您的位置:首页 > 财经 > 金融 > android编程_附近最好的装修公司_世界足球排名最新_企业网

android编程_附近最好的装修公司_世界足球排名最新_企业网

2024/12/23 12:35:11 来源:https://blog.csdn.net/huanghm88/article/details/144590711  浏览:    关键词:android编程_附近最好的装修公司_世界足球排名最新_企业网
android编程_附近最好的装修公司_世界足球排名最新_企业网
{"cells": [{"cell_type": "code","execution_count": 1,"metadata": {"collapsed": false},"outputs": [],"source": ["import torch\n","import torch.nn as nn\n","import torch.optim as optim\n","from torchvision import datasets, transforms, models\n","from torch.utils.data import DataLoader"]},{"cell_type": "code","execution_count": 2,"metadata": {"collapsed": true},"outputs": [],"source": ["# 数据预处理\n","transform = transforms.Compose([\n","    transforms.RandomResizedCrop(224),# 对图像进行随机的crop以后再resize成固定大小\n","    transforms.RandomRotation(20), # 随机旋转角度\n","    transforms.RandomHorizontalFlip(p=0.5), # 随机水平翻转\n","    transforms.ToTensor() \n","])\n"," \n","# 读取数据\n","root = 'image'\n","train_dataset = datasets.ImageFolder(root + '/train', transform)\n","test_dataset = datasets.ImageFolder(root + '/test', transform)\n"," \n","# 导入数据\n","train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True)\n","test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=8, shuffle=True)"]},{"cell_type": "code","execution_count": 3,"metadata": {"collapsed": false},"outputs": [{"name": "stdout","output_type": "stream","text": ["['cat', 'dog']\n","{'cat': 0, 'dog': 1}\n"]}],"source": ["classes = train_dataset.classes\n","classes_index = train_dataset.class_to_idx\n","print(classes)\n","print(classes_index)"]},{"cell_type": "code","execution_count": 4,"metadata": {"collapsed": false},"outputs": [{"name": "stdout","output_type": "stream","text": ["VGG(\n","  (features): Sequential(\n","    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n","    (1): ReLU(inplace=True)\n","    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n","    (3): ReLU(inplace=True)\n","    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n","    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n","    (6): ReLU(inplace=True)\n","    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n","    (8): ReLU(inplace=True)\n","    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n","    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n","    (11): ReLU(inplace=True)\n","    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n","    (13): ReLU(inplace=True)\n","    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n","    (15): ReLU(inplace=True)\n","    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n","    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n","    (18): ReLU(inplace=True)\n","    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n","    (20): ReLU(inplace=True)\n","    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n","    (22): ReLU(inplace=True)\n","    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n","    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n","    (25): ReLU(inplace=True)\n","    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n","    (27): ReLU(inplace=True)\n","    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n","    (29): ReLU(inplace=True)\n","    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n","  )\n","  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))\n","  (classifier): Sequential(\n","    (0): Linear(in_features=25088, out_features=4096, bias=True)\n","    (1): ReLU(inplace=True)\n","    (2): Dropout(p=0.5, inplace=False)\n","    (3): Linear(in_features=4096, out_features=4096, bias=True)\n","    (4): ReLU(inplace=True)\n","    (5): Dropout(p=0.5, inplace=False)\n","    (6): Linear(in_features=4096, out_features=1000, bias=True)\n","  )\n",")\n"]}],"source": ["model = models.vgg16(pretrained = True)\n","print(model)"]},{"cell_type": "code","execution_count": 5,"metadata": {"collapsed": true},"outputs": [],"source": ["# 如果我们想只训练模型的全连接层\n","# for param in model.parameters():\n","#     param.requires_grad = False\n","    \n","# 构建新的全连接层\n","model.classifier = torch.nn.Sequential(torch.nn.Linear(25088, 100),\n","                                       torch.nn.ReLU(),\n","                                       torch.nn.Dropout(p=0.5),\n","                                       torch.nn.Linear(100, 2))"]},{"cell_type": "code","execution_count": 6,"metadata": {"collapsed": true},"outputs": [],"source": ["LR = 0.0001\n","# 定义代价函数\n","entropy_loss = nn.CrossEntropyLoss()\n","# 定义优化器\n","optimizer = optim.SGD(model.parameters(), LR, momentum=0.9)"]},{"cell_type": "code","execution_count": 7,"metadata": {"collapsed": true},"outputs": [],"source": ["def train():\n","    model.train()\n","    for i, data in enumerate(train_loader):\n","        # 获得数据和对应的标签\n","        inputs, labels = data\n","        # 获得模型预测结果,(64,10)\n","        out = model(inputs)\n","        # 交叉熵代价函数out(batch,C),labels(batch)\n","        loss = entropy_loss(out, labels)\n","        # 梯度清0\n","        optimizer.zero_grad()\n","        # 计算梯度\n","        loss.backward()\n","        # 修改权值\n","        optimizer.step()\n","\n","\n","def test():\n","    model.eval()\n","    correct = 0\n","    for i, data in enumerate(test_loader):\n","        # 获得数据和对应的标签\n","        inputs, labels = data\n","        # 获得模型预测结果\n","        out = model(inputs)\n","        # 获得最大值,以及最大值所在的位置\n","        _, predicted = torch.max(out, 1)\n","        # 预测正确的数量\n","        correct += (predicted == labels).sum()\n","    print(\"Test acc: {0}\".format(correct.item() / len(test_dataset)))\n","    \n","    correct = 0\n","    for i, data in enumerate(train_loader):\n","        # 获得数据和对应的标签\n","        inputs, labels = data\n","        # 获得模型预测结果\n","        out = model(inputs)\n","        # 获得最大值,以及最大值所在的位置\n","        _, predicted = torch.max(out, 1)\n","        # 预测正确的数量\n","        correct += (predicted == labels).sum()\n","    print(\"Train acc: {0}\".format(correct.item() / len(train_dataset)))"]},{"cell_type": "code","execution_count": 8,"metadata": {"collapsed": false},"outputs": [{"name": "stdout","output_type": "stream","text": ["epoch: 0\n","Test acc: 0.785\n","Train acc: 0.825\n","epoch: 1\n","Test acc: 0.885\n","Train acc: 0.865\n","epoch: 2\n","Test acc: 0.845\n","Train acc: 0.8675\n","epoch: 3\n","Test acc: 0.945\n","Train acc: 0.885\n","epoch: 4\n","Test acc: 0.89\n","Train acc: 0.8675\n","epoch: 5\n","Test acc: 0.93\n","Train acc: 0.945\n","epoch: 6\n","Test acc: 0.915\n","Train acc: 0.93\n","epoch: 7\n","Test acc: 0.925\n","Train acc: 0.935\n","epoch: 8\n","Test acc: 0.9\n","Train acc: 0.9325\n","epoch: 9\n","Test acc: 0.91\n","Train acc: 0.9425\n"]}],"source": ["for epoch in range(0, 10):\n","    print('epoch:',epoch)\n","    train()\n","    test()\n","    \n","torch.save(model.state_dict(), 'cat_dog_cnn.pth')"]},{"cell_type": "code","execution_count": null,"metadata": {"collapsed": true},"outputs": [],"source": []},{"cell_type": "code","execution_count": null,"metadata": {"collapsed": true},"outputs": [],"source": []},{"cell_type": "code","execution_count": null,"metadata": {"collapsed": true},"outputs": [],"source": []},{"cell_type": "code","execution_count": null,"metadata": {"collapsed": true},"outputs": [],"source": []}],"metadata": {"anaconda-cloud": {},"kernelspec": {"display_name": "Python [default]","language": "python","name": "python3"},"language_info": {"codemirror_mode": {"name": "ipython","version": 3},"file_extension": ".py","mimetype": "text/x-python","name": "python","nbconvert_exporter": "python","pygments_lexer": "ipython3","version": "3.5.2"}},"nbformat": 4,"nbformat_minor": 2
}

