Partition Generative Modeling: Masked Modeling Without Masks
Justin Deschenaux, Lan Tran, Caglar Gulcehre
We show that it is possible to train masked generative models without using MASK tokens, resulting in efficiency gains at inference.
Abstract
Masked generative models (MGMs) can generate tokens in parallel and in any order, unlike autoregressive models (ARMs), which decode one token at a time, left-to-right. However, MGMs process the full-length sequence at every sampling step, including MASK tokens that carry no information. In contrast, ARMs process only the previously generated tokens. We introduce "Partition Generative Models" (PGMs), which replace masking with partitioning. Tokens are split into two groups that cannot attend to each other, and the model learns to predict each group conditioned on the other, eliminating mask tokens entirely. Because the groups do not interact, PGMs can process only the clean tokens during sampling, like ARMs, while retaining parallel, any-order generation, like MGMs. On OpenWebText, PGMs achieve 5-5.5x higher throughput than MDLM while producing samples with lower Generative Perplexity. On ImageNet, PGMs reach comparable FID to MaskGIT with a 7.5x throughput improvement. With twice as many steps, the FID improves to 4.56 while remaining 3.9x faster than MGMs. Finally, PGMs remain compatible with existing MGM samplers and distillation methods.
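The core idea in the abstract — two token groups that cannot attend to each other — can be expressed as a block attention mask. The sketch below is one plausible reading of that constraint, not the paper's actual implementation; the function name and the same-group-only rule are assumptions for illustration:

```python
import numpy as np

def partition_attention_mask(group_ids: np.ndarray) -> np.ndarray:
    """Boolean attention mask for a partitioned sequence (illustrative).

    Position i may attend to position j only when both tokens belong
    to the same partition group, so the two groups never see each other.
    """
    return group_ids[:, None] == group_ids[None, :]

# Toy example: a 4-token sequence split into groups {0, 1}.
groups = np.array([0, 0, 1, 1])
mask = partition_attention_mask(groups)
# mask[0, 1] is True (same group); mask[0, 2] is False (cross-group).
```

In practice such a mask would be passed to the attention layers of a transformer; cross-group conditioning would then have to flow through a separate mechanism (the paper's GroupSwap layer).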
Partition Generative Models replace masking with partitioning for efficient parallel generation, achieving higher throughput than masked generative models.
- Introduces PGM approach replacing masking with token partitioning where groups cannot attend to each other
- Eliminates mask tokens entirely while retaining parallel, any-order generation capabilities like masked models
- Achieves 5-5.5x higher throughput than MDLM on OpenWebText and 7.5x improvement over MaskGIT on ImageNet
- Token partitioning
- Parallel generation
- Iterative updating
- OpenWebText
- ImageNet
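The "parallel generation" and "iterative updating" themes above follow the generic any-order decoding loop used by masked generative models: repeatedly predict and commit a subset of the remaining positions. A minimal toy sketch, assuming a hypothetical `predict_fn` stand-in for the trained model:

```python
import numpy as np

def parallel_decode(seq_len, vocab_size, num_steps, predict_fn, rng):
    """Generic any-order iterative decoding loop (toy illustration)."""
    tokens = np.full(seq_len, -1, dtype=int)   # -1 = not yet generated
    order = rng.permutation(seq_len)           # any-order schedule
    for chunk in np.array_split(order, num_steps):
        # predict_fn sees the partial sequence and fills the chosen positions;
        # a PGM would only need to process the already-clean tokens here.
        tokens[chunk] = predict_fn(tokens, chunk)
    return tokens

rng = np.random.default_rng(0)
# Dummy predictor sampling uniformly; a real model conditions on `tokens`.
dummy = lambda tokens, chunk: rng.integers(0, 50, size=len(chunk))
out = parallel_decode(16, 50, num_steps=4, predict_fn=dummy, rng=rng)
```

After `num_steps` rounds every position holds a committed token, with multiple positions filled per step rather than one at a time as in autoregressive decoding.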
Limitations
- PGMs require a slight increase in parameters to match the validation perplexity of the MDLM baseline, attributed to the GroupSwap layer (from the paper)
- Training is slightly more computationally expensive than the baseline due to torch's default attention implementation (from the paper)
- Application to multimodal settings remains an open direction (from the paper)
Future work
- Explore optimizations to the PGM architecture, including more efficient GroupSwap mechanisms (from the paper)
- Investigate distillation techniques specifically designed for PGMs (from the paper)
- Extend the approach to multimodal settings (from the paper)
Author keywords
- masked generative modeling
- discrete diffusion
- masked diffusion language modeling
- diffusion language modeling