Attention Mechanism

The attention mechanism allows a model to focus on relevant parts of an input sequence when producing each element of an output sequence. It was introduced in the context of neural machine translation and later became the foundation of the transformer architecture.

Scaled Dot-Product Attention

Given queries $Q$ and keys $K$ of dimension $d_k$, and values $V$, attention is computed as:

$$ \text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V $$

The scaling factor $\frac{1}{\sqrt{d_k}}$ prevents the dot products from growing large in magnitude when $d_k$ is large, which would push the [[softmax-function|softmax]] into regions of extremely small gradients.
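
The effect of the scaling factor can be checked numerically. The sketch below assumes that the entries of a query and a key are independent with zero mean and unit variance (an assumption made for illustration, not something stated above); under that assumption the dot product has variance $d_k$, so dividing by $\sqrt{d_k}$ brings the variance back to roughly 1. The function name `dot_product_variance` is hypothetical.

```python
import random
import statistics

def dot_product_variance(d_k, trials=2000, seed=0):
    """Empirical variance of q.k when q and k have i.i.d. standard normal entries."""
    rng = random.Random(seed)
    dots = []
    for _ in range(trials):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        dots.append(sum(a * b for a, b in zip(q, k)))
    return statistics.pvariance(dots)

for d_k in (4, 64, 256):
    raw = dot_product_variance(d_k)
    # Dividing every dot product by sqrt(d_k) divides the variance by d_k.
    scaled = raw / d_k
    print(f"d_k={d_k:4d}  var(q.k)={raw:8.2f}  var(q.k / sqrt(d_k))={scaled:.2f}")
```

The unscaled variance should grow roughly in proportion to $d_k$, while the scaled variance stays close to 1 regardless of dimension.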

Implementation

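import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    """Return (output, weights); mask entries equal to 0 mark positions to ignore."""
    d_k = Q.size(-1)
    # Similarity scores, scaled by sqrt(d_k): shape (..., len_q, len_k)
    scores = Q @ K.transpose(-2, -1) / d_k ** 0.5
    if mask is not None:
        # -inf scores become exactly zero weight after the softmax
        scores = scores.masked_fill(mask == 0, float('-inf'))
    # Normalize over the key axis so each query's weights sum to 1
    weights = F.softmax(scores, dim=-1)
    return weights @ V, weights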
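
As a usage sketch, the function can be called with a causal (lower-triangular) mask so that each position attends only to itself and earlier positions. The tensor sizes and the mask below are illustrative assumptions; the function is repeated so the snippet runs on its own.

```python
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    # Same computation as the function above, repeated so this snippet runs alone.
    scores = Q @ K.transpose(-2, -1) / Q.size(-1) ** 0.5
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    weights = F.softmax(scores, dim=-1)
    return weights @ V, weights

torch.manual_seed(0)
batch, seq_len, d_k = 2, 4, 8  # illustrative sizes, not taken from the text
Q = torch.randn(batch, seq_len, d_k)
K = torch.randn(batch, seq_len, d_k)
V = torch.randn(batch, seq_len, d_k)

# Lower-triangular mask: position i may attend to positions 0..i only.
# Shape (seq_len, seq_len) broadcasts across the batch dimension.
causal_mask = torch.tril(torch.ones(seq_len, seq_len))

output, weights = scaled_dot_product_attention(Q, K, V, mask=causal_mask)
print(output.shape)  # torch.Size([2, 4, 8])
print(weights[0])    # entries above the diagonal are exactly zero
```

Note that a query row whose keys are all masked would receive only `-inf` scores, and the softmax of such a row is NaN. A causal mask avoids this because the diagonal is always unmasked.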

Multi-Head Attention

Rather than performing a single attention pass, multi-head attention projects $Q$, $K$, and $V$ into $h$ different subspaces and computes attention in each:

$$ \text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1,\ldots,\text{head}_h)\,W^O $$

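where $\text{head}_i = \text{Attention}(QW_i^Q,\, KW_i^K,\, VW_i^V)$. Here $W_i^Q$, $W_i^K$, and $W_i^V$ are the learned projection matrices for head $i$, and $W^O$ is the learned output projection applied to the concatenated heads.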
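
The formula above can be sketched as a small PyTorch module. This is a minimal illustration under several assumptions that are not stated in the text: each head has dimension `d_model // num_heads`, the per-head projection matrices are stacked into a single bias-free `nn.Linear` per input, and the class and attribute names (`MultiHeadAttention`, `w_q`, `w_k`, `w_v`, `w_o`) are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    """Minimal multi-head attention sketch; names and sizes are illustrative."""

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # assumed per-head dimension
        # Each linear layer stacks the h per-head matrices W_i^Q, W_i^K, W_i^V.
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.w_o = nn.Linear(d_model, d_model, bias=False)  # W^O

    def _split_heads(self, x):
        # (batch, seq, d_model) -> (batch, heads, seq, d_k)
        batch, seq_len, _ = x.shape
        return x.view(batch, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, Q, K, V, mask=None):
        q = self._split_heads(self.w_q(Q))
        k = self._split_heads(self.w_k(K))
        v = self._split_heads(self.w_v(V))
        # Scaled dot-product attention, computed for all heads at once.
        scores = q @ k.transpose(-2, -1) / self.d_k ** 0.5
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)
        heads = weights @ v  # (batch, heads, seq, d_k)
        batch, _, seq_len, _ = heads.shape
        # Concat(head_1, ..., head_h): merge the head axis back into the feature axis.
        concat = heads.transpose(1, 2).reshape(batch, seq_len, -1)
        return self.w_o(concat)

torch.manual_seed(0)
mha = MultiHeadAttention(d_model=16, num_heads=4)
x = torch.randn(2, 5, 16)
out = mha(x, x, x)  # self-attention: Q, K, and V all come from x
print(out.shape)    # torch.Size([2, 5, 16])
```

Stacking the per-head projections into one matrix and reshaping is mathematically equivalent to applying $h$ separate projections, and it lets all heads be computed in a single batched matrix multiplication.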