此功能处于测试阶段。
剪枝算法通常都用权重掩码来模拟实际的剪枝。 掩码可以用来检查某个剪枝(或稀疏)算法的模型性能,但还没有真正加速。 模型加速才是模型剪枝的最终目标。因此提供了此工具,来帮助基于用户提供的掩码(掩码来自于剪枝算法),将已有模型转换成小模型。
有两种剪枝算法。 一种是细粒度的剪枝,不改变权重形状,和输入输出的张量。 稀疏内核会被用来加速细粒度剪枝的层。 另一类是粗粒度的剪枝(例如,通道),通常,权重形状,输入输出张量会有所改变。 要加速这类剪枝算法,不需要使用系数内核,只需要用更小的层来替换。 由于开源社区中对稀疏内核的支持还比较有限,当前仅支持粗粒度剪枝,会在将来再支持细粒度的剪枝算法。
为了加速模型,被剪枝的层应该被替换掉,要么为粗粒度掩码使用较小的层,要么用稀疏内核来替换细粒度的掩码。 粗粒度掩码通常会改变权重的形状,或输入输出张量,因此,应该通过形状推断,来检查是否其它未被剪枝的层由于形状变化而需要改变形状。 因此,在设计中,主要有两个步骤:第一,做形状推理,找出所有应该替换的模块;第二,替换模块。 第一步需要模型的拓扑(即连接),我们使用了 jit.trace
来获取 PyTorch 的模型图。
对于每个模块,要准备四个函数,三个用于形状推理,一个用于模块替换。 三个形状推理函数是:给定权重形状推断输入/输出形状,给定输入形状推断权重/输出形状,给定输出形状推断权重/输入形状。 模块替换功能返回一个较小的新创建的模块。
from nni.compression.torch import ModelSpeedup
# model: 要加速的模型
# dummy_input: 模型的示例输入,传给 `jit.trace`
# masks_file: 剪枝算法创建的掩码文件
m_speedup = ModelSpeedup(model, dummy_input.to(device), masks_file)
m_speedup.speedup_model()
dummy_input = dummy_input.to(device)
start = time.time()
out = model(dummy_input)
print('elapsed time: ', time.time() - start)
完整示例参考这里
注意:当前支持 PyTorch 1.3.1 或更高版本。
由于每个模块需要 4 个函数用于形状推理和模块替换,因此工作量较大,当前仅实现了示例所需的函数。 如果要加速自己的模型,但当前不支持,欢迎贡献。
对于 PyTorch,仅提供了替换模块,如果是在 forward
中的函数,当前不支持。 一种解决方案是将函数变为 PyTorch 模块。
实验代码可在这里找到。
在一块 V100 GPU 上, 输入张量:torch.randn(64, 3, 32, 32)
次数 | 掩码时延 | 加速后的时延 |
---|---|---|
1 | 0.01197 | 0.005107 |
2 | 0.02019 | 0.008769 |
4 | 0.02733 | 0.014809 |
8 | 0.04310 | 0.027441 |
16 | 0.07731 | 0.05008 |
32 | 0.14464 | 0.10027 |
在 CPU 上, 输入张量:torch.randn(64, 1, 28, 28)
, 方差较大
次数 | 掩码时延 | 加速后的时延 |
---|---|---|
1 | 0.01383 | 0.01839 |
2 | 0.01167 | 0.003558 |
4 | 0.01636 | 0.01088 |
40 | 0.14412 | 0.08268 |
40 | 1.29385 | 0.14408 |
40 | 0.41035 | 0.46162 |
400 | 6.29020 | 5.82143 |
在一块 V100 GPU 上, 输入张量:torch.randn(64, 3, 32, 32)
次数 | 掩码时延 | 加速后的时延 |
---|---|---|
1 | 0.01026 | 0.003677 |
2 | 0.01657 | 0.008161 |
4 | 0.02458 | 0.020018 |
8 | 0.03498 | 0.025504 |
16 | 0.06757 | 0.047523 |
32 | 0.10487 | 0.086442 |
在一块 V100 GPU 上, 输入张量:torch.randn(64, 3, 32, 32)
次数 | 掩码时延 | 加速后的时延 |
---|---|---|
1 | 0.01389 | 0.004208 |
2 | 0.01628 | 0.008310 |
4 | 0.02521 | 0.014008 |
8 | 0.03386 | 0.023923 |
16 | 0.06042 | 0.046183 |
32 | 0.12421 | 0.087113 |