Files
myWiki/concepts/shared-weight-discretization.md

2.2 KiB
Raw Permalink Blame History

title, created, updated, type, tags, sources
title created updated type tags sources
Shared-Weight Discretization 2026-05-13 2026-05-13 concept
network-architecture
diffusion-language-model
multi-task-learning
https://arxiv.org/abs/2605.10938

Shared-Weight Discretization

共享权重离散化是 embedded-language-flows 的核心设计:同一个网络既执行去噪又执行解码,区别仅在于输入条件和输出处理

机制

网络签名为 net_θ(z, t, mode),其中:

  • z:当前(带噪)嵌入
  • t:时间步 ∈ [0,1]
  • mode:二进制 tokendenoisedecode

Denoise Mode (t < 1)

x̂ = net_θ(z_t, t, "denoise")
v̂ = (x̂ - z_t) / (1-t)       # 转换 x-prediction 为速度
L = MSE(v̂, v_true)

Decode Mode (t = 1)

# 先对 z 加 token 级 corruption 构造非平凡输入
z̃ = corrupt(z_1)
x̂ = net_θ(z̃, t=1, "decode")
logits = W · x̂               # unembedding 层
L = CrossEntropy(logits, s)  # s 是真实 token

为什么共享权重有效

x-prediction-parameterization 是关键:网络始终预测干净嵌入 x̂。在 denoise mode 中它转换为速度;在 decode mode 中它直接经 unembedding 转为 logits。两种模式共享网络权重因为它们在语义上一致——都试图恢复干净的 token 表示。

v-prediction 无法做到这一点:预测速度 v 与预测离散 token 之间的语义鸿沟使得权重共享不可行ELF 论文中实验证实)。

优势

  1. 零额外参数:不需要单独训练的 decoder与 LD4LG 等潜在扩散方法对比)
  2. 训练效率:两种模式在一个 batch 中通过 masking 同时训练,无额外计算开销
  3. 语义对齐:去噪目标(恢复干净嵌入)和解码目标(恢复干净 token共享底层表示

实现细节

训练时两分支按比例混合ELF 默认 80% denoise + 20% decode。推理时

  1. t < 1使用 denoise mode迭代更新嵌入
  2. t = 1使用 decode modeargmax 输出离散 token

相关概念