import os, sys
import cv2
import time
import datetime, pytz
import gradio as gr
import torch
import numpy as np
from torchvision.utils import save_image
import json
import threading
from queue import Queue
from pathlib import Path
import shutil
# Import files from the local folder
root_path = os.path.abspath('.')
sys.path.append(root_path)
from test_code.inference import super_resolve_img
from test_code.test_utils import load_grl, load_rrdb, load_dat
# Global configuration
OUTPUT_DIR = "outputs"
HISTORY_FILE = "history.json"
VIDEO_QUEUE_FILE = "video_queue.json"

video_queue = Queue()
processing_status = {}
history_lock = threading.Lock()  # Guards read-modify-write access to the history file
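
# Design note: the UI thread enqueues (task_id, video_path, model_name) tuples onto
# video_queue, and a single daemon worker drains them one at a time.
# processing_status maps task_id -> progress info and is polled by the Gradio timer;
# it lives only in memory, so queue and progress state are lost if the process restarts.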

# Initialize directories
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(os.path.join(OUTPUT_DIR, "images"), exist_ok=True)
os.makedirs(os.path.join(OUTPUT_DIR, "videos"), exist_ok=True)

# Map each expected checkpoint path to its GitHub release URL
WEIGHT_URLS = {
    "pretrained/4x_APISR_RRDB_GAN_generator.pth":
        "https://github.com/Kiteretsu77/APISR/releases/download/v0.2.0/4x_APISR_RRDB_GAN_generator.pth",
    "pretrained/4x_APISR_GRL_GAN_generator.pth":
        "https://github.com/Kiteretsu77/APISR/releases/download/v0.1.0/4x_APISR_GRL_GAN_generator.pth",
    "pretrained/2x_APISR_RRDB_GAN_generator.pth":
        "https://github.com/Kiteretsu77/APISR/releases/download/v0.1.0/2x_APISR_RRDB_GAN_generator.pth",
    "pretrained/4x_APISR_DAT_GAN_generator.pth":
        "https://github.com/Kiteretsu77/APISR/releases/download/v0.3.0/4x_APISR_DAT_GAN_generator.pth",
}


def auto_download_if_needed(weight_path):
    """Download the pretrained weight from the APISR releases if it is missing locally."""
    if os.path.exists(weight_path):
        return
    os.makedirs("pretrained", exist_ok=True)
    url = WEIGHT_URLS.get(weight_path)
    if url is not None:
        os.system(f"wget {url} -O {weight_path}")

def load_history():
    """Load processing history from the JSON file."""
    if os.path.exists(HISTORY_FILE):
        try:
            with open(HISTORY_FILE, 'r') as f:
                return json.load(f)
        except (json.JSONDecodeError, OSError):
            # A corrupt or unreadable history file should not take the app down
            return []
    return []


def save_history(history):
    """Save processing history to the JSON file."""
    with open(HISTORY_FILE, 'w') as f:
        json.dump(history, f, indent=2)

def add_to_history(input_path, output_path, model_name, process_type, status="completed"):
    """Add a record to history (newest first)."""
    # The worker thread and the Gradio handlers can both write history,
    # so serialize the read-modify-write with a lock
    with history_lock:
        history = load_history()
        record = {
            "timestamp": datetime.datetime.now().isoformat(),
            "input_path": input_path,
            "output_path": output_path,
            "model_name": model_name,
            "process_type": process_type,
            "status": status
        }
        history.insert(0, record)  # Add to beginning
        save_history(history)

def load_generator(model_name):
    """Load the generator matching the requested model name."""
    if model_name == "4xGRL":
        weight_path = "pretrained/4x_APISR_GRL_GAN_generator.pth"
        auto_download_if_needed(weight_path)
        generator = load_grl(weight_path, scale=4)
    elif model_name == "4xRRDB":
        weight_path = "pretrained/4x_APISR_RRDB_GAN_generator.pth"
        auto_download_if_needed(weight_path)
        generator = load_rrdb(weight_path, scale=4)
    elif model_name == "2xRRDB":
        weight_path = "pretrained/2x_APISR_RRDB_GAN_generator.pth"
        auto_download_if_needed(weight_path)
        generator = load_rrdb(weight_path, scale=2)
    elif model_name == "4xDAT":
        weight_path = "pretrained/4x_APISR_DAT_GAN_generator.pth"
        auto_download_if_needed(weight_path)
        generator = load_dat(weight_path, scale=4)
    else:
        raise ValueError(f"Model {model_name} not supported")
    return generator.to(device='cpu')
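
# Note: inference is pinned to CPU above, which suits a free Space without a GPU.
# If a GPU is available, moving the generator to device='cuda' instead should work,
# assuming the loaded models support it (untested here).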

def inference_image(img_path, model_name):
    """Super-resolve a single image and return it for display."""
    try:
        if img_path is None:
            return None, "❌ Please upload an image first"

        generator = load_generator(model_name)
        print("Processing image:", img_path)
        print("Time:", datetime.datetime.now(pytz.timezone('US/Eastern')))

        # Process image
        super_resolved_img = super_resolve_img(
            generator, img_path, output_path=None,
            downsample_threshold=720, crop_for_4x=True
        )

        # Save output
        timestamp = int(time.time() * 1000)
        output_name = f"image_{timestamp}.png"
        output_path = os.path.join(OUTPUT_DIR, "images", output_name)
        save_image(super_resolved_img, output_path)

        # Load and convert for display (OpenCV loads BGR; Gradio expects RGB)
        outputs = cv2.imread(output_path)
        outputs = cv2.cvtColor(outputs, cv2.COLOR_BGR2RGB)

        # Add to history
        add_to_history(img_path, output_path, model_name, "image")

        return outputs, f"✅ Saved to: {output_path}"
    except Exception as error:
        return None, f"❌ Error: {str(error)}"

def process_video_frame_by_frame(video_path, model_name, task_id):
    """Process a video frame by frame, then re-encode the frames with ffmpeg."""
    try:
        processing_status[task_id] = {"status": "processing", "progress": 0}

        # Load model
        generator = load_generator(model_name)

        # Open video
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError("Cannot open video file")

        # Get video properties (guard against containers that report 0 frames)
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 1)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Prepare output
        timestamp = int(time.time() * 1000)
        output_name = f"video_{timestamp}.mp4"
        output_path = os.path.join(OUTPUT_DIR, "videos", output_name)

        # Create temporary directory for frames
        temp_dir = f"temp_frames_{timestamp}"
        os.makedirs(temp_dir, exist_ok=True)

        # Process frames
        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Save frame temporarily
            temp_frame_path = os.path.join(temp_dir, f"frame_{frame_count:06d}.png")
            cv2.imwrite(temp_frame_path, frame)

            # Super-resolve frame
            super_resolved_img = super_resolve_img(
                generator, temp_frame_path, output_path=None,
                downsample_threshold=720, crop_for_4x=True
            )

            # Save processed frame
            output_frame_path = os.path.join(temp_dir, f"output_{frame_count:06d}.png")
            save_image(super_resolved_img, output_frame_path)

            frame_count += 1
            progress = int((frame_count / total_frames) * 100)
            processing_status[task_id] = {"status": "processing", "progress": progress}
            print(f"Task {task_id}: Processed frame {frame_count}/{total_frames} ({progress}%)")

        cap.release()

        # Combine frames into a video with ffmpeg; quote the paths so spaces do not
        # break the command, and use yuv420p so the output plays in browsers
        print(f"Task {task_id}: Combining frames into video...")
        processing_status[task_id] = {"status": "encoding", "progress": 100}
        os.system(
            f'ffmpeg -y -framerate {fps} -i "{temp_dir}/output_%06d.png" '
            f'-c:v libx264 -pix_fmt yuv420p "{output_path}"'
        )

        # Clean up
        shutil.rmtree(temp_dir)

        processing_status[task_id] = {"status": "completed", "progress": 100, "output": output_path}
        add_to_history(video_path, output_path, model_name, "video")
        print(f"Task {task_id}: Completed! Output: {output_path}")

    except Exception as error:
        processing_status[task_id] = {"status": "error", "error": str(error)}
        print(f"Task {task_id}: Error - {error}")

def video_queue_worker():
    """Background worker that drains the video queue one task at a time."""
    print("Video queue worker started...")
    while True:
        task = video_queue.get()
        if task is None:  # Poison pill to stop worker
            video_queue.task_done()
            break
        try:
            task_id, video_path, model_name = task
            print(f"Starting task {task_id}...")
            process_video_frame_by_frame(video_path, model_name, task_id)
        except Exception as e:
            print(f"Worker error: {e}")
        finally:
            # Mark the task done only after a successful get(), so task_done()
            # is never called without a matching item
            video_queue.task_done()
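
# Usage note: to stop the worker gracefully (e.g. during a shutdown hook), enqueue
# the poison pill and wait for pending work to drain:
#
#     video_queue.put(None)
#     video_queue.join()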

def submit_video(video_path, model_name):
    """Submit a video to the processing queue."""
    if video_path is None:
        return None, "❌ Please upload a video first"

    task_id = f"task_{int(time.time() * 1000)}"
    video_queue.put((task_id, video_path, model_name))
    processing_status[task_id] = {"status": "queued", "progress": 0}

    return None, f"✅ Video submitted to queue! Task ID: {task_id}\nCheck status in the monitoring section."

def get_queue_status():
    """Return the current queue status as formatted text."""
    status_text = "📊 **Queue Status**\n\n"
    status_text += f"Videos in queue: {video_queue.qsize()}\n\n"
    if processing_status:
        status_text += "**Active Tasks:**\n"
        for task_id, status in processing_status.items():
            status_text += f"\n• {task_id}:\n"
            status_text += f"  Status: {status['status']}\n"
            status_text += f"  Progress: {status.get('progress', 0)}%\n"
            if 'output' in status:
                status_text += f"  Output: {status['output']}\n"
            if 'error' in status:
                status_text += f"  Error: {status['error']}\n"
    else:
        status_text += "No active tasks"
    return status_text

def get_history_display():
    """Return the formatted processing history for display."""
    history = load_history()
    if not history:
        return "No history available"

    history_text = "📜 **Processing History**\n\n"
    for idx, record in enumerate(history[:50]):  # Show the 50 most recent records
        history_text += f"**{idx + 1}. {record['process_type'].upper()}** - {record['timestamp']}\n"
        history_text += f"  Model: {record['model_name']}\n"
        history_text += f"  Status: {record['status']}\n"
        history_text += f"  Output: {record['output_path']}\n\n"
    return history_text

def clear_history():
    """Delete the history file and refresh the display."""
    with history_lock:
        if os.path.exists(HISTORY_FILE):
            os.remove(HISTORY_FILE)
    return "✅ History cleared!", get_history_display()

if __name__ == '__main__':
    # Start background worker thread
    worker_thread = threading.Thread(target=video_queue_worker, daemon=True)
    worker_thread.start()

    MARKDOWN = """
# APISR: Anime Production Inspired Real-World Anime Super-Resolution (CVPR 2024)

[GitHub](https://github.com/Kiteretsu77/APISR) | [Paper](https://arxiv.org/abs/2403.01598)

APISR aims at restoring and enhancing low-quality, low-resolution **anime** images and video sources with various degradations from real-world scenarios.

### ⚠️ Note: Images with short side > 720px will be downsampled to 720px (e.g., 1920x1080 → 1280x720)

### 📹 New: Video processing runs in a background queue - you can close the browser and it continues!
"""

    # Create Gradio interface with Gradio 6.x syntax
    with gr.Blocks(title="APISR - Anime Super Resolution") as demo:
        gr.Markdown(MARKDOWN)

        with gr.Tabs():
            # Tab 1: Image Processing
            with gr.Tab("🖼️ Image Processing"):
                with gr.Row():
                    with gr.Column(scale=2):
                        input_image = gr.Image(type="filepath", label="Input Image")
                        image_model = gr.Dropdown(
                            choices=["2xRRDB", "4xRRDB", "4xGRL", "4xDAT"],
                            value="4xGRL",
                            label="Model"
                        )
                        image_btn = gr.Button("🚀 Process Image", variant="primary")
                    with gr.Column(scale=3):
                        output_image = gr.Image(type="numpy", label="Output Image")
                        image_status = gr.Textbox(label="Status", lines=2)

                with gr.Row():
                    gr.Examples(
                        examples=[
                            ["__assets__/lr_inputs/image-00277.png"],
                            ["__assets__/lr_inputs/image-00542.png"],
                            ["__assets__/lr_inputs/41.png"],
                            ["__assets__/lr_inputs/f91.jpg"],
                        ],
                        inputs=[input_image],
                    )

                image_btn.click(
                    fn=inference_image,
                    inputs=[input_image, image_model],
                    outputs=[output_image, image_status]
                )

            # Tab 2: Video Processing
            with gr.Tab("🎬 Video Processing"):
                gr.Markdown("""
### Video Processing Queue
Videos are processed in the background. You can submit multiple videos and close the browser - processing continues!
""")
                with gr.Row():
                    with gr.Column():
                        input_video = gr.Video(label="Input Video")
                        video_model = gr.Dropdown(
                            choices=["2xRRDB", "4xRRDB", "4xGRL", "4xDAT"],
                            value="4xGRL",
                            label="Model"
                        )
                        video_btn = gr.Button("📤 Submit to Queue", variant="primary")
                        video_status = gr.Textbox(label="Submission Status", lines=3)
                    with gr.Column():
                        gr.Markdown("### 📊 Queue Monitor")
                        queue_status = gr.Textbox(label="Queue Status", lines=15, interactive=False)
                        refresh_btn = gr.Button("🔄 Refresh Status")

                video_btn.click(
                    fn=submit_video,
                    inputs=[input_video, video_model],
                    outputs=[input_video, video_status]
                )
                refresh_btn.click(
                    fn=get_queue_status,
                    outputs=[queue_status]
                )

                # Auto-refresh every 5 seconds using Timer (Gradio 6.x way)
                timer = gr.Timer(value=5, active=True)
                timer.tick(
                    fn=get_queue_status,
                    outputs=[queue_status]
                )

            # Tab 3: History
            with gr.Tab("📜 History"):
                gr.Markdown("### Processing History")
                with gr.Row():
                    refresh_history_btn = gr.Button("🔄 Refresh History")
                    clear_history_btn = gr.Button("🗑️ Clear History", variant="stop")
                history_display = gr.Textbox(label="History", lines=20, interactive=False)
                clear_status = gr.Textbox(label="Status", lines=1, visible=True)

                refresh_history_btn.click(
                    fn=get_history_display,
                    outputs=[history_display]
                )
                clear_history_btn.click(
                    fn=clear_history,
                    outputs=[clear_status, history_display]
                )

        # Auto-load history on page load
        demo.load(fn=get_history_display, outputs=[history_display])

    # Launch the app
    demo.queue()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )