# APISRVideo / app.py
# Arrcttacsrks's picture
# Update app.py
# f8f5547 verified
import os, sys
import cv2
import time
import datetime, pytz
import gradio as gr
import torch
import numpy as np
from torchvision.utils import save_image
import json
import threading
from queue import Queue
from pathlib import Path
import shutil
# Import files from the local folder.
# The absolute path of the current working directory is appended to sys.path
# before the `test_code` imports below are resolved.
root_path = os.path.abspath('.')
sys.path.append(root_path)
from test_code.inference import super_resolve_img
from test_code.test_utils import load_grl, load_rrdb, load_dat
# ---------------------------------------------------------------------------
# Global configuration
# ---------------------------------------------------------------------------
OUTPUT_DIR = "outputs"                  # root folder for everything the app writes
HISTORY_FILE = "history.json"           # JSON list of processing records
VIDEO_QUEUE_FILE = "video_queue.json"   # not referenced in the visible part of this file

# Shared state: jobs waiting for the background worker, and a per-task
# status dictionary keyed by task id.
video_queue = Queue()
processing_status = dict()

# Initialize directories: the output root plus one sub-folder per media type.
os.makedirs(OUTPUT_DIR, exist_ok=True)
for _media_kind in ("images", "videos"):
    os.makedirs(os.path.join(OUTPUT_DIR, _media_kind), exist_ok=True)
del _media_kind  # keep the module namespace free of the loop variable
# Download URL for every checkpoint this app knows how to fetch, keyed by the
# exact local path that callers pass in.
_WEIGHT_URLS = {
    "pretrained/4x_APISR_RRDB_GAN_generator.pth":
        "https://github.com/Kiteretsu77/APISR/releases/download/v0.2.0/4x_APISR_RRDB_GAN_generator.pth",
    "pretrained/4x_APISR_GRL_GAN_generator.pth":
        "https://github.com/Kiteretsu77/APISR/releases/download/v0.1.0/4x_APISR_GRL_GAN_generator.pth",
    "pretrained/2x_APISR_RRDB_GAN_generator.pth":
        "https://github.com/Kiteretsu77/APISR/releases/download/v0.1.0/2x_APISR_RRDB_GAN_generator.pth",
    "pretrained/4x_APISR_DAT_GAN_generator.pth":
        "https://github.com/Kiteretsu77/APISR/releases/download/v0.3.0/4x_APISR_DAT_GAN_generator.pth",
}


def auto_download_if_needed(weight_path):
    """Fetch a known checkpoint with ``wget`` if it is not on disk yet.

    Args:
        weight_path: Local checkpoint path. Only the four paths listed in
            ``_WEIGHT_URLS`` can be downloaded; any other path is a no-op
            apart from creating the ``pretrained`` directory.

    Raises:
        RuntimeError: If ``wget`` is unavailable or the download fails.
    """
    import subprocess

    if os.path.exists(weight_path):
        return

    os.makedirs("pretrained", exist_ok=True)

    url = _WEIGHT_URLS.get(weight_path)
    if url is None:
        return

    # Download under a temporary name and rename only on success, so an
    # interrupted transfer is never mistaken for a complete checkpoint by the
    # os.path.exists() check above on the next call.
    partial_path = weight_path + ".part"
    try:
        subprocess.run(["wget", "-O", partial_path, url], check=True)
        os.replace(partial_path, weight_path)
    except (OSError, subprocess.CalledProcessError) as error:
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise RuntimeError(
            f"Failed to download {weight_path} from {url}"
        ) from error
def load_history():
    """Load processing history from JSON file.

    Returns:
        list: The records stored in ``HISTORY_FILE``. An empty list is
        returned when the file does not exist, cannot be parsed as JSON, or
        does not contain a JSON list, so that a damaged history file never
        blocks image/video processing or the history view.
    """
    try:
        with open(HISTORY_FILE, 'r', encoding='utf-8') as f:
            history = json.load(f)
    except FileNotFoundError:
        return []
    except (json.JSONDecodeError, UnicodeDecodeError) as error:
        print(f"Ignoring unreadable history file {HISTORY_FILE}: {error}")
        return []

    # Callers insert into the result, so anything other than a list is
    # treated the same as a missing history.
    return history if isinstance(history, list) else []
