在 PyTorch 中保存和加载神经网络
作者:互联网
import torch
import matplotlib.pyplot as plt
# Fix the RNG seed so the random noise drawn below is reproducible across runs.
torch.manual_seed(1)

# Fake data: a noisy parabola used as a toy regression target.
# x: 100 evenly spaced points from -1 to 1, unsqueezed on dim=1 into a
# column tensor of shape (100, 1).
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
# y: x squared plus noise scaled to [0, 0.2) (torch.rand samples uniformly
# from [0, 1)); same shape as x, i.e. (100, 1).
y = x.pow(2) + 0.2 * torch.rand(x.size())
保存神经网络及其训练后的当前状态
def save():
# save net1
net1 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
loss_func = torch.nn.MSELoss()
for t in range(100): prediction = net1(x) loss = loss_func(prediction, y) optimizer.zero_grad() loss.backward() optimizer.step() # plot result plt.figure(1, figsize=(10, 3)) plt.subplot(131) plt.title('Net1') plt.scatter(
标签:loss,plt,nn,torch,神经网络,pytorch,100,net1,加载 来源: https://blog.51cto.com/u_15177056/2725550