其他分享
首页 > 其他分享> > 飞桨学习二、本地开发环境搭建与测试(训练手写体识别)

飞桨学习二、本地开发环境搭建与测试(训练手写体识别)

作者:互联网

飞桨学习二、本地开发环境搭建与测试(训练手写体识别)


在这里插入图片描述

一、准备环境

二、安装paddlepaddle

# 这里使用CPU版本,因为我的电脑没英伟达GPU
python -m pip install paddlepaddle==2.0.1 -i https://mirror.baidu.com/pypi/simple

在这里插入图片描述

三、书写程序

1. 引入paddlepaddle

#加载飞桨
import paddle
# 查看版本2.0.1
print(paddle.__version__)

2. 加载数据集

from paddle.vision.transforms import ToTensor

train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=ToTensor())
val_dataset =  paddle.vision.datasets.MNIST(mode='test', transform=ToTensor())

3. 模型搭建

# Sequential形式组网
mnist = paddle.nn.Sequential(
# Flatten 将[1, 28, 28]形状的图片数据改变形状为[1, 784]。
    paddle.nn.Flatten(),
    paddle.nn.Linear(784, 512),
    paddle.nn.ReLU(),
    paddle.nn.Dropout(0.2),
    paddle.nn.Linear(512, 10)
)

4. 模型训练

# 预计模型结构生成模型对象,便于进行后续的配置、训练和验证
model = paddle.Model(mnist)

# 模型训练相关配置,准备损失计算方法,优化器和精度计算方法
model.prepare(paddle.optimizer.Adam(parameters=model.parameters()),
              paddle.nn.CrossEntropyLoss(),
              paddle.metric.Accuracy())

# 开始模型训练
model.fit(train_dataset,
          epochs=5,
          batch_size=64,
          verbose=1)

在这里插入图片描述

5. 模型评估

使用预先定义的验证数据集来评估前一步训练得到的模型的精度。
在这里插入图片描述

model.evaluate(val_dataset, verbose=0)
{'loss': [1.0728842e-06], 'acc': 0.9822}

可以看出,初步训练得到的模型效果在98%附近,在逐渐了解飞桨后,可以通过调整其中的训练参数来提升模型的精度。

标签:训练,nn,模型,paddlepaddle,paddle,飞桨,手写体,model,搭建
来源: https://blog.csdn.net/xundh/article/details/115577743