Pytorch 使用tensor特定条件判断索引

作者:judgechen1997 时间:2023-01-18 16:30:23 

torch.where() 用于将两个broadcastable的tensor组合成新的tensor,类似于c++中的三元操作符“?:”

区别于python numpy中的where()直接可以找到特定条件元素的index

Pytorch 使用tensor特定条件判断索引

想要实现numpy中where()的功能,可以借助nonzero()

Pytorch 使用tensor特定条件判断索引

对应numpy中的where()操作效果:

Pytorch 使用tensor特定条件判断索引

补充:Pytorch torch.Tensor.detach()方法的用法及修改指定模块权重的方法

detach

detach的中文意思是分离,官方解释是返回一个新的Tensor,从当前的计算图中分离出来

Pytorch 使用tensor特定条件判断索引

需要注意的是,返回的Tensor和原Tensor共享相同的存储空间,但是返回的 Tensor 永远不会需要梯度

Pytorch 使用tensor特定条件判断索引


import torch as t
a = t.ones(10,)
b = a.detach()
print(b)
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

那么这个函数有什么作用?

–假如A网络输出了一个Tensor类型的变量a, a要作为输入传入到B网络中,如果我想通过损失函数反向传播修改B网络的参数,但是不想修改A网络的参数,这个时候就可以使用detcah()方法


a = A(input)
a = detach()
b = B(a)
loss = criterion(b, target)
loss.backward()

来看一个实际的例子:


import torch as t
x = t.ones(1, requires_grad=True)
x.requires_grad   #True
y = t.ones(1, requires_grad=True)
y.requires_grad   #True
x = x.detach() #分离之后
x.requires_grad   #False
y = x+y         #tensor([2.])
y.requires_grad   #我还是True
y.retain_grad()   #y不是叶子张量,要加上这一行
z = t.pow(y, 2)
z.backward()  #反向传播
y.grad  #tensor([4.])
x.grad  #None

以上代码就说明了反向传播到y就结束了,没有到达x,所以x的grad属性为None

既然谈到了修改模型的权重问题,那么还有一种情况是:

–假如A网络输出了一个Tensor类型的变量a, a要作为输入传入到B网络中,如果我想通过损失函数反向传播修改A网络的参数,但是不想修改B网络的参数,这个时候又应该怎么办了?

这时可以使用Tensor.requires_grad属性,只需要将requires_grad修改为False即可.


for param in B.parameters():
param.requires_grad = False
a = A(input)
b = B(a)
loss = criterion(b, target)
loss.backward()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。如有错误或未考虑完全的地方,望不吝赐教。

来源:https://blog.csdn.net/judgechen1997/article/details/105820709

标签:Pytorch,tensor,索引
0
投稿

猜你喜欢

  • MySQL数据库卸载的完整步骤

    2024-01-13 13:12:52
  • Python爬虫HTPP请求方法有哪些

    2023-07-25 16:55:06
  • MySQL字段类型详解

    2009-01-05 09:23:00
  • 自动更新程序的设计框架

    2009-08-12 13:00:00
  • Python入门教程(四十一)Python的NumPy数组索引

    2023-07-17 01:38:55
  • Python必备基础之闭包和装饰器知识总结

    2022-05-21 22:34:30
  • Python实战之疫苗研发情况可视化

    2023-08-19 15:29:35
  • SQL附加数据库失败问题的解决方法

    2024-01-25 19:22:10
  • MySQL大量脏数据如何只保留最新的一条(最新推荐)

    2024-01-25 22:41:04
  • Python 'takes exactly 1 argument (2 given)' Python error

    2022-04-19 00:26:05
  • Scrapy爬虫实例讲解_校花网

    2023-03-02 14:46:39
  • Go语言切片前或中间插入项与内置copy()函数详解

    2024-05-22 10:16:19
  • python随机模块random的22种函数(小结)

    2022-08-11 18:09:03
  • Django Form设置文本框为readonly操作

    2023-11-11 03:03:43
  • python 协程中的迭代器,生成器原理及应用实例详解

    2022-09-01 07:19:56
  • Python模拟登录和登录跳转的参考示例

    2023-07-29 07:09:47
  • Sql Server 2005 默认端口修改方法

    2024-01-27 08:44:53
  • 网页设计三剑客

    2010-08-31 17:05:00
  • python 星号(*)的多种用途

    2021-08-13 06:16:27
  • python 给图像添加透明度(alpha通道)

    2021-05-04 04:57:22
  • asp之家 网络编程 m.aspxhome.com