"""
Evaluation Module for Fine-Tuned SQL Model
"""

import os
import json
import matplotlib.pyplot as plt
from datetime import datetime

# =============================================================================
# CONFIGURATION
# =============================================================================

OUTPUT_DIR = "outputs/finetuning"
RESULTS_DIR = f"{OUTPUT_DIR}/results"
VIZ_DIR = f"{OUTPUT_DIR}/visualizations"

# Number of test samples to evaluate (increase for a more thorough, slower run)
NUM_EVAL_SAMPLES = 50

def setup_directories():
    """Create the results and visualization output directories if needed."""
    for d in [RESULTS_DIR, VIZ_DIR]:
        os.makedirs(d, exist_ok=True)

# =============================================================================
# EVALUATION METRICS
# =============================================================================

def exact_match(pred, expected):
    """Check exact match."""
    return pred.lower().strip() == expected.lower().strip()

def token_accuracy(pred, expected):
    """Token overlap accuracy."""
    pred_tokens = set(pred.lower().split())
    exp_tokens = set(expected.lower().split())
    if not exp_tokens:
        return 0.0
    return len(pred_tokens & exp_tokens) / len(exp_tokens)

def keyword_accuracy(pred, expected):
    """SQL keyword match accuracy."""
    keywords = ['SELECT', 'FROM', 'WHERE', 'JOIN', 'GROUP BY', 
                'ORDER BY', 'COUNT', 'SUM', 'AVG', 'MAX', 'MIN']
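    # Note: keywords are matched as substrings of the uppercased query, so a
    # keyword like 'COUNT' would also match a column named 'COUNTRY'.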
    
    pred_kw = [k for k in keywords if k in pred.upper()]
    exp_kw = [k for k in keywords if k in expected.upper()]
    
    if not exp_kw:
        return 1.0 if not pred_kw else 0.0
    
    matches = sum(1 for k in exp_kw if k in pred_kw)
    return matches / len(exp_kw)

def structure_similarity(pred, expected):
    """SQL structure similarity."""
    clauses = ['SELECT', 'FROM', 'WHERE', 'JOIN', 'GROUP BY', 'ORDER BY', 'LIMIT']
    
    pred_struct = set(c for c in clauses if c in pred.upper())
    exp_struct = set(c for c in clauses if c in expected.upper())
    
    if not exp_struct and not pred_struct:
        return 1.0
    if not exp_struct or not pred_struct:
        return 0.0
    
    return len(pred_struct & exp_struct) / len(pred_struct | exp_struct)
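
# Worked example for the metric functions above (illustrative values, not from
# a real evaluation run):
#   pred     = "SELECT name FROM users WHERE age > 30"
#   expected = "SELECT name FROM users WHERE age >= 30"
#   exact_match(pred, expected)           -> False  ('>' vs '>=')
#   token_accuracy(pred, expected)        -> 7/8 = 0.875
#   keyword_accuracy(pred, expected)      -> 1.0   (SELECT, FROM, WHERE all present)
#   structure_similarity(pred, expected)  -> 1.0   (identical clause sets)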

# =============================================================================
# EVALUATION RUNNER
# =============================================================================

def evaluate_predictions(predictions, ground_truth):
    """Calculate all metrics for parallel lists of predicted and expected queries."""

    if not predictions or len(predictions) != len(ground_truth):
        raise ValueError("predictions and ground_truth must be non-empty and the same length")

    results = {
        'exact_match': [],
        'token_accuracy': [],
        'keyword_accuracy': [],
        'structure_similarity': []
    }
    
    for pred, exp in zip(predictions, ground_truth):
        results['exact_match'].append(1 if exact_match(pred, exp) else 0)
        results['token_accuracy'].append(token_accuracy(pred, exp))
        results['keyword_accuracy'].append(keyword_accuracy(pred, exp))
        results['structure_similarity'].append(structure_similarity(pred, exp))
    
    # Calculate averages
    metrics = {
        'total_samples': len(predictions),
        'exact_match_rate': sum(results['exact_match']) / len(results['exact_match']),
        'avg_token_accuracy': sum(results['token_accuracy']) / len(results['token_accuracy']),
        'avg_keyword_accuracy': sum(results['keyword_accuracy']) / len(results['keyword_accuracy']),
        'avg_structure_similarity': sum(results['structure_similarity']) / len(results['structure_similarity']),
        'detailed': results
    }
    
    return metrics
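
# Minimal programmatic use of the runner above (illustrative):
#   metrics = evaluate_predictions(["SELECT 1"], ["SELECT 1"])
#   metrics['exact_match_rate']  -> 1.0
#   metrics['detailed']          -> per-sample score lists for each metric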

# =============================================================================
# VISUALIZATIONS
# =============================================================================

def create_visualizations(metrics):
    """Create evaluation charts."""
    
    setup_directories()
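    # Note: the 'seaborn-v0_8-whitegrid' style name requires matplotlib >= 3.6;
    # older releases ship the same style as 'seaborn-whitegrid'.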
    plt.style.use('seaborn-v0_8-whitegrid')
    
    # 1. Metrics Overview
    fig, ax = plt.subplots(figsize=(10, 6))
    
    names = ['Exact Match', 'Token Acc', 'Keyword Acc', 'Structure Sim']
    values = [
        metrics['exact_match_rate'] * 100,
        metrics['avg_token_accuracy'] * 100,
        metrics['avg_keyword_accuracy'] * 100,
        metrics['avg_structure_similarity'] * 100
    ]
    colors = ['#3498db', '#2ecc71', '#9b59b6', '#e74c3c']
    
    bars = ax.bar(names, values, color=colors, edgecolor='black')
    for bar, val in zip(bars, values):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1,
                f'{val:.1f}%', ha='center', fontweight='bold')
    
    ax.set_ylabel('Score (%)')
    ax.set_title('Model Evaluation Metrics', fontsize=14, fontweight='bold')
    ax.set_ylim(0, 110)
    
    plt.tight_layout()
    plt.savefig(f'{VIZ_DIR}/01_metrics_overview.png', dpi=150)
    plt.close()
    print(f"  Saved: {VIZ_DIR}/01_metrics_overview.png")
    
    # 2. Token Accuracy Distribution
    fig, ax = plt.subplots(figsize=(10, 6))
    
    token_acc = metrics['detailed']['token_accuracy']
    ax.hist(token_acc, bins=20, color='#2ecc71', edgecolor='black', alpha=0.7)
    ax.axvline(sum(token_acc)/len(token_acc), color='red', linestyle='--',
               label=f"Mean: {sum(token_acc)/len(token_acc):.2f}")
    ax.set_xlabel('Token Accuracy')
    ax.set_ylabel('Frequency')
    ax.set_title('Token Accuracy Distribution', fontsize=14, fontweight='bold')
    ax.legend()
    
    plt.tight_layout()
    plt.savefig(f'{VIZ_DIR}/02_token_accuracy_dist.png', dpi=150)
    plt.close()
    print(f"  Saved: {VIZ_DIR}/02_token_accuracy_dist.png")
    
    # 3. Keyword Accuracy Distribution
    fig, ax = plt.subplots(figsize=(10, 6))
    
    kw_acc = metrics['detailed']['keyword_accuracy']
    ax.hist(kw_acc, bins=20, color='#9b59b6', edgecolor='black', alpha=0.7)
    ax.axvline(sum(kw_acc)/len(kw_acc), color='red', linestyle='--',
               label=f"Mean: {sum(kw_acc)/len(kw_acc):.2f}")
    ax.set_xlabel('Keyword Accuracy')
    ax.set_ylabel('Frequency')
    ax.set_title('Keyword Accuracy Distribution', fontsize=14, fontweight='bold')
    ax.legend()
    
    plt.tight_layout()
    plt.savefig(f'{VIZ_DIR}/03_keyword_accuracy_dist.png', dpi=150)
    plt.close()
    print(f"  Saved: {VIZ_DIR}/03_keyword_accuracy_dist.png")

# =============================================================================
# REPORT GENERATION
# =============================================================================

def generate_report(metrics):
    """Generate evaluation report."""
    
    report = f"""# Fine-Tuning Evaluation Report

**Generated:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

## Metrics Summary

| Metric | Score |
|--------|-------|
| Samples Evaluated | {metrics['total_samples']} |
| Exact Match Rate | {metrics['exact_match_rate']*100:.2f}% |
| Token Accuracy | {metrics['avg_token_accuracy']*100:.2f}% |
| Keyword Accuracy | {metrics['avg_keyword_accuracy']*100:.2f}% |
| Structure Similarity | {metrics['avg_structure_similarity']*100:.2f}% |

## Metrics Explanation

- **Exact Match**: Prediction identical to the ground-truth query (case-insensitive)
- **Token Accuracy**: Fraction of expected tokens that appear in the prediction
- **Keyword Accuracy**: Fraction of expected SQL keywords (SELECT, WHERE, etc.) present in the prediction
- **Structure Similarity**: Jaccard overlap of the clauses used by prediction and ground truth

## Visualizations

- `01_metrics_overview.png` - All metrics bar chart
- `02_token_accuracy_dist.png` - Token accuracy histogram
- `03_keyword_accuracy_dist.png` - Keyword accuracy histogram
"""
    
    with open(f'{RESULTS_DIR}/evaluation_report.md', 'w') as f:
        f.write(report)
    print(f"  Saved: {RESULTS_DIR}/evaluation_report.md")
    
    # Save summary metrics as JSON (per-sample 'detailed' lists are omitted to keep the file small)
    json_metrics = {k: v for k, v in metrics.items() if k != 'detailed'}
    with open(f'{RESULTS_DIR}/evaluation_results.json', 'w') as f:
        json.dump(json_metrics, f, indent=2)
    print(f"  Saved: {RESULTS_DIR}/evaluation_results.json")

# =============================================================================
# MAIN EVALUATION
# =============================================================================

def run_evaluation():
    """Run full evaluation."""
    
    print("=" * 60)
    print("EVALUATING FINE-TUNED MODEL")
    print("=" * 60)
    
    setup_directories()
    
    # Load test data
    print("\n[1/4] Loading test data...")
    test_file = f"{OUTPUT_DIR}/test.jsonl"
    
    if not os.path.exists(test_file):
        print("ERROR: Run prepare_data.py first!")
        return None
    
    test_data = []
    with open(test_file) as f:
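        # Each line is a standalone JSON object; the evaluation below expects
        # 'question' (natural-language prompt) and 'sql' (target query) keys.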
        for line in f:
            test_data.append(json.loads(line))
    
    test_data = test_data[:NUM_EVAL_SAMPLES]
    print(f"  Loaded {len(test_data)} samples")
    
    # Generate predictions
    print("\n[2/4] Generating predictions...")
    
    try:
        import sys
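        # Add the project root to sys.path so the `finetuning` package can be
        # imported when this script is run directly.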
        sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
        from finetuning.inference import SQLGenerator
        generator = SQLGenerator()
        
        predictions = []
        ground_truth = []
        
        for i, item in enumerate(test_data):
            pred = generator.generate(item['question'])
            predictions.append(pred)
            ground_truth.append(item['sql'])
            
            if (i + 1) % 10 == 0:
                print(f"  Progress: {i+1}/{len(test_data)}")
        
    except Exception as e:
        print(f"  Could not run model inference: {e}")
        print("  Falling back to ground truth as predictions (metrics will be trivially perfect; use only to test the pipeline)")
        predictions = [item['sql'] for item in test_data]
        ground_truth = [item['sql'] for item in test_data]
    
    # Calculate metrics
    print("\n[3/4] Calculating metrics...")
    metrics = evaluate_predictions(predictions, ground_truth)
    
    print(f"  Exact Match: {metrics['exact_match_rate']*100:.2f}%")
    print(f"  Token Accuracy: {metrics['avg_token_accuracy']*100:.2f}%")
    print(f"  Keyword Accuracy: {metrics['avg_keyword_accuracy']*100:.2f}%")
    print(f"  Structure Sim: {metrics['avg_structure_similarity']*100:.2f}%")
    
    # Generate outputs
    print("\n[4/4] Generating outputs...")
    create_visualizations(metrics)
    generate_report(metrics)
    
    print("\n" + "=" * 60)
    print("EVALUATION COMPLETE")
    print("=" * 60)
    
    return metrics
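
# Typical invocation: run this file directly from the repository root, e.g.
#   python <path to this script>
# It expects test.jsonl produced by prepare_data.py and, if available, the
# fine-tuned model behind finetuning.inference.SQLGenerator.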

if __name__ == "__main__":
    run_evaluation()