import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Define a subset of popular languages mapped to FLORES-200 codes for better UX.
# NLLB supports 200+, but a dropdown of 200 items can be unwieldy.
# Codes reference: https://github.com/facebookresearch/flores/blob/main/flores200/README.md
LANGUAGE_CODES = {
    "English": "eng_Latn",
    "French": "fra_Latn",
    "Spanish": "spa_Latn",
    "German": "deu_Latn",
    "Chinese (Simplified)": "zho_Hans",
    "Chinese (Traditional)": "zho_Hant",
    "Hindi": "hin_Deva",
    "Arabic": "arb_Arab",
    "Russian": "rus_Cyrl",
    "Portuguese": "por_Latn",
    "Japanese": "jpn_Jpan",
    "Korean": "kor_Hang",
    "Italian": "ita_Latn",
    "Dutch": "nld_Latn",
    "Turkish": "tur_Latn",
    "Vietnamese": "vie_Latn",
    "Indonesian": "ind_Latn",
    "Persian": "pes_Arab",
    "Polish": "pol_Latn",
    "Ukrainian": "ukr_Cyrl",
    "Swahili": "swh_Latn",
    "Urdu": "urd_Arab",
    "Bengali": "ben_Beng",
    "Tamil": "tam_Taml",
}

MODEL_NAME = "facebook/nllb-200-distilled-600M"

_model = None
_tokenizer = None


def get_device():
    """Determine the best available device."""
    if torch.cuda.is_available():
        return "cuda"
    elif torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def load_model():
    """Load the model and tokenizer lazily (singleton pattern)."""
    global _model, _tokenizer
    if _model is None:
        print(f"Loading {MODEL_NAME}...")
        device = get_device()
        _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        _model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)
        print("Model loaded successfully.")
    return _model, _tokenizer


def translate_text(text, src_lang_name, tgt_lang_name):
    """Translate `text` from `src_lang_name` to `tgt_lang_name` using NLLB."""
    if not text:
        return ""
    try:
        model, tokenizer = load_model()
        device = model.device

        # Map human-readable names to NLLB/FLORES-200 codes, with fallbacks.
        src_code = LANGUAGE_CODES.get(src_lang_name, "eng_Latn")
        tgt_code = LANGUAGE_CODES.get(tgt_lang_name, "fra_Latn")

        # Prepare inputs. src_lang must be set *before* tokenizing so the
        # source-language token is prepended to the input sequence.
        tokenizer.src_lang = src_code
        inputs = tokenizer(text, return_tensors="pt").to(device)

        # Generate the translation.
        # forced_bos_token_id forces the model to start generating in the
        # target language. convert_tokens_to_ids is used here instead of the
        # older tokenizer.lang_code_to_id mapping, which was removed from the
        # NLLB tokenizers in recent transformers releases.
        generated_tokens = model.generate(
            **inputs,
            forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_code),
            max_length=200,
        )

        # Decode the output tokens back to text.
        result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
        return result
    except Exception as e:
        return f"Error during translation: {str(e)}"
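

# Minimal usage sketch (an assumption, not part of the original module):
# running this file directly exercises translate_text end to end. The first
# call triggers the lazy model load in load_model(); later calls reuse the
# cached singleton, so only the first translation pays the load cost.
if __name__ == "__main__":
    sample = "The weather is beautiful today."
    print(translate_text(sample, "English", "French"))
    print(translate_text(sample, "English", "Japanese"))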