import datetime
import json
import os
import shutil
import subprocess
import sys
import threading
import time
import urllib.request
from pathlib import Path
from queue import Queue

import cv2
import gradio as gr
import numpy as np
import pytz
import torch
from torchvision.utils import save_image

# Import files from the local folder
root_path = os.path.abspath('.')
sys.path.append(root_path)
from test_code.inference import super_resolve_img
from test_code.test_utils import load_grl, load_rrdb, load_dat


# Global configuration
OUTPUT_DIR = "outputs"
HISTORY_FILE = "history.json"
VIDEO_QUEUE_FILE = "video_queue.json"
video_queue = Queue()
processing_status = {}

# Initialize directories
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(os.path.join(OUTPUT_DIR, "images"), exist_ok=True)
os.makedirs(os.path.join(OUTPUT_DIR, "videos"), exist_ok=True)

# Release URL for every weight file this app knows how to download.
_WEIGHT_URLS = {
    "pretrained/4x_APISR_RRDB_GAN_generator.pth":
        "https://github.com/Kiteretsu77/APISR/releases/download/v0.2.0/4x_APISR_RRDB_GAN_generator.pth",
    "pretrained/4x_APISR_GRL_GAN_generator.pth":
        "https://github.com/Kiteretsu77/APISR/releases/download/v0.1.0/4x_APISR_GRL_GAN_generator.pth",
    "pretrained/2x_APISR_RRDB_GAN_generator.pth":
        "https://github.com/Kiteretsu77/APISR/releases/download/v0.1.0/2x_APISR_RRDB_GAN_generator.pth",
    "pretrained/4x_APISR_DAT_GAN_generator.pth":
        "https://github.com/Kiteretsu77/APISR/releases/download/v0.3.0/4x_APISR_DAT_GAN_generator.pth",
}

# Model name -> (loader function, weight path, upscale factor).
_MODEL_REGISTRY = {
    "4xGRL": (load_grl, "pretrained/4x_APISR_GRL_GAN_generator.pth", 4),
    "4xRRDB": (load_rrdb, "pretrained/4x_APISR_RRDB_GAN_generator.pth", 4),
    "2xRRDB": (load_rrdb, "pretrained/2x_APISR_RRDB_GAN_generator.pth", 2),
    "4xDAT": (load_dat, "pretrained/4x_APISR_DAT_GAN_generator.pth", 4),
}

# Frame rate used when the container reports no usable FPS value.
_FALLBACK_FPS = 25.0

# Serializes read-modify-write cycles on HISTORY_FILE: history is appended
# both from the background video worker and from UI request handlers.
_history_lock = threading.Lock()


def auto_download_if_needed(weight_path):
    """Download a known pretrained weight file if it is not on disk yet.

    Does nothing when ``weight_path`` already exists or is not one of the
    paths listed in ``_WEIGHT_URLS``.

    Args:
        weight_path: Relative path of the weight file, e.g.
            ``"pretrained/4x_APISR_GRL_GAN_generator.pth"``.

    Raises:
        OSError: If the download or the final rename fails
            (``urllib.error.URLError`` is a subclass of ``OSError``).
    """
    if os.path.exists(weight_path):
        return

    os.makedirs("pretrained", exist_ok=True)

    url = _WEIGHT_URLS.get(weight_path)
    if url is None:
        return

    # Download to a side file first so an interrupted transfer never leaves
    # a truncated file at weight_path that would pass the exists() check.
    partial_path = weight_path + ".part"
    try:
        urllib.request.urlretrieve(url, partial_path)
    except Exception:
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise
    os.replace(partial_path, weight_path)


def load_history():
    """Load processing history from the JSON file.

    Returns:
        The list of history records, or an empty list when the file is
        missing, unreadable as JSON, or does not contain a list.
    """
    try:
        with open(HISTORY_FILE, 'r', encoding="utf-8") as f:
            history = json.load(f)
    except FileNotFoundError:
        return []
    except json.JSONDecodeError as error:
        # A damaged history file should not take the UI down.
        print(f"Warning: could not parse {HISTORY_FILE}: {error}")
        return []
    return history if isinstance(history, list) else []


