python实现基于信息增益的决策树归纳

作者:conggova 时间:2022-05-20 14:22:47 

本文实例为大家分享了基于信息增益的决策树归纳的Python实现代码,供大家参考,具体内容如下


# -*- coding: utf-8 -*-
import numpy as np
import matplotlib.mlab as mlab
import matplotlib.pyplot as plt
from copy import copy

#加载训练数据
#文件格式:属性标号,是否连续【yes|no】,属性说明
attribute_file_dest = 'F:\\bayes_categorize\\attribute.dat'
attribute_file = open(attribute_file_dest)

#文件格式:rec_id,attr1_value,attr2_value,...,attrn_value,class_id
trainning_data_file_dest = 'F:\\bayes_categorize\\trainning_data.dat'
trainning_data_file = open(trainning_data_file_dest)

#文件格式:class_id,class_desc
class_desc_file_dest = 'F:\\bayes_categorize\\class_desc.dat'
class_desc_file = open(class_desc_file_dest)

root_attr_dict = {}
for line in attribute_file :
 line = line.strip()
 fld_list = line.split(',')
 root_attr_dict[int(fld_list[0])] = tuple(fld_list[1:])

class_dict = {}
for line in class_desc_file :
 line = line.strip()
 fld_list = line.split(',')
 class_dict[int(fld_list[0])] = fld_list[1]

trainning_data_dict = {}
class_member_set_dict = {}
for line in trainning_data_file :
 line = line.strip()
 fld_list = line.split(',')
 rec_id = int(fld_list[0])
 a1 = int(fld_list[1])
 a2 = int(fld_list[2])
 a3 = float(fld_list[3])
 c_id = int(fld_list[4])

if c_id not in class_member_set_dict :
   class_member_set_dict[c_id] = set()
 class_member_set_dict[c_id].add(rec_id)
 trainning_data_dict[rec_id] = (a1 , a2 , a3 , c_id)

attribute_file.close()
class_desc_file.close()
trainning_data_file.close()

class_possibility_dict = {}
for c_id in class_member_set_dict :
 class_possibility_dict[c_id] = (len(class_member_set_dict[c_id]) + 0.0)/len(trainning_data_dict)  

#等待分类的数据
data_to_classify_file_dest = 'F:\\bayes_categorize\\trainning_data_new.dat'
data_to_classify_file = open(data_to_classify_file_dest)
data_to_classify_dict = {}
for line in data_to_classify_file :
 line = line.strip()
 fld_list = line.split(',')
 rec_id = int(fld_list[0])
 a1 = int(fld_list[1])
 a2 = int(fld_list[2])
 a3 = float(fld_list[3])
 c_id = int(fld_list[4])
 data_to_classify_dict[rec_id] = (a1 , a2 , a3 , c_id)
data_to_classify_file.close()

'''
决策树的表达
结点的需求:
1、指示出是哪一种分区 一共3种 一是离散穷举 二是连续有分裂点 三是离散有判别集合 零是叶子结点
2、保存分类所需信息
3、子结点列表
每个结点用Tuple类型表示
元素一是整形,取值123 分别对应两种分裂类型
元素二是集合类型 对于1保存所有的离散值 对于2保存分裂点 对于3保存判别集合 对于0保存分类结果类标号
元素三是dict key对于1来说是某个的离散值 对于23来说只有12两种 对于2来说1代表小于等于分裂点
对于3来说1代表属于判别集合
'''

#对于一个成员列表,计算其熵
#公式为 Info_D = - sum(pi * log2 (pi)) pi为一个元素属于Ci的概率,用|Ci|/|D|计算 ,对所有分类求和
def get_entropy( member_list ) :
 #成员总数
 mem_cnt = len(member_list)
 #首先找出member中所包含的分类
 class_dict = {}
 for mem_id in member_list :
   c_id = trainning_data_dict[mem_id][3]
   if c_id not in class_dict :
     class_dict[c_id] = set()
   class_dict[c_id].add(mem_id)

tmp_sum = 0.0
 for c_id in class_dict :
   pi = ( len(class_dict[c_id]) + 0.0 ) / mem_cnt
   tmp_sum += pi * mlab.log2(pi)
 tmp_sum = -tmp_sum
 return tmp_sum

def attribute_selection_method( member_list , attribute_dict ) :
 #先计算原始的熵
 info_D = get_entropy(member_list)

max_info_Gain = 0.0
 attr_get = 0
 split_point = 0.0
 for attr_id in attribute_dict :
   #对于每一个属性计算划分后的熵
   #信息增益等于原始的熵减去划分后的熵
   info_D_new = 0
   #如果是连续属性
   if attribute_dict[attr_id][0] == 'yes' :
     #先得到memberlist中此属性的取值序列,把序列中每一对相邻项的中值作为划分点计算熵
     #找出其中最小的,作为此连续属性的划分点
     value_list = []
     for mem_id in member_list :
       value_list.append(trainning_data_dict[mem_id][attr_id - 1])

