"""
Inference Module for Fine-Tuned SQL Model
Loads from: Local checkpoint OR Hugging Face Hub
"""
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from dotenv import load_dotenv
load_dotenv()
# =============================================================================
# CONFIGURATION
# =============================================================================
# Hugging Face Model ID (set in .env or Streamlit secrets)
HF_MODEL_ID = os.getenv("HF_MODEL_ID")
# Local paths
LOCAL_MODEL_DIR = "outputs/finetuning/checkpoints/final"
BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
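# Illustrative .env entry for loading from the Hub (the model ID below is a
# hypothetical placeholder, not a published repository):
#
#   HF_MODEL_ID=your-username/tinyllama-sql-finetuned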
# =============================================================================
# SQL GENERATOR CLASS
# =============================================================================
class SQLGenerator:
    """SQL generation using the fine-tuned model."""

    def __init__(self):
        """Load the fine-tuned model from a local checkpoint or the Hugging Face Hub."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Device: {self.device}")
        load_path = self._get_model_path()

        # Load tokenizer and model with memory optimization
        print(f"Loading model from: {load_path}")
        self.tokenizer = AutoTokenizer.from_pretrained(load_path)

        # Memory-efficient loading for cloud deployment
        self.model = AutoModelForCausalLM.from_pretrained(
            load_path,
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            device_map=None,         # Don't use device_map; move to device manually below
            low_cpu_mem_usage=True,  # Reduce peak memory during loading
            trust_remote_code=True,
        )
        # Move to device after loading, and switch to eval mode for inference
        self.model = self.model.to(self.device)
        self.model.eval()
        self.tokenizer.pad_token = self.tokenizer.eos_token
        print("✓ Model loaded!")
    def _get_model_path(self):
        """Determine where to load the model from."""
        # Check for required model files (not just folder existence)
        required_files = ["config.json", "tokenizer.json", "tokenizer_config.json"]

        # Priority 1: local checkpoint with actual model files
        if os.path.exists(LOCAL_MODEL_DIR):
            local_files = os.listdir(LOCAL_MODEL_DIR) if os.path.isdir(LOCAL_MODEL_DIR) else []
            has_model_files = any(f in local_files for f in required_files) or any(
                f.endswith((".safetensors", ".bin")) for f in local_files
            )
            if has_model_files:
                print(f"📁 Found local model checkpoint: {LOCAL_MODEL_DIR}")
                return LOCAL_MODEL_DIR
            else:
                print("⚠️ Local folder exists but contains no model files")

        # Priority 2: download from the Hugging Face Hub
        if HF_MODEL_ID:
            print(f"☁️ Downloading model from Hugging Face: {HF_MODEL_ID}")
            return HF_MODEL_ID

        # Priority 3: fall back to the base model
        print("⚠️ No fine-tuned model found, using base model")
        return BASE_MODEL
    def generate(self, question, context="", max_tokens=128):
        """Generate SQL for a natural-language question, optionally given schema context."""
        # Build the prompt in the same format used during fine-tuning
        if context:
            prompt = f"""{context}
### Question:
{question}
### SQL:"""
        else:
            prompt = f"""### Question:
{question}
### SQL:"""

        # Tokenize
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512,
        ).to(self.device)

        # Generate with low temperature for near-deterministic SQL
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=0.1,
                do_sample=True,
                top_p=0.95,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens; slicing the decoded string by
        # len(prompt) breaks if tokenization alters whitespace or truncates
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        sql = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

        # Drop anything generated past the next section marker
        if "###" in sql:
            sql = sql.split("###")[0].strip()
        return sql
# =============================================================================
# STANDALONE FUNCTION
# =============================================================================
_generator = None
def generate_sql(question, context=""):
    """Standalone SQL generation using a lazily created, module-level generator."""
    global _generator
    if _generator is None:
        _generator = SQLGenerator()
    return _generator.generate(question, context)
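# Example usage (a minimal sketch; assumes this module is importable as
# `inference` and that a model is available locally or via HF_MODEL_ID):
#
#   from inference import generate_sql
#   sql = generate_sql(
#       "List the names of customers from Canada",
#       context="CREATE TABLE customers (id INT, name TEXT, country TEXT)",
#   )
#   print(sql)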
# =============================================================================
# TEST
# =============================================================================
def test_inference():
    """Smoke-test SQL generation on a sample question."""
    print("=" * 60)
    print("TESTING SQL GENERATION")
    print("=" * 60)

    generator = SQLGenerator()
    questions = [
        "Find all employees with salary greater than 50000",
    ]

    print("\n" + "-" * 60)
    for q in questions:
        print(f"Q: {q}")
        sql = generator.generate(q)
        print(f"SQL: {sql}")
        print("-" * 60)
    print("\n✓ Test complete")


if __name__ == "__main__":
    test_inference()