使用tensorflow实现AlexNet

作者:triplebee 时间:2023-08-10 08:29:30 

AlexNet是2012年ImageNet比赛的冠军,虽然过去了很长时间,但是作为深度学习中的经典模型,AlexNet不但有助于我们理解其中所使用的很多技巧,而且非常有助于提升我们使用深度学习工具箱的熟练度。尤其是我刚入门深度学习,迫切需要一个能让自己熟悉tensorflow的小练习,于是就有了这个小玩意儿......

先放上我的代码:https://github.com/hjptriplebee/AlexNet_with_tensorflow

如果想运行代码,详细的配置要求都在上面链接的readme文件中了。本文建立在一定的tensorflow基础上,不会对太细的点进行说明。

模型结构

使用tensorflow实现AlexNet

关于模型结构网上的文献很多,我这里不赘述,一会儿都在代码里解释。

有一点需要注意,AlexNet将网络分成了上下两个部分,在论文中两部分结构完全相同,唯一不同的是他们放在不同GPU上训练,因为每一层的feature map之间都是独立的(除了全连接层),所以这相当于是提升训练速度的一种方法。很多AlexNet的复现都将上下两部分合并了,因为他们都是在单个GPU上运行的。虽然我也是在单个GPU上运行,但是我还是很想将最原始的网络结构还原出来,所以我的代码里也是分开的。

模型定义


def maxPoolLayer(x, kHeight, kWidth, strideX, strideY, name, padding = "SAME"):
 """max-pooling"""
 return tf.nn.max_pool(x, ksize = [1, kHeight, kWidth, 1],
            strides = [1, strideX, strideY, 1], padding = padding, name = name)

def dropout(x, keepPro, name = None):
 """dropout"""
 return tf.nn.dropout(x, keepPro, name)

def LRN(x, R, alpha, beta, name = None, bias = 1.0):
 """LRN"""
 return tf.nn.local_response_normalization(x, depth_radius = R, alpha = alpha,
                      beta = beta, bias = bias, name = name)

def fcLayer(x, inputD, outputD, reluFlag, name):
 """fully-connect"""
 with tf.variable_scope(name) as scope:
   w = tf.get_variable("w", shape = [inputD, outputD], dtype = "float")
   b = tf.get_variable("b", [outputD], dtype = "float")
   out = tf.nn.xw_plus_b(x, w, b, name = scope.name)
   if reluFlag:
     return tf.nn.relu(out)
   else:
     return out

def convLayer(x, kHeight, kWidth, strideX, strideY,
      featureNum, name, padding = "SAME", groups = 1):#group为2时等于AlexNet中分上下两部分
 """convlutional"""
 channel = int(x.get_shape()[-1])#获取channel
 conv = lambda a, b: tf.nn.conv2d(a, b, strides = [1, strideY, strideX, 1], padding = padding)#定义卷积的匿名函数
 with tf.variable_scope(name) as scope:
   w = tf.get_variable("w", shape = [kHeight, kWidth, channel/groups, featureNum])
   b = tf.get_variable("b", shape = [featureNum])

xNew = tf.split(value = x, num_or_size_splits = groups, axis = 3)#划分后的输入和权重
   wNew = tf.split(value = w, num_or_size_splits = groups, axis = 3)

featureMap = [conv(t1, t2) for t1, t2 in zip(xNew, wNew)] #分别提取feature map
   mergeFeatureMap = tf.concat(axis = 3, values = featureMap) #feature map整合
   # print mergeFeatureMap.shape
   out = tf.nn.bias_add(mergeFeatureMap, b)
   return tf.nn.relu(tf.reshape(out, mergeFeatureMap.get_shape().as_list()), name = scope.name) #relu后的结果

定义了卷积、pooling、LRN、dropout、全连接五个模块,其中卷积模块因为将网络的上下两部分分开了,所以比较复杂。接下来定义AlexNet。


class alexNet(object):
 """alexNet model"""
 def __init__(self, x, keepPro, classNum, skip, modelPath = "bvlc_alexnet.npy"):
   self.X = x
   self.KEEPPRO = keepPro
   self.CLASSNUM = classNum
   self.SKIP = skip
   self.MODELPATH = modelPath
   #build CNN
   self.buildCNN()

def buildCNN(self):
   """build model"""
   conv1 = convLayer(self.X, 11, 11, 4, 4, 96, "conv1", "VALID")
   pool1 = maxPoolLayer(conv1, 3, 3, 2, 2, "pool1", "VALID")
   lrn1 = LRN(pool1, 2, 2e-05, 0.75, "norm1")

conv2 = convLayer(lrn1, 5, 5, 1, 1, 256, "conv2", groups = 2)
   pool2 = maxPoolLayer(conv2, 3, 3, 2, 2, "pool2", "VALID")
   lrn2 = LRN(pool2, 2, 2e-05, 0.75, "lrn2")

conv3 = convLayer(lrn2, 3, 3, 1, 1, 384, "conv3")

conv4 = convLayer(conv3, 3, 3, 1, 1, 384, "conv4", groups = 2)

