Python线性网络实现分类糖尿病病例

作者:Henry_zs 时间:2022-03-13 11:23:25 

1. 加载数据集

这次我们搭建一个小小的多层线性网络对糖尿病的病例进行分类

首先先导入需要的库文件

Python线性网络实现分类糖尿病病例

先来看看我们的数据集

Python线性网络实现分类糖尿病病例

观察可以发现,前八列是我们的feature ,根据这八个特征可以判断出病人是否得了糖尿病。所以最后一列是1,0 的一个二分类问题

我们使用numpy 去导入数据集,delimiter 是定义分隔符,这里我们用逗号(,)分割

Python线性网络实现分类糖尿病病例

将前八列的特征放到我们的x_data里面,作为特征输入,最后一列放到y_data作为label

Tip :这里y_data 里面的 [-1] 中括号不可以省略,否则y_data会变成向量的形式

如果不习惯这种写法,可以用view改变一下形状就行

y_data = torch.from_numpy(xy[:,-1]).view(-1,1) #将y_data 的代码改成这样就可以了

下面是xy , x_data , y_data 打印出前两行的结果

Python线性网络实现分类糖尿病病例

Python线性网络实现分类糖尿病病例

2. 搭建网络+优化器

搭建网络的时候,要保证两层网络之间的维数能对应上

首先第一层的时候,因为前八列作为我们的x_data ,也就是说我们输入的特征是 8 维度的,那么由于 y = x * wT + b ,因为输入数据的x是(n * 8) 的,而我们定义的y维度是(n * 6) ,所以wT的维度应该是(8,6)

这里不需要知道啥时候转置,啥时候不转置之类的,只要满足线性的方程y = w*x+b,并且维度一致就行了。因为不管是转置,或者w和x谁在前,只是为了保证满足矩阵相乘而已

一个小的技巧就是:只需要看输入特征是多少,然后保证第一层第一个参数对应就行了,然后第一层第二个参数是想输出的维度。其次是第二层的第一个参数对应第一层第二个参数,以此类推....

Python线性网络实现分类糖尿病病例

我们采用的激活函数是ReLU , 由于是二元分类,最后一个网络的输出我们采用sigmoid输出

接下来,搭建实例化我们的网络,然后建立优化器

这里我们选择SGD随机梯度下降算法,学习率设置为0.01

Python线性网络实现分类糖尿病病例

3. 训练网络

Python线性网络实现分类糖尿病病例

训练网络的过程较为简单,大概的过程为

1. 计算预测值

2. 计算损失函数

3. 反向传播,之前要进行梯度清零

4. 梯度更新

5. 重复这个过程,epoch 为所有样本计算一次的周期,这次让epoch 迭代1000次

4. 代码

import torch.nn as nn    # 神经网络库
import matplotlib.pyplot as plt  # 绘图
import torch        # 张量
from torch import optim  # 优化器库
import numpy as np          # 数据处理
xy = np.loadtxt('./diabetes.csv.gz',delimiter=',',dtype=np.float32)    # 加载数据集
x_data = torch.from_numpy(xy[:,:-1])  # 所有行,除了最后一列的元素
y_data = torch.from_numpy(xy[:,-1]).view(-1,1) # -1也能拿出来是向量,但是[-1]会保证拿出来的是个矩阵
epoch_list =[]
loss_list = []
class Model(nn.Module):
   def __init__(self):
       super(Model,self).__init__()
       self.linear1 = nn.Linear(8,6)
       self.linear2 = nn.Linear(6,3)
       self.linear3 = nn.Linear(3,1)
       self.sigmoid = nn.Sigmoid()
       self.relu = nn.ReLU()
   def forward(self,x):
       x = self.relu(self.linear1(x))
       x = self.relu(self.linear2(x))
       x = self.sigmoid(self.linear3(x))
       return x
model = Model()
criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(),lr =0.01)
for epoch in range(1000):
   y_pred = model(x_data)
   loss = criterion(y_pred,y_data)   # 计算损失
   if epoch % 100 ==0:   # 每隔100次打印一下
       print(epoch,loss.item())
   #back propagation
   optimizer.zero_grad()    # 梯度清零
   loss.backward()          # 反向传播
   optimizer.step()         # 梯度更新
   epoch_list.append(epoch)
   loss_list.append(loss.item())
plt.plot(epoch_list,loss_list)
plt.show()

输出结果为:

Python线性网络实现分类糖尿病病例

Python线性网络实现分类糖尿病病例

来源:https://blog.csdn.net/qq_44886601/article/details/127347389

标签:Python,线性网络,分类
0
投稿

猜你喜欢

  • Microsoft Enterprise Library 5.0 如何集成MyS

    2011-03-16 15:19:00
  • javaScript合并对象的几个常见方式

    2024-04-16 08:58:26
  • mysql数据库mysql: [ERROR] unknown option '--skip-grant-tables'

    2024-01-18 08:41:27
  • asp.net mvc4 mysql制作简单分页组件(部分视图)

    2024-01-27 17:56:36
  • Go语言基础切片的创建及初始化示例详解

    2024-04-26 17:33:44
  • python虚拟环境virtualenv的使用教程

    2021-03-08 00:50:17
  • Python astype(np.float)函数使用方法解析

    2021-02-23 17:28:16
  • 一篇文章彻底搞懂Python中可迭代(Iterable)、迭代器(Iterator)与生成器(Generator)的概念

    2023-11-03 23:52:38
  • 深入学习python的yield和generator

    2022-01-15 05:00:28
  • php实现网站留言板功能

    2023-11-23 21:06:36
  • Python批量裁剪图片的思路详解

    2023-05-11 03:55:59
  • Git撤销已经推送(push)至远端仓库的提交(commit)信息操作

    2022-05-31 04:33:28
  • 如何将Yolov5的detect.py修改为可以直接调用的函数详解

    2021-12-12 22:21:28
  • python实现通讯录系统

    2023-06-12 20:57:50
  • 利用 Python 实现多任务进程

    2023-12-19 02:53:52
  • 详解如何在Linux(CentOS)下重置MySQL根(Root)密码

    2024-01-14 08:35:27
  • MySQL插入emoji表情失败问题的解决方法

    2024-01-24 05:26:05
  • 跟老齐学Python之私有函数和专有方法

    2021-04-13 20:38:18
  • 在python中利用dict转json按输入顺序输出内容方式

    2021-10-26 15:17:23
  • SQL语句练习实例之二——找出销售冠军

    2011-10-24 19:52:45
  • asp之家 网络编程 m.aspxhome.com