高级 · 4-6周

项目六:模仿学习桌面操作

在PyBullet中采集遥操作数据,训练并对比BC、ACT与Diffusion Policy三种模仿学习方法

高级难度 预计4-6周 9个章节

1. 项目概述

在PyBullet仿真中,用模仿学习训练Franka Panda机械臂完成桌面操作任务:

  • 方块推送到目标位置
  • 方块堆叠
  • 零件插入
  • 对比BC、ACT和Diffusion Policy三种方法
2. 难度:★★★★☆ (4/5)

3. 预估时间:4-6周

    ---


4. 项目结构

    imitation_learning/
    ├── envs/
    │   ├── push_env.py        # 推送任务环境
    │   ├── stack_env.py       # 堆叠任务环境
    │   └── insertion_env.py   # 插入任务环境
    ├── data_collection/
    │   ├── keyboard_teleop.py # 键盘遥操作
    │   └── record_episodes.py # 数据记录
    ├── policies/
    │   ├── bc_policy.py       # 行为克隆
    │   ├── act_policy.py      # ACT实现
    │   └── diffusion_policy.py # Diffusion Policy
    ├── training/
    │   ├── train.py           # 训练脚本
    │   └── config.py          # 配置
    ├── evaluation/
    │   ├── evaluate.py         # 评估脚本
    │   └── compare_methods.py  # 方法对比
    └── utils/
        ├── dataset.py          # 数据集加载
        └── visualization.py    # 可视化

    ---


