| title | authors | venue | year | arxiv | code | dataset | type | tags |
|---|---|---|---|---|---|---|---|---|
| One-Pass to Reason: Token Duplication and Block-Sparse Mask for Efficient Fine-Tuning on Multi-Turn Reasoning | Ritesh Goru, Shanay Mehta, Prateek Jain | ICML 2025 Workshop: 3rd Workshop on Efficient Systems for Foundational Models | 2025 | 2504.18246 | https://github.com/devrev/One-Pass-to-Reason | https://huggingface.co/datasets/devrev-research/MathChatSync-reasoning | paper | |
Abstract
Fine-tuning Large Language Models (LLMs) on multi-turn reasoning datasets requires N (number of turns) separate forward passes per conversation due to reasoning token visibility constraints, as reasoning tokens for a turn are discarded in subsequent turns. We propose duplicating response tokens along with a custom attention mask to enable single-pass processing of entire conversations. We prove our method produces identical losses to the N-pass approach while reducing time complexity from O(N³) to O(N²) and maintaining the same memory complexity for a transformer-based model. Our approach achieves significant training speedup while preserving accuracy. Our implementation is available online.
Core Problem
Reasoning models (e.g., DeepSeek-R1) generate internal reasoning tokens, produce a response, and then discard the reasoning tokens from context in subsequent turns. This creates:
- Visibility Constraints: Reasoning tokens must be visible while their own turn is generated but hidden from all subsequent turns; a single static attention mask over the unmodified token sequence cannot satisfy both requirements
- Position ID Discrepancy: Response tokens follow reasoning tokens during generation but directly follow human messages in later context
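Both issues show up in a toy example. The sketch below builds the N-pass training sequences for a two-turn conversation. The token names (`h1`, `t1a`, `r1`, ...) and the helper `n_pass_samples` are hypothetical, made up for illustration, and are not taken from the paper's code.

```python
from typing import List, Tuple

# Hypothetical toy format: each turn is (human, reasoning, response) token lists.
Turn = Tuple[List[str], List[str], List[str]]


def n_pass_samples(turns: List[Turn]) -> List[List[str]]:
    """Build one training sequence per turn, as an N-pass baseline would.

    Sequence i holds the human and response tokens of all earlier turns
    (their reasoning is dropped), followed by turn i in full.
    """
    samples: List[List[str]] = []
    context: List[str] = []
    for human, reasoning, response in turns:
        samples.append(context + human + reasoning + response)
        # Later turns see this turn without its reasoning tokens.
        context = context + human + response
    return samples


conversation: List[Turn] = [
    (["h1"], ["t1a", "t1b"], ["r1"]),
    (["h2"], ["t2a"], ["r2"]),
]
samples = n_pass_samples(conversation)
for sample in samples:
    print(sample)

# The same response token sits at different positions in the two passes.
pos_generated = samples[0].index("r1")  # follows the reasoning tokens
pos_context = samples[1].index("r1")    # directly follows the human message
print(pos_generated, pos_context)
```

In this toy case `r1` is at index 3 when it is generated and at index 1 when it later serves as context, which is the position ID discrepancy described above. The reasoning tokens `t1a` and `t1b` are absent from the second sequence, which is the visibility constraint.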
Method
- Token Duplication: Duplicate each turn's response tokens into two copies: ri_in (context copy), which does not attend to that turn's reasoning tokens, and ri_out (generation copy), which does
- Custom Block-Sparse Attention Mask: Single mask with visibility rules per token type
- Strategic Position ID Assignment: Maintains correct relative positions equivalent to N-pass
- Theorem 2.1: Proves loss equivalence L_N-Pass(c) = L_1-Pass(c)
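The first three items can be sketched in a few lines of plain Python. This is a minimal illustration under stated assumptions, not the authors' implementation: the per-turn layout order (human, reasoning, `r_in`, `r_out`), the rule tables `SAME_TURN_VISIBLE` and `CONTEXT_KINDS`, and the function names are all hypothetical and were chosen only to be consistent with the description above.

```python
from typing import List, Tuple

# Assumed visibility rules: which token kinds a query may see in its own turn.
SAME_TURN_VISIBLE = {
    "h": {"h"},
    "t": {"h", "t"},
    "r_in": {"h", "r_in"},        # context copy never sees reasoning
    "r_out": {"h", "t", "r_out"},  # generation copy sees reasoning
}
# Assumed: only human tokens and context copies are visible from later turns.
CONTEXT_KINDS = {"h", "r_in"}


def build_layout(turn_lengths: List[Tuple[int, int, int]]):
    """turn_lengths holds (n_human, n_reasoning, n_response) per turn."""
    kinds, turn_ids, position_ids = [], [], []
    context_len = 0  # tokens a later turn sees: human + response only
    for turn, (n_h, n_t, n_r) in enumerate(turn_lengths):
        segments = [
            ("h", n_h, context_len),
            ("t", n_t, context_len + n_h),
            ("r_in", n_r, context_len + n_h),         # directly after human
            ("r_out", n_r, context_len + n_h + n_t),  # after reasoning
        ]
        for kind, length, start in segments:
            for offset in range(length):
                kinds.append(kind)
                turn_ids.append(turn)
                position_ids.append(start + offset)
        context_len += n_h + n_r
    return kinds, turn_ids, position_ids


def build_mask(kinds: List[str], turn_ids: List[int]) -> List[List[bool]]:
    """mask[q][k] is True when query token q may attend to key token k."""
    n = len(kinds)
    mask = [[False] * n for _ in range(n)]
    for q in range(n):
        for k in range(q + 1):  # causal: keys never come after the query
            if turn_ids[k] == turn_ids[q]:
                mask[q][k] = kinds[k] in SAME_TURN_VISIBLE[kinds[q]]
            else:
                mask[q][k] = kinds[k] in CONTEXT_KINDS
    return mask


# Two turns: (1 human, 2 reasoning, 1 response) and (1 human, 1 reasoning, 1 response).
kinds, turn_ids, position_ids = build_layout([(1, 2, 1), (1, 1, 1)])
mask = build_mask(kinds, turn_ids)
print(kinds)
print(position_ids)
```

For this input the position IDs come out as `[0, 1, 2, 1, 3, 2, 3, 3, 4]`: each `r_out` token gets the position it would have right after its reasoning, and each `r_in` token gets the position it would have right after the human message, matching what the separate per-turn passes would assign. A real training setup would pass an equivalent mask to a block-sparse attention kernel rather than materialise a dense boolean matrix.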
Results
- 1.05×–1.22× faster than FlashAttention-2 N-Pass with packing (Qwen-3 4B, 8B, 32B)
- 1.44×–1.54× faster than FlexAttention N-Pass with packing
- Cost: ~33% more GPU memory
- Speedups grow with conversation depth (O(N²) vs O(N³) theoretical advantage)
- K-Pass variant allows speed–memory trade-off
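The O(N²) versus O(N³) gap can be made plausible with a rough count of query-key pairs. The accounting below is a simplified sketch, not the paper's proof. It assumes every turn has at most T tokens in total (human, reasoning, and response), that pass i of the N-pass baseline processes a sequence of length L_i ≤ iT, and that in the single pass each turn contributes at most 2T query tokens (because of the duplicated response), each attending to at most iT keys.

```latex
\begin{align*}
\text{N-pass:}\quad
  & \sum_{i=1}^{N} L_i^2 \;\le\; \sum_{i=1}^{N} (iT)^2
    \;=\; T^2\,\frac{N(N+1)(2N+1)}{6} \;=\; O(N^3 T^2) \\
\text{One-pass:}\quad
  & \sum_{i=1}^{N} \underbrace{2T}_{\text{queries in turn } i}
    \cdot \underbrace{iT}_{\text{visible keys}}
    \;=\; T^2\,N(N+1) \;=\; O(N^2 T^2)
\end{align*}
```

The N-pass baseline recomputes attention for all context tokens in every pass, while the single pass computes each token once. The one-pass bound is only realised when the attention kernel actually skips the masked-out blocks.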
Key Contributions
- Theoretical framework for single-pass multi-turn reasoning training
- MathChatSync Reasoning dataset (first public multi-turn reasoning dataset with explicit per-turn reasoning)
- Comprehensive empirical validation on Qwen-3 models using QLoRA