17.200种鸟类图片分类
作者:互联网
这个是一个多分类问题,我们先看一下这个数据集
首先我们有一个文件夹叫birds_train
这个文件夹中有200个子文件夹,每一个文件夹中都是一种鸟类的照片
我们打开一个文件夹
其中有39张该鸟类的图片,这个数据量对于训练来讲是很少的,加之我们有200个分类,所以最终如果从网上单拎出一张图片来测试,效果不会很好
- 当然这个都是鸟也很大程度上会影响结果,就算是给人看所有图片,然后随便拿一张数据集之外的图片令其说出是哪一种的鸟也很困难
目录
1 导入库
2 数据处理
2.1 获取所有图像路径
我们依然先使用glob获取所有图片路径字符串列表
我们打印出来其中一张图片的路径看一下
2.2 提取所有图像标签
我们要提取的是字符串中间的 Black_footed_Albatross ,简单的一种方法是使用两个split,第一个的参数是\\,第二个的参数是点
我们现在要获取所有种类的label就不要一个一个写了,因为一共有200多个,这里我们使用np.unique(),这个方法可以寻找到列表中的唯一值
2.3 将标签与序号对应
我们现在需要给每个标签一个序号,这里我们用到了enumerate()方法,在这里我介绍过 python内建方法_potato123232的博客-CSDN博客
我们显示出来看一下
我们在预测的时候还需要 序号对标签 的字典(因为不是二分类了,我们如果通过if...else...处理会很麻烦)
现在我们使用label_to_index获取所有的序号标签
2.4 随机数据
上面的所有数据我们注意到是没有进行随机的,没有随机的数据是对训练的效果是有影响的,所以我们需要随机数据,在这里我们介绍一种随机数据的新方法
我们首先使用np.random.seed()设置随机种子,随机种子的参数是任意整形数据,设置随机种子的目的是,无论在那一台电脑,随机多少次,随机的结果都是一样的
之后我们创建长度为图片数量的连续随机整形数组,使用的方法为np.random.permutation(),参数为数组的长度
之后我们以random_index作为索引乱序imgs_path与all_labels,直接写出来看起来会很迷惑,我们先举个例子
我们当前顺序索引前四个值
发现返回的结果是按也是按顺序排列的
那么我现在把index换成几个随便写的数
我们发现返回值也是按我们写的数排列的,也就是现在的第0张是原数据集的第99张,现在的第1张,是原来的第149张,由于我们索引是固定的,只要原来的imgs_path与all_labels对应,那么重新排序后的imgs_path与all_labels依然对应
我们现在把后面的索引换为刚刚生成的random_index
2.5 选取训练数据与测试数据
我们定义前80%的数据为训练数据,其余为测试数据
之后定义路径与标签,从0-train_count是训练用的,从train_count-最后一个是测试用的
3 创建数据集
首先创建训练集
然后创建测试集
我们发现现在传入的是路径,在之前的 卫星图片分类 代码中我们是先读取的图片,然后创建的数据集,此处我们先用路径创建数据集,再进行图片读取
- 只是展示出来还有这种方法可以操作,读取图片与创建数据集的顺序对不影响任何东西
我们现在定义读取图片的函数
我们上面的这个函数要直接对train_ds与test_ds进行操作,这两个数据集中是有label的,所以我们的参数要写label,即使我们不对label执行任何操作,我们加载图片的方式与卫星图片分类加载图片的方式相同,每一行是什么意思可以看一下这个 15.卫星图片分类_potato123232的博客-CSDN博客
之后我们使用map对数据集执行load_img(),我们在这里介绍一下多线程的概念,首先我们定义变量AUTOTUNE,令其获取到自适应合适的线程数
- 合适的线程数与cpu有关,与gpu无关
之后使用map,我们在这里多定义一个num_paraller_calls
测试集也用相同的方式处理
4 设置重复,乱序与批次
5 建立模型
我们的模型是这样的
- 我们称没有激活的输出为logits,我们会得到200个概率值,这些概率值加起来为1,其中的最大值就是它的预测结果
我们引入了上一章的批标准化层,在输出层,此次我们不加激活函数,不加激活函数只是为了展示后续的操作,对训练的结果没有影响
我们看一下模型整体
6 编译模型
由于我们输出层没有激活函数,所以from_logits置为True,损失函数为SparseCategoricalCrossentroy,该损失函数应用在多分类问题的顺序编码的情况
7 训练模型
- 训练的过程会比较久,我们如果运行中出现内存问题,我们要减少batch
训练之后我们绘制出acc与loss的曲线看一下
- acc
- loss
通过曲线我们可以发现acc依然有上升的趋势,loss依然有下降的趋势,我们增加epoch有可能会有更好的效果
之后我们将模型保存一下,以便之后进行预测
我们确认一下生成了这个文件夹
8 读取模型
首先我们先读取模型
这个时候我们就再另一个新开的py文件中读取模型了
- 其余两个库下面会用到
由于我们的分类比较多,所以我们要使用之前定义的index_to_label,所以在这个文件中,要把之前的index_to_label复制进来
9 预测模型
首先我们需要把图像加载进来,但是我们不能使用之前的加载函数,因为之前是对整个数据集进行操作,现在是对单张图片进行操作,所以我们要这样写
- 与之前的处理方式相同,只是去掉了label
之后我们定义预测函数,加载图像后提升维度,然后使用model.predict()进行预测,这个时候得到的结果是200个概率值的集合(我们下面会举个例子看一下),之后我们使用argmax()获取最大值的索引,再之后使用index_to_label获取到最大值索引对应的标签名称
我们现在找一张图片测试一下
这张图片的结果正确与否不是很重要,如果要追求准确度,最好的办法还是加大数据量
首先出现的是200个概率值
之后是机器给的预测结果
标签:鸟类,模型,分类,label,17.200,随机,数据,我们,图片 来源: https://blog.csdn.net/potato123232/article/details/120777581