将pytorch的网络等转移到cuda
作者:aleien1 时间:2023-08-10 08:33:46
神经网络一般用GPU来跑,我们的神经网络框架一般也都安装的GPU版本,本文就简单记录一下GPU使用的编写。
GPU的设置不在model,而是在Train的初始化上。
第一步是查看是否可以使用GPU
self.GPU_IN_USE = torch.cuda.is_available()
就是返回这个可不可以用GPU的函数,当你的pytorch是cpu版本的时候,他就会返回False。
然后是:
self.device = torch.device('cuda' if self.GPU_IN_USE else 'cpu')
torch.device是代表将torch.tensor分配到哪个设备的函数
接着是,我看到了一篇文章,原来就是将网络啊、数据啊、随机种子啊、损失函数啊、等等等等直接转移到CUDA上就好了!
于是下面就好理解多了:
转移模型:
self.model = Net(num_channels=1, upscale_factor=self.upscale_factor, base_channel=64, num_residuals=4).to(self.device)
设置cuda的随机种子:
torch.cuda.manual_seed(self.seed)
转移损失函数:
self.criterion.cuda()
转移数据:
data, target = data.to(self.device), target.to(self.device)
pytorch 网络定义参数的后面无法加.cuda()
pytorch定义网络__init__()的时候,参数不能加“cuda()", 不然参数不包含在state_dict()中,比如下面这种写法是错误的
self.W1 = nn.Parameter(torch.FloatTensor(3,3), requires_grad=True).cuda()
应该去掉".cuda()"
self.W1 = nn.Parameter(torch.FloatTensor(3,3), requires_grad=True)
来源:https://blog.csdn.net/weixin_42128941/article/details/103048866
标签:pytorch,网络,cuda
0
投稿
猜你喜欢
CSS控制鼠标样式变换方法
2007-11-17 07:58:00
python中创建一个包并引用使用的操作方法
2023-05-19 03:06:09
python实现去除下载电影和电视剧文件名中的多余字符的方法
2022-08-17 16:17:59
Selenium常见异常解析及解决方案示范
2023-06-27 20:09:48
Oracle、MySQL和SqlServe三种数据库分页查询语句的区别介绍
2024-01-15 08:48:41
教你使用vue-autofit 一行代码搞定自适应可视化大屏
2024-05-09 09:05:53
python爬取51job中hr的邮箱
2022-11-06 14:00:54
解决python2 绘图title,xlabel,ylabel出现中文乱码的问题
2022-05-27 16:06:04
Python面向对象编程关键深度探索类与对象
2021-12-07 03:06:20
Python 使用元类type创建类对象常见应用详解
2023-09-15 23:07:57
浅谈Python数据处理csv的应用小结
2021-06-10 01:40:41
Go1.18新特性使用Generics泛型进行流式处理
2024-02-17 07:12:42
关于Pytorch中模型的保存与迁移问题
2023-08-11 04:05:25
Vue集成lodop插件实现打印功能
2023-07-02 17:01:20
python3反转字符串的3种方法(小结)
2022-05-03 22:59:01
python 利用turtle库绘制笑脸和哭脸的例子
2022-01-16 08:35:30
mysql 无限级分类实现思路
2024-01-19 09:18:35
ASP中数据库调用中常见错误的现象和解决
2007-09-20 13:24:00
Python自动化之定位方法大杀器xpath
2023-11-22 05:08:57
用SQL语句添加删除修改字段、一些表与字段的基本操作、数据库备份等
2011-12-01 07:53:11