import os, sys
import cv2
import time
import datetime, pytz
import gradio as gr
import torch
import numpy as np
from torchvision.utils import save_image
import json
import threading
from queue import Queue
from pathlib import Path
import shutil

# Import files from the local folder
root_path = os.path.abspath('.')
sys.path.append(root_path)
from test_code.inference import super_resolve_img
from test_code.test_utils import load_grl, load_rrdb, load_dat

# Global configuration
OUTPUT_DIR = "outputs"
HISTORY_FILE = "history.json"
VIDEO_QUEUE_FILE = "video_queue.json"
video_queue = Queue()
processing_status = {}
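# video_queue holds (task_id, video_path, model_name) tuples submitted from the UI;
# processing_status maps each task ID to a small status dict and is shared between the
# background worker thread and the Gradio callbacks (plain dict reads/writes only).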

# Initialize directories
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(os.path.join(OUTPUT_DIR, "images"), exist_ok=True)
os.makedirs(os.path.join(OUTPUT_DIR, "videos"), exist_ok=True)


def auto_download_if_needed(weight_path):
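    """Download the requested pretrained generator weight from the APISR GitHub releases if it is missing locally."""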
    if os.path.exists(weight_path):
        return
    
    if not os.path.exists("pretrained"):
        os.makedirs("pretrained")
    
    if weight_path == "pretrained/4x_APISR_RRDB_GAN_generator.pth":
        os.system("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.2.0/4x_APISR_RRDB_GAN_generator.pth")
        os.system("mv 4x_APISR_RRDB_GAN_generator.pth pretrained")
    
    if weight_path == "pretrained/4x_APISR_GRL_GAN_generator.pth":
        os.system("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.1.0/4x_APISR_GRL_GAN_generator.pth")
        os.system("mv 4x_APISR_GRL_GAN_generator.pth pretrained")
        
    if weight_path == "pretrained/2x_APISR_RRDB_GAN_generator.pth":
        os.system("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.1.0/2x_APISR_RRDB_GAN_generator.pth")
        os.system("mv 2x_APISR_RRDB_GAN_generator.pth pretrained")
    
    if weight_path == "pretrained/4x_APISR_DAT_GAN_generator.pth":
        os.system("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.3.0/4x_APISR_DAT_GAN_generator.pth")
        os.system("mv 4x_APISR_DAT_GAN_generator.pth pretrained")


def load_history():
    """Load processing history from JSON file"""
    if os.path.exists(HISTORY_FILE):
        with open(HISTORY_FILE, 'r') as f:
            return json.load(f)
    return []


def save_history(history):
    """Save processing history to JSON file"""
    with open(HISTORY_FILE, 'w') as f:
        json.dump(history, f, indent=2)


def add_to_history(input_path, output_path, model_name, process_type, status="completed"):
    """Add a record to history"""
    history = load_history()
    record = {
        "timestamp": datetime.datetime.now().isoformat(),
        "input_path": input_path,
        "output_path": output_path,
        "model_name": model_name,
        "process_type": process_type,
        "status": status
    }
    history.insert(0, record)  # Add to beginning
    save_history(history)


def load_generator(model_name):
    """Load the appropriate model"""
    if model_name == "4xGRL":
        weight_path = "pretrained/4x_APISR_GRL_GAN_generator.pth"
        auto_download_if_needed(weight_path)
        generator = load_grl(weight_path, scale=4)
        
    elif model_name == "4xRRDB":
        weight_path = "pretrained/4x_APISR_RRDB_GAN_generator.pth"
        auto_download_if_needed(weight_path)
        generator = load_rrdb(weight_path, scale=4)
        
    elif model_name == "2xRRDB":
        weight_path = "pretrained/2x_APISR_RRDB_GAN_generator.pth"
        auto_download_if_needed(weight_path)
        generator = load_rrdb(weight_path, scale=2)
        
    elif model_name == "4xDAT":
        weight_path = "pretrained/4x_APISR_DAT_GAN_generator.pth"
        auto_download_if_needed(weight_path)
        generator = load_dat(weight_path, scale=4)
    else:
        raise ValueError(f"Model {model_name} not supported")
    
    return generator.to(device='cpu')


def inference_image(img_path, model_name):
    """Process a single image"""
    try:
        if img_path is None:
            return None, "❌ Please upload an image first"
            
        generator = load_generator(model_name)
        
        print("Processing image:", img_path)
        print("Time:", datetime.datetime.now(pytz.timezone('US/Eastern')))

        # Process image
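        # downsample_threshold=720 matches the note in the app description: inputs whose
        # short side exceeds 720px are downsampled before super-resolution.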
        super_resolved_img = super_resolve_img(
            generator, img_path, output_path=None, 
            downsample_threshold=720, crop_for_4x=True
        )
        
        # Save output
        timestamp = int(time.time() * 1000)
        output_name = f"image_{timestamp}.png"
        output_path = os.path.join(OUTPUT_DIR, "images", output_name)
        save_image(super_resolved_img, output_path)
        
        # Load and convert for display
        outputs = cv2.imread(output_path)
        outputs = cv2.cvtColor(outputs, cv2.COLOR_BGR2RGB)
        
        # Add to history
        add_to_history(img_path, output_path, model_name, "image")
        
        return outputs, f"✅ Saved to: {output_path}"
    
    except Exception as error:
        return None, f"❌ Error: {str(error)}"


def process_video_frame_by_frame(video_path, model_name, task_id):
    """Process video frame by frame"""
    try:
        processing_status[task_id] = {"status": "processing", "progress": 0}
        
        # Load model
        generator = load_generator(model_name)
        
        # Open video
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError("Cannot open video file")
        
        # Get video properties
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        
        # Prepare output
        timestamp = int(time.time() * 1000)
        output_name = f"video_{timestamp}.mp4"
        output_path = os.path.join(OUTPUT_DIR, "videos", output_name)
        
        # Create temporary directory for frames
        temp_dir = f"temp_frames_{timestamp}"
        os.makedirs(temp_dir, exist_ok=True)
        
        # Process frames
        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            
            # Save frame temporarily
            temp_frame_path = os.path.join(temp_dir, f"frame_{frame_count:06d}.png")
            cv2.imwrite(temp_frame_path, frame)
            
            # Super resolve frame
            super_resolved_img = super_resolve_img(
                generator, temp_frame_path, output_path=None,
                downsample_threshold=720, crop_for_4x=True
            )
            
            # Save processed frame
            output_frame_path = os.path.join(temp_dir, f"output_{frame_count:06d}.png")
            save_image(super_resolved_img, output_frame_path)
            
            frame_count += 1
            # Some containers report 0 for CAP_PROP_FRAME_COUNT; guard against division by zero
            progress = int((frame_count / max(total_frames, 1)) * 100)
            processing_status[task_id] = {"status": "processing", "progress": progress}
            
            print(f"Task {task_id}: Processed frame {frame_count}/{total_frames} ({progress}%)")
        
        cap.release()
        
        # Combine frames into video using ffmpeg
        print(f"Task {task_id}: Combining frames into video...")
        processing_status[task_id] = {"status": "encoding", "progress": 100}
        
        # Quote the frame pattern and output path so ffmpeg still works if the paths contain spaces
        os.system(f'ffmpeg -y -framerate {fps} -i "{temp_dir}/output_%06d.png" -c:v libx264 -pix_fmt yuv420p "{output_path}"')
        
        # Clean up
        shutil.rmtree(temp_dir)
        
        processing_status[task_id] = {"status": "completed", "progress": 100, "output": output_path}
        add_to_history(video_path, output_path, model_name, "video")
        
        print(f"Task {task_id}: Completed! Output: {output_path}")
        
    except Exception as error:
        processing_status[task_id] = {"status": "error", "error": str(error)}
        print(f"Task {task_id}: Error - {error}")


def video_queue_worker():
    """Background worker to process video queue"""
    print("Video queue worker started...")
    while True:
        try:
            task = video_queue.get()
            if task is None:  # Poison pill to stop worker
                break
            
            task_id, video_path, model_name = task
            print(f"Starting task {task_id}...")
            process_video_frame_by_frame(video_path, model_name, task_id)
            
        except Exception as e:
            print(f"Worker error: {e}")
        finally:
            video_queue.task_done()


def submit_video(video_path, model_name):
    """Submit video to processing queue"""
    if video_path is None:
        return None, "❌ Please upload a video first"
    
    task_id = f"task_{int(time.time() * 1000)}"
    video_queue.put((task_id, video_path, model_name))
    processing_status[task_id] = {"status": "queued", "progress": 0}
    
    return None, f"✅ Video submitted to queue! Task ID: {task_id}\nCheck status in the monitoring section."


def get_queue_status():
    """Get current queue status"""
    status_text = "📊 **Queue Status**\n\n"
    status_text += f"Videos in queue: {video_queue.qsize()}\n\n"
    
    if processing_status:
        status_text += "**Active Tasks:**\n"
        for task_id, status in processing_status.items():
            status_text += f"\n🎬 {task_id}:\n"
            status_text += f"  Status: {status['status']}\n"
            status_text += f"  Progress: {status.get('progress', 0)}%\n"
            if 'output' in status:
                status_text += f"  Output: {status['output']}\n"
            if 'error' in status:
                status_text += f"  Error: {status['error']}\n"
    else:
        status_text += "No active tasks"
    
    return status_text


def get_history_display():
    """Get formatted history for display"""
    history = load_history()
    if not history:
        return "No history available"
    
    history_text = "📜 **Processing History**\n\n"
    for idx, record in enumerate(history[:50]):  # Show last 50
        history_text += f"**{idx + 1}. {record['process_type'].upper()}** - {record['timestamp']}\n"
        history_text += f"   Model: {record['model_name']}\n"
        history_text += f"   Status: {record['status']}\n"
        history_text += f"   Output: {record['output_path']}\n\n"
    
    return history_text


def clear_history():
    """Clear all history"""
    if os.path.exists(HISTORY_FILE):
        os.remove(HISTORY_FILE)
    return "✅ History cleared!", get_history_display()


if __name__ == '__main__':
    
    # Start background worker thread
    worker_thread = threading.Thread(target=video_queue_worker, daemon=True)
    worker_thread.start()
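    # The worker thread is a daemon, so it exits with the main process; to stop it
    # explicitly, put None on video_queue (the poison pill handled in video_queue_worker).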
    
    MARKDOWN = """
# APISR: Anime Production Inspired Real-World Anime Super-Resolution (CVPR 2024)

[GitHub](https://github.com/Kiteretsu77/APISR) | [Paper](https://arxiv.org/abs/2403.01598)

APISR aims at restoring and enhancing low-quality low-resolution **anime** images and video sources with various degradations from real-world scenarios.

### ⚠️ Note: Images with a short side > 720px will be downsampled to 720px (e.g., 1920x1080 → 1280x720)
### 📹 New: Video processing runs in a background queue - you can close the browser and it continues!
"""

    # Create the Gradio interface using the Blocks API
    with gr.Blocks(title="APISR - Anime Super Resolution") as demo:
        
        gr.Markdown(MARKDOWN)
        
        with gr.Tabs():
            # Tab 1: Image Processing
            with gr.Tab("🖼️ Image Processing"):
                with gr.Row():
                    with gr.Column(scale=2):
                        input_image = gr.Image(type="filepath", label="Input Image")
                        image_model = gr.Dropdown(
                            choices=["2xRRDB", "4xRRDB", "4xGRL", "4xDAT"],
                            value="4xGRL",
                            label="Model"
                        )
                        image_btn = gr.Button("🚀 Process Image", variant="primary")
                    
                    with gr.Column(scale=3):
                        output_image = gr.Image(type="numpy", label="Output Image")
                        image_status = gr.Textbox(label="Status", lines=2)
                
                with gr.Row():
                    gr.Examples(
                        examples=[
                            ["__assets__/lr_inputs/image-00277.png"],
                            ["__assets__/lr_inputs/image-00542.png"],
                            ["__assets__/lr_inputs/41.png"],
                            ["__assets__/lr_inputs/f91.jpg"],
                        ],
                        inputs=[input_image],
                    )
                
                image_btn.click(
                    fn=inference_image, 
                    inputs=[input_image, image_model], 
                    outputs=[output_image, image_status]
                )
            
            # Tab 2: Video Processing
            with gr.Tab("🎬 Video Processing"):
                gr.Markdown("""
                ### Video Processing Queue
                Videos are processed in the background. You can submit multiple videos and close the browser - processing continues!
                """)
                
                with gr.Row():
                    with gr.Column():
                        input_video = gr.Video(label="Input Video")
                        video_model = gr.Dropdown(
                            choices=["2xRRDB", "4xRRDB", "4xGRL", "4xDAT"],
                            value="4xGRL",
                            label="Model"
                        )
                        video_btn = gr.Button("📤 Submit to Queue", variant="primary")
                        video_status = gr.Textbox(label="Submission Status", lines=3)
                    
                    with gr.Column():
                        gr.Markdown("### 📊 Queue Monitor")
                        queue_status = gr.Textbox(label="Queue Status", lines=15, interactive=False)
                        refresh_btn = gr.Button("🔄 Refresh Status")
                
                video_btn.click(
                    fn=submit_video,
                    inputs=[input_video, video_model],
                    outputs=[input_video, video_status]
                )
                
                refresh_btn.click(
                    fn=get_queue_status,
                    outputs=[queue_status]
                )
                
                # Auto-refresh the queue status every 5 seconds using gr.Timer
                timer = gr.Timer(value=5, active=True)
                timer.tick(
                    fn=get_queue_status,
                    outputs=[queue_status]
                )
            
            # Tab 3: History
            with gr.Tab("📜 History"):
                gr.Markdown("### Processing History")
                
                with gr.Row():
                    refresh_history_btn = gr.Button("🔄 Refresh History")
                    clear_history_btn = gr.Button("🗑️ Clear History", variant="stop")
                
                history_display = gr.Textbox(label="History", lines=20, interactive=False)
                clear_status = gr.Textbox(label="Status", lines=1, visible=True)
                
                refresh_history_btn.click(
                    fn=get_history_display,
                    outputs=[history_display]
                )
                
                clear_history_btn.click(
                    fn=clear_history,
                    outputs=[clear_status, history_display]
                )
                
                # Auto-load history on page load
                demo.load(fn=get_history_display, outputs=[history_display])

    # Launch the app
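    # demo.queue() enables Gradio's built-in event queue so long-running callbacks are
    # scheduled rather than handled synchronously by the web server.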
    demo.queue()
    demo.launch(
        server_name="0.0.0.0", 
        server_port=7860,
        share=False
    )