您的位置:首页 > 健康 > 养生 > 网络剪枝——network-slimming 项目复现

网络剪枝——network-slimming 项目复现

2024/10/6 20:32:30 来源:https://blog.csdn.net/qq_63512036/article/details/141091990  浏览:    关键词:网络剪枝——network-slimming 项目复现

目录

文章目录

  • 目录
  • 网络剪枝——network-slimming 项目复现
    • clone 存储库
    • Baseline
      • vgg
        • 训练
        • 结果
      • resnet
        • 训练
        • 结果
      • densenet
        • 训练
        • 结果
    • Sparsity
      • vgg
        • 训练
        • 结果
      • resnet
        • 训练
        • 结果
      • densenet
        • 训练
        • 结果
    • Prune
      • vgg
        • 命令
        • 结果
      • resnet
        • 命令
        • 结果
      • densenet
        • 命令
        • 结果
    • Fine-tune
      • vgg
        • 训练
        • 结果
      • resnet
        • 训练
        • 结果
      • densenet
        • 训练
        • 结果
    • 模型大小计算脚本 param_counter.py
    • 结果汇总
      • CIFAR10

网络剪枝——network-slimming 项目复现

  • 【GiHnub】:Eric-mingjie/network-slimming: Network Slimming (Pytorch) (ICCV 2017) (github.com)
  • 【作者复现项目】:
  • 通过百度网盘分享的文件:network-slimming-regin.zip
    链接:https://pan.baidu.com/s/1vTJSLS5ZDjE8R8XaApW96A?pwd=t1z2
    提取码:t1z2
    • 仅以 CIFAR-10 为例,CIFAR-100 同理.
    • 提供中文README_zh-CN.md.
    • 包含 CIFAR-10/100 数据集data.cifar10data.cifar100.
    • 解决了 main.py 运行报错问题.
    • 加入了计算训练后模型的 Parameters 大小脚本param_counter.py.

clone 存储库

注:若 clone 作者复现项目,则忽略这一步,直接进入下一步;若想自行从头复现,则 clone 以下存储库.

  • 链接:https://pan.baidu.com/s/1nppPLKoiPbJPW60HOa2TxQ?pwd=ud89
    提取码:ud89


Baseline

vgg

训练
  • 【命令】:
python main.py --dataset cifar10 --arch vgg --depth 19

  • 这个报错通常出现在使用 Python 的multiprocessing库来创建进程时,尤其是在 Windows 操作系统上. 在 Windows 上,Python 的multiprocessing模块启动新进程的方式与 Linux 或 macOS 不同,它使用 “spawn” 来启动新进程,这意味着每个子进程都会从头开始执行脚本. 因此,如果在脚本顶层级别启动进程(而不是在受保护的if __name__ == '__main__':块中),每个子进程都会尝试再次启动子进程,从而导致无限递归和上述错误.
  • 为了解决这个问题,应 确保多进程代码(即main.py)位于if __name__ == '__main__':保护块内.
# 导入部分
...def main():...if __name__ == '__main__':main()
  • 再次运行命令,又报错:

  • 这个报错通常发生在尝试直接索引一个0维的张量(tensor)时. 在 PyTorch 中,0 维张量是一个单一值的张量,但是不能像普通的数组那样通过索引来访问。要从 0 维张量中获取其 Python 数值,需要使用.item()方法.
  • 为了解决这个问题,应该 使用.item()方法来替换所有.data[0]的用法