5. PyBullet 操作环境

    import pybullet as p
    import pybullet_data
    import numpy as np


    class PushEnv:
        """Block-pushing task: a Franka Panda pushes a small cube to a goal.

        Actions are Cartesian end-effector displacements in the world frame;
        an episode succeeds when the cube is within SUCCESS_DIST of the goal.

        Every PyBullet call passes ``physicsClientId`` so several environments
        (e.g. a GUI teleop env and a DIRECT evaluation env) can coexist in one
        process without talking to the wrong physics server.
        """

        # 3 ee pos + 4 ee quat + 7 joint pos + 7 joint vel + 3 block + 3 goal + 3 offset
        OBS_DIM = 30
        ACT_DIM = 3   # end-effector displacement (dx, dy, dz)

        TABLE_HEIGHT = 0.63                           # z used when spawning block / goal
        WORKSPACE_LOW = np.array([0.3, -0.3, 0.63])   # clip bounds for the ee target
        WORKSPACE_HIGH = np.array([0.7, 0.3, 1.0])
        DEFAULT_Q = (0, -0.5, 0, -2.0, 0, 1.5, 0.8)   # home pose of the 7 arm joints
        ACTION_SCALE = 0.02   # metres of ee motion per unit of action
        SUBSTEPS = 20         # simulation steps per env step (~83 ms at 240 Hz)
        SUCCESS_DIST = 0.03   # block-to-goal distance counted as success (m)

        def __init__(self, render=False, max_steps=200):
            """Connect to PyBullet (GUI if *render*, else DIRECT) and load the scene."""
            self.render_mode = render
            self.max_steps = max_steps
            self.client = p.connect(p.GUI if render else p.DIRECT)

            cid = self.client
            p.setAdditionalSearchPath(pybullet_data.getDataPath(), physicsClientId=cid)
            p.setGravity(0, 0, -9.81, physicsClientId=cid)
            p.setTimeStep(1 / 240, physicsClientId=cid)

            # Per-episode objects; removed again at the start of the next reset().
            self.block = None
            self._goal_markers = []
            self._load_scene()

        def _load_scene(self):
            """Load the ground plane, the table and the Panda, and index its joints."""
            cid = self.client
            self.plane = p.loadURDF("plane.urdf", physicsClientId=cid)
            self.table = p.loadURDF("table/table.urdf", [0.5, 0, 0], [0, 0, 0, 1],
                                    physicsClientId=cid)
            self.robot = p.loadURDF("franka_panda/panda.urdf", [0, 0, 0.62], [0, 0, 0, 1],
                                    useFixedBase=True, physicsClientId=cid)

            # Joints whose name contains 'finger' belong to the gripper; every other
            # non-fixed joint is treated as an arm joint.
            self.arm_joints = []
            self.gripper_joints = []
            for i in range(p.getNumJoints(self.robot, physicsClientId=cid)):
                info = p.getJointInfo(self.robot, i, physicsClientId=cid)
                name = info[1].decode()
                if 'finger' in name:
                    self.gripper_joints.append(i)
                elif info[2] != p.JOINT_FIXED:
                    self.arm_joints.append(i)

            # The link right after the last arm joint is used as the end effector.
            self.ee_idx = self.arm_joints[-1] + 1
            self.n_arm = len(self.arm_joints)

        def reset(self, goal_pos=None):
            """Start a new episode and return the first observation.

            Args:
                goal_pos: optional goal position (x, y, z). Any array-like of three
                    floats is accepted. When omitted, a goal is sampled on the table.
            """
            cid = self.client
            if goal_pos is None:
                self.goal_pos = np.array([
                    np.random.uniform(0.3, 0.7),
                    np.random.uniform(-0.3, 0.3),
                    self.TABLE_HEIGHT,
                ])
            else:
                # Coerce lists/tuples so the vector arithmetic below and in
                # _get_obs works.
                self.goal_pos = np.asarray(goal_pos, dtype=float)

            # Rejection-sample a block position more than 15 cm from the goal.
            while True:
                self.block_pos = np.array([
                    np.random.uniform(0.3, 0.7),
                    np.random.uniform(-0.3, 0.3),
                    self.TABLE_HEIGHT,
                ])
                if np.linalg.norm(self.block_pos - self.goal_pos) > 0.15:
                    break

            # Remove the previous episode's cube and goal marker. Without this,
            # every reset leaves one more cube in the scene that can collide with
            # the arm and the new cube.
            if self.block is not None:
                p.removeBody(self.block, physicsClientId=cid)
            for item in self._goal_markers:
                if item >= 0:
                    p.removeUserDebugItem(item, physicsClientId=cid)

            self.block = p.loadURDF("cube_small.urdf", self.block_pos, [0, 0, 0, 1],
                                    globalScaling=0.5, physicsClientId=cid)

            # Green cross marking the goal (lifeTime=0: stays until removed).
            self._goal_markers = [
                p.addUserDebugLine(self.goal_pos - offset, self.goal_pos + offset,
                                   [0, 1, 0], 3, lifeTime=0, physicsClientId=cid)
                for offset in (np.array([0.05, 0, 0]), np.array([0, 0.05, 0]))
            ]

            self._reset_arm()
            self.step_count = 0
            return self._get_obs()

        def _reset_arm(self):
            """Teleport the arm to DEFAULT_Q and hold it there while the scene settles."""
            cid = self.client
            for joint, q in zip(self.arm_joints, self.DEFAULT_Q):
                p.resetJointState(self.robot, joint, q, targetVelocity=0,
                                  physicsClientId=cid)
                # Re-target the position motor as well. step() leaves its last IK
                # target active, and that target would otherwise pull the arm away
                # from the home pose during the settle steps below.
                p.setJointMotorControl2(self.robot, joint, p.POSITION_CONTROL,
                                        targetPosition=q, force=500,
                                        physicsClientId=cid)
            for _ in range(100):
                p.stepSimulation(physicsClientId=cid)

        def _get_obs(self):
            """Return the float32 observation vector (30-dim for the 7-joint Panda)."""
            cid = self.client
            ee_state = p.getLinkState(self.robot, self.ee_idx, physicsClientId=cid)
            joint_states = p.getJointStates(self.robot, self.arm_joints,
                                            physicsClientId=cid)
            joint_pos = np.array([s[0] for s in joint_states])
            joint_vel = np.array([s[1] for s in joint_states])
            block_pos, _ = p.getBasePositionAndOrientation(self.block,
                                                           physicsClientId=cid)
            block_pos = np.array(block_pos)

            obs = np.concatenate([
                ee_state[0],                # ee position (3)
                ee_state[1],                # ee orientation quaternion (4)
                joint_pos,                  # arm joint positions (7)
                joint_vel,                  # arm joint velocities (7)
                block_pos,                  # block position (3)
                self.goal_pos,              # goal position (3)
                self.goal_pos - block_pos,  # block-to-goal offset (3)
            ])
            return obs.astype(np.float32)

        def step(self, action):
            """Apply one end-effector displacement command.

            Args:
                action: (dx, dy, dz) in the world frame. It is scaled by
                    ACTION_SCALE (2 cm per unit) and added to the current ee
                    position; the result is clipped to the workspace.

            Returns:
                (obs, reward, done, truncated, info): reward is the negative
                block-to-goal distance, done means the block is within
                SUCCESS_DIST of the goal, and truncated means max_steps was reached.
            """
            cid = self.client
            self.step_count += 1

            ee_pos = np.array(p.getLinkState(self.robot, self.ee_idx,
                                             physicsClientId=cid)[0])
            target_pos = ee_pos + np.asarray(action, dtype=float) * self.ACTION_SCALE
            target_pos = np.clip(target_pos, self.WORKSPACE_LOW, self.WORKSPACE_HIGH)

            joint_poses = p.calculateInverseKinematics(self.robot, self.ee_idx,
                                                       target_pos, physicsClientId=cid)

            # IK returns one value per movable joint. Taking the first n_arm values
            # assumes the arm joints precede the finger joints in index order.
            for joint, q in zip(self.arm_joints, joint_poses[:self.n_arm]):
                p.setJointMotorControl2(self.robot, joint, p.POSITION_CONTROL,
                                        targetPosition=q, force=500, maxVelocity=1.0,
                                        physicsClientId=cid)

            for _ in range(self.SUBSTEPS):
                p.stepSimulation(physicsClientId=cid)

            block_pos, _ = p.getBasePositionAndOrientation(self.block,
                                                           physicsClientId=cid)
            dist = float(np.linalg.norm(np.array(block_pos) - self.goal_pos))

            reward = -dist
            done = dist < self.SUCCESS_DIST
            truncated = self.step_count >= self.max_steps
            return self._get_obs(), reward, done, truncated, {}

        def close(self):
            """Disconnect from the physics server; safe to call more than once."""
            if self.client is not None:
                p.disconnect(physicsClientId=self.client)
                self.client = None

    ---


