to_torch_dataloader

scikitplot.corpus.to_torch_dataloader(documents, *, text_feature=True, raw_tensor_feature=False, embedding_feature=False, label_field=None, label_map=None, batch_size=32, shuffle=False, num_workers=0, dtype_map=None)

Convert documents to a torch.utils.data.DataLoader.

Parameters:
documents : list[CorpusDocument]

Documents to convert.

text_feature : bool, optional

Include "text" key (list of str per batch). Default: True.

raw_tensor_feature : bool, optional

Include "raw_tensor" key (torch.Tensor, NCHW float32). Requires all tensors to have the same shape. Default: False.

embedding_feature : bool, optional

Include "embedding" key (torch.Tensor, shape (N, D)). Default: False.

label_field : str or None, optional

Name of the document attribute to use as the label. Default: None.

label_map : dict[str, int] or None, optional

Map string labels to class indices. Default: None.

batch_size : int, optional

Batch size. Default: 32.

shuffle : bool, optional

Shuffle data each epoch. Default: False.

num_workers : int, optional

DataLoader worker processes. Default: 0 (main process only).

dtype_map : dict or None, optional

Per-key dtypes to cast tensors to, e.g. {"raw_tensor": torch.float32}. Default: None.

Returns:
torch.utils.data.DataLoader

DataLoader over a CorpusDataset.

Raises:
ImportError

If PyTorch is not installed.

Return type:

Any

Notes

Fallback: When PyTorch is not available, returns a dict of NumPy arrays so pipelines can test without GPU hardware.

Channel order: Raw tensors from ImageReader are (H, W, C) uint8 (channels-last). This function converts them to (C, H, W) float32 in [0, 1] (channels-first, PyTorch convention) when dtype_map is not set.
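The channel-order conversion described in the note can be sketched in plain NumPy. This is a minimal illustration, not the library's actual code: the helper names `hwc_uint8_to_chw_float32` and `stack_batch` are hypothetical, and dividing by 255 is an assumption about how the [0, 1] range is reached. The shape check mirrors the requirement that all raw tensors share one shape before they can be stacked into an NCHW batch.

```python
import numpy as np


def hwc_uint8_to_chw_float32(image):
    """Convert one (H, W, C) uint8 image to (C, H, W) float32 in [0, 1].

    Hypothetical helper; scaling by 255 is an assumption.
    """
    image = np.asarray(image)
    if image.ndim != 3 or image.dtype != np.uint8:
        raise ValueError("expected an (H, W, C) uint8 array")
    # Move the channel axis to the front: (H, W, C) -> (C, H, W).
    chw = np.transpose(image, (2, 0, 1))
    # Cast first, then scale, so the result stays float32.
    return chw.astype(np.float32) / 255.0


def stack_batch(images):
    """Stack converted images into one (N, C, H, W) array.

    Hypothetical helper; raises if the images do not share a shape.
    """
    converted = [hwc_uint8_to_chw_float32(img) for img in images]
    shapes = {arr.shape for arr in converted}
    if len(shapes) > 1:
        raise ValueError(f"all tensors must share one shape, got {sorted(shapes)}")
    return np.stack(converted)


# Two synthetic 4x6 RGB images stand in for ImageReader output.
rng = np.random.default_rng(0)
images = [rng.integers(0, 256, size=(4, 6, 3), dtype=np.uint8) for _ in range(2)]
batch = stack_batch(images)
print(batch.shape, batch.dtype)
```

If PyTorch is installed, an array like `batch` can be wrapped with `torch.from_numpy(batch)` without copying.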

Examples

Image classification loader:

>>> loader = to_torch_dataloader(
...     docs,
...     raw_tensor_feature=True,
...     label_field="source_type",
...     label_map={"image": 0, "research": 1},
...     batch_size=16,
... )
>>> for batch in loader:
...     imgs = batch["raw_tensor"]  # (16, C, H, W) float32
...     labels = batch["label"]  # (16,) int64
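How `label_map` turns string labels into the int64 class indices seen in `batch["label"]` can be sketched as follows. The helper `encode_labels` is hypothetical, and raising `KeyError` for a label missing from the map is an assumption; the documentation above does not say how the library handles unmapped labels.

```python
import numpy as np


def encode_labels(labels, label_map):
    """Map string labels to int64 class indices (hypothetical helper)."""
    # Assumption: unmapped labels are treated as an error.
    missing = sorted({lab for lab in labels if lab not in label_map})
    if missing:
        raise KeyError(f"labels missing from label_map: {missing}")
    return np.array([label_map[lab] for lab in labels], dtype=np.int64)


label_map = {"image": 0, "research": 1}
encoded = encode_labels(["image", "research", "image"], label_map)
print(encoded)
```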