Training

The bertblocks.training package provides PyTorch Lightning modules for pretraining and finetuning, along with data loading, objectives, optimizers, and schedulers.

Lightning Modules

class bertblocks.training.modules.BertBlocksPretrainingModule(
learning_rate: float | None = 1e-05,
weight_decay: float | None = 0.001,
compile_model: bool | None = True,
pretrained_tokenizer_name_or_path: str | None = None,
optimizer_class: str | None = 'adamw',
optimizer_quantized: bool = False,
optimizer_kwargs: dict[str, Any] | None = None,
scheduler_warmup_kind: Literal['constant', 'linear', 'cosine', 'exponential'] | None = 'linear',
scheduler_warmup_steps: int | None = 1000,
scheduler_warmup_decay: float = 0.1,
scheduler_training_kind: Literal['constant', 'linear', 'cosine', 'exponential'] = 'constant',
scheduler_training_steps: int = -1,
scheduler_training_decay: float = 1.0,
scheduler_cooldown_kind: Literal['constant', 'linear', 'inverse-sqrt', 'cosine', 'exponential'] = 'linear',
scheduler_cooldown_steps: int = 0,
scheduler_cooldown_decay: float = 0.0,
objective: Literal['mlm', 'enhanced_mlm', 'diffusion'] = 'mlm',
gradient_checkpointing: bool = False,
model_config_kwargs: dict[str, Any] | None = None,
model_kwargs: dict[str, Any] | None = None,
)

Bases: LightningModule

PyTorch Lightning module for BertBlocks pretraining with the MLM, enhanced MLM, or masked diffusion objective.

This module encapsulates the complete training logic for BertBlocks pretraining, including model initialization, optimizer setup, and the training step. It supports optional model compilation (compile_model), gradient checkpointing, a three-phase (warmup, training, cooldown) learning rate schedule, and saving checkpoints in HuggingFace format.

The module automatically configures the BertBlocks model based on the provided hyperparameters and handles all aspects of the training loop.

configure_model() → None

Compile the model core after DDP setup so each rank has its own CUDA context.

Only the inner compute path (embed -> encode -> norm -> scale) is compiled. Unpadding and output assembly remain outside the compiler scope to avoid dynamic shape issues from varying batch dimensions.

configure_optimizers() → tuple[list[Optimizer], list[dict[str, Any]]]

Configure optimizers and learning rate schedulers.

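Weight decay is applied to all parameters except biases and normalization parameters. The learning rate follows a sequential three-phase schedule (warmup, training, cooldown) configured by the scheduler_* arguments; with the defaults, this is a linear warmup over 1000 steps followed by a constant learning rate, with no cooldown.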

Returns:

Contains optimizer list and scheduler configuration list.

The scheduler is configured to update every step during training.

Return type:

tuple
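A minimal sketch of the weight-decay split described above, assuming parameters are classified by name: the helper build_param_groups and the rule that any name containing "bias" or "norm" is exempt from decay are illustrative assumptions, not BertBlocks code. The input is any iterable of (name, parameter) pairs, such as model.named_parameters().

```python
from typing import Any, Iterable


def build_param_groups(
    named_params: Iterable[tuple[str, Any]],
    weight_decay: float,
    no_decay_keywords: tuple[str, ...] = ("bias", "norm"),
) -> list[dict[str, Any]]:
    """Split parameters into a decayed and a non-decayed group (hypothetical helper)."""
    decay: list[Any] = []
    no_decay: list[Any] = []
    for name, param in named_params:
        # Assumed rule: biases and normalization parameters are recognized by name.
        if any(key in name.lower() for key in no_decay_keywords):
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


# Strings stand in for tensors so the example runs without PyTorch.
groups = build_param_groups(
    [
        ("encoder.layers.0.attn.weight", "w0"),
        ("encoder.layers.0.attn.bias", "b0"),
        ("encoder.layers.0.norm.weight", "n0"),
    ],
    weight_decay=0.001,
)
```

The returned list follows the parameter-group format that torch.optim optimizers accept, so it could be passed as, for example, torch.optim.AdamW(build_param_groups(model.named_parameters(), 0.001), lr=1e-5).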

on_before_optimizer_step(optimizer: Optimizer) → None

Log grad norms at each optimizer step.

on_save_checkpoint(*args: Any, **kwargs: Any) → None

Save model checkpoint in HuggingFace format.

This method is called whenever Lightning saves a checkpoint and additionally saves the model in HuggingFace format for easy loading and deployment. Only saves on the main process in distributed training.

Parameters:
  • *args – Variable arguments (unused).

  • **kwargs – Keyword arguments (unused).

training_step(batch: dict[str, Tensor], batch_idx: int) → Tensor

Perform a single training step.

Parameters:
  • batch – Batch dictionary containing ‘input_ids’, ‘attention_mask’, and ‘labels’ tensors from the MLM collator.

  • batch_idx – Index of the current batch (unused but required by Lightning).

Returns:

MLM loss for backpropagation.

Return type:

torch.Tensor

validation_step(
batch: dict[str, Tensor],
batch_idx: int,
dataloader_idx: int = 0,
) → Tensor

Perform a single validation step.

class bertblocks.training.modules.BertBlocksFinetuningModule(
task: Literal['classification', 'token_classification', 'question_answering'],
pretrained_model_name_or_path: str,
num_labels: int | None = None,
learning_rate: float = 1e-05,
weight_decay: float = 0.01,
compile_model: bool = True,
optimizer_class: str = 'adamw',
optimizer_kwargs: dict[str, Any] | None = None,
optimizer_quantized: bool = False,
scheduler_type: Literal['linear', 'cosine', 'constant', 'polynomial'] | None = None,
scheduler_kwargs: dict[str, Any] | None = None,
warmup_steps: int = 0,
warmup_ratio: float = 0.0,
gradient_checkpointing: bool = False,
)

Bases: LightningModule

PyTorch Lightning module for BertBlocks finetuning.

This module handles finetuning of pretrained BertBlocks models on downstream tasks including classification, token classification, and question answering.

configure_optimizers() → Optimizer | dict[str, Any]

Configure optimizers and learning rate schedulers.

test_step(batch: dict[str, Tensor], batch_idx: int) → Tensor

Perform test step.

training_step(batch: dict[str, Tensor], batch_idx: int) → Tensor

Perform training step.

validation_step(batch: dict[str, Tensor], batch_idx: int) → Tensor

Perform validation step.

Data Modules

class bertblocks.training.modules.BertBlocksPretrainingDataModule(
train_dataset_name_or_path: str | list[str],
pretrained_tokenizer_name_or_path: str,
objective: Literal['mlm', 'enhanced_mlm', 'diffusion'] = 'mlm',
max_sequence_length: int = 512,
val_dataset_name_or_path: str | list[str] | None = None,
train_split: str | None = None,
val_split: str | list[str] | None = None,
file_format: str | None = None,
text_column: str = 'text',
streaming: bool = False,
shuffle: bool = False,
num_shards: int | None = None,
train_batch_size: int = 32,
val_batch_size: int = 32,
num_workers: int = 0,
collator_kwargs: dict[str, Any] | None = None,
packing: bool = False,
packing_pad_to_budget: bool = False,
cache_dir: str | None = None,
data_kwargs: dict[str, Any] | None = None,
)

Bases: LightningDataModule

PyTorch Lightning DataModule for BertBlocks pretraining.

This DataModule handles all aspects of data loading for pretraining, including dataset preparation, tokenization, and batch creation. Training and validation data are selected with train_dataset_name_or_path and val_dataset_name_or_path, each of which accepts a single dataset name or path or a list of them.

The module supports streaming datasets (streaming=True) for large-scale pretraining, optional sequence packing (packing=True), and configurable batch sizes and data loading parameters.

prepare_data() → None

Download and tokenize data. Called only on LOCAL_RANK=0 per node (Lightning handles this automatically).

setup(stage: str | None = None) → None

Load the dataset for training. Called on every process; loads from the cache populated by prepare_data().

train_dataloader() → DataLoader

Create the training data loader.

Returns:

DataLoader configured for pretraining, optionally with packed batches.

val_dataloader() → DataLoader | list[DataLoader]

Create the validation data loader(s).

Returns:

Single DataLoader, list of DataLoaders for multiple validation sets, or dummy dataloader if no validation dataset is set.

class bertblocks.training.modules.BertBlocksFinetuningDataModule(
task: Literal['classification', 'question_answering', 'token_classification'],
pretrained_tokenizer_name_or_path: str,
dataset_name_or_path: str | None = None,
dataset_config_name: str | None = None,
max_sequence_length: int = 512,
train_split: str = 'train',
val_split: str = 'validation',
test_split: str = 'test',
text_column: str = 'text',
label_column: str = 'label',
train_batch_size: int = 32,
val_batch_size: int = 32,
test_batch_size: int = 32,
num_workers: int = 0,
shuffle_train: bool = True,
collator_kwargs: dict[str, Any] | None = None,
)

Bases: LightningDataModule

PyTorch Lightning DataModule for finetuning tasks.

Supports classification, token classification, and question answering tasks with flexible dataset loading from HuggingFace Hub or local files.

prepare_data() → None

Download datasets if needed. Called once per node.

set_datasets(
train: Dataset | None = None,
val: Dataset | None = None,
test: Dataset | None = None,
) → None

Manually set datasets for custom data loading.

Parameters:
  • train – Training dataset.

  • val – Validation dataset.

  • test – Test dataset.

setup(stage: str | None = None) → None

Set up datasets for each process. Called on every process.

Parameters:

stage – Current stage (‘fit’, ‘validate’, ‘test’, ‘predict’).

test_dataloader() → DataLoader

Create the test data loader.

Returns:

DataLoader configured for testing or None if no test dataset.

train_dataloader() → DataLoader

Create the training data loader.

Returns:

DataLoader configured for training or None if no training dataset.

val_dataloader() → DataLoader

Create the validation data loader.

Returns:

DataLoader configured for validation or None if no validation dataset.

Training Objectives (Collators)

Collators prepare batches for specific training objectives. Each collator pads pretokenized inputs and handles masking and label creation for its task; tokenization itself happens in the data module.

class bertblocks.training.objectives.Collator(
pad_token_id: int,
mask_token_id: int,
vocab_size: int,
text_column: str = 'text',
label_column: str | None = None,
max_sequence_length: int = 1024,
)

Abstract base class for data collators.

A data collator processes pretokenized input data into a format suitable for model training. This includes padding and applying any necessary transformations, such as masking for language modeling tasks.

Subclasses must implement the compute_labels method.

abstractmethod compute_labels(tokenized: dict[str, Any]) → dict[str, Any]

Compute the labels for the given batch of tokenized inputs.

Parameters:

tokenized (dict[str, Any]) – The tokenized inputs for the batch.

Returns:

The computed labels for the batch.

Return type:

dict[str, Any]
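The padding that collators perform on pretokenized input can be sketched in plain Python as follows. The function pad_batch is hypothetical, and right-padding to the longest sequence in the batch (after truncating to max_sequence_length) is an assumption; the actual collators may pad differently, for example to a fixed length.

```python
def pad_batch(
    sequences: list[list[int]],
    pad_token_id: int,
    max_sequence_length: int = 1024,
) -> dict[str, list[list[int]]]:
    """Right-pad pretokenized sequences and build an attention mask (hypothetical helper)."""
    # Truncate to max_sequence_length, then pad to the longest remaining sequence.
    truncated = [seq[:max_sequence_length] for seq in sequences]
    target_len = max((len(seq) for seq in truncated), default=0)
    input_ids: list[list[int]] = []
    attention_mask: list[list[int]] = []
    for seq in truncated:
        n_pad = target_len - len(seq)
        input_ids.append(seq + [pad_token_id] * n_pad)
        # 1 marks real tokens, 0 marks padding.
        attention_mask.append([1] * len(seq) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}


batch = pad_batch([[101, 7, 8, 102], [101, 9, 102]], pad_token_id=0)
```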

class bertblocks.training.objectives.MaskedLanguageModelingCollator(
pad_token_id: int,
mask_token_id: int,
vocab_size: int,
text_column: str = 'text',
max_sequence_length: int = 1024,
mlm_probability: float = 0.3,
)

Bases: Collator

Data collator for masked language modeling pretraining.

Applies masking to create MLM training examples.

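Parameters:
  • pad_token_id (int) – Token ID used for padding sequences.

  • mask_token_id (int) – Token ID used for masking.

  • vocab_size (int) – Size of the tokenizer vocabulary.

  • text_column (str) – Name of the column containing text data. Defaults to “text”.

  • max_sequence_length (int) – Maximum sequence length for padding. Defaults to 1024.

  • mlm_probability (float) – Probability of selecting each token for masking. Defaults to 0.3.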

compute_labels(tokenized: dict[str, Any]) → Any

Compute the MLM labels for the given batch of tokenized inputs.

Applies masking following the BERT MLM strategy for tokens sampled with MLM probability:
  • 80% of the time: replace token with [MASK]

  • 10% of the time: replace token with random token

  • 10% of the time: keep token unchanged

Parameters:

tokenized (dict[str, Any]) – The tokenized inputs for the batch.

Returns:

The computed MLM labels for the batch.

Return type:

dict[str, Any]
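A single-sequence sketch of the 80/10/10 rule above, in plain Python for readability. The function bert_mask, its special_ids argument, and the use of -100 to mark positions that are not predicted are assumptions borrowed from common Hugging Face conventions, not the collator's confirmed internals; the real collator operates on padded batches.

```python
from __future__ import annotations

import random


def bert_mask(
    input_ids: list[int],
    mask_token_id: int,
    vocab_size: int,
    mlm_probability: float = 0.3,
    special_ids: frozenset[int] = frozenset(),
    rng: random.Random | None = None,
) -> tuple[list[int], list[int]]:
    """Apply BERT-style 80/10/10 masking to one sequence (illustrative sketch)."""
    if rng is None:
        rng = random.Random()
    masked = list(input_ids)
    labels = [-100] * len(input_ids)  # -100 marks positions that are not predicted
    for i, token in enumerate(input_ids):
        # Select each non-special token independently with probability mlm_probability.
        if token in special_ids or rng.random() >= mlm_probability:
            continue
        labels[i] = token  # the model must recover the original token here
        roll = rng.random()
        if roll < 0.8:
            masked[i] = mask_token_id  # 80%: replace with [MASK]
        elif roll < 0.9:
            masked[i] = rng.randrange(vocab_size)  # 10%: replace with a random token
        # Remaining 10%: keep the original token unchanged.
    return masked, labels


masked, labels = bert_mask(
    [101, 7592, 2088, 102],
    mask_token_id=103,
    vocab_size=30522,
    special_ids=frozenset({101, 102}),
    rng=random.Random(0),
)
```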

class bertblocks.training.objectives.EnhancedMaskedLanguageModelingCollator(
pad_token_id: int,
mask_token_id: int,
vocab_size: int,
text_column: str = 'text',
label_column: str | None = None,
max_sequence_length: int = 1024,
)

Bases: Collator

Data collator for enhanced masked language modeling pretraining.

Prepares pretokenized sequences for enhanced MLM. Unlike standard MLM, masking is not applied in the collator but is handled by the model itself.

Expects input to be already tokenized (handled in the data module).

compute_labels(tokenized: dict[str, Any]) → dict[str, Any]

Compute labels for enhanced masked language modeling.

Parameters:

tokenized (dict[str, Any]) – The tokenized inputs for the batch.

Returns:

The computed labels for the batch.

Return type:

dict[str, Any]

class bertblocks.training.objectives.TokenClassificationCollator(
pad_token_id: int,
mask_token_id: int,
vocab_size: int,
text_column: str = 'text',
label_column: str = 'labels',
max_sequence_length: int = 1024,
)

Bases: Collator

Data collator for token classification tasks.

Handles formatting for token classification tasks like NER, POS tagging, etc. Pads pretokenized sequences and aligns token-level labels.

Expects input to be already tokenized (handled in the data module).

Parameters:
  • pad_token_id (int) – Token ID used for padding sequences.

  • mask_token_id (int) – Token ID used for masking sequences.

  • vocab_size (int) – Size of the tokenizer vocabulary.

  • text_column (str) – Name of the column containing text data. Defaults to “text”.

  • label_column (str) – Name of the column containing label data. Defaults to “labels”.

  • max_sequence_length (int) – Maximum sequence length for padding. Defaults to 1024.

compute_labels(tokenized: dict[str, Any]) → dict[str, Any]

Compute token classification labels for the given batch.

For token classification, labels must stay aligned with the tokenized input. They are padded to match the sequence length, and special tokens and padding positions are labeled -100 (the default ignore_index of PyTorch's cross-entropy loss).

Parameters:

tokenized (dict[str, Any]) – The tokenized inputs for the batch.

Returns:

The tokenized inputs with properly formatted labels.

Return type:

dict[str, Any]
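One common way to produce such labels is sketched below, assuming each token carries the index of the word it came from (as returned by word_ids() on a Hugging Face fast-tokenizer encoding) and that only the first sub-token of each word is labeled. The function align_token_labels and the first-sub-token rule are assumptions; the collator itself may receive labels that are already token-aligned.

```python
from __future__ import annotations

IGNORE_INDEX = -100  # label value for positions that should not be scored


def align_token_labels(
    word_ids: list[int | None],
    word_labels: list[int],
    padded_length: int,
) -> list[int]:
    """Map word-level labels onto tokens and pad to padded_length (sketch)."""
    labels: list[int] = []
    previous_word: int | None = None
    for word_id in word_ids:
        if word_id is None:
            labels.append(IGNORE_INDEX)  # special token such as [CLS] or [SEP]
        elif word_id != previous_word:
            labels.append(word_labels[word_id])  # first sub-token of a word
        else:
            labels.append(IGNORE_INDEX)  # continuation sub-token of the same word
        previous_word = word_id
    labels = labels[:padded_length]
    # Padding positions are ignored as well.
    return labels + [IGNORE_INDEX] * (padded_length - len(labels))


# "[CLS] new york is [SEP]" where "york" was split into two sub-tokens.
aligned = align_token_labels([None, 0, 1, 1, 2, None], word_labels=[3, 4, 0], padded_length=8)
```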

class bertblocks.training.objectives.SequenceClassificationCollator(
pad_token_id: int,
mask_token_id: int,
vocab_size: int,
text_column: str = 'text',
label_column: str = 'label',
max_sequence_length: int = 1024,
)

Bases: Collator

Data collator for sequence classification tasks.

Handles formatting for sequence classification tasks like sentiment analysis, text classification, etc. Pads pretokenized sequences and preserves sequence-level labels.

Parameters:
  • pad_token_id (int) – Token ID used for padding sequences.

  • mask_token_id (int) – Token ID used for masking sequences.

  • vocab_size (int) – Size of the tokenizer vocabulary.

  • text_column (str) – Name of the column containing text data. Defaults to “text”.

  • label_column (str) – Name of the column containing label data. Defaults to “label”.

  • max_sequence_length (int) – Maximum sequence length for padding. Defaults to 1024.

compute_labels(tokenized: dict[str, Any]) → dict[str, Any]

Compute sequence classification labels for the given batch.

For sequence classification, we just need to preserve the original labels as they apply to the entire sequence.

Parameters:

tokenized (dict[str, Any]) – The tokenized inputs for the batch.

Returns:

The tokenized inputs with labels preserved.

Return type:

dict[str, Any]

class bertblocks.training.objectives.QuestionAnsweringCollator(
pad_token_id: int,
mask_token_id: int,
vocab_size: int,
text_column: str = 'question',
label_column: str = 'answers',
context_column: str = 'context',
max_sequence_length: int = 1024,
doc_stride: int = 128,
)

Bases: Collator

Data collator for question answering tasks.

Handles formatting for question answering tasks like SQuAD. Pads pretokenized question-context pairs and preserves start/end positions for answer span prediction.

Parameters:
  • pad_token_id (int) – Token ID used for padding sequences.

  • mask_token_id (int) – Token ID used for masking sequences.

  • vocab_size (int) – Size of the tokenizer vocabulary.

  • text_column (str) – Name of the column containing question data. Defaults to “question”.

  • label_column (str) – Name of the column containing answer data. Defaults to “answers”.

  • context_column (str) – Name of the column containing context data. Defaults to “context”.

  • max_sequence_length (int) – Maximum sequence length for padding. Defaults to 1024.

  • doc_stride (int) – Stride for sliding window when context is too long. Defaults to 128.

compute_labels(tokenized: dict[str, Any]) → dict[str, Any]

Compute question answering labels for the given batch.

For QA, we need to find the start and end positions of the answer spans within the tokenized context.

Parameters:

tokenized (dict[str, Any]) – The tokenized inputs for the batch.

Returns:

The tokenized inputs with start/end positions.

Return type:

dict[str, Any]
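A sketch of the usual offset-mapping approach to locating the answer span within one tokenized window. It assumes per-token character offsets and sequence IDs (0 for question, 1 for context, None for special tokens), as produced by Hugging Face fast tokenizers with return_offsets_mapping=True; the function answer_token_span and the convention of labeling out-of-window answers as (0, 0) are assumptions, not confirmed QuestionAnsweringCollator behavior. With doc_stride, each overlapping window of a long context would be labeled independently in this way.

```python
from __future__ import annotations


def answer_token_span(
    offsets: list[tuple[int, int]],
    sequence_ids: list[int | None],
    answer_start_char: int,
    answer_text: str,
) -> tuple[int, int]:
    """Return (start_position, end_position) of the answer tokens in one window (sketch)."""
    answer_end_char = answer_start_char + len(answer_text)
    context = [i for i, seq_id in enumerate(sequence_ids) if seq_id == 1]
    if not context:
        return 0, 0
    first, last = context[0], context[-1]
    # Assumed convention: answers not fully inside this window point at token 0.
    if offsets[first][0] > answer_start_char or offsets[last][1] < answer_end_char:
        return 0, 0
    start = first
    # Advance to the first context token that ends after the answer starts.
    while start <= last and offsets[start][1] <= answer_start_char:
        start += 1
    end = last
    # Move back to the last context token that starts before the answer ends.
    while end >= first and offsets[end][0] >= answer_end_char:
        end -= 1
    return start, end


# Question "Who?" and context "Paris is big": [CLS] who ? [SEP] paris is big [SEP]
offsets = [(0, 0), (0, 3), (3, 4), (0, 0), (0, 5), (6, 8), (9, 12), (0, 0)]
sequence_ids = [None, 0, 0, None, 1, 1, 1, None]
span = answer_token_span(offsets, sequence_ids, answer_start_char=6, answer_text="is")
```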

class bertblocks.training.objectives.MaskedDiffusionCollator(
pad_token_id: int,
mask_token_id: int,
vocab_size: int,
text_column: str = 'text',
max_sequence_length: int = 1024,
num_steps: int = 1000,
sampling_eps: float = 0.1,
noise_eps: float = 0.001,
min_masked: int | None = None,
max_masked: int | None = None,
)

Bases: Collator

Diffusion-based MLM collator that samples random masking probabilities per batch.

Masks tokens at random, using a masking probability sampled from a diffusion noise schedule. Expects input to be already tokenized (handled in the data module).

Parameters:
  • pad_token_id (int) – Token ID used for padding sequences.

  • mask_token_id (int) – Token ID to use for masking.

  • vocab_size (int) – Size of the tokenizer vocabulary.

  • text_column (str) – Name of the column containing text data. Defaults to “text”.

  • max_sequence_length (int) – Maximum sequence length for padding. Defaults to 1024.

  • num_steps (int) – Number of diffusion steps. Defaults to 1000.

  • sampling_eps (float) – Minimum timestep for sampling; controls the minimum masking probability. Defaults to 0.1.

  • noise_eps (float) – Epsilon for the noise schedule; controls the maximum masking probability (1 - noise_eps). Defaults to 0.001.

  • min_masked (int | None) – Minimum number of masked tokens per sequence. If None, not enforced.

  • max_masked (int | None) – Maximum number of masked tokens per sequence. If None, not enforced.

compute_labels(tokenized: dict[str, Any]) → Any

Compute the denoising MLM labels for the given batch of tokenized inputs.

Parameters:

tokenized (dict[str, Any]) – The tokenized inputs for the batch.

Returns:

The computed MLM labels for the batch.

Return type:

dict[str, Any]
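The sampling described above might look roughly like this. The exact schedule is an assumption: this sketch draws one discrete timestep per batch, maps it to t in [sampling_eps, 1], and uses a log-linear noise schedule whose masking probability is (1 - noise_eps) * t, which yields the minimum and maximum probabilities described for sampling_eps and noise_eps. It then masks a fixed, clamped number of non-padding tokens per sequence to honor min_masked and max_masked. The function diffusion_mask_batch is hypothetical.

```python
from __future__ import annotations

import random


def diffusion_mask_batch(
    batch: list[list[int]],
    mask_token_id: int,
    pad_token_id: int,
    num_steps: int = 1000,
    sampling_eps: float = 0.1,
    noise_eps: float = 0.001,
    min_masked: int | None = None,
    max_masked: int | None = None,
    rng: random.Random | None = None,
) -> dict[str, object]:
    """Mask a batch with one masking probability drawn from a noise schedule (sketch)."""
    if rng is None:
        rng = random.Random()
    # One discrete timestep per batch, mapped into t in [sampling_eps, 1].
    step = rng.randint(0, num_steps)
    t = sampling_eps + (1.0 - sampling_eps) * step / num_steps
    # Assumed log-linear schedule: probability grows linearly in t, capped at 1 - noise_eps.
    mask_prob = (1.0 - noise_eps) * t
    all_ids, all_labels = [], []
    for seq in batch:
        candidates = [i for i, tok in enumerate(seq) if tok != pad_token_id]
        n_mask = round(mask_prob * len(candidates))
        if min_masked is not None:
            n_mask = max(n_mask, min_masked)
        if max_masked is not None:
            n_mask = min(n_mask, max_masked)
        n_mask = min(n_mask, len(candidates))  # cannot mask more tokens than exist
        chosen = set(rng.sample(candidates, n_mask))
        all_ids.append([mask_token_id if i in chosen else tok for i, tok in enumerate(seq)])
        all_labels.append([tok if i in chosen else -100 for i, tok in enumerate(seq)])
    return {"input_ids": all_ids, "labels": all_labels, "mask_prob": mask_prob}


out = diffusion_mask_batch([[5, 6, 7, 8, 0, 0]], mask_token_id=4, pad_token_id=0, rng=random.Random(0))
```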

bertblocks.training.objectives.get_collator_cls(
objective: Literal['mlm', 'enhanced_mlm', 'classification', 'token_classification', 'question_answering', 'diffusion'],
) → type[Collator]

Get the appropriate data collator for the given objective.

Parameters:

objective (str) – The training objective. Available options: “mlm”, “enhanced_mlm”, “classification”, “token_classification”, “question_answering”, “diffusion”.

Raises:

ValueError – If the objective is unknown.

Returns:

The corresponding data collator class.

Return type:

type[Collator]

Optimizers

bertblocks.training.optimizer.get_optimizer(
optimizer_name: Literal['adafactor', 'adagrad', 'adam', 'adamw', 'galore', 'lamb', 'lars', 'lion', 'muon', 'sgd', 'shampoo', 'soap', 'sophiah', 'splus', 'rmsprop'],
params: list[dict[str, Any]],
optimizer_kwargs: dict[str, Any],
quantized: bool = False,
) → Optimizer

Instantiate the named optimizer with the given parameter groups and hyperparameters, and return it.

Parameters:
  • optimizer_name (str) – Name of the optimizer class to instantiate.

  • params (list[dict[str, Any]]) – Parameter groups to be optimized.

  • optimizer_kwargs (dict[str, Any]) – Additional hyperparameters passed to the optimizer (may be empty).

  • quantized (bool) – Whether to use an 8-bit quantized optimizer variant. Defaults to False.

Returns:

The instantiated optimizer.

Raises:
  • ValueError – If the specified optimizer name is not recognized.

  • ImportError – If the specified optimizer requires an optional dependency that is not installed.

Supported optimizer types:
  • adafactor

  • adagrad

  • adam

  • adamw

  • galore

  • lamb

  • lars

  • lion

  • muon

  • sgd

  • shampoo

  • soap

  • sophiah

  • splus

  • rmsprop

Schedulers

class bertblocks.training.scheduler.InverseSqrtScheduler(optimizer: Optimizer, cooldown_steps: int, last_epoch: int = -1)

Bases: LambdaLR

A scheduler that applies “1-sqrt” cooldown scaling.

The learning rate is scaled down by a factor of 1 - sqrt(current_step / cooldown_steps) (the “1-sqrt cooldown”). This scheduler is intended for cooldown phases. It has been reported to outperform linear cooldown and is presented as an alternative to cosine decay.

Parameters:
  • optimizer – The optimizer to use.

  • cooldown_steps – The number of cooldown steps. If this number is exceeded, the scheduler will transition to a constant learning rate of 0.0.

  • last_epoch – The index of the last epoch. Defaults to -1.


lr_lambda(current_step: int) → float

Adjust the learning rate based on cooldown steps.
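The 1-sqrt multiplier can be sketched as a plain function suitable for a LambdaLR-style lr_lambda. The name one_minus_sqrt_factor is hypothetical; the behavior follows the description above, including the drop to 0.0 once cooldown_steps is reached.

```python
import math


def one_minus_sqrt_factor(current_step: int, cooldown_steps: int) -> float:
    """Learning-rate multiplier for a 1-sqrt cooldown (hypothetical helper)."""
    if current_step >= cooldown_steps:
        return 0.0  # cooldown finished: learning rate stays at 0.0
    return 1.0 - math.sqrt(current_step / cooldown_steps)
```

For example, a quarter of the way through the cooldown the factor is 0.5, whereas a linear cooldown would still be at 0.75, so the 1-sqrt schedule sheds learning rate faster early on. In PyTorch, such a function could be wrapped as torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: one_minus_sqrt_factor(step, 1000)).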

bertblocks.training.scheduler.get_scheduler(
optimizer: Optimizer,
warmup_kind: Literal['linear', 'constant'] = 'linear',
warmup_steps: int = 0,
warmup_decay: float = 0.0,
training_kind: Literal['constant', 'linear', 'cosine', 'exponential'] = 'constant',
training_steps: int = -1,
training_decay: float = 1.0,
cooldown_kind: Literal['linear', 'inverse-sqrt', 'exponential'] = 'linear',
cooldown_steps: int = 0,
cooldown_decay: float = 0.0,
) → LRScheduler

Construct a sequential learning rate schedule with three phases: warmup, training, and cooldown.

Parameters:
  • optimizer – The optimizer to schedule the learning rate for.

  • warmup_kind – Kind of scheduler for warmup phase. Defaults to ‘linear’.

  • warmup_steps – Duration in steps for warmup phase. Defaults to 0 (no warmup).

  • warmup_decay – Decay value for warmup phase; has different effect depending on warmup_kind. Defaults to 0.0.

  • training_kind – Kind of scheduler for training phase. Defaults to ‘constant’.

  • training_steps – Duration in steps for training phase.

  • training_decay – Decay value for training phase; has different effect depending on training_kind. Defaults to 1.0 (constant learning rate).

  • cooldown_kind – Kind of scheduler for cooldown phase. Defaults to ‘linear’.

  • cooldown_steps – Duration in steps for cooldown phase. Defaults to 0 (no cooldown).

  • cooldown_decay – Decay value for cooldown phase; has different effect depending on cooldown_kind. Defaults to 0.0.

Returns:

Learning rate scheduler with sequential phases as defined.
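For the common case of linear warmup, a constant training phase, and linear cooldown, the combined multiplier applied to the base learning rate can be sketched as a function of the global step. This is an illustrative approximation, not get_scheduler itself: three_phase_factor is hypothetical, its warmup_start and cooldown_end arguments only loosely correspond to warmup_decay and cooldown_decay, and the handling of training_steps=-1 is not modeled.

```python
def three_phase_factor(
    step: int,
    warmup_steps: int,
    training_steps: int,
    cooldown_steps: int,
    warmup_start: float = 0.0,
    cooldown_end: float = 0.0,
) -> float:
    """Multiplier for linear warmup, constant training, then linear cooldown (sketch)."""
    if step < warmup_steps:
        # Warmup: ramp linearly from warmup_start up to 1.0.
        return warmup_start + (1.0 - warmup_start) * step / warmup_steps
    step -= warmup_steps
    if step < training_steps:
        return 1.0  # training phase: constant learning rate
    step -= training_steps
    if cooldown_steps <= 0:
        return 1.0  # no cooldown configured: keep the training learning rate
    if step < cooldown_steps:
        # Cooldown: ramp linearly from 1.0 down to cooldown_end.
        return 1.0 - (1.0 - cooldown_end) * step / cooldown_steps
    return cooldown_end  # after cooldown, hold the final value


base_lr = 1e-4
lrs = [base_lr * three_phase_factor(s, 10, 20, 10) for s in range(45)]
```

In PyTorch, chaining separate per-phase schedulers with torch.optim.lr_scheduler.SequentialLR is one standard way to build such a schedule.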

bertblocks.training.scheduler.get_single_scheduler(
optimizer: Optimizer,
kind: Literal['constant', 'linear', 'inverse-sqrt', 'exponential', 'cosine'] = 'constant',
num_steps: int = 0,
decay: float = 0.0,
direction: Literal['increase', 'decrease'] = 'increase',
) → LRScheduler | None

Return the corresponding instantiated scheduler for configuration provided.

Parameters:
  • optimizer – Optimizer to schedule the learning rate for.

  • kind – Kind of scheduler. One of ‘constant’, ‘linear’, ‘inverse-sqrt’, ‘cosine’, ‘exponential’. Defaults to ‘constant’.

  • num_steps – Number of steps to schedule. Defaults to 0.

  • decay – Decay value, depending on kind of scheduler. If ‘constant’, is applied as factor; if ‘linear’, is applied as start_factor; if ‘exponential’, is applied as gamma; if ‘cosine’, is applied as eta_min.

  • direction – Direction of scheduler. Defaults to ‘increase’.

Returns:

The specified scheduler.

Metrics

bertblocks.training.metrics.get_metrics_for_task(
task: Literal['classification', 'token_classification', 'question_answering'],
num_labels: int = 2,
) → dict[str, Metric]

Get default metrics for a given task.

Parameters:
  • task – The task type to get metrics for.

  • num_labels – The number of labels to use for the task.

Returns:

Dictionary mapping metric names to torchmetrics.Metric instances.

Sequence Packing