File size: 15,287 Bytes
f29ea6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
"""
Prompt Builder for SQL Learning Assistant
Handles: Context Management, User Interaction Flows, Edge Cases
"""

import re
import os
import sys
import json
from datetime import datetime

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from prompts.system_prompts import (
    get_system_prompt,
    get_prompt_template,
    CLARIFICATION_PROMPT,
    ERROR_RECOVERY_PROMPT
)

# =============================================================================
# OUTPUT DIRECTORIES
# =============================================================================

# Relative locations (resolved against the current working directory) for
# prompt artifacts and, beneath them, the prompt log files.
OUTPUT_DIR = "outputs/prompts"
LOGS_DIR = OUTPUT_DIR + "/logs"


def setup_directories():
    """Create OUTPUT_DIR and LOGS_DIR; existing directories are left as-is."""
    for directory in (OUTPUT_DIR, LOGS_DIR):
        os.makedirs(directory, exist_ok=True)

# =============================================================================
# CONTEXT MANAGEMENT
# =============================================================================

class ConversationContext:
    """
    Manages conversation history and context for multi-turn interactions.

    Public attributes:
        history: list of turn dicts with keys 'question', 'sql', 'success'
            and 'timestamp' (ISO string from datetime.now()).
        max_history: maximum number of turns kept by add_turn.
        current_tables: list; only reset by clear() within this class.
        current_schema: dict mapping table name -> iterable of column names.
        user_preferences: dict; not modified by this class after __init__.
    """

    def __init__(self, max_history=5):
        self.history = []
        self.max_history = max_history
        self.current_tables = []
        self.current_schema = {}
        self.user_preferences = {}

    def add_turn(self, question, sql_response, success=True):
        """
        Append one question/SQL exchange and drop the oldest turns beyond
        max_history. A max_history of zero or less retains nothing.
        """
        self.history.append({
            'question': question,
            'sql': sql_response,
            'success': success,
            'timestamp': datetime.now().isoformat()
        })

        # Trim from the front by an explicit surplus count. Slicing with
        # history[-max_history:] would keep the WHOLE list when max_history
        # is 0, because -0 == 0.
        surplus = len(self.history) - max(self.max_history, 0)
        if surplus > 0:
            self.history = self.history[surplus:]

    def get_history_context(self):
        """
        Format the three most recent turns for prompt injection.

        Returns "" when there is no history; otherwise a header line
        followed by a "Q: ... / SQL: ..." pair per turn, each pair ending
        with a blank line.
        """
        if not self.history:
            return ""

        pieces = ["Previous conversation:\n"]
        for turn in self.history[-3:]:  # Last 3 turns
            pieces.append(f"Q: {turn['question']}\nSQL: {turn['sql']}\n\n")
        return "".join(pieces)

    def set_schema(self, schema_dict):
        """Set current database schema context (table name -> column names)."""
        self.current_schema = schema_dict

    def get_schema_context(self):
        """
        Format the schema for prompt injection.

        Returns "" when no schema is set; otherwise a header line followed
        by one "- table: col1, col2" line per table.
        """
        if not self.current_schema:
            return ""

        pieces = ["Available tables and columns:\n"]
        for table, columns in self.current_schema.items():
            # str() so that non-string column identifiers do not break join.
            column_list = ', '.join(str(column) for column in columns)
            pieces.append(f"- {table}: {column_list}\n")
        return "".join(pieces)

    def clear(self):
        """
        Reset history, current tables and schema.

        max_history and user_preferences are left untouched.
        """
        self.history = []
        self.current_tables = []
        self.current_schema = {}

# =============================================================================
# QUERY ANALYSIS (For Specialized Flows)
# =============================================================================

def _find_terms(text, terms):
    """Return the entries of *terms* found in *text* as whole words or phrases."""
    return [
        term for term in terms
        if re.search(r"\b" + re.escape(term) + r"\b", text)
    ]


