您的位置:首页 > 健康 > 美食 > 【Pytorch】RNN for Name Classification

【Pytorch】RNN for Name Classification

2024/12/26 1:25:51 来源:https://blog.csdn.net/bryant_meng/article/details/140344709  浏览:    关键词:【Pytorch】RNN for Name Classification

在这里插入图片描述

参考学习来自:

  • https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html
  • RNN完成姓名分类
  • https://download.pytorch.org/tutorial/data.zip

导入库

import glob # 用于查找符合规则的文件名
import os
import unicodedata
import string
import torch
import torch.nn as nn
import torch.optim as optim
import random

GPU 配置

"Device configuration"
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

导入数据

在这里插入图片描述

"导入数据"
def findFiles(path):
    """Return the list of filesystem paths matching the glob pattern ``path``."""
    return glob.glob(path)


# List every name file found under ./data/names.
for i in findFiles('./data/names/*.txt'):
    print(i)

output

./data/names/Korean.txt
./data/names/Irish.txt
./data/names/Portuguese.txt
./data/names/Vietnamese.txt
./data/names/Czech.txt
./data/names/Russian.txt
./data/names/Scottish.txt
./data/names/German.txt
./data/names/Polish.txt
./data/names/Spanish.txt
./data/names/English.txt
./data/names/French.txt
./data/names/Japanese.txt
./data/names/Dutch.txt
./data/names/Greek.txt
./data/names/Chinese.txt
./data/names/Italian.txt
./data/names/Arabic.txt

查看名字所有字符

# Alphabet used for the one-hot encoding: all ASCII letters plus punctuation.
all_letters = string.ascii_letters + " .,;'"
print(all_letters)  # abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ .,;'
# Size of the alphabet, i.e. the width of each one-hot vector.
n_letters = len(all_letters)
print(n_letters)  # 57

字符都转化为 ASCII 形式

# Turn a Unicode string to plain ASCII
# Thanks to https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    """Strip diacritics from ``s`` and drop every character not in ``all_letters``."""
    kept = []
    # NFD splits accented characters into a base character plus combining marks.
    for ch in unicodedata.normalize('NFD', s):
        # Category 'Mn' identifies the combining marks; they are discarded.
        if unicodedata.category(ch) == 'Mn':
            continue
        if ch in all_letters:
            kept.append(ch)
    return ''.join(kept)
print(unicodeToAscii('Ślusàrski'))  # Slusarski

category_lines = {}  # names of each language, keyed by language
all_categories = []  # names of all the languages
# Read a file and return each of its lines
def readLines(filename):
    """Read ``filename`` and return its lines converted to plain ASCII.

    The file is decoded as UTF-8, surrounding whitespace is stripped, the
    content is split on newlines and every line is passed through
    ``unicodeToAscii``.
    """
    # A context manager closes the file handle as soon as the read is done.
    with open(filename, encoding='utf-8') as f:
        lines = f.read().strip().split('\n')
    return [unicodeToAscii(line) for line in lines]


# Read every file in turn and store the names it contains
for filename in findFiles('./data/names/*.txt'):
    # basename returns the bare file name; splitext separates the file name
    # from its extension, leaving the language name.
    category = os.path.splitext(os.path.basename(filename))[0]
    all_categories.append(category)
    lines = readLines(filename)
    category_lines[category] = lines

n_categories = len(all_categories)
# There are names for 18 languages in total
print(len(category_lines))  # 18
# Inspect the available languages
print(category_lines.keys())
"""
dict_keys(['Korean', 'Irish', 'Portuguese', 'Vietnamese', 'Czech', 'Russian', 'Scottish', 'German', 'Polish', 
'Spanish', 'English', 'French', 'Japanese', 'Dutch', 'Greek', 'Chinese', 'Italian', 'Arabic'])
"""
print(category_lines['English'][:5])  # ['Abbas', 'Abbey', 'Abbott', 'Abdi', 'Abel']

名字转化为 Tensor

"将Name转为Tensor"
# Find letter index from all_letters, e.g. "a"=0
def letterToIndex(letter):
    """Return the position of ``letter`` in ``all_letters`` (-1 when absent).

    Alphabet: abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ .,;'
    """
    return all_letters.find(letter)


# Turn a whole name into a (line_length x 1 x n_letters) tensor
def lineToTensor(line):
    """One-hot encode ``line`` as a tensor of shape (len(line), 1, n_letters)."""
    tensor = torch.zeros(len(line), 1, n_letters)
    for li, letter in enumerate(line):
        index = letterToIndex(letter)
        # letterToIndex returns -1 for a character outside all_letters; using
        # -1 as an index would silently mark the last letter ("'"), so such a
        # character is left as an all-zero row instead.
        if index >= 0:
            tensor[li][0][index] = 1
    return tensor


print(lineToTensor('Bryant').shape)  # torch.Size([6, 1, 57])
print(lineToTensor('Mike').shape)  # torch.Size([4, 1, 57])
print(lineToTensor('Mike'))

output

tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.]],[[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.]],[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.]],[[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.]]])

定义 RNN 分类网络

