Python PaddlePaddle机器学习之求解线性模型
作者:ZacheryZHANG??????? 时间:2023-04-19 08:35:14
前言
飞桨(PaddlePaddle)是集深度学习核心框架、工具组件和服务平台为一体的技术先进、功能完备的开源深度学习平台
1. 任务描述
乘坐出租车的时候,会有一个10元的起步价,只要上车就需要收取该起步价。
出租车每行驶1公里,需要再支付2元的行驶费用(2元/公里)
当一个乘客做完出租车之后,车上的计价器需要算出来该乘客需要支付的乘车费用。
如果以数学模型的角度可以很容易的解除该题的线性关系,及 Y=2x+10Y=2x+10,其中YY 为最终所需费用,xx 为行驶公里数。
试想,我们用机器学习的方法进行训练是不是也可以解决该问题呢,让机器来给我们推算出 YY 与 xx 的关系。即:知道乘客乘坐公里数和支付费用,但是并不知道每公里行驶费和起步价。
2. 代码演练
首先,我们以数学模型建立关系式,定义计价收费函数。该函数用来生成机器学习的数据集。定义好函数以后,接下来,我们传入6个数据(x),该函数可以计算出对应的Y值(也就是机器学习训练用到的真实值)。
def calculate_fee(distance_travelled):
return 10+ 2*distance_travelled
for x in [1.0, 3.0, 5.0, 9.0, 10.0, 20.0]:
print(calculate_fee(x))
接下来开始搭建线性回归。
2.1 数组转张量
将输入数据与输出结果数组转为张量:
import paddle
import numpy
x_data = paddle.to_tensor([[1.0], [3.0], [5.0], [9.0], [10.0], [20.0]])
y_data = paddle.to_tensor([[12.0],[16.0],[20.0],[28.0],[30.0],[50.0]])
linear = paddle.nn.Linear(in_features=1,out_features=1)
# 随机初始化w,b
w_before_opt = linear.weight.numpy().item()
b_before_opt = linear.bias.numpy().item()
# 打印初始w,b
print(w_before_opt,b_before_opt)
mse_loss = paddle.nn.MSELoss()
sgd_optimizer = paddle.optimizer.SGD(learning_rate=0.001, parameters=linear.parameters())
total_epoch = 5000
for i in range(total_epoch):
y_predict = linear(x_data)
loss = mse_loss(y_predict,y_data)
# 反向传播(求梯度)
loss.backward()
# 优化器往前走一步:求出的梯度给优化器用调参
sgd_optimizer.step()
# 优化器把调完参数所用的梯度去清掉,下次再去求
sgd_optimizer.clear_gradients()
# 打印信息
if i % 1000 == 0:
print(i,loss.numpy())
print("finish training, loss = {}".format(loss.numpy()) )
w_after_opt = linear.weight.numpy().item()
b_after_opt = linear.bias.numpy().item()
print(w_after_opt,b_after_opt)
来源:https://juejin.cn/post/7129528216800198664
标签:Python,PaddlePaddle,线性,模型
0
投稿
猜你喜欢
python获取一组数据里最大值max函数用法实例
2022-01-28 00:02:44
网站防止采集方法全攻略
2007-09-05 19:57:00
Python中如何实现MOOC扫码登录
2021-10-08 14:24:50
MYSQL教程:数据列类型与查询效率
2009-02-27 15:37:00
ASP中双引号单引号和&连接符使用技巧
2007-10-01 18:20:00
Python实现网络端口转发和重定向的方法
2023-09-23 10:19:59
django formset实现数据表的批量操作的示例代码
2023-10-10 15:20:21
python数据预处理 :样本分布不均的解决(过采样和欠采样)
2023-08-10 07:03:14
用SQL语句添加删除修改字段、一些表与字段的基本操作、数据库备份等
2024-01-26 16:53:22
Js 按照MVC模式制作自定义控件
2008-10-12 12:11:00
Python识别二维码的两种方法详解
2022-08-20 23:44:12
Python3实时操作处理日志文件的实现
2022-09-01 21:21:16
SQL Server中的SQL语句优化与效率问题
2024-01-20 05:26:57
使用keras框架cnn+ctc_loss识别不定长字符图片操作
2022-05-13 22:15:42
Pandas数据结构详细说明及如何创建Series,DataFrame对象方法
2021-03-14 12:13:35
Python中集合的内建函数和内建方法学习教程
2023-11-03 04:11:27
HTTP服务压力测试工具及相关术语讲解
2024-05-08 10:23:31
Python日志:自定义输出字段 json格式输出方式
2022-08-20 01:27:19
每个程序员都应该学习使用Python或Ruby
2023-09-05 06:03:52
Kettle下载与安装保姆级教程(最新)
2023-07-29 17:10:41