Python PaddlePaddle机器学习之求解线性模型

作者:ZacheryZHANG??????? 时间:2023-04-19 08:35:14 

前言

飞桨(PaddlePaddle)是集深度学习核心框架、工具组件和服务平台为一体的技术先进、功能完备的开源深度学习平台

1. 任务描述

  • 乘坐出租车的时候,会有一个10元的起步价,只要上车就需要收取该起步价。

  • 出租车每行驶1公里,需要再支付2元的行驶费用(2元/公里)

  • 当一个乘客做完出租车之后,车上的计价器需要算出来该乘客需要支付的乘车费用。

如果以数学模型的角度可以很容易的解除该题的线性关系,及 Y=2x+10Y=2x+10,其中YY 为最终所需费用,xx 为行驶公里数。

试想,我们用机器学习的方法进行训练是不是也可以解决该问题呢,让机器来给我们推算出 YY 与 xx 的关系。即:知道乘客乘坐公里数和支付费用,但是并不知道每公里行驶费和起步价。

2. 代码演练

首先,我们以数学模型建立关系式,定义计价收费函数。该函数用来生成机器学习的数据集。定义好函数以后,接下来,我们传入6个数据(x),该函数可以计算出对应的Y值(也就是机器学习训练用到的真实值)。

def calculate_fee(distance_travelled):
   return 10+ 2*distance_travelled
for x in [1.0, 3.0, 5.0, 9.0, 10.0, 20.0]:
   print(calculate_fee(x))

接下来开始搭建线性回归。

2.1 数组转张量

将输入数据与输出结果数组转为张量:

import paddle
import numpy
x_data = paddle.to_tensor([[1.0], [3.0], [5.0], [9.0], [10.0], [20.0]])
y_data = paddle.to_tensor([[12.0],[16.0],[20.0],[28.0],[30.0],[50.0]])
linear = paddle.nn.Linear(in_features=1,out_features=1)

# 随机初始化w,b
w_before_opt = linear.weight.numpy().item()
b_before_opt = linear.bias.numpy().item()
# 打印初始w,b
print(w_before_opt,b_before_opt)

mse_loss = paddle.nn.MSELoss()
sgd_optimizer = paddle.optimizer.SGD(learning_rate=0.001, parameters=linear.parameters())

total_epoch = 5000

for i in range(total_epoch):
   y_predict = linear(x_data)
   loss = mse_loss(y_predict,y_data)

# 反向传播(求梯度)
   loss.backward()
   # 优化器往前走一步:求出的梯度给优化器用调参
   sgd_optimizer.step()
   # 优化器把调完参数所用的梯度去清掉,下次再去求
   sgd_optimizer.clear_gradients()

# 打印信息
   if i % 1000 == 0:
       print(i,loss.numpy())
print("finish training, loss = {}".format(loss.numpy()) )

w_after_opt = linear.weight.numpy().item()
b_after_opt = linear.bias.numpy().item()
print(w_after_opt,b_after_opt)

来源:https://juejin.cn/post/7129528216800198664

标签:Python,PaddlePaddle,线性,模型
0
投稿

猜你喜欢

  • python获取一组数据里最大值max函数用法实例

    2022-01-28 00:02:44
  • 网站防止采集方法全攻略

    2007-09-05 19:57:00
  • Python中如何实现MOOC扫码登录

    2021-10-08 14:24:50
  • MYSQL教程:数据列类型与查询效率

    2009-02-27 15:37:00
  • ASP中双引号单引号和&连接符使用技巧

    2007-10-01 18:20:00
  • Python实现网络端口转发和重定向的方法

    2023-09-23 10:19:59
  • django formset实现数据表的批量操作的示例代码

    2023-10-10 15:20:21
  • python数据预处理 :样本分布不均的解决(过采样和欠采样)

    2023-08-10 07:03:14
  • 用SQL语句添加删除修改字段、一些表与字段的基本操作、数据库备份等

    2024-01-26 16:53:22
  • Js 按照MVC模式制作自定义控件

    2008-10-12 12:11:00
  • Python识别二维码的两种方法详解

    2022-08-20 23:44:12
  • Python3实时操作处理日志文件的实现

    2022-09-01 21:21:16
  • SQL Server中的SQL语句优化与效率问题

    2024-01-20 05:26:57
  • 使用keras框架cnn+ctc_loss识别不定长字符图片操作

    2022-05-13 22:15:42
  • Pandas数据结构详细说明及如何创建Series,DataFrame对象方法

    2021-03-14 12:13:35
  • Python中集合的内建函数和内建方法学习教程

    2023-11-03 04:11:27
  • HTTP服务压力测试工具及相关术语讲解

    2024-05-08 10:23:31
  • Python日志:自定义输出字段 json格式输出方式

    2022-08-20 01:27:19
  • 每个程序员都应该学习使用Python或Ruby

    2023-09-05 06:03:52
  • Kettle下载与安装保姆级教程(最新)

    2023-07-29 17:10:41
  • asp之家 网络编程 m.aspxhome.com