多层感知机（MLP）实现 Fashion-MNIST 分类
作者:互联网
# Train a one-hidden-layer MLP on Fashion-MNIST using the d2l helpers.
import torch
from torch import nn

# Flatten each image, then 784 -> 256 (ReLU) -> 10 outputs.
# 784 = 28 * 28 pixels of a Fashion-MNIST image; 10 = number of classes.
net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
)


def init_weights(m):
    """Re-initialise the weight of an ``nn.Linear`` layer from N(0, 0.01**2).

    Intended to be passed to ``net.apply``, which calls it on every submodule;
    modules that are not ``nn.Linear`` are left untouched. Only ``m.weight``
    is changed, so biases keep PyTorch's default initialisation.

    Args:
        m: a module visited by ``nn.Module.apply``.
    """
    if isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, std=0.01)


net.apply(init_weights)

# Hyperparameters.
batch_size, lr, num_epochs = 256, 0.1, 10
# nn.CrossEntropyLoss consumes raw logits, which is why the net ends without
# a softmax layer. (The original assigned this twice; once is enough.)
loss = nn.CrossEntropyLoss()
trainer = torch.optim.SGD(net.parameters(), lr=lr)


def main():
    """Load Fashion-MNIST, train ``net`` with d2l, then show matplotlib figures."""
    # Imported here so the model definition above can be imported (e.g. by
    # tests) without d2l or matplotlib installed.
    from d2l import torch as d2l
    from matplotlib import pyplot as plt

    train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)
    d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
    # The original called plt.figure(figsize=(20, 8), dpi=100) here. That makes
    # a new, empty current figure nothing draws into, so plt.show() displayed
    # an extra blank window. NOTE(review): the figure worth showing is
    # presumably the training plot d2l.train_ch3 draws - confirm.
    plt.show()


if __name__ == "__main__":
    main()
标签：fashion, nn, torch, iter, 感知机, d2l, import, net, mnist
来源：https://blog.csdn.net/Li12139/article/details/122374003