6. 数据采集(键盘遥操作)

    import os
    import pickle

    import numpy as np
    import pygame


    class KeyboardTeleop:
        """Collect demonstrations by steering the end effector from the keyboard.

        W/S move along +/-x, A/D along +/-y and Q/E along +/-z (see KEY_MAP).
        SPACE starts an episode, and pressing it again ends the episode early.
        An episode also ends automatically when env.step reports done (success)
        or truncated (step limit); timed-out episodes are discarded.
        ESC or closing the window stops collection. Episodes completed so far
        are still saved.
        """

        KEY_MAP = {
            pygame.K_w: np.array([1, 0, 0]),     # +x (away)
            pygame.K_s: np.array([-1, 0, 0]),    # -x (towards)
            pygame.K_a: np.array([0, 1, 0]),     # +y (left)
            pygame.K_d: np.array([0, -1, 0]),    # -y (right)
            pygame.K_q: np.array([0, 0, 1]),     # +z (up)
            pygame.K_e: np.array([0, 0, -1]),    # -z (down)
        }

        def __init__(self, env, save_dir='demos/'):
            self.env = env
            self.save_dir = save_dir
            self.episodes = []       # completed episodes: dicts of stacked arrays
            self.recording = False
            self._current = None     # per-step buffers of the episode being recorded

        def collect_episodes(self, n_episodes=50):
            """Record up to *n_episodes* demonstrations and save them to save_dir."""
            pygame.init()
            pygame.display.set_mode((300, 200))
            pygame.display.set_caption("Press SPACE to start episode, ESC to quit")

            clock = pygame.time.Clock()
            episode_idx = 0

            try:
                while episode_idx < n_episodes:
                    stop = False
                    for event in pygame.event.get():
                        if event.type == pygame.QUIT or (
                            event.type == pygame.KEYDOWN and event.key == pygame.K_ESCAPE
                        ):
                            stop = True
                            break
                        if event.type == pygame.KEYDOWN and event.key == pygame.K_SPACE:
                            if not self.recording:
                                self._start_episode(episode_idx)
                                print(f"采集第 {episode_idx+1} 条演示...")
                            elif self._end_episode():
                                episode_idx += 1
                                print(f"完成第 {episode_idx} 条演示")
                    if stop:
                        break

                    if self.recording:
                        obs = self.env._get_obs()   # observation *before* the action
                        action = self._get_action()
                        _, _, done, truncated, _ = self.env.step(action)
                        self._record_step(obs, action)

                        if done or truncated:
                            # A timeout without success is not a valid demonstration.
                            if self._end_episode(keep=bool(done)):
                                episode_idx += 1
                                print(f"完成第 {episode_idx} 条演示")
                            else:
                                print("演示超时未完成,已丢弃")

                    clock.tick(10)  # record at 10 Hz
            finally:
                # Always persist what was collected, even on ESC/QUIT or an error.
                if self.recording:
                    self._end_episode(keep=False)   # drop the partial episode
                try:
                    self._save_all()
                finally:
                    pygame.quit()

        def _get_action(self):
            """Sum the direction vectors of all held keys and normalise to unit length.

            Opposite keys cancel out. When no key is held, a zero action is returned.
            """
            pressed = pygame.key.get_pressed()
            action = np.zeros(3)
            for key, direction in self.KEY_MAP.items():
                if pressed[key]:
                    action += direction
            norm = np.linalg.norm(action)
            return action / norm if norm > 0 else action

        def _start_episode(self, episode_idx):
            """Reset the environment and start buffering a new episode.

            *episode_idx* is the number of episodes completed so far (kept for
            callers/logging; not used for storage).
            """
            self.env.reset()
            self._current = {'observations': [], 'actions': []}
            self.recording = True

        def _record_step(self, obs, action):
            """Append one (observation, action) pair to the current episode."""
            self._current['observations'].append(np.asarray(obs, dtype=np.float32))
            self._current['actions'].append(np.asarray(action, dtype=np.float32))

        def _end_episode(self, keep=True):
            """Stop recording and return True when the episode was kept.

            Episodes are dropped when *keep* is False or when they contain no steps.
            """
            self.recording = False
            episode, self._current = self._current, None
            if not keep or not episode or not episode['actions']:
                return False
            self.episodes.append({k: np.stack(v) for k, v in episode.items()})
            return True

        def _save_all(self):
            """Pickle each collected episode to its own file in save_dir.

            File names are episode_NNN.pkl, and existing files are never
            overwritten. NOTE(review): this per-episode dict layout
            ('observations', 'actions') must match what the training-side
            dataset loader expects; confirm.
            """
            os.makedirs(self.save_dir, exist_ok=True)
            next_idx = 0
            for episode in self.episodes:
                while True:
                    path = os.path.join(self.save_dir, f'episode_{next_idx:03d}.pkl')
                    next_idx += 1
                    if not os.path.exists(path):
                        break
                with open(path, 'wb') as f:
                    pickle.dump(episode, f)
            print(f"已保存 {len(self.episodes)} 条演示到 {self.save_dir}")

    ---


