""" Retriever Module for RAG System Loads from: Local ChromaDB OR HuggingFace Hub """ import os import sys from dotenv import load_dotenv load_dotenv() # Try new imports first, fall back to old try: from langchain_huggingface import HuggingFaceEmbeddings from langchain_chroma import Chroma except ImportError: from langchain_community.vectorstores import Chroma from langchain_community.embeddings import HuggingFaceEmbeddings sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) # ============================================================================= # CONFIGURATION # ============================================================================= LOCAL_CHROMADB_DIR = "chromadb_data" HF_CHROMADB_ID = os.getenv("HF_CHROMADB_ID", None) COLLECTION_NAME = "sql_knowledge" EMBEDDING_MODEL = "all-MiniLM-L6-v2" # ============================================================================= # CHROMADB LOADER # ============================================================================= def ensure_chromadb_exists(): """Ensure ChromaDB data exists - download from HF if needed.""" # Check if local has actual ChromaDB files (not just empty folder) if os.path.exists(LOCAL_CHROMADB_DIR): local_files = os.listdir(LOCAL_CHROMADB_DIR) if os.path.isdir(LOCAL_CHROMADB_DIR) else [] # ChromaDB creates files like chroma.sqlite3 or folders has_chroma_files = any('chroma' in f.lower() or 'sqlite' in f.lower() for f in local_files) or len(local_files) > 2 if has_chroma_files: print(f"📁 Using local ChromaDB: {LOCAL_CHROMADB_DIR}") return LOCAL_CHROMADB_DIR else: print(f"⚠️ ChromaDB folder exists but is empty or incomplete") # Download from HuggingFace if HF_CHROMADB_ID: print(f"☁️ Downloading ChromaDB from HuggingFace: {HF_CHROMADB_ID}") from huggingface_hub import snapshot_download # Create folder if not exists os.makedirs(LOCAL_CHROMADB_DIR, exist_ok=True) snapshot_download( repo_id=HF_CHROMADB_ID, repo_type="dataset", local_dir=LOCAL_CHROMADB_DIR ) print("✓ ChromaDB downloaded!") return LOCAL_CHROMADB_DIR # Need to build it from data print("⚠️ ChromaDB not found and no HF_CHROMADB_ID set. Building from data...") from rag.knowledge_base import build_knowledge_base build_knowledge_base(data_dir="data", batch_size=500) return LOCAL_CHROMADB_DIR # ============================================================================= # LANGCHAIN EMBEDDINGS # ============================================================================= def get_embeddings(): """Get HuggingFace embeddings for LangChain.""" return HuggingFaceEmbeddings( model_name=EMBEDDING_MODEL, model_kwargs={'device': 'cpu'}, encode_kwargs={'normalize_embeddings': True} ) # ============================================================================= # RANKING FUNCTIONS # ============================================================================= def calculate_relevance_score(result, query): """Calculate enhanced relevance score.""" base_score = result.get('score', 0.5) boost = 0.0 query_words = set(query.lower().split()) question_words = set(result.get('question', '').lower().split()) overlap = len(query_words & question_words) if overlap > 0: boost += 0.05 * min(overlap, 5) query_length = len(query.split()) if query_length <= 8 and result.get('complexity') == 'simple': boost += 0.1 elif query_length > 15 and result.get('complexity') == 'complex': boost += 0.1 return base_score + boost def rerank_results(results, query): """Re-rank results using enhanced relevance scoring.""" for r in results: r['relevance_score'] = calculate_relevance_score(r, query) results.sort(key=lambda x: x['relevance_score'], reverse=True) return results # ============================================================================= # FILTERING FUNCTIONS # ============================================================================= def filter_by_threshold(results, min_score=0.0): return [r for r in results if r.get('score', 0) >= min_score] def filter_by_complexity(results, complexity=None): if complexity is None: return results return [r for r in results if r.get('complexity') == complexity] # ============================================================================= # SQL RETRIEVER CLASS # ============================================================================= class SQLRetriever: """LangChain-based retriever with local/HuggingFace support.""" def __init__(self): """Initialize the retriever.""" print("Initializing SQL Retriever...") # Ensure ChromaDB exists chromadb_path = ensure_chromadb_exists() # Load embeddings self.embeddings = get_embeddings() # Load ChromaDB self.vectorstore = Chroma( collection_name=COLLECTION_NAME, persist_directory=chromadb_path, embedding_function=self.embeddings ) self.doc_count = self.vectorstore._collection.count() print(f"✓ Loaded {self.doc_count:,} documents from {chromadb_path}") def retrieve(self, query, top_k=5, min_score=None, complexity=None, rerank=True): """Retrieve similar questions with filtering and ranking.""" fetch_k = min(top_k * 3, 50) docs_with_scores = self.vectorstore.similarity_search_with_score(query, k=fetch_k) # Format results formatted = [] for doc, score in docs_with_scores: formatted.append({ 'question': doc.page_content, 'sql': doc.metadata.get('sql', ''), 'source': doc.metadata.get('source', 'unknown'), 'complexity': doc.metadata.get('complexity', 'unknown'), 'keywords': doc.metadata.get('keywords', ''), 'sql_clauses': doc.metadata.get('sql_clauses', ''), 'distance': score, 'score': 1 - score if score <= 1 else 1 / (1 + score) }) # Apply filters if min_score is not None: formatted = filter_by_threshold(formatted, min_score) if complexity is not None: formatted = filter_by_complexity(formatted, complexity) # Apply re-ranking if rerank: formatted = rerank_results(formatted, query) return formatted[:top_k] def retrieve_as_context(self, query, top_k=5): """Retrieve and format as context for LLM prompt.""" results = self.retrieve(query, top_k=top_k) if not results: return "" context = "Similar SQL examples:\n\n" for i, r in enumerate(results, 1): context += f"Example {i}:\n" context += f"Question: {r['question']}\n" context += f"SQL: {r['sql']}\n\n" return context def get_stats(self): """Get retriever statistics.""" return { 'total_documents': self.doc_count, 'collection_name': COLLECTION_NAME, 'embedding_model': EMBEDDING_MODEL, } # ============================================================================= # TEST # ============================================================================= def test_retriever(): """Test retriever.""" print("=" * 60) print("TESTING SQL RETRIEVER") print("=" * 60) retriever = SQLRetriever() query = "Find all employees with salary above 50000" results = retriever.retrieve(query, top_k=3) print(f"\nQuery: {query}\n") for i, r in enumerate(results, 1): print(f"Result {i}: (score: {r['score']:.3f})") print(f" Q: {r['question'][:60]}...") print(f" SQL: {r['sql'][:60]}...") print() print("✓ Test complete") if __name__ == "__main__": test_retriever()