您的位置:首页 > 财经 > 金融 > 企业网站的购买方式_建站63年来第一次闭站 北京站辟谣_品牌整合营销传播_如何建立自己的网络销售

企业网站的购买方式_建站63年来第一次闭站 北京站辟谣_品牌整合营销传播_如何建立自己的网络销售

2024/11/17 22:36:50 来源:https://blog.csdn.net/m0_53115174/article/details/142855919  浏览:    关键词:企业网站的购买方式_建站63年来第一次闭站 北京站辟谣_品牌整合营销传播_如何建立自己的网络销售
企业网站的购买方式_建站63年来第一次闭站 北京站辟谣_品牌整合营销传播_如何建立自己的网络销售

文章使用Fashion-MNIST数据集,做一次分类识别任务
Fashion-MNIST中包含的10个类别,分别为:
t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)
sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)、ankle boot(短靴)

0 图像数据

0.1 读取展示数据

import torch
import torchvision
from torch.utils import data
from torchvision import transforms
import matplotlib.pyplot as plt# 下载数据集 ,60,000 个训练样本和 10,000 个测试样本,每个样本包含一张28*28的灰度图和一个标签trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="D/DL_Data/Fashion-MNIST", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(root="D/DL_Data/Fashion-MNIST", train=False, transform=trans, download=True)print("test:",len(mnist_test))
print("train:",len(mnist_train))# 获取第一个样本的图像和标签
image, label = mnist_train[0]
print("图像的形状:", image.shape)
print("标签:", label)

在这里插入图片描述

0.2 可视化图像

# 可视化
def show_img():class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat','Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']# 可视化前5张图片fig, axes = plt.subplots(1, 10, figsize=(15, 3))for i in range(10):# 获取第 i 个样本的图像和标签image, label = mnist_train[i]# 将图像从 Tensor 转换回 numpy 数组,并移除通道维度image_np = image.squeeze().numpy()# 在子图中显示图像axes[i].imshow(image_np, cmap='gray')axes[i].set_title(f'Label: {class_names[label]}')axes[i].axis('off')  # 关闭坐标轴plt.tight_layout()plt.show()show_img()

在这里插入图片描述

0.3 整合为数据加载模块

def load_data_fashion_mnist(batch_size, resize=None):  #@savetrans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize))trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))train_iter, test_iter = load_data_fashion_mnist(256, resize=28)

1 初始化参数模型

我们选择把 28 ∗ 28 28*28 2828的图片展开成 1 ∗ 784 1*784 1784的向量,认为每个像素位置都是一个特征,所以输入是784维,输出是10个类别标签,所以输出是10维

因为softmaxhi回归类似于线性回归,所以权重 w w w应该是 784 ∗ 10 784*10 78410 的矩阵,偏置是 1 ∗ 10 1*10 110 的行向量,接下来如同线性回归中一样,使用正太分布初始化权重,偏置初始化为0:

num_inputs = 784
num_outputs = 10W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)

2 定义softmax操作

回顾一下softmax的公式:
在这里插入图片描述
由三个步骤组成:

  1. 对每个项目求幂
  2. 将每一行求和(小批量样本中,每个样本是一行),得到每个样本的规范化常数。
  3. 将每一行除以其规范化常数,确保结果的和为1。
# 定义softmax操作
def softmax(x):x_exp=torch.exp(x)x_exp_sum=x_exp.sum(1,keepdim=True)return x_exp/x_exp_sum

3 定义模型

# 定义模型
def net(x):x = x.reshape(-1, w.shape[0])  # 将图片重塑为 [batch_size, 784]temp = torch.matmul(x, w)temp = temp + breturn softmax(temp)

4 定义损失函数

使用从0开始深度学习(8)——softmax回归提到的交叉熵损失函数

# 定义损失函数
def cross_entropy(y_hat, y): # 预测值、真实值return - torch.log(y_hat[range(len(y_hat)), y]) # 计算负对数似然cross_entropy(y_hat, y)

5 分类精度

分类精度即正确预测数量与总预测数量之比。

def compute_accuracy(y_hat, y):  # 预测值、真实值if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1)  # 找到一个样本中,对应的最大概率的类别cmp = y_hat.type(y.dtype) == y  # 将预测值 y_hat 与真实标签 y 进行比较,生成一个布尔张量 cmpreturn float(cmp.type(y.dtype).sum())# 计算在指定数据集上模型的准确率
def evaluate_accuracy(net, data_iter):  if isinstance(net, torch.nn.Module):net.eval()  # 将模型设置为评估模式metric = Accumulator(2)  # 累加多个变量的总和。这里初始化了一个包含两个元素的累加器,分别用来存储正确预测的数量和总的预测数量。with torch.no_grad():for X, y in data_iter:metric.add(compute_accuracy(net(X), y), y.numel())return metric[0] / metric[1]class Accumulator:  #@save"""在n个变量上累加"""def __init__(self, n):self.data = [0.0] * ndef add(self, *args):self.data = [a + float(b) for a, b in zip(self.data, args)]def reset(self):self.data = [0.0] * len(self.data)def __getitem__(self, idx):return self.data[idx]# 评估模型
accuracy = evaluate_accuracy(net, test_iter)
print(f"Test Accuracy: {accuracy:.4f}")

在这里插入图片描述

6 定义优化器

# 定义优化器
def sgd(params, lr, batch_size):with torch.no_grad():for param in params:param -= lr * param.grad / batch_sizeparam.grad.zero_()

7 训练

# 训练模型
def train_epoch(net, train_iter, loss, updater):if isinstance(net, torch.nn.Module):net.train()  # 将模型设置为训练模式metric = Accumulator(3)  # 训练损失总和、训练准确度总和、样本数for X, y in train_iter:y_hat = net(X)l = loss(y_hat, y).mean()if isinstance(updater, torch.optim.Optimizer):updater.zero_grad()l.backward()updater.step()else:l.backward()updater([w, b], lr, batch_size)metric.add(float(l) * y.numel(), compute_accuracy(y_hat, y), y.numel())return metric[0] / metric[2], metric[1] / metric[2]def train(net, train_iter, test_iter, loss, num_epochs, updater):for epoch in range(num_epochs):train_metrics = train_epoch(net, train_iter, loss, updater)test_acc = evaluate_accuracy(net, test_iter)print(f'Epoch {epoch + 1}: Train Loss {train_metrics[0]:.3f}, Train Acc {train_metrics[1]:.3f}, Test Acc {test_acc:.3f}')# 训练模型
updater = lambda params, lr, batch_size: sgd(params, lr, batch_size)
train(net, train_iter, test_iter, cross_entropy, num_epochs, updater)

在这里插入图片描述

8 预测

# 定义 Fashion-MNIST 标签的文本描述
def get_fashion_mnist_labels(labels):text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]# 预测并显示结果
def predict(net, test_iter, n=6):for X, y in test_iter:break  # 只取一个批次的数据trues = get_fashion_mnist_labels(y)preds = get_fashion_mnist_labels(net(X).argmax(axis=1))titles = [true + '\n' + pred for true, pred in zip(trues, preds)]n = min(n, X.shape[0])fig, axs = plt.subplots(1, n, figsize=(12, 3))for i in range(n):axs[i].imshow(X[i].permute(1, 2, 0).squeeze().numpy(), cmap='gray')axs[i].set_title(titles[i])axs[i].axis('off')plt.show()# 调用预测函数
predict(net, test_iter, n=6)

在这里插入图片描述

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com