其他分享
首页 > 其他分享> > PyTorch 剪枝

PyTorch 剪枝

作者:互联网

pytorch 实现剪枝的思路是 生成一个掩码,然后同时保存 原参数、mask、新参数,如下图

 

pytorch 剪枝分为 局部剪枝、全局剪枝、自定义剪枝;

局部剪枝 是对 模型内 的部分模块 的 部分参数 进行剪枝,全局剪枝是对  整个模型进行剪枝;

 

本文旨在记录 pytorch 剪枝模块的用法,首先让我们构建一个模型

import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 3x3 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNet().to(device=device)

下面对 这个模型进行剪枝

 

局部剪枝

以修剪 第一层卷积  模块 为例

module = model.conv1
print(list(module.named_parameters()))
print(list(module.buffers()))

# 修剪是从 模块 中 删除 参数(如 weight),并用 weight_orig 保存该参数
# random_unstructured 是一种裁剪技术,随机非结构化裁剪
prune.random_unstructured(module, name="weight", amount=0.3)      # weight    bias
print(list(module.named_parameters()))

# 通过修剪技术会创建一个mask命名为 weight_mask 的模块缓冲区
print(list(module.named_buffers()))

# 新的参数保存为模块 的weight属性
print(module.weight)
# print(module.bias)

print(module._forward_pre_hooks)
# OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x000002695EBCEC18>)])

named_parameters() 内 存储的对象 除非手动删除,否则在剪枝过程中对其无影响

 

迭代剪枝

迭代剪枝 是 对 同一模块 进行 多种剪枝,执行逻辑是 顺序执行各剪枝操作

在之前  随机非结构化剪枝 的基础上进行  L1 L2 非结构化剪枝

## 增加一个修剪,看看变化
# l1范数修剪bias中3个最小条目
prune.l1_unstructured(module, name="bias", amount=3)
print(module.bias)
print(module._forward_pre_hooks)
# OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x000002695EBCEC18>),
#              (1, <torch.nn.utils.prune.L1Unstructured object at 0x000002695DE5CEB8>)])

print(list(module.named_parameters()))
print(list(module.named_buffers()))


### 迭代修剪
# 一个模块中的同一参数可以被多次修剪,多次修剪会顺序执行
# 如在之前的基础上,对 weight 参数继续修剪
# l2 结构化裁剪,n=2代表l2,dim=0代表在weight的第0轴进行结构化裁剪
prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)

# 查看 weight 参数的 剪枝 操作
for hook in module._forward_pre_hooks.values():
    if hook._tensor_name == "weight":  # select out the correct hook
        break

print(list(hook))
# [<torch.nn.utils.prune.RandomUnstructured object at 0x0000020AE2A6EC18>,
# <torch.nn.utils.prune.LnStructured object at 0x0000020AA872DE80>]

print(module.state_dict().keys())
# odict_keys(['weight_orig', 'bias_orig', 'weight_mask', 'bias_mask'])

 

修剪模型中的多个参数

### 修剪模型中的多个参数
new_model = LeNet()
for name, module in new_model.named_modules():
    # prune 20% of connections in all 2D-conv layers
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
    # prune 40% of connections in all linear layers
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)

print(dict(new_model.named_buffers()).keys())  # to verify that all masks exist

 

全局剪枝

以上研究通常被称为“局部”修剪方法,即通过比较每个条目的统计信息(权重,激活度,梯度等)来逐一修剪模型中的张量的做法。

但是,一种常见且可能更强大的技术是通过删除整个模型中最低的 20%的连接,

而不是删除每一层中最低的 20%的连接来修剪模型。

这很可能导致每个层的修剪百分比不同。

让我们看看如何使用torch.nn.utils.prune中的global_unstructured进行操作

model = LeNet()

parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)
# 检查每个修剪参数的稀疏性,该稀疏性不等于每层中的 20%。 但是,全局稀疏度将(大约)为 20%

 

自定义剪枝

见  参考资料3

 

训练中剪枝实例

见参考资料1

 

 

 

 

参考资料:

https://blog.csdn.net/qq_40268672/article/details/108631518  pytorch剪枝实战     训练时剪枝,类似 dropout 

https://blog.csdn.net/ssunshining/article/details/125121066  PyTorch--模型剪枝案例

https://www.w3cschool.cn/pytorch/pytorch-rnmi3bti.html  PyTorch 修剪教程

标签:剪枝,修剪,prune,weight,module,PyTorch,print
来源: https://www.cnblogs.com/yanshw/p/16592678.html