# Architecture Overview
BertBlocks implements a modular transformer encoder where each component can be independently configured and swapped. This page describes how the pieces fit together.
## Model Structure
A BertBlocks model consists of three main stages:
- **Embedding** – token and (optionally) token-type embeddings, plus optional embedding-level positional encodings
- **Encoder** – a stack of transformer blocks
- **Head** – task-specific output heads (MLM, classification, QA, etc.)
```mermaid
graph TD
    Input["Input IDs"]
    Embedding["<b>Embedding</b>"]
    Block["<b>Encoder Block ×N</b>"]
    Head["<b>Task Head</b>"]
    Input --> Embedding
    Embedding --> Block
    Block --> Head
```
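The three-stage flow above can be sketched in plain PyTorch. This is a minimal illustration, not the BertBlocks implementation; the class and attribute names here are invented for the example:

```python
# Minimal sketch of the embedding -> encoder -> head pipeline
# (illustrative only; not the actual BertBlocks classes).
import torch
import torch.nn as nn

class TinyEncoderModel(nn.Module):
    def __init__(self, vocab_size=100, hidden=32, n_blocks=2, n_heads=4, n_labels=3):
        super().__init__()
        # Stage 1: embedding (token embeddings only, for brevity)
        self.embedding = nn.Embedding(vocab_size, hidden)
        # Stage 2: a stack of N transformer encoder blocks
        layer = nn.TransformerEncoderLayer(hidden, n_heads,
                                           dim_feedforward=64, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_blocks)
        # Stage 3: a task head (sequence classification here)
        self.head = nn.Linear(hidden, n_labels)

    def forward(self, input_ids):
        x = self.embedding(input_ids)   # (batch, seq, hidden)
        x = self.encoder(x)             # (batch, seq, hidden)
        return self.head(x[:, 0])       # classify from the first token

model = TinyEncoderModel()
logits = model(torch.randint(0, 100, (2, 8)))
print(logits.shape)  # torch.Size([2, 3])
```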
## Transformer Block
Each Block contains:
- **Multi-head attention** (`Attention`) with configurable head count, dropout, and optional grouped query attention (GQA)
- **Feed-forward network** (`get_mlp()`) – standard MLP or Gated Linear Unit (GLU)
- **Normalization** (`get_norm()`) applied pre-attention, post-attention, or both
The normalization position is controlled by the `norm_pos` config key and supports `"pre"`, `"post"`, `"pre_and_post"`, or `"none"`.
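The effect of the `norm_pos` options can be sketched as follows. This is an illustrative wrapper around a generic sublayer, not the library's actual block code:

```python
# Sketch of how a norm_pos-style setting could place normalization around a
# sublayer with a residual connection (illustrative; not the real BertBlocks class).
import torch
import torch.nn as nn

class NormPlacement(nn.Module):
    def __init__(self, hidden, sublayer, norm_pos="pre"):
        super().__init__()
        assert norm_pos in {"pre", "post", "pre_and_post", "none"}
        self.sublayer = sublayer
        self.norm_pos = norm_pos
        self.pre_norm = nn.LayerNorm(hidden)
        self.post_norm = nn.LayerNorm(hidden)

    def forward(self, x):
        residual = x
        if self.norm_pos in ("pre", "pre_and_post"):
            x = self.pre_norm(x)               # normalize before the sublayer
        x = residual + self.sublayer(x)        # residual connection
        if self.norm_pos in ("post", "pre_and_post"):
            x = self.post_norm(x)              # normalize after the residual add
        return x

block = NormPlacement(16, nn.Linear(16, 16), norm_pos="pre")
out = block(torch.randn(2, 4, 16))
```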
## Attention
The `Attention` module supports:

- **Positional encodings** at the block level: ALiBi, RoPE (applied per-block, not at embedding time)
- **Backends**: Flash Attention, SDPA, or eager (plain PyTorch), selected via the `attention_backend` config key
- **Grouped Query Attention**: set `num_kv_heads` < `num_attention_heads`
- **QK normalization**: via `qk_norm`
- **Local attention**: via `local_attention` and `local_attention_window_size`

Backends are implemented as `AttentionBackend` subclasses and handle both padded and unpadded (variable-length) sequences.
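Grouped Query Attention shares each K/V head across a group of query heads. A minimal sketch using PyTorch's SDPA backend (one of the options listed above); the function name is invented for the example:

```python
# GQA sketch: fewer K/V heads than query heads, expanded before attention
# (illustrative; not the actual BertBlocks backend code).
import torch
import torch.nn.functional as F

def gqa_attention(q, k, v):
    # q:    (batch, num_heads,    seq, head_dim)
    # k, v: (batch, num_kv_heads, seq, head_dim), num_heads % num_kv_heads == 0
    num_heads, num_kv_heads = q.shape[1], k.shape[1]
    group = num_heads // num_kv_heads
    # Each K/V head serves a group of query heads
    k = k.repeat_interleave(group, dim=1)
    v = v.repeat_interleave(group, dim=1)
    return F.scaled_dot_product_attention(q, k, v)

q = torch.randn(1, 8, 5, 16)   # 8 query heads
k = torch.randn(1, 2, 5, 16)   # 2 KV heads -> group size 4
v = torch.randn(1, 2, 5, 16)
out = gqa_attention(q, k, v)
print(out.shape)  # torch.Size([1, 8, 5, 16])
```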
## Positional Encodings
BertBlocks supports positional encodings at two levels:
- **Embedding level** – encodings added once to the token embeddings
- **Block level** – ALiBi or RoPE, applied inside each attention computation rather than at embedding time
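As an example of a block-level encoding, ALiBi adds a fixed, head-specific linear distance penalty to attention scores. A plain-Python sketch of the standard slope schedule (illustrative; the function names are invented for the example):

```python
# ALiBi sketch: each head gets a fixed slope, and attention scores are
# penalized linearly with query-key distance (no learned parameters).
def alibi_slopes(num_heads):
    # Standard geometric slope schedule for a power-of-two head count
    start = 2.0 ** (-8.0 / num_heads)
    return [start ** (i + 1) for i in range(num_heads)]

def alibi_bias(slope, seq_len):
    # bias[i][j] = -slope * (i - j): the further a key is behind the query,
    # the larger the penalty
    return [[-slope * (i - j) for j in range(seq_len)] for i in range(seq_len)]

slopes = alibi_slopes(8)
print(slopes[0])       # 0.5
bias = alibi_bias(slopes[0], 4)
print(bias[3][0])      # -1.5
```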
## Normalization
The normalization function is selected via the `norm_fn` config key; each supported value maps to a normalization class returned by `get_norm()`.
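Two of the most common choices in transformer encoders are LayerNorm and RMSNorm. A plain-Python sketch contrasting the two (illustrative math only; `get_norm()` would return real `nn.Module` implementations, and which values `norm_fn` actually accepts is defined by the library):

```python
# LayerNorm vs. RMSNorm, elementwise-affine parameters omitted for brevity.
import math

def layer_norm(x, eps=1e-6):
    # Subtract the mean, then divide by the standard deviation
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def rms_norm(x, eps=1e-6):
    # No mean subtraction: divide by the root mean square only
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]

x = [1.0, 2.0, 3.0, 4.0]
print([round(v, 3) for v in layer_norm(x)])
print([round(v, 3) for v in rms_norm(x)])
```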
## Feed-Forward Networks
The `mlp_type` config key selects between:

- `"mlp"` – a standard two-layer MLP with activation
- `"glu"` – a Gated Linear Unit with configurable activation (SwiGLU, GeGLU, etc.)

Both are accessed through `get_mlp()`.
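A GLU splits the feed-forward input into a gate path and a value path; with SiLU as the gate activation this is the SwiGLU variant. A minimal sketch (illustrative; not the module `get_mlp()` actually returns):

```python
# SwiGLU sketch: gated feed-forward network with SiLU gating
# (illustrative; projection names are invented for the example).
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLU(nn.Module):
    def __init__(self, hidden, intermediate):
        super().__init__()
        self.gate_proj = nn.Linear(hidden, intermediate)
        self.up_proj = nn.Linear(hidden, intermediate)
        self.down_proj = nn.Linear(intermediate, hidden)

    def forward(self, x):
        # The gate path goes through the activation, then multiplies the
        # value path elementwise before projecting back down
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))

mlp = SwiGLU(hidden=32, intermediate=64)
out = mlp(torch.randn(2, 5, 32))
```

Swapping `F.silu` for `F.gelu` gives the GeGLU variant, which is what the "configurable activation" option refers to.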
## Task Heads
BertBlocks provides task-specific model wrappers for:

- Masked language modeling
- Sequence classification
- Token classification (NER)
- Extractive QA
- Masked diffusion modeling
All inherit from `BertBlocksPreTrainedModel` and are compatible with HuggingFace's `AutoModel` registry.