---
title: "One-Pass to Reason: Token Duplication and Block-Sparse Mask for Efficient Fine-Tuning on Multi-Turn Reasoning"
authors: "Ritesh Goru, Shanay Mehta, Prateek Jain"
venue: "ICML 2025 Workshop: 3rd Workshop on Efficient Systems for Foundational Models"
year: 2025
arxiv: "2504.18246"
code: "https://github.com/devrev/One-Pass-to-Reason"
dataset: "https://huggingface.co/datasets/devrev-research/MathChatSync-reasoning"
type: paper
tags: [efficient-fine-tuning, multi-turn-reasoning, attention-mask, token-duplication, single-pass-training]
---

## Abstract

Fine-tuning Large Language Models (LLMs) on multi-turn reasoning datasets requires N (number of turns) separate forward passes per conversation due to reasoning token visibility constraints, as reasoning tokens for a turn are discarded in subsequent turns. We propose duplicating response tokens along with a custom attention mask to enable single-pass processing of entire conversations. We prove our method produces identical losses to the N-pass approach while reducing time complexity from O(N³) to O(N²) and maintaining the same memory complexity for a transformer-based model. Our approach achieves significant training speedup while preserving accuracy. Our implementation is available online.
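
## Core Problem

Reasoning models (e.g., DeepSeek-R1) generate internal reasoning tokens before each response, and those reasoning tokens are dropped from the context in subsequent turns. This creates two problems for training:

1. **Visibility Constraints**: Reasoning tokens must be visible while the current turn's response is generated, but hidden from all subsequent turns. A response token therefore needs two different hidden states (with and without its reasoning in view), which no single attention mask over one copy of the conversation can provide.
2. **Position ID Discrepancy**: During generation, response tokens come right after the reasoning tokens; in later turns' context, they come right after the human message. The same token therefore has different position IDs in the two settings.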
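
Both constraints can be made concrete with a small sketch of the N-pass baseline's sequence layout. `n_pass_sequences` and `start_of` are hypothetical helpers, not part of the released code; they assume each turn is described only by token counts `(h, t, r)` for the human message, reasoning, and response, and segment labels such as `r1` are illustrative.

```python
def n_pass_sequences(lens):
    """Sequence layout of each N-pass training pass (hypothetical helper).

    lens: per-turn (h, t, r) token counts for human message, reasoning, response.
    Returns one list per pass k of (label, start_position, length) segments.
    """
    passes = []
    for k in range(1, len(lens) + 1):
        seq, pos = [], 0
        for j, (h, t, r) in enumerate(lens[:k], start=1):
            segs = [(f"h{j}", h)]
            if j == k:
                # Only the turn being trained keeps its reasoning tokens.
                segs += [(f"t{j}", t), (f"r{j}", r)]
            else:
                # Earlier turns appear as context: their reasoning is dropped.
                segs += [(f"r{j}", r)]
            for label, n in segs:
                seq.append((label, pos, n))
                pos += n
        passes.append(seq)
    return passes


def start_of(seq, label):
    """Position ID at which segment `label` starts within one pass."""
    return next(start for name, start, _ in seq if name == label)


passes = n_pass_sequences([(10, 40, 20), (12, 30, 25)])
print(start_of(passes[0], "r1"))  # 50: r1 while being generated, after t1
print(start_of(passes[1], "r1"))  # 10: r1 as context in pass 2, right after h1
print(any(name == "t1" for name, _, _ in passes[1]))  # False: t1 is gone
```

In this two-turn example, the first response starts at position 50 while it is generated but at position 10 when it reappears as context, and it can no longer attend to `t1`. Every pass also re-processes all earlier turns, which is where the extra cost of the N-pass approach comes from.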
|
||
|
||
## Method
|
||
|
||
1. **Token Duplication**: Duplicate response tokens so ri_in (context copy) does not attend to reasoning, while ri_out (generation copy) does
|
||
2. **Custom Block-Sparse Attention Mask**: Single mask with visibility rules per token type
|
||
3. **Strategic Position ID Assignment**: Maintains correct relative positions equivalent to N-pass
|
||
4. **Theorem 2.1**: Proves loss equivalence L_N-Pass(c) = L_1-Pass(c)
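
The rules above can be checked on a toy conversation. The following is a pure-Python sketch, not the authors' implementation: the segment order within a turn (human message, reasoning, $r_i^{\text{out}}$, $r_i^{\text{in}}$), the names `Tok`, `build_one_pass`, `visible`, `build_n_pass`, and `matches_n_pass`, and the exact mask predicate are all assumptions. It builds the duplicated sequence with position IDs, applies a per-token visibility rule, and verifies that every token attends to exactly the same keys, at the same positions, as its counterpart in the N-pass baseline. Matching attention sets and positions at every layer is what makes the logits, and therefore the loss, identical.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Tok:
    turn: int    # 1-based turn index i
    kind: str    # "h" human, "t" reasoning, "out" = r_i^out, "in" = r_i^in
    offset: int  # index within its segment
    pos: int     # position ID


def build_one_pass(lens):
    """Single-pass layout; assumed per-turn order is h_i, t_i, r_i^out, r_i^in."""
    toks, base = [], 0  # base: position where turn i's human message starts
    for i, (h, t, r) in enumerate(lens, start=1):
        toks += [Tok(i, "h", a, base + a) for a in range(h)]
        toks += [Tok(i, "t", a, base + h + a) for a in range(t)]
        toks += [Tok(i, "out", a, base + h + t + a) for a in range(r)]
        # r_i^in sits right after h_i, as it will in later turns' context.
        toks += [Tok(i, "in", a, base + h + a) for a in range(r)]
        base += h + r  # later turns never see t_i, so it takes no positions
    return toks


def visible(q, k, q_idx, k_idx):
    """Hypothetical mask rule: may query q (index q_idx) attend to key k?"""
    if k_idx > q_idx:
        return False                  # causal in sequence order
    if k.turn < q.turn:
        return k.kind in ("h", "in")  # earlier turns: human msgs + context copies
    # Same turn: the context copy r_i^in sees only h_i and itself (causally).
    return q.kind != "in" or k.kind in ("h", "in")


def build_n_pass(lens, p):
    """Tokens of N-pass pass p: earlier turns as (h_j, r_j), then h_p, t_p, r_p.
    Context responses are labelled "in" so they compare equal to r_j^in."""
    toks = []
    for j, (h, t, r) in enumerate(lens[:p], start=1):
        segs = [("h", h)] + ([("t", t), ("out", r)] if j == p else [("in", r)])
        for kind, n in segs:
            toks += [Tok(j, kind, a, len(toks) + a) for a in range(n)]
    return toks


def matches_n_pass(lens, mask=visible):
    """True if every 1-pass token sees exactly the keys (with the same position
    IDs) that its counterpart sees in the corresponding N-pass forward pass."""
    one = build_one_pass(lens)
    for qi, q in enumerate(one):
        p = q.turn + 1 if q.kind == "in" else q.turn  # pass where q is used
        if p > len(lens):
            continue  # the last turn's r_N^in is never used as context
        ref = build_n_pass(lens, p)
        if q not in ref:
            return False  # position ID mismatch
        seen_ref = set(ref[: ref.index(q) + 1])  # causal prefix in pass p
        seen_one = {k for ki, k in enumerate(one) if mask(q, k, qi, ki)}
        if seen_one != seen_ref:
            return False
    return True


lens = [(2, 3, 2), (1, 2, 3), (2, 1, 1)]
print(matches_n_pass(lens))                                      # True
print(matches_n_pass(lens, mask=lambda q, k, qi, ki: ki <= qi))  # False
```

A plain causal mask over the same duplicated sequence fails the check because $r_i^{\text{in}}$ would attend to the reasoning tokens and $r_i^{\text{out}}$. In a real training setup, a predicate like `visible` would be evaluated over per-token metadata tensors, for example as a `mask_mod` passed to PyTorch FlexAttention's `create_block_mask`, so that fully masked blocks are skipped.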

## Results

- 1.05×–1.22× faster than the N-Pass baseline using FlashAttention-2 with packing (Qwen-3 4B, 8B, 32B)
- 1.44×–1.54× faster than the N-Pass baseline using FlexAttention with packing
- ~33% higher GPU memory usage than N-Pass (memory complexity is unchanged)
- Speedups grow with the number of turns, consistent with the theoretical O(N²) vs. O(N³) time complexity
- A K-Pass variant offers a trade-off between speed and memory
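
Why speedups grow with depth can be illustrated with a rough cost model. This sketch counts only attention query-key pairs and is not the paper's analysis: it assumes uniform hypothetical per-turn lengths `h`, `t`, `r`, a kernel that skips masked pairs entirely, and a single-pass sequence holding, per turn, the human message, the reasoning, and both response copies. It ignores MLP and projection costs, which scale linearly with token count, so it overstates the measured speedups.

```python
def n_pass_pairs(n_turns, h, t, r):
    """Query-key pairs under causal attention, summed over all N passes."""
    total = 0
    for k in range(1, n_turns + 1):
        # Pass k: k-1 context turns (human + response), then h, t, r of turn k.
        length = (k - 1) * (h + r) + h + t + r
        total += length * (length + 1) // 2
    return total


def one_pass_pairs(n_turns, h, t, r):
    """Query-key pairs allowed by the single-pass mask, summed per query token."""
    total = 0
    for i in range(1, n_turns + 1):
        ctx = (i - 1) * (h + r)  # earlier human messages and context copies
        total += sum(ctx + a + 1 for a in range(h))          # human message
        total += sum(ctx + h + a + 1 for a in range(t))      # reasoning: sees h_i
        total += sum(ctx + h + t + a + 1 for a in range(r))  # r_i^out: sees h_i, t_i
        total += sum(ctx + h + a + 1 for a in range(r))      # r_i^in: sees h_i only
    return total


for n in (1, 2, 4, 8, 16):
    a, b = n_pass_pairs(n, 100, 500, 200), one_pass_pairs(n, 100, 500, 200)
    print(f"N={n:2d}  N-pass={a:>11,}  1-pass={b:>11,}  ratio={a / b:.2f}")
```

With a single turn the ratio is below 1, since duplication only adds work. As N grows, the N-pass total grows cubically because every pass re-processes all earlier turns, while the single-pass total grows quadratically, so the ratio rises roughly linearly with the number of turns.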

## Key Contributions

1. A theoretical framework for single-pass training on multi-turn reasoning conversations
2. The MathChatSync Reasoning dataset, the first public multi-turn reasoning dataset with explicit per-turn reasoning
3. Empirical validation on Qwen-3 models (4B, 8B, 32B) fine-tuned with QLoRA