Files
myWiki/raw/papers/gu-mamba-2024.md

95 lines
4.6 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

---
title: "Mamba: Linear-Time Sequence Modeling with Selective State Spaces"
authors: ["Albert Gu", "Tri Dao"]
date: 2023-12-01
arxiv_id: "2312.00752v2"
categories: ["cs.LG", "cs.AI"]
affiliations: ["Carnegie Mellon University", "Princeton University"]
paper_type: "conference"
code: "https://github.com/state-spaces/mamba"
---
# Mamba: Linear-Time Sequence Modeling with Selective State Spaces
## 摘要
Foundation model 几乎全部基于 Transformer 架构,但其注意力机制的二次复杂度在处理长序列时效率极低。各种次二次复杂度架构(线性注意力、门控卷积、结构状态空间模型)试图取代注意力,但在语言等核心模态上始终达不到 Transformer 质量。本文识别出这些模型的根本弱点——**缺乏内容感知推理能力content-based reasoning**——并通过两个关键创新解决:(1) 让 SSM 参数成为输入的函数选择机制S6使模型能根据当前 token 选择性传播或遗忘信息;(2) 设计硬件感知的并行算法,在循环模式下高效计算。最终形成极简架构 Mamba——无注意力甚至无 MLP 块。Mamba 推理吞吐量是 Transformer 的 5 倍,序列长度线性扩展,在语言、音频、基因组学等多个模态达到 SOTA。Mamba-3B 性能超过同规模 Transformer 并匹敌两倍规模的 Transformer。
## 核心贡献
1. **选择机制Selection Mechanism / S6**:将 SSM 参数(Δ, B, C变为输入依赖从时间不变LTI升级为时间变化
2. **硬件感知算法**通过并行关联扫描parallel associative scan在 SRAM 中计算,避免 GPU HBM 之间的 IO 瓶颈
3. **极简架构 Mamba**:将 H3 架构中的 SSM 层与 MLP 门控融合为单一同质块
4. **选择复制Selective Copying和归纳头Induction Heads合成任务**Mamba 不仅轻松解决,且能无限外推(>1M tokens
## 方法框架
### 从 S4 到 S6
传统 S4 的关键局限是 **线性时间不变性LTI**:参数 (Δ, A, B, C) 对所有时间步固定。这意味着状态更新规则不随输入内容改变——模型无法"选择性"关注或忽略特定 token。
Mamba 的选择机制S6将 B, C, Δ 变为输入 x 的函数:
```
B_t = s_B(x_t) # 输入 → 输入投影
C_t = s_C(x_t) # 输入 → 输出投影
Δ_t = τ_Δ(Δ + s_Δ(x_t)) # 输入依赖的步长
```
核心差异:
| 特性 | S4 (LTI) | S6 (Selective) |
|------|---------|---------------|
| 参数 | 时间不变 | 时间变化(输入依赖) |
| 计算模式 | 卷积 OR 循环 | 仅循环(需 scan |
| 选择性 | 无 | 有(过滤/保留) |
| 内容感知 | 否 | 是 |
### 硬件感知并行 Scan
选择机制消除了卷积等价性——模型必须是时间变化的无法用卷积并行计算。Mamba 通过**并行关联扫描parallel associative scan / Blelloch scan**解决:
1. 将状态更新展开为前缀和操作
2. 在 GPU SRAM 中做 kernel fusion避免将扩展状态写入 HBM
3. 输入在 HBM → 加载到 SRAM → scan + 离散化 → 写回 HBM
结果:比所有基于卷积的 SSM 快 3×A100 GPU
### Mamba 架构
```
Input → Mamba Block → ... (×L) → Output
Mamba Block:
x → LayerNorm → [Linear(expand) → Conv1d → SiLU → SSM(S6)] → LayerNorm → Linear → + (residual)
```
关键设计:
- **无注意力、无 MLP**:用选择性 SSM 取代二者
- **扩展因子 E=2**Linear 将 d_model 扩展到 2× 再投影回
- **残差连接 + SiLU 激活**
- **H3 简化**:将 H3 的两个门控 SSM 融合为单一选择性 SSM
## 实验结果
- **合成任务**Selective Copying 和 Induction Heads → Mamba 可以泛化到 >1M token 序列
- **语言建模**Mamba-3B 在 pretraining perplexity 和 0-shot 评估上超过 Pythia-3B匹敌 Pythia-7B5× 推理吞吐
- **音频**:在 SC09 语音生成上将 FID 降低一半以上
- **基因组学**:在 DNA 建模上超过 HyenaDNA 和 Transformer
## 关键概念
- [[selective-state-space]] — S6 选择机制,输入依赖的 SSM 参数化
- [[hardware-aware-algorithm]] — GPU 层次优化的并行 scan
- [[structured-state-space-models]] — S4 前身HiPPO 矩阵 + 对角结构
- [[selective-copy]] — 需要内容感知的选择性复制任务
- [[induction-heads]] — 解释 LLM in-context learning 能力的机制
- [[hippo]] — SSM 的数学基础High-order Polynomial Projection Operators
- [[content-based-reasoning]] — Mamba 识别并解决的核心弱点
## 参考
- 代码https://github.com/state-spaces/mamba
- S4 (Gu et al., 2022)
- H3 (Dao et al., 2023)
- 选择复制任务 (Arjovsky et al., 2016)
- 归纳头 (Olsson et al., 2022)