DreamZero (WAM) 模型架构

1. 整体架构概览

DreamZero 是一个世界动作模型 (World Action Model, WAM),通过联合预测未来视频帧与机器人动作序列,实现对未见任务的零样本泛化。其核心创新在于:基于 Wan2.1 视频扩散模型构建 Causal WAN DiT,以 Flow Matching 框架在视频 latent 空间和动作空间上同步去噪,配合多机器人体 Category-specific MLP 支持 DROID、AgiBot、YAM 等多种机器人平台,在 MolmoSpaces 和 RoboArena 双榜均位列第一(截至 2026 年 2 月)。

graph TB subgraph Input["输入"] IMG["历史视频帧
[B, 1, C, H, W]
320×176"] TXT["语言指令
(text prompt)"] EID["机器人体 ID
(embodiment_id)"] end subgraph Encoders["感知编码器"] CLIP["CLIP ViT 图像编码器
open-clip-xlm-roberta-large-vit-huge-14
→ 1536-dim 特征"] T5["umt5-xxl 文本编码器
→ 1536-dim 特征"] end subgraph VAE["视频 VAE (Wan2.1)"] ENC["VAE 编码器
观测帧像素 → 观测 latent
(仅压缩,无生成能力)"] DEC["VAE 解码器
预测 latent → 像素帧
(仅解压,可视化用)"] end subgraph Core["Causal WAN DiT (核心,真正的生成模型)"] DIT["40层因果扩散 Transformer
dim=5120, 40 heads
Flow Matching + RoPE + Flash Attn 3"] ACT_MLP["多机器人体 MLP
Category-specific Linear"] NOISE["纯高斯噪声
(未来帧初始状态)"] end subgraph Output["输出"] ACT_OUT["动作序列
[B, 24, action_dim]"] VID_OUT["可视化视频帧
[B, 33, H, W, 3]
(推理可选,不影响控制)"] end IMG --> CLIP TXT --> T5 IMG --> ENC CLIP -->|"视觉特征(条件)"| DIT T5 -->|"语言特征(条件)"| DIT ENC -->|"观测 latent(条件,填KV Cache)"| DIT EID --> ACT_MLP --> DIT NOISE -->|"去噪起点"| DIT DIT -->|"动作预测"| ACT_OUT DIT -->|"预测 latent(去噪生成)"| DEC --> VID_OUT style Input fill:#e8f4fd,stroke:#2196F3 style Encoders fill:#fff3e0,stroke:#FF9800 style VAE fill:#f3e5f5,stroke:#9C27B0 style Core fill:#e8f5e9,stroke:#4CAF50 style Output fill:#fce4ec,stroke:#E91E63

2. 核心组件详解

2.1 感知编码器

DreamZero 使用两个独立的预训练编码器将视觉和语言信息映射到统一的 1536-dim 特征空间:

graph LR subgraph ImagePath["图像编码路径"] VF["视频帧
[B, T_img, C, H, W]"] --> CLIP_ENC["CLIP ViT 编码
wan_video_image_encoder.py"] CLIP_ENC --> CLIP_FEAT["图像特征
[B, T_img, N_patch, 1536]"] end subgraph TextPath["文本编码路径"] PROMPT["语言提示
(string)"] --> T5_TOK["T5 分词器"] T5_TOK --> T5_ENC["umt5-xxl 编码
wan_video_text_encoder.py"] T5_ENC --> TXT_FEAT["文本特征
[B, T_txt, 1536]"] end subgraph Fusion["特征融合 → DiT"] CLIP_FEAT --> CAT["拼接为条件序列"] TXT_FEAT --> CAT CAT --> DIT_IN["输入 Causal WAN DiT"] end style ImagePath fill:#e3f2fd,stroke:#2196F3 style TextPath fill:#fff3e0,stroke:#FF9800 style Fusion fill:#e8f5e9,stroke:#4CAF50

2.2 视频 VAE

沿用 Wan2.1 的视频 VAE,作为压缩/解压工具(无生成能力),让 DiT 在低维 latent 空间工作以降低计算量。VAE 编码器和解码器服务于两个完全不同的目的,输入输出的 latent 含义不同。

graph LR subgraph Encode["VAE 编码器(输入侧)"] V_IN["观测帧像素
[B, T, H, W, 3]"] --> VAE_E["VAE 编码器
wan_video_vae.py"] VAE_E --> LAT["观测 latent(条件)
[B, 16, T, h, w]
用途:填入 KV Cache,作为 DiT 生成的历史条件"] end subgraph Decode["VAE 解码器(输出侧)"] LAT2["预测 latent(生成目标)
[B, 16, T, h, w]
DiT 从纯噪声去噪得到"] --> VAE_D["VAE 解码器"] VAE_D --> V_OUT["预测视频帧像素
[B, T, H, W, 3]
(仅用于可视化,不影响机器人控制)"] end subgraph Params["关键参数"] P1["latent 通道数: 16"] P2["时域压缩: 4×"] P3["空域压缩: 8×"] end style Encode fill:#e3f2fd,stroke:#2196F3 style Decode fill:#e8f5e9,stroke:#4CAF50 style Params fill:#fff9c4,stroke:#FFC107

2.3 Causal WAN DiT(核心扩散模型)

Causal WAN DiT 是 DreamZero 的核心,基于 Wan2.1 的 DiT 架构改造而来:40 层因果 Transformer,通过 Causal Masking 保证时序因果性,同时在 latent 序列中嵌入动作 token,实现视频与动作的联合生成。

graph TB subgraph InputSeq["输入序列构建"] direction LR VL["含噪预测 latent
[B, T_vid, 16]
(去噪目标,从纯噪声出发)"] AC["含噪动作
[B, 24, action_dim]"] VIS["视觉条件
[B, T_img, 1536]"] LNG["语言条件
[B, T_txt, 1536]"] TS["时间步 t"] VL --> ROPE["RoPE 位置编码"] AC --> ACT_PROJ["动作投影 MLP"] TS --> TIME_EMB["时间步嵌入
(正弦编码)"] end subgraph Block["单层 Causal DiT Block (×40)"] direction TB IN_B["输入 [B, T, 5120]"] --> LN1_B["RMSNorm"] LN1_B --> MOD1["Modulation
(shift, scale by timestep emb)"] MOD1 --> CATTN["因果自注意力
40 heads, Flash Attn 3
Causal Mask"] CATTN --> ADD1["+ 残差"] IN_B --> ADD1 ADD1 --> LN2_B["RMSNorm"] LN2_B --> XATTN["交叉注意力
(视觉 + 语言条件)"] XATTN --> ADD2["+ 残差"] ADD1 --> ADD2 ADD2 --> LN3_B["RMSNorm"] LN3_B --> FFN_B["FFN (SwiGLU)
dim 5120 → 13824 → 5120"] FFN_B --> ADD3["+ 残差"] ADD2 --> ADD3 ADD3 --> OUT_B["输出 [B, T, 5120]"] end subgraph Output["输出解码"] OUT_B --> SPLIT["分离 video / action token"] SPLIT --> VP["视频 latent 预测"] SPLIT --> AP["动作预测
[B, 24, action_dim]"] end ROPE --> Block ACT_PROJ --> Block TIME_EMB --> Block VIS --> Block LNG --> Block style InputSeq fill:#e3f2fd,stroke:#2196F3 style Block fill:#e8f5e9,stroke:#4CAF50 style Output fill:#fce4ec,stroke:#E91E63

因果注意力掩码设计: 视频帧 token 只能 attend 到过去帧和当前帧,动作 token attend 到所有视频 token(全局条件),保证生成的时序一致性。

graph LR subgraph Mask["Causal Attention Mask 示意"] direction LR F0["帧 0 token"] -->|"✓ attend"| F0S["自身"] F1["帧 1 token"] -->|"✓ attend"| F0_1["帧 0"] F1 -->|"✓ attend"| F1S["自身"] F1 -->|"✗ 不 attend"| F2_BLK["帧 2+ (未来)"] AT["动作 token"] -->|"✓ attend"| ALL["所有帧 token"] end style Mask fill:#fff9c4,stroke:#FFC107

2.4 多机器人体支持

DreamZero 通过每个机器人体独立的线性层(Category-specific MLP)将不同维度的动作空间映射到统一的 DiT 隐空间,支持在单一模型中处理多种机器人平台。

graph TB subgraph Embodiments["多机器人体动作编码"] direction TB D_ACT["DROID 动作
[B, 24, dim_droid]"] A_ACT["AgiBot 动作
[B, 24, dim_agibot]"] Y_ACT["YAM 动作
[B, 24, dim_yam]"] O_ACT["其他机器人体..."] D_ACT --> D_MLP["DROID Linear
dim_droid → 5120"] A_ACT --> A_MLP["AgiBot Linear
dim_agibot → 5120"] Y_ACT --> Y_MLP["YAM Linear
dim_yam → 5120"] O_ACT --> O_MLP["... Linear"] end subgraph Shared["共享 DiT 处理"] D_MLP --> UNIFIED["统一动作 token
[B, 24, 5120]"] A_MLP --> UNIFIED Y_MLP --> UNIFIED O_MLP --> UNIFIED UNIFIED --> DIT_PROC["Causal WAN DiT
(共享权重)"] end subgraph Decode["多机器人体动作解码"] DIT_PROC --> SPLIT_E["按 embodiment_id 路由"] SPLIT_E --> D_DEC["DROID Decoder Linear
5120 → dim_droid"] SPLIT_E --> A_DEC["AgiBot Decoder Linear
5120 → dim_agibot"] SPLIT_E --> Y_DEC["YAM Decoder Linear
5120 → dim_yam"] end style Embodiments fill:#e3f2fd,stroke:#2196F3 style Shared fill:#e8f5e9,stroke:#4CAF50 style Decode fill:#fce4ec,stroke:#E91E63

相对动作计算: 各机器人体在数据加载时将绝对动作转为相对值(action - reference_state),使模型学到的动作表示更加泛化。


3. 训练流水线

DreamZero 基于 Flow Matching 框架进行训练,同时对视频 latent 和动作 token 施加扩散损失。训练使用 DeepSpeed ZeRO-2 分布式优化,并通过 LoRA 高效微调 14B 参数基础模型。

graph TB subgraph DataLoad["数据加载 (LeRobot 格式)"] direction LR PARQ["LeRobot parquet 文件"] --> DS["lerobot.py / lerobot_sharded.py"] DS --> MOD["模态变换
· 视频 resize → 320×176
· 状态/动作归一化
· 相对动作计算
· 语言编码"] MOD --> BATCH_OUT["训练 Batch"] end subgraph Encode["编码阶段"] BATCH_OUT --> IMG_E["CLIP 图像编码
→ 视觉特征"] BATCH_OUT --> TXT_E["T5 文本编码
→ 语言特征"] BATCH_OUT --> VAE_E2["VAE 编码
视频 → latent"] end subgraph FlowMatch["Flow Matching 加噪"] T_SAMP["时间步采样
t ~ Uniform[0, 1000]"] NOISE2["高斯噪声 ε ~ N(0, I)"] ACT_GT["真实动作 / 视频 latent"] T_SAMP --> INTERP2["线性插值
x_t = ε·t/T + x_0·(1 - t/T)"] NOISE2 --> INTERP2 ACT_GT --> INTERP2 end subgraph Forward["前向传播"] IMG_E --> DIT_FWD["Causal WAN DiT"] TXT_E --> DIT_FWD VAE_E2 --> DIT_FWD INTERP2 --> DIT_FWD DIT_FWD --> PRED_ALL["联合预测
视频 latent + 动作"] end subgraph Loss["损失计算"] PRED_ALL --> L_VID["视频 Flow Matching Loss
(MSE on video latent)"] PRED_ALL --> L_ACT["动作 Flow Matching Loss
(MSE on action)"] L_VID --> L_TOTAL["总损失
L = λ_vid · L_vid + λ_act · L_act"] L_ACT --> L_TOTAL end subgraph Optim["优化"] L_TOTAL --> BACK["反向传播"] BACK --> LORA["LoRA 梯度更新
rank=4, alpha=4"] LORA --> OPT["AdamW
β₁=0.95, β₂=0.999
Cosine LR + Warmup"] OPT --> DS_ZERO["DeepSpeed ZeRO-2
梯度累积"] end DataLoad --> Encode Encode --> FlowMatch FlowMatch --> Forward Forward --> Loss Loss --> Optim style DataLoad fill:#e3f2fd,stroke:#2196F3 style Encode fill:#fff3e0,stroke:#FF9800 style FlowMatch fill:#f3e5f5,stroke:#9C27B0 style Forward fill:#e8f5e9,stroke:#4CAF50 style Loss fill:#fce4ec,stroke:#E91E63 style Optim fill:#e0f7fa,stroke:#00BCD4

LoRA 配置: 基础模型(Wan2.1,约 14B 参数)完全冻结,仅在 DiT 的注意力层插入 LoRA 适配器(rank=4, alpha=4),大幅降低显存占用。


4. 推理流水线

推理时模型从纯随机噪声出发,通过 Flow Matching 迭代去噪恢复动作序列和视频帧。DreamZero 支持通过 WebSocket 的分布式多 GPU 推理服务。

graph TB subgraph Init["初始化"] OBS["当前观测图像
[B, 1, C, H, W]"] LANG_IN["语言指令"] X_NOISE["纯随机噪声
x_T ~ N(0, I)
(视频 latent + 动作)"] end subgraph EncoderOnce["编码(仅执行一次)"] OBS --> CLIP_INF["CLIP 编码"] LANG_IN --> T5_INF["T5 编码"] CLIP_INF --> COND["条件特征缓存"] T5_INF --> COND end subgraph DenoiseLoop["迭代去噪循环 (N 步)"] direction TB S_T["步骤 T: x_T → DiT → x̂₀"] S_T1["步骤 T-1: x_{T-1} = scheduler(x̂₀) → DiT → x̂₀"] SDOTS["..."] S_1["步骤 1: x₁ → DiT → x̂₀"] S_T --> S_T1 --> SDOTS --> S_1 end subgraph PostProcess["后处理"] S_1 --> SPLIT_OUT["分离 video latent / action"] SPLIT_OUT --> VAE_DEC2["VAE 解码 → 视频帧"] SPLIT_OUT --> ACT_POST["动作后处理
(反归一化, 相对→绝对)"] ACT_POST --> ACT_EXEC["执行前 N 步动作
(Action Chunking)"] end Init --> EncoderOnce COND --> DenoiseLoop X_NOISE --> DenoiseLoop DenoiseLoop --> PostProcess style Init fill:#e3f2fd,stroke:#2196F3 style EncoderOnce fill:#fff3e0,stroke:#FF9800 style DenoiseLoop fill:#e8f5e9,stroke:#4CAF50 style PostProcess fill:#fce4ec,stroke:#E91E63

分布式推理服务

graph LR subgraph Client["客户端 (test_client_AR.py)"] CLI["机器人控制程序"] --> WS_CLI["WebSocket 客户端
发送观测 + 指令"] end subgraph Server["推理服务器 (socket_test_optimized_AR.py)"] WS_SRV["WebSocket 服务器
(Flask-SocketIO)"] REDIS["Redis
会话状态管理"] RAY_W["Ray Worker Pool
多 GPU 并行推理"] MODEL["DreamZero 模型
(GB200 / H100)"] WS_SRV --> REDIS WS_SRV --> RAY_W RAY_W --> MODEL end WS_CLI -->|"观测数据"| WS_SRV MODEL -->|"动作预测"| WS_SRV WS_SRV -->|"动作序列"| WS_CLI WS_CLI -->|"执行动作"| CLI style Client fill:#e3f2fd,stroke:#2196F3 style Server fill:#e8f5e9,stroke:#4CAF50

5. 关键超参数表

参数 说明
模型总参数 ~14B 基于 Wan2.1,LoRA 微调
DiT 层数 40 Causal WAN DiT
DiT hidden dim 5120 每层隐藏层维度
注意力头数 40 head_dim = 128
Embedding dim 1536 图像/文本编码器输出维度
VAE latent 通道 16 视频压缩维度
Action horizon 24 单次预测动作步数
视频帧数 33 输入/预测帧数
图像分辨率 320×176 训练分辨率
LoRA rank / alpha 4 / 4 参数高效微调配置
批归一化策略 DeepSpeed ZeRO-2 分布式训练
优化器 AdamW β₁=0.95, β₂=0.999
学习率调度 Cosine + Warmup 分布式多卡训练
精度 bfloat16 混合精度训练

6. 关键源文件表

组件 文件路径(相对 /home/zhuyilong/dreamzero/
核心 VLA 模型 groot/vla/model/dreamzero/base_vla.py
Action Head (Flow Matching) groot/vla/model/dreamzero/action_head/wan_flow_matching_action_tf.py
Causal WAN DiT groot/vla/model/dreamzero/modules/wan_video_dit_action_casual_chunk.py
视频 VAE groot/vla/model/dreamzero/modules/wan_video_vae.py
文本编码器 groot/vla/model/dreamzero/modules/wan_video_text_encoder.py
图像编码器 groot/vla/model/dreamzero/modules/wan_video_image_encoder.py
数据变换 groot/vla/model/dreamzero/transform/
数据集加载 groot/vla/data/dataset/lerobot.py
分片数据集 groot/vla/data/dataset/lerobot_sharded.py
训练器 groot/vla/experiment/experiment.py
训练基类 groot/vla/experiment/base.py
分布式推理服务 socket_test_optimized_AR.py
推理客户端 test_client_AR.py
Hydra 模型配置 groot/vla/configs/model/
训练脚本 scripts/train/

7. Attention 机制深入解析

7.1 Q / K / V 的本质分工

Attention 的计算分四步:

1. q = W_q(x)              # x 投影成 Q
2. k = W_k(context)        # context 投影成 K
3. v = W_v(context)        # context 投影成 V(同一原材料,不同矩阵)
4. 权重 = softmax(q · kᵀ / √d)
5. 输出 = 权重 · v

K 和 V 的原材料相同(都来自 context),但投影矩阵不同,作用完全不同:

矩阵 参与哪一步 对输出的影响 学到的是
W_q 步骤 4(点积) 间接(通过权重) 我该去找什么样的 K
W_k 步骤 4(点积) 间接(通过权重) 什么样的 Q 应该找到我
W_v 步骤 5(加权求和) 直接(构成输出) 被选中后我该提供什么内容

K 是抽象意义上的索引——它不直接出现在输出里,只决定权重分布,告诉模型"哪些位置比较重要"。V 才是真正被读取的内容。

graph LR subgraph Input["输入"] X["x(视频 token)"] CTX["context(文本/图像特征)"] end subgraph Proj["投影"] X --> WQ["W_q"] --> Q["Q\n寻址用"] CTX --> WK["W_k"] --> K["K\n索引(被动等待匹配)"] CTX --> WV["W_v"] --> V["V\n内容(直接构成输出)"] end subgraph Attn["Attention 计算"] Q --> DOT["q · kᵀ / √d\n点积"] K --> DOT DOT --> SM["softmax\n权重分布"] SM --> WV2["权重 · V\n加权求和"] V --> WV2 WV2 --> OUT["输出\n[B, T_x, dim]"] end style Input fill:#e3f2fd,stroke:#2196F3 style Proj fill:#fff3e0,stroke:#FF9800 style Attn fill:#e8f5e9,stroke:#4CAF50

7.2 Self-Attention vs Cross-Attention

区别只有一个:Q、K、V 的原材料从哪来。

Q 来自 K、V 来自 用途
Self-Attention 序列自身 x 序列自身 x 序列内部 token 互相交流
Cross-Attention 序列自身 x 另一个序列 context 从外部条件(文本/图像)读取信息

DreamZero 的每个 DiT Block 里两者都有:先 Self-Attention 让视频 token 内部交流,再 Cross-Attention 读取文本和图像条件。

graph LR subgraph ImagePath["图像编码路径"] VF["视频帧
[B, T_img, C, H, W]"] --> CLIP_ENC["CLIP ViT 编码
wan_video_image_encoder.py"] CLIP_ENC --> CLIP_FEAT["图像特征
[B, T_img, N_patch, 1536]"] end subgraph TextPath["文本编码路径"] PROMPT["语言提示
(string)"] --> T5_TOK["T5 分词器"] T5_TOK --> T5_ENC["umt5-xxl 编码
wan_video_text_encoder.py"] T5_ENC --> TXT_FEAT["文本特征
[B, T_txt, 1536]"] end subgraph Fusion["特征融合 → DiT"] CLIP_FEAT --> CAT["拼接为条件序列"] TXT_FEAT --> CAT CAT --> DIT_IN["输入 Causal WAN DiT"] end style ImagePath fill:#e3f2fd,stroke:#2196F3 style TextPath fill:#fff3e0,stroke:#FF9800 style Fusion fill:#e8f5e9,stroke:#4CAF50
0


7.3 为什么 K 不能等于 Q

K=Q 意味着 W_k = W_q,K 和 Q 是同一个向量。此时:

权重 = softmax(q · qᵀ / √d)   # x 和自身的相似度

权重退化成"找和自己像的",模型只能检索与自身相似的 token,无法找到"内容上互补但方向不同"的信息。

类比:图书馆里每本书的索引标签直接写的是"我想找机器人抓取的书"——所有书都在大喊需求,没有一本书在说"我能提供什么",检索系统完全失效。

K 和 Q 必须用不同的投影矩阵,让"被检索"和"去检索"解耦。


7.4 多头注意力(Multi-Head Attention)

多头就是把 dim 切成 h 份,并行跑 h 套独立的 Q/K/V:

head_i 输出 = softmax(Q_i · K_iᵀ / √d) · V_i
最终输出 = Concat(head_0, ..., head_h-1) · W_o

多头的价值来源于 Q 的多样性——每个头问不同的问题,从 context 里提取不同侧面的信息。

graph LR subgraph ImagePath["图像编码路径"] VF["视频帧
[B, T_img, C, H, W]"] --> CLIP_ENC["CLIP ViT 编码
wan_video_image_encoder.py"] CLIP_ENC --> CLIP_FEAT["图像特征
[B, T_img, N_patch, 1536]"] end subgraph TextPath["文本编码路径"] PROMPT["语言提示
(string)"] --> T5_TOK["T5 分词器"] T5_TOK --> T5_ENC["umt5-xxl 编码
wan_video_text_encoder.py"] T5_ENC --> TXT_FEAT["文本特征
[B, T_txt, 1536]"] end subgraph Fusion["特征融合 → DiT"] CLIP_FEAT --> CAT["拼接为条件序列"] TXT_FEAT --> CAT CAT --> DIT_IN["输入 Causal WAN DiT"] end style ImagePath fill:#e3f2fd,stroke:#2196F3 style TextPath fill:#fff3e0,stroke:#FF9800 style Fusion fill:#e8f5e9,stroke:#4CAF50
1

Q 不能共享,K/V 可以共享(MQA):

多K多V 单K单V
多Q(问题不同) 标准多头 MHA MQA(现代大模型常用)
单Q(问题相同) 无意义(退化) 普通单头

MQA(Multi-Query Attention)中所有 Q 头共享同一组 K/V,节省 KV Cache 显存,但不影响 Q 的多样性——40 个人用不同问题检索同一个书架,仍然能找出不同的重点。


7.5 W_q、W_k、W_v 各自如何被训练

三个矩阵的梯度路径不同,学到的东西也不同。

graph LR subgraph ImagePath["图像编码路径"] VF["视频帧
[B, T_img, C, H, W]"] --> CLIP_ENC["CLIP ViT 编码
wan_video_image_encoder.py"] CLIP_ENC --> CLIP_FEAT["图像特征
[B, T_img, N_patch, 1536]"] end subgraph TextPath["文本编码路径"] PROMPT["语言提示
(string)"] --> T5_TOK["T5 分词器"] T5_TOK --> T5_ENC["umt5-xxl 编码
wan_video_text_encoder.py"] T5_ENC --> TXT_FEAT["文本特征
[B, T_txt, 1536]"] end subgraph Fusion["特征融合 → DiT"] CLIP_FEAT --> CAT["拼接为条件序列"] TXT_FEAT --> CAT CAT --> DIT_IN["输入 Causal WAN DiT"] end style ImagePath fill:#e3f2fd,stroke:#2196F3 style TextPath fill:#fff3e0,stroke:#FF9800 style Fusion fill:#e8f5e9,stroke:#4CAF50
2

W_q 和 W_k 通过同一条路径(点积)互相塑造,W_v 的梯度路径完全独立,只由输出质量驱动。


7.6 QK Norm:为什么归一化能稳定训练

原始点积:

q · k = |q| × |k| × cos(θ)

模长和方向混在一起,W_q 和 W_k 需要同时控制两件事,容易震荡。

QK Norm 之后:

q_norm = q / |q|,  k_norm = k / |k|
q_norm · k_norm = cos(θ)     # 值域固定在 [-1, 1]

点积退化为纯余弦相似度,只剩方向这一个自由度,模长的干扰被彻底消除。W_q 和 W_k 只需要各自学"朝哪个方向",优化目标从三个因素(|q|、|k|、θ)变成一个(θ),训练更稳定。

DreamZero 代码中对应实现(wan2_1_submodule.py):

q = self.norm_q(self.q(x))        # QK Norm
k = self.norm_k(self.k(context))  # QK Norm