《tensorflow笔记》学习记录第六节

阅读: 评论:0

《tensorflow笔记》学习记录第六节

《tensorflow笔记》学习记录第六节

本节解决的问题

1.如何对输入的真实图片,输出预测结果?
2.如何制作数据集,实现特定应用?

答:

1.可以利用第五节中的网络结构,再添加部分代码(mnist_app.py), 实现输入手写数字图片输出识别结果。

2.利用tf中的 tfrecords文件。

6.1输入手写数字图片输出识别结果

6.1.1断点续训

断点续训可以解决神经网络训练被中断后,恢复神经网络训练时可以按之前的训练结果继续训练。实现方式是在反向传播中加入ckpt,代码如下

# 代码在第五节的反向传播代码基础上更改,省略了部分内容def backward(mnist):# 省略反向传播网络结构# 省略滑动平均等配置# 在Session()中加入ckpt实现断点续训练。# 实例化类saver = tf.train.Saver()with tf.Session() as sess:init_op = tf.global_variables_initializer()sess.run(init_op)ckpt = _checkpoint_state(MODEL_SAVE_PATH)if ckpt del_checkpoint_path:# 恢复保存的训练参数,实现断点续训 store(sess, del_checkpoint_path)for i in range(STEPS):# 省略循环训练和保存训练结果过程。

6.1.1图片预处理

首先,需要将图片预处理为全连接网络输入数据的格式。我们通过下面代码中的 pre_pic()函数实现。mnist数据集(本课所用的数字识别数据集),数据是黑底白字,黑底用0表示,白字用0~1之间浮点数表示,越接近1 颜色越白(这里老师说的应该是视觉直观上。( 《TensorFlow实战Google学习框架》 书中描述,书中是接近1为黑色。和/  网页上是说像素是0到255 255为黑色。应该是说图片数据中的1为有值。反色是因为在我们平时写的数字是白底黑字.总之要知道为什么在代码里反色。)

然后在restore_mode()函数中使用mnist_forward.py中定义的y, 并喂入处理之后的图片,模型预测的概率。

mnist_app.py 代码如下

#mnist_app.py 代码import tensorflow as tf
import numpy as np
from PIL import Image
import mnist_backward
import mnist_forwarddef restore_model(testPicArr):# tf.Graph().as_default() 应该是为了处理多个程序同时调用# mnist_forward.forward() 的情况,# 屏蔽with tf.Graph().as_default() as tg:# 一个窗口运行mnist_app,在另一个窗口同时运行mnist_backward,app会报错# =en# 根据tf官网 # The default graph is a property of the current thread. If you create a new thread, # and wish to use the default graph in that thread, you must explicitly add a with # g.as_default(): in that thread's function.with tf.Graph().as_default() as tg:x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])y = mnist_forward.forward(x, None)preValue = tf.argmax(y, 1)variable_averages= tf.train.ExponentialMovingAverage(mnist_backward.MOVING_AVERAGE_DECAY)# variables_to_restore()variable_to_restore = variable_averages.variables_to_restore()saver = tf.train.Saver(variables_to_restore)# 恢复网络结构with tf.Session() as sess:ckpt = _checkpoint_state(mnist_backward.MODEL_SAVE_PATH)if ckpt del_checkpoint_store(sess, del_checkpoint_path)preValue = sess.run(preValue, feed_dict={x:testPicArr})return preValueelse:print 'No checkpoint file found!'return -1# 预处理, 包括resize, 转变灰度图,二值化。
def pre_pic(picName):img = Image.open(picName)reIm = size((28, 28), Image.ANTIALIAS)im_arr = np.vert('L'))threshold = 50# im_arr[i][j] only have 0 or 255???for i in range(28):for j in range(28):# 反色处理,reb 和mnist 对黑的定义相反# reb中 0为黑,mnist 中 1为黑im_arr[i][j] = 255- im_arr[i][j]if im_arr[i][j] < threshold :im_arr[i][j] =0else:im_arr[i][j] =255nm_arr = shape([1, 784])nm_arr = nm_arr.astype(np.float32)img_ready = np.multiply(nm_arr, 1.0/255.0)return img_ready		def application():testNum = int(input('input the number of test pictures:'))for i range(testNum):testPic = raw_input('the path of test picture:')testPicArr = pre_pic(testPic)preValue = restore_model(testPicArr)print 'The prediction number is:', preValuedef main():application()if __name__ == '__main__':main()

输出截图如下:

 需要注意的是程序对mnist数据集外的图片识别率不是很高。

6.2 制作数据集

6.1.1tfrecords文件

tfrecords是一种二进制文件, 可先将图片和标签制作成为该格式的文件,使用tf.records进行数据读取,提高内存利用率。

用tf.train.Example的协议存储训练数据。训练数据的特征用键值对的形式表示。

如: ’img_raw' :值      ’label' :值     值是Byteslist/FloatList/Int64List

用SerializeToString() 把数据序列化成字符串存储。

伪代码如下,注意伪代码中把图片像素值除以255了 和mnist数据集是0到1之间的数对应

