PyTorch实现用CNN和LSTM对文本进行分类的方式

作者:Alphapeople 时间:2023-07-16 18:05:13 

model.py:


#!/usr/bin/python
# -*- coding: utf-8 -*-

import torch
from torch import nn
import numpy as np
from torch.autograd import Variable
import torch.nn.functional as F

class TextRNN(nn.Module):
    """Text classifier: embedding -> 2-layer bidirectional GRU -> MLP head.

    The vocabulary size (5000), embedding width (64), hidden size (128) and
    number of classes (10) are fixed by the layer definitions below.
    """

    def __init__(self):
        super(TextRNN, self).__init__()
        # Token-id -> dense vector lookup.
        self.embedding = nn.Embedding(5000, 64)
        # batch_first=True because forward() receives (batch, seq_len) ids;
        # without it the GRU would treat the batch axis as the time axis.
        # nn.LSTM accepts the same arguments if an LSTM is preferred.
        self.rnn = nn.GRU(input_size=64, hidden_size=128, num_layers=2,
                          bidirectional=True, batch_first=True)
        # 256 = 2 directions * 128 hidden units.
        self.f1 = nn.Sequential(nn.Linear(256, 128),
                                nn.Dropout(0.8),
                                nn.ReLU())
        # Explicit dim: normalise over the class axis of (batch, 10).
        self.f2 = nn.Sequential(nn.Linear(128, 10),
                                nn.Softmax(dim=1))

    def forward(self, x):
        """Return class probabilities of shape (batch, 10).

        Args:
            x: integer tensor of token ids shaped (batch, seq_len).
        """
        x = self.embedding(x)
        x, _ = self.rnn(x)
        # Functional dropout does not follow train()/eval() on its own, so
        # the module's mode is passed explicitly.
        x = F.dropout(x, p=0.8, training=self.training)
        # Classify from the GRU output at the last time step.
        x = self.f1(x[:, -1, :])
        return self.f2(x)

class TextCNN(nn.Module):
    """Text classifier: embedding -> 1-D convolution -> MLP head.

    Args:
        seq_len: length of every input sequence. The default of 600 gives
            the 256 * 596 flattened features used by the original layout.
    """

    def __init__(self, seq_len=600):
        super(TextCNN, self).__init__()
        self.embedding = nn.Embedding(5000, 64)
        self.conv = nn.Conv1d(64, 256, 5)
        # A width-5 kernel without padding shortens the sequence by 4.
        flat_features = 256 * (seq_len - 5 + 1)
        self.f1 = nn.Sequential(nn.Linear(flat_features, 128),
                                nn.ReLU())
        # Explicit dim: normalise over the class axis of (batch, 10).
        self.f2 = nn.Sequential(nn.Linear(128, 10),
                                nn.Softmax(dim=1))

    def forward(self, x):
        """Return class probabilities of shape (batch, 10).

        Args:
            x: integer tensor of token ids shaped (batch, seq_len).
        """
        x = self.embedding(x)
        # Conv1d wants (batch, channels, length). Transposing the tensor
        # directly keeps the autograd graph, so the embedding still receives
        # gradients (a detach()/numpy round-trip would cut them off).
        x = x.transpose(1, 2)
        x = self.conv(x)
        # Flatten per sample; a wrong seq_len then fails loudly in f1.
        x = x.reshape(x.size(0), -1)
        x = self.f1(x)
        return self.f2(x)

train.py:


# coding: utf-8

from __future__ import print_function
import torch
from torch import nn
from torch import optim
from torch.autograd import Variable
import os

import numpy as np

from model import TextRNN,TextCNN
from cnews_loader import read_vocab, read_category, batch_iter, process_file, build_vocab

# Dataset directory; every split file is named cnews.<part>.txt inside it.
base_dir = 'cnews'
train_dir, test_dir, val_dir, vocab_dir = (
    os.path.join(base_dir, 'cnews.%s.txt' % part)
    for part in ('train', 'test', 'val', 'vocab')
)

