python 还原梯度下降算法实现一维线性回归
作者:Mchael菜鸟 时间:2023-10-09 21:53:42
首先我们看公式:
这个是要拟合的函数
然后我们求出它的损失函数, 注意:这里的n和m均为数据集的长度,写的时候忘了
注意,前面的theta0-theta1x是实际值,后面的y是期望值
接着我们求出损失函数的偏导数:
最终,梯度下降的算法:
学习率一般小于1,当损失函数是0时,我们输出theta0和theta1.
接下来上代码!
class LinearRegression():
def __init__(self, data, theta0, theta1, learning_rate):
self.data = data
self.theta0 = theta0
self.theta1 = theta1
self.learning_rate = learning_rate
self.length = len(data)
# hypothesis
def h_theta(self, x):
return self.theta0 + self.theta1 * x
# cost function
def J(self):
temp = 0
for i in range(self.length):
temp += pow(self.h_theta(self.data[i][0]) - self.data[i][1], 2)
return 1 / (2 * self.m) * temp
# partial derivative
def pd_theta0_J(self):
temp = 0
for i in range(self.length):
temp += self.h_theta(self.data[i][0]) - self.data[i][1]
return 1 / self.m * temp
def pd_theta1_J(self):
temp = 0
for i in range(self.length):
temp += (self.h_theta(data[i][0]) - self.data[i][1]) * self.data[i][0]
return 1 / self.m * temp
# gradient descent
def gd(self):
min_cost = 0.00001
round = 1
max_round = 10000
while min_cost < abs(self.J()) and round <= max_round:
self.theta0 = self.theta0 - self.learning_rate * self.pd_theta0_J()
self.theta1 = self.theta1 - self.learning_rate * self.pd_theta1_J()
print('round', round, ':\t theta0=%.16f' % self.theta0, '\t theta1=%.16f' % self.theta1)
round += 1
return self.theta0, self.theta1
def main():
data = [[1, 2], [2, 5], [4, 8], [5, 9], [8, 15]] # 这里换成你想拟合的数[x, y]
# plot scatter
x = []
y = []
for i in range(len(data)):
x.append(data[i][0])
y.append(data[i][1])
plt.scatter(x, y)
# gradient descent
linear_regression = LinearRegression(data, theta0, theta1, learning_rate)
theta0, theta1 = linear_regression.gd()
# plot returned linear
x = np.arange(0, 10, 0.01)
y = theta0 + theta1 * x
plt.plot(x, y)
plt.show()
来源:https://blog.csdn.net/weixin_46490003/article/details/109184418
标签:python,一维,线性回归
0
投稿
猜你喜欢
Centos6.x服务器配置jdk+tomcat+mysql环境(jsp+mysql)
2023-06-14 12:14:13
js三维正方体(兼容ie/ff)
2008-04-12 14:38:00
PHP依赖注入原理与用法分析
2023-09-04 01:22:54
ExecuteReader(),ExecuteNonQuery(),ExecuteScalar(),ExecuteXmlReader()之间的区别
2023-07-08 23:15:54
安装SQL Server 2005时出现计数器错误
2008-11-28 14:19:00
讲解数据库管理系统必须提供的基本服务
2009-01-04 14:33:00
Oracle数据库由dataguard备库引起的log file sync等待问题
2023-07-17 07:35:25
asp网站生成静态页面攻略
2007-11-04 15:09:00
说说CSS的优先权 考虑CSS的继承与层叠
2008-12-11 13:33:00
小型分页的设计
2011-08-18 18:32:26
网站设计中的面包屑[译]
2009-03-22 15:42:00
让你知道codepage的重要,关于多语言编码
2008-01-31 12:04:00
php中加密解密DES类的简单使用方法示例
2023-09-07 23:28:44
详解Go语言变量作用域
2023-08-05 03:25:43
有效地使用 SQL事件探查器的提示和技巧
2009-01-15 13:39:00
向外扩展SQL Server 实现更高扩展性
2008-12-18 14:45:00
两段不错的JS文字特效
2007-09-27 12:52:00
MYSQL数据库常用命令集合
2009-02-26 16:01:00
php获取通过http协议post提交过来xml数据及解析xml
2023-11-14 15:43:36
ASP从数据库中获取下载文件
2007-10-06 21:17:00