def save_history(history):
    """Save processing history to the JSON file.

    The data is written to a temporary file and then moved into place so
    that readers never observe a half-written history file.

    Args:
        history: JSON-serializable list of history records.
    """
    tmp_path = HISTORY_FILE + ".tmp"
    with open(tmp_path, 'w', encoding="utf-8") as f:
        json.dump(history, f, indent=2)
    os.replace(tmp_path, HISTORY_FILE)


def add_to_history(input_path, output_path, model_name, process_type, status="completed"):
    """Add a record to the front of the history file.

    Args:
        input_path: Path of the processed input file.
        output_path: Path of the generated output file.
        model_name: Name of the model that was used.
        process_type: Kind of job; callers in this file pass ``"image"`` or ``"video"``.
        status: Final status stored with the record.
    """
    record = {
        "timestamp": datetime.datetime.now().isoformat(),
        "input_path": input_path,
        "output_path": output_path,
        "model_name": model_name,
        "process_type": process_type,
        "status": status
    }
    with _history_lock:
        history = load_history()
        history.insert(0, record)  # Newest first
        save_history(history)


def load_generator(model_name):
    """Load the generator network for ``model_name`` and move it to the CPU.

    Args:
        model_name: One of ``"4xGRL"``, ``"4xRRDB"``, ``"2xRRDB"``, ``"4xDAT"``.

    Returns:
        The loaded generator after ``.to(device='cpu')``.

    Raises:
        ValueError: If ``model_name`` is not a supported model.
    """
    try:
        loader, weight_path, scale = _MODEL_REGISTRY[model_name]
    except KeyError:
        raise ValueError(f"Model {model_name} not supported") from None

    auto_download_if_needed(weight_path)
    generator = loader(weight_path, scale=scale)
    return generator.to(device='cpu')


def inference_image(img_path, model_name):
    """Super-resolve a single image and save the result under OUTPUT_DIR/images.

    Args:
        img_path: Path of the uploaded image, or ``None`` if nothing was uploaded.
        model_name: Name of the model to use (see ``load_generator``).

    Returns:
        A ``(image, message)`` tuple. ``image`` is the saved output read back
        with OpenCV and converted to RGB, or ``None`` on failure; ``message``
        is a human-readable status line. Exceptions are reported through the
        message instead of being raised.
    """
    try:
        if img_path is None:
            return None, "❌ Please upload an image first"

        generator = load_generator(model_name)

        print("Processing image:", img_path)
        print("Time:", datetime.datetime.now(pytz.timezone('US/Eastern')))

        # Process image
        super_resolved_img = super_resolve_img(
            generator,
            img_path,
            output_path=None,
            downsample_threshold=720,
            crop_for_4x=True
        )

        # Save output
        timestamp = int(time.time() * 1000)
        output_name = f"image_{timestamp}.png"
        output_path = os.path.join(OUTPUT_DIR, "images", output_name)
        save_image(super_resolved_img, output_path)

        # Load and convert for display (OpenCV reads BGR, the UI expects RGB)
        outputs = cv2.imread(output_path)
        outputs = cv2.cvtColor(outputs, cv2.COLOR_BGR2RGB)

        # Add to history
        add_to_history(img_path, output_path, model_name, "image")

        return outputs, f"✅ Saved to: {output_path}"

    except Exception as error:
        return None, f"❌ Error: {str(error)}"


