python实现感知器算法(批处理)

作者:CommissarMa 时间:2022-09-06 18:02:25 

本文实例为大家分享了Python感知器算法实现的具体代码,供大家参考,具体内容如下

先创建感知器类:用于二分类


# -*- coding: utf-8 -*-

import numpy as np

class Perceptron(object):
 """
 感知器:用于二分类
 参照改写 https://blog.csdn.net/simple_the_best/article/details/54619495

属性:
 w0:偏差
 w:权向量
 learning_rate:学习率
 threshold:准则阈值
 """

def __init__(self,learning_rate=0.01,threshold=0.001):
   self.learning_rate=learning_rate
   self.threshold=threshold

def train(self,x,y):
   """训练
   参数:
   x:样本,维度为n*m(样本有m个特征,x输入就是m维),样本数量为n
   y:类标,维度为n*1,取值1和-1(正样本和负样本)

返回:
   self:object
   """
   self.w0=0.0
   self.w=np.full(x.shape[1],0.0)

k=0
   while(True):
     k+=1
     dJw0=0.0
     dJw=np.zeros(x.shape[1])
     err=0.0
     for i in range(0,x.shape[0]):
       if not (y[i]==1 or y[i]==-1):
         print("类标只能为1或-1!请核对!")
         break
       update=self.learning_rate*0.5*(y[i]-self.predict(x[i]))
       dJw0+=update
       dJw+=update*x[i]
       err+=np.abs(0.5*(y[i]-self.predict(x[i])))
     self.w0 += dJw0
     self.w += dJw
     if np.abs(np.sum(self.learning_rate*dJw))<self.threshold or k>500:
       print("迭代次数:",k," 错分样本数:",err)
       break
   return self

def predict(self,x):
   """预测类别
   参数:
   x:样本,1*m维,1个样本,m维特征

返回:
   yhat:预测的类标号,1或者-1,1代表正样本,-1代表负样本
   """
   if np.matmul(self.w,x.T)+self.w0>0:
     yhat=1
   else:
     yhat=-1
   return yhat

def predict_value(self,x):
   """预测值
   参数:
   x:样本,1*m维,1个样本,m维特征

返回:
   y:预测值
   """
   y=np.matmul(self.w,x.T)+self.w0
   return y

然后为Iris数据集创建一个Iris类,用于产生5折验证所需要的数据,并且能产生不同样本数量的数据集。


# -*- coding: utf-8 -*-
"""
Author:CommissarMa
2018年5月23日 16点52分
"""
import numpy as np
import scipy.io as sio

class Iris(object):
 """Iris数据集
 参数:
 data:根据size裁剪出来的iris数据集
 size:每种类型的样本数量
 way:one against the rest || one against one

注意:
 此处规定5折交叉验证(5-cv),所以每种类型样本的数量要是5的倍数
 多分类方式:one against the rest
 """

def __init__(self,size=50,way="one against the rest"):
   """
   size:每种类型的样本数量
   """
   data=sio.loadmat("C:\\Users\\CommissarMa\\Desktop\\模式识别\\课件ppt\\PR实验内容\\iris_data.mat")
   iris_data=data['iris_data']#iris_data:原数据集,shape:150*4,1-50个样本为第一类,51-100个样本为第二类,101-150个样本为第三类
   self.size=size
   self.way=way
   self.data=np.zeros((size*3,4))
   for r in range(0,size*3):
     self.data[r]=iris_data[int(r/size)*50+r%size]

def generate_train_data(self,index_fold,index_class,neg_class=None):
   """
   index_fold:5折验证的第几折,范围:0,1,2,3,4
   index_class:第几类作为正类,类别号:负类样本为-1,正类样本为1
   """
   if self.way=="one against the rest":
     fold_size=int(self.size/5)#将每类样本分成5份
     train_data=np.zeros((fold_size*4*3,4))
     label_data=np.full((fold_size*4*3),-1)
     for r in range(0,fold_size*4*3):
       n_class=int(r/(fold_size*4))#第几类
       n_fold=int((r%(fold_size*4))/fold_size)#第几折
       n=(r%(fold_size*4))%fold_size#第几个
       if n_fold<index_fold:
         train_data[r]=self.data[n_class*self.size+n_fold*fold_size+n]
       else:
         train_data[r]=self.data[n_class*self.size+(n_fold+1)*fold_size+n]

label_data[fold_size*4*index_class:fold_size*4*(index_class+1)]=1
   elif self.way=="one against one":
     if neg_class==None:
       print("one against one模式下需要提供负类的序号!")
       return
     else:
       fold_size=int(self.size/5)#将每类样本分成5份
       train_data=np.zeros((fold_size*4*2,4))
       label_data=np.full((fold_size*4*2),-1)
       for r in range(0,fold_size*4*2):
         n_class=int(r/(fold_size*4))#第几类
         n_fold=int((r%(fold_size*4))/fold_size)#第几折
         n=(r%(fold_size*4))%fold_size#第几个
         if n_class==0:#放正类样本
           if n_fold<index_fold:
             train_data[r]=self.data[index_class*self.size+n_fold*fold_size+n]
           else:
             train_data[r]=self.data[index_class*self.size+(n_fold+1)*fold_size+n]
         if n_class==1:#放负类样本
           if n_fold<index_fold:
             train_data[r]=self.data[neg_class*self.size+n_fold*fold_size+n]
           else:
             train_data[r]=self.data[neg_class*self.size+(n_fold+1)*fold_size+n]
       label_data[0:fold_size*4]=1
   else:
     print("多分类方式错误!只能为one against one 或 one against the rest!")
     return

return train_data,label_data