def analyze_query_intent(question):
    """
    Analyze user question to determine query type and intent.

    Terms are matched case-insensitively as whole words/phrases, so "find"
    does not count as the SQL keyword IN and "country" does not count as
    "count". Inflected forms (e.g. "updated") therefore do not match either.

    Args:
        question: User's natural language question (str).

    Returns:
        dict with:
            'query_type': one of 'general', 'simple', 'aggregation',
                'complex', 'modification'
            'keywords': upper-cased SQL keywords mentioned in the question
            'question_length': number of whitespace-separated words
    """
    question_lower = question.lower()

    # Checked in this order; a later match overrides an earlier one, so
    # modification outranks complex, which outranks aggregation.
    typed_terms = (
        ('aggregation', ['count', 'sum', 'average', 'avg', 'total', 'maximum',
                         'max', 'minimum', 'min', 'how many',
                         'what is the total']),
        ('complex', ['join', 'combine', 'merge', 'from multiple',
                     'across tables', 'subquery', 'nested',
                     'with the highest', 'with the lowest']),
        ('modification', ['insert', 'add new', 'update', 'change', 'modify',
                          'delete', 'remove']),
    )

    query_type = 'general'
    for candidate, terms in typed_terms:
        if _find_terms(question_lower, terms):
            query_type = candidate

    # Simple patterns (only if nothing else matched)
    simple_terms = ['show', 'list', 'get', 'find', 'select', 'display']
    if query_type == 'general' and _find_terms(question_lower, simple_terms):
        query_type = 'simple'

    # SQL keywords the user spelled out, reported in this fixed order.
    sql_keywords = ['where', 'group by', 'order by', 'having', 'limit', 'join',
                    'distinct', 'between', 'like', 'in']
    keywords = [kw.upper() for kw in _find_terms(question_lower, sql_keywords)]

    return {
        'query_type': query_type,
        'keywords': keywords,
        'question_length': len(question.split())
    }

# =============================================================================
# EDGE CASE HANDLING
# =============================================================================

def detect_edge_cases(question):
    """
    Detect potential edge cases in user question.

    Args:
        question: User's input text (str).

    Returns:
        list of edge case labels, in detection order, drawn from:
        'too_short', 'too_vague', 'multiple_questions', 'contains_sql',
        'dangerous_operation', 'not_sql_related'. Empty when none apply.

    The SQL-keyword and vague-word checks match whole words, so "updated"
    is not counted as UPDATE and "database" is not counted as "data".
    NOTE(review): plain English using two of "select/from/where/..." as
    ordinary words is still reported as 'contains_sql'.
    """
    edge_cases = []
    question_lower = question.lower()

    # Empty or too short
    if len(question.strip()) < 5:
        edge_cases.append('too_short')

    # Too vague: a filler noun in a question of fewer than five words
    vague_word = re.search(
        r"\b(?:something|stuff|things|data|information)\b", question_lower
    )
    if vague_word and len(question.split()) < 5:
        edge_cases.append('too_vague')

    # Multiple questions
    if question.count('?') > 1:
        edge_cases.append('multiple_questions')

    # Contains SQL (user pasted SQL instead of question): at least two
    # DISTINCT statement keywords present as whole words.
    sql_hits = set(re.findall(
        r"\b(?:select|insert|update|delete|from|where)\b", question_lower
    ))
    if len(sql_hits) >= 2:
        edge_cases.append('contains_sql')

    # Potentially dangerous operations. Substring match is deliberate so
    # that e.g. "drop tables" is caught by 'drop table'.
    dangerous_patterns = ['drop table', 'truncate', 'delete all', 'remove all']
    if any(p in question_lower for p in dangerous_patterns):
        edge_cases.append('dangerous_operation')

    # Non-SQL question. Substring match is deliberate so that stems such
    # as 'thank' also cover "thanks".
    non_sql_patterns = ['weather', 'hello', 'how are you', 'thank', 'bye']
    if any(p in question_lower for p in non_sql_patterns):
        edge_cases.append('not_sql_related')

    return edge_cases

def handle_edge_case(edge_case_type, question):
    """
    Look up the user-facing reply for a detected edge case.

    Args:
        edge_case_type: Label such as 'too_short' or 'contains_sql'.
        question: The user's text; accepted but not used here.

    Returns:
        (should_continue, message). Every known label yields
        (False, <explanation>); an unknown label yields (True, "").
    """
    blocking_messages = {
        'too_short':
            "Your question is too short. Please provide more details "
            "about what data you want to retrieve.",
        'too_vague':
            "Your question is a bit vague. Could you specify:\n"
            "- Which table(s) to query?\n"
            "- What columns to retrieve?\n"
            "- Any conditions to filter by?",
        'multiple_questions':
            "I detected multiple questions. Please ask one question at a "
            "time for accurate SQL generation.",
        'contains_sql':
            "It looks like you've pasted SQL code. Please describe what you "
            "want in natural language, and I'll generate the SQL for you.",
        'dangerous_operation':
            "⚠️ This appears to be a destructive operation "
            "(DROP/TRUNCATE/DELETE ALL). Please confirm you want to proceed "
            "or rephrase your question.",
        'not_sql_related':
            "I'm an SQL assistant. Please ask me questions about querying "
            "databases, and I'll help generate SQL queries.",
    }

    if edge_case_type in blocking_messages:
        return (False, blocking_messages[edge_case_type])
    return (True, "")

