# src/rag/knowledge_base.py
"""
Knowledge Base Builder for RAG System
Includes: Chunking Strategies, Vector Storage
"""
import os
import pandas as pd
import chromadb
import json
import re
from datetime import datetime
import sys
# Make the parent "src" directory importable so this module also works when run as a script.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from rag.embeddings import get_embeddings_batch
# =============================================================================
# CONFIGURATION
# =============================================================================
CHROMA_DIR = "chromadb_data"
COLLECTION_NAME = "sql_knowledge"
OUTPUT_DIR = "outputs/rag"
STATS_DIR = f"{OUTPUT_DIR}/stats"
REPORT_DIR = f"{OUTPUT_DIR}/reports"
def setup_directories():
"""Create necessary directories."""
for d in [CHROMA_DIR, OUTPUT_DIR, STATS_DIR, REPORT_DIR]:
os.makedirs(d, exist_ok=True)
# =============================================================================
# CHUNKING STRATEGIES
# =============================================================================
def chunk_by_sql_clauses(sql):
"""
Chunking Strategy 1: Split SQL by clauses.
Identifies SELECT, FROM, WHERE, GROUP BY, ORDER BY, etc.
"""
clauses = []
# Common SQL clause patterns
patterns = [
(r'\bSELECT\b.*?(?=\bFROM\b|$)', 'SELECT'),
(r'\bFROM\b.*?(?=\bWHERE\b|\bGROUP\b|\bORDER\b|\bLIMIT\b|$)', 'FROM'),
(r'\bWHERE\b.*?(?=\bGROUP\b|\bORDER\b|\bLIMIT\b|$)', 'WHERE'),
(r'\bGROUP BY\b.*?(?=\bHAVING\b|\bORDER\b|\bLIMIT\b|$)', 'GROUP BY'),
(r'\bHAVING\b.*?(?=\bORDER\b|\bLIMIT\b|$)', 'HAVING'),
(r'\bORDER BY\b.*?(?=\bLIMIT\b|$)', 'ORDER BY'),
(r'\bLIMIT\b.*', 'LIMIT'),
]
sql_upper = sql.upper()
for pattern, clause_name in patterns:
match = re.search(pattern, sql_upper, re.IGNORECASE | re.DOTALL)
if match:
clauses.append(clause_name)
return clauses
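# Illustrative example (input is assumed, not from the datasets): with the patterns
# above, chunk_by_sql_clauses("SELECT name FROM users WHERE age > 30")
# returns ['SELECT', 'FROM', 'WHERE'], since only those clause patterns match.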
def chunk_by_complexity(question, sql):
"""
Chunking Strategy 2: Categorize by query complexity.
"""
sql_upper = sql.upper()
# Determine complexity level
complexity_score = 0
# Check for complex features
if 'JOIN' in sql_upper:
complexity_score += 2
    if sql_upper.count('SELECT') > 1:  # a second SELECT indicates a subquery
complexity_score += 2
if 'GROUP BY' in sql_upper:
complexity_score += 1
if 'HAVING' in sql_upper:
complexity_score += 1
if 'ORDER BY' in sql_upper:
complexity_score += 1
if any(agg in sql_upper for agg in ['COUNT', 'SUM', 'AVG', 'MAX', 'MIN']):
complexity_score += 1
if 'UNION' in sql_upper:
complexity_score += 2
# Categorize
if complexity_score <= 1:
return 'simple'
elif complexity_score <= 3:
return 'intermediate'
else:
return 'complex'
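# Illustrative example (input is assumed): for
# "SELECT country, COUNT(*) FROM users GROUP BY country ORDER BY COUNT(*) DESC"
# the score is GROUP BY (+1) + ORDER BY (+1) + aggregate COUNT (+1) = 3,
# so chunk_by_complexity classifies it as 'intermediate'.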
def extract_sql_keywords(sql):
"""
Chunking Strategy 3: Extract SQL keywords for metadata.
"""
sql_upper = sql.upper()
keywords = []
# Operations
if 'SELECT' in sql_upper:
keywords.append('SELECT')
if 'INSERT' in sql_upper:
keywords.append('INSERT')
if 'UPDATE' in sql_upper:
keywords.append('UPDATE')
if 'DELETE' in sql_upper:
keywords.append('DELETE')
# Joins
if 'INNER JOIN' in sql_upper:
keywords.append('INNER JOIN')
elif 'LEFT JOIN' in sql_upper:
keywords.append('LEFT JOIN')
elif 'RIGHT JOIN' in sql_upper:
keywords.append('RIGHT JOIN')
elif 'JOIN' in sql_upper:
keywords.append('JOIN')
# Clauses
for clause in ['WHERE', 'GROUP BY', 'HAVING', 'ORDER BY', 'LIMIT']:
if clause in sql_upper:
keywords.append(clause)
# Aggregations
for agg in ['COUNT', 'SUM', 'AVG', 'MAX', 'MIN']:
if agg in sql_upper:
keywords.append(agg)
# Subqueries
if sql_upper.count('SELECT') > 1:
keywords.append('SUBQUERY')
return keywords
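# Illustrative example (input is assumed): for
# "SELECT COUNT(*) FROM orders JOIN users ON orders.user_id = users.id WHERE users.active = 1"
# extract_sql_keywords returns ['SELECT', 'JOIN', 'WHERE', 'COUNT'].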
def calculate_chunk_size(text):
"""Calculate appropriate chunk size category."""
word_count = len(text.split())
if word_count <= 10:
return 'short'
elif word_count <= 25:
return 'medium'
else:
return 'long'
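# Illustrative example (input is assumed): "How many users signed up last week?"
# has 7 words, so calculate_chunk_size returns 'short'; a 20-word question
# would fall into 'medium' and anything above 25 words into 'long'.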
# =============================================================================
# DOCUMENT PREPARATION WITH CHUNKING
# =============================================================================
def prepare_documents_with_chunking(datasets):
"""
Prepare documents with chunking metadata.
Each document gets rich metadata for filtering/ranking.
"""
documents = []
metadatas = []
ids = []
idx = 0
for source, df in datasets.items():
for _, row in df.iterrows():
question = str(row['question'])
sql = str(row['sql'])
# Apply chunking strategies
sql_clauses = chunk_by_sql_clauses(sql)
complexity = chunk_by_complexity(question, sql)
keywords = extract_sql_keywords(sql)
q_size = calculate_chunk_size(question)
sql_size = calculate_chunk_size(sql)
# Create rich metadata
metadata = {
'sql': sql,
'source': source,
'question': question,
# Chunking metadata
'complexity': complexity,
'sql_clauses': ','.join(sql_clauses),
'keywords': ','.join(keywords),
'question_size': q_size,
'sql_size': sql_size,
'keyword_count': len(keywords),
'clause_count': len(sql_clauses),
}
documents.append(question)
metadatas.append(metadata)
ids.append(f"doc_{idx}")
idx += 1
return documents, metadatas, ids
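# Sketch of the output shape (values are illustrative, keys match the code above):
#   documents[0] -> the raw question text
#   ids[0]       -> "doc_0"
#   metadatas[0] -> {'sql': 'SELECT ...', 'source': 'train', 'question': '...',
#                    'complexity': 'simple', 'sql_clauses': 'SELECT,FROM',
#                    'keywords': 'SELECT', 'question_size': 'short',
#                    'sql_size': 'short', 'keyword_count': 1, 'clause_count': 2}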
# =============================================================================
# CHROMADB CLIENT
# =============================================================================
def get_chroma_client():
"""Get ChromaDB persistent client."""
return chromadb.PersistentClient(path=CHROMA_DIR)
def get_or_create_collection(client):
"""Get or create the SQL knowledge collection."""
return client.get_or_create_collection(
name=COLLECTION_NAME,
metadata={"description": "SQL question-answer pairs with chunking metadata"}
)
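# Minimal usage sketch (assumes the knowledge base has already been built):
#   client = get_chroma_client()
#   collection = get_or_create_collection(client)
#   print(collection.count())  # number of indexed question/SQL pairs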
# =============================================================================
# DATA LOADING
# =============================================================================
def load_datasets(data_dir="data"):
"""Load ALL CSV datasets."""
datasets = {}
files = {
'train': 'train.csv',
'validation': 'validation.csv',
'test': 'test.csv'
# 'synthetic': 'synthetic.csv'
}
for name, filename in files.items():
filepath = os.path.join(data_dir, filename)
if os.path.exists(filepath):
df = pd.read_csv(filepath)
datasets[name] = df
print(f" Loaded {name}: {len(df):,} rows")
else:
print(f" Skipped {name}: file not found")
return datasets
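# Expected CSV layout (inferred from prepare_documents_with_chunking): each file
# needs at least a 'question' column and a 'sql' column, e.g.
#   question,sql
#   "How many users are there?","SELECT COUNT(*) FROM users"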
# =============================================================================
# KNOWLEDGE BASE BUILDING
# =============================================================================
def build_knowledge_base(data_dir="data", batch_size=500):
"""Build knowledge base with chunking strategies."""
print("=" * 50)
print("BUILDING RAG KNOWLEDGE BASE")
print("With Chunking Strategies")
print("=" * 50)
setup_directories()
# Step 1: Load data
print(f"\n[1/5] Loading datasets...")
datasets = load_datasets(data_dir)
if not datasets:
print("ERROR: No datasets found!")
return None
total_rows = sum(len(df) for df in datasets.values())
print(f" Total rows: {total_rows:,}")
# Step 2: Prepare documents with chunking
print(f"\n[2/5] Applying chunking strategies...")
documents, metadatas, ids = prepare_documents_with_chunking(datasets)
print(f" Total documents: {len(documents):,}")
# Show chunking stats
complexities = [m['complexity'] for m in metadatas]
print(f" Complexity distribution:")
print(f" Simple: {complexities.count('simple'):,}")
print(f" Intermediate: {complexities.count('intermediate'):,}")
print(f" Complex: {complexities.count('complex'):,}")
# Step 3: Initialize ChromaDB
print(f"\n[3/5] Initializing ChromaDB...")
client = get_chroma_client()
try:
client.delete_collection(COLLECTION_NAME)
print(" Deleted existing collection")
    except Exception:
        pass  # collection did not exist yet; nothing to delete
collection = get_or_create_collection(client)
print(f" Collection: {COLLECTION_NAME}")
# Step 4: Generate embeddings and store
print(f"\n[4/5] Generating embeddings...")
total_added = 0
for i in range(0, len(documents), batch_size):
batch_docs = documents[i:i + batch_size]
batch_meta = metadatas[i:i + batch_size]
batch_ids = ids[i:i + batch_size]
embeddings = get_embeddings_batch(batch_docs)
        if embeddings and embeddings[0] is not None:  # skip the batch if embedding generation failed
collection.add(
documents=batch_docs,
metadatas=batch_meta,
ids=batch_ids,
embeddings=embeddings
)
total_added += len(batch_docs)
progress = min(i + batch_size, len(documents))
pct = (progress / len(documents)) * 100
print(f" Progress: {progress:,}/{len(documents):,} ({pct:.1f}%)")
# Step 5: Save statistics
print(f"\n[5/5] Saving statistics...")
stats = {
'total_documents': total_added,
'sources': {name: len(df) for name, df in datasets.items()},
'collection_name': COLLECTION_NAME,
'embedding_model': 'all-MiniLM-L6-v2',
'chunking_strategies': [
'sql_clause_extraction',
'complexity_classification',
'keyword_extraction',
'size_categorization'
],
'complexity_distribution': {
'simple': complexities.count('simple'),
'intermediate': complexities.count('intermediate'),
'complex': complexities.count('complex')
},
'created_at': datetime.now().isoformat()
}
with open(f'{STATS_DIR}/knowledge_base_stats.json', 'w') as f:
json.dump(stats, f, indent=2)
generate_report(stats)
print("\n" + "=" * 50)
print("COMPLETE")
print("=" * 50)
print(f" Documents indexed: {total_added:,}")
print(f" Storage: {CHROMA_DIR}/")
return collection
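# Retrieval sketch (illustrative; assumes the collection built above and the same
# get_embeddings_batch helper used at index time):
#   client = get_chroma_client()
#   collection = client.get_collection(COLLECTION_NAME)
#   query_embedding = get_embeddings_batch(["Show the top 5 customers by total spend"])
#   results = collection.query(
#       query_embeddings=query_embedding,
#       n_results=3,
#       where={"complexity": "complex"},  # metadata filter enabled by the chunking metadata
#   )
#   print(results["metadatas"][0][0]["sql"])  # SQL of the closest match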
# =============================================================================
# REPORT GENERATION
# =============================================================================
def generate_report(stats):
"""Generate knowledge base report."""
report = f"""# RAG Knowledge Base Report
**Generated:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
## Overview
| Metric | Value |
|--------|-------|
| Total Documents | {stats['total_documents']:,} |
| Collection Name | {stats['collection_name']} |
| Embedding Model | {stats['embedding_model']} |
## Data Sources
| Source | Documents |
|--------|-----------|
"""
for source, count in stats['sources'].items():
report += f"| {source} | {count:,} |\n"
report += f"""
## Chunking Strategies
1. **SQL Clause Extraction**: Identifies SELECT, FROM, WHERE, GROUP BY, etc.
2. **Complexity Classification**: Categorizes as simple/intermediate/complex
3. **Keyword Extraction**: Extracts SQL operations (JOIN, COUNT, etc.)
4. **Size Categorization**: Classifies question/SQL length
## Complexity Distribution
| Level | Count |
|-------|-------|
| Simple | {stats['complexity_distribution']['simple']:,} |
| Intermediate | {stats['complexity_distribution']['intermediate']:,} |
| Complex | {stats['complexity_distribution']['complex']:,} |
## Document Metadata Structure
Each document contains:
- `sql`: The SQL query
- `source`: Origin dataset
- `question`: Original question
- `complexity`: simple/intermediate/complex
- `sql_clauses`: Comma-separated clauses
- `keywords`: SQL keywords found
- `question_size`: short/medium/long
- `sql_size`: short/medium/long
"""
with open(f'{REPORT_DIR}/knowledge_base_report.md', 'w') as f:
f.write(report)
print(f" Report saved to {REPORT_DIR}/")
# =============================================================================
# ENTRY POINT
# =============================================================================
if __name__ == "__main__":
build_knowledge_base(data_dir="data", batch_size=500)