DreamZero (WAM) 模型架构
1. 整体架构概览
DreamZero 是一个世界动作模型 (World Action Model, WAM),通过联合预测未来视频帧与机器人动作序列,实现对未见任务的零样本泛化。其核心创新在于:基于 Wan2.1 视频扩散模型构建 Causal WAN DiT,以 Flow Matching 框架在视频 latent 空间和动作空间上同步去噪,配合多机器人体 Category-specific MLP 支持 DROID、AgiBot、YAM 等多种机器人平台,在 MolmoSpaces 和 RoboArena 双榜均位列第一(截至 2026 年 2 月)。
[B, 1, C, H, W]
320×176"] TXT["语言指令
(text prompt)"] EID["机器人体 ID
(embodiment_id)"] end subgraph Encoders["感知编码器"] CLIP["CLIP ViT 图像编码器
open-clip-xlm-roberta-large-vit-huge-14
→ 1536-dim 特征"] T5["umt5-xxl 文本编码器
→ 1536-dim 特征"] end subgraph VAE["视频 VAE (Wan2.1)"] ENC["VAE 编码器
观测帧像素 → 观测 latent
(仅压缩,无生成能力)"] DEC["VAE 解码器
预测 latent → 像素帧
(仅解压,可视化用)"] end subgraph Core["Causal WAN DiT (核心,真正的生成模型)"] DIT["40层因果扩散 Transformer
dim=5120, 40 heads
Flow Matching + RoPE + Flash Attn 3"] ACT_MLP["多机器人体 MLP
Category-specific Linear"] NOISE["纯高斯噪声
(未来帧初始状态)"] end subgraph Output["输出"] ACT_OUT["动作序列
[B, 24, action_dim]"] VID_OUT["可视化视频帧
[B, 33, H, W, 3]
(推理可选,不影响控制)"] end IMG --> CLIP TXT --> T5 IMG --> ENC CLIP -->|"视觉特征(条件)"| DIT T5 -->|"语言特征(条件)"| DIT ENC -->|"观测 latent(条件,填KV Cache)"| DIT EID --> ACT_MLP --> DIT NOISE -->|"去噪起点"| DIT DIT -->|"动作预测"| ACT_OUT DIT -->|"预测 latent(去噪生成)"| DEC --> VID_OUT style Input fill:#e8f4fd,stroke:#2196F3 style Encoders fill:#fff3e0,stroke:#FF9800 style VAE fill:#f3e5f5,stroke:#9C27B0 style Core fill:#e8f5e9,stroke:#4CAF50 style Output fill:#fce4ec,stroke:#E91E63
2. 核心组件详解
2.1 感知编码器
DreamZero 使用两个独立的预训练编码器将视觉和语言信息映射到统一的 1536-dim 特征空间:
- 图像编码器:
open-clip-xlm-roberta-large-vit-huge-14(CLIP ViT-H/14),对输入视频帧逐帧提取视觉特征 - 文本编码器:
umt5-xxl(Google 多语言 T5),将自然语言指令编码为序列特征
[B, T_img, C, H, W]"] --> CLIP_ENC["CLIP ViT 编码
wan_video_image_encoder.py"] CLIP_ENC --> CLIP_FEAT["图像特征
[B, T_img, N_patch, 1536]"] end subgraph TextPath["文本编码路径"] PROMPT["语言提示
(string)"] --> T5_TOK["T5 分词器"] T5_TOK --> T5_ENC["umt5-xxl 编码
wan_video_text_encoder.py"] T5_ENC --> TXT_FEAT["文本特征
[B, T_txt, 1536]"] end subgraph Fusion["特征融合 → DiT"] CLIP_FEAT --> CAT["拼接为条件序列"] TXT_FEAT --> CAT CAT --> DIT_IN["输入 Causal WAN DiT"] end style ImagePath fill:#e3f2fd,stroke:#2196F3 style TextPath fill:#fff3e0,stroke:#FF9800 style Fusion fill:#e8f5e9,stroke:#4CAF50
2.2 视频 VAE
沿用 Wan2.1 的视频 VAE,作为压缩/解压工具(无生成能力),让 DiT 在低维 latent 空间工作以降低计算量。VAE 编码器和解码器服务于两个完全不同的目的,输入输出的 latent 含义不同。
[B, T, H, W, 3]"] --> VAE_E["VAE 编码器
wan_video_vae.py"] VAE_E --> LAT["观测 latent(条件)
[B, 16, T, h, w]
用途:填入 KV Cache,作为 DiT 生成的历史条件"] end subgraph Decode["VAE 解码器(输出侧)"] LAT2["预测 latent(生成目标)
[B, 16, T, h, w]
DiT 从纯噪声去噪得到"] --> VAE_D["VAE 解码器"] VAE_D --> V_OUT["预测视频帧像素
[B, T, H, W, 3]
(仅用于可视化,不影响机器人控制)"] end subgraph Params["关键参数"] P1["latent 通道数: 16"] P2["时域压缩: 4×"] P3["空域压缩: 8×"] end style Encode fill:#e3f2fd,stroke:#2196F3 style Decode fill:#e8f5e9,stroke:#4CAF50 style Params fill:#fff9c4,stroke:#FFC107
2.3 Causal WAN DiT(核心扩散模型)
Causal WAN DiT 是 DreamZero 的核心,基于 Wan2.1 的 DiT 架构改造而来:40 层因果 Transformer,通过 Causal Masking 保证时序因果性,同时在 latent 序列中嵌入动作 token,实现视频与动作的联合生成。
[B, T_vid, 16]
(去噪目标,从纯噪声出发)"] AC["含噪动作
[B, 24, action_dim]"] VIS["视觉条件
[B, T_img, 1536]"] LNG["语言条件
[B, T_txt, 1536]"] TS["时间步 t"] VL --> ROPE["RoPE 位置编码"] AC --> ACT_PROJ["动作投影 MLP"] TS --> TIME_EMB["时间步嵌入
(正弦编码)"] end subgraph Block["单层 Causal DiT Block (×40)"] direction TB IN_B["输入 [B, T, 5120]"] --> LN1_B["RMSNorm"] LN1_B --> MOD1["Modulation
(shift, scale by timestep emb)"] MOD1 --> CATTN["因果自注意力
40 heads, Flash Attn 3
Causal Mask"] CATTN --> ADD1["+ 残差"] IN_B --> ADD1 ADD1 --> LN2_B["RMSNorm"] LN2_B --> XATTN["交叉注意力
(视觉 + 语言条件)"] XATTN --> ADD2["+ 残差"] ADD1 --> ADD2 ADD2 --> LN3_B["RMSNorm"] LN3_B --> FFN_B["FFN (SwiGLU)
dim 5120 → 13824 → 5120"] FFN_B --> ADD3["+ 残差"] ADD2 --> ADD3 ADD3 --> OUT_B["输出 [B, T, 5120]"] end subgraph Output["输出解码"] OUT_B --> SPLIT["分离 video / action token"] SPLIT --> VP["视频 latent 预测"] SPLIT --> AP["动作预测
[B, 24, action_dim]"] end ROPE --> Block ACT_PROJ --> Block TIME_EMB --> Block VIS --> Block LNG --> Block style InputSeq fill:#e3f2fd,stroke:#2196F3 style Block fill:#e8f5e9,stroke:#4CAF50 style Output fill:#fce4ec,stroke:#E91E63
因果注意力掩码设计: 视频帧 token 只能 attend 到过去帧和当前帧,动作 token attend 到所有视频 token(全局条件),保证生成的时序一致性。
2.4 多机器人体支持
DreamZero 通过每个机器人体独立的线性层(Category-specific MLP)将不同维度的动作空间映射到统一的 DiT 隐空间,支持在单一模型中处理多种机器人平台。
[B, 24, dim_droid]"] A_ACT["AgiBot 动作
[B, 24, dim_agibot]"] Y_ACT["YAM 动作
[B, 24, dim_yam]"] O_ACT["其他机器人体..."] D_ACT --> D_MLP["DROID Linear
dim_droid → 5120"] A_ACT --> A_MLP["AgiBot Linear
dim_agibot → 5120"] Y_ACT --> Y_MLP["YAM Linear
dim_yam → 5120"] O_ACT --> O_MLP["... Linear"] end subgraph Shared["共享 DiT 处理"] D_MLP --> UNIFIED["统一动作 token
[B, 24, 5120]"] A_MLP --> UNIFIED Y_MLP --> UNIFIED O_MLP --> UNIFIED UNIFIED --> DIT_PROC["Causal WAN DiT
(共享权重)"] end subgraph Decode["多机器人体动作解码"] DIT_PROC --> SPLIT_E["按 embodiment_id 路由"] SPLIT_E --> D_DEC["DROID Decoder Linear
5120 → dim_droid"] SPLIT_E --> A_DEC["AgiBot Decoder Linear
5120 → dim_agibot"] SPLIT_E --> Y_DEC["YAM Decoder Linear
5120 → dim_yam"] end style Embodiments fill:#e3f2fd,stroke:#2196F3 style Shared fill:#e8f5e9,stroke:#4CAF50 style Decode fill:#fce4ec,stroke:#E91E63
相对动作计算: 各机器人体在数据加载时将绝对动作转为相对值(action - reference_state),使模型学到的动作表示更加泛化。
3. 训练流水线
DreamZero 基于 Flow Matching 框架进行训练,同时对视频 latent 和动作 token 施加扩散损失。训练使用 DeepSpeed ZeRO-2 分布式优化,并通过 LoRA 高效微调 14B 参数基础模型。
· 视频 resize → 320×176
· 状态/动作归一化
· 相对动作计算
· 语言编码"] MOD --> BATCH_OUT["训练 Batch"] end subgraph Encode["编码阶段"] BATCH_OUT --> IMG_E["CLIP 图像编码
→ 视觉特征"] BATCH_OUT --> TXT_E["T5 文本编码
→ 语言特征"] BATCH_OUT --> VAE_E2["VAE 编码
视频 → latent"] end subgraph FlowMatch["Flow Matching 加噪"] T_SAMP["时间步采样
t ~ Uniform[0, 1000]"] NOISE2["高斯噪声 ε ~ N(0, I)"] ACT_GT["真实动作 / 视频 latent"] T_SAMP --> INTERP2["线性插值
x_t = ε·t/T + x_0·(1 - t/T)"] NOISE2 --> INTERP2 ACT_GT --> INTERP2 end subgraph Forward["前向传播"] IMG_E --> DIT_FWD["Causal WAN DiT"] TXT_E --> DIT_FWD VAE_E2 --> DIT_FWD INTERP2 --> DIT_FWD DIT_FWD --> PRED_ALL["联合预测
视频 latent + 动作"] end subgraph Loss["损失计算"] PRED_ALL --> L_VID["视频 Flow Matching Loss
(MSE on video latent)"] PRED_ALL --> L_ACT["动作 Flow Matching Loss
(MSE on action)"] L_VID --> L_TOTAL["总损失
L = λ_vid · L_vid + λ_act · L_act"] L_ACT --> L_TOTAL end subgraph Optim["优化"] L_TOTAL --> BACK["反向传播"] BACK --> LORA["LoRA 梯度更新
rank=4, alpha=4"] LORA --> OPT["AdamW
β₁=0.95, β₂=0.999
Cosine LR + Warmup"] OPT --> DS_ZERO["DeepSpeed ZeRO-2
梯度累积"] end DataLoad --> Encode Encode --> FlowMatch FlowMatch --> Forward Forward --> Loss Loss --> Optim style DataLoad fill:#e3f2fd,stroke:#2196F3 style Encode fill:#fff3e0,stroke:#FF9800 style FlowMatch fill:#f3e5f5,stroke:#9C27B0 style Forward fill:#e8f5e9,stroke:#4CAF50 style Loss fill:#fce4ec,stroke:#E91E63 style Optim fill:#e0f7fa,stroke:#00BCD4
LoRA 配置: 基础模型(Wan2.1,约 14B 参数)完全冻结,仅在 DiT 的注意力层插入 LoRA 适配器(rank=4, alpha=4),大幅降低显存占用。
4. 推理流水线
推理时模型从纯随机噪声出发,通过 Flow Matching 迭代去噪恢复动作序列和视频帧。DreamZero 支持通过 WebSocket 的分布式多 GPU 推理服务。
[B, 1, C, H, W]"] LANG_IN["语言指令"] X_NOISE["纯随机噪声
x_T ~ N(0, I)
(视频 latent + 动作)"] end subgraph EncoderOnce["编码(仅执行一次)"] OBS --> CLIP_INF["CLIP 编码"] LANG_IN --> T5_INF["T5 编码"] CLIP_INF --> COND["条件特征缓存"] T5_INF --> COND end subgraph DenoiseLoop["迭代去噪循环 (N 步)"] direction TB S_T["步骤 T: x_T → DiT → x̂₀"] S_T1["步骤 T-1: x_{T-1} = scheduler(x̂₀) → DiT → x̂₀"] SDOTS["..."] S_1["步骤 1: x₁ → DiT → x̂₀"] S_T --> S_T1 --> SDOTS --> S_1 end subgraph PostProcess["后处理"] S_1 --> SPLIT_OUT["分离 video latent / action"] SPLIT_OUT --> VAE_DEC2["VAE 解码 → 视频帧"] SPLIT_OUT --> ACT_POST["动作后处理
(反归一化, 相对→绝对)"] ACT_POST --> ACT_EXEC["执行前 N 步动作
(Action Chunking)"] end Init --> EncoderOnce COND --> DenoiseLoop X_NOISE --> DenoiseLoop DenoiseLoop --> PostProcess style Init fill:#e3f2fd,stroke:#2196F3 style EncoderOnce fill:#fff3e0,stroke:#FF9800 style DenoiseLoop fill:#e8f5e9,stroke:#4CAF50 style PostProcess fill:#fce4ec,stroke:#E91E63
分布式推理服务
发送观测 + 指令"] end subgraph Server["推理服务器 (socket_test_optimized_AR.py)"] WS_SRV["WebSocket 服务器
(Flask-SocketIO)"] REDIS["Redis
会话状态管理"] RAY_W["Ray Worker Pool
多 GPU 并行推理"] MODEL["DreamZero 模型
(GB200 / H100)"] WS_SRV --> REDIS WS_SRV --> RAY_W RAY_W --> MODEL end WS_CLI -->|"观测数据"| WS_SRV MODEL -->|"动作预测"| WS_SRV WS_SRV -->|"动作序列"| WS_CLI WS_CLI -->|"执行动作"| CLI style Client fill:#e3f2fd,stroke:#2196F3 style Server fill:#e8f5e9,stroke:#4CAF50
5. 关键超参数表
| 参数 | 值 | 说明 |
|---|---|---|
| 模型总参数 | ~14B | 基于 Wan2.1,LoRA 微调 |
| DiT 层数 | 40 | Causal WAN DiT |
| DiT hidden dim | 5120 | 每层隐藏层维度 |
| 注意力头数 | 40 | head_dim = 128 |
| Embedding dim | 1536 | 图像/文本编码器输出维度 |
| VAE latent 通道 | 16 | 视频压缩维度 |
| Action horizon | 24 | 单次预测动作步数 |
| 视频帧数 | 33 | 输入/预测帧数 |
| 图像分辨率 | 320×176 | 训练分辨率 |
| LoRA rank / alpha | 4 / 4 | 参数高效微调配置 |
| 批归一化策略 | DeepSpeed ZeRO-2 | 分布式训练 |
| 优化器 | AdamW | β₁=0.95, β₂=0.999 |
| 学习率调度 | Cosine + Warmup | 分布式多卡训练 |
| 精度 | bfloat16 | 混合精度训练 |
6. 关键源文件表
| 组件 | 文件路径(相对 /home/zhuyilong/dreamzero/) |
|---|---|
| 核心 VLA 模型 | groot/vla/model/dreamzero/base_vla.py |
| Action Head (Flow Matching) | groot/vla/model/dreamzero/action_head/wan_flow_matching_action_tf.py |
| Causal WAN DiT | groot/vla/model/dreamzero/modules/wan_video_dit_action_casual_chunk.py |
| 视频 VAE | groot/vla/model/dreamzero/modules/wan_video_vae.py |
| 文本编码器 | groot/vla/model/dreamzero/modules/wan_video_text_encoder.py |
| 图像编码器 | groot/vla/model/dreamzero/modules/wan_video_image_encoder.py |
| 数据变换 | groot/vla/model/dreamzero/transform/ |
| 数据集加载 | groot/vla/data/dataset/lerobot.py |
| 分片数据集 | groot/vla/data/dataset/lerobot_sharded.py |
| 训练器 | groot/vla/experiment/experiment.py |
| 训练基类 | groot/vla/experiment/base.py |
| 分布式推理服务 | socket_test_optimized_AR.py |
| 推理客户端 | test_client_AR.py |
| Hydra 模型配置 | groot/vla/configs/model/ |
| 训练脚本 | scripts/train/ |
7. Attention 机制深入解析
7.1 Q / K / V 的本质分工
Attention 的计算分四步:
1. q = W_q(x) # x 投影成 Q
2. k = W_k(context) # context 投影成 K
3. v = W_v(context) # context 投影成 V(同一原材料,不同矩阵)
4. 权重 = softmax(q · kᵀ / √d)
5. 输出 = 权重 · v
K 和 V 的原材料相同(都来自 context),但投影矩阵不同,作用完全不同:
| 矩阵 | 参与哪一步 | 对输出的影响 | 学到的是 |
|---|---|---|---|
| W_q | 步骤 4(点积) | 间接(通过权重) | 我该去找什么样的 K |
| W_k | 步骤 4(点积) | 间接(通过权重) | 什么样的 Q 应该找到我 |
| W_v | 步骤 5(加权求和) | 直接(构成输出) | 被选中后我该提供什么内容 |
K 是抽象意义上的索引——它不直接出现在输出里,只决定权重分布,告诉模型"哪些位置比较重要"。V 才是真正被读取的内容。
7.2 Self-Attention vs Cross-Attention
区别只有一个:Q、K、V 的原材料从哪来。
| Q 来自 | K、V 来自 | 用途 | |
|---|---|---|---|
| Self-Attention | 序列自身 x | 序列自身 x | 序列内部 token 互相交流 |
| Cross-Attention | 序列自身 x | 另一个序列 context | 从外部条件(文本/图像)读取信息 |
DreamZero 的每个 DiT Block 里两者都有:先 Self-Attention 让视频 token 内部交流,再 Cross-Attention 读取文本和图像条件。
[B, T_img, C, H, W]"] --> CLIP_ENC["CLIP ViT 编码
wan_video_image_encoder.py"] CLIP_ENC --> CLIP_FEAT["图像特征
[B, T_img, N_patch, 1536]"] end subgraph TextPath["文本编码路径"] PROMPT["语言提示
(string)"] --> T5_TOK["T5 分词器"] T5_TOK --> T5_ENC["umt5-xxl 编码
wan_video_text_encoder.py"] T5_ENC --> TXT_FEAT["文本特征
[B, T_txt, 1536]"] end subgraph Fusion["特征融合 → DiT"] CLIP_FEAT --> CAT["拼接为条件序列"] TXT_FEAT --> CAT CAT --> DIT_IN["输入 Causal WAN DiT"] end style ImagePath fill:#e3f2fd,stroke:#2196F3 style TextPath fill:#fff3e0,stroke:#FF9800 style Fusion fill:#e8f5e9,stroke:#4CAF50
7.3 为什么 K 不能等于 Q
K=Q 意味着 W_k = W_q,K 和 Q 是同一个向量。此时:
权重 = softmax(q · qᵀ / √d) # x 和自身的相似度
权重退化成"找和自己像的",模型只能检索与自身相似的 token,无法找到"内容上互补但方向不同"的信息。
类比:图书馆里每本书的索引标签直接写的是"我想找机器人抓取的书"——所有书都在大喊需求,没有一本书在说"我能提供什么",检索系统完全失效。
K 和 Q 必须用不同的投影矩阵,让"被检索"和"去检索"解耦。
7.4 多头注意力(Multi-Head Attention)
多头就是把 dim 切成 h 份,并行跑 h 套独立的 Q/K/V:
head_i 输出 = softmax(Q_i · K_iᵀ / √d) · V_i
最终输出 = Concat(head_0, ..., head_h-1) · W_o
多头的价值来源于 Q 的多样性——每个头问不同的问题,从 context 里提取不同侧面的信息。
[B, T_img, C, H, W]"] --> CLIP_ENC["CLIP ViT 编码
wan_video_image_encoder.py"] CLIP_ENC --> CLIP_FEAT["图像特征
[B, T_img, N_patch, 1536]"] end subgraph TextPath["文本编码路径"] PROMPT["语言提示
(string)"] --> T5_TOK["T5 分词器"] T5_TOK --> T5_ENC["umt5-xxl 编码
wan_video_text_encoder.py"] T5_ENC --> TXT_FEAT["文本特征
[B, T_txt, 1536]"] end subgraph Fusion["特征融合 → DiT"] CLIP_FEAT --> CAT["拼接为条件序列"] TXT_FEAT --> CAT CAT --> DIT_IN["输入 Causal WAN DiT"] end style ImagePath fill:#e3f2fd,stroke:#2196F3 style TextPath fill:#fff3e0,stroke:#FF9800 style Fusion fill:#e8f5e9,stroke:#4CAF50
Q 不能共享,K/V 可以共享(MQA):
| 多K多V | 单K单V | |
|---|---|---|
| 多Q(问题不同) | 标准多头 MHA | MQA(现代大模型常用) |
| 单Q(问题相同) | 无意义(退化) | 普通单头 |
MQA(Multi-Query Attention)中所有 Q 头共享同一组 K/V,节省 KV Cache 显存,但不影响 Q 的多样性——40 个人用不同问题检索同一个书架,仍然能找出不同的重点。
7.5 W_q、W_k、W_v 各自如何被训练
三个矩阵的梯度路径不同,学到的东西也不同。
[B, T_img, C, H, W]"] --> CLIP_ENC["CLIP ViT 编码
wan_video_image_encoder.py"] CLIP_ENC --> CLIP_FEAT["图像特征
[B, T_img, N_patch, 1536]"] end subgraph TextPath["文本编码路径"] PROMPT["语言提示
(string)"] --> T5_TOK["T5 分词器"] T5_TOK --> T5_ENC["umt5-xxl 编码
wan_video_text_encoder.py"] T5_ENC --> TXT_FEAT["文本特征
[B, T_txt, 1536]"] end subgraph Fusion["特征融合 → DiT"] CLIP_FEAT --> CAT["拼接为条件序列"] TXT_FEAT --> CAT CAT --> DIT_IN["输入 Causal WAN DiT"] end style ImagePath fill:#e3f2fd,stroke:#2196F3 style TextPath fill:#fff3e0,stroke:#FF9800 style Fusion fill:#e8f5e9,stroke:#4CAF50
W_q 和 W_k 通过同一条路径(点积)互相塑造,W_v 的梯度路径完全独立,只由输出质量驱动。
7.6 QK Norm:为什么归一化能稳定训练
原始点积:
q · k = |q| × |k| × cos(θ)
模长和方向混在一起,W_q 和 W_k 需要同时控制两件事,容易震荡。
QK Norm 之后:
q_norm = q / |q|, k_norm = k / |k|
q_norm · k_norm = cos(θ) # 值域固定在 [-1, 1]
点积退化为纯余弦相似度,只剩方向这一个自由度,模长的干扰被彻底消除。W_q 和 W_k 只需要各自学"朝哪个方向",优化目标从三个因素(|q|、|k|、θ)变成一个(θ),训练更稳定。
DreamZero 代码中对应实现(wan2_1_submodule.py):
q = self.norm_q(self.q(x)) # QK Norm
k = self.norm_k(self.k(context)) # QK Norm