sklearn+python:线性回归案例

作者:yuanlulu 时间:2023-10-19 20:07:01 

使用一阶线性方程预测波士顿房价

载入的数据是随sklearn一起发布的,来自boston 1993年之前收集的506个房屋的数据和价格。load_boston()用于载入数据。


from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
import time
from sklearn.linear_model import LinearRegression

boston = load_boston()

X = boston.data
y = boston.target

print("X.shape:{}. y.shape:{}".format(X.shape, y.shape))
print('boston.feature_name:{}'.format(boston.feature_names))

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=3)

model = LinearRegression()

start = time.clock()
model.fit(X_train, y_train)

train_score = model.score(X_train, y_train)
cv_score = model.score(X_test, y_test)

print('time used:{0:.6f}; train_score:{1:.6f}, sv_score:{2:.6f}'.format((time.clock()-start),
                                   train_score, cv_score))

输出内容为:


X.shape:(506, 13). y.shape:(506,)
boston.feature_name:['CRIM' 'ZN' 'INDUS' 'CHAS' 'NOX' 'RM' 'AGE' 'DIS' 'RAD' 'TAX' 'PTRATIO'
'B' 'LSTAT']
time used:0.012403; train_score:0.723941, sv_score:0.794958

可以看到测试集上准确率并不高,应该是欠拟合。

使用多项式做线性回归

上面的例子是欠拟合的,说明模型太简单,无法拟合数据的情况。现在增加模型复杂度,引入多项式。

打个比方,如果原来的特征是[a, b]两个特征,

在degree为2的情况下, 多项式特征变为[1, a, b, a^2, ab, b^2]。degree为其它值的情况依次类推。

多项式特征相当于增加了数据和模型的复杂性,能够更好的拟合。

下面的代码使用Pipeline把多项式特征和线性回归特征连起来,最终测试degree在1、2、3的情况下的得分。


from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
import time
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures
from sklearn.pipeline import Pipeline

def polynomial_model(degree=1):
 polynomial_features = PolynomialFeatures(degree=degree, include_bias=False)

linear_regression = LinearRegression(normalize=True)
 pipeline = Pipeline([('polynomial_features', polynomial_features),
            ('linear_regression', linear_regression)])
 return pipeline

boston = load_boston()
X = boston.data
y = boston.target
print("X.shape:{}. y.shape:{}".format(X.shape, y.shape))
print('boston.feature_name:{}'.format(boston.feature_names))

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=3)

for i in range(1,4):
 print( 'degree:{}'.format( i ) )
 model = polynomial_model(degree=i)

start = time.clock()
 model.fit(X_train, y_train)

train_score = model.score(X_train, y_train)
 cv_score = model.score(X_test, y_test)

print('time used:{0:.6f}; train_score:{1:.6f}, sv_score:{2:.6f}'.format((time.clock()-start),
                                   train_score, cv_score))

输出结果为:


X.shape:(506, 13). y.shape:(506,)
boston.feature_name:['CRIM' 'ZN' 'INDUS' 'CHAS' 'NOX' 'RM' 'AGE' 'DIS' 'RAD' 'TAX' 'PTRATIO'
'B' 'LSTAT']
degree:1
time used:0.003576; train_score:0.723941, sv_score:0.794958
degree:2
time used:0.030123; train_score:0.930547, sv_score:0.860465
degree:3
time used:0.137346; train_score:1.000000, sv_score:-104.429619

可以看到degree为1和上面不使用多项式是一样的。degree为3在训练集上的得分为1,在测试集上得分是负数,明显过拟合了。

所以最终应该选择degree为2的模型。

二阶多项式比一阶多项式好的多,但是测试集和训练集上的得分仍有不少差距,这可能是数据不够的原因,需要更多的讯据才能进一步提高模型的准确度。

正规方程解法和梯度下降的比较

除了梯度下降法来逼近最优解,也可以使用正规的方程解法直接计算出最终的解来。

根据吴恩达的课程,线性回归最优解为:

theta = (X^T * X)^-1 * X^T * y

其实两种方法各有优缺点:

梯度下降法:

缺点:需要选择学习率,需要多次迭代

优点:特征值很多(1万以上)时仍然能以不错的速度工作

正规方程解法:

优点:不需要设置学习率,不需要多次迭代

缺点:需要计算X的转置和逆,复杂度O3;特征值很多(1万以上)时特变慢

在分类等非线性计算中,正规方程解法并不适用,所以梯度下降法适用范围更广。

来源:https://blog.csdn.net/yuanlulu/article/details/81068027

标签:sklearn,python,线性回归
0
投稿

猜你喜欢

  • 用python实现一幅春联实例代码

    2021-07-23 09:25:41
  • ASP数据库编程SQL常用技巧

    2024-01-20 04:53:59
  • 自适应css布局——流动布局新时代[译]

    2009-08-13 12:28:00
  • golang如何通过viper读取config.yaml文件

    2023-07-22 05:46:11
  • 一文搞懂MySQL预编译

    2024-01-25 21:52:21
  • 解决mac使用homebrew安装MySQL无法登陆问题

    2024-01-27 06:22:24
  • JavaScript逆向分析instagram登入过程

    2023-09-08 19:51:52
  • 基于Tensorflow使用CPU而不用GPU问题的解决

    2022-01-01 22:53:08
  • Python使用read_csv读数据遇到分隔符问题的2种解决方式

    2022-01-13 13:30:47
  • 编写安全的SQL Server扩展存储过程

    2008-11-25 11:16:00
  • Javascript Closures (2)

    2009-03-18 12:22:00
  • 表格可读性提升分析

    2010-05-19 13:03:00
  • Python数据可视化详解

    2021-10-02 19:28:55
  • Spring 数据库连接池(JDBC)详解

    2024-01-22 19:00:36
  • python去除扩展名的实例讲解

    2022-05-08 18:10:49
  • 如何让页面在打开时自动刷新一次让图片全部显示

    2024-04-17 10:10:44
  • 深入mysql基础知识的详解

    2024-01-21 06:04:45
  • Python中pygal绘制雷达图代码分享

    2023-09-27 10:03:59
  • Python多线程编程之threading模块详解

    2023-12-28 07:52:59
  • Python用5行代码实现批量抠图的示例代码

    2021-04-16 23:56:05
  • asp之家 网络编程 m.aspxhome.com