pytorch建立线性回归神经网络
- 模型建立,简单回归问题
- 可视化
- 模型保存,保存为pth结构
- 模型的调用,注意要有有原来的网络结构
- 应用调用的模型, 可视化2
模型建立,简单回归问题
import torch.nn as nn
x_data =torch.tensor([[1.0],[2.0],[3.0]])
y_data=torch.tensor([[2.0],[4.0],[6.0]])
#重点在于构造计算图 pytorch会自动计算梯度
#Z=wx+b 就是一个线性单元
class LinearModel(nn.Module):#Module的对象会自动实现backword()的过程#构造函数def __init__(self) :super(LinearModel, self).__init__()#Linear()构建y=wx+b,且继承于Module自动完成backword()的过程self.layer=nn.Sequential(nn.Linear(1,20),nn.Linear(20,20),nn.Linear(20,20),nn.Linear(20,1))#前馈计算的函数 必须有def forward(self,x):#调用linear的__call__(),在此函数中会调用forward()y_pred=self.layer(x)return y_preddef train(model, optimizer, criterion, num_epochs):losses = []for epoch in range(num_epochs):optimizer.zero_grad()y_pred=model(x_data)loss=criterion(y_pred,y_data)loss.backward()optimizer.step()losses.append(loss.item())if epoch % 300 == 0:print(f'Epoch {epoch}, Loss: {loss.item()}')return lossesmodel = LinearModel()
#调用损失函数
criterion=nn.MSELoss(size_average=False)
#优化器,lr学习率
optimizer=torch.optim.Adam(model.parameters(),lr=0.01)
losses = train(model, optimizer, criterion, num_epochs=1000)#迭代步数可增大,文章中用的是10000
# print(losses)
可视化
import matplotlib.pyplot as plt
import numpy as np
plt.figure(figsize=(3,2))
import matplotlib.pyplot as plt
plt.plot(losses)
plt.show()
plt.figure(2,figsize=(3,2))
plt.scatter(x_data,y_data)x_new=torch.Tensor(np.arange(0,4,0.01)).reshape(-1,1)
y_new=model(x_new)
# print(y_new)
plt.plot(x_new.detach().numpy(),y_new.detach().numpy(),color='r')
plt.show()
模型保存,保存为pth结构
保存模型,仅保存网络参数
#保存模型,仅保存网络参数
torch.save(model.state_dict(), 'model_params.pth')
模型的调用,注意要有有原来的网络结构
import torch.nn as nn
class LinearModel(torch.nn.Module):#Module的对象会自动实现backword()的过程#构造函数def __init__(self) :super(LinearModel, self).__init__()#Linear()构建y=wx+b,且继承于Module自动完成backword()的过程self.layer=nn.Sequential(nn.Linear(1,20),nn.Linear(20,20),nn.Linear(20,20),nn.Linear(20,1))#前馈计算的函数 必须有def forward(self,x):#调用linear的__call__(),在此函数中会调用forward()y_pred=self.layer(x)return y_pred
net=LinearModel()
net.load_state_dict(torch.load('model_params.pth'))
应用调用的模型, 可视化2
import matplotlib.pyplot as plt
import numpy as np
plt.figure(2,figsize=(3,2))
# plt.scatter(x_data,y_data)
x_new=torch.Tensor(np.arange(0,4,0.01)).reshape(-1,1)
y_new=net(x_new)
# print(y_new)
plt.plot(x_new.detach().numpy(),y_new.detach().numpy(),color='r')
plt.show()