title: One-Pass to Reason: Token Duplication and Block-Sparse Mask for Efficient Fine-Tuning on Multi-Turn Reasoning
authors: Ritesh Goru, Shanay Mehta, Prateek Jain
venue: ICML 2025 Workshop: 3rd Workshop on Efficient Systems for Foundational Models
year: 2025
arxiv: 2504.18246
code: https://github.com/devrev/One-Pass-to-Reason
dataset: https://huggingface.co/datasets/devrev-research/MathChatSync-reasoning
type: paper
tags: efficient-fine-tuning, multi-turn-reasoning, attention-mask, token-duplication, single-pass-training

Abstract

Fine-tuning Large Language Models (LLMs) on multi-turn reasoning datasets requires N (number of turns) separate forward passes per conversation due to reasoning token visibility constraints, as reasoning tokens for a turn are discarded in subsequent turns. We propose duplicating response tokens along with a custom attention mask to enable single-pass processing of entire conversations. We prove our method produces identical losses to the N-pass approach while reducing time complexity from O(N³) to O(N²) and maintaining the same memory complexity for a transformer based model. Our approach achieves significant training speedup while preserving accuracy. Our implementation is available online.

Core Problem

Reasoning models (e.g., DeepSeek-R1) generate internal reasoning tokens, produce a response, and then discard the reasoning tokens from context in subsequent turns. This creates:

  1. Visibility Constraints: Reasoning tokens must be visible during generation but hidden from subsequent turns; a single static attention mask over the un-duplicated token sequence cannot satisfy both requirements
  2. Position ID Discrepancy: Response tokens follow reasoning tokens during generation but directly follow human messages in later context
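
The two constraints above can be seen by building the N-pass training samples explicitly. The following is a minimal sketch, not the authors' code: it assumes each turn is a (human, reasoning, response) triple of token lists, and the function name `n_pass_samples` and the loss-mask convention (loss on reasoning and response tokens only) are illustrative assumptions.

```python
from typing import List, Tuple

# Hypothetical representation: one turn = (human, reasoning, response) tokens.
Turn = Tuple[List[str], List[str], List[str]]


def n_pass_samples(conv: List[Turn]) -> List[Tuple[List[str], List[bool]]]:
    """Build one training sample per turn (the N-pass baseline).

    Sample i contains the earlier turns WITHOUT their reasoning tokens,
    followed by the full current turn (human, reasoning, response).
    The boolean list marks the tokens assumed to contribute to the loss.
    """
    samples = []
    context: List[str] = []  # earlier turns, reasoning already discarded
    for human, reasoning, response in conv:
        tokens = context + human + reasoning + response
        loss_mask = [False] * (len(context) + len(human)) + [True] * (
            len(reasoning) + len(response)
        )
        samples.append((tokens, loss_mask))
        # Later turns see only the human message and the response.
        context = context + human + response
    return samples


conv = [(["h1"], ["t1"], ["r1"]), (["h2"], ["t2"], ["r2"])]
for tokens, loss_mask in n_pass_samples(conv):
    print(tokens, loss_mask)
```

In this toy conversation, `r1` sits at index 2 in the first sample (after `t1`) but at index 1 in the second (directly after `h1`), which is the position ID discrepancy described above.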

Method

  1. Token Duplication: Duplicate each turn's response tokens into two copies: r_i^in (context copy), which does not attend to that turn's reasoning tokens, and r_i^out (generation copy), which does
  2. Custom Block-Sparse Attention Mask: Single mask with visibility rules per token type
  3. Strategic Position ID Assignment: Maintains correct relative positions equivalent to N-pass
  4. Theorem 2.1: Proves loss equivalence L_N-Pass(c) = L_1-Pass(c)
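
To make the steps above concrete, here is a small sketch of how a duplicated sequence, its position IDs, and its attention mask could be built. It is an illustration under stated assumptions rather than the paper's implementation: the per-turn layout (human, reasoning, generation copy, context copy), the names `Tok`, `build_layout`, and `build_mask`, and the dense boolean mask are all hypothetical choices; the visibility rules are derived from the summary above (context copies never see reasoning, and later turns see only human tokens and context copies).

```python
from typing import List, NamedTuple, Tuple


class Tok(NamedTuple):
    kind: str  # "h" human, "t" reasoning, "out" generation copy, "in" context copy
    turn: int  # turn index
    pos: int   # position ID, chosen to match the N-pass numbering


# Kinds a query may attend to within its own turn (assumed rules).
VISIBLE_SAME_TURN = {
    "h": {"h"},
    "t": {"h", "t"},
    "out": {"h", "t", "out"},  # generation copy sees the reasoning
    "in": {"h", "in"},         # context copy does not
}


def build_layout(lengths: List[Tuple[int, int, int]]) -> List[Tok]:
    """lengths[i] = (n_human, n_reasoning, n_response) for turn i."""
    toks: List[Tok] = []
    start = 0  # position of the first human token of the current turn
    for turn, (n_h, n_t, n_r) in enumerate(lengths):
        toks += [Tok("h", turn, start + k) for k in range(n_h)]
        toks += [Tok("t", turn, start + n_h + k) for k in range(n_t)]
        toks += [Tok("out", turn, start + n_h + n_t + k) for k in range(n_r)]
        # The context copy restarts right after the human message.
        toks += [Tok("in", turn, start + n_h + k) for k in range(n_r)]
        start += n_h + n_r  # later turns never count reasoning tokens
    return toks


def build_mask(toks: List[Tok]) -> List[List[bool]]:
    """mask[q][k] is True when query token q may attend to key token k."""
    n = len(toks)
    mask = [[False] * n for _ in range(n)]
    for qi, q in enumerate(toks):
        for ki in range(qi + 1):  # causal: keys at or before the query
            k = toks[ki]
            if k.turn < q.turn:
                # Earlier turns expose only human tokens and context copies.
                mask[qi][ki] = k.kind in ("h", "in")
            else:
                mask[qi][ki] = k.kind in VISIBLE_SAME_TURN[q.kind]
    return mask


toks = build_layout([(1, 1, 1), (1, 1, 1)])
mask = build_mask(toks)
print([t.pos for t in toks])
```

A production version would presumably express these rules as a block-sparse mask for a fused attention kernel instead of materializing a dense matrix; the dense form is used here only because it is easy to inspect.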

Results

  • 1.05× to 1.22× faster than FlashAttention-2 N-Pass with packing (Qwen-3 4B, 8B, 32B)
  • 1.44× to 1.54× faster than FlexAttention N-Pass with packing
  • ~33% more GPU memory
  • Speedups grow with conversation depth (O(N²) vs O(N³) theoretical advantage)
  • K-Pass variant allows speed-memory trade-off
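
The O(N²) versus O(N³) scaling can be checked with a rough count of attention query-key pairs. This is a back-of-the-envelope sketch, not a measurement: it assumes every turn has the same human, reasoning, and response lengths, treats attention as dense (ignoring causal and block sparsity), and ignores all non-attention costs. The function names are hypothetical.

```python
def n_pass_cost(n_turns: int, h: int, t: int, r: int) -> int:
    """Dense query-key pairs summed over N separate passes.

    Pass i holds i earlier (human + response) turns plus the full
    current turn, so its length grows linearly with i.
    """
    return sum((i * (h + r) + h + t + r) ** 2 for i in range(n_turns))


def one_pass_cost(n_turns: int, h: int, t: int, r: int) -> int:
    """Dense query-key pairs for one pass with duplicated responses."""
    return (n_turns * (h + t + 2 * r)) ** 2


for n in (10, 100, 200):
    print(n, n_pass_cost(n, 1, 1, 1), one_pass_cost(n, 1, 1, 1))
```

Doubling the number of turns multiplies the one-pass count by 4 but the N-pass count by roughly 8, consistent with quadratic versus cubic growth. At small N the one-pass count can be larger because of the duplicated tokens, which fits the modest measured speedups and the extra memory reported above.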

Key Contributions

  1. Theoretical framework for single-pass multi-turn reasoning training
  2. MathChatSync Reasoning dataset (first public multi-turn reasoning dataset with explicit per-turn reasoning)
  3. Comprehensive empirical validation on Qwen-3 models using QLoRA