| """ | |
| Evaluation Module for Fine-Tuned SQL Model | |
| """ | |
| import os | |
| import json | |
| import matplotlib.pyplot as plt | |
| from datetime import datetime | |
| from collections import Counter | |

# =============================================================================
# CONFIGURATION
# =============================================================================
OUTPUT_DIR = "outputs/finetuning"
RESULTS_DIR = f"{OUTPUT_DIR}/results"
VIZ_DIR = f"{OUTPUT_DIR}/visualizations"

# Number of samples to evaluate
NUM_EVAL_SAMPLES = 50  # Change for more/less evaluation
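
# Expected on-disk layout once evaluation has run (paths are built from the
# constants above; file names match those written further down in this module):
#
#   outputs/finetuning/
#       test.jsonl                      <- input, produced by prepare_data.py
#       results/
#           evaluation_report.md
#           evaluation_results.json
#       visualizations/
#           01_metrics_overview.png
#           02_token_accuracy_dist.png
#           03_keyword_accuracy_dist.png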

def setup_directories():
    for d in [RESULTS_DIR, VIZ_DIR]:
        os.makedirs(d, exist_ok=True)

# =============================================================================
# EVALUATION METRICS
# =============================================================================
def exact_match(pred, expected):
    """Case-insensitive, whitespace-trimmed exact string match."""
    return pred.lower().strip() == expected.lower().strip()

def token_accuracy(pred, expected):
    """Fraction of expected tokens that also appear in the prediction (case-insensitive)."""
    pred_tokens = set(pred.lower().split())
    exp_tokens = set(expected.lower().split())
    if not exp_tokens:
        return 0.0
    return len(pred_tokens & exp_tokens) / len(exp_tokens)
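
# Illustrative example: token_accuracy is recall over the unique expected
# tokens, so extra tokens in the prediction are never penalized.
#
#   token_accuracy("SELECT name FROM users",
#                  "SELECT name FROM users WHERE id = 1")
#   # -> 0.5 (4 of the 8 unique expected tokens appear in the prediction)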

def keyword_accuracy(pred, expected):
    """Fraction of SQL keywords in the expected query that also appear in the prediction."""
    keywords = ['SELECT', 'FROM', 'WHERE', 'JOIN', 'GROUP BY',
                'ORDER BY', 'COUNT', 'SUM', 'AVG', 'MAX', 'MIN']
    pred_kw = [k for k in keywords if k in pred.upper()]
    exp_kw = [k for k in keywords if k in expected.upper()]
    if not exp_kw:
        return 1.0 if not pred_kw else 0.0
    matches = sum(1 for k in exp_kw if k in pred_kw)
    return matches / len(exp_kw)
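
# Illustrative example: keywords are detected by substring search on the
# upper-cased query, so identifiers that embed a keyword (e.g. "min_price"
# contains MIN) also count as a hit; a simple, known limitation of this check.
#
#   keyword_accuracy("SELECT COUNT(*) FROM users",
#                    "SELECT COUNT(*) FROM users WHERE age > 30")
#   # -> 0.75 (SELECT, FROM, COUNT matched; WHERE is missing from the prediction)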

def structure_similarity(pred, expected):
    """Jaccard similarity over the SQL clauses present in each query."""
    clauses = ['SELECT', 'FROM', 'WHERE', 'JOIN', 'GROUP BY', 'ORDER BY', 'LIMIT']
    pred_struct = set(c for c in clauses if c in pred.upper())
    exp_struct = set(c for c in clauses if c in expected.upper())
    if not exp_struct and not pred_struct:
        return 1.0
    if not exp_struct or not pred_struct:
        return 0.0
    return len(pred_struct & exp_struct) / len(pred_struct | exp_struct)
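
# Illustrative example: structure_similarity is intersection-over-union of the
# clause sets, so missing and extra clauses lower the score symmetrically.
#
#   structure_similarity("SELECT * FROM t WHERE x = 1",
#                        "SELECT * FROM t WHERE x = 1 ORDER BY x")
#   # -> 0.75 (shared: SELECT, FROM, WHERE; the union also includes ORDER BY)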

# =============================================================================
# EVALUATION RUNNER
# =============================================================================
def evaluate_predictions(predictions, ground_truth):
    """Compute all metrics over paired predictions and ground-truth queries."""
    # Guard against empty inputs, which would otherwise raise ZeroDivisionError below
    if not predictions or not ground_truth:
        raise ValueError("No predictions to evaluate")
    results = {
        'exact_match': [],
        'token_accuracy': [],
        'keyword_accuracy': [],
        'structure_similarity': []
    }
    for pred, exp in zip(predictions, ground_truth):
        results['exact_match'].append(1 if exact_match(pred, exp) else 0)
        results['token_accuracy'].append(token_accuracy(pred, exp))
        results['keyword_accuracy'].append(keyword_accuracy(pred, exp))
        results['structure_similarity'].append(structure_similarity(pred, exp))

    # Calculate averages
    metrics = {
        'total_samples': len(predictions),
        'exact_match_rate': sum(results['exact_match']) / len(results['exact_match']),
        'avg_token_accuracy': sum(results['token_accuracy']) / len(results['token_accuracy']),
        'avg_keyword_accuracy': sum(results['keyword_accuracy']) / len(results['keyword_accuracy']),
        'avg_structure_similarity': sum(results['structure_similarity']) / len(results['structure_similarity']),
        'detailed': results
    }
    return metrics
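
# Illustrative shape of the returned dictionary (values are placeholders):
#
#   {
#       'total_samples': 2,
#       'exact_match_rate': 0.5,
#       'avg_token_accuracy': 0.9,
#       'avg_keyword_accuracy': 1.0,
#       'avg_structure_similarity': 1.0,
#       'detailed': {'exact_match': [1, 0], 'token_accuracy': [1.0, 0.8], ...}
#   }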

# =============================================================================
# VISUALIZATIONS
# =============================================================================
def create_visualizations(metrics):
    """Create evaluation charts."""
    setup_directories()
    plt.style.use('seaborn-v0_8-whitegrid')

    # 1. Metrics Overview
    fig, ax = plt.subplots(figsize=(10, 6))
    names = ['Exact Match', 'Token Acc', 'Keyword Acc', 'Structure Sim']
    values = [
        metrics['exact_match_rate'] * 100,
        metrics['avg_token_accuracy'] * 100,
        metrics['avg_keyword_accuracy'] * 100,
        metrics['avg_structure_similarity'] * 100
    ]
    colors = ['#3498db', '#2ecc71', '#9b59b6', '#e74c3c']
    bars = ax.bar(names, values, color=colors, edgecolor='black')
    for bar, val in zip(bars, values):
        ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 1,
                f'{val:.1f}%', ha='center', fontweight='bold')
    ax.set_ylabel('Score (%)')
    ax.set_title('Model Evaluation Metrics', fontsize=14, fontweight='bold')
    ax.set_ylim(0, 110)
    plt.tight_layout()
    plt.savefig(f'{VIZ_DIR}/01_metrics_overview.png', dpi=150)
    plt.close()
    print(f" Saved: {VIZ_DIR}/01_metrics_overview.png")

    # 2. Token Accuracy Distribution
    fig, ax = plt.subplots(figsize=(10, 6))
    token_acc = metrics['detailed']['token_accuracy']
    ax.hist(token_acc, bins=20, color='#2ecc71', edgecolor='black', alpha=0.7)
    ax.axvline(sum(token_acc) / len(token_acc), color='red', linestyle='--',
               label=f"Mean: {sum(token_acc) / len(token_acc):.2f}")
    ax.set_xlabel('Token Accuracy')
    ax.set_ylabel('Frequency')
    ax.set_title('Token Accuracy Distribution', fontsize=14, fontweight='bold')
    ax.legend()
    plt.tight_layout()
    plt.savefig(f'{VIZ_DIR}/02_token_accuracy_dist.png', dpi=150)
    plt.close()
    print(f" Saved: {VIZ_DIR}/02_token_accuracy_dist.png")

    # 3. Keyword Accuracy Distribution
    fig, ax = plt.subplots(figsize=(10, 6))
    kw_acc = metrics['detailed']['keyword_accuracy']
    ax.hist(kw_acc, bins=20, color='#9b59b6', edgecolor='black', alpha=0.7)
    ax.axvline(sum(kw_acc) / len(kw_acc), color='red', linestyle='--',
               label=f"Mean: {sum(kw_acc) / len(kw_acc):.2f}")
    ax.set_xlabel('Keyword Accuracy')
    ax.set_ylabel('Frequency')
    ax.set_title('Keyword Accuracy Distribution', fontsize=14, fontweight='bold')
    ax.legend()
    plt.tight_layout()
    plt.savefig(f'{VIZ_DIR}/03_keyword_accuracy_dist.png', dpi=150)
    plt.close()
    print(f" Saved: {VIZ_DIR}/03_keyword_accuracy_dist.png")

# =============================================================================
# REPORT GENERATION
# =============================================================================
def generate_report(metrics):
    """Generate evaluation report."""
    report = f"""# Fine-Tuning Evaluation Report

**Generated:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

## Metrics Summary

| Metric | Score |
|--------|-------|
| Samples Evaluated | {metrics['total_samples']} |
| Exact Match Rate | {metrics['exact_match_rate']*100:.2f}% |
| Token Accuracy | {metrics['avg_token_accuracy']*100:.2f}% |
| Keyword Accuracy | {metrics['avg_keyword_accuracy']*100:.2f}% |
| Structure Similarity | {metrics['avg_structure_similarity']*100:.2f}% |

## Metrics Explanation

- **Exact Match**: Predictions identical to ground truth
- **Token Accuracy**: Word overlap between prediction and expected
- **Keyword Accuracy**: SQL keywords (SELECT, WHERE, etc.) match
- **Structure Similarity**: Query structure (clauses used) match

## Visualizations

- `01_metrics_overview.png` - All metrics bar chart
- `02_token_accuracy_dist.png` - Token accuracy histogram
- `03_keyword_accuracy_dist.png` - Keyword accuracy histogram
"""
    with open(f'{RESULTS_DIR}/evaluation_report.md', 'w') as f:
        f.write(report)
    print(f" Saved: {RESULTS_DIR}/evaluation_report.md")

    # Save JSON
    json_metrics = {k: v for k, v in metrics.items() if k != 'detailed'}
    with open(f'{RESULTS_DIR}/evaluation_results.json', 'w') as f:
        json.dump(json_metrics, f, indent=2)
    print(f" Saved: {RESULTS_DIR}/evaluation_results.json")

# =============================================================================
# MAIN EVALUATION
# =============================================================================
def run_evaluation():
    """Run full evaluation."""
    print("=" * 60)
    print("EVALUATING FINE-TUNED MODEL")
    print("=" * 60)
    setup_directories()

    # Load test data
    print("\n[1/4] Loading test data...")
    test_file = f"{OUTPUT_DIR}/test.jsonl"
    if not os.path.exists(test_file):
        print("ERROR: Run prepare_data.py first!")
        return None

    test_data = []
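    # Each line of test.jsonl is expected to be a JSON object carrying at least
    # the keys used below; an illustrative record:
    #   {"question": "How many users signed up in 2023?",
    #    "sql": "SELECT COUNT(*) FROM users WHERE signup_year = 2023"}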
    with open(test_file) as f:
        for line in f:
            test_data.append(json.loads(line))
    test_data = test_data[:NUM_EVAL_SAMPLES]
    print(f" Loaded {len(test_data)} samples")

    # Generate predictions
    print("\n[2/4] Generating predictions...")
    try:
        import sys
        sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
        from finetuning.inference import SQLGenerator

        generator = SQLGenerator()
        predictions = []
        ground_truth = []
        for i, item in enumerate(test_data):
            pred = generator.generate(item['question'])
            predictions.append(pred)
            ground_truth.append(item['sql'])
            if (i + 1) % 10 == 0:
                print(f" Progress: {i+1}/{len(test_data)}")
    except Exception as e:
        print(f" Error loading model: {e}")
        print(" Using ground truth as predictions (for testing metrics)")
        predictions = [item['sql'] for item in test_data]
        ground_truth = [item['sql'] for item in test_data]

    # Calculate metrics
    print("\n[3/4] Calculating metrics...")
    metrics = evaluate_predictions(predictions, ground_truth)
    print(f" Exact Match: {metrics['exact_match_rate']*100:.2f}%")
    print(f" Token Accuracy: {metrics['avg_token_accuracy']*100:.2f}%")
    print(f" Keyword Accuracy: {metrics['avg_keyword_accuracy']*100:.2f}%")
    print(f" Structure Sim: {metrics['avg_structure_similarity']*100:.2f}%")

    # Generate outputs
    print("\n[4/4] Generating outputs...")
    create_visualizations(metrics)
    generate_report(metrics)

    print("\n" + "=" * 60)
    print("EVALUATION COMPLETE")
    print("=" * 60)
    return metrics


if __name__ == "__main__":
    run_evaluation()