Il multi-head attention è il self-attention moltiplicato per N. Invece di un solo meccanismo di attenzione, ne metti 8, 16, 32 in parallelo. Ognuno impara a guardare la frase da un'angolazione diversa.
Perché? Perché una frase ha tanti tipi di relazioni insieme: chi fa cosa, soggetto-verbo, riferimenti pronominali, contesto temporale, sfumature semantiche. Una sola "testa" non ce la fa a catturare tutto. Otto teste fanno otto lavori diversi e poi i risultati si combinano.
Esempio concreto: nella frase "Maria ha detto a Luca che lui doveva tornare", una testa potrebbe seguire il riferimento "lui → Luca", un'altra la struttura "soggetto → verbo", un'altra ancora il contesto temporale.
Numeri tipici nei modelli reali:
- GPT-3: 96 teste per layer.
- Llama 3 70B: 64 teste.
- Modelli piccoli (7B): 32 teste.
Ogni testa ha le sue Query, Key, Value separate. I parametri totali del modello crescono ma il calcolo si parallelizza bene sulle GPU.
Negli ultimi modelli si usano varianti ottimizzate come Grouped-Query Attention (GQA) e Multi-Query Attention (MQA): meno teste per le Key/Value, riducono memoria senza perdere troppo in qualità. Llama 3 e Mistral le usano.