这是一个Jupyter Notebook文件的内容,主要实现了使用预训练的VGG16模型对猫和狗的图像进行分类任务。以下是对每个部分的详细解释:

1. 导入必要的库

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
  • 导入了torch相关的库用于深度学习操作,包括神经网络定义(nn)、优化器(optim)以及预定义的模型(models)和数据处理工具(datasetstransforms)。还导入了DataLoader用于加载数据。

2. 数据预处理和加载

transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomRotation(20),transforms.RandomHorizontalFlip(p=0.5),transforms.ToTensor()
])
root = 'image'
train_dataset = datasets.ImageFolder(root + '/train', transform)
test_dataset = datasets.ImageFolder(root + '/test', transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=8, shuffle=True)
  • 定义了数据预处理的操作,包括随机裁剪、旋转、水平翻转以及转换为张量。
  • 从指定的image文件夹下的traintest子文件夹中读取图像数据,并应用预处理操作。
  • 使用DataLoader分别创建了训练集和测试集的数据加载器,设置了批量大小为8,并打乱数据顺序。

3. 查看类别信息

classes = train_dataset.classes
classes_index = train_dataset.class_to_idx
print(classes)
print(classes_index)
  • 获取训练集中的类别名称列表和类别到索引的映射字典,并打印出来。这里显示有两个类别:猫和狗,以及它们对应的索引。

4. 加载预训练模型

model = models.vgg16(pretrained=True)
print(model)
  • 加载了预训练的VGG16模型,并打印出模型的结构,包括卷积层、池化层和全连接层等信息。

5. 修改模型结构(可选部分)

# 如果我们想只训练模型的全连接层
# for param in model.parameters():
#     param.requires_grad = False# 构建新的全连接层
model.classifier = torch.nn.Sequential(torch.nn.Linear(25088, 100),torch.nn.ReLU(),torch.nn.Dropout(p=0.5),torch.nn.Linear(100, 2))
  • 这部分代码展示了如何修改模型结构。首先注释掉了冻结所有层的代码,如果需要只训练全连接层,可以取消注释。然后重新定义了模型的全连接层部分,将输出类别改为2(猫和狗)。

6. 定义学习率、损失函数和优化器

LR = 0.0001
entropy_loss = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), LR, momentum=0.9)
  • 定义了学习率为0.0001,使用交叉熵损失函数,并选择随机梯度下降(SGD)作为优化器,设置了动量为0.9。

7. 定义训练和测试函数

def train():...
def test():...
  • train函数实现了模型的训练过程,包括获取数据和标签、计算模型输出、计算损失、梯度清零、反向传播和更新权重等步骤。
  • test函数实现了模型在测试集和训练集上的评估过程,计算预测正确的数量,并打印出准确率。

8. 模型训练和保存

for epoch in range(0, 10):print('epoch:', epoch)train()test()
torch.save(model.state_dict(), 'cat_dog_cnn.pth')
  • 进行10个轮次的训练和测试,每个轮次打印出当前轮次编号,并分别调用traintest函数。
  • 训练完成后,保存模型的权重到cat_dog_cnn.pth文件中。

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com