# Click to download the full code: (link from the original blog post)
import ctranslate2
import sentencepiece as spm
import os
def tokenize(sp, queries):
    """Encode text into SentencePiece subword tokens.

    Args:
        sp: a loaded sentencepiece.SentencePieceProcessor.
        queries: a single string or a list of strings.

    Returns:
        A list of token lists, one per input string. A single string is
        wrapped so the result shape is always list-of-lists, as expected
        by ctranslate2.Translator.translate_batch.
    """
    if isinstance(queries, list):
        return sp.encode(queries, out_type=str)
    return [sp.encode(queries, out_type=str)]


def detokenize(sp, queries, lang):
    """Decode token sequences and strip the leading target-language tag.

    Args:
        sp: a loaded sentencepiece.SentencePieceProcessor.
        queries: a list of token sequences to decode.
        lang: target language code; the decoded text is assumed to begin
            with a tag of length ``len(lang) + 1``, which is stripped.

    Returns:
        A list of decoded strings with the language prefix removed.
    """
    translation = sp.decode(queries)
    prefix_len = len(lang) + 1
    return [''.join(query)[prefix_len:] for query in translation]


# Allow duplicate OpenMP runtimes to coexist; works around the common
# "libiomp5 already initialized" crash when several native libraries
# each bundle their own OpenMP.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# Load the SentencePiece model that ships with the M2M-100 418M checkpoint.
# os.path.join keeps the path portable: the original hard-coded a Windows
# backslash path, in which "\s" is also an invalid escape sequence
# (DeprecationWarning today, a SyntaxError in future Python versions).
sp = spm.SentencePieceProcessor()
sp.load(os.path.join("m2m100_418m", "sentencepiece.model"))
# M2M-100 (as converted for CTranslate2) marks the target language with a
# special vocabulary token; for Chinese that token is "__zh__".
zh_token_id = "__zh__"
print(f"target token: {zh_token_id} ")
# English sentences to translate.
queries = [
    "Hello world! ",
    "My name is Jason Hong.",
    "In some cases, you might want to temporarily unload the model and load it back later. ",
    "The Translator object provides the methods unload_model and load_model for this purpose. ",
    "The model can be either fully unloaded or moved to the CPU memory. ",
]

# CTranslate2 translator backed by the converted M2M-100 418M model directory.
translator = ctranslate2.Translator("m2m100_418m")

queries_tokenized = tokenize(sp, queries)
print(f"queries_tokenized: {queries_tokenized} ")

results = translator.translate_batch(
    source=queries_tokenized,
    # Force Chinese as the target language for every sentence in the batch.
    target_prefix=[[zh_token_id]] * len(queries),
    beam_size=5,
    max_batch_size=1024,
    return_alternatives=False,
    disable_unk=True,
    replace_unknowns=True,
    repetition_penalty=3,
)

# Detokenize the best hypothesis of each result.
# NOTE(review): detokenize() defined above is not used here, so the target
# prefix token presumably remains in the decoded text — verify the output.
translations = [sp.decode(hyp.hypotheses[0]) for hyp in results]

# Print the translated sentences.
for line in translations:
    print(line)