您的位置:首页 > 游戏 > 游戏 > 自助建站吧_怎么做游戏代理_网站制作河南_网站推广方案范例

自助建站吧_怎么做游戏代理_网站制作河南_网站推广方案范例

2024/9/23 11:03:27 来源:https://blog.csdn.net/qq_44426403/article/details/142345158  浏览:    关键词:自助建站吧_怎么做游戏代理_网站制作河南_网站推广方案范例
自助建站吧_怎么做游戏代理_网站制作河南_网站推广方案范例

文章目录

  • 0、Prefix-Tuning基本原理
  • 1、Prefix-Tuning代码实战
    • 1.1、导包
    • 1.2、加载数据集
    • 1.3、数据集处理
    • 1.4、创建模型
    • 1.5、Prefix-Tuning
      • 1.5.1、配置文件
      • 1.5.2、创建模型
    • 1.6、配置训练参数
    • 1.7、创建训练器
    • 1.8、模型训练
    • 1.9、模型推理

0、Prefix-Tuning基本原理

 Prefix-Tuning的思想相较于Prompt-Tuning和P-Tuning,Prefix-Tuning不再将Prompt加在输入的Embedding层,而是将其作为可学习的前缀放置在transformer模型的中的每一层中,具体表现为past_key_values。
在这里插入图片描述

 past_key_values是transformer模型中历史计算的key和value的结果,最早是用于生成类模型解码加速,解码逻辑是根据历史输入,每次预测一个新的token,然后将新的token加入输入,再预测下一个token。这个过程中,会存在大量的重复计算, 因此可以将key和value的计算结果缓存,作为past_key_values输入到下一次的计算中,称之为kv_cache>
 Prefix-Tuning中,就是通过past_key_values的形式将可学习的部分放到了模型的每一层,这部分的内容称之为前缀。
在这里插入图片描述

1、Prefix-Tuning代码实战

1.1、导包

from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer

1.2、加载数据集

ds = Dataset.load_from_disk("../Data/alpaca_data_zh/")
ds

1.3、数据集处理

tokenizer = AutoTokenizer.from_pretrained("../Model/bloom-389m-zh")
tokenizer
def process_func(example):MAX_LENGTH = 256input_ids, attention_mask, labels = [], [], []instruction = tokenizer("\n".join(["Human: " + example["instruction"], example["input"]]).strip() + "\n\nAssistant: ")response = tokenizer(example["output"] + tokenizer.eos_token)input_ids = instruction["input_ids"] + response["input_ids"]attention_mask = instruction["attention_mask"] + response["attention_mask"]labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]if len(input_ids) > MAX_LENGTH:input_ids = input_ids[:MAX_LENGTH]attention_mask = attention_mask[:MAX_LENGTH]labels = labels[:MAX_LENGTH]return {"input_ids": input_ids,"attention_mask": attention_mask,"labels": labels}tokenized_ds = ds.map(process_func, remove_columns=ds.column_names)
tokenized_ds

1.4、创建模型

model = AutoModelForCausalLM.from_pretrained("../Model/bloom-389m-zh", low_cpu_mem_usage=True)

1.5、Prefix-Tuning

1.5.1、配置文件

from peft import PrefixTuningConfig, get_peft_model, TaskTypeconfig = PrefixTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=10, prefix_projection=True,encoder_hidden_size=1024)
config

1.5.2、创建模型

model = get_peft_model(model, config)
model.print_trainable_parameters()

1.6、配置训练参数

args = TrainingArguments(output_dir="./chatbot",per_device_train_batch_size=1,gradient_accumulation_steps=8,logging_steps=10,num_train_epochs=1
)

1.7、创建训练器

trainer = Trainer(model=model,args=args,train_dataset=tokenized_ds,data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
)

1.8、模型训练

trainer.train()

1.9、模型推理

model = model.cuda()
ipt = tokenizer("Human: {}\n{}".format("考试有哪些技巧?", "").strip() + "\n\nAssistant: ", return_tensors="pt").to(model.device)
tokenizer.decode(model.generate(**ipt, max_length=128, do_sample=True)[0], skip_special_tokens=True)

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com