How Do Transformers Learn to Associate Tokens: Gradient Leading Terms Bring Mechanistic Interpretability
Shawn Im, Changdae Oh, Zhen Fang, Sharon Li
Abstract
Semantic associations such as the link between "bird" and "flew" are foundational for language modeling, as they enable models to go beyond memorization and instead generalize and generate coherent text. Understanding how these associations are learned and represented in language models is essential for connecting deep learning with linguistic theory and for developing a mechanistic foundation for large language models. In this work, we analyze how these associations emerge from natural language data in attention-based language models through the lens of training dynamics. By leveraging a leading-term approximation of the gradients, we develop closed-form expressions for the weights at early stages of training that explain how semantic associations first take shape. Our analysis reveals that each set of transformer weights has a closed-form expression as a simple composition of three basis functions (bigram, token-interchangeability, and context mappings) that reflect the statistics of the text corpus, and it uncovers how each component of the transformer captures semantic associations through these compositions. Experiments on real-world LLMs demonstrate that our theoretical weight characterizations closely match the learned weights, and qualitative analyses further show how our theorem sheds light on interpreting the learned associations in transformers.
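The leading-term idea can be illustrated in a deliberately simplified setting that is our own assumption, not the paper's transformer: a single linear next-token predictor whose logits for current token x are the column W[:, x], trained with cross-entropy from W = 0. At zero initialization the softmax is uniform, so the negative gradient is exactly P(x, y) - P(x)/V, that is, bigram statistics plus a uniform correction, and one gradient step writes those statistics into W. The function names, toy corpus, and learning rate below are hypothetical; the sketch only checks this closed form against the exact gradient.

```python
import math
from collections import Counter

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def mean_ce_grad(W, pairs, V):
    """Exact gradient of the mean cross-entropy for the linear model logits = W[:, x].

    W is a V x V list of lists indexed W[y][x]; pairs holds (current, next) token ids.
    """
    G = [[0.0] * V for _ in range(V)]
    for x, y in pairs:
        p = softmax([W[v][x] for v in range(V)])
        for v in range(V):
            # d(CE)/d(logit_v) = p_v - 1[v == y]; only column x of W receives gradient.
            G[v][x] += (p[v] - (1.0 if v == y else 0.0)) / len(pairs)
    return G

def leading_term(pairs, V):
    """Closed form of -grad at W = 0: P(x, y) - P(x) / V, indexed [y][x]."""
    n = len(pairs)
    joint = Counter(pairs)
    marginal = Counter(x for x, _ in pairs)
    return [[joint[(x, y)] / n - marginal[x] / (n * V) for x in range(V)]
            for y in range(V)]

# Hypothetical toy corpus over 4 token ids: 0 = "the", 1 = "bird", 2 = "flew", 3 = "away".
corpus = [0, 1, 2, 3, 0, 1, 2, 0, 1]
pairs = list(zip(corpus, corpus[1:]))
V = 4

G = mean_ce_grad([[0.0] * V for _ in range(V)], pairs, V)
L = leading_term(pairs, V)
print(max(abs(-G[v][x] - L[v][x]) for v in range(V) for x in range(V)))  # ~0 up to rounding

eta = 1.0  # hypothetical learning rate
W1 = [[-eta * G[v][x] for x in range(V)] for v in range(V)]  # one step from W = 0
print(max(range(V), key=lambda v: W1[v][1]))  # strongest association for "bird": 2 ("flew")
```

The abstract describes analogous closed forms for every weight set of an attention-based model; this sketch covers only the single-matrix, zero-initialization case, where the leading term is exact for the first step.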
Gradient leading-term analysis reveals how semantic associations emerge in transformers as compositions of bigram, interchangeability, and context mapping functions.
- Develops closed-form expressions for transformer weights at early training stages using a leading-term approximation of the gradients
- Reveals that transformer weights decompose into compositions of three basis functions: bigram mapping, token-interchangeability mapping, and context mapping
- Demonstrates that the theoretical weight characterizations closely match the learned weights of real-world LLMs
- Gradient leading-term analysis
- Closed-form weight expressions
- Training dynamics analysis
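To give a feel for what "compositions of corpus statistics" can look like, here is a minimal sketch that computes three illustrative mappings on a toy corpus. The definitions are assumptions chosen for illustration and are not the paper's exact formulas: the bigram mapping is the empirical next-token distribution, token interchangeability is the overlap of two tokens' next-token distributions, and the context mapping is the next-token distribution given that a token appears within a hypothetical window of preceding tokens. All function names and the toy corpus are hypothetical.

```python
from collections import Counter

def normalize(row):
    """Turn a row of counts into a probability distribution (all zeros stays all zeros)."""
    total = sum(row)
    return [c / total for c in row] if total else [0.0] * len(row)

def matmul(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

def transpose(A):
    return [list(col) for col in zip(*A)]

def bigram_mapping(corpus, V):
    """Assumed: B[x][y] = empirical P(next token = y | current token = x)."""
    counts = Counter(zip(corpus, corpus[1:]))
    return [normalize([counts[(x, y)] for y in range(V)]) for x in range(V)]

def interchangeability_mapping(B):
    """Assumed: S[a][b] = sum_y B[a][y] * B[b][y], the overlap of next-token distributions."""
    return matmul(B, transpose(B))

def context_mapping(corpus, V, window):
    """Assumed: C[u][y] = empirical P(next token = y | u occurs in the previous `window` tokens)."""
    counts = [[0] * V for _ in range(V)]
    for t in range(1, len(corpus)):
        for u in set(corpus[max(0, t - window):t]):
            counts[u][corpus[t]] += 1
    return [normalize(row) for row in counts]

# Hypothetical toy corpus: "the bird flew . the plane flew . the bird sang ."
vocab = ["the", "bird", "plane", "flew", "sang", "."]
corpus = [0, 1, 3, 5, 0, 2, 3, 5, 0, 1, 4, 5]
V = len(vocab)

B = bigram_mapping(corpus, V)
S = interchangeability_mapping(B)
C = context_mapping(corpus, V, window=3)
SB = matmul(S, B)  # one composition: interchangeable tokens share their bigram predictions

print(B[2][4])   # "plane" is never followed by "sang": 0.0
print(SB[2][4])  # after composing with S, "plane" -> "sang" gets mass via "bird": 0.25
```

The composition S·B shows in miniature how statistics that combine several mappings can assign probability to a pair never seen as a bigram, which is the kind of generalization beyond memorization the abstract refers to; which compositions actually appear in each transformer component is determined by the paper's theorems, not by this sketch.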
Authors did not state explicit limitations.
- Discover common factors allowing weight matrices across components to be decomposed into simple functions of shared factors (from the paper)
- Leverage theory to formulate broad hypotheses about how concepts arise in models, extending beyond individual mechanisms (from the paper)
Author keywords
- Semantic associations
- Interpretability
- LLM