SegNet transfer learning
Author: Internet
This post mainly follows the blog post https://blog.csdn.net/u012426298/article/details/81386817.
First, get the pretrained model and the matching prototxt files. I won't post the download links here; they are in the blog post referenced above.
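1. Dataset
To avoid labelling a large number of images by hand, I used the already-annotated NYU dataset. See the blog post https://blog.csdn.net/weixin_43915709/article/details/88774325. Processing labels40.mat yields the label images.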
# -*- coding: UTF-8 -*-
# Extract the label images from the .mat file.
# Note: the layout in this file differs from the official one; height and
# width are swapped, so each slice has to be transposed.
import os

import numpy as np
import scipy.io as scio
from PIL import Image

dataFile = './labels40.mat'
data = scio.loadmat(dataFile)
labels = np.array(data["labels40"])

path_converted = './nyu_labels40'
if not os.path.isdir(path_converted):
    os.makedirs(path_converted)

for i in range(1449):  # 1449 labelled images, as hard-coded in the original
    label = labels[:, :, i].transpose((1, 0))  # transpose
    label_img = Image.fromarray(np.uint8(label))
    # label_img = label_img.rotate(270)
    label_img = label_img.transpose(Image.ROTATE_270)
    # Files are named 000001.png, 000002.png, ...
    iconpath = os.path.join(path_converted, '%06d.png' % (i + 1))
    label_img.save(iconpath, optimize=True)
2. Determine which classes occur in the label images and compute class_weighting, using the class_weight.py script below:
Command:
python class_weight.py --dir ./labels_deal # path to the label images
class_weight.py
from __future__ import print_function  # lets the script run on Python 2 and 3

import argparse
import collections
import os
import sys
from os import listdir

import numpy as np
from PIL import Image

# Import arguments
parser = argparse.ArgumentParser()
parser.add_argument('--dir', type=str, help='Path to the folder containing the images with annotations')
args = parser.parse_args()
if args.dir:
    cwd = args.dir
    if not args.dir.endswith('/'):
        cwd = cwd + '/'
else:
    cwd = os.getcwd() + '/'

image_names = listdir(cwd)
# Keep only images and append image_names to directory
image_list = [cwd + s for s in image_names if s.lower().endswith(('.png', '.jpg', '.jpeg'))]
print("Number of images:", len(image_list))


def count_all_pixels(image_list):
    dic_class_imgcount = dict()
    overall_pixelcount = dict()
    result = dict()
    for img in image_list:
        sys.stdout.write('.')
        sys.stdout.flush()
        for key, value in get_class_per_image(img).items():
            # Sum up the number of classes returned from get_class_per_image function
            overall_pixelcount[key] = overall_pixelcount.get(key, 0) + value
            # If the class is present in the image, then increase the value by one
            # shows in how many images a particular class is present
            dic_class_imgcount[key] = dic_class_imgcount.get(key, 0) + 1
    print("Done")
    # Save above 2 variables in a list
    for (k, v), (k2, v2) in zip(overall_pixelcount.items(), dic_class_imgcount.items()):
        if k != k2:
            print("This was impossible to happen, but somehow it did")
            sys.exit(1)
        result[k] = [v, v2]
    return result


def get_class_per_image(img):
    # Count how many pixels of each label value the image contains
    dic_class_pixelcount = dict()
    im = Image.open(img)
    pix = im.load()
    for x in range(im.size[0]):
        for y in range(im.size[1]):
            dic_class_pixelcount[pix[x, y]] = dic_class_pixelcount.get(pix[x, y], 0) + 1
    # del dic_class_pixelcount[11]
    return dic_class_pixelcount


def cal_class_weights(image_list):
    freq_images = dict()
    weights = collections.OrderedDict()
    # calculate freq per class
    # 360 * 480 is the image size hard-coded in this script; a constant factor
    # cancels in median / freq as long as all images have the same size
    for k, (v1, v2) in count_all_pixels(image_list).items():
        freq_images[k] = v1 / (v2 * 360 * 480 * 1.0)
    # calculate median of freqs (list() is needed on Python 3)
    median = np.median(list(freq_images.values()))
    # calculate weights, in class-id order so the output can be pasted as is
    for k in sorted(freq_images):
        weights[k] = median / freq_images[k]
    return weights


results = cal_class_weights(image_list)
# Print the results
for k, v in results.items():
    print(" class", k, "weight:", round(v, 4))
print("Copy this:")
for k, v in results.items():
    print(" class_weighting:", round(v, 4))
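The weighting that class_weight.py computes is median frequency balancing: for each class, freq = (pixels of that class) / (total pixels of the images in which the class appears), and weight = median(freq) / freq. Below is a minimal vectorized sketch of the same idea with NumPy, assuming the label images are already loaded as integer arrays. The function name `median_frequency_weights` is hypothetical and not part of the script above, and the toy arrays are made up for illustration.

```python
import numpy as np


def median_frequency_weights(label_arrays):
    """Return {class_id: weight} using median frequency balancing.

    label_arrays: iterable of 2-D integer arrays, one per label image.
    """
    pixel_count = {}   # class id -> pixels of that class over all images
    area_present = {}  # class id -> total pixels of images containing the class
    for labels in label_arrays:
        labels = np.asarray(labels)
        ids, counts = np.unique(labels, return_counts=True)
        for class_id, count in zip(ids.tolist(), counts.tolist()):
            pixel_count[class_id] = pixel_count.get(class_id, 0) + count
            area_present[class_id] = area_present.get(class_id, 0) + labels.size
    freq = {c: pixel_count[c] / float(area_present[c]) for c in pixel_count}
    median = float(np.median(list(freq.values())))
    # Sorted by class id so the result can be emitted in label order
    return {c: median / freq[c] for c in sorted(freq)}


if __name__ == "__main__":
    # Two tiny made-up label images
    toy = [np.array([[0, 0], [0, 1]]), np.array([[0, 0], [2, 2]])]
    for class_id, weight in median_frequency_weights(toy).items():
        print("class", class_id, "weight:", round(weight, 4))
```

Unlike the script, this version uses each image's real size instead of a fixed 360 * 480, so it also works when the images differ in size.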
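I had always assumed that my label images contained 40 classes (https://blog.csdn.net/u012455577/article/details/86316996). So I made do with someone else's 40-class class_weighting values, and training kept failing with: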
F0725 17:02:42.888584 17046 math_functions.cu:121] Check failed: status == CUBLAS_STATUS_SUCCESS (11 vs. 0) CUBLAS_STATUS_MAPPING_ERROR
What a dumb mistake... Only after running class_weight.py did I find out that there are actually 48 classes.
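A mismatch like this can be caught before training by listing the label values that actually occur and comparing them with the num_output you plan to use. The sketch below assumes the label images are loaded as NumPy arrays; `out_of_range_labels` is a hypothetical helper. That out-of-range labels caused the CUBLAS error above is my reading of the situation, not something the log itself states.

```python
import numpy as np


def out_of_range_labels(label_arrays, num_classes, ignore_label=None):
    """Return the sorted label values outside [0, num_classes) seen in any image."""
    seen = set()
    for labels in label_arrays:
        seen.update(np.unique(np.asarray(labels)).tolist())
    if ignore_label is not None:
        seen.discard(ignore_label)  # the ignored label is allowed to be out of range
    return sorted(v for v in seen if v < 0 or v >= num_classes)
```

With Pillow installed, a real label file can be turned into an array with np.array(Image.open(path)) before it is passed in. An empty result means every label fits the chosen number of classes.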
3. Create the train.txt and test.txt files
txtfile.sh
#!/usr/bin/env sh
DATA_train=/home/zml/data/nyu/40_label/images_deal
MASK_train=/home/zml/data/nyu/40_label/labels_deal
DATA_test=/home/zml/data/nyu/40_label/images_deal
MASK_test=/home/zml/data/nyu/40_label/labels_deal
MY=/home/zml/temp/transferlearn/
################################################
# Both lists are sorted so that line N of img.txt and line N of mask.txt
# refer to the same file name; find alone does not guarantee any order.
rm -f "$MY/train.txt" "$MY/img.txt" "$MY/mask.txt"
echo "Create train.txt"
find "$DATA_train/" -name "*.png" | sort > "$MY/img.txt"
find "$MASK_train/" -name "*.png" | sort > "$MY/mask.txt"
paste -d " " "$MY/img.txt" "$MY/mask.txt" > "$MY/train.txt"
rm -f "$MY/img.txt" "$MY/mask.txt"
##################################################
rm -f "$MY/test.txt"
echo "Create test.txt"
find "$DATA_test/" -name "*.png" | sort > "$MY/img.txt"
find "$MASK_test/" -name "*.png" | sort > "$MY/mask.txt"
paste -d " " "$MY/img.txt" "$MY/mask.txt" > "$MY/test.txt"
rm -f "$MY/img.txt" "$MY/mask.txt"
Running sh txtfile.sh produces train.txt and test.txt, with one "image_path label_path" pair per line.
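Pairing lines with paste only works if both file lists come out in the same order. Here is a sketch of an alternative that pairs each image with the mask of the same base name. It assumes an image and its mask share a file name, which the post does not state explicitly. `pair_by_name` and `format_list` are hypothetical helpers.

```python
import os


def pair_by_name(image_paths, mask_paths):
    """Pair image and mask paths whose file names (without extension) match."""
    masks = {os.path.splitext(os.path.basename(p))[0]: p for p in mask_paths}
    pairs = []
    for image in sorted(image_paths):
        stem = os.path.splitext(os.path.basename(image))[0]
        if stem in masks:  # images without a matching mask are skipped
            pairs.append((image, masks[stem]))
    return pairs


def format_list(pairs):
    """Format pairs as 'image_path mask_path' lines, the layout of train.txt."""
    return "".join("%s %s\n" % pair for pair in pairs)
```

Writing the result of format_list(...) to train.txt gives the same file layout as the shell script, and images without a mask are dropped instead of shifting every later line.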
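4. Modify segnet_train.prototxt
Set the last num_output to the total number of classes in your dataset's label images.
Set ignore_label to your number of classes, and replace the class_weighting entries in the file with the values that class_weight.py printed earlier.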
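As far as I understand caffe-segnet, the i-th class_weighting entry applies to label i, so the lines need to be in class-index order with no class missing. A small sketch that turns a {class id: weight} mapping into the lines to paste; `class_weighting_lines` is a hypothetical helper, and it assumes class ids run from 0 to num_classes - 1.

```python
def class_weighting_lines(weights, num_classes):
    """Return one 'class_weighting: w' line per class id 0..num_classes-1."""
    missing = [c for c in range(num_classes) if c not in weights]
    if missing:
        # A gap would shift every later weight onto the wrong class
        raise ValueError("no weight for classes: %s" % missing)
    return ["class_weighting: %.4f" % weights[c] for c in range(num_classes)]
```

For example, print("\n".join(class_weighting_lines(weights, 48))) would produce the block to paste for a 48-class label set, where weights is the mapping computed from your own label images.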
5. Modify segnet_solver.prototxt: adjust the learning rate and the other settings. There is a lot in this file, so I won't go through it in detail.
6. Run
/home/zml/caffe/caffe-segnet-cudnn5/build/tools/caffe train -solver /home/zml/temp/transferlearn/file/segnet_solver.prototxt -weights /home/zml/temp/transferlearn/file/segnet_pascal.caffemodel -gpu 0
Source: https://blog.csdn.net/menglanzeng/article/details/97619173