# Hosting-page residue, kept as comments so the module parses
# (the pasted 1-440 line-number gutter has been dropped):
# Spaces: Sleeping
# File size: 15,287 Bytes
# Revision: f29ea6c
"""
Prompt Builder for SQL Learning Assistant
Handles: Context Management, User Interaction Flows, Edge Cases
"""
import re
import os
import sys
import json
from datetime import datetime
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from prompts.system_prompts import (
get_system_prompt,
get_prompt_template,
CLARIFICATION_PROMPT,
ERROR_RECOVERY_PROMPT
)
# =============================================================================
# OUTPUT DIRECTORIES
# =============================================================================
# Root directory for generated prompt artifacts (relative path, so it is
# resolved against the current working directory).
OUTPUT_DIR = "outputs/prompts"
# Subdirectory holding the JSONL prompt logs.
LOGS_DIR = OUTPUT_DIR + "/logs"


def setup_directories():
    """Create OUTPUT_DIR and LOGS_DIR if they do not already exist."""
    for directory in (OUTPUT_DIR, LOGS_DIR):
        os.makedirs(directory, exist_ok=True)
# =============================================================================
# CONTEXT MANAGEMENT
# =============================================================================
class ConversationContext:
    """
    Manages conversation history and schema context for multi-turn interactions.

    Attributes:
        history: Recorded turns, oldest first. Each turn is a dict with
            'question', 'sql', 'success' and 'timestamp' (ISO-8601 string
            from ``datetime.now()``).
        max_history: Maximum number of turns kept in ``history``.
        current_tables: List reset by ``clear()``; not populated by this class.
        current_schema: Mapping of table name -> column names, as given to
            ``set_schema()``.
        user_preferences: Dict initialised empty; not read by this class.
    """

    def __init__(self, max_history=5):
        self.history = []
        self.max_history = max_history
        self.current_tables = []
        self.current_schema = {}
        self.user_preferences = {}

    def add_turn(self, question, sql_response, success=True):
        """
        Record a conversation turn, keeping only the newest ``max_history`` turns.

        Args:
            question: The user's natural-language question.
            sql_response: The SQL produced for the question.
            success: Whether the interaction succeeded (stored with the turn).
        """
        self.history.append({
            'question': question,
            'sql': sql_response,
            'success': success,
            'timestamp': datetime.now().isoformat(),
        })
        # Trim by an explicit count: a ``history[-max_history:]`` slice would
        # keep every turn when max_history is 0 (because -0 == 0).
        # A non-positive limit keeps nothing.
        excess = len(self.history) - max(self.max_history, 0)
        if excess > 0:
            self.history = self.history[excess:]

    def get_history_context(self, max_turns=3):
        """
        Format the most recent turns for prompt injection.

        Args:
            max_turns: How many of the newest turns to include (default 3).
                A non-positive value includes none.

        Returns:
            "Previous conversation:\\n" followed by one "Q: ...\\nSQL: ...\\n\\n"
            entry per turn, or "" when there is nothing to include.
        """
        recent = self.history[-max_turns:] if max_turns > 0 else []
        if not recent:
            return ""
        entries = [f"Q: {turn['question']}\nSQL: {turn['sql']}\n\n" for turn in recent]
        return "Previous conversation:\n" + "".join(entries)

    def set_schema(self, schema_dict):
        """Set the schema context (table name -> column names); stored by reference."""
        self.current_schema = schema_dict

    def get_schema_context(self):
        """
        Format the schema for prompt injection.

        Returns:
            "Available tables and columns:\\n" followed by one
            "- table: col1, col2\\n" line per table, or "" if no schema is set.
        """
        if not self.current_schema:
            return ""
        lines = [
            f"- {table}: {', '.join(columns)}\n"
            for table, columns in self.current_schema.items()
        ]
        return "Available tables and columns:\n" + "".join(lines)

    def clear(self):
        """
        Reset conversation history, current tables and schema.

        ``user_preferences`` is left untouched.
        """
        self.history = []
        self.current_tables = []
        self.current_schema = {}
