感知器基础原理及python实现过程详解

作者:沙克的世界 时间:2023-11-07 16:24:35 

简单版本,按照李航《统计学习方法》中的思路编写。

感知器基础原理及python实现过程详解

数据采用了sklearn自带的著名iris数据集,最优化求解采用了SGD(随机梯度下降)算法。

预处理增加了标准化操作。


'''
perceptron classifier

created on 2019.9.14
author: vince
'''
import pandas
import numpy
import logging
import matplotlib.pyplot as plt

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

class Perceptron:
    """Binary perceptron classifier (primal form, as in Li Hang's book).

    Attributes
    ----------
    w : 1d numpy array
        Weights after training; ``w[0]`` is the bias and ``w[1:]`` are the
        per-feature weights.
    l : 1d numpy array of length ``iter_num``
        Number of misclassified training samples counted at the start of
        each iteration. Entries for iterations skipped after convergence
        stay 0.
    """

    def __init__(self, eta=0.01, iter_num=50, batch_size=1):
        """
        eta: float = learning rate (between 0.0 and 1.0).
        iter_num: int = maximum number of iterations over the training dataset.
        batch_size: int = intended gradient-descent batch mode:
            1 -> SGD, 0 -> BGD, otherwise MBGD.
            NOTE: this value is stored but not read by ``train``, which
            always updates on a single misclassified sample per iteration.
        """
        self.eta = eta
        self.iter_num = iter_num
        self.batch_size = batch_size

    def train(self, X, Y):
        """Fit ``self.w`` to the training data.

        X: array-like, shape=[n_samples, n_features] = training vectors,
            where n_samples is the number of training samples and
            n_features is the number of features.
        Y: array-like, shape=[n_samples] = target labels. They must be -1 or
            +1, the only values ``activation`` produces; otherwise a sample
            can never count as correctly classified.

        Each iteration first counts the misclassified samples into
        ``self.l``. It then applies the perceptron update
        ``w <- w + eta * y_i * [1, x_i]`` using the first misclassified
        sample. Training stops early once no sample is misclassified,
        because later iterations could not change ``w``.
        """
        X = numpy.asarray(X, dtype=float)
        Y = numpy.asarray(Y).ravel()
        self.w = numpy.zeros(1 + X.shape[1])
        self.l = numpy.zeros(self.iter_num)
        debug_enabled = logging.getLogger().isEnabledFor(logging.DEBUG)
        for iter_index in range(self.iter_num):
            # The weights are fixed during this scan, so predicting every row at
            # once yields exactly the set a per-sample loop would find.
            misclassified = numpy.flatnonzero(self.activation(X) != Y)
            self.l[iter_index] = misclassified.size
            if debug_enabled:
                for sample_index in misclassified:
                    logging.debug("%s: pred(%s), label(%s), %s", sample_index,
                                  self.net_input(X[sample_index]),
                                  Y[sample_index], X[sample_index])
            if misclassified.size:
                first = misclassified[0]
                self.w[0] += self.eta * Y[first]
                self.w[1:] += self.eta * Y[first] * X[first]
            # Log the whole weight vector so any number of features works.
            logging.info("iter %s: %s, %s", iter_index, self.w, self.l[iter_index])
            if not misclassified.size:
                break

    def activation(self, x):
        """Return +1 where ``net_input(x) >= 0`` and -1 elsewhere.

        Accepts a single sample or a 2-D array of samples (one per row).
        """
        return numpy.where(self.net_input(x) >= 0.0, 1, -1)

    def net_input(self, x):
        """Return the linear score ``x . w[1:] + w[0]`` (row-wise for 2-D x)."""
        return numpy.dot(x, self.w[1:]) + self.w[0]

    def predict(self, x):
        """Predict the -1/+1 class label(s) of ``x``; alias of ``activation``."""
        return self.activation(x)

def _standardize(train, test):
    """Scale columns to zero mean and unit variance using training-set statistics.

    The mean and std come from ``train`` only, so no information from the
    test split leaks into preprocessing.
    """
    mean = train.mean(axis=0)
    std = train.std(axis=0)
    return (train - mean) / std, (test - mean) / std


