"""
Inference Module for Fine-Tuned SQL Model
Loads from: Local checkpoint OR Hugging Face Hub
"""
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from dotenv import load_dotenv
load_dotenv()
# =============================================================================
# CONFIGURATION
# =============================================================================
# Hugging Face Model ID (set in .env or Streamlit secrets)
HF_MODEL_ID = os.getenv("HF_MODEL_ID", None)
# Local paths
LOCAL_MODEL_DIR = "outputs/finetuning/checkpoints/final"
BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
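
# Example .env entry (illustrative placeholder, not a real model id; point it
# at your own fine-tuned checkpoint on the Hub):
#
#   HF_MODEL_ID=your-username/tinyllama-sql-finetuned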
# =============================================================================
# SQL GENERATOR CLASS
# =============================================================================
class SQLGenerator:
    """SQL generation using the fine-tuned model."""

    def __init__(self):
        """Load the fine-tuned model from local or HuggingFace."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Device: {self.device}")

        load_path = self._get_model_path()

        # Load tokenizer and model with memory optimization
        print(f"Loading model from: {load_path}")
        self.tokenizer = AutoTokenizer.from_pretrained(load_path)

        # Memory-efficient loading for cloud deployment
        self.model = AutoModelForCausalLM.from_pretrained(
            load_path,
            torch_dtype=torch.float32,  # Use float32 for CPU
            device_map=None,            # Don't use device_map on CPU
            low_cpu_mem_usage=True,     # Reduce memory during loading
            trust_remote_code=True
        )

        # Move to device after loading
        self.model = self.model.to(self.device)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        print("✅ Model loaded!")
    def _get_model_path(self):
        """Determine where to load the model from."""
        # Check for required model files (not just folder existence)
        required_files = ['config.json', 'tokenizer.json', 'tokenizer_config.json']

        # Priority 1: Local checkpoint with actual model files
        if os.path.exists(LOCAL_MODEL_DIR):
            local_files = os.listdir(LOCAL_MODEL_DIR) if os.path.isdir(LOCAL_MODEL_DIR) else []
            has_model_files = (
                any(f in local_files for f in required_files)
                or any(f.endswith('.safetensors') or f.endswith('.bin') for f in local_files)
            )
            if has_model_files:
                print(f"📁 Found local model checkpoint: {LOCAL_MODEL_DIR}")
                return LOCAL_MODEL_DIR
            else:
                print("⚠️ Local folder exists but no model files found")

        # Priority 2: Download from HuggingFace Hub
        if HF_MODEL_ID:
            print(f"☁️ Downloading model from HuggingFace: {HF_MODEL_ID}")
            return HF_MODEL_ID

        # Priority 3: Base model fallback
        print("⚠️ No fine-tuned model found, using base model")
        return BASE_MODEL
    def generate(self, question, context="", max_tokens=128):
        """Generate SQL from a question."""
        # Build prompt
        if context:
            prompt = f"""{context}
### Question:
{question}
### SQL:"""
        else:
            prompt = f"""### Question:
{question}
### SQL:"""
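        # For illustration only (not executed): with a schema context such as
        # "CREATE TABLE employees (id INT, name TEXT, salary INT)" -- a made-up
        # placeholder -- and the question "Find all employees with salary
        # greater than 50000", the rendered prompt looks like:
        #
        #   CREATE TABLE employees (id INT, name TEXT, salary INT)
        #   ### Question:
        #   Find all employees with salary greater than 50000
        #   ### SQL: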
        # Tokenize
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512
        ).to(self.device)

        # Generate
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=0.1,
                do_sample=True,
                top_p=0.95,
                pad_token_id=self.tokenizer.eos_token_id
            )

        # Decode
        generated = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Extract SQL: drop the prompt prefix and stop at the next "###" marker
        sql = generated[len(prompt):].strip()
        if "###" in sql:
            sql = sql.split("###")[0].strip()
        return sql
# =============================================================================
# STANDALONE FUNCTION
# =============================================================================
_generator = None

def generate_sql(question, context=""):
    """Standalone SQL generation (lazily creates a shared SQLGenerator)."""
    global _generator
    if _generator is None:
        _generator = SQLGenerator()
    return _generator.generate(question, context)
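
# Example usage (a sketch; the schema below is a made-up placeholder and the
# module name "inference" is assumed from the file's title):
#
#   from inference import generate_sql
#
#   schema = "CREATE TABLE employees (id INT, name TEXT, salary INT)"
#   print(generate_sql("Find all employees with salary greater than 50000", schema))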
# =============================================================================
# TEST
# =============================================================================
def test_inference():
    """Test the model."""
    print("=" * 60)
    print("TESTING SQL GENERATION")
    print("=" * 60)

    generator = SQLGenerator()

    questions = [
        "Find all employees with salary greater than 50000",
    ]

    print("\n" + "-" * 60)
    for q in questions:
        print(f"Q: {q}")
        sql = generator.generate(q)
        print(f"SQL: {sql}")
        print("-" * 60)

    print("\n✅ Test complete")


if __name__ == "__main__":
    test_inference()