"定义网络"
class RNN_NAME(nn.Module):
    """Multi-layer GRU followed by a linear layer that classifies a sequence.

    Args:
        input_size: number of features per time step (here the 57 letters).
        hidden_size: size of the GRU output / hidden state.
        output_size: number of classes (here 18 languages).
        layers: number of stacked GRU layers.
        batch_size: initial batch size used by ``init_hidden``; ``forward``
            overwrites it with the batch size of each input.
    """

    def __init__(self, input_size, hidden_size, output_size, layers, batch_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.layers = layers
        self.batch_size = batch_size
        self.gru = nn.GRU(
            input_size=self.input_size,
            hidden_size=self.hidden_size,
            num_layers=self.layers,
        )
        self.FC = nn.Linear(self.hidden_size, self.output_size)

    def init_hidden(self):
        """Return a zero hidden state of shape (layers, batch_size, hidden_size).

        The tensor is created on the device (and with the dtype) of the model
        parameters, so the class no longer needs a module-level ``device``.
        """
        weight = next(self.parameters())
        return torch.zeros(
            self.layers,
            self.batch_size,
            self.hidden_size,
            device=weight.device,
            dtype=weight.dtype,
        )

    def forward(self, x):
        """Return class scores of shape (batch, output_size).

        ``x`` is indexed as (seq_len, batch, input_size): dimension 1 is taken
        as the batch size and the last time step of the GRU output is fed to
        the linear layer.
        """
        self.batch_size = x.size(1)
        self.hidden = self.init_hidden()  # start every sequence from a zero memory
        rnn_out, self.hidden = self.gru(x, self.hidden)
        out = self.FC(rnn_out[-1])
        return out

RNN 网络初始化

"初始化"
# Hyper-parameters of the classifier
input_size = n_letters
output_size = len(all_categories)  # 18 classes
hidden_size = 128
layers = 3
batch_size = 1  # names differ in length, so batch_size is set to 1
# Build the network, then move it to the selected device
rnn_name = RNN_NAME(input_size, hidden_size, output_size, layers, batch_size)
rnn_name = rnn_name.to(device)
# Check the network input/output with a single name
input_data = lineToTensor('Mike')
out_data = rnn_name(input_data.to(device))
print(out_data.size())
# torch.Size([1, 18])
print(out_data)
"""
tensor([[ 0.0051, -0.0081, -0.0906,  0.1054,  0.0679, -0.0187, -0.0497, -0.0349,-0.0466,  0.0004, -0.0099, -0.0940, -0.0201, -0.0694, -0.0542, -0.0807,0.0005,  0.1373]], grad_fn=<AddmmBackward0>)
"""

准备训练数据集

"准备训练数据"
# Turn the network output into the result that gets displayed
def categoryFromOutput(output):
    """Return ``(category_name, category_index)`` for the top score in ``output``."""
    _, top_i = output.topk(1)
    category_i = top_i[0].item()
    return all_categories[category_i], category_i


# Get a random training example
def randomChoice(l):
    """Return one randomly chosen element of the sequence ``l``."""
    return l[random.randint(0, len(l) - 1)]


def randomTrainingExample():
    """Pick a random name of a random language and build its training tensors.

    Returns:
        ``(category, line, category_tensor, line_tensor)``: the language name,
        the chosen name, the language index as a 1-element long tensor and the
        one-hot encoding of the name.
    """
    category = randomChoice(all_categories)  # pick a random language
    line = randomChoice(category_lines[category])  # pick a random name of that language
    category_tensor = torch.tensor([all_categories.index(category)]).long()  # class index as a tensor
    line_tensor = lineToTensor(line)  # name as a tensor
    return category, line, category_tensor, line_tensor


for i in range(10):
    category, line, category_tensor, line_tensor = randomTrainingExample()
    print('category = ', category, '/ line = ', line)
"""
category =  Scottish / line =  Miller
category =  Irish / line =  O'Mooney
category =  Dutch / line =  Kools
category =  French / line =  Bouchard
category =  French / line =  Masson
category =  English / line =  May
category =  Spanish / line =  Holguin
category =  Portuguese / line =  Moreno
category =  Greek / line =  Nomikos
category =  Czech / line =  Cernohous
"""

计算预测精度

"计算精度"
def get_accuracy(logit, target):
    """Return the percentage (0-100) of predictions in ``logit`` that match ``target``.

    The prediction of each row is the index of its largest score along
    dimension 1. The count of correct predictions is divided by the number of
    targets rather than by a module-level ``batch_size``, so the result is
    right for any batch size. An empty ``target`` yields 0.0.
    """
    total = target.numel()
    if total == 0:
        return 0.0
    predictions = torch.max(logit, 1)[1].view(target.size())
    corrects = (predictions == target).sum()
    accuracy = 100.0 * corrects / total
    return accuracy.item()

网络训练

"训练"
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(rnn_name.parameters(), lr=0.005)
N_EPHOCS = 100
# 20074 samples in total
n_iters = 2000  # number of names trained on per epoch
print_every = 200  # every print_every iterations: print a result, back-propagate once and update the weights
for epoch in range(N_EPHOCS):
    train_running_loss = 0.0
    loss = torch.tensor([0.0]).float().to(device)
    train_count = 0
    rnn_name.train()
    # training round (loop variable named `step` to avoid shadowing the builtin `iter`)
    for step in range(1, n_iters + 1):
        # fetch one sample
        category, line, category_tensor, line_tensor = randomTrainingExample()
        category_tensor = category_tensor.to(device)
        line_tensor = line_tensor.to(device)
        # forward pass; RNN_NAME.forward re-creates the hidden state itself,
        # so no explicit reset is needed before the call
        outputs = rnn_name(line_tensor)
        guess, guess_i = categoryFromOutput(outputs)
        if guess == category:
            train_count = train_count + 1
        # accumulate the loss until the next update
        loss = loss + criterion(outputs, category_tensor)
        if step % print_every == 0:
            correct = '✓' if guess == category else '✗ ({})'.format(category)
            print('{} /{} {}'.format(line, guess, correct))
            # back-propagate the loss accumulated over print_every samples
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_running_loss = train_running_loss + loss.detach().item()
            loss = torch.tensor([0.0]).float().to(device)
    # report the accuracy of this epoch
    train_acc = train_count / n_iters * 100
    print('=======')
    print('Epoch : {:0>2d} | Loss : {:<6.4f} | Train Accuracy : {:<6.2f}%'.format(epoch, train_running_loss/n_iters, train_acc))
    print('=======')

"验证"
test_count = 0
test_num = 500  # evaluate on 500 names
rnn_name.eval()
# Inference only: disable autograd so no graph is built for each sample.
with torch.no_grad():
    for test_iter in range(1, test_num + 1):
        # NOTE(review): samples are drawn with randomTrainingExample, i.e. from
        # the same pool used for training, not from a held-out set.
        category, line, category_tensor, line_tensor = randomTrainingExample()
        outputs = rnn_name(line_tensor.to(device))
        guess, guess_i = categoryFromOutput(outputs)
        if guess == category:
            test_count = test_count + 1
print('Test Accuracy : {:<6.4f}%'.format(test_count/test_num*100))
"""
=======
Epoch : 71 | Loss : 0.2698 | Train Accuracy : 92.35 %
=======
Konda /Japanese ✓
Seighin /Irish ✓
Sedmik /Czech ✓
Jestkov /Russian ✓
So /Korean ✓
Laar /Dutch ✓
Mazza /Italian ✓
Abl /Czech ✓
Doyle /Irish ✓
Skwor /Czech ✓
=======
Epoch : 72 | Loss : 0.2583 | Train Accuracy : 91.40 %
=======
Martoyas /Greek ✗ (Russian)
Hallett /English ✓
Bermudez /Spanish ✓
Bonhomme /French ✓
Lam /Chinese ✗ (Vietnamese)
Kourempes /Greek ✓
Matsuki /Japanese ✓
Sokolofsky /Polish ✓
Kuipers /Dutch ✓
Aller /Dutch ✓
=======
Epoch : 73 | Loss : 0.2564 | Train Accuracy : 91.65 %
=======
Xiao /Chinese ✓
Lian /Chinese ✓
Prill /Czech ✓
Glatter /Czech ✓
Bureau /French ✓
Vallance /French ✗ (English)
Johnstone /Scottish ✓
Kassab /Arabic ✓
Atiyeh /Arabic ✓
Luong /Vietnamese ✓
=======
Epoch : 74 | Loss : 0.2610 | Train Accuracy : 91.60 %
=======
Poirier /French ✓
Hahn /German ✓
Botros /Arabic ✓
Adrichem /Dutch ✓
Rios /Portuguese ✗ (Spanish)
Zangari /Italian ✓
Mccallum /Scottish ✓
Ferro /Portuguese ✓
Kim /Vietnamese ✗ (Korean)
Chang /Korean ✓
=======
Epoch : 75 | Loss : 0.2792 | Train Accuracy : 90.85 %
=======
Monte /Italian ✓
Han /Vietnamese ✗ (Korean)
Ying /Chinese ✓
O'Driscoll /Irish ✓
Poirier /French ✓
Tsuruya /Japanese ✓
Ralph /Czech ✗ (English)
Bonheur /French ✓
Romagna /Italian ✓
Krakowski /Polish ✓
=======
Epoch : 76 | Loss : 0.3000 | Train Accuracy : 90.85 %
=======
Fay /French ✓
Plisek /Czech ✓
Travere /French ✓
Moon /Korean ✓
Aalst /Dutch ✓
Sassa /Japanese ✓
Holzmann /German ✓
Kruessel /Czech ✓
Truong /Vietnamese ✓
Kassab /Arabic ✓
=======
Epoch : 77 | Loss : 0.2602 | Train Accuracy : 92.00 %
=======
Janda /Polish ✓
Doan /Vietnamese ✓
O'Reilly /Irish ✓
Torres /Portuguese ✓
Esipovich /Russian ✓
Johnston /Scottish ✓
Viteri /Spanish ✗ (Italian)
Vandroogenbroeck /Dutch ✓
Montero /Spanish ✓
Wan /Chinese ✓
=======
Epoch : 78 | Loss : 0.2460 | Train Accuracy : 92.35 %
=======
Ha /Korean ✗ (Vietnamese)
Christie /Scottish ✓
Pereira /Portuguese ✓
Kuiper /Dutch ✓
Bautista /Spanish ✓
Takewaki /Japanese ✓
Ying /Chinese ✓
Sokoloff /Polish ✓
Raghailligh /Irish ✓
Edgell /English ✓
=======
Epoch : 79 | Loss : 0.2640 | Train Accuracy : 91.60 %
=======
Neal /English ✓
Kafka /Czech ✓
Hay /Scottish ✓
Pelletier /French ✓
Wilchek /Czech ✓
Marchesi /Italian ✓
Michel /Spanish ✓
Ling /Chinese ✓
Zou /Chinese ✓
Freitas /Portuguese ✓
=======
Epoch : 80 | Loss : 0.2579 | Train Accuracy : 92.60 %
=======
Kelly /Scottish ✓
Chou /Korean ✓
Unsworth /English ✓
Larue /French ✓
Tron /Vietnamese ✓
Herodes /Czech ✓
Fremut /Czech ✓
Thai /Vietnamese ✓
Clark /Irish ✗ (Scottish)
Halenkov /Russian ✓
=======
Epoch : 81 | Loss : 0.2453 | Train Accuracy : 92.40 %
=======
Than /Vietnamese ✓
Pfeifer /Czech ✓
Kennedy /Scottish ✓
Klerkx /Dutch ✓
Freitas /Portuguese ✓
Miller /Scottish ✓
Deeb /Arabic ✓
Almetov /Russian ✓
Li /Korean ✓
Perez /Spanish ✓
=======
Epoch : 82 | Loss : 0.2579 | Train Accuracy : 91.85 %
=======
Horiatis /Greek ✓
Clowes /English ✓
Santos /Portuguese ✓
Mochan /Irish ✓
Sokolof /Polish ✓
Ogterop /Dutch ✓
Schwartz /Czech ✓
Abadi /Arabic ✓
Winogrodzki /Polish ✓
Bazzhin /Russian ✓
=======
Epoch : 83 | Loss : 0.2397 | Train Accuracy : 92.25 %
=======
Ybarra /Spanish ✓
Jacobson /English ✓
Shi /Chinese ✓
Berg /Dutch ✓
Araya /Spanish ✓
Banos /Greek ✓
Winograd /Polish ✓
Schneider /German ✗ (Czech)
Calligaris /Italian ✓
Hasbulatov /Russian ✓
=======
Epoch : 84 | Loss : 0.2466 | Train Accuracy : 92.25 %
=======
Felkerzam /Polish ✗ (Russian)
Gomolka /Polish ✓
Belo /Portuguese ✓
Santos /Portuguese ✓
Rim /Korean ✓
Solomon /French ✓
Agnellutti /Italian ✓
Santana /Portuguese ✓
Ramaker /Dutch ✓
Palmisano /Italian ✓
=======
Epoch : 85 | Loss : 0.2390 | Train Accuracy : 92.80 %
=======
Moraitopoulos /Greek ✓
Rosenberger /German ✓
Torres /Portuguese ✓
Sabbagh /Arabic ✓
Aggelen /Dutch ✓
Chieu /Chinese ✓
Ukhtomsky /Russian ✓
Kouros /Greek ✓
Portyansky /Russian ✓
Dinh /Vietnamese ✓
=======
Epoch : 86 | Loss : 0.2173 | Train Accuracy : 93.15 %
=======
Jaskulski /Polish ✓
Mendez /Spanish ✓
Smyth /English ✓
Mo /Korean ✓
Seegers /Dutch ✓
Elford /English ✓
Lau /Chinese ✓
Sargent /French ✓
Zeman /Czech ✓
Ohme /German ✓
=======
Epoch : 87 | Loss : 0.2189 | Train Accuracy : 93.00 %
=======
Gwang  /Korean ✓
Zhernakov /Russian ✓
Quraishi /Arabic ✓
Asker /Arabic ✓
Franco /Portuguese ✓
Bacon /Czech ✓
Weineltk /Czech ✓
Thuy /Vietnamese ✓
Close /Greek ✓
Xiang /Chinese ✓
=======
Epoch : 88 | Loss : 0.2497 | Train Accuracy : 91.95 %
=======
Basara /Arabic ✓
Esteves /Portuguese ✓
Metz /German ✓
Lopez /Spanish ✓
Bukoski /Polish ✓
Rao /Italian ✓
Klerk /Dutch ✓
Espinoza /Spanish ✓
Wehunt /German ✓
Giehl /German ✓
=======
Epoch : 89 | Loss : 0.2454 | Train Accuracy : 92.70 %
=======
Deniau /French ✓
Bover /Spanish ✗ (Italian)
Soho /Japanese ✓
Ron /Korean ✓
Dolezal /Czech ✓
Yamamura /Japanese ✓
Ramsay /Scottish ✓
Fromm /German ✓
Huerta /Spanish ✓
Colman /Irish ✓
=======
Epoch : 90 | Loss : 0.2432 | Train Accuracy : 92.55 %
=======
Fei /Chinese ✓
Dinh /Vietnamese ✓
Krawiec /Czech ✓
Urbanek /Czech ✓
Yan /Chinese ✓
Giehl /German ✓
Riain /Irish ✓
Zhestkov /Russian ✓
Almetiev /Russian ✓
Deniel /French ✓
=======
Epoch : 91 | Loss : 0.2318 | Train Accuracy : 92.45 %
=======
Macleod /Scottish ✓
Savatier /French ✓
Chong /Chinese ✗ (Korean)
Prchal /Czech ✓
Tsoumada /Greek ✓
Agosti /Italian ✓
Mojjis /Czech ✓
Mitsui /Japanese ✓
San nicolas /Spanish ✓
Kokoris /Greek ✓
=======
Epoch : 92 | Loss : 0.2163 | Train Accuracy : 92.25 %
=======
Marubeni /Japanese ✓
Suh /Korean ✓
Klerks /Dutch ✓
Roma /Spanish ✓
Lac /Vietnamese ✓
Khoury /Arabic ✓
Hoefler /German ✓
Gibson /Scottish ✓
Jasso /Spanish ✓
Oirschotten /Dutch ✓
=======
Epoch : 93 | Loss : 0.2329 | Train Accuracy : 92.90 %
=======
Kikuchi /Japanese ✓
Blecher /German ✓
Ying /Chinese ✓
Nahas /Arabic ✓
Araujo /Portuguese ✓
Koemans /Dutch ✓
Durant /French ✓
Martzenyuk /Russian ✓
Petru /Czech ✓
Wood /Scottish ✓
=======
Epoch : 94 | Loss : 0.2080 | Train Accuracy : 93.35 %
=======
Vespa /Italian ✓
Kowalczyk /Polish ✓
Nahas /Arabic ✓
Tresler /German ✓
Caivano /Italian ✓
Pho /Vietnamese ✓
Dael /Dutch ✓
Heidl /Czech ✓
Totolos /Greek ✓
Adlersflugel /German ✓
=======
Epoch : 95 | Loss : 0.2400 | Train Accuracy : 92.45 %
=======
Zwolenksy /Czech ✓
Chermak /Czech ✓
Sgro /Italian ✓
Lestrange /French ✓
Hwang /Korean ✓
Barros /Spanish ✗ (Portuguese)
Boyle /Scottish ✓
Schmeling /German ✓
Poplawski /Polish ✓
Pantelas /Greek ✓
=======
Epoch : 96 | Loss : 0.2150 | Train Accuracy : 93.10 %
=======
Asher /English ✓
Sai /Vietnamese ✓
Hakimi /Arabic ✓
Sheng /Chinese ✓
Lebedevich /Russian ✓
Castellano /Spanish ✓
Battaglia /Italian ✓
Cardozo /Portuguese ✓
Sinagra /Italian ✓
Bumgarner /German ✓
=======
Epoch : 97 | Loss : 0.2359 | Train Accuracy : 92.80 %
=======
Lyes /English ✓
Do /Vietnamese ✓
Busch /German ✓
Shimura /Japanese ✓
O'Connor /Irish ✓
Geiger /German ✓
Reier /German ✓
Paris /French ✓
Sokolof /Polish ✓
Limarev /Russian ✓
=======
Epoch : 98 | Loss : 0.2369 | Train Accuracy : 92.55 %
=======
Osladil /Czech ✓
Roy /French ✓
Ortega /Spanish ✓
Ruzzier /Italian ✓
Chou /Korean ✓
Mateus /Portuguese ✓
Ibanez /Spanish ✓
Cabello /Spanish ✓
Duguay /French ✓
Yi /Korean ✓
=======
Epoch : 99 | Loss : 0.2192 | Train Accuracy : 92.65 %
=======
Test Accuracy : 94.8000%
"""

精度不足 95%

下面写写测试流程

"单独测试"
# Classify one hand-written string and show the raw scores and the prediction.
input_data = lineToTensor('WMN').to(device)
out_data = rnn_name(input_data)
print(out_data)
print(categoryFromOutput(out_data))
"""
tensor([[ 5.7034,  1.0659, -8.1210,  0.5218, -2.5818,  0.1362,  7.4420,  3.9547,1.0824, -7.2981,  4.0036, -2.2944,  1.6835, -3.5377, -7.6178,  9.8920,-0.5789, -1.0545]], device='cuda:0', grad_fn=<AddmmBackward0>)
('Chinese', 15)
"""
# Same check with a second name.
input_data = lineToTensor('Huang').to(device)
out_data = rnn_name(input_data)
print(out_data)
print(categoryFromOutput(out_data))
"""
tensor([[ 4.1755, -1.2160, -3.8454,  5.4587, -4.6412,  0.8137,  1.2896,  1.7429,-1.3145, -2.1634,  1.0224, -2.7052,  2.1592,  1.1909, -8.0874, 10.2988,-1.9252, -0.0729]], device='cuda:0', grad_fn=<AddmmBackward0>)
('Chinese', 15)
"""

完整代码

import glob # 用于查找符合规则的文件名
import os
import unicodedata
import string
import torch
import torch.nn as nn
import torch.optim as optim
import random

"Device configuration"
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

"导入数据"
def findFiles(path):
    """Return the list of filesystem paths matching the glob pattern ``path``."""
    return glob.glob(path)


# List every name file found under ./data/names.
for i in findFiles('./data/names/*.txt'):
    print(i)
"""
./data/names/Korean.txt
./data/names/Irish.txt
./data/names/Portuguese.txt
./data/names/Vietnamese.txt
./data/names/Czech.txt
./data/names/Russian.txt
./data/names/Scottish.txt
./data/names/German.txt
./data/names/Polish.txt
./data/names/Spanish.txt
./data/names/English.txt
./data/names/French.txt
./data/names/Japanese.txt
./data/names/Dutch.txt
./data/names/Greek.txt
./data/names/Chinese.txt
./data/names/Italian.txt
./data/names/Arabic.txt
"""
# Alphabet used for the one-hot encoding: all ASCII letters plus punctuation.
all_letters = string.ascii_letters + " .,;'"
print(all_letters)  # abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ .,;'
# Size of the alphabet, i.e. the width of each one-hot vector.
n_letters = len(all_letters)
print(n_letters)  # 57

# Turn a Unicode string to plain ASCII
# Thanks to https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    """Strip diacritics from ``s`` and drop every character not in ``all_letters``."""
    kept = []
    # NFD splits accented characters into a base character plus combining marks.
    for ch in unicodedata.normalize('NFD', s):
        # Category 'Mn' identifies the combining marks; they are discarded.
        if unicodedata.category(ch) == 'Mn':
            continue
        if ch in all_letters:
            kept.append(ch)
    return ''.join(kept)
print(unicodeToAscii('Ślusàrski'))  # Slusarski

category_lines = {}  # names of each language, keyed by language
all_categories = []  # names of all the languages
# Read a file and return each of its lines
def readLines(filename):
    """Read ``filename`` and return its lines converted to plain ASCII.

    The file is decoded as UTF-8, surrounding whitespace is stripped, the
    content is split on newlines and every line is passed through
    ``unicodeToAscii``.
    """
    # A context manager closes the file handle as soon as the read is done.
    with open(filename, encoding='utf-8') as f:
        lines = f.read().strip().split('\n')
    return [unicodeToAscii(line) for line in lines]


# Read every file in turn and store the names it contains
for filename in findFiles('./data/names/*.txt'):
    # basename returns the bare file name; splitext separates the file name
    # from its extension, leaving the language name.
    category = os.path.splitext(os.path.basename(filename))[0]
    all_categories.append(category)
    lines = readLines(filename)
    category_lines[category] = lines

n_categories = len(all_categories)
# There are names for 18 languages in total
print(len(category_lines))  # 18
# Inspect the available languages
print(category_lines.keys())
"""
dict_keys(['Korean', 'Irish', 'Portuguese', 'Vietnamese', 'Czech', 'Russian', 'Scottish', 'German', 'Polish', 
'Spanish', 'English', 'French', 'Japanese', 'Dutch', 'Greek', 'Chinese', 'Italian', 'Arabic'])
"""
print(category_lines['English'][:5])  # ['Abbas', 'Abbey', 'Abbott', 'Abdi', 'Abel']

"将Name转为Tensor"
# Find letter index from all_letters, e.g. "a"=0
def letterToIndex(letter):
    """Return the position of ``letter`` in ``all_letters`` (-1 when absent).

    Alphabet: abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ .,;'
    """
    return all_letters.find(letter)


# Turn a whole name into a (line_length x 1 x n_letters) tensor
def lineToTensor(line):
    """One-hot encode ``line`` as a tensor of shape (len(line), 1, n_letters)."""
    tensor = torch.zeros(len(line), 1, n_letters)
    for li, letter in enumerate(line):
        index = letterToIndex(letter)
        # letterToIndex returns -1 for a character outside all_letters; using
        # -1 as an index would silently mark the last letter ("'"), so such a
        # character is left as an all-zero row instead.
        if index >= 0:
            tensor[li][0][index] = 1
    return tensor


print(lineToTensor('Bryant').shape)  # torch.Size([6, 1, 57])
print(lineToTensor('Mike').shape)  # torch.Size([4, 1, 57])
print(lineToTensor('Mike'))
"""
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.]],[[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.]],[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.]],[[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.]]])
"""

"定义网络"
class RNN_NAME(nn.Module):
    """Multi-layer GRU followed by a linear layer that classifies a sequence.

    Args:
        input_size: number of features per time step (here the 57 letters).
        hidden_size: size of the GRU output / hidden state.
        output_size: number of classes (here 18 languages).
        layers: number of stacked GRU layers.
        batch_size: initial batch size used by ``init_hidden``; ``forward``
            overwrites it with the batch size of each input.
    """

    def __init__(self, input_size, hidden_size, output_size, layers, batch_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.layers = layers
        self.batch_size = batch_size
        self.gru = nn.GRU(
            input_size=self.input_size,
            hidden_size=self.hidden_size,
            num_layers=self.layers,
        )
        self.FC = nn.Linear(self.hidden_size, self.output_size)

    def init_hidden(self):
        """Return a zero hidden state of shape (layers, batch_size, hidden_size).

        The tensor is created on the device (and with the dtype) of the model
        parameters, so the class no longer needs a module-level ``device``.
        """
        weight = next(self.parameters())
        return torch.zeros(
            self.layers,
            self.batch_size,
            self.hidden_size,
            device=weight.device,
            dtype=weight.dtype,
        )

    def forward(self, x):
        """Return class scores of shape (batch, output_size).

        ``x`` is indexed as (seq_len, batch, input_size): dimension 1 is taken
        as the batch size and the last time step of the GRU output is fed to
        the linear layer.
        """
        self.batch_size = x.size(1)
        self.hidden = self.init_hidden()  # start every sequence from a zero memory
        rnn_out, self.hidden = self.gru(x, self.hidden)
        out = self.FC(rnn_out[-1])
        return out


"初始化"
# Hyper-parameters of the classifier
input_size = n_letters
output_size = len(all_categories)  # 18 classes
hidden_size = 128
layers = 3
batch_size = 1  # names differ in length, so batch_size is set to 1
# Build the network, then move it to the selected device
rnn_name = RNN_NAME(input_size, hidden_size, output_size, layers, batch_size)
rnn_name = rnn_name.to(device)
# Check the network input/output with a single name
input_data = lineToTensor('Mike')
out_data = rnn_name(input_data.to(device))
print(out_data.size())
# torch.Size([1, 18])
print(out_data)
"""
tensor([[ 0.0051, -0.0081, -0.0906,  0.1054,  0.0679, -0.0187, -0.0497, -0.0349,-0.0466,  0.0004, -0.0099, -0.0940, -0.0201, -0.0694, -0.0542, -0.0807,0.0005,  0.1373]], grad_fn=<AddmmBackward0>)
"""

"准备训练数据"
# Turn the network output into the result that gets displayed
def categoryFromOutput(output):
    """Return ``(category_name, category_index)`` for the top score in ``output``."""
    _, top_i = output.topk(1)
    category_i = top_i[0].item()
    return all_categories[category_i], category_i


# Get a random training example
def randomChoice(l):
    """Return one randomly chosen element of the sequence ``l``."""
    return l[random.randint(0, len(l) - 1)]


def randomTrainingExample():
    """Pick a random name of a random language and build its training tensors.

    Returns:
        ``(category, line, category_tensor, line_tensor)``: the language name,
        the chosen name, the language index as a 1-element long tensor and the
        one-hot encoding of the name.
    """
    category = randomChoice(all_categories)  # pick a random language
    line = randomChoice(category_lines[category])  # pick a random name of that language
    category_tensor = torch.tensor([all_categories.index(category)]).long()  # class index as a tensor
    line_tensor = lineToTensor(line)  # name as a tensor
    return category, line, category_tensor, line_tensor


for i in range(10):
    category, line, category_tensor, line_tensor = randomTrainingExample()
    print('category = ', category, '/ line = ', line)
"""
category =  Scottish / line =  Miller
category =  Irish / line =  O'Mooney
category =  Dutch / line =  Kools
category =  French / line =  Bouchard
category =  French / line =  Masson
category =  English / line =  May
category =  Spanish / line =  Holguin
category =  Portuguese / line =  Moreno
category =  Greek / line =  Nomikos
category =  Czech / line =  Cernohous
"""

"计算精度"
def get_accuracy(logit, target):
    """Return the percentage (0-100) of predictions in ``logit`` that match ``target``.

    The prediction of each row is the index of its largest score along
    dimension 1. The count of correct predictions is divided by the number of
    targets rather than by a module-level ``batch_size``, so the result is
    right for any batch size. An empty ``target`` yields 0.0.
    """
    total = target.numel()
    if total == 0:
        return 0.0
    predictions = torch.max(logit, 1)[1].view(target.size())
    corrects = (predictions == target).sum()
    accuracy = 100.0 * corrects / total
    return accuracy.item()


"训练"
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(rnn_name.parameters(), lr=0.005)
N_EPHOCS = 100
# 20074 samples in total
n_iters = 2000  # number of names trained on per epoch
print_every = 200  # every print_every iterations: print a result, back-propagate once and update the weights
for epoch in range(N_EPHOCS):
    train_running_loss = 0.0
    loss = torch.tensor([0.0]).float().to(device)
    train_count = 0
    rnn_name.train()
    # training round (loop variable named `step` to avoid shadowing the builtin `iter`)
    for step in range(1, n_iters + 1):
        # fetch one sample
        category, line, category_tensor, line_tensor = randomTrainingExample()
        category_tensor = category_tensor.to(device)
        line_tensor = line_tensor.to(device)
        # forward pass; RNN_NAME.forward re-creates the hidden state itself,
        # so no explicit reset is needed before the call
        outputs = rnn_name(line_tensor)
        guess, guess_i = categoryFromOutput(outputs)
        if guess == category:
            train_count = train_count + 1
        # accumulate the loss until the next update
        loss = loss + criterion(outputs, category_tensor)
        if step % print_every == 0:
            correct = '✓' if guess == category else '✗ ({})'.format(category)
            print('{} /{} {}'.format(line, guess, correct))
            # back-propagate the loss accumulated over print_every samples
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_running_loss = train_running_loss + loss.detach().item()
            loss = torch.tensor([0.0]).float().to(device)
    # report the accuracy of this epoch
    train_acc = train_count / n_iters * 100
    print('=======')
    print('Epoch : {:0>2d} | Loss : {:<6.4f} | Train Accuracy : {:<6.2f}%'.format(epoch, train_running_loss/n_iters, train_acc))
    print('=======')

"验证"
test_count = 0
test_num = 500  # evaluate on 500 names
rnn_name.eval()
# Inference only: disable autograd so no graph is built for each sample.
with torch.no_grad():
    for test_iter in range(1, test_num + 1):
        # NOTE(review): samples are drawn with randomTrainingExample, i.e. from
        # the same pool used for training, not from a held-out set.
        category, line, category_tensor, line_tensor = randomTrainingExample()
        outputs = rnn_name(line_tensor.to(device))
        guess, guess_i = categoryFromOutput(outputs)
        if guess == category:
            test_count = test_count + 1
print('Test Accuracy : {:<6.4f}%'.format(test_count/test_num*100))
"""
=======
Epoch : 71 | Loss : 0.2698 | Train Accuracy : 92.35 %
=======
Konda /Japanese ✓
Seighin /Irish ✓
Sedmik /Czech ✓
Jestkov /Russian ✓
So /Korean ✓
Laar /Dutch ✓
Mazza /Italian ✓
Abl /Czech ✓
Doyle /Irish ✓
Skwor /Czech ✓
=======
Epoch : 72 | Loss : 0.2583 | Train Accuracy : 91.40 %
=======
Martoyas /Greek ✗ (Russian)
Hallett /English ✓
Bermudez /Spanish ✓
Bonhomme /French ✓
Lam /Chinese ✗ (Vietnamese)
Kourempes /Greek ✓
Matsuki /Japanese ✓
Sokolofsky /Polish ✓
Kuipers /Dutch ✓
Aller /Dutch ✓
=======
Epoch : 73 | Loss : 0.2564 | Train Accuracy : 91.65 %
=======
Xiao /Chinese ✓
Lian /Chinese ✓
Prill /Czech ✓
Glatter /Czech ✓
Bureau /French ✓
Vallance /French ✗ (English)
Johnstone /Scottish ✓
Kassab /Arabic ✓
Atiyeh /Arabic ✓
Luong /Vietnamese ✓
=======
Epoch : 74 | Loss : 0.2610 | Train Accuracy : 91.60 %
=======
Poirier /French ✓
Hahn /German ✓
Botros /Arabic ✓
Adrichem /Dutch ✓
Rios /Portuguese ✗ (Spanish)
Zangari /Italian ✓
Mccallum /Scottish ✓
Ferro /Portuguese ✓
Kim /Vietnamese ✗ (Korean)
Chang /Korean ✓
=======
Epoch : 75 | Loss : 0.2792 | Train Accuracy : 90.85 %
=======
Monte /Italian ✓
Han /Vietnamese ✗ (Korean)
Ying /Chinese ✓
O'Driscoll /Irish ✓
Poirier /French ✓
Tsuruya /Japanese ✓
Ralph /Czech ✗ (English)
Bonheur /French ✓
Romagna /Italian ✓
Krakowski /Polish ✓
=======
Epoch : 76 | Loss : 0.3000 | Train Accuracy : 90.85 %
=======
Fay /French ✓
Plisek /Czech ✓
Travere /French ✓
Moon /Korean ✓
Aalst /Dutch ✓
Sassa /Japanese ✓
Holzmann /German ✓
Kruessel /Czech ✓
Truong /Vietnamese ✓
Kassab /Arabic ✓
=======
Epoch : 77 | Loss : 0.2602 | Train Accuracy : 92.00 %
=======
Janda /Polish ✓
Doan /Vietnamese ✓
O'Reilly /Irish ✓
Torres /Portuguese ✓
Esipovich /Russian ✓
Johnston /Scottish ✓
Viteri /Spanish ✗ (Italian)
Vandroogenbroeck /Dutch ✓
Montero /Spanish ✓
Wan /Chinese ✓
=======
Epoch : 78 | Loss : 0.2460 | Train Accuracy : 92.35 %
=======
Ha /Korean ✗ (Vietnamese)
Christie /Scottish ✓
Pereira /Portuguese ✓
Kuiper /Dutch ✓
Bautista /Spanish ✓
Takewaki /Japanese ✓
Ying /Chinese ✓
Sokoloff /Polish ✓
Raghailligh /Irish ✓
Edgell /English ✓
=======
Epoch : 79 | Loss : 0.2640 | Train Accuracy : 91.60 %
=======
Neal /English ✓
Kafka /Czech ✓
Hay /Scottish ✓
Pelletier /French ✓
Wilchek /Czech ✓
Marchesi /Italian ✓
Michel /Spanish ✓
Ling /Chinese ✓
Zou /Chinese ✓
Freitas /Portuguese ✓
=======
Epoch : 80 | Loss : 0.2579 | Train Accuracy : 92.60 %
=======
Kelly /Scottish ✓
Chou /Korean ✓
Unsworth /English ✓
Larue /French ✓
Tron /Vietnamese ✓
Herodes /Czech ✓
Fremut /Czech ✓
Thai /Vietnamese ✓
Clark /Irish ✗ (Scottish)
Halenkov /Russian ✓
=======
Epoch : 81 | Loss : 0.2453 | Train Accuracy : 92.40 %
=======
Than /Vietnamese ✓
Pfeifer /Czech ✓
Kennedy /Scottish ✓
Klerkx /Dutch ✓
Freitas /Portuguese ✓
Miller /Scottish ✓
Deeb /Arabic ✓
Almetov /Russian ✓
Li /Korean ✓
Perez /Spanish ✓
=======
Epoch : 82 | Loss : 0.2579 | Train Accuracy : 91.85 %
=======
Horiatis /Greek ✓
Clowes /English ✓
Santos /Portuguese ✓
Mochan /Irish ✓
Sokolof /Polish ✓
Ogterop /Dutch ✓
Schwartz /Czech ✓
Abadi /Arabic ✓
Winogrodzki /Polish ✓
Bazzhin /Russian ✓
=======
Epoch : 83 | Loss : 0.2397 | Train Accuracy : 92.25 %
=======
Ybarra /Spanish ✓
Jacobson /English ✓
Shi /Chinese ✓
Berg /Dutch ✓
Araya /Spanish ✓
Banos /Greek ✓
Winograd /Polish ✓
Schneider /German ✗ (Czech)
Calligaris /Italian ✓
Hasbulatov /Russian ✓
=======
Epoch : 84 | Loss : 0.2466 | Train Accuracy : 92.25 %
=======
Felkerzam /Polish ✗ (Russian)
Gomolka /Polish ✓
Belo /Portuguese ✓
Santos /Portuguese ✓
Rim /Korean ✓
Solomon /French ✓
Agnellutti /Italian ✓
Santana /Portuguese ✓
Ramaker /Dutch ✓
Palmisano /Italian ✓
=======
Epoch : 85 | Loss : 0.2390 | Train Accuracy : 92.80 %
=======
Moraitopoulos /Greek ✓
Rosenberger /German ✓
Torres /Portuguese ✓
Sabbagh /Arabic ✓
Aggelen /Dutch ✓
Chieu /Chinese ✓
Ukhtomsky /Russian ✓
Kouros /Greek ✓
Portyansky /Russian ✓
Dinh /Vietnamese ✓
=======
Epoch : 86 | Loss : 0.2173 | Train Accuracy : 93.15 %
=======
Jaskulski /Polish ✓
Mendez /Spanish ✓
Smyth /English ✓
Mo /Korean ✓
Seegers /Dutch ✓
Elford /English ✓
Lau /Chinese ✓
Sargent /French ✓
Zeman /Czech ✓
Ohme /German ✓
=======
Epoch : 87 | Loss : 0.2189 | Train Accuracy : 93.00 %
=======
Gwang  /Korean ✓
Zhernakov /Russian ✓
Quraishi /Arabic ✓
Asker /Arabic ✓
Franco /Portuguese ✓
Bacon /Czech ✓
Weineltk /Czech ✓
Thuy /Vietnamese ✓
Close /Greek ✓
Xiang /Chinese ✓
=======
Epoch : 88 | Loss : 0.2497 | Train Accuracy : 91.95 %
=======
Basara /Arabic ✓
Esteves /Portuguese ✓
Metz /German ✓
Lopez /Spanish ✓
Bukoski /Polish ✓
Rao /Italian ✓
Klerk /Dutch ✓
Espinoza /Spanish ✓
Wehunt /German ✓
Giehl /German ✓
=======
Epoch : 89 | Loss : 0.2454 | Train Accuracy : 92.70 %
=======
Deniau /French ✓
Bover /Spanish ✗ (Italian)
Soho /Japanese ✓
Ron /Korean ✓
Dolezal /Czech ✓
Yamamura /Japanese ✓
Ramsay /Scottish ✓
Fromm /German ✓
Huerta /Spanish ✓
Colman /Irish ✓
=======
Epoch : 90 | Loss : 0.2432 | Train Accuracy : 92.55 %
=======
Fei /Chinese ✓
Dinh /Vietnamese ✓
Krawiec /Czech ✓
Urbanek /Czech ✓
Yan /Chinese ✓
Giehl /German ✓
Riain /Irish ✓
Zhestkov /Russian ✓
Almetiev /Russian ✓
Deniel /French ✓
=======
Epoch : 91 | Loss : 0.2318 | Train Accuracy : 92.45 %
=======
Macleod /Scottish ✓
Savatier /French ✓
Chong /Chinese ✗ (Korean)
Prchal /Czech ✓
Tsoumada /Greek ✓
Agosti /Italian ✓
Mojjis /Czech ✓
Mitsui /Japanese ✓
San nicolas /Spanish ✓
Kokoris /Greek ✓
=======
Epoch : 92 | Loss : 0.2163 | Train Accuracy : 92.25 %
=======
Marubeni /Japanese ✓
Suh /Korean ✓
Klerks /Dutch ✓
Roma /Spanish ✓
Lac /Vietnamese ✓
Khoury /Arabic ✓
Hoefler /German ✓
Gibson /Scottish ✓
Jasso /Spanish ✓
Oirschotten /Dutch ✓
=======
Epoch : 93 | Loss : 0.2329 | Train Accuracy : 92.90 %
=======
Kikuchi /Japanese ✓
Blecher /German ✓
Ying /Chinese ✓
Nahas /Arabic ✓
Araujo /Portuguese ✓
Koemans /Dutch ✓
Durant /French ✓
Martzenyuk /Russian ✓
Petru /Czech ✓
Wood /Scottish ✓
=======
Epoch : 94 | Loss : 0.2080 | Train Accuracy : 93.35 %
=======
Vespa /Italian ✓
Kowalczyk /Polish ✓
Nahas /Arabic ✓
Tresler /German ✓
Caivano /Italian ✓
Pho /Vietnamese ✓
Dael /Dutch ✓
Heidl /Czech ✓
Totolos /Greek ✓
Adlersflugel /German ✓
=======
Epoch : 95 | Loss : 0.2400 | Train Accuracy : 92.45 %
=======
Zwolenksy /Czech ✓
Chermak /Czech ✓
Sgro /Italian ✓
Lestrange /French ✓
Hwang /Korean ✓
Barros /Spanish ✗ (Portuguese)
Boyle /Scottish ✓
Schmeling /German ✓
Poplawski /Polish ✓
Pantelas /Greek ✓
=======
Epoch : 96 | Loss : 0.2150 | Train Accuracy : 93.10 %
=======
Asher /English ✓
Sai /Vietnamese ✓
Hakimi /Arabic ✓
Sheng /Chinese ✓
Lebedevich /Russian ✓
Castellano /Spanish ✓
Battaglia /Italian ✓
Cardozo /Portuguese ✓
Sinagra /Italian ✓
Bumgarner /German ✓
=======
Epoch : 97 | Loss : 0.2359 | Train Accuracy : 92.80 %
=======
Lyes /English ✓
Do /Vietnamese ✓
Busch /German ✓
Shimura /Japanese ✓
O'Connor /Irish ✓
Geiger /German ✓
Reier /German ✓
Paris /French ✓
Sokolof /Polish ✓
Limarev /Russian ✓
=======
Epoch : 98 | Loss : 0.2369 | Train Accuracy : 92.55 %
=======
Osladil /Czech ✓
Roy /French ✓
Ortega /Spanish ✓
Ruzzier /Italian ✓
Chou /Korean ✓
Mateus /Portuguese ✓
Ibanez /Spanish ✓
Cabello /Spanish ✓
Duguay /French ✓
Yi /Korean ✓
=======
Epoch : 99 | Loss : 0.2192 | Train Accuracy : 92.65 %
=======
Test Accuracy : 94.8000%
""""单独测试"
# Classify one hand-written string and show the raw scores and the prediction.
input_data = lineToTensor('WMN').to(device)
out_data = rnn_name(input_data)
print(out_data)
print(categoryFromOutput(out_data))
"""
tensor([[ 5.7034,  1.0659, -8.1210,  0.5218, -2.5818,  0.1362,  7.4420,  3.9547,1.0824, -7.2981,  4.0036, -2.2944,  1.6835, -3.5377, -7.6178,  9.8920,-0.5789, -1.0545]], device='cuda:0', grad_fn=<AddmmBackward0>)
('Chinese', 15)
"""
# Same check with a second name.
input_data = lineToTensor('Huang').to(device)
out_data = rnn_name(input_data)
print(out_data)
print(categoryFromOutput(out_data))
"""
tensor([[ 4.1755, -1.2160, -3.8454,  5.4587, -4.6412,  0.8137,  1.2896,  1.7429,-1.3145, -2.1634,  1.0224, -2.7052,  2.1592,  1.1909, -8.0874, 10.2988,-1.9252, -0.0729]], device='cuda:0', grad_fn=<AddmmBackward0>)
('Chinese', 15)
"""

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com