python的numpy模块实现逻辑回归模型

作者:上进的小菜鸟 时间:2022-10-01 07:05:59 

使用python的numpy模块实现逻辑回归模型的代码,供大家参考,具体内容如下

使用了numpy模块,pandas模块,matplotlib模块

1.初始化参数

def initial_para(nums_feature):
    """initial the weights and bias which is zero"""
    #nums_feature是输入数据的属性数目,因此权重w是[1, nums_feature]维
    #且w和b均初始化为0
    w = np.zeros((1, nums_feature))
    b = 0
    return w, b

2.逻辑回归方程

def activation(x, w , b):
    """a linear function and then sigmoid activation function: 
    x_ = w*x +b,y = 1/(1+exp(-x_))"""
    #线性方程,输入的x是[batch, 2]维,输出是[1, batch]维,batch是模型优化迭代一次输入数据的数目
    #[1, 2] * [2, batch] = [1, batch], 所以是w * x.T(x的转置)
    #np.dot是矩阵乘法
    x_ = np.dot(w, x.T) + b
    #np.exp是实现e的x次幂
    sigmoid = 1 / (1 + np.exp(-x_))
    return sigmoid

3.梯度下降

def gradient_descent_batch(x, w, b, label, learning_rate):
    #获取输入数据的数目,即batch大小
    n = len(label)
    #进行逻辑回归预测
    sigmoid = activation(x, w, b)
    #损失函数,np.sum是将矩阵求和
    cost = -np.sum(label.T * np.log(sigmoid) + (1-label).T * np.log(1-sigmoid)) / n
    #求对w和b的偏导(即梯度值)
    g_w = np.dot(x.T, (sigmoid - label.T).T) / n
    g_b = np.sum((sigmoid - label.T)) / n
    #根据梯度更新参数
    w = w - learning_rate * g_w.T
    b = b - learning_rate * g_b
    return w, b, cost

4.模型优化

def optimal_model_batch(x, label, nums_feature, step=10000, batch_size=1):
    """train the model with batch"""
    length = len(x)
    w, b = initial_para(nums_feature)
    for i in range(step):
        #随机获取一个batch数目的数据
        num = randint(0, length - 1 - batch_size)
        x_batch = x[num:(num+batch_size), :]
        label_batch = label[num:num+batch_size]
        #进行一次梯度更新(优化)
        w, b, cost = gradient_descent_batch(x_batch, w, b, label_batch, 0.0001)
        #每1000次打印一下损失值
        if i%1000 == 0:
            print('step is : ', i, ', cost is: ', cost)
    return w, b

5.读取数据,数据预处理,训练模型,评估精度

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from random import randint
from sklearn.preprocessing import StandardScaler
 
def _main():
    #读取csv格式的数据data_path是数据的路径
    data = pd.read_csv('data_path')
    #获取样本属性和标签
    x = data.iloc[:, 2:4].values
    y = data.iloc[:, 4].values
    #将数据集分为测试集和训练集
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size = 0.2, random_state=0)
    #数据预处理,去均值化
    standardscaler = StandardScaler()
    x_train = standardscaler.fit_transform(x_train)
    x_test = standardscaler.transform(x_test)
    #w, b = optimal_model(x_train, y_train, 2, 50000)
    #训练模型
    w, b = optimal_model_batch(x_train, y_train, 2, 50000, 64)
    print('trian is over')
    #对测试集进行预测,并计算精度
    predict = activation(x_test, w, b).T
    n = 0
    for i, p in enumerate(predict):
        if p >=0.5:
            if y_test[i] == 1:
                n += 1
        else:
            if y_test[i] == 0:
                n += 1
    print('accuracy is : ', n / len(y_test))

6.结果可视化

predict = np.reshape(np.int32(predict), [len(predict)])
    #将预测结果以散点图的形式可视化
    for i, j in enumerate(np.unique(predict)):
        plt.scatter(x_test[predict == j, 0], x_test[predict == j, 1], 
        c = ListedColormap(('red', 'blue'))(i), label=j)
    plt.show()

python的numpy模块实现逻辑回归模型

来源:https://blog.csdn.net/qq_35153620/article/details/95763896

标签:python,numpy,逻辑回归
0
投稿

猜你喜欢

  • Softmax函数原理及Python实现过程解析

    2022-12-15 02:18:24
  • python matplotlib折线图样式实现过程

    2022-10-28 12:18:08
  • MSSQL 基本语法及实例操作语句

    2012-07-11 15:40:09
  • 在cmd命令行里进入和退出Python程序的方法

    2023-07-18 04:21:14
  • python 移动图片到另外一个文件夹的实例

    2022-09-17 07:56:14
  • Sql Server查询性能优化之不可小觑的书签查找介绍

    2012-05-22 18:24:53
  • PHP5.6读写excel表格文件操作示例

    2023-11-21 15:03:21
  • 用python实现将数组元素按从小到大的顺序排列方法

    2022-01-07 22:03:25
  • SQL Server 2000安装图解教程

    2009-09-09 19:59:00
  • Python CSV模块使用实例

    2022-02-04 18:56:36
  • 安装MySQL的步骤和方法

    2009-07-30 08:38:00
  • Python如何优雅获取本机IP方法

    2021-03-07 15:46:16
  • Python抓包程序mitmproxy安装和使用过程图解

    2023-12-09 19:45:12
  • 常用的匹配正则表达式和实例

    2008-06-07 09:19:00
  • python3字符串操作总结

    2023-08-23 22:31:23
  • 分享216色网页拾色器(调色板)

    2007-09-27 12:33:00
  • javascript下兼容firefox选取textarea文本的代码

    2013-08-30 02:05:16
  • Python中import语句用法案例讲解

    2023-08-07 05:33:47
  • django 使用全局搜索功能的实例详解

    2023-01-26 05:56:56
  • thinkphp学习笔记之多表查询

    2023-11-15 02:57:15
  • asp之家 网络编程 m.aspxhome.com