to_torch_dataloader
- scikitplot.corpus.to_torch_dataloader(documents, *, text_feature=True, raw_tensor_feature=False, embedding_feature=False, label_field=None, label_map=None, batch_size=32, shuffle=False, num_workers=0, dtype_map=None)
Convert documents to a torch.utils.data.DataLoader.

- Parameters:
- documents : list[CorpusDocument]
Documents to convert.
- text_feature : bool, optional
Include the "text" key (a list of str per batch). Default: True.
- raw_tensor_feature : bool, optional
Include the "raw_tensor" key (torch.Tensor, NCHW float32). Requires all tensors to have the same shape. Default: False.
- embedding_feature : bool, optional
Include the "embedding" key (torch.Tensor, shape (N, D)). Default: False.
- label_field : str or None, optional
Name of the document attribute to use as the label. Default: None.
- label_map : dict[str, int] or None, optional
Mapping from string labels to integer class indices. Default: None.
- batch_size : int, optional
Number of documents per batch. Default: 32.
- shuffle : bool, optional
Reshuffle the data at every epoch. Default: False.
- num_workers : int, optional
Number of DataLoader worker processes. Default: 0 (load in the main process only).
- dtype_map : dict or None, optional
Per-key dtype casts for tensors, e.g. {"raw_tensor": torch.float32}. Default: None.
- Returns:
- torch.utils.data.DataLoader
DataLoader over a CorpusDataset.
- Raises:
- ImportError
If PyTorch is not installed.
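The parameters above imply a per-batch structure: a "text" list of strings and, when label_field and label_map are given, an integer "label" array. The NumPy-only sketch below shows one way that label lookup and batching could work; it is not the library's implementation. `Doc`, `collate_batch`, and `iter_batches` are hypothetical names, and raising KeyError for a label missing from label_map is an assumption.

```python
from dataclasses import dataclass
from typing import Any, Dict, Iterator, List, Optional

import numpy as np


@dataclass
class Doc:
    """Hypothetical stand-in for CorpusDocument (only the fields used here)."""
    text: str
    source_type: str


def collate_batch(
    docs: List[Doc],
    label_field: Optional[str] = None,
    label_map: Optional[Dict[str, int]] = None,
) -> Dict[str, Any]:
    """Hypothetical collate step: build one batch dict from a list of docs."""
    batch: Dict[str, Any] = {"text": [d.text for d in docs]}  # list of str per batch
    if label_field is not None:
        raw = [getattr(d, label_field) for d in docs]
        if label_map is not None:
            # Assumption: an unmapped label raises KeyError instead of being skipped.
            batch["label"] = np.asarray([label_map[r] for r in raw], dtype=np.int64)
        else:
            # Without a label_map the values are kept as-is; NumPy infers the dtype.
            batch["label"] = np.asarray(raw)
    return batch


def iter_batches(
    docs: List[Doc], batch_size: int = 32, **kwargs: Any
) -> Iterator[Dict[str, Any]]:
    """Yield batches of up to batch_size docs in input order (like shuffle=False)."""
    for start in range(0, len(docs), batch_size):
        yield collate_batch(docs[start:start + batch_size], **kwargs)
```

With batch_size=2 and three documents, the last batch holds the single remaining document, mirroring how a DataLoader without drop_last yields a short final batch.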
Notes
Fallback: When PyTorch is not available, returns a dict of NumPy arrays so that pipelines can be tested without GPU hardware.
Channel order: Raw tensors from ImageReader are (H, W, C) uint8 (channels-last). This function converts them to (C, H, W) float32 in [0, 1] (channels-first, PyTorch convention) when dtype_map is not set.

Examples
Image classification loader:

>>> loader = to_torch_dataloader(
...     docs,
...     raw_tensor_feature=True,
...     label_field="source_type",
...     label_map={"image": 0, "research": 1},
...     batch_size=16,
... )
>>> for batch in loader:
...     imgs = batch["raw_tensor"]  # (16, C, H, W) float32
...     labels = batch["label"]     # (16,) int64
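The channel-order note describes a transpose from (H, W, C) uint8 to (C, H, W) float32 in [0, 1]. The page states only the target range, so the sketch below assumes the rescale is a plain division by 255; it also checks the same-shape requirement of raw_tensor_feature before stacking into an NCHW batch. Both function names are hypothetical and the code uses NumPy rather than torch.

```python
from typing import List

import numpy as np


def hwc_uint8_to_chw_float32(image: np.ndarray) -> np.ndarray:
    """Convert an (H, W, C) uint8 image to (C, H, W) float32 in [0, 1]."""
    if image.ndim != 3 or image.dtype != np.uint8:
        raise ValueError("expected an (H, W, C) uint8 array")
    # Move the channel axis to the front: (H, W, C) -> (C, H, W).
    chw = np.transpose(image, (2, 0, 1))
    # Assumption: scaling is a plain division by 255, so 0 -> 0.0 and 255 -> 1.0.
    return chw.astype(np.float32) / np.float32(255.0)


def stack_batch(images: List[np.ndarray]) -> np.ndarray:
    """Stack same-shape images into one (N, C, H, W) float32 array."""
    chw = [hwc_uint8_to_chw_float32(img) for img in images]
    shapes = {a.shape for a in chw}
    if len(shapes) != 1:
        # Mirrors the documented requirement that all raw tensors share one shape.
        raise ValueError(f"all images must share one shape, got {sorted(shapes)}")
    return np.stack(chw, axis=0)
```

In PyTorch the equivalent step is usually a permute followed by a float cast; the NumPy version keeps the sketch runnable without PyTorch installed.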