Deep-Learning-Based Saliency Detection for Land-Cover Extraction from Remote Sensing Imagery (MINet)
Like the previous two posts, saliency detection does not seem to solve my problem. I swear this is the last one; after this I am going to try a different direction. Even though I did not reach my goal, the results here are actually decent, so if you need this, it is worth tuning carefully.
Repository used: https://github.com/lartpang/MINet
[Figures: original image, ground-truth label, and prediction result]
Evaluation results:
acc: 0.9055214352077908
acc_cls: 0.8682510382904347
iou: [0.88870665 0.61525859]
miou: 0.7519826202767053
fwavacc: 0.8376228680494308
class_accuracy: 0.7143424443012731
class_recall: 0.7998325458213021
accuracy: 0.9007926079195722
f1_score: 0.7546741156614227
Note that these numbers come from a run with the default parameters; the IoU is already above 0.6 out of the box, which looks decent, although training is fairly slow.
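These are standard confusion-matrix metrics for binary segmentation. As a rough illustration of how values such as acc, iou, miou, and f1_score can be computed from one prediction/label pair (my own sketch, not the evaluation script that produced the numbers above):
# Rough sketch: confusion-matrix metrics for a binary segmentation result (illustrative only,
# not the evaluation code used for the numbers above).
import numpy as np

def binary_seg_metrics(pred, gt):
    """pred, gt: arrays of the same shape with values in {0, 1}."""
    tp = np.sum((pred == 1) & (gt == 1))
    fp = np.sum((pred == 1) & (gt == 0))
    fn = np.sum((pred == 0) & (gt == 1))
    tn = np.sum((pred == 0) & (gt == 0))
    eps = 1e-10
    acc = (tp + tn) / (tp + tn + fp + fn + eps)
    iou_fg = tp / (tp + fp + fn + eps)                  # foreground IoU
    iou_bg = tn / (tn + fp + fn + eps)                  # background IoU
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    return {"acc": acc, "iou": [iou_bg, iou_fg], "miou": (iou_bg + iou_fg) / 2, "f1_score": f1}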
1. Data Preparation
Data preparation is simple; the files are just stored in an ordinary folder layout.
[Figures: first-level and second-level directory structure]
The folder names should preferably match mine, because the code builds the paths by concatenating these names. Also, just keep the image and label file names identical. A sketch of the expected layout is shown below.
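For reference, the layout below is what the loaders expect. This is my own sketch based on the BUILD paths used in config.py later; Image and Mask are the folder names the code concatenates, and the test folder only needs Image because of the modified test-time loading described in the next section.
Dataset/
└── BUILD/
    ├── Train/
    │   ├── Image/   (e.g. 0001.jpg, 0002.jpg, ...)
    │   └── Mask/    (e.g. 0001.png, 0002.png, ... same base names as the images)
    └── Test/
        └── Image/   (images to predict on; no Mask folder is required)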
2. Data Loading
What needs to be changed is the test-time data loading. The original training pipeline also covers testing and validation; I removed the validation step during training. A short usage sketch of the modified test-time loader follows the listing.
# -*- coding: utf-8 -*-
# @Time : 2020/7/22
# @Author : Lart Pang
# @Email : lartpang@163.com
# @File : dataloader.py
# @Project : code
# @GitHub : https://github.com/lartpang
import os
import random
from functools import partial
import torch
from PIL import Image
from prefetch_generator import BackgroundGenerator
from torch.nn.functional import interpolate
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms
from config import arg_config
from utils.joint_transforms import Compose, JointResize, RandomHorizontallyFlip, RandomRotate
from utils.misc import construct_print
def _get_suffix(path_list):
ext_list = list(set([os.path.splitext(p)[1] for p in path_list]))
if len(ext_list) != 1:
if ".png" in ext_list:
ext = ".png"
elif ".jpg" in ext_list:
ext = ".jpg"
elif ".bmp" in ext_list:
ext = ".bmp"
else:
raise NotImplementedError
        construct_print(f"The data folder contains multiple file extensions; only {ext} will be used")
else:
ext = ext_list[0]
return ext
def _make_dataset(root):
img_path = os.path.join(root, "Image")
mask_path = os.path.join(root, "Mask")
img_list = os.listdir(img_path)
mask_list = os.listdir(mask_path)
img_suffix = _get_suffix(img_list)
mask_suffix = _get_suffix(mask_list)
img_list = [os.path.splitext(f)[0] for f in mask_list if f.endswith(mask_suffix)]
return [
(
os.path.join(img_path, img_name + img_suffix),
os.path.join(mask_path, img_name + mask_suffix),
)
for img_name in img_list
]
def _make_dataset2(root):
img_path = os.path.join(root, "Image")
# mask_path = os.path.join(root, "Mask")
img_list = os.listdir(img_path)
# mask_list = os.listdir(mask_path)
img_suffix = _get_suffix(img_list)
# mask_suffix = _get_suffix(mask_list)
# img_list = [os.path.splitext(f)[0] for f in mask_list if f.endswith(mask_suffix)]
return [
(
os.path.join(img_path, img_name),
# os.path.join(mask_path, img_name + mask_suffix),
)
for img_name in img_list
]
def _read_list_from_file(list_filepath):
img_list = []
with open(list_filepath, mode="r", encoding="utf-8") as openedfile:
line = openedfile.readline()
while line:
img_list.append(line.split()[0])
line = openedfile.readline()
return img_list
def _make_dataset_from_list(list_filepath, prefix=(".png", ".png")):
img_list = _read_list_from_file(list_filepath)
return [
(
os.path.join(
                os.path.join(os.path.dirname(img_path), "Image"),  # the path is pieced together here
os.path.basename(img_path) + prefix[0],
),
os.path.join(
                os.path.join(os.path.dirname(img_path), "Mask"),  # the path is pieced together here
os.path.basename(img_path) + prefix[1],
),
)
for img_path in img_list
]
def _make_dataset_from_list2(list_filepath, prefix=(".png", ".png")):  # for test-time loading; no labels needed, since labels are often unavailable at test time
img_list = _read_list_from_file(list_filepath)
return [
(
os.path.join(
                os.path.join(os.path.dirname(img_path), "Image"),  # the path is pieced together here
os.path.basename(img_path) + prefix[0],
),
# os.path.join(
# os.path.join(os.path.dirname(img_path), "Mask"),
# os.path.basename(img_path) + prefix[1],
# ),
)
for img_path in img_list
]
class ImageFolder(Dataset):
def __init__(self, root, in_size, training, prefix, use_bigt=False):
self.training = training
self.use_bigt = use_bigt
if os.path.isdir(root):
construct_print(f"{root} is an image folder, we will test on it.")
self.imgs = _make_dataset(root)
elif os.path.isfile(root):
construct_print(
f"{root} is a list of images, we will use these paths to read the "
f"corresponding image"
)
self.imgs = _make_dataset_from_list(root, prefix=prefix)
else:
raise NotImplementedError
if self.training:
self.joint_transform = Compose(
[JointResize(in_size), RandomHorizontallyFlip(), RandomRotate(10)]
)
img_transform = [transforms.ColorJitter(0.1, 0.1, 0.1)]
self.mask_transform = transforms.ToTensor()
else:
            # If the input is a tuple, resize to that exact size; if it is a single number, scale proportionally so that the shorter side equals that value
img_transform = [
transforms.Resize((in_size, in_size), interpolation=Image.BILINEAR),
]
self.img_transform = transforms.Compose(
[
*img_transform,
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
# transforms.Normalize([0.341414, 0.357437, 0.298912], [0.143317, 0.112520, 0.113972]),
]
)
def __getitem__(self, index):
img_path, mask_path = self.imgs[index]
img_name = os.path.splitext(os.path.basename(img_path))[0]
img = Image.open(img_path).convert("RGB")
if self.training:
mask = Image.open(mask_path).convert("L")
img, mask = self.joint_transform(img, mask)
img = self.img_transform(img)
mask = self.mask_transform(mask)
if self.use_bigt:
                mask = mask.ge(0.5).float()  # binarize
return img, mask, img_name
else:
# todo: When evaluating, the mask path may not exist. But our code defaults to its existence, which makes
# it impossible to use dataloader to generate a prediction without a mask path.
img = self.img_transform(img)
# img = img / 255.0
return img, mask_path, img_name
def __len__(self):
return len(self.imgs)
class ImageFolder2(Dataset):  # added dataset class for test-time loading
def __init__(self, root, in_size, training, prefix, use_bigt=False):
self.training = training
self.use_bigt = use_bigt
if os.path.isdir(root):
construct_print(f"{root} is an image folder, we will test on it.")
self.imgs = _make_dataset2(root)
elif os.path.isfile(root):
construct_print(
f"{root} is a list of images, we will use these paths to read the "
f"corresponding image"
)
self.imgs = _make_dataset_from_list2(root, prefix=prefix)
else:
raise NotImplementedError
        # If the input is a tuple, resize to that exact size; if it is a single number, scale proportionally so that the shorter side equals that value
img_transform = [
transforms.Resize((in_size, in_size), interpolation=Image.BILINEAR),
]
self.img_transform = transforms.Compose(
[
*img_transform,
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
# transforms.Normalize([0.341414, 0.357437, 0.298912], [0.143317, 0.112520, 0.113972]),
]
)
def __getitem__(self, index):
# print(self.imgs[index][0])
img_path = self.imgs[index][0]
img_name = os.path.splitext(os.path.basename(img_path))[0]
img = Image.open(img_path).convert("RGB")
img = self.img_transform(img)
return img, img_name
def __len__(self):
return len(self.imgs)
class DataLoaderX(DataLoader):
def __iter__(self):
return BackgroundGenerator(super(DataLoaderX, self).__iter__())
def _collate_fn(batch, size_list):
size = random.choice(size_list)
img, mask, image_name = [list(item) for item in zip(*batch)]
img = torch.stack(img, dim=0)
img = interpolate(img, size=(size, size), mode="bilinear", align_corners=False)
mask = torch.stack(mask, dim=0)
mask = interpolate(mask, size=(size, size), mode="nearest")
return img, mask, image_name
def _mask_loader(dataset, shuffle, drop_last, size_list):
assert float(torch.__version__[:3]) >= 1.2, (
"If you want to use the pytorch < 1.2, you need to "
"comment out the line `collate_fn=...` when you set the `size_list` to `None`."
)
return DataLoaderX(
dataset=dataset,
collate_fn=partial(_collate_fn, size_list=size_list) if size_list else None,
batch_size=arg_config["batch_size"],
num_workers=arg_config["num_workers"],
shuffle=shuffle,
drop_last=drop_last,
pin_memory=True,
)
def create_loader(data_path, training, size_list=None, prefix=(".jpg", ".png"), get_length=False):
if training:
construct_print(f"Training on: {data_path}")
imageset = ImageFolder(
data_path,
in_size=arg_config["input_size"],
prefix=prefix,
use_bigt=arg_config["use_bigt"],
training=True,
)
loader = _mask_loader(imageset, shuffle=True, drop_last=True, size_list=size_list)
else:
construct_print(f"Testing on: {data_path}")
imageset = ImageFolder2(
data_path, in_size=arg_config["input_size"], prefix=prefix, training=False,
)
loader = _mask_loader(imageset, shuffle=False, drop_last=False, size_list=None)
if get_length:
length_of_dataset = len(imageset)
return loader, length_of_dataset
else:
return loader
if __name__ == "__main__":
loader = create_loader(
data_path=arg_config["rgb_data"]["tr_data_path"],
training=True,
get_length=False,
size_list=arg_config["size_list"],
)
for idx, train_data in enumerate(loader):
train_inputs, train_masks, *train_other_data = train_data
print(f"" f"batch: {idx} ", train_inputs.size(), train_masks.size())
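As a quick illustration of what the modified test-time loader yields, here is a minimal usage sketch (my own example; it assumes the config and dataset layout described in this post):
# Minimal usage sketch of the inference loader: it yields only images and file names, no masks.
from config import arg_config
from utils.dataloader import create_loader

te_loader = create_loader(
    data_path=arg_config["rgb_data"]["te_data_list"]["modelte"],
    training=False,                 # uses ImageFolder2 / _make_dataset2 internally
    prefix=arg_config["prefix"],
    get_length=False,
)
for imgs, names in te_loader:
    print(imgs.shape, names)        # e.g. torch.Size([4, 3, 512, 512]) and a tuple of base names
    break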
3. Training
This codebase is controlled mainly through the configuration file, so let's go over that first.
config.py
import os
__all__ = ["proj_root", "arg_config"]
from collections import OrderedDict
proj_root = os.path.dirname(__file__)
datasets_root = "./Dataset/"
# the original author's dataset paths
# ecssd_path = os.path.join(datasets_root, "Saliency/RGBSOD", "ECSSD")
# dutomron_path = os.path.join(datasets_root, "Saliency/RGBSOD", "DUT-OMRON")
# hkuis_path = os.path.join(datasets_root, "Saliency/RGBSOD", "HKU-IS")
# pascals_path = os.path.join(datasets_root, "Saliency/RGBSOD", "PASCAL-S")
# soc_path = os.path.join(datasets_root, "Saliency/RGBSOD", "SOC/Test")
# dutstr_path = os.path.join(datasets_root, "Saliency/RGBSOD", "DUTS/Train")
# dutste_path = os.path.join(datasets_root, "Saliency/RGBSOD", "DUTS/Test")
# the paths I used for my own tests
# dutstr_path = os.path.join(datasets_root, "ECSSD/Train")
ecssdte_path = os.path.join(datasets_root, "ECSSD/Test")
modelte_path = os.path.join(datasets_root, "TEST")
rivertr_path = os.path.join(datasets_root, "RIVER/Train")
riverte_path = os.path.join(datasets_root, "RIVER/Test")
buildtr_path = os.path.join(datasets_root, "BUILD/Train")
buildte_path = os.path.join(datasets_root, "BUILD/Test")
arg_config = {
    "model": "MINet_VGG16",  # the model actually used; it must be imported in `network/__init__.py`
    "info": "",  # extra information about this experiment, appended to the end of exp_name; if empty, nothing is appended
    "use_amp": False,  # whether to use amp to accelerate training
    "resume_mode": "inference",  # the mode for resuming parameters: ['train', 'test', 'inference', '']  # note: because of my changes, use '' for training and 'inference' for testing
    "use_aux_loss": False,  # whether to use auxiliary losses; several loss functions can be added via self.loss_funcs in solver.py
    "save_pre": True,  # whether to keep the final prediction results
    "epoch_num": 60,  # number of training epochs, 0: directly test the model
    "lr": 0.001,  # reduce this by a factor of 100 when fine-tuning
"xlsx_name": "result.xlsx", # the name of the record file
    # dataset settings
    "rgb_data": {
        "tr_data_path": buildtr_path,  # training data path
        "te_data_list": OrderedDict(
            {
                # "pascal-s": pascals_path,
                # "ecssd": ecssdte_path,
                # "hku-is": hkuis_path,
                # "duts": dutste_path,
                # "dut-omron": dutomron_path,
                # "soc": soc_path,
                # "river": riverte_path,
                "modelte": buildte_path,  # test data path
},
),
},
    # monitoring information during training
    "tb_update": 50,  # >0 enables tensorboard logging
    "print_freq": 50,  # >0, save information during the iterations
    # img_prefix, gt_prefix: the extensions used when paths come from an index/list file
"prefix": (".jpg", ".png"),
    # if you don't use multi-scale training, you can set 'size_list': None
    # "size_list": [224, 256, 288, 320, 352],
    "size_list": None,  # multi-scale training disabled
    "reduction": "mean",  # how the loss is reduced: "mean" or "sum"
    # optimizer and learning-rate decay
    "optim": "adam",  # learning rate for the customized part
    "weight_decay": 5e-4,  # set to 0.0001 when fine-tuning
    "momentum": 0.9,
    "nesterov": False,
    "sche_usebatch": False,
    "lr_type": "poly",
    "warmup_epoch": 1,  # depends on the specific lr_type; only used when lr_type contains 'warmup'; setting it to 1 means no warmup.
    "lr_decay": 0.9,  # poly
    "use_bigt": True,  # whether to binarize the ground truth during training (threshold 0.5)
    "batch_size": 4,  # when resuming training, it is best to keep the same batch size
    "num_workers": 0,  # do not set this too high, otherwise data loading slows down when several jobs train at the same time
    "input_size": 512,  # input image size; images of a different size are resized automatically
}
main.py
In this file I added the inference option, which corresponds to the resume_mode setting in the config file.
import shutil
from datetime import datetime
from config import arg_config, proj_root
from utils.misc import construct_exp_name, construct_path, construct_print, pre_mkdir, set_seed
from utils.solver import Solver
construct_print(f"{datetime.now()}: Initializing...")
construct_print(f"Project Root: {proj_root}")
init_start = datetime.now()
exp_name = construct_exp_name(arg_config)
path_config = construct_path(
proj_root=proj_root, exp_name=exp_name, xlsx_name=arg_config["xlsx_name"],
)
pre_mkdir(path_config)
set_seed(seed=0, use_cudnn_benchmark=arg_config["size_list"] != None)
solver = Solver(exp_name, arg_config, path_config)
construct_print(f"Total initialization time:{datetime.now() - init_start}")
shutil.copy(f"{proj_root}/config.py", path_config["cfg_log"])
shutil.copy(f"{proj_root}/utils/solver.py", path_config["trainer_log"])
construct_print(f"{datetime.now()}: Start...")
if arg_config["resume_mode"] == "test":
solver.test()
elif arg_config["resume_mode"] == "inference":  # added this branch
solver.inference()
else:
solver.train()
construct_print(f"{datetime.now()}: End...")
solver.py
import os
from pprint import pprint
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from tqdm import tqdm
import network as network_lib
from loss.CEL import CEL
from loss.focal_loss import FocalLoss  # the loss functions below were added by me and will be shared together later (a rough IoU-loss sketch is given after this listing)
from loss.dice_loss import DiceLoss
from loss.iou_loss import IoULoss
from utils.dataloader import create_loader
from utils.metric import cal_maxf, cal_pr_mae_meanf
from utils.misc import (
AvgMeter,
construct_print,
write_data_to_file,
)
from utils.pipeline_ops import (
get_total_loss,
make_optimizer,
make_scheduler,
resume_checkpoint,
save_checkpoint,
)
from utils.recorder import TBRecorder, Timer, XLSXRecoder
class Solver:
def __init__(self, exp_name: str, arg_dict: dict, path_dict: dict):
super(Solver, self).__init__()
self.exp_name = exp_name
self.arg_dict = arg_dict
self.path_dict = path_dict
self.dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.to_pil = transforms.ToPILImage()
self.tr_data_path = self.arg_dict["rgb_data"]["tr_data_path"]
self.te_data_list = self.arg_dict["rgb_data"]["te_data_list"]
self.save_path = self.path_dict["save"]
self.save_pre = self.arg_dict["save_pre"]
if self.arg_dict["tb_update"] > 0:
self.tb_recorder = TBRecorder(tb_path=self.path_dict["tb"])
if self.arg_dict["xlsx_name"]:
self.xlsx_recorder = XLSXRecoder(xlsx_path=self.path_dict["xlsx"])
        # attributes that depend on the attributes defined above
self.tr_loader = create_loader(
data_path=self.tr_data_path,
training=True,
size_list=self.arg_dict["size_list"],
prefix=self.arg_dict["prefix"],
get_length=False,
)
self.end_epoch = self.arg_dict["epoch_num"]
self.iter_num = self.end_epoch * len(self.tr_loader)
if hasattr(network_lib, self.arg_dict["model"]):
self.net = getattr(network_lib, self.arg_dict["model"])().to(self.dev)
else:
raise AttributeError
pprint(self.arg_dict)
if self.arg_dict["resume_mode"] == "test":
# resume model only to test model.
# self.start_epoch is useless
resume_checkpoint(
model=self.net, load_path=self.path_dict["final_state_net"], mode="onlynet",
)
return
        # since inference mode was newly added, the corresponding branch was added here as well
if self.arg_dict["resume_mode"] == "inference":
# resume model only to test model.
# self.start_epoch is useless
resume_checkpoint(
model=self.net, load_path=self.path_dict["final_state_net"], mode="onlynet",
)
return
        # multiple losses can be used; remember to set the corresponding option in config.py to True
self.loss_funcs = [
# torch.nn.BCEWithLogitsLoss(reduction=self.arg_dict["reduction"]).to(self.dev)
# FocalLoss()
IoULoss()
]
if self.arg_dict["use_aux_loss"]:
self.loss_funcs.append(CEL().to(self.dev))
self.opti = make_optimizer(
model=self.net,
optimizer_type=self.arg_dict["optim"],
optimizer_info=dict(
lr=self.arg_dict["lr"],
momentum=self.arg_dict["momentum"],
weight_decay=self.arg_dict["weight_decay"],
nesterov=self.arg_dict["nesterov"],
),
)
self.sche = make_scheduler(
optimizer=self.opti,
total_num=self.iter_num if self.arg_dict["sche_usebatch"] else self.end_epoch,
scheduler_type=self.arg_dict["lr_type"],
scheduler_info=dict(
lr_decay=self.arg_dict["lr_decay"], warmup_epoch=self.arg_dict["warmup_epoch"]
),
)
# AMP
if self.arg_dict["use_amp"]:
construct_print("Now, we will use the amp to accelerate training!")
from apex import amp
self.amp = amp
self.net, self.opti = self.amp.initialize(self.net, self.opti, opt_level="O1")
else:
self.amp = None
if self.arg_dict["resume_mode"] == "train":
# resume model to train the model
self.start_epoch = resume_checkpoint(
model=self.net,
optimizer=self.opti,
scheduler=self.sche,
amp=self.amp,
exp_name=self.exp_name,
load_path=self.path_dict["final_full_net"],
mode="all",
)
else:
# only train a new model.
self.start_epoch = 0
def train(self):
for curr_epoch in range(self.start_epoch, self.end_epoch):
train_loss_record = AvgMeter()
self._train_per_epoch(curr_epoch, train_loss_record)
            # adjust the learning rate per epoch
if not self.arg_dict["sche_usebatch"]:
self.sche.step()
            # save (and test) every epoch; the saved parameters correspond to epoch curr_epoch + 1
save_checkpoint(
model=self.net,
optimizer=self.opti,
scheduler=self.sche,
amp=self.amp,
exp_name=self.exp_name,
current_epoch=curr_epoch + 1,
full_net_path=self.path_dict["final_full_net"],
state_net_path=self.path_dict["final_state_net"],
            )  # save the parameters
            # I commented the following out; to use it, replace ImageFolder2 with ImageFolder in create_loader in dataloader.py
# if self.arg_dict["use_amp"]:
# # https://github.com/NVIDIA/apex/issues/567
# with self.amp.disable_casts():
# construct_print("When evaluating, we wish to evaluate in pure fp32.")
# self.test()
# else:
# self.test()
@Timer
def _train_per_epoch(self, curr_epoch, train_loss_record):
for curr_iter_in_epoch, train_data in enumerate(self.tr_loader):
num_iter_per_epoch = len(self.tr_loader)
curr_iter = curr_epoch * num_iter_per_epoch + curr_iter_in_epoch
self.opti.zero_grad()
train_inputs, train_masks, _ = train_data
train_inputs = train_inputs.to(self.dev, non_blocking=True)
train_masks = train_masks.to(self.dev, non_blocking=True)
train_preds = self.net(train_inputs)
train_loss, loss_item_list = get_total_loss(train_preds, train_masks, self.loss_funcs)
if self.amp:
with self.amp.scale_loss(train_loss, self.opti) as scaled_loss:
scaled_loss.backward()
else:
train_loss.backward()
self.opti.step()
if self.arg_dict["sche_usebatch"]:
self.sche.step()
            # only call item() when accumulating the value
train_iter_loss = train_loss.item()
train_batch_size = train_inputs.size(0)
train_loss_record.update(train_iter_loss, train_batch_size)
            # tensorboard logging
if (
self.arg_dict["tb_update"] > 0
and (curr_iter + 1) % self.arg_dict["tb_update"] == 0
):
self.tb_recorder.record_curve("trloss_avg", train_loss_record.avg, curr_iter)
self.tb_recorder.record_curve("trloss_iter", train_iter_loss, curr_iter)
self.tb_recorder.record_curve("lr", self.opti.param_groups, curr_iter)
self.tb_recorder.record_image("trmasks", train_masks, curr_iter)
self.tb_recorder.record_image("trsodout", train_preds.sigmoid(), curr_iter)
self.tb_recorder.record_image("trsodin", train_inputs, curr_iter)
            # log the data of each iteration
if (
self.arg_dict["print_freq"] > 0
and (curr_iter + 1) % self.arg_dict["print_freq"] == 0
):
lr_str = ",".join(
[f"{param_groups['lr']:.7f}" for param_groups in self.opti.param_groups]
)
log = (
f"{curr_iter_in_epoch}:{num_iter_per_epoch}/"
f"{curr_iter}:{self.iter_num}/"
f"{curr_epoch}:{self.end_epoch} "
f"{self.exp_name}\n"
f"Lr:{lr_str} "
f"M:{train_loss_record.avg:.5f} C:{train_iter_loss:.5f} "
f"{loss_item_list}"
)
print(log)
write_data_to_file(log, self.path_dict["tr_log"])
def test(self):
self.net.eval()
total_results = {}
for data_name, data_path in self.te_data_list.items():
construct_print(f"Testing with testset: {data_name}")
self.te_loader = create_loader(
data_path=data_path,
training=False,
prefix=self.arg_dict["prefix"],
get_length=False,
)
self.save_path = os.path.join(self.path_dict["save"], data_name)
if not os.path.exists(self.save_path):
construct_print(f"{self.save_path} do not exist. Let's create it.")
os.makedirs(self.save_path)
results = self._test_process(save_pre=self.save_pre)
msg = f"Results on the testset({data_name}:'{data_path}'): {results}"
construct_print(msg)
write_data_to_file(msg, self.path_dict["te_log"])
total_results[data_name] = results
self.net.train()
if self.arg_dict["xlsx_name"]:
# save result into xlsx file.
self.xlsx_recorder.write_xlsx(self.exp_name, total_results)
def _test_process(self, save_pre):
loader = self.te_loader
pres = [AvgMeter() for _ in range(256)]
recs = [AvgMeter() for _ in range(256)]
meanfs = AvgMeter()
maes = AvgMeter()
tqdm_iter = tqdm(enumerate(loader), total=len(loader), leave=False)
for test_batch_id, test_data in tqdm_iter:
tqdm_iter.set_description(f"{self.exp_name}: te=>{test_batch_id + 1}")
with torch.no_grad():
in_imgs, in_mask_paths, in_names = test_data
in_imgs = in_imgs.to(self.dev, non_blocking=True)
outputs = self.net(in_imgs)
outputs_np = outputs.sigmoid().cpu().detach()
for item_id, out_item in enumerate(outputs_np):
gimg_path = os.path.join(in_mask_paths[item_id])
gt_img = Image.open(gimg_path).convert("L")
out_img = self.to_pil(out_item).resize(gt_img.size, resample=Image.NEAREST)
if save_pre:
oimg_path = os.path.join(self.save_path, in_names[item_id] + ".png")
out_img.save(oimg_path)
gt_img = np.array(gt_img)
out_img = np.array(out_img)
ps, rs, mae, meanf = cal_pr_mae_meanf(out_img, gt_img)
for pidx, pdata in enumerate(zip(ps, rs)):
p, r = pdata
pres[pidx].update(p)
recs[pidx].update(r)
maes.update(mae)
meanfs.update(meanf)
maxf = cal_maxf([pre.avg for pre in pres], [rec.avg for rec in recs])
results = {"MAXF": maxf, "MEANF": meanfs.avg, "MAE": maes.avg}
return results
    # the inference path below was added by me
def inference(self):
self.net.eval()
total_results = {}
for data_name, data_path in self.te_data_list.items():
construct_print(f"Testing with testset: {data_name}")
self.te_loader = create_loader(
data_path=data_path,
training=False,
prefix=self.arg_dict["prefix"],
get_length=False,
)
self.save_path = os.path.join(self.path_dict["save"], data_name)
if not os.path.exists(self.save_path):
construct_print(f"{self.save_path} do not exist. Let's create it.")
os.makedirs(self.save_path)
self._inference_process(save_pre=self.save_pre)
# msg = f"Results on the testset({data_name}:'{data_path}'): {results}"
# construct_print(msg)
# write_data_to_file(msg, self.path_dict["te_log"])
# total_results[data_name] = results
# self.net.train()
# if self.arg_dict["xlsx_name"]:
# # save result into xlsx file.
# self.xlsx_recorder.write_xlsx(self.exp_name, total_results)
def _inference_process(self, save_pre):
loader = self.te_loader
tqdm_iter = tqdm(enumerate(loader), total=len(loader), leave=False)
for test_batch_id, test_data in tqdm_iter:
tqdm_iter.set_description(f"{self.exp_name}: te=>{test_batch_id + 1}")
with torch.no_grad():
in_imgs, in_names= test_data
# print(in_imgs.shape)
in_imgs = in_imgs.to(self.dev, non_blocking=True)
outputs = self.net(in_imgs)
outputs_np = outputs.sigmoid().cpu().detach()
for item_id, out_item in enumerate(outputs_np):
out_img = self.to_pil(out_item).resize((256,256), resample=Image.NEAREST)
if save_pre:
oimg_path = os.path.join(self.save_path, in_names[item_id] + ".png")
out_img.save(oimg_path)
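The custom loss files imported at the top of solver.py (loss/focal_loss.py, loss/dice_loss.py, loss/iou_loss.py) are not listed in this post; they are part of the bundle shared at the end. Purely as an illustration (my own minimal sketch, not the packaged implementation), a soft IoU loss compatible with how get_total_loss calls it could look like this:
# loss/iou_loss.py -- minimal sketch of a soft IoU loss (illustrative only, not the packaged file).
# Assumes the predictions are raw logits and the masks are already in [0, 1].
import torch
import torch.nn as nn

class IoULoss(nn.Module):
    def __init__(self, eps: float = 1e-6):
        super(IoULoss, self).__init__()
        self.eps = eps

    def forward(self, preds: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
        probs = preds.sigmoid()
        inter = (probs * masks).sum(dim=(1, 2, 3))
        union = (probs + masks - probs * masks).sum(dim=(1, 2, 3))
        iou = (inter + self.eps) / (union + self.eps)
        return 1.0 - iou.mean()  # average the soft IoU over the batch and turn it into a loss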
pipeline_ops.py
Here I modified the loss-aggregation function get_total_loss: my own loss functions raised an error with the original version, and with this change they work. A quick sanity check of get_total_loss is given after the listing.
import os
import torch
import torch.nn as nn
import torch.optim.optimizer as optim
import torch.optim.lr_scheduler as sche
import numpy as np
from torch.optim import Adam, SGD
from utils.misc import construct_print
def get_total_loss(
train_preds: torch.Tensor, train_masks: torch.Tensor, loss_funcs: list
) -> (float, list):
"""
return the sum of the list of loss functions with train_preds and train_masks
Args:
train_preds (torch.Tensor): predictions
train_masks (torch.Tensor): masks
loss_funcs (list): the list of loss functions
Returns: the sum of all losses and the list of result strings
"""
loss_list = []
loss_item_list = []
    assert len(loss_funcs) != 0, "Please specify the loss functions in `loss_funcs`"
    for loss in loss_funcs:
        loss_out = loss(train_preds, train_masks)
        loss_list.append(loss_out)  # append once, whatever type the loss returns
        try:
            # tensor losses expose .item(); some custom losses may return a plain float instead
            loss_item_list.append(f"{loss_out.item():.5f}")
        except:
            loss_item_list.append(f"{loss_out:.5f}")
train_loss = sum(loss_list)
return train_loss, loss_item_list
def save_checkpoint(
model: nn.Module = None,
optimizer: optim.Optimizer = None,
scheduler: sche._LRScheduler = None,
amp=None,
exp_name: str = "",
current_epoch: int = 1,
full_net_path: str = "",
state_net_path: str = "",
):
"""
    Save both the full checkpoint (large) and the weights-only state dict (small).
Args:
model (nn.Module): model object
optimizer (optim.Optimizer): optimizer object
scheduler (sche._LRScheduler): scheduler object
amp (): apex.amp
exp_name (str): exp_name
current_epoch (int): in the epoch, model **will** be trained
full_net_path (str): the path for saving the full model parameters
state_net_path (str): the path for saving the state dict.
"""
state_dict = {
"arch": exp_name,
"epoch": current_epoch,
"net_state": model.state_dict(),
"opti_state": optimizer.state_dict(),
"sche_state": scheduler.state_dict(),
"amp_state": amp.state_dict() if amp else None,
}
torch.save(state_dict, full_net_path)
torch.save(model.state_dict(), state_net_path)
def resume_checkpoint(
model: nn.Module = None,
optimizer: optim.Optimizer = None,
scheduler: sche._LRScheduler = None,
amp=None,
exp_name: str = "",
load_path: str = "",
mode: str = "all",
):
"""
    Resume the model from a saved checkpoint.
Args:
model (nn.Module): model object
optimizer (optim.Optimizer): optimizer object
scheduler (sche._LRScheduler): scheduler object
amp (): apex.amp
exp_name (str): exp_name
        load_path (str): path where the model checkpoint is stored
        mode (str): which resume mode to use:
            - 'all': resume the full model, including the training-related parameters;
            - 'onlynet': only resume the model's weight parameters
    Returns: mode 'all' -> start_epoch; mode 'onlynet' -> None
"""
if os.path.exists(load_path) and os.path.isfile(load_path):
construct_print(f"Loading checkpoint '{load_path}'")
checkpoint = torch.load(load_path)
if mode == "all":
if exp_name and exp_name != checkpoint["arch"]:
                # if exp_name is given, it must match checkpoint["arch"]; otherwise no check is performed
raise Exception(f"We can not match {exp_name} with {load_path}.")
start_epoch = checkpoint["epoch"]
if hasattr(model, "module"):
model.module.load_state_dict(checkpoint["net_state"])
else:
model.load_state_dict(checkpoint["net_state"])
optimizer.load_state_dict(checkpoint["opti_state"])
scheduler.load_state_dict(checkpoint["sche_state"])
if checkpoint.get("amp_state", None):
if amp:
amp.load_state_dict(checkpoint["amp_state"])
else:
construct_print("You are not using amp.")
else:
construct_print("The state_dict of amp is None.")
construct_print(
f"Loaded '{load_path}' " f"(will train at epoch" f" {checkpoint['epoch']})"
)
return start_epoch
elif mode == "onlynet":
if hasattr(model, "module"):
model.module.load_state_dict(checkpoint)
else:
model.load_state_dict(checkpoint)
construct_print(
f"Loaded checkpoint '{load_path}' " f"(only has the model's weight params)"
)
else:
raise NotImplementedError
else:
        raise Exception(f"The path {load_path} is invalid, please check it.")
def make_scheduler(
optimizer: optim.Optimizer, total_num: int, scheduler_type: str, scheduler_info: dict
) -> sche._LRScheduler:
def get_lr_coefficient(curr_epoch):
nonlocal total_num
# curr_epoch start from 0
# total_num = iter_num if args["sche_usebatch"] else end_epoch
if scheduler_type == "poly":
coefficient = pow((1 - float(curr_epoch) / total_num), scheduler_info["lr_decay"])
elif scheduler_type == "poly_warmup":
turning_epoch = scheduler_info["warmup_epoch"]
if curr_epoch < turning_epoch:
# 0,1,2,...,turning_epoch-1
coefficient = 1 / turning_epoch * (1 + curr_epoch)
else:
# turning_epoch,...,end_epoch
curr_epoch -= turning_epoch - 1
total_num -= turning_epoch - 1
coefficient = pow((1 - float(curr_epoch) / total_num), scheduler_info["lr_decay"])
elif scheduler_type == "cosine_warmup":
turning_epoch = scheduler_info["warmup_epoch"]
if curr_epoch < turning_epoch:
# 0,1,2,...,turning_epoch-1
coefficient = 1 / turning_epoch * (1 + curr_epoch)
else:
# turning_epoch,...,end_epoch
curr_epoch -= turning_epoch - 1
total_num -= turning_epoch - 1
coefficient = (1 + np.cos(np.pi * curr_epoch / total_num)) / 2
elif scheduler_type == "f3_sche":
coefficient = 1 - abs((curr_epoch + 1) / (total_num + 1) * 2 - 1)
else:
raise NotImplementedError
return coefficient
scheduler = sche.LambdaLR(optimizer, lr_lambda=get_lr_coefficient)
return scheduler
def make_optimizer(model: nn.Module, optimizer_type: str, optimizer_info: dict) -> optim.Optimizer:
if optimizer_type == "sgd_trick":
# https://github.com/implus/PytorchInsight/blob/master/classification/imagenet_tricks.py
params = [
{
"params": [
p for name, p in model.named_parameters() if ("bias" in name or "bn" in name)
],
"weight_decay": 0,
},
{
"params": [
p
for name, p in model.named_parameters()
if ("bias" not in name and "bn" not in name)
]
},
]
optimizer = SGD(
params,
lr=optimizer_info["lr"],
momentum=optimizer_info["momentum"],
weight_decay=optimizer_info["weight_decay"],
nesterov=optimizer_info["nesterov"],
)
elif optimizer_type == "sgd_r3":
params = [
            # Do not apply weight decay to the bias parameters. Weight decay mainly reduces
            # overfitting by constraining the network parameters (both weights and biases)
            # through L2 regularization, which makes them smoother.
{
"params": [
param for name, param in model.named_parameters() if name[-4:] == "bias"
],
"lr": 2 * optimizer_info["lr"],
},
{
"params": [
param for name, param in model.named_parameters() if name[-4:] != "bias"
],
"lr": optimizer_info["lr"],
"weight_decay": optimizer_info["weight_decay"],
},
]
optimizer = SGD(params, momentum=optimizer_info["momentum"])
elif optimizer_type == "sgd_all":
optimizer = SGD(
model.parameters(),
lr=optimizer_info["lr"],
weight_decay=optimizer_info["weight_decay"],
momentum=optimizer_info["momentum"],
)
elif optimizer_type == "adam":
optimizer = Adam(
model.parameters(),
lr=optimizer_info["lr"],
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=optimizer_info["weight_decay"],
)
elif optimizer_type == "f3_trick":
backbone, head = [], []
for name, params_tensor in model.named_parameters():
if name.startswith("div_2"):
pass
elif name.startswith("div"):
backbone.append(params_tensor)
else:
head.append(params_tensor)
params = [
{"params": backbone, "lr": 0.1 * optimizer_info["lr"]},
{"params": head, "lr": optimizer_info["lr"]},
]
optimizer = SGD(
params=params,
momentum=optimizer_info["momentum"],
weight_decay=optimizer_info["weight_decay"],
nesterov=optimizer_info["nesterov"],
)
else:
raise NotImplementedError
print("optimizer = ", optimizer)
return optimizer
if __name__ == "__main__":
a = torch.rand((3, 3)).bool()
print(isinstance(a, torch.FloatTensor), a.type())
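As mentioned above, get_total_loss simply sums whatever losses are in the list and collects their formatted values. A quick sanity check with dummy tensors (my own illustration; run it from the project root):
# Quick sanity check of get_total_loss (illustrative only).
import torch
from utils.pipeline_ops import get_total_loss

preds = torch.randn(2, 1, 64, 64)                    # raw network logits
masks = torch.randint(0, 2, (2, 1, 64, 64)).float()  # binary ground truth
bce = torch.nn.BCEWithLogitsLoss(reduction="mean")
total, items = get_total_loss(preds, masks, [bce])
print(total.item(), items)                           # the summed loss and its formatted pieces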
4. Prediction
After training finishes, an output folder is generated automatically; once config.py is fully set up, this folder and much of its contents are created for you. Remember to set "resume_mode": "inference" for prediction; the results are stored in the pre folder inside output.
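Concretely, the entries I change in config.py before running python main.py for prediction are sketched below (an excerpt only; the full arg_config is listed in section 3):
# config.py (relevant entries for prediction; excerpt of the full arg_config from section 3)
"resume_mode": "inference",   # '' for training, 'inference' for prediction with the saved weights
"rgb_data": {
    "tr_data_path": buildtr_path,
    "te_data_list": OrderedDict({"modelte": buildte_path}),  # each test entry needs an Image/ subfolder
},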
Below is the reference package I uploaded to Baidu Netdisk. The data was provided in my earlier posts and is not included here.
Link: https://pan.baidu.com/s/1n1gfGEIm9kibVAwv8ifKNA
Extraction code: 7477
Source: https://blog.csdn.net/qq_20373723/article/details/112739902