Architecture Overview

BertBlocks implements a modular transformer encoder where each component can be independently configured and swapped. This page describes how the pieces fit together.

Model Structure

A BertBlocks model consists of three main stages:

  1. Embedding – Token and (optionally) token-type embeddings, plus optional embedding-level positional encodings

  2. Encoder – A stack of transformer blocks

  3. Head – Task-specific output heads (MLM, classification, QA, etc.)

    graph TD
        Input["Input IDs"]
        Embedding["<b>Embedding</b>"]
        Block["<b>Encoder Block ×N</b>"]
        Head["<b>Task Head</b>"]

        Input --> Embedding
        Embedding --> Block
        Block --> Head
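The three stages map onto a handful of config keys. As a sketch (all key names below appear on this page, but the dict layout and the chosen values are assumptions, not the library's documented defaults):

```python
# Hypothetical config sketch: key names are taken from this page;
# the exact constructor API and default values are assumptions.
config = {
    # Embedding stage
    "embd_pos_enc_kind": "none",     # "sinusoidal", "learned", or "none"
    # Encoder blocks
    "block_pos_enc_kind": "rope",    # "alibi", "rope", or "none"
    "norm_fn": "rms",                # one of the norm_fn values listed below
    "norm_pos": "pre",               # "pre", "post", "pre_and_post", "none"
    "mlp_type": "glu",               # "mlp" or "glu"
    "attention_backend": "sdpa",     # Flash Attention, SDPA, or eager;
                                     # exact string values are assumed
}
```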

Transformer Block

Each Block contains:

  • Multi-head attention (Attention) with configurable head count, dropout, and optional grouped query attention (GQA)

  • Feed-forward network (get_mlp()) – standard MLP or Gated Linear Unit (GLU)

  • Normalization (get_norm()) applied pre-attention, post-attention, or both

The normalization position is controlled by the norm_pos config key and supports "pre", "post", "pre_and_post", or "none".
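The four placements can be sketched as a residual sublayer wrapper. This is an illustrative pure-Python stand-in, not the library's actual code; `sublayer`, `fn`, and `norm` are hypothetical names:

```python
def sublayer(x, fn, norm, norm_pos):
    """Residual sublayer with normalization placed per `norm_pos`.

    Sketch only: `fn` stands in for attention or the feed-forward
    network, `norm` for the configured normalization function.
    """
    if norm_pos == "pre":
        return x + fn(norm(x))
    if norm_pos == "post":
        return norm(x + fn(x))
    if norm_pos == "pre_and_post":
        return norm(x + fn(norm(x)))
    return x + fn(x)  # "none"
```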

Attention

The Attention module supports:

  • Positional encodings at the block level: ALiBi or RoPE, applied inside each block's attention rather than once at embedding time

  • Backends: Flash Attention, SDPA, or eager (plain PyTorch), selected via the attention_backend config key

  • Grouped Query Attention: set num_kv_heads < num_attention_heads

  • QK normalization: via qk_norm

  • Local attention: via local_attention and local_attention_window_size

Backends are implemented as AttentionBackend subclasses and handle both padded and unpadded (variable-length) sequences.
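Under grouped query attention, each KV head is shared by a contiguous group of query heads. The index mapping can be sketched as follows (`kv_head_for` is a hypothetical helper for illustration, not part of the library's API):

```python
def kv_head_for(q_head: int, num_attention_heads: int, num_kv_heads: int) -> int:
    """Map a query-head index to the KV head it reads from under GQA.

    With num_kv_heads == num_attention_heads this reduces to standard
    multi-head attention; with num_kv_heads == 1 it is multi-query attention.
    """
    assert num_attention_heads % num_kv_heads == 0
    group_size = num_attention_heads // num_kv_heads
    return q_head // group_size
```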

Positional Encodings

BertBlocks supports positional encodings at two levels:

Level       Config key           Options
---------   ------------------   -------------------------------
Embedding   embd_pos_enc_kind    "sinusoidal", "learned", "none"
Block       block_pos_enc_kind   "alibi", "rope", "none"

Embedding-level encodings are added once to the token embeddings. Block-level encodings are applied inside each attention computation.
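As a concrete example of an embedding-level encoding, the classic sinusoidal scheme uses sine for even dimension indices and cosine for odd ones, with frequencies decreasing geometrically across dimension pairs. A minimal sketch of the standard formula (assuming the usual base of 10000; this is not BertBlocks' actual implementation):

```python
import math

def sinusoidal_pos_enc(pos: int, i: int, d: int, base: float = 10000.0) -> float:
    """Standard sinusoidal positional encoding for one (position, dim) pair.

    Even dims use sin, odd dims use cos; each dim pair shares a frequency.
    """
    angle = pos / (base ** (2 * (i // 2) / d))
    return math.sin(angle) if i % 2 == 0 else math.cos(angle)
```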

Normalization

Available normalization functions (set via norm_fn):

Value            Class
--------------   ------------------
"layer"          torch.nn.LayerNorm
"rms"            torch.nn.RMSNorm
"group"          torch.nn.GroupNorm
"deep"           DeepNorm
"dynamic_tanh"   DynamicTanhNorm
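To illustrate one of these, RMS normalization divides each feature by the root-mean-square of the feature vector rather than subtracting a mean and dividing by a standard deviation. A pure-Python sketch of the math (the real class is torch.nn.RMSNorm; the learned scale parameter is omitted here for clarity):

```python
import math

def rms_norm(xs, eps=1e-6):
    """RMS-normalize a feature vector: y_i = x_i / sqrt(mean(x^2) + eps).

    Illustrative only; omits the learnable per-feature scale.
    """
    rms = math.sqrt(sum(x * x for x in xs) / len(xs) + eps)
    return [x / rms for x in xs]
```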

Feed-Forward Networks

The mlp_type config key selects between:

  • "mlp" – Standard two-layer MLP with activation

  • "glu" – Gated Linear Unit with configurable activation (SwiGLU, GeGLU, etc.)

Both are accessed through get_mlp().
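The structural difference between the two variants can be sketched with plain callables standing in for linear layers (the function and argument names here are illustrative, not the library's):

```python
def mlp(x, w_in, w_out, act):
    """Standard two-layer MLP: w_out(act(w_in(x)))."""
    return w_out(act(w_in(x)))

def glu(x, w_gate, w_up, w_out, act):
    """Gated Linear Unit: w_out(act(w_gate(x)) * w_up(x)).

    With act = SiLU this is SwiGLU; with act = GELU it is GeGLU.
    """
    return w_out(act(w_gate(x)) * w_up(x))
```

The GLU variant trades one projection for two parallel ones, using one branch to gate the other element-wise.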

Task Heads

BertBlocks provides several task-specific model wrappers:

Class                                 Task
-----------------------------------   --------------------------
BertBlocksForMaskedLM                 Masked language modeling
BertBlocksForSequenceClassification   Sequence classification
BertBlocksForTokenClassification      Token classification (NER)
BertBlocksForQuestionAnswering        Extractive QA
BertBlocksForMaskedDiffusion          Masked diffusion modeling
All inherit from BertBlocksPreTrainedModel and are compatible with HuggingFace’s AutoModel registry.