def _encode_frames_to_video(frames_dir, fps, output_path):
    """Encode ``output_%06d.png`` frames in ``frames_dir`` into an H.264 video.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
        FileNotFoundError: If the ``ffmpeg`` executable cannot be found.
    """
    # Argument list (no shell) so paths are never re-parsed by a shell;
    # check=True so a failed encode is not reported as success.
    subprocess.run(
        [
            "ffmpeg", "-y",
            "-framerate", str(fps),
            "-i", os.path.join(frames_dir, "output_%06d.png"),
            "-c:v", "libx264",
            "-pix_fmt", "yuv420p",
            output_path,
        ],
        check=True,
    )


def process_video_frame_by_frame(video_path, model_name, task_id):
    """Super-resolve a video one frame at a time and encode the result.

    Progress and the final outcome are published in
    ``processing_status[task_id]``. Exceptions are caught and recorded there
    with status ``"error"``; nothing is raised to the caller.

    Args:
        video_path: Path of the input video.
        model_name: Name of the model to use (see ``load_generator``).
        task_id: Key under which status updates are stored.
    """
    cap = None
    temp_dir = None
    try:
        processing_status[task_id] = {"status": "processing", "progress": 0}

        # Load model
        generator = load_generator(model_name)

        # Open video
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError("Cannot open video file")

        # Get video properties
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:  # also true for NaN
            fps = _FALLBACK_FPS
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Prepare output
        timestamp = int(time.time() * 1000)
        output_name = f"video_{timestamp}.mp4"
        output_path = os.path.join(OUTPUT_DIR, "videos", output_name)

        # Create temporary directory for frames
        temp_dir = f"temp_frames_{timestamp}"
        os.makedirs(temp_dir, exist_ok=True)

        # Process frames
        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Save frame temporarily (super_resolve_img takes a file path)
            temp_frame_path = os.path.join(temp_dir, f"frame_{frame_count:06d}.png")
            cv2.imwrite(temp_frame_path, frame)

            # Super resolve frame
            super_resolved_img = super_resolve_img(
                generator,
                temp_frame_path,
                output_path=None,
                downsample_threshold=720,
                crop_for_4x=True
            )

            # Save processed frame
            output_frame_path = os.path.join(temp_dir, f"output_{frame_count:06d}.png")
            save_image(super_resolved_img, output_frame_path)

            # The raw input frame is no longer needed; drop it to save disk.
            os.remove(temp_frame_path)

            frame_count += 1
            # The reported frame count can be 0 or smaller than the real
            # number of frames, so guard the division and cap at 100.
            if total_frames > 0:
                progress = min(100, int((frame_count / total_frames) * 100))
            else:
                progress = 0
            processing_status[task_id] = {"status": "processing", "progress": progress}

            print(f"Task {task_id}: Processed frame {frame_count}/{total_frames} ({progress}%)")

        if frame_count == 0:
            raise ValueError("No frames could be read from video file")

        # Combine frames into video using ffmpeg
        print(f"Task {task_id}: Combining frames into video...")
        processing_status[task_id] = {"status": "encoding", "progress": 100}
        _encode_frames_to_video(temp_dir, fps, output_path)

        processing_status[task_id] = {"status": "completed", "progress": 100, "output": output_path}
        add_to_history(video_path, output_path, model_name, "video")

        print(f"Task {task_id}: Completed! Output: {output_path}")

    except Exception as error:
        processing_status[task_id] = {"status": "error", "error": str(error)}
        print(f"Task {task_id}: Error - {error}")

    finally:
        # Release the capture and remove temp frames on success and failure.
        if cap is not None:
            cap.release()
        if temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)


def video_queue_worker():
    """Background worker that processes queued videos one at a time.

    Blocks on ``video_queue`` and runs ``process_video_frame_by_frame`` for
    each ``(task_id, video_path, model_name)`` item. A ``None`` item stops
    the worker.
    """
    print("Video queue worker started...")

    while True:
        task = video_queue.get()
        try:
            if task is None:  # Poison pill to stop worker
                break

            task_id, video_path, model_name = task
            print(f"Starting task {task_id}...")
            process_video_frame_by_frame(video_path, model_name, task_id)

        except Exception as e:
            print(f"Worker error: {e}")
        finally:
            # Exactly one task_done() per successful get(), including the pill.
            video_queue.task_done()


