其他分享
首页 > 其他分享> > EnsNet: Ensconce Text in the Wild 模型训练

EnsNet: Ensconce Text in the Wild 模型训练

作者:互联网

参考网址:

$ https://github.com/HCIILAB/Scene-Text-Removal

环境配置

$ git clone https://github.com/HCIILAB/Scene-Text-Removal
$ https://files.pythonhosted.org/packages/b0/e3/0a7bf93413623ec5a1fa42eb3c89f88731a62155f22ca6b1abc8c67c28d3/mxnet_cu90-1.5.0-py2.py3-none-win_amd64.whl
$ pip install mxnet_cu90-1.5.0-py2.py3-none-win_amd64.whl
$ pip install mxnet_cu90-1.5.0-py2.py3-none-win_amd64.whl
$ 验证是否安装成功
Python
>> import mxnet
import 成功说明没有问题。

下载数据集

$ 目的是为了参照它的数据集整理我们自己的数据集。数据格式最好和它的一样,方便先跑通代码。

模型训练

$ 目的是为了参照它的数据集整理我们自己的数据集。数据格式最好和它的一样,方便先跑通代码。

python train.py --trainset_path=’dataset’ --checkpoint=’save_model’ --gpu=0 --lr=0.0002 --n_epoch=5000

网络调整

训练这个网络存在的问题是:

给出来的数据和给的数据读取方式,不匹配,或者至少我没有理解。

解压以后的数据格式为:

syn_train下面包含img和label两个文件夹。

实际读图像的代码如下所示:

class MyDataSet(Dataset): def __init__(self, root, split, is_transform=False,is_train=True): self.root = os.path.join(root, split) self.is_transform = is_transform self.img_paths = [] self._img_512 = os.path.join(root, split, 'train_512', '{}.png') self._mask_512 = os.path.join(root, split, 'mask_512', '{}.png') self._lbl_512 = os.path.join(root, split, 'train_512', '{}.png') self._img_256 = os.path.join(root, split, 'train_256', '{}.png') self._lbl_256 = os.path.join(root, split, 'train_256', '{}.png') self._img_128 = os.path.join(root, split, 'train_128', '{}.png') for fn in os.listdir(os.path.join(root, split, 'train_512')): if len(fn) > 3 and fn[-4:] == '.png': self.img_paths.append(fn[:-4])
def __len__(self): return len(self.img_paths)
def __getitem__(self, idx): img_path_512 = self._img_512.format(self.img_paths[idx]) img_path_256 = self._img_256.format(self.img_paths[idx]) img_path_128 = self._img_128.format(self.img_paths[idx]) lbl_path_256 = self._lbl_256.format(self.img_paths[idx]) mask_path_512 = self._mask_512.format(self.img_paths[idx]) lbl_path_512 = self._lbl_512.format(self.img_paths[idx]) img_arr_256 = mx.image.imread(img_path_256).astype(np.float32)/127.5 - 1 img_arr_512 = mx.image.imread(img_path_512).astype(np.float32)/127.5 - 1 img_arr_128 = mx.image.imread(img_path_128).astype(np.float32)/127.5 - 1 img_arr_512 = mx.image.imresize(img_arr_512, img_wd * 2, img_ht) img_arr_in_512, img_arr_out_512 = [mx.image.fixed_crop(img_arr_512, 0, 0, img_wd, img_ht), mx.image.fixed_crop(img_arr_512, img_wd, 0, img_wd, img_ht)] if os.path.exists(mask_path_512): mask_512 = mx.image.imread(mask_path_512) else: mask_512 = mx.image.imread(mask_path_512.replace(".png",'.jpg',1)) tep_mask_512 = nd.slice_axis(mask_512, axis=2, begin=0, end=1)/255 if self.is_transform: imgs = [img_arr_out_512, img_arr_in_512, tep_mask_512,img_arr_256,img_arr_128] imgs = random_horizontal_flip(imgs) imgs = random_rotate(imgs) img_arr_out_512,img_arr_in_512,tep_mask_512,img_arr_256,img_arr_128 = imgs[0], imgs[1], imgs[2], imgs[3],imgs[4] img_arr_in_512, img_arr_out_512 = [nd.transpose(img_arr_in_512, (2,0,1)), nd.transpose(img_arr_out_512, (2,0,1))] img_arr_out_256 = nd.transpose(img_arr_256, (2,0,1)) img_arr_out_128 = nd.transpose(img_arr_128, (2,0,1)) tep_mask_512 = tep_mask_512.reshape(tep_mask_512.shape[0],tep_mask_512.shape[1],1) tep_mask_512 = nd.transpose(tep_mask_512,(2,0,1)) return img_arr_out_512,img_arr_in_512,tep_mask_512,img_arr_out_256,img_arr_out_128 不匹配,实际我们没有这么多文件夹。 排查问题的过程:

标签:arr,img,Text,self,mask,EnsNet,Wild,path,512
来源: https://www.cnblogs.com/wjjcjj/p/12017064.html