# =============================================================================
# QUERY ANALYSIS (For Specialized Flows)
# =============================================================================
def analyze_query_intent(question):
    """
    Analyze a user question to determine query type and intent.

    Patterns are matched case-insensitively as whole words/phrases, with an
    optional trailing plural "s" (e.g. "counts", "totals"). Short patterns
    therefore no longer fire inside unrelated words: "in" in "find", "sum" in
    "summary", "min" in "admin".

    Categories are checked in order and later matches override earlier ones,
    so precedence is modification > complex > aggregation. 'simple' applies
    only when nothing else matched.

    Args:
        question: The user's natural-language question.

    Returns:
        dict with 'query_type' ('general', 'simple', 'aggregation', 'complex'
        or 'modification'), 'keywords' (upper-cased SQL keywords mentioned in
        the question, in a fixed order) and 'question_length' (number of
        whitespace-separated words).
    """
    question_lower = question.lower()

    def mentions(phrase):
        # Whole-word/phrase match; any run of whitespace may separate words.
        body = r"\s+".join(re.escape(word) for word in phrase.split())
        return re.search(rf"\b{body}s?\b", question_lower) is not None

    # Detect query type
    query_type = 'general'

    # Aggregation patterns
    agg_patterns = ['count', 'sum', 'average', 'avg', 'total', 'maximum', 'max',
                    'minimum', 'min', 'how many', 'what is the total']
    if any(mentions(p) for p in agg_patterns):
        query_type = 'aggregation'

    # Complex query patterns
    complex_patterns = ['join', 'combine', 'merge', 'from multiple', 'across tables',
                        'subquery', 'nested', 'with the highest', 'with the lowest']
    if any(mentions(p) for p in complex_patterns):
        query_type = 'complex'

    # Modification patterns
    mod_patterns = ['insert', 'add new', 'update', 'change', 'modify', 'delete', 'remove']
    if any(mentions(p) for p in mod_patterns):
        query_type = 'modification'

    # Simple patterns (if nothing else matched)
    simple_patterns = ['show', 'list', 'get', 'find', 'select', 'display']
    if query_type == 'general' and any(mentions(p) for p in simple_patterns):
        query_type = 'simple'

    # SQL keywords the user mentioned explicitly
    sql_keywords = ['where', 'group by', 'order by', 'having', 'limit', 'join',
                    'distinct', 'between', 'like', 'in']
    keywords = [kw.upper() for kw in sql_keywords if mentions(kw)]

    return {
        'query_type': query_type,
        'keywords': keywords,
        'question_length': len(question.split())
    }
# =============================================================================
# EDGE CASE HANDLING
# =============================================================================
# A pasted SQL statement: a statement verb together with the clause that makes
# it a statement (SELECT ... FROM, INSERT INTO, UPDATE <table> SET, DELETE FROM).
# Plain English like "customers from Ohio where ..." does not match.
_PASTED_SQL_RE = re.compile(
    r"\bselect\b.*\bfrom\b"
    r"|\binsert\s+into\b"
    r"|\bupdate\s+\w+\s+set\b"
    r"|\bdelete\s+from\b",
    re.IGNORECASE | re.DOTALL,
)

# Vague filler words, matched as whole words so that e.g. "database" is not "data".
_VAGUE_WORDS_RE = re.compile(r"\b(?:something|stuff|things|data|information)\b")


