
DenseNet



Code

from typing import Callable
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor,Compose,Resize
import torch.optim as optim
from torch.nn import functional as F
def conv_block(in_channels,out_channels):
    blk=nn.Sequential(
        nn.BatchNorm2d(in_channels),
        nn.ReLU(),
        nn.Conv2d(in_channels,out_channels,kernel_size=3,padding=1)
    )
    return blk
class DenseBlock(nn.Module):
    def __init__(self,num_conv,in_channels,out_channels):
        super(DenseBlock,self).__init__()
        net=[]
        for i in range(num_conv):
            in_c=in_channels+i*out_channels
            net.append(conv_block(in_c,out_channels))
        self.net=nn.ModuleList(net)
        self.out_channels=in_channels+num_conv*out_channels # output channel count after all concatenations
    def forward(self,x):
        for blk in self.net:
            y=blk(x)
            x=torch.cat((x,y),dim=1) # concatenate with the input along the channel dimension
        return x
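# Quick sanity check (a minimal sketch added here, not in the original post):
# with 2 conv blocks, 3 input channels and growth rate 10, the output should
# have 3 + 2*10 = 23 channels while the spatial size stays unchanged.
blk=DenseBlock(2,3,10)
X=torch.rand(4,3,8,8)
Y=blk(X)
print(Y.shape)   # expected: torch.Size([4, 23, 8, 8])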
# Transition layer: the dense blocks keep increasing the channel count, so a 1x1 conv is used to reduce it (and average pooling halves the spatial size)
def transition_block(in_channels,out_channels):
    blk=nn.Sequential(
        nn.BatchNorm2d(in_channels),
        nn.ReLU(),
        nn.Conv2d(in_channels,out_channels,kernel_size=1),
        nn.AvgPool2d(kernel_size=2,stride=2)
    )
    return blk
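# Another small sketch (not in the original post): the transition layer's 1x1
# convolution reduces 23 channels to 10, and the average pooling halves the
# spatial resolution from 8x8 to 4x4.
blk=transition_block(23,10)
print(blk(torch.rand(4,23,8,8)).shape)   # expected: torch.Size([4, 10, 4, 4])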
net=nn.Sequential(
    nn.Conv2d(1,64,kernel_size=7,stride=2,padding=3),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
)

num_channels,growth_rate=64,32         # num_channels: current channel count; growth_rate: channels each conv block adds
num_convs_in_dense_block=[4,4,4,4]     # number of conv blocks in each dense block

for i,num_convs in enumerate(num_convs_in_dense_block):
    DB=DenseBlock(num_convs,num_channels,growth_rate)
    net.add_module(f'DenseBlock_{i}',DB)
    # the previous dense block's output channel count
    num_channels=DB.out_channels
    # insert a transition layer between dense blocks to halve the channel count
    if i !=len(num_convs_in_dense_block)-1:
        net.add_module(f'transition_block_{i}',transition_block(num_channels,num_channels//2))
        num_channels=num_channels//2
        
net.add_module('BN',nn.BatchNorm2d(num_channels))
net.add_module('relu',nn.ReLU())
net.add_module('global_avg_pool',nn.AdaptiveAvgPool2d((1,1)))
net.add_module('fc',nn.Sequential(nn.Flatten(),
                    nn.Linear(num_channels,10)
))
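# Optional check (an added sketch, not from the original post): count the
# trainable parameters of the assembled network.
print(sum(p.numel() for p in net.parameters() if p.requires_grad))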

Check each stage's output shape

X=torch.rand(1,1,96,96)
for name,layer in net.named_children():
    X=layer(X)
    print(name,'output size',X.shape)

Output:

0 output size torch.Size([1, 64, 48, 48])
1 output size torch.Size([1, 64, 48, 48])
2 output size torch.Size([1, 64, 48, 48])
3 output size torch.Size([1, 64, 24, 24])
DenseBlock_0 output size torch.Size([1, 192, 24, 24])
transition_block_0 output size torch.Size([1, 96, 12, 12])
DenseBlock_1 output size torch.Size([1, 224, 12, 12])
transition_block_1 output size torch.Size([1, 112, 6, 6])
DenseBlock_2 output size torch.Size([1, 240, 6, 6])
transition_block_2 output size torch.Size([1, 120, 3, 3])
DenseBlock_3 output size torch.Size([1, 248, 3, 3])
BN output size torch.Size([1, 248, 3, 3])
relu output size torch.Size([1, 248, 3, 3])
global_avg_pool output size torch.Size([1, 248, 1, 1])
fc output size torch.Size([1, 10])
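
The imports at the top (FashionMNIST, DataLoader, ToTensor/Compose/Resize, optim) are never used above. A minimal training sketch on Fashion-MNIST, resized to 96x96 to match the shape check, might look like the following; the batch size, learning rate, Adam optimizer, and single epoch are assumptions for illustration, not settings from the original post.

transform=Compose([Resize(96),ToTensor()])
train_set=FashionMNIST(root='./data',train=True,download=True,transform=transform)
train_loader=DataLoader(train_set,batch_size=128,shuffle=True)

device='cuda' if torch.cuda.is_available() else 'cpu'
net=net.to(device)
optimizer=optim.Adam(net.parameters(),lr=0.001)
loss_fn=nn.CrossEntropyLoss()

net.train()
for epoch in range(1):                 # one epoch, just to show the loop
    for X,y in train_loader:
        X,y=X.to(device),y.to(device)
        optimizer.zero_grad()
        loss=loss_fn(net(X),y)
        loss.backward()
        optimizer.step()
    print(f'epoch {epoch}, last batch loss {loss.item():.4f}')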

Source: https://www.cnblogs.com/libraxionghao/p/16230772.html