moheesh committed
Commit f29ea6c · 1 Parent(s): e9aa12a

got all my code

Files changed (44)
  1. Dockerfile +1 -1
  2. src/.env.example +32 -0
  3. src/.gitignore +100 -0
  4. src/README.md +209 -0
  5. src/app.py +497 -0
  6. src/config.py +98 -0
  7. src/finetuning/__init__.py +0 -0
  8. src/finetuning/evaluate.py +293 -0
  9. src/finetuning/inference.py +168 -0
  10. src/finetuning/prepare_data.py +149 -0
  11. src/finetuning/train.py +218 -0
  12. src/outputs/finetuning/data_stats.json +7 -0
  13. src/outputs/finetuning/results/evaluation_report.md +26 -0
  14. src/outputs/finetuning/results/evaluation_results.json +7 -0
  15. src/outputs/finetuning/test.jsonl +100 -0
  16. src/outputs/finetuning/train.jsonl +100 -0
  17. src/outputs/finetuning/val.jsonl +100 -0
  18. src/outputs/finetuning/visualizations/01_metrics_overview.png +0 -0
  19. src/outputs/finetuning/visualizations/02_token_accuracy_dist.png +0 -0
  20. src/outputs/finetuning/visualizations/03_keyword_accuracy_dist.png +0 -0
  21. src/outputs/finetuning/visualizations/04_training_loss.png +0 -0
  22. src/outputs/rag/reports/knowledge_base_report.md +46 -0
  23. src/outputs/rag/stats/knowledge_base_stats.json +22 -0
  24. src/outputs/synthetic/reports/synthetic_report.md +47 -0
  25. src/outputs/synthetic/stats/statistics.json +24 -0
  26. src/outputs/synthetic/visualizations/01_size_comparison.png +0 -0
  27. src/outputs/synthetic/visualizations/02_length_distribution.png +0 -0
  28. src/outputs/synthetic/visualizations/03_diversity_distribution.png +0 -0
  29. src/pipeline/integrated.py +584 -0
  30. src/prompts/__init__.py +0 -0
  31. src/prompts/prompt_builder.py +440 -0
  32. src/prompts/system_prompts.py +162 -0
  33. src/rag/__init__.py +0 -0
  34. src/rag/embeddings.py +87 -0
  35. src/rag/knowledge_base.py +415 -0
  36. src/rag/retriever.py +234 -0
  37. src/requirements.txt +23 -0
  38. src/streamlit_app.py +0 -40
  39. src/synthetic/__init__.py +0 -0
  40. src/synthetic/generate_data.py +401 -0
  41. src/synthetic/synonyms.py +149 -0
  42. src/tests/test_finetuned.py +0 -0
  43. src/tests/test_rag.py +0 -0
  44. src/tests/test_synthetic.py +0 -0
Dockerfile CHANGED
@@ -17,4 +17,4 @@ EXPOSE 8501
 
  HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
 
- ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
+ ENTRYPOINT ["streamlit", "run", "src/app.py", "--server.port=8501", "--server.address=0.0.0.0"]
src/.env.example ADDED
@@ -0,0 +1,32 @@
1
+ # =============================================================================
2
+ # GEMINI API KEYS (Required)
3
+ # =============================================================================
4
+ GEMINI_API_KEY=your-primary-gemini-key
5
+ GEMINI_API_KEY_FALLBACK_1=your-fallback-key-1
6
+ GEMINI_API_KEY_FALLBACK_2=your-fallback-key-2
7
+
8
+ # =============================================================================
9
+ # GEMINI MODELS
10
+ # =============================================================================
11
+ GEMINI_MODEL=gemini-2.5-flash
12
+ GEMINI_MODEL_FALLBACK_1=gemini-2.5-flash-lite
13
+
14
+ # =============================================================================
15
+ # HUGGINGFACE (Required for cloud deployment, optional for local)
16
+ # =============================================================================
17
+ HF_TOKEN=your-huggingface-token
18
+ HF_MODEL_ID=your-username/sql-tinyllama-lora
19
+ HF_CHROMADB_ID=your-username/sql-chromadb
20
+
21
+ # =============================================================================
22
+ # HOW IT WORKS:
23
+ # =============================================================================
24
+ # LOCAL RUN:
25
+ # - If outputs/finetuning/checkpoints/final exists → uses local model
26
+ # - If chromadb_data exists → uses local ChromaDB
27
+ #
28
+ # CLOUD RUN (Streamlit):
29
+ # - If HF_MODEL_ID set → downloads model from HuggingFace
30
+ # - If HF_CHROMADB_ID set → downloads ChromaDB from HuggingFace
31
+ # - Falls back to building ChromaDB from data/ folder if needed
32
+ # =============================================================================
src/.gitignore ADDED
@@ -0,0 +1,100 @@
1
+ # =============================================================================
2
+ # ENVIRONMENT & SECRETS
3
+ # =============================================================================
4
+ .env
5
+ .env.local
6
+ .env.production
7
+
8
+ # =============================================================================
9
+ # VIRTUAL ENVIRONMENT
10
+ # =============================================================================
11
+ .venv/
12
+ venv/
13
+ env/
14
+ ENV/
15
+ data/
16
+ data/synthetic.csv
17
+ # =============================================================================
18
+ # PYTHON
19
+ # =============================================================================
20
+ __pycache__/
21
+ *.py[cod]
22
+ *$py.class
23
+ *.so
24
+ .Python
25
+ build/
26
+ develop-eggs/
27
+ dist/
28
+ downloads/
29
+ eggs/
30
+ .eggs/
31
+ lib/
32
+ lib64/
33
+ parts/
34
+ sdist/
35
+ var/
36
+ wheels/
37
+ *.egg-info/
38
+ .installed.cfg
39
+ *.egg
40
+
41
+ # =============================================================================
42
+ # MODEL FILES (Upload to HuggingFace instead)
43
+ # =============================================================================
44
+ outputs/finetuning/checkpoints/
45
+ *.bin
46
+ *.pt
47
+ *.pth
48
+ *.safetensors
49
+ *.ckpt
50
+
51
+ # =============================================================================
52
+ # CHROMADB (Upload to HuggingFace instead)
53
+ # =============================================================================
54
+ chromadb_data/
55
+
56
+ # =============================================================================
57
+ # LOGS & OUTPUTS
58
+ # =============================================================================
59
+ *.log
60
+ outputs/*/logs/
61
+ outputs/pipeline/logs/
62
+ outputs/prompts/logs/
63
+
64
+ # =============================================================================
65
+ # IDE
66
+ # =============================================================================
67
+ .vscode/
68
+ .idea/
69
+ *.swp
70
+ *.swo
71
+ *~
72
+
73
+ # =============================================================================
74
+ # OS
75
+ # =============================================================================
76
+ .DS_Store
77
+ .DS_Store?
78
+ ._*
79
+ .Spotlight-V100
80
+ .Trashes
81
+ ehthumbs.db
82
+ Thumbs.db
83
+
84
+ # =============================================================================
85
+ # JUPYTER
86
+ # =============================================================================
87
+ .ipynb_checkpoints/
88
+ *.ipynb
89
+
90
+ # =============================================================================
91
+ # KEEP THESE (don't ignore)
92
+ # =============================================================================
93
+ # !data/
94
+ # !data/*.csv
95
+ # !docs/
96
+ # !docs/index.html
97
+ # !*.py
98
+ # !requirements.txt
99
+ # !README.md
100
+ # !.env.example
src/README.md ADDED
@@ -0,0 +1,209 @@
1
+ # ⚡ Prompt to SQL using RAG + LLM
2
+
3
+ AI-powered Natural Language to SQL conversion using RAG, Fine-tuned LLM, and Gemini Enhancement.
4
+
5
+ ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
6
+ ![Streamlit](https://img.shields.io/badge/Streamlit-1.28+-red.svg)
7
+ ![License](https://img.shields.io/badge/License-MIT-green.svg)
8
+
9
+ ## 🌐 Live Demo
10
+
11
+ - **🚀 Web App:** [Streamlit App](https://your-app.streamlit.app)
12
+ - **📄 Project Page:** [GitHub Pages](https://moheesh.github.io/Prompt_to_SQL_using_RAG_LLM)
13
+
14
+ ## ✨ Features
15
+
16
+ | Feature | Description |
17
+ |---------|-------------|
18
+ | 🔍 **RAG Retrieval** | 80,000+ SQL examples in ChromaDB vector store |
19
+ | 🤖 **Fine-tuned LLM** | TinyLlama with LoRA for SQL generation |
20
+ | ✨ **Gemini Enhancement** | Query refinement, validation & explanation |
21
+ | 📝 **Prompt Engineering** | Context management, edge cases, query analysis |
22
+ | 📦 **Synthetic Data** | Data augmentation with 5 techniques |
23
+ | 🔄 **Auto Fallback** | Multiple API keys & models for reliability (see the sketch below) |
24
+
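The "Auto Fallback" feature above is implemented in `pipeline/integrated.py`, which is not reproduced in this view. As a rough illustration only — assuming the `google-generativeai` package and the `GEMINI_KEYS` / `GEMINI_MODELS` lists that `src/config.py` builds from the environment variables, and run from the `src/` directory — key/model rotation can look like this:

```python
import google.generativeai as genai
from config import GEMINI_KEYS, GEMINI_MODELS

def generate_with_fallback(prompt: str) -> str | None:
    """Try every configured API key and model until one call succeeds."""
    for key in GEMINI_KEYS:
        genai.configure(api_key=key)
        for model_name in GEMINI_MODELS:
            try:
                model = genai.GenerativeModel(model_name)
                return model.generate_content(prompt).text
            except Exception:
                continue  # quota or availability error: try the next model/key
    return None  # every key/model combination failed
```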
25
+ ## 🔄 Pipeline Architecture
26
+
27
+ ```
28
+ ┌─────────────────────┐
29
+ │ Synthetic Data │ (Training augmentation)
30
+ └─────────┬───────────┘
31
+
32
+ ┌─────────────────────┐
33
+ │ Fine-tuned Model │ (LoRA training on TinyLlama)
34
+ └─────────┬───────────┘
35
+
36
+ ┌─────────────────────┐
37
+ │ User Question │ (Natural language input)
38
+ └─────────┬───────────┘
39
+
40
+ ┌─────────────────────┐
41
+ │ RAG Retrieval │ (Similar examples from ChromaDB)
42
+ └─────────┬───────────┘
43
+
44
+ ┌─────────────────────┐
45
+ │ Prompt Engineering │ (Context + query formatting)
46
+ └─────────┬───────────┘
47
+
48
+ ┌─────────────────────┐
49
+ │ Fine-tuned Model │ (SQL generation)
50
+ └─────────┬───────────┘
51
+
52
+ ┌─────────────────────┐
53
+ │ Gemini Enhancement │ (Refine + explain)
54
+ └─────────┬───────────┘
55
+
56
+ ┌─────────────────────┐
57
+ │ Final SQL │ (Optimized output)
58
+ └─────────────────────┘
59
+ ```
60
+
61
+ ## 📁 Project Structure
62
+
63
+ ```
64
+ Prompt_to_SQL_using_RAG_LLM/
65
+ ├── app.py # Streamlit UI
66
+ ├── config.py # Central configuration
67
+ ├── requirements.txt # Dependencies
68
+
69
+ ├── pipeline/
70
+ │ └── integrated.py # Main pipeline (RAG + Model + Gemini)
71
+
72
+ ├── finetuning/
73
+ │ ├── prepare_data.py # Data preparation
74
+ │ ├── train.py # LoRA fine-tuning
75
+ │ ├── evaluate.py # Model evaluation
76
+ │ └── inference.py # SQL generation
77
+
78
+ ├── rag/
79
+ │ ├── embeddings.py # Sentence transformers
80
+ │ ├── knowledge_base.py # ChromaDB builder
81
+ │ └── retriever.py # LangChain retriever
82
+
83
+ ├── prompts/
84
+ │ ├── prompt_builder.py # Context management
85
+ │ └── system_prompts.py # Prompt templates
86
+
87
+ ├── synthetic/
88
+ │ ├── generate_data.py # Data augmentation
89
+ │ └── synonyms.py # Synonym dictionary
90
+
91
+ ├── data/
92
+ │ ├── train.csv
93
+ │ ├── validation.csv
94
+ │ └── test.csv
95
+
96
+ └── docs/
97
+ └── index.html # GitHub Pages
98
+ ```
99
+
100
+ ## 🛠️ Setup
101
+
102
+ ### 1. Clone the Repository
103
+
104
+ ```bash
105
+ git clone https://github.com/moheesh/Prompt_to_SQL_using_RAG_LLM.git
106
+ cd Prompt_to_SQL_using_RAG_LLM
107
+ ```
108
+
109
+ ### 2. Create Virtual Environment
110
+
111
+ ```bash
112
+ python -m venv .venv
113
+
114
+ # Windows
115
+ .venv\Scripts\activate
116
+
117
+ # Mac/Linux
118
+ source .venv/bin/activate
119
+ ```
120
+
121
+ ### 3. Install Dependencies
122
+
123
+ ```bash
124
+ pip install -r requirements.txt
125
+ ```
126
+
127
+ ### 4. Configure Environment
128
+
129
+ Create a `.env` file:
130
+
131
+ ```env
132
+ # Gemini API
133
+ GEMINI_API_KEY=your-primary-key
134
+ GEMINI_MODEL=gemini-2.5-flash
135
+
136
+ # HuggingFace (for cloud deployment)
137
+ HF_TOKEN=your-hf-token
138
+ HF_MODEL_ID=your-username/sql-tinyllama-lora
139
+ HF_CHROMADB_ID=your-username/sql-chromadb
140
+ ```
141
+
142
+ ### 5. Build Knowledge Base (First Time)
143
+
144
+ ```bash
145
+ python rag/knowledge_base.py
146
+ ```
147
+
148
+ ### 6. Run the App
149
+
150
+ ```bash
151
+ streamlit run app.py
152
+ ```
153
+
154
+ ## 🚀 Deployment
155
+
156
+ ### Upload to HuggingFace
157
+
158
+ ```bash
159
+ # Login
160
+ huggingface-cli login
161
+
162
+ # Upload model
163
+ python -c "from huggingface_hub import HfApi; api = HfApi(); api.upload_folder(folder_path='outputs/finetuning/checkpoints/final', repo_id='moheesh/sql-tinyllama-lora', repo_type='model', create_repo=True)"
164
+
165
+ # Upload ChromaDB
166
+ python -c "from huggingface_hub import HfApi; api = HfApi(); api.upload_folder(folder_path='chromadb_data', repo_id='moheesh/sql-chromadb', repo_type='dataset', create_repo=True)"
167
+ ```
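The two `python -c` one-liners above are equivalent to the short script below, with repo creation done as a separate explicit step (`create_repo(..., exist_ok=True)`) so reruns do not fail:

```python
from huggingface_hub import HfApi

api = HfApi()  # picks up the token from `huggingface-cli login` or HF_TOKEN

# Fine-tuned LoRA checkpoint -> model repo
api.create_repo("moheesh/sql-tinyllama-lora", repo_type="model", exist_ok=True)
api.upload_folder(
    folder_path="outputs/finetuning/checkpoints/final",
    repo_id="moheesh/sql-tinyllama-lora",
    repo_type="model",
)

# ChromaDB vector store -> dataset repo
api.create_repo("moheesh/sql-chromadb", repo_type="dataset", exist_ok=True)
api.upload_folder(
    folder_path="chromadb_data",
    repo_id="moheesh/sql-chromadb",
    repo_type="dataset",
)
```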
168
+
169
+ ### Deploy to Streamlit Cloud
170
+
171
+ 1. Push code to GitHub
172
+ 2. Go to [share.streamlit.io](https://share.streamlit.io)
173
+ 3. Connect your repo
174
+ 4. Add secrets (same as `.env`)
175
+ 5. Deploy!
176
+
177
+ ## 🛠️ Tech Stack
178
+
179
+ | Component | Technology |
180
+ |-----------|------------|
181
+ | LLM | TinyLlama + LoRA |
182
+ | Vector DB | ChromaDB |
183
+ | Embeddings | all-MiniLM-L6-v2 |
184
+ | Enhancement | Gemini API |
185
+ | Framework | LangChain |
186
+ | UI | Streamlit |
187
+
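The retrieval side of this stack lives in `rag/embeddings.py` and `rag/retriever.py`. The sketch below is not that code, only an illustration of how the pieces in the table typically connect; it assumes the `chromadb_data` store and the `sql_knowledge` collection named in `config.py`:

```python
import chromadb
from sentence_transformers import SentenceTransformer

# Embed the user question with the same model used to build the store
embedder = SentenceTransformer("all-MiniLM-L6-v2")
query_vec = embedder.encode("Find all employees with salary above 50000").tolist()

# Query the persisted ChromaDB collection for the most similar examples
client = chromadb.PersistentClient(path="chromadb_data")
collection = client.get_collection("sql_knowledge")
hits = collection.query(query_embeddings=[query_vec], n_results=3)

for doc, meta in zip(hits["documents"][0], hits["metadatas"][0]):
    print(doc, meta)
```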
188
+ ## 📊 Evaluation Metrics
189
+
190
+ | Metric | Score |
191
+ |--------|-------|
192
+ | Exact Match | XX% |
193
+ | Token Accuracy | XX% |
194
+ | Keyword Accuracy | XX% |
195
+ | Structure Similarity | XX% |
196
+
197
+ ## 🎓 Course
198
+
199
+ **INFO7375** - Northeastern University
200
+
201
+ ## 👤 Author
202
+
203
+ **Your Name**
204
+ - GitHub: [@moheesh](https://github.com/moheesh)
205
+ - LinkedIn: [LinkedIn](https://linkedin.com/in/moheesh-k-a-a95306169)
206
+
207
+ ## 📝 License
208
+
209
+ MIT License - see [LICENSE](LICENSE) for details.
src/app.py ADDED
@@ -0,0 +1,497 @@
1
+ """
2
+ Streamlit App for SQL Learning Assistant
3
+ Integrates: RAG + Fine-tuned Model + Gemini Enhancement
4
+ """
5
+
6
+ import streamlit as st
7
+ import os
8
+ import sys
9
+ from dotenv import load_dotenv
10
+
11
+ # Load environment variables FIRST
12
+ load_dotenv()
13
+
14
+ # Add this file's directory to the import path
15
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
16
+
17
+ # =============================================================================
18
+ # PAGE CONFIG - MUST BE FIRST STREAMLIT COMMAND
19
+ # =============================================================================
20
+
21
+ st.set_page_config(
22
+ page_title="SQL Learning Assistant",
23
+ page_icon="⚡",
24
+ layout="wide",
25
+ initial_sidebar_state="expanded"
26
+ )
27
+
28
+ # =============================================================================
29
+ # CACHED LOADERS - Load on-demand, cache forever
30
+ # =============================================================================
31
+
32
+ @st.cache_resource(show_spinner=False)
33
+ def load_chromadb():
34
+ """Download ChromaDB from HuggingFace if needed."""
35
+ chromadb_path = "chromadb_data"
36
+ hf_chromadb_id = os.getenv("HF_CHROMADB_ID", None)
37
+
38
+ has_files = False
39
+ if os.path.exists(chromadb_path):
40
+ local_files = os.listdir(chromadb_path) if os.path.isdir(chromadb_path) else []
41
+ has_files = any('chroma' in f.lower() or 'sqlite' in f.lower() for f in local_files) or len(local_files) > 2
42
+
43
+ if not has_files and hf_chromadb_id:
44
+ from huggingface_hub import snapshot_download
45
+ os.makedirs(chromadb_path, exist_ok=True)
46
+ snapshot_download(repo_id=hf_chromadb_id, repo_type="dataset", local_dir=chromadb_path)
47
+
48
+ return chromadb_path
49
+
50
+ @st.cache_resource(show_spinner=False)
51
+ def load_retriever():
52
+ """Load the RAG retriever."""
53
+ load_chromadb()
54
+ from rag.retriever import SQLRetriever
55
+ return SQLRetriever()
56
+
57
+ @st.cache_resource(show_spinner=False)
58
+ def load_model():
59
+ """Load the fine-tuned model."""
60
+ from finetuning.inference import SQLGenerator
61
+ return SQLGenerator()
62
+
63
+ @st.cache_resource(show_spinner=False)
64
+ def load_prompt_builder():
65
+ """Load prompt builder."""
66
+ from prompts.prompt_builder import PromptBuilder
67
+ return PromptBuilder()
68
+
69
+ @st.cache_resource(show_spinner=False)
70
+ def load_gemini():
71
+ """Load Gemini client."""
72
+ from pipeline.integrated import GeminiClient, GEMINI_KEYS
73
+ if GEMINI_KEYS:
74
+ return GeminiClient()
75
+ return None
76
+
77
+ # =============================================================================
78
+ # HELPER FUNCTION TO RUN PIPELINE
79
+ # =============================================================================
80
+
81
+ def run_pipeline(question, num_examples=3):
82
+ """Run the full pipeline - loads components on first use."""
83
+ result = {
84
+ 'question': question,
85
+ 'success': False,
86
+ 'steps': {}
87
+ }
88
+
89
+ # Step 1: RAG
90
+ rag_context = ""
91
+ examples = []
92
+ try:
93
+ with st.spinner("🔍 Loading RAG system..."):
94
+ retriever = load_retriever()
95
+ if retriever:
96
+ examples = retriever.retrieve(question, top_k=num_examples)
97
+ rag_context = "Similar SQL examples:\n\n"
98
+ for i, r in enumerate(examples, 1):
99
+ rag_context += f"Example {i}:\nQuestion: {r['question']}\nSQL: {r['sql']}\n\n"
100
+ except Exception as e:
101
+ st.warning(f"RAG error: {e}")
102
+
103
+ result['steps']['rag'] = {'examples': examples, 'num_examples': len(examples), 'context': rag_context}
104
+
105
+ # Step 2: Prompt
106
+ prompt = ""
107
+ try:
108
+ prompt_builder = load_prompt_builder()
109
+ if prompt_builder:
110
+ prompt_result = prompt_builder.build_prompt(question=question, rag_context=rag_context)
111
+ if prompt_result['success']:
112
+ prompt = prompt_result['prompt']
113
+ except:
114
+ pass
115
+ if not prompt:
116
+ prompt = f"{rag_context}\nQuestion: {question}\n\nSQL:"
117
+
118
+ result['steps']['prompt'] = {'prompt': prompt, 'length': len(prompt)}
119
+
120
+ # Step 3: Fine-tuned Model
121
+ finetuned_sql = None
122
+ try:
123
+ with st.spinner("🤖 Loading AI model..."):
124
+ model = load_model()
125
+ if model:
126
+ finetuned_sql = model.generate(question, rag_context)
127
+ except Exception as e:
128
+ st.warning(f"Model error: {e}")
129
+
130
+ result['steps']['finetuned'] = {'sql': finetuned_sql, 'error': None if finetuned_sql else 'Model not available'}
131
+
132
+ if not finetuned_sql:
133
+ return result
134
+
135
+ # Step 4: Gemini Enhancement
136
+ enhanced_sql = finetuned_sql
137
+ try:
138
+ gemini = load_gemini()
139
+ if gemini:
140
+ enhance_prompt = f"""You are an SQL expert. Review and enhance this SQL query.
141
+
142
+ Original Question: {question}
143
+
144
+ Generated SQL (by a smaller model):
145
+ {finetuned_sql}
146
+
147
+ Rules:
148
+ - If the SQL is correct, return it unchanged
149
+ - If it needs fixes, return the corrected version
150
+ - Return ONLY the SQL query, no explanations
151
+
152
+ Enhanced SQL:"""
153
+ response, error = gemini.generate(enhance_prompt)
154
+ if response and not error:
155
+ enhanced_sql = response.strip()
156
+ if enhanced_sql.startswith("```"):
157
+ lines = enhanced_sql.split("\n")
158
+ enhanced_sql = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
159
+ if enhanced_sql.lower().startswith("sql"):
160
+ enhanced_sql = enhanced_sql[3:].strip()
161
+ except Exception as e:
162
+ st.warning(f"Gemini enhance error: {e}")
163
+
164
+ result['steps']['gemini_enhance'] = {'sql': enhanced_sql, 'info': {'enhanced': enhanced_sql != finetuned_sql}}
165
+ result['final_sql'] = enhanced_sql
166
+
167
+ # Step 5: Explanation
168
+ explanation = ""
169
+ try:
170
+ gemini = load_gemini()
171
+ if gemini:
172
+ explain_prompt = f"Explain this SQL query in simple terms (2-3 sentences):\n\nSQL: {enhanced_sql}"
173
+ response, error = gemini.generate(explain_prompt)
174
+ if response and not error:
175
+ explanation = response.strip()
176
+ except:
177
+ pass
178
+
179
+ result['explanation'] = explanation
180
+ result['success'] = True
181
+
182
+ return result
183
+
184
+ # =============================================================================
185
+ # CUSTOM CSS
186
+ # =============================================================================
187
+
188
+ st.markdown("""
189
+ <style>
190
+ .stApp {
191
+ background: linear-gradient(135deg, #0f0f23 0%, #1a1a2e 50%, #16213e 100%);
192
+ }
193
+
194
+ .main-header {
195
+ font-size: 3rem;
196
+ font-weight: 800;
197
+ background: linear-gradient(120deg, #00d4ff, #7c3aed, #f472b6);
198
+ -webkit-background-clip: text;
199
+ -webkit-text-fill-color: transparent;
200
+ background-clip: text;
201
+ text-align: center;
202
+ margin-bottom: 0.5rem;
203
+ }
204
+
205
+ .sub-header {
206
+ font-size: 1.1rem;
207
+ color: #94a3b8;
208
+ text-align: center;
209
+ margin-bottom: 2rem;
210
+ }
211
+
212
+ .stButton > button {
213
+ background: linear-gradient(135deg, #1e293b 0%, #334155 100%);
214
+ color: #e2e8f0;
215
+ border: 1px solid #475569;
216
+ border-radius: 10px;
217
+ transition: all 0.3s ease;
218
+ }
219
+
220
+ .stButton > button:hover {
221
+ background: linear-gradient(135deg, #3b82f6 0%, #8b5cf6 100%);
222
+ border-color: #60a5fa;
223
+ transform: translateY(-2px);
224
+ }
225
+
226
+ .stTextInput > div > div > input {
227
+ background: rgba(30, 41, 59, 0.8);
228
+ border: 1px solid #475569;
229
+ border-radius: 12px;
230
+ color: #f1f5f9;
231
+ }
232
+
233
+ [data-testid="stSidebar"] {
234
+ background: linear-gradient(180deg, #0f172a 0%, #1e293b 100%);
235
+ }
236
+
237
+ .pipeline-box {
238
+ background: rgba(30, 41, 59, 0.6);
239
+ border: 1px solid #475569;
240
+ border-radius: 8px;
241
+ padding: 0.5rem 1rem;
242
+ margin: 0.25rem 0;
243
+ font-size: 0.85rem;
244
+ text-align: center;
245
+ }
246
+
247
+ .pipeline-arrow {
248
+ color: #3b82f6;
249
+ text-align: center;
250
+ font-size: 1.2rem;
251
+ }
252
+ </style>
253
+ """, unsafe_allow_html=True)
254
+
255
+ # =============================================================================
256
+ # HEADER
257
+ # =============================================================================
258
+
259
+ st.markdown('<p class="main-header">⚡ SQL Learning Assistant</p>', unsafe_allow_html=True)
260
+ st.markdown('<p class="sub-header">Transform Natural Language into SQL using AI-Powered Pipeline</p>', unsafe_allow_html=True)
261
+
262
+ # =============================================================================
263
+ # SIDEBAR
264
+ # =============================================================================
265
+
266
+ with st.sidebar:
267
+ st.markdown("## ⚙️ Configuration")
268
+ st.markdown("---")
269
+
270
+ st.markdown("### 🎯 RAG Settings")
271
+ num_examples = st.slider("Similar examples to retrieve", min_value=1, max_value=5, value=3)
272
+
273
+ st.markdown("---")
274
+
275
+ st.markdown("### 📊 System Status")
276
+ col1, col2 = st.columns(2)
277
+ with col1:
278
+ st.markdown("✅ **RAG**")
279
+ st.markdown("✅ **Model**")
280
+ with col2:
281
+ st.markdown("✅ **Prompts**")
282
+ if os.getenv("GEMINI_API_KEY"):
283
+ st.markdown("✅ **Gemini**")
284
+ else:
285
+ st.markdown("❌ **Gemini**")
286
+
287
+ st.markdown("---")
288
+
289
+ st.markdown("### 🔄 Pipeline Flow")
290
+ pipeline_steps = [
291
+ ("📦", "Synthetic Data"),
292
+ ("🎓", "Fine-tuned Model"),
293
+ ("❓", "User Question"),
294
+ ("🔍", "RAG Retrieval"),
295
+ ("📝", "Prompt Engineering"),
296
+ ("🤖", "Model Inference"),
297
+ ("✨", "Gemini Enhancement"),
298
+ ("✅", "Final Output"),
299
+ ]
300
+
301
+ for i, (icon, title) in enumerate(pipeline_steps):
302
+ st.markdown(f'<div class="pipeline-box">{icon} <strong>{title}</strong></div>', unsafe_allow_html=True)
303
+ if i < len(pipeline_steps) - 1:
304
+ st.markdown('<p class="pipeline-arrow">↓</p>', unsafe_allow_html=True)
305
+
306
+ st.markdown("---")
307
+ st.markdown("### 📚 About")
308
+ st.markdown("**Course:** INFO7375")
309
+
310
+ # =============================================================================
311
+ # MAIN CONTENT
312
+ # =============================================================================
313
+
314
+ if "messages" not in st.session_state:
315
+ st.session_state.messages = []
316
+
317
+ if "results_history" not in st.session_state:
318
+ st.session_state.results_history = []
319
+
320
+ if "input_text" not in st.session_state:
321
+ st.session_state.input_text = ""
322
+
323
+ # =============================================================================
324
+ # EXAMPLE QUESTIONS
325
+ # =============================================================================
326
+
327
+ st.markdown("### 💡 Try an Example")
328
+
329
+ example_questions = [
330
+ ("👥 Employees", "Find all employees with salary above 50000"),
331
+ ("📊 Orders", "Count total orders by customer"),
332
+ ("🏆 Top Products", "Show top 5 products by revenue"),
333
+ ("📅 Recent", "List customers who placed orders in 2024"),
334
+ ("💰 Salary", "Calculate average salary by department"),
335
+ ]
336
+
337
+ cols = st.columns(5)
338
+ for i, (label, ex_question) in enumerate(example_questions):
339
+ with cols[i]:
340
+ if st.button(label, key=f"ex_{i}", use_container_width=True, help=ex_question):
341
+ st.session_state.input_text = ex_question
342
+
343
+ # =============================================================================
344
+ # INPUT AREA
345
+ # =============================================================================
346
+
347
+ st.markdown("### 🎤 Ask Your Question")
348
+
349
+ col1, col2 = st.columns([6, 1])
350
+
351
+ with col1:
352
+ question = st.text_input(
353
+ "Question",
354
+ placeholder="e.g., Find all employees with salary greater than 50000...",
355
+ label_visibility="collapsed",
356
+ key="input_text"
357
+ )
358
+
359
+ with col2:
360
+ submit_btn = st.button("🚀 Run", type="primary", use_container_width=True)
361
+
362
+ st.markdown("---")
363
+
364
+ # =============================================================================
365
+ # CHAT HISTORY
366
+ # =============================================================================
367
+
368
+ for i, message in enumerate(st.session_state.messages):
369
+ with st.chat_message(message["role"], avatar="🧑‍💻" if message["role"] == "user" else "🤖"):
370
+ st.markdown(message["content"])
371
+
372
+ if message["role"] == "assistant":
373
+ result_idx = i // 2
374
+ if result_idx < len(st.session_state.results_history):
375
+ result = st.session_state.results_history[result_idx]
376
+ if result and result.get('success'):
377
+ with st.expander("🔍 View Pipeline Details", expanded=False):
378
+ tab1, tab2, tab3, tab4 = st.tabs(["🔍 RAG", "📝 Prompt", "🤖 Fine-tuned", "✨ Gemini"])
379
+
380
+ with tab1:
381
+ examples = result['steps']['rag'].get('examples', [])
382
+ st.markdown(f"**Retrieved {len(examples)} examples**")
383
+ for j, ex in enumerate(examples, 1):
384
+ st.markdown(f"**Example {j}** | Score: `{ex.get('score', 0):.3f}`")
385
+ st.markdown(f"Q: {ex.get('question', 'N/A')}")
386
+ st.code(ex.get('sql', 'N/A'), language="sql")
387
+
388
+ with tab2:
389
+ st.markdown("**Constructed Prompt:**")
390
+ st.code(result['steps']['prompt'].get('prompt', 'N/A'), language="text")
391
+
392
+ with tab3:
393
+ st.markdown("**Fine-tuned Model Output:**")
394
+ st.code(result['steps']['finetuned'].get('sql', 'N/A'), language="sql")
395
+
396
+ with tab4:
397
+ if 'gemini_enhance' in result['steps']:
398
+ st.markdown("**Enhanced SQL:**")
399
+ st.code(result['steps']['gemini_enhance'].get('sql', 'N/A'), language="sql")
400
+
401
+ # =============================================================================
402
+ # PROCESS QUERY
403
+ # =============================================================================
404
+
405
+ if submit_btn and question:
406
+ st.session_state.messages.append({"role": "user", "content": question})
407
+
408
+ with st.chat_message("user", avatar="🧑‍💻"):
409
+ st.markdown(question)
410
+
411
+ with st.chat_message("assistant", avatar="🤖"):
412
+ with st.status("🔄 Processing your query...", expanded=True) as status:
413
+ st.write("🔍 Retrieving similar examples...")
414
+ st.write("📝 Building prompt...")
415
+ st.write("🤖 Generating SQL...")
416
+ st.write("✨ Enhancing with Gemini...")
417
+
418
+ result = run_pipeline(question=question, num_examples=num_examples)
419
+
420
+ status.update(label="✅ Complete!", state="complete", expanded=False)
421
+
422
+ st.session_state.results_history.append(result)
423
+
424
+ if result['success']:
425
+ st.markdown("### ✅ Generated SQL")
426
+ st.code(result['final_sql'], language="sql")
427
+
428
+ if 'gemini_enhance' in result['steps']:
429
+ original = result['steps']['finetuned'].get('sql', '')
430
+ enhanced = result['steps']['gemini_enhance'].get('sql', '')
431
+ if original != enhanced:
432
+ st.success("✨ Query optimized by Gemini!")
433
+ else:
434
+ st.info("✓ Query was already optimal")
435
+
436
+ if 'explanation' in result and result['explanation']:
437
+ if not result['explanation'].startswith("Explanation error"):
438
+ st.markdown("### 📖 Explanation")
439
+ st.info(result['explanation'])
440
+
441
+ with st.expander("🔍 View Pipeline Details", expanded=False):
442
+ tab1, tab2, tab3, tab4 = st.tabs(["🔍 RAG", "📝 Prompt", "🤖 Fine-tuned", "✨ Gemini"])
443
+
444
+ with tab1:
445
+ examples = result['steps']['rag'].get('examples', [])
446
+ st.markdown(f"**Retrieved {len(examples)} examples**")
447
+ for j, ex in enumerate(examples, 1):
448
+ st.markdown(f"**Example {j}** | Score: `{ex.get('score', 0):.3f}`")
449
+ st.markdown(f"Q: {ex.get('question', 'N/A')}")
450
+ st.code(ex.get('sql', 'N/A'), language="sql")
451
+
452
+ with tab2:
453
+ st.markdown("**Constructed Prompt:**")
454
+ st.code(result['steps']['prompt'].get('prompt', 'N/A'), language="text")
455
+
456
+ with tab3:
457
+ st.markdown("**Fine-tuned Model Output:**")
458
+ st.code(result['steps']['finetuned'].get('sql', 'N/A'), language="sql")
459
+
460
+ with tab4:
461
+ if 'gemini_enhance' in result['steps']:
462
+ st.markdown("**Enhanced SQL:**")
463
+ st.code(result['steps']['gemini_enhance'].get('sql', 'N/A'), language="sql")
464
+
465
+ response_text = f"**Generated SQL:**\n```sql\n{result['final_sql']}\n```"
466
+ if 'explanation' in result and not result['explanation'].startswith("Explanation error"):
467
+ response_text += f"\n\n**Explanation:** {result['explanation']}"
468
+
469
+ st.session_state.messages.append({"role": "assistant", "content": response_text})
470
+
471
+ else:
472
+ st.error("❌ Failed to generate SQL. Please try again.")
473
+ st.session_state.messages.append({"role": "assistant", "content": "❌ Failed to generate SQL."})
474
+
475
+ elif submit_btn and not question:
476
+ st.warning("⚠️ Please enter a question first!")
477
+
478
+ # =============================================================================
479
+ # FOOTER
480
+ # =============================================================================
481
+
482
+ st.markdown("---")
483
+
484
+ col1, col2, col3 = st.columns([1, 2, 1])
485
+
486
+ with col1:
487
+ if st.button("🗑️ Clear Chat", use_container_width=True):
488
+ st.session_state.messages = []
489
+ st.session_state.results_history = []
490
+ st.session_state.input_text = ""
491
+ st.rerun()
492
+
493
+ with col2:
494
+ st.markdown('<p style="text-align: center; color: #64748b;">Built with ❤️ using Streamlit • LangChain • Gemini</p>', unsafe_allow_html=True)
495
+
496
+ with col3:
497
+ st.markdown('<p style="text-align: right; color: #64748b;"><strong>INFO7375</strong></p>', unsafe_allow_html=True)
src/config.py ADDED
@@ -0,0 +1,98 @@
1
+ """
2
+ Central Configuration for SQL Learning Assistant
3
+ Handles local vs HuggingFace paths automatically
4
+ """
5
+
6
+ import os
7
+ from dotenv import load_dotenv
8
+
9
+ load_dotenv()
10
+
11
+ # =============================================================================
12
+ # HUGGINGFACE CONFIGURATION (for cloud deployment)
13
+ # Set these in .env or Streamlit secrets
14
+ # =============================================================================
15
+
16
+ HF_MODEL_ID = os.getenv("HF_MODEL_ID", None) # e.g., "username/sql-tinyllama-lora"
17
+ HF_CHROMADB_ID = os.getenv("HF_CHROMADB_ID", None) # e.g., "username/sql-chromadb"
18
+ HF_TOKEN = os.getenv("HF_TOKEN", None)
19
+
20
+ # =============================================================================
21
+ # LOCAL PATHS
22
+ # =============================================================================
23
+
24
+ LOCAL_MODEL_DIR = "outputs/finetuning/checkpoints/final"
25
+ LOCAL_CHROMADB_DIR = "chromadb_data"
26
+ LOCAL_DATA_DIR = "data"
27
+
28
+ # =============================================================================
29
+ # GEMINI CONFIGURATION
30
+ # =============================================================================
31
+
32
+ GEMINI_KEYS = [
33
+ os.getenv("GEMINI_API_KEY"),
34
+ os.getenv("GEMINI_API_KEY_FALLBACK_1"),
35
+ os.getenv("GEMINI_API_KEY_FALLBACK_2"),
36
+ ]
37
+ GEMINI_KEYS = [k for k in GEMINI_KEYS if k] # Remove None values
38
+
39
+ GEMINI_MODELS = [
40
+ os.getenv("GEMINI_MODEL", "gemini-2.5-flash"),
41
+ os.getenv("GEMINI_MODEL_FALLBACK_1"),
42
+ ]
43
+ GEMINI_MODELS = [m for m in GEMINI_MODELS if m] # Remove None values
44
+
45
+ # =============================================================================
46
+ # RAG CONFIGURATION
47
+ # =============================================================================
48
+
49
+ EMBEDDING_MODEL = "all-MiniLM-L6-v2"
50
+ COLLECTION_NAME = "sql_knowledge"
51
+
52
+ # =============================================================================
53
+ # HELPER FUNCTIONS
54
+ # =============================================================================
55
+
56
+ def is_local():
57
+ """Check if running locally (has local model/data)."""
58
+ return os.path.exists(LOCAL_MODEL_DIR) and os.path.exists(LOCAL_CHROMADB_DIR)
59
+
60
+ def is_cloud():
61
+ """Check if running in cloud (has HF config)."""
62
+ return HF_MODEL_ID is not None or HF_CHROMADB_ID is not None
63
+
64
+ def get_model_source():
65
+ """Get where model will be loaded from."""
66
+ if os.path.exists(LOCAL_MODEL_DIR) and os.listdir(LOCAL_MODEL_DIR):
67
+ return "local", LOCAL_MODEL_DIR
68
+ elif HF_MODEL_ID:
69
+ return "huggingface", HF_MODEL_ID
70
+ else:
71
+ return "base", "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
72
+
73
+ def get_chromadb_source():
74
+ """Get where ChromaDB will be loaded from."""
75
+ if os.path.exists(LOCAL_CHROMADB_DIR) and os.listdir(LOCAL_CHROMADB_DIR):
76
+ return "local", LOCAL_CHROMADB_DIR
77
+ elif HF_CHROMADB_ID:
78
+ return "huggingface", HF_CHROMADB_ID
79
+ else:
80
+ return "build", LOCAL_DATA_DIR
81
+
82
+ def print_config():
83
+ """Print current configuration."""
84
+ print("=" * 50)
85
+ print("CONFIGURATION")
86
+ print("=" * 50)
87
+
88
+ model_src, model_path = get_model_source()
89
+ chromadb_src, chromadb_path = get_chromadb_source()
90
+
91
+ print(f"Model: {model_src} → {model_path}")
92
+ print(f"ChromaDB: {chromadb_src} → {chromadb_path}")
93
+ print(f"Gemini Keys: {len(GEMINI_KEYS)} available")
94
+ print(f"Gemini Models: {GEMINI_MODELS}")
95
+ print("=" * 50)
96
+
97
+ if __name__ == "__main__":
98
+ print_config()
src/finetuning/__init__.py ADDED
File without changes
src/finetuning/evaluate.py ADDED
@@ -0,0 +1,293 @@
1
+ """
2
+ Evaluation Module for Fine-Tuned SQL Model
3
+ """
4
+
5
+ import os
6
+ import json
7
+ import matplotlib.pyplot as plt
8
+ from datetime import datetime
9
+ from collections import Counter
10
+
11
+ # =============================================================================
12
+ # CONFIGURATION
13
+ # =============================================================================
14
+
15
+ OUTPUT_DIR = "outputs/finetuning"
16
+ RESULTS_DIR = f"{OUTPUT_DIR}/results"
17
+ VIZ_DIR = f"{OUTPUT_DIR}/visualizations"
18
+
19
+ # Number of samples to evaluate
20
+ NUM_EVAL_SAMPLES = 50 # Change for more/less evaluation
21
+
22
+ def setup_directories():
23
+ for d in [RESULTS_DIR, VIZ_DIR]:
24
+ os.makedirs(d, exist_ok=True)
25
+
26
+ # =============================================================================
27
+ # EVALUATION METRICS
28
+ # =============================================================================
29
+
30
+ def exact_match(pred, expected):
31
+ """Check exact match."""
32
+ return pred.lower().strip() == expected.lower().strip()
33
+
34
+ def token_accuracy(pred, expected):
35
+ """Token overlap accuracy."""
36
+ pred_tokens = set(pred.lower().split())
37
+ exp_tokens = set(expected.lower().split())
38
+ if not exp_tokens:
39
+ return 0.0
40
+ return len(pred_tokens & exp_tokens) / len(exp_tokens)
41
+
42
+ def keyword_accuracy(pred, expected):
43
+ """SQL keyword match accuracy."""
44
+ keywords = ['SELECT', 'FROM', 'WHERE', 'JOIN', 'GROUP BY',
45
+ 'ORDER BY', 'COUNT', 'SUM', 'AVG', 'MAX', 'MIN']
46
+
47
+ pred_kw = [k for k in keywords if k in pred.upper()]
48
+ exp_kw = [k for k in keywords if k in expected.upper()]
49
+
50
+ if not exp_kw:
51
+ return 1.0 if not pred_kw else 0.0
52
+
53
+ matches = sum(1 for k in exp_kw if k in pred_kw)
54
+ return matches / len(exp_kw)
55
+
56
+ def structure_similarity(pred, expected):
57
+ """SQL structure similarity."""
58
+ clauses = ['SELECT', 'FROM', 'WHERE', 'JOIN', 'GROUP BY', 'ORDER BY', 'LIMIT']
59
+
60
+ pred_struct = set(c for c in clauses if c in pred.upper())
61
+ exp_struct = set(c for c in clauses if c in expected.upper())
62
+
63
+ if not exp_struct and not pred_struct:
64
+ return 1.0
65
+ if not exp_struct or not pred_struct:
66
+ return 0.0
67
+
68
+ return len(pred_struct & exp_struct) / len(pred_struct | exp_struct)
69
+
70
+ # =============================================================================
71
+ # EVALUATION RUNNER
72
+ # =============================================================================
73
+
74
+ def evaluate_predictions(predictions, ground_truth):
75
+ """Calculate all metrics."""
76
+
77
+ results = {
78
+ 'exact_match': [],
79
+ 'token_accuracy': [],
80
+ 'keyword_accuracy': [],
81
+ 'structure_similarity': []
82
+ }
83
+
84
+ for pred, exp in zip(predictions, ground_truth):
85
+ results['exact_match'].append(1 if exact_match(pred, exp) else 0)
86
+ results['token_accuracy'].append(token_accuracy(pred, exp))
87
+ results['keyword_accuracy'].append(keyword_accuracy(pred, exp))
88
+ results['structure_similarity'].append(structure_similarity(pred, exp))
89
+
90
+ # Calculate averages
91
+ metrics = {
92
+ 'total_samples': len(predictions),
93
+ 'exact_match_rate': sum(results['exact_match']) / len(results['exact_match']),
94
+ 'avg_token_accuracy': sum(results['token_accuracy']) / len(results['token_accuracy']),
95
+ 'avg_keyword_accuracy': sum(results['keyword_accuracy']) / len(results['keyword_accuracy']),
96
+ 'avg_structure_similarity': sum(results['structure_similarity']) / len(results['structure_similarity']),
97
+ 'detailed': results
98
+ }
99
+
100
+ return metrics
101
+
102
+ # =============================================================================
103
+ # VISUALIZATIONS
104
+ # =============================================================================
105
+
106
+ def create_visualizations(metrics):
107
+ """Create evaluation charts."""
108
+
109
+ setup_directories()
110
+ plt.style.use('seaborn-v0_8-whitegrid')
111
+
112
+ # 1. Metrics Overview
113
+ fig, ax = plt.subplots(figsize=(10, 6))
114
+
115
+ names = ['Exact Match', 'Token Acc', 'Keyword Acc', 'Structure Sim']
116
+ values = [
117
+ metrics['exact_match_rate'] * 100,
118
+ metrics['avg_token_accuracy'] * 100,
119
+ metrics['avg_keyword_accuracy'] * 100,
120
+ metrics['avg_structure_similarity'] * 100
121
+ ]
122
+ colors = ['#3498db', '#2ecc71', '#9b59b6', '#e74c3c']
123
+
124
+ bars = ax.bar(names, values, color=colors, edgecolor='black')
125
+ for bar, val in zip(bars, values):
126
+ ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1,
127
+ f'{val:.1f}%', ha='center', fontweight='bold')
128
+
129
+ ax.set_ylabel('Score (%)')
130
+ ax.set_title('Model Evaluation Metrics', fontsize=14, fontweight='bold')
131
+ ax.set_ylim(0, 110)
132
+
133
+ plt.tight_layout()
134
+ plt.savefig(f'{VIZ_DIR}/01_metrics_overview.png', dpi=150)
135
+ plt.close()
136
+ print(f" Saved: {VIZ_DIR}/01_metrics_overview.png")
137
+
138
+ # 2. Token Accuracy Distribution
139
+ fig, ax = plt.subplots(figsize=(10, 6))
140
+
141
+ token_acc = metrics['detailed']['token_accuracy']
142
+ ax.hist(token_acc, bins=20, color='#2ecc71', edgecolor='black', alpha=0.7)
143
+ ax.axvline(sum(token_acc)/len(token_acc), color='red', linestyle='--',
144
+ label=f"Mean: {sum(token_acc)/len(token_acc):.2f}")
145
+ ax.set_xlabel('Token Accuracy')
146
+ ax.set_ylabel('Frequency')
147
+ ax.set_title('Token Accuracy Distribution', fontsize=14, fontweight='bold')
148
+ ax.legend()
149
+
150
+ plt.tight_layout()
151
+ plt.savefig(f'{VIZ_DIR}/02_token_accuracy_dist.png', dpi=150)
152
+ plt.close()
153
+ print(f" Saved: {VIZ_DIR}/02_token_accuracy_dist.png")
154
+
155
+ # 3. Keyword Accuracy Distribution
156
+ fig, ax = plt.subplots(figsize=(10, 6))
157
+
158
+ kw_acc = metrics['detailed']['keyword_accuracy']
159
+ ax.hist(kw_acc, bins=20, color='#9b59b6', edgecolor='black', alpha=0.7)
160
+ ax.axvline(sum(kw_acc)/len(kw_acc), color='red', linestyle='--',
161
+ label=f"Mean: {sum(kw_acc)/len(kw_acc):.2f}")
162
+ ax.set_xlabel('Keyword Accuracy')
163
+ ax.set_ylabel('Frequency')
164
+ ax.set_title('Keyword Accuracy Distribution', fontsize=14, fontweight='bold')
165
+ ax.legend()
166
+
167
+ plt.tight_layout()
168
+ plt.savefig(f'{VIZ_DIR}/03_keyword_accuracy_dist.png', dpi=150)
169
+ plt.close()
170
+ print(f" Saved: {VIZ_DIR}/03_keyword_accuracy_dist.png")
171
+
172
+ # =============================================================================
173
+ # REPORT GENERATION
174
+ # =============================================================================
175
+
176
+ def generate_report(metrics):
177
+ """Generate evaluation report."""
178
+
179
+ report = f"""# Fine-Tuning Evaluation Report
180
+
181
+ **Generated:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
182
+
183
+ ## Metrics Summary
184
+
185
+ | Metric | Score |
186
+ |--------|-------|
187
+ | Samples Evaluated | {metrics['total_samples']} |
188
+ | Exact Match Rate | {metrics['exact_match_rate']*100:.2f}% |
189
+ | Token Accuracy | {metrics['avg_token_accuracy']*100:.2f}% |
190
+ | Keyword Accuracy | {metrics['avg_keyword_accuracy']*100:.2f}% |
191
+ | Structure Similarity | {metrics['avg_structure_similarity']*100:.2f}% |
192
+
193
+ ## Metrics Explanation
194
+
195
+ - **Exact Match**: Predictions identical to ground truth
196
+ - **Token Accuracy**: Word overlap between prediction and expected
197
+ - **Keyword Accuracy**: SQL keywords (SELECT, WHERE, etc.) match
198
+ - **Structure Similarity**: Query structure (clauses used) match
199
+
200
+ ## Visualizations
201
+
202
+ - `01_metrics_overview.png` - All metrics bar chart
203
+ - `02_token_accuracy_dist.png` - Token accuracy histogram
204
+ - `03_keyword_accuracy_dist.png` - Keyword accuracy histogram
205
+ """
206
+
207
+ with open(f'{RESULTS_DIR}/evaluation_report.md', 'w') as f:
208
+ f.write(report)
209
+ print(f" Saved: {RESULTS_DIR}/evaluation_report.md")
210
+
211
+ # Save JSON
212
+ json_metrics = {k: v for k, v in metrics.items() if k != 'detailed'}
213
+ with open(f'{RESULTS_DIR}/evaluation_results.json', 'w') as f:
214
+ json.dump(json_metrics, f, indent=2)
215
+ print(f" Saved: {RESULTS_DIR}/evaluation_results.json")
216
+
217
+ # =============================================================================
218
+ # MAIN EVALUATION
219
+ # =============================================================================
220
+
221
+ def run_evaluation():
222
+ """Run full evaluation."""
223
+
224
+ print("=" * 60)
225
+ print("EVALUATING FINE-TUNED MODEL")
226
+ print("=" * 60)
227
+
228
+ setup_directories()
229
+
230
+ # Load test data
231
+ print("\n[1/4] Loading test data...")
232
+ test_file = f"{OUTPUT_DIR}/test.jsonl"
233
+
234
+ if not os.path.exists(test_file):
235
+ print("ERROR: Run prepare_data.py first!")
236
+ return None
237
+
238
+ test_data = []
239
+ with open(test_file) as f:
240
+ for line in f:
241
+ test_data.append(json.loads(line))
242
+
243
+ test_data = test_data[:NUM_EVAL_SAMPLES]
244
+ print(f" Loaded {len(test_data)} samples")
245
+
246
+ # Generate predictions
247
+ print("\n[2/4] Generating predictions...")
248
+
249
+ try:
250
+ import sys
251
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
252
+ from finetuning.inference import SQLGenerator
253
+ generator = SQLGenerator()
254
+
255
+ predictions = []
256
+ ground_truth = []
257
+
258
+ for i, item in enumerate(test_data):
259
+ pred = generator.generate(item['question'])
260
+ predictions.append(pred)
261
+ ground_truth.append(item['sql'])
262
+
263
+ if (i + 1) % 10 == 0:
264
+ print(f" Progress: {i+1}/{len(test_data)}")
265
+
266
+ except Exception as e:
267
+ print(f" Error loading model: {e}")
268
+ print(" Using ground truth as predictions (for testing metrics)")
269
+ predictions = [item['sql'] for item in test_data]
270
+ ground_truth = [item['sql'] for item in test_data]
271
+
272
+ # Calculate metrics
273
+ print("\n[3/4] Calculating metrics...")
274
+ metrics = evaluate_predictions(predictions, ground_truth)
275
+
276
+ print(f" Exact Match: {metrics['exact_match_rate']*100:.2f}%")
277
+ print(f" Token Accuracy: {metrics['avg_token_accuracy']*100:.2f}%")
278
+ print(f" Keyword Accuracy: {metrics['avg_keyword_accuracy']*100:.2f}%")
279
+ print(f" Structure Sim: {metrics['avg_structure_similarity']*100:.2f}%")
280
+
281
+ # Generate outputs
282
+ print("\n[4/4] Generating outputs...")
283
+ create_visualizations(metrics)
284
+ generate_report(metrics)
285
+
286
+ print("\n" + "=" * 60)
287
+ print("EVALUATION COMPLETE")
288
+ print("=" * 60)
289
+
290
+ return metrics
291
+
292
+ if __name__ == "__main__":
293
+ run_evaluation()
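As a quick sanity check of the metric functions defined in `evaluate.py` above (run from the `src/` directory), a partially correct prediction scores like this:

```python
from finetuning.evaluate import (
    exact_match, token_accuracy, keyword_accuracy, structure_similarity
)

pred = "SELECT name FROM employees"
expected = "SELECT name FROM employees WHERE salary > 50000"

print(exact_match(pred, expected))           # False
print(token_accuracy(pred, expected))        # 0.5   -> 4 of 8 expected tokens overlap
print(keyword_accuracy(pred, expected))      # ~0.67 -> SELECT, FROM out of SELECT, FROM, WHERE
print(structure_similarity(pred, expected))  # ~0.67 -> Jaccard of {SELECT, FROM} vs {SELECT, FROM, WHERE}
```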
src/finetuning/inference.py ADDED
@@ -0,0 +1,168 @@
1
+ """
2
+ Inference Module for Fine-Tuned SQL Model
3
+ Loads from: Local checkpoint OR Hugging Face Hub
4
+ """
5
+
6
+ import os
7
+ import torch
8
+ from transformers import AutoModelForCausalLM, AutoTokenizer
9
+ from dotenv import load_dotenv
10
+
11
+ load_dotenv()
12
+
13
+ # =============================================================================
14
+ # CONFIGURATION
15
+ # =============================================================================
16
+
17
+ # Hugging Face Model ID (set in .env or Streamlit secrets)
18
+ HF_MODEL_ID = os.getenv("HF_MODEL_ID", None)
19
+
20
+ # Local paths
21
+ LOCAL_MODEL_DIR = "outputs/finetuning/checkpoints/final"
22
+ BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
23
+
24
+ # =============================================================================
25
+ # SQL GENERATOR CLASS
26
+ # =============================================================================
27
+
28
+ class SQLGenerator:
29
+ """SQL Generation using fine-tuned model."""
30
+
31
+ def __init__(self):
32
+ """Load the fine-tuned model from local or HuggingFace."""
33
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
34
+ print(f"Device: {self.device}")
35
+
36
+ load_path = self._get_model_path()
37
+
38
+ # Load tokenizer and model with memory optimization
39
+ print(f"Loading model from: {load_path}")
40
+ self.tokenizer = AutoTokenizer.from_pretrained(load_path)
41
+
42
+ # Memory-efficient loading for cloud deployment
43
+ self.model = AutoModelForCausalLM.from_pretrained(
44
+ load_path,
45
+ torch_dtype=torch.float32, # Use float32 for CPU
46
+ device_map=None, # Don't use device_map on CPU
47
+ low_cpu_mem_usage=True, # Reduce memory during loading
48
+ trust_remote_code=True
49
+ )
50
+
51
+ # Move to device after loading
52
+ self.model = self.model.to(self.device)
53
+
54
+ self.tokenizer.pad_token = self.tokenizer.eos_token
55
+ print("✓ Model loaded!")
56
+
57
+ def _get_model_path(self):
58
+ """Determine where to load model from."""
59
+
60
+ # Check for required model files (not just folder existence)
61
+ required_files = ['config.json', 'tokenizer.json', 'tokenizer_config.json']
62
+
63
+ # Priority 1: Local checkpoint with actual model files
64
+ if os.path.exists(LOCAL_MODEL_DIR):
65
+ local_files = os.listdir(LOCAL_MODEL_DIR) if os.path.isdir(LOCAL_MODEL_DIR) else []
66
+ has_model_files = any(f in local_files for f in required_files) or any(f.endswith('.safetensors') or f.endswith('.bin') for f in local_files)
67
+
68
+ if has_model_files:
69
+ print(f"📁 Found local model checkpoint: {LOCAL_MODEL_DIR}")
70
+ return LOCAL_MODEL_DIR
71
+ else:
72
+ print(f"⚠️ Local folder exists but no model files found")
73
+
74
+ # Priority 2: Download from HuggingFace Hub
75
+ if HF_MODEL_ID:
76
+ print(f"☁️ Downloading model from HuggingFace: {HF_MODEL_ID}")
77
+ return HF_MODEL_ID
78
+
79
+ # Priority 3: Base model fallback
80
+ print("⚠️ No fine-tuned model found, using base model")
81
+ return BASE_MODEL
82
+
83
+ def generate(self, question, context="", max_tokens=128):
84
+ """Generate SQL from question."""
85
+
86
+ # Build prompt
87
+ if context:
88
+ prompt = f"""{context}
89
+
90
+ ### Question:
91
+ {question}
92
+
93
+ ### SQL:"""
94
+ else:
95
+ prompt = f"""### Question:
96
+ {question}
97
+
98
+ ### SQL:"""
99
+
100
+ # Tokenize
101
+ inputs = self.tokenizer(
102
+ prompt,
103
+ return_tensors="pt",
104
+ truncation=True,
105
+ max_length=512
106
+ ).to(self.device)
107
+
108
+ # Generate
109
+ with torch.no_grad():
110
+ outputs = self.model.generate(
111
+ **inputs,
112
+ max_new_tokens=max_tokens,
113
+ temperature=0.1,
114
+ do_sample=True,
115
+ top_p=0.95,
116
+ pad_token_id=self.tokenizer.eos_token_id
117
+ )
118
+
119
+ # Decode
120
+ generated = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
121
+
122
+ # Extract SQL
123
+ sql = generated[len(prompt):].strip()
124
+ if "###" in sql:
125
+ sql = sql.split("###")[0].strip()
126
+
127
+ return sql
128
+
129
+ # =============================================================================
130
+ # STANDALONE FUNCTION
131
+ # =============================================================================
132
+
133
+ _generator = None
134
+
135
+ def generate_sql(question, context=""):
136
+ """Standalone SQL generation."""
137
+ global _generator
138
+ if _generator is None:
139
+ _generator = SQLGenerator()
140
+ return _generator.generate(question, context)
141
+
142
+ # =============================================================================
143
+ # TEST
144
+ # =============================================================================
145
+
146
+ def test_inference():
147
+ """Test the model."""
148
+ print("=" * 60)
149
+ print("TESTING SQL GENERATION")
150
+ print("=" * 60)
151
+
152
+ generator = SQLGenerator()
153
+
154
+ questions = [
155
+ "Find all employees with salary greater than 50000",
156
+ ]
157
+
158
+ print("\n" + "-" * 60)
159
+ for q in questions:
160
+ print(f"Q: {q}")
161
+ sql = generator.generate(q)
162
+ print(f"SQL: {sql}")
163
+ print("-" * 60)
164
+
165
+ print("\n✓ Test complete")
166
+
167
+ if __name__ == "__main__":
168
+ test_inference()
src/finetuning/prepare_data.py ADDED
@@ -0,0 +1,149 @@
1
+ """
2
+ Data Preparation for Fine-Tuning
3
+ Uses train.csv, validation.csv, test.csv correctly.
4
+ """
5
+
6
+ import os
7
+ import pandas as pd
8
+ import json
9
+ from datetime import datetime
10
+
11
+ # =============================================================================
12
+ # CONFIGURATION
13
+ # =============================================================================
14
+
15
+ OUTPUT_DIR = "outputs/finetuning"
16
+ DATA_DIR = "data"
17
+
18
+ # Change this for testing vs full training
19
+ MAX_SAMPLES = 100 # Set to None for full data
20
+
21
+ def setup_directories():
22
+ for d in [OUTPUT_DIR, f"{OUTPUT_DIR}/results", f"{OUTPUT_DIR}/logs"]:
23
+ os.makedirs(d, exist_ok=True)
24
+
25
+ # =============================================================================
26
+ # PROMPT TEMPLATE
27
+ # =============================================================================
28
+
29
+ def format_for_training(question, sql):
30
+ """Format single example for instruction fine-tuning."""
31
+ text = f"""### Question:
32
+ {question}
33
+
34
+ ### SQL:
35
+ {sql}"""
36
+ return text
37
+
38
+ # =============================================================================
39
+ # DATA LOADING
40
+ # =============================================================================
41
+
42
+ def load_csv_file(filepath, max_samples=None):
43
+ """Load a single CSV file."""
44
+ if not os.path.exists(filepath):
45
+ print(f" File not found: {filepath}")
46
+ return None
47
+
48
+ df = pd.read_csv(filepath)
49
+
50
+ if max_samples and len(df) > max_samples:
51
+ df = df.sample(n=max_samples, random_state=42)
52
+
53
+ return df
54
+
55
+ def format_dataframe(df, source_name):
56
+ """Convert dataframe to training format."""
57
+ formatted = []
58
+ for _, row in df.iterrows():
59
+ formatted.append({
60
+ 'text': format_for_training(row['question'], row['sql']),
61
+ 'question': str(row['question']),
62
+ 'sql': str(row['sql']),
63
+ 'source': source_name
64
+ })
65
+ return formatted
66
+
67
+ def save_jsonl(data, filepath):
68
+ """Save data as JSONL file."""
69
+ with open(filepath, 'w', encoding='utf-8') as f:
70
+ for item in data:
71
+ f.write(json.dumps(item) + '\n')
72
+ print(f" Saved: {filepath}")
73
+
74
+ # =============================================================================
75
+ # MAIN FUNCTION
76
+ # =============================================================================
77
+
78
+ def prepare_finetuning_data():
79
+ """Prepare data for fine-tuning."""
80
+
81
+ print("=" * 50)
82
+ print("PREPARING FINE-TUNING DATA")
83
+ print(f"Max samples per file: {MAX_SAMPLES if MAX_SAMPLES else 'ALL'}")
84
+ print("=" * 50)
85
+
86
+ setup_directories()
87
+
88
+ # Load train data
89
+ print("\n[1/5] Loading training data...")
90
+ train_df = load_csv_file(f"{DATA_DIR}/train.csv", MAX_SAMPLES)
91
+ print(f" train.csv: {len(train_df):,} rows")
92
+
93
+ # Load synthetic and combine with train
94
+ # synthetic_df = load_csv_file(f"{DATA_DIR}/synthetic.csv", MAX_SAMPLES)
95
+ # if synthetic_df is not None:
96
+ # print(f" synthetic.csv: {len(synthetic_df):,} rows")
97
+ # train_df = pd.concat([train_df, synthetic_df], ignore_index=True)
98
+ # print(f" Combined training: {len(train_df):,} rows")
99
+
100
+ # Load validation data
101
+ print("\n[2/5] Loading validation data...")
102
+ val_df = load_csv_file(f"{DATA_DIR}/validation.csv", MAX_SAMPLES)
103
+ print(f" validation.csv: {len(val_df):,} rows")
104
+
105
+ # Load test data
106
+ print("\n[3/5] Loading test data...")
107
+ test_df = load_csv_file(f"{DATA_DIR}/test.csv", MAX_SAMPLES)
108
+ print(f" test.csv: {len(test_df):,} rows")
109
+
110
+ # Format data
111
+ print("\n[4/5] Formatting data...")
112
+ train_data = format_dataframe(train_df, 'train')
113
+ val_data = format_dataframe(val_df, 'validation')
114
+ test_data = format_dataframe(test_df, 'test')
115
+
116
+ # Save files
117
+ print("\n[5/5] Saving files...")
118
+ save_jsonl(train_data, f"{OUTPUT_DIR}/train.jsonl")
119
+ save_jsonl(val_data, f"{OUTPUT_DIR}/val.jsonl")
120
+ save_jsonl(test_data, f"{OUTPUT_DIR}/test.jsonl")
121
+
122
+ # Save stats
123
+ stats = {
124
+ 'train_samples': len(train_data),
125
+ 'val_samples': len(val_data),
126
+ 'test_samples': len(test_data),
127
+ 'max_samples': MAX_SAMPLES,
128
+ 'created_at': datetime.now().isoformat()
129
+ }
130
+
131
+ with open(f"{OUTPUT_DIR}/data_stats.json", 'w') as f:
132
+ json.dump(stats, f, indent=2)
133
+
134
+ # Summary
135
+ print("\n" + "=" * 50)
136
+ print("COMPLETE")
137
+ print("=" * 50)
138
+ print(f" Train: {len(train_data):,}")
139
+ print(f" Val: {len(val_data):,}")
140
+ print(f" Test: {len(test_data):,}")
141
+
142
+ return stats
143
+
144
+ # =============================================================================
145
+ # ENTRY POINT
146
+ # =============================================================================
147
+
148
+ if __name__ == "__main__":
149
+ prepare_finetuning_data()
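
A quick way to sanity-check the JSONL that `prepare_data.py` writes is to read back one record and confirm it matches the `### Question: / ### SQL:` template consumed downstream. This is a minimal sketch, assuming `prepare_finetuning_data()` has already produced `outputs/finetuning/train.jsonl`; the field names follow `format_dataframe` above.

```python
# Minimal sketch: inspect one record produced by prepare_data.py.
# Assumes outputs/finetuning/train.jsonl already exists.
import json

with open("outputs/finetuning/train.jsonl", encoding="utf-8") as f:
    record = json.loads(f.readline())

print(record["question"])  # natural-language question
print(record["sql"])       # target SQL query
print(record["text"])      # combined "### Question: / ### SQL:" prompt used for training
```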
src/finetuning/train.py ADDED
@@ -0,0 +1,218 @@
1
+ """
2
+ Fine-Tuning Script for SQL Generation Model
3
+ Uses LoRA for efficient fine-tuning.
4
+ """
5
+
6
+ import os
7
+ import json
8
+ import torch
9
+ from datetime import datetime
10
+ from datasets import load_dataset
11
+ from transformers import (
12
+ AutoModelForCausalLM,
13
+ AutoTokenizer,
14
+ TrainingArguments,
15
+ Trainer,
16
+ DataCollatorForLanguageModeling
17
+ )
18
+ from peft import LoraConfig, get_peft_model, TaskType
19
+
20
+ # =============================================================================
21
+ # CONFIGURATION
22
+ # =============================================================================
23
+
24
+ MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
25
+ OUTPUT_DIR = "outputs/finetuning"
26
+ CHECKPOINT_DIR = f"{OUTPUT_DIR}/checkpoints"
27
+ LOGS_DIR = f"{OUTPUT_DIR}/logs"
28
+
29
+ # Training config (optimized for RTX 4070)
30
+ TRAINING_CONFIG = {
31
+ 'num_epochs': 3,
32
+ 'batch_size': 8,
33
+ 'learning_rate': 2e-4,
34
+ 'max_length': 256,
35
+ 'warmup_steps': 100,
36
+ 'logging_steps': 50,
37
+ 'save_steps': 500,
38
+ 'gradient_accumulation_steps': 2,
39
+ }
40
+
41
+ # LoRA config
42
+ LORA_CONFIG = {
43
+ 'r': 16,
44
+ 'lora_alpha': 32,
45
+ 'lora_dropout': 0.1,
46
+ 'target_modules': ['q_proj', 'v_proj', 'k_proj', 'o_proj']
47
+ }
48
+
49
+ def setup_directories():
50
+ for d in [OUTPUT_DIR, CHECKPOINT_DIR, LOGS_DIR]:
51
+ os.makedirs(d, exist_ok=True)
52
+
53
+ # =============================================================================
54
+ # TRAINING FUNCTIONS
55
+ # =============================================================================
56
+
57
+ def load_data():
58
+ """Load prepared training data."""
59
+ train_file = f"{OUTPUT_DIR}/train.jsonl"
60
+ val_file = f"{OUTPUT_DIR}/val.jsonl"
61
+
62
+ if not os.path.exists(train_file):
63
+ raise FileNotFoundError("Run prepare_data.py first!")
64
+
65
+ return load_dataset('json', data_files={
66
+ 'train': train_file,
67
+ 'validation': val_file
68
+ })
69
+
70
+ def setup_model():
71
+ """Load model and tokenizer with LoRA."""
72
+ print(f"Loading: {MODEL_NAME}")
73
+
74
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
75
+ tokenizer.pad_token = tokenizer.eos_token
76
+ tokenizer.padding_side = "right"
77
+
78
+ model = AutoModelForCausalLM.from_pretrained(
79
+ MODEL_NAME,
80
+ torch_dtype=torch.float16,
81
+ device_map="auto"
82
+ )
83
+
84
+ lora_config = LoraConfig(
85
+ task_type=TaskType.CAUSAL_LM,
86
+ r=LORA_CONFIG['r'],
87
+ lora_alpha=LORA_CONFIG['lora_alpha'],
88
+ lora_dropout=LORA_CONFIG['lora_dropout'],
89
+ target_modules=LORA_CONFIG['target_modules']
90
+ )
91
+
92
+ model = get_peft_model(model, lora_config)
93
+ model.print_trainable_parameters()
94
+
95
+ return model, tokenizer
96
+
97
+ def tokenize(examples, tokenizer):
98
+ """Tokenize examples."""
99
+ return tokenizer(
100
+ examples['text'],
101
+ truncation=True,
102
+ padding='max_length',
103
+ max_length=TRAINING_CONFIG['max_length']
104
+ )
105
+
106
+ def train(model, tokenizer, dataset):
107
+ """Train the model."""
108
+
109
+ # Tokenize
110
+ print("Tokenizing...")
111
+ tokenized_train = dataset['train'].map(
112
+ lambda x: tokenize(x, tokenizer),
113
+ batched=True,
114
+ remove_columns=dataset['train'].column_names
115
+ )
116
+ tokenized_val = dataset['validation'].map(
117
+ lambda x: tokenize(x, tokenizer),
118
+ batched=True,
119
+ remove_columns=dataset['validation'].column_names
120
+ )
121
+
122
+ # Training args
123
+ training_args = TrainingArguments(
124
+ output_dir=CHECKPOINT_DIR,
125
+ num_train_epochs=TRAINING_CONFIG['num_epochs'],
126
+ per_device_train_batch_size=TRAINING_CONFIG['batch_size'],
127
+ per_device_eval_batch_size=TRAINING_CONFIG['batch_size'],
128
+ learning_rate=TRAINING_CONFIG['learning_rate'],
129
+ warmup_steps=TRAINING_CONFIG['warmup_steps'],
130
+ logging_steps=TRAINING_CONFIG['logging_steps'],
131
+ save_steps=TRAINING_CONFIG['save_steps'],
132
+ gradient_accumulation_steps=TRAINING_CONFIG['gradient_accumulation_steps'],
133
+ eval_strategy="steps",
134
+ eval_steps=TRAINING_CONFIG['save_steps'],
135
+ save_total_limit=2,
136
+ fp16=True,
137
+ report_to="none",
138
+ logging_dir=LOGS_DIR,
139
+ dataloader_pin_memory=False,
140
+ )
141
+
142
+ # Trainer
143
+ trainer = Trainer(
144
+ model=model,
145
+ args=training_args,
146
+ train_dataset=tokenized_train,
147
+ eval_dataset=tokenized_val,
148
+ data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
149
+ )
150
+
151
+ # Train
152
+ print(f"\nTraining: {len(tokenized_train)} samples, {TRAINING_CONFIG['num_epochs']} epochs")
153
+ result = trainer.train()
154
+
155
+ # Save
156
+ print("\nSaving model...")
157
+ trainer.save_model(f"{CHECKPOINT_DIR}/final")
158
+ tokenizer.save_pretrained(f"{CHECKPOINT_DIR}/final")
159
+
160
+ # Stats
161
+ stats = {
162
+ 'train_loss': result.training_loss,
163
+ 'runtime_seconds': result.metrics['train_runtime'],
164
+ 'samples_per_second': result.metrics['train_samples_per_second'],
165
+ 'epochs': TRAINING_CONFIG['num_epochs'],
166
+ 'total_steps': result.global_step,
167
+ 'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU',
168
+ 'completed_at': datetime.now().isoformat()
169
+ }
170
+
171
+ with open(f"{CHECKPOINT_DIR}/training_stats.json", 'w') as f:
172
+ json.dump(stats, f, indent=2)
173
+
174
+ return stats
175
+
176
+ # =============================================================================
177
+ # MAIN
178
+ # =============================================================================
179
+
180
+ def run_finetuning():
181
+ """Main function."""
182
+
183
+ print("=" * 60)
184
+ print("FINE-TUNING SQL MODEL")
185
+ if torch.cuda.is_available():
186
+ print(f"GPU: {torch.cuda.get_device_name(0)}")
187
+ else:
188
+ print("GPU: Not available (using CPU)")
189
+ print("=" * 60)
190
+
191
+ setup_directories()
192
+
193
+ # Load data
194
+ print("\n[1/3] Loading data...")
195
+ dataset = load_data()
196
+ print(f" Train: {len(dataset['train']):,}")
197
+ print(f" Val: {len(dataset['validation']):,}")
198
+
199
+ # Setup model
200
+ print("\n[2/3] Setting up model...")
201
+ model, tokenizer = setup_model()
202
+
203
+ # Train
204
+ print("\n[3/3] Training...")
205
+ stats = train(model, tokenizer, dataset)
206
+
207
+ # Done
208
+ print("\n" + "=" * 60)
209
+ print("TRAINING COMPLETE")
210
+ print("=" * 60)
211
+ print(f" Loss: {stats['train_loss']:.4f}")
212
+ print(f" Time: {stats['runtime_seconds']/60:.1f} min")
213
+ print(f" Model: {CHECKPOINT_DIR}/final")
214
+
215
+ return stats
216
+
217
+ if __name__ == "__main__":
218
+ run_finetuning()
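
Once `run_finetuning()` completes, the LoRA adapter and tokenizer saved under `outputs/finetuning/checkpoints/final` can be reloaded for generation. The sketch below is illustrative only (it is not the repository's own inference code): it uses standard `transformers` + `peft` calls, the same base model, and the same prompt template as `prepare_data.py`; the sample question is taken from `test.jsonl`.

```python
# Minimal sketch: load the LoRA adapter saved by train.py and generate SQL.
# Assumes training finished and outputs/finetuning/checkpoints/final exists.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

BASE = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
ADAPTER = "outputs/finetuning/checkpoints/final"

tokenizer = AutoTokenizer.from_pretrained(ADAPTER)
base = AutoModelForCausalLM.from_pretrained(BASE, torch_dtype=torch.float16, device_map="auto")
model = PeftModel.from_pretrained(base, ADAPTER)  # attach the fine-tuned LoRA weights
model.eval()

prompt = "### Question:\nHow many poles had 72 points?\n\n### SQL:\n"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=64, do_sample=False)

# Strip the prompt tokens and print only the generated SQL continuation.
print(tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
```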
src/outputs/finetuning/data_stats.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "train_samples": 100,
3
+ "val_samples": 100,
4
+ "test_samples": 100,
5
+ "max_samples": 100,
6
+ "created_at": "2025-12-08T01:22:29.002186"
7
+ }
src/outputs/finetuning/results/evaluation_report.md ADDED
@@ -0,0 +1,26 @@
1
+ # Fine-Tuning Evaluation Report
2
+
3
+ **Generated:** 2025-12-08 01:32:37
4
+
5
+ ## Metrics Summary
6
+
7
+ | Metric | Score |
8
+ |--------|-------|
9
+ | Samples Evaluated | 50 |
10
+ | Exact Match Rate | 0.00% |
11
+ | Token Accuracy | 47.21% |
12
+ | Keyword Accuracy | 91.33% |
13
+ | Structure Similarity | 91.07% |
14
+
15
+ ## Metrics Explanation
16
+
17
+ - **Exact Match**: Predictions identical to ground truth
18
+ - **Token Accuracy**: Word overlap between prediction and expected
19
+ - **Keyword Accuracy**: SQL keywords (SELECT, WHERE, etc.) match
20
+ - **Structure Similarity**: Query structure (clauses used) match
21
+
22
+ ## Visualizations
23
+
24
+ - `01_metrics_overview.png` - All metrics bar chart
25
+ - `02_token_accuracy_dist.png` - Token accuracy histogram
26
+ - `03_keyword_accuracy_dist.png` - Keyword accuracy histogram
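
To make the report's metric definitions concrete, the sketch below shows one plausible reading of "Token Accuracy" (word overlap with the reference SQL) and "Keyword Accuracy" (overlap restricted to SQL keywords such as SELECT and WHERE). It is a rough illustration only; the project's actual evaluation code may compute these differently.

```python
# Illustrative only: rough versions of the report's token/keyword accuracy metrics.
SQL_KEYWORDS = {"SELECT", "FROM", "WHERE", "AND", "OR", "COUNT", "SUM", "AVG", "MIN", "MAX"}

def token_accuracy(pred: str, ref: str) -> float:
    """Fraction of distinct reference tokens that also appear in the prediction."""
    pred_tokens = set(pred.upper().split())
    ref_tokens = set(ref.upper().split())
    return len(pred_tokens & ref_tokens) / len(ref_tokens) if ref_tokens else 0.0

def keyword_accuracy(pred: str, ref: str) -> float:
    """Same overlap, but restricted to SQL keywords in the reference query."""
    pred_kw = set(pred.upper().split()) & SQL_KEYWORDS
    ref_kw = set(ref.upper().split()) & SQL_KEYWORDS
    return len(pred_kw & ref_kw) / len(ref_kw) if ref_kw else 1.0

# Example: prediction adds COUNT that the reference lacks, so token accuracy < 1.
print(token_accuracy("SELECT COUNT Poles FROM table WHERE Points = 72",
                     "SELECT Poles FROM table WHERE Points = 72"))
```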
src/outputs/finetuning/results/evaluation_results.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "total_samples": 50,
3
+ "exact_match_rate": 0.0,
4
+ "avg_token_accuracy": 0.472115604983252,
5
+ "avg_keyword_accuracy": 0.9133333333333334,
6
+ "avg_structure_similarity": 0.9106666666666667
7
+ }
src/outputs/finetuning/test.jsonl ADDED
@@ -0,0 +1,100 @@
1
+ {"text": "### Question:\nWhat is the whole of Drawn that has a Lost of 4?\n\n### SQL:\nSELECT SUM Drawn FROM table WHERE Lost = 4", "question": "What is the whole of Drawn that has a Lost of 4?", "sql": "SELECT SUM Drawn FROM table WHERE Lost = 4", "source": "test"}
2
+ {"text": "### Question:\nWhat is the rating of the episode with a rating/share of 0.9/4?\n\n### SQL:\nSELECT Rating FROM table WHERE Rating/Share (18\u201349) = 0.9/4", "question": "What is the rating of the episode with a rating/share of 0.9/4?", "sql": "SELECT Rating FROM table WHERE Rating/Share (18\u201349) = 0.9/4", "source": "test"}
3
+ {"text": "### Question:\nWhat is the last year that someone is first elected in this table?\n\n### SQL:\nSELECT MAX First elected FROM table", "question": "What is the last year that someone is first elected in this table?", "sql": "SELECT MAX First elected FROM table", "source": "test"}
4
+ {"text": "### Question:\nHow many poles had 72 points?\n\n### SQL:\nSELECT COUNT Poles FROM table WHERE Points = 72", "question": "How many poles had 72 points?", "sql": "SELECT COUNT Poles FROM table WHERE Points = 72", "source": "test"}
5
+ {"text": "### Question:\nWho are all the Moto2 winners when the grand prix was Shell Advance Malaysian Grand Prix?\n\n### SQL:\nSELECT Moto2 winner FROM table WHERE Grand Prix = Shell Advance Malaysian Grand Prix", "question": "Who are all the Moto2 winners when the grand prix was Shell Advance Malaysian Grand Prix?", "sql": "SELECT Moto2 winner FROM table WHERE Grand Prix = Shell Advance Malaysian Grand Prix", "source": "test"}
6
+ {"text": "### Question:\nHow many incumbents are there in the georgia 8 district when the party is democratic?\n\n### SQL:\nSELECT COUNT Incumbent FROM table WHERE Party = Democratic AND District = Georgia 8", "question": "How many incumbents are there in the georgia 8 district when the party is democratic?", "sql": "SELECT COUNT Incumbent FROM table WHERE Party = Democratic AND District = Georgia 8", "source": "test"}
7
+ {"text": "### Question:\nWhat is the largest Area (msr) that has an Area less than 291.045, is part of the Per family, and has a rank higher than 78?\n\n### SQL:\nSELECT MAX Area (msr) FROM table WHERE Area (sq.deg.) < 291.045 AND Family = per AND Rank > 78", "question": "What is the largest Area (msr) that has an Area less than 291.045, is part of the Per family, and has a rank higher than 78?", "sql": "SELECT MAX Area (msr) FROM table WHERE Area (sq.deg.) < 291.045 AND Family = per AND Rank > 78", "source": "test"}
8
+ {"text": "### Question:\nWho won women's double in 2002, the year that Kenneth Vella won men's singles?\n\n### SQL:\nSELECT Women's doubles FROM table WHERE Men's singles = kenneth vella AND Year = 2002", "question": "Who won women's double in 2002, the year that Kenneth Vella won men's singles?", "sql": "SELECT Women's doubles FROM table WHERE Men's singles = kenneth vella AND Year = 2002", "source": "test"}
9
+ {"text": "### Question:\nWhat's k. j. choi's to par?\n\n### SQL:\nSELECT To par FROM table WHERE Player = k. j. choi", "question": "What's k. j. choi's to par?", "sql": "SELECT To par FROM table WHERE Player = k. j. choi", "source": "test"}
10
+ {"text": "### Question:\nWho was the winner when the time was 1:24.00?\n\n### SQL:\nSELECT Winner/2nd FROM table WHERE Time = 1:24.00", "question": "Who was the winner when the time was 1:24.00?", "sql": "SELECT Winner/2nd FROM table WHERE Time = 1:24.00", "source": "test"}
11
+ {"text": "### Question:\nWhich couple had a week 2 score of exactly 23?\n\n### SQL:\nSELECT Couple FROM table WHERE Week 2 = 23", "question": "Which couple had a week 2 score of exactly 23?", "sql": "SELECT Couple FROM table WHERE Week 2 = 23", "source": "test"}
12
+ {"text": "### Question:\nWhat is the record when the result was w 52\u201343?\n\n### SQL:\nSELECT Record FROM table WHERE Result = w 52\u201343", "question": "What is the record when the result was w 52\u201343?", "sql": "SELECT Record FROM table WHERE Result = w 52\u201343", "source": "test"}
13
+ {"text": "### Question:\nWhat's the L3 cache that has a low power part number?\n\n### SQL:\nSELECT L3 cache FROM table WHERE Part number(s) = low power", "question": "What's the L3 cache that has a low power part number?", "sql": "SELECT L3 cache FROM table WHERE Part number(s) = low power", "source": "test"}
14
+ {"text": "### Question:\nWho won the race on 24 August?\n\n### SQL:\nSELECT Winning driver FROM table WHERE Date = 24 august", "question": "Who won the race on 24 August?", "sql": "SELECT Winning driver FROM table WHERE Date = 24 august", "source": "test"}
15
+ {"text": "### Question:\nWhat is the original artist when the vocal percussionist is Alexei Kalveks?\n\n### SQL:\nSELECT Original Artist FROM table WHERE Vocal Percussionist = Alexei Kalveks", "question": "What is the original artist when the vocal percussionist is Alexei Kalveks?", "sql": "SELECT Original Artist FROM table WHERE Vocal Percussionist = Alexei Kalveks", "source": "test"}
16
+ {"text": "### Question:\nwhat is the winning % for the years 2006-11?\n\n### SQL:\nSELECT Winning % FROM table WHERE Years = 2006-11", "question": "what is the winning % for the years 2006-11?", "sql": "SELECT Winning % FROM table WHERE Years = 2006-11", "source": "test"}
17
+ {"text": "### Question:\nWhat is Method, when Event is \"Reality Submission Fighting 2\"?\n\n### SQL:\nSELECT Method FROM table WHERE Event = reality submission fighting 2", "question": "What is Method, when Event is \"Reality Submission Fighting 2\"?", "sql": "SELECT Method FROM table WHERE Event = reality submission fighting 2", "source": "test"}
18
+ {"text": "### Question:\nWhat was the date of the game that had a loss of lidle (10-8)?\n\n### SQL:\nSELECT Date FROM table WHERE Loss = lidle (10-8)", "question": "What was the date of the game that had a loss of lidle (10-8)?", "sql": "SELECT Date FROM table WHERE Loss = lidle (10-8)", "source": "test"}
19
+ {"text": "### Question:\nWho played mixed doubles when Anna Keir played women's singles?\n\n### SQL:\nSELECT Mixed doubles FROM table WHERE Women's singles = anna keir", "question": "Who played mixed doubles when Anna Keir played women's singles?", "sql": "SELECT Mixed doubles FROM table WHERE Women's singles = anna keir", "source": "test"}
20
+ {"text": "### Question:\nWhich tone has a Standard Thai at \u0e1b\u0e25\u0e32?\n\n### SQL:\nSELECT Tone FROM table WHERE Standard Thai = \u0e1b\u0e25\u0e32", "question": "Which tone has a Standard Thai at \u0e1b\u0e25\u0e32?", "sql": "SELECT Tone FROM table WHERE Standard Thai = \u0e1b\u0e25\u0e32", "source": "test"}
21
+ {"text": "### Question:\nName the payload that has a weight of 12,000\n\n### SQL:\nSELECT Payload (kg) FROM table WHERE Weight (kg) = 12,000", "question": "Name the payload that has a weight of 12,000", "sql": "SELECT Payload (kg) FROM table WHERE Weight (kg) = 12,000", "source": "test"}
22
+ {"text": "### Question:\nWhich school's round was 24?\n\n### SQL:\nSELECT School FROM table WHERE Round = 24", "question": "Which school's round was 24?", "sql": "SELECT School FROM table WHERE Round = 24", "source": "test"}
23
+ {"text": "### Question:\nWhat is the least for Scottish Cup with a Challenge Cup greater than 0, Player Paul Keegan, and League Cup greater than 0?\n\n### SQL:\nSELECT MIN Scottish Cup FROM table WHERE Challenge Cup > 0 AND Player = paul keegan AND League Cup > 0", "question": "What is the least for Scottish Cup with a Challenge Cup greater than 0, Player Paul Keegan, and League Cup greater than 0?", "sql": "SELECT MIN Scottish Cup FROM table WHERE Challenge Cup > 0 AND Player = paul keegan AND League Cup > 0", "source": "test"}
24
+ {"text": "### Question:\nWhat is the only type of university that was founded in 1873?\n\n### SQL:\nSELECT Control FROM table WHERE Founded = 1873", "question": "What is the only type of university that was founded in 1873?", "sql": "SELECT Control FROM table WHERE Founded = 1873", "source": "test"}
25
+ {"text": "### Question:\nHow many games were there in the 1966 season?\n\n### SQL:\nSELECT MAX Game FROM table", "question": "How many games were there in the 1966 season?", "sql": "SELECT MAX Game FROM table", "source": "test"}
26
+ {"text": "### Question:\nWhen luz mcclinton is the name what is the season?\n\n### SQL:\nSELECT Season FROM table WHERE Name = Luz McClinton", "question": "When luz mcclinton is the name what is the season?", "sql": "SELECT Season FROM table WHERE Name = Luz McClinton", "source": "test"}
27
+ {"text": "### Question:\nWho was the original artist for First Solo?\n\n### SQL:\nSELECT Original artist FROM table WHERE Theme = First Solo", "question": "Who was the original artist for First Solo?", "sql": "SELECT Original artist FROM table WHERE Theme = First Solo", "source": "test"}
28
+ {"text": "### Question:\nWhat is the lowest grid for Roberto Rolfo with more than 26 laps?\n\n### SQL:\nSELECT MIN Grid FROM table WHERE Rider = roberto rolfo AND Laps > 26", "question": "What is the lowest grid for Roberto Rolfo with more than 26 laps?", "sql": "SELECT MIN Grid FROM table WHERE Rider = roberto rolfo AND Laps > 26", "source": "test"}
29
+ {"text": "### Question:\nWhat is the average lap for suzuki gsx-r1000 k7 and at grid 6?\n\n### SQL:\nSELECT AVG Laps FROM table WHERE Bike = suzuki gsx-r1000 k7 AND Grid = 6", "question": "What is the average lap for suzuki gsx-r1000 k7 and at grid 6?", "sql": "SELECT AVG Laps FROM table WHERE Bike = suzuki gsx-r1000 k7 AND Grid = 6", "source": "test"}
30
+ {"text": "### Question:\nWhat is the date when the Lakers were the home team?\n\n### SQL:\nSELECT Date FROM table WHERE Home = lakers", "question": "What is the date when the Lakers were the home team?", "sql": "SELECT Date FROM table WHERE Home = lakers", "source": "test"}
31
+ {"text": "### Question:\nWhen was there a score of 7-1?\n\n### SQL:\nSELECT Date FROM table WHERE Score = 7-1", "question": "When was there a score of 7-1?", "sql": "SELECT Date FROM table WHERE Score = 7-1", "source": "test"}
32
+ {"text": "### Question:\nWho received 6,131 televotes?\n\n### SQL:\nSELECT Televote Points FROM table WHERE Televotes = 6,131", "question": "Who received 6,131 televotes?", "sql": "SELECT Televote Points FROM table WHERE Televotes = 6,131", "source": "test"}
33
+ {"text": "### Question:\nHow many episodes aired Saturday, July 11, 2009\n\n### SQL:\nSELECT COUNT Episode # FROM table WHERE US air date = Saturday, July 11, 2009", "question": "How many episodes aired Saturday, July 11, 2009", "sql": "SELECT COUNT Episode # FROM table WHERE US air date = Saturday, July 11, 2009", "source": "test"}
34
+ {"text": "### Question:\nWhich episode aired in the USA on 20 May 2005?\n\n### SQL:\nSELECT Episode FROM table WHERE Airdate (USA) = 20 may 2005", "question": "Which episode aired in the USA on 20 May 2005?", "sql": "SELECT Episode FROM table WHERE Airdate (USA) = 20 may 2005", "source": "test"}
35
+ {"text": "### Question:\nWhat was the best finish for 206 on the money list?\n\n### SQL:\nSELECT Best finish FROM table WHERE Money list rank = 206", "question": "What was the best finish for 206 on the money list?", "sql": "SELECT Best finish FROM table WHERE Money list rank = 206", "source": "test"}
36
+ {"text": "### Question:\nName the sum of pick # for round less than 1\n\n### SQL:\nSELECT SUM Pick # FROM table WHERE Round < 1", "question": "Name the sum of pick # for round less than 1", "sql": "SELECT SUM Pick # FROM table WHERE Round < 1", "source": "test"}
37
+ {"text": "### Question:\nWhat engine did the Team Lotus have after 1965?\n\n### SQL:\nSELECT Engine FROM table WHERE Entrant = team lotus AND Year > 1965", "question": "What engine did the Team Lotus have after 1965?", "sql": "SELECT Engine FROM table WHERE Entrant = team lotus AND Year > 1965", "source": "test"}
38
+ {"text": "### Question:\nOpponent of chicago bulls had what location?\n\n### SQL:\nSELECT Location FROM table WHERE Opponent = chicago bulls", "question": "Opponent of chicago bulls had what location?", "sql": "SELECT Location FROM table WHERE Opponent = chicago bulls", "source": "test"}
39
+ {"text": "### Question:\nWhat religious groups made up 0.72% of the Indian population in 2001?\n\n### SQL:\nSELECT Religious group FROM table WHERE Population % 2001 = 0.72%", "question": "What religious groups made up 0.72% of the Indian population in 2001?", "sql": "SELECT Religious group FROM table WHERE Population % 2001 = 0.72%", "source": "test"}
40
+ {"text": "### Question:\nOn which date was the high assists Delonte West Earl Watson (6)?\n\n### SQL:\nSELECT Date FROM table WHERE High assists = delonte west earl watson (6)", "question": "On which date was the high assists Delonte West Earl Watson (6)?", "sql": "SELECT Date FROM table WHERE High assists = delonte west earl watson (6)", "source": "test"}
41
+ {"text": "### Question:\nWhat is the enrollment for the institution in Westfield, Massachusetts? \n\n### SQL:\nSELECT Enrollment FROM table WHERE Location = Westfield, Massachusetts", "question": "What is the enrollment for the institution in Westfield, Massachusetts? ", "sql": "SELECT Enrollment FROM table WHERE Location = Westfield, Massachusetts", "source": "test"}
42
+ {"text": "### Question:\nWhat is the Studio of the Film with a Gross rental of $7,500,000?\n\n### SQL:\nSELECT Studio FROM table WHERE Gross rental = $7,500,000", "question": "What is the Studio of the Film with a Gross rental of $7,500,000?", "sql": "SELECT Studio FROM table WHERE Gross rental = $7,500,000", "source": "test"}
43
+ {"text": "### Question:\nWhat was the Outcome of the match played on Hard (i) Surface?\n\n### SQL:\nSELECT Outcome FROM table WHERE Surface = hard (i)", "question": "What was the Outcome of the match played on Hard (i) Surface?", "sql": "SELECT Outcome FROM table WHERE Surface = hard (i)", "source": "test"}
44
+ {"text": "### Question:\nWhat is the highest rank of the player who played 30 events and made less than $2,708,005?\n\n### SQL:\nSELECT MAX Rank FROM table WHERE Earnings ( $ ) < 2,708,005 AND Events = 30", "question": "What is the highest rank of the player who played 30 events and made less than $2,708,005?", "sql": "SELECT MAX Rank FROM table WHERE Earnings ( $ ) < 2,708,005 AND Events = 30", "source": "test"}
45
+ {"text": "### Question:\nHow many games had Montreal Canadiens as an opponent?\n\n### SQL:\nSELECT SUM Game FROM table WHERE Opponent = montreal canadiens", "question": "How many games had Montreal Canadiens as an opponent?", "sql": "SELECT SUM Game FROM table WHERE Opponent = montreal canadiens", "source": "test"}
46
+ {"text": "### Question:\nWhat is the length of the highway with the route name sh 2?\n\n### SQL:\nSELECT Length FROM table WHERE Route Name = sh 2", "question": "What is the length of the highway with the route name sh 2?", "sql": "SELECT Length FROM table WHERE Route Name = sh 2", "source": "test"}
47
+ {"text": "### Question:\nCan you tell me total number of Silver that has the Republic of latvian ssr, and the Total larger than 6?\n\n### SQL:\nSELECT COUNT Silver FROM table WHERE Republic = latvian ssr AND Total > 6", "question": "Can you tell me total number of Silver that has the Republic of latvian ssr, and the Total larger than 6?", "sql": "SELECT COUNT Silver FROM table WHERE Republic = latvian ssr AND Total > 6", "source": "test"}
48
+ {"text": "### Question:\nWhat date did they play the Florida Panthers?\n\n### SQL:\nSELECT Date FROM table WHERE Opponent = Florida Panthers", "question": "What date did they play the Florida Panthers?", "sql": "SELECT Date FROM table WHERE Opponent = Florida Panthers", "source": "test"}
49
+ {"text": "### Question:\nwhich college has a player called Riley Clayton?\n\n### SQL:\nSELECT College FROM table WHERE Player = riley clayton", "question": "which college has a player called Riley Clayton?", "sql": "SELECT College FROM table WHERE Player = riley clayton", "source": "test"}
50
+ {"text": "### Question:\nWhat is the most minimal Final year that has a North or east end of covington?\n\n### SQL:\nSELECT MIN Final year FROM table WHERE North or east terminus = covington", "question": "What is the most minimal Final year that has a North or east end of covington?", "sql": "SELECT MIN Final year FROM table WHERE North or east terminus = covington", "source": "test"}
51
+ {"text": "### Question:\nWho directed the episode whose production code is pabf05?\n\n### SQL:\nSELECT Directed by FROM table WHERE Production code = PABF05", "question": "Who directed the episode whose production code is pabf05?", "sql": "SELECT Directed by FROM table WHERE Production code = PABF05", "source": "test"}
52
+ {"text": "### Question:\nWhat is the best fit (all data) when the best fit (WMAP, extra parameter) shows \u2014?\n\n### SQL:\nSELECT Best fit (all data) FROM table WHERE Best fit (WMAP, extra parameter) = \u2014", "question": "What is the best fit (all data) when the best fit (WMAP, extra parameter) shows \u2014?", "sql": "SELECT Best fit (all data) FROM table WHERE Best fit (WMAP, extra parameter) = \u2014", "source": "test"}
53
+ {"text": "### Question:\nCan you tell me the lowest Week that has the Attendance smaller than 34,336?\n\n### SQL:\nSELECT MIN Week FROM table WHERE Attendance < 34,336", "question": "Can you tell me the lowest Week that has the Attendance smaller than 34,336?", "sql": "SELECT MIN Week FROM table WHERE Attendance < 34,336", "source": "test"}
54
+ {"text": "### Question:\nWhat is the Lead in the 2004-05 Season?\n\n### SQL:\nSELECT Lead FROM table WHERE Season = 2004-05", "question": "What is the Lead in the 2004-05 Season?", "sql": "SELECT Lead FROM table WHERE Season = 2004-05", "source": "test"}
55
+ {"text": "### Question:\nWhat is the result for week 12 against the Green Bay Packers?\n\n### SQL:\nSELECT Result FROM table WHERE Week > 12 AND Opponent = green bay packers", "question": "What is the result for week 12 against the Green Bay Packers?", "sql": "SELECT Result FROM table WHERE Week > 12 AND Opponent = green bay packers", "source": "test"}
56
+ {"text": "### Question:\nWhat was the first season for the club that in 2012 was 2nd in Superettan?\n\n### SQL:\nSELECT First season FROM table WHERE Position in 2012 = 2nd in Superettan", "question": "What was the first season for the club that in 2012 was 2nd in Superettan?", "sql": "SELECT First season FROM table WHERE Position in 2012 = 2nd in Superettan", "source": "test"}
57
+ {"text": "### Question:\nWhat was the extra info for the Commonwealth Games?\n\n### SQL:\nSELECT Extra FROM table WHERE Tournament = commonwealth games", "question": "What was the extra info for the Commonwealth Games?", "sql": "SELECT Extra FROM table WHERE Tournament = commonwealth games", "source": "test"}
58
+ {"text": "### Question:\nWhich player played for the Grizzlies from 1997-1998?\n\n### SQL:\nSELECT Player FROM table WHERE Years for Grizzlies = 1997-1998", "question": "Which player played for the Grizzlies from 1997-1998?", "sql": "SELECT Player FROM table WHERE Years for Grizzlies = 1997-1998", "source": "test"}
59
+ {"text": "### Question:\nNone of the communities listed has a percentage smaller than 8.6 in 2006.\n\n### SQL:\nSELECT COUNT Seats 2001 FROM table WHERE % 2006 < 8.6", "question": "None of the communities listed has a percentage smaller than 8.6 in 2006.", "sql": "SELECT COUNT Seats 2001 FROM table WHERE % 2006 < 8.6", "source": "test"}
60
+ {"text": "### Question:\nWhat is the basketball status for Valparaiso who has an indoor track status of yes?\n\n### SQL:\nSELECT Bask FROM table WHERE Indoor track = yes AND School = valparaiso", "question": "What is the basketball status for Valparaiso who has an indoor track status of yes?", "sql": "SELECT Bask FROM table WHERE Indoor track = yes AND School = valparaiso", "source": "test"}
61
+ {"text": "### Question:\nWhat 1979 Hindi film had Ravindra Jain directing music?\n\n### SQL:\nSELECT Film name FROM table WHERE Language = hindi AND Lyricist = ravindra jain AND Music director = ravindra jain AND Year = 1979", "question": "What 1979 Hindi film had Ravindra Jain directing music?", "sql": "SELECT Film name FROM table WHERE Language = hindi AND Lyricist = ravindra jain AND Music director = ravindra jain AND Year = 1979", "source": "test"}
62
+ {"text": "### Question:\nWhat was the winning score in the Alfred Dunhill links championship?\n\n### SQL:\nSELECT Winning score FROM table WHERE Tournament = alfred dunhill links championship", "question": "What was the winning score in the Alfred Dunhill links championship?", "sql": "SELECT Winning score FROM table WHERE Tournament = alfred dunhill links championship", "source": "test"}
63
+ {"text": "### Question:\nTell me the highest Grid for Maurice Trintignant and laps less than 87\n\n### SQL:\nSELECT MAX Grid FROM table WHERE Driver = maurice trintignant AND Laps < 87", "question": "Tell me the highest Grid for Maurice Trintignant and laps less than 87", "sql": "SELECT MAX Grid FROM table WHERE Driver = maurice trintignant AND Laps < 87", "source": "test"}
64
+ {"text": "### Question:\nWhich Play-Off has Events of \u2013, and a Season of a-6?\n\n### SQL:\nSELECT Play-Off FROM table WHERE Events = \u2013 AND Season = a-6", "question": "Which Play-Off has Events of \u2013, and a Season of a-6?", "sql": "SELECT Play-Off FROM table WHERE Events = \u2013 AND Season = a-6", "source": "test"}
65
+ {"text": "### Question:\nWho did team phoenix visit in their home?\n\n### SQL:\nSELECT Home FROM table WHERE Visitor = phoenix", "question": "Who did team phoenix visit in their home?", "sql": "SELECT Home FROM table WHERE Visitor = phoenix", "source": "test"}
66
+ {"text": "### Question:\nName the record with home of bucks on 24 november 2007\n\n### SQL:\nSELECT Record FROM table WHERE Home = bucks AND Date = 24 november 2007", "question": "Name the record with home of bucks on 24 november 2007", "sql": "SELECT Record FROM table WHERE Home = bucks AND Date = 24 november 2007", "source": "test"}
67
+ {"text": "### Question:\nWhat is Myron Walwyn with a Territorial at-large Constiuency's First Elected Date\n\n### SQL:\nSELECT First elected FROM table WHERE Constiuency = territorial at-large AND Name = myron walwyn", "question": "What is Myron Walwyn with a Territorial at-large Constiuency's First Elected Date", "sql": "SELECT First elected FROM table WHERE Constiuency = territorial at-large AND Name = myron walwyn", "source": "test"}
68
+ {"text": "### Question:\nWhat competition had a Rank-Qualifying of 1st and a ball apparatus?\n\n### SQL:\nSELECT Competition Description FROM table WHERE Rank-Qualifying = 1st AND Apparatus = ball", "question": "What competition had a Rank-Qualifying of 1st and a ball apparatus?", "sql": "SELECT Competition Description FROM table WHERE Rank-Qualifying = 1st AND Apparatus = ball", "source": "test"}
69
+ {"text": "### Question:\nWhat is the average area larger than Code 19025 but a smaller region than 12?\n\n### SQL:\nSELECT AVG Area (km 2 ) FROM table WHERE Code > 19025 AND Region < 12", "question": "What is the average area larger than Code 19025 but a smaller region than 12?", "sql": "SELECT AVG Area (km 2 ) FROM table WHERE Code > 19025 AND Region < 12", "source": "test"}
70
+ {"text": "### Question:\nWhat is the lowest number of laps for Marco Simoncelli on a grid higher than 11?\n\n### SQL:\nSELECT MIN Laps FROM table WHERE Rider = marco simoncelli AND Grid > 11", "question": "What is the lowest number of laps for Marco Simoncelli on a grid higher than 11?", "sql": "SELECT MIN Laps FROM table WHERE Rider = marco simoncelli AND Grid > 11", "source": "test"}
71
+ {"text": "### Question:\nWhat is the highest Year, when Apparatus is \"Vault\", and when Rank-Final is less than 9?\n\n### SQL:\nSELECT MAX Year FROM table WHERE Apparatus = vault AND Rank-Final < 9", "question": "What is the highest Year, when Apparatus is \"Vault\", and when Rank-Final is less than 9?", "sql": "SELECT MAX Year FROM table WHERE Apparatus = vault AND Rank-Final < 9", "source": "test"}
72
+ {"text": "### Question:\nHow many poles has a percentage of 22.08%?\n\n### SQL:\nSELECT SUM Poles FROM table WHERE Percentage = 22.08%", "question": "How many poles has a percentage of 22.08%?", "sql": "SELECT SUM Poles FROM table WHERE Percentage = 22.08%", "source": "test"}
73
+ {"text": "### Question:\nName the number of score for sacramento\n\n### SQL:\nSELECT COUNT Score FROM table WHERE Team = Sacramento", "question": "Name the number of score for sacramento", "sql": "SELECT COUNT Score FROM table WHERE Team = Sacramento", "source": "test"}
74
+ {"text": "### Question:\nWhat was the nationality of the player with a score of 72-72-67=211?\n\n### SQL:\nSELECT Country FROM table WHERE Score = 72-72-67=211", "question": "What was the nationality of the player with a score of 72-72-67=211?", "sql": "SELECT Country FROM table WHERE Score = 72-72-67=211", "source": "test"}
75
+ {"text": "### Question:\nWhat was the venue that had 5000 m after 2009?\n\n### SQL:\nSELECT Venue FROM table WHERE Year > 2009 AND Notes = 5000 m", "question": "What was the venue that had 5000 m after 2009?", "sql": "SELECT Venue FROM table WHERE Year > 2009 AND Notes = 5000 m", "source": "test"}
76
+ {"text": "### Question:\nwhat is the sspec number when the part number is cw8064701470802?\n\n### SQL:\nSELECT sSpec number FROM table WHERE Part number(s) = cw8064701470802", "question": "what is the sspec number when the part number is cw8064701470802?", "sql": "SELECT sSpec number FROM table WHERE Part number(s) = cw8064701470802", "source": "test"}
77
+ {"text": "### Question:\nWhich Base has a Name of .44 wcf?\n\n### SQL:\nSELECT Base FROM table WHERE Name = .44 wcf", "question": "Which Base has a Name of .44 wcf?", "sql": "SELECT Base FROM table WHERE Name = .44 wcf", "source": "test"}
78
+ {"text": "### Question:\nWhat is the value of the runner up column for the Alberta province?\n\n### SQL:\nSELECT MAX Runner Up FROM table WHERE Province = Alberta", "question": "What is the value of the runner up column for the Alberta province?", "sql": "SELECT MAX Runner Up FROM table WHERE Province = Alberta", "source": "test"}
79
+ {"text": "### Question:\nWhat is the Athlete of the race with a Time of 9.78?\n\n### SQL:\nSELECT Athlete FROM table WHERE Time = 9.78", "question": "What is the Athlete of the race with a Time of 9.78?", "sql": "SELECT Athlete FROM table WHERE Time = 9.78", "source": "test"}
80
+ {"text": "### Question:\nWhat was the rating of the episode \"After Hours\"?\n\n### SQL:\nSELECT Rating (Millions) FROM table WHERE Title = \"after hours\"", "question": "What was the rating of the episode \"After Hours\"?", "sql": "SELECT Rating (Millions) FROM table WHERE Title = \"after hours\"", "source": "test"}
81
+ {"text": "### Question:\nWhat is the number for the forward position from the school/club team La Salle?\n\n### SQL:\nSELECT Number FROM table WHERE Position = forward AND School/Club Team = la salle", "question": "What is the number for the forward position from the school/club team La Salle?", "sql": "SELECT Number FROM table WHERE Position = forward AND School/Club Team = la salle", "source": "test"}
82
+ {"text": "### Question:\nwhat is the highest wickets when the best bowling is 2/32 and matches is less than 5?\n\n### SQL:\nSELECT MAX Wickets FROM table WHERE Best Bowling = 2/32 AND Matches < 5", "question": "what is the highest wickets when the best bowling is 2/32 and matches is less than 5?", "sql": "SELECT MAX Wickets FROM table WHERE Best Bowling = 2/32 AND Matches < 5", "source": "test"}
83
+ {"text": "### Question:\nWhich show runs on Friday at 05:00 AM?\n\n### SQL:\nSELECT 05:00 AM FROM table WHERE Time = friday", "question": "Which show runs on Friday at 05:00 AM?", "sql": "SELECT 05:00 AM FROM table WHERE Time = friday", "source": "test"}
84
+ {"text": "### Question:\nHow many players have the hometown Pennsauken, NJ?\n\n### SQL:\nSELECT COUNT Player FROM table WHERE Hometown = Pennsauken, NJ", "question": "How many players have the hometown Pennsauken, NJ?", "sql": "SELECT COUNT Player FROM table WHERE Hometown = Pennsauken, NJ", "source": "test"}
85
+ {"text": "### Question:\nHow many losses does Cross Keys RFC have?\n\n### SQL:\nSELECT COUNT Lost FROM table WHERE Club = Cross Keys RFC", "question": "How many losses does Cross Keys RFC have?", "sql": "SELECT COUNT Lost FROM table WHERE Club = Cross Keys RFC", "source": "test"}
86
+ {"text": "### Question:\nWhat was the time when the method was TKO?\n\n### SQL:\nSELECT Time FROM table WHERE Method = tko", "question": "What was the time when the method was TKO?", "sql": "SELECT Time FROM table WHERE Method = tko", "source": "test"}
87
+ {"text": "### Question:\nWhat was the type of sussex?\n\n### SQL:\nSELECT Type FROM table WHERE Name = Sussex", "question": "What was the type of sussex?", "sql": "SELECT Type FROM table WHERE Name = Sussex", "source": "test"}
88
+ {"text": "### Question:\nWhat Title has a Role of Mylene?\n\n### SQL:\nSELECT Title FROM table WHERE Role = mylene", "question": "What Title has a Role of Mylene?", "sql": "SELECT Title FROM table WHERE Role = mylene", "source": "test"}
89
+ {"text": "### Question:\nWho is the captain of Neil Warnock's team?\n\n### SQL:\nSELECT Team captain FROM table WHERE Manager = Neil Warnock", "question": "Who is the captain of Neil Warnock's team?", "sql": "SELECT Team captain FROM table WHERE Manager = Neil Warnock", "source": "test"}
90
+ {"text": "### Question:\nWhat is the winning % for the 2010 QF?\n\n### SQL:\nSELECT Win % FROM table WHERE 2010 = qf", "question": "What is the winning % for the 2010 QF?", "sql": "SELECT Win % FROM table WHERE 2010 = qf", "source": "test"}
91
+ {"text": "### Question:\nWhat's the court ranking of 5th son of tadayori and has revenues of 10,000 koku?\n\n### SQL:\nSELECT Court Rank FROM table WHERE Revenues = 10,000 koku AND Lineage = 5th son of tadayori", "question": "What's the court ranking of 5th son of tadayori and has revenues of 10,000 koku?", "sql": "SELECT Court Rank FROM table WHERE Revenues = 10,000 koku AND Lineage = 5th son of tadayori", "source": "test"}
92
+ {"text": "### Question:\nWhich race was on the Las Vegas Motor Speedway for 2 hours?\n\n### SQL:\nSELECT Race FROM table WHERE Circuit = las vegas motor speedway AND Length = 2 hours", "question": "Which race was on the Las Vegas Motor Speedway for 2 hours?", "sql": "SELECT Race FROM table WHERE Circuit = las vegas motor speedway AND Length = 2 hours", "source": "test"}
93
+ {"text": "### Question:\nWhat venue hsoted the european cross country championships with a notes of junior men individual 6.595km?\n\n### SQL:\nSELECT Venue FROM table WHERE Competition = european cross country championships AND Notes = junior men individual 6.595km", "question": "What venue hsoted the european cross country championships with a notes of junior men individual 6.595km?", "sql": "SELECT Venue FROM table WHERE Competition = european cross country championships AND Notes = junior men individual 6.595km", "source": "test"}
94
+ {"text": "### Question:\nWhat was the Opponent in Week 9?\n\n### SQL:\nSELECT Opponent FROM table WHERE Week = 9", "question": "What was the Opponent in Week 9?", "sql": "SELECT Opponent FROM table WHERE Week = 9", "source": "test"}
95
+ {"text": "### Question:\nBefore round 7, what is the greatest Pick # for a player that plays defensive tackle?\n\n### SQL:\nSELECT MAX Pick # FROM table WHERE Position = defensive tackle AND Round < 7", "question": "Before round 7, what is the greatest Pick # for a player that plays defensive tackle?", "sql": "SELECT MAX Pick # FROM table WHERE Position = defensive tackle AND Round < 7", "source": "test"}
96
+ {"text": "### Question:\nWhat was the highest Pick for Lonnie Brockman before round 9?\n\n### SQL:\nSELECT MAX Pick FROM table WHERE Player = lonnie brockman AND Round < 9", "question": "What was the highest Pick for Lonnie Brockman before round 9?", "sql": "SELECT MAX Pick FROM table WHERE Player = lonnie brockman AND Round < 9", "source": "test"}
97
+ {"text": "### Question:\nWhen was brian lemay born?\n\n### SQL:\nSELECT Date of Birth (Age) FROM table WHERE Player = brian lemay", "question": "When was brian lemay born?", "sql": "SELECT Date of Birth (Age) FROM table WHERE Player = brian lemay", "source": "test"}
98
+ {"text": "### Question:\nWhat was the time of the NJKF Titans Neo X event?\n\n### SQL:\nSELECT Time FROM table WHERE Event = njkf titans neo x", "question": "What was the time of the NJKF Titans Neo X event?", "sql": "SELECT Time FROM table WHERE Event = njkf titans neo x", "source": "test"}
99
+ {"text": "### Question:\nWhich kit maker have Trond Sollied as a manager?\n\n### SQL:\nSELECT Kit maker FROM table WHERE Manager = trond sollied", "question": "Which kit maker have Trond Sollied as a manager?", "sql": "SELECT Kit maker FROM table WHERE Manager = trond sollied", "source": "test"}
100
+ {"text": "### Question:\nWhich Locomotive Entered Service in November 1984 and has an Operator of Southern Shorthaul Railroad?\n\n### SQL:\nSELECT Locomotive FROM table WHERE Operator = southern shorthaul railroad AND Entered service = november 1984", "question": "Which Locomotive Entered Service in November 1984 and has an Operator of Southern Shorthaul Railroad?", "sql": "SELECT Locomotive FROM table WHERE Operator = southern shorthaul railroad AND Entered service = november 1984", "source": "test"}
src/outputs/finetuning/train.jsonl ADDED
@@ -0,0 +1,100 @@
1
+ {"text": "### Question:\nWhat was the year when Ch\u016bnqi\u016b Ch\u00e1sh\u00ec (\u6625\u79cb\u8336\u5ba4) was submitted?\n\n### SQL:\nSELECT Year (Ceremony) FROM table WHERE Original title = Ch\u016bnqi\u016b ch\u00e1sh\u00ec (\u6625\u79cb\u8336\u5ba4)", "question": "What was the year when Ch\u016bnqi\u016b Ch\u00e1sh\u00ec (\u6625\u79cb\u8336\u5ba4) was submitted?", "sql": "SELECT Year (Ceremony) FROM table WHERE Original title = Ch\u016bnqi\u016b ch\u00e1sh\u00ec (\u6625\u79cb\u8336\u5ba4)", "source": "train"}
2
+ {"text": "### Question:\nWhat is Name, when Builder is \"Kerr Stuart\"?\n\n### SQL:\nSELECT Name FROM table WHERE Builder = kerr stuart", "question": "What is Name, when Builder is \"Kerr Stuart\"?", "sql": "SELECT Name FROM table WHERE Builder = kerr stuart", "source": "train"}
3
+ {"text": "### Question:\nWhat series had more than 10 Podiums?\n\n### SQL:\nSELECT Series FROM table WHERE Podiums > 10", "question": "What series had more than 10 Podiums?", "sql": "SELECT Series FROM table WHERE Podiums > 10", "source": "train"}
4
+ {"text": "### Question:\nFor what ceremony was \"Fire Dancer\" not nominated? \n\n### SQL:\nSELECT Year (Ceremony) FROM table WHERE Original title = Fire Dancer", "question": "For what ceremony was \"Fire Dancer\" not nominated? ", "sql": "SELECT Year (Ceremony) FROM table WHERE Original title = Fire Dancer", "source": "train"}
5
+ {"text": "### Question:\nHow many people went to the game with Indiana visiting?\n\n### SQL:\nSELECT Attendance FROM table WHERE Visitor = indiana", "question": "How many people went to the game with Indiana visiting?", "sql": "SELECT Attendance FROM table WHERE Visitor = indiana", "source": "train"}
6
+ {"text": "### Question:\nWith a swim (1.5km) of 18:55 and a run (10km) of 32:37, what is the trans 2?\n\n### SQL:\nSELECT Trans 2 FROM table WHERE Swim (1.5km) = 18:55 AND Run (10km) = 32:37", "question": "With a swim (1.5km) of 18:55 and a run (10km) of 32:37, what is the trans 2?", "sql": "SELECT Trans 2 FROM table WHERE Swim (1.5km) = 18:55 AND Run (10km) = 32:37", "source": "train"}
7
+ {"text": "### Question:\nWhat is the region for Chep\u00e9n with 3 districts?\n\n### SQL:\nSELECT Region FROM table WHERE Districts = 3 AND Capital = chep\u00e9n", "question": "What is the region for Chep\u00e9n with 3 districts?", "sql": "SELECT Region FROM table WHERE Districts = 3 AND Capital = chep\u00e9n", "source": "train"}
8
+ {"text": "### Question:\nWhat was the third place of the performance in 2006 with the host Japan?\n\n### SQL:\nSELECT Third place FROM table WHERE Host = japan AND Season < 2006", "question": "What was the third place of the performance in 2006 with the host Japan?", "sql": "SELECT Third place FROM table WHERE Host = japan AND Season < 2006", "source": "train"}
9
+ {"text": "### Question:\nWhat is the 1999-2000 team, when the Height (cm) is less than 187, and when the Birthplace is Cloquet, Minnesota?\n\n### SQL:\nSELECT 1999-2000 team FROM table WHERE Height (cm) < 187 AND Birthplace = cloquet, minnesota", "question": "What is the 1999-2000 team, when the Height (cm) is less than 187, and when the Birthplace is Cloquet, Minnesota?", "sql": "SELECT 1999-2000 team FROM table WHERE Height (cm) < 187 AND Birthplace = cloquet, minnesota", "source": "train"}
10
+ {"text": "### Question:\nWhat is the highest Tournaments, when Pro Debut is \"July 2002\"?\n\n### SQL:\nSELECT MAX Tournaments FROM table WHERE Pro Debut = july 2002", "question": "What is the highest Tournaments, when Pro Debut is \"July 2002\"?", "sql": "SELECT MAX Tournaments FROM table WHERE Pro Debut = july 2002", "source": "train"}
11
+ {"text": "### Question:\nwhat is the total time in office when the assumed office is 1 november 1856?\n\n### SQL:\nSELECT TOTAL Time in Office: FROM table WHERE Assumed Office: = 1 november 1856", "question": "what is the total time in office when the assumed office is 1 november 1856?", "sql": "SELECT TOTAL Time in Office: FROM table WHERE Assumed Office: = 1 november 1856", "source": "train"}
12
+ {"text": "### Question:\nWhen did the term end for the term that had government 27 and Minister Tzachi Hanegbi?\n\n### SQL:\nSELECT Term end FROM table WHERE Governments = 27 AND Minister = tzachi hanegbi", "question": "When did the term end for the term that had government 27 and Minister Tzachi Hanegbi?", "sql": "SELECT Term end FROM table WHERE Governments = 27 AND Minister = tzachi hanegbi", "source": "train"}
13
+ {"text": "### Question:\nWhat is the Outcome of the Doubles played on Carpet?\n\n### SQL:\nSELECT Outcome FROM table WHERE Surface = carpet", "question": "What is the Outcome of the Doubles played on Carpet?", "sql": "SELECT Outcome FROM table WHERE Surface = carpet", "source": "train"}
14
+ {"text": "### Question:\nName the venue for geelong away team\n\n### SQL:\nSELECT Venue FROM table WHERE Away team = geelong", "question": "Name the venue for geelong away team", "sql": "SELECT Venue FROM table WHERE Away team = geelong", "source": "train"}
15
+ {"text": "### Question:\nWhat was the score for the final played on 2 July 2012?\n\n### SQL:\nSELECT Score FROM table WHERE Date = 2 july 2012", "question": "What was the score for the final played on 2 July 2012?", "sql": "SELECT Score FROM table WHERE Date = 2 july 2012", "source": "train"}
16
+ {"text": "### Question:\nBetween November 25\u201330, 2008 the sellout rate was at 75%, indicating that the ration between shows to sellout was what?\n\n### SQL:\nSELECT Shows / Sellout FROM table WHERE Sellout (%) = 75%", "question": "Between November 25\u201330, 2008 the sellout rate was at 75%, indicating that the ration between shows to sellout was what?", "sql": "SELECT Shows / Sellout FROM table WHERE Sellout (%) = 75%", "source": "train"}
17
+ {"text": "### Question:\nWhat is the total ranking when there are less than 16 draws, less than 1 point, and the English translation is in love with you?\n\n### SQL:\nSELECT SUM Place FROM table WHERE Draw < 16 AND English translation = in love with you AND Points < 1", "question": "What is the total ranking when there are less than 16 draws, less than 1 point, and the English translation is in love with you?", "sql": "SELECT SUM Place FROM table WHERE Draw < 16 AND English translation = in love with you AND Points < 1", "source": "train"}
18
+ {"text": "### Question:\nWhich Ratio as % has a Ratio of 8/9?\n\n### SQL:\nSELECT Ratio as % FROM table WHERE Ratio = 8/9", "question": "Which Ratio as % has a Ratio of 8/9?", "sql": "SELECT Ratio as % FROM table WHERE Ratio = 8/9", "source": "train"}
19
+ {"text": "### Question:\nWhat is the Early Modern English phonology used in the example b\u014dg > \"bough\"; pl\u014dg > pl\u014dh > \"plough\"?\n\n### SQL:\nSELECT Early Modern English FROM table WHERE Example = b\u014dg > \"bough\"; pl\u014dg > pl\u014dh > \"plough\"", "question": "What is the Early Modern English phonology used in the example b\u014dg > \"bough\"; pl\u014dg > pl\u014dh > \"plough\"?", "sql": "SELECT Early Modern English FROM table WHERE Example = b\u014dg > \"bough\"; pl\u014dg > pl\u014dh > \"plough\"", "source": "train"}
20
+ {"text": "### Question:\nWhat is the greatest Wins with Matches smaller than 5, and a Year of 1994?\n\n### SQL:\nSELECT MAX Wins FROM table WHERE Matches < 5 AND Year = 1994", "question": "What is the greatest Wins with Matches smaller than 5, and a Year of 1994?", "sql": "SELECT MAX Wins FROM table WHERE Matches < 5 AND Year = 1994", "source": "train"}
21
+ {"text": "### Question:\nWhat ranking is the Battersea Power Station?\n\n### SQL:\nSELECT Rank FROM table WHERE Name = battersea power station", "question": "What ranking is the Battersea Power Station?", "sql": "SELECT Rank FROM table WHERE Name = battersea power station", "source": "train"}
22
+ {"text": "### Question:\nWhat pick number was the player that was picked by Edmonton?\n\n### SQL:\nSELECT Pick # FROM table WHERE CFL Team = Edmonton", "question": "What pick number was the player that was picked by Edmonton?", "sql": "SELECT Pick # FROM table WHERE CFL Team = Edmonton", "source": "train"}
23
+ {"text": "### Question:\nWhat was the location of the fight when Gassaway fought kevin knabjian?\n\n### SQL:\nSELECT Location FROM table WHERE Opponent = kevin knabjian", "question": "What was the location of the fight when Gassaway fought kevin knabjian?", "sql": "SELECT Location FROM table WHERE Opponent = kevin knabjian", "source": "train"}
24
+ {"text": "### Question:\nHow many points did the song \"Stille f\u00f8r stormen\" get?\n\n### SQL:\nSELECT MIN Total Points FROM table WHERE Song = \"Stille f\u00f8r stormen\"", "question": "How many points did the song \"Stille f\u00f8r stormen\" get?", "sql": "SELECT MIN Total Points FROM table WHERE Song = \"Stille f\u00f8r stormen\"", "source": "train"}
25
+ {"text": "### Question:\nWht years did truck robinson play?\n\n### SQL:\nSELECT Years for Jazz FROM table WHERE Player = Truck Robinson", "question": "Wht years did truck robinson play?", "sql": "SELECT Years for Jazz FROM table WHERE Player = Truck Robinson", "source": "train"}
26
+ {"text": "### Question:\nWhich Score has a Couple of cristi\u00e1n & cheryl, and a Style of cha-cha-cha?\n\n### SQL:\nSELECT Score FROM table WHERE Couple = cristi\u00e1n & cheryl AND Style = cha-cha-cha", "question": "Which Score has a Couple of cristi\u00e1n & cheryl, and a Style of cha-cha-cha?", "sql": "SELECT Score FROM table WHERE Couple = cristi\u00e1n & cheryl AND Style = cha-cha-cha", "source": "train"}
27
+ {"text": "### Question:\nWhat is the IUPAC name for chloroform?\n\n### SQL:\nSELECT IUPAC name FROM table WHERE Common name = chloroform", "question": "What is the IUPAC name for chloroform?", "sql": "SELECT IUPAC name FROM table WHERE Common name = chloroform", "source": "train"}
28
+ {"text": "### Question:\nWhat is the Constituency Number when the Number of Electorates (2003) is more than 156,910, and Reserved for sc?\n\n### SQL:\nSELECT Constituency number FROM table WHERE Number of electorates (2003) > 156,910 AND Reserved for ( SC / ST /None) = sc", "question": "What is the Constituency Number when the Number of Electorates (2003) is more than 156,910, and Reserved for sc?", "sql": "SELECT Constituency number FROM table WHERE Number of electorates (2003) > 156,910 AND Reserved for ( SC / ST /None) = sc", "source": "train"}
29
+ {"text": "### Question:\nWhat was John Jones's pick#?\n\n### SQL:\nSELECT SUM Pick FROM table WHERE Player = john jones", "question": "What was John Jones's pick#?", "sql": "SELECT SUM Pick FROM table WHERE Player = john jones", "source": "train"}
30
+ {"text": "### Question:\nWhich runner(s)-up had a Winning score of \u201313 (68-70-66-71=275) and a Margin of victory of 3 strokes?\n\n### SQL:\nSELECT Runner(s)-up FROM table WHERE Margin of victory = 3 strokes AND Winning score = \u201313 (68-70-66-71=275)", "question": "Which runner(s)-up had a Winning score of \u201313 (68-70-66-71=275) and a Margin of victory of 3 strokes?", "sql": "SELECT Runner(s)-up FROM table WHERE Margin of victory = 3 strokes AND Winning score = \u201313 (68-70-66-71=275)", "source": "train"}
31
+ {"text": "### Question:\nWhat is the number of people in attendance when Tonbridge Angels is the opponent?\n\n### SQL:\nSELECT Attendance FROM table WHERE Opponent = tonbridge angels", "question": "What is the number of people in attendance when Tonbridge Angels is the opponent?", "sql": "SELECT Attendance FROM table WHERE Opponent = tonbridge angels", "source": "train"}
32
+ {"text": "### Question:\nWhich letter has the British a\u026a?\n\n### SQL:\nSELECT Letter FROM table WHERE British = a\u026a", "question": "Which letter has the British a\u026a?", "sql": "SELECT Letter FROM table WHERE British = a\u026a", "source": "train"}
33
+ {"text": "### Question:\nIn which city was the berlin marathon?\n\n### SQL:\nSELECT Location FROM table WHERE Road race = Berlin Marathon", "question": "In which city was the berlin marathon?", "sql": "SELECT Location FROM table WHERE Road race = Berlin Marathon", "source": "train"}
34
+ {"text": "### Question:\nName the year 2007 for 668 2008-q1\n\n### SQL:\nSELECT year 2007 FROM table WHERE 2008 - Q1 = 668", "question": "Name the year 2007 for 668 2008-q1", "sql": "SELECT year 2007 FROM table WHERE 2008 - Q1 = 668", "source": "train"}
35
+ {"text": "### Question:\nWhat was Collingwood's score when they played against North Melbourne at home?\n\n### SQL:\nSELECT Home team score FROM table WHERE Away team = north melbourne", "question": "What was Collingwood's score when they played against North Melbourne at home?", "sql": "SELECT Home team score FROM table WHERE Away team = north melbourne", "source": "train"}
36
+ {"text": "### Question:\nWhich division were the Brewers a part of in the 1987 season?\n\n### SQL:\nSELECT Division FROM table WHERE Team season = 1987", "question": "Which division were the Brewers a part of in the 1987 season?", "sql": "SELECT Division FROM table WHERE Team season = 1987", "source": "train"}
37
+ {"text": "### Question:\nWhich TV Station has a Romaji Title of kegareta shita?\n\n### SQL:\nSELECT TV Station FROM table WHERE Romaji Title = kegareta shita", "question": "Which TV Station has a Romaji Title of kegareta shita?", "sql": "SELECT TV Station FROM table WHERE Romaji Title = kegareta shita", "source": "train"}
38
+ {"text": "### Question:\nTell me the notes with method of points and event of adcc 2001 absolute with result of loss\n\n### SQL:\nSELECT Notes FROM table WHERE Method = points AND Event = adcc 2001 absolute AND Result = loss", "question": "Tell me the notes with method of points and event of adcc 2001 absolute with result of loss", "sql": "SELECT Notes FROM table WHERE Method = points AND Event = adcc 2001 absolute AND Result = loss", "source": "train"}
39
+ {"text": "### Question:\nWhat is the highest value for SF round for the country of England?\n\n### SQL:\nSELECT MAX SF Round FROM table WHERE Country = England", "question": "What is the highest value for SF round for the country of England?", "sql": "SELECT MAX SF Round FROM table WHERE Country = England", "source": "train"}
40
+ {"text": "### Question:\nWhat country id Bob Rosburg from?\n\n### SQL:\nSELECT Country FROM table WHERE Player = bob rosburg", "question": "What country id Bob Rosburg from?", "sql": "SELECT Country FROM table WHERE Player = bob rosburg", "source": "train"}
41
+ {"text": "### Question:\nWho is the Alternate for Sweden?\n\n### SQL:\nSELECT Alternate FROM table WHERE Nation = sweden", "question": "Who is the Alternate for Sweden?", "sql": "SELECT Alternate FROM table WHERE Nation = sweden", "source": "train"}
42
+ {"text": "### Question:\nWhich Format has a Frequency of 100.5 fm?\n\n### SQL:\nSELECT Format FROM table WHERE Frequency = 100.5 fm", "question": "Which Format has a Frequency of 100.5 fm?", "sql": "SELECT Format FROM table WHERE Frequency = 100.5 fm", "source": "train"}
43
+ {"text": "### Question:\nWhat country is Lee Janzen from?\n\n### SQL:\nSELECT Country FROM table WHERE Player = lee janzen", "question": "What country is Lee Janzen from?", "sql": "SELECT Country FROM table WHERE Player = lee janzen", "source": "train"}
44
+ {"text": "### Question:\nHead Coach casemiro mior is at which Club?\n\n### SQL:\nSELECT Club FROM table WHERE Head Coach = casemiro mior", "question": "Head Coach casemiro mior is at which Club?", "sql": "SELECT Club FROM table WHERE Head Coach = casemiro mior", "source": "train"}
45
+ {"text": "### Question:\nOn which date was the opponent the Chicago Bears?\n\n### SQL:\nSELECT Date FROM table WHERE Opponent = chicago bears", "question": "On which date was the opponent the Chicago Bears?", "sql": "SELECT Date FROM table WHERE Opponent = chicago bears", "source": "train"}
46
+ {"text": "### Question:\nWhich languages are offered in the coverage area of klang petaling jaya shah alam?\n\n### SQL:\nSELECT Language FROM table WHERE Coverage Area = Klang Petaling Jaya Shah Alam", "question": "Which languages are offered in the coverage area of klang petaling jaya shah alam?", "sql": "SELECT Language FROM table WHERE Coverage Area = Klang Petaling Jaya Shah Alam", "source": "train"}
47
+ {"text": "### Question:\nWhat is the to par for Jiyai Shin when the place is t1?\n\n### SQL:\nSELECT To par FROM table WHERE Place = t1 AND Player = jiyai shin", "question": "What is the to par for Jiyai Shin when the place is t1?", "sql": "SELECT To par FROM table WHERE Place = t1 AND Player = jiyai shin", "source": "train"}
48
+ {"text": "### Question:\nWhat is the total number of Division(s), when Team is Chongqing Lifan, and when Apps is greater than 9?\n\n### SQL:\nSELECT COUNT Division FROM table WHERE Team = chongqing lifan AND Apps > 9", "question": "What is the total number of Division(s), when Team is Chongqing Lifan, and when Apps is greater than 9?", "sql": "SELECT COUNT Division FROM table WHERE Team = chongqing lifan AND Apps > 9", "source": "train"}
49
+ {"text": "### Question:\nWhat is Mike Weir's To par?\n\n### SQL:\nSELECT To par FROM table WHERE Player = mike weir", "question": "What is Mike Weir's To par?", "sql": "SELECT To par FROM table WHERE Player = mike weir", "source": "train"}
50
+ {"text": "### Question:\nWhat was the away team when the home was st kilda?\n\n### SQL:\nSELECT Away team FROM table WHERE Home team = st kilda", "question": "What was the away team when the home was st kilda?", "sql": "SELECT Away team FROM table WHERE Home team = st kilda", "source": "train"}
51
+ {"text": "### Question:\nWhat is the % of same-sex marriages for the year of 2011?\n\n### SQL:\nSELECT % same-sex marriages FROM table WHERE Year = 2011", "question": "What is the % of same-sex marriages for the year of 2011?", "sql": "SELECT % same-sex marriages FROM table WHERE Year = 2011", "source": "train"}
52
+ {"text": "### Question:\nWhat Place has a To par of \u20134?\n\n### SQL:\nSELECT Place FROM table WHERE To par = \u20134", "question": "What Place has a To par of \u20134?", "sql": "SELECT Place FROM table WHERE To par = \u20134", "source": "train"}
53
+ {"text": "### Question:\nName the district for 1994\n\n### SQL:\nSELECT District FROM table WHERE First elected = 1994", "question": "Name the district for 1994", "sql": "SELECT District FROM table WHERE First elected = 1994", "source": "train"}
54
+ {"text": "### Question:\nWhich Body Width/mm has a Lead Pitch/mm smaller than 0.55, and a Part Number of tsop48?\n\n### SQL:\nSELECT MIN Body Width/mm FROM table WHERE Lead Pitch/mm < 0.55 AND Part Number = tsop48", "question": "Which Body Width/mm has a Lead Pitch/mm smaller than 0.55, and a Part Number of tsop48?", "sql": "SELECT MIN Body Width/mm FROM table WHERE Lead Pitch/mm < 0.55 AND Part Number = tsop48", "source": "train"}
55
+ {"text": "### Question:\nWhat is the lowest number of episodes for anabel barnston?\n\n### SQL:\nSELECT MIN Episodes FROM table WHERE Actor = anabel barnston", "question": "What is the lowest number of episodes for anabel barnston?", "sql": "SELECT MIN Episodes FROM table WHERE Actor = anabel barnston", "source": "train"}
56
+ {"text": "### Question:\nFor what league was the player in G position drafted?\n\n### SQL:\nSELECT League from FROM table WHERE Position = g", "question": "For what league was the player in G position drafted?", "sql": "SELECT League from FROM table WHERE Position = g", "source": "train"}
57
+ {"text": "### Question:\nWhat is the last episode which has segment d as blown glass?\n\n### SQL:\nSELECT MAX Episode FROM table WHERE Segment D = Blown Glass", "question": "What is the last episode which has segment d as blown glass?", "sql": "SELECT MAX Episode FROM table WHERE Segment D = Blown Glass", "source": "train"}
58
+ {"text": "### Question:\nWhat is the date for the 10b serial?\n\n### SQL:\nSELECT Date FROM table WHERE Serial = 10b", "question": "What is the date for the 10b serial?", "sql": "SELECT Date FROM table WHERE Serial = 10b", "source": "train"}
59
+ {"text": "### Question:\nWhen \"we're going to disney world (part 1)\" is the title what is the air date?\n\n### SQL:\nSELECT Original air date FROM table WHERE Title = \"We're Going to Disney World (Part 1)\"", "question": "When \"we're going to disney world (part 1)\" is the title what is the air date?", "sql": "SELECT Original air date FROM table WHERE Title = \"We're Going to Disney World (Part 1)\"", "source": "train"}
60
+ {"text": "### Question:\nHow did the School/Club Team of Manuel Luis Quezon acquire their Forward?\n\n### SQL:\nSELECT Acquisition via FROM table WHERE Position = forward AND School/Club Team = manuel luis quezon", "question": "How did the School/Club Team of Manuel Luis Quezon acquire their Forward?", "sql": "SELECT Acquisition via FROM table WHERE Position = forward AND School/Club Team = manuel luis quezon", "source": "train"}
61
+ {"text": "### Question:\nWhat is the Loss has an Attendance more than 43,095 and a Record of 31\u201329?\n\n### SQL:\nSELECT Loss FROM table WHERE Attendance > 43,095 AND Record = 31\u201329", "question": "What is the Loss has an Attendance more than 43,095 and a Record of 31\u201329?", "sql": "SELECT Loss FROM table WHERE Attendance > 43,095 AND Record = 31\u201329", "source": "train"}
62
+ {"text": "### Question:\nWhat is the ethernet ports of the u10 appliance?\n\n### SQL:\nSELECT Ethernet Ports FROM table WHERE Name = u10", "question": "What is the ethernet ports of the u10 appliance?", "sql": "SELECT Ethernet Ports FROM table WHERE Name = u10", "source": "train"}
63
+ {"text": "### Question:\nName the 2007 for 2005 of a and 003 of a with 2009 of sf\n\n### SQL:\nSELECT 2007 FROM table WHERE 2005 = a AND 2003 = a AND 2009 = sf", "question": "Name the 2007 for 2005 of a and 003 of a with 2009 of sf", "sql": "SELECT 2007 FROM table WHERE 2005 = a AND 2003 = a AND 2009 = sf", "source": "train"}
64
+ {"text": "### Question:\nWhat is the constructor for the race with Nigel Mansell as the fastest lap?\n\n### SQL:\nSELECT Constructor FROM table WHERE Fastest Lap = nigel mansell", "question": "What is the constructor for the race with Nigel Mansell as the fastest lap?", "sql": "SELECT Constructor FROM table WHERE Fastest Lap = nigel mansell", "source": "train"}
65
+ {"text": "### Question:\nWhich institution's nickname is the Polar Bears?\n\n### SQL:\nSELECT Institution FROM table WHERE Nickname = Polar Bears", "question": "Which institution's nickname is the Polar Bears?", "sql": "SELECT Institution FROM table WHERE Nickname = Polar Bears", "source": "train"}
66
+ {"text": "### Question:\nWhich Away team has an Away team score of 11.18 (84)?\n\n### SQL:\nSELECT Away team FROM table WHERE Away team score = 11.18 (84)", "question": "Which Away team has an Away team score of 11.18 (84)?", "sql": "SELECT Away team FROM table WHERE Away team score = 11.18 (84)", "source": "train"}
67
+ {"text": "### Question:\nwhat is the current version with license gpl v3?\n\n### SQL:\nSELECT Current version FROM table WHERE License = gpl v3", "question": "what is the current version with license gpl v3?", "sql": "SELECT Current version FROM table WHERE License = gpl v3", "source": "train"}
68
+ {"text": "### Question:\nWhat was the record on April 1?\n\n### SQL:\nSELECT Record FROM table WHERE Date = april 1", "question": "What was the record on April 1?", "sql": "SELECT Record FROM table WHERE Date = april 1", "source": "train"}
69
+ {"text": "### Question:\nWhich Winning score has a Margin of victory of 1 stroke, and a Date of 21 jun 1981?\n\n### SQL:\nSELECT Winning score FROM table WHERE Margin of victory = 1 stroke AND Date = 21 jun 1981", "question": "Which Winning score has a Margin of victory of 1 stroke, and a Date of 21 jun 1981?", "sql": "SELECT Winning score FROM table WHERE Margin of victory = 1 stroke AND Date = 21 jun 1981", "source": "train"}
70
+ {"text": "### Question:\nWhich Avoirdupois value is translated to grain?\n\n### SQL:\nSELECT Avoirdupois value FROM table WHERE Translation = grain", "question": "Which Avoirdupois value is translated to grain?", "sql": "SELECT Avoirdupois value FROM table WHERE Translation = grain", "source": "train"}
71
+ {"text": "### Question:\nName the most margin for nco party and p. ramachandran won\n\n### SQL:\nSELECT MAX Margin FROM table WHERE Party = NCO AND Winner = P. Ramachandran", "question": "Name the most margin for nco party and p. ramachandran won", "sql": "SELECT MAX Margin FROM table WHERE Party = NCO AND Winner = P. Ramachandran", "source": "train"}
72
+ {"text": "### Question:\nWhat is the Catalog with a Date that is february 20, 2002?\n\n### SQL:\nSELECT Catalog FROM table WHERE Date = february 20, 2002", "question": "What is the Catalog with a Date that is february 20, 2002?", "sql": "SELECT Catalog FROM table WHERE Date = february 20, 2002", "source": "train"}
73
+ {"text": "### Question:\nWho is the driver of the chassis-engine porsche 956 gti?\n\n### SQL:\nSELECT Driver FROM table WHERE Chassis \u2013 Engine = porsche 956 gti", "question": "Who is the driver of the chassis-engine porsche 956 gti?", "sql": "SELECT Driver FROM table WHERE Chassis \u2013 Engine = porsche 956 gti", "source": "train"}
74
+ {"text": "### Question:\nWhat is the 1st party with Charles Isaac Elton as the 2nd member?\n\n### SQL:\nSELECT 1st Party FROM table WHERE 2nd Member = charles isaac elton", "question": "What is the 1st party with Charles Isaac Elton as the 2nd member?", "sql": "SELECT 1st Party FROM table WHERE 2nd Member = charles isaac elton", "source": "train"}
75
+ {"text": "### Question:\nWhat is the earliest date of the game with a score of 2-2?\n\n### SQL:\nSELECT MIN Date FROM table WHERE Score = 2-2", "question": "What is the earliest date of the game with a score of 2-2?", "sql": "SELECT MIN Date FROM table WHERE Score = 2-2", "source": "train"}
76
+ {"text": "### Question:\nWhat was the original nfl team that the player was in from the midwestern conference?\n\n### SQL:\nSELECT Original NFL team FROM table WHERE Conf. = midwestern", "question": "What was the original nfl team that the player was in from the midwestern conference?", "sql": "SELECT Original NFL team FROM table WHERE Conf. = midwestern", "source": "train"}
77
+ {"text": "### Question:\nName the total number of domestic mail for 7853 for total frieght and mail\n\n### SQL:\nSELECT COUNT Domestic mail FROM table WHERE Total freight and mail = 7853", "question": "Name the total number of domestic mail for 7853 for total frieght and mail", "sql": "SELECT COUNT Domestic mail FROM table WHERE Total freight and mail = 7853", "source": "train"}
78
+ {"text": "### Question:\nWhat time is listed against the Wrestler Jimmy Rave?\n\n### SQL:\nSELECT Time FROM table WHERE Wrestler = jimmy rave", "question": "What time is listed against the Wrestler Jimmy Rave?", "sql": "SELECT Time FROM table WHERE Wrestler = jimmy rave", "source": "train"}
79
+ {"text": "### Question:\nWhat's the listed average of Cuts made that has a Top-5 of 3, and a Top-10 that's smaller than 5?\n\n### SQL:\nSELECT AVG Cuts made FROM table WHERE Top-5 = 3 AND Top-10 < 5", "question": "What's the listed average of Cuts made that has a Top-5 of 3, and a Top-10 that's smaller than 5?", "sql": "SELECT AVG Cuts made FROM table WHERE Top-5 = 3 AND Top-10 < 5", "source": "train"}
80
+ {"text": "### Question:\nHow large was the crowd at Carlton's home game?\n\n### SQL:\nSELECT COUNT Crowd FROM table WHERE Home team = carlton", "question": "How large was the crowd at Carlton's home game?", "sql": "SELECT COUNT Crowd FROM table WHERE Home team = carlton", "source": "train"}
81
+ {"text": "### Question:\nWhat date did \"The runner\" originally air on?\n\n### SQL:\nSELECT Original air date FROM table WHERE Title = \"The Runner\"", "question": "What date did \"The runner\" originally air on?", "sql": "SELECT Original air date FROM table WHERE Title = \"The Runner\"", "source": "train"}
82
+ {"text": "### Question:\nWhen 10th, south west district 1 is the mens 2nd xi what is the ladies 1st xi?\n\n### SQL:\nSELECT Ladies 1st XI FROM table WHERE Mens 2nd XI = 10th, South West District 1", "question": "When 10th, south west district 1 is the mens 2nd xi what is the ladies 1st xi?", "sql": "SELECT Ladies 1st XI FROM table WHERE Mens 2nd XI = 10th, South West District 1", "source": "train"}
83
+ {"text": "### Question:\nWhat is Show, when Episode Number is 1, when Year is less than 2010, and when Original Airdate is January 20, 2008?\n\n### SQL:\nSELECT Show FROM table WHERE Episode number = 1 AND Year < 2010 AND Original airdate = january 20, 2008", "question": "What is Show, when Episode Number is 1, when Year is less than 2010, and when Original Airdate is January 20, 2008?", "sql": "SELECT Show FROM table WHERE Episode number = 1 AND Year < 2010 AND Original airdate = january 20, 2008", "source": "train"}
84
+ {"text": "### Question:\nWhat is the verb for the Proto-Austronesian word *diri?\n\n### SQL:\nSELECT Verb FROM table WHERE Proto-Austronesian = *diri", "question": "What is the verb for the Proto-Austronesian word *diri?", "sql": "SELECT Verb FROM table WHERE Proto-Austronesian = *diri", "source": "train"}
85
+ {"text": "### Question:\nWhat did winner Gary Player par?\n\n### SQL:\nSELECT To par FROM table WHERE Winner = gary player", "question": "What did winner Gary Player par?", "sql": "SELECT To par FROM table WHERE Winner = gary player", "source": "train"}
86
+ {"text": "### Question:\nWhat notes have 2 as the rank?\n\n### SQL:\nSELECT Notes FROM table WHERE Rank = 2", "question": "What notes have 2 as the rank?", "sql": "SELECT Notes FROM table WHERE Rank = 2", "source": "train"}
87
+ {"text": "### Question:\nWho was the Away Captain when the Home Captain was Joe Darling at Melbourne Cricket Ground?\n\n### SQL:\nSELECT Away captain FROM table WHERE Venue = melbourne cricket ground AND Home captain = joe darling", "question": "Who was the Away Captain when the Home Captain was Joe Darling at Melbourne Cricket Ground?", "sql": "SELECT Away captain FROM table WHERE Venue = melbourne cricket ground AND Home captain = joe darling", "source": "train"}
88
+ {"text": "### Question:\nWhich spacecraft were launched by the Titan II?\n\n### SQL:\nSELECT Spacecraft FROM table WHERE Launcher = titan ii", "question": "Which spacecraft were launched by the Titan II?", "sql": "SELECT Spacecraft FROM table WHERE Launcher = titan ii", "source": "train"}
89
+ {"text": "### Question:\nWhat shows for 3:30 pm when 12:30 pm is the young and the restless?\n\n### SQL:\nSELECT noon FROM table WHERE 12:30 pm = the young and the restless", "question": "What shows for 3:30 pm when 12:30 pm is the young and the restless?", "sql": "SELECT noon FROM table WHERE 12:30 pm = the young and the restless", "source": "train"}
90
+ {"text": "### Question:\nWhat's the fed tax that has a total tax greater than 33.2, a minimum sales tax less than 41.01 and in Vancouver, BC?\n\n### SQL:\nSELECT AVG Federal excise tax ( CAD\u00a2 / L ) FROM table WHERE Total excise tax (CAD\u00a2/L) > 33.2 AND Government = vancouver, bc AND Minimum tax incl. sales taxes (CAD\u00a2/L) < 41.01", "question": "What's the fed tax that has a total tax greater than 33.2, a minimum sales tax less than 41.01 and in Vancouver, BC?", "sql": "SELECT AVG Federal excise tax ( CAD\u00a2 / L ) FROM table WHERE Total excise tax (CAD\u00a2/L) > 33.2 AND Government = vancouver, bc AND Minimum tax incl. sales taxes (CAD\u00a2/L) < 41.01", "source": "train"}
91
+ {"text": "### Question:\nWhat is the highest total number?\n\n### SQL:\nSELECT MAX Total# FROM table", "question": "What is the highest total number?", "sql": "SELECT MAX Total# FROM table", "source": "train"}
92
+ {"text": "### Question:\nWhat is the score of the match that was against alberto berasategui?\n\n### SQL:\nSELECT Score in the final FROM table WHERE Opponent in the final = alberto berasategui", "question": "What is the score of the match that was against alberto berasategui?", "sql": "SELECT Score in the final FROM table WHERE Opponent in the final = alberto berasategui", "source": "train"}
93
+ {"text": "### Question:\nWhat is the newest Cap with a Goals stat larger than 17 and which was done by Brian Turner?\n\n### SQL:\nSELECT Most Recent Cap FROM table WHERE Goals > 17 AND Name = brian turner", "question": "What is the newest Cap with a Goals stat larger than 17 and which was done by Brian Turner?", "sql": "SELECT Most Recent Cap FROM table WHERE Goals > 17 AND Name = brian turner", "source": "train"}
94
+ {"text": "### Question:\nWhat was the losing score on September 1?\n\n### SQL:\nSELECT Loss FROM table WHERE Date = september 1", "question": "What was the losing score on September 1?", "sql": "SELECT Loss FROM table WHERE Date = september 1", "source": "train"}
95
+ {"text": "### Question:\nWhat is the score of the away team whose opponent scored 14.8 (92)?\n\n### SQL:\nSELECT Away team score FROM table WHERE Home team score = 14.8 (92)", "question": "What is the score of the away team whose opponent scored 14.8 (92)?", "sql": "SELECT Away team score FROM table WHERE Home team score = 14.8 (92)", "source": "train"}
96
+ {"text": "### Question:\nWhat are donor payments in the country where there are 12 children to 6 families (2 per family)?\n\n### SQL:\nSELECT Donor payment FROM table WHERE Children per donor = 12 children to 6 families (2 per family)", "question": "What are donor payments in the country where there are 12 children to 6 families (2 per family)?", "sql": "SELECT Donor payment FROM table WHERE Children per donor = 12 children to 6 families (2 per family)", "source": "train"}
97
+ {"text": "### Question:\nWhat was the original air date for the episode with production code 1wab06?\n\n### SQL:\nSELECT Originalairdate FROM table WHERE Production code = 1WAB06", "question": "What was the original air date for the episode with production code 1wab06?", "sql": "SELECT Originalairdate FROM table WHERE Production code = 1WAB06", "source": "train"}
98
+ {"text": "### Question:\nWhat was the total for David O'Callaghan, and a Tally of 1-9?\n\n### SQL:\nSELECT SUM Total FROM table WHERE Player = david o'callaghan AND Tally = 1-9", "question": "What was the total for David O'Callaghan, and a Tally of 1-9?", "sql": "SELECT SUM Total FROM table WHERE Player = david o'callaghan AND Tally = 1-9", "source": "train"}
99
+ {"text": "### Question:\nAttendance larger than 17,001, and a Date of june 15 had what decision?\n\n### SQL:\nSELECT Decision FROM table WHERE Attendance > 17,001 AND Date = june 15", "question": "Attendance larger than 17,001, and a Date of june 15 had what decision?", "sql": "SELECT Decision FROM table WHERE Attendance > 17,001 AND Date = june 15", "source": "train"}
100
+ {"text": "### Question:\nWhat is Album Artist, when Song is \"\"Something New\" (with Mint Royale and Class A)\"?\n\n### SQL:\nSELECT Album artist FROM table WHERE Song = \"something new\" (with mint royale and class a)", "question": "What is Album Artist, when Song is \"\"Something New\" (with Mint Royale and Class A)\"?", "sql": "SELECT Album artist FROM table WHERE Song = \"something new\" (with mint royale and class a)", "source": "train"}
src/outputs/finetuning/val.jsonl ADDED
@@ -0,0 +1,100 @@
1
+ {"text": "### Question:\nName the play for 1976\n\n### SQL:\nSELECT Play FROM table WHERE Year = 1976", "question": "Name the play for 1976", "sql": "SELECT Play FROM table WHERE Year = 1976", "source": "validation"}
2
+ {"text": "### Question:\nwhat are all the playoffs for u.s. open cup in 1st round\n\n### SQL:\nSELECT Playoffs FROM table WHERE U.S. Open Cup = 1st Round", "question": "what are all the playoffs for u.s. open cup in 1st round", "sql": "SELECT Playoffs FROM table WHERE U.S. Open Cup = 1st Round", "source": "validation"}
3
+ {"text": "### Question:\nWhat is the location of the game that has a number smaller than 2?\n\n### SQL:\nSELECT Location FROM table WHERE Game < 2", "question": "What is the location of the game that has a number smaller than 2?", "sql": "SELECT Location FROM table WHERE Game < 2", "source": "validation"}
4
+ {"text": "### Question:\nWhat is 2004, when 2005 is \"Not Tier I\"?\n\n### SQL:\nSELECT 2004 FROM table WHERE 2005 = not tier i", "question": "What is 2004, when 2005 is \"Not Tier I\"?", "sql": "SELECT 2004 FROM table WHERE 2005 = not tier i", "source": "validation"}
5
+ {"text": "### Question:\nWhich venue led to a result of 13th and had an extra of Long Race?\n\n### SQL:\nSELECT Venue FROM table WHERE Extra = long race AND Result = 13th", "question": "Which venue led to a result of 13th and had an extra of Long Race?", "sql": "SELECT Venue FROM table WHERE Extra = long race AND Result = 13th", "source": "validation"}
6
+ {"text": "### Question:\nWhat was Anders Forsbrand's score when the TO par is +4?\n\n### SQL:\nSELECT Score FROM table WHERE To par = +4 AND Player = anders forsbrand", "question": "What was Anders Forsbrand's score when the TO par is +4?", "sql": "SELECT Score FROM table WHERE To par = +4 AND Player = anders forsbrand", "source": "validation"}
7
+ {"text": "### Question:\nWhat was the attendance of the game that had an away team of FK Mogren?\n\n### SQL:\nSELECT Attendance FROM table WHERE Guest = fk mogren", "question": "What was the attendance of the game that had an away team of FK Mogren?", "sql": "SELECT Attendance FROM table WHERE Guest = fk mogren", "source": "validation"}
8
+ {"text": "### Question:\nWhat is the Air Date that has a 18\u201349 larger than 1.9, less than 7.54 viewers and a rating less than 4.9?\n\n### SQL:\nSELECT Air Date FROM table WHERE 18\u201349 > 1.9 AND Viewers < 7.54 AND Rating < 4.9", "question": "What is the Air Date that has a 18\u201349 larger than 1.9, less than 7.54 viewers and a rating less than 4.9?", "sql": "SELECT Air Date FROM table WHERE 18\u201349 > 1.9 AND Viewers < 7.54 AND Rating < 4.9", "source": "validation"}
9
+ {"text": "### Question:\nWhich Avg/G is the lowest one that has a Long smaller than 47, and a Name of frank murphy, and a Gain smaller than 569?\n\n### SQL:\nSELECT MIN Avg/G FROM table WHERE Long < 47 AND Name = frank murphy AND Gain < 569", "question": "Which Avg/G is the lowest one that has a Long smaller than 47, and a Name of frank murphy, and a Gain smaller than 569?", "sql": "SELECT MIN Avg/G FROM table WHERE Long < 47 AND Name = frank murphy AND Gain < 569", "source": "validation"}
10
+ {"text": "### Question:\nWhich rank has 1 silver medal and more than 1 gold medal?\n\n### SQL:\nSELECT Rank FROM table WHERE Silver = 1 AND Gold > 1", "question": "Which rank has 1 silver medal and more than 1 gold medal?", "sql": "SELECT Rank FROM table WHERE Silver = 1 AND Gold > 1", "source": "validation"}
11
+ {"text": "### Question:\nName the number of candidates for # of seats won being 43\n\n### SQL:\nSELECT # of candidates FROM table WHERE # of seats won = 43", "question": "Name the number of candidates for # of seats won being 43", "sql": "SELECT # of candidates FROM table WHERE # of seats won = 43", "source": "validation"}
12
+ {"text": "### Question:\nWhat is the home team score at lake oval?\n\n### SQL:\nSELECT Home team score FROM table WHERE Venue = lake oval", "question": "What is the home team score at lake oval?", "sql": "SELECT Home team score FROM table WHERE Venue = lake oval", "source": "validation"}
13
+ {"text": "### Question:\nWhich loss has an attendance greater than 49,688 and 11-8 as the record?\n\n### SQL:\nSELECT Loss FROM table WHERE Attendance > 49,688 AND Record = 11-8", "question": "Which loss has an attendance greater than 49,688 and 11-8 as the record?", "sql": "SELECT Loss FROM table WHERE Attendance > 49,688 AND Record = 11-8", "source": "validation"}
14
+ {"text": "### Question:\nWhat is the sum of pick# for Don Majkowski?3\n\n### SQL:\nSELECT SUM Pick # FROM table WHERE Player = don majkowski", "question": "What is the sum of pick# for Don Majkowski?3", "sql": "SELECT SUM Pick # FROM table WHERE Player = don majkowski", "source": "validation"}
15
+ {"text": "### Question:\nWhat is the total number of wins for riders with fewer than 56 races and more than 0 titles?\n\n### SQL:\nSELECT COUNT Wins FROM table WHERE Races < 56 AND Titles > 0", "question": "What is the total number of wins for riders with fewer than 56 races and more than 0 titles?", "sql": "SELECT COUNT Wins FROM table WHERE Races < 56 AND Titles > 0", "source": "validation"}
16
+ {"text": "### Question:\nHow much did the girl, nicknamed Chidi, weigh at birth?\n\n### SQL:\nSELECT Weight at birth FROM table WHERE Gender = girl AND Nickname = chidi", "question": "How much did the girl, nicknamed Chidi, weigh at birth?", "sql": "SELECT Weight at birth FROM table WHERE Gender = girl AND Nickname = chidi", "source": "validation"}
17
+ {"text": "### Question:\nOn which apparatus did Kanayeva have a final score smaller than 75.5 and a qualifying score smaller than 18.7?\n\n### SQL:\nSELECT Apparatus FROM table WHERE Score-Final < 75.5 AND Score-Qualifying < 18.7", "question": "On which apparatus did Kanayeva have a final score smaller than 75.5 and a qualifying score smaller than 18.7?", "sql": "SELECT Apparatus FROM table WHERE Score-Final < 75.5 AND Score-Qualifying < 18.7", "source": "validation"}
18
+ {"text": "### Question:\nWhat are the rounds for the B tyres and Ferrari 053 engine +?\n\n### SQL:\nSELECT Rounds FROM table WHERE Tyre = b AND Engine \u2020 = ferrari 053", "question": "What are the rounds for the B tyres and Ferrari 053 engine +?", "sql": "SELECT Rounds FROM table WHERE Tyre = b AND Engine \u2020 = ferrari 053", "source": "validation"}
19
+ {"text": "### Question:\nWhat is the sexual abuse rate where the conflict is the Burundi Civil War?\n\n### SQL:\nSELECT MAX Sexual abuse 1 FROM table WHERE Conflict = Burundi Civil War", "question": "What is the sexual abuse rate where the conflict is the Burundi Civil War?", "sql": "SELECT MAX Sexual abuse 1 FROM table WHERE Conflict = Burundi Civil War", "source": "validation"}
20
+ {"text": "### Question:\nWhen 0-1 is the series who has the highest amount of assists?\n\n### SQL:\nSELECT High assists FROM table WHERE Series = 0-1", "question": "When 0-1 is the series who has the highest amount of assists?", "sql": "SELECT High assists FROM table WHERE Series = 0-1", "source": "validation"}
21
+ {"text": "### Question:\nWhich MLS team has the #41 pick?\n\n### SQL:\nSELECT MLS team FROM table WHERE Pick # = 41", "question": "Which MLS team has the #41 pick?", "sql": "SELECT MLS team FROM table WHERE Pick # = 41", "source": "validation"}
22
+ {"text": "### Question:\nWhat is the most bronze can be when silver is larger than 2, and the nation is germany, and gold is more than 8?\n\n### SQL:\nSELECT MAX Bronze FROM table WHERE Silver > 2 AND Nation = germany AND Gold > 8", "question": "What is the most bronze can be when silver is larger than 2, and the nation is germany, and gold is more than 8?", "sql": "SELECT MAX Bronze FROM table WHERE Silver > 2 AND Nation = germany AND Gold > 8", "source": "validation"}
23
+ {"text": "### Question:\nWhat dates contained matches at the venue Bourda?\n\n### SQL:\nSELECT Date FROM table WHERE Venue = bourda", "question": "What dates contained matches at the venue Bourda?", "sql": "SELECT Date FROM table WHERE Venue = bourda", "source": "validation"}
24
+ {"text": "### Question:\nWhat is the minimum amount of poles?\n\n### SQL:\nSELECT MIN Poles FROM table", "question": "What is the minimum amount of poles?", "sql": "SELECT MIN Poles FROM table", "source": "validation"}
25
+ {"text": "### Question:\nWhat is the average Episode # with a 7 share and 18\u201349 is less than 2 and the Air Date of may 21, 2009?\n\n### SQL:\nSELECT AVG Episode # FROM table WHERE Share = 7 AND 18\u201349 < 2 AND Air Date = may 21, 2009", "question": "What is the average Episode # with a 7 share and 18\u201349 is less than 2 and the Air Date of may 21, 2009?", "sql": "SELECT AVG Episode # FROM table WHERE Share = 7 AND 18\u201349 < 2 AND Air Date = may 21, 2009", "source": "validation"}
26
+ {"text": "### Question:\nWhat is the most lost games for the team with a difference smaller than 86 and points of 32?\n\n### SQL:\nSELECT MAX Lost FROM table WHERE Points = 32 AND Difference < 86", "question": "What is the most lost games for the team with a difference smaller than 86 and points of 32?", "sql": "SELECT MAX Lost FROM table WHERE Points = 32 AND Difference < 86", "source": "validation"}
27
+ {"text": "### Question:\nwhat's the\u00a0first elected\u00a0with\u00a0district\u00a0being florida 7\n\n### SQL:\nSELECT First elected FROM table WHERE District = Florida 7", "question": "what's the\u00a0first elected\u00a0with\u00a0district\u00a0being florida 7", "sql": "SELECT First elected FROM table WHERE District = Florida 7", "source": "validation"}
28
+ {"text": "### Question:\nWhat is the average number of points for a song ranked 2nd with a draw greater than 3?\n\n### SQL:\nSELECT AVG Points FROM table WHERE Rank = 2nd AND Draw > 3", "question": "What is the average number of points for a song ranked 2nd with a draw greater than 3?", "sql": "SELECT AVG Points FROM table WHERE Rank = 2nd AND Draw > 3", "source": "validation"}
29
+ {"text": "### Question:\nWhat is the destination when the train number is 16526?\n\n### SQL:\nSELECT Destination FROM table WHERE Train number = 16526", "question": "What is the destination when the train number is 16526?", "sql": "SELECT Destination FROM table WHERE Train number = 16526", "source": "validation"}
30
+ {"text": "### Question:\nWhat is the highest game that has 32 points and a team rank larger than 4 named montepaschi siena\n\n### SQL:\nSELECT MAX Games FROM table WHERE Points = 32 AND Team = montepaschi siena AND Rank > 4", "question": "What is the highest game that has 32 points and a team rank larger than 4 named montepaschi siena", "sql": "SELECT MAX Games FROM table WHERE Points = 32 AND Team = montepaschi siena AND Rank > 4", "source": "validation"}
31
+ {"text": "### Question:\nWhat is Episode, when Jeremy's Guest is \"Pauline McLynn\"?\n\n### SQL:\nSELECT Episode FROM table WHERE Jeremy's guest = pauline mclynn", "question": "What is Episode, when Jeremy's Guest is \"Pauline McLynn\"?", "sql": "SELECT Episode FROM table WHERE Jeremy's guest = pauline mclynn", "source": "validation"}
32
+ {"text": "### Question:\nWhat is the poor law union of the Kilmaloda townland?\n\n### SQL:\nSELECT Poor law union FROM table WHERE Townland = Kilmaloda", "question": "What is the poor law union of the Kilmaloda townland?", "sql": "SELECT Poor law union FROM table WHERE Townland = Kilmaloda", "source": "validation"}
33
+ {"text": "### Question:\nWhat is the largest pick in round 8?\n\n### SQL:\nSELECT MAX Pick FROM table WHERE Round = 8", "question": "What is the largest pick in round 8?", "sql": "SELECT MAX Pick FROM table WHERE Round = 8", "source": "validation"}
34
+ {"text": "### Question:\nOn what date was the attendance at TD Garden 18,624?\n\n### SQL:\nSELECT Date FROM table WHERE Location Attendance = TD Garden 18,624", "question": "On what date was the attendance at TD Garden 18,624?", "sql": "SELECT Date FROM table WHERE Location Attendance = TD Garden 18,624", "source": "validation"}
35
+ {"text": "### Question:\nWhat is canada's margin?\n\n### SQL:\nSELECT SUM Margin FROM table WHERE Country = canada", "question": "What is canada's margin?", "sql": "SELECT SUM Margin FROM table WHERE Country = canada", "source": "validation"}
36
+ {"text": "### Question:\nWhat Sweet Sixteen team is in the Colonial conference?\n\n### SQL:\nSELECT Sweet Sixteen FROM table WHERE Conference = colonial", "question": "What Sweet Sixteen team is in the Colonial conference?", "sql": "SELECT Sweet Sixteen FROM table WHERE Conference = colonial", "source": "validation"}
37
+ {"text": "### Question:\nHow many resorts have 118 runs?\n\n### SQL:\nSELECT COUNT Name FROM table WHERE Runs = 118", "question": "How many resorts have 118 runs?", "sql": "SELECT COUNT Name FROM table WHERE Runs = 118", "source": "validation"}
38
+ {"text": "### Question:\nWho was the winner against Lindsay Davenport?\n\n### SQL:\nSELECT Winner FROM table WHERE Finalist = lindsay davenport", "question": "Who was the winner against Lindsay Davenport?", "sql": "SELECT Winner FROM table WHERE Finalist = lindsay davenport", "source": "validation"}
39
+ {"text": "### Question:\nHow many laps for a grid larger than 1 with a Time/Retired of halfshaft?\n\n### SQL:\nSELECT Laps FROM table WHERE Grid > 1 AND Time/Retired = halfshaft", "question": "How many laps for a grid larger than 1 with a Time/Retired of halfshaft?", "sql": "SELECT Laps FROM table WHERE Grid > 1 AND Time/Retired = halfshaft", "source": "validation"}
40
+ {"text": "### Question:\nIn what year was the feature at a 33.3S latitude named? \n\n### SQL:\nSELECT MAX Year named FROM table WHERE Latitude = 33.3S", "question": "In what year was the feature at a 33.3S latitude named? ", "sql": "SELECT MAX Year named FROM table WHERE Latitude = 33.3S", "source": "validation"}
41
+ {"text": "### Question:\nWhich Thirds (Under 17's) have a Reserve of barnawartha?\n\n### SQL:\nSELECT Thirds (Under 17's) FROM table WHERE Reserves = barnawartha", "question": "Which Thirds (Under 17's) have a Reserve of barnawartha?", "sql": "SELECT Thirds (Under 17's) FROM table WHERE Reserves = barnawartha", "source": "validation"}
42
+ {"text": "### Question:\nWhat was the outcome of the match against Stacy Margolin?\n\n### SQL:\nSELECT Outcome FROM table WHERE Opponent = stacy margolin", "question": "What was the outcome of the match against Stacy Margolin?", "sql": "SELECT Outcome FROM table WHERE Opponent = stacy margolin", "source": "validation"}
43
+ {"text": "### Question:\nIf the working force of HK is 10.4%, what is the salary range?\n\n### SQL:\nSELECT Salary range FROM table WHERE Working force of HK = 10.4%", "question": "If the working force of HK is 10.4%, what is the salary range?", "sql": "SELECT Salary range FROM table WHERE Working force of HK = 10.4%", "source": "validation"}
44
+ {"text": "### Question:\nWhat is the sum of the pick from texas a&i college with a round greater than 1?\n\n### SQL:\nSELECT SUM Pick FROM table WHERE College = texas a&i AND Round > 1", "question": "What is the sum of the pick from texas a&i college with a round greater than 1?", "sql": "SELECT SUM Pick FROM table WHERE College = texas a&i AND Round > 1", "source": "validation"}
45
+ {"text": "### Question:\nWhich Second has a Lead of ben hebert?\n\n### SQL:\nSELECT Second FROM table WHERE Lead = ben hebert", "question": "Which Second has a Lead of ben hebert?", "sql": "SELECT Second FROM table WHERE Lead = ben hebert", "source": "validation"}
46
+ {"text": "### Question:\nWhich Genre has a Game of donkey kong country?\n\n### SQL:\nSELECT Genre FROM table WHERE Game = donkey kong country", "question": "Which Genre has a Game of donkey kong country?", "sql": "SELECT Genre FROM table WHERE Game = donkey kong country", "source": "validation"}
47
+ {"text": "### Question:\nWhat is the location of the Carousel toll plaza?\n\n### SQL:\nSELECT Location FROM table WHERE Name = Carousel Toll Plaza", "question": "What is the location of the Carousel toll plaza?", "sql": "SELECT Location FROM table WHERE Name = Carousel Toll Plaza", "source": "validation"}
48
+ {"text": "### Question:\nWhat is Turkey's average Gold entry that also has a Bronze entry that is smaller than 2 and the Total is greater than 1?\n\n### SQL:\nSELECT AVG Gold FROM table WHERE Bronze < 2 AND Nation = turkey AND Total > 1", "question": "What is Turkey's average Gold entry that also has a Bronze entry that is smaller than 2 and the Total is greater than 1?", "sql": "SELECT AVG Gold FROM table WHERE Bronze < 2 AND Nation = turkey AND Total > 1", "source": "validation"}
49
+ {"text": "### Question:\nWhich Class has a Quantity made of 29?\n\n### SQL:\nSELECT Class FROM table WHERE Quantity made = 29", "question": "Which Class has a Quantity made of 29?", "sql": "SELECT Class FROM table WHERE Quantity made = 29", "source": "validation"}
50
+ {"text": "### Question:\nWhich Oberliga Bayern has a Season of 1981-82?\n\n### SQL:\nSELECT Oberliga Bayern FROM table WHERE Season = 1981-82", "question": "Which Oberliga Bayern has a Season of 1981-82?", "sql": "SELECT Oberliga Bayern FROM table WHERE Season = 1981-82", "source": "validation"}
51
+ {"text": "### Question:\nWhat is the number of podiums with 0 wins, 0 F.L. and 35 points?\n\n### SQL:\nSELECT Podiums FROM table WHERE Wins = 0 AND F.L. = 0 AND Points = 35", "question": "What is the number of podiums with 0 wins, 0 F.L. and 35 points?", "sql": "SELECT Podiums FROM table WHERE Wins = 0 AND F.L. = 0 AND Points = 35", "source": "validation"}
52
+ {"text": "### Question:\nWho was the publisher of Martial Law: Dead Ringers?\n\n### SQL:\nSELECT Publisher FROM table WHERE Release title = martial law: dead ringers", "question": "Who was the publisher of Martial Law: Dead Ringers?", "sql": "SELECT Publisher FROM table WHERE Release title = martial law: dead ringers", "source": "validation"}
53
+ {"text": "### Question:\nWhat is the Almali village with the S\u00fcsk\u0259n village z\u0259rn\u0259?\n\n### SQL:\nSELECT Almal\u0131 (Qax) FROM table WHERE S\u00fcsk\u0259n = z\u0259rn\u0259", "question": "What is the Almali village with the S\u00fcsk\u0259n village z\u0259rn\u0259?", "sql": "SELECT Almal\u0131 (Qax) FROM table WHERE S\u00fcsk\u0259n = z\u0259rn\u0259", "source": "validation"}
54
+ {"text": "### Question:\nName the typed for formed from 6-pul trailer third in res unit\n\n### SQL:\nSELECT Type FROM table WHERE Formed from = 6-pul trailer third in res unit", "question": "Name the typed for formed from 6-pul trailer third in res unit", "sql": "SELECT Type FROM table WHERE Formed from = 6-pul trailer third in res unit", "source": "validation"}
55
+ {"text": "### Question:\nName the 2009/10 with 2011/12 of lq and 2008/09 of not held\n\n### SQL:\nSELECT 2009/ 10 FROM table WHERE 2011/ 12 = lq AND 2008/ 09 = not held", "question": "Name the 2009/10 with 2011/12 of lq and 2008/09 of not held", "sql": "SELECT 2009/ 10 FROM table WHERE 2011/ 12 = lq AND 2008/ 09 = not held", "source": "validation"}
56
+ {"text": "### Question:\nWhat is the tyres for the JBW type 2 chassis?\n\n### SQL:\nSELECT Tyres FROM table WHERE Chassis = jbw type 2", "question": "What is the tyres for the JBW type 2 chassis?", "sql": "SELECT Tyres FROM table WHERE Chassis = jbw type 2", "source": "validation"}
57
+ {"text": "### Question:\nHow many total appearances (league only) have a name of gavin dykes?\n\n### SQL:\nSELECT Total Appearances(league only) FROM table WHERE Name = gavin dykes", "question": "How many total appearances (league only) have a name of gavin dykes?", "sql": "SELECT Total Appearances(league only) FROM table WHERE Name = gavin dykes", "source": "validation"}
58
+ {"text": "### Question:\nWhat is the sum of laps that has a car number of larger than 1, is a ford, and has 155 points?\n\n### SQL:\nSELECT SUM Laps FROM table WHERE Car # > 1 AND Make = ford AND Points = 155", "question": "What is the sum of laps that has a car number of larger than 1, is a ford, and has 155 points?", "sql": "SELECT SUM Laps FROM table WHERE Car # > 1 AND Make = ford AND Points = 155", "source": "validation"}
59
+ {"text": "### Question:\nWhat was the average crowd size of games held at Glenferrie Oval?\n\n### SQL:\nSELECT AVG Crowd FROM table WHERE Venue = glenferrie oval", "question": "What was the average crowd size of games held at Glenferrie Oval?", "sql": "SELECT AVG Crowd FROM table WHERE Venue = glenferrie oval", "source": "validation"}
60
+ {"text": "### Question:\nWhat version of iWork was released on October 22, 2013 with a pages version greater than 2?\n\n### SQL:\nSELECT iWork version FROM table WHERE Release date = october 22, 2013 AND Pages version > 2", "question": "What version of iWork was released on October 22, 2013 with a pages version greater than 2?", "sql": "SELECT iWork version FROM table WHERE Release date = october 22, 2013 AND Pages version > 2", "source": "validation"}
61
+ {"text": "### Question:\nName the player for chicago black hawks\n\n### SQL:\nSELECT Player FROM table WHERE NHL team = Chicago Black Hawks", "question": "Name the player for chicago black hawks", "sql": "SELECT Player FROM table WHERE NHL team = Chicago Black Hawks", "source": "validation"}
62
+ {"text": "### Question:\nWhat is the streak for game 2?\n\n### SQL:\nSELECT Streak FROM table WHERE Game = 2", "question": "What is the streak for game 2?", "sql": "SELECT Streak FROM table WHERE Game = 2", "source": "validation"}
63
+ {"text": "### Question:\nI want the D 45 and D 42 of r 22\n\n### SQL:\nSELECT D 45 FROM table WHERE D 42 = r 22", "question": "I want the D 45 and D 42 of r 22", "sql": "SELECT D 45 FROM table WHERE D 42 = r 22", "source": "validation"}
64
+ {"text": "### Question:\nWho was the away team when Queensland Roar was the home team in the round less than 3?\n\n### SQL:\nSELECT Away Team FROM table WHERE Round < 3 AND Home Team = queensland roar", "question": "Who was the away team when Queensland Roar was the home team in the round less than 3?", "sql": "SELECT Away Team FROM table WHERE Round < 3 AND Home Team = queensland roar", "source": "validation"}
65
+ {"text": "### Question:\nHow many artists were there for the show thoroughly modern millie?\n\n### SQL:\nSELECT COUNT Artist FROM table WHERE Show = Thoroughly Modern Millie", "question": "How many artists were there for the show thoroughly modern millie?", "sql": "SELECT COUNT Artist FROM table WHERE Show = Thoroughly Modern Millie", "source": "validation"}
66
+ {"text": "### Question:\nWhich wrestling event was at the 2008 Beijing games?\n\n### SQL:\nSELECT Event FROM table WHERE Sport = wrestling AND Games = 2008 beijing", "question": "Which wrestling event was at the 2008 Beijing games?", "sql": "SELECT Event FROM table WHERE Sport = wrestling AND Games = 2008 beijing", "source": "validation"}
67
+ {"text": "### Question:\nWho was the opponent in London, England in a round less than 2?\n\n### SQL:\nSELECT Opponent FROM table WHERE Location = london, england AND Round < 2", "question": "Who was the opponent in London, England in a round less than 2?", "sql": "SELECT Opponent FROM table WHERE Location = london, england AND Round < 2", "source": "validation"}
68
+ {"text": "### Question:\nWhat is the ceremony year when Ganito Kami Noon, Paano Kayo Ngayon was the original title?\n\n### SQL:\nSELECT Year (Ceremony) FROM table WHERE Original title = ganito kami noon, paano kayo ngayon", "question": "What is the ceremony year when Ganito Kami Noon, Paano Kayo Ngayon was the original title?", "sql": "SELECT Year (Ceremony) FROM table WHERE Original title = ganito kami noon, paano kayo ngayon", "source": "validation"}
69
+ {"text": "### Question:\nWhen the total score is 740, what is tromso?\n\n### SQL:\nSELECT MIN Troms\u00f8 FROM table WHERE Total = 740", "question": "When the total score is 740, what is tromso?", "sql": "SELECT MIN Troms\u00f8 FROM table WHERE Total = 740", "source": "validation"}
70
+ {"text": "### Question:\nWhat is the result for director Said Elmarouk before 2008?\n\n### SQL:\nSELECT Result FROM table WHERE Director = said elmarouk AND Year < 2008", "question": "What is the result for director Said Elmarouk before 2008?", "sql": "SELECT Result FROM table WHERE Director = said elmarouk AND Year < 2008", "source": "validation"}
71
+ {"text": "### Question:\nWhen was the score 56-26?\n\n### SQL:\nSELECT Date FROM table WHERE Record = 56-26", "question": "When was the score 56-26?", "sql": "SELECT Date FROM table WHERE Record = 56-26", "source": "validation"}
72
+ {"text": "### Question:\nName the D 44 when it has a D 46 of d 31\n\n### SQL:\nSELECT D 44 FROM table WHERE D 46 = d 31", "question": "Name the D 44 when it has a D 46 of d 31", "sql": "SELECT D 44 FROM table WHERE D 46 = d 31", "source": "validation"}
73
+ {"text": "### Question:\nWhat was the date of the race that lasted 6 hours?\n\n### SQL:\nSELECT Date FROM table WHERE Length/Duration = 6 hours", "question": "What was the date of the race that lasted 6 hours?", "sql": "SELECT Date FROM table WHERE Length/Duration = 6 hours", "source": "validation"}
74
+ {"text": "### Question:\nWhich event is in the 1952 summer olympics?\n\n### SQL:\nSELECT Event FROM table WHERE Olympics = 1952 summer olympics", "question": "Which event is in the 1952 summer olympics?", "sql": "SELECT Event FROM table WHERE Olympics = 1952 summer olympics", "source": "validation"}
75
+ {"text": "### Question:\n the 2010 clausura tournament?\n\n### SQL:\nSELECT Coefficient FROM table WHERE Tournament = 2010 Clausura", "question": " the 2010 clausura tournament?", "sql": "SELECT Coefficient FROM table WHERE Tournament = 2010 Clausura", "source": "validation"}
76
+ {"text": "### Question:\nWhat was the score of the BCS National Championship game?\n\n### SQL:\nSELECT Score FROM table WHERE Bowl Game = bcs national championship", "question": "What was the score of the BCS National Championship game?", "sql": "SELECT Score FROM table WHERE Bowl Game = bcs national championship", "source": "validation"}
77
+ {"text": "### Question:\nWhat was the attendance when their record stood at 0-2-2?\n\n### SQL:\nSELECT SUM Attendance FROM table WHERE Record = 0-2-2", "question": "What was the attendance when their record stood at 0-2-2?", "sql": "SELECT SUM Attendance FROM table WHERE Record = 0-2-2", "source": "validation"}
78
+ {"text": "### Question:\nWhat were the results before the year 2000?\n\n### SQL:\nSELECT Result FROM table WHERE Year < 2000", "question": "What were the results before the year 2000?", "sql": "SELECT Result FROM table WHERE Year < 2000", "source": "validation"}
79
+ {"text": "### Question:\nHow much time is required for less than 35 laps and less than 10 grids?\n\n### SQL:\nSELECT Time/Retired FROM table WHERE Laps < 35 AND Grid < 10", "question": "How much time is required for less than 35 laps and less than 10 grids?", "sql": "SELECT Time/Retired FROM table WHERE Laps < 35 AND Grid < 10", "source": "validation"}
80
+ {"text": "### Question:\nWhen oslo is 48, what is stavanger?\n\n### SQL:\nSELECT MIN Stavanger FROM table WHERE Oslo = 48", "question": "When oslo is 48, what is stavanger?", "sql": "SELECT MIN Stavanger FROM table WHERE Oslo = 48", "source": "validation"}
81
+ {"text": "### Question:\nWhen was the year that had an average attendance of 5,445?\n\n### SQL:\nSELECT Year FROM table WHERE Avg. attendance = 5,445", "question": "When was the year that had an average attendance of 5,445?", "sql": "SELECT Year FROM table WHERE Avg. attendance = 5,445", "source": "validation"}
82
+ {"text": "### Question:\nFor the game with 528 attendance, what was the result?\n\n### SQL:\nSELECT Result FROM table WHERE Attendance = 528", "question": "For the game with 528 attendance, what was the result?", "sql": "SELECT Result FROM table WHERE Attendance = 528", "source": "validation"}
83
+ {"text": "### Question:\nWhat dated was the game played at the location delta center 19,911?\n\n### SQL:\nSELECT Date FROM table WHERE Location Attendance = Delta Center 19,911", "question": "What dated was the game played at the location delta center 19,911?", "sql": "SELECT Date FROM table WHERE Location Attendance = Delta Center 19,911", "source": "validation"}
84
+ {"text": "### Question:\nWhat is the ISBN of \"Dead as a Doornail?\n\n### SQL:\nSELECT Paperback FROM table WHERE Title = Dead as a Doornail", "question": "What is the ISBN of \"Dead as a Doornail?", "sql": "SELECT Paperback FROM table WHERE Title = Dead as a Doornail", "source": "validation"}
85
+ {"text": "### Question:\nWhat scored is recorded on April 24?\n\n### SQL:\nSELECT Score FROM table WHERE Date = april 24", "question": "What scored is recorded on April 24?", "sql": "SELECT Score FROM table WHERE Date = april 24", "source": "validation"}
86
+ {"text": "### Question:\nWho acquired tom norton?\n\n### SQL:\nSELECT Acquired FROM table WHERE Player = tom norton", "question": "Who acquired tom norton?", "sql": "SELECT Acquired FROM table WHERE Player = tom norton", "source": "validation"}
87
+ {"text": "### Question:\nWHAT IS THE HIGHEST VIEWERS WITH AN EPISODE LESS THAN 15 AND SHARE LAGER THAN 7?\n\n### SQL:\nSELECT MAX Viewers (millions) FROM table WHERE Episode number < 15 AND Share > 7", "question": "WHAT IS THE HIGHEST VIEWERS WITH AN EPISODE LESS THAN 15 AND SHARE LAGER THAN 7?", "sql": "SELECT MAX Viewers (millions) FROM table WHERE Episode number < 15 AND Share > 7", "source": "validation"}
88
+ {"text": "### Question:\nWhat is the English name given to the city of St. John's?\n\n### SQL:\nSELECT Capital ( exonym ) FROM table WHERE Capital ( endonym ) = St. John's", "question": "What is the English name given to the city of St. John's?", "sql": "SELECT Capital ( exonym ) FROM table WHERE Capital ( endonym ) = St. John's", "source": "validation"}
89
+ {"text": "### Question:\nWhat was the result of the game played on November 23, 2003?\n\n### SQL:\nSELECT Result FROM table WHERE Date = november 23, 2003", "question": "What was the result of the game played on November 23, 2003?", "sql": "SELECT Result FROM table WHERE Date = november 23, 2003", "source": "validation"}
90
+ {"text": "### Question:\nWho directed An Egg Scramble?\n\n### SQL:\nSELECT Director FROM table WHERE Title = an egg scramble", "question": "Who directed An Egg Scramble?", "sql": "SELECT Director FROM table WHERE Title = an egg scramble", "source": "validation"}
91
+ {"text": "### Question:\nWhat is the Bulgarian Commander of the Battle of Rusion?\n\n### SQL:\nSELECT Bulgarian Commander FROM table WHERE Battle = battle of rusion", "question": "What is the Bulgarian Commander of the Battle of Rusion?", "sql": "SELECT Bulgarian Commander FROM table WHERE Battle = battle of rusion", "source": "validation"}
92
+ {"text": "### Question:\nWhat was the location of the game when the record was 12-4?\n\n### SQL:\nSELECT Location FROM table WHERE Record = 12-4", "question": "What was the location of the game when the record was 12-4?", "sql": "SELECT Location FROM table WHERE Record = 12-4", "source": "validation"}
93
+ {"text": "### Question:\nWhat Service Name has UTV as the owner?\n\n### SQL:\nSELECT Service name FROM table WHERE Owner = utv", "question": "What Service Name has UTV as the owner?", "sql": "SELECT Service name FROM table WHERE Owner = utv", "source": "validation"}
94
+ {"text": "### Question:\nWhich Rebuilt has a Name as rebuilt of binevanagh?\n\n### SQL:\nSELECT Rebuilt FROM table WHERE Name as rebuilt = binevanagh", "question": "Which Rebuilt has a Name as rebuilt of binevanagh?", "sql": "SELECT Rebuilt FROM table WHERE Name as rebuilt = binevanagh", "source": "validation"}
95
+ {"text": "### Question:\nwhat is the margin of victory when the runner-up is amy alcott and the winning score is \u20139 (72-68-67=207)?\n\n### SQL:\nSELECT Margin of victory FROM table WHERE Runner(s)-up = amy alcott AND Winning score = \u20139 (72-68-67=207)", "question": "what is the margin of victory when the runner-up is amy alcott and the winning score is \u20139 (72-68-67=207)?", "sql": "SELECT Margin of victory FROM table WHERE Runner(s)-up = amy alcott AND Winning score = \u20139 (72-68-67=207)", "source": "validation"}
96
+ {"text": "### Question:\nWhat is the height of the building with 40 floors?\n\n### SQL:\nSELECT Height ft / m FROM table WHERE Floors = 40", "question": "What is the height of the building with 40 floors?", "sql": "SELECT Height ft / m FROM table WHERE Floors = 40", "source": "validation"}
97
+ {"text": "### Question:\nWhat is the total number drawn with goals against less than 55, and a total of 14 losses?\n\n### SQL:\nSELECT COUNT Drawn FROM table WHERE Goals Against < 55 AND Lost = 14", "question": "What is the total number drawn with goals against less than 55, and a total of 14 losses?", "sql": "SELECT COUNT Drawn FROM table WHERE Goals Against < 55 AND Lost = 14", "source": "validation"}
98
+ {"text": "### Question:\nWhich engine from 1973 has a Brabham bt37 chassis?\n\n### SQL:\nSELECT Engine FROM table WHERE Year = 1973 AND Chassis = brabham bt37", "question": "Which engine from 1973 has a Brabham bt37 chassis?", "sql": "SELECT Engine FROM table WHERE Year = 1973 AND Chassis = brabham bt37", "source": "validation"}
99
+ {"text": "### Question:\nTell me the final score for january 9 for cincinnati bengals\n\n### SQL:\nSELECT Final Score FROM table WHERE Date = january 9 AND Host Team = cincinnati bengals", "question": "Tell me the final score for january 9 for cincinnati bengals", "sql": "SELECT Final Score FROM table WHERE Date = january 9 AND Host Team = cincinnati bengals", "source": "validation"}
100
+ {"text": "### Question:\nWhat player was place of t1 in To Par and had a score of 70-73-69=212?\n\n### SQL:\nSELECT To par FROM table WHERE Place = t1 AND Score = 70-73-69=212", "question": "What player was place of t1 in To Par and had a score of 70-73-69=212?", "sql": "SELECT To par FROM table WHERE Place = t1 AND Score = 70-73-69=212", "source": "validation"}
src/outputs/finetuning/visualizations/01_metrics_overview.png ADDED
src/outputs/finetuning/visualizations/02_token_accuracy_dist.png ADDED
src/outputs/finetuning/visualizations/03_keyword_accuracy_dist.png ADDED
src/outputs/finetuning/visualizations/04_training_loss.png ADDED
src/outputs/rag/reports/knowledge_base_report.md ADDED
@@ -0,0 +1,46 @@
1
+ # RAG Knowledge Base Report
2
+
3
+ **Generated:** 2025-12-07 23:58:56
4
+
5
+ ## Overview
6
+
7
+ | Metric | Value |
8
+ |--------|-------|
9
+ | Total Documents | 80,654 |
10
+ | Collection Name | sql_knowledge |
11
+ | Embedding Model | all-MiniLM-L6-v2 |
12
+
13
+ ## Data Sources
14
+
15
+ | Source | Documents |
16
+ |--------|-----------|
17
+ | train | 56,355 |
18
+ | validation | 8,421 |
19
+ | test | 15,878 |
20
+
21
+ ## Chunking Strategies
22
+
23
+ 1. **SQL Clause Extraction**: Identifies SELECT, FROM, WHERE, GROUP BY, etc.
24
+ 2. **Complexity Classification**: Categorizes as simple/intermediate/complex
25
+ 3. **Keyword Extraction**: Extracts SQL operations (JOIN, COUNT, etc.)
26
+ 4. **Size Categorization**: Classifies question/SQL length
27
+
28
+ ## Complexity Distribution
29
+
30
+ | Level | Count |
31
+ |-------|-------|
32
+ | Simple | 80,396 |
33
+ | Intermediate | 258 |
34
+ | Complex | 0 |
35
+
36
+ ## Document Metadata Structure
37
+
38
+ Each document contains:
39
+ - `sql`: The SQL query
40
+ - `source`: Origin dataset
41
+ - `question`: Original question
42
+ - `complexity`: simple/intermediate/complex
43
+ - `sql_clauses`: Comma-separated clauses
44
+ - `keywords`: SQL keywords found
45
+ - `question_size`: short/medium/long
46
+ - `sql_size`: short/medium/long
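
> Editor's note: the metadata fields listed above can double as retrieval filters. A minimal usage sketch (not part of this commit), assuming the collection name, storage path, and embedding model defined in `src/rag/knowledge_base.py` and `src/rag/embeddings.py`:

```python
# Sketch: query the knowledge base with a metadata filter on complexity.
import chromadb
from sentence_transformers import SentenceTransformer

client = chromadb.PersistentClient(path="chromadb_data")   # same dir as knowledge_base.py
collection = client.get_collection("sql_knowledge")        # same collection name
model = SentenceTransformer("all-MiniLM-L6-v2")             # same embedding model

query = "count employees per department"
embedding = model.encode(query).tolist()

# Restrict retrieval to intermediate-complexity examples via the chunking metadata.
results = collection.query(
    query_embeddings=[embedding],
    n_results=3,
    where={"complexity": "intermediate"},
)
for doc, meta in zip(results["documents"][0], results["metadatas"][0]):
    print(doc, "->", meta["sql"])
```
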
src/outputs/rag/stats/knowledge_base_stats.json ADDED
@@ -0,0 +1,22 @@
1
+ {
2
+ "total_documents": 80654,
3
+ "sources": {
4
+ "train": 56355,
5
+ "validation": 8421,
6
+ "test": 15878
7
+ },
8
+ "collection_name": "sql_knowledge",
9
+ "embedding_model": "all-MiniLM-L6-v2",
10
+ "chunking_strategies": [
11
+ "sql_clause_extraction",
12
+ "complexity_classification",
13
+ "keyword_extraction",
14
+ "size_categorization"
15
+ ],
16
+ "complexity_distribution": {
17
+ "simple": 80396,
18
+ "intermediate": 258,
19
+ "complex": 0
20
+ },
21
+ "created_at": "2025-12-07T23:58:56.309706"
22
+ }
src/outputs/synthetic/reports/synthetic_report.md ADDED
@@ -0,0 +1,47 @@
1
+ # Synthetic Data Generation Report
2
+
3
+ **Generated:** 2025-12-07 23:24:17
4
+
5
+ ## Dataset Statistics
6
+
7
+ | Metric | Original | Synthetic |
8
+ |--------|----------|-----------|
9
+ | Samples | 52,527 | 142,639 |
10
+ | Avg Length | 11.64 | 14.75 |
11
+ | Min Length | 3 | 3 |
12
+ | Max Length | 44 | 49 |
13
+ | Unique Words | 50,846 | 60,734 |
14
+
15
+ ## Augmentation Results
16
+
17
+ - **Augmentation Factor:** 2.72x
18
+ - **Avg Diversity Score:** 0.2832
19
+ - **Min Diversity Score:** 0.103
20
+ - **Max Diversity Score:** 0.8
21
+
22
+ ## Techniques Used
23
+
24
+ 1. Synonym Replacement (40% probability)
25
+ 2. Random Insertion (15% probability)
26
+ 3. Random Swap (10% probability)
27
+ 4. Structure Variation (prefix/suffix)
28
+ 5. Case Variation
29
+
30
+ ## Quality Controls
31
+
32
+ - Minimum question length: 10 characters
33
+ - Maximum question length: 500 characters
34
+ - Minimum diversity score: 0.1
35
+ - Duplicate removal via MD5 hashing
36
+
37
+ ## Privacy Measures
38
+
39
+ - Email anonymization
40
+ - Phone number anonymization
41
+ - SSN anonymization
42
+
43
+ ## Visualizations
44
+
45
+ - `01_size_comparison.png` - Dataset size comparison
46
+ - `02_length_distribution.png` - Question length distribution
47
+ - `03_diversity_distribution.png` - Diversity score distribution
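
> Editor's note: the augmentation itself is implemented in `src/synthetic/generate_data.py` (not reproduced here). A toy sketch of the synonym-replacement and MD5 de-duplication steps described above; the synonym map and function names below are hypothetical, not the project's own:

```python
# Illustrative sketch only -- probabilities mirror the report (40% synonym replacement).
import hashlib
import random

SYNONYMS = {"show": ["display", "list"], "find": ["get", "retrieve"]}  # toy synonym map

def augment_question(question: str, p_synonym: float = 0.4) -> str:
    """Replace known words with a synonym with probability p_synonym."""
    out = []
    for word in question.split():
        key = word.lower()
        if key in SYNONYMS and random.random() < p_synonym:
            out.append(random.choice(SYNONYMS[key]))
        else:
            out.append(word)
    return " ".join(out)

def dedupe(questions):
    """Drop duplicates via MD5 hashing, mirroring the quality control above."""
    seen, unique = set(), []
    for q in questions:
        h = hashlib.md5(q.lower().encode()).hexdigest()
        if h not in seen:
            seen.add(h)
            unique.append(q)
    return unique
```
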
src/outputs/synthetic/stats/statistics.json ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ "original": {
3
+ "name": "Original",
4
+ "samples": 52527,
5
+ "avg_length": 11.64,
6
+ "min_length": 3,
7
+ "max_length": 44,
8
+ "unique_words": 50846
9
+ },
10
+ "synthetic": {
11
+ "name": "Synthetic",
12
+ "samples": 142639,
13
+ "avg_length": 14.75,
14
+ "min_length": 3,
15
+ "max_length": 49,
16
+ "unique_words": 60734
17
+ },
18
+ "diversity": {
19
+ "avg": 0.2832,
20
+ "min": 0.103,
21
+ "max": 0.8
22
+ },
23
+ "augmentation_factor": 2.72
24
+ }
src/outputs/synthetic/visualizations/01_size_comparison.png ADDED
src/outputs/synthetic/visualizations/02_length_distribution.png ADDED
src/outputs/synthetic/visualizations/03_diversity_distribution.png ADDED
src/pipeline/integrated.py ADDED
@@ -0,0 +1,584 @@
1
+ """
2
+ Integrated Pipeline: RAG + Fine-tuned Model + Gemini Enhancement
3
+ """
4
+
5
+ import os
6
+ import sys
7
+ import json
8
+ from datetime import datetime
9
+ from dotenv import load_dotenv
10
+
11
+ # Load environment variables from .env file
12
+ load_dotenv()
13
+
14
+ # Add parent directory
15
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
16
+
17
+ # =============================================================================
18
+ # CONFIGURATION
19
+ # =============================================================================
20
+
21
+ OUTPUT_DIR = "outputs/pipeline"
22
+ LOGS_DIR = f"{OUTPUT_DIR}/logs"
23
+
24
+ # Gemini config - loaded from .env with fallbacks
25
+ GEMINI_KEYS = [
26
+ os.getenv("GEMINI_API_KEY"),
27
+ os.getenv("GEMINI_API_KEY_FALLBACK_1"),
28
+ os.getenv("GEMINI_API_KEY_FALLBACK_2"),
29
+ ]
30
+ # Remove None values
31
+ GEMINI_KEYS = [k for k in GEMINI_KEYS if k]
32
+
33
+ GEMINI_MODELS = [
34
+ os.getenv("GEMINI_MODEL", "gemini-2.5-flash"),
35
+ os.getenv("GEMINI_MODEL_FALLBACK_1"),
36
+ ]
37
+ # Remove None values
38
+ GEMINI_MODELS = [m for m in GEMINI_MODELS if m]
39
+
40
+ if not GEMINI_KEYS:
41
+ print("⚠️ Warning: No GEMINI_API_KEY found in .env file")
42
+ else:
43
+ print(f"✓ Found {len(GEMINI_KEYS)} Gemini API key(s)")
44
+ print(f"✓ Found {len(GEMINI_MODELS)} Gemini model(s)")
45
+
46
+ def setup_directories():
47
+ for d in [OUTPUT_DIR, LOGS_DIR]:
48
+ os.makedirs(d, exist_ok=True)
49
+
50
+ # =============================================================================
51
+ # GEMINI CLIENT WITH FALLBACK
52
+ # =============================================================================
53
+
54
+ class GeminiClient:
55
+ """Gemini client with automatic fallback for rate limits."""
56
+
57
+ def __init__(self):
58
+ self.genai = None
59
+ self.current_key_idx = 0
60
+ self.current_model_idx = 0
61
+ self.model = None
62
+ self.initialized = False
63
+
64
+ try:
65
+ import google.generativeai as genai
66
+ self.genai = genai
67
+ self._init_model()
68
+ except ImportError:
69
+ print("✗ google-generativeai not installed")
70
+
71
+ def _init_model(self):
72
+ """Initialize model with current key and model."""
73
+ if not GEMINI_KEYS:
74
+ return False
75
+
76
+ key = GEMINI_KEYS[self.current_key_idx]
77
+ model_name = GEMINI_MODELS[self.current_model_idx]
78
+
79
+ try:
80
+ self.genai.configure(api_key=key)
81
+ self.model = self.genai.GenerativeModel(model_name)
82
+ self.initialized = True
83
+ print(f" Using API key #{self.current_key_idx + 1}, model: {model_name}")
84
+ return True
85
+ except Exception as e:
86
+ print(f" Failed to init Gemini: {e}")
87
+ return False
88
+
89
+ def _switch_to_next(self):
90
+ """Switch to next model or key combination."""
91
+ # Try next model with same key
92
+ if self.current_model_idx < len(GEMINI_MODELS) - 1:
93
+ self.current_model_idx += 1
94
+ print(f" ⟳ Switching to fallback model: {GEMINI_MODELS[self.current_model_idx]}")
95
+ return self._init_model()
96
+
97
+ # Try next key with first model
98
+ if self.current_key_idx < len(GEMINI_KEYS) - 1:
99
+ self.current_key_idx += 1
100
+ self.current_model_idx = 0
101
+ print(f" ⟳ Switching to fallback API key #{self.current_key_idx + 1}")
102
+ return self._init_model()
103
+
104
+ # No more fallbacks
105
+ print(" ✗ All Gemini keys/models exhausted")
106
+ return False
107
+
108
+ def generate(self, prompt, max_retries=None):
109
+ """Generate content with automatic fallback."""
110
+ if not self.initialized or not self.model:
111
+ return None, "Gemini not initialized"
112
+
113
+ # Calculate max retries based on available combinations
114
+ if max_retries is None:
115
+ max_retries = len(GEMINI_KEYS) * len(GEMINI_MODELS)
116
+
117
+ attempts = 0
118
+ while attempts < max_retries:
119
+ try:
120
+ response = self.model.generate_content(prompt)
121
+ return response.text.strip(), None
122
+ except Exception as e:
123
+ error_str = str(e)
124
+
125
+ # Check if rate limit error
126
+ if "429" in error_str or "quota" in error_str.lower() or "rate" in error_str.lower():
127
+ print(f" ⚠️ Rate limit hit")
128
+ if not self._switch_to_next():
129
+ return None, "All API keys exhausted"
130
+ attempts += 1
131
+ else:
132
+ # Other error, don't retry
133
+ return None, error_str
134
+
135
+ return None, "Max retries exceeded"
136
+
137
+ def is_available(self):
138
+ """Check if Gemini is available."""
139
+ return self.initialized and self.model is not None
140
+
141
+
142
+ # =============================================================================
143
+ # COMPONENT IMPORTS
144
+ # =============================================================================
145
+
146
+ def load_components():
147
+ """Load all pipeline components."""
148
+ components = {}
149
+
150
+ # 1. RAG Retriever (using SQLRetriever class)
151
+ try:
152
+ from rag.retriever import SQLRetriever
153
+ components['rag'] = SQLRetriever()
154
+ print("✓ RAG Retriever loaded")
155
+ except Exception as e:
156
+ components['rag'] = None
157
+ print(f"✗ RAG not available: {e}")
158
+
159
+ # 2. Prompt Builder
160
+ try:
161
+ from prompts.prompt_builder import PromptBuilder
162
+ components['prompt_builder'] = PromptBuilder()
163
+ print("✓ Prompt Builder loaded")
164
+ except Exception as e:
165
+ components['prompt_builder'] = None
166
+ print(f"✗ Prompt Builder not available: {e}")
167
+
168
+ # 3. Fine-tuned Model
169
+ try:
170
+ from finetuning.inference import SQLGenerator
171
+ components['finetuned_model'] = SQLGenerator()
172
+ print("✓ Fine-tuned model loaded")
173
+ except Exception as e:
174
+ components['finetuned_model'] = None
175
+ print(f"✗ Fine-tuned model not available: {e}")
176
+
177
+ # 4. Gemini with fallback support
178
+ try:
179
+ if GEMINI_KEYS:
180
+ components['gemini'] = GeminiClient()
181
+ if components['gemini'].is_available():
182
+ print("✓ Gemini loaded")
183
+ else:
184
+ components['gemini'] = None
185
+ print("✗ Gemini failed to initialize")
186
+ else:
187
+ components['gemini'] = None
188
+ print("✗ Gemini not available (no API keys)")
189
+ except Exception as e:
190
+ components['gemini'] = None
191
+ print(f"✗ Gemini not available: {e}")
192
+
193
+ return components
194
+
195
+ # =============================================================================
196
+ # GEMINI ENHANCEMENT PROMPTS
197
+ # =============================================================================
198
+
199
+ GEMINI_REFINE_PROMPT = """You are an SQL expert. Review and enhance this SQL query.
200
+
201
+ Original Question: {question}
202
+
203
+ Generated SQL (by a smaller model):
204
+ {sql}
205
+
206
+ Your tasks:
207
+ 1. Check for syntax errors
208
+ 2. Check for logical errors
209
+ 3. Optimize if possible
210
+ 4. Fix any issues
211
+
212
+ Rules:
213
+ - If the SQL is correct, return it unchanged
214
+ - If it needs fixes, return the corrected version
215
+ - Return ONLY the SQL query, no explanations
216
+
217
+ Enhanced SQL:"""
218
+
219
+ GEMINI_VALIDATE_PROMPT = """You are an SQL validator. Check this SQL query.
220
+
221
+ Question: {question}
222
+ SQL: {sql}
223
+
224
+ Respond in JSON format:
225
+ {{
226
+ "is_valid": true/false,
227
+ "errors": ["list of errors if any"],
228
+ "suggestions": ["list of suggestions if any"],
229
+ "confidence": 0.0-1.0
230
+ }}
231
+
232
+ JSON Response:"""
233
+
234
+ GEMINI_EXPLAIN_PROMPT = """Explain this SQL query in simple terms.
235
+
236
+ SQL: {sql}
237
+
238
+ Provide a brief, beginner-friendly explanation (2-3 sentences):"""
239
+
240
+ # =============================================================================
241
+ # PIPELINE CLASS
242
+ # =============================================================================
243
+
244
+ class IntegratedPipeline:
245
+ """
246
+ Complete pipeline: RAG → Prompt → Fine-tuned Model → Gemini Enhancement
247
+ """
248
+
249
+ def __init__(self):
250
+ setup_directories()
251
+ print("\n" + "=" * 50)
252
+ print("LOADING PIPELINE COMPONENTS")
253
+ print("=" * 50)
254
+ self.components = load_components()
255
+ print("=" * 50 + "\n")
256
+
257
+ # -------------------------------------------------------------------------
258
+ # STEP 1: RAG Retrieval
259
+ # -------------------------------------------------------------------------
260
+ def retrieve_context(self, question, top_k=3):
261
+ """Retrieve similar examples using RAG."""
262
+ if not self.components['rag']:
263
+ return "", []
264
+
265
+ try:
266
+ # Use SQLRetriever's retrieve method
267
+ results = self.components['rag'].retrieve(question, top_k=top_k)
268
+
269
+ # Format as context string
270
+ context = "Similar SQL examples:\n\n"
271
+ examples = []
272
+ for i, r in enumerate(results, 1):
273
+ context += f"Example {i}:\n"
274
+ context += f"Question: {r['question']}\n"
275
+ context += f"SQL: {r['sql']}\n\n"
276
+ examples.append(r)
277
+
278
+ return context, examples
279
+ except Exception as e:
280
+ print(f"RAG error: {e}")
281
+ return "", []
282
+
283
+ def retrieve_context_formatted(self, question, top_k=3):
284
+ """Use SQLRetriever's built-in context formatting."""
285
+ if not self.components['rag']:
286
+ return ""
287
+
288
+ try:
289
+ return self.components['rag'].retrieve_as_context(question, top_k=top_k)
290
+ except Exception as e:
291
+ print(f"RAG error: {e}")
292
+ return ""
293
+
294
+ # -------------------------------------------------------------------------
295
+ # STEP 2: Build Prompt
296
+ # -------------------------------------------------------------------------
297
+ def build_prompt(self, question, rag_context):
298
+ """Build prompt with context."""
299
+ if self.components['prompt_builder']:
300
+ result = self.components['prompt_builder'].build_prompt(
301
+ question=question,
302
+ rag_context=rag_context
303
+ )
304
+ if result['success']:
305
+ return result['prompt']
306
+
307
+ # Fallback: simple prompt
308
+ if rag_context:
309
+ return f"{rag_context}\nQuestion: {question}\n\nSQL:"
310
+ return f"Generate SQL for: {question}\n\nSQL:"
311
+
312
+ # -------------------------------------------------------------------------
313
+ # STEP 3: Fine-tuned Model Generation
314
+ # -------------------------------------------------------------------------
315
+ def generate_with_finetuned(self, question, context=""):
316
+ """Generate SQL using fine-tuned model."""
317
+ if not self.components['finetuned_model']:
318
+ return None, "Fine-tuned model not available"
319
+
320
+ try:
321
+ sql = self.components['finetuned_model'].generate(question, context)
322
+ return sql, None
323
+ except Exception as e:
324
+ return None, str(e)
325
+
326
+ # -------------------------------------------------------------------------
327
+ # STEP 4: Gemini Enhancement
328
+ # -------------------------------------------------------------------------
329
+ def enhance_with_gemini(self, question, sql):
330
+ """Use Gemini to refine/validate the SQL."""
331
+ if not self.components['gemini']:
332
+ return sql, {"enhanced": False, "reason": "Gemini not available"}
333
+
334
+ try:
335
+ prompt = GEMINI_REFINE_PROMPT.format(question=question, sql=sql)
336
+ enhanced_sql, error = self.components['gemini'].generate(prompt)
337
+
338
+ if error:
339
+ return sql, {"enhanced": False, "reason": error}
340
+
341
+ # Clean up response
342
+ enhanced_sql = self._clean_sql(enhanced_sql)
343
+
344
+ return enhanced_sql, {"enhanced": True, "original": sql}
345
+ except Exception as e:
346
+ return sql, {"enhanced": False, "reason": str(e)}
347
+
348
+ def validate_with_gemini(self, question, sql):
349
+ """Use Gemini to validate SQL."""
350
+ if not self.components['gemini']:
351
+ return {"is_valid": True, "confidence": 0.5}
352
+
353
+ try:
354
+ prompt = GEMINI_VALIDATE_PROMPT.format(question=question, sql=sql)
355
+ text, error = self.components['gemini'].generate(prompt)
356
+
357
+ if error:
358
+ return {"is_valid": True, "confidence": 0.5, "error": error}
359
+
360
+ # Remove markdown code blocks if present
361
+ if text.startswith("```"):
362
+ text = text.split("```")[1]
363
+ if text.startswith("json"):
364
+ text = text[4:]
365
+
366
+ return json.loads(text)
367
+ except Exception:
368
+ return {"is_valid": True, "confidence": 0.5}
369
+
370
+ def explain_with_gemini(self, sql):
371
+ """Use Gemini to explain the SQL."""
372
+ if not self.components['gemini']:
373
+ return "Explanation not available"
374
+
375
+ try:
376
+ prompt = GEMINI_EXPLAIN_PROMPT.format(sql=sql)
377
+ explanation, error = self.components['gemini'].generate(prompt)
378
+
379
+ if error:
380
+ return f"Explanation error: {error}"
381
+
382
+ return explanation
383
+ except Exception as e:
384
+ return f"Explanation error: {e}"
385
+
386
+ # -------------------------------------------------------------------------
387
+ # MAIN PIPELINE
388
+ # -------------------------------------------------------------------------
389
+ def run(self, question, enhance=True, validate=False, explain=False, top_k=3):
390
+ """
391
+ Run the complete pipeline.
392
+
393
+ Args:
394
+ question: Natural language question
395
+ enhance: Use Gemini to enhance SQL
396
+ validate: Use Gemini to validate SQL
397
+ explain: Use Gemini to explain SQL
398
+ top_k: Number of RAG examples to retrieve
399
+
400
+ Returns:
401
+ dict with all results
402
+ """
403
+ result = {
404
+ 'question': question,
405
+ 'timestamp': datetime.now().isoformat(),
406
+ 'steps': {}
407
+ }
408
+
409
+ # Step 1: RAG Retrieval
410
+ rag_context, examples = self.retrieve_context(question, top_k=top_k)
411
+ result['steps']['rag'] = {
412
+ 'context': rag_context,
413
+ 'examples': examples,
414
+ 'num_examples': len(examples)
415
+ }
416
+
417
+ # Step 2: Build Prompt
418
+ prompt = self.build_prompt(question, rag_context)
419
+ result['steps']['prompt'] = {
420
+ 'prompt': prompt,
421
+ 'length': len(prompt)
422
+ }
423
+
424
+ # Step 3: Fine-tuned Model
425
+ finetuned_sql, error = self.generate_with_finetuned(question, rag_context)
426
+ result['steps']['finetuned'] = {
427
+ 'sql': finetuned_sql,
428
+ 'error': error
429
+ }
430
+
431
+ if not finetuned_sql:
432
+ result['success'] = False
433
+ result['final_sql'] = None
434
+ return result
435
+
436
+ # Step 4: Gemini Enhancement
437
+ if enhance:
438
+ enhanced_sql, enhance_info = self.enhance_with_gemini(question, finetuned_sql)
439
+ result['steps']['gemini_enhance'] = {
440
+ 'sql': enhanced_sql,
441
+ 'info': enhance_info
442
+ }
443
+ result['final_sql'] = enhanced_sql
444
+ else:
445
+ result['final_sql'] = finetuned_sql
446
+
447
+ # Optional: Validation
448
+ if validate:
449
+ validation = self.validate_with_gemini(question, result['final_sql'])
450
+ result['steps']['validation'] = validation
451
+
452
+ # Optional: Explanation
453
+ if explain:
454
+ explanation = self.explain_with_gemini(result['final_sql'])
455
+ result['explanation'] = explanation
456
+
457
+ result['success'] = True
458
+
459
+ # Log result
460
+ self._log_result(result)
461
+
462
+ return result
463
+
464
+ # -------------------------------------------------------------------------
465
+ # UTILITIES
466
+ # -------------------------------------------------------------------------
467
+ def _clean_sql(self, sql):
468
+ """Clean SQL output."""
469
+ sql = sql.strip()
470
+ # Remove markdown code blocks
471
+ if sql.startswith("```"):
472
+ lines = sql.split("\n")
473
+ sql = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
474
+ # Remove leading 'sql' keyword
475
+ if sql.lower().startswith("sql"):
476
+ sql = sql[3:].strip()
477
+ return sql
478
+
479
+ def _log_result(self, result):
480
+ """Log pipeline result."""
481
+ log_file = f"{LOGS_DIR}/pipeline_log.jsonl"
482
+ # Remove examples from log to save space
483
+ log_result = {k: (dict(v) if isinstance(v, dict) else v) for k, v in result.items()}  # copy nested dicts so the caller's result is not mutated below
484
+ if 'steps' in log_result and 'rag' in log_result['steps']:
485
+ log_result['steps']['rag'] = {
486
+ 'num_examples': log_result['steps']['rag'].get('num_examples', 0)
487
+ }
488
+ with open(log_file, 'a') as f:
489
+ f.write(json.dumps(log_result, default=str) + '\n')
490
+
491
+ def get_component_status(self):
492
+ """Get status of all components."""
493
+ return {
494
+ 'rag': self.components['rag'] is not None,
495
+ 'prompt_builder': self.components['prompt_builder'] is not None,
496
+ 'finetuned_model': self.components['finetuned_model'] is not None,
497
+ 'gemini': self.components['gemini'] is not None
498
+ }
499
+
500
+ # =============================================================================
501
+ # SIMPLE INTERFACE
502
+ # =============================================================================
503
+
504
+ _pipeline = None
505
+
506
+ def get_pipeline():
507
+ """Get or create pipeline instance."""
508
+ global _pipeline
509
+ if _pipeline is None:
510
+ _pipeline = IntegratedPipeline()
511
+ return _pipeline
512
+
513
+ def generate_sql(question, enhance=True, explain=False):
514
+ """Simple function to generate SQL."""
515
+ pipeline = get_pipeline()
516
+ result = pipeline.run(question, enhance=enhance, explain=explain)
517
+
518
+ if result['success']:
519
+ return result['final_sql']
520
+ return None
521
+
522
+ # =============================================================================
523
+ # TEST
524
+ # =============================================================================
525
+
526
+ def test_pipeline():
527
+ """Test the integrated pipeline."""
528
+
529
+ print("=" * 60)
530
+ print("TESTING INTEGRATED PIPELINE")
531
+ print("=" * 60)
532
+
533
+ pipeline = IntegratedPipeline()
534
+
535
+ # Show component status
536
+ print("\nComponent Status:")
537
+ status = pipeline.get_component_status()
538
+ for comp, loaded in status.items():
539
+ icon = "✓" if loaded else "✗"
540
+ print(f" {icon} {comp}")
541
+
542
+ questions = [
543
+ "Find all employees with salary above 50000",
544
+ ]
545
+
546
+ for q in questions:
547
+ print(f"\n{'='*60}")
548
+ print(f"Question: {q}")
549
+ print("-" * 60)
550
+
551
+ result = pipeline.run(q, enhance=True, explain=True, top_k=3)
552
+
553
+ # Show RAG results
554
+ print(f"\n[RAG] Retrieved {result['steps']['rag']['num_examples']} examples")
555
+
556
+ # Show fine-tuned output
557
+ print(f"\n[Fine-tuned Model]")
558
+ if result['steps']['finetuned']['sql']:
559
+ print(f" SQL: {result['steps']['finetuned']['sql']}")
560
+ else:
561
+ print(f" Error: {result['steps']['finetuned']['error']}")
562
+
563
+ # Show Gemini enhancement
564
+ if 'gemini_enhance' in result['steps']:
565
+ print(f"\n[Gemini Enhanced]")
566
+ print(f" SQL: {result['steps']['gemini_enhance']['sql']}")
567
+ if result['steps']['finetuned']['sql'] != result['steps']['gemini_enhance']['sql']:
568
+ print(f" ✨ Query was improved!")
569
+
570
+ # Show final
571
+ print(f"\n[Final SQL]")
572
+ print(f" {result['final_sql']}")
573
+
574
+ # Show explanation
575
+ if 'explanation' in result:
576
+ print(f"\n[Explanation]")
577
+ print(f" {result['explanation']}")
578
+
579
+ print("\n" + "=" * 60)
580
+ print("✓ Pipeline test complete")
581
+ print("=" * 60)
582
+
583
+ if __name__ == "__main__":
584
+ test_pipeline()
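
> Editor's note: a short usage sketch of the module above (assumes it is run from `src/` so `pipeline.integrated` is importable and the `.env` keys are configured):

```python
from pipeline.integrated import IntegratedPipeline, generate_sql

# One-liner: RAG -> fine-tuned model -> optional Gemini refinement
sql = generate_sql("Find all employees with salary above 50000", enhance=True)
print(sql)

# Full pipeline with validation and explanation
pipeline = IntegratedPipeline()
result = pipeline.run(
    "Count orders per customer in 2023",
    enhance=True, validate=True, explain=True, top_k=3,
)
print(result["final_sql"])
print(result.get("explanation"))
```
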
src/prompts/__init__.py ADDED
File without changes
src/prompts/prompt_builder.py ADDED
@@ -0,0 +1,440 @@
1
+ """
2
+ Prompt Builder for SQL Learning Assistant
3
+ Handles: Context Management, User Interaction Flows, Edge Cases
4
+ """
5
+
6
+ import re
7
+ import os
8
+ import sys
9
+ import json
10
+ from datetime import datetime
11
+
12
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
13
+ from prompts.system_prompts import (
14
+ get_system_prompt,
15
+ get_prompt_template,
16
+ CLARIFICATION_PROMPT,
17
+ ERROR_RECOVERY_PROMPT
18
+ )
19
+
20
+ # =============================================================================
21
+ # OUTPUT DIRECTORIES
22
+ # =============================================================================
23
+
24
+ OUTPUT_DIR = "outputs/prompts"
25
+ LOGS_DIR = f"{OUTPUT_DIR}/logs"
26
+
27
+ def setup_directories():
28
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
29
+ os.makedirs(LOGS_DIR, exist_ok=True)
30
+
31
+ # =============================================================================
32
+ # CONTEXT MANAGEMENT
33
+ # =============================================================================
34
+
35
+ class ConversationContext:
36
+ """
37
+ Manages conversation history and context for multi-turn interactions.
38
+ """
39
+
40
+ def __init__(self, max_history=5):
41
+ self.history = []
42
+ self.max_history = max_history
43
+ self.current_tables = []
44
+ self.current_schema = {}
45
+ self.user_preferences = {}
46
+
47
+ def add_turn(self, question, sql_response, success=True):
48
+ """Add a conversation turn to history."""
49
+ self.history.append({
50
+ 'question': question,
51
+ 'sql': sql_response,
52
+ 'success': success,
53
+ 'timestamp': datetime.now().isoformat()
54
+ })
55
+
56
+ # Keep only recent history
57
+ if len(self.history) > self.max_history:
58
+ self.history = self.history[-self.max_history:]
59
+
60
+ def get_history_context(self):
61
+ """Format history for prompt injection."""
62
+ if not self.history:
63
+ return ""
64
+
65
+ context = "Previous conversation:\n"
66
+ for turn in self.history[-3:]: # Last 3 turns
67
+ context += f"Q: {turn['question']}\n"
68
+ context += f"SQL: {turn['sql']}\n\n"
69
+
70
+ return context
71
+
72
+ def set_schema(self, schema_dict):
73
+ """Set current database schema context."""
74
+ self.current_schema = schema_dict
75
+
76
+ def get_schema_context(self):
77
+ """Format schema for prompt injection."""
78
+ if not self.current_schema:
79
+ return ""
80
+
81
+ context = "Available tables and columns:\n"
82
+ for table, columns in self.current_schema.items():
83
+ context += f"- {table}: {', '.join(columns)}\n"
84
+
85
+ return context
86
+
87
+ def clear(self):
88
+ """Clear conversation history."""
89
+ self.history = []
90
+ self.current_tables = []
91
+ self.current_schema = {}
92
+
93
+ # =============================================================================
94
+ # QUERY ANALYSIS (For Specialized Flows)
95
+ # =============================================================================
96
+
97
+ def analyze_query_intent(question):
98
+ """
99
+ Analyze user question to determine query type and intent.
100
+ Returns: dict with query_type, keywords, entities
101
+ """
102
+ question_lower = question.lower()
103
+
104
+ # Detect query type
105
+ query_type = 'general'
106
+
107
+ # Aggregation patterns
108
+ agg_patterns = ['count', 'sum', 'average', 'avg', 'total', 'maximum', 'max',
109
+ 'minimum', 'min', 'how many', 'what is the total']
110
+ if any(p in question_lower for p in agg_patterns):
111
+ query_type = 'aggregation'
112
+
113
+ # Complex query patterns
114
+ complex_patterns = ['join', 'combine', 'merge', 'from multiple', 'across tables',
115
+ 'subquery', 'nested', 'with the highest', 'with the lowest']
116
+ if any(p in question_lower for p in complex_patterns):
117
+ query_type = 'complex'
118
+
119
+ # Modification patterns
120
+ mod_patterns = ['insert', 'add new', 'update', 'change', 'modify', 'delete', 'remove']
121
+ if any(p in question_lower for p in mod_patterns):
122
+ query_type = 'modification'
123
+
124
+ # Simple patterns (if nothing else matched)
125
+ simple_patterns = ['show', 'list', 'get', 'find', 'select', 'display']
126
+ if query_type == 'general' and any(p in question_lower for p in simple_patterns):
127
+ query_type = 'simple'
128
+
129
+ # Extract potential keywords
130
+ keywords = []
131
+ sql_keywords = ['where', 'group by', 'order by', 'having', 'limit', 'join',
132
+ 'distinct', 'between', 'like', 'in']
133
+ for kw in sql_keywords:
134
+ if kw in question_lower:
135
+ keywords.append(kw.upper())
136
+
137
+ return {
138
+ 'query_type': query_type,
139
+ 'keywords': keywords,
140
+ 'question_length': len(question.split())
141
+ }
142
+
143
+ # =============================================================================
144
+ # EDGE CASE HANDLING
145
+ # =============================================================================
146
+
147
+ def detect_edge_cases(question):
148
+ """
149
+ Detect potential edge cases in user question.
150
+ Returns: list of edge case types detected
151
+ """
152
+ edge_cases = []
153
+ question_lower = question.lower()
154
+
155
+ # Empty or too short
156
+ if len(question.strip()) < 5:
157
+ edge_cases.append('too_short')
158
+
159
+ # Too vague
160
+ vague_patterns = ['something', 'stuff', 'things', 'data', 'information']
161
+ if any(p in question_lower for p in vague_patterns) and len(question.split()) < 5:
162
+ edge_cases.append('too_vague')
163
+
164
+ # Multiple questions
165
+ if question.count('?') > 1:
166
+ edge_cases.append('multiple_questions')
167
+
168
+ # Contains SQL (user pasted SQL instead of question)
169
+ sql_patterns = ['SELECT', 'INSERT', 'UPDATE', 'DELETE', 'FROM', 'WHERE']
170
+ if sum(1 for p in sql_patterns if p in question.upper()) >= 2:
171
+ edge_cases.append('contains_sql')
172
+
173
+ # Potentially dangerous operations
174
+ dangerous_patterns = ['drop table', 'truncate', 'delete all', 'remove all']
175
+ if any(p in question_lower for p in dangerous_patterns):
176
+ edge_cases.append('dangerous_operation')
177
+
178
+ # Non-SQL question
179
+ non_sql_patterns = ['weather', 'hello', 'how are you', 'thank', 'bye']
180
+ if any(p in question_lower for p in non_sql_patterns):
181
+ edge_cases.append('not_sql_related')
182
+
183
+ return edge_cases
184
+
185
+ def handle_edge_case(edge_case_type, question):
186
+ """
187
+ Generate appropriate response for edge cases.
188
+ Returns: (should_continue, message)
189
+ """
190
+ responses = {
191
+ 'too_short': (
192
+ False,
193
+ "Your question is too short. Please provide more details about what data you want to retrieve."
194
+ ),
195
+ 'too_vague': (
196
+ False,
197
+ "Your question is a bit vague. Could you specify:\n- Which table(s) to query?\n- What columns to retrieve?\n- Any conditions to filter by?"
198
+ ),
199
+ 'multiple_questions': (
200
+ False,
201
+ "I detected multiple questions. Please ask one question at a time for accurate SQL generation."
202
+ ),
203
+ 'contains_sql': (
204
+ False,
205
+ "It looks like you've pasted SQL code. Please describe what you want in natural language, and I'll generate the SQL for you."
206
+ ),
207
+ 'dangerous_operation': (
208
+ False,
209
+ "⚠️ This appears to be a destructive operation (DROP/TRUNCATE/DELETE ALL). Please confirm you want to proceed or rephrase your question."
210
+ ),
211
+ 'not_sql_related': (
212
+ False,
213
+ "I'm an SQL assistant. Please ask me questions about querying databases, and I'll help generate SQL queries."
214
+ )
215
+ }
216
+
217
+ return responses.get(edge_case_type, (True, ""))
218
+
219
+ # =============================================================================
220
+ # PROMPT BUILDER CLASS
221
+ # =============================================================================
222
+
223
+ class PromptBuilder:
224
+ """
225
+ Main class for building prompts with context management.
226
+ """
227
+
228
+ def __init__(self):
229
+ self.context = ConversationContext()
230
+ self.log_file = None
231
+ setup_directories()
232
+
233
+ def build_prompt(self, question, rag_context="", include_history=True):
234
+ """
235
+ Build complete prompt for SQL generation.
236
+
237
+ Args:
238
+ question: User's natural language question
239
+ rag_context: Retrieved examples from RAG
240
+ include_history: Whether to include conversation history
241
+
242
+ Returns:
243
+ dict with 'success', 'prompt' or 'error'
244
+ """
245
+ # Check for edge cases
246
+ edge_cases = detect_edge_cases(question)
247
+
248
+ if edge_cases:
249
+ should_continue, message = handle_edge_case(edge_cases[0], question)
250
+ if not should_continue:
251
+ return {
252
+ 'success': False,
253
+ 'error': message,
254
+ 'edge_case': edge_cases[0]
255
+ }
256
+
257
+ # Analyze query intent
258
+ intent = analyze_query_intent(question)
259
+
260
+ # Get appropriate system prompt
261
+ system_prompt = get_system_prompt(intent['query_type'])
262
+
263
+ # Build context parts
264
+ context_parts = []
265
+
266
+ # Add schema context if available
267
+ schema_context = self.context.get_schema_context()
268
+ if schema_context:
269
+ context_parts.append(schema_context)
270
+
271
+ # Add conversation history
272
+ if include_history:
273
+ history_context = self.context.get_history_context()
274
+ if history_context:
275
+ context_parts.append(history_context)
276
+
277
+ # Add RAG context
278
+ if rag_context:
279
+ context_parts.append(rag_context)
280
+
281
+ # Build final prompt
282
+ if rag_context:
283
+ template = get_prompt_template('rag')
284
+ prompt = template.format(
285
+ context=rag_context,
286
+ question=question
287
+ )
288
+ else:
289
+ template = get_prompt_template('zero_shot')
290
+ prompt = template.format(question=question)
291
+
292
+ # Combine everything
293
+ full_prompt = f"{system_prompt}\n\n"
294
+ if context_parts:
295
+ full_prompt += "\n".join(context_parts) + "\n\n"
296
+ full_prompt += prompt
297
+
298
+ # Log the prompt
299
+ self._log_prompt(question, intent, full_prompt)
300
+
301
+ return {
302
+ 'success': True,
303
+ 'prompt': full_prompt,
304
+ 'system_prompt': system_prompt,
305
+ 'query_type': intent['query_type'],
306
+ 'keywords': intent['keywords']
307
+ }
308
+
309
+ def add_response(self, question, sql_response, success=True):
310
+ """Add a completed interaction to history."""
311
+ self.context.add_turn(question, sql_response, success)
312
+
313
+ def set_schema(self, schema_dict):
314
+ """Set database schema for context."""
315
+ self.context.set_schema(schema_dict)
316
+
317
+ def clear_context(self):
318
+ """Clear all context."""
319
+ self.context.clear()
320
+
321
+ def _log_prompt(self, question, intent, prompt):
322
+ """Log prompt for debugging/analysis."""
323
+ log_entry = {
324
+ 'timestamp': datetime.now().isoformat(),
325
+ 'question': question,
326
+ 'intent': intent,
327
+ 'prompt_length': len(prompt)
328
+ }
329
+
330
+ log_file = f"{LOGS_DIR}/prompt_log.jsonl"
331
+ with open(log_file, 'a') as f:
332
+ f.write(json.dumps(log_entry) + '\n')
333
+
334
+ # =============================================================================
335
+ # USER INTERACTION FLOWS
336
+ # =============================================================================
337
+
338
+ def get_clarification_questions(question, intent):
339
+ """
340
+ Generate clarification questions for ambiguous queries.
341
+ """
342
+ clarifications = []
343
+
344
+ # Generic clarifications based on query type
345
+ if intent['query_type'] == 'aggregation':
346
+ clarifications.append("Which column should be aggregated?")
347
+ clarifications.append("Should results be grouped by any column?")
348
+
349
+ if intent['query_type'] == 'complex':
350
+ clarifications.append("Which tables need to be joined?")
351
+ clarifications.append("What is the relationship between the tables?")
352
+
353
+ # Check for missing specifics
354
+ if 'table' not in question.lower():
355
+ clarifications.append("Which table(s) should be queried?")
356
+
357
+ if not any(word in question.lower() for word in ['all', 'specific', 'where', 'filter']):
358
+ clarifications.append("Do you want all records or filtered results?")
359
+
360
+ return clarifications
361
+
362
+ def create_error_recovery_prompt(original_question, error_message):
363
+ """
364
+ Create prompt for recovering from errors.
365
+ """
366
+ return ERROR_RECOVERY_PROMPT.format(
367
+ error=error_message,
368
+ question=original_question
369
+ )
370
+
371
+ # =============================================================================
372
+ # TEST
373
+ # =============================================================================
374
+
375
+ def test_prompt_builder():
376
+ """Test the prompt builder functionality."""
377
+
378
+ print("=" * 60)
379
+ print("TESTING PROMPT BUILDER")
380
+ print("=" * 60)
381
+
382
+ builder = PromptBuilder()
383
+
384
+ # Test 1: Normal question
385
+ print("\n[TEST 1] Normal Question")
386
+ print("-" * 40)
387
+ result = builder.build_prompt(
388
+ "Find all employees with salary above 50000",
389
+ rag_context="Example 1:\nQ: Get workers earning more than 40000\nSQL: SELECT * FROM employees WHERE salary > 40000"
390
+ )
391
+ print(f"Success: {result['success']}")
392
+ print(f"Query Type: {result.get('query_type')}")
393
+ print(f"Prompt Length: {len(result.get('prompt', ''))}")
394
+
395
+ # Test 2: Edge case - too short
396
+ print("\n[TEST 2] Edge Case - Too Short")
397
+ print("-" * 40)
398
+ result = builder.build_prompt("SQL")
399
+ print(f"Success: {result['success']}")
400
+ print(f"Error: {result.get('error', 'None')}")
401
+
402
+ # Test 3: Edge case - contains SQL
403
+ print("\n[TEST 3] Edge Case - Contains SQL")
404
+ print("-" * 40)
405
+ result = builder.build_prompt("SELECT * FROM users WHERE id = 1")
406
+ print(f"Success: {result['success']}")
407
+ print(f"Error: {result.get('error', 'None')}")
408
+
409
+ # Test 4: Edge case - dangerous operation
410
+ print("\n[TEST 4] Edge Case - Dangerous Operation")
411
+ print("-" * 40)
412
+ result = builder.build_prompt("Drop table users")
413
+ print(f"Success: {result['success']}")
414
+ print(f"Error: {result.get('error', 'None')}")
415
+
416
+ # Test 5: Aggregation query
417
+ print("\n[TEST 5] Aggregation Query")
418
+ print("-" * 40)
419
+ result = builder.build_prompt("Count total orders by customer")
420
+ print(f"Success: {result['success']}")
421
+ print(f"Query Type: {result.get('query_type')}")
422
+
423
+ # Test 6: Context management
424
+ print("\n[TEST 6] Context Management")
425
+ print("-" * 40)
426
+ builder.set_schema({
427
+ 'employees': ['id', 'name', 'salary', 'dept_id'],
428
+ 'departments': ['id', 'name', 'location']
429
+ })
430
+ builder.add_response("Show all employees", "SELECT * FROM employees", success=True)
431
+ result = builder.build_prompt("Now filter by department")
432
+ print(f"Success: {result['success']}")
433
+ print(f"Has History: {'Previous conversation' in result.get('prompt', '')}")
434
+
435
+ print("\n" + "=" * 60)
436
+ print("✓ All tests complete")
437
+ print("=" * 60)
438
+
439
+ if __name__ == "__main__":
440
+ test_prompt_builder()
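
> Editor's note: usage sketch for the builder and its standalone helpers (not part of the commit), assuming imports from `src/`:

```python
from prompts.prompt_builder import PromptBuilder, analyze_query_intent, detect_edge_cases

builder = PromptBuilder()
builder.set_schema({"employees": ["id", "name", "salary", "dept_id"]})
builder.add_response("Show all employees", "SELECT * FROM employees")

result = builder.build_prompt("Count employees per department")
if result["success"]:
    print(result["query_type"])       # expected: 'aggregation'
    print(result["prompt"][:200])
else:
    print("Rejected:", result["error"], result.get("edge_case"))

# The helpers can also be called directly:
print(detect_edge_cases("Drop table users"))          # ['dangerous_operation']
print(analyze_query_intent("Join orders and users"))  # {'query_type': 'complex', ...}
```
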
src/prompts/system_prompts.py ADDED
@@ -0,0 +1,162 @@
1
+ """
2
+ System Prompts for SQL Learning Assistant
3
+ Systematic prompting strategies for different use cases.
4
+ """
5
+
6
+ # =============================================================================
7
+ # BASE SYSTEM PROMPT
8
+ # =============================================================================
9
+
10
+ BASE_SYSTEM_PROMPT = """You are an expert SQL assistant. Your task is to generate accurate SQL queries based on natural language questions.
11
+
12
+ Rules:
13
+ 1. Generate ONLY the SQL query, no explanations unless asked
14
+ 2. Use standard SQL syntax
15
+ 3. Be precise and efficient in your queries
16
+ 4. If the question is ambiguous, make reasonable assumptions
17
+ 5. Always use proper SQL formatting
18
+ """
19
+
20
+ # =============================================================================
21
+ # SPECIALIZED PROMPTS BY USE CASE
22
+ # =============================================================================
23
+
24
+ # For simple SELECT queries
25
+ SIMPLE_QUERY_PROMPT = """You are an SQL assistant specializing in simple queries.
26
+
27
+ Your task: Convert the natural language question into a basic SQL SELECT query.
28
+
29
+ Guidelines:
30
+ - Use simple SELECT, FROM, WHERE clauses
31
+ - Avoid complex joins unless necessary
32
+ - Keep queries straightforward and readable
33
+ """
34
+
35
+ # For complex queries with JOINs
36
+ COMPLEX_QUERY_PROMPT = """You are an SQL assistant specializing in complex queries.
37
+
38
+ Your task: Convert the natural language question into an SQL query that may involve:
39
+ - Multiple JOINs (INNER, LEFT, RIGHT)
40
+ - Subqueries
41
+ - Multiple conditions
42
+ - Aggregations with GROUP BY
43
+
44
+ Guidelines:
45
+ - Use appropriate JOIN types
46
+ - Structure subqueries clearly
47
+ - Use aliases for readability
48
+ """
49
+
50
+ # For aggregation queries
51
+ AGGREGATION_PROMPT = """You are an SQL assistant specializing in aggregation queries.
52
+
53
+ Your task: Convert the natural language question into an SQL query using aggregate functions.
54
+
55
+ Guidelines:
56
+ - Use COUNT, SUM, AVG, MAX, MIN appropriately
57
+ - Include GROUP BY when aggregating
58
+ - Use HAVING for aggregate conditions
59
+ - Consider ORDER BY for ranked results
60
+ """
61
+
62
+ # For data modification (if needed)
63
+ MODIFICATION_PROMPT = """You are an SQL assistant for data modification queries.
64
+
65
+ Your task: Convert the natural language request into INSERT, UPDATE, or DELETE statements.
66
+
67
+ Guidelines:
68
+ - Be cautious with DELETE and UPDATE
69
+ - Always include WHERE clause for UPDATE/DELETE
70
+ - Validate data types for INSERT
71
+ """
72
+
73
+ # =============================================================================
74
+ # PROMPT TEMPLATES WITH CONTEXT
75
+ # =============================================================================
76
+
77
+ RAG_CONTEXT_TEMPLATE = """You are an expert SQL assistant.
78
+
79
+ Here are similar examples to help you:
80
+
81
+ {context}
82
+
83
+ Based on these examples, generate the SQL query for:
84
+ Question: {question}
85
+
86
+ SQL:"""
87
+
88
+ FEW_SHOT_TEMPLATE = """You are an expert SQL assistant. Learn from these examples:
89
+
90
+ {examples}
91
+
92
+ Now generate SQL for this question:
93
+ Question: {question}
94
+
95
+ SQL:"""
96
+
97
+ ZERO_SHOT_TEMPLATE = """You are an expert SQL assistant.
98
+
99
+ Generate the SQL query for:
100
+ Question: {question}
101
+
102
+ SQL:"""
103
+
104
+ # =============================================================================
105
+ # ERROR HANDLING PROMPTS
106
+ # =============================================================================
107
+
108
+ CLARIFICATION_PROMPT = """I need more information to generate the SQL query.
109
+
110
+ Original question: {question}
111
+
112
+ Please clarify:
113
+ {clarification_points}
114
+ """
115
+
116
+ ERROR_RECOVERY_PROMPT = """I encountered an issue with the previous query.
117
+
118
+ Error: {error}
119
+ Original question: {question}
120
+
121
+ Let me try a different approach:
122
+ """
123
+
124
+ # =============================================================================
125
+ # PROMPT SELECTOR
126
+ # =============================================================================
127
+
128
+ def get_system_prompt(query_type='general'):
129
+ """
130
+ Get appropriate system prompt based on query type.
131
+
132
+ Args:
133
+ query_type: 'simple', 'complex', 'aggregation', 'modification', 'general'
134
+
135
+ Returns:
136
+ System prompt string
137
+ """
138
+ prompts = {
139
+ 'simple': SIMPLE_QUERY_PROMPT,
140
+ 'complex': COMPLEX_QUERY_PROMPT,
141
+ 'aggregation': AGGREGATION_PROMPT,
142
+ 'modification': MODIFICATION_PROMPT,
143
+ 'general': BASE_SYSTEM_PROMPT
144
+ }
145
+ return prompts.get(query_type, BASE_SYSTEM_PROMPT)
146
+
147
+ def get_prompt_template(template_type='rag'):
148
+ """
149
+ Get prompt template by type.
150
+
151
+ Args:
152
+ template_type: 'rag', 'few_shot', 'zero_shot'
153
+
154
+ Returns:
155
+ Template string
156
+ """
157
+ templates = {
158
+ 'rag': RAG_CONTEXT_TEMPLATE,
159
+ 'few_shot': FEW_SHOT_TEMPLATE,
160
+ 'zero_shot': ZERO_SHOT_TEMPLATE
161
+ }
162
+ return templates.get(template_type, RAG_CONTEXT_TEMPLATE)
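
> Editor's note: a small sketch of selecting a system prompt and filling the RAG template defined above (example strings are illustrative):

```python
from prompts.system_prompts import get_system_prompt, get_prompt_template

system = get_system_prompt("aggregation")   # unknown types fall back to BASE_SYSTEM_PROMPT
template = get_prompt_template("rag")       # RAG_CONTEXT_TEMPLATE with {context} and {question}

prompt = template.format(
    context="Q: Count users by country\nSQL: SELECT country, COUNT(*) FROM users GROUP BY country",
    question="Count orders by status",
)
print(system + "\n\n" + prompt)
```
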
src/rag/__init__.py ADDED
File without changes
src/rag/embeddings.py ADDED
@@ -0,0 +1,87 @@
1
+ """
2
+ Embedding Module for RAG System
3
+ Uses FREE sentence-transformers (no API costs).
4
+ Gemini is ONLY used for final SQL generation.
5
+ """
6
+
7
+ from sentence_transformers import SentenceTransformer
8
+ import os
9
+
10
+ # =============================================================================
11
+ # FREE LOCAL EMBEDDING MODEL
12
+ # =============================================================================
13
+
14
+ # Using all-MiniLM-L6-v2: fast, good quality, 384 dimensions
15
+ MODEL_NAME = "all-MiniLM-L6-v2"
16
+
17
+ # Global model instance (loaded once)
18
+ _model = None
19
+
20
+ def get_model():
21
+ """Get or load the embedding model."""
22
+ global _model
23
+ if _model is None:
24
+ print(f" Loading embedding model: {MODEL_NAME}")
25
+ _model = SentenceTransformer(MODEL_NAME)
26
+ return _model
27
+
28
+ # =============================================================================
29
+ # EMBEDDING FUNCTIONS
30
+ # =============================================================================
31
+
32
+ def get_embedding(text):
33
+ """Get embedding for a single text."""
34
+ try:
35
+ model = get_model()
36
+ embedding = model.encode(text, convert_to_numpy=True)
37
+ return embedding.tolist()
38
+ except Exception as e:
39
+ print(f"Error getting embedding: {e}")
40
+ return None
41
+
42
+ def get_embeddings_batch(texts):
43
+ """Get embeddings for multiple texts at once (efficient)."""
44
+ try:
45
+ model = get_model()
46
+ embeddings = model.encode(texts, convert_to_numpy=True, show_progress_bar=False)
47
+ return [emb.tolist() for emb in embeddings]
48
+ except Exception as e:
49
+ print(f"Error in batch embedding: {e}")
50
+ return [None] * len(texts)
51
+
52
+ # =============================================================================
53
+ # TEST
54
+ # =============================================================================
55
+
56
+ def test_embedding():
57
+ """Test embedding functionality."""
58
+ print("=" * 50)
59
+ print("TESTING EMBEDDINGS (FREE - No API)")
60
+ print("=" * 50)
61
+
62
+ test_texts = [
63
+ "Find all employees with salary greater than 50000",
64
+ "Show customers who ordered last month",
65
+ "Count products by category"
66
+ ]
67
+
68
+ print(f"\nModel: {MODEL_NAME}")
69
+ print(f"Testing with {len(test_texts)} texts...\n")
70
+
71
+ # Single embedding
72
+ emb = get_embedding(test_texts[0])
73
+ if emb:
74
+ print(f"✓ Single embedding works")
75
+ print(f" Dimension: {len(emb)}")
76
+
77
+ # Batch embedding
78
+ embs = get_embeddings_batch(test_texts)
79
+ if embs and embs[0]:
80
+ print(f"✓ Batch embedding works")
81
+ print(f" Got {len(embs)} embeddings")
82
+
83
+ print("\n✓ All tests passed (FREE - No Gemini used)")
84
+ return True
85
+
86
+ if __name__ == "__main__":
87
+ test_embedding()
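
> Editor's note: sketch (not part of the commit) of ranking candidates by cosine similarity with the local embeddings above; candidate strings are illustrative:

```python
import numpy as np
from rag.embeddings import get_embedding, get_embeddings_batch

query = "Find all employees with salary greater than 50000"
candidates = [
    "Show workers earning more than 40000",
    "Count products by category",
]

q = np.array(get_embedding(query))
cands = np.array(get_embeddings_batch(candidates))

# Cosine similarity: dot(a, b) / (||a|| * ||b||)
sims = cands @ q / (np.linalg.norm(cands, axis=1) * np.linalg.norm(q))
for text, score in sorted(zip(candidates, sims), key=lambda x: -x[1]):
    print(f"{score:.3f}  {text}")
```
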
src/rag/knowledge_base.py ADDED
@@ -0,0 +1,415 @@
1
+ """
2
+ Knowledge Base Builder for RAG System
3
+ Includes: Chunking Strategies, Vector Storage
4
+ """
5
+
6
+ import os
7
+ import pandas as pd
8
+ import chromadb
9
+ import json
10
+ import re
11
+ from datetime import datetime
12
+ import sys
13
+
14
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
15
+ from rag.embeddings import get_embeddings_batch
16
+
17
+ # =============================================================================
18
+ # CONFIGURATION
19
+ # =============================================================================
20
+
21
+ CHROMA_DIR = "chromadb_data"
22
+ COLLECTION_NAME = "sql_knowledge"
23
+ OUTPUT_DIR = "outputs/rag"
24
+ STATS_DIR = f"{OUTPUT_DIR}/stats"
25
+ REPORT_DIR = f"{OUTPUT_DIR}/reports"
26
+
27
+ def setup_directories():
28
+ """Create necessary directories."""
29
+ for d in [CHROMA_DIR, OUTPUT_DIR, STATS_DIR, REPORT_DIR]:
30
+ os.makedirs(d, exist_ok=True)
31
+
32
+ # =============================================================================
33
+ # CHUNKING STRATEGIES
34
+ # =============================================================================
35
+
36
+ def chunk_by_sql_clauses(sql):
37
+ """
38
+ Chunking Strategy 1: Split SQL by clauses.
39
+ Identifies SELECT, FROM, WHERE, GROUP BY, ORDER BY, etc.
40
+ """
41
+ clauses = []
42
+
43
+ # Common SQL clause patterns
44
+ patterns = [
45
+ (r'\bSELECT\b.*?(?=\bFROM\b|$)', 'SELECT'),
46
+ (r'\bFROM\b.*?(?=\bWHERE\b|\bGROUP\b|\bORDER\b|\bLIMIT\b|$)', 'FROM'),
47
+ (r'\bWHERE\b.*?(?=\bGROUP\b|\bORDER\b|\bLIMIT\b|$)', 'WHERE'),
48
+ (r'\bGROUP BY\b.*?(?=\bHAVING\b|\bORDER\b|\bLIMIT\b|$)', 'GROUP BY'),
49
+ (r'\bHAVING\b.*?(?=\bORDER\b|\bLIMIT\b|$)', 'HAVING'),
50
+ (r'\bORDER BY\b.*?(?=\bLIMIT\b|$)', 'ORDER BY'),
51
+ (r'\bLIMIT\b.*', 'LIMIT'),
52
+ ]
53
+
54
+ sql_upper = sql.upper()
55
+ for pattern, clause_name in patterns:
56
+ match = re.search(pattern, sql_upper, re.IGNORECASE | re.DOTALL)
57
+ if match:
58
+ clauses.append(clause_name)
59
+
60
+ return clauses
61
+
62
+ def chunk_by_complexity(question, sql):
63
+ """
64
+ Chunking Strategy 2: Categorize by query complexity.
65
+ """
66
+ sql_upper = sql.upper()
67
+
68
+ # Determine complexity level
69
+ complexity_score = 0
70
+
71
+ # Check for complex features
72
+ if 'JOIN' in sql_upper:
73
+ complexity_score += 2
74
+ if 'SUBQUERY' in sql_upper or sql_upper.count('SELECT') > 1:
75
+ complexity_score += 2
76
+ if 'GROUP BY' in sql_upper:
77
+ complexity_score += 1
78
+ if 'HAVING' in sql_upper:
79
+ complexity_score += 1
80
+ if 'ORDER BY' in sql_upper:
81
+ complexity_score += 1
82
+ if any(agg in sql_upper for agg in ['COUNT', 'SUM', 'AVG', 'MAX', 'MIN']):
83
+ complexity_score += 1
84
+ if 'UNION' in sql_upper:
85
+ complexity_score += 2
86
+
87
+ # Categorize
88
+ if complexity_score <= 1:
89
+ return 'simple'
90
+ elif complexity_score <= 3:
91
+ return 'intermediate'
92
+ else:
93
+ return 'complex'
94
+
95
+ def extract_sql_keywords(sql):
96
+ """
97
+ Chunking Strategy 3: Extract SQL keywords for metadata.
98
+ """
99
+ sql_upper = sql.upper()
100
+
101
+ keywords = []
102
+
103
+ # Operations
104
+ if 'SELECT' in sql_upper:
105
+ keywords.append('SELECT')
106
+ if 'INSERT' in sql_upper:
107
+ keywords.append('INSERT')
108
+ if 'UPDATE' in sql_upper:
109
+ keywords.append('UPDATE')
110
+ if 'DELETE' in sql_upper:
111
+ keywords.append('DELETE')
112
+
113
+ # Joins
114
+ if 'INNER JOIN' in sql_upper:
115
+ keywords.append('INNER JOIN')
116
+ elif 'LEFT JOIN' in sql_upper:
117
+ keywords.append('LEFT JOIN')
118
+ elif 'RIGHT JOIN' in sql_upper:
119
+ keywords.append('RIGHT JOIN')
120
+ elif 'JOIN' in sql_upper:
121
+ keywords.append('JOIN')
122
+
123
+ # Clauses
124
+ for clause in ['WHERE', 'GROUP BY', 'HAVING', 'ORDER BY', 'LIMIT']:
125
+ if clause in sql_upper:
126
+ keywords.append(clause)
127
+
128
+ # Aggregations
129
+ for agg in ['COUNT', 'SUM', 'AVG', 'MAX', 'MIN']:
130
+ if agg in sql_upper:
131
+ keywords.append(agg)
132
+
133
+ # Subqueries
134
+ if sql_upper.count('SELECT') > 1:
135
+ keywords.append('SUBQUERY')
136
+
137
+ return keywords
138
+
139
+ def calculate_chunk_size(text):
140
+ """Calculate appropriate chunk size category."""
141
+ word_count = len(text.split())
142
+
143
+ if word_count <= 10:
144
+ return 'short'
145
+ elif word_count <= 25:
146
+ return 'medium'
147
+ else:
148
+ return 'long'
149
+
150
+ # =============================================================================
151
+ # DOCUMENT PREPARATION WITH CHUNKING
152
+ # =============================================================================
153
+
154
+ def prepare_documents_with_chunking(datasets):
155
+ """
156
+ Prepare documents with chunking metadata.
157
+ Each document gets rich metadata for filtering/ranking.
158
+ """
159
+ documents = []
160
+ metadatas = []
161
+ ids = []
162
+
163
+ idx = 0
164
+ for source, df in datasets.items():
165
+ for _, row in df.iterrows():
166
+ question = str(row['question'])
167
+ sql = str(row['sql'])
168
+
169
+ # Apply chunking strategies
170
+ sql_clauses = chunk_by_sql_clauses(sql)
171
+ complexity = chunk_by_complexity(question, sql)
172
+ keywords = extract_sql_keywords(sql)
173
+ q_size = calculate_chunk_size(question)
174
+ sql_size = calculate_chunk_size(sql)
175
+
176
+ # Create rich metadata
177
+ metadata = {
178
+ 'sql': sql,
179
+ 'source': source,
180
+ 'question': question,
181
+ # Chunking metadata
182
+ 'complexity': complexity,
183
+ 'sql_clauses': ','.join(sql_clauses),
184
+ 'keywords': ','.join(keywords),
185
+ 'question_size': q_size,
186
+ 'sql_size': sql_size,
187
+ 'keyword_count': len(keywords),
188
+ 'clause_count': len(sql_clauses),
189
+ }
190
+
191
+ documents.append(question)
192
+ metadatas.append(metadata)
193
+ ids.append(f"doc_{idx}")
194
+ idx += 1
195
+
196
+ return documents, metadatas, ids
197
+
198
+ # =============================================================================
199
+ # CHROMADB CLIENT
200
+ # =============================================================================
201
+
202
+ def get_chroma_client():
203
+ """Get ChromaDB persistent client."""
204
+ return chromadb.PersistentClient(path=CHROMA_DIR)
205
+
206
+ def get_or_create_collection(client):
207
+ """Get or create the SQL knowledge collection."""
208
+ return client.get_or_create_collection(
209
+ name=COLLECTION_NAME,
210
+ metadata={"description": "SQL question-answer pairs with chunking metadata"}
211
+ )
212
+
213
+ # =============================================================================
214
+ # DATA LOADING
215
+ # =============================================================================
216
+
217
+ def load_datasets(data_dir="data"):
218
+ """Load ALL CSV datasets."""
219
+ datasets = {}
220
+
221
+ files = {
222
+ 'train': 'train.csv',
223
+ 'validation': 'validation.csv',
224
+ 'test': 'test.csv'
225
+ # 'synthetic': 'synthetic.csv'
226
+ }
227
+
228
+ for name, filename in files.items():
229
+ filepath = os.path.join(data_dir, filename)
230
+ if os.path.exists(filepath):
231
+ df = pd.read_csv(filepath)
232
+ datasets[name] = df
233
+ print(f" Loaded {name}: {len(df):,} rows")
234
+ else:
235
+ print(f" Skipped {name}: file not found")
236
+
237
+ return datasets
238
+
239
+ # =============================================================================
240
+ # KNOWLEDGE BASE BUILDING
241
+ # =============================================================================
242
+
243
+ def build_knowledge_base(data_dir="data", batch_size=500):
244
+ """Build knowledge base with chunking strategies."""
245
+
246
+ print("=" * 50)
247
+ print("BUILDING RAG KNOWLEDGE BASE")
248
+ print("With Chunking Strategies")
249
+ print("=" * 50)
250
+
251
+ setup_directories()
252
+
253
+ # Step 1: Load data
254
+ print(f"\n[1/5] Loading datasets...")
255
+ datasets = load_datasets(data_dir)
256
+
257
+ if not datasets:
258
+ print("ERROR: No datasets found!")
259
+ return None
260
+
261
+ total_rows = sum(len(df) for df in datasets.values())
262
+ print(f" Total rows: {total_rows:,}")
263
+
264
+ # Step 2: Prepare documents with chunking
265
+ print(f"\n[2/5] Applying chunking strategies...")
266
+ documents, metadatas, ids = prepare_documents_with_chunking(datasets)
267
+ print(f" Total documents: {len(documents):,}")
268
+
269
+ # Show chunking stats
270
+ complexities = [m['complexity'] for m in metadatas]
271
+ print(f" Complexity distribution:")
272
+ print(f" Simple: {complexities.count('simple'):,}")
273
+ print(f" Intermediate: {complexities.count('intermediate'):,}")
274
+ print(f" Complex: {complexities.count('complex'):,}")
275
+
276
+ # Step 3: Initialize ChromaDB
277
+ print(f"\n[3/5] Initializing ChromaDB...")
278
+ client = get_chroma_client()
279
+
280
+ try:
281
+ client.delete_collection(COLLECTION_NAME)
282
+ print(" Deleted existing collection")
283
+ except Exception:
284
+ pass
285
+
286
+ collection = get_or_create_collection(client)
287
+ print(f" Collection: {COLLECTION_NAME}")
288
+
289
+ # Step 4: Generate embeddings and store
290
+ print(f"\n[4/5] Generating embeddings...")
291
+
292
+ total_added = 0
293
+
294
+ for i in range(0, len(documents), batch_size):
295
+ batch_docs = documents[i:i + batch_size]
296
+ batch_meta = metadatas[i:i + batch_size]
297
+ batch_ids = ids[i:i + batch_size]
298
+
299
+ embeddings = get_embeddings_batch(batch_docs)
300
+
301
+ if embeddings and embeddings[0] is not None:
302
+ collection.add(
303
+ documents=batch_docs,
304
+ metadatas=batch_meta,
305
+ ids=batch_ids,
306
+ embeddings=embeddings
307
+ )
308
+ total_added += len(batch_docs)
309
+
310
+ progress = min(i + batch_size, len(documents))
311
+ pct = (progress / len(documents)) * 100
312
+ print(f" Progress: {progress:,}/{len(documents):,} ({pct:.1f}%)")
313
+
314
+ # Step 5: Save statistics
315
+ print(f"\n[5/5] Saving statistics...")
316
+ stats = {
317
+ 'total_documents': total_added,
318
+ 'sources': {name: len(df) for name, df in datasets.items()},
319
+ 'collection_name': COLLECTION_NAME,
320
+ 'embedding_model': 'all-MiniLM-L6-v2',
321
+ 'chunking_strategies': [
322
+ 'sql_clause_extraction',
323
+ 'complexity_classification',
324
+ 'keyword_extraction',
325
+ 'size_categorization'
326
+ ],
327
+ 'complexity_distribution': {
328
+ 'simple': complexities.count('simple'),
329
+ 'intermediate': complexities.count('intermediate'),
330
+ 'complex': complexities.count('complex')
331
+ },
332
+ 'created_at': datetime.now().isoformat()
333
+ }
334
+
335
+ with open(f'{STATS_DIR}/knowledge_base_stats.json', 'w') as f:
336
+ json.dump(stats, f, indent=2)
337
+
338
+ generate_report(stats)
339
+
340
+ print("\n" + "=" * 50)
341
+ print("COMPLETE")
342
+ print("=" * 50)
343
+ print(f" Documents indexed: {total_added:,}")
344
+ print(f" Storage: {CHROMA_DIR}/")
345
+
346
+ return collection
347
+
348
+ # =============================================================================
349
+ # REPORT GENERATION
350
+ # =============================================================================
351
+
352
+ def generate_report(stats):
353
+ """Generate knowledge base report."""
354
+
355
+ report = f"""# RAG Knowledge Base Report
356
+
357
+ **Generated:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
358
+
359
+ ## Overview
360
+
361
+ | Metric | Value |
362
+ |--------|-------|
363
+ | Total Documents | {stats['total_documents']:,} |
364
+ | Collection Name | {stats['collection_name']} |
365
+ | Embedding Model | {stats['embedding_model']} |
366
+
367
+ ## Data Sources
368
+
369
+ | Source | Documents |
370
+ |--------|-----------|
371
+ """
372
+
373
+ for source, count in stats['sources'].items():
374
+ report += f"| {source} | {count:,} |\n"
375
+
376
+ report += f"""
377
+ ## Chunking Strategies
378
+
379
+ 1. **SQL Clause Extraction**: Identifies SELECT, FROM, WHERE, GROUP BY, etc.
380
+ 2. **Complexity Classification**: Categorizes as simple/intermediate/complex
381
+ 3. **Keyword Extraction**: Extracts SQL operations (JOIN, COUNT, etc.)
382
+ 4. **Size Categorization**: Classifies question/SQL length
383
+
384
+ ## Complexity Distribution
385
+
386
+ | Level | Count |
387
+ |-------|-------|
388
+ | Simple | {stats['complexity_distribution']['simple']:,} |
389
+ | Intermediate | {stats['complexity_distribution']['intermediate']:,} |
390
+ | Complex | {stats['complexity_distribution']['complex']:,} |
391
+
392
+ ## Document Metadata Structure
393
+
394
+ Each document contains:
395
+ - `sql`: The SQL query
396
+ - `source`: Origin dataset
397
+ - `question`: Original question
398
+ - `complexity`: simple/intermediate/complex
399
+ - `sql_clauses`: Comma-separated clauses
400
+ - `keywords`: SQL keywords found
401
+ - `question_size`: short/medium/long
402
+ - `sql_size`: short/medium/long
403
+ """
404
+
405
+ with open(f'{REPORT_DIR}/knowledge_base_report.md', 'w') as f:
406
+ f.write(report)
407
+
408
+ print(f" Report saved to {REPORT_DIR}/")
409
+
410
+ # =============================================================================
411
+ # ENTRY POINT
412
+ # =============================================================================
413
+
414
+ if __name__ == "__main__":
415
+ build_knowledge_base(data_dir="data", batch_size=500)
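For reference, a minimal sketch of how the chunking metadata stored by build_knowledge_base could be used at query time. The path "chromadb_data", the collection name "sql_knowledge", and the complexity field come from this repo; the sample question is hypothetical, and the snippet assumes the collection was built with all-MiniLM-L6-v2 (384-dim) embeddings so that Chroma's default query embedding is dimensionally compatible:

import chromadb

client = chromadb.PersistentClient(path="chromadb_data")
collection = client.get_collection("sql_knowledge")

# Restrict retrieval to examples tagged 'simple' by chunk_by_complexity()
results = collection.query(
    query_texts=["List all customers from Canada"],  # hypothetical question
    n_results=3,
    where={"complexity": "simple"},
)
for question, meta in zip(results["documents"][0], results["metadatas"][0]):
    print(question, "->", meta["sql"])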
src/rag/retriever.py ADDED
@@ -0,0 +1,234 @@
1
+ """
2
+ Retriever Module for RAG System
3
+ Loads from: Local ChromaDB OR HuggingFace Hub
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ from dotenv import load_dotenv
9
+
10
+ load_dotenv()
11
+
12
+ # Try new imports first, fall back to old
13
+ try:
14
+ from langchain_huggingface import HuggingFaceEmbeddings
15
+ from langchain_chroma import Chroma
16
+ except ImportError:
17
+ from langchain_community.vectorstores import Chroma
18
+ from langchain_community.embeddings import HuggingFaceEmbeddings
19
+
20
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
21
+
22
+ # =============================================================================
23
+ # CONFIGURATION
24
+ # =============================================================================
25
+
26
+ LOCAL_CHROMADB_DIR = "chromadb_data"
27
+ HF_CHROMADB_ID = os.getenv("HF_CHROMADB_ID", None)
28
+ COLLECTION_NAME = "sql_knowledge"
29
+ EMBEDDING_MODEL = "all-MiniLM-L6-v2"
30
+
31
+ # =============================================================================
32
+ # CHROMADB LOADER
33
+ # =============================================================================
34
+
35
+ def ensure_chromadb_exists():
36
+ """Ensure ChromaDB data exists - download from HF if needed."""
37
+
38
+ # Check if local has actual ChromaDB files (not just empty folder)
39
+ if os.path.exists(LOCAL_CHROMADB_DIR):
40
+ local_files = os.listdir(LOCAL_CHROMADB_DIR) if os.path.isdir(LOCAL_CHROMADB_DIR) else []
41
+ # ChromaDB creates files like chroma.sqlite3 or folders
42
+ has_chroma_files = any('chroma' in f.lower() or 'sqlite' in f.lower() for f in local_files) or len(local_files) > 2
43
+
44
+ if has_chroma_files:
45
+ print(f"📁 Using local ChromaDB: {LOCAL_CHROMADB_DIR}")
46
+ return LOCAL_CHROMADB_DIR
47
+ else:
48
+ print(f"⚠️ ChromaDB folder exists but is empty or incomplete")
49
+
50
+ # Download from HuggingFace
51
+ if HF_CHROMADB_ID:
52
+ print(f"☁️ Downloading ChromaDB from HuggingFace: {HF_CHROMADB_ID}")
53
+ from huggingface_hub import snapshot_download
54
+
55
+ # Create folder if not exists
56
+ os.makedirs(LOCAL_CHROMADB_DIR, exist_ok=True)
57
+
58
+ snapshot_download(
59
+ repo_id=HF_CHROMADB_ID,
60
+ repo_type="dataset",
61
+ local_dir=LOCAL_CHROMADB_DIR
62
+ )
63
+ print("✓ ChromaDB downloaded!")
64
+ return LOCAL_CHROMADB_DIR
65
+
66
+ # Need to build it from data
67
+ print("⚠️ ChromaDB not found and no HF_CHROMADB_ID set. Building from data...")
68
+ from rag.knowledge_base import build_knowledge_base
69
+ build_knowledge_base(data_dir="data", batch_size=500)
70
+ return LOCAL_CHROMADB_DIR
71
+
72
+ # =============================================================================
73
+ # LANGCHAIN EMBEDDINGS
74
+ # =============================================================================
75
+
76
+ def get_embeddings():
77
+ """Get HuggingFace embeddings for LangChain."""
78
+ return HuggingFaceEmbeddings(
79
+ model_name=EMBEDDING_MODEL,
80
+ model_kwargs={'device': 'cpu'},
81
+ encode_kwargs={'normalize_embeddings': True}
82
+ )
83
+
84
+ # =============================================================================
85
+ # RANKING FUNCTIONS
86
+ # =============================================================================
87
+
88
+ def calculate_relevance_score(result, query):
89
+ """Calculate enhanced relevance score."""
90
+ base_score = result.get('score', 0.5)
91
+ boost = 0.0
92
+
93
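+ # Boost for word overlap between the query and the stored question (at most +0.25)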
+ query_words = set(query.lower().split())
94
+ question_words = set(result.get('question', '').lower().split())
95
+ overlap = len(query_words & question_words)
96
+ if overlap > 0:
97
+ boost += 0.05 * min(overlap, 5)
98
+
99
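+ # Prefer simple examples for short queries and complex examples for long queries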
+ query_length = len(query.split())
100
+ if query_length <= 8 and result.get('complexity') == 'simple':
101
+ boost += 0.1
102
+ elif query_length > 15 and result.get('complexity') == 'complex':
103
+ boost += 0.1
104
+
105
+ return base_score + boost
106
+
107
+ def rerank_results(results, query):
108
+ """Re-rank results using enhanced relevance scoring."""
109
+ for r in results:
110
+ r['relevance_score'] = calculate_relevance_score(r, query)
111
+ results.sort(key=lambda x: x['relevance_score'], reverse=True)
112
+ return results
113
+
114
+ # =============================================================================
115
+ # FILTERING FUNCTIONS
116
+ # =============================================================================
117
+
118
+ def filter_by_threshold(results, min_score=0.0):
119
+ return [r for r in results if r.get('score', 0) >= min_score]
120
+
121
+ def filter_by_complexity(results, complexity=None):
122
+ if complexity is None:
123
+ return results
124
+ return [r for r in results if r.get('complexity') == complexity]
125
+
126
+ # =============================================================================
127
+ # SQL RETRIEVER CLASS
128
+ # =============================================================================
129
+
130
+ class SQLRetriever:
131
+ """LangChain-based retriever with local/HuggingFace support."""
132
+
133
+ def __init__(self):
134
+ """Initialize the retriever."""
135
+ print("Initializing SQL Retriever...")
136
+
137
+ # Ensure ChromaDB exists
138
+ chromadb_path = ensure_chromadb_exists()
139
+
140
+ # Load embeddings
141
+ self.embeddings = get_embeddings()
142
+
143
+ # Load ChromaDB
144
+ self.vectorstore = Chroma(
145
+ collection_name=COLLECTION_NAME,
146
+ persist_directory=chromadb_path,
147
+ embedding_function=self.embeddings
148
+ )
149
+
150
+ self.doc_count = self.vectorstore._collection.count()
151
+ print(f"✓ Loaded {self.doc_count:,} documents from {chromadb_path}")
152
+
153
+ def retrieve(self, query, top_k=5, min_score=None, complexity=None, rerank=True):
154
+ """Retrieve similar questions with filtering and ranking."""
155
+
156
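+ # Over-fetch candidates (up to 3x top_k, capped at 50) so filtering and re-ranking still leave top_k results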
+ fetch_k = min(top_k * 3, 50)
157
+ docs_with_scores = self.vectorstore.similarity_search_with_score(query, k=fetch_k)
158
+
159
+ # Format results
160
+ formatted = []
161
+ for doc, score in docs_with_scores:
162
+ formatted.append({
163
+ 'question': doc.page_content,
164
+ 'sql': doc.metadata.get('sql', ''),
165
+ 'source': doc.metadata.get('source', 'unknown'),
166
+ 'complexity': doc.metadata.get('complexity', 'unknown'),
167
+ 'keywords': doc.metadata.get('keywords', ''),
168
+ 'sql_clauses': doc.metadata.get('sql_clauses', ''),
169
+ 'distance': score,
170
+ 'score': 1 - score if score <= 1 else 1 / (1 + score)
171
+ })
172
+
173
+ # Apply filters
174
+ if min_score is not None:
175
+ formatted = filter_by_threshold(formatted, min_score)
176
+
177
+ if complexity is not None:
178
+ formatted = filter_by_complexity(formatted, complexity)
179
+
180
+ # Apply re-ranking
181
+ if rerank:
182
+ formatted = rerank_results(formatted, query)
183
+
184
+ return formatted[:top_k]
185
+
186
+ def retrieve_as_context(self, query, top_k=5):
187
+ """Retrieve and format as context for LLM prompt."""
188
+ results = self.retrieve(query, top_k=top_k)
189
+
190
+ if not results:
191
+ return ""
192
+
193
+ context = "Similar SQL examples:\n\n"
194
+ for i, r in enumerate(results, 1):
195
+ context += f"Example {i}:\n"
196
+ context += f"Question: {r['question']}\n"
197
+ context += f"SQL: {r['sql']}\n\n"
198
+
199
+ return context
200
+
201
+ def get_stats(self):
202
+ """Get retriever statistics."""
203
+ return {
204
+ 'total_documents': self.doc_count,
205
+ 'collection_name': COLLECTION_NAME,
206
+ 'embedding_model': EMBEDDING_MODEL,
207
+ }
208
+
209
+ # =============================================================================
210
+ # TEST
211
+ # =============================================================================
212
+
213
+ def test_retriever():
214
+ """Test retriever."""
215
+ print("=" * 60)
216
+ print("TESTING SQL RETRIEVER")
217
+ print("=" * 60)
218
+
219
+ retriever = SQLRetriever()
220
+
221
+ query = "Find all employees with salary above 50000"
222
+ results = retriever.retrieve(query, top_k=3)
223
+
224
+ print(f"\nQuery: {query}\n")
225
+ for i, r in enumerate(results, 1):
226
+ print(f"Result {i}: (score: {r['score']:.3f})")
227
+ print(f" Q: {r['question'][:60]}...")
228
+ print(f" SQL: {r['sql'][:60]}...")
229
+ print()
230
+
231
+ print("✓ Test complete")
232
+
233
+ if __name__ == "__main__":
234
+ test_retriever()
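A minimal usage sketch of the retriever above; the question text, score threshold, and complexity filter are illustrative values for the parameters exposed by SQLRetriever.retrieve():

from rag.retriever import SQLRetriever

retriever = SQLRetriever()

# Keep only reasonably similar, intermediate-complexity matches (values are illustrative)
results = retriever.retrieve(
    "total revenue per product category",
    top_k=3,
    min_score=0.3,
    complexity="intermediate",
)

# Or get a ready-to-paste prompt context block
context = retriever.retrieve_as_context("total revenue per product category", top_k=3)
print(context)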
src/requirements.txt ADDED
@@ -0,0 +1,23 @@
1
+ --extra-index-url https://download.pytorch.org/whl/cu130
2
+
3
+
4
+ streamlit
5
+ chromadb
6
+ google-generativeai
7
+ python-dotenv
8
+ pandas
9
+ datasets
10
+ transformers
11
+ peft
12
+ accelerate
13
+ bitsandbytes
14
+ torch
15
+ torchvision
16
+ torchaudio
17
+ sentencepiece
18
+ huggingface_hub
19
+ matplotlib
20
+ sentence-transformers
21
+ langchain
22
+ langchain-community
23
+ langchain-chroma
src/streamlit_app.py DELETED
@@ -1,40 +0,0 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
- import streamlit as st
5
-
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
src/synthetic/__init__.py ADDED
File without changes
src/synthetic/generate_data.py ADDED
@@ -0,0 +1,401 @@
1
+ """
2
+ Synthetic Data Generation for SQL Learning Assistant
3
+
4
+ Covers:
5
+ 1. Create synthetic datasets for training/testing
6
+ 2. Implement data augmentation techniques
7
+ 3. Ensure diversity and quality of generated data
8
+ 4. Address privacy and ethical considerations
9
+ """
10
+
11
+ import pandas as pd
12
+ import random
13
+ import re
14
+ import hashlib
15
+ import json
16
+ from collections import Counter
17
+ from datetime import datetime
18
+ import matplotlib.pyplot as plt
19
+ import os
20
+ import sys
21
+
22
+ # Add parent directory to path for imports
23
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
24
+ from synthetic.synonyms import SYNONYMS, get_synonym, has_synonym
25
+
26
+ # =============================================================================
27
+ # OUTPUT DIRECTORIES
28
+ # =============================================================================
29
+
30
+ OUTPUT_DIR = "outputs/synthetic"
31
+ VIZ_DIR = f"{OUTPUT_DIR}/visualizations"
32
+ REPORT_DIR = f"{OUTPUT_DIR}/reports"
33
+ STATS_DIR = f"{OUTPUT_DIR}/stats"
34
+
35
+ def setup_directories():
36
+ """Create output directories."""
37
+ for d in [OUTPUT_DIR, VIZ_DIR, REPORT_DIR, STATS_DIR]:
38
+ os.makedirs(d, exist_ok=True)
39
+
40
+ # =============================================================================
41
+ # SENTENCE VARIATIONS
42
+ # =============================================================================
43
+
44
+ PREFIXES = ["", "Can you ", "Please ", "I want to ", "I need to ",
45
+ "Could you ", "Help me ", "Show me how to "]
46
+
47
+ SUFFIXES = ["", "?", " please", " for me", " please?"]
48
+
49
+ # =============================================================================
50
+ # AUGMENTATION TECHNIQUES
51
+ # =============================================================================
52
+
53
+ def replace_synonyms(text, prob=0.4):
54
+ """Technique 1: Replace words with synonyms."""
55
+ words = text.split()
56
+ result = []
57
+ for word in words:
58
+ clean = re.sub(r'[^\w]', '', word).lower()
59
+ if has_synonym(clean) and random.random() < prob:
60
+ syn = get_synonym(clean)
61
+ result.append(syn if word[-1] not in '.,?!' else syn + word[-1])
62
+ else:
63
+ result.append(word)
64
+ return ' '.join(result)
65
+
66
+ def random_insertion(text, prob=0.15):
67
+ """Technique 2: Insert contextual words."""
68
+ inserts = ["also", "specifically", "exactly", "just", "only"]
69
+ words = text.split()
70
+ if len(words) > 3 and random.random() < prob:
71
+ pos = random.randint(1, len(words) - 1)
72
+ words.insert(pos, random.choice(inserts))
73
+ return ' '.join(words)
74
+
75
+ def random_swap(text, prob=0.1):
76
+ """Technique 3: Swap adjacent words."""
77
+ words = text.split()
78
+ if len(words) > 4 and random.random() < prob:
79
+ pos = random.randint(1, len(words) - 3)
80
+ words[pos], words[pos + 1] = words[pos + 1], words[pos]
81
+ return ' '.join(words)
82
+
83
+ def structure_variation(text):
84
+ """Technique 4: Add prefixes and suffixes."""
85
+ prefix = random.choice(PREFIXES)
86
+ suffix = random.choice(SUFFIXES)
87
+ if prefix:
88
+ text = text[0].lower() + text[1:] if text else text
89
+ result = prefix + text + suffix
90
+ return result[0].upper() + result[1:] if result else result
91
+
92
+ def case_variation(text):
93
+ """Technique 5: Vary capitalization."""
94
+ r = random.random()
95
+ if r < 0.6:
96
+ return text[0].upper() + text[1:].lower() if text else text
97
+ elif r < 0.85:
98
+ return text.lower()
99
+ return text
100
+
101
+ def generate_variation(question):
102
+ """Apply all augmentation techniques."""
103
+ variation = question
104
+ variation = replace_synonyms(variation)
105
+ variation = random_insertion(variation)
106
+ variation = random_swap(variation)
107
+ variation = structure_variation(variation)
108
+ variation = case_variation(variation)
109
+ return variation
110
+
111
+ # =============================================================================
112
+ # QUALITY AND DIVERSITY
113
+ # =============================================================================
114
+
115
+ def diversity_score(original, variation):
116
+ """Calculate diversity between original and variation."""
117
+ orig_words = set(original.lower().split())
118
+ var_words = set(variation.lower().split())
119
+ if not orig_words or not var_words:
120
+ return 0
121
+ intersection = orig_words & var_words
122
+ union = orig_words | var_words
123
+ return 1 - (len(intersection) / len(union))
124
+
125
+ def quality_check(question, sql):
126
+ """Check if generated data passes quality standards."""
127
+ if not question or len(question.strip()) < 10:
128
+ return False
129
+ if not sql or len(sql.strip()) < 5:
130
+ return False
131
+ if not re.search(r'[a-zA-Z]', question):
132
+ return False
133
+ if len(question) > 500:
134
+ return False
135
+ return True
136
+
137
+ def remove_duplicates(data):
138
+ """Remove duplicate entries."""
139
+ seen = set()
140
+ unique = []
141
+ for item in data:
142
+ normalized = re.sub(r'[^\w\s]', '', item['question'].lower())
143
+ normalized = ' '.join(normalized.split())
144
+ h = hashlib.md5(normalized.encode()).hexdigest()
145
+ if h not in seen:
146
+ seen.add(h)
147
+ unique.append(item)
148
+ return unique
149
+
150
+ # =============================================================================
151
+ # PRIVACY (ETHICAL CONSIDERATIONS)
152
+ # =============================================================================
153
+
154
+ def anonymize(text):
155
+ """Remove sensitive information."""
156
+ text = re.sub(r'\b[\w.-]+@[\w.-]+\.\w+\b', '[EMAIL]', text)
157
+ text = re.sub(r'\b\d{3}[-.]?\d{3}[-.]?\d{4}\b', '[PHONE]', text)
158
+ text = re.sub(r'\b\d{3}-\d{2}-\d{4}\b', '[SSN]', text)
159
+ return text
160
+
161
+ # =============================================================================
162
+ # STATISTICS
163
+ # =============================================================================
164
+
165
+ def calculate_stats(original_df, synthetic_df):
166
+ """Calculate dataset statistics."""
167
+ def get_stats(df, name):
168
+ questions = df['question'].tolist()
169
+ lengths = [len(q.split()) for q in questions]
170
+ return {
171
+ 'name': name,
172
+ 'samples': len(df),
173
+ 'avg_length': round(sum(lengths) / len(lengths), 2),
174
+ 'min_length': min(lengths),
175
+ 'max_length': max(lengths),
176
+ 'unique_words': len(set(' '.join(questions).lower().split()))
177
+ }
178
+
179
+ orig_stats = get_stats(original_df, 'Original')
180
+ synth_stats = get_stats(synthetic_df, 'Synthetic')
181
+
182
+ diversity_scores = synthetic_df['diversity_score'].tolist()
183
+ diversity_stats = {
184
+ 'avg': round(sum(diversity_scores) / len(diversity_scores), 4),
185
+ 'min': round(min(diversity_scores), 4),
186
+ 'max': round(max(diversity_scores), 4)
187
+ }
188
+
189
+ return {
190
+ 'original': orig_stats,
191
+ 'synthetic': synth_stats,
192
+ 'diversity': diversity_stats,
193
+ 'augmentation_factor': round(len(synthetic_df) / len(original_df), 2)
194
+ }
195
+
196
+ # =============================================================================
197
+ # VISUALIZATIONS
198
+ # =============================================================================
199
+
200
+ def create_visualizations(original_df, synthetic_df):
201
+ """Create and save visualizations."""
202
+ plt.style.use('seaborn-v0_8-whitegrid')
203
+
204
+ # 1. Dataset Size Comparison
205
+ fig, ax = plt.subplots(figsize=(8, 5))
206
+ sizes = [len(original_df), len(synthetic_df)]
207
+ bars = ax.bar(['Original', 'Synthetic'], sizes, color=['#3498db', '#2ecc71'])
208
+ for bar, size in zip(bars, sizes):
209
+ ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 20,
210
+ f'{size:,}', ha='center', fontweight='bold')
211
+ ax.set_ylabel('Samples')
212
+ ax.set_title('Dataset Size Comparison')
213
+ plt.savefig(f'{VIZ_DIR}/01_size_comparison.png', dpi=150, bbox_inches='tight')
214
+ plt.close()
215
+
216
+ # 2. Question Length Distribution
217
+ fig, axes = plt.subplots(1, 2, figsize=(12, 4))
218
+ orig_len = [len(q.split()) for q in original_df['question']]
219
+ synth_len = [len(q.split()) for q in synthetic_df['question']]
220
+
221
+ axes[0].hist(orig_len, bins=25, color='#3498db', alpha=0.7)
222
+ axes[0].set_title('Original - Question Length')
223
+ axes[0].set_xlabel('Words')
224
+
225
+ axes[1].hist(synth_len, bins=25, color='#2ecc71', alpha=0.7)
226
+ axes[1].set_title('Synthetic - Question Length')
227
+ axes[1].set_xlabel('Words')
228
+
229
+ plt.tight_layout()
230
+ plt.savefig(f'{VIZ_DIR}/02_length_distribution.png', dpi=150, bbox_inches='tight')
231
+ plt.close()
232
+
233
+ # 3. Diversity Score Distribution
234
+ fig, ax = plt.subplots(figsize=(8, 5))
235
+ ax.hist(synthetic_df['diversity_score'], bins=20, color='#9b59b6', alpha=0.7)
236
+ ax.axvline(synthetic_df['diversity_score'].mean(), color='red', linestyle='--',
237
+ label=f"Mean: {synthetic_df['diversity_score'].mean():.3f}")
238
+ ax.set_xlabel('Diversity Score')
239
+ ax.set_ylabel('Frequency')
240
+ ax.set_title('Diversity Score Distribution')
241
+ ax.legend()
242
+ plt.savefig(f'{VIZ_DIR}/03_diversity_distribution.png', dpi=150, bbox_inches='tight')
243
+ plt.close()
244
+
245
+ print(f" Visualizations saved to {VIZ_DIR}/")
246
+
247
+ # =============================================================================
248
+ # REPORT GENERATION
249
+ # =============================================================================
250
+
251
+ def generate_report(stats):
252
+ """Generate markdown report."""
253
+ report = f"""# Synthetic Data Generation Report
254
+
255
+ **Generated:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
256
+
257
+ ## Dataset Statistics
258
+
259
+ | Metric | Original | Synthetic |
260
+ |--------|----------|-----------|
261
+ | Samples | {stats['original']['samples']:,} | {stats['synthetic']['samples']:,} |
262
+ | Avg Length | {stats['original']['avg_length']} | {stats['synthetic']['avg_length']} |
263
+ | Min Length | {stats['original']['min_length']} | {stats['synthetic']['min_length']} |
264
+ | Max Length | {stats['original']['max_length']} | {stats['synthetic']['max_length']} |
265
+ | Unique Words | {stats['original']['unique_words']:,} | {stats['synthetic']['unique_words']:,} |
266
+
267
+ ## Augmentation Results
268
+
269
+ - **Augmentation Factor:** {stats['augmentation_factor']}x
270
+ - **Avg Diversity Score:** {stats['diversity']['avg']}
271
+ - **Min Diversity Score:** {stats['diversity']['min']}
272
+ - **Max Diversity Score:** {stats['diversity']['max']}
273
+
274
+ ## Techniques Used
275
+
276
+ 1. Synonym Replacement (40% probability)
277
+ 2. Random Insertion (15% probability)
278
+ 3. Random Swap (10% probability)
279
+ 4. Structure Variation (prefix/suffix)
280
+ 5. Case Variation
281
+
282
+ ## Quality Controls
283
+
284
+ - Minimum question length: 10 characters
285
+ - Maximum question length: 500 characters
286
+ - Minimum diversity score: 0.1
287
+ - Duplicate removal via MD5 hashing
288
+
289
+ ## Privacy Measures
290
+
291
+ - Email anonymization
292
+ - Phone number anonymization
293
+ - SSN anonymization
294
+
295
+ ## Visualizations
296
+
297
+ - `01_size_comparison.png` - Dataset size comparison
298
+ - `02_length_distribution.png` - Question length distribution
299
+ - `03_diversity_distribution.png` - Diversity score distribution
300
+ """
301
+
302
+ with open(f'{REPORT_DIR}/synthetic_report.md', 'w') as f:
303
+ f.write(report)
304
+ print(f" Report saved to {REPORT_DIR}/synthetic_report.md")
305
+
306
+ # =============================================================================
307
+ # MAIN PIPELINE
308
+ # =============================================================================
309
+
310
+ def generate_synthetic_data(input_csv, output_csv, sample_size=500, variations=3, min_diversity=0.1):
311
+ """Main synthetic data generation pipeline."""
312
+
313
+ print("=" * 50)
314
+ print("SYNTHETIC DATA GENERATION")
315
+ print("=" * 50)
316
+
317
+ # Setup
318
+ setup_directories()
319
+
320
+ # Load data
321
+ print(f"\n[1/6] Loading {input_csv}...")
322
+ df = pd.read_csv(input_csv)
323
+ sample_df = df.sample(n=min(sample_size, len(df)), random_state=42)
324
+ print(f" Sampled {len(sample_df)} rows")
325
+
326
+ # Generate variations
327
+ print(f"\n[2/6] Generating variations...")
328
+ synthetic_data = []
329
+ skipped = 0
330
+
331
+ for _, row in sample_df.iterrows():
332
+ question = anonymize(str(row['question']))
333
+ sql = anonymize(str(row['sql']))
334
+
335
+ for _ in range(variations):
336
+ variation = generate_variation(question)
337
+ div_score = diversity_score(question, variation)
338
+
339
+ if div_score < min_diversity or not quality_check(variation, sql):
340
+ skipped += 1
341
+ continue
342
+
343
+ synthetic_data.append({
344
+ 'question': variation,
345
+ 'sql': sql,
346
+ 'original_question': question,
347
+ 'diversity_score': round(div_score, 3),
348
+ 'is_synthetic': True
349
+ })
350
+
351
+ print(f" Generated: {len(synthetic_data)}, Skipped: {skipped}")
352
+
353
+ # Remove duplicates
354
+ print(f"\n[3/6] Removing duplicates...")
355
+ before = len(synthetic_data)
356
+ synthetic_data = remove_duplicates(synthetic_data)
357
+ print(f" Removed {before - len(synthetic_data)} duplicates")
358
+
359
+ # Save data
360
+ print(f"\n[4/6] Saving data...")
361
+ synthetic_df = pd.DataFrame(synthetic_data)
362
+ synthetic_df.to_csv(output_csv, index=False)
363
+ print(f" Saved to {output_csv}")
364
+
365
+ # Calculate stats
366
+ print(f"\n[5/6] Calculating statistics...")
367
+ stats = calculate_stats(sample_df, synthetic_df)
368
+
369
+ # Save stats as JSON
370
+ with open(f'{STATS_DIR}/statistics.json', 'w') as f:
371
+ json.dump(stats, f, indent=2)
372
+ print(f" Stats saved to {STATS_DIR}/statistics.json")
373
+
374
+ # Generate visualizations and report
375
+ print(f"\n[6/6] Creating outputs...")
376
+ create_visualizations(sample_df, synthetic_df)
377
+ generate_report(stats)
378
+
379
+ # Summary
380
+ print("\n" + "=" * 50)
381
+ print("COMPLETE")
382
+ print("=" * 50)
383
+ print(f" Original: {stats['original']['samples']:,} samples")
384
+ print(f" Synthetic: {stats['synthetic']['samples']:,} samples")
385
+ print(f" Augmentation: {stats['augmentation_factor']}x")
386
+ print(f" Avg Diversity: {stats['diversity']['avg']}")
387
+
388
+ return synthetic_df
389
+
390
+ # =============================================================================
391
+ # ENTRY POINT
392
+ # =============================================================================
393
+
394
+ if __name__ == "__main__":
395
+ generate_synthetic_data(
396
+ input_csv="data/train.csv",
397
+ output_csv="data/synthetic.csv",
398
+ sample_size=52527,
399
+ variations=3,
400
+ min_diversity=0.1
401
+ )
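A quick sketch of how one augmented variation is produced with the functions above. Every technique is probabilistic, so output differs between runs; the result shown in the comment is only illustrative, and the SQL string is a made-up example:

from synthetic.generate_data import generate_variation, diversity_score, quality_check

question = "Find all employees with salary above 50000"
variation = generate_variation(question)  # e.g. "Can you show all staff with pay above 50000?"

print(variation)
print(diversity_score(question, variation))  # word-level Jaccard distance from the original
print(quality_check(variation, "SELECT * FROM employees WHERE salary > 50000"))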
src/synthetic/synonyms.py ADDED
@@ -0,0 +1,149 @@
1
+ """
2
+ Synonym Dictionary for SQL Question Augmentation
3
+ """
4
+
5
+ import random
6
+
7
+ # =============================================================================
8
+ # SYNONYMS BY CATEGORY
9
+ # =============================================================================
10
+
11
+ # Action Verbs
12
+ QUERY_VERBS = {
13
+ "find": ["get", "show", "display", "list", "retrieve", "fetch", "return", "select"],
14
+ "show": ["display", "list", "get", "find", "retrieve", "present"],
15
+ "get": ["find", "show", "display", "list", "retrieve", "fetch", "obtain"],
16
+ "list": ["show", "display", "get", "find", "enumerate"],
17
+ "retrieve": ["get", "fetch", "find", "obtain", "extract"],
18
+ "select": ["choose", "pick", "get", "retrieve", "find"],
19
+ "search": ["find", "look for", "look up", "query"],
20
+ }
21
+
22
+ CALCULATION_VERBS = {
23
+ "calculate": ["compute", "determine", "find", "figure out", "work out"],
24
+ "compute": ["calculate", "determine", "figure out"],
25
+ "count": ["tally", "enumerate", "number", "total up"],
26
+ "sum": ["total", "add up", "aggregate"],
27
+ "average": ["mean", "find the average of"],
28
+ }
29
+
30
+ MANIPULATION_VERBS = {
31
+ "sort": ["order", "arrange", "rank", "organize"],
32
+ "filter": ["narrow down", "limit", "restrict", "select"],
33
+ "group": ["categorize", "organize", "cluster", "aggregate"],
34
+ "join": ["combine", "merge", "connect", "link"],
35
+ "update": ["modify", "change", "edit", "alter"],
36
+ "delete": ["remove", "erase", "drop", "eliminate"],
37
+ "insert": ["add", "create", "put", "include"],
38
+ }
39
+
40
+ # Comparison Terms
41
+ COMPARISONS = {
42
+ "greater than": ["more than", "above", "exceeding", "over", "higher than"],
43
+ "less than": ["below", "under", "fewer than", "smaller than", "lower than"],
44
+ "equal to": ["equals", "is", "matching", "same as"],
45
+ "between": ["in the range of", "ranging from", "within"],
46
+ "contains": ["includes", "has", "with"],
47
+ }
48
+
49
+ # Aggregation Terms
50
+ AGGREGATIONS = {
51
+ "maximum": ["highest", "largest", "greatest", "max", "top"],
52
+ "minimum": ["lowest", "smallest", "least", "min", "bottom"],
53
+ "average": ["mean", "avg"],
54
+ "total": ["sum", "combined", "overall", "aggregate"],
55
+ "count": ["number of", "how many", "total number of"],
56
+ "distinct": ["unique", "different", "separate"],
57
+ }
58
+
59
+ # Business Entities
60
+ ENTITIES = {
61
+ "employees": ["workers", "staff", "personnel", "team members"],
62
+ "customers": ["clients", "users", "buyers", "patrons"],
63
+ "products": ["items", "goods", "merchandise"],
64
+ "orders": ["purchases", "transactions", "sales"],
65
+ "suppliers": ["vendors", "providers", "distributors"],
66
+ "company": ["firm", "organization", "business"],
67
+ "department": ["dept", "division", "section", "unit"],
68
+ "manager": ["supervisor", "boss", "lead", "head"],
69
+ }
70
+
71
+ # Financial Terms
72
+ FINANCIAL = {
73
+ "price": ["cost", "amount", "value", "rate"],
74
+ "salary": ["pay", "wage", "income", "earnings"],
75
+ "revenue": ["income", "earnings", "sales"],
76
+ "profit": ["earnings", "gain", "margin"],
77
+ "cost": ["price", "expense", "charge"],
78
+ }
79
+
80
+ # Time Terms
81
+ TIME_TERMS = {
82
+ "date": ["day", "time", "period"],
83
+ "year": ["annum", "calendar year"],
84
+ "month": ["period", "calendar month"],
85
+ "recent": ["latest", "newest", "most recent"],
86
+ "current": ["present", "existing", "active"],
87
+ "previous": ["prior", "former", "past", "earlier"],
88
+ "last": ["final", "most recent", "latest"],
89
+ "first": ["initial", "earliest", "beginning"],
90
+ }
91
+
92
+ # Quantifiers
93
+ QUANTIFIERS = {
94
+ "all": ["every", "each", "the entire", "complete"],
95
+ "some": ["a few", "certain", "several"],
96
+ "many": ["numerous", "multiple", "several"],
97
+ "few": ["some", "a small number of", "limited"],
98
+ "only": ["just", "solely", "exclusively"],
99
+ }
100
+
101
+ # Adjectives
102
+ ADJECTIVES = {
103
+ "highest": ["greatest", "maximum", "largest", "top"],
104
+ "lowest": ["smallest", "minimum", "least", "bottom"],
105
+ "active": ["current", "live", "enabled"],
106
+ "inactive": ["disabled", "dormant", "idle"],
107
+ "new": ["recent", "latest", "fresh"],
108
+ "old": ["previous", "former", "past"],
109
+ }
110
+
111
+ # =============================================================================
112
+ # COMBINED DICTIONARY
113
+ # =============================================================================
114
+
115
+ def get_all_synonyms():
116
+ """Combine all synonym dictionaries."""
117
+ all_synonyms = {}
118
+ for d in [QUERY_VERBS, CALCULATION_VERBS, MANIPULATION_VERBS,
119
+ COMPARISONS, AGGREGATIONS, ENTITIES, FINANCIAL,
120
+ TIME_TERMS, QUANTIFIERS, ADJECTIVES]:
121
+ all_synonyms.update(d)
122
+ return all_synonyms
123
+
124
+ SYNONYMS = get_all_synonyms()
125
+
126
+ # =============================================================================
127
+ # UTILITY FUNCTIONS
128
+ # =============================================================================
129
+
130
+ def get_synonym(word):
131
+ """Get a random synonym for a word."""
132
+ word_lower = word.lower()
133
+ if word_lower in SYNONYMS:
134
+ return random.choice(SYNONYMS[word_lower])
135
+ return word
136
+
137
+ def has_synonym(word):
138
+ """Check if a word has synonyms."""
139
+ return word.lower() in SYNONYMS
140
+
141
+ def print_stats():
142
+ """Print synonym statistics."""
143
+ total_words = len(SYNONYMS)
144
+ total_synonyms = sum(len(v) for v in SYNONYMS.values())
145
+ print(f"Total words: {total_words}")
146
+ print(f"Total synonyms: {total_synonyms}")
147
+
148
+ if __name__ == "__main__":
149
+ print_stats()
src/tests/test_finetuned.py ADDED
File without changes
src/tests/test_rag.py ADDED
File without changes
src/tests/test_synthetic.py ADDED
File without changes