
SegNet Transfer Learning

Author: Internet

This post mainly follows the blog post https://blog.csdn.net/u012426298/article/details/81386817.

First, obtain the pretrained model and the corresponding prototxt files. I won't repeat the download links here; they are given in the blog post above.

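1. Dataset

To avoid annotating a large number of images myself, I used the already-labeled NYU dataset. See the blog post https://blog.csdn.net/weixin_43915709/article/details/88774325. Processing labels40.mat yields the label images.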

# -*- coding: UTF-8 -*-
# Extract the label images from the .mat file.
# Note: the array layout in this file differs from the official one; height and
# width must be swapped, i.e. each slice has to be transposed.
import os

import numpy as np
import scipy.io as scio
from PIL import Image

dataFile = './labels40.mat'
data = scio.loadmat(dataFile)
labels=np.array(data["labels40"])

path_converted='./nyu_labels40'
if not os.path.isdir(path_converted):
    os.makedirs(path_converted)

for i in range(labels.shape[2]):  # one slice per image (1449 in this dataset)
    label = labels[:, :, i].transpose((1, 0))  # swap height and width
    label_img = Image.fromarray(np.uint8(label))
    label_img = label_img.transpose(Image.ROTATE_270)

    # File names are 1-based and zero-padded: 000001.png, 000002.png, ...
    iconpath = os.path.join(path_converted, '%06d.png' % (i + 1))
    label_img.save(iconpath, optimize=True)

2. Determine the classes present in the label images and the class_weighting values, using class_weight.py below:

Command:

python class_weight.py --dir  ./labels_deal  # path to the label images

class_weight.py 

import numpy as np
import argparse
import os
from PIL import Image
from os import listdir
import sys
import collections

# Import arguments
parser = argparse.ArgumentParser()
parser.add_argument('--dir', type=str, help='Path to the folder containing the images with annotations')
args = parser.parse_args()

if args.dir:
    cwd = args.dir
    if not args.dir.endswith('/'): cwd = cwd + '/'
else:
    cwd = os.getcwd() + '/'

image_names = listdir(cwd)
# Keep only images and append image_names to directory
image_list = [cwd + s for s in image_names if s.lower().endswith(('.png', '.jpg', '.jpeg'))]

print("Number of images: %d" % len(image_list))

def count_all_pixels(image_list):
    dic_class_imgcount = dict()
    overall_pixelcount = dict()
    result = dict()
    for img in image_list:
        sys.stdout.write('.')
        sys.stdout.flush()
        for key, value in get_class_per_image(img).items():
            # Sum up the number of classes returned from get_class_per_image function
            overall_pixelcount[key] = overall_pixelcount.get(key, 0) + value
            # If the class is present in the image, then increase the value by one
            # shows in how many images a particular class is present
            dic_class_imgcount[key] = dic_class_imgcount.get(key, 0) + 1
    print("Done")
    # Save above 2 variables in a list
    for (k, v), (k2, v2) in zip(overall_pixelcount.items(), dic_class_imgcount.items()):
        if k != k2:
            sys.exit("Class keys are out of sync between the two counters")
        result[k] = [v, v2]
    return result


def get_class_per_image(img):
    dic_class_pixelcount = dict()
    im = Image.open(img)
    pix = im.load()
    for x in range(im.size[0]):
        for y in range(im.size[1]):
            dic_class_pixelcount[pix[x, y]] = dic_class_pixelcount.get(pix[x, y], 0) + 1
    #del dic_class_pixelcount[11]
    return dic_class_pixelcount


def cal_class_weights(image_list):
    freq_images = dict()
    weights = collections.OrderedDict()
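    # calculate freq per class: pixels of the class divided by the total pixels
    # of the images that contain it (label images are assumed to be 480x360)
    for k, (v1, v2) in count_all_pixels(image_list).items():
        freq_images[k] = v1 / (v2 * 360 * 480 * 1.0)
    # calculate median of freqs (list() keeps this working on Python 3 as well)
    median = np.median(list(freq_images.values()))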
    # calculate weights
    for k, v in freq_images.items():
        weights[k] = median / v
    return weights

results = cal_class_weights(image_list)

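# Print the results, sorted by class id because the class_weighting
# entries in the prototxt are positional (first entry is class 0)
for k, v in sorted(results.items()):
    print("    class %s weight: %s" % (k, round(v, 4)))

print("Copy this:")
for k, v in sorted(results.items()):
    print("    class_weighting: %s" % round(v, 4))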
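
The weighting computed by class_weight.py is median frequency balancing: a class's frequency is its pixel count divided by the total pixel count of the images in which it appears, and its weight is the median frequency divided by its own frequency. The sketch below reproduces that arithmetic on two tiny hand-made label grids using only the standard library. The function name `median_frequency_weights` and the toy data are illustrative assumptions, not part of the original script.

```python
from collections import Counter
from statistics import median


def median_frequency_weights(label_images):
    """Median frequency balancing over a list of 2-D label grids (lists of rows).

    freq(c)   = pixels of class c / total pixels of the images that contain c
    weight(c) = median(freq) / freq(c)
    """
    pixel_count = Counter()  # total pixels per class
    area_count = Counter()   # total pixels of the images in which the class appears
    for image in label_images:
        per_image = Counter(value for row in image for value in row)
        area = sum(per_image.values())
        for cls, n in per_image.items():
            pixel_count[cls] += n
            area_count[cls] += area
    freq = {cls: pixel_count[cls] / area_count[cls] for cls in pixel_count}
    med = median(freq.values())
    # Return the weights ordered by class id
    return {cls: med / f for cls, f in sorted(freq.items())}


toy_labels = [
    [[0, 0], [0, 1]],  # class 0 three times, class 1 once
    [[0, 0], [0, 0]],  # class 0 only
]
weights = median_frequency_weights(toy_labels)
for cls, w in weights.items():
    print("class", cls, "class_weighting:", round(w, 4))
```

Here class 0 has frequency 7/8 and class 1 has frequency 1/4, so the median is 0.5625 and the weights are about 0.6429 and 2.25. Since every frequency is divided by the same image area when all label images share one size, that constant cancels out of median / freq.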

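I had assumed that my label images contained 40 classes (see https://blog.csdn.net/u012455577/article/details/86316996), so I made do with someone else's 40-class class_weighting values, and training kept failing with: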

F0725 17:02:42.888584 17046 math_functions.cu:121] Check failed: status == CUBLAS_STATUS_SUCCESS (11 vs. 0)  CUBLAS_STATUS_MAPPING_ERROR

That was a silly mistake on my part: only after running class_weight.py did I discover that there were actually 48 classes.
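
A class count mismatch like this can be caught before training starts. The sketch below is a minimal check, assuming that every label value must lie in [0, num_output) or equal the ignore label. The helper name `find_invalid_labels` and the toy grid are hypothetical, and whether out-of-range labels were the exact trigger of the CUBLAS error above is an assumption.

```python
def find_invalid_labels(label_image, num_output, ignore_label=None):
    """Return the sorted label values that fall outside [0, num_output).

    label_image is a 2-D grid (list of rows) of integer class ids.
    A value equal to ignore_label is treated as valid.
    """
    seen = {value for row in label_image for value in row}
    return sorted(
        v for v in seen
        if v != ignore_label and not (0 <= v < num_output)
    )


# Toy label grid; 255 plays the role of the ignore label here
toy_label = [[0, 3, 39], [40, 47, 255]]
invalid = find_invalid_labels(toy_label, num_output=40, ignore_label=255)
print("labels outside the expected range:", invalid)
```

With real data, the grid could come from a loaded label image converted to nested lists; an empty result means every label fits the configured num_output.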

3. Create the train.txt and test.txt files

txtfile.sh

#!/usr/bin/env sh
DATA_train=/home/zml/data/nyu/40_label/images_deal
MASK_train=/home/zml/data/nyu/40_label/labels_deal
DATA_test=/home/zml/data/nyu/40_label/images_deal
MASK_test=/home/zml/data/nyu/40_label/labels_deal

MY=/home/zml/temp/transferlearn/
 
################################################
rm -rf $MY/train.txt
 
echo "Create train.txt"
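# Sort both listings so that line N of img.txt and mask.txt refer to the same frame
find $DATA_train/ -name "*.png" | sort > $MY/img.txt
find $MASK_train/ -name "*.png" | sort > $MY/mask.txt
paste -d " " $MY/img.txt $MY/mask.txt>$MY/train.txt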
 
rm -rf $MY/img.txt
rm -rf $MY/mask.txt
 
##################################################
rm -rf $MY/test.txt
 
echo "Create test.txt"
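# Sort both listings so that line N of img.txt and mask.txt refer to the same frame
find $DATA_test/ -name "*.png" | sort > $MY/img.txt
find $MASK_test/ -name "*.png" | sort > $MY/mask.txt
paste -d " " $MY/img.txt $MY/mask.txt>$MY/test.txt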
 
rm -rf $MY/img.txt
rm -rf $MY/mask.txt

Running sh txtfile.sh produces train.txt and test.txt.
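
Each line of train.txt holds an image path and its mask path separated by a space, so the two listings must line up one to one. As an alternative to the shell script, here is a hedged Python sketch that pairs files by name instead of by listing order. The helper name `write_pair_list`, the demo folder names, and the assumption that an image and its mask share the same file name are mine, not the original author's.

```python
import os
import tempfile


def write_pair_list(image_dir, mask_dir, out_path, ext=".png"):
    """Write 'image_path mask_path' lines for files present in both folders."""
    names = sorted(
        n for n in os.listdir(image_dir)
        if n.lower().endswith(ext) and os.path.isfile(os.path.join(mask_dir, n))
    )
    with open(out_path, "w") as f:
        for n in names:
            f.write("%s %s\n" % (os.path.join(image_dir, n),
                                 os.path.join(mask_dir, n)))
    return len(names)


# Demo on throwaway folders; a real run would pass images_deal and labels_deal.
root = tempfile.mkdtemp()
image_dir = os.path.join(root, "images")
mask_dir = os.path.join(root, "labels")
os.makedirs(image_dir)
os.makedirs(mask_dir)
for name in ("000002.png", "000001.png", "000003.png"):
    open(os.path.join(image_dir, name), "w").close()
for name in ("000001.png", "000002.png"):  # 000003.png has no mask
    open(os.path.join(mask_dir, name), "w").close()

out_path = os.path.join(root, "train.txt")
count = write_pair_list(image_dir, mask_dir, out_path)
with open(out_path) as f:
    lines = f.read().splitlines()
print(count, "pairs written")
```

Images without a matching mask are skipped, so a missing label file cannot shift every following pair by one line.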

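4. Modify segnet_train.prototxt

Change the last num_output to the total number of classes in your own dataset's label images.

Set ignore_label to your number of classes, and replace the class_weighting entries with the values produced earlier by class_weight.py.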
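
The loss settings can be generated rather than typed by hand. The sketch below formats a `loss_param` fragment in the style of the SegNet tutorial prototxt files (`weight_by_label_freqs`, `ignore_label`, and one `class_weighting` line per class). The helper name `format_loss_param` and the three example weights are placeholders for illustration, not values computed from the NYU labels.

```python
def format_loss_param(class_weights, ignore_label):
    """Build a loss_param text fragment with one class_weighting line per class.

    class_weights must be ordered by class id, since the entries are positional.
    """
    lines = [
        "loss_param: {",
        "  weight_by_label_freqs: true",
        "  ignore_label: %d" % ignore_label,
    ]
    lines += ["  class_weighting: %.4f" % w for w in class_weights]
    lines.append("}")
    return "\n".join(lines)


example_weights = [0.6429, 2.25, 1.0]  # placeholder values, not real NYU weights
fragment = format_loss_param(example_weights, ignore_label=len(example_weights))
print(fragment)
```

Generating the fragment from the same list that determines num_output makes it harder for the number of class_weighting lines to drift away from the class count.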

5. Modify segnet_solver.prototxt: adjust the learning rate and the other settings. This file has many options, so I won't go through them in detail.

6. Run

/home/zml/caffe/caffe-segnet-cudnn5/build/tools/caffe train -solver /home/zml/temp/transferlearn/file/segnet_solver.prototxt -weights /home/zml/temp/transferlearn/file/segnet_pascal.caffemodel -gpu 0

 

Source: https://blog.csdn.net/menglanzeng/article/details/97619173