其他分享
首页 > 其他分享> > HybridSN代码修改

HybridSN代码修改

作者:互联网

研究小组里初学习深度学习的同学我都布置写过 HybridSN 的代码:https://github.com/OUCTheoryGroup/colab_demo/blob/master/202003_models/HybridSN_GRSL2020.ipynb

最近做SRDP的同学反映,跑 Pavia 数据集的时候内存会爆,主要原因是 createImageCubes 这个函数有个地方:

patchesData = np.zeros([ width*height, windowSize, windowsSize, spectral_num])

因为 Pavia 数据集尺寸较大,width*height 就比较大了,内存会爆掉。

其实图像中大部分为是0,没有label,我们要取的,只是有label的部分。现在我改了改,先做个循环,看看有多少个像素有 label,然后记录在 count 里,分配内存时:

patchesData = np.zeros([count, windowSize, windowSize, spectral_num])

这样 count 比以前的 width*height 要小很多,内存就不会爆了

修改后的代码如下,供感兴趣的同学参考(Github上的我就不改了,留给以后新同学排雷):

# 在每个像素周围提取 patch 
def createImageCubes(X, y, windowSize=5, removeZeroLabels = True):
    # 给 X 做 padding
    margin = int((windowSize - 1) / 2)
    zeroPaddedX = padWithZeros(X, margin=margin)
    # 获得 y 中的标记样本数
    count = 0
    for r in range(0, y.shape[0]):
        for c in range(0, y.shape[1]):
            if y[r, c] != 0:
                count = count+1

    # split patches
    patchesData = np.zeros([count, windowSize, windowSize, X.shape[2]])
    patchesLabels = np.zeros(count)

    count = 0
    for r in range(margin, zeroPaddedX.shape[0] - margin):
        for c in range(margin, zeroPaddedX.shape[1] - margin):
            if y[r-margin, c-margin] != 0:
                patch = zeroPaddedX[r - margin:r + margin + 1, c - margin:c + margin + 1]   
                patchesData[count, :, :, :] = patch
                patchesLabels[count] = y[r-margin, c-margin]
                count = count + 1

    return patchesData, patchesLabels

标签:count,HybridSN,patchesData,代码,windowSize,修改,shape,range,margin
来源: https://www.cnblogs.com/gaopursuit/p/15852357.html