Arrcttacsrks committed
Commit 965d342 · verified · 1 Parent(s): 97fab97

Update app.py

Files changed (1)
  1. app.py +79 -87
app.py CHANGED
@@ -1,20 +1,23 @@
-import os
-import sys
+'''
+    Gradio demo (almost the same code as the one used in Huggingface space)
+'''
+import os, sys
 import cv2
 import time
-import datetime
-import pytz
+import datetime, pytz
 import gradio as gr
 import torch
 import numpy as np
 from torchvision.utils import save_image

+
 # Import files from the local folder
 root_path = os.path.abspath('.')
 sys.path.append(root_path)
 from test_code.inference import super_resolve_img
 from test_code.test_utils import load_grl, load_rrdb, load_dat

+
 def auto_download_if_needed(weight_path):
     if os.path.exists(weight_path):
         return
@@ -22,86 +25,62 @@ def auto_download_if_needed(weight_path):
     if not os.path.exists("pretrained"):
         os.makedirs("pretrained")

-    weight_mappings = {
-        "pretrained/4x_APISR_RRDB_GAN_generator.pth": "v0.2.0/4x_APISR_RRDB_GAN_generator.pth",
-        "pretrained/4x_APISR_GRL_GAN_generator.pth": "v0.1.0/4x_APISR_GRL_GAN_generator.pth",
-        "pretrained/2x_APISR_RRDB_GAN_generator.pth": "v0.1.0/2x_APISR_RRDB_GAN_generator.pth",
-        "pretrained/4x_APISR_DAT_GAN_generator.pth": "v0.3.0/4x_APISR_DAT_GAN_generator.pth"
-    }
-
-    if weight_path in weight_mappings:
-        version_path = weight_mappings[weight_path]
-        filename = os.path.basename(weight_path)
-        os.system(f"wget https://github.com/Kiteretsu77/APISR/releases/download/{version_path}")
-        os.system(f"mv {filename} pretrained")
-
-def load_model_with_device(loader_func, weight_path, scale, device):
-    # Check if CUDA is available, otherwise fall back to CPU
-    if device.type == 'cuda' and not torch.cuda.is_available():
-        device = torch.device('cpu')  # Force CPU if CUDA is not available
+    if weight_path == "pretrained/4x_APISR_RRDB_GAN_generator.pth":
+        os.system("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.2.0/4x_APISR_RRDB_GAN_generator.pth")
+        os.system("mv 4x_APISR_RRDB_GAN_generator.pth pretrained")

-    # Load the state dict with map_location to handle CPU if necessary
-    state_dict = torch.load(weight_path, map_location=device)
-
-    # Initialize the model
-    generator = loader_func(weight_path, scale=scale)
+    if weight_path == "pretrained/4x_APISR_GRL_GAN_generator.pth":
+        os.system("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.1.0/4x_APISR_GRL_GAN_generator.pth")
+        os.system("mv 4x_APISR_GRL_GAN_generator.pth pretrained")
+
+    if weight_path == "pretrained/2x_APISR_RRDB_GAN_generator.pth":
+        os.system("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.1.0/2x_APISR_RRDB_GAN_generator.pth")
+        os.system("mv 2x_APISR_RRDB_GAN_generator.pth pretrained")

-    # Load the state dict and move to device
-    if hasattr(generator, 'load_state_dict'):
-        generator.load_state_dict(state_dict)
-    generator = generator.to(device)
+    if weight_path == "pretrained/4x_APISR_DAT_GAN_generator.pth":
+        os.system("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.3.0/4x_APISR_DAT_GAN_generator.pth")
+        os.system("mv 4x_APISR_DAT_GAN_generator.pth pretrained")

-    return generator
+

 def inference(img_path, model_name):
+
     try:
-        # Determine device - use GPU if available, otherwise CPU
-        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         weight_dtype = torch.float32

-        # Model configurations
-        model_configs = {
-            "4xGRL": ("pretrained/4x_APISR_GRL_GAN_generator.pth", load_grl, 4),
-            "4xRRDB": ("pretrained/4x_APISR_RRDB_GAN_generator.pth", load_rrdb, 4),
-            "2xRRDB": ("pretrained/2x_APISR_RRDB_GAN_generator.pth", load_rrdb, 2),
-            "4xDAT": ("pretrained/4x_APISR_DAT_GAN_generator.pth", load_dat, 4)
-        }
-
-        if model_name not in model_configs:
-            raise gr.Error("Unsupported model selected")
+        # Load the model
+        if model_name == "4xGRL":
+            weight_path = "pretrained/4x_APISR_GRL_GAN_generator.pth"
+            auto_download_if_needed(weight_path)
+            generator = load_grl(weight_path, scale=4)  # Directly use default way now
+
+        elif model_name == "4xRRDB":
+            weight_path = "pretrained/4x_APISR_RRDB_GAN_generator.pth"
+            auto_download_if_needed(weight_path)
+            generator = load_rrdb(weight_path, scale=4)  # Directly use default way now
+
+        elif model_name == "2xRRDB":
+            weight_path = "pretrained/2x_APISR_RRDB_GAN_generator.pth"
+            auto_download_if_needed(weight_path)
+            generator = load_rrdb(weight_path, scale=2)  # Directly use default way now
+
+        elif model_name == "4xDAT":
+            weight_path = "pretrained/4x_APISR_DAT_GAN_generator.pth"
+            auto_download_if_needed(weight_path)
+            generator = load_dat(weight_path, scale=4)  # Directly use default way now

-        weight_path, loader_func, scale = model_configs[model_name]
-        auto_download_if_needed(weight_path)
+        else:
+            raise gr.Error("We don't support such Model")

-        # Load model and move to appropriate device
-        try:
-            generator = load_model_with_device(loader_func, weight_path, scale, device)
-        except RuntimeError as e:
-            if "out of memory" in str(e):
-                # If we run out of CUDA memory, try loading on CPU instead
-                device = torch.device('cpu')
-                generator = load_model_with_device(loader_func, weight_path, scale, device)
-            else:
-                raise e
-
         generator = generator.to(dtype=weight_dtype)
-        generator.eval()  # Set to evaluation mode
-
-        print(f"Processing {img_path} on {device}")
-        print(f"Current time: {datetime.datetime.now(pytz.timezone('US/Eastern'))}")
-
-        # Process image
-        super_resolved_img = super_resolve_img(
-            generator,
-            img_path,
-            output_path=None,
-            weight_dtype=weight_dtype,
-            downsample_threshold=720,
-            crop_for_4x=True
-        )
-
-        # Save and convert output
-        store_name = f"output_{time.time()}.png"
+
+
+        print("We are processing ", img_path)
+        print("The time now is ", datetime.datetime.now(pytz.timezone('US/Eastern')))
+
+        # In default, we will automatically use crop to match 4x size
+        super_resolved_img = super_resolve_img(generator, img_path, output_path=None, weight_dtype=weight_dtype, downsample_threshold=720, crop_for_4x=True)
+        store_name = str(time.time()) + ".png"
         save_image(super_resolved_img, store_name)
         outputs = cv2.imread(store_name)
         outputs = cv2.cvtColor(outputs, cv2.COLOR_RGB2BGR)
@@ -109,11 +88,16 @@ def inference(img_path, model_name):

         return outputs

+
     except Exception as error:
-        raise gr.Error(f"Error during processing: {str(error)}")
+        raise gr.Error(f"global exception: {error}")
+
+

 if __name__ == '__main__':
-    MARKDOWN = """
+
+    MARKDOWN = \
+    """
     ## <p style='text-align: center'> APISR: Anime Production Inspired Real-World Anime Super-Resolution (CVPR 2024) </p>

     [GitHub](https://github.com/Kiteretsu77/APISR) | [Paper](https://arxiv.org/abs/2403.01598)
@@ -134,10 +118,15 @@ if __name__ == '__main__':
             with gr.Column(scale=2):
                 input_image = gr.Image(type="filepath", label="Input")
                 model_name = gr.Dropdown(
-                    ["2xRRDB", "4xRRDB", "4xGRL", "4xDAT"],
+                    [
+                        "2xRRDB",
+                        "4xRRDB",
+                        "4xGRL",
+                        "4xDAT",
+                    ],
                     type="value",
                     value="4xGRL",
-                    label="Model"
+                    label="model",
                 )
                 run_btn = gr.Button(value="Submit")

@@ -145,16 +134,19 @@ if __name__ == '__main__':
                 output_image = gr.Image(type="numpy", label="Output image")

         with gr.Row(elem_classes=["container"]):
-            gr.Examples([
-                ["__assets__/lr_inputs/image-00277.png"],
-                ["__assets__/lr_inputs/image-00542.png"],
-                ["__assets__/lr_inputs/41.png"],
-                ["__assets__/lr_inputs/f91.jpg"],
-                ["__assets__/lr_inputs/image-00440.png"],
-                ["__assets__/lr_inputs/image-00164.jpg"],
-                ["__assets__/lr_inputs/img_eva.jpeg"],
-                ["__assets__/lr_inputs/naruto.jpg"],
-            ], [input_image])
+            gr.Examples(
+                [
+                    ["__assets__/lr_inputs/image-00277.png"],
+                    ["__assets__/lr_inputs/image-00542.png"],
+                    ["__assets__/lr_inputs/41.png"],
+                    ["__assets__/lr_inputs/f91.jpg"],
+                    ["__assets__/lr_inputs/image-00440.png"],
+                    ["__assets__/lr_inputs/image-00164.jpg"],
+                    ["__assets__/lr_inputs/img_eva.jpeg"],
+                    ["__assets__/lr_inputs/naruto.jpg"],
+                ],
+                [input_image],
+            )

         run_btn.click(inference, inputs=[input_image, model_name], outputs=[output_image])

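For reference, the committed `inference()` can also be exercised without launching the Gradio UI. A minimal sketch, assuming the APISR repo root is the working directory; the direct `app` import and the output filename are illustrative and not part of this commit:

```python
# Minimal sketch (not part of the commit): call the demo's inference()
# directly instead of going through the Gradio interface. Assumes this runs
# from the APISR repo root so app.py, pretrained/ and __assets__/ resolve.
import cv2

from app import inference

# Model names match the Dropdown choices: "2xRRDB", "4xRRDB", "4xGRL", "4xDAT".
upscaled = inference("__assets__/lr_inputs/naruto.jpg", "4xGRL")

# inference() returns a NumPy array converted for gr.Image display (channels
# swapped after cv2.imread), so convert back before writing with OpenCV.
cv2.imwrite("naruto_4x.png", cv2.cvtColor(upscaled, cv2.COLOR_RGB2BGR))
```

On first use of a given model, `auto_download_if_needed` fetches the corresponding generator weights into `pretrained/` via the wget calls shown in the diff.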