pyTorch深入学习梯度和Linear Regression实现
作者:算法菜鸟飞高高 发布时间:2023-05-12 01:00:45
梯度
PyTorch的数据结构是tensor,它有个属性叫做requires_grad,设置为True以后,就开始track在其上的所有操作,前向计算完成后,可以通过backward来进行梯度回传。
评估模型的时候我们并不需要梯度回传,使用with torch.no_grad() 将不需要梯度的代码段包裹起来。每个Tensor都有一个.grad_fn属性,该属性即创建该Tensor的Function,直接用构造的tensor返回None,否则是生成该tensor的操作。
tensor(data, *, dtype=None, device=None, requires_grad=False, pin_memory=False) -> Tensor
#require_grad默认是false,下面我们将显式的开启
x = torch.tensor([1,2,3],requires_grad=True,dtype=torch.float)
注意只有数据类型是浮点型和complex类型才能require梯度,所以这里显示指定dtype为torch.float32
x = torch.tensor([1,2,3],requires_grad=True,dtype=torch.float32)
> tensor([1.,2.,3.],grad_fn=None)
y = x + 2
> tensor([3.,4.,5.],grad_fn=<AddBackward0>)
z = y * y * 3
> tensor([3.,4.,5.],grad_fn=<MulBackward0>)
像x这种直接创建的,没有grad_fn,被称为叶子结点。grad_fn记录了一个个基本操作用来进行梯度计算的。
关于梯度回传计算看下面一个例子
x = torch.ones((2,2),requires_grad=True)
> tensor([[1.,1.],
> [1.,1.]],requires_grad=True)
y = x + 2
z = y * y * 3
out = z.mean()
#out是一个标量,无需指定求偏导的变量
out.backward()
x.grad
> tensor([[4.500,4.500],
> [4.500,4.500]])
#每次计算梯度前,需要将梯度清零,否则会累加
x.grad.data.zero_()
值得注意的是只有叶子节点的梯度在回传时才会被计算,也就是说,上面的例子中拿不到y和z的grad。
来看一个中断求导的例子
x = torch.tensor(1.,requires_grad=True)
y1 = x ** 2
with torch.no_grad()
y2 = x ** 3
y3 = y1 + y2
y3.backward()
print(x.grad)
> 2
本来梯度应该为5的,但是由于y2被with torch.no_grad()包裹,在梯度计算的时候不会被追踪。
如果我们想要修改某个tensor的数值但是又不想被autograd记录,那么需要使用对x.data进行操作就行这也是一个张量。
线性回归(linear regression)
利用线性回归来预测一栋房屋的价格,价格取决于很多feature,这里简化问题,假设价格只取决于两个因素,面积(平方米)和房龄(年)
x1代表面积,x2代表房龄,售出价格为y
模拟数据集
假设我们的样本数量为1000个,每个数据包括两个features,则数据为1000 * 2的2-d张量,用正太分布来随机取值。
labels是房屋的价格,长度为1000的一维张量。
真实w和b提前把值定好,然后再取一个干扰量 δ \delta δ(也用高斯分布取值,用来模拟真实数据集中的偏差)
num_features = 2#两个特征
num_examples = 1000 #样本个数
w = torch.normal(0,1,(num_features,1))
b = torch.tensor(4.2)
samples = torch.normal(0,1,(num_examples,num_features))
labels = samples.matmul(w) + b
noise = torch.normal(0,.01,labels.shape)
labels += noise
加载数据集
import random
def data_iter(samples,labels,batch_size):
num_samples = samples.shape[0] #获得batch轴的长度
indices = [i for i in range(num_samples)]
random.shuffle(indices)#将索引数组原地打乱
for i in range(0,num_samples,batch_size):
j = torch.LongTensor(indices[i:min(i+batch_size,num_samples)])
yield samples.index_select(0,j),labels(0,j)
torch.index_select(dim,index)
dim表示tensor的轴,index是一个tensor,里面包含的是索引。
定义loss_function
def loss_function(predict,labels):
loss = (predict,labels)** 2 / 2
return loss.mean()
定义优化器
def loss_function(predict,labels):
loss = (predict,labels)** 2 / 2
return loss.mean()
开始训练
w = torch.normal(0.,1.,(num_features,1),requires_grad=True)
b = torch.zero(0.,dtype=torch.float32,requires_grad=True)
batch_size = 100
for epoch in range(10):
for data, label in data_iter(samples,labels,batch_size):
predict = data.matmul(w) + b
loss = loss_function(predict,label)
loss.backward()
optimizer([w,b],0.05)
w.grad.data.zero_()
b.grad.data.zero_()
以上就是pyTorch深入学习梯度和Linear Regression实现的详细内容,更多关于pyTorch实现梯度和Linear Regression的资料请关注脚本之家其它相关文章!
来源:https://blog.csdn.net/qq_43152622/article/details/116792624


