# Two tensors can be combined into one Dataset object.
features = tf.constant([[1, 3], [2, 1], [3, 3]])  # ==> 3x2 tensor
labels = tf.constant(['A', 'B', 'A'])  # ==> 1-D tensor with 3 elements
dataset = Dataset.from_tensor_slices((features, labels))
# Both the features and the labels tensors can be converted
# to a Dataset object separately and combined after.
features_dataset = Dataset.from_tensor_slices(features)
labels_dataset = Dataset.from_tensor_slices(labels)
dataset = Dataset.zip((features_dataset, labels_dataset))
预处理
Dataset级别的api
.apply()
Args:
transformation func, from Dataset to Dataset
Rets:
a new Dataset
def dataset_fn(ds):
    return ds.filter(lambda x: x < 5)

dataset = dataset.apply(dataset_fn)
.batch()
Args:
batch_size
drop_remainder=False
新Dataset输出的数据在最外层新增一个维度,大小为batch_size;当drop_remainder=False时,最后一个batch的大小可能小于batch_size
dataset = dataset.batch(3, drop_remainder=True)
.shard()
Args:
num_shards
index
Rets:
一个新的Dataset,只包含原Dataset中1/num_shards的数据,即在原Dataset中的下标满足 下标 % num_shards == index 的那些元素
# enumerate(start=5) pairs each element with a counter starting at 5.
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
dataset = dataset.enumerate(start=5)
for element in dataset.as_numpy_iterator():
    print(element)
'''
(5, 1)
(6, 2)
(7, 3)
'''
.as_numpy_iterator()
Rets:
一个迭代器(不是新的Dataset),迭代时将Dataset中的每个元素转换为numpy类型后返回
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
for element in dataset.as_numpy_iterator():
    print(element)