您的位置:首页 > 健康 > 养生 > 企业年报申报入口官网_淘宝网络推广怎么做_北京seo优化公司_怎么做微信小程序

企业年报申报入口官网_淘宝网络推广怎么做_北京seo优化公司_怎么做微信小程序

2024/12/22 11:16:01 来源:https://blog.csdn.net/hzether/article/details/144533532  浏览:    关键词:企业年报申报入口官网_淘宝网络推广怎么做_北京seo优化公司_怎么做微信小程序
企业年报申报入口官网_淘宝网络推广怎么做_北京seo优化公司_怎么做微信小程序

猫狗图像分类项目

这是一个使用PyTorch实现的猫狗图像分类项目。

项目结构

  • model.py: 定义了CNN模型结构
  • train.py: 训练模型的脚本
  • predict.py: 使用训练好的模型进行预测
  • requirements.txt: 项目依赖

环境配置

  1. 创建虚拟环境(推荐)
python -m venv venv
source venv/bin/activate  # Linux/Mac
venv\Scripts\activate     # Windows
  1. 安装依赖
#requirements.txt
torch>=2.0.0
torchvision>=0.15.0
pillow>=9.0.0
numpy>=1.21.0
tqdm>=4.65.0···```bash
pip install -r requirements.txt

数据集准备

准备数据集目录结构如下:

data/
├── train/
│   ├── cat/
│   │   ├── cat1.jpg
│   │   ├── cat2.jpg
│   │   └── ...
│   └── dog/
│       ├── dog1.jpg
│       ├── dog2.jpg
│       └── ...

训练模型

python train.py

预测

修改 predict.py 中的图片路径和模型路径,然后运行:

python predict.py

模型说明

  • 使用了简单的CNN架构
  • 输入图像大小:224x224
  • 输出类别:猫(0)和狗(1)

下载图片

python download_dataset.py

import os
import urllib.request
import zipfile
from tqdm import tqdmdef download_file(url, filename):"""下载文件并显示进度条"""class DownloadProgressBar(tqdm):def update_to(self, b=1, bsize=1, tsize=None):if tsize is not None:self.total = tsizeself.update(b * bsize - self.n)with DownloadProgressBar(unit='B', unit_scale=True,miniters=1, desc=filename) as t:urllib.request.urlretrieve(url, filename=filename,reporthook=t.update_to)def prepare_dataset():# 创建数据目录os.makedirs('data/train/cat', exist_ok=True)os.makedirs('data/train/dog', exist_ok=True)# 下载示例数据集print("正在下载示例数据集...")dataset_url = "https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_5340.zip"zip_path = "dataset.zip"try:download_file(dataset_url, zip_path)print("正在解压数据集...")with zipfile.ZipFile(zip_path, 'r') as zip_ref:zip_ref.extractall("temp_data")# 移动文件到对应目录import shutilcat_source = "temp_data/PetImages/Cat"dog_source = "temp_data/PetImages/Dog"print("正在整理数据...")# 移动一部分猫的图片for i, filename in enumerate(os.listdir(cat_source)):if i >= 1000:  # 只使用1000张图片breaksrc = os.path.join(cat_source, filename)dst = os.path.join('data/train/cat', filename)try:if os.path.getsize(src) > 0:  # 检查文件是否有效shutil.copy2(src, dst)except:continue# 移动一部分狗的图片for i, filename in enumerate(os.listdir(dog_source)):if i >= 1000:  # 只使用1000张图片breaksrc = os.path.join(dog_source, filename)dst = os.path.join('data/train/dog', filename)try:if os.path.getsize(src) > 0:  # 检查文件是否有效shutil.copy2(src, dst)except:continue# 清理临时文件print("清理临时文件...")os.remove(zip_path)shutil.rmtree("temp_data")print("数据集准备完成!")print(f"猫图片数量: {len(os.listdir('data/train/cat'))}")print(f"狗图片数量: {len(os.listdir('data/train/dog'))}")except Exception as e:print(f"下载或处理数据时出错: {str(e)}")if __name__ == "__main__":prepare_dataset()

各取1000张进行训练:

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from tqdm import tqdm
from model import CatDogNet# 设置设备 - 如果有GPU就用GPU,没有就用CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 数据预处理转换
# 1. 调整图片大小为224x224
# 2. 转换为tensor格式
# 3. 标准化处理(使用ImageNet的均值和标准差)
transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])def train_model(data_dir, num_epochs=10, batch_size=32, learning_rate=0.001):"""训练模型的主函数参数:data_dir (str): 数据集目录路径num_epochs (int): 训练轮数,默认10轮batch_size (int): 批次大小,默认32learning_rate (float): 学习率,默认0.001"""# 加载训练数据集# ImageFolder会自动根据子文件夹名称作为类别标签train_dataset = datasets.ImageFolder(os.path.join(data_dir, 'train'),transform=transform)# 创建数据加载器# shuffle=True 确保每个epoch数据顺序随机# num_workers=4 使用4个进程加载数据train_loader = DataLoader(train_dataset,batch_size=batch_size,shuffle=True,num_workers=4)# 初始化模型、损失函数和优化器model = CatDogNet().to(device)  # 将模型移到GPU(如果可用)criterion = nn.CrossEntropyLoss()  # 交叉熵损失函数optimizer = optim.Adam(model.parameters(), lr=learning_rate)  # Adam优化器# 训练循环for epoch in range(num_epochs):model.train()  # 设置为训练模式running_loss = 0.0  # 记录总损失correct = 0  # 记录正确预测数total = 0  # 记录总样本数# 创建进度条progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')for inputs, labels in progress_bar:# 将数据移到GPU(如果可用)inputs, labels = inputs.to(device), labels.to(device)# 清零梯度optimizer.zero_grad()# 前向传播outputs = model(inputs)# 计算损失loss = criterion(outputs, labels)# 反向传播loss.backward()# 更新参数optimizer.step()# 统计训练信息running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()# 更新进度条信息progress_bar.set_postfix({'loss': f'{running_loss/len(progress_bar):.3f}',  # 平均损失'acc': f'{100.*correct/total:.2f}%'  # 准确率})# 每5个epoch保存一次模型if (epoch + 1) % 5 == 0:torch.save(model.state_dict(), f'model_epoch_{epoch+1}.pth')print('训练完成!')return modelif __name__ == '__main__':# 设置数据集路径并开始训练data_dir = './data'  # 数据集路径model = train_model(data_dir)  # 开始训练模型

预测

import torch
from PIL import Image
from torchvision import transforms
from model import CatDogNetdef predict_image(image_path, model_path):# 设置设备device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 加载模型model = CatDogNet().to(device)model.load_state_dict(torch.load(model_path))model.eval()# 图像预处理transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])# 加载并处理图像image = Image.open(image_path).convert('RGB')image = transform(image).unsqueeze(0).to(device)# 预测with torch.no_grad():outputs = model(image)_, predicted = torch.max(outputs, 1)# 返回预测结果class_names = ['cat', 'dog']return class_names[predicted.item()]if __name__ == '__main__':# 使用示例image_path = 'path_to_your_image.jpg'  # 修改为你的图片路径model_path = 'model_epoch_10.pth'      # 修改为你的模型路径result = predict_image(image_path, model_path)print(f'预测结果: {result}')
result, confidence = predict_image('path_to_your_image.jpg') #单张
predict_directory('path_to_your_directory') #批量

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com