
[Deep Learning] Sentiment Classification, with Complete Code (Network Structure: Encoding + Fully Connected)

2024/7/3 20:53:12  Source: https://blog.csdn.net/weixin_40293999/article/details/139845399

PyTorch 1.12
Building the embedding layer:

import torch, torchtext, torchdata
# A lookup table for 1000 token ids, each mapped to a 2-dimensional vector
torch.nn.Embedding(num_embeddings=1000, embedding_dim=2)

Embedding(1000, 2)
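`nn.Embedding(1000, 2)` is a trainable lookup table with 1000 rows (one per token id), each holding a 2-dimensional vector. Its forward pass just selects rows by index. The plain-Python sketch below illustrates only that lookup. `make_table` and `lookup` are hypothetical helpers. Drawing the initial values from a standard normal mirrors PyTorch's default initialization for `nn.Embedding`.

```python
import random

def make_table(num_embeddings, embedding_dim, seed=0):
    # Hypothetical helper: one row of embedding_dim values per token id, drawn from N(0, 1)
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(embedding_dim)]
            for _ in range(num_embeddings)]

def lookup(table, ids):
    # An embedding "forward pass" is row selection by token id
    return [table[i] for i in ids]

table = make_table(num_embeddings=1000, embedding_dim=2)
vectors = lookup(table, [14, 10, 6])
print(len(vectors), len(vectors[0]))  # 3 2
```

Identical ids always return the same row. Training only changes the values stored in those rows.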

train_iter, test_iter = torchtext.datasets.IMDB()
print(next(iter(train_iter)))
('neg', 'I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered "controversial" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered pornographic. Really, the sex and nudity scenes are few and far between, even then it\'s not shot like some cheaply made porno. While my countrymen mind find it shocking, in reality sex and nudity are a major staple in Swedish cinema. Even Ingmar Bergman, arguably their answer to good old boy John Ford, had sex scenes in his films.<br /><br />I do commend the filmmakers for the fact that any sex shown in the film is shown for artistic purposes rather than just to shock people and make money to be shown in pornographic theaters in America. I AM CURIOUS-YELLOW is a good film for anyone wanting to study the meat and potatoes (no pun intended) of Swedish cinema. But really, this film doesn\'t have much of a plot.')
unique_labels=set([label for (label, text) in train_iter])
print(unique_labels)
num_class = len(unique_labels)
{'neg', 'pos'}
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
tokenizer = get_tokenizer('basic_english')
print(tokenizer('this is a book about PyTorch.'))
['this', 'is', 'a', 'book', 'about', 'pytorch', '.']
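The `basic_english` tokenizer lower-cases the text and splits punctuation off into separate tokens. That is why `PyTorch.` became `'pytorch'` and `'.'`. The regex-based `simple_tokenize` below is a hypothetical approximation for illustration only. It also drops the `<br />` line breaks that appear in IMDB reviews. torchtext's actual rules differ in detail, for example in how quotes, apostrophes, `;` and `:` are handled.

```python
import re

def simple_tokenize(text):
    # Hypothetical approximation of a "basic English" tokenizer:
    # lower-case, replace HTML line breaks with spaces,
    # pad common punctuation with spaces so each mark becomes its own token, then split
    text = text.lower().replace("<br />", " ")
    text = re.sub(r"([.,!?()])", r" \1 ", text)
    return text.split()

print(simple_tokenize("this is a book about PyTorch."))
# ['this', 'is', 'a', 'book', 'about', 'pytorch', '.']
```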
def yield_tokens(data):
    for _, text in data:
        yield tokenizer(text)

vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<pad>", "<unk>"], min_freq=5)
print(len(vocab))

30123

vocab.set_default_index(vocab["<unk>"])
print(vocab(["this","is","a","book","about","pytorch","/"]))
[14, 10, 6, 276, 50, 1, 2192]
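In the output above, `pytorch` maps to index 1. That is the `<unk>` index: `<pad>` and `<unk>` are placed first, and `pytorch` is not in the vocabulary, so it falls back to the default index. The sketch below shows how `specials`, `min_freq` and a default index interact. `build_vocab` and `lookup_ids` are hypothetical stand-ins, not torchtext's API. Ordering the remaining tokens by descending frequency and breaking ties alphabetically is an assumption made for this sketch.

```python
from collections import Counter

def build_vocab(token_lists, specials=("<pad>", "<unk>"), min_freq=1):
    # Count every token across the corpus
    counter = Counter()
    for tokens in token_lists:
        counter.update(tokens)
    # Specials take the first indices; the rest are ordered by descending
    # frequency (ties broken alphabetically) and filtered by min_freq
    itos = list(specials)
    for token, freq in sorted(counter.items(), key=lambda kv: (-kv[1], kv[0])):
        if freq >= min_freq and token not in itos:
            itos.append(token)
    return {token: idx for idx, token in enumerate(itos)}

def lookup_ids(stoi, tokens, default="<unk>"):
    # Out-of-vocabulary tokens fall back to the default index (cf. set_default_index)
    return [stoi.get(t, stoi[default]) for t in tokens]

corpus = [["a", "b", "a"], ["a", "c", "b"]]
stoi = build_vocab(corpus, min_freq=2)
print(stoi)                               # {'<pad>': 0, '<unk>': 1, 'a': 2, 'b': 3}
print(lookup_ids(stoi, ["a", "c", "b"]))  # [2, 1, 3]
```

Here `c` occurs only once, so `min_freq=2` drops it and it is looked up as `<unk>`.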
text_pipeline = lambda x: vocab(tokenizer(x))  # raw text -> list of token ids
label_pipeline = lambda x: int(x == "pos")  # 'pos' -> 1, 'neg' -> 0
text_pipeline('This is a book about PyTorch.')
print(label_pipeline('pos'))
from torchtext.data.functional import to_map_style_dataset
train_dataset = to_map_style_dataset(train_iter)
test_dataset = to_map_style_dataset(test_iter)
from torch.utils.data import DataLoader
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Batch processing: turn a list of (label, text) pairs into tensors
def collate_batch(batch):
    label_list, text_list, offsets = [], [], [0]
    for (_label, _text) in batch:
        label_list.append(label_pipeline(_label))
        process_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(process_text)
        offsets.append(process_text.size(0))
    label_list = torch.tensor(label_list)
    # All reviews in the batch are concatenated into one 1-D tensor
    text_list = torch.cat(text_list)
    # Start position of each review inside the concatenated text_list
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    return label_list.to(device), text_list.to(device), offsets.to(device)
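`collate_batch` does not pad reviews to a common length. It concatenates them and records where each one starts. `offsets` begins as `[0]` and collects every review's length. The last length is dropped, and a cumulative sum gives the start positions. The pure-Python sketch below reproduces just that bookkeeping. `flatten_with_offsets` is a hypothetical helper, not part of the original code.

```python
from itertools import accumulate

def flatten_with_offsets(sequences):
    # Concatenate variable-length id lists into one flat list
    flat = [tok for seq in sequences for tok in seq]
    lengths = [len(seq) for seq in sequences]
    # Start offsets: 0, then the running total of all previous lengths
    offsets = list(accumulate(lengths[:-1], initial=0)) if lengths else []
    return flat, offsets

flat, offsets = flatten_with_offsets([[14, 10, 6], [276, 50], [1]])
print(flat)     # [14, 10, 6, 276, 50, 1]
print(offsets)  # [0, 3, 5]
```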
BATCHSIZE=64
train_dataloader = DataLoader(train_dataset, batch_size=BATCHSIZE,shuffle=True, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=BATCHSIZE, collate_fn=collate_batch)
label_list, text_list, offsets = next(iter(train_dataloader))
len(label_list), len(text_list), len(offsets)
import torch.nn as nn
vocab_size = len(vocab)  # vocabulary size
embedding_dim = 100  # dimension of each word-embedding vector
class TextClassificationModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class) -> None:
        super().__init__()
        # EmbeddingBag averages (mode="mean" by default) the embeddings of all tokens in each review
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        # Initialize the weights here
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)
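The forward pass has two steps. `nn.EmbeddingBag` in its default mean mode looks up every token of a review, using `offsets` to tell where each review starts, and averages the rows into one vector. `self.fc` then maps that vector to `num_class` scores. The pure-Python sketch below mimics those two steps on tiny hand-written numbers. It covers only the arithmetic: no training, sparse gradients or initialization. `embedding_bag_mean` and `linear` are hypothetical helpers. Returning zeros for an empty bag follows PyTorch's documented EmbeddingBag behavior.

```python
def embedding_bag_mean(table, flat_ids, offsets):
    # Each bag runs from its offset to the next offset (or the end of flat_ids);
    # its vector is the element-wise mean of the looked-up rows
    dim = len(table[0])
    bounds = offsets + [len(flat_ids)]
    bags = []
    for start, end in zip(bounds[:-1], bounds[1:]):
        rows = [table[i] for i in flat_ids[start:end]]
        if not rows:
            bags.append([0.0] * dim)  # an empty bag yields a zero vector
        else:
            bags.append([sum(col) / len(rows) for col in zip(*rows)])
    return bags

def linear(x, weight, bias):
    # weight has shape (out_features, in_features), as in nn.Linear
    return [sum(w * v for w, v in zip(w_row, x)) + b for w_row, b in zip(weight, bias)]

table = [[0.0, 0.0], [1.0, 3.0], [3.0, 5.0]]  # 3 tokens, embed_dim = 2
bags = embedding_bag_mean(table, [1, 2, 0], [0, 2])
print(bags)  # [[2.0, 4.0], [0.0, 0.0]]
weight, bias = [[1.0, 0.0], [0.0, 1.0]], [0.0, 0.5]  # num_class = 2
print([linear(b, weight, bias) for b in bags])  # [[2.0, 4.5], [0.0, 0.5]]
```

Averaging discards word order. That keeps the model small and fast, at the cost of ignoring phrases such as "not good".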
vocab_size = len(vocab)
model=TextClassificationModel(vocab_size, embedding_dim, num_class).to(device)
model
TextClassificationModel(
  (embedding): EmbeddingBag(30123, 100, mode=mean)
  (fc): Linear(in_features=100, out_features=2, bias=True)
)
loss_fn = nn.CrossEntropyLoss()  # loss function for classification
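`nn.CrossEntropyLoss` takes the raw scores (logits) produced by `self.fc` and the integer labels. It applies log-softmax and averages the negative log-likelihood over the batch. The pure-Python sketch below, with a hypothetical `cross_entropy` helper, follows that definition. With two classes and uninformative scores the loss is ln 2 ≈ 0.693. This matches the train loss of roughly 0.69 in epoch 0 of the training log, i.e. the model starts near chance level.

```python
import math

def cross_entropy(logits, labels):
    # Mean over the batch of -log softmax(logits)[label]
    # (nn.CrossEntropyLoss with its default reduction="mean")
    total = 0.0
    for row, label in zip(logits, labels):
        m = max(row)  # subtract the max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_sum_exp - row[label]
    return total / len(labels)

print(round(cross_entropy([[0.0, 0.0]], [1]), 4))  # 0.6931, i.e. ln 2
```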
from torch.optim import lr_scheduler  # used to decay the learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# Define the learning-rate decay schedule
exp_lr_scheduler = lr_scheduler.StepLR(optimizer,step_size=20,gamma=0.1)
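`StepLR` multiplies the learning rate by `gamma` every `step_size` calls to `scheduler.step()`. `fit()` calls `exp_lr_scheduler.step()` once per epoch. With `epochs = 20` and `step_size=20`, all 20 epochs therefore train at `lr=0.1`, and the decay to 0.01 only takes effect after training has ended. To decay during training, `step_size` would need to be smaller than the number of epochs. The hypothetical `step_lr` function below computes the rate in effect for each epoch under those assumptions.

```python
def step_lr(base_lr, epoch, step_size=20, gamma=0.1):
    # Learning rate in effect during `epoch` (0-based) when scheduler.step()
    # is called once at the end of every epoch, as fit() does
    return base_lr * gamma ** (epoch // step_size)

rates = [step_lr(0.1, e) for e in range(20)]
print(set(rates))                   # {0.1}
print(round(step_lr(0.1, 20), 6))  # 0.01
```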
def train(dataloader):
    total_acc, total_count, total_loss = 0, 0, 0
    # Switch to training mode (affects layers such as Dropout and BatchNorm;
    # this model has neither, but calling it is good practice)
    model.train()
    # Each batch is the (label, text, offsets) triple returned by collate_batch
    for label, text, offsets in dataloader:
        predicted_label = model(text, offsets)  # the model needs offsets as well as text
        loss = loss_fn(predicted_label, label)
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
            total_loss += loss.item() * label.size(0)
    return total_loss / total_count, total_acc / total_count

def test(dataloader):
    model.eval()  # evaluation mode
    total_acc, total_count, total_loss = 0, 0, 0
    with torch.no_grad():
        for idx, (label, text, offsets) in enumerate(dataloader):
            predicted_label = model(text, offsets)
            loss = loss_fn(predicted_label, label)
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
            total_loss += loss.item() * label.size(0)
    return total_loss / total_count, total_acc / total_count
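`train()` and `test()` accumulate per-sample statistics. `loss_fn` returns the batch *mean*, so multiplying it by `label.size(0)` before summing makes the final figure a true per-sample average, even when the last batch is smaller than `BATCHSIZE`. Accuracy counts how often `argmax(1)` equals the label. The pure-Python sketch below reproduces that bookkeeping. `update_metrics` is a hypothetical helper, and the scores and losses are made-up numbers.

```python
def update_metrics(totals, logits, labels, batch_loss):
    # totals = [correct, count, loss_sum]; batch_loss is the batch *mean* loss,
    # so it is re-weighted by batch size before summing
    preds = [max(range(len(row)), key=row.__getitem__) for row in logits]  # argmax(1)
    totals[0] += sum(p == y for p, y in zip(preds, labels))
    totals[1] += len(labels)
    totals[2] += batch_loss * len(labels)
    return totals

totals = [0, 0, 0.0]
update_metrics(totals, [[0.2, 0.8], [0.9, 0.1]], [1, 1], batch_loss=0.5)
update_metrics(totals, [[0.3, 0.7]], [1], batch_loss=0.2)
correct, count, loss_sum = totals
print(loss_sum / count, correct / count)
```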
def fit(epochs, train_dl, test_dl):
    train_loss = []
    train_acc = []
    test_loss = []
    test_acc = []
    for epoch in range(epochs):
        epoch_loss, epoch_acc = train(train_dl)
        epoch_test_loss, epoch_test_acc = test(test_dl)
        train_loss.append(epoch_loss)
        train_acc.append(epoch_acc)
        test_loss.append(epoch_test_loss)
        test_acc.append(epoch_test_acc)
        exp_lr_scheduler.step()
        template = ("epoch:{:2d}, train_loss:{:.5f},train_acc:{:.1f}%, test_loss:{:.5f}, test_acc:{:.1f}%")
        print(template.format(epoch, epoch_loss, epoch_acc * 100, epoch_test_loss, epoch_test_acc * 100))
        print("Done!")
    return train_loss, test_loss, train_acc, test_acc
epochs = 20
fit(epochs, train_dataloader, test_dataloader)
epoch: 0, train_loss:0.69063,train_acc:53.0%, test_loss:0.68775, test_acc:54.7%
Done!
epoch: 1, train_loss:0.68182,train_acc:58.4%, test_loss:0.68003, test_acc:59.2%
Done!
epoch: 2, train_loss:0.67389,train_acc:61.6%, test_loss:0.67334, test_acc:60.2%
Done!
epoch: 3, train_loss:0.66643,train_acc:63.2%, test_loss:0.66645, test_acc:62.6%
Done!
epoch: 4, train_loss:0.65892,train_acc:64.4%, test_loss:0.65915, test_acc:63.5%
Done!
epoch: 5, train_loss:0.65146,train_acc:65.2%, test_loss:0.65190, test_acc:65.0%
Done!
epoch: 6, train_loss:0.64365,train_acc:66.1%, test_loss:0.64442, test_acc:65.7%
Done!
epoch: 7, train_loss:0.63593,train_acc:66.8%, test_loss:0.63676, test_acc:66.6%
Done!
epoch: 8, train_loss:0.62797,train_acc:67.7%, test_loss:0.62896, test_acc:67.3%
Done!
epoch: 9, train_loss:0.61967,train_acc:68.7%, test_loss:0.62102, test_acc:68.3%
Done!
epoch:10, train_loss:0.61128,train_acc:69.3%, test_loss:0.61303, test_acc:69.1%
Done!
epoch:11, train_loss:0.60308,train_acc:70.2%, test_loss:0.60502, test_acc:69.7%
Done!
epoch:12, train_loss:0.59478,train_acc:70.9%, test_loss:0.59658, test_acc:70.5%
...
epoch:18, train_loss:0.54764,train_acc:74.2%, test_loss:0.55076, test_acc:73.6%
Done!
epoch:19, train_loss:0.54041,train_acc:74.9%, test_loss:0.54327, test_acc:74.3%
Done!
