【TensorFlow】用TFRecord方式对数据进行读取(一)

【TensorFlow】用TFRecord方式对数据进行读取(一)在做深度学习项目时,在模型训练前,通常要对训练/验证图像进行读取操作。之前博文《TensorFlow卷积神经网络-猫狗识别》使用的是OpenCV读取的方式。使用OpenCV把图像读成矩阵形式当然可以满足模型训练的要求,此方式在处理小批量图像时还可以,如果处理大批量图像,就显得有点慢了。对于大型项目、大批量的图像,经常用TFRecord的方式对数据进行读取。TFRecord是TensorF…

在做深度学习项目时,在模型训练前,通常要对训练/验证图像进行读取操作。之前博文《TensorFlow 卷积神经网络 – 猫狗识别》使用的是OpenCV读取的方式。使用OpenCV把图像读成矩阵形式当然可以满足模型训练的要求,此方式在处理小批量图像时还可以,如果处理大批量图像,就显得有点慢了。

对于大型项目、大批量的图像,经常用TFRecord的方式对数据进行读取。TFRecord是TensorFlow支持的格式,速度快,1W以上的量建议使用TFRecord。TFRecord文件是以二进制进行存储数据的,适合以串行的方式读取大批量数据。其优势是能更好的利用内存,更方便地复制和移动,这更符合TensorFlow执行引擎的处理方式。通常数据转换成tfrecord格式需要写个小程序将每一个样本组装成protocol buffer定义的Example的对象,序列化成字符串,再由tf.python_io.TFRecordWriter写入文件即可。

在使用TFRecord方式读取数据之前,通常需要把相同类型的数据放在同一个文件夹。例如:

【TensorFlow】用TFRecord方式对数据进行读取(一)

上图中,“flower_photos”为总文件夹,里面放了5个子文件夹,即把所有的玫瑰图片放到“roses”文件夹,所有的向日葵图片放到“sunflowers”文件夹,等等。这样做的目的是方便完成“图片路径”–“图片标签(例:1、2、3)”–“图片名称(例:daisy、dandelion、roses)”之间的映射。

roses文件夹下的图片:

【TensorFlow】用TFRecord方式对数据进行读取(一)

程序实现

目录结构:

【TensorFlow】用TFRecord方式对数据进行读取(一)

flower_label.txt:

此文件的内容存放./flower_photos目录下的5个子文件名称,方便程序读取图片。

daisy
dandelion
roses
sunflowers
tulips

build_image_data.py:

# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
import os
import random
import sys
# 多线程制作,速度更快。数据预处理、建立数据源写一块
import threading

import numpy as np
import tensorflow as tf

# 定义string和int类型参数
# 没演示验证集,只有训练集,可以在目录里面加上验证集。train_directory为参数名
tf.app.flags.DEFINE_string('train_directory', './flower_photos/', 'Training data directory')
# 验证集,未指定单独的验证集,偷懒
tf.app.flags.DEFINE_string('validation_directory', './flower_photos/', 'Validation data directory')
# TFRecord输出目录
tf.app.flags.DEFINE_string('output_directory', './data/', 'Output data directory')

# 想生成几个TFrecord文件,train_shards / num_threads 要能够整除,这样才好能分配数量
tf.app.flags.DEFINE_integer('train_shards', 2, 'Number of shards in training TFRecord files.')
# 同上,不做验证集,只做训练集
tf.app.flags.DEFINE_integer('validation_shards', 0, 'Number of shards in validation TFRecord files.')
# 启动线程的个数
tf.app.flags.DEFINE_integer('num_threads', 2, 'Number of threads to preprocess the images.')

# The labels file contains a list of valid labels are held in this file .
# Assumes that the file contains entries as such:
# dog
# cat
# flower
# where each line corresponds to a labels. We map each label contained in
# the file to an integer corresponding to the line number starting from 0.
# flower_label.txt和子文件夹的名字一一对应
tf.app.flags.DEFINE_string('labels_file', './flower_label.txt', 'labels file')

# 获得上述定义的参数
FLAGS = tf.app.flags.FLAGS


