Pi0-FAST 模型架构

1. 整体架构概览

Pi0-FAST (Fast Action Sequence Tokenizer) 是 Physical Intelligence 提出的一种高效视觉-语言-动作 (VLA) 模型。与 Pi0 使用 Flow Matching 扩散生成连续动作不同,Pi0-FAST 将连续动作离散化为 token,然后通过 PaliGemma 视觉语言模型进行 自回归 next-token 预测。这种设计避免了多步去噪推理,实现了更快的动作生成。

核心思路:连续动作 → DCT 变换 → BPE 分词 → 离散 token → 自回归预测

graph TB subgraph Input["输入"] IMG["相机图像
(pixel_values)"] TXT["语言指令
(input_ids)"] ACT_TRAIN["动作序列
[B, chunk_size, action_dim]
(仅训练)"] end subgraph Tokenizer["FAST 动作分词器"] DCT["DCT 变换
连续动作 → 频域系数"] BPE["BPE 分词
频域系数 → 离散 token ID"] end subgraph VLM["PaliGemma 视觉语言模型"] VE["SigLIP 视觉编码器"] EMB["词元嵌入层
(共享: 语言 + 动作 token)"] GEMMA["Gemma 2B
18 层 Transformer"] LM_HEAD["LM Head
(Linear → vocab logits)"] end subgraph Output["输出"] LOGITS["下一个 token 预测
(训练: 交叉熵损失)"] ACTIONS["预测动作序列
(推理: 自回归解码 → IDCT)"] end IMG --> VE --> GEMMA TXT --> EMB --> GEMMA ACT_TRAIN --> DCT --> BPE --> EMB GEMMA --> LM_HEAD LM_HEAD --> LOGITS LM_HEAD --> ACTIONS style Input fill:#e8f4fd,stroke:#2196F3 style Tokenizer fill:#fff3e0,stroke:#FF9800 style VLM fill:#e8f5e9,stroke:#4CAF50 style Output fill:#fce4ec,stroke:#E91E63

2. 核心组件详解

2.1 PaliGemma 视觉语言模型

Pi0-FAST 使用 PaliGemma(SigLIP + Gemma)作为统一的视觉语言主干。与 Pi0 不同,Pi0-FAST 没有独立的 Action Expert——动作 token 直接复用语言模型的词元嵌入和 LM Head 进行预测。

graph LR subgraph Vision["SigLIP 视觉编码器"] PIX["pixel_values
[B, C, H, W]"] --> SIG["SigLIP
视觉 Transformer"] SIG --> IMG_EMB["图像嵌入
[B, N_patches, 2048]"] end subgraph Language["词元嵌入(共享)"] LANG_TOK["语言 token
[B, seq_len]"] --> EMBED["embed_tokens()"] ACT_TOK["动作 token
[B, num_tokens]"] --> EMBED EMBED --> LANG_EMB["嵌入 × √dim
[B, *, 2048]"] end subgraph LM["Gemma 2B (18 层)"] direction TB L0["Layer 0"] --> L1["Layer 1~16"] --> L17["Layer 17"] end IMG_EMB --> LM LANG_EMB --> LM LM --> HEAD["LM Head
2048 → 257152 (vocab)"] style Vision fill:#e3f2fd,stroke:#2196F3 style Language fill:#f3e5f5,stroke:#9C27B0 style LM fill:#e8f5e9,stroke:#4CAF50

关键设计: - 动作 token 和语言 token 共享同一个嵌入层 (embed_tokens) - 动作 token 和语言 token 共享同一个 LM Head 进行预测 - 嵌入后乘以 √hidden_dim 进行缩放(与 Gemma 标准做法一致) - 无需额外的状态投影或动作投影层

2.2 FAST 动作分词器 (Action Tokenizer)

FAST 分词器将连续动作序列转换为离散 token,是 Pi0-FAST 的核心创新。

graph LR subgraph Encode["编码(训练时)"] A["连续动作
[B, T, action_dim]"] --> DCT["DCT 变换
(离散余弦变换)"] DCT --> COEFF["频域系数
(压缩表示)"] COEFF --> BPE["BPE 分词
(字节对编码)"] BPE --> TOKENS["离散 token ID
[B, num_tokens]"] end subgraph Decode["解码(推理时)"] PRED_TOK["预测 token ID"] --> DE_BPE["BPE 解码"] DE_BPE --> IDCT["IDCT 逆变换
(scipy.fftpack.idct)"] IDCT --> PRED_A["连续动作
[B, T, action_dim]"] end style Encode fill:#e8f5e9,stroke:#4CAF50 style Decode fill:#fff3e0,stroke:#FF9800

DCT 变换的优势: - 将时间域动作序列转换为频域系数,实现 有损压缩 - 低频成分(平滑运动)用少量系数即可表示 - 天然规避了连续值域的归一化和 padding 不稳定问题 - BPE 分词进一��将系数编码为可预测的离散 token

Token 空间: - 总词表大小: 257,152(Gemma 原始词表) - fast_skip_tokens: 128(跳过前 128 个 token) - 最大动作 token 数: 256(max_action_tokens

2.3 自定义注意力掩码

Pi0-FAST 使用精心设计的注意力掩码,确保图像/语言和动作 token 之间正确的信息流动。

graph TB subgraph Mask["注意力掩码结构"] direction LR subgraph QueryAxis["Query (行)"] QI["图像 token"] QL["语言 token"] QA["动作 token"] end subgraph Rules["注意力规则"] R1["图像 ↔ 图像: ✅ 双向"] R2["图像 ↔ 语言: ✅ 双向"] R3["语言 ↔ 语言: ✅ 双向"] R4["动作 → 图像: ✅ 允许"] R5["动作 → 语言: ✅ 允许"] R6["图像 → 动作: ❌ 禁止"] R7["语言 → 动作: ❌ 禁止"] R8["动作 ↔ 动作: ⬇️ 因果"] end end style Mask fill:#e0f2f1,stroke:#009688
graph LR subgraph AttentionMatrix["注意力矩阵可视化"] M[" 图像 语言 动作
图像 [✅✅✅ ✅✅ ❌❌❌]
语言 [✅✅✅ ✅✅ ❌❌❌]
动作₁ [✅✅✅ ✅✅ ✅❌❌]
动作₂ [✅✅✅ ✅✅ ✅✅❌]
动作₃ [✅✅✅ ✅✅ ✅✅✅]"] end style AttentionMatrix fill:#f5f5f5,stroke:#9E9E9E

实现细节 (_create_custom_attention_mask_fast): 1. 初始化全零布尔掩码 [B, total_len, total_len] 2. 图像/语言段之间:设为 True(双向注意力) 3. 动作段 → 图像/语言段:设为 True(动作可看到观测) 4. 动作段内部:设下三角 True(因果注意力,自回归生成) 5. 最后与 padding mask 相交:att_masks &= pad_2d_masks


3. 训练流水线

graph TB subgraph DataPrep["数据准备"] RAW_A["原始连续动作
[B, chunk_size, action_dim]"] RAW_A --> NORM["归一化
(MEAN_STD)"] NORM --> PAD["零填充至 max_action_dim=32"] PAD --> FAST_ENC["FAST Tokenizer 编码
DCT → BPE"] FAST_ENC --> ACT_TOKENS["动作 token ID
[B, max_action_tokens]"] RAW_IMG["原始图像
[B, n_cameras, C, H, W]"] RAW_IMG --> RESIZE["resize_with_pad
→ 224×224"] RAW_LANG["语言指令
(文本)"] RAW_LANG --> TOK["PaliGemma Tokenizer
max_length=200"] TOK --> LANG_TOKENS["语言 token ID
[B, seq_len]"] end subgraph Embedding["嵌入构建"] RESIZE --> VIS_EMB["SigLIP 编码
→ 图像嵌入"] LANG_TOKENS --> LANG_EMB["embed_tokens × √dim
→ 语言嵌入"] ACT_TOKENS --> ACT_EMB["embed_tokens × √dim
→ 动作嵌入"] VIS_EMB --> CONCAT["拼接
[图像, 语言, 动作]"] LANG_EMB --> CONCAT ACT_EMB --> CONCAT end subgraph Forward["前向传播"] CONCAT --> ATT_MASK["构建注意力掩码
(双向 + 因果)"] ATT_MASK --> GEMMA["Gemma 2B
18 层 Transformer"] GEMMA --> HIDDEN["隐藏状态
[B, total_len, 2048]"] end subgraph Loss["损失计算"] HIDDEN --> EXTRACT["提取动作位置
hidden[:, -num_fast:]"] EXTRACT --> LM_HEAD["LM Head
→ vocab logits"] LM_HEAD --> SHIFT["左移 logits / 右移 targets
(next-token prediction)"] SHIFT --> CE["CrossEntropyLoss
(带 padding mask)"] CE --> LOSS["fast_loss = masked_sum / mask_count"] end style DataPrep fill:#e3f2fd,stroke:#2196F3 style Embedding fill:#f3e5f5,stroke:#9C27B0 style Forward fill:#e8f5e9,stroke:#4CAF50 style Loss fill:#fce4ec,stroke:#E91E63

训练目标详解:

输入 token:  [img_1, ..., img_N, lang_1, ..., lang_M, act_1, act_2, ..., act_{T-1}]
目标 token:  [                                         act_2, act_3, ..., act_T    ]

4. 推理流水线

Pi0-FAST 支持两种推理模式:无缓存KV 缓存 加速。

4.1 KV 缓存推理(默认,高效)

graph TB subgraph Prefill["阶段一: Prefill(一次性)"] IMG["图像"] --> VIS["SigLIP 编码"] LANG["语言 + BOS"] --> EMB["词元嵌入"] VIS --> CAT["拼接 [图像, 语言, BOS]"] EMB --> CAT CAT --> FWD1["Gemma 前向
use_cache=True"] FWD1 --> KV["KV Cache
(18 层缓存)"] FWD1 --> FIRST["LM Head → 采样
→ 第一个动作 token"] end subgraph Decode["阶段二: 自回归解码(逐 token)"] FIRST --> LOOP["循环 t = 1 → max_decoding_steps"] LOOP --> EMB_NEW["嵌入上一个 token
[B, 1, 2048]"] EMB_NEW --> FWD2["Gemma 前向
+ KV Cache
只处理 1 个 token"] KV --> FWD2 FWD2 --> SAMPLE["LM Head → 采样"] SAMPLE -->|"temperature=0: argmax
temperature>0: multinomial"| NEXT["下一个 token"] NEXT --> LOOP end subgraph PostProcess["后处理"] LOOP -->|"生成完毕"| ALL_TOKENS["所有动作 token
[B, max_decoding_steps]"] ALL_TOKENS --> DEBPE["BPE 解码"] DEBPE --> IDCT["IDCT 逆变换"] IDCT --> FINAL["连续动作
[B, chunk_size, action_dim]"] end style Prefill fill:#e3f2fd,stroke:#2196F3 style Decode fill:#e8f5e9,stroke:#4CAF50 style PostProcess fill:#fff3e0,stroke:#FF9800

KV 缓存优化要点: - Prefill 阶段:处理完整前缀序列,生成 KV Cache - 解码阶段:每步只处理 1 个新 token,重用缓存的 K/V - 注意力掩码:新 token 可注意到所有历���有效 token - 复杂度从 O(T²) 降低到 O(T)(T = 总序列长度)

4.2 无缓存推理(简单但慢)

graph LR subgraph NaiveLoop["逐步完整前向传播"] T0["t=0: [img, lang] → Gemma → token₁"] T1["t=1: [img, lang, token₁] → Gemma → token₂"] T2["t=2: [img, lang, token₁, token₂] → Gemma → token₃"] DOTS["..."] TN["t=N: [img, lang, token₁...N] → Gemma → token_{N+1}"] T0 --> T1 --> T2 --> DOTS --> TN end style NaiveLoop fill:#ffebee,stroke:#f44336

5. 与 Pi0 的关键区别

特性 Pi0 Pi0-FAST
动作表示 连续值 (Flow Matching) 离散 token (FAST Tokenizer)
动作生成方式 迭代去噪 (10 步) 自回归 next-token 预测
Action Expert 独立 Gemma 300M 无(复用 PaliGemma)
模型结构 双流 (PaliGemma + Expert) 单流 (仅 PaliGemma)
训练损失 MSE (flow matching) CrossEntropy (next-token)
推理速度 10 次完整前向传播 ~256 次轻量解码步 (KV cache)
状态输入 投影到隐藏维 → suffix 不直接使用(通过语言描述)
归一化敏感性 高(MEAN_STD + padding 不稳定) 低(token 化天然规避)
Padding 问题 严重(30 维零值影响 flow matching) 无(token 化后无连续 padding)
graph LR subgraph Pi0Path["Pi0: Flow Matching 路径"] P_IN["连续动作"] --> P_NOISE["加噪 x_t"] --> P_DIT["双流 Transformer
× 10 步去噪"] --> P_OUT["连续动作"] end subgraph FastPath["Pi0-FAST: Token 路径"] F_IN["连续动作"] --> F_DCT["DCT"] --> F_BPE["BPE"] --> F_AR["单流 Transformer
自回归解码"] --> F_IDCT["IDCT"] --> F_OUT["连续动作"] end style Pi0Path fill:#e3f2fd,stroke:#2196F3 style FastPath fill:#e8f5e9,stroke:#4CAF50

6. 关键超参数

模型结构

参数 说明
paligemma_variant gemma_2b PaliGemma 使用 Gemma 2B
Gemma 2B width 2048 隐藏层维度
Gemma 2B depth 18 Transformer 层数
Gemma 2B mlp_dim 16,384 FFN 中间层维度
Gemma 2B num_heads 8 注意力头数
Gemma 2B head_dim 256 每头维度
Gemma 2B num_kv_heads 1 GQA KV 头数
词表大小 257,152 Gemma 原始词表

动作空间

参数 说明
chunk_size 50 预测动作步数
n_action_steps 50 执行动作步数
max_state_dim 32 状态向量填充维度
max_action_dim 32 动作向量填充维度
max_action_tokens 256 最大离散动作 token 数
fast_skip_tokens 128 跳过词表前 128 个 token

推理参数

参数 说明
temperature 0.0 采样温度 (0=argmax)
max_decoding_steps 256 最大自回归步数
use_kv_cache True 启用 KV 缓存加速
validate_action_token_prefix True 验证 "Action: " 前缀

训练参数

参数 说明
optimizer_lr 2.5e-5 AdamW 学习率
optimizer_betas (0.9, 0.95) Adam β 参数
optimizer_weight_decay 0.01 权重衰减
optimizer_grad_clip_norm 1.0 梯度裁剪阈值
scheduler_warmup_steps 1,000 预热步数
scheduler_decay_steps 30,000 余弦衰减总步数
scheduler_decay_lr 2.5e-6 最终学习率
tokenizer_max_length 200 语言 token 最大长度
image_resolution (224, 224) 输入图像分辨率

7. 关键源文件

组件 类名 文件
策略封装 PI0FastPolicy lerobot/policies/pi0_fast/modeling_pi0_fast.py:804
核心模型 PI0FastPytorch lerobot/policies/pi0_fast/modeling_pi0_fast.py:278
PaliGemma 封装 PI0FastPaliGemma lerobot/policies/pi0_fast/modeling_pi0_fast.py:186
Gemma 配置 GemmaConfig / get_gemma_config lerobot/policies/pi0_fast/modeling_pi0_fast.py:150
配置类 PI0FastConfig lerobot/policies/pi0_fast/configuration_pi0_fast.py:31
嵌入构建 embed_prefix_fast() lerobot/policies/pi0_fast/modeling_pi0_fast.py:355
注意力掩码 _create_custom_attention_mask_fast() lerobot/policies/pi0_fast/modeling_pi0_fast.py:443
训练前向 forward() lerobot/policies/pi0_fast/modeling_pi0_fast.py:484
推理 (无缓存) sample_actions_fast() lerobot/policies/pi0_fast/modeling_pi0_fast.py:583
推理 (KV 缓存) sample_actions_fast_kv_cache() lerobot/policies/pi0_fast/modeling_pi0_fast.py:680
图像预处理 resize_with_pad_torch() lerobot/policies/pi0_fast/modeling_pi0_fast.py:76
FAST 分词器 physical-intelligence/fast HuggingFace Hub (外部)
文本分词器 google/paligemma-3b-pt-224 HuggingFace Hub (外部)