其他分享
首页 > 其他分享> > 在pytorch中 保存 和加载神经网络

在pytorch中 保存 和加载神经网络

作者:互联网

import torch
import matplotlib.pyplot as plt

torch.manual_seed(1) # reproducible

fake data

x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size()) # noisy y data (tensor), shape=(100, 1)

保存神经网络和神经网络当前训练后的状态

def save():
# save net1
net1 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
loss_func = torch.nn.MSELoss()

for t in range(100):
    prediction = net1(x)
    loss = loss_func(prediction, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# plot result
plt.figure(1, figsize=(10, 3))
plt.subplot(131)
plt.title('Net1')
plt.scatter(

               

标签:loss,plt,nn,torch,神经网络,pytorch,100,net1,加载
来源: https://blog.51cto.com/u_15177056/2725550