设置 seed,以使模型结果可重复。
目录
1. 原理
1.1 伪随机数生成器
1.2 深度学习与随机种子
2. 代码
1. 原理
设置随机种子的目的是为了在使用伪随机数生成器(PRNG, Pseudorandom Number Generator)时,使得生成的随机数序列是可重复的。
1.1 伪随机数生成器
伪随机数生成器 (PRNG, Pseudorandom Number Generator),也称为确定性随机位生成器 (DRBG, Deterministic Random Bit Generator),是一种生成数字序列的算法,其属性近似于随机数序列的属性。PRNG 生成的序列不是真正的随机序列,因为它完全由一个初始值决定,称为 PRNG 的种子(可能包括真正的随机值)。尽管可以使用硬件随机数生成器生成更接近真正随机的序列,但伪随机数生成器在实践中因其数字生成速度和可重复性而很重要。
1.2 深度学习与随机种子
这里引用某乎用户丹尼尔小博士的说法:
深度学习网络模型中初始的权值参数通常都是初始化成随机数,而使用梯度下降法最终得到的局部最优解对于初始位置点的选择很敏感。为了能够完全复现作者的开源深度学习代码,随机种子的选择能够减少一定程度上,算法结果的随机性,也就是更接近于原始作者的结果,即产生随机种子意味着每次运行实验,产生的随机数都是相同的。
2. 代码
这里参考 transformers.set_seed。该代码默认兼容 torch。
# coding=utf-8
# @Author: Fulai Cui (cuifulai@mail.hfut.edu.cn)
# @Time: 2024/9/18 21:35
import randomimport numpy as np
import torchdef set_seed(seed: int, deterministic: bool = False):"""Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch` and/or `tf` (if installed).Args:seed (`int`):The seed to set.deterministic (`bool`, *optional*, defaults to `False`):Whether to use deterministic algorithms where available. Can slow down training."""random.seed(seed)np.random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed_all(seed)# ^^ safe to call this function even if cuda is not availableif deterministic:torch.use_deterministic_algorithms(True)try:import tensorflow as tftf.random.set_seed(seed)if deterministic:tf.config.experimental.enable_op_determinism()except ImportError:passdef main():set_seed(42)if __name__ == '__main__':main()