def submit_video(video_path, model_name):
    """Submit a video to the background processing queue.

    Args:
        video_path: Path of the uploaded video, or ``None`` if nothing was uploaded.
        model_name: Name of the model to use (see ``load_generator``).

    Returns:
        A ``(None, message)`` tuple; the first element is always ``None``.
    """
    if video_path is None:
        return None, "❌ Please upload a video first"

    task_id = f"task_{int(time.time() * 1000)}"
    # Publish "queued" BEFORE enqueueing: otherwise the worker may already
    # have written "processing" and this assignment would overwrite it.
    processing_status[task_id] = {"status": "queued", "progress": 0}
    video_queue.put((task_id, video_path, model_name))

    return None, f"✅ Video submitted to queue! Task ID: {task_id}\nCheck status in the monitoring section."
def get_queue_status():
    """Build a text report of the video queue and all known tasks.

    Returns:
        A string with the number of queued videos followed by one section
        per entry in ``processing_status`` (status, progress and, when
        present, output path or error message).
    """
    parts = [
        "📊 **Queue Status**\n\n",
        f"Videos in queue: {video_queue.qsize()}\n\n",
    ]

    # Snapshot first: other threads add/replace entries while we iterate,
    # and iterating the live dict could raise
    # "dictionary changed size during iteration".
    tasks = list(processing_status.items())

    if tasks:
        parts.append("**Active Tasks:**\n")
        for task_id, status in tasks:
            parts.append(f"\n🎬 {task_id}:\n")
            parts.append(f" Status: {status['status']}\n")
            parts.append(f" Progress: {status.get('progress', 0)}%\n")
            if 'output' in status:
                parts.append(f" Output: {status['output']}\n")
            if 'error' in status:
                parts.append(f" Error: {status['error']}\n")
    else:
        parts.append("No active tasks")

    return "".join(parts)


def get_history_display():
    """Format the most recent history records for display.

    Returns:
        A text listing of up to the first 50 records returned by
        ``load_history()``, or ``"No history available"`` when there are none.
    """
    history = load_history()

    if not history:
        return "No history available"

    parts = ["📜 **Processing History**\n\n"]
    for position, record in enumerate(history[:50], start=1):  # Show last 50
        parts.append(f"**{position}. {record['process_type'].upper()}** - {record['timestamp']}\n")
        parts.append(f" Model: {record['model_name']}\n")
        parts.append(f" Status: {record['status']}\n")
        parts.append(f" Output: {record['output_path']}\n\n")

    return "".join(parts)


def clear_history():
    """Delete the history file.

    Returns:
        A ``(message, history_text)`` tuple: a confirmation message and the
        refreshed output of ``get_history_display()``.
    """
    # Remove directly instead of exists()+remove(): the file may disappear
    # between the two calls.
    try:
        os.remove(HISTORY_FILE)
    except FileNotFoundError:
        pass
    return "✅ History cleared!", get_history_display()


# Kept at column 0: Markdown treats text indented by four spaces as a code block.
MARKDOWN = """
# APISR: Anime Production Inspired Real-World Anime Super-Resolution (CVPR 2024)

[GitHub](https://github.com/Kiteretsu77/APISR) | [Paper](https://arxiv.org/abs/2403.01598)

APISR aims at restoring and enhancing low-quality low-resolution **anime** images and video sources with various degradations from real-world scenarios.

### ⚠️ Note: Images with short side > 720px will be downsampled to 720px (e.g., 1920x1080 → 1280x720)

### 📹 New: Video processing runs in background queue - you can close the browser and it continues!
"""

_VIDEO_TAB_MARKDOWN = """
### Video Processing Queue
Videos are processed in the background. You can submit multiple videos and close the browser - processing continues!
"""

