您的位置:首页 > 新闻 > 热点要闻 > 个人申请网站_海南省人民政府_企业推广宣传文案_卖网站链接

个人申请网站_海南省人民政府_企业推广宣传文案_卖网站链接

2025/4/19 14:57:51 来源:https://blog.csdn.net/liuhe2296044/article/details/147168901  浏览:    关键词:个人申请网站_海南省人民政府_企业推广宣传文案_卖网站链接
个人申请网站_海南省人民政府_企业推广宣传文案_卖网站链接

infoNCE

代码1:(样本格式为query_n个positive_n个hardnegative)

  • PairwiseModel并不是模型,而是连接model和loss的一个包装类。
  • PairwiseModel接收两种类型样本 【query + pos pair】or【query + pos + neg triplet】。

  • CrossEntropyLoss还可以传入label_smoothing=0.05,用于对比学习。label_smoothing = 0.3时,label_smoothing 的作用是把硬标签 [0, 0, 1, 0] 平滑成类似 [0.1, 0.1, 0.7, 0.1],从而使得 CrossEntropyLoss 不再只惩罚预测不对的类,还会对非目标类的概率也做约束,使模型更加平滑稳定、泛化更强。
  • AutoModelForEmbedding的pooling_method选择mean还是cls根据模型来定,如果模型训练的时候用cls向量当做句子表征,则用cls。否则则用mean。

代码2:(样本格式为query_positive,只有正样本,负样本为batch内其他样本)

import os
import torch.nn as nn
from datasets import load_dataset
from transformers import AutoTokenizer, AdamW, get_linear_schedule_with_warmup, TrainingArguments
from retrievals import AutoModelForEmbedding, RetrievalTrainer, RetrievalCollator, PairwiseModel
from retrievals.losses import ArcFaceAdaptiveMarginLoss, InfoNCE, SimCSE, TripletLoss
model_name_or_path: str = '../model/m3e-base'
# model_name_or_path: str = '../model/bge-small-zh-v1.5'
batch_size: int = 2
epochs: int = 3
#数据集会按照dev、train、test划分。具体有哪个,得print来看,再用split="dev"获取 dev的部分。
train_dataset = load_dataset("../../dataset/C-MTEB/T2Reranking", split="dev") #这个数据集并不是 query_positive格式,而是query_n个positive,因此需要更改
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
model = AutoModelForEmbedding.from_pretrained(model_name_or_path, pooling_method="mean")
train_model = PairwiseModel(model, loss_fn=InfoNCE(nn.CrossEntropyLoss(label_smoothing=0.05)))
optimizer = AdamW(train_model.parameters(), lr=5e-5)
num_train_steps = int(len(train_dataset) / batch_size * epochs)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0.05 * num_train_steps, num_training_steps=num_train_steps)
training_arguments = TrainingArguments(output_dir='./checkpoints',num_train_epochs=epochs,per_device_train_batch_size=batch_size,remove_unused_columns=False,logging_steps=50,
)
# 处理后会得到一个两个key的dict,每个value是一个包含dict(其中包含input_ids、token_type_ids、attention_mask)
dc=RetrievalCollator(tokenizer, keys=['query', 'positive'], max_lengths=[64, 128]) 
trainer = RetrievalTrainer(model=train_model,args=training_arguments,train_dataset=train_dataset,data_collator=dc, # 相当于 自定义collate_fn 函数
)
trainer.optimizer = optimizer
trainer.scheduler = scheduler
trainer.train()

参考:

动手学习RAG: moka-ai/m3e 模型微调deepspeed与对比学习_m3e模型微调-CSDN博客

https://github.com/LongxingTan/open-retrievals?tab=readme-ov-file

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com