Arrcttacsrks commited on
Commit
97fab97
·
verified ·
1 Parent(s): ab0e436

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -4
app.py CHANGED
@@ -1,7 +1,9 @@
1
- import os, sys
 
2
  import cv2
3
  import time
4
- import datetime, pytz
 
5
  import gradio as gr
6
  import torch
7
  import numpy as np
@@ -34,7 +36,11 @@ def auto_download_if_needed(weight_path):
34
  os.system(f"mv {filename} pretrained")
35
 
36
  def load_model_with_device(loader_func, weight_path, scale, device):
37
- # First load the state dict
 
 
 
 
38
  state_dict = torch.load(weight_path, map_location=device)
39
 
40
  # Initialize the model
@@ -152,4 +158,4 @@ if __name__ == '__main__':
152
 
153
  run_btn.click(inference, inputs=[input_image, model_name], outputs=[output_image])
154
 
155
- block.launch()
 
1
+ import os
2
+ import sys
3
  import cv2
4
  import time
5
+ import datetime
6
+ import pytz
7
  import gradio as gr
8
  import torch
9
  import numpy as np
 
36
  os.system(f"mv {filename} pretrained")
37
 
38
  def load_model_with_device(loader_func, weight_path, scale, device):
39
+ # Check if CUDA is available, otherwise fall back to CPU
40
+ if device.type == 'cuda' and not torch.cuda.is_available():
41
+ device = torch.device('cpu') # Force CPU if CUDA is not available
42
+
43
+ # Load the state dict with map_location to handle CPU if necessary
44
  state_dict = torch.load(weight_path, map_location=device)
45
 
46
  # Initialize the model
 
158
 
159
  run_btn.click(inference, inputs=[input_image, model_name], outputs=[output_image])
160
 
161
+ block.launch()