def save_history(history):
    """Save processing history to JSON file.

    The data is written to a temporary file in the same directory and then
    renamed over ``HISTORY_FILE``, so a reader never observes a half-written
    file and an interrupted write leaves the previous history intact.

    Args:
        history: JSON-serialisable list of history records.
    """
    import tempfile

    target_dir = os.path.dirname(os.path.abspath(HISTORY_FILE))
    fd, tmp_path = tempfile.mkstemp(
        dir=target_dir, prefix=".history_", suffix=".tmp"
    )
    try:
        with os.fdopen(fd, 'w') as f:
            json.dump(history, f, indent=2)
        os.replace(tmp_path, HISTORY_FILE)
    except BaseException:
        # Do not leave stray temp files behind when serialisation or the
        # rename fails; the original error is re-raised unchanged.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
# Serialises the load -> insert -> save sequence below. This file calls
# add_to_history() both from the background video worker thread and from the
# image handler, and without a lock two overlapping calls can each load the
# same list and one record is lost.
_HISTORY_LOCK = threading.Lock()


def add_to_history(input_path, output_path, model_name, process_type, status="completed"):
    """Add a record to history.

    Args:
        input_path: Path of the source image or video.
        output_path: Path of the generated result.
        model_name: Name of the model used for processing.
        process_type: Kind of job; callers in this file pass ``"image"`` or
            ``"video"``.
        status: Outcome stored with the record (defaults to ``"completed"``).
    """
    record = {
        "timestamp": datetime.datetime.now().isoformat(),
        "input_path": input_path,
        "output_path": output_path,
        "model_name": model_name,
        "process_type": process_type,
        "status": status,
    }
    with _HISTORY_LOCK:
        history = load_history()
        history.insert(0, record)  # Add to beginning so the newest comes first
        save_history(history)
def load_generator(model_name):
    """Build the generator network for ``model_name`` and place it on the CPU.

    Supported names are "4xGRL", "4xRRDB", "2xRRDB" and "4xDAT". The matching
    checkpoint is downloaded first when it is not already on disk.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    # (model name, loader family, checkpoint path, upscale factor)
    known_models = (
        ("4xGRL", "grl", "pretrained/4x_APISR_GRL_GAN_generator.pth", 4),
        ("4xRRDB", "rrdb", "pretrained/4x_APISR_RRDB_GAN_generator.pth", 4),
        ("2xRRDB", "rrdb", "pretrained/2x_APISR_RRDB_GAN_generator.pth", 2),
        ("4xDAT", "dat", "pretrained/4x_APISR_DAT_GAN_generator.pth", 4),
    )

    for name, family, checkpoint, factor in known_models:
        if model_name == name:
            break
    else:
        raise ValueError(f"Model {model_name} not supported")

    auto_download_if_needed(checkpoint)

    if family == "grl":
        network = load_grl(checkpoint, scale=factor)
    elif family == "rrdb":
        network = load_rrdb(checkpoint, scale=factor)
    else:
        network = load_dat(checkpoint, scale=factor)

    return network.to(device='cpu')
def inference_image(img_path, model_name):
    """Process a single image.

    Args:
        img_path: Path of the uploaded image, or ``None`` when nothing was
            uploaded.
        model_name: Model name understood by ``load_generator``.

    Returns:
        tuple: ``(image, status)``. ``image`` is the result converted from
        BGR to RGB for display, or ``None`` on failure; ``status`` is a
        human-readable message naming the saved file or the error.
    """
    if img_path is None:
        return None, "❌ Please upload an image first"

    try:
        generator = load_generator(model_name)
        print("Processing image:", img_path)
        print("Time:", datetime.datetime.now(pytz.timezone('US/Eastern')))

        # Process image
        super_resolved_img = super_resolve_img(
            generator, img_path, output_path=None,
            downsample_threshold=720, crop_for_4x=True
        )

        # Save output under a millisecond-timestamped name
        timestamp = int(time.time() * 1000)
        output_name = f"image_{timestamp}.png"
        output_path = os.path.join(OUTPUT_DIR, "images", output_name)
        save_image(super_resolved_img, output_path)

        # Load and convert for display. cv2.imread returns None instead of
        # raising when the file cannot be read, so check it explicitly to
        # report a meaningful error.
        outputs = cv2.imread(output_path)
        if outputs is None:
            raise RuntimeError(f"Could not read back saved image: {output_path}")
        outputs = cv2.cvtColor(outputs, cv2.COLOR_BGR2RGB)

        # Add to history
        add_to_history(img_path, output_path, model_name, "image")
        return outputs, f"✅ Saved to: {output_path}"
    except Exception as error:
        # UI boundary: surface the failure in the status box and the log
        # instead of crashing the request.
        print(f"Image processing failed for {img_path}: {error}")
        return None, f"❌ Error: {str(error)}"
def process_video_frame_by_frame(video_path, model_name, task_id):
    """Process video frame by frame.

    Every decoded frame is written to a temporary PNG, super-resolved, and
    saved as an output PNG; the output frames are then encoded to an H.264
    MP4 with ffmpeg. Progress and the final outcome are published in
    ``processing_status[task_id]``. Failures are recorded there (status
    ``"error"``) rather than raised.

    Args:
        video_path: Path of the input video.
        model_name: Model name understood by ``load_generator``.
        task_id: Key under which status updates are stored.
    """
    import subprocess

    cap = None
    temp_dir = None
    try:
        processing_status[task_id] = {"status": "processing", "progress": 0}

        # Load model
        generator = load_generator(model_name)

        # Open video
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError("Cannot open video file")

        # Get video properties
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Prepare output
        timestamp = int(time.time() * 1000)
        output_name = f"video_{timestamp}.mp4"
        output_path = os.path.join(OUTPUT_DIR, "videos", output_name)

        # Create temporary directory for frames
        temp_dir = f"temp_frames_{timestamp}"
        os.makedirs(temp_dir, exist_ok=True)

        # Process frames
        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Save frame temporarily (super_resolve_img is given a file path)
            temp_frame_path = os.path.join(temp_dir, f"frame_{frame_count:06d}.png")
            cv2.imwrite(temp_frame_path, frame)

            # Super resolve frame
            super_resolved_img = super_resolve_img(
                generator, temp_frame_path, output_path=None,
                downsample_threshold=720, crop_for_4x=True
            )

            # Save processed frame
            output_frame_path = os.path.join(temp_dir, f"output_{frame_count:06d}.png")
            save_image(super_resolved_img, output_frame_path)

            frame_count += 1
            # The reported frame count may be 0 or smaller than the real
            # number of frames; avoid dividing by zero and cap at 100%.
            if total_frames > 0:
                progress = min(100, int((frame_count / total_frames) * 100))
            else:
                progress = 0
            processing_status[task_id] = {"status": "processing", "progress": progress}
            print(f"Task {task_id}: Processed frame {frame_count}/{total_frames} ({progress}%)")

        if frame_count == 0:
            raise ValueError("No frames could be read from video file")

        # Combine frames into video using ffmpeg
        print(f"Task {task_id}: Combining frames into video...")
        processing_status[task_id] = {"status": "encoding", "progress": 100}

        command = ["ffmpeg", "-y"]
        if fps > 0:
            # Only pass a frame rate when the container reported a usable
            # one; otherwise let ffmpeg fall back to its own default.
            command += ["-framerate", str(fps)]
        command += [
            "-i", os.path.join(temp_dir, "output_%06d.png"),
            "-c:v", "libx264", "-pix_fmt", "yuv420p",
            output_path,
        ]
        # Argument list (no shell) so paths are passed verbatim; check=True
        # turns a failed encode into an error instead of a false "completed".
        subprocess.run(command, check=True)

        processing_status[task_id] = {"status": "completed", "progress": 100, "output": output_path}
        add_to_history(video_path, output_path, model_name, "video")
        print(f"Task {task_id}: Completed! Output: {output_path}")
    except Exception as error:
        processing_status[task_id] = {"status": "error", "error": str(error)}
        print(f"Task {task_id}: Error - {error}")
    finally:
        # Clean up on success and on failure alike.
        if cap is not None:
            cap.release()
        if temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)
def video_queue_worker():
    """Consume jobs from ``video_queue`` until a ``None`` sentinel arrives.

    Each job is a ``(task_id, video_path, model_name)`` tuple handed to
    ``process_video_frame_by_frame``. Any exception raised while handling a
    job is printed and the loop carries on with the next one. Every item
    taken from the queue, including the sentinel, is acknowledged with
    ``task_done()``.
    """
    print("Video queue worker started...")
    keep_running = True
    while keep_running:
        try:
            job = video_queue.get()
            if job is None:
                # Sentinel: leave the loop after acknowledging it below.
                keep_running = False
            else:
                job_id, source_path, chosen_model = job
                print(f"Starting task {job_id}...")
                process_video_frame_by_frame(source_path, chosen_model, job_id)
        except Exception as failure:
            print(f"Worker error: {failure}")
        finally:
            video_queue.task_done()
def submit_video(video_path, model_name):
    """Submit video to processing queue.

    Args:
        video_path: Path of the uploaded video, or ``None`` when nothing was
            uploaded.
        model_name: Model name to process the video with.

    Returns:
        tuple: ``(None, message)``. The first element is always ``None``;
        the second is the confirmation (including the task id) or the error
        text.
    """
    if video_path is None:
        return None, "❌ Please upload a video first"

    task_id = f"task_{int(time.time() * 1000)}"
    # Register the task as queued BEFORE enqueueing it: once the job is in
    # the queue the worker thread may start it immediately, and writing
    # "queued" afterwards would overwrite its "processing" status.
    processing_status[task_id] = {"status": "queued", "progress": 0}
    video_queue.put((task_id, video_path, model_name))
    return None, f"✅ Video submitted to queue! Task ID: {task_id}\nCheck status in the monitoring section."
def get_queue_status():
    """Get current queue status.

    Returns:
        str: Markdown-style text with the number of queued videos followed
        by one entry per known task (status, progress, and output path or
        error when present), or "No active tasks" when none are recorded.
    """
    parts = [
        "📊 **Queue Status**\n\n",
        f"Videos in queue: {video_queue.qsize()}\n\n",
    ]

    # Work on a snapshot: the worker thread adds and replaces entries while
    # this runs, and iterating the live dict could raise
    # "dictionary changed size during iteration".
    tasks = list(processing_status.items())
    if not tasks:
        parts.append("No active tasks")
        return "".join(parts)

    parts.append("**Active Tasks:**\n")
    for task_id, status in tasks:
        parts.append(f"\n🎬 {task_id}:\n")
        parts.append(f" Status: {status['status']}\n")
        parts.append(f" Progress: {status.get('progress', 0)}%\n")
        if 'output' in status:
            parts.append(f" Output: {status['output']}\n")
        if 'error' in status:
            parts.append(f" Error: {status['error']}\n")
    return "".join(parts)
def get_history_display():
    """Get formatted history for display.

    Returns:
        str: "No history available" when there are no records; otherwise a
        header followed by the first 50 records of the history list. Fields
        missing from a record are shown as "unknown" so one malformed entry
        cannot break the whole view.
    """
    history = load_history()
    if not history:
        return "No history available"

    parts = ["📜 **Processing History**\n\n"]
    for idx, record in enumerate(history[:50]):  # Show last 50
        process_type = str(record.get('process_type', 'unknown')).upper()
        parts.append(f"**{idx + 1}. {process_type}** - {record.get('timestamp', 'unknown')}\n")
        parts.append(f" Model: {record.get('model_name', 'unknown')}\n")
        parts.append(f" Status: {record.get('status', 'unknown')}\n")
        parts.append(f" Output: {record.get('output_path', 'unknown')}\n\n")
    return "".join(parts)
def clear_history():
    """Clear all history.

    Deletes ``HISTORY_FILE`` if it exists.

    Returns:
        tuple: ``(status_message, history_text)`` where ``history_text`` is
        the freshly rendered history view.
    """
    # Remove directly and tolerate absence: avoids the gap between an
    # exists() check and the remove() in which another thread could delete
    # or recreate the file.
    try:
        os.remove(HISTORY_FILE)
    except FileNotFoundError:
        pass
    return "✅ History cleared!", get_history_display()
if __name__ == '__main__':
    # Start background worker thread. It is a daemon thread, so it does not
    # keep the process alive once the main thread exits.
    worker_thread = threading.Thread(target=video_queue_worker, daemon=True)
    worker_thread.start()

    # Intro text shown at the top of the page (kept at column 0 so the
    # Markdown headings render as headings).
    MARKDOWN = """
