PRISM: PARAMETRICALLY RESTRUCTURED INFERENCE FOR SPECULATIVE SAMPLING DRAFT MODELS
Abstract
Large Language Models (LLMs), constrained by their auto-regressive nature, have long suffered from slow and expensive decoding. Speculative sampling methods, which alleviate the memory-bandwidth bottleneck, have attracted attention from both the systems and AI research communities. The demand for high predictive performance has driven a trend toward parametrically larger and more powerful draft models, which in turn introduces growing computational overhead. While existing works balance these trade-offs to find a sweet spot, in this paper we examine the effectiveness-efficiency dilemma more deeply and address it through architectural innovation. By disaggregating the computation of each predictive step across different parameter sets, we restructure the draft model's computational paths, decoupling representation capacity from inference cost and making the model both scalable and fast. Extensive experiments show that our PRISM drafter outperforms state-of-the-art draft architectures in acceptance length and end-to-end throughput when trained on the same dataset. We also show that PRISM scales exceptionally well on large datasets where other architectures fail to. On average, PRISM speculative decoding achieves more than a 2.6x end-to-end speedup when integrated with an already highly optimized inference engine.
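To make the core idea concrete, the following is a minimal, purely illustrative sketch of draft-then-verify speculative decoding in which each predictive draft step uses its own small parameter set, rather than one shared large drafter run repeatedly. All names, sizes, and the toy verifier here are assumptions for illustration only, not the paper's actual PRISM parameterization.

```python
import random

random.seed(0)

VOCAB, HIDDEN, STEPS = 16, 8, 3

def rand_matrix(rows, cols):
    return [[random.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]

# Hypothetical per-step drafter: one small weight matrix per predictive step,
# so depth-k drafting costs k cheap projections instead of k full passes
# through one large shared draft model. (Illustrative assumption.)
step_weights = [rand_matrix(HIDDEN, VOCAB) for _ in range(STEPS)]
embed = rand_matrix(VOCAB, HIDDEN)

def matvec(vec, mat):
    """Row-vector times matrix: (HIDDEN,) x (HIDDEN, VOCAB) -> (VOCAB,)."""
    return [sum(v * mat[i][j] for i, v in enumerate(vec))
            for j in range(len(mat[0]))]

def draft(hidden):
    """Greedily propose STEPS tokens, one parameter set per step."""
    tokens = []
    for W in step_weights:                 # step-specific parameters
        logits = matvec(hidden, W)
        tok = logits.index(max(logits))
        tokens.append(tok)
        # Fold the drafted token's embedding back into the hidden state.
        hidden = [h + e for h, e in zip(hidden, embed[tok])]
    return tokens

def verify(tokens, target_next):
    """Toy verifier: accept the longest prefix the target model agrees with."""
    accepted = []
    for tok in tokens:
        if tok != target_next(accepted):
            break
        accepted.append(tok)
    return accepted

hidden0 = [random.uniform(-1, 1) for _ in range(HIDDEN)]
proposal = draft(hidden0)
# Toy target that happens to agree with the first two drafted tokens only.
accepted = verify(proposal, lambda ctx: proposal[len(ctx)] if len(ctx) < 2 else -1)
print(len(proposal), len(accepted))  # drafted 3 tokens, 2 accepted
```

The accepted-prefix length here is the abstract's "acceptance length": the more drafted tokens the target model accepts per verification pass, the fewer expensive target-model steps are needed per generated token.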