python 还原梯度下降算法实现一维线性回归

作者:Mchael菜鸟 时间:2023-10-09 21:53:42 

首先我们看公式:

python 还原梯度下降算法实现一维线性回归

这个是要拟合的函数

然后我们求出它的损失函数, 注意:这里的n和m均为数据集的长度,写的时候忘了

python 还原梯度下降算法实现一维线性回归

注意,前面的theta0-theta1x是实际值,后面的y是期望值
接着我们求出损失函数的偏导数:

python 还原梯度下降算法实现一维线性回归

最终,梯度下降的算法:

python 还原梯度下降算法实现一维线性回归

学习率一般小于1,当损失函数是0时,我们输出theta0和theta1.
接下来上代码!


class LinearRegression():

def __init__(self, data, theta0, theta1, learning_rate):
   self.data = data
   self.theta0 = theta0
   self.theta1 = theta1
   self.learning_rate = learning_rate
   self.length = len(data)

# hypothesis
 def h_theta(self, x):
   return self.theta0 + self.theta1 * x

# cost function
 def J(self):
   temp = 0
   for i in range(self.length):
     temp += pow(self.h_theta(self.data[i][0]) - self.data[i][1], 2)
   return 1 / (2 * self.m) * temp

# partial derivative
 def pd_theta0_J(self):
   temp = 0
   for i in range(self.length):
     temp += self.h_theta(self.data[i][0]) - self.data[i][1]
   return 1 / self.m * temp

def pd_theta1_J(self):
   temp = 0
   for i in range(self.length):
     temp += (self.h_theta(data[i][0]) - self.data[i][1]) * self.data[i][0]
   return 1 / self.m * temp

# gradient descent
 def gd(self):
   min_cost = 0.00001
   round = 1
   max_round = 10000
   while min_cost < abs(self.J()) and round <= max_round:
     self.theta0 = self.theta0 - self.learning_rate * self.pd_theta0_J()
     self.theta1 = self.theta1 - self.learning_rate * self.pd_theta1_J()

print('round', round, ':\t theta0=%.16f' % self.theta0, '\t theta1=%.16f' % self.theta1)
     round += 1
   return self.theta0, self.theta1

def main():
data = [[1, 2], [2, 5], [4, 8], [5, 9], [8, 15]] # 这里换成你想拟合的数[x, y]
# plot scatter
 x = []
 y = []
 for i in range(len(data)):
   x.append(data[i][0])
   y.append(data[i][1])
 plt.scatter(x, y)

# gradient descent
 linear_regression = LinearRegression(data, theta0, theta1, learning_rate)
 theta0, theta1 = linear_regression.gd()

# plot returned linear
 x = np.arange(0, 10, 0.01)
 y = theta0 + theta1 * x
 plt.plot(x, y)
 plt.show()

来源:https://blog.csdn.net/weixin_46490003/article/details/109184418

标签:python,一维,线性回归
0
投稿

猜你喜欢

  • Centos6.x服务器配置jdk+tomcat+mysql环境(jsp+mysql)

    2023-06-14 12:14:13
  • js三维正方体(兼容ie/ff)

    2008-04-12 14:38:00
  • PHP依赖注入原理与用法分析

    2023-09-04 01:22:54
  • ExecuteReader(),ExecuteNonQuery(),ExecuteScalar(),ExecuteXmlReader()之间的区别

    2023-07-08 23:15:54
  • 安装SQL Server 2005时出现计数器错误

    2008-11-28 14:19:00
  • 讲解数据库管理系统必须提供的基本服务

    2009-01-04 14:33:00
  • Oracle数据库由dataguard备库引起的log file sync等待问题

    2023-07-17 07:35:25
  • asp网站生成静态页面攻略

    2007-11-04 15:09:00
  • 说说CSS的优先权 考虑CSS的继承与层叠

    2008-12-11 13:33:00
  • 小型分页的设计

    2011-08-18 18:32:26
  • 网站设计中的面包屑[译]

    2009-03-22 15:42:00
  • 让你知道codepage的重要,关于多语言编码

    2008-01-31 12:04:00
  • php中加密解密DES类的简单使用方法示例

    2023-09-07 23:28:44
  • 详解Go语言变量作用域

    2023-08-05 03:25:43
  • 有效地使用 SQL事件探查器的提示和技巧

    2009-01-15 13:39:00
  • 向外扩展SQL Server 实现更高扩展性

    2008-12-18 14:45:00
  • 两段不错的JS文字特效

    2007-09-27 12:52:00
  • MYSQL数据库常用命令集合

    2009-02-26 16:01:00
  • php获取通过http协议post提交过来xml数据及解析xml

    2023-11-14 15:43:36
  • ASP从数据库中获取下载文件

    2007-10-06 21:17:00
  • asp之家 网络编程 m.aspxhome.com