| import os |
| import sys |
|
|
| import torch |
| from lightning import seed_everything |
| from safetensors.torch import load_file as load_safetensors |
|
|
| from ldf_utils.initialize import compare_statedict_and_parameters, instantiate, load_config |
|
|
# Silence the HuggingFace tokenizers fork/parallelism warning (and avoid its
# thread-pool interacting badly with forked worker processes).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
|
|
def load_model_from_config():
    """Instantiate the VAE and main model from the loaded config and restore
    their safetensors checkpoints.

    Relative checkpoint/tokenizer paths found in the config are resolved
    against the directory of the ``--config`` CLI argument (falling back to
    the current working directory when ``--config`` is absent from
    ``sys.argv``).

    Returns:
        tuple: ``(vae, model)``, both moved to the selected device
        (CUDA when available, else CPU) and switched to eval mode.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.set_float32_matmul_precision("high")
    cfg = load_config()
    seed_everything(cfg.seed)

    config_dir = _config_dir_from_argv()

    # --- VAE -----------------------------------------------------------------
    vae = instantiate(
        target=cfg.test_vae.target,
        cfg=None,
        hfstyle=False,
        **cfg.test_vae.params,
    )
    _restore_checkpoint(
        vae, _resolve_path(cfg.test_vae_ckpt, config_dir), device, "VAE model"
    )

    # --- Main model ----------------------------------------------------------
    # Rebase any relative resource paths in the model params before
    # instantiation so the model sees absolute paths.
    model_params = dict(cfg.model.params)
    for key in ("checkpoint_path", "tokenizer_path"):
        if model_params.get(key):
            model_params[key] = _resolve_path(model_params[key], config_dir)

    model = instantiate(
        target=cfg.model.target, cfg=None, hfstyle=False, **model_params
    )
    _restore_checkpoint(
        model, _resolve_path(cfg.test_ckpt, config_dir), device, "model"
    )

    return vae, model


def _config_dir_from_argv():
    """Return the directory of the ``--config`` CLI argument, or CWD if absent."""
    if '--config' in sys.argv:
        config_idx = sys.argv.index('--config') + 1
        return os.path.dirname(os.path.abspath(sys.argv[config_idx]))
    return os.getcwd()


def _resolve_path(path, base_dir):
    """Make *path* absolute by joining it onto *base_dir* when it is relative."""
    return path if os.path.isabs(path) else os.path.join(base_dir, path)


def _restore_checkpoint(module, ckpt_path, device, label):
    """Load safetensors weights into *module* and prepare it for inference.

    Loads the state dict strictly, sanity-checks it against the module's
    parameters/buffers, then moves the module to *device* in eval mode.
    """
    state_dict = load_safetensors(ckpt_path)
    module.load_state_dict(state_dict, strict=True)
    print(f"Loaded {label} from {ckpt_path}")

    # Cross-check that the state dict matches the module's declared
    # parameters and buffers.
    compare_statedict_and_parameters(
        state_dict=module.state_dict(),
        named_parameters=module.named_parameters(),
        named_buffers=module.named_buffers(),
    )
    module.to(device)
    module.eval()
|
|
|
|
@torch.inference_mode()
def generate_feature_stream(
    model, feature_length, text, feature_text_end=None, num_denoise_steps=None
):
    """
    Streaming interface for feature generation.

    Args:
        model: Loaded model exposing a ``stream_generate`` generator method.
        feature_length: List[int], generation length for each sample.
        text: List[str] or List[List[str]], text prompts.
        feature_text_end: List[List[int]], time points where text ends
            (used when ``text`` is a list of lists); omitted from the batch
            when ``None``.
        num_denoise_steps: Number of denoising steps (``None`` defers to the
            model's default).
    Yields:
        dict: Contains "generated" (current generated feature segment).
    """
    # Assemble the batch dict the way model.stream_generate expects it.
    x = {"feature_length": torch.tensor(feature_length), "text": text}
    if feature_text_end is not None:
        x["feature_text_end"] = feature_text_end

    # Delegate streaming directly to the model's generator instead of an
    # explicit pass-through loop.
    yield from model.stream_generate(x, num_denoise_steps=num_denoise_steps)
|
|
|
|
if __name__ == "__main__":
    import argparse

    # CLI for a single generation run: config path plus prompt/length/output
    # options.
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True, help="Path to config")
    parser.add_argument(
        "--text", type=str, default="a person walks forward", help="Text prompt"
    )
    parser.add_argument("--length", type=int, default=120, help="Motion length")
    parser.add_argument(
        "--output", type=str, default="output.mp4", help="Output video path"
    )
    parser.add_argument(
        "--num_denoise_steps", type=int, default=None, help="Number of denoising steps"
    )
    args = parser.parse_args()

    print("Loading model...")
    # NOTE(review): load_model_from_config() re-reads --config from sys.argv
    # itself rather than taking args.config; args.text/length/output are not
    # used in the lines visible here — presumably consumed further below.
    # Confirm against the full file.
    vae, model = load_model_from_config()
|
|
|
|