import os
import sys
import cv2
import math
import json
import torch
import gradio as gr
import numpy as np
from PIL import Image
from PIL import ImageOps
from pathlib import Path
import multiprocessing as mp

from vitra.utils.data_utils import resize_short_side_to_target, load_normalizer, recon_traj
from vitra.utils.config_utils import load_config
from scipy.spatial.transform import Rotation as R
import spaces

repo_root = Path(__file__).parent
sys.path.insert(0, str(repo_root))

from visualization.visualize_core import HandVisualizer, normalize_camera_intrinsics, save_to_video, Renderer, process_single_hand_labels
from visualization.visualize_core import Config as HandConfig

from inference_human_prediction import (
    get_state,
    euler_traj_to_rotmat_traj,
)

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Global services, populated once by initialize_services()
vla_model = None
vla_normalizer = None
hand_reconstructor = None
visualizer = None
hand_config = None
app_config = None


def vla_predict(model, normalizer, image, instruction, state, state_mask,
                action_mask, fov, num_ddim_steps, cfg_scale, sample_times):
    """
    VLA prediction function that runs on GPU.
    The model is already loaded and moved to CUDA in the main process.
    """
    from vitra.datasets.human_dataset import pad_state_human, pad_action
    from vitra.datasets.dataset_utils import ActionFeature, StateFeature

    # Normalize the current state, then pad state and action masks to the unified feature layout
    norm_state = normalizer.normalize_state(state.copy())

    unified_action_dim = ActionFeature.ALL_FEATURES[1]
    unified_state_dim = StateFeature.ALL_FEATURES[1]

    unified_state, unified_state_mask = pad_state_human(
        state=norm_state,
        state_mask=state_mask,
        action_dim=normalizer.action_mean.shape[0],
        state_dim=normalizer.state_mean.shape[0],
        unified_state_dim=unified_state_dim,
    )
    _, unified_action_mask = pad_action(
        actions=None,
        action_mask=action_mask.copy(),
        action_dim=normalizer.action_mean.shape[0],
        unified_action_dim=unified_action_dim
    )

    # Move the inputs (and the model) onto the GPU
    device = torch.device('cuda')
    fov = torch.from_numpy(fov).unsqueeze(0).to(device)
    unified_state = unified_state.unsqueeze(0).to(device)
    unified_state_mask = unified_state_mask.unsqueeze(0).to(device)
    unified_action_mask = unified_action_mask.unsqueeze(0).to(device)

    model = model.to(device)

    norm_action = model.predict_action(
        image=image,
        instruction=instruction,
        current_state=unified_state,
        current_state_mask=unified_state_mask,
        action_mask_torch=unified_action_mask,
        num_ddim_steps=num_ddim_steps,
        cfg_scale=cfg_scale,
        fov=fov,
        sample_times=sample_times,
    )

    # Keep only the 102-dim hand action and undo the normalization
    norm_action = norm_action[:, :, :102]
    unnorm_action = normalizer.unnormalize_action(norm_action)

    if isinstance(unnorm_action, torch.Tensor):
        unnorm_action_np = unnorm_action.cpu().numpy()
    else:
        unnorm_action_np = np.array(unnorm_action)

    return unnorm_action_np
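
# As consumed by generate_prediction() below, the returned array is laid out as
# (sample_times, chunk_size, 102): columns 0:51 are the left hand and 51:102 the right hand,
# each 51-dim (3 translation + 3 global-orientation Euler angles + 45 finger-joint Euler angles).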


class GradioConfig:
    """Configuration for the Gradio app."""

    def __init__(self):
        # VLA model
        self.config_path = 'microsoft/VITRA-VLA-3B'
        self.model_path = None
        self.statistics_path = None

        # Hand reconstruction
        self.hawor_model_path = 'arnoldland/HAWOR'
        self.detector_path = './weights/hawor/external/detector.pt'
        self.moge_model_name = 'Ruicheng/moge-2-vitl'
        self.mano_path = './weights/mano'

        # Visualization
        self.fps = 8
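
# To run from local checkpoints instead of the Hub defaults, model_path and statistics_path
# could be pointed at local files (hypothetical paths shown), e.g.
#   self.model_path = './checkpoints/vitra_vla.pt'
#   self.statistics_path = './checkpoints/statistics.json'
# initialize_services() forwards them as configs['model_load_path'] and configs['statistics_path'].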


def initialize_services():
    """Initialize all models once at startup"""
    global vla_model, vla_normalizer, hand_reconstructor, visualizer, hand_config, app_config

    if vla_model is not None:
        return "Services already initialized"

    try:
        app_config = GradioConfig()

        hf_token = os.environ.get('HF_TOKEN', None)
        if hf_token:
            from huggingface_hub import login
            login(token=hf_token)
            print("Logged in to HuggingFace Hub")

        print("Loading VLA model...")
        from vitra.models import load_model
        from vitra.utils.data_utils import load_normalizer

        configs = load_config(app_config.config_path)
        if app_config.model_path is not None:
            configs['model_load_path'] = app_config.model_path
        if app_config.statistics_path is not None:
            configs['statistics_path'] = app_config.statistics_path

        globals()['vla_model'] = load_model(configs).cuda()
        globals()['vla_model'].eval()
        globals()['vla_normalizer'] = load_normalizer(configs)
        print("VLA model loaded")

        print("Loading Hand Reconstructor...")
        from data.tools.hand_recon_core import Config, HandReconstructor

        class ArgsObj:
            pass
        args_obj = ArgsObj()
        args_obj.hawor_model_path = app_config.hawor_model_path
        args_obj.detector_path = app_config.detector_path
        args_obj.moge_model_name = app_config.moge_model_name
        args_obj.mano_path = app_config.mano_path

        recon_config = Config(args_obj)
        globals()['hand_reconstructor'] = HandReconstructor(config=recon_config, device='cuda')
        print("Hand Reconstructor loaded")

        print("Loading Visualizer...")
        globals()['hand_config'] = HandConfig(app_config)
        globals()['hand_config'].FPS = app_config.fps
        globals()['visualizer'] = HandVisualizer(globals()['hand_config'], render_gradual_traj=False)
        globals()['visualizer'].mano = globals()['visualizer'].mano.cuda()
        print("Visualizer loaded")

        return "✅ All services initialized successfully!"

    except Exception as e:
        import traceback
        return f"❌ Failed to initialize services: {str(e)}\n{traceback.format_exc()}"


def validate_image_dimensions(image):
    """Validate image dimensions before GPU allocation.

    Returns (is_valid, message)
    """
    if image is None:
        return True, ""

    if isinstance(image, np.ndarray):
        img_pil = Image.fromarray(image)
    else:
        img_pil = image

    width, height = img_pil.size
    if width < height:
        error_msg = f"❌ Please upload a landscape image (width ≥ height).\nCurrent image: {width}x{height} (portrait orientation)"
        return False, error_msg

    return True, ""


def validate_and_process_wrapper(image, session_state, progress=gr.Progress()):
    """Wrapper function to validate image before GPU allocation"""
    if image is None:
        return ("Waiting for image upload...",
                gr.update(interactive=False),
                None,
                False,
                False,
                session_state)

    is_valid, error_msg = validate_image_dimensions(image)
    if not is_valid:
        return (error_msg,
                gr.update(interactive=False),
                None,
                False,
                False,
                session_state)

    return process_image_upload(image, session_state, progress)


# spaces.GPU requests GPU hardware from the Spaces runtime for this call (duration in seconds)
@spaces.GPU(duration=120)
def process_image_upload(image, session_state, progress=gr.Progress()):
    """Process uploaded image and run hand reconstruction"""
    global hand_reconstructor
    if torch.cuda.is_available():
        print("CUDA is available for image processing")
    else:
        print("CUDA is NOT available for image processing")

    # Wait up to 60 s for the freshly allocated GPU to become usable
    import time
    start_time = time.time()
    while time.time() - start_time < 60:
        try:
            if torch.cuda.is_available():
                torch.zeros(1).cuda()
                break
        except Exception:
            time.sleep(2)

    if hand_reconstructor is None:
        return ("Services not initialized. Please wait for initialization to complete.",
                gr.update(interactive=False),
                None,
                False,
                False,
                session_state)

    try:
        progress(0, desc="Preparing image...")

        if isinstance(image, np.ndarray):
            img_pil = Image.fromarray(image)
        else:
            img_pil = image

        session_state['current_image'] = img_pil

        progress(0.2, desc="Running hand reconstruction...")

        image_np = np.array(img_pil)
        image_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)

        image_list = [image_bgr]
        hand_data = hand_reconstructor.recon(image_list)

        session_state['current_hand_data'] = hand_data

        progress(1.0, desc="Hand reconstruction complete!")

        has_left = 'left' in hand_data and len(hand_data['left']) > 0
        has_right = 'right' in hand_data and len(hand_data['right']) > 0

        info_msg = "✅ Hand reconstruction complete!\n"
        info_msg += "Detected hands: "
        if has_left and has_right:
            info_msg += "Left ✓, Right ✓"
        elif has_left:
            info_msg += "Left ✓, Right ✗"
        elif has_right:
            info_msg += "Left ✗, Right ✓"
        else:
            info_msg += "None detected"

        session_state['detected_left'] = has_left
        session_state['detected_right'] = has_right

        return (info_msg,
                gr.update(interactive=True),
                hand_data,
                has_left,
                has_right,
                session_state)

    except Exception as e:
        import traceback
        error_msg = f"❌ Hand reconstruction failed: {str(e)}\n{traceback.format_exc()}"

        session_state['detected_left'] = False
        session_state['detected_right'] = False

        return (error_msg,
                gr.update(interactive=True),
                None,
                False,
                False,
                session_state)


def update_checkboxes(has_left, has_right):
    """Update checkbox states based on detected hands (no progress bar)"""
    left_checkbox_update = gr.update(
        value=has_left,
        interactive=has_left,
        elem_classes="disabled-checkbox" if not has_left else ""
    )
    right_checkbox_update = gr.update(
        value=has_right,
        interactive=has_right,
        elem_classes="disabled-checkbox" if not has_right else ""
    )

    left_instruction_update = gr.update(
        interactive=has_left,
        elem_classes="disabled-textbox" if not has_left else ""
    )
    right_instruction_update = gr.update(
        interactive=has_right,
        elem_classes="disabled-textbox" if not has_right else ""
    )

    return left_checkbox_update, right_checkbox_update, left_instruction_update, right_instruction_update


def update_instruction_interactivity(use_left, use_right):
    """Update instruction textbox interactivity based on checkbox states"""
    left_update = gr.update(
        interactive=use_left,
        elem_classes="disabled-textbox" if not use_left else ""
    )
    right_update = gr.update(
        interactive=use_right,
        elem_classes="disabled-textbox" if not use_right else ""
    )
    return left_update, right_update


def update_final_instruction(left_instruction, right_instruction, use_left, use_right):
    """Update final instruction based on left/right inputs and checkbox states"""
    left_text = left_instruction if use_left else "None."
    right_text = right_instruction if use_right else "None."

    final = f"Left hand: {left_text} Right hand: {right_text}"

    styled_output = f"""<div style='padding: 12px; background-color: #f0f7ff; border-left: 4px solid #4A90E2; border-radius: 4px; margin-top: 10px;'>
    <strong style='color: #2c5282;'>📝 Final Instruction:</strong><br>
    <span style='color: #1a365d; font-size: 14px;'>{final}</span>
    </div>"""

    return gr.update(value=styled_output), final
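
# For example, update_final_instruction("Pick up the cup.", "Open the drawer.", True, False)
# produces the combined string "Left hand: Pick up the cup. Right hand: None."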


def parse_instruction(instruction_text):
    """Parse combined instruction into left and right parts"""
    import re

    left_match = re.search(r'Left(?:\s+hand)?:\s*([^.]*(?:\.[^LR]*)*)(?=Right|$)', instruction_text, re.IGNORECASE)
    right_match = re.search(r'Right(?:\s+hand)?:\s*(.+?)$', instruction_text, re.IGNORECASE)

    left_text = left_match.group(1).strip() if left_match else "None."
    right_text = right_match.group(1).strip() if right_match else "None."

    return left_text, right_text
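
# Example, using the instruction format shared by the bundled examples:
#   parse_instruction("Left hand: Put the trash into the garbage. Right hand: None.")
#   -> ("Put the trash into the garbage.", "None.")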


@spaces.GPU(duration=120)
def generate_prediction(instruction, use_left, use_right, sample_times, num_ddim_steps, cfg_scale, hand_data, image, progress=gr.Progress()):
    """Generate hand motion prediction and visualization"""
    global vla_model, vla_normalizer, visualizer, hand_config, app_config

    # Wait up to 60 s for the freshly allocated GPU to become usable
    import time
    start_time = time.time()
    while time.time() - start_time < 60:
        try:
            if torch.cuda.is_available():
                torch.zeros(1).cuda()
                break
        except Exception:
            time.sleep(2)

    if hand_data is None:
        return None, "Please upload an image and wait for hand reconstruction first"

    if not use_left and not use_right:
        return None, "Please select at least one hand (left or right)"

    try:
        progress(0, desc="Preparing data...")

        if image is None:
            return None, "Image not found. Please upload an image first."

        ori_w, ori_h = image.size

        try:
            image = ImageOps.exif_transpose(image)
        except Exception:
            pass

        image_resized = resize_short_side_to_target(image, target=224)
        w, h = image_resized.size

        current_state_left = None
        current_state_right = None
        beta_left = None
        beta_right = None

        progress(0.1, desc="Extracting hand states...")

        if use_right:
            current_state_right, beta_right, fov_x, _ = get_state(hand_data, hand_side='right')
        if use_left:
            current_state_left, beta_left, fov_x, _ = get_state(hand_data, hand_side='left')

        # Convert the horizontal FOV from degrees to radians, recover the focal length,
        # then derive the vertical FOV and the pinhole intrinsics of the resized image
        fov_x = fov_x * np.pi / 180
        f_ori = ori_w / np.tan(fov_x / 2) / 2
        fov_y = 2 * np.arctan(ori_h / (2 * f_ori))

        f = w / np.tan(fov_x / 2) / 2
        intrinsics = np.array([
            [f, 0, w/2],
            [0, f, h/2],
            [0, 0, 1]
        ])
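
        # Illustrative numbers (not from the app): with fov_x = 90° and a resized frame of
        # 398x224, f = 398 / (2 * tan(45°)) = 199 px and fov_y = 2 * atan(224 / 398) ≈ 58.7°.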

        if current_state_left is None and current_state_right is None:
            return None, "No valid hand states found"

        state_left = current_state_left if use_left else np.zeros_like(current_state_right)
        beta_left = beta_left if use_left else np.zeros_like(beta_right)
        state_right = current_state_right if use_right else np.zeros_like(current_state_left)
        beta_right = beta_right if use_right else np.zeros_like(beta_left)

        state = np.concatenate([state_left, beta_left, state_right, beta_right], axis=0)
        state_mask = np.array([use_left, use_right], dtype=bool)

        configs = load_config(app_config.config_path)
        chunk_size = configs.get('fwd_pred_next_n', 16)
        action_mask = np.tile(np.array([[use_left, use_right]], dtype=bool), (chunk_size, 1))

        fov = np.array([fov_x, fov_y], dtype=np.float32)
        image_resized_np = np.array(image_resized)

        progress(0.3, desc="Running VLA inference...")

        unnorm_action = vla_predict(
            model=vla_model,
            normalizer=vla_normalizer,
            image=image_resized_np,
            instruction=instruction,
            state=state,
            state_mask=state_mask,
            action_mask=action_mask,
            fov=fov,
            num_ddim_steps=num_ddim_steps,
            cfg_scale=cfg_scale,
            sample_times=sample_times,
        )

        progress(0.6, desc="Visualizing predictions...")

        fx_exo = intrinsics[0, 0]
        fy_exo = intrinsics[1, 1]
        renderer = Renderer(w, h, (fx_exo, fy_exo), 'cuda')

        T = chunk_size + 1
        traj_right_list = np.zeros((sample_times, T, 51), dtype=np.float32)
        traj_left_list = np.zeros((sample_times, T, 51), dtype=np.float32)

        traj_mask = np.tile(np.array([[use_left, use_right]], dtype=bool), (T, 1))
        left_hand_mask = traj_mask[:, 0]
        right_hand_mask = traj_mask[:, 1]
        hand_mask = (left_hand_mask, right_hand_mask)

        all_rendered_frames = []

        for i in range(sample_times):
            progress(0.6 + 0.3 * (i / sample_times), desc=f"Rendering sample {i+1}/{sample_times}...")

            traj_right = traj_right_list[i]
            traj_left = traj_left_list[i]

            # Reconstruct absolute trajectories from the current state and relative actions
            if use_left:
                traj_left = recon_traj(
                    state=state_left,
                    rel_action=unnorm_action[i, :, 0:51],
                )
            if use_right:
                traj_right = recon_traj(
                    state=state_right,
                    rel_action=unnorm_action[i, :, 51:102],
                )

            left_hand_labels = {
                'transl_worldspace': traj_left[:, 0:3],
                'global_orient_worldspace': R.from_euler('xyz', traj_left[:, 3:6]).as_matrix(),
                'hand_pose': euler_traj_to_rotmat_traj(traj_left[:, 6:51], T),
                'beta': beta_left,
            }
            right_hand_labels = {
                'transl_worldspace': traj_right[:, 0:3],
                'global_orient_worldspace': R.from_euler('xyz', traj_right[:, 3:6]).as_matrix(),
                'hand_pose': euler_traj_to_rotmat_traj(traj_right[:, 6:51], T),
                'beta': beta_right,
            }

            verts_left_worldspace, _ = process_single_hand_labels(left_hand_labels, left_hand_mask, visualizer.mano, is_left=True)
            verts_right_worldspace, _ = process_single_hand_labels(right_hand_labels, right_hand_mask, visualizer.mano, is_left=False)

            hand_traj_worldspace = (verts_left_worldspace, verts_right_worldspace)

            R_w2c = np.broadcast_to(np.eye(3), (T, 3, 3)).copy()
            t_w2c = np.zeros((T, 3, 1), dtype=np.float32)
            extrinsics = (R_w2c, t_w2c)

            image_bgr = image_resized_np[..., ::-1]
            resize_video_frames = [image_bgr] * T
            save_frames = visualizer._render_hand_trajectory(
                resize_video_frames,
                hand_traj_worldspace,
                hand_mask,
                extrinsics,
                renderer,
                mode='first'
            )

            all_rendered_frames.append(save_frames)

        progress(0.95, desc="Creating output video...")

        num_frames = len(all_rendered_frames[0])
        grid_cols = math.ceil(math.sqrt(sample_times))
        grid_rows = math.ceil(sample_times / grid_cols)
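        # e.g. sample_times = 4 gives a 2x2 grid; 5 gives 3 columns x 2 rows, and the unused
        # cell is filled with a black frame by the padding loop below.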
        combined_frames = []
        for frame_idx in range(num_frames):
            sample_frames = [all_rendered_frames[i][frame_idx] for i in range(sample_times)]

            while len(sample_frames) < grid_rows * grid_cols:
                black_frame = np.zeros_like(sample_frames[0])
                sample_frames.append(black_frame)

            rows = []
            for row_idx in range(grid_rows):
                row_frames = sample_frames[row_idx * grid_cols:(row_idx + 1) * grid_cols]
                row_concat = np.concatenate(row_frames, axis=1)
                rows.append(row_concat)

            combined_frame = np.concatenate(rows, axis=0)
            combined_frames.append(combined_frame)

        output_dir = Path("./temp_gradio/outputs")
        output_dir.mkdir(parents=True, exist_ok=True)
        output_path = output_dir / "prediction.mp4"
        save_to_video(combined_frames, str(output_path), fps=hand_config.FPS)

        progress(1.0, desc="Complete!")

        return str(output_path), f"✅ Generated {sample_times} prediction samples successfully!"

    except Exception as e:
        import traceback
        error_msg = f"❌ Prediction failed: {str(e)}\n{traceback.format_exc()}"
        return None, error_msg


def load_examples():
    """Automatically load all image examples from the examples folder"""
    examples_dir = Path(__file__).parent / "examples"

    default_instructions = {
        "0001.jpg": "Left hand: Put the trash into the garbage. Right hand: None.",
        "0002.jpg": "Left hand: None. Right hand: Pick up the picture of Michael Jackson.",
        "0003.png": "Left hand: None. Right hand: Pick up the metal water cup.",
        "0004.jpg": "Left hand: Squeeze the dish sponge. Right hand: None.",
        "0005.jpg": "Left hand: None. Right hand: Cut the meat with the knife.",
        "0006.jpg": "Left hand: Open the closet door. Right hand: None.",
        "0007.jpg": "Left hand: None. Right hand: Cut the paper with the scissors.",
        "0008.jpg": "Left hand: Wipe the countertop with the cloth. Right hand: None.",
        "0009.jpg": "Left hand: None. Right hand: Open the cabinet door.",
        "0010.png": "Left hand: None. Right hand: Turn on the faucet.",
        "0011.jpg": "Left hand: Put the drink bottle into the trash can. Right hand: None.",
        "0012.jpg": "Left hand: None. Right hand: Pick up the gray cup from the cabinet.",
        "0013.jpg": "Left hand: None. Right hand: Take the milk bottle out of the fridge.",
        "0014.jpg": "Left hand: None. Right hand: Pick up the balloon.",
        "0015.jpg": "Left hand: None. Right hand: Pick up the picture with the smaller red heart.",
        "0016.jpg": "Left hand: None. Right hand: Pick up the picture with \"Cat\".",
        "0017.jpg": "Left hand: None. Right hand: Pick up the picture of the Statue of Liberty.",
        "0018.jpg": "Left hand: None. Right hand: Pick up the picture of the two people.",
    }

    examples_images = []
    instructions_map = {}

    if examples_dir.exists():
        image_files = sorted([f for f in examples_dir.iterdir()
                              if f.suffix.lower() in ['.jpg', '.jpeg', '.png']])

        for img_path in image_files:
            img_path_str = str(img_path)
            instruction = default_instructions.get(
                img_path.name,
                "Left hand: Perform the action. Right hand: None."
            )

            examples_images.append([img_path_str])
            instructions_map[img_path_str] = instruction

    return examples_images, instructions_map


def get_instruction_for_image(image_path, instructions_map):
    """Get the instruction for a given image path"""
    if image_path is None:
        return gr.update()

    instruction = instructions_map.get(str(image_path), "")
    return instruction


def create_gradio_interface():
    """Create Gradio interface"""

    with gr.Blocks(delete_cache=(600, 600), title="3D Hand Motion Prediction with VITRA") as demo:

        gr.HTML("""
        <style>
        .disabled-checkbox {
            opacity: 0.5 !important;
            pointer-events: none !important;
        }
        .disabled-textbox textarea {
            background-color: #f5f5f5 !important;
            color: #9e9e9e !important;
            cursor: not-allowed !important;
        }
        </style>
        """)

        gr.HTML("""
        <div align="center">
            <h1> 🤖 Hand Action Prediction with <a href="https://microsoft.github.io/VITRA/" target="_blank" style="text-decoration: underline; font-weight: bold; color: #4A90E2;">VITRA</a> <a title="Github" href="https://github.com/microsoft/VITRA" target="_blank" rel="noopener noreferrer" style="display: inline-block;"> <img src="https://img.shields.io/github/stars/microsoft/VITRA?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars"> </a> </h1>
        </div>

        <div style="line-height: 1.8;">
        <br>
        <p style="font-size: 16px;">Upload a <strong style="color: #7C4DFF;">landscape</strong>, <strong style="color: #7C4DFF;">egocentric (first-person)</strong> image containing hand(s) and provide instructions to predict future 3D hand trajectories.</p>

        <h3>🌟 Steps:</h3>
        <ol>
            <li>Upload a landscape-view image containing hand(s).</li>
            <li>Enter text instructions describing the desired task.</li>
            <li>Configure advanced settings (optional) and click "Generate 3D Hand Trajectory".</li>
        </ol>

        <h3>💡 Tips:</h3>
        <ul>
            <li><strong>Use Left/Right Hand</strong>: Select which hand to predict based on what's detected and what you want to predict.</li>
            <li><strong>Instruction</strong>: Provide clear and specific imperative instructions separately for the left and right hands, and enter them in the corresponding fields. If the results are unsatisfactory, <strong style="color: #7C4DFF;">try providing more detailed instructions</strong> (e.g., color, orientation, etc.).</li>
            <li>For best inference quality, it is recommended to <strong style="color: #7C4DFF;">capture landscape-view images from a camera height close to that of a human head</strong>. Highly unusual or distorted hand poses/positions may cause inference failures.</li>
            <li>Note that each generation produces only a single action chunk starting from the current state, which <strong style="color: #7C4DFF;">does not necessarily complete the entire task</strong>. Executing an entire chunk in one step may lead to reduced precision.</li>
        </ul>

        </div>

        <hr style='border: none; border-top: 1px solid #e0e0e0; margin: 20px 0;'>
        """)

        with gr.Row():
            with gr.Column(scale=1):
                gr.HTML("""
                <div style='background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 15px; border-radius: 8px; margin-bottom: 15px;'>
                    <h3 style='color: white; margin: 0; text-align: center;'>📄 Input</h3>
                </div>
                """)

                input_image = gr.Image(
                    label="🖼️ Upload Image with Hands",
                    type="pil",
                    height=300,
                )

                recon_status = gr.Textbox(
                    label="🔍 Hand Reconstruction Status",
                    value="⏳ Waiting for image upload...",
                    interactive=False,
                    lines=2,
                    container=True
                )

                gr.Markdown("---")
                gr.HTML("""
                <div style='background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 15px; border-radius: 8px; margin-bottom: 15px;'>
                    <h3 style='color: white; margin: 0; text-align: center;'>⚙️ Prediction Settings</h3>
                </div>
                """)
                gr.HTML("""
                <div style='padding: 8px; background-color: #e8eaf6; border-left: 4px solid #5c6bc0; border-radius: 4px; margin-bottom: 10px;'>
                    <strong style='color: #3949ab;'>👋 Select Hands:</strong>
                </div>
                """)
                with gr.Row():
                    use_left = gr.Checkbox(label="Use Left Hand", value=True)
                    use_right = gr.Checkbox(label="Use Right Hand", value=True)

                gr.HTML("""
                <div style='padding: 8px; background-color: #e8eaf6; border-left: 4px solid #5c6bc0; border-radius: 4px; margin: 15px 0 10px 0;'>
                    <strong style='color: #3949ab;'>✍️ Instructions:</strong>
                </div>
                """)
                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            gr.HTML("<div style='display: flex; align-items: center; min-height: 40px; padding-right: 2px;'><span style='font-weight: 600; color: #5c6bc0; white-space: nowrap;'>Left hand:</span></div>")
                            left_instruction = gr.Textbox(
                                label="",
                                value="Put the trash into the garbage.",
                                lines=1,
                                max_lines=5,
                                placeholder="Describe left hand action...",
                                show_label=False,
                                interactive=True,
                                scale=3
                            )
                    with gr.Column():
                        with gr.Row():
                            gr.HTML("<div style='display: flex; align-items: center; min-height: 40px; padding-right: 2px;'><span style='font-weight: 600; color: #5c6bc0; white-space: nowrap;'>Right hand:</span></div>")
                            right_instruction = gr.Textbox(
                                label="",
                                value="None.",
                                lines=1,
                                max_lines=5,
                                placeholder="Describe right hand action...",
                                show_label=False,
                                interactive=True,
                                scale=3
                            )

                final_instruction = gr.HTML(
                    value="""<div style='padding: 12px; background-color: #f0f7ff; border-left: 4px solid #4A90E2; border-radius: 4px; margin-top: 10px;'>
                    <strong style='color: #2c5282;'>📝 Final Instruction:</strong><br>
                    <span style='color: #1a365d; font-size: 14px;'>Left hand: Put the trash into the garbage. Right hand: None.</span>
                    </div>""",
                    show_label=False
                )
                final_instruction_text = gr.State(value="Left hand: Put the trash into the garbage. Right hand: None.")

                with gr.Accordion("🔧 Advanced Settings", open=False):
                    sample_times = gr.Slider(
                        minimum=1,
                        maximum=9,
                        value=4,
                        step=1,
                        label="Number of Samples",
                        info="Multiple samples show different possible trajectories."
                    )
                    num_ddim_steps = gr.Slider(
                        minimum=1,
                        maximum=50,
                        value=10,
                        step=5,
                        label="DDIM Steps",
                        info="DDIM steps of the diffusion model. 10 is usually sufficient."
                    )
                    cfg_scale = gr.Slider(
                        minimum=1.0,
                        maximum=15.0,
                        value=5.0,
                        step=0.5,
                        label="CFG Scale",
                        info="Classifier-free guidance scale of the diffusion model."
                    )

                generate_btn = gr.Button("🎬 Generate 3D Hand Trajectory", variant="primary", size="lg")

                hand_data = gr.State(value=None)
                detected_left = gr.State(value=False)
                detected_right = gr.State(value=False)

                session_state = gr.State(value={})

            with gr.Column(scale=1):
                gr.HTML("""
                <div style='background: linear-gradient(135deg, #4facfe 0%, #00f2fe 100%); padding: 15px; border-radius: 8px; margin-bottom: 15px;'>
                    <h3 style='color: white; margin: 0; text-align: center;'>🎬 Output</h3>
                </div>
                """)

                output_video = gr.Video(
                    label="🎬 Predicted Hand Motion",
                    height=500,
                    autoplay=True
                )

                gen_status = gr.Textbox(
                    label="📊 Generation Status",
                    value="",
                    interactive=False,
                    lines=2
                )

        gr.Markdown("---")
        gr.HTML("""
        <div style='background: linear-gradient(135deg, #89f7fe 0%, #66a6ff 100%); padding: 15px; border-radius: 8px; margin: 20px 0 10px 0;'>
            <h3 style='color: white; margin: 0; text-align: center;'>📋 Examples</h3>
        </div>
        """)
        gr.HTML("""
        <div style='padding: 10px; background-color: #e7f3ff; border-left: 4px solid #2196F3; border-radius: 4px; margin-bottom: 15px;'>
            <span style='color: #1565c0;'>👆 Click any example below to load the image and instruction</span>
        </div>
        """)

        examples_images, instructions_map = load_examples()

        example_gallery = gr.Gallery(
            value=[img[0] for img in examples_images],
            label="",
            columns=6,
            height="450",
            object_fit="contain",
            show_label=False
        )

        def load_example_from_gallery(evt: gr.SelectData):
            selected_index = evt.index
            if selected_index < len(examples_images):
                img_path = examples_images[selected_index][0]
                instruction_text = instructions_map.get(img_path, "")

                left_text, right_text = parse_instruction(instruction_text)

                return gr.update(value=img_path), gr.update(value=left_text), gr.update(value=right_text), gr.update(interactive=False)
            return gr.update(), gr.update(), gr.update(), gr.update()

        example_gallery.select(
            fn=load_example_from_gallery,
            inputs=[],
            outputs=[input_image, left_instruction, right_instruction, generate_btn],
            show_progress=False
        ).then(
            fn=update_final_instruction,
            inputs=[left_instruction, right_instruction, use_left, use_right],
            outputs=[final_instruction, final_instruction_text],
            show_progress=False
        )

        input_image.change(
            fn=validate_and_process_wrapper,
            inputs=[input_image, session_state],
            outputs=[recon_status, generate_btn, hand_data, detected_left, detected_right, session_state],
            show_progress='full'
        ).then(
            fn=update_checkboxes,
            inputs=[detected_left, detected_right],
            outputs=[use_left, use_right, left_instruction, right_instruction],
            show_progress=False
        )
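
        # The handlers below keep the hand checkboxes, the per-hand instruction textboxes and
        # the combined final instruction in sync whenever a selection or an instruction changes.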
        use_left.change(
            fn=update_instruction_interactivity,
            inputs=[use_left, use_right],
            outputs=[left_instruction, right_instruction],
            show_progress=False
        ).then(
            fn=update_final_instruction,
            inputs=[left_instruction, right_instruction, use_left, use_right],
            outputs=[final_instruction, final_instruction_text],
            show_progress=False
        )

        use_right.change(
            fn=update_instruction_interactivity,
            inputs=[use_left, use_right],
            outputs=[left_instruction, right_instruction],
            show_progress=False
        ).then(
            fn=update_final_instruction,
            inputs=[left_instruction, right_instruction, use_left, use_right],
            outputs=[final_instruction, final_instruction_text],
            show_progress=False
        )

        left_instruction.change(
            fn=update_final_instruction,
            inputs=[left_instruction, right_instruction, use_left, use_right],
            outputs=[final_instruction, final_instruction_text],
            show_progress=False
        )

        right_instruction.change(
            fn=update_final_instruction,
            inputs=[left_instruction, right_instruction, use_left, use_right],
            outputs=[final_instruction, final_instruction_text],
            show_progress=False
        )

        generate_btn.click(
            fn=generate_prediction,
            inputs=[final_instruction_text, use_left, use_right, sample_times, num_ddim_steps, cfg_scale, hand_data, input_image],
            outputs=[output_video, gen_status],
            show_progress='full'
        )

    return demo


if __name__ == "__main__":
    # Launch the Gradio app
    print("Initializing services...")
    init_msg = initialize_services()
    print(init_msg)

    if "Failed" in init_msg:
        print("⚠️ Services failed to initialize. Please check the configuration and try again.")

    demo = create_gradio_interface()

    demo.launch()