def generate_test_data(self,index_fold):
   """生成测试数据
   index_fold:5折验证的第几折,范围:0,1,2,3,4

返回值:
   test_data:对应于第index_fold折的测试数据
   label_data:类别号为0,1,2
   """
   fold_size=int(self.size/5)#将每类样本分成5份
   test_data=np.zeros((fold_size*3,4))
   label_data=np.zeros(fold_size*3)
   for r in range(0,fold_size*3):
     test_data[r]=self.data[int(int(r/fold_size)*self.size)+int(index_fold*fold_size)+r%fold_size]
   label_data[0:fold_size]=0
   label_data[fold_size:fold_size*2]=1
   label_data[fold_size*2:fold_size*3]=2

return test_data,label_data

然后我们进行训练测试,先使用one against the rest策略:


# -*- coding: utf-8 -*-

from perceptron import Perceptron
from iris_data import Iris
import numpy as np

if __name__=="__main__":
  iris=Iris(size=50,way="one against the rest")

correct_all=0
  for n_fold in range(0,5):
    p=[Perceptron(),Perceptron(),Perceptron()]
    for c in range(0,3):
      x,y=iris.generate_train_data(index_fold=n_fold,index_class=c)
      p[c].train(x,y)
    #训练完毕,开始测试
    correct=0
    x_test,y_test=iris.generate_test_data(index_fold=n_fold)
    num=len(x_test)
    for i in range(0,num):
      maxvalue=max(p[0].predict_value(x_test[i]),p[1].predict_value(x_test[i]),
         p[2].predict_value(x_test[i]))
      if maxvalue==p[int(y_test[i])].predict_value(x_test[i]):
        correct+=1
    print("错分数量:",num-correct,"错误率:",(num-correct)/num)
    correct_all+=correct
  print("平均错误率:",(num*5-correct_all)/(num*5))

然后使用one against one 策略去训练测试:


# -*- coding: utf-8 -*-

from perceptron import Perceptron
from iris_data import Iris
import numpy as np

if __name__=="__main__":
  iris=Iris(size=10,way="one against one")

correct_all=0
  for n_fold in range(0,5):
    #训练
    p01=Perceptron()#0类和1类比较的判别器
    p02=Perceptron()
    p12=Perceptron()
    x,y=iris.generate_train_data(index_fold=n_fold,index_class=0,neg_class=1)
    p01.train(x,y)
    x,y=iris.generate_train_data(index_fold=n_fold,index_class=0,neg_class=2)
    p02.train(x,y)
    x,y=iris.generate_train_data(index_fold=n_fold,index_class=1,neg_class=2)
    p12.train(x,y)
    #测试
    correct=0
    x_test,y_test=iris.generate_test_data(index_fold=n_fold)
    num=len(x_test)
    for i in range(0,num):
      vote0=0
      vote1=0
      vote2=0
      if p01.predict_value(x_test[i])>0:
        vote0+=1
      else:
        vote1+=1
      if p02.predict_value(x_test[i])>0:
        vote0+=1
      else:
        vote2+=1
      if p12.predict_value(x_test[i])>0:
        vote1+=1
      else:
        vote2+=1

if vote0==max(vote0,vote1,vote2) and int(vote0)==int(y_test[i]):
        correct+=1
      elif vote1==max(vote0,vote1,vote2) and int(vote1)==int(y_test[i]):
        correct+=1
      elif vote2==max(vote0,vote1,vote2) and int(vote2)==int(y_test[i]):
        correct+=1
    print("错分数量:",num-correct,"错误率:",(num-correct)/num)
    correct_all+=correct
  print("平均错误率:",(num*5-correct_all)/(num*5))

实验结果如图所示:

python实现感知器算法(批处理)

来源:https://blog.csdn.net/u012343179/article/details/80433068

标签:python,感知器
0
投稿

猜你喜欢

  • T-SQL查询为何慎用IN和NOT IN详解

    2024-01-21 08:26:01
  • vue中自定义指令(directive)的基本使用方法

    2024-05-28 15:46:32
  • Python列表去重复项的N种方法(实例代码)

    2023-06-27 16:00:20
  • 在python2.7中用numpy.reshape 对图像进行切割的方法

    2021-12-23 19:11:02
  • 改变 Python 中线程执行顺序的方法

    2022-01-14 16:11:10
  • 简单谈谈Python的pycurl模块

    2023-07-14 01:42:03
  • Oracle中pivot函数图文实例详解

    2023-07-12 22:13:49
  • Python3实现捕获Ctrl+C终止信号

    2021-10-17 14:03:55
  • Python办公自动化PPT批量转换操作

    2023-11-07 16:54:20
  • 使用Pandas将inf, nan转化成特定的值

    2023-04-15 23:36:33
  • Ceph集群CephFS文件存储核心概念及部署使用详解

    2023-04-18 02:23:31
  • 比较详细PHP生成静态页面教程

    2023-10-14 18:54:31
  • Python+Qt身体特征识别人数统计源码窗体程序(使用步骤)

    2021-06-03 10:40:54
  • 对python插入数据库和生成插入sql的示例讲解

    2022-03-10 05:46:40
  • numpy中hstack vstack stack concatenate函数示例详解

    2023-02-22 19:39:06
  • 浅谈品牌的视觉识别

    2009-07-03 12:28:00
  • 18个超棒的Web和移动应用开发框架

    2011-03-31 17:04:00
  • python爬虫爬取笔趣网小说网站过程图解

    2022-10-06 10:56:50
  • 情感的容器 被寄托了的QQ2010视觉设计

    2010-02-03 14:51:00
  • 详解Go多协程并发环境下的错误处理

    2024-02-15 21:11:59
  • asp之家 网络编程 m.aspxhome.com