pytorch中如何设置随机种子

作者:大虾飞哥哥 时间:2021-10-24 06:43:51 

pytorch设置随机种子

pytorch设置随机种子 - 保证复现模型所有的训练过程

在使用 PyTorch 时,如果希望通过设置随机数种子,在 GPU 或 CPU 上固定每一次的训练结果,则需要在程序执行的开始处添加以下代码:

def seed_everything():
   '''
   设置整个开发环境的seed
   :param seed:
   :param device:
   :return:
   '''
   import os
   import random
   import numpy as np

random.seed(seed)
   os.environ['PYTHONHASHSEED'] = str(seed)
   np.random.seed(seed)
   torch.manual_seed(seed)
   torch.cuda.manual_seed(seed)
   torch.cuda.manual_seed_all(seed)

# some cudnn methods can be random even after fixing the seed
   # unless you tell it to be deterministic
   torch.backends.cudnn.deterministic = True

pytorch/tensorflow设置随机种子 ,保证结果复现

Pytorch随机种子设置

import numpy as np
import random
import os
import torch
def seed_torch(seed=2021):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.enabled = False
seed_torch()

Tensorflow设置随机种子

第一步 仅导入设置种子和初始化种子值所需的那些库

import tensorflow as tf
import os
import numpy as np
import random

SEED = 0

第二步 为所有可能具有随机行为的库初始化种子的函数

def set_seeds(seed=SEED):
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    tf.random.set_seed(seed)
    np.random.seed(seed)

第三步 激活 Tensorflow 确定性功能

def set_global_determinism(seed=SEED):
    set_seeds(seed=seed)

    os.environ['TF_DETERMINISTIC_OPS'] = '1'
    os.environ['TF_CUDNN_DETERMINISTIC'] = '1'
    
    tf.config.threading.set_inter_op_parallelism_threads(1)
    tf.config.threading.set_intra_op_parallelism_threads(1)

# Call the above function with seed value
set_global_determinism(seed=SEED)

来源:https://blog.csdn.net/xu624735206/article/details/124999824

标签:pytorch,随机种子
0
投稿

猜你喜欢

  • Python实现的根据IP地址计算子网掩码位数功能示例

    2021-08-30 21:21:31
  • Python强大的自省机制详解

    2021-06-07 02:07:57
  • MySQL数据表使用的SQL语句整理

    2024-01-20 07:13:03
  • PyTorch中的CUDA的操作方法

    2022-02-24 18:54:41
  • JS字符串累加Array不一定比字符串累加快(根据电脑配置)

    2024-05-02 16:10:18
  • sqlserver 禁用触发器和启用触发器的语句

    2024-01-19 21:38:17
  • vue开发chrome插件,实现获取界面数据和保存到数据库功能

    2024-01-19 03:18:57
  • 注册表单之死

    2008-08-07 13:02:00
  • VB应用程序访问SQL Server的常用方法

    2009-01-21 14:28:00
  • python中urllib.unquote乱码的原因与解决方法

    2023-08-24 14:56:43
  • win10下tensorflow和matplotlib安装教程

    2023-03-23 21:47:27
  • 隐藏你的.php文件的实现方法

    2023-10-20 22:58:01
  • JS实现图片手风琴效果

    2023-08-23 19:28:27
  • python爬虫获取淘宝天猫商品详细参数

    2021-06-08 09:27:29
  • python scipy求解非线性方程的方法(fsolve/root)

    2022-01-06 15:46:00
  • 使用Python解决常见格式图像读取nii,dicom,mhd

    2021-11-14 23:36:59
  • Python使用sorted对字典的key或value排序

    2023-12-12 06:36:53
  • axios 拦截器管理类链式调用手写实现及原理剖析

    2023-07-02 16:38:23
  • go使用makefile脚本编译应用的方法小结

    2024-04-25 15:17:57
  • python画图时linestyle,color和loc参数的设置方式

    2021-07-03 16:15:07
  • asp之家 网络编程 m.aspxhome.com