# Model names offered in both the image and the video dropdowns.
_MODEL_CHOICES = ["2xRRDB", "4xRRDB", "4xGRL", "4xDAT"]


def _build_demo():
    """Create the Gradio Blocks UI (image, video and history tabs) and return it."""
    with gr.Blocks(title="APISR - Anime Super Resolution") as demo:
        gr.Markdown(MARKDOWN)

        with gr.Tabs():
            # Tab 1: Image Processing
            with gr.Tab("🖼️ Image Processing"):
                with gr.Row():
                    with gr.Column(scale=2):
                        input_image = gr.Image(type="filepath", label="Input Image")
                        image_model = gr.Dropdown(
                            choices=list(_MODEL_CHOICES),
                            value="4xGRL",
                            label="Model"
                        )
                        image_btn = gr.Button("🚀 Process Image", variant="primary")

                    with gr.Column(scale=3):
                        output_image = gr.Image(type="numpy", label="Output Image")
                        image_status = gr.Textbox(label="Status", lines=2)

                with gr.Row():
                    gr.Examples(
                        examples=[
                            ["__assets__/lr_inputs/image-00277.png"],
                            ["__assets__/lr_inputs/image-00542.png"],
                            ["__assets__/lr_inputs/41.png"],
                            ["__assets__/lr_inputs/f91.jpg"],
                        ],
                        inputs=[input_image],
                    )

                image_btn.click(
                    fn=inference_image,
                    inputs=[input_image, image_model],
                    outputs=[output_image, image_status]
                )

            # Tab 2: Video Processing
            with gr.Tab("🎬 Video Processing"):
                gr.Markdown(_VIDEO_TAB_MARKDOWN)

                with gr.Row():
                    with gr.Column():
                        input_video = gr.Video(label="Input Video")
                        video_model = gr.Dropdown(
                            choices=list(_MODEL_CHOICES),
                            value="4xGRL",
                            label="Model"
                        )
                        video_btn = gr.Button("📤 Submit to Queue", variant="primary")
                        video_status = gr.Textbox(label="Submission Status", lines=3)

                    with gr.Column():
                        gr.Markdown("### 📊 Queue Monitor")
                        queue_status = gr.Textbox(label="Queue Status", lines=15, interactive=False)
                        refresh_btn = gr.Button("🔄 Refresh Status")

                # submit_video returns None as first output, which clears the input.
                video_btn.click(
                    fn=submit_video,
                    inputs=[input_video, video_model],
                    outputs=[input_video, video_status]
                )

                refresh_btn.click(
                    fn=get_queue_status,
                    outputs=[queue_status]
                )

                # Auto-refresh the queue monitor every 5 seconds
                timer = gr.Timer(value=5, active=True)
                timer.tick(
                    fn=get_queue_status,
                    outputs=[queue_status]
                )

            # Tab 3: History
            with gr.Tab("📜 History"):
                gr.Markdown("### Processing History")

                with gr.Row():
                    refresh_history_btn = gr.Button("🔄 Refresh History")
                    clear_history_btn = gr.Button("🗑️ Clear History", variant="stop")

                history_display = gr.Textbox(label="History", lines=20, interactive=False)
                clear_status = gr.Textbox(label="Status", lines=1, visible=True)

                refresh_history_btn.click(
                    fn=get_history_display,
                    outputs=[history_display]
                )

                clear_history_btn.click(
                    fn=clear_history,
                    outputs=[clear_status, history_display]
                )

        # Auto-load history on page load. Registered while the Blocks context
        # is still open, since event listeners are attached inside it.
        demo.load(fn=get_history_display, outputs=[history_display])

    return demo


if __name__ == '__main__':
    # Start background worker thread (daemon: it will not block interpreter exit)
    worker_thread = threading.Thread(target=video_queue_worker, daemon=True)
    worker_thread.start()

    demo = _build_demo()

    # Launch the app
    demo.queue()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )