Pytorch高阶API示范——线性回归模型
生活随笔
收集整理的這篇文章主要介紹了
Pytorch高阶API示范——线性回归模型
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
本文與《20天吃透Pytorch》有所不同,《20天吃透Pytorch》中是繼承之前的模型進行擬合,本文是單獨建立網絡進行擬合。
代碼實現:
import torch import numpy as np import matplotlib.pyplot as plt import pandas as pd from torch import nn import torch.nn.functional as F from torch.utils.data import Dataset,DataLoader,TensorDataset""" 1.準備數據 """ n=800 #樣本數量#生成測試用的數據集 X = 10*torch.rand([n,2])-5.0 #torch.rand是均勻分布 w0 = torch.tensor([[2.0],[-3.0]]) b0 = torch.tensor([10.0]) Y = X@w0 + b0 + torch.normal(0.0,2.0,size=[n,1]) ## @表示矩陣乘法,增加正態擾動#數據可視化 plt.figure(figsize= (12,5)) ax1 = plt.subplot(121) ax1.scatter(X[:,0],Y[:,0],c = 'b',label = 'samples') ax1.legend() #圖例 plt.xlabel("x1") plt.ylabel("y",rotation = 0) ax2 = plt.subplot(122) ax2.scatter(X[:,1],Y[:,0],c = 'g',label = 'samples') ax2.legend() plt.xlabel('x2') plt.ylabel('y',rotation = 0) plt.show()""" 構建通道 """ds = TensorDataset(X,Y) ds_train,ds_valid = torch.utils.data.random_split(ds,[int (n*0.7),n-int(n*0.7)]) #選取總樣本的70%為訓練數據 dl_train = DataLoader(ds_train,batch_size=10,shuffle=True) dl_valid = DataLoader(ds_valid,batch_size=10,shuffle=True)""" 2.定義模型 """class LinearRegression(torch.nn.Module):def __init__(self):super(LinearRegression, self).__init__()self.fc = nn.Linear(2,1)def forward(self,x):x = self.fc(x)return xnet = LinearRegression() """ 3.訓練模型 """ loss_func = torch.nn.MSELoss() optimizer= torch.optim.Adam(net.parameters(),lr = 0.01)eporchs = 10 log_step_freq = 20for eporch in range(1,eporchs+1):net.train()loss_sum = 0.0metric_sum = 0.0step = 1for step,(features,labels) in enumerate(dl_train,1):predictions = net(features)loss = loss_func(predictions,labels)optimizer.zero_grad()loss.backward()optimizer.step()w = net.state_dict()["fc.weight"]b = net.state_dict()["fc.bias"]print("step =", step, "loss = ", loss)print("w =", w)print("b =", b)loss_sum += loss.item()""" 結果可視化 """ w,b = net.state_dict()["fc.weight"],net.state_dict()["fc.bias"]plt.figure(figsize = (12,5)) ax1 = plt.subplot(121) ax1.scatter(X[:,0],Y[:,0], c = "b",label = "samples") ax1.plot(X[:,0],w[0,0]*X[:,0]+b[0],"-r",linewidth = 5.0,label = "model") ax1.legend() plt.xlabel("x1") plt.ylabel("y",rotation = 0)ax2 = plt.subplot(122) ax2.scatter(X[:,1],Y[:,0], c = "g",label = "samples") ax2.plot(X[:,1],w[0,1]*X[:,1]+b[0],"-r",linewidth = 5.0,label = "model") ax2.legend() plt.xlabel("x2") plt.ylabel("y",rotation = 0)plt.show()結果展示:
數據部分:
線性回歸結果:
創作挑戰賽新人創作獎勵來咯,堅持創作打卡瓜分現金大獎總結
以上是生活随笔為你收集整理的Pytorch高阶API示范——线性回归模型的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 做梦梦到瓜子壳什么意思
- 下一篇: 梦到二胎生了个男孩是什么意思