pytorch中的自定义反向传播,求导实例

作者:xuxiaoyuxuxiaoyu 时间:2021-08-07 06:57:53 

pytorch中自定义backward()函数。在图像处理过程中,我们有时候会使用自己定义的算法处理图像,这些算法多是基于numpy或者scipy等包。

那么如何将自定义算法的梯度加入到pytorch的计算图中,能使用Loss.backward()操作自动求导并优化呢。下面的代码展示了这个功能`


import torch
import numpy as np
from PIL import Image
from torch.autograd import gradcheck
class Bicubic(torch.autograd.Function):
def basis_function(self, x, a=-1):
 x_abs = np.abs(x)
 if x_abs < 1 and x_abs >= 0:
   y = (a + 2) * np.power(x_abs, 3) - (a + 3) * np.power(x_abs, 2) + 1
 elif x_abs > 1 and x_abs < 2:
   y = a * np.power(x_abs, 3) - 5 * a * np.power(x_abs, 2) + 8 * a * x_abs - 4 * a
 else:
   y = 0
 return y
def bicubic_interpolate(self,data_in, scale=1 / 4, mode='edge'):
 # data_in = data_in.detach().numpy()
 self.grad = np.zeros(data_in.shape,dtype=np.float32)
 obj_shape = (int(data_in.shape[0] * scale), int(data_in.shape[1] * scale), data_in.shape[2])
 data_tmp = data_in.copy()
 data_obj = np.zeros(shape=obj_shape, dtype=np.float32)
 data_in = np.pad(data_in, pad_width=((2, 2), (2, 2), (0, 0)), mode=mode)
 print(data_tmp.shape)
 for axis0 in range(obj_shape[0]):
   f_0 = float(axis0) / scale - np.floor(axis0 / scale)
   int_0 = int(axis0 / scale) + 2
   axis0_weight = np.array(
     [[self.basis_function(1 + f_0), self.basis_function(f_0), self.basis_function(1 - f_0), self.basis_function(2 - f_0)]])
   for axis1 in range(obj_shape[1]):
     f_1 = float(axis1) / scale - np.floor(axis1 / scale)
     int_1 = int(axis1 / scale) + 2
     axis1_weight = np.array(
       [[self.basis_function(1 + f_1), self.basis_function(f_1), self.basis_function(1 - f_1), self.basis_function(2 - f_1)]])
     nbr_pixel = np.zeros(shape=(obj_shape[2], 4, 4), dtype=np.float32)
     grad_point = np.matmul(np.transpose(axis0_weight, (1, 0)), axis1_weight)
     for i in range(4):
       for j in range(4):
         nbr_pixel[:, i, j] = data_in[int_0 + i - 1, int_1 + j - 1, :]
         for ii in range(data_in.shape[2]):
           self.grad[int_0 - 2 + i - 1, int_1 - 2 + j - 1, ii] = grad_point[i,j]
     tmp = np.matmul(axis0_weight, nbr_pixel)
     data_obj[axis0, axis1, :] = np.matmul(tmp, np.transpose(axis1_weight, (1, 0)))[:, 0, 0]
     # img = np.transpose(img[0, :, :, :], [1, 2, 0])
 return data_obj

def forward(self,input):
 print(type(input))
 input_ = input.detach().numpy()
 output = self.bicubic_interpolate(input_)
 # return input.new(output)
 return torch.Tensor(output)

def backward(self,grad_output):
 print(self.grad.shape,grad_output.shape)
 grad_output.detach().numpy()
 grad_output_tmp = np.zeros(self.grad.shape,dtype=np.float32)
 for i in range(self.grad.shape[0]):
   for j in range(self.grad.shape[1]):
     grad_output_tmp[i,j,:] = grad_output[int(i/4),int(j/4),:]
 grad_input = grad_output_tmp*self.grad
 print(type(grad_input))
 # return grad_output.new(grad_input)
 return torch.Tensor(grad_input)

def bicubic(input):
return Bicubic()(input)

def main():
hr = Image.open('./baboon/baboon_hr.png').convert('L')
hr = torch.Tensor(np.expand_dims(np.array(hr), axis=2))
hr.requires_grad = True
lr = bicubic(hr)
print(lr.is_leaf)
loss=torch.mean(lr)
loss.backward()
if __name__ =='__main__':
main()

要想实现自动求导,必须同时实现forward(),backward()两个函数。

1、从代码中可以看出来,forward()函数是针对numpy数据操作,返回值再重新指定为torch.Tensor类型。因此就有这个问题出现了:forward输入input被转换为numpy类型,输出转换为tensor类型,那么输出output的grad_fn参数是如何指定的呢。调试发现,当main()中hr的requires_grad被指定为True,即hr被指定为需要求导的叶子节点。只要Bicubic类继承自torch.autograd.Function,那么output也就是代码中的lr的grad_fn就会被指定为<main.Bicubic object at 0x000001DD5A280D68>,即Bicubic这个类。

2、backward()为求导的函数,gard_output是链式求导法则的上一级的梯度,grad_input即为我们想要得到的梯度。只需要在输入指定grad_output,在调用loss.backward()过程中的某一步会执行到Bicubic的backwward()函数

来源:https://blog.csdn.net/xuxiaoyuxuxiaoyu/article/details/86737492

标签:pytorch,自定义,反向传播,求导
0
投稿

猜你喜欢

  • 解决pycharm导入本地py文件时,模块下方出现红色波浪线的问题

    2023-11-11 10:38:14
  • 深入解析PHP 5.3.x 的strtotime() 时区设定 警告信息修复

    2023-11-06 19:25:27
  • 用gpu训练好的神经网络,用tensorflow-cpu跑出错的原因及解决方案

    2021-02-11 08:06:31
  • asp如何在线查询本地机的文件?

    2010-06-22 21:19:00
  • python字符串查找函数的用法详解

    2022-12-09 11:32:47
  • Python爬虫之教你利用Scrapy爬取图片

    2022-11-02 10:35:02
  • php导出excel格式数据问题

    2023-07-13 22:46:06
  • Python常用断言函数实例汇总

    2023-07-04 18:15:23
  • Case和If哪个更好用?

    2009-10-28 18:25:00
  • numpy:np.newaxis 实现将行向量转换成列向量

    2023-07-13 03:14:39
  • Python玩转加密的技巧【推荐】

    2023-06-26 07:01:31
  • PHP开发实现微信退款功能示例

    2023-06-30 09:10:25
  • ASP日期和时间函数用法详解

    2007-10-13 19:33:00
  • 解决Python2.7中IDLE启动没有反应的问题

    2022-10-17 17:43:57
  • Python变量名详细规则详细变量值介绍

    2021-08-05 07:51:57
  • 用Dreamweaver设计自动关闭的网页

    2010-09-02 12:29:00
  • 简单好用的PHP分页类

    2023-11-22 09:32:39
  • python读取查看npz/npy文件数据以及数据完全显示方法实例

    2022-05-15 15:45:36
  • 在python中将字符串转为json对象并取值的方法

    2022-12-19 02:22:37
  • 运用ASP调用数据库中视图及存储过程

    2008-02-03 15:33:00
  • asp之家 网络编程 m.aspxhome.com