Mission B4 & B11
Wed 15 May 9 a.m. PDT — 10 a.m. PDT


Wed 15 May 9:00 - 9:20 PDT

FlashDecoding++: Faster Large Language Model Inference with Asynchronization, Flat GEMM Optimization, and Heuristics

Ke Hong · Guohao Dai · Jiaming Xu · Qiuli Mao · Xiuhong Li · Jun Liu · kangdi chen · Yuhan Dong · Yu Wang

As Large Language Models (LLMs) become increasingly important in various domains, the performance of LLM inference is crucial to large-scale LLM applications. However, the following challenges remain unsolved in accelerating LLM inference: (1) Synchronized partial softmax update. The softmax operation requires a synchronized update among the partial softmax results, leading to ∼20% overhead for the attention computation in LLMs. (2) Under-utilized computation of flat GEMM. The matrices involved in GEMM during LLM inference are flat, leading to under-utilized computation and a 50% performance loss after zero-padding in previous designs (e.g., cuBLAS, CUTLASS). (3) Performance loss due to static dataflow. Kernel performance in LLMs depends on varied input data features, hardware configurations, etc. A single, static dataflow may lead to a 50.25% performance loss for GEMMs of different shapes in LLM inference.
We present FlashDecoding++, a fast LLM inference engine supporting mainstream LLMs and hardware backends. To tackle the above challenges, FlashDecoding++ proposes: (1) Asynchronized softmax with unified max value. FlashDecoding++ introduces a unified max value technique for different partial softmax computations to avoid synchronization. Based on this, fine-grained pipelining is proposed, leading to 1.05× and 1.14× speedups for the prefill and decoding stages in LLM inference, respectively. (2) Flat GEMM optimization with double buffering. FlashDecoding++ points out that flat GEMMs with different shapes face varied bottlenecks. Techniques such as double buffering are then introduced, leading to up to 52% speedup for the flat GEMM operation. (3) Heuristic dataflow with hardware resource adaptation. FlashDecoding++ heuristically optimizes dataflow using different hardware resources (e.g., Tensor Cores or CUDA cores), taking input dynamics into account. This design leads to up to 29% speedup compared with the static dataflow.
Due to the versatility of its optimizations, FlashDecoding++ achieves up to 4.86× and 2.18× speedup on NVIDIA and AMD GPUs, respectively, compared with Hugging Face implementations. FlashDecoding++ also achieves an average 1.37× speedup over FlashDecoding, a state-of-the-art LLM inference engine, on various LLMs (e.g., Llama2, ChatGLM2).
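The unified max value idea in point (1) can be illustrated with a small sketch. It is not the FlashDecoding++ kernel: it is a plain-Python toy, and the names `softmax_unified_max`, `partial_exp_sum`, `chunk_size`, and `phi` are assumptions made for this example. It shows only that when every chunk subtracts the same constant `phi` instead of its own local maximum, the partial sums can be added directly with no rescaling step between chunks.

```python
import math

def softmax_reference(scores):
    """Standard softmax using the global maximum for numerical stability."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def partial_exp_sum(chunk, phi):
    """Per-chunk work: exponentials shifted by the shared constant phi.

    No chunk needs to know another chunk's maximum, so chunks are independent.
    """
    exps = [math.exp(s - phi) for s in chunk]
    return exps, sum(exps)

def softmax_unified_max(scores, chunk_size, phi):
    """Softmax computed chunk by chunk with one shared shift value phi.

    phi is a hypothetical, pre-chosen constant (an assumption of this sketch);
    it must be close enough to the real scores that exp() neither overflows
    nor underflows to zero.
    """
    chunks = [scores[i:i + chunk_size] for i in range(0, len(scores), chunk_size)]
    partials = [partial_exp_sum(c, phi) for c in chunks]
    # Partial sums share the same shift, so they add up without rescaling.
    total = sum(s for _, s in partials)
    return [e / total for exps, _ in partials for e in exps]
```

Because the shift cancels in the final division, the result matches the reference softmax for any `phi` that keeps the exponentials in floating-point range.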

Wed 15 May 9:20 - 9:40 PDT

Prompt Cache: Modular Attention Reuse for Low-Latency Inference

In Gim · Guojun Chen · Seung-seob Lee · Nikhil Sarda · Anurag Khandelwal · Lin Zhong

We present Prompt Cache, an approach for accelerating inference for large language models (LLMs) by reusing attention states across different LLM prompts. Many input prompts have overlapping text segments, such as system messages, prompt templates, and documents provided for context.
Our key insight is that by precomputing and storing the attention states of these frequently occurring text segments on the inference server, we can efficiently reuse them when these segments appear in user prompts. Prompt Cache employs a schema to explicitly define such reusable text segments, called prompt modules. The schema ensures positional accuracy during attention state reuse and provides users with an interface to access cached states in their prompts.
Using a prototype implementation, we evaluate Prompt Cache across several LLMs. We show that Prompt Cache significantly reduces time-to-first-token latency, especially for longer prompts such as document-based question answering and recommendations. The improvements range from 8× for GPU-based inference to 60× for CPU-based inference, all while maintaining output accuracy and without the need for model parameter modifications.
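The reuse pattern described above can be sketched as a toy cache. This is not the Prompt Cache implementation or its schema language: `PromptCache`, `encode`, and the dictionary-based schema are hypothetical names for this example, and `encode` is a stand-in that only pairs tokens with positions rather than computing real attention states. The sketch shows two things from the abstract: each prompt module has a fixed position assigned by the schema, and its states are computed once and reused across prompts.

```python
from dataclasses import dataclass, field

def encode(tokens, start_pos):
    """Stand-in for computing attention states (hypothetical, not a model call)."""
    return [(start_pos + i, tok) for i, tok in enumerate(tokens)]

@dataclass
class PromptCache:
    # Schema: module name -> (start position, token list). Fixed start
    # positions keep cached states positionally consistent across prompts.
    schema: dict
    store: dict = field(default_factory=dict)
    computed_tokens: int = 0  # counts tokens that actually had to be encoded

    def _states_for(self, name):
        if name not in self.store:  # cache miss: encode the module once
            start, tokens = self.schema[name]
            self.store[name] = encode(tokens, start)
            self.computed_tokens += len(tokens)
        return self.store[name]

    def build(self, module_names, new_tokens, new_start):
        """Assemble states for a prompt from cached modules plus new text."""
        states = []
        for name in module_names:
            states.extend(self._states_for(name))
        states.extend(encode(new_tokens, new_start))  # only new text is encoded
        self.computed_tokens += len(new_tokens)
        return states
```

In this toy, a second prompt that reuses the same modules pays only for its new tokens, which is the source of the time-to-first-token savings the abstract reports for long shared segments.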

Wed 15 May 9:40 - 10:00 PDT

Keyformer: KV Cache reduction through key tokens selection for Efficient Generative Inference

Muhammad Adnan · Akhil Arunkumar · Gaurav Jain · Prashant Nair · Ilya Soloveychik · Purushotham Kamath

Transformers have emerged as the standard architecture for Large Language Models (LLMs). In generative language models, the inference process involves two main phases: prompt processing and token generation. Token generation, which constitutes most of the computational load, primarily entails vector-matrix multiplications and interactions with the Key-Value ($\mathsf{KV}$) cache. This phase is memory bandwidth-bound due to the overhead of transferring weights and $\mathsf{KV}$ cache values from memory to the computing units, which involves relatively low compute intensity. This memory bottleneck becomes particularly prominent in applications that demand long-context and extensive text generation, both of which are increasingly crucial for LLMs.
This paper introduces an innovative approach to mitigate the challenges associated with $\mathsf{KV}$ cache size and memory bandwidth utilization, termed "$\mathsf{Keyformer}$". $\mathsf{Keyformer}$ capitalizes on the observation that during generative inference, approximately 90% of the attention weight is concentrated on a select subset of tokens, which act as "key" tokens. $\mathsf{Keyformer}$’s key token identification takes into account the discarded tokens by utilizing a novel score function. By retaining only these "key" tokens in the $\mathsf{KV}$ cache, both the $\mathsf{KV}$ cache size and memory bandwidth usage are significantly reduced while maintaining the model’s accuracy. We evaluate $\mathsf{Keyformer}$’s effectiveness using three foundational models: GPT-J, Cerebras-GPT, and MPT, which employ various positional embedding algorithms. Our assessment covers a range of tasks, with a primary focus on summarization and conversation tasks that involve extended contexts. $\mathsf{Keyformer}$’s $\mathsf{KV}$ cache reduction reduces inference latency by 2.1$\times$ and boosts token generation throughput by 2.4$\times$, all while preserving the model’s accuracy.
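The general idea of keeping only high-attention tokens in the cache can be sketched as follows. This does not reproduce Keyformer's score function, which the abstract describes only as novel and aware of discarded tokens. As a swapped-in simplification, the sketch ranks tokens by plain accumulated attention weight; `accumulated_scores`, `select_key_tokens`, `retained_mass`, and `budget` are hypothetical names for this example.

```python
def accumulated_scores(attn_rows):
    """Sum the attention each cached token received over recent decoding steps.

    attn_rows[t][j] is the attention weight that step t placed on token j.
    """
    n = len(attn_rows[0])
    return [sum(row[j] for row in attn_rows) for j in range(n)]

def select_key_tokens(attn_rows, budget):
    """Return the indices of the `budget` highest-scoring tokens, in order.

    Simplification: plain accumulated attention, not Keyformer's score function.
    """
    scores = accumulated_scores(attn_rows)
    ranked = sorted(range(len(scores)), key=lambda j: scores[j], reverse=True)
    return sorted(ranked[:budget])

def retained_mass(attn_rows, keep):
    """Fraction of total attention weight that falls on the kept tokens."""
    scores = accumulated_scores(attn_rows)
    return sum(scores[j] for j in keep) / sum(scores)
```

With attention rows `[[0.70, 0.05, 0.20, 0.05], [0.60, 0.05, 0.30, 0.05]]` and a budget of 2, the sketch keeps tokens 0 and 2, which carry 90% of the attention weight while halving the number of cached entries.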