在预测的时候,加with torch.no_grad():
with torch.no_grad():for i,batch in enumerate(test_loader):pass
with torch.no_grad():
的主要作用是在指定的代码块中暂时禁用梯度计算,进行推理时,我们不需要计算梯度,只关心模型的输出。
参考:DEBUG:pytorch训练时候没有问题,预测时候内存迅速增长,爆掉,out of memory的解决办法_模型训练没问题,但测试时报内存不足-CSDN博客