GitHub开源项目Hyperspectral-Classification的解析
作者:互联网
GitHub链接:Hyperspectral-Classification Pytorch。
项目简介
项目的作者是Xidian university,是基于PyTorch的高光谱图像地物目标的分类程序。该项目兼容Python 2.7和Python 3.5+,基于PyTorch深度学习和GPU计算框架,并使用Visdom可视化服务器。
预定义的公开的数据集有:
- 帕维亚大学
- 帕维亚中心
- 肯尼迪航天中心
- 印度松树
- 博茨瓦纳
用户也可添加自定义的数据集,示例是“数据融合大赛2018的高光谱数据集”DFC2018_HSI。开发人员应该为CUSTOM_DATASETS_CONFIG变量添加一个新条目,并为其用例定义特定的数据加载器。
该工具实现了scikit-learn库中的几个SVM变体以及PyTorch中实现的许多最先进的深度网络:
- SVM(带网格搜索的线性,RBF和多核)
- SGD(使用随机梯度下降的线性SVM进行快速优化)
基线神经网络(4个完全连接的层,有丢失) - 1D CNN(用于高光谱图像分类的深度卷积神经网络,Hu等人,Journal of Sensors 2015)
- 半监督的1D CNN(Autoencodeurs pour la visualization d’images hyperspectrales,Boulch et al。,GRETSI 2017)
- 2D CNN(用于图像分类和频带选择的高光谱CNN,应用于人脸识别,Sharma等,技术报告2018)
- 半监督2D CNN(用于高光谱图像分类的半监督卷积神经网络,Liu等,遥感信函2017)
- 3D CNN(用于遥感图像分类的三维深度学习方法,Hamida等,TGRS 2018)
- 3D FCN(基于上下文深度CNN的高光谱分类,Lee和Kwon,IGARSS 2016)
- 3D CNN(基于卷积神经网络的深度特征提取和高光谱图像分类,Chen等,TGRS 2016)
- 3D CNN(三维卷积神经网络的高光谱图像的光谱 - 空间分类,Li等,遥感2017)
- 3D CNN(HSI-CNN:用于高光谱图像的新型卷积神经网络,Luo等,ICPR 2018)
- 多尺度3D CNN(用于高光谱图像分类的多尺度3D深度卷积神经网络,He等,ICIP 2017)
用户也可以通过修改models.py
文件来添加自定义深层网络。这意味着为自定义深层网络创建一个新类并更改该get_model
功能。
项目各模块和函数的解析
utils.py
get_device(ordinal)
功能:
根据输入参数,判断device
为CPU或GPU。
输入和输出:
输入:
ordinal
:一个int
类型的数,表示用哪个GPU
输出:
device
:一个超参数,表示运算的位置(CPU or GPU)
代码:
def get_device(ordinal):
# Use GPU ?
if ordinal < 0:
print("Computation on CPU")
device = torch.device('cpu')
elif torch.cuda.is_available():
print("Computation on CUDA GPU device {}".format(ordinal))
device = torch.device('cuda:{}'.format(ordinal))
else:
print("/!\\ CUDA was requested but is not available! Computation will go on CPU. /!\\")
device = torch.device('cpu')
return device
解析:
其实就是一个简单的分支结构:
ordinal < 0
:CPUordinal < 0
且orch.cuda.is_available() == True
:GPUordinal < 0
且orch.cuda.is_available() == False
:CPU
open_file(dataset)
功能:
打开指定的数据集的文件。
输入和输出:
输入:
dataset
:数据集文件的完整路径,比如C:\Datasets\OwnData\OwnData.mat
。
输出
(以读取.mat
为例,因为读取的以.mat
文件居多):
- 一个以变量名为键,以数据为值的字典
dictionary
。
代码:
def open_file(dataset):
_, ext = os.path.splitext(dataset)
ext = ext.lower()
if ext == '.mat':
# Load Matlab array
return io.loadmat(dataset)
elif ext == '.tif' or ext == '.tiff':
# Load TIFF file
return misc.imread(dataset)
elif ext == '.hdr':
img = spectral.open_image(dataset)
return img.load()
else:
raise ValueError("Unknown file format: {}".format(ext))
解析:
最重要的是 _, ext = os.path.splitext(dataset)
中的os.path.splitext(path)
函数。
该函数将输入的路径path
拆分为文件名 + 扩展名,并依次作为返回值。_, ext
表示只获取扩展名,存入变量ext
。之后就是根据不同的扩展名选择不同的打开方式。
需要注意的是,打开.mat
文件,返回值是一个以变量名为键,以数据为值的字典dictionary
。要取出其中的数据,需要通过字典操作,通过访问键来获取值,比如img = open_file(folder + 'OwnData.mat')['Data']
。
convert_to_color_()
功能:
将标签数组转换为RGB颜色编码图像。
输入和输出:
输入:
arr_2d
:int
类型的二维的标签数组(int 2D array of labels)palette
: 每个标签对应的RGB元组,三个值(dict of colors used (label number -> RGB tuple) )
输出:
int
RGB格式的彩色编码标签的2D图像(int 2D images of color-encoded labels in RGB format)
代码:
def convert_to_color_(arr_2d, palette=None):
"""Convert an array of labels to RGB color-encoded image.
Args:
arr_2d: int 2D array of labels
palette: dict of colors used (label number -> RGB tuple) # 哪个标签对应什么样的颜色(RGB三个值)
Returns:
arr_3d: int 2D images of color-encoded labels in RGB format # RGB三通道图像
"""
arr_3d = np.zeros((arr_2d.shape[0], arr_2d.shape[1], 3), dtype=np.uint8) # 确定维度和编码方式,行列为arr_2d的行列,编码方式为uint8
# 异常报错
if palette is None:
raise Exception("Unknown color palette")
for c, i in palette.items():
m = arr_2d == c
arr_3d[m] = i
return arr_3d
解析:
(暂略)
convert_from_color_()
功能:
将RGB编码图像转换为灰度标签。
输入和输出:
输入:
arr_3d
: int 2D image of color-coded labels on 3 channelspalette
:dict of colors used (RGB tuple -> label number)
输出:
arr_2d
: int 2D array of labels
代码:
def convert_from_color_(arr_3d, palette=None):
"""Convert an RGB-encoded image to grayscale labels.
Args:
arr_3d: int 2D image of color-coded labels on 3 channels
palette: dict of colors used (RGB tuple -> label number)
Returns:
arr_2d: int 2D array of labels
"""
if palette is None:
raise Exception("Unknown color palette")
arr_2d = np.zeros((arr_3d.shape[0], arr_3d.shape[1]), dtype=np.uint8)
for c, i in palette.items():
m = np.all(arr_3d == np.array(c).reshape(1, 1, 3), axis=2)
arr_2d[m] = i
return arr_2d
解析:
(暂略)
display_predictions()
功能:
使用visdom
可视化服务来可视化预测结果。
输入和输出:
输入:
pred
:预测结果,二维vis
:vis服务gt
:ground truthcaption
:图表名称
输出:
visdom
服务器网址显示图表
代码:
def display_predictions(pred, vis, gt=None, caption=""): # caption 字幕
if gt is None:
vis.images([np.transpose(pred, (2, 0, 1))],
opts={'caption': caption})
else:
vis.images([np.transpose(pred, (2, 0, 1)),
np.transpose(gt, (2, 0, 1))],
nrow=2,
opts={'caption': caption})
解析:
函数整体是一个简单的分支结构,分为gt is None
和gt is not None
两种情况。
当gt is None
时:
vis.images()
函数绘制一个列表images。它需要一个输入B x C x H x W
(C
->channel;H
->height;W
->width)张量或list of images全部相同的大小。它使大小的图像(B / Nrow,Nrow)的网格。
vis.images()
的可调参数如下:
nrow
:连续的图像数量padding
:在图像周围填充,四边均匀填充opts.jpgquality
:JPG质量(number0-100;默认= 100)opts.caption
:图像的标题
所以vis.images([np.transpose(pred, (2, 0, 1))], opts={'caption': caption})
中主要有两部分
[np.transpose(pred, (2, 0, 1))]
表示要可视化的图。opts={'caption': caption}
表示可选的操作。
np.transpose()
是交换矩阵维度的数组(详见另一篇博客维度交换函数——a.transpose(m,n,r)),因为原始的图像的维度排序默认是H×W×C
,而vis.images()
要求的是C×H×W
。所以张量维度的维度排序就要从(0,1,2)
变为(2,0,1)
,而这通过np.transpose()
函数实现。
opts={'caption': caption}
就是给图加标题。
当gt is not None
时:
对pred
和gt
都要通过np.transpose()
函数来进行维度交换,另外参数nrow
需要指定为2
。
display_dataset()
功能:
选择3个波段作为RGB波段,显示RGB合成图像。
输入和输出:
输入:
img
: 3D hyperspectral imagegt
: 2D array labelsbands
: tuple of RGB bands to selectlabels
: list of label class namespalette
: dict of colorsdisplay
(optional): type of display, if any
但是,只有img
和bands
这两个变量被用到了,其他4个变量都没有用到。
输出:
visdom
服务器网址显示图表
代码:
def display_dataset(img, gt, bands, labels, palette, vis):
"""Display the specified dataset.
Args:
img: 3D hyperspectral image
gt: 2D array labels
bands: tuple of RGB bands to select
labels: list of label class names
palette: dict of colors
display (optional): type of display, if any
"""
print("Image has dimensions {}x{} and {} channels".format(*img.shape))
rgb = spectral.get_rgb(img, bands) # 从SpyFile对象或numpy数组中提取RGB数据以供显示。
rgb /= np.max(rgb) # 最大值化处理
rgb = np.asarray(255 * rgb, dtype='uint8') # 转为ndarray类型
# Display the RGB composite image 显示RGB合成图像
caption = "RGB (bands {}, {}, {})".format(*bands) # *来拆解变量
# send to visdom server
vis.images([np.transpose(rgb, (2, 0, 1))],
opts={'caption': caption})
解析:
首先通过rgb = spectral.get_rgb(img, bands)
来获取img
中的指定的波段bands
,来作为RGB波段。然后最大值化处理rgb /= np.max(rgb)
,将数值放缩到[0,1]之间。然后通过rgb = np.asarray(255 * rgb, dtype='uint8')
来将rgb
放缩到[0,255]之间,同时设定dtype=‘uint8’,即uint8编码的RGB图像。
之后就是visdom server
的操作,先设置标题 caption = "RGB (bands {}, {}, {})".format(*bands)
,其中format(*bands)
通过*
将列表类型(我猜的)的bands
拆解,分别输出。然后调用vis.images()
来将rgb
可视化,参数解析见上面的display_predictions()
函数部分。
explore_spectrums()
(暂略)
plot_spectrums()
(暂略)
build_dataset()
功能:
根据图像和蒙版创建训练样本列表。
输入和输出:
输入:
mat
: 3D hyperspectral matrix to extract the spectrums from # 用来提取光谱的高光谱矩阵gt
: 2D ground truthignored_labels
(optional): list of classes to ignore, e.g. 0 to remove
输出:
- 根据图像和蒙版创建训练样本列表(Create a list of training samples based on an image and a mask.)
代码:
ef build_dataset(mat, gt, ignored_labels=None):
"""Create a list of training samples based on an image and a mask.
Args:
mat: 3D hyperspectral matrix to extract the spectrums from # 用来提取光谱的高光谱矩阵
gt: 2D ground truth
ignored_labels (optional): list of classes to ignore, e.g. 0 to remove
unlabeled pixels
return_indices (optional): bool set to True to return the indices of
the chosen samples
"""
samples = []
labels = []
# Check that image and ground truth have the same 2D dimensions
assert mat.shape[:2] == gt.shape[:2] # 检查维度是否相符,比如PaviaU的mat和gt都是(610, 340)
for label in np.unique(gt):
if label in ignored_labels:
continue
else:
indices = np.nonzero(gt == label) # 返回同一类标签的全部索引。(对gt每个元素判断是否为label,是的话为1否则为0,然后提取全部的非零元素的索引
samples += list(mat[indices])
labels += len(indices[0]) * [label]
return np.asarray(samples), np.asarray(labels)
解析:
首先检查数组维度是否相同,通过assert
关键字实现,其中assert condition
等于 if not condition: raise AssertionError()
。
assert mat.shape[:2] == gt.shape[:2]
来检查mat
和gt
的前两个维度是否相同。
mat.shape[:2]
为提取mat
数组的前两个维度。
np.unique()
函数返回值为The sorted unique values,类型为ndarray。
之后遍历np.unique()
的返回值,通过np.nonzero(gt == label)
获取每次遍历的gt
中gt == label
的元素的索引,返回为indices
。
之后将mat
中对应索引的元素通过samples += list(mat[indices])
来扩充到samples
中。
下面通过一个简单的实例来说明:
import random
import numpy as np
mat = np.array([[0,0,0,0,0],[0,100,200,300,0],[0,200,300,200,0],[0,300,200,100,0],[0,0,0,0,0]])
gt = np.array([[0,0,0,0,0],[0,1,2,3,0],[0,2,3,2,0],[0,3,2,1,0],[0,0,0,0,0]])
ignored_labels = [0]
samples = []
labels = []
# Check that image and ground truth have the same 2D dimensions
assert mat.shape[:2] == gt.shape[:2] # 检查维度是否相符,比如PaviaU的mat和gt都是(610, 340)
for label in np.unique(gt):
if label in ignored_labels:
continue
else:
indices = np.nonzero(gt == label) # 返回同一类标签的全部索引。(对gt每个元素判断是否为label,是的话为1否则为0,然后提取全部的非零元素的索引
samples += list(mat[indices])
labels += len(indices[0]) * [label]
print(mat)
# [[ 0 0 0 0 0]
# [ 0 100 200 300 0]
# [ 0 200 300 200 0]
# [ 0 300 200 100 0]
# [ 0 0 0 0 0]]
print(gt)
# [[0 0 0 0 0]
# [0 1 2 3 0]
# [0 2 3 2 0]
# [0 3 2 1 0]
# [0 0 0 0 0]]
print(samples)
# [100, 100, 200, 200, 200, 200, 300, 300, 300]
print(labels)
# [1, 1, 2, 2, 2, 2, 3, 3, 3]
get_random_pos()
功能:
随机返回输入图像的一个corner(Return the corners of a random window in the input image)
输入和输出:
输入:
img
: 2D (or more) image, e.g. RGB or grayscale imagewindow_shape
: (width, height) tuple of the window
输出:
-
xmin
,xmax
,ymin
,ymax
: tuple of the corners of the window代表corner位置的两个点(左下角和右上角),表现为两个参数。
代码:
def get_random_pos(img, window_shape):
""" Return the corners of a random window in the input image
Args:
img: 2D (or more) image, e.g. RGB or grayscale image
window_shape: (width, height) tuple of the window
Returns:
xmin, xmax, ymin, ymax: tuple of the corners of the window
"""
# 思路:先随机找到一个点,然后在此基础上加上网格的w和h两个点构成一个网格
w, h = window_shape
W, H = img.shape[:2] # 获取img的前两个维度
x1 = random.randint(0, W - w - 1) # 前闭后闭区间内产生随机数
x2 = x1 + w
y1 = random.randint(0, H - h - 1)
y2 = y1 + h
return x1, x2, y1, y2
解析
先通过 w, h = window_shape
从输入元组中提取width
和height
,再通过 W, H = img.shape[:2]
获取img
的前两个维度W
和H
。
然后生成左下角点的位置,所用函数是random.randint()
,在前闭后闭区间产生随机数。W
维度的随机数的范围是(0, W - w - 1)
,H
维度同理。
将左下角(x1,y1)
分别加上w
和h
,则得到右下角(x2,y2)
。这样就表示了一个corner。
最后将xmin
, xmax
, ymin
, ymax
这4个作为函数的返回值返回。
sliding_window()
功能:
生成在输入图像上滑动窗口生成器(Sliding window generator over an input image)
输入和输出:
输入:
image
: 2D+ image to slide the window on, e.g. RGB or hyperspectralstep
: int stride of the sliding windowwindow_size
: int tuple, width and height of the windowwith_data
(optional): bool set to True to return both the data and the corner indices
输出:
当with_data
为真时,返回image[x:x + w, y:y + h], x, y, w, h
, 即窗口的数据和窗口的位置参数。当with_data
为假时,返回x, y, w, h
,即仅仅返回窗口的位置参数。
代码:
def sliding_window(image, step=10, window_size=(20, 20), with_data=True):
"""Sliding window generator over an input image. # 在输入图像上滑动窗口生成器
Args:
image: 2D+ image to slide the window on, e.g. RGB or hyperspectral
step: int stride of the sliding window
window_size: int tuple, width and height of the window
with_data (optional): bool set to True to return both the data and the
corner indices
Yields:
([data], x, y, w, h) where x and y are the top-left corner of the
window, (w,h) the window size
"""
# slide a window across the image
w, h = window_size
W, H = image.shape[:2]
offset_w = (W - w) % step
offset_h = (H - h) % step
for x in range(0, W - w + offset_w, step):
if x + w > W:
x = W - w
for y in range(0, H - h + offset_h, step):
if y + h > H:
y = H - h
if with_data:
yield image[x:x + w, y:y + h], x, y, w, h
else:
yield x, y, w, h
解析:
通过 w, h = window_size
和W, H = image.shape[:2]
分别获得窗口大小的参数和图像的尺寸。然后定义offset_w
和offset_h
使得窗口在合适的范围滑动。
关键字yield
来创建一个生成器。
带yield
的函数是一个生成器,而不是一个函数了,这个生成器有一个函数就是next
函数,next
就相当于“下一步”生成哪个数,这一次的next
开始的地方是接着上一次的next
停止的地方执行的。所以调用next
的时候,生成器并不会从该函数的开始执行,只是接着上一步停止的地方开始,然后遇到yield
后,return
出要生成的数,此步就结束。
对于yield
的详细解释见python中yield的用法详解——最简单,最清晰的解释。
对于这个函数来说,每调用一次这个函数,窗口就会从上一次的位置滑动一个步长。
count_sliding_window()
功能:
计算图像中的窗口数(Count the number of windows in an image.)
输入和输出:
输入:
image
: 2D+ image to slide the window on, e.g. RGB or hyperspectral, …step
: int stride of the sliding windowwindow_size
: int tuple, width and height of the window
输出:
- int number of windows
代码:
def count_sliding_window(top, step=10, window_size=(20, 20)):
“”" Count the number of windows in an image. # 计算图像中的窗口数
def count_sliding_window(top, step=10, window_size=(20, 20)):
""" Count the number of windows in an image. # 计算图像中的窗口数
Args:
image: 2D+ image to slide the window on, e.g. RGB or hyperspectral, ...
step: int stride of the sliding window
window_size: int tuple, width and height of the window
Returns:
int number of windows
"""
sw = sliding_window(top, step, window_size, with_data=False)
return sum(1 for _ in sw)
解析:
先通过调用sliding_window()
函数得到window
的集合sw
,然后遍历sw
每遍历一次返回值+1
。
grouper()
功能:
Browse an iterable by grouping n elements by n elements.
输入和输出:
输入:
n
: int, size of the groupsiterable
: the iterable to Browse
输出:
- chunk of n elements from the iterable
代码:
def grouper(n, iterable): # 分组器?
""" Browse an iterable by grouping n elements by n elements. # 通过n个元素对n个元素进行分组来浏览iterable
Args:
n: int, size of the groups
iterable: the iterable to Browse 迭代
Yields:
chunk of n elements from the iterable 可迭代的n个元素块
"""
it = iter(iterable)
while True:
chunk = tuple(itertools.islice(it, n))
if not chunk:
return
yield chunk
解析:
(暂略)
metrics()
功能:
计算并打印指标,包括准确率,混淆矩阵和F1分数(Compute and print metrics (accuracy
, confusion matrix
and F1 scores
).)
输入和输出:
输入:
prediction
: list of predicted labelstarget
: list of target labelsignored_labels
(optional): list of labels to ignore, e.g. 0 for undefn_classes
(optional): number of classes, max(target) by default
输出:
accuracy
F1 score
by classconfusion matrix
代码:
def metrics(prediction, target, ignored_labels=[], n_classes=None): # 输出指标
"""Compute and print metrics (accuracy, confusion matrix and F1 scores).
Args:
prediction: list of predicted labels
target: list of target labels
ignored_labels (optional): list of labels to ignore, e.g. 0 for undef
n_classes (optional): number of classes, max(target) by default
Returns:
accuracy, F1 score by class, confusion matrix
"""
ignored_mask = np.zeros(target.shape[:2], dtype=np.bool)
for l in ignored_labels:
ignored_mask[target == l] = True
ignored_mask = ~ignored_mask
target = target[ignored_mask]
prediction = prediction[ignored_mask]
results = {}
n_classes = np.max(target) + 1 if n_classes is None else n_classes
cm = confusion_matrix(
target,
prediction,
labels=range(n_classes))
results["Confusion matrix"] = cm
# Compute global accuracy
total = np.sum(cm)
accuracy = sum([cm[x][x] for x in range(len(cm))])
accuracy *= 100 / float(total)
results["Accuracy"] = accuracy
# Compute F1 score
F1scores = np.zeros(len(cm))
for i in range(len(cm)):
try:
F1 = 2. * cm[i, i] / (np.sum(cm[i, :]) + np.sum(cm[:, i]))
except ZeroDivisionError:
F1 = 0.
F1scores[i] = F1
results["F1 scores"] = F1scores
# Compute kappa coefficient
pa = np.trace(cm) / float(total)
pe = np.sum(np.sum(cm, axis=0) * np.sum(cm, axis=1)) / \
float(total * total)
kappa = (pa - pe) / (1 - pe)
results["Kappa"] = kappa
return results
解析:
(这里要说一下,我对prediction
和target
的数据的组织形式不明确,主要是有np.max(target)
这个代码存在)
ignored_mask
部分是创造了一个蒙版(mask
),目的是不再考虑标签为ignored_labels
的部分。
results = {}
将输出结果定义为字典dictionary
类型,作为函数的返回值。一开始的时候只是将results
作为空字典,然后逐步增加键值对。
计算混淆矩阵
cm = confusion_matrix(target, prediction, labels=range(n_classes))
调用confusion_matrix()
函数(该函数详见 python sklearn 计算混淆矩阵 confusion_matrix()函数)。简单来说,这句代码通过target
, prediction
, labels
这3个参数,来计算得到array
类型的混淆矩阵,并将结果返回给cm
。
之后results["Confusion matrix"] = cm
这句代码,在字典类型的result
中加入"Confusion matrix": cm
的键值对。
计算分类准确率
先通过total = np.sum(cm)
来计算总样本数total
。通过对混淆矩阵cm
的每一个元素求和,总和即为总样本数total
。
再计算分类准确的样本数accuracy
。混淆矩阵cm
的对角线的元素值的总和,即为分类准确的样本数accuracy
。
二者比值即为最后的准确率accuracy
。 accuracy *= 100 / float(total)
import numpy as np
cm = np.array([[1,1,1],[2,2,2],[3,3,3]])
print(cm)
# [[1 1 1]
# [2 2 2]
# [3 3 3]]
print(len(cm))
# 3
print(range(len(cm)))
# range(0, 3)
print('\n')
for i in range(len(cm)):
print(i)
# 0
# 1
# 2
之后results["Accuracy"] = accuracy
这句代码,在字典类型的result
中加入"Accuracy": accuracy
的键值对。
计算F1 score
F1分数(F1 Score),是统计学中用来衡量分类模型精确度的一种指标。它同时兼顾了分类模型的精确率Accuracy和召回率Recall Rate 。F1分数可以看作是模型精确率和召回率的一种调和平均。
这部分代码直接调用就行,套公式而已。
计算kappa系数
也是一种指标,套公式。(暂略)
返回结果
将字典dictionary
类型的result
作为函数的返回值。
show_results()
功能:
在visdom界面以文本形式输出结果。
输入和输出:
输入:
results
:字典dictionary
类型,包含Confusion matrix
、Accuracy
、F1 scores
和Kappa
四个key
。vis
:visdom
可视化服务。label_values
:默认为None
。agregated
:默认为False
。
输出:
text
:visdom
输出的文本。
代码:
def show_results(results, vis, label_values=None, agregated=False): # 可视化模块
text = ""
# if agregated部分没看懂要干啥
if agregated:
accuracies = [r["Accuracy"] for r in results]
kappas = [r["Kappa"] for r in results]
F1_scores = [r["F1 scores"] for r in results]
F1_scores_mean = np.mean(F1_scores, axis=0)
F1_scores_std = np.std(F1_scores, axis=0)
cm = np.mean([r["Confusion matrix"] for r in results], axis=0)
text += "Agregated results :\n"
else:
cm = results["Confusion matrix"]
accuracy = results["Accuracy"]
F1scores = results["F1 scores"]
kappa = results["Kappa"]
vis.heatmap(cm, opts={'title': "Confusion matrix",
'marginbottom': 150,
'marginleft': 150,
'width': 500,
'height': 500,
'rownames': label_values, 'columnnames': label_values})
text += "Confusion matrix :\n"
text += str(cm)
text += "---\n"
if agregated:
text += ("Accuracy: {:.03f} +- {:.03f}\n".format(np.mean(accuracies),
np.std(accuracies)))
else:
text += "Accuracy : {:.03f}%\n".format(accuracy)
text += "---\n"
text += "F1 scores :\n"
if agregated:
for label, score, std in zip(label_values, F1_scores_mean,
F1_scores_std):
text += "\t{}: {:.03f} +- {:.03f}\n".format(label, score, std)
else:
for label, score in zip(label_values, F1scores):
text += "\t{}: {:.03f}\n".format(label, score)
text += "---\n"
if agregated:
text += ("Kappa: {:.03f} +- {:.03f}\n".format(np.mean(kappas),
np.std(kappas)))
else:
text += "Kappa: {:.03f}\n".format(kappa)
vis.text(text.replace('\n', '<br/>'))
print(text)
解析:
整体思路是遍历result
的键值对,然后通过text += XXX
来扩充text
,最后在visdom
上打印result
。
由于不知道agregated
是在干啥,而且默认是False
,所以只考虑agregated = false
的情况。
首先,通过访问字典result
的键,来获取对应键的值。
然后通过vis.heatmap()
函数来绘制一个热图,它需要输入NxM张量X来指定热图中每个位置的值,此处为cm
。设置title
:'title': "Confusion matrix"
,尺寸:'marginbottom': 150, 'marginleft': 150, 'width': 500, 'height': 500
,行列标签:'rownames': label_values, 'columnnames': label_values
。
对于cm
,先通过str(cm)
将cm转为字符串类型,然后通过+=
扩充到text
中。
对于Accuracy
也是一样扩充到text
中,text += "Accuracy : {:.03f}%\n".format(accuracy)
。
对于F1scores
,需要对应的label_values
。通过zip(label_values, F1scores)
来将可迭代对象label_values
和F1scores
的对应元素组成元组,并以对象的形式返回(zip()函数详见:Python zip() 函数)。之后通过for
循环遍历zip()
返回的对象,同时扩充text
。for label, score in zip(label_values, F1scores): text += "\t{}: {:.03f}\n".format(label, score)
。
对于Kappa
就是简单的扩充到text
中,text += "Kappa: {:.03f}\n".format(kappa)
。
vis.text()
函数的功能是在一个盒子里打印文本。可以使用它来嵌入任意的HTML
。它需要输入一个text
字符串。opts
目前没有具体的支持。text.replace('\n', '<br/>)
用来实现换行符的替换。vis.text(text.replace('\n', '<br/>'))
sample_gt()
功能:
从标签数组gt
中提取固定百分比的样本(Extract a fixed percentage of samples from an array of labels)。
需要强调的是,被分割为训练集和测试集的样本,不包括类别为ignored_labels
的sample。被分割的只是有效的sample。
输入和输出:
输入:
gt
: a 2D array of int labelspercentage
: [0, 1] float
输出:
train_gt
:2D arrays of int labelstest_gt
:2D arrays of int labels
代码:
def sample_gt(gt, train_size, mode='random'):
"""Extract a fixed percentage of samples from an array of labels. 从标签数组中提取固定百分比的样本。
Args:
gt: a 2D array of int labels
percentage: [0, 1] float
Returns:
train_gt, test_gt: 2D arrays of int labels
"""
indices = np.nonzero(gt)
X = list(zip(*indices)) # x,y features
y = gt[indices].ravel() # classes
train_gt = np.zeros_like(gt)
test_gt = np.zeros_like(gt)
if train_size > 1:
train_size = int(train_size)
if mode == 'random':
train_indices, test_indices = sklearn.model_selection.train_test_split(X, train_size=train_size, stratify=y)
train_indices = [list(t) for t in zip(*train_indices)]
test_indices = [list(t) for t in zip(*test_indices)]
train_gt[train_indices] = gt[train_indices]
test_gt[test_indices] = gt[test_indices]
elif mode == 'fixed':
print("Sampling {} with train size = {}".format(mode, train_size))
train_indices, test_indices = [], []
for c in np.unique(gt):
if c == 0:
continue
indices = np.nonzero(gt == c)
X = list(zip(*indices)) # x,y features
train, test = sklearn.model_selection.train_test_split(X, train_size=train_size)
train_indices += train
test_indices += test
train_indices = [list(t) for t in zip(*train_indices)]
test_indices = [list(t) for t in zip(*test_indices)]
train_gt[train_indices] = gt[train_indices]
test_gt[test_indices] = gt[test_indices]
elif mode == 'disjoint':
train_gt = np.copy(gt)
test_gt = np.copy(gt)
for c in np.unique(gt):
mask = gt == c
for x in range(gt.shape[0]):
first_half_count = np.count_nonzero(mask[:x, :])
second_half_count = np.count_nonzero(mask[x:, :])
try:
ratio = first_half_count / second_half_count
if ratio > 0.9 * train_size and ratio < 1.1 * train_size:
break
except ZeroDivisionError:
continue
mask[:x, :] = 0
train_gt[mask] = 0
test_gt[train_gt > 0] = 0
else:
raise ValueError("{} sampling is not implemented yet.".format(mode))
return train_gt, test_gt
解析:
indices = np.nonzero(gt)
获取gt
中非零的元素的索引,返回值为两个array
数组构成的元组,分别表示x
和y
方向的索引。
X = list(zip(*indices))
中,首先通过*
来将np.nonzero(gt)
的返回值拆解为两个array
,然后通过zip()
函数来将*indices
拆解得到的两个array
的对应元素组成元组,并以对象的形式返回,然后通过list()
将类型转为列表list
类型。y = gt[indices].ravel()
这一句先通过gt[indices]
根据索引indices取得
相应的元素,然后通过ravel()
展开成一维数组。
这一部分代码的功能可以通过这个demo表示 :
gt = np.array([[0,0,0,0],[0,1,2,0],[0,3,4,0],[0,0,0,0]])
print(gt)
# [[0 0 0 0]
# [0 1 2 0]
# [0 3 4 0]
# [0 0 0 0]]
indices = np.nonzero(gt)
X = list(zip(*indices)) # x,y features (x,y)形式的索引
y = gt[indices].ravel() # classes
print(X)
# [(1, 1), (1, 2), (2, 1), (2, 2)]
print(type(X))
# <class 'list'>
print(y)
# [1 2 3 4]
print(type(y))
# <class 'numpy.ndarray'>
train_gt = np.zeros_like(gt)
和test_gt = np.zeros_like(gt)
将train_gt
和test_gt初始化为全0的,与gt
维度相同的数组。
由于默认的mode
为random
,所以这里只解析mode == random
的情况。
train_indices, test_indices = sklearn.model_selection.train_test_split(X, train_size=train_size, stratify=y)
这一句代码主要是调用sklearn.model_selection.train_test_split()
函数,用来将数据集划分成训练集和测试集。train_indices
和test_indices
是划分的结果,为元素的索引,形式是这种:[(4, 2), (3, 3), (2, 2), (3, 2), (1, 1), (4, 3)]
。
train_indices = [list(t) for t in zip(*train_indices)]
这一句是转换train_indices
的表示形式,变为这种形式:[[4, 3, 2, 3, 1, 4], [2, 3, 2, 2, 1, 3]]
。
train_gt[train_indices] = gt[train_indices]
这一句是从gt
中提取训练集的样本。其他非训练集样本保持初始化的0不变,作为ignored_labels
。
下面附上一个小demo,来帮助更好地理解:
import numpy as np
import sklearn.model_selection
gt = np.array([[0,0,0,0],[0,1,2,0],[0,3,4,0],[1,2,3,4],[4,3,3,2],[0,0,0,0]])
print(gt)
# [[0 0 0 0]
# [0 1 2 0]
# [0 3 4 0]
# [1 2 3 4]
# [4 3 3 2]
# [0 0 0 0]]
indices = np.nonzero(gt)
X = list(zip(*indices)) # x,y features (x,y)形式的索引
y = gt[indices].ravel() # classes
print(X)
# [(1, 1), (1, 2), (2, 1), (2, 2), (3, 0), (3, 1), (3, 2), (3, 3), (4, 0), (4, 1), (4, 2), (4, 3)]
print(type(X))
# <class 'list'>
print(y)
# [1 2 3 4 1 2 3 4 4 3 3 2]
print(type(y))
# <class 'numpy.ndarray'>
train_gt = np.zeros_like(gt)
test_gt = np.zeros_like(gt)
train_indices, test_indices = sklearn.model_selection.train_test_split(X, train_size=0.5, stratify=y)
print(train_indices)
# [(4, 2), (3, 3), (2, 2), (3, 2), (1, 1), (4, 3)]
print(test_indices)
# [(1, 2), (3, 1), (2, 1), (4, 0), (3, 0), (4, 1)]
print('________________________')
y_train = []
for i, j in train_indices:
y_train.append(gt[i][j])
y_test = []
for i, j in test_indices:
y_test.append(gt[i][j])
print(y_train)
# [3, 4, 2, 1, 2, 3]
print(y_test)
# [2, 3, 3, 4, 4, 1]
print('________________________')
train_indices = [list(t) for t in zip(*train_indices)]
print(train_indices)
# [[4, 3, 2, 3, 1, 4], [2, 3, 2, 2, 1, 3]]
train_gt[train_indices] = gt[train_indices]
print(train_gt)
# [[0 0 0 0]
# [0 1 0 0]
# [0 0 4 0]
# [0 0 3 4]
# [0 0 3 2]
# [0 0 0 0]]
compute_imf_weights()
(暂略)
camel_to_snake()
(暂略)
module.py
_addindent()
(暂略)
class Module(object)
这部分是所有神经网络模块的基类。
File is read-only。也就是最好不要修改。
所以(暂略)
model.py
class Baseline(nn.Module)
定义一个class,继承nn.Module
。
属性:
无
方法:
weight_init()
def weight_init(m):
if isinstance(m, nn.Linear): # 判断类型是否相同
init.kaiming_normal_(m.weight) # 一种权重初始化方法
init.zeros_(m.bias)
用来初始化权重weight
和偏置bias
。
首先通过if isinstance(m, nn.Linear)
来看输入m是不是和nn.Linear
是一类(这里是继承关系)(鲁棒性检验)。
通过init.kaiming_normal_(m.weight)
来初始化权重weight
,其中kaiming_normal_()
是一种初始化权重的方法。
init.zeros_(m.bias)
将偏置bias
初始化为0。
__ init__()
def __init__(self, input_channels, n_classes, dropout=False): # 类的属性的初始化
super(Baseline, self).__init__()
self.use_dropout = dropout
if dropout:
self.dropout = nn.Dropout(p=0.5)
self.fc1 = nn.Linear(input_channels, 2048)
self.fc2 = nn.Linear(2048, 4096)
self.fc3 = nn.Linear(4096, 2048)
self.fc4 = nn.Linear(2048, n_classes)
self.apply(self.weight_init)
这一部分是对类的初始化,包括是否使用dropout(True
or Flase
)、网络的层数和in_channel
和out_channel
。同时对网络的参数(weight
和bias
)进行初始化,通过self.apply(self.weight_init)
来实现。
forward(self, x)
def forward(self, x):
x = F.relu(self.fc1(x))
if self.use_dropout:
x = self.dropout(x)
x = F.relu(self.fc2(x))
if self.use_dropout:
x = self.dropout(x)
x = F.relu(self.fc3(x))
if self.use_dropout:
x = self.dropout(x)
x = self.fc4(x)
return x
这个方法定义了前向传播过程,流程如下:
输入 |
---|
nn.Linear(input_channels, 2048) |
relu() |
dropout() |
self.fc2 = nn.Linear(2048, 4096) |
relu() |
dropout() |
nn.Linear(4096, 2048) |
relu() |
dropout() |
self.fc4 = nn.Linear(2048, n_classes) |
输出 |
class HuEtAl(nn.Module)
属性:
无
方法:
weight_init()
def weight_init(m):
# [All the trainable parameters in our CNN should be initialized to
# be a random value between −0.05 and 0.05.]
# 我们CNN中的所有可训练参数应初始化为介于-0.05和0.05之间的随机值。
if isinstance(m, nn.Linear) or isinstance(m, nn.Conv1d):
init.uniform_(m.weight, -0.05, 0.05)
init.zeros_(m.bias)
模型权重weight
初始化为介于-0.05和0.05之间的随机值,偏置bias
初始化为0。
_get_final_flattened_size()
def _get_final_flattened_size(self): # 得到最终的扁平尺寸
with torch.no_grad():
x = torch.zeros(1, 1, self.input_channels) # 生成一个1×1×input_channels的全0的tensor
x = self.pool(self.conv(x)) # 先卷积,再池化
return x.numel() # numel()返回张量中的元素个数
首先设置with torch.no_grad()
。当确定不会调用Tensor.backward()
时,禁用 gradient calculation 对于 inference 非常有用。 它将减少计算的内存消耗,否则会有requires_grad = True
。
x = torch.zeros(1, 1, self.input_channels)
生成一个1×1×input_channels
的全0的tensor。
x = self.pool(self.conv(x))
,先卷积,再池化。
return x.numel()
返回张量中的元素个数。
至于为什么这么做,我现在没看很明白。
__ init__()
def __init__(self, input_channels, n_classes, kernel_size=None, pool_size=None):
super(HuEtAl, self).__init__()
if kernel_size is None:
# [In our experiments, k1 is better to be [ceil](n1/9)]
kernel_size = math.ceil(input_channels / 9)
if pool_size is None:
# The authors recommand that k2's value is chosen so that the pooled features have 30~40 values
# ceil(kernel_size/5) gives the same values as in the paper so let's assume it's okay
pool_size = math.ceil(kernel_size / 5)
self.input_channels = input_channels
# [The first hidden convolution layer C1 filters the n1 x 1 input data with 20 kernels of size k1 x 1]
# 第一隐藏卷积层C1用大小为k1×1的20个内核对n1×1输入数据进行滤波
self.conv = nn.Conv1d(1, 20, kernel_size)
self.pool = nn.MaxPool1d(pool_size)
self.features_size = self._get_final_flattened_size()
# [n4 is set to be 100]
self.fc1 = nn.Linear(self.features_size, 100)
self.fc2 = nn.Linear(100, n_classes)
self.apply(self.weight_init)
这部分是初始化kernel_size
、pool_size
,同时读取input_channels
,以及网络层的结构。
kernel_size
和pool_size
的选择以论文为依据。kernel_size
为 math.ceil(input_channels / 9)
,即input_channels / 9
的向上取整。pool_size
= math.ceil(kernel_size / 5)
,即kernel_size / 5
的向上取整。
self.input_channels = input_channels
。
对于网络层的结构:
第一隐藏卷积层C1用大小为k1×1的20个内核对n1×1输入数据进行滤波,self.conv = nn.Conv1d(1, 20, kernel_size)
。
池化层统一是self.pool = nn.MaxPool1d(pool_size)
。
self.features_size = self._get_final_flattened_size()
获取展平的size
,作为第一个线性层(全连接层)的输入维度。
然后定义两个线性层(全连接层),第一个是self.fc1 = nn.Linear(self.features_size, 100)
,输入维度self.features_size
,输出维度100
。第二个是self.fc2 = nn.Linear(100, n_classes)
,输入维度100
,输出维度n_classes
。
forward()
def forward(self, x):
# [In our design architecture, we choose the hyperbolic tangent function tanh(u)]
# 在我们的设计架构中,我们选择双曲正切函数
x = x.squeeze(dim=-1).squeeze(dim=-1)
x = x.unsqueeze(1)
x = self.conv(x)
x = torch.tanh(self.pool(x))
x = x.view(-1, self.features_size)
x = torch.tanh(self.fc1(x))
x = self.fc2(x)
return x
x = x.squeeze(dim=-1).squeeze(dim=-1)
是数据预处理,将x的倒数第一个维度和倒数第二个维度抹掉。
x = x.unsqueeze(1)
再在第2个维度的位置增加一个维度(维度的索引从0开始)。
用一个小demo演示(注意:这里x是tensor类型):
import numpy as np
x = np.array([1,2,3])
print(x.shape)
# (3,)
x = x.reshape(3,1,1)
print(x.shape)
# (3, 1, 1)
x = x.squeeze(-1).squeeze(-1)
print(x.shape)
# (3,)
之后是网络的前向传播过程。
输入 |
---|
nn.Conv1d(1, 20, kernel_size) |
nn.MaxPool1d(pool_size) |
tanh() |
view(-1, self.features_size) |
nn.Linear(self.features_size, 100) |
tanh() |
nn.Linear(100, n_classes) |
输出 |
(其他网络模型,暂略……)
get_model()
功能:
获取模型模型名称和相应地超参数(实例化并获得具有足够超参数的模型,Instantiate and obtain a model with adequate hyperparameters)
输入和输出:
输入:
name
:模型的名称,string类型(string of the model name)kwargs
:超参数,dictionary类型,**kwargs
表示数目不定
输出:
model
: PyTorch networkoptimizer
: PyTorch optimizercriterion
: PyTorch loss Functionkwargs
: 具有理智默认值的超参数(hyperparameters with sane defaults)
代码和解析:
def get_model(name, **kwargs):
"""
Instantiate and obtain a model with adequate hyperparameters
Args:
name: string of the model name 网络名,string类型
kwargs: hyperparameters 超参数,dictionary类型,**kwargs表示数目不定
Returns:
model: PyTorch network
optimizer: PyTorch optimizer
criterion: PyTorch loss Function
kwargs: hyperparameters with sane defaults 具有理智默认值的超参数
"""
device = kwargs.setdefault('device', torch.device('cpu')) # 获取字典kwargs中键device的值,否则返回默认值为cpu。
n_classes = kwargs['n_classes'] # 获取字典kwargs中键n_classes的值。
n_bands = kwargs['n_bands'] # 获取字典kwargs中键n_bands的值。
weights = torch.ones(n_classes)
weights[torch.LongTensor(kwargs['ignored_labels'])] = 0.
weights = weights.to(device) # 放到cpu或gpu
weights = kwargs.setdefault('weights', weights)
首先要强调的是,超参数以键值对的形式存储在kwargs
中。通过访问字典kwargs
来获得常用的超参数device
、n_classes
和n_bands
,同时通过setdefault()
函数来设定默认值。
if name == 'nn':
……
elif name == 'hamida':
……
这一部分是,根据模型的选择,用分支语句来设定自己模型的合适的超参数,比如learning_rate
、optimizer
、criterion
、epoch
和batch_size
等。
model = model.to(device)
epoch = kwargs.setdefault('epoch', 100)
kwargs.setdefault('scheduler', optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=epoch//4, verbose=True))
#kwargs.setdefault('scheduler', None)
kwargs.setdefault('batch_size', 100)
kwargs.setdefault('supervision', 'full')
kwargs.setdefault('flip_augmentation', False)
kwargs.setdefault('radiation_augmentation', False)
kwargs.setdefault('mixture_augmentation', False)
kwargs['center_pixel'] = center_pixel
这部分是用到的函数主要都是setdefault()
,当模型参数不全时,将参数设定为默认值。
但其实模型的参数一般都在前面的if-esle
设定好了,所以这些其实就是查漏补缺的作用。
val()
功能:
计算 val set 的准确率。
输入和输出:
输入:
net
data_loader
device
supervision
输出:
accuracy / total
:其实就是准确率,accuracy
统计的是准确的个数。
代码和解析:
函数定义
def val(net, data_loader, device='cpu', supervision='full'):
# TODO : fix me using metrics()
初始化和预操作
accuracy, total = 0., 0.
ignored_labels = data_loader.dataset.ignored_labels
- 将
accuracy
和total
初始化为float类型的0
- 获取
ignored_labels
开始检测
for batch_idx, (data, target) in enumerate(data_loader):
需要注意的是,再检测过程中,只遍历一遍 val set ,即epoch
为1
。
保持梯度
with torch.no_grad():
因为这部分只是作为检测,并不进行网络的训练,,所以要设置with torch.no_grad()
。
选择device
# Load the data into the GPU if required
data, target = data.to(device), target.to(device)
把数据放到device
上,没什么好说的。
根据不同方式获取预测值
if supervision == 'full':
output = net(data)
elif supervision == 'semi':
outs = net(data)
output, rec = outs
_, output = torch.max(output, dim=1)
一般都是全监督,所以只看全监督的情况。output = net(data)
,没啥可说的。
_, output = torch.max(output, dim=1)
获取预测值,torch.max(output, dim=1)
是按行找到最大的元素,并返回最大的元素和索引(values, indices)
。_, output
表示获取返回值里的索引indices
,索引是几就代表是第几类。
统计accuracy和total
for out, pred in zip(output.view(-1), target.view(-1)):
if out.item() in ignored_labels:
continue
else:
accuracy += out.item() == pred.item()
total += 1
没啥好说的,很容易看懂。
返回 accuracy / total
return accuracy / total
save_model()
def save_model(model, model_name, dataset_name, **kwargs):
model_dir = './checkpoints/' + model_name + "/" + dataset_name + "/"
if not os.path.isdir(model_dir):
os.makedirs(model_dir, exist_ok=True)
if isinstance(model, torch.nn.Module):
filename = str('wk') + "_epoch{epoch}_{metric:.2f}".format(**kwargs)
tqdm.write("Saving neural network weights in {}".format(filename))
torch.save(model.state_dict(), model_dir + filename + '.pth')
else:
filename = str('wk')
tqdm.write("Saving model params in {}".format(filename))
joblib.dump(model, model_dir + filename + '.pkl')
比较容易看明白,现在也不太需要细究,暂略。
train()
功能:
封装好的网络训练的函数。(Training loop to optimize a network for several epochs and a specified loss)
输入和输出:
输入:
net
: a PyTorch modeloptimizer
: a PyTorch optimizerdata_loader
: a PyTorch dataset loaderepoch
: int specifying the number of training epochscriterion
: a PyTorch-compatible loss function, e.g. nn.CrossEntropyLossdevice
(optional): torch device to use (defaults to CPU)display_iter
(optional): number of iterations before refreshing the display (False/None to switch off).scheduler
(optional): PyTorch scheduler,基于epoch
调整学习率lr
val_loader
(optional): validation datasetsupervision
(optional): ‘full’ or ‘semi’
输出:
- 无
代码和解析:
函数的定义和信息
def train(net, optimizer, criterion, data_loader, epoch, scheduler=None,
display_iter=100, device=torch.device('cpu'), display=None,
val_loader=None, supervision='full'):
"""
Training loop to optimize a network for several epochs and a specified loss
Args:
net: a PyTorch model
optimizer: a PyTorch optimizer
data_loader: a PyTorch dataset loader
epoch: int specifying the number of training epochs
criterion: a PyTorch-compatible loss function, e.g. nn.CrossEntropyLoss
device (optional): torch device to use (defaults to CPU)
display_iter (optional): number of iterations before refreshing the display (False/None to switch off).
scheduler (optional): PyTorch scheduler
val_loader (optional): validation dataset
supervision (optional): 'full' or 'semi'
"""
损失函数的鲁棒性检测
if criterion is None:
raise Exception("Missing criterion. You must specify a loss function.")
这一部分就是,如果损失函数criterion
不存在,则报错并打印报错信息。
这是为了增加程序的鲁棒性,不影响程序的主要功能。
初始化部分变量
net.to(device)
save_epoch = epoch // 20 if epoch > 20 else 1
losses = np.zeros(1000000)
mean_losses = np.zeros(100000000)
iter_ = 1
loss_win, val_win = None, None
val_accuracies = []
训练网络
开始epoch循环
for e in tqdm(range(1, epoch + 1), desc="Training the network"):
range(1, epoch + 1)
表示循环次数为epoch
次。
tqdm()
创建一个进度条,描述信息为desc="Training the network"
。
设置模型为训练模式
# Set the network to training mode
net.train()
avg_loss = 0.
关于train()
函数,PyTorch官方文档如下:
train
(mode=True)[SOURCE]Sets the module in training mode.
This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g.
Dropout
,BatchNorm
, etc.
在训练模式下设置模块。
这仅对某些模块有任何影响。如果它们受到影响(例如Dropout,BatchNorm等),有关其在training / evaluation 模式中的行为的详细信息,请参阅特定模块的文档。
在这里net.train()
就是将net
转为训练training模式。
avg_loss = 0.
将avg_loss
初始化为float类型的0
。
epoch
中按batch
训练
for batch_idx, (data, target) in tqdm(enumerate(data_loader), total=len(data_loader)):
enumerate(data_loader)
将data_loader
构成一个索引序列。enumerate(train_loader): <enumerate object at 0x000001D86C258750>
。
len(data_loader)
的返回值为一个epoch
内的batch
的值(样本被分为多少个batch
)。在这里训练集样本数为np.count_nonzero(train_gt): 4063
,对应的len(train_loader): 41。
由于设定的batch_size
为100
,所以训练集的4063
个样本被分为41
个batch
(最后一个batch
的样本数不够batch_size
)。
通过遍历enumerate(data_loader)
构成的索引序列,索引被赋值给batch_idx
,表示这是第几个batch
;数据和真实值被赋值给(data, target)
,表示输入数据和真值。
data
target
to device
# Load the data into the GPU if required
data, target = data.to(device), target.to(device)
将data
和target
放到对应的device
上,默认为cpu,一般为gpu。
正向传播
optimizer.zero_grad()
if supervision == 'full':
output = net(data)
loss = criterion(output, target)
elif supervision == 'semi':
outs = net(data)
output, rec = outs
loss = criterion[0](output, target) + net.aux_loss_weight * criterion[1](rec, data)
else:
raise ValueError("supervision mode \"{}\" is unknown.".format(supervision))
首先将optimizer
的梯度置零:optimizer.zero_grad()
。
然后根据监督方式的不同选择不同的训练方法,因为一般都是全监督,所以只分析全监督的情况:
if supervision == 'full':
output = net(data)
loss = criterion(output, target)
- 计算预测值
output
:output = net(data)
- 计算损失函数
loss
:criterion(output, target)
反向传播
loss.backward()
optimizer.step()
- 反向传播:
loss.backward()
- 优化:
optimizer.step()
计算损失
avg_loss += loss.item()
losses[iter_] = loss.item()
mean_losses[iter_] = np.mean(losses[max(0, iter_ - 100):iter_ + 1])
avg_loss
:初始化为0
,累加每个batch
计算得到的loss
的值。losses
:ndarray类型,每个位置依次存储每次迭代(每个batch
)的损失的值。比如索引为1
的位置存放的是第一次迭代(第一个batch
)的loss
的值。mean_losses
:ndarray类型,将索引为iter_
的位置置为losses
的[max(0, iter_ - 100):iter_ + 1]
的均值。至于为什么这么干,不知道。
绘制Training loss和Validation accuracy曲线
if display_iter and iter_ % display_iter == 0:
这里说一下为什么要有这样一个if
判断。
首先要解决一下变量display_iter
的含义:
display_iter (optional): number of iterations before refreshing the display (False/None to switch off).
简单来说,display_iter
的意义是,一旦迭代次数iter_
是display_iter
整数倍(比如100,200,……),就刷新显示(refreshing the display)。
所以display_iter
不为0
,且迭代次数iter_
是display_iter
整数倍(迭代次数iter_
对display_iter
取余为0
)的时候,才更新 Training loss 和 Validation accuracy 曲线。
string = 'Train (epoch {}/{}) [{}/{} ({:.0f}%)]\tLoss: {:.6f}'
string = string.format(
e, epoch, batch_idx *
len(data), len(data) * len(data_loader),
100. * batch_idx / len(data_loader), mean_losses[iter_])
这一段打印的示例(我自己的数据)为,有6个值:
Train (epoch 55/100) [31100/32200 (97%)] Loss: 0.024220
e
:当前的epoch
数,第几次遍历整个训练集。epoch
:总的epoch
数,一共遍历几次训练集。batch_idx *len(data)
:这个epoch
已经训练结束的样本数。batch_idx
是现在的batch
数,len(data)
是一个batch
的训练样本数。len(data) * len(data_loader)
:这个epoch
要训练的样本数。len(data_loader)
是总的的batch
数,len(data)
是一个batch
的训练样本数。100. * batch_idx / len(data_loader)
:这个epoch
已经训练结束的样本数 / 这个epoch
要训练的样本数,的值,一个百分数。mean_losses[iter_]
:第iter_
次迭代(第iter_
个batch
)的损失的值。
update = None if loss_win is None else 'append'
loss_win = display.line(
X=np.arange(iter_ - display_iter, iter_),
Y=mean_losses[iter_ - display_iter:iter_],
win=loss_win,
update=update,
opts={'title': "Training loss",
'xlabel': "Iterations",
'ylabel': "Loss"
}
)
tqdm.write(string)
第一句update = None if loss_win is None else 'append'
的意思是,如果loss_win
是None
则 update
是None
,如果loss_win
不是None
则 update
是'append'
。
而loss_win
的定义在第一句的下面,这就导致第一次运行的时候update
是None
,之后运行的时候都是'append'
。
update: None
iter_: 100
display_iter: 100
--------------------------------------------------------
Train (epoch 3/100) [1700/4100 (41%)] Loss: 0.916481
update: append
iter_: 200
display_iter: 100
--------------------------------------------------------
Train (epoch 5/100) [3500/4100 (85%)] Loss: 0.695566
update: append
iter_: 300
display_iter: 100
--------------------------------------------------------
Train (epoch 8/100) [1200/4100 (29%)] Loss: 0.599200
loss_win = display.line(
X=np.arange(iter_ - display_iter, iter_),
Y=mean_losses[iter_ - display_iter:iter_],
win=loss_win,
update=update,
opts={'title': "Training loss",
'xlabel': "Iterations",
'ylabel': "Loss"
}
)
tqdm.write(string)
这里的display
是vis
,采用visdom
可视化,用到的函数是vis.line()
。关于vis.line()
:
vis.line
这个函数绘制一个线条图。它需要输入一个N或NxM张量 Y来指定要绘制的M线(连接N点)的值。它还采用可选的X张量来指定相应的x轴值; X可以是一个N张量(在这种情况下,所有的线将共享相同的x轴值)或具有相同的大小Y。
以下opts是支持的:
- opts.fillarea :填充行(boolean)以下的区域
- opts.colormap :colormap(string; default = ‘Viridis’)
- opts.markers :show markers(boolean; default = false)
- opts.markersymbol:标志符号(string;默认= ‘dot’)
- opts.markersize :标记大小(number;默认= ‘10’)
- opts.legend :table包含图例名称
win=loss_win
据我猜测,应该是将操作的window
设置为loss_win
,否则这么多window
不知道操作哪一个。
前3次的结果图如图:
[外链图片转存失败(img-6scjDfp7-1568206440201)(file:///E:\734167802\Image\Group\BC~H{8N6WZ9FQWXTFVFRJ]6.png)]
简单来说就是第一次运行的时候,恰好有iter_
等于display_iter
这时候创建新窗口并绘制一次,之后iter_
等于display_iter
的整数倍的时候再绘制一次,并将后绘制的部分更新到第一次绘制的图像(窗口)上。
最后由tqdm.write(string)
打印进度条。
if len(val_accuracies) > 0:
val_win = display.line(Y=np.array(val_accuracies),
X=np.arange(len(val_accuracies)),
win=val_win,
opts={'title': "Validation accuracy",
'xlabel': "Epochs",
'ylabel': "Accuracy"
})
这部分是打印 Validation accuracy
,还有指定窗口这点东西,上面都有,不再赘述。
迭代变量加一
iter_ += 1
回收无用变量
del(data, target, loss, output)
对于del
方法:
与 init() 方法对应的是 del() 方法,init() 方法用于初始化 Python 对象,而 del() 则用于销毁 Python 对象,即在任何 Python 对象将要被系统回收之时,系统都会自动调用该对象的 del() 方法。 当程序不再需要一个 Python 对象时,系统必须把该对象所占用的内存空间释放出来,这个过程被称为垃圾回收(GC,Garbage Collector),Python 会自动回收所有对象所占用的内存空间,因此开发者无须关心对象垃圾回收的过程。
简单来说,在运行完这一个batch
,这一个batch
的data
、target
、loss
和output
都不被需要了,可以被回收来释放内存。
因为在下一个batch
又会又新的data
、target
,产生新的loss
和output
。
到此为止,一个epoch完毕
计算 avg_loss,val_accuracies,metric
avg_loss /= len(data_loader)
if val_loader is not None:
val_acc = val(net, val_loader, device=device, supervision=supervision)
val_accuracies.append(val_acc)
metric = -val_acc
else:
metric = avg_loss
avg_loss
是对一个epoch
内的所有的batch
的损失loss
取均值。
下面的一个判断语句是,如果val_loader
是None
(第一次执行到这里),metric = avg_loss
;之后执行到这里,就调用val()
函数计算 val set 的准确率,并加入到val_accuracies
中,再将指标metric
设置为-val_acc
(不知道为什么,暂略)。
Save the weights
# Save the weights
if e % save_epoch == 0:
save_model(net, camel_to_snake(str(net.__class__.__name__)), data_loader.dataset.name, epoch=e, metric=abs(metric))
存储的文件名的示例为:
wk_epoch60_0.99.pth
test()
功能:
Test a model on a specific image(在特定图像上测试模型)
输入和输出:
输入:
net
img
:用来做test
的图像hyperparams
:超参数的字典
输出:
probs
:W × H × n_classes
。
代码和解析:
函数定义
def test(net, img, hyperparams):
"""
Test a model on a specific image
"""
模型设置为test模式
net.eval()
提取超参数
patch_size = hyperparams['patch_size']
center_pixel = hyperparams['center_pixel']
batch_size, device = hyperparams['batch_size'], hyperparams['device']
n_classes = hyperparams['n_classes']
kwargs = {'step': hyperparams['test_stride'], 'window_size': (patch_size, patch_size)}
patch_size
:窗口的大小。窗口可以包含上下文信息。center_pixel
:为True
的时候,只看中间的样本,不考虑上下文信息。batch_size
device
kwargs
:一个字典,step
为步长,window_size
为窗口大小(元组类型)
初始化返回结果 probs
probs = np.zeros(img.shape[:2] + (n_classes,))
img
的维度是W × H × channel
,img.shape[:2] + (n_classes,)
获得img
的前两个维度W × H
,并将n_classes
作为第三个维度,即W × H × n_classes
。
用一个小demo演示:
shape = (340,680,103)
print(shape[:2] + (10,))
# (340, 680, 10)
probs
的初始化结果为W × H × n_classes
的全0数组。
计算迭代总数 iterations
iterations = count_sliding_window(img, **kwargs) // batch_size
count_sliding_window()
计算整个图像可以产生多少个window
,每batch_size
最为一批(batch
),二者相除就是最大迭代次数iterations
。
开始迭代
for batch in tqdm(grouper(batch_size, sliding_window(img, **kwargs)),
total=(iterations),
desc="Inference on the image"
):
tqdm()
是生成进度条,total
表示进度条的上限,desc
为描述信息的字符串。
grouper()
是分组器,返回 chunk of n elements from the iterable
,应该是从迭代器sliding_window()
获取n个elements
。
提取数据:
with torch.no_grad():
if patch_size == 1:
data = [b[0][0, 0] for b in batch]
data = np.copy(data)
data = torch.from_numpy(data)
else:
data = [b[0] for b in batch]
data = np.copy(data)
data = data.transpose(0, 3, 1, 2)
data = torch.from_numpy(data)
data = data.unsqueeze(1)
通过这部分自加的代码,来查看batch
和b
的相关信息:
# ----------------------------自加-------------------------
print('batch',batch)
for b in batch:
print('b:',b)
print('b[0]:',b[0])
print('b[0][0, 0]:',b[0][0, 0])
os.system('pause')
# ----------------------------自加-------------------------
这里只详细解释patch_size
的情况。
首先是 b
的形式,b
的类型是tuple(因为元素类型不同):
b: (array([[[0.077625, 0.09325 , 0.0695 , 0.045 , 0.035625, 0.0375 ,
0.03425 , 0.0345 , 0.0415 , 0.039875, 0.03475 , 0.031875,
0.029 , 0.025875, 0.02625 , 0.026125, 0.021 , 0.017375,
0.017125, 0.01925 , 0.021 , 0.02525 , 0.028125, 0.028875,
0.0305 , 0.032125, 0.032875, 0.03275 , 0.03325 , 0.0345 ,
0.035625, 0.036375, 0.035625, 0.034 , 0.033875, 0.030125,
0.026 , 0.02425 , 0.022375, 0.019625, 0.02025 , 0.0215 ,
0.02125 , 0.0195 , 0.018875, 0.017625, 0.016875, 0.015875,
0.017125, 0.017875, 0.018 , 0.016125, 0.012875, 0.013 ,
0.012375, 0.011625, 0.013375, 0.01575 , 0.014625, 0.01225 ,
0.01175 , 0.01175 , 0.012375, 0.01275 , 0.0145 , 0.019125,
0.0235 , 0.030375, 0.04025 , 0.051625, 0.0615 , 0.073875,
0.092125, 0.116625, 0.140625, 0.165875, 0.189875, 0.20825 ,
0.22375 , 0.24175 , 0.253625, 0.25425 , 0.25125 , 0.258625,
0.273875, 0.279125, 0.280625, 0.281125, 0.281875, 0.28125 ,
0.281125, 0.279875, 0.279875, 0.28525 , 0.286 , 0.28025 ,
0.274125, 0.27525 , 0.278125, 0.28325 , 0.2885 , 0.293125,
0.295125]]], dtype=float32), 0, 2, 1, 1)
那么batch
应该就是由好多b
组成的列表或元组(我更倾向是列表)。
那么b[0]
就是数据的部分,即:
array([[[0.077625, 0.09325 , 0.0695 , 0.045 , 0.035625, 0.0375 ,
0.03425 , 0.0345 , 0.0415 , 0.039875, 0.03475 , 0.031875,
0.029 , 0.025875, 0.02625 , 0.026125, 0.021 , 0.017375,
0.017125, 0.01925 , 0.021 , 0.02525 , 0.028125, 0.028875,
0.0305 , 0.032125, 0.032875, 0.03275 , 0.03325 , 0.0345 ,
0.035625, 0.036375, 0.035625, 0.034 , 0.033875, 0.030125,
0.026 , 0.02425 , 0.022375, 0.019625, 0.02025 , 0.0215 ,
0.02125 , 0.0195 , 0.018875, 0.017625, 0.016875, 0.015875,
0.017125, 0.017875, 0.018 , 0.016125, 0.012875, 0.013 ,
0.012375, 0.011625, 0.013375, 0.01575 , 0.014625, 0.01225 ,
0.01175 , 0.01175 , 0.012375, 0.01275 , 0.0145 , 0.019125,
0.0235 , 0.030375, 0.04025 , 0.051625, 0.0615 , 0.073875,
0.092125, 0.116625, 0.140625, 0.165875, 0.189875, 0.20825 ,
0.22375 , 0.24175 , 0.253625, 0.25425 , 0.25125 , 0.258625,
0.273875, 0.279125, 0.280625, 0.281125, 0.281875, 0.28125 ,
0.281125, 0.279875, 0.279875, 0.28525 , 0.286 , 0.28025 ,
0.274125, 0.27525 , 0.278125, 0.28325 , 0.2885 , 0.293125,
0.295125]]], dtype=float32)
b[0]
的shape
为:
(1, 1, 103)
那么b[0][0,0]
和它的shape
为:
[0.077625, 0.09325 , 0.0695 , 0.045 , 0.035625, 0.0375 ,
0.03425 , 0.0345 , 0.0415 , 0.039875, 0.03475 , 0.031875,
0.029 , 0.025875, 0.02625 , 0.026125, 0.021 , 0.017375,
0.017125, 0.01925 , 0.021 , 0.02525 , 0.028125, 0.028875,
0.0305 , 0.032125, 0.032875, 0.03275 , 0.03325 , 0.0345 ,
0.035625, 0.036375, 0.035625, 0.034 , 0.033875, 0.030125,
0.026 , 0.02425 , 0.022375, 0.019625, 0.02025 , 0.0215 ,
0.02125 , 0.0195 , 0.018875, 0.017625, 0.016875, 0.015875,
0.017125, 0.017875, 0.018 , 0.016125, 0.012875, 0.013 ,
0.012375, 0.011625, 0.013375, 0.01575 , 0.014625, 0.01225 ,
0.01175 , 0.01175 , 0.012375, 0.01275 , 0.0145 , 0.019125,
0.0235 , 0.030375, 0.04025 , 0.051625, 0.0615 , 0.073875,
0.092125, 0.116625, 0.140625, 0.165875, 0.189875, 0.20825 ,
0.22375 , 0.24175 , 0.253625, 0.25425 , 0.25125 , 0.258625,
0.273875, 0.279125, 0.280625, 0.281125, 0.281875, 0.28125 ,
0.281125, 0.279875, 0.279875, 0.28525 , 0.286 , 0.28025 ,
0.274125, 0.27525 , 0.278125, 0.28325 , 0.2885 , 0.293125,
0.295125]
(103,)
了解了b
,b[0]
和b[0][0,1]
的组成形式和类型后,再回来看这一句:
data = [b[0][0, 0] for b in batch]
等号右边先是一个中括号[]
,表示是一个列表,列表里面是一个for循环,将每次循环得到的b
进行b[0][0, 0]
的操作,放到列表里作为列表的一个元素。
用以和小demo演示一下功能:
a1 = np.array([[[0.079625, 0.074 , 0.06025 , 0.0695 , 0.0635 , 0.0355 ,
0.02225 , 0.02475 , 0.024125, 0.028 , 0.027125, 0.026875,
0.023375, 0.020125, 0.019 , 0.017 , 0.0155 , 0.01525 ,
0.015875, 0.01575 , 0.015625, 0.015375, 0.018375, 0.0235 ,
0.026 , 0.025375, 0.02525 , 0.02575 , 0.027375, 0.029375,
0.02975 , 0.028375, 0.027125, 0.026875, 0.027 , 0.025125,
0.02375 , 0.020875, 0.018625, 0.02025 , 0.02175 , 0.0225 ,
0.022125, 0.02125 , 0.0205 , 0.021625, 0.0235 , 0.02275 ,
0.020125, 0.018375, 0.017625, 0.01925 , 0.021875, 0.02075 ,
0.0175 , 0.01725 , 0.01825 , 0.017375, 0.016125, 0.018125,
0.01925 , 0.017125, 0.016625, 0.016 , 0.016125, 0.02175 ,
0.030625, 0.04225 , 0.056875, 0.073125, 0.09 , 0.10625 ,
0.126625, 0.153125, 0.1825 , 0.21275 , 0.24225 , 0.269625,
0.289625, 0.304125, 0.315625, 0.319 , 0.311625, 0.31925 ,
0.341625, 0.347625, 0.3435 , 0.3435 , 0.342125, 0.33875 ,
0.335125, 0.33025 , 0.330625, 0.3355 , 0.334375, 0.326125,
0.317625, 0.318875, 0.321375, 0.321125, 0.321625, 0.3275 ,
0.3305 ]]])
a2 = np.array([[[0.115 , 0.0825 , 0.058125, 0.038875, 0.04375 , 0.046875,
0.047 , 0.041375, 0.032625, 0.024875, 0.022375, 0.021125,
0.02075 , 0.022375, 0.021625, 0.019125, 0.017375, 0.015 ,
0.013375, 0.016625, 0.019875, 0.0235 , 0.028375, 0.0305 ,
0.030875, 0.032375, 0.035 , 0.03725 , 0.04 , 0.042375,
0.042625, 0.043375, 0.042375, 0.039875, 0.0385 , 0.036375,
0.033875, 0.032375, 0.03075 , 0.029 , 0.0295 , 0.029625,
0.030625, 0.027625, 0.024875, 0.02525 , 0.025375, 0.02625 ,
0.026375, 0.026 , 0.0265 , 0.026875, 0.026875, 0.027125,
0.02675 , 0.024125, 0.022875, 0.021625, 0.01975 , 0.019125,
0.0195 , 0.019875, 0.019625, 0.02025 , 0.023625, 0.0295 ,
0.03625 , 0.046625, 0.06175 , 0.076875, 0.09225 , 0.108625,
0.130875, 0.157375, 0.182625, 0.209 , 0.2285 , 0.246 ,
0.262 , 0.272625, 0.27975 , 0.276875, 0.269125, 0.279 ,
0.299875, 0.3005 , 0.294 , 0.291625, 0.293 , 0.290375,
0.289 , 0.292875, 0.2935 , 0.290875, 0.287375, 0.278625,
0.27325 , 0.278375, 0.281 , 0.27725 , 0.282375, 0.294 ,
0.298125]]])
batch = [a1, a2]
data = [b[0, 0] for b in batch]
data = np.copy(data)
print('data:',data)
# data: [[0.079625 0.074 0.06025 0.0695 0.0635 0.0355 0.02225 0.02475
# 0.024125 0.028 0.027125 0.026875 0.023375 0.020125 0.019 0.017
# 0.0155 0.01525 0.015875 0.01575 0.015625 0.015375 0.018375 0.0235
# 0.026 0.025375 0.02525 0.02575 0.027375 0.029375 0.02975 0.028375
# 0.027125 0.026875 0.027 0.025125 0.02375 0.020875 0.018625 0.02025
# 0.02175 0.0225 0.022125 0.02125 0.0205 0.021625 0.0235 0.02275
# 0.020125 0.018375 0.017625 0.01925 0.021875 0.02075 0.0175 0.01725
# 0.01825 0.017375 0.016125 0.018125 0.01925 0.017125 0.016625 0.016
# 0.016125 0.02175 0.030625 0.04225 0.056875 0.073125 0.09 0.10625
# 0.126625 0.153125 0.1825 0.21275 0.24225 0.269625 0.289625 0.304125
# 0.315625 0.319 0.311625 0.31925 0.341625 0.347625 0.3435 0.3435
# 0.342125 0.33875 0.335125 0.33025 0.330625 0.3355 0.334375 0.326125
# 0.317625 0.318875 0.321375 0.321125 0.321625 0.3275 0.3305 ]
# [0.115 0.0825 0.058125 0.038875 0.04375 0.046875 0.047 0.041375
# 0.032625 0.024875 0.022375 0.021125 0.02075 0.022375 0.021625 0.019125
# 0.017375 0.015 0.013375 0.016625 0.019875 0.0235 0.028375 0.0305
# 0.030875 0.032375 0.035 0.03725 0.04 0.042375 0.042625 0.043375
# 0.042375 0.039875 0.0385 0.036375 0.033875 0.032375 0.03075 0.029
# 0.0295 0.029625 0.030625 0.027625 0.024875 0.02525 0.025375 0.02625
# 0.026375 0.026 0.0265 0.026875 0.026875 0.027125 0.02675 0.024125
# 0.022875 0.021625 0.01975 0.019125 0.0195 0.019875 0.019625 0.02025
# 0.023625 0.0295 0.03625 0.046625 0.06175 0.076875 0.09225 0.108625
# 0.130875 0.157375 0.182625 0.209 0.2285 0.246 0.262 0.272625
# 0.27975 0.276875 0.269125 0.279 0.299875 0.3005 0.294 0.291625
# 0.293 0.290375 0.289 0.292875 0.2935 0.290875 0.287375 0.278625
# 0.27325 0.278375 0.281 0.27725 0.282375 0.294 0.298125]]
print('data.shape:',data.shape)
# data.shape: (2, 103)
这里的batch
只设定了两个元素,可以看到最终的返回值的shape
为(2, 103)
,应该是每个sample
作为一行。
如果元组batch
中有100
个元素,data
的shape
就是 (100,103)
,为(batch_size, channel)
。
所以再看if patch_size == 1
的这部分代码:
with torch.no_grad():
if patch_size == 1:
data = [b[0][0, 0] for b in batch]
data = np.copy(data)
data = torch.from_numpy(data)
首先是test,所以设定在with torch.no_grad():
下执行。
data = [b[0][0, 0] for b in batch]
将元组batch
中的每个元素的第一个元素(每一个sample
的数据)提取出来,组成一个叫data
的列表。
然后通过np.copy()
将data
从list转成array。
data = torch.from_numpy(data)
将data
从array转成tensor。
获取预测值 output
indices = [b[1:] for b in batch]
data = data.to(device)
output = net(data)
if isinstance(output, tuple):
output = output[0]
output = output.to('cpu')
if patch_size == 1 or center_pixel:
output = output.numpy()
else:
output = np.transpose(output.numpy(), (0, 2, 3, 1))
indices = [b[1:] for b in batch]
获取索引信息。
data = data.to(device)
将data
放到相应的device
上。
output = net(data)
获取data
的预测值。
if isinstance(output, tuple): output = output[0]
这一句不知道,暂略。
output = output.to('cpu')
将output
转到cpu。
然后在patch_size == 1 or center_pixel
的情况下,将output
转成array类型output = output.numpy()
。
统计结果
for (x, y, w, h), out in zip(indices, output):
if center_pixel:
probs[x + w // 2, y + h // 2] += out
else:
probs[x:x + w, y:y + h] += out
首先强调一下,返回结果probs
初始化的时候是初始化为全0
的(对应的一般只设置一个ignored_label
并将其对应的label
作为0
)。
整个test
的大致意思是,每次通过grouper()
获取一个batch
的sample(个数为batch_size
),然后将它们的预测结果output
更新到probs
中。这样一个batch
一个batch
地进行完,就得到了全部训练集样本的预测结果(训练样本以外的样本,对应位置的值为全零)。
其它暂略。
inference.py
这一部分的代码,跟main.py有重复。
(暂略)
datasets.py
此文件包含用于高光谱图像和相关助手的PyTorch数据集。
DATASETS_CONFIG + 更新
数据集配置
DATASETS_CONFIG
是数据集配置字典,是字典dictionary类型。键值对的键是dataset_name
,值是数据集的urls
、img
和gt
。
DATASETS_CONFIG = {
'PaviaC': {
'urls': ['http://www.ehu.eus/ccwintco/uploads/e/e3/Pavia.mat', # urls是链接
'http://www.ehu.eus/ccwintco/uploads/5/53/Pavia_gt.mat'],
'img': 'Pavia.mat',
'gt': 'Pavia_gt.mat'
},
'PaviaU': {
'urls': ['http://www.ehu.eus/ccwintco/uploads/e/ee/PaviaU.mat',
'http://www.ehu.eus/ccwintco/uploads/5/50/PaviaU_gt.mat'],
'img': 'PaviaU.mat',
'gt': 'PaviaU_gt.mat'
},
'KSC': {
'urls': ['http://www.ehu.es/ccwintco/uploads/2/26/KSC.mat',
'http://www.ehu.es/ccwintco/uploads/a/a6/KSC_gt.mat'],
'img': 'KSC.mat',
'gt': 'KSC_gt.mat'
},
'IndianPines': {
'urls': ['http://www.ehu.eus/ccwintco/uploads/6/67/Indian_pines_corrected.mat',
'http://www.ehu.eus/ccwintco/uploads/c/c4/Indian_pines_gt.mat'],
'img': 'Indian_pines_corrected.mat',
'gt': 'Indian_pines_gt.mat'
},
'Botswana': {
'urls': ['http://www.ehu.es/ccwintco/uploads/7/72/Botswana.mat',
'http://www.ehu.es/ccwintco/uploads/5/58/Botswana_gt.mat'],
'img': 'Botswana.mat',
'gt': 'Botswana_gt.mat',
}
}
更新数据集配置
try:
from custom_datasets import CUSTOM_DATASETS_CONFIG
DATASETS_CONFIG.update(CUSTOM_DATASETS_CONFIG)
except ImportError:
pass
本质上是将字典CUSTOM_DATASETS_CONFIG
更新到字典DATASETS_CONFIG
中,用到的是字典操作的update()
函数。
class TqdmUpTo(tqdm)
一个class 进度条功能。
(暂略)
get_dataset()
####功能:
下载并读取数据集。
####输入和输出:
输入:
dataset_name
: string with the name of the datasettarget_folder
(optional): folder to store the datasets, defaults to./
。当然我一般是指定位置的。datasets
(optional): dataset configuration dictionary, defaults to prebuilt one。一般设定为DATASETS_CONFIG
。
输出:
img
: 3D hyperspectral image (WxHxB),B为波段。gt
: 2D int array of labelslabel_values
: list of class namesignored_labels
: list of int classes to ignorergb_bands
: int元组,对应红色、绿色和蓝色波段(int tuple that correspond to red, green and blue bands)
代码和解析:
初始化参数:
def get_dataset(dataset_name, target_folder="./", datasets=DATASETS_CONFIG):
# def get_dataset(dataset_name, target_folder="C:\\Users\\73416\\PycharmProjects\\HSIproject\\Datasets\\", datasets=DATASETS_CONFIG):
""" Gets the dataset specified by name and return the related components.
Args:
dataset_name: string with the name of the dataset
target_folder (optional): folder to store the datasets, defaults to ./
datasets (optional): dataset configuration dictionary, defaults to prebuilt one
Returns:
img: 3D hyperspectral image (WxHxB)
gt: 2D int array of labels # 标签array
label_values: list of class names # 类的名单
ignored_labels: list of int classes to ignore
rgb_bands: int tuple that correspond to red, green and blue bands # int元组,对应红色、绿色和蓝色波段
"""
target_folder = "C:\\Datasets\\" # 自加,修改数据集的路径
# print(target_folder) # 自加
palette = None
# 当输入的数据集的名字没有在数据集字典datasets=DATASETS_CONFIG中,则报错dataset is unknown
if dataset_name not in datasets.keys():
raise ValueError("{} dataset is unknown.".format(dataset_name))
# 字典操作,取得数据集字典datasets中,键(key)为dataset_name的值(urls、img和gt)
dataset = datasets[dataset_name]
folder = target_folder + datasets[dataset_name].get('folder', dataset_name + '/')
# folder为:C:\Datasets\PaviaU/
这部分是初始的一些参数:
target_folder
:数据集文件夹Datasets
的存放路径。比如target_folder = "C:\\Datasets\\"
palette
:调色板,初始化为None
。dataset
:取得数据集字典datasets
中,键(key)为dataset_name
的值(urls
、img
和gt
)folder
:特定数据集文件夹的存放路径。比如C:\Datasets\PaviaU/
下载数据集:
# Download the dataset if is not present
if dataset.get('download', True):
# 如果没有folder(C:\Datasets\PaviaU/)文件夹,则创建该文件夹
if not os.path.isdir(folder):
os.mkdir(folder)
# 下载数据集(暂pass)
for url in datasets[dataset_name]['urls']:
# download the files
filename = url.split('/')[-1]
if not os.path.exists(folder + filename):
with TqdmUpTo(unit='B', unit_scale=True, miniters=1,
desc="Downloading {}".format(filename)) as t:
urlretrieve(url, filename=folder + filename,
reporthook=t.update_to)
elif not os.path.isdir(folder):
print("WARNING: {} is not downloadable.".format(dataset_name))
if dataset.get('download', True):
,则下载指定数据集。
首先检查指定路径folder
下,文件夹是否存在 ,os.path.isdir(folder)
。如果不存在则在指定路径folder
下,创建文件夹。
之后是下载数据集的代码(包括获取urls
、创建进度条等),暂略。
当然还有对于dataset_name
的鲁棒性检查,也是暂略。
读取数据集+预处理:
数据集读取:
# 读取数据集
if dataset_name == 'PaviaC':
# Load the image
# 通过自己写的open_file()函数打开C:\Datasets\PaviaU/Pavia.mat文件,返回值为字典类型,通过['pavia']来提取键值对的中的值
img = open_file(folder + 'Pavia.mat')['pavia']
# 取RGB波段,为什么这么取不知道
rgb_bands = (55, 41, 12)
# 通过自己写的open_file()函数打开C:\Datasets\PaviaU/Pavia_gt.mat文件,返回值为字典类型,通过['pavia_gt']来提取键值对的中的值
gt = open_file(folder + 'Pavia_gt.mat')['pavia_gt']
# ???label_values有什么用,如何和gt链接
label_values = ["Undefined", "Water", "Trees", "Asphalt",
"Self-Blocking Bricks", "Bitumen", "Tiles", "Shadows",
"Meadows", "Bare Soil"]
ignored_labels = [0]
elif dataset_name == 'PaviaU':
# Load the image
img = open_file(folder + 'PaviaU.mat')['paviaU']
rgb_bands = (55, 41, 12)
gt = open_file(folder + 'PaviaU_gt.mat')['paviaU_gt']
label_values = ['Undefined', 'Asphalt', 'Meadows', 'Gravel', 'Trees',
'Painted metal sheets', 'Bare Soil', 'Bitumen',
'Self-Blocking Bricks', 'Shadows']
ignored_labels = [0]
elif dataset_name == 'IndianPines':
# Load the image
img = open_file(folder + 'Indian_pines_corrected.mat')
img = img['indian_pines_corrected']
rgb_bands = (43, 21, 11) # AVIRIS sensor
gt = open_file(folder + 'Indian_pines_gt.mat')['indian_pines_gt']
label_values = ["Undefined", "Alfalfa", "Corn-notill", "Corn-mintill",
"Corn", "Grass-pasture", "Grass-trees",
"Grass-pasture-mowed", "Hay-windrowed", "Oats",
"Soybean-notill", "Soybean-mintill", "Soybean-clean",
"Wheat", "Woods", "Buildings-Grass-Trees-Drives",
"Stone-Steel-Towers"]
ignored_labels = [0]
elif dataset_name == 'Botswana':
# Load the image
img = open_file(folder + 'Botswana.mat')['Botswana']
rgb_bands = (75, 33, 15)
gt = open_file(folder + 'Botswana_gt.mat')['Botswana_gt']
label_values = ["Undefined", "Water", "Hippo grass",
"Floodplain grasses 1", "Floodplain grasses 2",
"Reeds", "Riparian", "Firescar", "Island interior",
"Acacia woodlands", "Acacia shrublands",
"Acacia grasslands", "Short mopane", "Mixed mopane",
"Exposed soils"]
ignored_labels = [0]
elif dataset_name == 'KSC':
# Load the image
img = open_file(folder + 'KSC.mat')['KSC']
rgb_bands = (43, 21, 11) # AVIRIS sensor
gt = open_file(folder + 'KSC_gt.mat')['KSC_gt']
label_values = ["Undefined", "Scrub", "Willow swamp",
"Cabbage palm hammock", "Cabbage palm/oak hammock",
"Slash pine", "Oak/broadleaf hammock",
"Hardwood swamp", "Graminoid marsh", "Spartina marsh",
"Cattail marsh", "Salt marsh", "Mud flats", "Wate"]
ignored_labels = [0]
else:
# 详细见自定义数据集模块
# Custom dataset
img, gt, rgb_bands, ignored_labels, label_values, palette = CUSTOM_DATASETS_CONFIG[dataset_name]['loader'](folder)
这部分是读取下载好的数据集文件,包括3D image
和2D label
。读取不同的数据集文件的操作也有很大的重复性。
每次读取数据集文件的时候,都是干了这几件事情(以dataset_name == 'PaviaC'
为例):
- 读取数据。通过自己写的
open_file()
函数打开C:\Datasets\PaviaU/Pavia.mat
文件,返回值为字典类型,通过['pavia']
来提取键值对的中的值。代码:img = open_file(folder + 'Pavia.mat')['pavia']
- 取RGB波段,但怎么取不知道。代码:
rgb_bands = (55, 41, 12)
- 读取gt。通过自己写的
open_file()
函数打开C:\Datasets\PaviaU/Pavia_gt.mat
文件,返回值为字典类型,通过['pavia_gt']
来提取键值对的中的值。代码:gt = open_file(folder + 'Pavia_gt.mat')['pavia_gt']
- 确定
label_values
。代码:label_values = ["Undefined", "Water", "Trees", "Asphalt", "Self-Blocking Bricks", "Bitumen", "Tiles", "Shadows", "Meadows", "Bare Soil"]
- 确定
ignored_labels
,一般为0。代码:ignored_labels = [0]
**需要注意的是:**当要处理的数据集不是项目预先定义的数据集的时候(即处理的是用户自己的数据集),会在最后的else
来返回值。详见CUSTOM_DATASETS_CONFIG
。
else:
# 详细见自定义数据集模块
# Custom dataset
img, gt, rgb_bands, ignored_labels, label_values, palette = CUSTOM_DATASETS_CONFIG[dataset_name]['loader'](folder)
处理NaN的情况:
# 处理NaN的情况
# Filter NaN out
nan_mask = np.isnan(img.sum(axis=-1))
if np.count_nonzero(nan_mask) > 0:
print("Warning: NaN have been found in the data. It is preferable to remove them beforehand. Learning on NaN data is disabled.")
img[nan_mask] = 0
gt[nan_mask] = 0
ignored_labels.append(0)
ignored_labels = list(set(ignored_labels))
这个情况不常见,暂略。
Normalization 归一化:
# Normalization 归一化
img = np.asarray(img, dtype='float32')
img = (img - np.min(img)) / (np.max(img) - np.min(img))
首先把img
的每个元素的类型变为float32
(img = np.asarray(img, dtype='float32')
),然后归一化操作(img = (img - np.min(img)) / (np.max(img) - np.min(img))
)。
返回值:
return img, gt, label_values, ignored_labels, rgb_bands, palette
img
: 3D hyperspectral image (WxHxB),B为波段。gt
: 2D int array of labelslabel_values
: list of class namesignored_labels
: list of int classes to ignorergb_bands
: int元组,对应红色、绿色和蓝色波段(int tuple that correspond to red, green and blue bands)palette
:默认返回为None。
class HyperX(torch.utils.data.Dataset)
这是高光谱场景的通用的类。
class HyperX(torch.utils.data.Dataset):
类名为HyperX
,继承的父类是torch.utils.data.Dataset
。
__ init__(self, data, gt, **hyperparams):
功能:
对类的属性进行初始化。
输入和输出:
输入:
data
: 3D hyperspectral image 图形gt
: 2D array of labels 标签**hyperparams
:hyperparams
是包含超参数的字典dictionary。**
表示这个位置接收任意多个关键字参数(比如a=1,b=2,c=3,d=4,e=5
等类似于键值对)。**
将多输入的变量,存储为字典dictionary类型 。
输出:
无。
代码和解析:
读取img、gt和超参数
class HyperX(torch.utils.data.Dataset):
""" Generic class for a hyperspectral scene """
def __init__(self, data, gt, **hyperparams): #??? **hyperparams表示接受不定数量的超参数?
"""
Args:
data: 3D hyperspectral image 图形
gt: 2D array of labels 标签
patch_size: int, size of the spatial neighbourhood (int,空间邻域的大小)
center_pixel: bool, set to True to consider only the label of the
center pixel (bool类型,设置为True仅考虑中心像素的标签)
data_augmentation: bool, set to True to perform random flips (数据增强)
supervision: 'full' or 'semi' supervised algorithms (监督方式:监督或半监督)
"""
super(HyperX, self).__init__()
# 读取img
self.data = data
# 读取gt
self.label = gt
# 读取超参数
self.name = hyperparams['dataset']
self.patch_size = hyperparams['patch_size']
self.ignored_labels = set(hyperparams['ignored_labels'])
self.flip_augmentation = hyperparams['flip_augmentation']
self.radiation_augmentation = hyperparams['radiation_augmentation']
self.mixture_augmentation = hyperparams['mixture_augmentation']
self.center_pixel = hyperparams['center_pixel']
supervision = hyperparams['supervision']
读取ignored_labels
的这行代码(self.ignored_labels = set(hyperparams['ignored_labels'])
)有一个set()
函数,是把一个可迭代对象的元素变为集合类型。
一个set()
函数的小demo:
x = set('runoob')
print(x)
# {'r', 'b', 'o', 'u', 'n'}
print(type(x))
# <class 'set'>
其他没什么好说的。
监督方式:
# 监督方式
# Fully supervised : use all pixels with label not ignored. 全监督:使用标签未被忽略的所有像素
if supervision == 'full':
mask = np.ones_like(gt)
for l in self.ignored_labels:
mask[gt == l] = 0
# Semi-supervised : use all pixels, except padding. 半监督:使用除填充之外的所有像素
elif supervision == 'semi':
mask = np.ones_like(gt)
全监督是使用标签未被忽略的所有像素,mask
是将gt
中类别为ignored_labels
的对应位置置零,其他位置置一。
半监督是使用除padding之外的所有像素,mask
是全部置一。
获取索引:
x_pos, y_pos = np.nonzero(mask)
p = self.patch_size // 2
self.indices = np.array([(x,y) for x,y in zip(x_pos, y_pos) if x > p and x < data.shape[0] - p and y > p and y < data.shape[1] - p])
self.labels = [self.label[x,y] for x,y in self.indices]
np.random.shuffle(self.indices)
由于mask
中ignored_labels
为0,所以通过np.nonzero()
获取mask
中的非零元素的索引,返回值为两个array组成的元组,一个array是 x 轴的索引,另一个array是 y 轴的索引。
取p = self.patch_size // 2
(" / “就表示浮点数除法,返回浮点结果;” // "表示整数除法),但为什么这样取并不知道。但我猜测是将图片分块成不同的block。而且我觉得是把整个图片分成了4个block,类似“田”的形状。
下一步是获取指定范围的元素的索引,范围是 x ∈ (p, data.shape[0] - p)
, y ∈ (p, data.shape[1] - p)
。方法是通过zip()
函数来使索引变为(x, y)
的形式,然后遍历这个索引筛选出在x ∈ (p, data.shape[0] - p)
, y ∈ (p, data.shape[1] - p)
中的索引,把值赋给indices
。
然后由self.labels = [self.label[x,y] for x,y in self.indices]
获取相应索引indices
下的标签。
np.random.shuffle(self.indices)
将self.indices
打乱。(但是标签并没有跟着被打乱呀?)
flip(*arrays)
功能:
对输入的多个数组arrays
进行水平或垂直翻转。(我怀疑是数据增强的一种方法)
输入和输出:
输入:
arrays
:多个数组。类型不明。
*
表示这个位置接收任意多个非关键字参数(比如1,2,3,4,5
等值) ,*
将多输入的变量,存储为元组tuple类型。
输出:
arrays
:多个数组。类型不明。
代码和解析:
def flip(*arrays):
horizontal = np.random.random() > 0.5
vertical = np.random.random() > 0.5
if horizontal:
arrays = [np.fliplr(arr) for arr in arrays]
if vertical:
arrays = [np.flipud(arr) for arr in arrays]
return arrays
水平(左右)翻转和垂直(上下)翻转,都是随机,通过生成随机数来实现。代码中默认的概率为0.5。对于p = 0.5如何实现,就是先生成一个随机数,若比0.5大则为True
,反之则为False
。
当horizontal
为True
时,通过遍历arrays
使得每一个arr
都水平(左右)翻转。
当vertical
为True
时,通过遍历arrays
使得每一个arr
都垂直(上下)翻转。
radiation_noise()
功能:
给数据data加上噪音。
输入和输出:
输入:
data
:带处理的数据。alpha_range
:data
的保留范围,默认为(0.9, 1.1)beta
:噪音noise
的保留比例,默认为1/25
输出:
alpha * data + beta * noise
:data
和noise
的组合
代码和解析:
@staticmethod
def radiation_noise(data, alpha_range=(0.9, 1.1), beta=1/25):
alpha = np.random.uniform(*alpha_range)
noise = np.random.normal(loc=0., scale=1.0, size=data.shape)
return alpha * data + beta * noise
首先通过alpha_range
确定alpha
的值,用np.random.uniform(a, b)
函数随机生成(a, b)之间的随机数,*alpha_range
表示接收的是非关键字类型的变量,并将变量拆解。alpha = np.random.uniform(*alpha_range)
。
随机数noise
通过正态分布产生,noise = np.random.normal(loc=0., scale=1.0, size=data.shape)
表示产生均值为0,标准差为1.0,与data
同size的服从正态分布的随机数。
最后返回data
和noise
的组合:alpha * data + beta * noise
。
mixture_noise()
看不懂,暂略。
__ len__()
返回对象的indices
属性的长度。
def __len__(self):
return len(self.indices)
__ getitem__()
功能:
得到以指定位置i
为中心,size为patch_size × patch_size
的图像块block。
同时附加数据增强效果。
这里的item指的是”图像块“。
输入和输出:
输入:
i
:所取位置的索引
输出:
data
:(Batch x) Planes x Channels x Width x Heightlabel
:标签
代码和解析:
获取图像块:
def __getitem__(self, i):
x, y = self.indices[i]
x1, y1 = x - self.patch_size // 2, y - self.patch_size // 2
x2, y2 = x1 + self.patch_size, y1 + self.patch_size
data = self.data[x1:x2, y1:y2]
label = self.label[x1:x2, y1:y2]
这部分是得到以指定位置i
为中心,size为patch_size × patch_size
的图像块block。
大概是这样的原理:
x2, y2 | ||
---|---|---|
x, y | ||
x1, y1 |
由x, y分别减去self.patch_size // 2
得到x1, y1,然后由x1, y1加上self.patch_size
得到x2, y2。
然后通过[x1:x2, y1:y2]
来获取对应的data
和label
的图像块。
数据增强:
if self.flip_augmentation and self.patch_size > 1:
# Perform data augmentation (only on 2D patches)
data, label = self.flip(data, label)
if self.radiation_augmentation and np.random.random() < 0.1:
data = self.radiation_noise(data)
if self.mixture_augmentation and np.random.random() < 0.2:
data = self.mixture_noise(data, label)
这里数据增强并不是默认执行,而是需要self.flip_augmentation == True
或者self.radiation_augmentation== True
或者self.mixture_augmentation== True
。
对于self.flip_augmentation == True
的情况,需要self.patch_size > 1
,而且仅仅对2D的data
执行。
对于self.radiation_augmentation== True
的情况,需要np.random.random() < 0.1
,即只有10%的概率执行操作。
对于self.mixture_augmentation== True
的情况,需要np.random.random() < 0.2
,即只有20%的概率执行操作。
data
和label
转为ndarray类型:
# Copy the data into numpy arrays (PyTorch doesn't like numpy views)
data = np.asarray(np.copy(data).transpose((2, 0, 1)), dtype='float32')
label = np.asarray(np.copy(label), dtype='int64')
这部分将data
和label
转为ndarray类型。来源类型暂不明确。
PyTorch doesn’t like numpy views。numpy的view是channel × row × column,即C × W × H
。所以要通过转置将第三个维度提到第一个维度的位置。
numpy的view是C × W × H
,而PyTorch是W × H × C
。
data = np.asarray(np.copy(data).transpose((2, 0, 1)), dtype='float32')
这一句先拷贝data
,作为ndarray类型,再调用转置函数transpose()
将维度顺序从W × H × C
变为C × W × H
。最后由dtype='float32'
将每个元素的类型设置为float32
。
对于label
,就直接拷贝变ndarray,再设置元素类型为int64
,就Ok了。
ndarray转tensor:
# Load the data into PyTorch tensors
data = torch.from_numpy(data)
label = torch.from_numpy(label)
简单的应用torch.from_numpy()
,没什么好讲的。
Extract the center label if needed:
# Extract the center label if needed
if self.center_pixel and self.patch_size > 1:
label = label[self.patch_size // 2, self.patch_size // 2]
# Remove unused dimensions when we work with invidual spectrums
elif self.patch_size == 1:
data = data[:, 0, 0]
label = label[0, 0]
对于self.center_pixel == True
且self.patch_size > 1
,对于size为self.patch_size × self.patch_size
的label
只取[self.patch_size // 2, self.patch_size // 2]
的位置。
好叭这部分暂略,看不懂。
# Add a fourth dimension for 3D CNN
if self.patch_size > 1:
# Make 4D data ((Batch x) Planes x Channels x Width x Height)
data = data.unsqueeze(0)
这部分是给data
在第一个维度的位置增加一个维度,作为Batch
,即最后的维度顺序是Batch x Channels x Width x Height
。
增加维度的是unsqueeze()
函数,axis = 0
表示在第一个维度的位置增加维度。
返回值:
return data, label
返回值为data, label。类型为tensor。
| | x2, y2 |
| :----: | :–: | :----: |
| | x, y | |
| x1, y1 | | |
由x, y分别减去self.patch_size // 2
得到x1, y1,然后由x1, y1加上self.patch_size
得到x2, y2。
然后通过[x1:x2, y1:y2]
来获取对应的data
和label
的图像块。
数据增强:
if self.flip_augmentation and self.patch_size > 1:
# Perform data augmentation (only on 2D patches)
data, label = self.flip(data, label)
if self.radiation_augmentation and np.random.random() < 0.1:
data = self.radiation_noise(data)
if self.mixture_augmentation and np.random.random() < 0.2:
data = self.mixture_noise(data, label)
这里数据增强并不是默认执行,而是需要self.flip_augmentation == True
或者self.radiation_augmentation== True
或者self.mixture_augmentation== True
。
对于self.flip_augmentation == True
的情况,需要self.patch_size > 1
,而且仅仅对2D的data
执行。
对于self.radiation_augmentation== True
的情况,需要np.random.random() < 0.1
,即只有10%的概率执行操作。
对于self.mixture_augmentation== True
的情况,需要np.random.random() < 0.2
,即只有20%的概率执行操作。
data
和label
转为ndarray类型:
# Copy the data into numpy arrays (PyTorch doesn't like numpy views)
data = np.asarray(np.copy(data).transpose((2, 0, 1)), dtype='float32')
label = np.asarray(np.copy(label), dtype='int64')
这部分将data
和label
转为ndarray类型。来源类型暂不明确。
PyTorch doesn’t like numpy views。numpy的view是channel × row × column,即C × W × H
。所以要通过转置将第三个维度提到第一个维度的位置。
numpy的view是C × W × H
,而PyTorch是W × H × C
。
data = np.asarray(np.copy(data).transpose((2, 0, 1)), dtype='float32')
这一句先拷贝data
,作为ndarray类型,再调用转置函数transpose()
将维度顺序从W × H × C
变为C × W × H
。最后由dtype='float32'
将每个元素的类型设置为float32
。
对于label
,就直接拷贝变ndarray,再设置元素类型为int64
,就Ok了。
ndarray转tensor:
# Load the data into PyTorch tensors
data = torch.from_numpy(data)
label = torch.from_numpy(label)
简单的应用torch.from_numpy()
,没什么好讲的。
Extract the center label if needed:
# Extract the center label if needed
if self.center_pixel and self.patch_size > 1:
label = label[self.patch_size // 2, self.patch_size // 2]
# Remove unused dimensions when we work with invidual spectrums
elif self.patch_size == 1:
data = data[:, 0, 0]
label = label[0, 0]
对于self.center_pixel == True
且self.patch_size > 1
,对于size为self.patch_size × self.patch_size
的label
只取[self.patch_size // 2, self.patch_size // 2]
的位置。
好叭这部分暂略,看不懂。
# Add a fourth dimension for 3D CNN
if self.patch_size > 1:
# Make 4D data ((Batch x) Planes x Channels x Width x Height)
data = data.unsqueeze(0)
这部分是给data
在第一个维度的位置增加一个维度,作为Batch
,即最后的维度顺序是Batch x Channels x Width x Height
。
增加维度的是unsqueeze()
函数,axis = 0
表示在第一个维度的位置增加维度。
返回值:
return data, label
返回值为data, label。类型为tensor。
标签:gt,Classification,Hyperspectral,self,label,GitHub,np,data,size 来源: https://blog.csdn.net/qq_41683065/article/details/100748883