#获取相邻元素的中值序列
     mid_value_list = []
     value_list.sort()
     #print value_list
     last_value = None
     for value in value_list :
       if value == last_value :
         continue
       if last_value is not None :
         mid_value_list.append((last_value+value)/2)
       last_value = value
     #print mid_value_list
     #对于中值序列做循环
     #计算以此值做为划分点的熵
     #总的熵等于两个划分的熵乘以两个划分的比重
     min_info = 1000000000.0
     total_mens = len(member_list) + 0.0
     for mid_value in mid_value_list :
       #小于mid_value的mem
       less_list = []
       #大于
       more_list = []
       for tmp_mem_id in member_list :
         if trainning_data_dict[tmp_mem_id][attr_id - 1] <= mid_value :
           less_list.append(tmp_mem_id)
         else :
           more_list.append(tmp_mem_id)
       sum_info = len(less_list)/total_mens * get_entropy(less_list) \
       + len(more_list)/total_mens * get_entropy(more_list)

if sum_info < min_info :
         min_info = sum_info
         split_point = mid_value

info_D_new = min_info
   #如果是离散属性
   else :
     #计算划分后的熵
     #采用循环累加的方式
     attr_value_member_dict = {} #键为attribute value , 值为memberlist
     for tmp_mem_id in member_list :
       attr_value = trainning_data_dict[tmp_mem_id][attr_id - 1]
       if attr_value not in attr_value_member_dict :
         attr_value_member_dict[attr_value] = []
       attr_value_member_dict[attr_value].append(tmp_mem_id)
     #将每个离散值的熵乘以比重加到这上面
     total_mens = len(member_list) + 0.0
     sum_info = 0.0
     for a_value in attr_value_member_dict :
       sum_info += len(attr_value_member_dict[a_value])/total_mens \
       * get_entropy(attr_value_member_dict[a_value])

info_D_new = sum_info

info_Gain = info_D - info_D_new
   if info_Gain > max_info_Gain :
     max_info_Gain = info_Gain
     attr_get = attr_id

#如果是离散的
 #print 'attr_get ' + str(attr_get)
 if attribute_dict[attr_get][0] == 'no' :
   return (1 , attr_get , split_point)
 else :  
   return (2 , attr_get , split_point)
 #第三类先不考虑

def get_decision_tree(father_node , key , member_list , attr_dict ) :
 #最终的结果是新建一个结点,并且添加到father_node的sub_node_dict,对key为键
 #检查memberlist 如果都是同类的,则生成一个叶子结点,set里面保存类标号
 class_set = set()
 for mem_id in member_list :
   class_set.add(trainning_data_dict[mem_id][3])
 if len(class_set) == 1 :
   father_node[2][key] = (0 , (1 , class_set) , {} )
   return

#检查attribute_list,如果为空,产生叶子结点,类标号为memberlist中多数元素的类标号
 #如果几个类的成员等量,则打印提示,并且全部添加到set里面
 if not attr_dict :
   class_cnt_dict = {}
   for mem_id in member_list :
     c_id = trainning_data_dict[mem_id][3]
     if c_id not in class_cnt_dict :
       class_cnt_dict[c_id] = 1
     else :
       class_cnt_dict[c_id] += 1

class_set = set()
   max_cnt = 0
   for c_id in class_cnt_dict :
     if class_cnt_dict[c_id] > max_cnt :
       max_cnt = class_cnt_dict[c_id]
       class_set.clear()
       class_set.add(c_id)
     elif class_cnt_dict[c_id] == max_cnt :
       class_set.add(c_id)

if len(class_set) > 1 :
     print 'more than one class !'

father_node[2][key] = (0 , (1 , class_set ) , {} )
   return

#找出最好的分区方案 , 暂不考虑第三种划分方法
 #比较所有离散属性和所有连续属性的所有中值点划分的信息增益
 split_criterion = attribute_selection_method(member_list , attr_dict)
 #print split_criterion
 selected_plan_id = split_criterion[0]
 selected_attr_id = split_criterion[1]

#如果采用的是离散属性做为分区方案,删除这个属性
 new_attr_dict = copy(attr_dict)
 if attr_dict[selected_attr_id][0] == 'no' :
   del new_attr_dict[selected_attr_id]

#建立一个结点new_node,father_node[2][key] = new_node
 #然后对new node的每一个key , sub_member_list,
 #调用 get_decision_tree(new_node , new_key , sub_member_list , new_attribute_dict)
 #实现递归
 ele2 = ( selected_attr_id , set() )
 #如果是1 , ele2保存所有离散值
 if selected_plan_id == 1 :
   for mem_id in member_list :
     ele2[1].add(trainning_data_dict[mem_id][selected_attr_id - 1])
 #如果是2,ele2保存分裂点
 elif selected_plan_id == 2 :
   ele2[1].add(split_criterion[2])
 #如果是3则保存判别集合,先不管
 else :
   print 'not completed'
   pass

