Pytorch基本变量类型FloatTensor与Variable用法

作者:jingxian 时间:2022-10-10 21:14:45 

pytorch中基本的变量类型当属FloatTensor(以下都用floattensor),而Variable(以下都用variable)是floattensor的封装,除了包含floattensor还包含有梯度信息

pytorch中的dochi给出一些对于floattensor的基本的操作,比如四则运算以及平方等(链接),这些操作对于floattensor是十分的不友好,有时候需要写一个正则化的项需要写很长的一串,比如两个floattensor之间的相加需要用torch.add()来实现

然而正确的打开方式并不是这样

韩国一位大神写了一个pytorch的turorial,其中包含style transfer的一个代码实现


for step in range(config.total_step):

# Extract multiple(5) conv feature vectors
   target_features = vgg(target)  # 每一次输入到网络中的是同样一张图片,反传优化的目标是输入的target
   content_features = vgg(Variable(content))
   style_features = vgg(Variable(style))

style_loss = 0
   content_loss = 0
   for f1, f2, f3 in zip(target_features, content_features, style_features):
     # Compute content loss (target and content image)
     content_loss += torch.mean((f1 - f2)**2) # square 可以进行直接加-操作?可以,并且mean对所有的元素进行均值化造作

# Reshape conv features
     _, c, h, w = f1.size() # channel height width
     f1 = f1.view(c, h * w) # reshape a vector
     f3 = f3.view(c, h * w) # reshape a vector

# Compute gram matrix
     f1 = torch.mm(f1, f1.t())
     f3 = torch.mm(f3, f3.t())

# Compute style loss (target and style image)
     style_loss += torch.mean((f1 - f3)**2) / (c * h * w)  # 总共元素的数目?

其中f1与f2,f3的变量类型是Variable,作者对其直接用四则运算符进行加减,并且用python内置的**进行平方操作,然后


# -*-coding: utf-8 -*-
import torch
from torch.autograd import Variable

# dtype = torch.FloatTensor
dtype = torch.cuda.FloatTensor # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Randomly initialize weights
w1 = torch.randn(D_in, H).type(dtype) # 两个权重矩阵
w2 = torch.randn(D_in, H).type(dtype)
# operate with +-*/ and **
w3 = w1-2*w2
w4 = w3**2
w5 = w4/w1

# operate the Variable with +-*/ and **
w6 = Variable(torch.randn(N, D_in).type(dtype))
w7 = Variable(torch.randn(N, D_in).type(dtype))
w8 = w6 + w7
w9 = w6*w7
w10 = w9**2
print(1)

基本上调试的结果与预期相符

Pytorch基本变量类型FloatTensor与Variable用法

所以,对于floattensor以及variable进行普通的+-×/以及**没毛病

来源:https://blog.csdn.net/u013517182/article/details/93051322

标签:Pytorch,FloatTensor,Variable
0
投稿

猜你喜欢

  • 详解python opencv、scikit-image和PIL图像处理库比较

    2021-11-10 02:24:13
  • JavaScript设计模式之享元模式实例详解

    2024-04-17 10:08:34
  • python中使用 xlwt 操作excel的常见方法与问题

    2021-09-12 05:10:02
  • 利用Python的turtle库绘制玫瑰教程

    2021-12-01 19:33:52
  • js读取图片的宽和高

    2007-08-04 10:14:00
  • 在ASP.NET 2.0中操作数据之五十二:使用FileUpload上传文件

    2023-07-07 04:19:18
  • Python如何生成树形图案

    2022-11-17 03:39:58
  • 基于python的itchat库实现微信聊天机器人(推荐)

    2021-11-30 13:54:21
  • 解决pycharm最左侧Tool Buttons显示不全的问题

    2022-11-22 13:23:22
  • jenkins配置163邮箱的操作方法

    2023-08-10 22:54:05
  • python检测服务器端口代码实例

    2023-07-07 06:34:14
  • asp Access数据备份,还原,压缩类代码

    2011-03-07 11:16:00
  • MySQL隔离级别和锁机制的深入讲解

    2024-01-14 06:57:53
  • python 双循环遍历list 变量判断代码

    2021-02-10 12:38:12
  • python获取百度热榜链接的实例方法

    2022-10-02 23:10:01
  • centos7环境下二进制安装包安装 mysql5.6的方法详解

    2024-01-26 23:37:33
  • laravel框架中路由设置,路由参数和路由命名实例分析

    2024-06-05 09:43:33
  • Django框架模型简单介绍与使用分析

    2021-04-06 02:59:19
  • MySQL decimal unsigned更新负数转化为0

    2024-01-14 20:59:36
  • Python使用Scrapy爬取妹子图

    2022-06-17 23:47:41
  • asp之家 网络编程 m.aspxhome.com