keras topN显示,自编写代码案例

作者:姚贤贤 时间:2021-03-19 03:15:13 

对于使用已经训练好的模型,比如VGG,RESNET等,keras都自带了一个keras.applications.imagenet_utils.decode_predictions的方法,有很多限制:


def decode_predictions(preds, top=5):
"""Decodes the prediction of an ImageNet model.

# Arguments
preds: Numpy tensor encoding a batch of predictions.
top: Integer, how many top-guesses to return.

# Returns
A list of lists of top class prediction tuples
`(class_name, class_description, score)`.
One list of tuples per sample in batch input.

# Raises
ValueError: In case of invalid shape of the `pred` array
 (must be 2D).
"""
global CLASS_INDEX
if len(preds.shape) != 2 or preds.shape[1] != 1000:
raise ValueError('`decode_predictions` expects '
   'a batch of predictions '
   '(i.e. a 2D array of shape (samples, 1000)). '
   'Found array with shape: ' + str(preds.shape))
if CLASS_INDEX is None:
fpath = get_file('imagenet_class_index.json',
   CLASS_INDEX_PATH,
   cache_subdir='models',
   file_hash='c2c37ea517e94d9795004a39431a14cb')
with open(fpath) as f:
 CLASS_INDEX = json.load(f)
results = []
for pred in preds:
top_indices = pred.argsort()[-top:][::-1]
result = [tuple(CLASS_INDEX[str(i)]) + (pred[i],) for i in top_indices]
result.sort(key=lambda x: x[2], reverse=True)
results.append(result)
return results

把重要的东西挖出来,然后自己敲,这样就OK了,下例以MNIST数据集为例:


import keras
from keras.models import Sequential
from keras.layers import Dense
import numpy as np
import tflearn
import tflearn.datasets.mnist as mnist

def decode_predictions_custom(preds, top=5):
CLASS_CUSTOM = ["0","1","2","3","4","5","6","7","8","9"]
results = []
for pred in preds:
top_indices = pred.argsort()[-top:][::-1]
result = [tuple(CLASS_CUSTOM[i]) + (pred[i]*100,) for i in top_indices]
results.append(result)
return results

x_train, y_train, x_test, y_test = mnist.load_data(one_hot=True)

model = Sequential()
model.add(Dense(units=64, activation='relu', input_dim=784))
model.add(Dense(units=10, activation='softmax'))
model.compile(loss='categorical_crossentropy',
 optimizer='sgd',
 metrics=['accuracy'])
model.fit(x_train, y_train, epochs=10, batch_size=128)
# score = model.evaluate(x_test, y_test, batch_size=128)
# print(score)
preds = model.predict(x_test[0:1,:])
p = decode_predictions_custom(preds)
for (i,(label,prob)) in enumerate(p[0]):
print("{}. {}: {:.2f}%".format(i+1, label,prob))
# 1. 7: 99.43%
# 2. 9: 0.24%
# 3. 3: 0.23%
# 4. 0: 0.05%
# 5. 2: 0.03%

补充知识:keras简单的去噪自编码器代码和各种类型自编码器代码

我就废话不多说了,大家还是直接看代码吧~


start = time()

from keras.models import Sequential
from keras.layers import Dense, Dropout,Input
from keras.layers import Embedding
from keras.layers import Conv1D, GlobalAveragePooling1D, MaxPooling1D
from keras import layers
from keras.models import Model

# Parameters for denoising autoencoder
nb_visible = 120
nb_hidden = 64
batch_size = 16
# Build autoencoder model
input_img = Input(shape=(nb_visible,))

encoded = Dense(nb_hidden, activation='relu')(input_img)
decoded = Dense(nb_visible, activation='sigmoid')(encoded)

autoencoder = Model(input=input_img, output=decoded)
autoencoder.compile(loss='mean_squared_error',optimizer='adam',metrics=['mae'])
autoencoder.summary()

# Train
### 加一个early_stooping
import keras

early_stopping = keras.callbacks.EarlyStopping(
 monitor='val_loss',
 min_delta=0.0001,
 patience=5,
 verbose=0,
 mode='auto'
)
autoencoder.fit(X_train_np, y_train_np, nb_epoch=50, batch_size=batch_size , shuffle=True,
       callbacks = [early_stopping],verbose = 1,validation_data=(X_test_np, y_test_np))
# Evaluate
evaluation = autoencoder.evaluate(X_test_np, y_test_np, batch_size=batch_size , verbose=1)
print('val_loss: %.6f, val_mean_absolute_error: %.6f' % (evaluation[0], evaluation[1]))

end = time()
print('耗时:'+str((end-start)/60))

keras各种自编码代码

来源:https://blog.csdn.net/u011311291/article/details/79991716

标签:keras,topN显示,自编
0
投稿

猜你喜欢

  • python练习之循环控制语句 break 与 continue

    2022-04-15 12:31:20
  • python的语句结构你真的了解吗

    2022-08-11 23:05:39
  • python基于搜索引擎实现文章查重功能

    2022-01-21 19:25:50
  • python中文件操作与异常的处理图文详解

    2021-09-04 16:04:33
  • WEB移动应用框架构想

    2010-09-28 16:26:00
  • fckeditor编辑器在php中的配置方法

    2023-10-14 14:26:52
  • 判断数据库里存在的BIG5码

    2009-04-09 18:31:00
  • AJAX的jQuery实现入门(二)

    2008-05-01 13:04:00
  • PHP平滑关闭/重启的实现方法

    2023-10-05 08:48:29
  • python判断变量是否为int、字符串、列表、元组、字典的方法详解

    2022-09-28 05:11:57
  • Python的几个高级语法概念浅析(lambda表达式闭包装饰器)

    2021-08-07 14:18:45
  • Python hashlib加密模块常用方法解析

    2022-03-11 05:20:05
  • Python爬虫之网页图片抓取的方法

    2021-12-19 00:47:20
  • perl哈希hash的常见用法介绍

    2023-08-12 18:46:59
  • python实现TF-IDF算法解析

    2021-06-02 03:27:51
  • django中上传图片分页三级联动效果的实现代码

    2022-02-26 18:39:31
  • 如何尽快释放掉Connection对象建立的连接?

    2009-12-16 18:38:00
  • 交互设计实用指南系列(3)—“有效性”之“适时帮助”

    2009-12-25 14:29:00
  • Python 数据类型--集合set

    2021-11-23 21:17:54
  • js版sliderBar(滑动条)控件

    2008-10-18 15:59:00
  • asp之家 网络编程 m.aspxhome.com