TensorFlow实现简单线性回归

作者:kylinxjd 时间:2023-09-18 13:23:45 

本文实例为大家分享了TensorFlow实现简单线性回归的具体代码,供大家参考,具体内容如下

简单的一元线性回归

一元线性回归公式:

TensorFlow实现简单线性回归

其中x是特征:[x1,x2,x3,…,xn,]T
w是权重,b是偏置值

代码实现

导入必须的包

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os

# 屏蔽warning以下的日志信息
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

产生模拟数据

def generate_data():
    x = tf.constant(np.array([i for i in range(0, 100, 5)]).reshape(-1, 1), tf.float32)
    y = tf.add(tf.matmul(x, [[1.3]]) + 1, tf.random_normal([20, 1], stddev=30))
    return x, y

x是100行1列的数据,tf.matmul是矩阵相乘,所以权值设置成二维的。
设置的w是1.3, b是1

实现回归

def myregression():
    """
    自实现线性回归
    :return:
    """
    x, y = generate_data()
    #     建立模型  y = x * w + b
    # w 1x1的二维数据
    w = tf.Variable(tf.random_normal([1, 1], mean=0.0, stddev=1.0), name='weight_a')
    b = tf.Variable(0.0, name='bias_b')

    y_predict = tf.matmul(x, a) + b

    # 建立损失函数
    loss = tf.reduce_mean(tf.square(y_predict - y))
    
    # 训练
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss=loss)

    # 初始化全局变量
    init_op = tf.global_variables_initializer()

  
    with tf.Session() as sess:
        sess.run(init_op)
        print('初始的权重:%f偏置值:%f' % (a.eval(), b.eval()))
    
        # 训练优化
        for i in range(1, 100):
            sess.run(train_op)
            print('第%d次优化的权重:%f偏置值:%f' % (i, a.eval(), b.eval()))
        # 显示回归效果
        show_img(x.eval(), y.eval(), y_predict.eval())

使用matplotlib查看回归效果

def show_img(x, y, y_pre):
    plt.scatter(x, y)
    plt.plot(x, y_pre)
    plt.show()

完整代码

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

def generate_data():
    x = tf.constant(np.array([i for i in range(0, 100, 5)]).reshape(-1, 1), tf.float32)
    y = tf.add(tf.matmul(x, [[1.3]]) + 1, tf.random_normal([20, 1], stddev=30))
    return x, y

def myregression():
    """
    自实现线性回归
    :return:
    """
    x, y = generate_data()
    # 建立模型  y = x * w + b
    w = tf.Variable(tf.random_normal([1, 1], mean=0.0, stddev=1.0), name='weight_a')
    b = tf.Variable(0.0, name='bias_b')

    y_predict = tf.matmul(x, w) + b

    # 建立损失函数
    loss = tf.reduce_mean(tf.square(y_predict - y))
    # 训练
    train_op = tf.train.GradientDescentOptimizer(0.0001).minimize(loss=loss)

    init_op = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init_op)
        print('初始的权重:%f偏置值:%f' % (w.eval(), b.eval()))
        # 训练优化
        for i in range(1, 35000):
            sess.run(train_op)
            print('第%d次优化的权重:%f偏置值:%f' % (i, w.eval(), b.eval()))
        show_img(x.eval(), y.eval(), y_predict.eval())

def show_img(x, y, y_pre):
    plt.scatter(x, y)
    plt.plot(x, y_pre)
    plt.show()

if __name__ == '__main__':
    myregression()

看看训练的结果(因为数据是随机产生的,每次的训练结果都会不同,可适当调节梯度下降的学习率和训练步数)

TensorFlow实现简单线性回归

35000次的训练结果

TensorFlow实现简单线性回归

来源:https://blog.csdn.net/kylinxjd/article/details/105557304

标签:TensorFlow,线性回归
0
投稿

猜你喜欢

  • 利用Python函数实现一个万历表完整示例

    2022-06-30 18:06:22
  • python去掉字符串中重复字符的方法

    2022-11-23 09:17:35
  • Python 作为小程序后端的三种实现方法(推荐)

    2023-03-30 09:26:05
  • Excel VBA连接并操作Oracle

    2009-08-08 22:58:00
  • Python 类,property属性(简化属性的操作),@property,property()用法示例

    2022-01-04 19:21:53
  • python爬虫开发之使用python爬虫库requests,urllib与今日头条搜索功能爬取搜索内容实例

    2022-01-05 19:39:44
  • vue 过滤、模糊查询及计算属性 computed详解

    2024-05-09 09:53:30
  • MSSQL数据库的定期自动备份计划。

    2024-01-27 11:01:04
  • Go语言转换所有字符串为大写或者小写的方法

    2023-06-21 19:48:07
  • Python调用Matplotlib绘制振动图、箱型图和提琴图

    2022-02-08 05:56:09
  • oracle合并列的函数wm_concat的使用详解

    2024-01-25 20:54:19
  • 一文搞懂 parseInt()函数异常行为

    2024-04-30 08:57:11
  • Windows下安装Django框架的方法简明教程

    2021-06-26 20:26:13
  • 显示ASP页面源码的代码

    2008-10-12 13:05:00
  • 利用Python读取Excel表内容的详细过程

    2022-10-24 05:43:33
  • MySQL外键约束的禁用与启用命令

    2024-01-27 00:45:04
  • ASPJPEG组件使用详解(缩略图+水印)

    2007-09-19 17:31:00
  • Django如何实现防止XSS攻击

    2022-04-13 10:52:39
  • 百度UEditor编辑器使用教程与使用方法(图文)

    2023-03-31 14:07:53
  • PHP入门教程之会话控制技巧(cookie与session)

    2023-11-16 00:13:39
  • asp之家 网络编程 m.aspxhome.com