# APISR: Anime Production Inspired Real-World Anime Super-Resolution (CVPR 2024)
[GitHub](https://github.com/Kiteretsu77/APISR) | [Paper](https://arxiv.org/abs/2403.01598)
APISR aims at restoring and enhancing low-quality low-resolution **anime** images and video sources with various degradations from real-world scenarios.
### ⚠️ Note: Images with short side > 720px will be downsampled to 720px (e.g., 1920x1080 → 1280x720)
### 📹 New: Video processing runs in background queue - you can close the browser and it continues!
"""

    # Create Gradio interface with Gradio 6.x syntax
    with gr.Blocks(title="APISR - Anime Super Resolution") as demo:
        gr.Markdown(MARKDOWN)

        with gr.Tabs():
            # Tab 1: Image Processing
            with gr.Tab("🖼️ Image Processing"):
                with gr.Row():
                    with gr.Column(scale=2):
                        input_image = gr.Image(type="filepath", label="Input Image")
                        image_model = gr.Dropdown(
                            choices=["2xRRDB", "4xRRDB", "4xGRL", "4xDAT"],
                            value="4xGRL",
                            label="Model"
                        )
                        image_btn = gr.Button("🚀 Process Image", variant="primary")
                    with gr.Column(scale=3):
                        output_image = gr.Image(type="numpy", label="Output Image")
                        image_status = gr.Textbox(label="Status", lines=2)

                with gr.Row():
                    gr.Examples(
                        examples=[
                            ["__assets__/lr_inputs/image-00277.png"],
                            ["__assets__/lr_inputs/image-00542.png"],
                            ["__assets__/lr_inputs/41.png"],
                            ["__assets__/lr_inputs/f91.jpg"],
                        ],
                        inputs=[input_image],
                    )

                image_btn.click(
                    fn=inference_image,
                    inputs=[input_image, image_model],
                    outputs=[output_image, image_status]
                )

            # Tab 2: Video Processing
            with gr.Tab("🎬 Video Processing"):
                gr.Markdown("""
### Video Processing Queue
Videos are processed in the background. You can submit multiple videos and close the browser - processing continues!
""")
                with gr.Row():
                    with gr.Column():
                        input_video = gr.Video(label="Input Video")
                        video_model = gr.Dropdown(
                            choices=["2xRRDB", "4xRRDB", "4xGRL", "4xDAT"],
                            value="4xGRL",
                            label="Model"
                        )
                        video_btn = gr.Button("📤 Submit to Queue", variant="primary")
                        video_status = gr.Textbox(label="Submission Status", lines=3)
                    with gr.Column():
                        gr.Markdown("### 📊 Queue Monitor")
                        queue_status = gr.Textbox(label="Queue Status", lines=15, interactive=False)
                        refresh_btn = gr.Button("🔄 Refresh Status")

                # submit_video returns None as its first value, which is
                # routed back to the video input component.
                video_btn.click(
                    fn=submit_video,
                    inputs=[input_video, video_model],
                    outputs=[input_video, video_status]
                )
                refresh_btn.click(
                    fn=get_queue_status,
                    outputs=[queue_status]
                )

                # Auto-refresh using Timer (Gradio 6.x way): poll the queue
                # status every 5 seconds.
                timer = gr.Timer(value=5, active=True)
                timer.tick(
                    fn=get_queue_status,
                    outputs=[queue_status]
                )

            # Tab 3: History
            with gr.Tab("📜 History"):
                gr.Markdown("### Processing History")
                with gr.Row():
                    refresh_history_btn = gr.Button("🔄 Refresh History")
                    clear_history_btn = gr.Button("🗑️ Clear History", variant="stop")
                history_display = gr.Textbox(label="History", lines=20, interactive=False)
                clear_status = gr.Textbox(label="Status", lines=1, visible=True)

                refresh_history_btn.click(
                    fn=get_history_display,
                    outputs=[history_display]
                )
                clear_history_btn.click(
                    fn=clear_history,
                    outputs=[clear_status, history_display]
                )

        # Auto-load history on page load
        demo.load(fn=get_history_display, outputs=[history_display])

    # Launch the app
    demo.queue()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )