其他分享
首页 > 其他分享> > 1-cascade MRI reconstruction: dataset.py

1-cascade MRI reconstruction: dataset.py

作者:互联网

首先配置文件不能少config.yaml

# Model Parameters
network:
        num_cascades: 6
        num_layers: 5 # Number of layers in the CNN per cascade
        num_filters: 64
        kernel_size: 3
        stride: 1
        padding: 1 #A padding of 1 is needed to keep the image in the same size
        noise: null #Noise in the measurements. To be used in the data consistency step

#Dataset parameters
dataset:
        data_path: 'data1/'
        acceleration_factor: 4.0
        fraction: 0.8 #train set size
        shuffle: 3 #Seed for numpy random generator
        sample_n: 10
        acq_noise: 0 #acquisation noise
        centred: False
        norm: 'ortho'  #norm: 'ortho' or null. if 'ortho', performs unitary transform, otherwise normal dft

# Training parameters
train:
        batch_size: 1
        num_epochs: 5
        early_stop: 100

        # Adam Optimizer Parameters
        learning_rate: 0.001
        b_1: 0.9
        b_2: 0.999
        l2: 0.0000001

        # Miscellaneous
        output_path: 'logs'
        cuda: False

单步执行

import os
import torch
import numpy as np
from math import ceil
from helpers_1 import *
from scipy.io import loadmat
from numpy.lib.stride_tricks import as_strided
import yaml
args = yaml.load(open('config.yaml', 'r'), Loader=yaml.FullLoader)

在这里插入图片描述
dataset = OCMRDataset(fold=‘train’, **args[‘dataset’])

        self.evalset = evalset
        self.data_path = data_path
        self.acc = acceleration_factor
        self.sample_n = sample_n
        self.noise = acq_noise
        self.centred = centred
        self.norm = norm
        self.files = os.listdir(self.data_path)
        if shuffle:
            np.random.seed(shuffle)
            np.random.shuffle(self.files) 
        if fold == 'train':
            self.files = self.files[:int(len(self.files) * fraction)]

在这里插入图片描述

    def __getitem__(self, idx):
        if self.evalset and idx == 0:
            np.random.seed(9001)
        data = loadmat(os.path.join(self.data_path, self.files[idx]))['xn'] * 1e3

在这里插入图片描述

        data = np.expand_dims(data, 0)

因为这里batch_size设的1,所以就有一个256*256
在这里插入图片描述

        mask = self.cartesian_mask(data.shape)
    def cartesian_mask(self, shape):
        N, Nx, Ny = int(np.prod(shape[:-2])), shape[-2], shape[-1]
        pdf_x = normal_pdf(Nx, 0.5/(Nx/10.)**2)

在这里插入图片描述

def normal_pdf(length, sensitivity):
    return np.exp(-sensitivity * (np.arange(length) - length / 2)**2)