def _evaluate(model, x_val, y_val, batch_size=100):
    """Return the fraction of validation samples the model classifies correctly.

    The model is put in eval mode and run under ``torch.no_grad()`` so that
    evaluation neither applies dropout nor updates any weights. The previous
    train/eval mode is restored before returning.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x_batch, y_batch in batch_iter(x_val, y_val, batch_size):
            x = torch.LongTensor(np.array(x_batch))
            y = torch.Tensor(np.array(y_batch))
            hits = torch.argmax(model(x), 1) == torch.argmax(y, 1)
            correct += int(hits.sum().item())
            total += int(hits.numel())
    model.train(was_training)
    return correct / float(total) if total else 0.0


def train():
    """Train the classifier and checkpoint the best-validating weights.

    Reads the module-level ``word_to_id`` and ``cat_to_id`` mappings (set in
    the ``__main__`` block) plus ``train_dir`` / ``val_dir``. Every 20 epochs
    the whole validation split is scored, the accuracy is printed, and the
    weights are saved to ``model_params.pkl`` when the accuracy improves.
    """
    # Character ids per sample and the matching one-hot labels.
    x_train, y_train = process_file(train_dir, word_to_id, cat_to_id, 600)
    x_val, y_val = process_file(val_dir, word_to_id, cat_to_id, 600)
    # Swap in TextCNN() to train the convolutional model instead.
    model = TextRNN()
    # NOTE(review): the models end in Softmax while this loss expects raw
    # scores; kept as-is so the model interface does not change -- confirm.
    criterion = nn.MultiLabelSoftMarginLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    best_val_acc = 0
    for epoch in range(1000):
        model.train()
        for x_batch, y_batch in batch_iter(x_train, y_train, 100):
            x = torch.LongTensor(np.array(x_batch))
            y = torch.Tensor(np.array(y_batch))
            out = model(x)
            loss = criterion(out, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Validate on the validation split only; no weight updates happen here.
        if (epoch + 1) % 20 == 0:
            accuracy = _evaluate(model, x_val, y_val, 100)
            if accuracy > best_val_acc:
                torch.save(model.state_dict(), 'model_params.pkl')
                best_val_acc = accuracy
            print(accuracy)

if __name__ == '__main__':
    # Vocabulary: every character kept from the training text and its id.
    words, word_to_id = read_vocab(vocab_dir)
    # Category names and the id assigned to each of them.
    categories, cat_to_id = read_category()
    # Number of entries in the vocabulary.
    vocab_size = len(words)
    train()

test.py:


# coding: utf-8

from __future__ import print_function

import os
import tensorflow.contrib.keras as kr
import torch
from torch import nn
from cnews_loader import read_category, read_vocab
from model import TextRNN
from torch.autograd import Variable
import numpy as np
# Python 3 has no ``unicode`` builtin; alias it to ``str`` so the same
# name works under both interpreter generations.
try:
    unicode
except NameError:
    unicode = str

# Dataset directory and the vocabulary file inside it.
base_dir = 'cnews'
vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt')

class TextCNN(nn.Module):
    """Text classifier: embedding -> 1-D convolution -> MLP head.

    Args:
        seq_len: length of every input sequence. The default of 600 gives
            the 152576 (= 256 * 596) flattened features of the original
            layout, so existing checkpoints still load.
    """

    def __init__(self, seq_len=600):
        super(TextCNN, self).__init__()
        self.embedding = nn.Embedding(5000, 64)
        self.conv = nn.Conv1d(64, 256, 5)
        # A width-5 kernel without padding shortens the sequence by 4.
        flat_features = 256 * (seq_len - 5 + 1)
        self.f1 = nn.Sequential(nn.Linear(flat_features, 128),
                                nn.ReLU())
        # Explicit dim: normalise over the class axis of (batch, 10).
        self.f2 = nn.Sequential(nn.Linear(128, 10),
                                nn.Softmax(dim=1))

    def forward(self, x):
        """Return class probabilities of shape (batch, 10).

        Args:
            x: integer tensor of token ids shaped (batch, seq_len).
        """
        x = self.embedding(x)
        # Conv1d wants (batch, channels, length); transpose in torch rather
        # than via a detach()/numpy round-trip.
        x = x.transpose(1, 2)
        x = self.conv(x)
        # Flatten per sample; a wrong seq_len then fails loudly in f1.
        x = x.reshape(x.size(0), -1)
        x = self.f1(x)
        return self.f2(x)

class CnnModel:
    """Inference wrapper around a TextCNN loaded from ``model_params.pkl``."""

    def __init__(self):
        self.categories, self.cat_to_id = read_category()
        self.words, self.word_to_id = read_vocab(vocab_dir)
        self.model = TextCNN()
        self.model.load_state_dict(torch.load('model_params.pkl'))
        # This wrapper only predicts, so keep the network in eval mode.
        self.model.eval()

    def predict(self, message):
        """Return the category name predicted for the text ``message``.

        Characters missing from the vocabulary are dropped, and the id
        sequence is padded/truncated to 600 entries before being scored.
        """
        # ``unicode`` keeps the text handling identical on Python 2 and 3.
        content = unicode(message)
        data = [self.word_to_id[x] for x in content if x in self.word_to_id]
        data = kr.preprocessing.sequence.pad_sequences([data], 600)
        data = torch.LongTensor(data)
        # No gradients are needed for inference.
        with torch.no_grad():
            y_pred_cls = self.model(data)
        class_index = torch.argmax(y_pred_cls[0]).item()
        return self.categories[class_index]

class RnnModel:
    """Inference wrapper around a TextRNN loaded from ``model_rnn_params.pkl``."""

    def __init__(self):
        self.categories, self.cat_to_id = read_category()
        self.words, self.word_to_id = read_vocab(vocab_dir)
        self.model = TextRNN()
        self.model.load_state_dict(torch.load('model_rnn_params.pkl'))
        # This wrapper only predicts; eval mode turns the model's Dropout
        # layers off so predictions are not randomly perturbed.
        self.model.eval()

    def predict(self, message):
        """Return the category name predicted for the text ``message``.

        Characters missing from the vocabulary are dropped, and the id
        sequence is padded/truncated to 600 entries before being scored.
        """
        # ``unicode`` keeps the text handling identical on Python 2 and 3.
        content = unicode(message)
        data = [self.word_to_id[x] for x in content if x in self.word_to_id]
        data = kr.preprocessing.sequence.pad_sequences([data], 600)
        data = torch.LongTensor(data)
        # No gradients are needed for inference.
        with torch.no_grad():
            y_pred_cls = self.model(data)
        class_index = torch.argmax(y_pred_cls[0]).item()
        return self.categories[class_index]

if __name__ == '__main__':
 # Demo entry point: build one predictor and classify two sample articles.
 model = CnnModel()
 # model = RnnModel()
 # Two raw news texts used as demo inputs.
 test_demo = ['湖人助教力助科比恢复手感 他也是阿泰的精神导师新浪体育讯记者戴高乐报道 上赛季,科比的右手食指遭遇重创,他的投篮手感也因此大受影响。不过很快科比就调整了自己的投篮手型,并通过这一方式让自己的投篮命中率回升。而在这科比背后,有一位特别助教对科比帮助很大,他就是查克·珀森。珀森上赛季担任湖人的特别助教,除了帮助科比调整投篮手型之外,他的另一个重要任务就是担任阿泰的精神导师。来到湖人队之后,阿泰收敛起了暴躁的脾气,成为湖人夺冠路上不可或缺的一员,珀森的“心灵按摩”功不可没。经历了上赛季的成功之后,珀森本赛季被“升职”成为湖人队的全职助教,每场比赛,他都会坐在球场边,帮助禅师杰克逊一起指挥湖人球员在场上拼杀。对于珀森的工作,禅师非常欣赏,“查克非常善于分析问题,”菲尔·杰克逊说,“他总是在寻找问题的答案,同时也在找造成这一问题的原因,这是我们都非常乐于看到的。我会在平时把防守中出现的一些问题交给他,然后他会通过组织球员练习找到解决的办法。他在球员时代曾是一名很好的外线投手,不过现在他与内线球员的配合也相当不错。',
        '弗老大被裁美国媒体看热闹“特权”在中国像蠢蛋弗老大要走了。虽然他只在首钢男篮效力了13天,而且表现毫无亮点,大大地让球迷和俱乐部失望了,但就像中国人常说的“好聚好散”,队友还是友好地与他告别,俱乐部与他和平分手,球迷还请他留下了在北京的最后一次签名。相比之下,弗老大的同胞美国人却没那么“宽容”。他们嘲讽这位NBA前巨星的英雄迟暮,批评他在CBA的业余表现,还惊讶于中国人的“大方”。今天,北京首钢俱乐部将与弗朗西斯继续商讨解约一事。从昨日的进展来看,双方可以做到“买卖不成人意在”,但回到美国后,恐怕等待弗朗西斯的就没有这么轻松的环境了。进展@北京昨日与队友告别 最后一次为球迷签名弗朗西斯在13天里为首钢队打了4场比赛,3场的得分为0,只有一场得了2分。昨天是他来到北京的第14天,虽然他与首钢还未正式解约,但双方都明白“缘分已尽”。下午,弗朗西斯来到首钢俱乐部与队友们告别。弗朗西斯走到队友身边,依次与他们握手拥抱。“你们都对我很好,安排的条件也很好,我很喜欢这支球队,想融入你们,但我现在真的很不适应。希望你们']
 # Print each text followed by the category name the model predicts for it.
 for i in test_demo:
   print(i,":",model.predict(i))

cnews_loader.py:


# coding: utf-8

import sys
from collections import Counter

import numpy as np
import tensorflow.contrib.keras as kr

# Record the interpreter generation once; the helpers below branch on it.
is_py3 = sys.version_info[0] > 2
if not is_py3:
    # Python 2 only: make utf-8 the default encoding for implicit conversions.
    reload(sys)
    sys.setdefaultencoding("utf-8")

def native_word(word, encoding='utf-8'):
    """Return ``word`` unchanged on Python 3, encoded with ``encoding`` on Python 2.

    Intended for running a model trained under Python 3 from Python 2.
    """
    return word if is_py3 else word.encode(encoding)

def native_content(content):
    """Return ``content`` unchanged on Python 3, utf-8 decoded on Python 2."""
    return content if is_py3 else content.decode('utf-8')

def open_file(filename, mode='r'):
    """Open ``filename`` in a way that works on both Python 2 and Python 3.

    On Python 3 the file is opened as utf-8 text with undecodable bytes
    ignored; on Python 2 the plain builtin ``open`` is used.

    mode: 'r' or 'w' for read or write
    """
    if not is_py3:
        return open(filename, mode)
    return open(filename, mode, encoding='utf-8', errors='ignore')

def read_file(filename):
    """Read a ``label<TAB>content`` file into parallel lists.

    Returns:
        (contents, labels): ``contents[i]`` is the list of characters of the
        i-th usable line and ``labels[i]`` is its label.

    Lines that do not split into exactly two tab-separated fields, have an
    empty content field, or cannot be decoded are skipped.
    """
    contents, labels = [], []
    with open_file(filename) as f:
        for line in f:
            try:
                label, content = line.strip().split('\t')
                if not content:
                    continue
                # Convert both fields before appending so a failure on either
                # one can never leave the two lists with different lengths.
                chars = list(native_content(content))
                category = native_content(label)
            except ValueError:
                # Wrong field count or undecodable text (UnicodeDecodeError
                # is a ValueError): skip the line, let other errors surface.
                continue
            contents.append(chars)
            labels.append(category)
    return contents, labels

def build_vocab(train_dir, vocab_dir, vocab_size=5000):
    """Build the vocabulary from the training file and write it to ``vocab_dir``.

    The output has one entry per line: ``<PAD>`` first, then the
    ``vocab_size - 1`` most frequent characters of the training text.
    """
    data_train, _ = read_file(train_dir)

    counter = Counter()
    for content in data_train:
        counter.update(content)

    count_pairs = counter.most_common(vocab_size - 1)
    # <PAD> is used to pad every text to the same length. Building the list
    # with a comprehension also copes with an empty training file.
    words = ['<PAD>'] + [word for word, _ in count_pairs]
    # Context manager so the vocabulary file is flushed and closed.
    with open_file(vocab_dir, mode='w') as f:
        f.write('\n'.join(words) + '\n')

def read_vocab(vocab_dir):
    """Load the vocabulary file and return ``(words, word_to_id)``.

    ``words`` holds one stripped entry per line of the file and
    ``word_to_id`` maps each entry to its line index.
    """
    with open_file(vocab_dir) as fp:
        # On Python 2 every entry is converted to unicode by native_content.
        words = [native_content(line.strip()) for line in fp]
    word_to_id = {word: idx for idx, word in enumerate(words)}
    return words, word_to_id

def read_category():
    """Return the fixed category list and its name -> id mapping.

    Returns:
        (categories, cat_to_id): the ten category names in a fixed order and
        a dict mapping each name to its index in that list.
    """
    categories = ['体育', '财经', '房产', '家居', '教育', '科技', '时尚', '时政', '游戏', '娱乐']

    categories = [native_content(x) for x in categories]

    cat_to_id = dict(zip(categories, range(len(categories))))

    return categories, cat_to_id

def to_words(content, words):
    """Turn a sequence of ids back into text by looking each id up in ``words``."""
    pieces = [words[idx] for idx in content]
    return ''.join(pieces)

def process_file(filename, word_to_id, cat_to_id, max_length=600):
    """Convert a labelled text file into padded id sequences and one-hot labels.

    Returns:
        (x_pad, y_pad): the id sequences padded to ``max_length`` and the
        labels encoded with ``len(cat_to_id)`` classes.
    """
    # Each entry of contents is one text, labels holds its category.
    contents, labels = read_file(filename)
    # Map every text to ids, dropping characters not in the vocabulary.
    data_id = [
        [word_to_id[ch] for ch in text if ch in word_to_id]
        for text in contents
    ]
    # Category id of every text.
    label_id = [cat_to_id[label] for label in labels]

    # Pad the texts to a fixed length with keras' pad_sequences.
    x_pad = kr.preprocessing.sequence.pad_sequences(data_id, max_length)
    # One-hot encode the labels.
    y_pad = kr.utils.to_categorical(label_id, num_classes=len(cat_to_id))
    return x_pad, y_pad

def batch_iter(x, y, batch_size=64):
    """Yield shuffled ``(x_batch, y_batch)`` pairs covering the data once.

    Args:
        x, y: arrays of equal length that support integer-array indexing.
        batch_size: maximum number of samples per batch; the last batch may
            be smaller.

    Empty input yields no batches.
    """
    data_len = len(x)

    # One shared permutation keeps every sample aligned with its label.
    indices = np.random.permutation(np.arange(data_len))
    x_shuffle = x[indices]
    y_shuffle = y[indices]

    # Stepping by batch_size avoids the float-division batch count, which
    # produced one spurious empty batch for empty input on Python 3.
    for start_id in range(0, data_len, batch_size):
        end_id = min(start_id + batch_size, data_len)
        yield x_shuffle[start_id:end_id], y_shuffle[start_id:end_id]

来源:https://blog.csdn.net/weixin_38241876/article/details/90606639

标签:pytorch,CNN,LSTM,分类
0
投稿

猜你喜欢

  • CSS Sprites对CSS布局的意义、优点和缺点介绍

    2008-07-14 07:22:00
  • ASP 获取文件扩展名函数getFileExt()

    2011-03-11 11:18:00
  • 页面制作的重要性

    2007-10-30 13:14:00
  • Python学习之魔法函数(filter,map,reduce)详解

    2023-03-25 05:32:21
  • Javascript Math对象

    2024-05-03 15:59:39
  • 分享:在存储过程中使用另一个存储过程返回的查询结果集的方法

    2024-01-16 13:03:57
  • Linux下编译安装MySQL-Python教程

    2021-05-03 05:05:40
  • pandas DataFrame运算的实现

    2021-06-02 21:08:22
  • Pycharm设置界面全黑的方法

    2021-09-15 11:13:51
  • Python实现查看系统启动项功能示例

    2022-12-27 17:03:14
  • Windows安装MySQL8.0.28.0.msi方式(图文详解)

    2024-01-24 14:55:24
  • 详解Golang如何实现支持随机删除元素的堆

    2024-02-22 20:04:53
  • 如何优化网站图片以快速显示

    2008-04-05 10:09:00
  • mysql insert语句操作实例讲解

    2024-01-15 12:12:24
  • 使用Python读取大文件的方法

    2022-02-18 00:43:35
  • 利用python做数据拟合详情

    2023-04-22 15:32:17
  • mysql二进制日志文件恢复数据库

    2024-01-16 10:55:05
  • Python 多进程并发操作中进程池Pool的实例

    2022-06-28 16:31:37
  • Python文字截图识别OCR工具实例解析

    2021-07-02 12:58:20
  • python将二维数组升为一维数组或二维降为一维方法实例

    2023-07-25 07:51:59
  • asp之家 网络编程 m.aspxhome.com