Arrcttacsrks commited on
Commit
ab0e436
·
verified ·
1 Parent(s): 6c29300

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -8
app.py CHANGED
@@ -33,13 +33,27 @@ def auto_download_if_needed(weight_path):
33
  os.system(f"wget https://github.com/Kiteretsu77/APISR/releases/download/{version_path}")
34
  os.system(f"mv {filename} pretrained")
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  def inference(img_path, model_name):
37
  try:
38
  # Determine device - use GPU if available, otherwise CPU
39
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
40
  weight_dtype = torch.float32
41
 
42
- # Load the model with appropriate device mapping
43
  model_configs = {
44
  "4xGRL": ("pretrained/4x_APISR_GRL_GAN_generator.pth", load_grl, 4),
45
  "4xRRDB": ("pretrained/4x_APISR_RRDB_GAN_generator.pth", load_rrdb, 4),
@@ -53,13 +67,19 @@ def inference(img_path, model_name):
53
  weight_path, loader_func, scale = model_configs[model_name]
54
  auto_download_if_needed(weight_path)
55
 
56
- # Load model with explicit device mapping
57
- generator = loader_func(
58
- weight_path,
59
- scale=scale,
60
- map_location=device
61
- )
62
- generator = generator.to(device=device, dtype=weight_dtype)
 
 
 
 
 
 
63
 
64
  print(f"Processing {img_path} on {device}")
65
  print(f"Current time: {datetime.datetime.now(pytz.timezone('US/Eastern'))}")
 
33
  os.system(f"wget https://github.com/Kiteretsu77/APISR/releases/download/{version_path}")
34
  os.system(f"mv {filename} pretrained")
35
 
36
+ def load_model_with_device(loader_func, weight_path, scale, device):
37
+ # First load the state dict
38
+ state_dict = torch.load(weight_path, map_location=device)
39
+
40
+ # Initialize the model
41
+ generator = loader_func(weight_path, scale=scale)
42
+
43
+ # Load the state dict and move to device
44
+ if hasattr(generator, 'load_state_dict'):
45
+ generator.load_state_dict(state_dict)
46
+ generator = generator.to(device)
47
+
48
+ return generator
49
+
50
  def inference(img_path, model_name):
51
  try:
52
  # Determine device - use GPU if available, otherwise CPU
53
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
54
  weight_dtype = torch.float32
55
 
56
+ # Model configurations
57
  model_configs = {
58
  "4xGRL": ("pretrained/4x_APISR_GRL_GAN_generator.pth", load_grl, 4),
59
  "4xRRDB": ("pretrained/4x_APISR_RRDB_GAN_generator.pth", load_rrdb, 4),
 
67
  weight_path, loader_func, scale = model_configs[model_name]
68
  auto_download_if_needed(weight_path)
69
 
70
+ # Load model and move to appropriate device
71
+ try:
72
+ generator = load_model_with_device(loader_func, weight_path, scale, device)
73
+ except RuntimeError as e:
74
+ if "out of memory" in str(e):
75
+ # If we run out of CUDA memory, try loading on CPU instead
76
+ device = torch.device('cpu')
77
+ generator = load_model_with_device(loader_func, weight_path, scale, device)
78
+ else:
79
+ raise e
80
+
81
+ generator = generator.to(dtype=weight_dtype)
82
+ generator.eval() # Set to evaluation mode
83
 
84
  print(f"Processing {img_path} on {device}")
85
  print(f"Current time: {datetime.datetime.now(pytz.timezone('US/Eastern'))}")