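In a deep learning project, the training/validation images usually have to be read in before the model can be trained. My earlier post "TensorFlow 卷积神经网络 – 猫狗识别" (TensorFlow convolutional neural networks: cat/dog recognition) read the images with OpenCV. Reading images into matrices with OpenCV does meet the needs of model training. It works fine for small numbers of images, but becomes slow for large image collections.
For large projects with many images, data is commonly read from TFRecord files instead. TFRecord is TensorFlow's native record format and is fast to read; for datasets of more than roughly 10,000 samples, TFRecord is the recommended choice. A TFRecord file stores data as a sequence of binary records, which suits reading large amounts of data serially. Its advantages are better memory utilization and files that are easier to copy and move, which matches the way the TensorFlow execution engine consumes data. Converting data to TFRecord usually means writing a small program that:
- packs each sample into an Example object (a protocol buffer message);
- serializes it to a string;
- writes it to a file with tf.python_io.TFRecordWriter (tf.io.TFRecordWriter in TensorFlow 2).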
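This sketch makes "binary records read serially" concrete. It is a minimal pure-Python version of the framing a TFRecord file uses, following the format described in TensorFlow's TFRecord documentation. Each record is stored as:
- an 8-byte little-endian length;
- a masked CRC-32C of that length;
- the payload bytes;
- a masked CRC-32C of the payload.

The helper names (`crc32c`, `masked_crc32c`, `write_records`, `read_records`) are hypothetical. The payloads are plain bytes rather than serialized `Example` protos. This is for illustration only; real code should use TensorFlow's own writer and reader.

```python
import struct


def _make_crc32c_table():
    # Table-driven CRC-32C (Castagnoli), reflected polynomial 0x82F63B78.
    table = []
    for i in range(256):
        crc = i
        for _ in range(8):
            crc = (crc >> 1) ^ 0x82F63B78 if crc & 1 else crc >> 1
        table.append(crc)
    return table


_CRC32C_TABLE = _make_crc32c_table()


def crc32c(data: bytes) -> int:
    crc = 0xFFFFFFFF
    for byte in data:
        crc = _CRC32C_TABLE[(crc ^ byte) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF


def masked_crc32c(data: bytes) -> int:
    # TFRecord stores a "masked" CRC: rotate right by 15 bits, then add a constant.
    crc = crc32c(data)
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF


def write_records(path, payloads):
    """Write each bytes payload as one framed record."""
    with open(path, "wb") as f:
        for data in payloads:
            header = struct.pack("<Q", len(data))  # uint64 length, little-endian
            f.write(header)
            f.write(struct.pack("<I", masked_crc32c(header)))
            f.write(data)
            f.write(struct.pack("<I", masked_crc32c(data)))


def read_records(path):
    """Yield each record's payload in file order, checking both CRCs."""
    with open(path, "rb") as f:
        while True:
            header = f.read(8)
            if not header:
                return  # clean end of file
            if len(header) != 8:
                raise ValueError("truncated length header")
            (length,) = struct.unpack("<Q", header)
            (length_crc,) = struct.unpack("<I", f.read(4))
            if length_crc != masked_crc32c(header):
                raise ValueError("corrupt length header")
            data = f.read(length)
            (data_crc,) = struct.unpack("<I", f.read(4))
            if len(data) != length or data_crc != masked_crc32c(data):
                raise ValueError("corrupt or truncated payload")
            yield data
```

Every record carries its own length, so a reader can stream through the file from front to back without any index. This is what makes the format a good fit for sequential reads of large datasets.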
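Before reading data with TFRecord, data of the same class is usually placed in the same folder. For example:
In the figure above, "flower_photos" is the top-level folder and contains 5 subfolders: all rose images go into the "roses" folder, all sunflower images into the "sunflowers" folder, and so on. This makes it easy to map between the image path, the image label (e.g. 1, 2, 3) and the class name (e.g. daisy, dandelion, roses).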
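The mapping between image path, label and class name can be sketched in plain Python. This assumes one subfolder per class and a list of class names in a fixed order. `find_image_files` is a hypothetical helper that plays the same role as `_find_image_files` in the script further down:
- Labels start at 1.
- One shuffled index list is applied to all three lists, so they stay aligned.
- It uses `glob` and a local `random.Random(seed)` where the script uses `tf.gfile.Glob` and the global `random.seed`.

```python
import glob
import os
import random


def find_image_files(data_dir, class_names, seed=12345):
    """Return index-aligned, shuffled lists (filenames, texts, labels).

    Labels follow the order of class_names starting at 1; 0 is left for a
    background class.
    """
    filenames, texts, labels = [], [], []
    for label, name in enumerate(class_names, start=1):
        # Sorting makes the result depend only on the seed, not on listing order.
        matches = sorted(glob.glob(os.path.join(data_dir, name, "*")))
        filenames.extend(matches)
        texts.extend([name] * len(matches))
        labels.extend([label] * len(matches))
    # Shuffle one permutation and apply it to all three lists so they stay aligned.
    order = list(range(len(filenames)))
    random.Random(seed).shuffle(order)
    return ([filenames[i] for i in order],
            [texts[i] for i in order],
            [labels[i] for i in order])
```

For the flower data this would be called as `find_image_files('./flower_photos', ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips'])`.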
Images in the roses folder:
Implementation
Directory structure:
flower_label.txt:
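This file lists the names of the 5 subfolders under ./flower_photos, one per line, so the program knows where to read images from. The line order also fixes the integer labels: the first line (daisy) gets label 1, the second gets 2, and so on, with 0 left unused as a background class.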
daisy
dandelion
roses
sunflowers
tulips
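The step from label file to integer labels can be shown with a small hypothetical helper, `parse_labels`. It assumes the same convention as the script: the first line gets label 1, leaving 0 for a background class. It also skips blank lines, so that a trailing empty line does not turn into a class with an empty name.

```python
def parse_labels(text):
    """Map each non-empty line of a labels file to an integer label starting at 1."""
    names = [line.strip() for line in text.splitlines() if line.strip()]
    # enumerate(..., start=1) leaves label 0 free for a background class.
    return {name: label for label, name in enumerate(names, start=1)}
```

Inverting the dictionary, e.g. `{v: k for k, v in parse_labels(text).items()}`, gives the reverse lookup from a predicted label back to a class name.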
build_image_data.py (TensorFlow 1.x code; it uses tf.app.flags, tf.Session and tf.placeholder):
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from datetime import datetime
import os
import random
import sys
# Build the shards with multiple threads for speed; preprocessing and writing the data source happen in the same pass.
import threading
import numpy as np
import tensorflow as tf
# Define the string and integer flags.
# Only a training set is built here; a validation directory can be added the same way. train_directory is the flag name.
tf.app.flags.DEFINE_string('train_directory', './flower_photos/', 'Training data directory')
# Validation set: no separate validation set is prepared, so this just points at the same folder.
tf.app.flags.DEFINE_string('validation_directory', './flower_photos/', 'Validation data directory')
# Output directory for the TFRecord files.
tf.app.flags.DEFINE_string('output_directory', './data/', 'Output data directory')
# Number of TFRecord files (shards) to write. train_shards must be divisible by
# num_threads so that every thread writes the same number of shards.
tf.app.flags.DEFINE_integer('train_shards', 2, 'Number of shards in training TFRecord files.')
# As above: no validation set is built, only the training set.
tf.app.flags.DEFINE_integer('validation_shards', 0, 'Number of shards in validation TFRecord files.')
# Number of threads to launch.
tf.app.flags.DEFINE_integer('num_threads', 2, 'Number of threads to preprocess the images.')
# The labels file contains the list of valid labels, one per line, e.g.:
#   dog
#   cat
#   flower
# Each label in the file is mapped to an integer given by its line number,
# starting from 1 (index 0 is reserved for a background class).
# The entries in flower_label.txt must match the subfolder names one-to-one.
tf.app.flags.DEFINE_string('labels_file', './flower_label.txt', 'labels file')
# Collect the flags defined above.
FLAGS = tf.app.flags.FLAGS
def _int64_feature(value):
    """Wrapper for inserting int64 features into Example proto."""
    # isinstance() checks whether an object is of a given type. Unlike type(),
    # it also accepts instances of subclasses (it takes inheritance into account).
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
def _bytes_feature(value):
    """Wrapper for inserting bytes features into Example proto."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _find_image_files(data_dir, labels_file):
    """
    Build a list of all image files and labels in the data set.
    :param data_dir: string, path to the root directory of images.
    :param labels_file: string, path to the labels file.
        The list of valid labels is held in this file, one per line, e.g.:
            dog
            cat
            flower
        Each label is mapped to an integer, starting with 1 for the label on
        the first line (0 is left as a background class).
    :return:
        filenames: list of strings; each string is a path to an image file.
        texts: list of strings; each string is the class, e.g. 'dog'.
        labels: list of integers; each integer identifies the ground truth.
    """
    print('Target directory: %s.' % data_dir)
    # Read the contents of flower_label.txt.
    # tf.gfile.FastGFile(path, mode) opens a file through TensorFlow's file API:
    # mode 'r' reads text, 'rb' reads raw bytes (used for image data below).
    # Blank lines are skipped so they do not become an empty class name.
    unique_labels = [l.strip() for l in tf.gfile.FastGFile(labels_file, 'r').readlines() if l.strip()]
    labels = []
    filenames = []
    texts = []
    # Leave label index 0 empty as a background class.
    label_index = 1
    # Construct the list of JPEG files and labels.
    for text in unique_labels:
        jpeg_file_path = '%s/%s/*' % (data_dir, text)
        try:
            # tf.gfile.Glob() returns the list of files matching the given pattern.
            matching_files = tf.gfile.Glob(jpeg_file_path)
        except Exception:
            print('Skipping %s: cannot list files.' % jpeg_file_path)
            continue
        # Extend labels with this class's index (starting from 1), one entry per image.
        labels.extend([label_index] * len(matching_files))
        # Extend texts with the class name taken from flower_label.txt.
        texts.extend([text] * len(matching_files))
        filenames.extend(matching_files)
        label_index += 1
    # Shuffle the ordering of all image files in order to guarantee
    # random ordering of the images with respect to label in the
    # saved TFRecord files. Make the randomization repeatable.
    # Without this, the files would be grouped by label (1, 1, ..., 2, 2, ..., 5).
    shuffled_index = list(range(len(filenames)))
    # Fix the seed so that the shuffle is identical on every run.
    random.seed(12345)
    random.shuffle(shuffled_index)
    # Reorder all three lists with the same permutation so they stay aligned.
    filenames = [filenames[i] for i in shuffled_index]
    texts = [texts[i] for i in shuffled_index]
    labels = [labels[i] for i in shuffled_index]
    print('Found %d JPEG files across %d labels inside %s.' % (len(filenames), len(unique_labels), data_dir))
    return filenames, texts, labels
class ImageCoder(object):
    """Helper class that provides TensorFlow image coding utilities."""
    # Converts every image to an RGB JPEG.
    def __init__(self):
        # Create a single Session to run all image coding calls.
        self._sess = tf.Session()
        # Initializes function that converts PNG to JPEG data,
        # so that all images end up in the same format.
        self._png_data = tf.placeholder(dtype=tf.string)
        # Decode to 3 channels.
        image = tf.image.decode_png(self._png_data, channels=3)
        # Re-encode as an RGB JPEG.
        self._png_to_jpeg = tf.image.encode_jpeg(image, format='rgb', quality=100)
        # Initializes function that decodes RGB JPEG data.
        self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
        self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)
    def png_to_jpeg(self, image_data):
        return self._sess.run(self._png_to_jpeg, feed_dict={self._png_data: image_data})
    def decode_jpeg(self, image_data):
        image = self._sess.run(self._decode_jpeg, feed_dict={self._decode_jpeg_data: image_data})
        assert len(image.shape) == 3
        assert image.shape[2] == 3
        return image
def _process_image(filename, coder):
    """
    Process a single image file.
    :param filename: string, path to an image file e.g., '/path/to/example.JPG'.
    :param coder: instance of ImageCoder to provide TensorFlow image coding utils.
    :return: image_buffer: string, JPEG encoding of RGB image.
             height: integer, image height in pixels.
             width: integer, image width in pixels.
    """
    # Read the image file as raw bytes.
    with tf.gfile.FastGFile(filename, 'rb') as f:
        image_data = f.read()
    # Convert any PNG to JPEG for consistency.
    if _is_png(filename):
        print('Converting PNG to JPEG for %s' % filename)
        image_data = coder.png_to_jpeg(image_data)
    # Decode the RGB JPEG.
    image = coder.decode_jpeg(image_data)
    # Check that image converted to RGB: shape is (height, width, channels).
    assert len(image.shape) == 3
    height = image.shape[0]
    width = image.shape[1]
    # Make sure the image has exactly three channels.
    assert image.shape[2] == 3
    return image_data, height, width
def _is_png(filename):
    """
    Determine if a file contains a PNG format image.
    :param filename: string, path of the image file.
    :return: boolean indicating if the image is a PNG.
    """
    # Check the extension case-insensitively, so both '.png' and '.PNG' match.
    return filename.lower().endswith('.png')
def _convert_to_example(filename, image_buffer, label, text, height, width):
    """
    Build an Example proto for an example.
    :param filename: string, path to an image file, e.g., '/path/to/example.JPG'
    :param image_buffer: string, JPEG encoding of RGB image
    :param label: integer, identifier for the ground truth for the network
    :param text: string, unique human-readable, e.g. 'dog'
    :param height: integer, image height in pixels
    :param width: integer, image width in pixels
    :return: Example proto
    """
    colorspace = 'RGB'
    channels = 3
    image_format = 'JPEG'
    # tf.compat.as_bytes() converts bytes or unicode to bytes, encoding text as UTF-8.
    example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': _int64_feature(height),
        'image/width': _int64_feature(width),
        'image/colorspace': _bytes_feature(tf.compat.as_bytes(colorspace)),
        'image/channels': _int64_feature(channels),
        'image/class/label': _int64_feature(label),
        'image/class/text': _bytes_feature(tf.compat.as_bytes(text)),
        'image/format': _bytes_feature(tf.compat.as_bytes(image_format)),
        'image/filename': _bytes_feature(tf.compat.as_bytes(os.path.basename(filename))),
        'image/encoded': _bytes_feature(tf.compat.as_bytes(image_buffer))
        # 'image/encoded': _bytes_feature(image_buffer)
    }))
    return example
def _process_image_files_batch(coder, thread_index, ranges, name, filenames, texts, labels, num_shards):
    """
    Processes and saves list of images as TFRecord in 1 thread.
    :param coder: instance of ImageCoder to provide TensorFlow image coding utils.
    :param thread_index: integer, unique batch to run index is within [0, len(ranges)).
    :param ranges: list of pairs of integers specifying ranges of each batches to analyze in parallel.
    :param name: string, unique identifier specifying the data set.
    :param filenames: list of strings; each string is a path to an image file.
    :param texts: list of strings; each string is human readable, e.g. 'dog'.
    :param labels: list of integers; each integer identifies the ground truth.
    :param num_shards: integer number of shards for this data set.
    :return:
    """
    # Each thread produces N shards where N=int(num_shards / num_threads).
    # For instance, if num_shards=128, and the num_threads=2, then the first
    # thread would produce shards [0, 64).
    num_threads = len(ranges)
    assert not num_shards % num_threads
    num_shards_per_batch = int(num_shards / num_threads)
    shard_ranges = np.linspace(ranges[thread_index][0], ranges[thread_index][1], num_shards_per_batch + 1).astype(int)
    num_files_in_thread = ranges[thread_index][1] - ranges[thread_index][0]
    counter = 0
    for s in range(num_shards_per_batch):
        # Generate a sharded version of the file name, e.g. 'train-00001-of-00002.tfrecord'
        shard = thread_index * num_shards_per_batch + s
        output_filename = '%s-%.5d-of-%.5d.tfrecord' % (name, shard, num_shards)
        output_file = os.path.join(FLAGS.output_directory, output_filename)
        writer = tf.python_io.TFRecordWriter(output_file)
        shard_counter = 0
        files_in_shard = np.arange(shard_ranges[s], shard_ranges[s + 1], dtype=int)
        for i in files_in_shard:
            filename = filenames[i]  # full path
            label = labels[i]  # integer label
            text = texts[i]  # class (folder) name
            image_buffer, height, width = _process_image(filename, coder)
            example = _convert_to_example(filename, image_buffer, label, text, height, width)
            writer.write(example.SerializeToString())
            shard_counter += 1
            counter += 1
            if not counter % 1000:
                print('%s [thread %d]: Processed %d of %d image in thread batch.' % (
                    datetime.now(), thread_index, counter, num_files_in_thread))
                sys.stdout.flush()
        writer.close()
        print('%s [thread %d]: Wrote %d images to %s' % (datetime.now(), thread_index, shard_counter, output_file))
        # Flush so that output from the different threads appears promptly.
        sys.stdout.flush()
        shard_counter = 0
    print('%s [thread %d]: Wrote %d images to %d shards.' % (
        datetime.now(), thread_index, counter, num_shards_per_batch))
    sys.stdout.flush()
def _process_image_files(name, filenames, texts, labels, num_shards):
    """
    Process and save list of images as TFRecord of Example protos.
    :param name: string, unique identifier specifying the data set
    :param filenames: list of strings; each string is a path to an image file
    :param texts: list of strings; each string is human readable, e.g. 'dog'
    :param labels: list of integers; each integer identifies the ground truth
    :param num_shards: integer number of shards for this data set.
    :return:
    """
    # filenames, texts and labels must have the same length.
    assert len(filenames) == len(texts)
    assert len(filenames) == len(labels)
    # Break all images into batches with a [ranges[i][0], ranges[i][1]].
    # With 3670 images and 2 threads, spacing is [0, 1835, 3670]: images 0 to 1835
    # go to one thread and images 1835 to 3670 to the other.
    # (np.int was removed in NumPy 1.24; the built-in int is equivalent.)
    spacing = np.linspace(0, len(filenames), FLAGS.num_threads + 1).astype(int)
    # Turn spacing into consecutive pairs: [0, 1835] and [1835, 3670].
    ranges = []
    for i in range(len(spacing) - 1):
        ranges.append([spacing[i], spacing[i + 1]])
    # Launch a thread for each batch.
    print('launching %d threads for spacings: %s' % (FLAGS.num_threads, ranges))
    sys.stdout.flush()
    # Create a mechanism for monitoring when all threads are finished
    # (TensorFlow's thread coordinator).
    coord = tf.train.Coordinator()
    # Create a generic TensorFlow-based utility for converting all image coding.
    coder = ImageCoder()
    threads = []
    for thread_index in range(len(ranges)):
        args = (coder, thread_index, ranges, name, filenames, texts, labels, num_shards)
        t = threading.Thread(target=_process_image_files_batch, args=args)
        t.start()
        threads.append(t)
    # Wait for all the threads to terminate.
    coord.join(threads)
    print('%s: Finished writing all %d images in data set.' % (datetime.now(), len(filenames)))
    sys.stdout.flush()
def _process_dataset(name, directory, num_shards, labels_file):
    """Process a complete data set and save it as a TFRecord.
    Args:
        name: string, unique identifier specifying the data set.
        directory: string, root path to the data set.
        num_shards: integer number of shards for this data set.
        labels_file: string, path to the labels file.
    """
    filenames, texts, labels = _find_image_files(directory, labels_file)
    _process_image_files(name, filenames, texts, labels, num_shards)
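def main(unused_argv):
    assert not FLAGS.train_shards % FLAGS.num_threads, (
        'Please make FLAGS.num_threads commensurate with FLAGS.train_shards')
    assert not FLAGS.validation_shards % FLAGS.num_threads, (
        'Please make FLAGS.num_threads commensurate with FLAGS.validation_shards')
    print('Output data directory: %s' % FLAGS.output_directory)
    # Create the output directory if it is missing so the writers can open their files.
    if not tf.gfile.Exists(FLAGS.output_directory):
        tf.gfile.MakeDirs(FLAGS.output_directory)
    # run it!
    # Training set.
    _process_dataset('train', FLAGS.train_directory, FLAGS.train_shards, FLAGS.labels_file)
    # Validation set (not built here).
    # _process_dataset('validation', FLAGS.validation_directory, FLAGS.validation_shards, FLAGS.labels_file)
if __name__ == '__main__':
    tf.app.run()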
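The way the script divides the work can be checked without TensorFlow:
- one contiguous range of images per thread;
- each range cut into `train_shards / num_threads` shards;
- files named `<name>-<shard>-of-<num_shards>.tfrecord`.

`split_ranges` and `shard_plan` are hypothetical helpers written for illustration. Where the script uses `np.linspace(...).astype(int)`, they use integer arithmetic. For the 3670-image, 2-thread, 2-shard run in this post, both approaches give the same boundaries.

```python
def split_ranges(num_files, num_threads):
    """Split [0, num_files) into num_threads contiguous [start, end) ranges."""
    bounds = [i * num_files // num_threads for i in range(num_threads + 1)]
    return [[bounds[i], bounds[i + 1]] for i in range(num_threads)]


def shard_plan(name, num_files, num_threads, num_shards):
    """Return {thread_index: [(output_filename, start, end), ...]}."""
    if num_shards % num_threads:
        raise ValueError("num_shards must be a multiple of num_threads")
    per_thread = num_shards // num_threads
    plan = {}
    for t, (lo, hi) in enumerate(split_ranges(num_files, num_threads)):
        # Cut this thread's range into per_thread consecutive shards.
        cuts = [lo + s * (hi - lo) // per_thread for s in range(per_thread + 1)]
        plan[t] = [
            ("%s-%.5d-of-%.5d.tfrecord" % (name, t * per_thread + s, num_shards),
             cuts[s], cuts[s + 1])
            for s in range(per_thread)
        ]
    return plan
```

With `train_shards=4` and `num_threads=2`, thread 1 would write `train-00002-of-00004.tfrecord` and `train-00003-of-00004.tfrecord`. This shows why the shard count has to divide evenly by the thread count.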
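Output. Note that the lines "Wrote 1835 images to 1835 shards." show the thread's file count in place of its shard count; with train_shards=2 and num_threads=2, each thread wrote a single shard:
Output data directory: ./data/
Target directory: ./flower_photos/.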
Instructions for updating:
Use tf.gfile.GFile.
Found 3670 JPEG files across 5 labels inside ./flower_photos/.
launching 2 threads for spacings: [[0, 1835], [1835, 3670]]
2019-08-28 12:49:17.142402 [thread 0]: Processed 1000 of 1835 image in thread batch.
2019-08-28 12:49:17.362402 [thread 1]: Processed 1000 of 1835 image in thread batch.
2019-08-28 12:49:25.261402 [thread 0]: Wrote 1835 images to ./data/train-00000-of-00002.tfrecord
2019-08-28 12:49:25.261402 [thread 0]: Wrote 1835 images to 1835 shards.
2019-08-28 12:49:25.810402 [thread 1]: Wrote 1835 images to ./data/train-00001-of-00002.tfrecord
2019-08-28 12:49:25.810402 [thread 1]: Wrote 1835 images to 1835 shards.
2019-08-28 12:49:26.274402: Finished writing all 3670 images in data set.
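To double-check the log after a run, count the records in each shard. The usual route is TensorFlow's own reader, e.g. `tf.python_io.tf_record_iterator` in TensorFlow 1.x. A dependency-free alternative is the sketch below: it walks the TFRecord length headers and skips each payload without verifying the CRCs. `count_tfrecords` is a hypothetical helper. It assumes the standard TFRecord framing: an 8-byte length, a 4-byte length CRC, the payload, and a 4-byte payload CRC.

```python
import os
import struct


def count_tfrecords(path):
    """Count records by walking TFRecord length headers (CRCs are not checked)."""
    size = os.path.getsize(path)
    count = 0
    with open(path, "rb") as f:
        while f.tell() < size:
            header = f.read(8)
            if len(header) != 8:
                raise ValueError("truncated length header")
            (length,) = struct.unpack("<Q", header)
            # Skip the 4-byte length CRC, the payload and the 4-byte payload CRC.
            f.seek(4 + length + 4, os.SEEK_CUR)
            if f.tell() > size:
                raise ValueError("truncated record")
            count += 1
    return count
```

For example, `sum(count_tfrecords(p) for p in glob.glob('./data/train-*.tfrecord'))` should equal the 3670 images reported above, provided both shards were written completely.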
The generated TFRecord files:
Reference:
https://blog.csdn.net/moyu123456789/article/details/83956366
This concludes the article 【TensorFlow】用TFRecord方式对数据进行读取(一) ([TensorFlow] Reading data with TFRecord, part 1). Thanks for reading.
If you repost this article, please keep the source: https://bianchenghao.cn/9939.html