# 在 train 函数中
if batch_idx % args.log_interval == 0:print('Train Epoch: {} [{}/{} ({:.1f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))# 在 test 函数中
for data, target in test_loader:if args.cuda:data, target = data.cuda(), target.cuda()data, target = Variable(data), Variable(target)output = model(data)test_loss += F.cross_entropy(output, target, reduction='sum').item()  # sum up batch losspred = output.data.max(1, keepdim=True)[1]correct += pred.eq(target.data.view_as(pred)).cpu().sum()test_loss /= len(test_loader.dataset)
  • 再次运行命令就正常运行了:

结果
  • Terminal

  • 在 ./logs 生成文件checkpoint.pth.tarmodel_best.pth.tar

resnet

训练
  • 【命令】:
python main.py --dataset cifar10 --arch resnet --depth 164
结果

densenet

训练
  • 【命令】:
python main.py --dataset cifar10 --arch densenet --depth 40
结果


Sparsity

vgg

训练
  • 【命令】:
python main.py -sr --s 0.0001 --dataset cifar10 --arch vgg --depth 19
结果

resnet

训练
  • 【命令】:
python main.py -sr --s 0.00001 --dataset cifar10 --arch resnet --depth 164
结果

densenet

训练
  • 【命令】:
python main.py -sr --s 0.00001 --dataset cifar10 --arch densenet --depth 40
结果


Prune

vgg

命令
python vggprune.py --dataset cifar10 --depth 19 --percent 0.7 --model ./results/CIFAR10_results/CIFAR10-Vgg/Sparsity/model_best.pth.tar --save ./prunes

  • main.py同理,为了解决这个问题,应 确保多进程代码位于if __name__ == '__main__':保护块内
# 导入部分
...def main():...if __name__ == '__main__':main()
  • 之后就可以正常运行了.

结果
  • Terminal

  • 在./prunes生成文件prune.txtpruned.pth.tar

  • prune.txt中我们可以看到 Number of parametersTest accuracy

resnet

命令
python resprune.py --dataset cifar10 --depth 164 --percent 0.4 --model ./results/CIFAR10_results/CIFAR10-Resnet-164/Sparsity/model_best.pth.tar --save ./prunes
结果

densenet

命令
python denseprune.py --dataset cifar10 --depth 40 --percent 0.4 --model ./results/CIFAR10_results/CIFAR10-Densenet-40/Sparsity/model_best.pth.tar --save ./prunes
结果


Fine-tune

vgg

训练
  • 【命令】:
python main.py --refine ./results/CIFAR10_results/CIFAR10-Vgg/Prune/pruned.pth.tar --dataset cifar10 --arch vgg --depth 19 --epochs 160
结果

resnet

训练
  • 【命令】:
python main.py --refine ./results/CIFAR10_results/CIFAR10-Resnet-164/Prune/pruned.pth.tar --dataset cifar10 --arch resnet --depth 164 --epochs 160
结果

densenet

训练
  • 【命令】:
python main.py --refine ./results/CIFAR10_results/CIFAR10-Densenet-40/Prune/pruned.pth.tar --dataset cifar10 --arch densenet --depth 40 --epochs 160
结果


模型大小计算脚本 param_counter.py

  • 【路径】:./script/param_counter.py
import torchdef load_model(model_path):model = torch.load(model_path, map_location=torch.device('cpu'))return modeldef count_parameters(model_state_dict):total_params = sum(p.numel() for p in model_state_dict.values())return total_paramsdef get_model_parameters(model_path):# 加载模型状态字典model = load_model(model_path)# 模型状态字典存储在 'state_dict' 键下model_state_dict = model['state_dict'] if 'state_dict' in model else model# 计算参数总数total_params = count_parameters(model_state_dict)return total_params
  • main.py中:
from script.param_counter import get_model_parametersdef main():...# 计算 Parametersmodel_path = 'logs/model_best.pth.tar'total_params = get_model_parameters(model_path)print(f'Total parameters in the model: {total_params}')

结果汇总

注:与原项目结果略有差别.

CIFAR10

CIFAR10-VggBaselineSparsity(1e-4)Prune(70%)Fine-tune-160(70%)
Top1 Accuracy(%)93.7293.6033.9893.75
Parameters20.05M20.05M2.22M2.23M
CIFAR10-Resnet-164BaselineSparsity(1e-5)Prune(40%)Fine-tune-160(40%)
Top1 Accuracy(%)94.9995.0094.5995.27
Parameters1.74M1.74M1.46M1.49M
CIFAR10-Densenet-40BaselineSparsity(1e-5)Prune(40%)Fine-tune-160(40%)
Top1 Accuracy(%)94.1594.3794.1494.48
Parameters1.09M1.09M0.70M0.72M

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com