def _int64_feature(value):
    """Wrapper for inserting int64 feature into Example proto."""
    """isinstance() 函数来判断一个对象是否是一个已知的类型,类似 type()
    isinstance() 与 type() 区别:
        type() 不会认为子类是一种父类类型,不考虑继承关系。
        isinstance() 会认为子类是一种父类类型,考虑继承关系。
    """
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def _bytes_feature(value):
    """Wrapper for inserting bytes features into Example proto."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _find_image_files(data_dir, labels_file):
    """
    Build a list of all images files and labels in the data set.
    :param data_dir: string, path to the root directory of images.
    :param labels_file: string, path to the labels file.
        The list of valid labels are held in this file, Assumes that the file contains entries as such:
            dog
            cat flower
        where each line corresponds to a label. We map each label contained in the file to an integer staring with the
        integer 0 corresponding to the label contained in the first line.

    :return:
        filenames: list of strings; each string is a path to an image file.
        texts: list of strings; each string is the class, e.g. 'dog'
        labels: list of integer; each integer identifies the ground truth.
    """

    print('目标文件夹位置:%s.' % data_dir)
    # 读flower_label.txt文件的内容
    """tf.gfile.FastGFile(path, decodestyle) 
    函数功能:实现对图片的读取。 
    函数参数:(1)path:图片所在路径 (2)decodestyle:图片的解码方式。(‘r’:UTF-8编码; ‘rb’:非UTF-8编码)
    """
    unique_labels = [l.strip() for l in tf.gfile.FastGFile(labels_file, 'r').readlines()]

    labels = []
    filenames = []
    texts = []

    # Leave label index 0 empty as a background class.
    label_index = 1

    # Construct the list of JPEG files and labels.
    for text in unique_labels:
        jpeg_file_path = '%s/%s/*' % (data_dir, text)
        try:
            # tf.gfile.Glob()用于返回与给定模式匹配的文件列表
            matching_files = tf.gfile.Glob(jpeg_file_path)
        except:
            print(jpeg_file_path)
            continue

        # 从“1”开始,扩充每一图片类别的labels
        labels.extend([label_index] * len(matching_files))
        # 根据flower_label.txt内容,扩充texts
        texts.extend([text] * len(matching_files))
        filenames.extend(matching_files)

        label_index += 1

    # shuffle the ordering of all image files in order to guarantee
    # random ordering of the images with respect to label in the
    # saved TFRecord files. Make the randomization repeatable.
    # 洗牌,把当前顺序打乱,标签为1、2、3、4、5、打乱
    shuffled_index = list(range(len(filenames)))
    # 保证shuffled_index之后每次的随机一样
    random.seed(12345)
    random.shuffle(shuffled_index)

    # 数据重新排列,执行完shuffle之后,数据可以对应上
    filenames = [filenames[i] for i in shuffled_index]
    texts = [texts[i] for i in shuffled_index]
    labels = [labels[i] for i in shuffled_index]

    print('Found %d JPEG files across %d labels inside %s.' % (len(filenames), len(unique_labels), data_dir))
    return filenames, texts, labels


class ImageCoder(object):
    """Helper class that provides TensorFlow image coding utilities."""

    # 把所有图片转换成.jpg的RGB的形式
    def __init__(self):
        # Create a single Session to run all image coding calls.
        self._sess = tf.Session()

        # Initializes function that converts PNG to JPEG data.
        # 确保所有图像格式都相同
        self._png_data = tf.placeholder(dtype=tf.string)
        # 解码为3通道
        image = tf.image.decode_png(self._png_data, channels=3)
        # 编码为RGB
        self._png_to_jpeg = tf.image.encode_jpeg(image, format='rgb', quality=100)

        # Initializes function that decodes RGB JPEG data.
        self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
        self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)

    def png_to_jpeg(self, image_data):
        return self._sess.run(self._png_to_jpeg, feed_dict={self._png_data: image_data})

    def decode_jpeg(self, image_data):
        image = self._sess.run(self._decode_jpeg, feed_dict={self._decode_jpeg_data: image_data})
        assert len(image.shape) == 3
        assert image.shape[2] == 3
        return image


def _process_image(filename, coder):
    """
    Process a single image file.
    :param filename: string, path to an image file e.g., '/path/to/example.JPG'.
    :param coder: instance of ImageCoder to provide TensorFlow image coding utils.
    :return: image_buffer: string, JPEG encoding of RGB image.
        height: integer, image height in pixels.
        width: integer, image width in pixels.
    """
    # Read the image file.
    with tf.gfile.FastGFile(filename, 'rb') as f:
        image_data = f.read()

    # Convert any PNG to JPEG's for consistency.
    if _is_png(filename):
        print('Converting PNG to JPEG for %s' % filename)
        image_data = coder.png_to_jpeg(image_data)

    # Decode the RGB JPEG.
    image = coder.decode_jpeg(image_data)

    # Check that image converted to RGB. h, w, channel
    assert len(image.shape) == 3
    height = image.shape[0]
    width = image.shape[1]
    # 判断是否三通道
    assert image.shape[2] == 3

    return image_data, height, width


def _is_png(filename):
    """
    Determine if a file contains a PNG format image.
    :param filename: string, path of the iamge file.
    :return: boolean indicating if the image is a PNG.
    """
    return '.png' in filename


def _convert_to_example(filename, image_buffer, label, text, height, width):
    """
    Build an Example proto for an example.
    :param filename: string, path to an image file, e.g., '/path/to/example.JPG'
    :param image_buffer: string, JPEG encoding of RGB image
    :param label: integer, identifier for the ground truth for the network
    :param text: string, unique human-readable, e.g. 'dog'
    :param height: integer, image height in pixels
    :param width: integer, image width in pixels
    :return: Example proto
    """
    colorspace = 'RGB'
    channels = 3
    image_format = 'JPEG'

    # tf.compat.as_bytes(),将字节或unicode转换为字节,使用utf-8编码文本
    example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': _int64_feature(height),
        'image/width': _int64_feature(width),
        'image/colorspace': _bytes_feature(tf.compat.as_bytes(colorspace)),
        'image/channels': _int64_feature(channels),
        'image/class/label': _int64_feature(label),
        'image/class/text': _bytes_feature(tf.compat.as_bytes(text)),
        'image/format': _bytes_feature(tf.compat.as_bytes(image_format)),
        'image/filename': _bytes_feature(tf.compat.as_bytes(os.path.basename(filename))),
        'image/encoded': _bytes_feature(tf.compat.as_bytes(image_buffer))
        # 'image/encoded': _bytes_feature(image_buffer)
    }))
    return example


def _process_image_files_batch(coder, thread_index, ranges, name, filenames, texts, labels, num_shards):
    """
    Processes and saves list of images as TFRecord in 1 thread.
    :param coder: instance of ImageCoder to provide TensorFlow image coding utils.
    :param thread_index: integer, unique batch to run index is within [0, len(ranges)].
    :param ranges: list of pairs of integers specifying ranges of each batches to analyze in parallel.
    :param name: string, unique identifier specifying the data set.
    :param filenames: list of strings; each string is a path to an image file.
    :param texts: list of strings; each string is human readable, e.g. 'dog'.
    :param labels: list of integer; each integer identifies the ground truth.
    :param num_shards: integer number of shards for this data set.
    :return:
    """
    # Each thread produces N shards where N=int(num_shards / num_threads).
    # For instance, if num_shards=128, and the num_threads=2, then the first thread would produce shards[0, 64].
    num_threads = len(ranges)
    assert not num_shards % num_threads
    num_shards_per_batch = int(num_shards / num_threads)

    shard_ranges = np.linspace(ranges[thread_index][0], ranges[thread_index][1], num_shards_per_batch + 1).astype(int)
    num_files_in_thread = ranges[thread_index][1] - ranges[thread_index][0]

    counter = 0
    for s in range(num_shards_per_batch):
        # Generate a sharded version of the file name, e.g. 'train-00001-of-00002'
        shard = thread_index * num_shards_per_batch + s
        output_filename = '%s-%.5d-of-%.5d.tfrecord' % (name, shard, num_shards)
        output_file = os.path.join(FLAGS.output_directory, output_filename)
        writer = tf.python_io.TFRecordWriter(output_file)

        shard_counter = 0
        files_in_shard = np.arange(shard_ranges[s], shard_ranges[s + 1], dtype=int)
        for i in files_in_shard:
            filename = filenames[i]  # 全路径
            label = labels[i]  # 标签
            text = texts[i]  # 文件夹名称

            image_buffer, height, width = _process_image(filename, coder)

            example = _convert_to_example(filename, image_buffer, label, text, height, width)

            writer.write(example.SerializeToString())
            shard_counter += 1
            counter += 1

            if not counter % 1000:
                print('%s [thread %d]: Processed %d of %d image in thread batch.' % (
                    datetime.now(), thread_index, counter, num_files_in_thread))
                sys.stdout.flush()

        writer.close()
        print('%s [thread %d]: Wrote %d images to %s' % (datetime.now(), thread_index, shard_counter, output_file))
        # 关闭多线程
        sys.stdout.flush()
        shard_counter = 0
    print(
        '%s [thread %d]: Wrote %d images to %d shards.' % (datetime.now(), thread_index, counter, num_files_in_thread))
    sys.stdout.flush()


def _process_image_files(name, filenames, texts, labels, num_shards):
    """
    Process and save list of image as TFRecord of Example protos.
    :param name: string, unique identifier specifying the data set
    :param filenames: list of strings; each string is a path to an image file
    :param texts: list of strings; each string is human readable, e.g.'dog
    :param labels: list of integer identifies the ground truth
    :param num_shards: integer number os shards for this data set.
    :return:
    """

    # filenames、texts、labels数量相对应
    assert len(filenames) == len(texts)
    assert len(filenames) == len(labels)

    # Break all images into batches with a [ranges[i][0], ranges[i][1]].
    # [0, 1835, 3670],从0至1835交给一个线程做;1835至3670交给另一个线程完成。
    spacing = np.linspace(0, len(filenames), FLAGS.num_threads + 1).astype(np.int)
    # 把spacing分成两部分,得到[0, 1835]和[1835, 3670]
    ranges = []
    for i in range(len(spacing) - 1):
        ranges.append([spacing[i], spacing[i + 1]])

    # Launch a thread for each batch.
    print('launching %d threads for spacings: %s' % (FLAGS.num_threads, ranges))
    sys.stdout.flush()

    # Create a mechanism for monitoring when all threads are finished.
    # TensorFlow的线程管理器
    coord = tf.train.Coordinator()

    # Create a generic TensorFlow-based utility for converting all image coding.
    coder = ImageCoder()

    threads = []
    for thread_index in range(len(ranges)):
        args = (coder, thread_index, ranges, name, filenames, texts, labels, num_shards)
        t = threading.Thread(target=_process_image_files_batch, args=args)
        t.start()
        threads.append(t)

    # Wait for all the threads to terminate.
    coord.join(threads)
    print('%s: Finished writing all %d images in data set.' % (datetime.now(), len(filenames)))
    sys.stdout.flush()


def _process_dataset(name, directory, num_shards, labels_file):
    """Process a complete data set and save it as a TFRecord.

    Args:
        name: string, unique identifier specifying the data set.
        directory: string, root path to the data set.
        num_shards: integer number if shards for this data set.
        labels_file: string, path to the labels file.
    """
    filenames, texts, labels = _find_image_files(directory, labels_file)
    _process_image_files(name, filenames, texts, labels, num_shards)


def main(unused_argv):
    assert not FLAGS.train_shards % FLAGS.num_threads, ('在测试集中,线程数量应用建立文件个数相对应')
    assert not FLAGS.validation_shards % FLAGS.num_threads, ('在验证集中,线程数量应用建立文件个数相对应')
    print('生成数据文件夹%s' % FLAGS.output_directory)

    # run it!
    # 训练集
    _process_dataset('train', FLAGS.train_directory, FLAGS.train_shards, FLAGS.labels_file)
    # 验证集
    # _process_dataset('validation', FLAGS.validation_directory, FLAGS.validation_shards, FLAGS.labels_file)


if __name__ == '__main__':
    tf.app.run()

执行结果:

生成数据文件夹./data/
目标文件夹位置:./flower_photos/.
Instructions for updating:
Use tf.gfile.GFile.
Found 3670 JPEG files across 5 labels inside ./flower_photos/.
launching 2 threads for spacings: [[0, 1835], [1835, 3670]]
2019-08-28 12:49:17.142402 [thread 0]: Processed 1000 of 1835 image in thread batch.
2019-08-28 12:49:17.362402 [thread 1]: Processed 1000 of 1835 image in thread batch.
2019-08-28 12:49:25.261402 [thread 0]: Wrote 1835 images to ./data/train-00000-of-00002.tfrecord
2019-08-28 12:49:25.261402 [thread 0]: Wrote 1835 images to 1835 shards.
2019-08-28 12:49:25.810402 [thread 1]: Wrote 1835 images to ./data/train-00001-of-00002.tfrecord
2019-08-28 12:49:25.810402 [thread 1]: Wrote 1835 images to 1835 shards.
2019-08-28 12:49:26.274402: Finished writing all 3670 images in data set.

生成的TFRecord文件:

【TensorFlow】用TFRecord方式对数据进行读取(一)

参考:

https://blog.csdn.net/moyu123456789/article/details/83956366

今天的文章【TensorFlow】用TFRecord方式对数据进行读取(一)分享到此就结束了,感谢您的阅读。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。
如需转载请保留出处:https://bianchenghao.cn/9939.html

(0)
编程小号编程小号

相关推荐

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注