7. 训练与对比

    import torch
    import torch.nn as nn
    import matplotlib.pyplot as plt

    # Imports follow the package layout of section 4 (run from the project root,
    # e.g. `python -m training.train`).
    # NOTE(review): evaluate_policy is assumed to live in evaluation/evaluate.py; confirm.
    from envs.push_env import PushEnv
    from utils.dataset import DemonstrationDataset
    from policies.bc_policy import BCPolicy
    from policies.act_policy import ACTPolicy
    from policies.diffusion_policy import DiffusionPolicy
    from evaluation.evaluate import evaluate_policy


    def _train_bc(policy, loader, n_epochs, lr):
        """Fit *policy* to (state, action) batches by MSE regression.

        Returns:
            A list with the mean training loss of each epoch.
        """
        optimizer = torch.optim.Adam(policy.parameters(), lr=lr)
        criterion = nn.MSELoss()
        policy.train()
        losses = []
        for _ in range(n_epochs):
            total = 0.0
            for states, actions in loader:
                optimizer.zero_grad()
                loss = criterion(policy(states), actions)
                loss.backward()
                optimizer.step()
                total += loss.item()
            losses.append(total / max(len(loader), 1))  # guard against an empty dataset
        return losses


    def _evaluate(policy, env, n_trials):
        """Run evaluate_policy with the policy in eval mode and gradients disabled.

        Assumes the policies are torch.nn.Module subclasses, since they expose
        .parameters().
        """
        policy.eval()
        with torch.no_grad():
            return evaluate_policy(policy, env, n_trials=n_trials)


    def _plot_comparison(results, n_trials, path='imitation_learning_comparison.png'):
        """Save a loss-curve plot and a success-rate bar chart side by side.

        Only methods whose results contain 'losses' get a curve.
        Success rates are assumed to be fractions in [0, 1].
        """
        methods = ['BC', 'ACT', 'Diffusion']
        rates = [results[m]['success_rate'] for m in methods]

        fig, (loss_ax, rate_ax) = plt.subplots(1, 2, figsize=(12, 4))

        for name, res in results.items():
            if res.get('losses'):
                loss_ax.plot(res['losses'], label=name)
        loss_ax.set_xlabel('Epoch')
        loss_ax.set_ylabel('Loss')
        # English titles: matplotlib's default font has no CJK glyphs, so
        # Chinese titles render as empty boxes.
        loss_ax.set_title('Training Loss')
        loss_ax.legend()

        rate_ax.bar(methods, rates, color=['#667eea', '#764ba2', '#f093fb'])
        rate_ax.set_ylabel('Success Rate')
        rate_ax.set_ylim(0, 1.1)   # leave room for the labels above the bars
        rate_ax.set_title(f'Success Rate ({n_trials} trials)')
        for i, v in enumerate(rates):
            rate_ax.text(i, v + 0.02, f'{v:.0%}', ha='center')

        fig.tight_layout()
        fig.savefig(path, dpi=150)
        plt.close(fig)


    def train_and_compare(data_dir='demos/push/', n_epochs=100, n_trials=50):
        """Train BC, ACT and Diffusion Policy on the push demos and compare them.

        Args:
            data_dir: directory of recorded demonstrations.
            n_epochs: number of BC training epochs.
            n_trials: evaluation episodes per method.

        Returns:
            A dict mapping method name to {'success_rate': ...}; BC also has 'losses'.
        """
        dataset = DemonstrationDataset(data_dir)
        train_loader = torch.utils.data.DataLoader(
            dataset, batch_size=32, shuffle=True
        )
        state_dim, action_dim = PushEnv.OBS_DIM, PushEnv.ACT_DIM

        env = PushEnv(render=False)
        results = {}
        try:
            # ========== Behavior Cloning ==========
            print("=== 训练 Behavior Cloning ===")
            bc = BCPolicy(state_dim=state_dim, action_dim=action_dim)
            bc_losses = _train_bc(bc, train_loader, n_epochs, lr=1e-3)
            bc_success_rate = _evaluate(bc, env, n_trials)
            results['BC'] = {'losses': bc_losses, 'success_rate': bc_success_rate}

            # ========== ACT ==========
            print("=== 训练 ACT ===")
            act = ACTPolicy(state_dim=state_dim, action_dim=action_dim, chunk_size=20)
            act_optimizer = torch.optim.Adam(act.parameters(), lr=1e-4)
            # TODO: implement the ACT training loop with act_optimizer. ACT
            # predicts chunk_size-step action chunks, which the single-step
            # (state, action) dataset above may not provide; confirm. Until this
            # is done, ACT is evaluated untrained.
            act_success_rate = _evaluate(act, env, n_trials)
            results['ACT'] = {'success_rate': act_success_rate}

            # ========== Diffusion Policy ==========
            print("=== 训练 Diffusion Policy ===")
            diffusion = DiffusionPolicy(state_dim=state_dim, action_dim=action_dim)
            diffusion_optimizer = torch.optim.Adam(diffusion.parameters(), lr=1e-4)
            # TODO: implement the Diffusion Policy training loop with
            # diffusion_optimizer; until then it is evaluated untrained.
            diff_success_rate = _evaluate(diffusion, env, n_trials)
            results['Diffusion'] = {'success_rate': diff_success_rate}
        finally:
            env.close()

        # ========== Visual comparison ==========
        _plot_comparison(results, n_trials)
        print(f"\n结果: BC={bc_success_rate:.1%}, "
              f"ACT={act_success_rate:.1%}, "
              f"Diffusion={diff_success_rate:.1%}")

        return results

    ---


8. 验收标准

  • 数据采集:能通过键盘遥操作采集50+条有效演示
  • BC训练:在推送任务上成功率 > 60%
  • ACT改进:ACT成功率比BC提升 > 10个百分点
  • Diffusion Policy:在多模态任务上(演示中同时包含从方块左侧和右侧推送两种策略),Diffusion的成功率明显高于BC——BC用MSE回归训练,会把两种模式平均成介于两者之间的动作,而Diffusion Policy能够表示多峰的动作分布
  • 实验报告:包含训练曲线、成功率、失败案例分析
    ---


9. 参考资源

  • robomimic - 模仿学习基准
  • ACT (Action Chunking Transformer)
  • Diffusion Policy