1. ResNet classifier training
 

import torch
import torchvision
from torchvision import transforms
from torch.utils.data import random_split
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet50
import os

# Define the transformation
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load the dataset
data = torchvision.datasets.ImageFolder(root=r"D:\train_model\train_data_set", transform=transform)
classes_set = data.classes

# Save the class names to classes.txt
with open('classes.txt', 'w') as f:
    for class_name in classes_set:
        f.write(class_name + '\n')

# Split the data into train and test sets
train_size = int(0.8 * len(data))
test_size = len(data) - train_size
train_data, test_data = random_split(data, [train_size, test_size])

# Load the train and test data into data loaders
train_loader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=32, shuffle=False)

# Define the model (pretrained=True is deprecated in newer torchvision;
# weights=ResNet50_Weights.DEFAULT is the current equivalent)
model = resnet50(pretrained=True)

# Replace the last layer with a head sized to the number of classes
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, len(classes_set))

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Move the model to the device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Define the number of epochs
num_epochs = 10

# Train the model
for epoch in range(num_epochs):
    # Train the model on the training set
    model.train()
    train_loss = 0.0
    for i, (inputs, labels) in enumerate(train_loader):
        # Move the data to the device
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Update the training loss
        train_loss += loss.item() * inputs.size(0)

    # Evaluate the model on the test set
    model.eval()
    test_loss = 0.0
    test_acc = 0.0
    with torch.no_grad():
        for i, (inputs, labels) in enumerate(test_loader):
            # Move the data to the device
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Forward
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Update the test loss and accuracy
            test_loss += loss.item() * inputs.size(0)
            _, preds = torch.max(outputs, 1)
            test_acc += torch.sum(preds == labels.data)

    # Print the training and test loss and accuracy
    train_loss /= len(train_data)
    test_loss /= len(test_data)
    test_acc = test_acc.double() / len(test_data)
    print(f"Epoch [{epoch + 1}/{num_epochs}] Train Loss: {train_loss:.4f} Test Loss: {test_loss:.4f} Test Acc: {test_acc:.4f}")

# Save the model parameters (create the output directory if it does not exist)
os.makedirs('./model', exist_ok=True)
torch.save(model.state_dict(), './model/trained_model.pth')
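
To use the saved checkpoint, the network has to be rebuilt with the same replaced fc layer before loading the weights, since torch.save above only stores the state dict. The sketch below is not part of the original post; it assumes the classes.txt and ./model/trained_model.pth files produced by the training script, and "test.jpg" is a placeholder path for whatever image you want to classify.

import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import resnet50
from PIL import Image

# Rebuild the class list saved during training
with open('classes.txt') as f:
    classes = [line.strip() for line in f if line.strip()]

# Same preprocessing as the training script
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Recreate the architecture with the same replaced head, then load the weights
model = resnet50()
model.fc = nn.Linear(model.fc.in_features, len(classes))
model.load_state_dict(torch.load('./model/trained_model.pth', map_location=device))
model = model.to(device)
model.eval()

# "test.jpg" is a placeholder image path
image = Image.open('test.jpg').convert('RGB')
x = transform(image).unsqueeze(0).to(device)  # add the batch dimension

with torch.no_grad():
    logits = model(x)
    probs = torch.softmax(logits, dim=1)
    conf, pred = torch.max(probs, dim=1)

print(f"Predicted: {classes[pred.item()]} (confidence {conf.item():.4f})")

Note that the training script fine-tunes every parameter of the pretrained ResNet-50; if the dataset is small, freezing the backbone (requires_grad = False for everything except model.fc) is a common variation that trades accuracy ceiling for faster, more stable training.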