# =============================================================================
# PROMPT BUILDER CLASS
# =============================================================================

class PromptBuilder:
    """
    Main class for building prompts with context management.

    Owns a ConversationContext (schema + recent turns) and combines it with
    a system prompt and a question template into one prompt string.
    Constructing an instance creates the output/log directories.
    """

    def __init__(self):
        self.context = ConversationContext()
        # Not read by this class; _log_prompt writes to a fixed file under
        # LOGS_DIR. Retained so existing attribute access keeps working.
        self.log_file = None
        setup_directories()

    def build_prompt(self, question, rag_context="", include_history=True):
        """
        Build complete prompt for SQL generation.

        Args:
            question: User's natural language question
            rag_context: Retrieved examples from RAG
            include_history: Whether to include conversation history

        Returns:
            On rejection (first detected edge case blocks the request):
                {'success': False, 'error': <message>, 'edge_case': <label>}
            Otherwise:
                {'success': True, 'prompt', 'system_prompt', 'query_type',
                 'keywords'}
        """
        # Check for edge cases; only the first detected one is handled.
        edge_cases = detect_edge_cases(question)

        if edge_cases:
            should_continue, message = handle_edge_case(edge_cases[0], question)
            if not should_continue:
                return {
                    'success': False,
                    'error': message,
                    'edge_case': edge_cases[0]
                }

        # Analyze query intent and pick the matching system prompt
        intent = analyze_query_intent(question)
        system_prompt = get_system_prompt(intent['query_type'])

        # Render the question template first so we can tell whether it
        # already carries the retrieved examples.
        if rag_context:
            template = get_prompt_template('rag')
            prompt = template.format(
                context=rag_context,
                question=question
            )
        else:
            template = get_prompt_template('zero_shot')
            prompt = template.format(question=question)

        # Build context parts: schema, then history, then (maybe) RAG
        context_parts = []

        schema_context = self.context.get_schema_context()
        if schema_context:
            context_parts.append(schema_context)

        if include_history:
            history_context = self.context.get_history_context()
            if history_context:
                context_parts.append(history_context)

        # The 'rag' template is formatted with context=rag_context. Adding
        # the same text here as well would put the examples into the prompt
        # twice, so it is only added when the rendered template lacks it.
        if rag_context and rag_context not in prompt:
            context_parts.append(rag_context)

        # Combine everything
        full_prompt = f"{system_prompt}\n\n"
        if context_parts:
            full_prompt += "\n".join(context_parts) + "\n\n"
        full_prompt += prompt

        # Log the prompt
        self._log_prompt(question, intent, full_prompt)

        return {
            'success': True,
            'prompt': full_prompt,
            'system_prompt': system_prompt,
            'query_type': intent['query_type'],
            'keywords': intent['keywords']
        }

    def add_response(self, question, sql_response, success=True):
        """Add a completed interaction to history."""
        self.context.add_turn(question, sql_response, success)

    def set_schema(self, schema_dict):
        """Set database schema for context."""
        self.context.set_schema(schema_dict)

    def clear_context(self):
        """Clear all context (delegates to ConversationContext.clear)."""
        self.context.clear()

    def _log_prompt(self, question, intent, prompt):
        """
        Append one JSON line (timestamp, question, intent, prompt length)
        to prompt_log.jsonl under LOGS_DIR. The prompt text itself is not
        written, only its length.
        """
        log_entry = {
            'timestamp': datetime.now().isoformat(),
            'question': question,
            'intent': intent,
            'prompt_length': len(prompt)
        }

        # LOGS_DIR is a relative path; re-create it here in case the working
        # directory changed or the folder was removed after __init__.
        os.makedirs(LOGS_DIR, exist_ok=True)
        log_file = f"{LOGS_DIR}/prompt_log.jsonl"
        with open(log_file, 'a', encoding='utf-8') as f:
            f.write(json.dumps(log_entry) + '\n')