在这里插入图片描述
在这里插入图片描述

        lmda = Nx/(2.*self.acc)
        n_lines = int(Nx / self.acc)

        # add uniform distribution
        pdf_x += lmda * 1./Nx

        if self.sample_n:
            pdf_x[Nx//2-self.sample_n//2:Nx//2+self.sample_n//2] = 0
            pdf_x /= np.sum(pdf_x)
            n_lines -= self.sample_n

在这里插入图片描述
在这里插入图片描述

        mask = np.zeros((N, Nx))
        for i in range(N):
            idx = np.random.choice(Nx, n_lines, False, pdf_x)
            mask[i, idx] = 1

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

        if self.sample_n:
            mask[:, Nx//2-self.sample_n//2:Nx//2+self.sample_n//2] = 1

        size = mask.itemsize
        mask = as_strided(mask, (N, Nx, Ny), (size * Nx, size, 0))
        mask = mask.reshape(shape)

在这里插入图片描述
mask上下很多行都是0,中间位置1越来越密集
在这里插入图片描述

        if not self.centred:
            mask = ifftshift(mask, axes=(-1, -2))

        return mask

反过来了1变0,0变1
在这里插入图片描述
在这里插入图片描述

        data_und, k_und = self.undersample(data, mask)

在这里插入图片描述

        assert x.shape == mask.shape
        # zero mean complex Gaussian noise
        noise_power = self.noise
        nz = np.sqrt(.5)*(np.random.normal(0, 1, x.shape) + 1j * np.random.normal(0, 1, x.shape))
        nz = nz * np.sqrt(noise_power)

在这里插入图片描述

        if self.norm == 'ortho':
            # multiplicative factor
            nz = nz * np.sqrt(np.prod(mask.shape[-2:]))

在这里插入图片描述

        if self.centred:
            x_f = fft2c(x, norm=self.norm)
            x_fu = mask * (x_f + nz)
            x_u = ifft2c(x_fu, norm=self.norm)
            return x_u, x_fu
        else:
            x_f = fft2(x, norm=self.norm)
            x_fu = mask * (x_f + nz)
            x_u = ifft2(x_fu, norm=self.norm)
            return x_u, x_fu

在这里插入图片描述
在这里插入图片描述
data_und, k_und = x_u, x_fu
在这里插入图片描述
在这里插入图片描述

        data_gnd = format_data(data)
def format_data(data, mask=False):
    if mask: 
        data = data * (1+1j)
    data = complex2real(data)
def complex2real(x):
	x_real = np.real(x)

在这里插入图片描述

x_imag = np.imag(x)

在这里插入图片描述

y = np.array([x_real, x_imag]).astype(np.float)

在这里插入图片描述

    if x.ndim >= 3:
        y = y.swapaxes(0, 1)
    return y

在这里插入图片描述

def format_data(data, mask=False):

    data = complex2real(data)
    return data.squeeze(0)

在这里插入图片描述
data_gnd = data.squeeze(0)
在这里插入图片描述

        data_und = format_data(data_und)
def format_data(data, mask=False):

    data = complex2real(data)
    return data.squeeze(0)
def complex2real(x):
    x_real = np.real(x)
    x_imag = np.imag(x)
    y = np.array([x_real, x_imag]).astype(np.float)
    # re-order in convenient order
    if x.ndim >= 3:
        y = y.swapaxes(0, 1)
    return y

data = complex2real(data)即等于y
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
data_und = data.squeeze(0)
在这里插入图片描述

        k_und = format_data(k_und)
def format_data(data, mask=False):

    data = complex2real(data)
    return data.squeeze(0)
def complex2real(x):
    x_real = np.real(x)
    x_imag = np.imag(x)
    y = np.array([x_real, x_imag]).astype(np.float)
    # re-order in convenient order
    if x.ndim >= 3:
        y = y.swapaxes(0, 1)
    return y

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
k_und = data.squeeze(0)
在这里插入图片描述

        mask = format_data(mask, mask=True)
def format_data(data, mask=False):

    data = complex2real(data)
    return data.squeeze(0)
def complex2real(x):
    x_real = np.real(x)
    x_imag = np.imag(x)
    y = np.array([x_real, x_imag]).astype(np.float)
    # re-order in convenient order
    if x.ndim >= 3:
        y = y.swapaxes(0, 1)
    return y

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
data = complex2real(data)即等于y
在这里插入图片描述
mask = data.squeeze(0)
在这里插入图片描述

        return {
            'image': data_und,
            'k': k_und.transpose(1,2,0),
            'mask': mask.transpose(1,2,0),
            'full': data_gnd
        }

sample = dataset[0]即4个tensor
第一个tensor。‘image’: data_und
第二个tensor。‘k’: k_und.transpose(1,2,0)
在这里插入图片描述
第三个tensor。‘mask’: mask.transpose(1,2,0)
在这里插入图片描述
第四个tensor。‘full’: data_gnd

输出:
Sample image shape: (2, 256, 256)
Sample full shape: (2, 256, 256)
在这里插入图片描述

标签:self,py,mask,Nx,cascade,und,MRI,np,data
来源: https://blog.csdn.net/xuru_0927/article/details/118654663