conv5 = convLayer(conv4, 3, 3, 1, 1, 256, "conv5", groups = 2)
   pool5 = maxPoolLayer(conv5, 3, 3, 2, 2, "pool5", "VALID")

fcIn = tf.reshape(pool5, [-1, 256 * 6 * 6])
   fc1 = fcLayer(fcIn, 256 * 6 * 6, 4096, True, "fc6")
   dropout1 = dropout(fc1, self.KEEPPRO)

fc2 = fcLayer(dropout1, 4096, 4096, True, "fc7")
   dropout2 = dropout(fc2, self.KEEPPRO)

self.fc3 = fcLayer(dropout2, 4096, self.CLASSNUM, True, "fc8")

def loadModel(self, sess):
   """load model"""
   wDict = np.load(self.MODELPATH, encoding = "bytes").item()
   #for layers in model
   for name in wDict:
     if name not in self.SKIP:
       with tf.variable_scope(name, reuse = True):
         for p in wDict[name]:
           if len(p.shape) == 1:  
             #bias 只有一维
             sess.run(tf.get_variable('b', trainable = False).assign(p))
           else:
             #weights  
             sess.run(tf.get_variable('w', trainable = False).assign(p))

buildCNN函数完全按照alexnet的结构搭建网络。
loadModel函数从模型文件中读取参数,采用的模型文件见github上的readme说明。
至此,我们定义了完整的模型,下面开始测试模型。

模型测试

ImageNet训练的AlexNet有很多类,几乎包含所有常见的物体,因此我们随便从网上找几张图片测试。比如我直接用了之前做项目的渣土车图片:

使用tensorflow实现AlexNet

然后编写测试代码:


#some params
dropoutPro = 1
classNum = 1000
skip = []
#get testImage
testPath = "testModel"
testImg = []
for f in os.listdir(testPath):
 testImg.append(cv2.imread(testPath + "/" + f))

imgMean = np.array([104, 117, 124], np.float)
x = tf.placeholder("float", [1, 227, 227, 3])

model = alexnet.alexNet(x, dropoutPro, classNum, skip)
score = model.fc3
softmax = tf.nn.softmax(score)

with tf.Session() as sess:
 sess.run(tf.global_variables_initializer())
 model.loadModel(sess) #加载模型

for i, img in enumerate(testImg):
   #img preprocess
   test = cv2.resize(img.astype(np.float), (227, 227)) #resize成网络输入大小
   test -= imgMean #去均值
   test = test.reshape((1, 227, 227, 3)) #拉成tensor
   maxx = np.argmax(sess.run(softmax, feed_dict = {x: test}))
   res = caffe_classes.class_names[maxx] #取概率最大类的下标
   #print(res)
   font = cv2.FONT_HERSHEY_SIMPLEX
   cv2.putText(img, res, (int(img.shape[0]/3), int(img.shape[1]/3)), font, 1, (0, 255, 0), 2)#绘制类的名字
   cv2.imshow("demo", img)  
   cv2.waitKey(5000) #显示5秒

如上代码所示,首先需要设置一些参数,然后读取指定路径下的测试图像,再对模型做一个初始化,最后是真正测试代码。测试结果如下:

使用tensorflow实现AlexNet

来源:http://blog.csdn.net/accepthjp/article/details/69999309

标签:tensorflow,AlexNet
0
投稿

猜你喜欢

  • Python 文件与文件对象及文件打开关闭

    2021-06-16 16:08:44
  • Python实现数字的格式化输出

    2021-10-11 18:11:27
  • SQL 重复记录问题的处理方法小结

    2024-01-16 14:56:36
  • python 求1-100之间的奇数或者偶数之和的实例

    2021-05-28 19:48:58
  • python中argparse模块基础及使用步骤

    2023-01-26 19:21:59
  • Pandas透视表(pivot_table)详解

    2022-03-26 00:21:29
  • 详细解读tornado协程(coroutine)原理

    2022-08-21 00:20:18
  • 基于scrapy的redis安装和配置方法

    2022-07-15 17:26:56
  • 好玩的vbs微信小程序之语言播报功能

    2023-04-27 12:54:29
  • 基于JavaScript如何实现私有成员的语法特征及私有成员的实现方式

    2024-04-22 22:37:54
  • Python2与Python3的区别点整理

    2022-02-23 07:44:46
  • Python实现迭代时使用索引的方法示例

    2022-12-15 11:08:48
  • GoLang nil与interface的空指针深入分析

    2024-02-18 01:58:50
  • python 图片验证码代码分享

    2022-02-21 17:10:27
  • python 切片和range()用法说明

    2021-12-12 07:40:52
  • Python实战之手写一个搜索引擎

    2023-07-11 21:16:49
  • 在django中图片上传的格式校验及大小方法

    2023-04-02 23:12:56
  • Python进制转换用法详解

    2021-08-20 15:18:40
  • 基于python实现删除指定文件类型

    2022-02-16 06:19:48
  • PHP自定义函数格式化json数据示例

    2023-07-17 07:17:45
  • asp之家 网络编程 m.aspxhome.com