def detect_edge_cases(question):
    """
    Detect potential edge cases in a user question.

    Args:
        question: The user's natural-language question. ``None`` is treated
            as an empty string.

    Returns:
        list of edge-case type strings, in detection order. Possible values:
        'too_short', 'too_vague', 'multiple_questions', 'contains_sql',
        'dangerous_operation' and 'not_sql_related'.
    """
    if question is None:
        question = ""
    edge_cases = []
    question_lower = question.lower()
    word_count = len(question.split())

    # Empty or too short
    if len(question.strip()) < 5:
        edge_cases.append('too_short')

    # Too vague: a short question built around filler words
    if word_count < 5 and _VAGUE_WORDS_RE.search(question_lower):
        edge_cases.append('too_vague')

    # Multiple questions
    if question.count('?') > 1:
        edge_cases.append('multiple_questions')

    # Contains SQL (user pasted SQL instead of a question)
    if _PASTED_SQL_RE.search(question):
        edge_cases.append('contains_sql')

    # Potentially dangerous operations
    dangerous_patterns = ['drop table', 'truncate', 'delete all', 'remove all']
    if any(p in question_lower for p in dangerous_patterns):
        edge_cases.append('dangerous_operation')

    # Non-SQL question
    non_sql_patterns = ['weather', 'hello', 'how are you', 'thank', 'bye']
    if any(p in question_lower for p in non_sql_patterns):
        edge_cases.append('not_sql_related')

    return edge_cases
def handle_edge_case(edge_case_type, question):
    """
    Map a detected edge case to a user-facing message.

    Args:
        edge_case_type: One of the strings produced by detect_edge_cases().
        question: The original question (accepted for interface symmetry;
            not used when choosing the message).

    Returns:
        (should_continue, message). Every known edge case yields
        (False, <message>); an unknown type yields (True, "").
    """
    messages = {
        'too_short': "Your question is too short. Please provide more details about what data you want to retrieve.",
        'too_vague': "Your question is a bit vague. Could you specify:\n- Which table(s) to query?\n- What columns to retrieve?\n- Any conditions to filter by?",
        'multiple_questions': "I detected multiple questions. Please ask one question at a time for accurate SQL generation.",
        'contains_sql': "It looks like you've pasted SQL code. Please describe what you want in natural language, and I'll generate the SQL for you.",
        'dangerous_operation': "⚠️ This appears to be a destructive operation (DROP/TRUNCATE/DELETE ALL). Please confirm you want to proceed or rephrase your question.",
        'not_sql_related': "I'm an SQL assistant. Please ask me questions about querying databases, and I'll help generate SQL queries.",
    }
    message = messages.get(edge_case_type)
    if message is None:
        # Unknown edge case: let prompt building proceed.
        return (True, "")
    return (False, message)
