to_tensorflow_dataset
- scikitplot.corpus.to_tensorflow_dataset(documents, *, text_feature=True, raw_tensor_feature=False, embedding_feature=False, label_field=None, label_map=None, batch_size=32, shuffle=False, shuffle_seed=None, dtype_map=None)
Convert documents to a tf.data.Dataset.
- Parameters:
- documents : list[CorpusDocument]
Documents to convert.
- text_feature : bool, optional
Include "text" feature (tf.string). Default: True.
- raw_tensor_feature : bool, optional
Include "raw_tensor" feature (tf.uint8) when documents carry pixel arrays. Requires all tensors to share the same shape. Default: False.
- embedding_feature : bool, optional
Include "embedding" feature (tf.float32). Default: False.
- label_field : str or None, optional
CorpusDocument attribute to use as label (e.g. "source_type"). Default: None (no label).
- label_map : dict[str, int] or None, optional
Map string label values to integer class ids. Required when label_field is set and the field contains strings. Default: None.
- batch_size : int, optional
Batch size. None disables batching. Default: 32.
- shuffle : bool, optional
Shuffle the dataset before batching. Default: False.
- shuffle_seed : int or None, optional
Seed for deterministic shuffling. Default: None.
- dtype_map : dict or None, optional
Cast feature dtypes, e.g. {"raw_tensor": tf.float32}.
- Returns:
- tf.data.Dataset
Batched dataset of feature dicts (and optionally labels).
- Raises:
- ImportError
If TensorFlow is not installed.
- ValueError
If raw_tensor_feature is True but raw tensors have different shapes across documents.
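The shape requirement behind the ValueError can be checked before calling the function. The following is a minimal sketch in plain Python, not the library's own validation: `check_uniform_shapes` is a hypothetical helper, and the idea that each document exposes its pixel array under an attribute named `raw_tensor` is an assumption.

```python
def check_uniform_shapes(shapes):
    """Return the common shape, or raise ValueError if any shape differs from the first."""
    shapes = [tuple(shape) for shape in shapes]
    if not shapes:
        return None  # nothing to compare
    expected = shapes[0]
    mismatched = [i for i, shape in enumerate(shapes) if shape != expected]
    if mismatched:
        raise ValueError(
            f"expected every raw tensor to have shape {expected}, "
            f"but documents at indices {mismatched} differ"
        )
    return expected


# With real documents one might pass something like
# [doc.raw_tensor.shape for doc in docs]; that attribute name is an assumption.
common_shape = check_uniform_shapes([(224, 224, 3), (224, 224, 3)])
print(common_shape)
```

Running such a check up front reports which documents need resizing, rather than only that the conversion failed.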
Notes
Fallback: When TensorFlow is not available, returns a dict of NumPy arrays (via to_numpy_arrays) so pipelines can test the shape of the output without requiring a GPU environment.
Examples
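Since the fallback described in the Notes returns a plain dict instead of a tf.data.Dataset, calling code may want to branch on the result type. The sketch below is illustrative only: `summarize_output` and `FakeArray` are hypothetical names, and the assumption that the fallback dict is keyed by feature name (for example "text") is not confirmed by this page.

```python
class FakeArray:
    """Stand-in for a NumPy array; only the .shape attribute is used."""

    def __init__(self, shape):
        self.shape = shape


def summarize_output(result):
    """Return {feature_name: shape} for a dict of arrays, or None for anything else."""
    if isinstance(result, dict):
        return {
            name: tuple(getattr(value, "shape", ()))
            for name, value in result.items()
        }
    return None  # e.g. a real tf.data.Dataset, which is not a dict


# Simulated fallback output; the key "text" is assumed to mirror the feature name.
fallback = {"text": FakeArray((16,))}
summary = summarize_output(fallback)
print(summary)
```

A pipeline test can then assert on the returned shapes without importing TensorFlow.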
Text-only dataset for a Keras text classifier:
>>> ds = to_tensorflow_dataset(docs, text_feature=True, batch_size=16)
>>> for batch in ds.take(1):
...     print(batch["text"].shape)  # (16,)
Image dataset for a CNN:
>>> ds = to_tensorflow_dataset(
...     docs,
...     text_feature=False,
...     raw_tensor_feature=True,
...     label_field="source_type",
...     label_map={"image": 0, "research": 1},
... )
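A label_map like the one above can also be derived from the documents instead of written by hand. This is a minimal sketch under stated assumptions: `FakeDocument` is a hypothetical stand-in for CorpusDocument, `build_label_map` is not part of the library, and sorting the unique values is just one way to get stable class ids.

```python
from dataclasses import dataclass


@dataclass
class FakeDocument:
    """Hypothetical stand-in for CorpusDocument; only the label attribute matters here."""

    text: str
    source_type: str


def build_label_map(documents, label_field):
    """Assign consecutive integer ids to the sorted unique values of `label_field`."""
    values = sorted({getattr(doc, label_field) for doc in documents})
    return {value: index for index, value in enumerate(values)}


docs = [
    FakeDocument("a scanned figure", "image"),
    FakeDocument("an abstract", "research"),
    FakeDocument("another figure", "image"),
]
label_map = build_label_map(docs, "source_type")
print(label_map)  # {'image': 0, 'research': 1}
```

Sorting keeps the ids reproducible across runs, which matters when the same mapping has to be reused at inference time.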