猜你喜欢
- 本文介绍了tf.truncated_normal与tf.random_normal的详细用法,分享给大家,具体如下:tf.truncated
- 微信小程序终于可以支持npm导入第三方库了(https://developers.weixin.qq....),但是这种导入模式和使用模式有
- 代码如下dat=['1', '2', '3', '0', '0
- 简介pycurl类似于Python的urllib,但是pycurl是对libcurl的封装,速度更快。本文使用的是pycurl 7.43.0
- 本文实例为大家分享了vue实现购物车功能的具体代码,供大家参考,具体内容如下new Vue({ el: "#app",
- 一、 Axios 的封装在 Vue 项目中,和后台进行数据交互是频繁且不可或缺的,刚开始没进行 Axios 封装的时候,每次请求后台数据都是
- 亲测可用学习vee-validate,首先可以去阅读官方文档,更为详细可以阅读官网中的规则。一、安装您可以通过npm或通过CDN安装此插件。
- Python 输出 "Hello, World!",英文没有问题,但是如果你输出中文字符"你好,世界"
- 有些时候我们不得已要利用values来反向查询key,有没有简单的方法呢?下面我给大家列举一些方法,方便大家使用python3>>
- ES在之前的博客已有介绍,提供很多接口,本文介绍如何使用python批量导入。ES官网上有较多说明文档,仔细研究并结合搜索引擎应该不难使用。
- 看lifesinger的《由Kimi找茬想到的》,我想到的:1、 我不同意将“合并付款”定调在“很多卖家都需要”。这个“很多”在卖家里面大概
- 我使用anaconda安装的python3.6.3,并且
- 要点说明在绘制散点图的时候,通常使用变量作为输入数据的载体。其实,也可以使用字符串作为输入数据的存储载体。下面代码的data = {“a”:
- Mac安装python3环境首先我先给说明一下:我也是初次接触python,有一定的Java基础,对编程语法有一定基础,当然小菜在这里全当小
- 1. txt文件(1) 单位矩阵即主对角线上的元素均为1,其余元素均为0的正方形矩阵。在NumPy中可以用eye函数创建一个这样的二维数组,
- 爬虫具有域名切换、信息收集以及信息存储功能。这里讲述如何构建基础的爬虫架构。1、urllib库:包含从网络请求数据、处理cookie、改变请
- 前言项目开发中,产品经理提了这样一个需求:将系统中的附件实现批量打包下载功能。本来系统中是有单个下载及批量下载功能,现在应业务方的需求,需要
- 这里首先要介绍官方文档,对python有了进一步深度的学习的大家们应该会发现,网上不管csdn或者简书上还是什么地方,教程来源基本就是官方文
- ChineseCalendar 是一个 Python 包,用于获取中国传统日历信息。这个包提供了中国农历、二十四节气、传统节日、黄历等信息。
- 我就废话不多说了,大家还是直接看代码吧~import numpy as np#从scipy库中导入插值需要的方法 interpolatefr