import os
import sys
import time

import torch
import gradio as gr
import numpy as np
import imageio
import spaces
from PIL import Image

from videox_fun.ui.wan_ui import Wan_Controller, css
from videox_fun.ui.ui import (
    create_model_type, create_model_checkpoints, create_finetune_models_checkpoints,
    create_teacache_params, create_cfg_skip_params, create_cfg_riflex_k,
    create_prompts, create_samplers, create_height_width,
    create_generation_methods_and_video_length, create_generation_method,
    create_cfg_and_seedbox, create_ui_outputs
)
from videox_fun.data.dataset_image_video import derive_ground_object_from_instruction
from videox_fun.utils.lora_utils import merge_lora, unmerge_lora
from videox_fun.utils.utils import save_videos_grid, timer

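# generate_wrapper below is a plain module-level function (rather than a bound
# method) so that `spaces.GPU` can wrap the entire generation call for ZeroGPU
# scheduling; ui() injects the controller instance via this module global.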
global_controller = None


@spaces.GPU(duration=300)
@timer
def generate_wrapper(*args):
    global global_controller
    return global_controller.generate(*args)

def create_height_width_english(default_height, default_width, maximum_height, maximum_width):
    """English-labelled variant of create_height_width. Every control is created
    hidden, since VideoCoF infers the output resolution from the input video."""
    resize_method = gr.Radio(
        ["Generate by", "Resize according to Reference"],
        value="Generate by",
        show_label=False,
        visible=False
    )
    width_slider = gr.Slider(label="Width", value=default_width, minimum=128, maximum=maximum_width, step=16, visible=False)
    height_slider = gr.Slider(label="Height", value=default_height, minimum=128, maximum=maximum_height, step=16, visible=False)
    base_resolution = gr.Radio(label="Base Resolution", value=512, choices=[512, 640, 768, 896, 960, 1024], visible=False)

    return resize_method, width_slider, height_slider, base_resolution

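# Note: the hidden resolution controls above exist only to satisfy the shared
# UI plumbing; generate() below renders at the native resolution of the
# uploaded clip instead of reading these sliders.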
def load_video_frames(video_path: str, source_frames: int):
    """Sample `source_frames` frames from a video, padding with the last frame
    (or black frames) when the clip is too short. Returns a tensor of shape
    (1, C, T, H, W) scaled to [-1, 1], plus the original height and width."""
    assert source_frames is not None, "source_frames is required"

    reader = imageio.get_reader(video_path)
    try:
        total_frames = reader.count_frames()
    except Exception:
        # Some containers do not report a frame count; count manually and reopen,
        # since iterating exhausts the reader.
        total_frames = sum(1 for _ in reader)
        reader.close()
        reader = imageio.get_reader(video_path)

    # Stride evenly across the clip, starting at a random offset when possible.
    stride = max(1, total_frames // source_frames)
    start_frame = torch.randint(0, max(1, total_frames - stride * source_frames), (1,))[0].item()

    frames = []
    original_height, original_width = None, None

    for i in range(source_frames):
        idx = start_frame + i * stride
        if idx >= total_frames:
            break
        try:
            frame = reader.get_data(idx)
            pil_frame = Image.fromarray(frame)
            if original_height is None:
                original_width, original_height = pil_frame.size
                print(f"Original video dimensions: {original_width}x{original_height}")
            frames.append(pil_frame)
        except IndexError:
            break

    reader.close()

    # Pad short clips so the tensor always holds exactly `source_frames` frames.
    while len(frames) < source_frames:
        if frames:
            frames.append(frames[-1].copy())
        else:
            w, h = (original_width, original_height) if original_width else (832, 480)
            frames.append(Image.new('RGB', (w, h), (0, 0, 0)))

    assert len(frames) == source_frames, f"Loaded {len(frames)} frames, expected {source_frames}"
    print(f"Loaded {source_frames} source frames")

    # (T, H, W, C) uint8 -> (1, C, T, H, W) float in [-1, 1].
    input_video = torch.from_numpy(np.array(frames))
    input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0).float()
    input_video = input_video * (2.0 / 255.0) - 1.0

    return input_video, original_height, original_width

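# A minimal usage sketch for load_video_frames (illustrative only; the asset
# path is one of the demo clips bundled with the UI below):
#
#     clip, h, w = load_video_frames("assets/two_man.mp4", source_frames=33)
#     # clip.shape == (1, 3, 33, h, w), with values scaled to [-1, 1]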
def preload_models(controller, default_model_path, default_lora_name, acc_lora_path):
    """
    Preload the base model and LoRAs before launching the app to avoid first-run latency.
    """
    controller.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Track which LoRAs are merged so generate() can skip redundant merges.
    if not hasattr(controller, "_active_lora_path"):
        controller._active_lora_path = None
    if not hasattr(controller, "_acc_lora_active"):
        controller._acc_lora_active = False

    try:
        print(f"[preload] Loading base model: {default_model_path}")
        controller.update_diffusion_transformer(default_model_path)

        base_candidate = os.path.join(controller.personalized_model_dir, os.path.basename(default_model_path))
        if os.path.exists(base_candidate):
            controller.update_base_model(os.path.basename(base_candidate))
        else:
            print(f"[preload] Skip update_base_model (not found at {base_candidate})")

        print(f"[preload] Loading VideoCoF LoRA: {default_lora_name}")
        controller.update_lora_model(default_lora_name)
        if controller.lora_model_path and controller.lora_model_path != "none":
            controller.pipeline = merge_lora(
                controller.pipeline,
                controller.lora_model_path,
                multiplier=1.0,
                device=controller.device,
            )
            controller._active_lora_path = controller.lora_model_path

        if acc_lora_path and os.path.exists(acc_lora_path):
            print(f"[preload] Loading Acceleration LoRA: {acc_lora_path}")
            controller.pipeline = merge_lora(
                controller.pipeline, acc_lora_path, multiplier=1.0, device=controller.device
            )
            controller._acc_lora_active = True
        else:
            print(f"[preload] Acceleration LoRA not found at {acc_lora_path}")
    except Exception as e:
        print(f"[preload] Warning: preload failed: {e}")
    finally:
        torch.cuda.empty_cache()

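# Note on the bookkeeping above: merge_lora folds LoRA weights into the
# pipeline, so _active_lora_path and _acc_lora_active act as idempotency
# guards that let generate() skip redundant merges and unmerge on switches.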
class VideoCoF_Controller(Wan_Controller):
    @timer
    def generate(
        self,
        diffusion_transformer_dropdown,
        base_model_dropdown,
        lora_model_dropdown,
        lora_alpha_slider,
        prompt_textbox,
        negative_prompt_textbox,
        sampler_dropdown,
        sample_step_slider,
        resize_method,
        width_slider,
        height_slider,
        base_resolution,
        generation_method,
        length_slider,
        overlap_video_length,
        partial_video_length,
        cfg_scale_slider,
        start_image,
        end_image,
        validation_video,
        validation_video_mask,
        control_video,
        denoise_strength,
        seed_textbox,
        ref_image=None,
        # VideoCoF-specific parameters.
        source_frames_slider=33,
        reasoning_frames_slider=4,
        repeat_rope_checkbox=True,
        enable_acceleration=True,
        fps=8,
        is_api=False,
    ):
        self.clear_cache()
        print("VideoCoF Generation started.")

        # Pick the device and make sure the (possibly offloaded) pipeline is on it.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        try:
            if hasattr(self, "pipeline") and self.pipeline is not None:
                self.pipeline.to(self.device)
        except Exception as move_e:
            print(f"Warning: failed to move pipeline to {self.device}: {move_e}")

        # Reload weights only when the UI selection actually changed.
        if self.diffusion_transformer_dropdown != diffusion_transformer_dropdown:
            self.update_diffusion_transformer(diffusion_transformer_dropdown)

        if self.base_model_path != base_model_dropdown:
            self.update_base_model(base_model_dropdown)

        if self.lora_model_path != lora_model_dropdown:
            self.update_lora_model(lora_model_dropdown)

        # LoRA bookkeeping (already set when preload_models ran at startup).
        if not hasattr(self, "_active_lora_path"):
            self._active_lora_path = None
        if not hasattr(self, "_acc_lora_active"):
            self._acc_lora_active = False

        # Rebuild the scheduler for the selected sampler; shift is forced to 1
        # for Flow_Unipc/Flow_DPM++, while the pipeline call below uses shift=3.
        scheduler_config = self.pipeline.scheduler.config
        if sampler_dropdown in ["Flow_Unipc", "Flow_DPM++"]:
            scheduler_config['shift'] = 1
        self.pipeline.scheduler = self.scheduler_dict[sampler_dropdown].from_config(scheduler_config)

        if self.lora_model_path != "none":
            # Unmerge a previously merged VideoCoF LoRA before merging a new one.
            if self._active_lora_path and self._active_lora_path != self.lora_model_path:
                print(f"Unmerging previous VideoCoF LoRA: {self._active_lora_path}")
                self.pipeline = unmerge_lora(self.pipeline, self._active_lora_path, multiplier=lora_alpha_slider, device=self.device)
                self._active_lora_path = None

            if self._active_lora_path != self.lora_model_path:
                print(f"Merge VideoCoF LoRA: {self.lora_model_path}")
                self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider, device=self.device)
                self._active_lora_path = self.lora_model_path

        # Optional 4-step acceleration LoRA (FusionX), merged or unmerged on toggle.
        acc_lora_path = os.path.join(self.personalized_model_dir, "Wan2.1_Text_to_Video_14B_FusionX_LoRA.safetensors")
        if enable_acceleration:
            if os.path.exists(acc_lora_path):
                if not self._acc_lora_active:
                    print(f"Merge Acceleration LoRA: {acc_lora_path}")
                    self.pipeline = merge_lora(self.pipeline, acc_lora_path, multiplier=1.0, device=self.device)
                    self._acc_lora_active = True
            else:
                print(f"Warning: Acceleration LoRA not found at {acc_lora_path}")
        else:
            if self._acc_lora_active and os.path.exists(acc_lora_path):
                print("Unmerging Acceleration LoRA (disabled)")
                self.pipeline = unmerge_lora(self.pipeline, acc_lora_path, multiplier=1.0, device=self.device)
                self._acc_lora_active = False

        # Seed handling: check for an empty textbox before calling int() on it.
        if seed_textbox != "" and int(seed_textbox) != -1:
            torch.manual_seed(int(seed_textbox))
        else:
            seed_textbox = np.random.randint(0, 2**31 - 1)  # stay within int32 range

        # Build the generator on the transformer's device when possible; fall
        # back to self.device for meta tensors (e.g. under CPU offload).
        gen_device = getattr(getattr(self, "pipeline", None), "transformer", None)
        gen_device = gen_device.device if gen_device is not None else self.device
        if gen_device.type == 'meta':
            gen_device = self.device
        generator = torch.Generator(device=gen_device).manual_seed(int(seed_textbox))

        try:
            # Accept the input clip from either upload slot.
            input_video_path = validation_video
            if input_video_path is None:
                input_video_path = control_video
            if input_video_path is None:
                raise ValueError("Please upload a video for VideoCoF generation.")

            # Build the chain-of-frames prompt: the original scene, a grounded
            # object segment, then the edited scene from the user instruction.
            edit_text = prompt_textbox
            ground_instr = derive_ground_object_from_instruction(edit_text)
            prompt = (
                "A video sequence showing three parts: first the original scene, "
                f"then grounded {ground_instr}, and finally the same scene but {edit_text}"
            )
            print(f"Constructed prompt: {prompt}")

            input_video_tensor, video_height, video_width = load_video_frames(
                input_video_path,
                source_frames=source_frames_slider
            )

            # Generate at the source video's native resolution.
            h, w = video_height, video_width
            print(f"Input video dimensions: {w}x{h}")

            print(f"Running pipeline with frames={length_slider}, source={source_frames_slider}, reasoning={reasoning_frames_slider}")
            shift = 3  # flow shift for the pipeline (the schedulers above use shift=1)

            sample = self.pipeline(
                video=input_video_tensor,
                prompt=prompt,
                num_frames=length_slider,
                source_frames=source_frames_slider,
                reasoning_frames=reasoning_frames_slider,
                negative_prompt=negative_prompt_textbox,
                height=h,
                width=w,
                generator=generator,
                guidance_scale=cfg_scale_slider,
                num_inference_steps=sample_step_slider,
                shift=shift,
                repeat_rope=repeat_rope_checkbox,
                cot=True,
            ).videos

            # Keep only the trailing edited segment of the chain-of-frames output.
            final_video = sample[:, :, -source_frames_slider:, :, :]

        except Exception as e:
            print(f"Error: {e}")

            # Roll back any merged LoRAs so the pipeline is clean for the next run.
            if self._acc_lora_active and os.path.exists(acc_lora_path):
                print("Unmerging Acceleration LoRA (due to error)")
                self.pipeline = unmerge_lora(self.pipeline, acc_lora_path, multiplier=1.0, device=self.device)
                self._acc_lora_active = False

            if self._active_lora_path:
                print("Unmerging VideoCoF LoRA (due to error)")
                self.pipeline = unmerge_lora(self.pipeline, self._active_lora_path, multiplier=lora_alpha_slider, device=self.device)
                self._active_lora_path = None
            return gr.update(), gr.update(), f"Error: {str(e)}"

        save_sample_path = self.save_outputs(
            False, source_frames_slider, final_video, fps=fps
        )

        return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success"

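# generate() returns (image_update, video_update, status_message), in the same
# order as the click handler's outputs wired up in ui() below.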
def ui(GPU_memory_mode, scheduler_dict, config_path, compile_dit, weight_dtype):
    controller = VideoCoF_Controller(
        GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint",
        config_path=config_path, compile_dit=compile_dit,
        weight_dtype=weight_dtype
    )
    # Expose the controller to generate_wrapper (see the module global above).
    global global_controller
    global_controller = controller

    with gr.Blocks() as demo:
        gr.Markdown("# VideoCoF Demo")

        with gr.Column(variant="panel"):
            local_model_dir = os.path.join("models", "Wan2.1-T2V-14B")
            diffusion_transformer_dropdown, _ = create_model_checkpoints(controller, visible=False, default_model=local_model_dir)

            # Download weights on first launch; failures are non-fatal, so a
            # pre-provisioned models/ directory also works.
            try:
                from huggingface_hub import snapshot_download, hf_hub_download
                print("Downloading Wan2.1-T2V-14B weights...")
                hf_model_id = "Wan-AI/Wan2.1-T2V-14B"
                snapshot_download(repo_id=hf_model_id, local_dir=local_model_dir, local_dir_use_symlinks=False)

                os.makedirs("models/Personalized_Model", exist_ok=True)

                print("Downloading VideoCoF weights...")
                default_lora_name = "videocof.safetensors"
                hf_hub_download(repo_id="XiangpengYang/VideoCoF", filename=default_lora_name, local_dir="models/Personalized_Model")

                print("Downloading FusionX Acceleration LoRA...")
                acc_lora_filename = "Wan2.1_Text_to_Video_14B_FusionX_LoRA.safetensors"
                hf_hub_download(repo_id="MonsterMMORPG/Wan_GGUF", filename=acc_lora_filename, local_dir="models/Personalized_Model")

            except Exception as e:
                print(f"Warning: Failed to pre-download weights: {e}")

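            # snapshot_download / hf_hub_download skip files that already exist
            # intact in local_dir, so restarting the app should not re-download
            # the weights.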
            base_model_dropdown, lora_model_dropdown, lora_alpha_slider, _ = create_finetune_models_checkpoints(
                controller, visible=False, default_lora="videocof.safetensors"
            )
            # The VideoCoF LoRA is always applied at full strength.
            lora_alpha_slider.value = 1.0

            acc_lora_path = os.path.join("models", "Personalized_Model", "Wan2.1_Text_to_Video_14B_FusionX_LoRA.safetensors")
            preload_models(controller, local_model_dir, "videocof.safetensors", acc_lora_path)

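            # Preloading here, while the UI is being built, means the first
            # Generate click does not pay the model-load latency.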
        with gr.Column(variant="panel"):
            prompt_textbox, negative_prompt_textbox = create_prompts(prompt="Remove the young man with short black hair wearing black shirt on the left.")

            with gr.Row():
                with gr.Column():
                    sampler_dropdown, sample_step_slider = create_samplers(controller)
                    # Four steps suffice when the FusionX acceleration LoRA is merged.
                    sample_step_slider.value = 4

                    with gr.Group():
                        gr.Markdown("### VideoCoF Parameters")
                        source_frames_slider = gr.Slider(label="Source Frames", minimum=1, maximum=100, value=33, step=1)
                        reasoning_frames_slider = gr.Slider(label="Reasoning Frames", minimum=1, maximum=20, value=4, step=1)
                        repeat_rope_checkbox = gr.Checkbox(label="Repeat RoPE", value=True)
                        enable_acceleration = gr.Checkbox(label="Enable 4-step Acceleration (FusionX LoRA)", value=True)

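                    # "Source Frames" is how many frames are sampled from the
                    # uploaded clip; "Reasoning Frames" sizes the intermediate
                    # grounded segment referenced by the prompt built in generate().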
                    resize_method, width_slider, height_slider, base_resolution = create_height_width_english(
                        default_height=480, default_width=832, maximum_height=1344, maximum_width=1344
                    )

                    generation_method, length_slider, overlap_video_length, partial_video_length = \
                        create_generation_methods_and_video_length(
                            ["Video Generation"],
                            default_video_length=65,
                            maximum_video_length=161
                        )

                    image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method(
                        ["Video to Video"],
                        prompt_textbox,
                        support_end_image=False,
                        default_video="assets/two_man.mp4",
                        video_examples=[
                            ["assets/two_man.mp4", "Remove the young man with short black hair wearing black shirt on the left."],
                            ["assets/three_people.mp4", "Remove the man with short dark hair wearing a gray suit on the right."],
                            ["assets/office.mp4", "Remove the beige CRT computer setup."],
                            ["assets/woman_ballon.mp4", "Add the woman in a floral dress pointing at the balloon on the left."],
                            ["assets/greenhouse.mp4", "A white Samoyed is watching the man, who crouches in a greenhouse. The Samoyed is covered in thick, fluffy white fur, giving it a very soft and plush appearance. Its ears are erect and triangular, making it look alert and intelligent. The Samoyed's face features its signature smile, with bright black eyes that convey friendliness and curiosity."],
                            ["assets/gameplay.mp4", "Add the woman holding the blue game controller to the left of the man, engaged in gameplay."],
                            ["assets/dog.mp4", "Add the brown and white beagle interacting with and drinking from the metallic bowl on the wooden floor."],
                            ["assets/sign.mp4", "Replace the yellow \"SCHOOL\" sign with a red hospital sign, featuring a white hospital emblem on the top and the word \"HOSPITAL\" below."],
                            ["assets/old_man.mp4", "Swap the old man with long white hair and a blue checkered shirt at the left side of the frame with a woman with curly brown hair and a denim shirt."],
                            ["assets/pants.mp4", "Swap the white pants worn by the individual with the light blue jeans."],
                            ["assets/bowl.mp4", "Make the largest cup on the right white and smooth."],
                            ["assets/ketchup.mp4", "Make the ketchup bottle to the right of the BBQ sauce bottle violet color."],
                            ["assets/fruit.mp4", "Make the pomegranate at the right side of the basket lavender color."]
                        ],
                    )

                    # Force the video-to-video input to be visible and editable.
                    validation_video.visible = True
                    validation_video.interactive = True

                    cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(True)
                    seed_textbox.value = "0"
                    # Guidance is effectively disabled (scale 1.0), matching the
                    # 4-step accelerated sampling setup.
                    cfg_scale_slider.value = 1.0

                    generate_button = gr.Button(value="Generate", variant='primary')

                result_image, result_video, infer_progress = create_ui_outputs()

        generate_button.click(
            fn=generate_wrapper,
            inputs=[
                diffusion_transformer_dropdown,
                base_model_dropdown,
                lora_model_dropdown,
                lora_alpha_slider,
                prompt_textbox,
                negative_prompt_textbox,
                sampler_dropdown,
                sample_step_slider,
                resize_method,
                width_slider,
                height_slider,
                base_resolution,
                generation_method,
                length_slider,
                overlap_video_length,
                partial_video_length,
                cfg_scale_slider,
                start_image,
                end_image,
                validation_video,
                validation_video_mask,
                control_video,
                denoise_strength,
                seed_textbox,
                ref_image,
                # VideoCoF-specific inputs; these must stay in the same order as
                # the trailing parameters of VideoCoF_Controller.generate.
                source_frames_slider,
                reasoning_frames_slider,
                repeat_rope_checkbox,
                enable_acceleration
            ],
            outputs=[result_image, result_video, infer_progress]
        )

    return demo, controller

if __name__ == "__main__":
    from videox_fun.ui.controller import flow_scheduler_dict

    # Launch configuration.
    GPU_memory_mode = "sequential_cpu_offload"
    compile_dit = False
    weight_dtype = torch.bfloat16
    server_name = "0.0.0.0"
    server_port = 7860
    config_path = "config/wan2.1/wan_civitai.yaml"

    demo, controller = ui(GPU_memory_mode, flow_scheduler_dict, config_path, compile_dit, weight_dtype)

    demo.queue(status_update_rate=1).launch(
        server_name=server_name,
        server_port=server_port,
        prevent_thread_lock=True,
        share=False
    )

    # launch() returns immediately with prevent_thread_lock=True, so keep the
    # main thread alive while the server runs.
    while True:
        time.sleep(5)