""" Knowledge Base Builder for RAG System Includes: Chunking Strategies, Vector Storage """ import os import pandas as pd import chromadb import json import re from datetime import datetime import sys sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from rag.embeddings import get_embeddings_batch # ============================================================================= # CONFIGURATION # ============================================================================= CHROMA_DIR = "chromadb_data" COLLECTION_NAME = "sql_knowledge" OUTPUT_DIR = "outputs/rag" STATS_DIR = f"{OUTPUT_DIR}/stats" REPORT_DIR = f"{OUTPUT_DIR}/reports" def setup_directories(): """Create necessary directories.""" for d in [CHROMA_DIR, OUTPUT_DIR, STATS_DIR, REPORT_DIR]: os.makedirs(d, exist_ok=True) # ============================================================================= # CHUNKING STRATEGIES # ============================================================================= def chunk_by_sql_clauses(sql): """ Chunking Strategy 1: Split SQL by clauses. Identifies SELECT, FROM, WHERE, GROUP BY, ORDER BY, etc. """ clauses = [] # Common SQL clause patterns patterns = [ (r'\bSELECT\b.*?(?=\bFROM\b|$)', 'SELECT'), (r'\bFROM\b.*?(?=\bWHERE\b|\bGROUP\b|\bORDER\b|\bLIMIT\b|$)', 'FROM'), (r'\bWHERE\b.*?(?=\bGROUP\b|\bORDER\b|\bLIMIT\b|$)', 'WHERE'), (r'\bGROUP BY\b.*?(?=\bHAVING\b|\bORDER\b|\bLIMIT\b|$)', 'GROUP BY'), (r'\bHAVING\b.*?(?=\bORDER\b|\bLIMIT\b|$)', 'HAVING'), (r'\bORDER BY\b.*?(?=\bLIMIT\b|$)', 'ORDER BY'), (r'\bLIMIT\b.*', 'LIMIT'), ] sql_upper = sql.upper() for pattern, clause_name in patterns: match = re.search(pattern, sql_upper, re.IGNORECASE | re.DOTALL) if match: clauses.append(clause_name) return clauses def chunk_by_complexity(question, sql): """ Chunking Strategy 2: Categorize by query complexity. """ sql_upper = sql.upper() # Determine complexity level complexity_score = 0 # Check for complex features if 'JOIN' in sql_upper: complexity_score += 2 if 'SUBQUERY' in sql_upper or sql_upper.count('SELECT') > 1: complexity_score += 2 if 'GROUP BY' in sql_upper: complexity_score += 1 if 'HAVING' in sql_upper: complexity_score += 1 if 'ORDER BY' in sql_upper: complexity_score += 1 if any(agg in sql_upper for agg in ['COUNT', 'SUM', 'AVG', 'MAX', 'MIN']): complexity_score += 1 if 'UNION' in sql_upper: complexity_score += 2 # Categorize if complexity_score <= 1: return 'simple' elif complexity_score <= 3: return 'intermediate' else: return 'complex' def extract_sql_keywords(sql): """ Chunking Strategy 3: Extract SQL keywords for metadata. 
""" sql_upper = sql.upper() keywords = [] # Operations if 'SELECT' in sql_upper: keywords.append('SELECT') if 'INSERT' in sql_upper: keywords.append('INSERT') if 'UPDATE' in sql_upper: keywords.append('UPDATE') if 'DELETE' in sql_upper: keywords.append('DELETE') # Joins if 'INNER JOIN' in sql_upper: keywords.append('INNER JOIN') elif 'LEFT JOIN' in sql_upper: keywords.append('LEFT JOIN') elif 'RIGHT JOIN' in sql_upper: keywords.append('RIGHT JOIN') elif 'JOIN' in sql_upper: keywords.append('JOIN') # Clauses for clause in ['WHERE', 'GROUP BY', 'HAVING', 'ORDER BY', 'LIMIT']: if clause in sql_upper: keywords.append(clause) # Aggregations for agg in ['COUNT', 'SUM', 'AVG', 'MAX', 'MIN']: if agg in sql_upper: keywords.append(agg) # Subqueries if sql_upper.count('SELECT') > 1: keywords.append('SUBQUERY') return keywords def calculate_chunk_size(text): """Calculate appropriate chunk size category.""" word_count = len(text.split()) if word_count <= 10: return 'short' elif word_count <= 25: return 'medium' else: return 'long' # ============================================================================= # DOCUMENT PREPARATION WITH CHUNKING # ============================================================================= def prepare_documents_with_chunking(datasets): """ Prepare documents with chunking metadata. Each document gets rich metadata for filtering/ranking. """ documents = [] metadatas = [] ids = [] idx = 0 for source, df in datasets.items(): for _, row in df.iterrows(): question = str(row['question']) sql = str(row['sql']) # Apply chunking strategies sql_clauses = chunk_by_sql_clauses(sql) complexity = chunk_by_complexity(question, sql) keywords = extract_sql_keywords(sql) q_size = calculate_chunk_size(question) sql_size = calculate_chunk_size(sql) # Create rich metadata metadata = { 'sql': sql, 'source': source, 'question': question, # Chunking metadata 'complexity': complexity, 'sql_clauses': ','.join(sql_clauses), 'keywords': ','.join(keywords), 'question_size': q_size, 'sql_size': sql_size, 'keyword_count': len(keywords), 'clause_count': len(sql_clauses), } documents.append(question) metadatas.append(metadata) ids.append(f"doc_{idx}") idx += 1 return documents, metadatas, ids # ============================================================================= # CHROMADB CLIENT # ============================================================================= def get_chroma_client(): """Get ChromaDB persistent client.""" return chromadb.PersistentClient(path=CHROMA_DIR) def get_or_create_collection(client): """Get or create the SQL knowledge collection.""" return client.get_or_create_collection( name=COLLECTION_NAME, metadata={"description": "SQL question-answer pairs with chunking metadata"} ) # ============================================================================= # DATA LOADING # ============================================================================= def load_datasets(data_dir="data"): """Load ALL CSV datasets.""" datasets = {} files = { 'train': 'train.csv', 'validation': 'validation.csv', 'test': 'test.csv' # 'synthetic': 'synthetic.csv' } for name, filename in files.items(): filepath = os.path.join(data_dir, filename) if os.path.exists(filepath): df = pd.read_csv(filepath) datasets[name] = df print(f" Loaded {name}: {len(df):,} rows") else: print(f" Skipped {name}: file not found") return datasets # ============================================================================= # KNOWLEDGE BASE BUILDING # 
# =============================================================================
# CHROMADB CLIENT
# =============================================================================

def get_chroma_client():
    """Get ChromaDB persistent client."""
    return chromadb.PersistentClient(path=CHROMA_DIR)


def get_or_create_collection(client):
    """Get or create the SQL knowledge collection."""
    return client.get_or_create_collection(
        name=COLLECTION_NAME,
        metadata={"description": "SQL question-answer pairs with chunking metadata"}
    )


# =============================================================================
# DATA LOADING
# =============================================================================

def load_datasets(data_dir="data"):
    """Load ALL CSV datasets."""
    datasets = {}
    files = {
        'train': 'train.csv',
        'validation': 'validation.csv',
        'test': 'test.csv'
        # 'synthetic': 'synthetic.csv'
    }

    for name, filename in files.items():
        filepath = os.path.join(data_dir, filename)
        if os.path.exists(filepath):
            df = pd.read_csv(filepath)
            datasets[name] = df
            print(f" Loaded {name}: {len(df):,} rows")
        else:
            print(f" Skipped {name}: file not found")

    return datasets


# =============================================================================
# KNOWLEDGE BASE BUILDING
# =============================================================================

def build_knowledge_base(data_dir="data", batch_size=500):
    """Build knowledge base with chunking strategies."""
    print("=" * 50)
    print("BUILDING RAG KNOWLEDGE BASE")
    print("With Chunking Strategies")
    print("=" * 50)

    setup_directories()

    # Step 1: Load data
    print(f"\n[1/5] Loading datasets...")
    datasets = load_datasets(data_dir)
    if not datasets:
        print("ERROR: No datasets found!")
        return None

    total_rows = sum(len(df) for df in datasets.values())
    print(f" Total rows: {total_rows:,}")

    # Step 2: Prepare documents with chunking
    print(f"\n[2/5] Applying chunking strategies...")
    documents, metadatas, ids = prepare_documents_with_chunking(datasets)
    print(f" Total documents: {len(documents):,}")

    # Show chunking stats
    complexities = [m['complexity'] for m in metadatas]
    print(f" Complexity distribution:")
    print(f" Simple: {complexities.count('simple'):,}")
    print(f" Intermediate: {complexities.count('intermediate'):,}")
    print(f" Complex: {complexities.count('complex'):,}")

    # Step 3: Initialize ChromaDB
    print(f"\n[3/5] Initializing ChromaDB...")
    client = get_chroma_client()
    try:
        client.delete_collection(COLLECTION_NAME)
        print(" Deleted existing collection")
    except Exception:
        pass
    collection = get_or_create_collection(client)
    print(f" Collection: {COLLECTION_NAME}")

    # Step 4: Generate embeddings and store
    print(f"\n[4/5] Generating embeddings...")
    total_added = 0

    for i in range(0, len(documents), batch_size):
        batch_docs = documents[i:i + batch_size]
        batch_meta = metadatas[i:i + batch_size]
        batch_ids = ids[i:i + batch_size]

        embeddings = get_embeddings_batch(batch_docs)

        if embeddings and embeddings[0] is not None:
            collection.add(
                documents=batch_docs,
                metadatas=batch_meta,
                ids=batch_ids,
                embeddings=embeddings
            )
            total_added += len(batch_docs)
        else:
            # Surface embedding failures instead of silently dropping the batch
            print(f" WARNING: no embeddings returned for batch starting at {i}; batch skipped")

        progress = min(i + batch_size, len(documents))
        pct = (progress / len(documents)) * 100
        print(f" Progress: {progress:,}/{len(documents):,} ({pct:.1f}%)")

    # Step 5: Save statistics
    print(f"\n[5/5] Saving statistics...")
    stats = {
        'total_documents': total_added,
        'sources': {name: len(df) for name, df in datasets.items()},
        'collection_name': COLLECTION_NAME,
        'embedding_model': 'all-MiniLM-L6-v2',
        'chunking_strategies': [
            'sql_clause_extraction',
            'complexity_classification',
            'keyword_extraction',
            'size_categorization'
        ],
        'complexity_distribution': {
            'simple': complexities.count('simple'),
            'intermediate': complexities.count('intermediate'),
            'complex': complexities.count('complex')
        },
        'created_at': datetime.now().isoformat()
    }

    with open(f'{STATS_DIR}/knowledge_base_stats.json', 'w') as f:
        json.dump(stats, f, indent=2)

    generate_report(stats)

    print("\n" + "=" * 50)
    print("COMPLETE")
    print("=" * 50)
    print(f" Documents indexed: {total_added:,}")
    print(f" Storage: {CHROMA_DIR}/")

    return collection
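
# Illustrative sketch (assumes the knowledge base above has been built and
# persisted): the chunking metadata supports filtered retrieval, e.g. limiting
# candidates to queries classified as simple. The question is hypothetical;
# get_embeddings_batch is assumed to return one vector per input text.
#
#   client = get_chroma_client()
#   collection = client.get_collection(COLLECTION_NAME)
#   query_vec = get_embeddings_batch(["How many users signed up last month?"])[0]
#   results = collection.query(
#       query_embeddings=[query_vec],
#       n_results=5,
#       where={"complexity": "simple"},  # filter on chunking metadata
#   )
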
# =============================================================================
# REPORT GENERATION
# =============================================================================

def generate_report(stats):
    """Generate knowledge base report."""
    report = f"""# RAG Knowledge Base Report

**Generated:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

## Overview

| Metric | Value |
|--------|-------|
| Total Documents | {stats['total_documents']:,} |
| Collection Name | {stats['collection_name']} |
| Embedding Model | {stats['embedding_model']} |

## Data Sources

| Source | Documents |
|--------|-----------|
"""
    for source, count in stats['sources'].items():
        report += f"| {source} | {count:,} |\n"

    report += f"""
## Chunking Strategies

1. **SQL Clause Extraction**: Identifies SELECT, FROM, WHERE, GROUP BY, etc.
2. **Complexity Classification**: Categorizes as simple/intermediate/complex
3. **Keyword Extraction**: Extracts SQL operations (JOIN, COUNT, etc.)
4. **Size Categorization**: Classifies question/SQL length

## Complexity Distribution

| Level | Count |
|-------|-------|
| Simple | {stats['complexity_distribution']['simple']:,} |
| Intermediate | {stats['complexity_distribution']['intermediate']:,} |
| Complex | {stats['complexity_distribution']['complex']:,} |

## Document Metadata Structure

Each document contains:
- `sql`: The SQL query
- `source`: Origin dataset
- `question`: Original question
- `complexity`: simple/intermediate/complex
- `sql_clauses`: Comma-separated clauses
- `keywords`: SQL keywords found
- `question_size`: short/medium/long
- `sql_size`: short/medium/long
"""

    with open(f'{REPORT_DIR}/knowledge_base_report.md', 'w') as f:
        f.write(report)

    print(f" Report saved to {REPORT_DIR}/")


# =============================================================================
# ENTRY POINT
# =============================================================================

if __name__ == "__main__":
    build_knowledge_base(data_dir="data", batch_size=500)