使用keras实现BiLSTM+CNN+CRF文字标记NER

作者:xinfeng2005 时间:2022-05-01 04:46:22 

我就废话不多说了,大家还是直接看代码吧~


import keras
from sklearn.model_selection import train_test_split
import tensorflow as tf
from keras.callbacks import ModelCheckpoint,Callback
# import keras.backend as K
from keras.layers import *
from keras.models import Model
from keras.optimizers import SGD, RMSprop, Adagrad,Adam
from keras.models import *
from keras.metrics import *
from keras import backend as K
from keras.regularizers import *
from keras.metrics import categorical_accuracy
# from keras.regularizers import activity_l1 #通过L1正则项,使得输出更加稀疏
from keras_contrib.layers import CRF

from visual_callbacks import AccLossPlotter
plotter = AccLossPlotter(graphs=['acc', 'loss'], save_graph=True, save_graph_path=sys.path[0])

# from crf import CRFLayer,create_custom_objects

class LossHistory(Callback):
 def on_train_begin(self, logs={}):
   self.losses = []

def on_batch_end(self, batch, logs={}):
   self.losses.append(logs.get('loss'))
# def on_epoch_end(self, epoch, logs=None):

word_input = Input(shape=(max_len,), dtype='int32', name='word_input')
word_emb = Embedding(len(char_value_dict)+2, output_dim=64, input_length=max_len, dropout=0.2, name='word_emb')(word_input)
bilstm = Bidirectional(LSTM(32, dropout_W=0.1, dropout_U=0.1, return_sequences=True))(word_emb)
bilstm_d = Dropout(0.1)(bilstm)
half_window_size = 2
paddinglayer = ZeroPadding1D(padding=half_window_size)(word_emb)
conv = Conv1D(nb_filter=50, filter_length=(2 * half_window_size + 1), border_mode='valid')(paddinglayer)
conv_d = Dropout(0.1)(conv)
dense_conv = TimeDistributed(Dense(50))(conv_d)
rnn_cnn_merge = merge([bilstm_d, dense_conv], mode='concat', concat_axis=2)
dense = TimeDistributed(Dense(class_label_count))(rnn_cnn_merge)
crf = CRF(class_label_count, sparse_target=False)
crf_output = crf(dense)
model = Model(input=[word_input], output=[crf_output])
model.compile(loss=crf.loss_function, optimizer='adam', metrics=[crf.accuracy])
model.summary()

# serialize model to JSON
model_json = model.to_json()
with open("model.json", "w") as json_file:
 json_file.write(model_json)

#编译模型
# model.compile(loss='categorical_crossentropy', optimizer=adam, metrics=['acc',])

# 用于保存验证集误差最小的参数,当验证集误差减少时,立马保存下来
checkpointer = ModelCheckpoint(filepath="bilstm_1102_k205_tf130.w", verbose=0, save_best_only=True, save_weights_only=True) #save_weights_only=True
history = LossHistory()

history = model.fit(x_train, y_train,
         batch_size=32, epochs=500,#validation_data = ([x_test, seq_lens_test], y_test),
         callbacks=[checkpointer, history, plotter],
         verbose=1,
         validation_split=0.1,
         )

补充知识:keras训练模型使用自定义CTC损失函数,重载模型时报错解决办法

使用keras训练模型,用到了ctc损失函数,需要自定义损失函数如下:

self.ctc_model.compile(loss={'ctc': lambda y_true, output: output}, optimizer=opt)

其中loss为自定义函数,使用字典{‘ctc': lambda y_true, output: output}

训练完模型后需要重载模型,如下:

from keras.models import load_model

model=load_model('final_ctc_model.h5')

报错:

Unknown loss function : <lambda>

由于是自定义的损失函数需要加参数custom_objects,这里需要定义字典{'': lambda y_true, output: output},正确代码如下:

model=load_model('final_ctc_model.h5',custom_objects={'<lambda>': lambda y_true, output: output})

可能是因为要将自己定义的loss函数加入到keras函数里

在这之前试了很多次,如果用lambda y_true, output: output定义loss

函数字典名只能是'<lambda>',不能是别的字符

如果自定义一个函数如loss_func作为loss函数如:

self.ctc_model.compile(loss=loss_func, optimizer=opt)

可以在重载时使用

am=load_model('final_ctc_model.h5',custom_objects={'loss_func': loss_func})

此时注意字典名和函数名要相同

来源:https://blog.csdn.net/xinfeng2005/article/details/78485748

标签:keras,BiLSTM,CNN,CRF,NER
0
投稿

猜你喜欢

  • python pandas dataframe 去重函数的具体使用

    2023-10-15 00:56:36
  • CSS控制字体效果的思考

    2011-06-14 09:44:02
  • 详解python日期时间处理

    2021-08-20 17:07:53
  • 如何获知文件被改动的情况?

    2009-11-24 20:42:00
  • 当视觉设计师遇上产品经理、开发工程师…[译]

    2010-01-17 10:18:00
  • mysql密码过期导致连接不上mysql

    2024-01-22 12:28:06
  • 支持多浏览器(IE、Firefox、Opera)剪切板复制函数_脚本之家修正版

    2024-05-03 15:08:06
  • MySQL主从复制问题总结及排查过程

    2024-01-15 07:05:44
  • Python-GUI wxPython之自动化数据生成器的项目实战

    2021-06-08 14:43:18
  • Python Numpy实现计算矩阵的均值和标准差详解

    2021-12-20 09:51:20
  • 浅谈Python数学建模之固定费用问题

    2022-09-08 10:26:41
  • pandas 获取季度,月度,年度首尾日期的方法

    2022-08-16 06:53:06
  • sqlserver 巧妙的自关联运用

    2012-07-21 14:55:12
  • python http服务flask架构实用代码详解分析

    2023-07-31 13:52:59
  • 基于python的docx模块处理word和WPS的docx格式文件方式

    2021-11-13 12:07:55
  • window.onload使用指南

    2024-04-18 10:58:51
  • Java获取网络文件并插入数据库的代码

    2024-01-23 19:35:10
  • SQL Server 2005如何设置多字段做关键字

    2009-01-08 15:57:00
  • pytorch中Schedule与warmup_steps的用法说明

    2023-07-07 00:18:14
  • 用Python写一个无界面的2048小游戏

    2022-02-12 11:18:23
  • asp之家 网络编程 m.aspxhome.com