Training¶
The bertblocks.training package provides PyTorch Lightning modules for pretraining
and finetuning, along with data loading, objectives, optimizers, and schedulers.
Lightning Modules¶
- class bertblocks.training.modules.BertBlocksPretrainingModule(
- learning_rate: float | None = 1e-05,
- weight_decay: float | None = 0.001,
- compile_model: bool | None = True,
- pretrained_tokenizer_name_or_path: str | None = None,
- optimizer_class: str | None = 'adamw',
- optimizer_quantized: bool = False,
- optimizer_kwargs: dict[str, Any] | None = None,
- scheduler_warmup_kind: Literal['constant', 'linear', 'cosine', 'exponential'] | None = 'linear',
- scheduler_warmup_steps: int | None = 1000,
- scheduler_warmup_decay: float = 0.1,
- scheduler_training_kind: Literal['constant', 'linear', 'cosine', 'exponential'] = 'constant',
- scheduler_training_steps: int = -1,
- scheduler_training_decay: float = 1.0,
- scheduler_cooldown_kind: Literal['constant', 'linear', 'inverse-sqrt', 'cosine', 'exponential'] = 'linear',
- scheduler_cooldown_steps: int = 0,
- scheduler_cooldown_decay: float = 0.0,
- objective: Literal['mlm', 'enhanced_mlm', 'diffusion'] = 'mlm',
- gradient_checkpointing: bool = False,
- model_config_kwargs: dict[str, Any] | None = None,
- model_kwargs: dict[str, Any] | None = None,
- )
Bases: LightningModule
PyTorch Lightning module for BertBlocks MLM pretraining.
This module encapsulates the complete training logic for BertBlocks pretraining, including model initialization, optimization setup, and training step implementation. It supports advanced features like model compilation, sophisticated learning rate scheduling, and automatic checkpoint saving.
The module automatically configures the BertBlocks model based on the provided hyperparameters and handles all aspects of the training loop.
- configure_model() None[source]¶
Compile the model core after DDP setup so each rank has its own CUDA context.
Only the inner compute path (embed -> encode -> norm -> scale) is compiled. Unpadding and output assembly remain outside the compiler scope to avoid dynamic shape issues from varying batch dimensions.
- configure_optimizers() tuple[list[Optimizer], list[dict[str, Any]]][source]¶
Configure optimizers and learning rate schedulers.
Sets up the optimizer with weight decay applied only to non-bias parameters (excluding norm parameters), together with a sequential learning rate schedule (warmup, training, and cooldown phases as configured).
- Returns:
- Contains optimizer list and scheduler configuration list.
The scheduler is configured to update every step during training.
- Return type:
- tuple[list[Optimizer], list[dict[str, Any]]]
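The decay/no-decay split described above can be sketched in plain Python (a minimal illustration, not the module's actual implementation; the name-based bias/norm heuristic is an assumption):

```python
def split_param_groups(named_params, weight_decay=0.001):
    """Apply weight decay only to weight matrices, not to biases or norm parameters.

    named_params: iterable of (name, shape) pairs, standing in for
    model.named_parameters().
    """
    decay, no_decay = [], []
    for name, shape in named_params:
        # 1-D parameters (biases, norm scales) are excluded from weight decay
        if len(shape) < 2 or "bias" in name or "norm" in name:
            no_decay.append(name)
        else:
            decay.append(name)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]
```

The two groups are then handed to the optimizer so that only genuine weight matrices are regularized.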
- on_save_checkpoint(*args: Any, **kwargs: Any) None[source]¶
Save model checkpoint in HuggingFace format.
This method is called whenever Lightning saves a checkpoint and additionally saves the model in HuggingFace format for easy loading and deployment. Only saves on the main process in distributed training.
- Parameters:
*args – Variable arguments (unused).
**kwargs – Keyword arguments (unused).
- training_step(batch: dict[str, Tensor], batch_idx: int) Tensor[source]¶
Perform a single training step.
- Parameters:
batch – Batch dictionary containing ‘input_ids’, ‘attention_mask’, and ‘labels’ tensors from the MLM collator.
batch_idx – Index of the current batch (unused but required by Lightning).
- Returns:
MLM loss for backpropagation.
- Return type:
- Tensor
- class bertblocks.training.modules.BertBlocksFinetuningModule(
- task: Literal['classification', 'token_classification', 'question_answering'],
- pretrained_model_name_or_path: str,
- num_labels: int | None = None,
- learning_rate: float = 1e-05,
- weight_decay: float = 0.01,
- compile_model: bool = True,
- optimizer_class: str = 'adamw',
- optimizer_kwargs: dict[str, Any] | None = None,
- optimizer_quantized: bool = False,
- scheduler_type: Literal['linear', 'cosine', 'constant', 'polynomial'] | None = None,
- scheduler_kwargs: dict[str, Any] | None = None,
- warmup_steps: int = 0,
- warmup_ratio: float = 0.0,
- gradient_checkpointing: bool = False,
- )
Bases: LightningModule
PyTorch Lightning module for BertBlocks finetuning.
This module handles finetuning of pretrained BertBlocks models on downstream tasks including classification, token classification, and question answering.
Data Modules¶
- class bertblocks.training.modules.BertBlocksPretrainingDataModule(
- train_dataset_name_or_path: str | list[str],
- pretrained_tokenizer_name_or_path: str,
- objective: Literal['mlm', 'enhanced_mlm', 'diffusion'] = 'mlm',
- max_sequence_length: int = 512,
- val_dataset_name_or_path: str | list[str] | None = None,
- train_split: str | None = None,
- val_split: str | list[str] | None = None,
- file_format: str | None = None,
- text_column: str = 'text',
- streaming: bool = False,
- shuffle: bool = False,
- num_shards: int | None = None,
- train_batch_size: int = 32,
- val_batch_size: int = 32,
- num_workers: int = 0,
- collator_kwargs: dict[str, Any] | None = None,
- packing: bool = False,
- packing_pad_to_budget: bool = False,
- cache_dir: str | None = None,
- data_kwargs: dict[str, Any] | None = None,
- )
Bases: LightningDataModule
PyTorch Lightning DataModule for BertBlocks MLM pretraining.
This DataModule handles all aspects of data loading for pretraining, including dataset preparation, tokenization, and batch creation. Currently configured to use the TinyStories dataset but can be easily adapted for other datasets.
The module supports streaming datasets for large-scale pretraining and includes configurable batch sizes and data loading parameters.
- prepare_data() None[source]¶
Download and tokenize data. Called only on LOCAL_RANK=0 per node (Lightning handles this automatically).
- setup(stage: str | None = None) None[source]¶
Load the dataset for training. Called on every process - loads from cache populated by prepare_data().
- train_dataloader() DataLoader[source]¶
Create the training data loader.
- Returns:
DataLoader configured for pretraining, optionally with packed batches.
- val_dataloader() DataLoader | list[DataLoader][source]¶
Create the validation data loader(s).
- Returns:
Single DataLoader, list of DataLoaders for multiple validation sets, or dummy dataloader if no validation dataset is set.
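Putting the pieces together, a typical pretraining run wires the data module and Lightning module into a Trainer (an illustrative sketch; the dataset, tokenizer name, and trainer flags are placeholders, not recommendations):

```python
import lightning as L

from bertblocks.training.modules import (
    BertBlocksPretrainingDataModule,
    BertBlocksPretrainingModule,
)

# Placeholder dataset and tokenizer names; substitute your own.
datamodule = BertBlocksPretrainingDataModule(
    train_dataset_name_or_path="roneneldan/TinyStories",
    pretrained_tokenizer_name_or_path="bert-base-uncased",
    objective="mlm",
    train_batch_size=32,
)
module = BertBlocksPretrainingModule(
    pretrained_tokenizer_name_or_path="bert-base-uncased",
    learning_rate=1e-5,
    objective="mlm",
)
trainer = L.Trainer(max_steps=10_000)
trainer.fit(module, datamodule=datamodule)
```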
- class bertblocks.training.modules.BertBlocksFinetuningDataModule(
- task: Literal['classification', 'question_answering', 'token_classification'],
- pretrained_tokenizer_name_or_path: str,
- dataset_name_or_path: str | None = None,
- dataset_config_name: str | None = None,
- max_sequence_length: int = 512,
- train_split: str = 'train',
- val_split: str = 'validation',
- test_split: str = 'test',
- text_column: str = 'text',
- label_column: str = 'label',
- train_batch_size: int = 32,
- val_batch_size: int = 32,
- test_batch_size: int = 32,
- num_workers: int = 0,
- shuffle_train: bool = True,
- collator_kwargs: dict[str, Any] | None = None,
- )
Bases: LightningDataModule
PyTorch Lightning DataModule for finetuning tasks.
Supports classification, token classification, and question answering tasks with flexible dataset loading from HuggingFace Hub or local files.
- set_datasets( ) None[source]¶
Manually set datasets for custom data loading.
- Parameters:
train – Training dataset.
val – Validation dataset.
test – Test dataset.
- setup(stage: str | None = None) None[source]¶
Set up datasets for each process. Called on every process.
- Parameters:
stage – Current stage (‘fit’, ‘validate’, ‘test’, ‘predict’).
- test_dataloader() DataLoader[source]¶
Create the test data loader.
- Returns:
DataLoader configured for testing or None if no test dataset.
- train_dataloader() DataLoader[source]¶
Create the training data loader.
- Returns:
DataLoader configured for training or None if no training dataset.
- val_dataloader() DataLoader[source]¶
Create the validation data loader.
- Returns:
DataLoader configured for validation or None if no validation dataset.
Training Objectives (Collators)¶
Collators prepare batches for specific training objectives. Each collator handles tokenization, masking, and label creation for its task.
- class bertblocks.training.objectives.Collator(
- pad_token_id: int,
- mask_token_id: int,
- vocab_size: int,
- text_column: str = 'text',
- label_column: str | None = None,
- max_sequence_length: int = 1024,
- )
Abstract data collator class for pretraining tasks.
A data collator is responsible for processing pretokenized input data into a format suitable for model training. This includes padding and applying any necessary transformations such as masking for language modeling tasks.
Subclasses must implement the compute_labels method.
- class bertblocks.training.objectives.MaskedLanguageModelingCollator(
- pad_token_id: int,
- mask_token_id: int,
- vocab_size: int,
- text_column: str = 'text',
- max_sequence_length: int = 1024,
- mlm_probability: float = 0.3,
- )
Bases: Collator
Data collator for masked language modeling pretraining.
Applies masking to create MLM training examples.
- Parameters:
pad_token_id (int) – Token ID used for padding sequences.
mask_token_id (int) – Token ID used for masking sequences.
vocab_size (int) – Size of the tokenizer vocabulary.
text_column (str) – Name of the column containing text data. Defaults to “text”.
max_sequence_length (int) – Maximum sequence length for padding. Defaults to 1024.
mlm_probability (float) – Probability of selecting each token for masking. Defaults to 0.3.
- compute_labels(tokenized: dict[str, Any]) Any[source]¶
Compute the MLM labels for the given batch of tokenized inputs.
Applies masking following the BERT MLM strategy. For tokens sampled with probability mlm_probability:
- 80% of the time: replace the token with [MASK]
- 10% of the time: replace the token with a random token
- 10% of the time: keep the token unchanged
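The 80/10/10 strategy can be sketched in plain Python (illustrative only; the real collator operates on batched tensors, and the token-ID constants and function name here are assumptions):

```python
import random

MASK_ID = 103        # assumed [MASK] id; depends on the tokenizer
VOCAB_SIZE = 30522   # assumed vocabulary size
IGNORE_INDEX = -100  # positions the loss should ignore

def apply_mlm_masking(input_ids, mlm_probability=0.3, rng=random):
    """Return (corrupted_ids, labels) following the BERT 80/10/10 scheme."""
    corrupted, labels = [], []
    for tok in input_ids:
        if rng.random() < mlm_probability:
            labels.append(tok)  # predict the original token at this position
            roll = rng.random()
            if roll < 0.8:
                corrupted.append(MASK_ID)                    # 80%: [MASK]
            elif roll < 0.9:
                corrupted.append(rng.randrange(VOCAB_SIZE))  # 10%: random token
            else:
                corrupted.append(tok)                        # 10%: unchanged
        else:
            corrupted.append(tok)
            labels.append(IGNORE_INDEX)
    return corrupted, labels
```

Positions with label -100 are skipped by the cross-entropy loss, so the model is trained only on the sampled positions.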
- class bertblocks.training.objectives.EnhancedMaskedLanguageModelingCollator(
- pad_token_id: int,
- mask_token_id: int,
- vocab_size: int,
- text_column: str = 'text',
- label_column: str | None = None,
- max_sequence_length: int = 1024,
- )
Bases: Collator
Data collator for enhanced masked language modeling pretraining.
Prepares pretokenized sequences for enhanced MLM. Unlike standard MLM, masking is not applied in the collator but is handled by the model itself.
Expects input to be already tokenized (handled in the data module).
- class bertblocks.training.objectives.TokenClassificationCollator(
- pad_token_id: int,
- mask_token_id: int,
- vocab_size: int,
- text_column: str = 'text',
- label_column: str = 'labels',
- max_sequence_length: int = 1024,
- )
Bases: Collator
Data collator for token classification tasks.
Handles formatting for token classification tasks like NER, POS tagging, etc. Pads pretokenized sequences and aligns token-level labels.
Expects input to be already tokenized (handled in the data module).
- Parameters:
pad_token_id (int) – Token ID used for padding sequences.
mask_token_id (int) – Token ID used for masking sequences.
vocab_size (int) – Size of the tokenizer vocabulary.
text_column (str) – Name of the column containing text data. Defaults to “text”.
label_column (str) – Name of the column containing label data. Defaults to “labels”.
max_sequence_length (int) – Maximum sequence length for padding. Defaults to 1024.
- compute_labels(tokenized: dict[str, Any]) dict[str, Any][source]¶
Compute token classification labels for the given batch.
For token classification, we need to ensure labels align with tokenized input. The labels should be padded to match the sequence length and use -100 for special tokens and padding.
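The padding side of this can be sketched as follows (a simplified illustration; the real collator also aligns labels to subword tokens, which is not shown here):

```python
IGNORE_INDEX = -100  # ignored by the cross-entropy loss

def pad_token_labels(labels, max_length):
    """Pad (or truncate) per-token labels to max_length.

    Padding positions get -100 so the loss skips them, matching the
    convention used for special tokens.
    """
    padded = labels[:max_length]
    padded = padded + [IGNORE_INDEX] * (max_length - len(padded))
    return padded
```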
- class bertblocks.training.objectives.SequenceClassificationCollator(
- pad_token_id: int,
- mask_token_id: int,
- vocab_size: int,
- text_column: str = 'text',
- label_column: str = 'label',
- max_sequence_length: int = 1024,
- )
Bases: Collator
Data collator for sequence classification tasks.
Handles formatting for sequence classification tasks like sentiment analysis, text classification, etc. Pads pretokenized sequences and preserves sequence-level labels.
- Parameters:
pad_token_id (int) – Token ID used for padding sequences.
mask_token_id (int) – Token ID used for masking sequences.
vocab_size (int) – Size of the tokenizer vocabulary.
text_column (str) – Name of the column containing text data. Defaults to “text”.
label_column (str) – Name of the column containing label data. Defaults to “label”.
max_sequence_length (int) – Maximum sequence length for padding. Defaults to 1024.
- class bertblocks.training.objectives.QuestionAnsweringCollator(
- pad_token_id: int,
- mask_token_id: int,
- vocab_size: int,
- text_column: str = 'question',
- label_column: str = 'answers',
- context_column: str = 'context',
- max_sequence_length: int = 1024,
- doc_stride: int = 128,
- )
Bases: Collator
Data collator for question answering tasks.
Handles formatting for question answering tasks like SQuAD. Pads pretokenized question-context pairs and preserves start/end positions for answer span prediction.
- Parameters:
pad_token_id (int) – Token ID used for padding sequences.
mask_token_id (int) – Token ID used for masking sequences.
vocab_size (int) – Size of the tokenizer vocabulary.
text_column (str) – Name of the column containing question data. Defaults to “question”.
label_column (str) – Name of the column containing answer data. Defaults to “answers”.
context_column (str) – Name of the column containing context data. Defaults to “context”.
max_sequence_length (int) – Maximum sequence length for padding. Defaults to 1024.
doc_stride (int) – Stride for sliding window when context is too long. Defaults to 128.
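The doc_stride sliding window works roughly like this (a minimal sketch over context token ids; the real collator also prepends the question to each window and tracks answer spans):

```python
def sliding_windows(context_ids, max_length=384, doc_stride=128):
    """Split a long context into overlapping windows of at most max_length tokens.

    Consecutive windows overlap by max_length - doc_stride tokens, so an
    answer span near a window boundary is fully contained in at least one window.
    """
    windows = []
    start = 0
    while True:
        windows.append(context_ids[start:start + max_length])
        if start + max_length >= len(context_ids):
            break
        start += doc_stride
    return windows
```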
- class bertblocks.training.objectives.MaskedDiffusionCollator(
- pad_token_id: int,
- mask_token_id: int,
- vocab_size: int,
- text_column: str = 'text',
- max_sequence_length: int = 1024,
- num_steps: int = 1000,
- sampling_eps: float = 0.1,
- noise_eps: float = 0.001,
- min_masked: int | None = None,
- max_masked: int | None = None,
- )
Bases: Collator
Diffusion-based MLM collator that samples random masking probabilities per batch.
Randomly masks tokens using a sampled masking probability from a diffusion schedule. Expects input to be already tokenized (handled in the data module).
- Parameters:
pad_token_id (int) – Token ID used for padding sequences.
mask_token_id (int) – Token ID to use for masking.
vocab_size (int) – Size of the tokenizer vocabulary.
text_column (str) – The name of the column containing text data.
max_sequence_length (int) – The maximum length for padding.
num_steps (int) – Number of diffusion steps.
sampling_eps (float) – Minimum timestep for sampling, controls minimum masking probability.
noise_eps (float) – Epsilon for the noise schedule, controls maximum masking probability (1 - noise_eps).
min_masked (int | None) – Minimum number of masked tokens per sequence. If None, not enforced.
max_masked (int | None) – Maximum number of masked tokens per sequence. If None, not enforced.
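One plausible reading of the sampling_eps/noise_eps bounds is a per-batch masking probability drawn from a schedule over [sampling_eps, 1 - noise_eps]. The sketch below uses a simple linear mapping from timestep to probability; this is an illustrative schedule, not necessarily the one the package implements:

```python
import random

def sample_mask_probability(num_steps=1000, sampling_eps=0.1, noise_eps=0.001,
                            rng=random):
    """Draw a per-batch masking probability from a discrete schedule.

    A timestep t is sampled uniformly from {1, ..., num_steps} and mapped
    linearly onto (sampling_eps, 1 - noise_eps]; this keeps the probability
    away from both 0 (nothing to predict) and 1 (no context left).
    """
    t = rng.randrange(1, num_steps + 1)
    frac = t / num_steps
    return sampling_eps + frac * (1.0 - noise_eps - sampling_eps)
```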
- bertblocks.training.objectives.get_collator_cls(
- objective: Literal['mlm', 'enhanced_mlm', 'classification', 'token_classification', 'question_answering', 'diffusion'],
- )
Get the appropriate data collator class for the given objective.
- Parameters:
objective (str) – The training objective. Available options: “mlm”, “enhanced_mlm”, “classification”, “token_classification”, “question_answering”, “diffusion”.
- Raises:
ValueError – If the objective is unknown.
- Returns:
The corresponding data collator class.
- Return type:
- type[Collator]
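The function behaves like a simple dispatch table (an illustrative sketch using class names as strings; the real function returns the collator classes documented above):

```python
# Illustrative mapping mirroring get_collator_cls.
_COLLATORS = {
    "mlm": "MaskedLanguageModelingCollator",
    "enhanced_mlm": "EnhancedMaskedLanguageModelingCollator",
    "classification": "SequenceClassificationCollator",
    "token_classification": "TokenClassificationCollator",
    "question_answering": "QuestionAnsweringCollator",
    "diffusion": "MaskedDiffusionCollator",
}

def get_collator_cls_sketch(objective):
    """Look up the collator for an objective; unknown names raise ValueError."""
    try:
        return _COLLATORS[objective]
    except KeyError:
        raise ValueError(f"Unknown objective: {objective!r}") from None
```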
Optimizers¶
- bertblocks.training.optimizer.get_optimizer(
- optimizer_name: Literal['adafactor', 'adagrad', 'adam', 'adamw', 'galore', 'lamb', 'lars', 'lion', 'muon', 'sgd', 'shampoo', 'soap', 'sophiah', 'splus', 'rmsprop'],
- params: list[dict[str, Any]],
- optimizer_kwargs: dict[str, Any],
- quantized: bool = False,
- )
Instantiate the specified optimizer with the given parameters and hyperparameters and return it.
- Parameters:
optimizer_name – Name of the optimizer to instantiate.
params – Parameter groups to optimize.
optimizer_kwargs – Additional keyword arguments for the optimizer constructor.
quantized – Whether to use a quantized variant of the optimizer. Defaults to False.
- Returns:
The instantiated optimizer.
- Raises:
ValueError – If the specified optimizer name is not recognized.
ImportError – If the specified optimizer name requires a non-installed optional dependency.
- Supported optimizer types:
adafactor
adagrad
adam
adamw
galore
lamb
lars
lion
muon
sgd
shampoo
soap
sophiah
splus
rmsprop
Schedulers¶
- class bertblocks.training.scheduler.InverseSqrtScheduler(optimizer: Optimizer, cooldown_steps: int, last_epoch: int = -1)[source]¶
Bases: LambdaLR
A scheduler that applies inverse square-root scaling.
Scaling is a function of the step count (“1-sqrt cooldown”): the learning rate is scaled down by the factor 1 - sqrt(current_step / cooldown_steps). This scheduler is intended for cooldown phases. It has been reported to be superior to linear cooldown and is presented as an alternative to cosine decay.
- Parameters:
optimizer – The optimizer to use.
cooldown_steps – The number of cooldown steps. If this number is exceeded, the scheduler will transition to a constant learning rate of 0.0.
last_epoch – The index of the last epoch. Defaults to -1.
References
Scaling Laws and Compute-Optimal Training Beyond Fixed Training Durations (https://arxiv.org/abs/2405.18392)
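The 1-sqrt cooldown factor can be written directly (a sketch of the multiplicative factor the scheduler applies per step; the class itself wraps this in a LambdaLR):

```python
import math

def inverse_sqrt_cooldown_factor(step, cooldown_steps):
    """Multiplicative LR factor: 1 - sqrt(step / cooldown_steps), 0.0 afterwards."""
    if step >= cooldown_steps:
        return 0.0  # past the cooldown window, hold the LR at zero
    return 1.0 - math.sqrt(step / cooldown_steps)
```

Because the square root grows fastest near zero, most of the decay happens early in the cooldown, with a long shallow tail toward zero.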
- bertblocks.training.scheduler.get_scheduler(
- optimizer: Optimizer,
- warmup_kind: Literal['linear', 'constant'] = 'linear',
- warmup_steps: int = 0,
- warmup_decay: float = 0.0,
- training_kind: Literal['constant', 'linear', 'cosine', 'exponential'] = 'constant',
- training_steps: int = -1,
- training_decay: float = 1.0,
- cooldown_kind: Literal['linear', 'inverse-sqrt', 'exponential'] = 'linear',
- cooldown_steps: int = 0,
- cooldown_decay: float = 0.0,
- )
Construct a sequential learning rate schedule with three phases: warmup, training, and cooldown.
- Parameters:
optimizer – The optimizer to schedule the learning rate for.
warmup_kind – Kind of scheduler for warmup phase. Defaults to ‘linear’.
warmup_steps – Duration in steps for warmup phase. Defaults to 0 (no warmup).
warmup_decay – Decay value for warmup phase; has different effect depending on warmup_kind. Defaults to 0.0.
training_kind – Kind of scheduler for training phase.
training_steps – Duration in steps for training phase.
training_decay – Decay value for training phase; has different effect depending on training_kind. Defaults to 1.0 (constant learning rate).
cooldown_kind – Kind of scheduler for cooldown phase. Defaults to ‘linear’.
cooldown_steps – Duration in steps for cooldown phase. Defaults to 0 (no cooldown).
cooldown_decay – Decay value for cooldown phase; has different effect depending on cooldown_kind. Defaults to 0.0.
- Returns:
Learning rate scheduler with sequential phases as defined.
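The three-phase composition can be sketched as a piecewise LR multiplier (simplified to linear warmup, constant training, and linear cooldown; the real implementation supports the other kinds listed above and is built from per-phase schedulers):

```python
def lr_factor(step, warmup_steps, training_steps, cooldown_steps):
    """Piecewise LR multiplier: linear warmup -> constant -> linear cooldown to 0."""
    if step < warmup_steps:
        return step / max(warmup_steps, 1)   # ramp 0 -> 1
    step -= warmup_steps
    if step < training_steps:
        return 1.0                           # hold at the base learning rate
    step -= training_steps
    if step < cooldown_steps:
        return 1.0 - step / cooldown_steps   # ramp 1 -> 0
    return 0.0
```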
- bertblocks.training.scheduler.get_single_scheduler(
- optimizer: Optimizer,
- kind: Literal['constant', 'linear', 'inverse-sqrt', 'exponential', 'cosine'] = 'constant',
- num_steps: int = 0,
- decay: float = 0.0,
- direction: Literal['increase', 'decrease'] = 'increase',
- )
Return the instantiated scheduler corresponding to the provided configuration.
- Parameters:
optimizer – Optimizer to schedule the learning rate for.
kind – Kind of scheduler. One of ‘constant’, ‘linear’, ‘inverse-sqrt’, ‘exponential’, ‘cosine’. Defaults to ‘constant’.
num_steps – Number of steps to schedule. Defaults to 0.
decay – Decay value, depending on kind of scheduler. If ‘constant’, is applied as factor; if ‘linear’, is applied as start_factor; if ‘exponential’, is applied as gamma; if ‘cosine’, is applied as eta_min.
direction – Direction of scheduler. Defaults to ‘increase’.
- Returns:
The specified scheduler.