# Attention Mechanism
The attention mechanism allows a model to focus on relevant parts of an input sequence when producing each element of an output sequence. It was introduced in the context of neural machine translation and later became the foundation of the transformer architecture.
## Scaled Dot-Product Attention
Given queries $Q$, keys $K$, and values $V$, attention is computed as:
$$ \text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V $$
The scaling factor $\frac{1}{\sqrt{d_k}}$ prevents the dot products from growing large in magnitude when $d_k$ is large, which would push the [[softmax-function|softmax]] into regions of extremely small gradients.
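To see why the scaling matters, a small numerical sketch (random vectors with unit-variance entries; the sizes are illustrative) shows that unscaled dot products have standard deviation roughly $\sqrt{d_k}$, while the scaled scores stay near unit scale:

```python
import torch

torch.manual_seed(0)

# For q, k with i.i.d. unit-variance entries, q·k is a sum of d_k terms,
# so its standard deviation grows like sqrt(d_k); dividing by sqrt(d_k)
# keeps the scores at roughly unit scale regardless of dimension.
for d_k in (4, 64, 1024):
    q = torch.randn(10_000, d_k)
    k = torch.randn(10_000, d_k)
    dots = (q * k).sum(dim=-1)
    scaled = dots / d_k ** 0.5
    print(f"d_k={d_k:4d}  std(dot)={dots.std().item():7.2f}  "
          f"std(scaled)={scaled.std().item():.2f}")
```

With large scores, the softmax saturates toward a one-hot vector and its gradients vanish, which is exactly what the $\sqrt{d_k}$ division avoids.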
## Implementation
```python
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    """Compute softmax(Q K^T / sqrt(d_k)) V and the attention weights."""
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / d_k ** 0.5
    if mask is not None:
        # Masked positions get -inf so softmax assigns them zero weight.
        scores = scores.masked_fill(mask == 0, float('-inf'))
    weights = F.softmax(scores, dim=-1)
    return weights @ V, weights
```
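As a usage sketch, the same computation can be unrolled step by step on toy tensors with a causal mask (the shapes here are illustrative, not prescribed by the text):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Toy inputs: batch=2, seq_len=5, d_k=8.
Q, K, V = (torch.randn(2, 5, 8) for _ in range(3))

# Causal mask: position i may attend only to positions j <= i.
mask = torch.tril(torch.ones(5, 5, dtype=torch.bool))

scores = Q @ K.transpose(-2, -1) / Q.size(-1) ** 0.5
scores = scores.masked_fill(~mask, float('-inf'))
weights = F.softmax(scores, dim=-1)
out = weights @ V

print(out.shape)  # torch.Size([2, 5, 8])
# Each row of the weights is a probability distribution over positions.
print(torch.allclose(weights.sum(-1), torch.ones(2, 5)))  # True
```

Masked entries receive exactly zero weight, so the first position attends only to itself.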
## Multi-Head Attention
Rather than performing a single attention pass, multi-head attention projects $Q$, $K$, and $V$ into $h$ different subspaces and computes attention in each:
$$ \text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1,\ldots,\text{head}_h)\,W^O $$
where $\text{head}_i = \text{Attention}(QW_i^Q,\, KW_i^K,\, VW_i^V)$.
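The projections and heads can be sketched as a PyTorch module. This is a minimal illustration, assuming the per-head matrices $W_i^Q, W_i^K, W_i^V$ are fused into single linear layers; `d_model=32` and `num_heads=4` are illustrative choices, not values from the text:

```python
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        # W^Q, W^K, W^V for all heads fused into one projection each.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)  # output projection W^O

    def forward(self, Q, K, V, mask=None):
        B, T, _ = Q.shape

        def split(x):  # (B, T, d_model) -> (B, h, T, d_k)
            return x.view(B, -1, self.num_heads, self.d_k).transpose(1, 2)

        q, k, v = split(self.W_q(Q)), split(self.W_k(K)), split(self.W_v(V))
        scores = q @ k.transpose(-2, -1) / self.d_k ** 0.5
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        weights = torch.softmax(scores, dim=-1)
        out = weights @ v                              # (B, h, T, d_k)
        # Concat heads back to (B, T, h * d_k), then apply W^O.
        out = out.transpose(1, 2).contiguous().view(B, T, -1)
        return self.W_o(out)

mha = MultiHeadAttention(d_model=32, num_heads=4)
x = torch.randn(2, 5, 32)
print(mha(x, x, x).shape)  # torch.Size([2, 5, 32])
```

Fusing the per-head projections is the standard trick: one matrix multiply followed by a reshape is equivalent to $h$ separate projections, because each head simply reads its own slice of the output dimension.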