keras用auc做metrics以及早停实例
作者:ssswill 时间:2022-04-19 03:55:12
我就废话不多说了,大家还是直接看代码吧~
import tensorflow as tf
from sklearn.metrics import roc_auc_score
def auroc(y_true, y_pred):
return tf.py_func(roc_auc_score, (y_true, y_pred), tf.double)
# Build Model...
model.compile(loss='categorical_crossentropy', optimizer='adam',metrics=['accuracy', auroc])
完整例子:
def auc(y_true, y_pred):
auc = tf.metrics.auc(y_true, y_pred)[1]
K.get_session().run(tf.local_variables_initializer())
return auc
def create_model_nn(in_dim,layer_size=200):
model = Sequential()
model.add(Dense(layer_size,input_dim=in_dim, kernel_initializer='normal'))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(Dropout(0.3))
for i in range(2):
model.add(Dense(layer_size))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(Dropout(0.3))
model.add(Dense(1, activation='sigmoid'))
adam = optimizers.Adam(lr=0.01)
model.compile(optimizer=adam,loss='binary_crossentropy',metrics = [auc])
return model
####cv train
folds = StratifiedKFold(n_splits=5, shuffle=False, random_state=15)
oof = np.zeros(len(df_train))
predictions = np.zeros(len(df_test))
for fold_, (trn_idx, val_idx) in enumerate(folds.split(df_train.values, target2.values)):
print("fold n°{}".format(fold_))
X_train = df_train.iloc[trn_idx][features]
y_train = target2.iloc[trn_idx]
X_valid = df_train.iloc[val_idx][features]
y_valid = target2.iloc[val_idx]
model_nn = create_model_nn(X_train.shape[1])
callback = EarlyStopping(monitor="val_auc", patience=50, verbose=0, mode='max')
history = model_nn.fit(X_train, y_train, validation_data = (X_valid ,y_valid),epochs=1000,batch_size=64,verbose=0,callbacks=[callback])
print('\n Validation Max score : {}'.format(np.max(history.history['val_auc'])))
predictions += model_nn.predict(df_test[features]).ravel()/folds.n_splits
补充知识:Keras可使用的评价函数
1:binary_accuracy(对二分类问题,计算在所有预测值上的平均正确率)
binary_accuracy(y_true, y_pred)
2:categorical_accuracy(对多分类问题,计算在所有预测值上的平均正确率)
categorical_accuracy(y_true, y_pred)
3:sparse_categorical_accuracy(与categorical_accuracy相同,在对稀疏的目标值预测时有用 )
sparse_categorical_accuracy(y_true, y_pred)
4:top_k_categorical_accuracy(计算top-k正确率,当预测值的前k个值中存在目标类别即认为预测正确 )
top_k_categorical_accuracy(y_true, y_pred, k=5)
5:sparse_top_k_categorical_accuracy(与top_k_categorical_accracy作用相同,但适用于稀疏情况)
sparse_top_k_categorical_accuracy(y_true, y_pred, k=5)
来源:https://blog.csdn.net/ssswill/article/details/95515314
标签:keras,auc,metrics,早停
![](/images/zang.png)
![](/images/jiucuo.png)
猜你喜欢
JavaScript图片放大镜效果
2009-10-19 22:15:00
![](https://img.aspxhome.com/file/UploadPic/up/2009101922390143.gif)
Go语言正则表达式示例
2023-04-13 19:41:34
Python序列化模块JSON与Pickle
2022-11-06 00:24:37
GoLang中panic与recover函数以及defer语句超详细讲解
2024-03-22 09:41:37
python中的Json模块dumps、dump、loads、load函数用法详解
2023-11-09 20:01:30
![](https://img.aspxhome.com/file/2023/7/105417_0s.png)
使用Python开发游戏运行脚本成功调用大漠插件
2021-03-09 21:05:53
![](https://img.aspxhome.com/file/2023/9/99679_0s.png)
深入分析MSSQL数据库中事务隔离级别和锁机制
2024-01-22 02:53:35
九个Python列表生成式高频面试题汇总
2023-06-04 20:09:51
python批量生成条形码的示例
2023-02-22 17:49:03
Python压缩模块zipfile实现原理及用法解析
2023-07-13 03:01:46
如何获取文件的名称和扩展名?
2009-11-23 20:50:00
用Python将结果保存为xlsx的方法
2021-10-22 22:59:34
Python设计模式中的状态模式你了解吗
2023-07-14 08:20:28
![](https://img.aspxhome.com/file/2023/2/60962_0s.png)
JavaScript开发时的五个小提示
2007-11-21 19:54:00
golang常用库之操作数据库的orm框架-gorm基本使用详解
2024-01-28 21:22:19
MS SQL 查询数据在数据库中所在行
2009-04-26 19:36:00
mysql存储过程之创建(CREATE PROCEDURE)和调用(CALL)及变量创建(DECLARE)和赋值(SET)操作方法
2024-01-19 22:48:06
在数据库里将毫秒转换成date格式的方法
2024-01-19 01:27:00
python使用pyshp读写shp文件的实现
2023-10-02 04:07:13
python实现Zabbix-API监控
2022-04-23 17:41:00
![](https://img.aspxhome.com/file/2023/5/71505_0s.jpg)