bertblocks.training.packing¶
Sequence packing using a batch sampler and a collator wrapper to produce fully flat-packed batches.
Classes¶
PackingBatchSampler
Distributed batch sampler that packs variable-length sequences into token-budget batches.
PackingCollatorWrapper
Wrapper that packs output from existing collators into flat tensors.
PackingIterableDataset
Wrapper for iterable datasets that packs continuous sequences into batches.
Functions¶
is_packed_batch
Detect if attention mask uses packed sequence index format.
Module Contents¶
- class bertblocks.training.packing.PackingBatchSampler(
- lengths: list[int],
- token_budget: int,
- world_size: int | None = None,
- rank: int | None = None,
- shuffle: bool = True,
- seed: int = 0,
- drop_last: bool = False,
- )
Bases:
torch.utils.data.Sampler[list[int]]
Distributed batch sampler that packs variable-length sequences into token-budget batches.
This sampler follows PyTorch’s DistributedSampler pattern: it packs ALL sequences globally into batches, truncates to ensure even distribution across ranks, then assigns batches round-robin to each rank. All ranks independently compute the same global packing using deterministic shuffling, ensuring no distributed communication is needed.
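The pack-globally-then-assign-round-robin scheme described above can be sketched as follows. This is an illustrative first-fit packing under the stated shuffle/seed semantics, not the sampler's actual heuristic, and `global_packing_sketch` is a hypothetical name:

```python
import torch

def global_packing_sketch(lengths, token_budget, world_size, rank, seed=0, shuffle=True):
    # Deterministic shuffle: every rank draws the same permutation from the
    # shared seed, so all ranks agree on the packing without communication.
    if shuffle:
        g = torch.Generator()
        g.manual_seed(seed)
        indices = torch.randperm(len(lengths), generator=g).tolist()
    else:
        indices = list(range(len(lengths)))

    # Greedy packing: start a new batch whenever the budget would be exceeded.
    # (Illustrative only; the real sampler's packing heuristic may differ.)
    batches, current, current_tokens = [], [], 0
    for idx in indices:
        n = lengths[idx]
        if current and current_tokens + n > token_budget:
            batches.append(current)
            current, current_tokens = [], 0
        current.append(idx)
        current_tokens += n
    if current:
        batches.append(current)

    # Truncate so every rank sees the same number of batches, then hand out
    # batches round-robin: rank r takes batches r, r + world_size, ...
    usable = (len(batches) // world_size) * world_size
    return batches[rank:usable:world_size]
```

Because each rank recomputes the identical global packing from the seed, the round-robin slice is all that differs between processes.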
- Parameters:
lengths – Sequence lengths for all dataset samples.
token_budget – Maximum tokens per packed batch.
world_size – Number of processes participating in distributed training. If None, uses world_size from current distributed group.
rank – Rank of current process within num_replicas. If None, uses rank from current distributed group.
shuffle – Whether to shuffle indices before packing. Default: True.
seed – Random seed for shuffling. Should be identical across all processes. Default: 0.
drop_last – Drop the last batch, which might not be fully packed. Default: False.
Example
>>> from datasets import load_dataset
>>> from torch.utils.data import DataLoader
>>> dataset = load_dataset("dataset_name", split="train")
>>> batch_sampler = PackingBatchSampler(list(dataset["length"]), token_budget=4096)
>>> collator = PackingCollatorWrapper(base_collator, token_budget=4096)
>>> dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collator)
- __iter__() → collections.abc.Iterator[list[int]][source]¶
Return an iterator over the current rank’s batches.
- class bertblocks.training.packing.PackingCollatorWrapper(
- base_collator: collections.abc.Callable[[list[dict[str, Any]]], dict[str, Any]],
- token_budget: int = 0,
- )
Wrapper that packs output from existing collators into flat tensors.
This wrapper takes any existing collator (MLM, classification, token classification, etc.) and wraps its output into the packed format expected by the model:
- Flat tensors of shape [1, total_tokens] (if token_budget > 0, padded to [1, token_budget])
- Attention mask with sequence indices (0, 1, 2, …) and -1 for padding
The wrapper automatically detects the label type and handles both:
- Token-level labels (e.g., MLM): shape [batch_size, seq_len] -> packed to [1, total_tokens]
- Sequence-level labels (e.g., classification): shape [batch_size] -> packed to [num_sequences]
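The attention-mask convention described above can be illustrated with a small helper; `make_packed_mask` is a hypothetical name for demonstration, not part of the library:

```python
import torch

def make_packed_mask(seq_lengths, token_budget=0):
    # Each token carries the 0-based index of the sequence it belongs to,
    # and padding positions are marked with -1, as described above.
    parts = [torch.full((n,), i, dtype=torch.long) for i, n in enumerate(seq_lengths)]
    mask = torch.cat(parts)
    if token_budget > 0:
        pad = token_budget - mask.numel()
        mask = torch.cat([mask, torch.full((pad,), -1, dtype=torch.long)])
    return mask.unsqueeze(0)  # shape [1, total_tokens] or [1, token_budget]
```

For example, two sequences of lengths 3 and 2 packed into a budget of 8 yield the mask `[[0, 0, 0, 1, 1, -1, -1, -1]]`, from which per-sequence boundaries can be recovered without a padded 2-D layout.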
- Parameters:
base_collator – The underlying collator to wrap.
token_budget – Maximum total tokens per packed batch. If 0, no padding is applied and output is returned with dynamic shape [1, total_tokens]. Defaults to 0.
Example
>>> from bertblocks.training.objectives import MaskedLanguageModelingCollator
>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
>>> mlm_collator = MaskedLanguageModelingCollator(tokenizer, max_sequence_length=512)
>>> # With padding for fixed-size batches
>>> packing_collator = PackingCollatorWrapper(mlm_collator, token_budget=512 * 32)
>>> # Without padding for dynamic-size batches
>>> flat_collator = PackingCollatorWrapper(mlm_collator, token_budget=0)
- __call__(batch: list[dict[str, Any]]) → dict[str, torch.Tensor][source]¶
Pack a batch of samples into flat tensors.
- Parameters:
batch – List of samples from the dataset.
- Returns:
input_ids: [1, token_budget] (if token_budget > 0) or [1, total_tokens] (if token_budget == 0)
attention_mask: [1, token_budget] or [1, total_tokens] with 0-based sequence indices and -1 for padding
- labels (if present):
Token-level: [1, token_budget] or [1, total_tokens]
Sequence-level: [num_sequences] or [num_sequences, num_classes]
- Return type:
Dictionary with packed tensors
- class bertblocks.training.packing.PackingIterableDataset(
- dataset: torch.utils.data.IterableDataset,
- token_budget: int,
- length_column: str = 'length',
- drop_last: bool = False,
- )
Bases:
torch.utils.data.IterableDataset
Wrapper for iterable datasets that packs continuous sequences into batches.
This wrapper is designed to be used with PyTorch DataLoader for streaming datasets. Each iteration yields a list of samples that have been packed together based on their sequence lengths to maximize GPU utilization.
The underlying dataset must include a length field (specified by length_column) in each sample. This can be added using dataset.map() before wrapping.
In distributed training with num_workers > 0, sharding is handled automatically by PyTorch’s DataLoader worker sharding mechanism for IterableDataset.
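The streaming packing behaviour can be sketched as a greedy generator. `pack_stream` is an illustrative stand-in for the wrapper's iteration logic, assuming samples carry the length field described above:

```python
def pack_stream(samples, token_budget, length_column="length", drop_last=False):
    # Greedily accumulate samples until adding the next one would exceed
    # the token budget, then yield the accumulated batch. A sketch of the
    # behaviour described above, not the library's implementation.
    batch, tokens = [], 0
    for sample in samples:
        n = sample[length_column]
        if batch and tokens + n > token_budget:
            yield batch
            batch, tokens = [], 0
        batch.append(sample)
        tokens += n
    if batch and not drop_last:
        yield batch
```

Because the generator only ever looks at the current sample, it works on an unbounded stream with constant memory, which is what makes this approach suitable for streaming datasets.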
- Parameters:
dataset – The underlying iterable dataset to wrap.
token_budget – Maximum total tokens per batch (typically max_length * batch_size).
length_column – Name of the field containing sequence length. Defaults to “length”.
drop_last – Drop the last incomplete batch. Defaults to False.
Example
>>> from datasets import load_dataset
>>> from torch.utils.data import DataLoader
>>> dataset = load_dataset("dataset_name", streaming=True, split="train")
>>> dataset = dataset.map(...)  # Add length column if not already present
>>> packed_dataset = PackingIterableDataset(dataset, token_budget=4096)
>>> collator = PackingCollatorWrapper(base_collator, token_budget=4096)
>>> dataloader = DataLoader(packed_dataset, batch_size=None, collate_fn=collator)
- bertblocks.training.packing.is_packed_batch(attention_mask: torch.Tensor | None) → bool[source]¶
Detect if attention mask uses packed sequence index format.
Packed format uses sequence indices (0, 1, 2, …) with -1 for padding, while standard format uses binary 0/1 values.
- Parameters:
attention_mask – Attention mask tensor to check, or None.
- Returns:
True if packed format, False if standard format or None.
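The detection rule above can be sketched as follows; `is_packed_sketch` is a hypothetical reimplementation for illustration, not the library's code. Since packed masks carry sequence indices and -1 padding, any value outside {0, 1} implies the packed format:

```python
import torch

def is_packed_sketch(attention_mask):
    # Packed masks contain sequence indices (possibly > 1) and -1 padding;
    # standard masks contain only binary 0/1 values.
    if attention_mask is None:
        return False
    return bool(((attention_mask < 0) | (attention_mask > 1)).any())
```

One caveat of any value-based detector: a packed batch holding a single sequence with no padding is all zeros and therefore indistinguishable from a standard mask, so this sketch classifies that edge case as standard format.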