# =============================================================================
# PROMPT BUILDER CLASS
# =============================================================================
class PromptBuilder:
    """
    Main class for building prompts with context management.

    Combines a query-type-specific system prompt, optional schema and
    conversation-history context, optional RAG examples and a prompt
    template into one prompt string. Each built prompt is logged as one
    JSON line.
    """

    def __init__(self):
        self.context = ConversationContext()
        # Optional override for the prompt-log path; None means
        # f"{LOGS_DIR}/prompt_log.jsonl".
        self.log_file = None
        setup_directories()

    def build_prompt(self, question, rag_context="", include_history=True):
        """
        Build the complete prompt for SQL generation.

        Args:
            question: User's natural language question.
            rag_context: Retrieved examples from RAG ("" for none).
            include_history: Whether to include conversation history.

        Returns:
            dict. On success: 'success' (True), 'prompt', 'system_prompt',
            'query_type' and 'keywords'. When an edge case stops the build:
            'success' (False), 'error' (user-facing message) and 'edge_case'.
        """
        # Check for edge cases; only the first one detected decides.
        edge_cases = detect_edge_cases(question)
        if edge_cases:
            should_continue, message = handle_edge_case(edge_cases[0], question)
            if not should_continue:
                return {
                    'success': False,
                    'error': message,
                    'edge_case': edge_cases[0],
                }

        # Analyze query intent and pick the matching system prompt.
        intent = analyze_query_intent(question)
        system_prompt = get_system_prompt(intent['query_type'])

        # Render the task template.
        if rag_context:
            template = get_prompt_template('rag')
            prompt = template.format(context=rag_context, question=question)
        else:
            template = get_prompt_template('zero_shot')
            prompt = template.format(question=question)

        # Build context parts: schema, then history, then RAG examples.
        context_parts = []
        schema_context = self.context.get_schema_context()
        if schema_context:
            context_parts.append(schema_context)
        if include_history:
            history_context = self.context.get_history_context()
            if history_context:
                context_parts.append(history_context)
        # The RAG template was formatted with context=rag_context above. If it
        # has a {context} field, the examples are already in `prompt`, and
        # adding them here too would put them in the final prompt twice.
        if rag_context and not self._template_uses_field(template, 'context'):
            context_parts.append(rag_context)

        # Combine: system prompt, context section, task prompt.
        full_prompt = f"{system_prompt}\n\n"
        if context_parts:
            full_prompt += "\n".join(context_parts) + "\n\n"
        full_prompt += prompt

        # Log the prompt
        self._log_prompt(question, intent, full_prompt)

        return {
            'success': True,
            'prompt': full_prompt,
            'system_prompt': system_prompt,
            'query_type': intent['query_type'],
            'keywords': intent['keywords']
        }

    def add_response(self, question, sql_response, success=True):
        """Add a completed interaction to history."""
        self.context.add_turn(question, sql_response, success)

    def set_schema(self, schema_dict):
        """Set database schema for context."""
        self.context.set_schema(schema_dict)

    def clear_context(self):
        """Clear history, current tables and schema from the context."""
        self.context.clear()

    @staticmethod
    def _template_uses_field(template, field):
        """Return True if ``template`` has a str.format replacement field rooted at ``field``."""
        import string  # local import: only needed for template inspection

        for _, field_name, _, _ in string.Formatter().parse(template):
            if field_name is None:
                continue
            # "{context.x}" / "{context[0]}" still refer to the `context` argument.
            if re.split(r"[.\[]", field_name, maxsplit=1)[0] == field:
                return True
        return False

    def _log_prompt(self, question, intent, prompt):
        """
        Append one JSON line describing a built prompt, for debugging/analysis.

        Logging is best-effort: an OSError while writing is reported on
        stderr instead of aborting prompt construction.
        """
        log_entry = {
            'timestamp': datetime.now().isoformat(),
            'question': question,
            'intent': intent,
            'prompt_length': len(prompt)
        }
        log_file = self.log_file or f"{LOGS_DIR}/prompt_log.jsonl"
        try:
            with open(log_file, 'a', encoding='utf-8') as f:
                f.write(json.dumps(log_entry) + '\n')
        except OSError as exc:
            print(f"Warning: could not write prompt log {log_file}: {exc}", file=sys.stderr)
# =============================================================================
# USER INTERACTION FLOWS
# =============================================================================
def get_clarification_questions(question, intent):
    """
    Generate clarification questions for an ambiguous query.

    Word checks use whole words, so e.g. "small" does not count as "all",
    while "tables" counts as mentioning a table.

    Args:
        question: The user's natural-language question.
        intent: Result of analyze_query_intent(); only 'query_type' is read
            (a missing key is treated as no specific type).

    Returns:
        list of clarification question strings (possibly empty), with the
        query-type-specific questions first.
    """
    words = set(re.findall(r"[a-z0-9_]+", question.lower()))
    query_type = intent.get('query_type')
    clarifications = []

    # Generic clarifications based on query type (types are mutually exclusive).
    if query_type == 'aggregation':
        clarifications.append("Which column should be aggregated?")
        clarifications.append("Should results be grouped by any column?")
    elif query_type == 'complex':
        clarifications.append("Which tables need to be joined?")
        clarifications.append("What is the relationship between the tables?")

    # Check for missing specifics.
    if not words & {'table', 'tables'}:
        clarifications.append("Which table(s) should be queried?")
    filter_words = {'all', 'specific', 'specifically', 'where',
                    'filter', 'filters', 'filtered', 'filtering'}
    if not words & filter_words:
        clarifications.append("Do you want all records or filtered results?")
    return clarifications
def create_error_recovery_prompt(original_question, error_message):
    """
    Fill ERROR_RECOVERY_PROMPT for a retry after a failed generation.

    Args:
        original_question: The question whose SQL generation failed.
        error_message: The error text to include in the prompt.

    Returns:
        ERROR_RECOVERY_PROMPT formatted with ``question`` and ``error``.
    """
    substitutions = {
        'question': original_question,
        'error': error_message,
    }
    return ERROR_RECOVERY_PROMPT.format(**substitutions)