new_node = ( selected_plan_id , ele2 , {} )
 father_node[2][key] = new_node

#生成KEY,并递归调用
 if selected_plan_id == 1 :
   #每个attr_value是一个key
   attr_value_member_dict = {}
   for mem_id in member_list :
     attr_value = trainning_data_dict[mem_id][selected_attr_id - 1 ]
     if attr_value not in attr_value_member_dict :
       attr_value_member_dict[attr_value] = []
     attr_value_member_dict[attr_value].append(mem_id)
   for attr_value in attr_value_member_dict :
     get_decision_tree(new_node , attr_value , attr_value_member_dict[attr_value] , new_attr_dict)
   pass
 elif selected_plan_id == 2 :
   #key 只有12 , 小于等于分裂点的是1 , 大于的是2
   less_list = []
   more_list = []
   for mem_id in member_list :
     attr_value = trainning_data_dict[mem_id][selected_attr_id - 1 ]
     if attr_value <= split_criterion[2] :
       less_list.append(mem_id)
     else :
       more_list.append(mem_id)
   #if len(less_list) != 0 :
   get_decision_tree(new_node , 1 , less_list , new_attr_dict)
   #if len(more_list) != 0 :
   get_decision_tree(new_node , 2 , more_list , new_attr_dict)
   pass
 #如果是3则保存判别集合,先不管
 else :
   print 'not completed'
   pass

def get_class_sub(node , tp ) :
 #
 attr_id = node[1][0]
 plan_id = node[0]
 key = 0
 if plan_id == 0 :
   return node[1][1]
 elif plan_id == 1 :
   key = tp[attr_id - 1]
 elif plan_id == 2 :
   split_point = tuple(node[1][1])[0]
   attr_value = tp[attr_id - 1]
   if attr_value <= split_point :
     key = 1
   else :
     key = 2
 else :
   print 'error'
   return set()

return get_class_sub(node[2][key] , tp )

def get_class(r_node , tp) :
 #tp为一组属性值
 if r_node[0] != -1 :
   print 'error'
   return set()

if 1 in r_node[2] :
   return get_class_sub(r_node[2][1] , tp)
 else :
   print 'error'
   return set()

if __name__ == '__main__' :
 root_node = ( -1 , set() , {} )
 mem_list = trainning_data_dict.keys()
 get_decision_tree(root_node , 1 , mem_list , root_attr_dict )

#测试分类器的准确率
 diff_cnt = 0
 for mem_id in data_to_classify_dict :
   c_id = get_class(root_node , data_to_classify_dict[mem_id][0:3])
   if tuple(c_id)[0] != data_to_classify_dict[mem_id][3] :
     print tuple(c_id)[0]
     print data_to_classify_dict[mem_id][3]
     print 'different'
     diff_cnt += 1
 print diff_cnt

来源:https://blog.csdn.net/conggova/article/details/77528966

标签:python,信息增益,决策树
0
投稿

猜你喜欢

  • python 多线程实现多任务的方法示例

    2021-04-12 08:36:05
  • 详解Pytest测试用例的执行方法

    2022-02-15 18:28:14
  • IntelliJ IDEA卡死,如何优化内存

    2023-07-04 12:10:27
  • 实现一个获取元素样式的函数getStyle

    2009-02-10 10:37:00
  • python 协程 gevent原理与用法分析

    2021-10-12 23:36:19
  • 自然语言处理NLP TextRNN实现情感分类

    2022-01-20 11:14:47
  • Vue.js实现文章评论和回复评论功能

    2024-05-29 22:20:31
  • mysql递归函数with recursive的用法举例

    2024-01-16 22:37:22
  • 如何提取python字符串括号中的内容

    2021-01-11 01:29:03
  • 通过VB6将ASP编译封装成DLL组件最简教程 附全部工程源文件

    2012-11-30 20:20:50
  • JavaScript中clientWidth,offsetWidth,scrollWidth的区别

    2024-04-22 22:24:59
  • POST与GET方法的区别简要分析

    2022-06-26 17:27:36
  • 详解Python中的字符串格式化

    2023-09-10 22:38:14
  • MYSQL中binlog优化的一些思考汇总

    2024-01-23 01:58:25
  • Python使用Django实现博客系统完整版

    2021-02-10 14:43:48
  • 如何利用Pandas删除某列指定值所在的行

    2023-10-29 11:49:39
  • conda与jupyter notebook kernel核环境不一致的问题解决

    2021-07-03 15:43:02
  • MySQL优化方案之开启慢查询日志

    2024-01-23 09:30:23
  • PHP常用字符串操作函数实例总结(trim、nl2br、addcslashes、uudecode、md5等)

    2023-10-02 13:10:01
  • vue单向数据流的深入讲解

    2024-04-10 13:48:33
  • asp之家 网络编程 m.aspxhome.com