import html
import json
import logging
import os
import sys
from functools import lru_cache
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import gradio as gr
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio
from huggingface_hub import HfApi, snapshot_download
|
|
# Configure root logging once at import time: INFO level, timestamped
# "time - logger - level - message" lines, written to stdout only.
logging.basicConfig(
    handlers=[logging.StreamHandler(stream=sys.stdout)],
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
# Module-level logger used by the data-loading helpers below.
logger = logging.getLogger(name=__name__)
|
|
# Hub dataset repo to browse; override with the DATASET_ID environment variable.
DEFAULT_DATASET_ID = os.environ.get(
    "DATASET_ID", "DynamicIntelligence/humanoid-robots-training-dataset"
)
# Directory (relative to the working directory) the dataset snapshot is downloaded into.
LOCAL_DATASET_DIR = Path("dataset_cache")
# Optional Hugging Face token passed to Hub API/download calls; None when unset.
HF_TOKEN = os.environ.get("HF_TOKEN")
|
|
# Display label for each tracked joint, keyed by the short alias that prefixes
# the per-joint column names in the episode state DataFrame.
JOINT_ALIASES = dict(
    wrist="Wrist",
    thumb_tip="Thumb Tip",
    index_mcp="Index MCP",
    index_tip="Index Tip",
)

# Short alias -> value of the parquet ``joint_name`` column for that joint.
JOINT_NAME_MAP = dict(
    wrist="WRIST",
    thumb_tip="THUMB_TIP",
    index_mcp="INDEX_FINGER_MCP",
    index_tip="INDEX_FINGER_TIP",
)

# Per-joint metric columns read from the parquet, mapped to their plot axis labels.
METRIC_LABELS = dict(
    x_cm="X (cm)",
    y_cm="Y (cm)",
    z_cm="Z (cm)",
    yaw_deg="Yaw (°)",
    pitch_deg="Pitch (°)",
    roll_deg="Roll (°)",
)

# 2x3 layout of the trajectory plots: positions on the first row, angles on the second.
PLOT_GRID = [
    ["x_cm", "y_cm", "z_cm"],
    ["yaw_deg", "pitch_deg", "roll_deg"],
]

# PLOT_GRID flattened row by row; must match the order the plot components are created in.
PLOT_ORDER = [name for grid_row in PLOT_GRID for name in grid_row]
|
|
# Stylesheet passed to gr.Blocks(css=...) in build_interface. The class names
# here match the elem_classes and inline-HTML classes used elsewhere in this
# file (side-panel, stats-card, episodes-title, episode-list, instruction-card,
# video-card, video-title, download-button, plots-wrap, plot-html, ...).
# NOTE(review): `.episode-list input[type="radio"]:checked + label` only matches
# when the <label> is the input's next sibling; if Gradio renders the <input>
# nested inside its <label>, the "selected" highlight never applies — confirm
# against the rendered DOM.
# `.plot-html iframe` only takes effect when plot HTML is wrapped in an <iframe>.
CUSTOM_CSS = """
:root, .gradio-container, body {
background-color: #050a18 !important;
color: #f8fafc !important;
font-family: 'Inter', 'Segoe UI', system-ui, sans-serif;
}
.side-panel {
background: #0f172a;
padding: 20px;
border-radius: 18px;
border: 1px solid #1f2b47;
min-height: 100%;
}
.stats-card ul {
list-style: none;
padding: 0;
margin: 0;
font-size: 0.92rem;
}
.stats-card li {
margin-bottom: 10px;
color: #e2e8f0;
}
.stats-card span {
display: inline-block;
margin-right: 6px;
color: #7dd3fc;
}
.episodes-title {
margin: 18px 0 8px;
font-size: 0.78rem;
text-transform: uppercase;
letter-spacing: 0.14em;
color: #94a3b8;
}
.episode-list .gr-form {
padding: 0;
}
.episode-list .gr-form > div {
gap: 0;
}
.episode-list input[type="radio"] {
display: none;
}
.episode-list label {
background: transparent !important;
border: none !important;
color: #cbd5f5 !important;
padding: 3px 0 !important;
justify-content: flex-start;
font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;
font-size: 0.9rem;
text-decoration: underline;
}
.episode-list label:hover {
color: #67e8f9 !important;
cursor: pointer;
}
.episode-list input[type="radio"]:checked + label {
color: #facc15 !important;
font-weight: 700;
margin-left: -2px;
}
.main-panel {
padding-top: 8px;
}
.instruction-card {
background: #0f172a;
padding: 18px 20px;
border-radius: 18px;
border: 1px solid #1f2b47;
}
.instruction-label {
font-size: 0.75rem;
letter-spacing: 0.12em;
text-transform: uppercase;
color: #94a3b8;
margin-bottom: 10px;
}
.instruction-text {
font-size: 1.1rem;
line-height: 1.5;
}
.video-card {
background: #0f172a;
border: 1px solid #1f2b47;
border-radius: 18px;
padding: 18px 20px;
margin-top: 18px;
}
.video-title {
font-size: 0.78rem;
text-transform: uppercase;
letter-spacing: 0.18em;
color: #94a3b8;
margin-bottom: 8px;
}
.video-panel video {
border-radius: 12px;
border: 1px solid #1f2b47;
background: #030712;
}
.download-button button {
border-radius: 999px;
border: 1px solid #334155;
background: #1e293b;
color: #f8fafc;
font-size: 0.85rem;
padding: 8px 24px;
}
.download-button button:hover {
border-color: #67e8f9;
color: #67e8f9;
}
.plots-wrap {
margin-top: 18px;
}
.plots-wrap .gr-row {
gap: 16px;
}
.plot-html {
background: #111a2c;
border-radius: 12px;
padding: 10px;
border: 1px solid #1f2b47;
min-height: 320px;
}
.plot-html iframe {
width: 100%;
height: 300px;
border: none;
}
"""
|
|
|
|
@lru_cache(maxsize=1)
def get_dataset_revision(repo_id: str) -> Optional[str]:
    """Return the current commit SHA of the dataset repo ``repo_id`` on the Hub.

    Any failure of the Hub lookup is logged as a warning and yields ``None``
    (callers then pass ``revision=None`` downstream). The result, including a
    ``None`` from a failed lookup, is memoized for the most recent ``repo_id``,
    so a failure is not retried for that repo.
    """
    sha: Optional[str]
    try:
        repo = HfApi(token=HF_TOKEN).repo_info(repo_id=repo_id, repo_type="dataset")
        sha = repo.sha
    except Exception as exc:
        # Best effort: the viewer can still run without a pinned revision.
        logger.warning(f"Could not fetch dataset revision for {repo_id}: {exc}")
        sha = None
    return sha
|
|
|
|
@lru_cache(maxsize=2)
def get_dataset_root(repo_id: str, revision: Optional[str]) -> Path:
    """Download (or reuse) a snapshot of dataset ``repo_id`` and return its local root.

    Files are materialized under ``LOCAL_DATASET_DIR`` via ``snapshot_download``.
    Results are memoized per ``(repo_id, revision)`` pair (up to two entries).

    NOTE(review): every ``repo_id`` shares the same ``LOCAL_DATASET_DIR``, so
    browsing two different repos would mix their files in one directory.
    NOTE(review): newer huggingface_hub releases deprecate/ignore
    ``local_dir_use_symlinks``; it is kept here for older versions — confirm
    the pinned version before removing it.
    """
    download_options = {
        "repo_id": repo_id,
        "repo_type": "dataset",
        "revision": revision,
        "local_dir": LOCAL_DATASET_DIR,
        "local_dir_use_symlinks": False,
        "token": HF_TOKEN,
    }
    snapshot_dir = snapshot_download(**download_options)
    return Path(snapshot_dir)
|
|
|
|
@lru_cache(maxsize=2)
def load_info(repo_id: str, revision: Optional[str]) -> Dict:
    """Parse and return ``meta/info.json`` from the downloaded dataset snapshot.

    Memoized per ``(repo_id, revision)``: every caller receives the same dict
    object, so treat it as read-only.
    """
    meta_file = get_dataset_root(repo_id, revision).joinpath("meta", "info.json")
    return json.loads(meta_file.read_text(encoding="utf-8"))
|
|
|
|
def resolve_path(
    root: Path,
    template: Union[str, Dict[str, str]],
    episode_chunk: int,
    episode_index: int,
) -> Path:
    """Build the local path of an episode file from a metadata path template.

    Args:
        root: Local dataset root directory.
        template: A ``str.format`` template using ``{episode_chunk}`` and
            ``{episode_index}`` placeholders, or a dict of such templates, in
            which case its ``"rgb"`` entry is used.
        episode_chunk: Chunk number substituted into the template.
        episode_index: Episode number substituted into the template.

    Returns:
        ``root`` joined with the formatted template.

    Raises:
        ValueError: If ``template`` is a dict without an ``"rgb"`` entry.
        KeyError: If the template uses a placeholder other than the two above.
    """
    if isinstance(template, dict):
        # Multi-stream metadata (e.g. per-camera video paths): only RGB is shown.
        template = template.get("rgb")
        if template is None:
            raise ValueError("RGB template missing from metadata")
    return root / template.format(episode_chunk=episode_chunk, episode_index=episode_index)
|
|
|
|
@lru_cache(maxsize=64)
def load_episode(repo_id: str, episode_index: int, revision: Optional[str]) -> Dict:
    """Load one episode's trajectories, RGB video path and language instruction.

    Args:
        repo_id: Hub dataset repo id.
        episode_index: Value of ``episode_index`` in ``info["episodes"]``.
        revision: Dataset revision (commit SHA) or ``None``.

    Returns:
        A dict with ``timestamps`` (per-frame ``timestamp_s`` values),
        ``state_df`` (one row per frame, ``<alias>_<metric>`` columns),
        ``rgb_path`` (``Path`` of the episode video; may not exist on disk) and
        ``instruction`` (``str``). Memoized, so callers share the returned
        objects and should not mutate them.

    Raises:
        ValueError: If the episode is not listed in the metadata, or its parquet
            lacks the required columns.
        FileNotFoundError: If the episode's parquet file is missing.
    """
    info = load_info(repo_id, revision)
    root = get_dataset_root(repo_id, revision)
    episode_meta = next(
        (ep for ep in info["episodes"] if ep["episode_index"] == episode_index), None
    )
    if episode_meta is None:
        raise ValueError(f"Episode {episode_index} not found in metadata")

    chunk = episode_meta["episode_chunk"]
    parquet_path = resolve_path(root, info["data_path"], chunk, episode_index)
    if not parquet_path.exists():
        raise FileNotFoundError(f"Parquet file not found: {parquet_path}")

    df = pd.read_parquet(parquet_path)
    timestamps, state_df = build_state_dataframe(df)

    rgb_path = resolve_path(root, info["video_path"], chunk, episode_index)
    if not rgb_path.exists():
        # Not fatal here; log it so a downstream video-player error has a clear cause.
        logger.warning("RGB video not found for episode %s: %s", episode_index, rgb_path)

    return {
        "timestamps": timestamps,
        "state_df": state_df,
        "rgb_path": rgb_path,
        "instruction": _resolve_instruction(episode_meta, df, info),
    }


def _resolve_instruction(episode_meta: Dict, df: pd.DataFrame, info: Dict) -> str:
    """Pick the episode's language instruction from the best available source.

    Priority: a truthy ``language_instruction`` in the episode metadata, then the
    first non-null value of the parquet ``language_instruction`` column, then
    a truthy dataset-level ``task``, then a hard-coded fallback. The result is
    always a ``str`` because it is later passed to ``html.escape``.
    """
    from_meta = episode_meta.get("language_instruction")
    if from_meta:
        return str(from_meta)
    if "language_instruction" in df.columns:
        non_null = df["language_instruction"].dropna()
        if not non_null.empty:
            return str(non_null.iloc[0])
    # `or` also covers an explicit null/empty "task" entry in info.json.
    return str(info.get("task") or "Tape roll to bowl")
|
|
|
|
def build_state_dataframe(
    df: pd.DataFrame,
    joint_name_map: Optional[Dict[str, str]] = None,
    metrics: Optional[List[str]] = None,
) -> Tuple[List[float], pd.DataFrame]:
    """Pivot a long per-joint episode table into one row per frame.

    Args:
        df: Episode parquet contents with ``frame_idx``, ``timestamp_s`` and
            ``joint_name`` columns plus metric columns.
        joint_name_map: Alias -> ``joint_name`` value to extract. Defaults to
            ``JOINT_NAME_MAP``.
        metrics: Metric column names to copy per joint. Defaults to the keys of
            ``METRIC_LABELS``. Metrics absent from ``df`` are skipped.

    Returns:
        ``(timestamps, state_df)``. ``timestamps`` holds the first
        ``timestamp_s`` seen for each distinct ``frame_idx``, in ascending
        frame order. ``state_df`` has one row per such frame (RangeIndex) and
        float ``<alias>_<metric>`` columns, NaN where a joint has no row for
        that frame.

    Raises:
        ValueError: If the timing columns or the ``joint_name`` column are missing.
    """
    if "frame_idx" not in df.columns or "timestamp_s" not in df.columns:
        raise ValueError("Episode parquet missing frame timing information.")
    if "joint_name" not in df.columns:
        raise ValueError("Episode parquet missing joint_name column.")
    if joint_name_map is None:
        joint_name_map = JOINT_NAME_MAP
    if metrics is None:
        metrics = list(METRIC_LABELS)

    # One timestamp per frame, ordered by frame index.
    frame_times = (
        df[["frame_idx", "timestamp_s"]]
        .drop_duplicates("frame_idx")
        .set_index("frame_idx")
        .sort_index()
    )
    frame_indices = frame_times.index.to_list()

    state_df = pd.DataFrame(index=frame_indices)
    for alias, joint_name in joint_name_map.items():
        # NOTE(review): assumes each joint has at most one row per frame;
        # reindex raises on duplicate frame_idx values (e.g. two hands).
        joint_df = (
            df[df["joint_name"] == joint_name]
            .set_index("frame_idx")
            .sort_index()
            .reindex(frame_indices)
        )
        for metric in metrics:
            if metric in joint_df.columns:
                # Both frames share the frame_idx index, so values align by frame.
                state_df[f"{alias}_{metric}"] = joint_df[metric].astype(float)

    state_df.reset_index(drop=True, inplace=True)
    timestamps = frame_times["timestamp_s"].to_list()
    return timestamps, state_df
|
|
|
|
def build_plot_fig(data: Dict, metric: str) -> go.Figure:
    """Build a dark-themed line chart of one metric over time, one trace per joint.

    Args:
        data: Episode payload as returned by ``load_episode`` (its
            ``timestamps`` and ``state_df`` entries are used).
        metric: A key of ``METRIC_LABELS``, e.g. ``"x_cm"``.

    Joints whose ``<alias>_<metric>`` column is missing from ``state_df`` are
    left out of the chart.
    """
    times = data["timestamps"]
    frame = data["state_df"]
    traces = [
        go.Scatter(x=times, y=frame[f"{alias}_{metric}"], mode="lines", name=label)
        for alias, label in JOINT_ALIASES.items()
        if f"{alias}_{metric}" in frame.columns
    ]
    fig = go.Figure(data=traces)
    fig.update_layout(
        margin=dict(l=20, r=20, t=30, b=20),
        height=250,
        template="plotly_dark",
        # Horizontal legend pinned just above the top-right corner of the plot area.
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
        xaxis_title="Time (s)",
        yaxis_title=METRIC_LABELS[metric],
    )
    faint_grid = dict(showgrid=True, gridwidth=0.5, gridcolor="rgba(255,255,255,0.1)")
    fig.update_xaxes(**faint_grid)
    fig.update_yaxes(**faint_grid)
    return fig
|
|
|
|
def build_plot_html(data: Dict, metric: str) -> str:
    """Render one metric's plot as a self-contained ``<iframe>`` HTML snippet.

    The snippet is displayed through ``gr.HTML``, which inserts it as markup.
    Browsers do not execute ``<script>`` tags inserted that way, so the bare
    div+script produced by ``pio.to_html(full_html=False)`` renders as an empty
    box. Embedding the figure in an iframe's ``srcdoc`` gives it its own
    document in which plotly.js (loaded from the CDN) runs normally. The
    ``.plot-html iframe`` rule in CUSTOM_CSS sizes the frame; the width/height
    attributes are only a fallback.

    Args:
        data: Episode payload as returned by ``load_episode``.
        metric: A key of ``METRIC_LABELS``.

    Returns:
        An HTML string containing a single ``<iframe>`` element.
    """
    fig_div = pio.to_html(
        build_plot_fig(data, metric), include_plotlyjs="cdn", full_html=False
    )
    document = (
        '<!DOCTYPE html><html><head><meta charset="utf-8">'
        # Match the .plot-html card background and drop the default body margin.
        "<style>html,body{margin:0;padding:0;background:#111a2c;}</style>"
        f"</head><body>{fig_div}</body></html>"
    )
    # srcdoc is an attribute value, so the whole document must be HTML-escaped.
    return (
        f'<iframe srcdoc="{html.escape(document, quote=True)}" '
        'width="100%" height="300"></iframe>'
    )
|
|
|
|
def format_episode_label(idx: int) -> str:
    """Return the selector label for episode ``idx``, zero-padded to two digits.

    Example: ``3`` -> ``"Episode 03"``; wider numbers are not truncated.
    """
    return "Episode {:02d}".format(idx)
|
|
|
|
def parse_episode_label(label: str) -> int:
    """Recover the episode index from a selector label, e.g. ``"Episode 03"`` -> ``3``.

    Every occurrence of the word ``"Episode"`` is removed and the surrounding
    whitespace stripped; a remainder that is not an integer raises ``ValueError``.
    """
    digits = "".join(label.split("Episode")).strip()
    return int(digits)
|
|
|
|
def format_instruction_html(text: str) -> str:
    """Wrap an instruction string in the styled "Language Instruction" card.

    ``text`` is HTML-escaped so dataset content cannot inject markup.
    """
    escaped = html.escape(text)
    fragments = [
        '<div class="instruction-card">',
        '<p class="instruction-label">Language Instruction</p>',
        '<p class="instruction-text">' + escaped + "</p>",
        "</div>",
    ]
    return "".join(fragments)
|
|
|
|
def _build_stats_html(info: Dict, num_episodes: int) -> str:
    """Render the sidebar summary card: total frames, episode count and fps."""
    # Episodes without a "num_frames" entry count as 0; fps falls back to 30.
    total_frames = sum(ep.get("num_frames", 0) for ep in info["episodes"])
    fps = info.get("fps", 30.0)
    return f"""
<div class="stats-card">
  <ul>
    <li><span>Number of samples/frames:</span> {total_frames:,}</li>
    <li><span>Number of episodes:</span> {num_episodes}</li>
    <li><span>Frames per second:</span> {fps:.1f}</li>
  </ul>
</div>
"""


def _build_theme():
    """Return the dark Soft theme whose colors match CUSTOM_CSS."""
    return gr.themes.Soft(
        primary_hue="cyan", secondary_hue="blue", neutral_hue="slate"
    ).set(
        body_background_fill="#0c1424",
        body_text_color="#f8fafc",
        block_background_fill="#111a2c",
        block_title_text_color="#f8fafc",
        input_background_fill="#151f33",
        border_color_primary="#1f2b47",
        shadow_drop="none",
    )


def build_interface():
    """Build the Gradio app for browsing episodes of ``DEFAULT_DATASET_ID``.

    Dataset metadata and the first episode are loaded eagerly so the page
    renders with content. The episode selector then refreshes the instruction
    card, video, download button and the trajectory plots.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.

    Raises:
        RuntimeError: If the dataset metadata lists no episodes.
    """
    revision = get_dataset_revision(DEFAULT_DATASET_ID)
    info = load_info(DEFAULT_DATASET_ID, revision)
    episode_indices = sorted(ep["episode_index"] for ep in info["episodes"])
    if not episode_indices:
        raise RuntimeError("No episodes found in dataset metadata.")

    default_idx = episode_indices[0]
    default_data = load_episode(DEFAULT_DATASET_ID, default_idx, revision)
    default_video = str(default_data["rgb_path"])
    # Only the metrics actually laid out in PLOT_GRID need an initial figure.
    default_figs = {metric: build_plot_html(default_data, metric) for metric in PLOT_ORDER}
    stats_html = _build_stats_html(info, len(episode_indices))

    with gr.Blocks(theme=_build_theme(), css=CUSTOM_CSS) as demo:
        gr.Markdown("# Humanoid Robots Hand Pose Viewer")
        gr.Markdown(
            "Visualize RGB + 6DoF hand trajectories for all Moving_Mini tasks "
            "(humanoid-robots-training-dataset)."
        )

        with gr.Row(equal_height=True):
            with gr.Column(scale=1, min_width=260, elem_classes=["side-panel"]):
                gr.HTML(stats_html)
                gr.HTML('<div class="episodes-title">Episodes</div>')
                episode_radio = gr.Radio(
                    choices=[format_episode_label(i) for i in episode_indices],
                    value=format_episode_label(default_idx),
                    label="Episodes",
                    elem_classes=["episode-list"],
                )
            with gr.Column(scale=2, min_width=640, elem_classes=["main-panel"]):
                instruction_box = gr.HTML(
                    format_instruction_html(default_data["instruction"]),
                    label="Language Instruction",
                )
                with gr.Column(elem_classes=["video-card"]):
                    gr.HTML('<div class="video-title">RGB</div>')
                    video = gr.Video(
                        height=360,
                        value=default_video,
                        elem_classes=["video-panel"],
                        show_label=False,
                        show_download_button=False,
                    )
                    download_button = gr.DownloadButton(
                        label="Download",
                        value=default_video,
                        elem_classes=["download-button"],
                    )

                plot_outputs = []
                gr.Markdown("### Joint trajectories", elem_classes=["plots-title"])
                with gr.Column(elem_classes=["plots-wrap"]):
                    # Components are created row by row, i.e. in PLOT_ORDER.
                    for row in PLOT_GRID:
                        with gr.Row():
                            for metric in row:
                                plot = gr.HTML(
                                    value=default_figs[metric], elem_classes=["plot-html"]
                                )
                                plot_outputs.append(plot)

        outputs = [instruction_box, video, download_button] + plot_outputs

        def load_episode_payload(label: str):
            """Return new values for ``outputs`` (same order) for the selected episode."""
            idx = parse_episode_label(label)
            data = load_episode(DEFAULT_DATASET_ID, idx, revision)
            video_path = str(data["rgb_path"])
            return [
                format_instruction_html(data["instruction"]),
                video_path,
                # Component.update() was removed in Gradio 4 (the first release
                # with DownloadButton); gr.update() is the supported way to
                # patch a component's properties from an event handler.
                gr.update(value=video_path),
                *(build_plot_html(data, metric) for metric in PLOT_ORDER),
            ]

        episode_radio.change(fn=load_episode_payload, inputs=episode_radio, outputs=outputs)

    return demo
|
|
|
|
|
|
def main():
    """Build the viewer, enable request queuing and serve it with the API page hidden."""
    app = build_interface()
    # Blocks.queue() configures the app in place (and returns it), so a separate
    # statement is equivalent to the chained form.
    app.queue()
    app.launch(show_api=False)


if __name__ == "__main__":
    main()
|
|