Python深度学习pyTorch权重衰减与L2范数正则化解析

作者:算法菜鸟飞高高 时间:2021-03-18 11:39:01 

Python深度学习pyTorch权重衰减与L2范数正则化解析

下面进行一个高维线性实验

假设我们的真实方程是:

Python深度学习pyTorch权重衰减与L2范数正则化解析

假设feature数200,训练样本和测试样本各20个

模拟数据集


num_train,num_test = 10,10
num_features = 200
true_w = torch.ones((num_features,1),dtype=torch.float32) * 0.01
true_b = torch.tensor(0.5)
samples = torch.normal(0,1,(num_train+num_test,num_features))
noise = torch.normal(0,0.01,(num_train+num_test,1))
labels = samples.matmul(true_w) + true_b + noise
train_samples, train_labels= samples[:num_train],labels[:num_train]
test_samples, test_labels = samples[num_train:],labels[num_train:]

定义带正则项的loss function


def loss_function(predict,label,w,lambd):
   loss = (predict - label) ** 2
   loss = loss.mean() + lambd * (w**2).mean()
   return loss

画图的方法


def semilogy(x_val,y_val,x_label,y_label,x2_val,y2_val,legend):
   plt.figure(figsize=(3,3))
   plt.xlabel(x_label)
   plt.ylabel(y_label)
   plt.semilogy(x_val,y_val)
   if x2_val and y2_val:
       plt.semilogy(x2_val,y2_val)
       plt.legend(legend)
   plt.show()

拟合和画图


def fit_and_plot(train_samples,train_labels,test_samples,test_labels,num_epoch,lambd):
   w = torch.normal(0,1,(train_samples.shape[-1],1),requires_grad=True)
   b = torch.tensor(0.,requires_grad=True)
   optimizer = torch.optim.Adam([w,b],lr=0.05)
   train_loss = []
   test_loss = []
   for epoch in range(num_epoch):
       predict = train_samples.matmul(w) + b
       epoch_train_loss = loss_function(predict,train_labels,w,lambd)
       optimizer.zero_grad()
       epoch_train_loss.backward()
       optimizer.step()
       test_predict = test_sapmles.matmul(w) + b
       epoch_test_loss = loss_function(test_predict,test_labels,w,lambd)
       train_loss.append(epoch_train_loss.item())
       test_loss.append(epoch_test_loss.item())
   semilogy(range(1,num_epoch+1),train_loss,'epoch','loss',range(1,num_epoch+1),test_loss,['train','test'])

Python深度学习pyTorch权重衰减与L2范数正则化解析
可以发现加了正则项的模型,在测试集上的loss确实下降了

以上就是Python深度学习pyTorch权重衰减与L2范数正则化解析的详细内容,更多关于Python pyTorch权重与L2范数正则化的资料请关注脚本之家其它相关文章!

来源:https://blog.csdn.net/qq_43152622/article/details/116937183

标签:深度学习,pyTorch,权重,范数正则化
0
投稿

猜你喜欢

  • SQL Server 在分页获取数据的同时获取到总记录数

    2024-01-24 09:04:13
  • MySQL 基本概念

    2011-09-10 16:22:34
  • php使用pthreads v3多线程实现抓取新浪新闻信息操作示例

    2023-10-12 19:21:46
  • javascript中不易分清的slice,splice和split三个函数

    2024-04-28 09:37:29
  • 简述MySQL主键和外键使用及说明

    2024-01-13 19:29:28
  • Python获取协程返回值的四种方式详解

    2023-10-03 15:13:21
  • Python+decimal完成精度计算的示例详解

    2022-09-18 11:07:35
  • mysql 8.0.12 winx64详细安装教程

    2024-01-26 12:37:19
  • 浅谈用Python实现一个大数据搜索引擎

    2022-05-11 19:15:52
  • pycharm console 打印中文为乱码问题及解决

    2023-06-15 22:30:02
  • 用Python将IP地址在整型和字符串之间轻松转换

    2021-03-31 16:37:16
  • 自定义Django Form中choicefield下拉菜单选取数据库内容实例

    2024-01-25 09:02:02
  • 对Python之gzip文件读写的方法详解

    2021-03-24 17:54:01
  • 使用Spring Boot实现操作数据库的接口的过程

    2024-01-25 02:02:49
  • [JS效果]动画效果打开/关闭/移动层

    2008-04-10 11:42:00
  • position:relative/absolute无法冲破的等级

    2007-05-11 17:03:00
  • 获取MSSQL数据字典的SQL语句

    2024-01-20 11:35:16
  • python免杀技术shellcode的加载与执行

    2021-10-27 16:25:06
  • 企业级使用LAMP源码安装教程

    2024-01-17 19:41:29
  • Go语言接口用法实例

    2024-02-04 22:27:30
  • asp之家 网络编程 m.aspxhome.com