def _plot_decision_boundary(p, features, labels, xlabel, ylabel):
    """Scatter 2-D points and draw the learned line w0 + w1*x + w2*y = 0."""
    x_min, x_max = features[:, 0].min() - 1, features[:, 0].max() + 1
    y_min, y_max = features[:, 1].min() - 1, features[:, 1].max() + 1
    plt.xlim(x_min, x_max)
    plt.ylim(y_min, y_max)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)

    plt.scatter(features[:, 0], features[:, 1], c=labels, marker='o', s=10)

    w0, w1, w2 = p.w
    if w2 != 0:
        k = -w1 / w2
        d = -w0 / w2
        plt.plot([x_min, x_max], [k * x_min + d, k * x_max + d], "go-")
    elif w1 != 0:
        # w2 == 0 means a vertical boundary: w0 + w1*x = 0.
        x0 = -w0 / w1
        plt.plot([x0, x0], [y_min, y_max], "go-")

    plt.show()


def main():
    """Train a perceptron on two iris classes, then report and plot the result.

    Uses setosa (label -1) vs. versicolor (label +1) with feature columns
    0 and 2. The data is split 2/3 : 1/3 into train/test, and
    standardization is fitted on the training split only.
    """
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
                        datefmt='%a, %d %b %Y %H:%M:%S')

    iris = load_iris()

    # Rows 0-99 of the iris data hold exactly the two classes with
    # target 0 (setosa) and target 1 (versicolor).
    feature_columns = [0, 2]
    features = iris.data[:100, feature_columns]
    labels = numpy.where(iris.target[:100] == 0, -1, 1)

    # 2/3 data for training, 1/3 data for testing
    train_raw, test_raw, train_labels, test_labels = train_test_split(
        features, labels, test_size=0.33, random_state=23323)

    # normalization (fitted on the training split to avoid data leakage)
    train_features, test_features = _standardize(train_raw, test_raw)

    logging.info("train set shape:%s", train_features.shape)

    p = Perceptron()
    p.train(train_features, train_labels)

    # predict() works row-wise on a 2-D array, so no per-sample append loop is needed.
    test_predict = p.predict(test_features)

    score = accuracy_score(test_labels, test_predict)
    logging.info("The accuracy score is: %s ", score)

    _plot_decision_boundary(p, train_features, train_labels,
                            iris.feature_names[feature_columns[0]] + " (standardized)",
                            iris.feature_names[feature_columns[1]] + " (standardized)")

if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()

感知器基础原理及python实现过程详解

来源:https://www.cnblogs.com/thsss/p/11519846.html

标签:感知器,原理,python,实现
0
投稿

猜你喜欢

  • JavaScript ES6中const、let与var的对比详解

    2024-05-22 10:37:36
  • Go语言单元测试模拟服务请求和接口返回

    2024-04-23 09:41:13
  • python爬取招聘要求等信息实例

    2021-01-27 21:22:36
  • axios拦截器工作方式及原理源码解析

    2023-07-02 16:38:36
  • mysql中取字符串中的数字的语句

    2024-01-15 02:16:15
  • selenium+opencv实现滑块验证码的登陆

    2022-03-28 06:49:04
  • 浅谈Python中函数的定义及其调用方法

    2022-09-01 09:35:35
  • Python画图小案例之小雪人超详细源码注释

    2021-09-21 11:49:44
  • Spring数据库事务的实现机制讲解

    2024-01-19 11:32:10
  • sql模式设置引起的问题解决办法

    2024-01-17 03:38:16
  • Python装饰器基础概念与用法详解

    2021-07-07 12:32:46
  • python+opencv识别图片中的圆形

    2022-02-10 00:04:23
  • openCV实现图像融合的示例代码

    2022-05-20 03:28:16
  • python-序列解包(对可迭代元素的快速取值方法)

    2023-12-28 23:23:57
  • Python range与enumerate函数区别解析

    2022-03-05 21:40:20
  • 网页设计五原则

    2007-11-03 13:50:00
  • 对Python函数设计规范详解

    2023-08-02 15:59:17
  • 如何定义记录集打开的游标类型?

    2009-11-15 20:19:00
  • bootstrap为水平排列的表单和内联表单设置可选的图标

    2024-05-05 09:16:06
  • asp如何让页面背景五彩缤纷?

    2010-05-13 16:38:00
  • asp之家 网络编程 m.aspxhome.com