您的位置:首页 > 文旅 > 旅游 > 安陆网站设计_discuz应用中心_网站宣传推广策划_2022新闻热点10条

安陆网站设计_discuz应用中心_网站宣传推广策划_2022新闻热点10条

2024/10/6 1:35:28 来源:https://blog.csdn.net/qq_42875127/article/details/142319609  浏览:    关键词:安陆网站设计_discuz应用中心_网站宣传推广策划_2022新闻热点10条
安陆网站设计_discuz应用中心_网站宣传推广策划_2022新闻热点10条

文章目录

  • Link
  • Task: Few-shot Classification
  • Baseline
    • Simple—transfer learning
    • Medium — FO-MAML
    • Strong — MAML

Link

Kaggle

Task: Few-shot Classification

The Omniglot dataset

  • background set: 30 alphabets
  • evaluation set: 20 alphabets
  • Problem setup: 5-way 1-shot classification

Omniglot数据集

  • 背景集:30个字母
  • 评估集:20个字母
  • 问题设置:5-way 1-shot分类
    Definition of support set and query set

Baseline

Simple—transfer learning

直接把sample code运行即可

  • traing:
    对随机选择的5个任务进行正常分类训练验证/测试
  • validation / testing:
    对五个 Support Images 进行微调,并对Query Images进行推理

Slover首先从训练集中选择5个任务,然后对选择的5个任务进行正常分类训练。在推理中,模型在支持集support set图像上微调inner_train_step步骤,然后在查询集Query Set图像上进行推理。
为了与元学习Slover保持一致,基本Slover具有与元学习Slover完全相同的输入输出格式

def BaseSolver(model,optimizer,x,n_way,k_shot,q_query,loss_fn,inner_train_step=1,inner_lr=0.4,train=True,return_labels=False,
):criterion, task_loss, task_acc = loss_fn, [], []labels = []for meta_batch in x:# Get datasupport_set = meta_batch[: n_way * k_shot]query_set = meta_batch[n_way * k_shot :]if train:""" training loop """# Use the support set to calculate losslabels = create_label(n_way, k_shot).to(device)logits = model.forward(support_set)loss = criterion(logits, labels)task_loss.append(loss)task_acc.append(calculate_accuracy(logits, labels))else:""" validation / testing loop """# First update model with support set images for `inner_train_step` stepsfast_weights = OrderedDict(model.named_parameters())for inner_step in range(inner_train_step):# Simply trainingtrain_label = create_label(n_way, k_shot).to(device)logits = model.functional_forward(support_set, fast_weights)loss = criterion(logits, train_label)grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)# Perform SGDfast_weights = OrderedDict((name, param - inner_lr * grad)for ((name, param), grad) in zip(fast_weights.items(), grads))if not return_labels:""" validation """val_label = create_label(n_way, q_query).to(device)logits = model.functional_forward(query_set, fast_weights)loss = criterion(logits, val_label)task_loss.append(loss)task_acc.append(calculate_accuracy(logits, val_label))else:""" testing """logits = model.functional_forward(query_set, fast_weights)labels.extend(torch.argmax(logits, -1).cpu().numpy())if return_labels:return labelsbatch_loss = torch.stack(task_loss).mean()task_acc = np.mean(task_acc)if train:# Update modelmodel.train()optimizer.zero_grad()batch_loss.backward()optimizer.step()return batch_loss, task_acc

Medium — FO-MAML

FOMAML(First-Order MAML)是MAML(Model-Agnostic Meta-Learning)的一种简化版本。MAML是一种元学习算法,旨在通过训练模型使其能够在少量新数据上快速适应新任务。FOMAML通过忽略二阶导数来简化MAML的计算过程,从而提高计算效率。它在许多情况下表现良好,尤其是在计算资源有限的情况下。然而,它也可能在某些任务上表现不如完整的MAML。

MAML的核心思想是通过在多个任务上进行训练,使得模型能够在面对新任务时,只需少量数据就能快速收敛到一个好的参数配置。具体来说,MAML的训练过程包括两个层次的优化:

  • 内层优化(Inner Loop):在每个任务上进行少量的梯度更新,以适应该任务。

  • 外层优化(Outer Loop):在所有任务上进行梯度更新,以优化模型的初始参数,使得模型在面对新任务时能够快速适应。

""" Inner Loop Update """
grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=False) # create_graph=False:这个参数表示在计算梯度时不创建计算图。在FOMAML中,我们只关心一阶导数,因此不需要创建计算图fast_weights = OrderedDict((name, param - inner_lr * grad)for ((name, param), grad) in zip(fast_weights.items(), grads))""" Outer Loop Update """# TODO: Finish the outer loop update# raise NotimplementedErrormeta_batch_loss.backward()optimizer.step()

Strong — MAML

""" Inner Loop Update """
grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)fast_weights = OrderedDict((name, param - inner_lr * grad)for ((name, param), grad) in zip(fast_weights.items(), grads))""" Outer Loop Update """# TODO: Finish the outer loop update# raise NotimplementedErrormeta_batch_loss.backward()optimizer.step()

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com