keras用auc做metrics以及早停实例

作者:ssswill 时间:2022-04-19 03:55:12 

我就废话不多说了,大家还是直接看代码吧~


import tensorflow as tf
from sklearn.metrics import roc_auc_score

def auroc(y_true, y_pred):
return tf.py_func(roc_auc_score, (y_true, y_pred), tf.double)
# Build Model...

model.compile(loss='categorical_crossentropy', optimizer='adam',metrics=['accuracy', auroc])

完整例子:


def auc(y_true, y_pred):
auc = tf.metrics.auc(y_true, y_pred)[1]
K.get_session().run(tf.local_variables_initializer())
return auc

def create_model_nn(in_dim,layer_size=200):
model = Sequential()
model.add(Dense(layer_size,input_dim=in_dim, kernel_initializer='normal'))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(Dropout(0.3))
for i in range(2):
 model.add(Dense(layer_size))
 model.add(BatchNormalization())
 model.add(Activation('relu'))
 model.add(Dropout(0.3))
model.add(Dense(1, activation='sigmoid'))
adam = optimizers.Adam(lr=0.01)
model.compile(optimizer=adam,loss='binary_crossentropy',metrics = [auc])
return model
####cv train
folds = StratifiedKFold(n_splits=5, shuffle=False, random_state=15)
oof = np.zeros(len(df_train))
predictions = np.zeros(len(df_test))
for fold_, (trn_idx, val_idx) in enumerate(folds.split(df_train.values, target2.values)):
print("fold n°{}".format(fold_))
X_train = df_train.iloc[trn_idx][features]
y_train = target2.iloc[trn_idx]
X_valid = df_train.iloc[val_idx][features]
y_valid = target2.iloc[val_idx]
model_nn = create_model_nn(X_train.shape[1])
callback = EarlyStopping(monitor="val_auc", patience=50, verbose=0, mode='max')
history = model_nn.fit(X_train, y_train, validation_data = (X_valid ,y_valid),epochs=1000,batch_size=64,verbose=0,callbacks=[callback])
print('\n Validation Max score : {}'.format(np.max(history.history['val_auc'])))
predictions += model_nn.predict(df_test[features]).ravel()/folds.n_splits

补充知识:Keras可使用的评价函数

1:binary_accuracy(对二分类问题,计算在所有预测值上的平均正确率)

binary_accuracy(y_true, y_pred)

2:categorical_accuracy(对多分类问题,计算在所有预测值上的平均正确率)

categorical_accuracy(y_true, y_pred)

3:sparse_categorical_accuracy(与categorical_accuracy相同,在对稀疏的目标值预测时有用 )

sparse_categorical_accuracy(y_true, y_pred)

4:top_k_categorical_accuracy(计算top-k正确率,当预测值的前k个值中存在目标类别即认为预测正确 )

top_k_categorical_accuracy(y_true, y_pred, k=5)

5:sparse_top_k_categorical_accuracy(与top_k_categorical_accracy作用相同,但适用于稀疏情况)

sparse_top_k_categorical_accuracy(y_true, y_pred, k=5)

来源:https://blog.csdn.net/ssswill/article/details/95515314

标签:keras,auc,metrics,早停
0
投稿

猜你喜欢

  • JavaScript图片放大镜效果

    2009-10-19 22:15:00
  • Go语言正则表达式示例

    2023-04-13 19:41:34
  • Python序列化模块JSON与Pickle

    2022-11-06 00:24:37
  • GoLang中panic与recover函数以及defer语句超详细讲解

    2024-03-22 09:41:37
  • python中的Json模块dumps、dump、loads、load函数用法详解

    2023-11-09 20:01:30
  • 使用Python开发游戏运行脚本成功调用大漠插件

    2021-03-09 21:05:53
  • 深入分析MSSQL数据库中事务隔离级别和锁机制

    2024-01-22 02:53:35
  • 九个Python列表生成式高频面试题汇总

    2023-06-04 20:09:51
  • python批量生成条形码的示例

    2023-02-22 17:49:03
  • Python压缩模块zipfile实现原理及用法解析

    2023-07-13 03:01:46
  • 如何获取文件的名称和扩展名?

    2009-11-23 20:50:00
  • 用Python将结果保存为xlsx的方法

    2021-10-22 22:59:34
  • Python设计模式中的状态模式你了解吗

    2023-07-14 08:20:28
  • JavaScript开发时的五个小提示

    2007-11-21 19:54:00
  • golang常用库之操作数据库的orm框架-gorm基本使用详解

    2024-01-28 21:22:19
  • MS SQL 查询数据在数据库中所在行

    2009-04-26 19:36:00
  • mysql存储过程之创建(CREATE PROCEDURE)和调用(CALL)及变量创建(DECLARE)和赋值(SET)操作方法

    2024-01-19 22:48:06
  • 在数据库里将毫秒转换成date格式的方法

    2024-01-19 01:27:00
  • python使用pyshp读写shp文件的实现

    2023-10-02 04:07:13
  • python实现Zabbix-API监控

    2022-04-23 17:41:00
  • asp之家 网络编程 m.aspxhome.com