The Package-Wrangler's Toolbox: tf.data.Dataset

What is this?

  • tf.data.Dataset provides high-level APIs for importing, preprocessing, and exporting datasets
  • It targets scenarios where the data is loaded into memory batch by batch rather than all at once
  • Pipelined processing can be built from a few simple API calls (see the sketch below)
  • At its core, this thing is just a heavily encapsulated IO adapter
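
As a taste of that pipelining, a minimal sketch (the chained transformations and toy sizes here are illustrative, not prescriptive):

import tensorflow as tf

# source -> map -> batch -> prefetch, all lazily evaluated
dataset = (tf.data.Dataset.range(100)
           .map(lambda x: x * 2, num_parallel_calls=tf.data.experimental.AUTOTUNE)
           .batch(8)
           .prefetch(tf.data.experimental.AUTOTUNE))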

Creating a Dataset

Import-related APIs

  • Import from a Python list
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
  • Import from text files
dataset = tf.data.TextLineDataset([
    "file1.txt",
    "file2.txt"
])
  • Import from TFRecord files
dataset = tf.data.TFRecordDataset([
    "file1.tfrecords",
    "file2.tfrecords"
])
  • Import from multiple files (glob pattern)
dataset = tf.data.Dataset.list_files("/path/*.txt")
  • Other import methods (one sketched below)
    • tf.data.FixedLengthRecordDataset
    • tf.data.Dataset.from_generator
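
For instance, a minimal tf.data.Dataset.from_generator sketch (the generator and the output_signature below are purely illustrative):

def gen():
    for i in range(3):
        yield i

# output_signature declares the shape/dtype of each yielded element
dataset = tf.data.Dataset.from_generator(
    gen,
    output_signature=tf.TensorSpec(shape=(), dtype=tf.int32))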

Some elegant import patterns

  • Importing a tuple
# Slicing a tuple of 1D tensors produces tuple elements containing scalar tensors.
dataset = tf.data.Dataset.from_tensor_slices((
    [1, 2],
    [3, 4],
    [5, 6]
))
list(dataset.as_numpy_iterator())
# [(1, 3, 5), (2, 4, 6)]
  • Importing two tensors
# Two tensors can be combined into one Dataset object.
features = tf.constant([[1, 3], [2, 1], [3, 3]])  # ==> 3x2 tensor
labels = tf.constant(['A', 'B', 'A'])  # ==> tensor of shape (3,)
dataset = tf.data.Dataset.from_tensor_slices((features, labels))

# Both the features and the labels tensors can be converted
# to a Dataset object separately and combined after.
features_dataset = tf.data.Dataset.from_tensor_slices(features)
labels_dataset = tf.data.Dataset.from_tensor_slices(labels)
dataset = tf.data.Dataset.zip((features_dataset, labels_dataset))

Preprocessing

Dataset-level APIs

  • .apply()

    • Args:

      • transformation_func: a function mapping a Dataset to a Dataset
    • Rets:

      • a new Dataset
def dataset_fn(ds):
    return ds.filter(lambda x: x < 5)

dataset = dataset.apply(dataset_fn)
  • .batch()
    • Args:
      • batch_size
      • drop_remainder=False
    • Elements of the new Dataset gain an extra leading dimension of size batch_size (demonstrated below)
dataset = dataset.batch(3, drop_remainder=True)
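
The effect on shapes, as a quick sketch (toy data; the range size is arbitrary):

dataset = tf.data.Dataset.range(8)
dataset = dataset.batch(3, drop_remainder=True)
list(dataset.as_numpy_iterator())
# ==> [array([0, 1, 2]), array([3, 4, 5])]  -- the leftover [6, 7] is dropped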
  • .shard()
    • Args:
      • num_shards
      • index
    • Rets:
      • a new Dataset holding only 1/num_shards of the original data: the elements whose position in the original Dataset satisfies position % num_shards == index (see the sketch below)
dataset_B = dataset_A.shard(num_shards=3, index=0)
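
For instance, assuming dataset_A is a toy range Dataset:

dataset_A = tf.data.Dataset.range(9)
dataset_B = dataset_A.shard(num_shards=3, index=0)
list(dataset_B.as_numpy_iterator())
# ==> [0, 3, 6]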
  • .concatenate()
    • Args:
      • another Dataset
    • Rets:
      • a new Dataset
a = tf.data.Dataset.range(1, 4)  # ==> [ 1, 2, 3 ]
b = tf.data.Dataset.range(4, 8) # ==> [ 4, 5, 6, 7 ]
ds = a.concatenate(b)
# ==> [ 1, 2, 3, 4, 5, 6, 7 ]
  • Dataset.zip()
    • Args:
      • (Dataset, Dataset, …)
    • Rets:
      • a new Dataset
a = tf.data.Dataset.range(1, 4)  # ==> [ 1, 2, 3 ]
b = tf.data.Dataset.range(4, 7) # ==> [ 4, 5, 6 ]
ds = tf.data.Dataset.zip((a, b))
# ==> [(1, 4), (2, 5), (3, 6)]

Element-level APIs

  • .filter()
    • Args:
      • predicate(func)
    • Rets:
      • a new Dataset containing only the elements for which predicate returns True
dataset = dataset.filter(lambda x: x < 3)
  • .map()
    • Args:
      • map_func
      • num_parallel_calls(=tf.data.experimental.AUTOTUNE)
      • deterministic
    • Rets:
      • a new Dataset
dataset = tf.data.Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
dataset = dataset.map(lambda x: x + 1)
  • .interleave()
    • Args:
      • map_func
      • cycle_length=AUTOTUNE
      • block_length=1
      • num_parallel_calls(=tf.data.experimental.AUTOTUNE)
      • deterministic=None
    • Rets:
      • a new Dataset

How to understand it:

  1. For each element of the source Dataset, map_func produces a new (inner) Dataset; up to cycle_length of these inner Datasets are open and consumed concurrently
# assume cycle_length = 3: three inner Datasets are open at once
d1 = map_func(x1)
d2 = map_func(x2)
d3 = map_func(x3)

# map_func(x4) is only opened once one of d1..d3 is exhausted
  2. The output visits the open inner Datasets round-robin, taking block_length consecutive elements from one before moving on to the next, until every element has been consumed; when an inner Dataset runs out, the next input element's Dataset takes its slot
# assume block_length = 2

d1 => [a1, a2]
d2 => [b1, b2]
d3 => [c1, c2]

d1 => [a3, a4]
...

# the final result is flattened
==> [a1, a2, b1, b2, c1, c2, a3, a4, ...]
  3. The final result is a single flattened Dataset (see the concrete sketch below)
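
A concrete sketch, mirroring the example in the TensorFlow docs (the repeat(6) inner Dataset is purely illustrative):

dataset = tf.data.Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
dataset = dataset.interleave(
    lambda x: tf.data.Dataset.from_tensors(x).repeat(6),
    cycle_length=2, block_length=4)
list(dataset.as_numpy_iterator())
# ==> [1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2,
#      3, 3, 3, 3, 4, 4, 4, 4, 3, 3, 4, 4,
#      5, 5, 5, 5, 5, 5]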

Exporting a Dataset

  • .prefetch()
    • Args:
      • buffer_size
    • Rets:
      • a new Dataset

The official guidance is to append .prefetch() after all other Dataset transformations

# prefetch two elements
dataset = dataset.prefetch(2)

# prefetch two batches
dataset = dataset.batch(20).prefetch(2)
  • .shuffle()
    • Args:
      • buffer_size
      • seed=None
      • reshuffle_each_iteration=None
    • Rets:
      • a new Dataset

The official guidance is buffer_size >= the number of elements in the Dataset, which yields a perfect shuffle
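
A minimal sketch (the toy size and seed are illustrative):

dataset = tf.data.Dataset.range(10)
# the buffer covers the whole Dataset, so this is a perfect shuffle
dataset = dataset.shuffle(buffer_size=10, seed=42)
list(dataset.as_numpy_iterator())
# ==> some permutation of [0, 1, ..., 9]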

  • .enumerate()
    • Args:
      • start=0
    • Rets:
      • a new Dataset in which each element becomes an (index, element) pair
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
dataset = dataset.enumerate(start=5)
for element in dataset.as_numpy_iterator():
    print(element)

'''
(5, 1)
(6, 2)
(7, 3)
'''
  • .as_numpy_iterator()
    • Rets:
      • an iterator whose elements have been converted to numpy values
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
for element in dataset.as_numpy_iterator():
    print(element)

'''
1
2
3
'''