Multi-Head Attention

[mùlti-hed attènscion]

Versione potenziata del self-attention dove più "teste" guardano la frase in parallelo, ognuna catturando relazioni diverse tra le parole.

Il multi-head attention è il self-attention moltiplicato per N. Invece di un solo meccanismo di attenzione, ne metti 8, 16, 32 in parallelo. Ognuno impara a guardare la frase da un'angolazione diversa.

Perché? Perché una frase ha tanti tipi di relazioni insieme: chi fa cosa, soggetto-verbo, riferimenti pronominali, contesto temporale, sfumature semantiche. Una sola "testa" non ce la fa a catturare tutto. Otto teste fanno otto lavori diversi e poi i risultati si combinano.

Esempio concreto: nella frase "Maria ha detto a Luca che lui doveva tornare", una testa potrebbe seguire il riferimento "lui → Luca", un'altra la struttura "soggetto → verbo", un'altra ancora il contesto temporale.

Numeri tipici nei modelli reali:

  • GPT-3: 96 teste per layer.
  • Llama 3 70B: 64 teste.
  • Modelli piccoli (7B): 32 teste.

Ogni testa ha le sue Query, Key, Value separate. I parametri totali del modello crescono ma il calcolo si parallelizza bene sulle GPU.

Negli ultimi modelli si usano varianti ottimizzate come Grouped-Query Attention (GQA) e Multi-Query Attention (MQA): meno teste per le Key/Value, riducono memoria senza perdere troppo in qualità. Llama 3 e Mistral le usano.