# 生成tfrecords 文件
writer = tf.python_io.TFRecordWriter(tfRecordName) # 新建一个writer
for 循环遍历每张图和标签 :example = tf.train.Example(feature&#ain.Features(feature={'img_raw':tf.train.Feature(bytes_list&#ain.Byteslist(value=[img_raw])),'label':tf.train.Feature(int64_list&#ain.Int64List(value=labels))}))  # 把每张图片和标签封装到example中wirter.write(example.SerializeToString()) # 把example 进行序列化# 解析tfrecords文件
fliename_queue &#ain.string_input_producer([tfRecord_path])
reader = tf.TFRecordReader() # 新建一个reader
_, serialized_example = ad(filename_queue)
features = tf.parse_single_example(serialized_example,features={'img_raw':tf.FixedLenFeature([],tf.string),'label':tf.FixedLenFeature([10],tf.int64)
})
img = tf.decode_raw(features['img_raw'], tf.uint8)
img.set_shape([784])
# 注意这里把图片像素值除以255了 和mnist数据集是0到1之间的数对应
img = tf.cast(img, tf.float32)*(1/255) 
label = tf.cast(features['label'], tf.float32)

本节课代码和第五节课的区别为

数据集生成部分的代码:

# coding:utf-8
import tensorflow as tf
import numpy as np
from PIL import Image
import  os #tf.disable_v2_behavior()
image_train_path = './mnist_data_jpg/mnist_train_jpg_60000/'
label_train_path = './mnist_data_jpg/mnist_train_'
tfRecord_train = './data/mnist_train.tfrecords'
image_test_path = './mnist_data_jpg/mnist_test_jpg_10000/'
label_test_path = './mnist_data_jpg/mnist_test_'
tfRecord_test = './data/mnist_test.tfrecords'
data_path = './data'
resize_height = 28
resize_width =28# 生成tfrecords文件
def write_tfRecord(tfRecordName, image_path, label_path):# 新建一个writerwriter = tf.python_io.TFRecordWriter(tfRecordName)num_pic = 0f = open(label_path, 'r')contents &#adlines()f.close()# 循环遍历每张图和标签for content in contents:value = content.split()img_path =image_path + value[0]img = Image.open(img_path)img_raw = bytes()labels = [0] * 10labels[int(value[1])] = 1# 把图片和标签封装到exampleexample = tf.train.Example(features&#ain.Features(feature={'img_raw': tf.train.Feature(bytes_list&#ain.BytesList(value=[img_raw])),'label':   tf.train.Feature(int64_list&#ain.Int64List(value=labels))}))# example 序列化writer.write(example.SerializeToString())num_pic += 1print("the number of picture:", num_pic)writer.close()print("write tfrecord successful")def generate_tfRecord():isExists = ists(data_path)if not isExists:os.makedirs(data_path)print('the directory was created successfully')else:print('directory already exists')write_tfRecord(tfRecord_train, image_train_path, label_train_path)write_tfRecord(tfRecord_test, image_test_path, label_test_path)# 解析tfrecords文件
def read_tfRecord(tfRecord_path):# 函数生成一个先入先出的队列,文件阅读器会使用它来读取数据filename_queue = tf.train.string_input_producer([tfRecord_path], shuffle=True)# 新建一个readerreader = tf.TFRecordReader()# 解序列化,标签和图片的键名应该和制作tfrecords的键名相同,其中标签给出几分类_, serialized_example &#ad(filename_queue)# 将tf.train.Example协议内存块(protocol buffer)解析为张量features = tf.parse_single_example(serialized_example,features={'label':tf.FixedLenFeature([10], tf.int64),'img_raw':tf.FixedLenFeature([],tf.string)})# 将 img_raw 字符串转化为8位无符号整形img = tf.decode_raw(features['img_raw'],tf.uint8)img.set_shape([784])# 注意这里把图片像素值除以255了 和mnist数据集是0到1之间的数对应img = tf.cast(img, tf.float32) * (1. / 255)label = tf.cast(features['label'], tf.float32)return  img, labeldef get_tfrecord(num,isTrain=True):if isTrain:tfRecord_path = tfRecord_trainelse:tfRecord_path = tfRecord_testimg, label = read_tfRecord(tfRecord_path)#随机读取一个batch的数据img_batch, label_batch = tf.train.shuffle_batch([img, label],batch_size=num,num_threads=2,capacity=1000,min_after_dequeue=700)# 返回的图片和标签为随机抽取的batch_size组return  img_batch, label_batchdef main():generate_tfRecord()if __name__ == "__main__":main()

本文发布于:2024-02-02 09:45:18,感谢您对本站的认可!

本文链接:https://www.4u4v.net/it/170683831942972.html

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。

标签:第六节   笔记   tensorflow
留言与评论(共有 0 条评论)
   
验证码:

Copyright ©2019-2022 Comsenz Inc.Powered by ©

网站地图1 网站地图2 网站地图3 网站地图4 网站地图5 网站地图6 网站地图7 网站地图8 网站地图9 网站地图10 网站地图11 网站地图12 网站地图13 网站地图14 网站地图15 网站地图16 网站地图17 网站地图18 网站地图19 网站地图20 网站地图21 网站地图22/a> 网站地图23