# =============================================================================
# USER INTERACTION FLOWS
# =============================================================================

def get_clarification_questions(question, intent):
    """
    Suggest follow-up questions for an ambiguous request.

    Args:
        question: The user's text.
        intent: Mapping with a 'query_type' entry.

    Returns:
        list of question strings: type-specific ones first (for
        'aggregation' or 'complex'), then a table prompt when the text does
        not contain "table", then an all-vs-filtered prompt when the text
        contains none of "all", "specific", "where", "filter".
    """
    kind = intent['query_type']
    if kind == 'aggregation':
        follow_ups = [
            "Which column should be aggregated?",
            "Should results be grouped by any column?",
        ]
    elif kind == 'complex':
        follow_ups = [
            "Which tables need to be joined?",
            "What is the relationship between the tables?",
        ]
    else:
        follow_ups = []

    lowered = question.lower()

    # Substring checks on the lower-cased text.
    if 'table' not in lowered:
        follow_ups.append("Which table(s) should be queried?")

    scope_hints = ('all', 'specific', 'where', 'filter')
    if all(hint not in lowered for hint in scope_hints):
        follow_ups.append("Do you want all records or filtered results?")

    return follow_ups

def create_error_recovery_prompt(original_question, error_message):
    """
    Render ERROR_RECOVERY_PROMPT for a failed attempt.

    The template receives error_message as ``error`` and original_question
    as ``question``; the formatted string is returned.
    """
    fields = {
        'error': error_message,
        'question': original_question,
    }
    return ERROR_RECOVERY_PROMPT.format(**fields)

# =============================================================================
# TEST
# =============================================================================

def test_prompt_builder():
    """
    Printed smoke test for PromptBuilder.

    Runs six scenarios (normal question with RAG context, three inputs meant
    to trip edge-case detection, an aggregation question, and a follow-up
    after schema/history were set) and prints the outcome of each. Nothing
    is asserted; results are for visual inspection.
    """
    banner = "=" * 60

    def announce(number, title):
        # Section header followed by a dashed rule.
        print(f"\n[TEST {number}] {title}")
        print("-" * 40)

    def show_success(outcome):
        print(f"Success: {outcome['success']}")

    print(banner)
    print("TESTING PROMPT BUILDER")
    print(banner)

    builder = PromptBuilder()

    # Test 1: Normal question
    announce(1, "Normal Question")
    example = (
        "Example 1:\n"
        "Q: Get workers earning more than 40000\n"
        "SQL: SELECT * FROM employees WHERE salary > 40000"
    )
    result = builder.build_prompt(
        "Find all employees with salary above 50000",
        rag_context=example
    )
    show_success(result)
    print(f"Query Type: {result.get('query_type')}")
    print(f"Prompt Length: {len(result.get('prompt', ''))}")

    # Tests 2-4: edge-case inputs (too short, pasted SQL, dangerous operation)
    edge_inputs = [
        (2, "Edge Case - Too Short", "SQL"),
        (3, "Edge Case - Contains SQL", "SELECT * FROM users WHERE id = 1"),
        (4, "Edge Case - Dangerous Operation", "Drop table users"),
    ]
    for number, title, text in edge_inputs:
        announce(number, title)
        result = builder.build_prompt(text)
        show_success(result)
        print(f"Error: {result.get('error', 'None')}")

    # Test 5: Aggregation query
    announce(5, "Aggregation Query")
    result = builder.build_prompt("Count total orders by customer")
    show_success(result)
    print(f"Query Type: {result.get('query_type')}")

    # Test 6: Context management
    announce(6, "Context Management")
    schema = {
        'employees': ['id', 'name', 'salary', 'dept_id'],
        'departments': ['id', 'name', 'location']
    }
    builder.set_schema(schema)
    builder.add_response("Show all employees", "SELECT * FROM employees", success=True)
    result = builder.build_prompt("Now filter by department")
    show_success(result)
    has_history = 'Previous conversation' in result.get('prompt', '')
    print(f"Has History: {has_history}")

    print("\n" + banner)
    print("✓ All tests complete")
    print(banner)


if __name__ == "__main__":
    test_prompt_builder()