# =============================================================================
# TEST
# =============================================================================
def test_prompt_builder():
    """
    Smoke-test PromptBuilder end to end and print a human-readable report.

    Runs six scenarios against one PromptBuilder:
    - a normal question with RAG examples;
    - three edge cases (too short, pasted SQL, destructive operation);
    - an aggregation query;
    - schema plus history context.

    Each scenario is compared with its expected outcome. The closing banner
    reports mismatches instead of always claiming success.

    Side effects: PromptBuilder() creates the output directories, and each
    successfully built prompt is appended to the prompt log.
    """
    failures = []

    def expect(label, passed):
        # Remember the failing scenario so the final banner can list it.
        if not passed:
            failures.append(label)
            print(f"  ✗ {label}: unexpected result")

    print("=" * 60)
    print("TESTING PROMPT BUILDER")
    print("=" * 60)
    builder = PromptBuilder()

    # Test 1: Normal question
    print("\n[TEST 1] Normal Question")
    print("-" * 40)
    result = builder.build_prompt(
        "Find all employees with salary above 50000",
        rag_context="Example 1:\nQ: Get workers earning more than 40000\nSQL: SELECT * FROM employees WHERE salary > 40000"
    )
    print(f"Success: {result['success']}")
    print(f"Query Type: {result.get('query_type')}")
    print(f"Prompt Length: {len(result.get('prompt', ''))}")
    expect("TEST 1", result['success'])

    # Test 2: Edge case - too short
    print("\n[TEST 2] Edge Case - Too Short")
    print("-" * 40)
    result = builder.build_prompt("SQL")
    print(f"Success: {result['success']}")
    print(f"Error: {result.get('error', 'None')}")
    expect("TEST 2", result.get('edge_case') == 'too_short')

    # Test 3: Edge case - contains SQL
    print("\n[TEST 3] Edge Case - Contains SQL")
    print("-" * 40)
    result = builder.build_prompt("SELECT * FROM users WHERE id = 1")
    print(f"Success: {result['success']}")
    print(f"Error: {result.get('error', 'None')}")
    expect("TEST 3", result.get('edge_case') == 'contains_sql')

    # Test 4: Edge case - dangerous operation
    print("\n[TEST 4] Edge Case - Dangerous Operation")
    print("-" * 40)
    result = builder.build_prompt("Drop table users")
    print(f"Success: {result['success']}")
    print(f"Error: {result.get('error', 'None')}")
    expect("TEST 4", result.get('edge_case') == 'dangerous_operation')

    # Test 5: Aggregation query
    print("\n[TEST 5] Aggregation Query")
    print("-" * 40)
    result = builder.build_prompt("Count total orders by customer")
    print(f"Success: {result['success']}")
    print(f"Query Type: {result.get('query_type')}")
    expect("TEST 5", result.get('query_type') == 'aggregation')

    # Test 6: Context management
    print("\n[TEST 6] Context Management")
    print("-" * 40)
    builder.set_schema({
        'employees': ['id', 'name', 'salary', 'dept_id'],
        'departments': ['id', 'name', 'location']
    })
    builder.add_response("Show all employees", "SELECT * FROM employees", success=True)
    result = builder.build_prompt("Now filter by department")
    has_history = 'Previous conversation' in result.get('prompt', '')
    print(f"Success: {result['success']}")
    print(f"Has History: {has_history}")
    expect("TEST 6", result['success'] and has_history)

    print("\n" + "=" * 60)
    if failures:
        print(f"✗ {len(failures)} test(s) failed: {', '.join(failures)}")
    else:
        print("✓ All tests complete")
    print("=" * 60)
# Run the smoke test when executed as a script. (The stray trailing "|" from
# the paste was a syntax error and has been removed.)
if __name__ == "__main__":
    test_prompt_builder()