sql-learning-assistant / src /prompts /prompt_builder.py
moheesh
got all my code
f29ea6c
"""
Prompt Builder for SQL Learning Assistant
Handles: Context Management, User Interaction Flows, Edge Cases
"""
import re
import os
import sys
import json
from datetime import datetime
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from prompts.system_prompts import (
get_system_prompt,
get_prompt_template,
CLARIFICATION_PROMPT,
ERROR_RECOVERY_PROMPT
)
# =============================================================================
# OUTPUT DIRECTORIES
# =============================================================================
OUTPUT_DIR = "outputs/prompts"
LOGS_DIR = f"{OUTPUT_DIR}/logs"
def setup_directories():
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(LOGS_DIR, exist_ok=True)
# =============================================================================
# CONTEXT MANAGEMENT
# =============================================================================
class ConversationContext:
"""
Manages conversation history and context for multi-turn interactions.
"""
def __init__(self, max_history=5):
self.history = []
self.max_history = max_history
self.current_tables = []
self.current_schema = {}
self.user_preferences = {}
def add_turn(self, question, sql_response, success=True):
"""Add a conversation turn to history."""
self.history.append({
'question': question,
'sql': sql_response,
'success': success,
'timestamp': datetime.now().isoformat()
})
# Keep only recent history
if len(self.history) > self.max_history:
self.history = self.history[-self.max_history:]
def get_history_context(self):
"""Format history for prompt injection."""
if not self.history:
return ""
context = "Previous conversation:\n"
for turn in self.history[-3:]: # Last 3 turns
context += f"Q: {turn['question']}\n"
context += f"SQL: {turn['sql']}\n\n"
return context
def set_schema(self, schema_dict):
"""Set current database schema context."""
self.current_schema = schema_dict
def get_schema_context(self):
"""Format schema for prompt injection."""
if not self.current_schema:
return ""
context = "Available tables and columns:\n"
for table, columns in self.current_schema.items():
context += f"- {table}: {', '.join(columns)}\n"
return context
def clear(self):
"""Clear conversation history."""
self.history = []
self.current_tables = []
self.current_schema = {}
# =============================================================================
# QUERY ANALYSIS (For Specialized Flows)
# =============================================================================
def analyze_query_intent(question):
"""
Analyze user question to determine query type and intent.
Returns: dict with query_type, keywords, entities
"""
question_lower = question.lower()
# Detect query type
query_type = 'general'
# Aggregation patterns
agg_patterns = ['count', 'sum', 'average', 'avg', 'total', 'maximum', 'max',
'minimum', 'min', 'how many', 'what is the total']
if any(p in question_lower for p in agg_patterns):
query_type = 'aggregation'
# Complex query patterns
complex_patterns = ['join', 'combine', 'merge', 'from multiple', 'across tables',
'subquery', 'nested', 'with the highest', 'with the lowest']
if any(p in question_lower for p in complex_patterns):
query_type = 'complex'
# Modification patterns
mod_patterns = ['insert', 'add new', 'update', 'change', 'modify', 'delete', 'remove']
if any(p in question_lower for p in mod_patterns):
query_type = 'modification'
# Simple patterns (if nothing else matched)
simple_patterns = ['show', 'list', 'get', 'find', 'select', 'display']
if query_type == 'general' and any(p in question_lower for p in simple_patterns):
query_type = 'simple'
# Extract potential keywords
keywords = []
sql_keywords = ['where', 'group by', 'order by', 'having', 'limit', 'join',
'distinct', 'between', 'like', 'in']
for kw in sql_keywords:
if kw in question_lower:
keywords.append(kw.upper())
return {
'query_type': query_type,
'keywords': keywords,
'question_length': len(question.split())
}
# =============================================================================
# EDGE CASE HANDLING
# =============================================================================
def detect_edge_cases(question):
"""
Detect potential edge cases in user question.
Returns: list of edge case types detected
"""
edge_cases = []
question_lower = question.lower()
# Empty or too short
if len(question.strip()) < 5:
edge_cases.append('too_short')
# Too vague
vague_patterns = ['something', 'stuff', 'things', 'data', 'information']
if any(p in question_lower for p in vague_patterns) and len(question.split()) < 5:
edge_cases.append('too_vague')
# Multiple questions
if question.count('?') > 1:
edge_cases.append('multiple_questions')
# Contains SQL (user pasted SQL instead of question)
sql_patterns = ['SELECT', 'INSERT', 'UPDATE', 'DELETE', 'FROM', 'WHERE']
if sum(1 for p in sql_patterns if p in question.upper()) >= 2:
edge_cases.append('contains_sql')
# Potentially dangerous operations
dangerous_patterns = ['drop table', 'truncate', 'delete all', 'remove all']
if any(p in question_lower for p in dangerous_patterns):
edge_cases.append('dangerous_operation')
# Non-SQL question
non_sql_patterns = ['weather', 'hello', 'how are you', 'thank', 'bye']
if any(p in question_lower for p in non_sql_patterns):
edge_cases.append('not_sql_related')
return edge_cases
def handle_edge_case(edge_case_type, question):
"""
Generate appropriate response for edge cases.
Returns: (should_continue, message)
"""
responses = {
'too_short': (
False,
"Your question is too short. Please provide more details about what data you want to retrieve."
),
'too_vague': (
False,
"Your question is a bit vague. Could you specify:\n- Which table(s) to query?\n- What columns to retrieve?\n- Any conditions to filter by?"
),
'multiple_questions': (
False,
"I detected multiple questions. Please ask one question at a time for accurate SQL generation."
),
'contains_sql': (
False,
"It looks like you've pasted SQL code. Please describe what you want in natural language, and I'll generate the SQL for you."
),
'dangerous_operation': (
False,
"⚠️ This appears to be a destructive operation (DROP/TRUNCATE/DELETE ALL). Please confirm you want to proceed or rephrase your question."
),
'not_sql_related': (
False,
"I'm an SQL assistant. Please ask me questions about querying databases, and I'll help generate SQL queries."
)
}
return responses.get(edge_case_type, (True, ""))
# =============================================================================
# PROMPT BUILDER CLASS
# =============================================================================
class PromptBuilder:
"""
Main class for building prompts with context management.
"""
def __init__(self):
self.context = ConversationContext()
self.log_file = None
setup_directories()
def build_prompt(self, question, rag_context="", include_history=True):
"""
Build complete prompt for SQL generation.
Args:
question: User's natural language question
rag_context: Retrieved examples from RAG
include_history: Whether to include conversation history
Returns:
dict with 'success', 'prompt' or 'error'
"""
# Check for edge cases
edge_cases = detect_edge_cases(question)
if edge_cases:
should_continue, message = handle_edge_case(edge_cases[0], question)
if not should_continue:
return {
'success': False,
'error': message,
'edge_case': edge_cases[0]
}
# Analyze query intent
intent = analyze_query_intent(question)
# Get appropriate system prompt
system_prompt = get_system_prompt(intent['query_type'])
# Build context parts
context_parts = []
# Add schema context if available
schema_context = self.context.get_schema_context()
if schema_context:
context_parts.append(schema_context)
# Add conversation history
if include_history:
history_context = self.context.get_history_context()
if history_context:
context_parts.append(history_context)
# Add RAG context
if rag_context:
context_parts.append(rag_context)
# Build final prompt
if rag_context:
template = get_prompt_template('rag')
prompt = template.format(
context=rag_context,
question=question
)
else:
template = get_prompt_template('zero_shot')
prompt = template.format(question=question)
# Combine everything
full_prompt = f"{system_prompt}\n\n"
if context_parts:
full_prompt += "\n".join(context_parts) + "\n\n"
full_prompt += prompt
# Log the prompt
self._log_prompt(question, intent, full_prompt)
return {
'success': True,
'prompt': full_prompt,
'system_prompt': system_prompt,
'query_type': intent['query_type'],
'keywords': intent['keywords']
}
def add_response(self, question, sql_response, success=True):
"""Add a completed interaction to history."""
self.context.add_turn(question, sql_response, success)
def set_schema(self, schema_dict):
"""Set database schema for context."""
self.context.set_schema(schema_dict)
def clear_context(self):
"""Clear all context."""
self.context.clear()
def _log_prompt(self, question, intent, prompt):
"""Log prompt for debugging/analysis."""
log_entry = {
'timestamp': datetime.now().isoformat(),
'question': question,
'intent': intent,
'prompt_length': len(prompt)
}
log_file = f"{LOGS_DIR}/prompt_log.jsonl"
with open(log_file, 'a') as f:
f.write(json.dumps(log_entry) + '\n')
# =============================================================================
# USER INTERACTION FLOWS
# =============================================================================
def get_clarification_questions(question, intent):
"""
Generate clarification questions for ambiguous queries.
"""
clarifications = []
# Generic clarifications based on query type
if intent['query_type'] == 'aggregation':
clarifications.append("Which column should be aggregated?")
clarifications.append("Should results be grouped by any column?")
if intent['query_type'] == 'complex':
clarifications.append("Which tables need to be joined?")
clarifications.append("What is the relationship between the tables?")
# Check for missing specifics
if 'table' not in question.lower():
clarifications.append("Which table(s) should be queried?")
if not any(word in question.lower() for word in ['all', 'specific', 'where', 'filter']):
clarifications.append("Do you want all records or filtered results?")
return clarifications
def create_error_recovery_prompt(original_question, error_message):
"""
Create prompt for recovering from errors.
"""
return ERROR_RECOVERY_PROMPT.format(
error=error_message,
question=original_question
)
# =============================================================================
# TEST
# =============================================================================
def test_prompt_builder():
"""Test the prompt builder functionality."""
print("=" * 60)
print("TESTING PROMPT BUILDER")
print("=" * 60)
builder = PromptBuilder()
# Test 1: Normal question
print("\n[TEST 1] Normal Question")
print("-" * 40)
result = builder.build_prompt(
"Find all employees with salary above 50000",
rag_context="Example 1:\nQ: Get workers earning more than 40000\nSQL: SELECT * FROM employees WHERE salary > 40000"
)
print(f"Success: {result['success']}")
print(f"Query Type: {result.get('query_type')}")
print(f"Prompt Length: {len(result.get('prompt', ''))}")
# Test 2: Edge case - too short
print("\n[TEST 2] Edge Case - Too Short")
print("-" * 40)
result = builder.build_prompt("SQL")
print(f"Success: {result['success']}")
print(f"Error: {result.get('error', 'None')}")
# Test 3: Edge case - contains SQL
print("\n[TEST 3] Edge Case - Contains SQL")
print("-" * 40)
result = builder.build_prompt("SELECT * FROM users WHERE id = 1")
print(f"Success: {result['success']}")
print(f"Error: {result.get('error', 'None')}")
# Test 4: Edge case - dangerous operation
print("\n[TEST 4] Edge Case - Dangerous Operation")
print("-" * 40)
result = builder.build_prompt("Drop table users")
print(f"Success: {result['success']}")
print(f"Error: {result.get('error', 'None')}")
# Test 5: Aggregation query
print("\n[TEST 5] Aggregation Query")
print("-" * 40)
result = builder.build_prompt("Count total orders by customer")
print(f"Success: {result['success']}")
print(f"Query Type: {result.get('query_type')}")
# Test 6: Context management
print("\n[TEST 6] Context Management")
print("-" * 40)
builder.set_schema({
'employees': ['id', 'name', 'salary', 'dept_id'],
'departments': ['id', 'name', 'location']
})
builder.add_response("Show all employees", "SELECT * FROM employees", success=True)
result = builder.build_prompt("Now filter by department")
print(f"Success: {result['success']}")
print(f"Has History: {'Previous conversation' in result.get('prompt', '')}")
print("\n" + "=" * 60)
print("✓ All tests complete")
print("=" * 60)
if __name__ == "__main__":
test_prompt_builder()