Pytorch 使用tensor特定条件判断索引
作者:judgechen1997 发布时间:2023-01-18 16:30:23
torch.where() 用于将两个broadcastable的tensor组合成新的tensor,类似于c++中的三元操作符“?:”
区别于python numpy中的where()直接可以找到特定条件元素的index
想要实现numpy中where()的功能,可以借助nonzero()
对应numpy中的where()操作效果:
补充:Pytorch torch.Tensor.detach()方法的用法及修改指定模块权重的方法
detach
detach的中文意思是分离,官方解释是返回一个新的Tensor,从当前的计算图中分离出来
需要注意的是,返回的Tensor和原Tensor共享相同的存储空间,但是返回的 Tensor 永远不会需要梯度
import torch as t
a = t.ones(10,)
b = a.detach()
print(b)
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
那么这个函数有什么作用?
–假如A网络输出了一个Tensor类型的变量a, a要作为输入传入到B网络中,如果我想通过损失函数反向传播修改B网络的参数,但是不想修改A网络的参数,这个时候就可以使用detcah()方法
a = A(input)
a = detach()
b = B(a)
loss = criterion(b, target)
loss.backward()
来看一个实际的例子:
import torch as t
x = t.ones(1, requires_grad=True)
x.requires_grad #True
y = t.ones(1, requires_grad=True)
y.requires_grad #True
x = x.detach() #分离之后
x.requires_grad #False
y = x+y #tensor([2.])
y.requires_grad #我还是True
y.retain_grad() #y不是叶子张量,要加上这一行
z = t.pow(y, 2)
z.backward() #反向传播
y.grad #tensor([4.])
x.grad #None
以上代码就说明了反向传播到y就结束了,没有到达x,所以x的grad属性为None
既然谈到了修改模型的权重问题,那么还有一种情况是:
–假如A网络输出了一个Tensor类型的变量a, a要作为输入传入到B网络中,如果我想通过损失函数反向传播修改A网络的参数,但是不想修改B网络的参数,这个时候又应该怎么办了?
这时可以使用Tensor.requires_grad属性,只需要将requires_grad修改为False即可.
for param in B.parameters():
param.requires_grad = False
a = A(input)
b = B(a)
loss = criterion(b, target)
loss.backward()
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。如有错误或未考虑完全的地方,望不吝赐教。
来源:https://blog.csdn.net/judgechen1997/article/details/105820709
猜你喜欢
- CLI工程全局安装vue-clinpm install -g @vue/cli通过cli创建uni-app项目 vue creat
- 命令行进入python打开cmd——>直接输入python即可,如下退出python方法一:先按Ctrl+z,再按Enter(回车键)
- 如下所示:函数说明type()返回数据结构类型(list、dict、numpy.ndarray 等)dtype()返回数据元素的数据类型(i
- 版本:MySQL-5.7.32前言:对于业务繁忙的数据库来说,在运行了一定时间后,往往会产生一些数据量较大的表,特别是对于每天新增数据较多的
- 使用Requests测试带签名的接口部分业务为了安全需要,需要对接口请求数据做签名校验,一般制定一下规则1、业务方接入系统,需申请业务ID以
- 我们先来看一下效果(简单的写了一个):原理:将post请求的代码数据写入了服务器的一个文件,然后用服务器的python编译器执行返回结果实现
- MS SQL Server中文版的预设日期datetime格式是yyyy-mm-dd hh:mm:ss.mmm 长短日期格式 --短日期格式
- 重置MySQL中表中自增列的初始值的实现方法1. 问题的提出 在MySQL的数据库设计中,一般都会设计自增的数字列,
- Python 的元组与列表类似,不同之处在于元组的元素不能修改。元组使用小括号,列表使用方括号。元组创建很简单,只需要在括号中添加元素,并使
- 来源:http://stackoverflow.com/questions/3806562/ways-to-move-up-and-down
- 示例from optparse import OptionParser[...]def main():
- 准备篇:1、配置防火墙,开启80端口、3306端口说明:Ubuntu默认安装是没有开启任何防火墙的,为了服务器的安全,建议大家安装启用防火墙
- 一、使用我使用的是python3,可以自行搜索下载二、安装phone模块pip install phone三、测试代码如下:from pho
- 1. 利用resnet18做迁移学习import torchfrom torchvision import models if __name
- python中查找指定的字符串的方法如下:code#查询def selStr(): sStr1 = 'jsjtt.com
- 目录前言filestools库介绍一行代码给图片加水印总结前言版权相当重要,对于某张图片,可能是你精心制作的思维导图,或者你精心设计的某个l
- 在 MySQL下,在进行中文模糊检索时,经常会返回一些与之不相关的记录,如查找 "%a%" 时,返回的可能有中文字符,却
- 现有:班级表(A_CLASS)学生表( STUDENT)注:学生表(STUDENT)的classId关联班级表(A_CLASS)的主键ID代
- 前言:最近正在将一个使用单文件组件的 Options API 的 Vue2 JavaScript 项目升级为 Vue3 typescript
- 一、迭代器迭代器就是iter(可迭代对象函数)返回的对象,说人话.......可迭代对象由一个个迭代器组成可以用next()函数获取可迭代对