Spaces:
Running
on
Zero
Running
on
Zero
Tonic
commited on
fix directory arguments
Browse files
main.py
CHANGED
|
@@ -333,19 +333,27 @@ def chords_string_to_list(chords: str):
|
|
| 333 |
chrd_times = [x.split(',') for x in chords[1:-1].split('),(')]
|
| 334 |
return [(x[0], float(x[1])) for x in chrd_times]
|
| 335 |
|
| 336 |
-
# Add this before model loading
|
| 337 |
def patch_jasco_cache():
|
| 338 |
"""Monkey patch JASCO cache initialization"""
|
| 339 |
from audiocraft.modules import jasco_conditioners
|
| 340 |
|
| 341 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 342 |
|
| 343 |
def new_init(self, *args, **kwargs):
|
| 344 |
if 'cache_path' in kwargs:
|
| 345 |
kwargs['cache_path'] = os.path.join(CACHE_DIR, 'drum_cache')
|
| 346 |
return original_init(self, *args, **kwargs)
|
| 347 |
|
| 348 |
-
jasco_conditioners
|
|
|
|
|
|
|
|
|
|
| 349 |
|
| 350 |
# Apply the patch
|
| 351 |
patch_jasco_cache()
|
|
@@ -364,32 +372,24 @@ def load_model(version='facebook/jasco-chords-drums-melody-400M'):
|
|
| 364 |
cache_path = os.path.join(CACHE_DIR, version.replace('/', '_'))
|
| 365 |
os.makedirs(cache_path, exist_ok=True)
|
| 366 |
|
| 367 |
-
# Set
|
| 368 |
os.environ['AUDIOCRAFT_CACHE_DIR'] = cache_path
|
| 369 |
os.environ['TRANSFORMERS_CACHE'] = os.path.join(cache_path, 'transformers')
|
| 370 |
|
| 371 |
-
# Initialize model with custom cache configuration
|
| 372 |
-
model_kwargs = {
|
| 373 |
-
'device': 'cuda',
|
| 374 |
-
'cache_dir': cache_path,
|
| 375 |
-
'model_cache_dir': cache_path
|
| 376 |
-
}
|
| 377 |
-
|
| 378 |
# Initialize chord mapping
|
| 379 |
mapping_file = initialize_chord_mapping()
|
| 380 |
os.environ['AUDIOCRAFT_CHORD_MAPPING'] = mapping_file
|
| 381 |
|
| 382 |
-
# Load the model with
|
| 383 |
MODEL = JASCO.get_pretrained(
|
| 384 |
version,
|
| 385 |
-
device='cuda'
|
| 386 |
-
cache_dir=cache_path,
|
| 387 |
-
local_files_only=False
|
| 388 |
)
|
| 389 |
MODEL.name = version
|
| 390 |
|
| 391 |
-
# Configure model paths
|
| 392 |
-
MODEL
|
|
|
|
| 393 |
|
| 394 |
# Load the chord mapping
|
| 395 |
with open(mapping_file, 'rb') as f:
|
|
|
|
| 333 |
chrd_times = [x.split(',') for x in chords[1:-1].split('),(')]
|
| 334 |
return [(x[0], float(x[1])) for x in chrd_times]
|
| 335 |
|
|
|
|
| 336 |
def patch_jasco_cache():
|
| 337 |
"""Monkey patch JASCO cache initialization"""
|
| 338 |
from audiocraft.modules import jasco_conditioners
|
| 339 |
|
| 340 |
+
if hasattr(jasco_conditioners, 'DrumConditioner'):
|
| 341 |
+
original_init = jasco_conditioners.DrumConditioner.__init__
|
| 342 |
+
elif hasattr(jasco_conditioners, 'DrumsConditioner'):
|
| 343 |
+
original_init = jasco_conditioners.DrumsConditioner.__init__
|
| 344 |
+
else:
|
| 345 |
+
print("Warning: Could not find DrumConditioner class")
|
| 346 |
+
return
|
| 347 |
|
| 348 |
def new_init(self, *args, **kwargs):
|
| 349 |
if 'cache_path' in kwargs:
|
| 350 |
kwargs['cache_path'] = os.path.join(CACHE_DIR, 'drum_cache')
|
| 351 |
return original_init(self, *args, **kwargs)
|
| 352 |
|
| 353 |
+
if hasattr(jasco_conditioners, 'DrumConditioner'):
|
| 354 |
+
jasco_conditioners.DrumConditioner.__init__ = new_init
|
| 355 |
+
elif hasattr(jasco_conditioners, 'DrumsConditioner'):
|
| 356 |
+
jasco_conditioners.DrumsConditioner.__init__ = new_init
|
| 357 |
|
| 358 |
# Apply the patch
|
| 359 |
patch_jasco_cache()
|
|
|
|
| 372 |
cache_path = os.path.join(CACHE_DIR, version.replace('/', '_'))
|
| 373 |
os.makedirs(cache_path, exist_ok=True)
|
| 374 |
|
| 375 |
+
# Set environment variables for caching
|
| 376 |
os.environ['AUDIOCRAFT_CACHE_DIR'] = cache_path
|
| 377 |
os.environ['TRANSFORMERS_CACHE'] = os.path.join(cache_path, 'transformers')
|
| 378 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 379 |
# Initialize chord mapping
|
| 380 |
mapping_file = initialize_chord_mapping()
|
| 381 |
os.environ['AUDIOCRAFT_CHORD_MAPPING'] = mapping_file
|
| 382 |
|
| 383 |
+
# Load the model with only supported parameters
|
| 384 |
MODEL = JASCO.get_pretrained(
|
| 385 |
version,
|
| 386 |
+
device='cuda'
|
|
|
|
|
|
|
| 387 |
)
|
| 388 |
MODEL.name = version
|
| 389 |
|
| 390 |
+
# Configure model paths after loading
|
| 391 |
+
if hasattr(MODEL, '_cache_dir'):
|
| 392 |
+
MODEL._cache_dir = cache_path
|
| 393 |
|
| 394 |
# Load the chord mapping
|
| 395 |
with open(mapping_file, 'rb') as f:
|