Spaces:
Sleeping
Sleeping
moheesh
commited on
Commit
·
f29ea6c
1
Parent(s):
e9aa12a
got all my code
Browse files- Dockerfile +1 -1
- src/.env.example +32 -0
- src/.gitignore +100 -0
- src/README.md +209 -0
- src/app.py +497 -0
- src/config.py +98 -0
- src/finetuning/__init__.py +0 -0
- src/finetuning/evaluate.py +293 -0
- src/finetuning/inference.py +168 -0
- src/finetuning/prepare_data.py +149 -0
- src/finetuning/train.py +218 -0
- src/outputs/finetuning/data_stats.json +7 -0
- src/outputs/finetuning/results/evaluation_report.md +26 -0
- src/outputs/finetuning/results/evaluation_results.json +7 -0
- src/outputs/finetuning/test.jsonl +100 -0
- src/outputs/finetuning/train.jsonl +100 -0
- src/outputs/finetuning/val.jsonl +100 -0
- src/outputs/finetuning/visualizations/01_metrics_overview.png +0 -0
- src/outputs/finetuning/visualizations/02_token_accuracy_dist.png +0 -0
- src/outputs/finetuning/visualizations/03_keyword_accuracy_dist.png +0 -0
- src/outputs/finetuning/visualizations/04_training_loss.png +0 -0
- src/outputs/rag/reports/knowledge_base_report.md +46 -0
- src/outputs/rag/stats/knowledge_base_stats.json +22 -0
- src/outputs/synthetic/reports/synthetic_report.md +47 -0
- src/outputs/synthetic/stats/statistics.json +24 -0
- src/outputs/synthetic/visualizations/01_size_comparison.png +0 -0
- src/outputs/synthetic/visualizations/02_length_distribution.png +0 -0
- src/outputs/synthetic/visualizations/03_diversity_distribution.png +0 -0
- src/pipeline/integrated.py +584 -0
- src/prompts/__init__.py +0 -0
- src/prompts/prompt_builder.py +440 -0
- src/prompts/system_prompts.py +162 -0
- src/rag/__init__.py +0 -0
- src/rag/embeddings.py +87 -0
- src/rag/knowledge_base.py +415 -0
- src/rag/retriever.py +234 -0
- src/requirements.txt +23 -0
- src/streamlit_app.py +0 -40
- src/synthetic/__init__.py +0 -0
- src/synthetic/generate_data.py +401 -0
- src/synthetic/synonyms.py +149 -0
- src/tests/test_finetuned.py +0 -0
- src/tests/test_rag.py +0 -0
- src/tests/test_synthetic.py +0 -0
Dockerfile
CHANGED
|
@@ -17,4 +17,4 @@ EXPOSE 8501
|
|
| 17 |
|
| 18 |
HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
|
| 19 |
|
| 20 |
-
ENTRYPOINT ["streamlit", "run", "src/
|
|
|
|
| 17 |
|
| 18 |
HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
|
| 19 |
|
| 20 |
+
ENTRYPOINT ["streamlit", "run", "src/app.py", "--server.port=8501", "--server.address=0.0.0.0"]
|
src/.env.example
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# =============================================================================
|
| 2 |
+
# GEMINI API KEYS (Required)
|
| 3 |
+
# =============================================================================
|
| 4 |
+
GEMINI_API_KEY=your-primary-gemini-key
|
| 5 |
+
GEMINI_API_KEY_FALLBACK_1=your-fallback-key-1
|
| 6 |
+
GEMINI_API_KEY_FALLBACK_2=your-fallback-key-2
|
| 7 |
+
|
| 8 |
+
# =============================================================================
|
| 9 |
+
# GEMINI MODELS
|
| 10 |
+
# =============================================================================
|
| 11 |
+
GEMINI_MODEL=gemini-2.5-flash
|
| 12 |
+
GEMINI_MODEL_FALLBACK_1=gemini-2.5-flash-lite
|
| 13 |
+
|
| 14 |
+
# =============================================================================
|
| 15 |
+
# HUGGINGFACE (Required for cloud deployment, optional for local)
|
| 16 |
+
# =============================================================================
|
| 17 |
+
HF_TOKEN=your-huggingface-token
|
| 18 |
+
HF_MODEL_ID=your-username/sql-tinyllama-lora
|
| 19 |
+
HF_CHROMADB_ID=your-username/sql-chromadb
|
| 20 |
+
|
| 21 |
+
# =============================================================================
|
| 22 |
+
# HOW IT WORKS:
|
| 23 |
+
# =============================================================================
|
| 24 |
+
# LOCAL RUN:
|
| 25 |
+
# - If outputs/finetuning/checkpoints/final exists → uses local model
|
| 26 |
+
# - If chromadb_data exists → uses local ChromaDB
|
| 27 |
+
#
|
| 28 |
+
# CLOUD RUN (Streamlit):
|
| 29 |
+
# - If HF_MODEL_ID set → downloads model from HuggingFace
|
| 30 |
+
# - If HF_CHROMADB_ID set → downloads ChromaDB from HuggingFace
|
| 31 |
+
# - Falls back to building ChromaDB from data/ folder if needed
|
| 32 |
+
# =============================================================================
|
src/.gitignore
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# =============================================================================
|
| 2 |
+
# ENVIRONMENT & SECRETS
|
| 3 |
+
# =============================================================================
|
| 4 |
+
.env
|
| 5 |
+
.env.local
|
| 6 |
+
.env.production
|
| 7 |
+
|
| 8 |
+
# =============================================================================
|
| 9 |
+
# VIRTUAL ENVIRONMENT
|
| 10 |
+
# =============================================================================
|
| 11 |
+
.venv/
|
| 12 |
+
venv/
|
| 13 |
+
env/
|
| 14 |
+
ENV/
|
| 15 |
+
data/
|
| 16 |
+
data/synthetic.csv
|
| 17 |
+
# =============================================================================
|
| 18 |
+
# PYTHON
|
| 19 |
+
# =============================================================================
|
| 20 |
+
__pycache__/
|
| 21 |
+
*.py[cod]
|
| 22 |
+
*$py.class
|
| 23 |
+
*.so
|
| 24 |
+
.Python
|
| 25 |
+
build/
|
| 26 |
+
develop-eggs/
|
| 27 |
+
dist/
|
| 28 |
+
downloads/
|
| 29 |
+
eggs/
|
| 30 |
+
.eggs/
|
| 31 |
+
lib/
|
| 32 |
+
lib64/
|
| 33 |
+
parts/
|
| 34 |
+
sdist/
|
| 35 |
+
var/
|
| 36 |
+
wheels/
|
| 37 |
+
*.egg-info/
|
| 38 |
+
.installed.cfg
|
| 39 |
+
*.egg
|
| 40 |
+
|
| 41 |
+
# =============================================================================
|
| 42 |
+
# MODEL FILES (Upload to HuggingFace instead)
|
| 43 |
+
# =============================================================================
|
| 44 |
+
outputs/finetuning/checkpoints/
|
| 45 |
+
*.bin
|
| 46 |
+
*.pt
|
| 47 |
+
*.pth
|
| 48 |
+
*.safetensors
|
| 49 |
+
*.ckpt
|
| 50 |
+
|
| 51 |
+
# =============================================================================
|
| 52 |
+
# CHROMADB (Upload to HuggingFace instead)
|
| 53 |
+
# =============================================================================
|
| 54 |
+
chromadb_data/
|
| 55 |
+
|
| 56 |
+
# =============================================================================
|
| 57 |
+
# LOGS & OUTPUTS
|
| 58 |
+
# =============================================================================
|
| 59 |
+
*.log
|
| 60 |
+
outputs/*/logs/
|
| 61 |
+
outputs/pipeline/logs/
|
| 62 |
+
outputs/prompts/logs/
|
| 63 |
+
|
| 64 |
+
# =============================================================================
|
| 65 |
+
# IDE
|
| 66 |
+
# =============================================================================
|
| 67 |
+
.vscode/
|
| 68 |
+
.idea/
|
| 69 |
+
*.swp
|
| 70 |
+
*.swo
|
| 71 |
+
*~
|
| 72 |
+
|
| 73 |
+
# =============================================================================
|
| 74 |
+
# OS
|
| 75 |
+
# =============================================================================
|
| 76 |
+
.DS_Store
|
| 77 |
+
.DS_Store?
|
| 78 |
+
._*
|
| 79 |
+
.Spotlight-V100
|
| 80 |
+
.Trashes
|
| 81 |
+
ehthumbs.db
|
| 82 |
+
Thumbs.db
|
| 83 |
+
|
| 84 |
+
# =============================================================================
|
| 85 |
+
# JUPYTER
|
| 86 |
+
# =============================================================================
|
| 87 |
+
.ipynb_checkpoints/
|
| 88 |
+
*.ipynb
|
| 89 |
+
|
| 90 |
+
# =============================================================================
|
| 91 |
+
# KEEP THESE (don't ignore)
|
| 92 |
+
# =============================================================================
|
| 93 |
+
# !data/
|
| 94 |
+
# !data/*.csv
|
| 95 |
+
# !docs/
|
| 96 |
+
# !docs/index.html
|
| 97 |
+
# !*.py
|
| 98 |
+
# !requirements.txt
|
| 99 |
+
# !README.md
|
| 100 |
+
# !.env.example
|
src/README.md
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ⚡ Prompt to SQL using RAG + LLM
|
| 2 |
+
|
| 3 |
+
AI-powered Natural Language to SQL conversion using RAG, Fine-tuned LLM, and Gemini Enhancement.
|
| 4 |
+
|
| 5 |
+

|
| 6 |
+

|
| 7 |
+

|
| 8 |
+
|
| 9 |
+
## 🌐 Live Demo
|
| 10 |
+
|
| 11 |
+
- **🚀 Web App:** [Streamlit App](https://your-app.streamlit.app)
|
| 12 |
+
- **📄 Project Page:** [GitHub Pages](https://moheesh.github.io/Prompt_to_SQL_using_RAG_LLM)
|
| 13 |
+
|
| 14 |
+
## ✨ Features
|
| 15 |
+
|
| 16 |
+
| Feature | Description |
|
| 17 |
+
|---------|-------------|
|
| 18 |
+
| 🔍 **RAG Retrieval** | 80,000+ SQL examples in ChromaDB vector store |
|
| 19 |
+
| 🤖 **Fine-tuned LLM** | TinyLlama with LoRA for SQL generation |
|
| 20 |
+
| ✨ **Gemini Enhancement** | Query refinement, validation & explanation |
|
| 21 |
+
| 📝 **Prompt Engineering** | Context management, edge cases, query analysis |
|
| 22 |
+
| 📦 **Synthetic Data** | Data augmentation with 5 techniques |
|
| 23 |
+
| 🔄 **Auto Fallback** | Multiple API keys & models for reliability |
|
| 24 |
+
|
| 25 |
+
## 🔄 Pipeline Architecture
|
| 26 |
+
|
| 27 |
+
```
|
| 28 |
+
┌─────────────────────┐
|
| 29 |
+
│ Synthetic Data │ (Training augmentation)
|
| 30 |
+
└─────────┬───────────┘
|
| 31 |
+
↓
|
| 32 |
+
┌─────────────────────┐
|
| 33 |
+
│ Fine-tuned Model │ (LoRA training on TinyLlama)
|
| 34 |
+
└─────────┬───────────┘
|
| 35 |
+
↓
|
| 36 |
+
┌─────────────────────┐
|
| 37 |
+
│ User Question │ (Natural language input)
|
| 38 |
+
└─────────┬───────────┘
|
| 39 |
+
↓
|
| 40 |
+
┌─────────────────────┐
|
| 41 |
+
│ RAG Retrieval │ (Similar examples from ChromaDB)
|
| 42 |
+
└─────────┬───────────┘
|
| 43 |
+
↓
|
| 44 |
+
┌─────────────────────┐
|
| 45 |
+
│ Prompt Engineering │ (Context + query formatting)
|
| 46 |
+
└─────────┬───────────┘
|
| 47 |
+
↓
|
| 48 |
+
┌─────────────────────┐
|
| 49 |
+
│ Fine-tuned Model │ (SQL generation)
|
| 50 |
+
└─────────┬───────────┘
|
| 51 |
+
↓
|
| 52 |
+
┌─────────────────────┐
|
| 53 |
+
│ Gemini Enhancement │ (Refine + explain)
|
| 54 |
+
└─────────┬───────────┘
|
| 55 |
+
↓
|
| 56 |
+
┌─────────────────────┐
|
| 57 |
+
│ Final SQL │ (Optimized output)
|
| 58 |
+
└─────────────────────┘
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
## 📁 Project Structure
|
| 62 |
+
|
| 63 |
+
```
|
| 64 |
+
Prompt_to_SQL_using_RAG_LLM/
|
| 65 |
+
├── app.py # Streamlit UI
|
| 66 |
+
├── config.py # Central configuration
|
| 67 |
+
├── requirements.txt # Dependencies
|
| 68 |
+
│
|
| 69 |
+
├── pipeline/
|
| 70 |
+
│ └── integrated.py # Main pipeline (RAG + Model + Gemini)
|
| 71 |
+
│
|
| 72 |
+
├── finetuning/
|
| 73 |
+
│ ├── prepare_data.py # Data preparation
|
| 74 |
+
│ ├── train.py # LoRA fine-tuning
|
| 75 |
+
│ ├── evaluate.py # Model evaluation
|
| 76 |
+
│ └── inference.py # SQL generation
|
| 77 |
+
│
|
| 78 |
+
├── rag/
|
| 79 |
+
│ ├── embeddings.py # Sentence transformers
|
| 80 |
+
│ ├── knowledge_base.py # ChromaDB builder
|
| 81 |
+
│ └── retriever.py # LangChain retriever
|
| 82 |
+
│
|
| 83 |
+
├── prompts/
|
| 84 |
+
│ ├── prompt_builder.py # Context management
|
| 85 |
+
│ └── system_prompts.py # Prompt templates
|
| 86 |
+
│
|
| 87 |
+
├── synthetic/
|
| 88 |
+
│ ├── generate_data.py # Data augmentation
|
| 89 |
+
│ └── synonyms.py # Synonym dictionary
|
| 90 |
+
│
|
| 91 |
+
├── data/
|
| 92 |
+
│ ├── train.csv
|
| 93 |
+
│ ├── validation.csv
|
| 94 |
+
│ └── test.csv
|
| 95 |
+
│
|
| 96 |
+
└── docs/
|
| 97 |
+
└── index.html # GitHub Pages
|
| 98 |
+
```
|
| 99 |
+
|
| 100 |
+
## 🛠️ Setup
|
| 101 |
+
|
| 102 |
+
### 1. Clone the Repository
|
| 103 |
+
|
| 104 |
+
```bash
|
| 105 |
+
git clone https://github.com/moheesh/Prompt_to_SQL_using_RAG_LLM.git
|
| 106 |
+
cd Prompt_to_SQL_using_RAG_LLM
|
| 107 |
+
```
|
| 108 |
+
|
| 109 |
+
### 2. Create Virtual Environment
|
| 110 |
+
|
| 111 |
+
```bash
|
| 112 |
+
python -m venv .venv
|
| 113 |
+
|
| 114 |
+
# Windows
|
| 115 |
+
.venv\Scripts\activate
|
| 116 |
+
|
| 117 |
+
# Mac/Linux
|
| 118 |
+
source .venv/bin/activate
|
| 119 |
+
```
|
| 120 |
+
|
| 121 |
+
### 3. Install Dependencies
|
| 122 |
+
|
| 123 |
+
```bash
|
| 124 |
+
pip install -r requirements.txt
|
| 125 |
+
```
|
| 126 |
+
|
| 127 |
+
### 4. Configure Environment
|
| 128 |
+
|
| 129 |
+
Create a `.env` file:
|
| 130 |
+
|
| 131 |
+
```env
|
| 132 |
+
# Gemini API
|
| 133 |
+
GEMINI_API_KEY=your-primary-key
|
| 134 |
+
GEMINI_MODEL=gemini-2.5-flash
|
| 135 |
+
|
| 136 |
+
# HuggingFace (for cloud deployment)
|
| 137 |
+
HF_TOKEN=your-hf-token
|
| 138 |
+
HF_MODEL_ID=your-username/sql-tinyllama-lora
|
| 139 |
+
HF_CHROMADB_ID=your-username/sql-chromadb
|
| 140 |
+
```
|
| 141 |
+
|
| 142 |
+
### 5. Build Knowledge Base (First Time)
|
| 143 |
+
|
| 144 |
+
```bash
|
| 145 |
+
python rag/knowledge_base.py
|
| 146 |
+
```
|
| 147 |
+
|
| 148 |
+
### 6. Run the App
|
| 149 |
+
|
| 150 |
+
```bash
|
| 151 |
+
streamlit run app.py
|
| 152 |
+
```
|
| 153 |
+
|
| 154 |
+
## 🚀 Deployment
|
| 155 |
+
|
| 156 |
+
### Upload to HuggingFace
|
| 157 |
+
|
| 158 |
+
```bash
|
| 159 |
+
# Login
|
| 160 |
+
huggingface-cli login
|
| 161 |
+
|
| 162 |
+
# Upload model
|
| 163 |
+
python -c "from huggingface_hub import HfApi; api = HfApi(); api.upload_folder(folder_path='outputs/finetuning/checkpoints/final', repo_id='moheesh/sql-tinyllama-lora', repo_type='model', create_repo=True)"
|
| 164 |
+
|
| 165 |
+
# Upload ChromaDB
|
| 166 |
+
python -c "from huggingface_hub import HfApi; api = HfApi(); api.upload_folder(folder_path='chromadb_data', repo_id='moheesh/sql-chromadb', repo_type='dataset', create_repo=True)"
|
| 167 |
+
```
|
| 168 |
+
|
| 169 |
+
### Deploy to Streamlit Cloud
|
| 170 |
+
|
| 171 |
+
1. Push code to GitHub
|
| 172 |
+
2. Go to [share.streamlit.io](https://share.streamlit.io)
|
| 173 |
+
3. Connect your repo
|
| 174 |
+
4. Add secrets (same as `.env`)
|
| 175 |
+
5. Deploy!
|
| 176 |
+
|
| 177 |
+
## 🛠️ Tech Stack
|
| 178 |
+
|
| 179 |
+
| Component | Technology |
|
| 180 |
+
|-----------|------------|
|
| 181 |
+
| LLM | TinyLlama + LoRA |
|
| 182 |
+
| Vector DB | ChromaDB |
|
| 183 |
+
| Embeddings | all-MiniLM-L6-v2 |
|
| 184 |
+
| Enhancement | Gemini API |
|
| 185 |
+
| Framework | LangChain |
|
| 186 |
+
| UI | Streamlit |
|
| 187 |
+
|
| 188 |
+
## 📊 Evaluation Metrics
|
| 189 |
+
|
| 190 |
+
| Metric | Score |
|
| 191 |
+
|--------|-------|
|
| 192 |
+
| Exact Match | XX% |
|
| 193 |
+
| Token Accuracy | XX% |
|
| 194 |
+
| Keyword Accuracy | XX% |
|
| 195 |
+
| Structure Similarity | XX% |
|
| 196 |
+
|
| 197 |
+
## 🎓 Course
|
| 198 |
+
|
| 199 |
+
**INFO7375** - Northeastern University
|
| 200 |
+
|
| 201 |
+
## 👤 Author
|
| 202 |
+
|
| 203 |
+
**Your Name**
|
| 204 |
+
- GitHub: [@moheesh](https://github.com/moheesh)
|
| 205 |
+
- LinkedIn: [LinkedIn](https://linkedin.com/in/moheesh-k-a-a95306169)
|
| 206 |
+
|
| 207 |
+
## 📝 License
|
| 208 |
+
|
| 209 |
+
MIT License - see [LICENSE](LICENSE) for details.
|
src/app.py
ADDED
|
@@ -0,0 +1,497 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Streamlit App for SQL Learning Assistant
|
| 3 |
+
Integrates: RAG + Fine-tuned Model + Gemini Enhancement
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import streamlit as st
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
from dotenv import load_dotenv
|
| 10 |
+
|
| 11 |
+
# Load environment variables FIRST
|
| 12 |
+
load_dotenv()
|
| 13 |
+
|
| 14 |
+
# Add parent directory
|
| 15 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 16 |
+
|
| 17 |
+
# =============================================================================
|
| 18 |
+
# PAGE CONFIG - MUST BE FIRST STREAMLIT COMMAND
|
| 19 |
+
# =============================================================================
|
| 20 |
+
|
| 21 |
+
st.set_page_config(
|
| 22 |
+
page_title="SQL Learning Assistant",
|
| 23 |
+
page_icon="⚡",
|
| 24 |
+
layout="wide",
|
| 25 |
+
initial_sidebar_state="expanded"
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
# =============================================================================
|
| 29 |
+
# CACHED LOADERS - Load on-demand, cache forever
|
| 30 |
+
# =============================================================================
|
| 31 |
+
|
| 32 |
+
@st.cache_resource(show_spinner=False)
|
| 33 |
+
def load_chromadb():
|
| 34 |
+
"""Download ChromaDB from HuggingFace if needed."""
|
| 35 |
+
chromadb_path = "chromadb_data"
|
| 36 |
+
hf_chromadb_id = os.getenv("HF_CHROMADB_ID", None)
|
| 37 |
+
|
| 38 |
+
has_files = False
|
| 39 |
+
if os.path.exists(chromadb_path):
|
| 40 |
+
local_files = os.listdir(chromadb_path) if os.path.isdir(chromadb_path) else []
|
| 41 |
+
has_files = any('chroma' in f.lower() or 'sqlite' in f.lower() for f in local_files) or len(local_files) > 2
|
| 42 |
+
|
| 43 |
+
if not has_files and hf_chromadb_id:
|
| 44 |
+
from huggingface_hub import snapshot_download
|
| 45 |
+
os.makedirs(chromadb_path, exist_ok=True)
|
| 46 |
+
snapshot_download(repo_id=hf_chromadb_id, repo_type="dataset", local_dir=chromadb_path)
|
| 47 |
+
|
| 48 |
+
return chromadb_path
|
| 49 |
+
|
| 50 |
+
@st.cache_resource(show_spinner=False)
|
| 51 |
+
def load_retriever():
|
| 52 |
+
"""Load the RAG retriever."""
|
| 53 |
+
load_chromadb()
|
| 54 |
+
from rag.retriever import SQLRetriever
|
| 55 |
+
return SQLRetriever()
|
| 56 |
+
|
| 57 |
+
@st.cache_resource(show_spinner=False)
|
| 58 |
+
def load_model():
|
| 59 |
+
"""Load the fine-tuned model."""
|
| 60 |
+
from finetuning.inference import SQLGenerator
|
| 61 |
+
return SQLGenerator()
|
| 62 |
+
|
| 63 |
+
@st.cache_resource(show_spinner=False)
|
| 64 |
+
def load_prompt_builder():
|
| 65 |
+
"""Load prompt builder."""
|
| 66 |
+
from prompts.prompt_builder import PromptBuilder
|
| 67 |
+
return PromptBuilder()
|
| 68 |
+
|
| 69 |
+
@st.cache_resource(show_spinner=False)
|
| 70 |
+
def load_gemini():
|
| 71 |
+
"""Load Gemini client."""
|
| 72 |
+
from pipeline.integrated import GeminiClient, GEMINI_KEYS
|
| 73 |
+
if GEMINI_KEYS:
|
| 74 |
+
return GeminiClient()
|
| 75 |
+
return None
|
| 76 |
+
|
| 77 |
+
# =============================================================================
|
| 78 |
+
# HELPER FUNCTION TO RUN PIPELINE
|
| 79 |
+
# =============================================================================
|
| 80 |
+
|
| 81 |
+
def run_pipeline(question, num_examples=3):
|
| 82 |
+
"""Run the full pipeline - loads components on first use."""
|
| 83 |
+
result = {
|
| 84 |
+
'question': question,
|
| 85 |
+
'success': False,
|
| 86 |
+
'steps': {}
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
# Step 1: RAG
|
| 90 |
+
rag_context = ""
|
| 91 |
+
examples = []
|
| 92 |
+
try:
|
| 93 |
+
with st.spinner("🔍 Loading RAG system..."):
|
| 94 |
+
retriever = load_retriever()
|
| 95 |
+
if retriever:
|
| 96 |
+
examples = retriever.retrieve(question, top_k=num_examples)
|
| 97 |
+
rag_context = "Similar SQL examples:\n\n"
|
| 98 |
+
for i, r in enumerate(examples, 1):
|
| 99 |
+
rag_context += f"Example {i}:\nQuestion: {r['question']}\nSQL: {r['sql']}\n\n"
|
| 100 |
+
except Exception as e:
|
| 101 |
+
st.warning(f"RAG error: {e}")
|
| 102 |
+
|
| 103 |
+
result['steps']['rag'] = {'examples': examples, 'num_examples': len(examples), 'context': rag_context}
|
| 104 |
+
|
| 105 |
+
# Step 2: Prompt
|
| 106 |
+
prompt = ""
|
| 107 |
+
try:
|
| 108 |
+
prompt_builder = load_prompt_builder()
|
| 109 |
+
if prompt_builder:
|
| 110 |
+
prompt_result = prompt_builder.build_prompt(question=question, rag_context=rag_context)
|
| 111 |
+
if prompt_result['success']:
|
| 112 |
+
prompt = prompt_result['prompt']
|
| 113 |
+
except:
|
| 114 |
+
pass
|
| 115 |
+
if not prompt:
|
| 116 |
+
prompt = f"{rag_context}\nQuestion: {question}\n\nSQL:"
|
| 117 |
+
|
| 118 |
+
result['steps']['prompt'] = {'prompt': prompt, 'length': len(prompt)}
|
| 119 |
+
|
| 120 |
+
# Step 3: Fine-tuned Model
|
| 121 |
+
finetuned_sql = None
|
| 122 |
+
try:
|
| 123 |
+
with st.spinner("🤖 Loading AI model..."):
|
| 124 |
+
model = load_model()
|
| 125 |
+
if model:
|
| 126 |
+
finetuned_sql = model.generate(question, rag_context)
|
| 127 |
+
except Exception as e:
|
| 128 |
+
st.warning(f"Model error: {e}")
|
| 129 |
+
|
| 130 |
+
result['steps']['finetuned'] = {'sql': finetuned_sql, 'error': None if finetuned_sql else 'Model not available'}
|
| 131 |
+
|
| 132 |
+
if not finetuned_sql:
|
| 133 |
+
return result
|
| 134 |
+
|
| 135 |
+
# Step 4: Gemini Enhancement
|
| 136 |
+
enhanced_sql = finetuned_sql
|
| 137 |
+
try:
|
| 138 |
+
gemini = load_gemini()
|
| 139 |
+
if gemini:
|
| 140 |
+
enhance_prompt = f"""You are an SQL expert. Review and enhance this SQL query.
|
| 141 |
+
|
| 142 |
+
Original Question: {question}
|
| 143 |
+
|
| 144 |
+
Generated SQL (by a smaller model):
|
| 145 |
+
{finetuned_sql}
|
| 146 |
+
|
| 147 |
+
Rules:
|
| 148 |
+
- If the SQL is correct, return it unchanged
|
| 149 |
+
- If it needs fixes, return the corrected version
|
| 150 |
+
- Return ONLY the SQL query, no explanations
|
| 151 |
+
|
| 152 |
+
Enhanced SQL:"""
|
| 153 |
+
response, error = gemini.generate(enhance_prompt)
|
| 154 |
+
if response and not error:
|
| 155 |
+
enhanced_sql = response.strip()
|
| 156 |
+
if enhanced_sql.startswith("```"):
|
| 157 |
+
lines = enhanced_sql.split("\n")
|
| 158 |
+
enhanced_sql = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
|
| 159 |
+
if enhanced_sql.lower().startswith("sql"):
|
| 160 |
+
enhanced_sql = enhanced_sql[3:].strip()
|
| 161 |
+
except Exception as e:
|
| 162 |
+
st.warning(f"Gemini enhance error: {e}")
|
| 163 |
+
|
| 164 |
+
result['steps']['gemini_enhance'] = {'sql': enhanced_sql, 'info': {'enhanced': enhanced_sql != finetuned_sql}}
|
| 165 |
+
result['final_sql'] = enhanced_sql
|
| 166 |
+
|
| 167 |
+
# Step 5: Explanation
|
| 168 |
+
explanation = ""
|
| 169 |
+
try:
|
| 170 |
+
gemini = load_gemini()
|
| 171 |
+
if gemini:
|
| 172 |
+
explain_prompt = f"Explain this SQL query in simple terms (2-3 sentences):\n\nSQL: {enhanced_sql}"
|
| 173 |
+
response, error = gemini.generate(explain_prompt)
|
| 174 |
+
if response and not error:
|
| 175 |
+
explanation = response.strip()
|
| 176 |
+
except:
|
| 177 |
+
pass
|
| 178 |
+
|
| 179 |
+
result['explanation'] = explanation
|
| 180 |
+
result['success'] = True
|
| 181 |
+
|
| 182 |
+
return result
|
| 183 |
+
|
| 184 |
+
# =============================================================================
|
| 185 |
+
# CUSTOM CSS
|
| 186 |
+
# =============================================================================
|
| 187 |
+
|
| 188 |
+
st.markdown("""
|
| 189 |
+
<style>
|
| 190 |
+
.stApp {
|
| 191 |
+
background: linear-gradient(135deg, #0f0f23 0%, #1a1a2e 50%, #16213e 100%);
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
.main-header {
|
| 195 |
+
font-size: 3rem;
|
| 196 |
+
font-weight: 800;
|
| 197 |
+
background: linear-gradient(120deg, #00d4ff, #7c3aed, #f472b6);
|
| 198 |
+
-webkit-background-clip: text;
|
| 199 |
+
-webkit-text-fill-color: transparent;
|
| 200 |
+
background-clip: text;
|
| 201 |
+
text-align: center;
|
| 202 |
+
margin-bottom: 0.5rem;
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
.sub-header {
|
| 206 |
+
font-size: 1.1rem;
|
| 207 |
+
color: #94a3b8;
|
| 208 |
+
text-align: center;
|
| 209 |
+
margin-bottom: 2rem;
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
.stButton > button {
|
| 213 |
+
background: linear-gradient(135deg, #1e293b 0%, #334155 100%);
|
| 214 |
+
color: #e2e8f0;
|
| 215 |
+
border: 1px solid #475569;
|
| 216 |
+
border-radius: 10px;
|
| 217 |
+
transition: all 0.3s ease;
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
.stButton > button:hover {
|
| 221 |
+
background: linear-gradient(135deg, #3b82f6 0%, #8b5cf6 100%);
|
| 222 |
+
border-color: #60a5fa;
|
| 223 |
+
transform: translateY(-2px);
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
.stTextInput > div > div > input {
|
| 227 |
+
background: rgba(30, 41, 59, 0.8);
|
| 228 |
+
border: 1px solid #475569;
|
| 229 |
+
border-radius: 12px;
|
| 230 |
+
color: #f1f5f9;
|
| 231 |
+
}
|
| 232 |
+
|
| 233 |
+
[data-testid="stSidebar"] {
|
| 234 |
+
background: linear-gradient(180deg, #0f172a 0%, #1e293b 100%);
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
.pipeline-box {
|
| 238 |
+
background: rgba(30, 41, 59, 0.6);
|
| 239 |
+
border: 1px solid #475569;
|
| 240 |
+
border-radius: 8px;
|
| 241 |
+
padding: 0.5rem 1rem;
|
| 242 |
+
margin: 0.25rem 0;
|
| 243 |
+
font-size: 0.85rem;
|
| 244 |
+
text-align: center;
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
.pipeline-arrow {
|
| 248 |
+
color: #3b82f6;
|
| 249 |
+
text-align: center;
|
| 250 |
+
font-size: 1.2rem;
|
| 251 |
+
}
|
| 252 |
+
</style>
|
| 253 |
+
""", unsafe_allow_html=True)
|
| 254 |
+
|
| 255 |
+
# =============================================================================
|
| 256 |
+
# HEADER
|
| 257 |
+
# =============================================================================
|
| 258 |
+
|
| 259 |
+
st.markdown('<p class="main-header">⚡ SQL Learning Assistant</p>', unsafe_allow_html=True)
|
| 260 |
+
st.markdown('<p class="sub-header">Transform Natural Language into SQL using AI-Powered Pipeline</p>', unsafe_allow_html=True)
|
| 261 |
+
|
| 262 |
+
# =============================================================================
|
| 263 |
+
# SIDEBAR
|
| 264 |
+
# =============================================================================
|
| 265 |
+
|
| 266 |
+
with st.sidebar:
|
| 267 |
+
st.markdown("## ⚙️ Configuration")
|
| 268 |
+
st.markdown("---")
|
| 269 |
+
|
| 270 |
+
st.markdown("### 🎯 RAG Settings")
|
| 271 |
+
num_examples = st.slider("Similar examples to retrieve", min_value=1, max_value=5, value=3)
|
| 272 |
+
|
| 273 |
+
st.markdown("---")
|
| 274 |
+
|
| 275 |
+
st.markdown("### 📊 System Status")
|
| 276 |
+
col1, col2 = st.columns(2)
|
| 277 |
+
with col1:
|
| 278 |
+
st.markdown("✅ **RAG**")
|
| 279 |
+
st.markdown("✅ **Model**")
|
| 280 |
+
with col2:
|
| 281 |
+
st.markdown("✅ **Prompts**")
|
| 282 |
+
if os.getenv("GEMINI_API_KEY"):
|
| 283 |
+
st.markdown("✅ **Gemini**")
|
| 284 |
+
else:
|
| 285 |
+
st.markdown("❌ **Gemini**")
|
| 286 |
+
|
| 287 |
+
st.markdown("---")
|
| 288 |
+
|
| 289 |
+
st.markdown("### 🔄 Pipeline Flow")
|
| 290 |
+
pipeline_steps = [
|
| 291 |
+
("📦", "Synthetic Data"),
|
| 292 |
+
("🎓", "Fine-tuned Model"),
|
| 293 |
+
("❓", "User Question"),
|
| 294 |
+
("🔍", "RAG Retrieval"),
|
| 295 |
+
("📝", "Prompt Engineering"),
|
| 296 |
+
("🤖", "Model Inference"),
|
| 297 |
+
("✨", "Gemini Enhancement"),
|
| 298 |
+
("✅", "Final Output"),
|
| 299 |
+
]
|
| 300 |
+
|
| 301 |
+
for i, (icon, title) in enumerate(pipeline_steps):
|
| 302 |
+
st.markdown(f'<div class="pipeline-box">{icon} <strong>{title}</strong></div>', unsafe_allow_html=True)
|
| 303 |
+
if i < len(pipeline_steps) - 1:
|
| 304 |
+
st.markdown('<p class="pipeline-arrow">↓</p>', unsafe_allow_html=True)
|
| 305 |
+
|
| 306 |
+
st.markdown("---")
|
| 307 |
+
st.markdown("### 📚 About")
|
| 308 |
+
st.markdown("**Course:** INFO7375")
|
| 309 |
+
|
| 310 |
+
# =============================================================================
|
| 311 |
+
# MAIN CONTENT
|
| 312 |
+
# =============================================================================
|
| 313 |
+
|
| 314 |
+
if "messages" not in st.session_state:
|
| 315 |
+
st.session_state.messages = []
|
| 316 |
+
|
| 317 |
+
if "results_history" not in st.session_state:
|
| 318 |
+
st.session_state.results_history = []
|
| 319 |
+
|
| 320 |
+
if "input_text" not in st.session_state:
|
| 321 |
+
st.session_state.input_text = ""
|
| 322 |
+
|
| 323 |
+
# =============================================================================
|
| 324 |
+
# EXAMPLE QUESTIONS
|
| 325 |
+
# =============================================================================
|
| 326 |
+
|
| 327 |
+
st.markdown("### 💡 Try an Example")
|
| 328 |
+
|
| 329 |
+
example_questions = [
|
| 330 |
+
("👥 Employees", "Find all employees with salary above 50000"),
|
| 331 |
+
("📊 Orders", "Count total orders by customer"),
|
| 332 |
+
("🏆 Top Products", "Show top 5 products by revenue"),
|
| 333 |
+
("📅 Recent", "List customers who placed orders in 2024"),
|
| 334 |
+
("💰 Salary", "Calculate average salary by department"),
|
| 335 |
+
]
|
| 336 |
+
|
| 337 |
+
cols = st.columns(5)
|
| 338 |
+
for i, (label, ex_question) in enumerate(example_questions):
|
| 339 |
+
with cols[i]:
|
| 340 |
+
if st.button(label, key=f"ex_{i}", use_container_width=True, help=ex_question):
|
| 341 |
+
st.session_state.input_text = ex_question
|
| 342 |
+
|
| 343 |
+
# =============================================================================
|
| 344 |
+
# INPUT AREA
|
| 345 |
+
# =============================================================================
|
| 346 |
+
|
| 347 |
+
st.markdown("### 🎤 Ask Your Question")
|
| 348 |
+
|
| 349 |
+
col1, col2 = st.columns([6, 1])
|
| 350 |
+
|
| 351 |
+
with col1:
|
| 352 |
+
question = st.text_input(
|
| 353 |
+
"Question",
|
| 354 |
+
placeholder="e.g., Find all employees with salary greater than 50000...",
|
| 355 |
+
label_visibility="collapsed",
|
| 356 |
+
key="input_text"
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
with col2:
|
| 360 |
+
submit_btn = st.button("🚀 Run", type="primary", use_container_width=True)
|
| 361 |
+
|
| 362 |
+
st.markdown("---")
|
| 363 |
+
|
| 364 |
+
# =============================================================================
|
| 365 |
+
# CHAT HISTORY
|
| 366 |
+
# =============================================================================
|
| 367 |
+
|
| 368 |
+
for i, message in enumerate(st.session_state.messages):
|
| 369 |
+
with st.chat_message(message["role"], avatar="🧑💻" if message["role"] == "user" else "🤖"):
|
| 370 |
+
st.markdown(message["content"])
|
| 371 |
+
|
| 372 |
+
if message["role"] == "assistant":
|
| 373 |
+
result_idx = i // 2
|
| 374 |
+
if result_idx < len(st.session_state.results_history):
|
| 375 |
+
result = st.session_state.results_history[result_idx]
|
| 376 |
+
if result and result.get('success'):
|
| 377 |
+
with st.expander("🔍 View Pipeline Details", expanded=False):
|
| 378 |
+
tab1, tab2, tab3, tab4 = st.tabs(["🔍 RAG", "📝 Prompt", "🤖 Fine-tuned", "✨ Gemini"])
|
| 379 |
+
|
| 380 |
+
with tab1:
|
| 381 |
+
examples = result['steps']['rag'].get('examples', [])
|
| 382 |
+
st.markdown(f"**Retrieved {len(examples)} examples**")
|
| 383 |
+
for j, ex in enumerate(examples, 1):
|
| 384 |
+
st.markdown(f"**Example {j}** | Score: `{ex.get('score', 0):.3f}`")
|
| 385 |
+
st.markdown(f"Q: {ex.get('question', 'N/A')}")
|
| 386 |
+
st.code(ex.get('sql', 'N/A'), language="sql")
|
| 387 |
+
|
| 388 |
+
with tab2:
|
| 389 |
+
st.markdown("**Constructed Prompt:**")
|
| 390 |
+
st.code(result['steps']['prompt'].get('prompt', 'N/A'), language="text")
|
| 391 |
+
|
| 392 |
+
with tab3:
|
| 393 |
+
st.markdown("**Fine-tuned Model Output:**")
|
| 394 |
+
st.code(result['steps']['finetuned'].get('sql', 'N/A'), language="sql")
|
| 395 |
+
|
| 396 |
+
with tab4:
|
| 397 |
+
if 'gemini_enhance' in result['steps']:
|
| 398 |
+
st.markdown("**Enhanced SQL:**")
|
| 399 |
+
st.code(result['steps']['gemini_enhance'].get('sql', 'N/A'), language="sql")
|
| 400 |
+
|
| 401 |
+
# =============================================================================
|
| 402 |
+
# PROCESS QUERY
|
| 403 |
+
# =============================================================================
|
| 404 |
+
|
| 405 |
+
if submit_btn and question:
|
| 406 |
+
st.session_state.messages.append({"role": "user", "content": question})
|
| 407 |
+
|
| 408 |
+
with st.chat_message("user", avatar="🧑💻"):
|
| 409 |
+
st.markdown(question)
|
| 410 |
+
|
| 411 |
+
with st.chat_message("assistant", avatar="🤖"):
|
| 412 |
+
with st.status("🔄 Processing your query...", expanded=True) as status:
|
| 413 |
+
st.write("🔍 Retrieving similar examples...")
|
| 414 |
+
st.write("📝 Building prompt...")
|
| 415 |
+
st.write("🤖 Generating SQL...")
|
| 416 |
+
st.write("✨ Enhancing with Gemini...")
|
| 417 |
+
|
| 418 |
+
result = run_pipeline(question=question, num_examples=num_examples)
|
| 419 |
+
|
| 420 |
+
status.update(label="✅ Complete!", state="complete", expanded=False)
|
| 421 |
+
|
| 422 |
+
st.session_state.results_history.append(result)
|
| 423 |
+
|
| 424 |
+
if result['success']:
|
| 425 |
+
st.markdown("### ✅ Generated SQL")
|
| 426 |
+
st.code(result['final_sql'], language="sql")
|
| 427 |
+
|
| 428 |
+
if 'gemini_enhance' in result['steps']:
|
| 429 |
+
original = result['steps']['finetuned'].get('sql', '')
|
| 430 |
+
enhanced = result['steps']['gemini_enhance'].get('sql', '')
|
| 431 |
+
if original != enhanced:
|
| 432 |
+
st.success("✨ Query optimized by Gemini!")
|
| 433 |
+
else:
|
| 434 |
+
st.info("✓ Query was already optimal")
|
| 435 |
+
|
| 436 |
+
if 'explanation' in result and result['explanation']:
|
| 437 |
+
if not result['explanation'].startswith("Explanation error"):
|
| 438 |
+
st.markdown("### 📖 Explanation")
|
| 439 |
+
st.info(result['explanation'])
|
| 440 |
+
|
| 441 |
+
with st.expander("🔍 View Pipeline Details", expanded=False):
|
| 442 |
+
tab1, tab2, tab3, tab4 = st.tabs(["🔍 RAG", "📝 Prompt", "🤖 Fine-tuned", "✨ Gemini"])
|
| 443 |
+
|
| 444 |
+
with tab1:
|
| 445 |
+
examples = result['steps']['rag'].get('examples', [])
|
| 446 |
+
st.markdown(f"**Retrieved {len(examples)} examples**")
|
| 447 |
+
for j, ex in enumerate(examples, 1):
|
| 448 |
+
st.markdown(f"**Example {j}** | Score: `{ex.get('score', 0):.3f}`")
|
| 449 |
+
st.markdown(f"Q: {ex.get('question', 'N/A')}")
|
| 450 |
+
st.code(ex.get('sql', 'N/A'), language="sql")
|
| 451 |
+
|
| 452 |
+
with tab2:
|
| 453 |
+
st.markdown("**Constructed Prompt:**")
|
| 454 |
+
st.code(result['steps']['prompt'].get('prompt', 'N/A'), language="text")
|
| 455 |
+
|
| 456 |
+
with tab3:
|
| 457 |
+
st.markdown("**Fine-tuned Model Output:**")
|
| 458 |
+
st.code(result['steps']['finetuned'].get('sql', 'N/A'), language="sql")
|
| 459 |
+
|
| 460 |
+
with tab4:
|
| 461 |
+
if 'gemini_enhance' in result['steps']:
|
| 462 |
+
st.markdown("**Enhanced SQL:**")
|
| 463 |
+
st.code(result['steps']['gemini_enhance'].get('sql', 'N/A'), language="sql")
|
| 464 |
+
|
| 465 |
+
response_text = f"**Generated SQL:**\n```sql\n{result['final_sql']}\n```"
|
| 466 |
+
if 'explanation' in result and not result['explanation'].startswith("Explanation error"):
|
| 467 |
+
response_text += f"\n\n**Explanation:** {result['explanation']}"
|
| 468 |
+
|
| 469 |
+
st.session_state.messages.append({"role": "assistant", "content": response_text})
|
| 470 |
+
|
| 471 |
+
else:
|
| 472 |
+
st.error("❌ Failed to generate SQL. Please try again.")
|
| 473 |
+
st.session_state.messages.append({"role": "assistant", "content": "❌ Failed to generate SQL."})
|
| 474 |
+
|
| 475 |
+
elif submit_btn and not question:
|
| 476 |
+
st.warning("⚠️ Please enter a question first!")
|
| 477 |
+
|
| 478 |
+
# =============================================================================
|
| 479 |
+
# FOOTER
|
| 480 |
+
# =============================================================================
|
| 481 |
+
|
| 482 |
+
st.markdown("---")
|
| 483 |
+
|
| 484 |
+
col1, col2, col3 = st.columns([1, 2, 1])
|
| 485 |
+
|
| 486 |
+
with col1:
|
| 487 |
+
if st.button("🗑️ Clear Chat", use_container_width=True):
|
| 488 |
+
st.session_state.messages = []
|
| 489 |
+
st.session_state.results_history = []
|
| 490 |
+
st.session_state.input_text = ""
|
| 491 |
+
st.rerun()
|
| 492 |
+
|
| 493 |
+
with col2:
|
| 494 |
+
st.markdown('<p style="text-align: center; color: #64748b;">Built with ❤️ using Streamlit • LangChain • Gemini</p>', unsafe_allow_html=True)
|
| 495 |
+
|
| 496 |
+
with col3:
|
| 497 |
+
st.markdown('<p style="text-align: right; color: #64748b;"><strong>INFO7375</strong></p>', unsafe_allow_html=True)
|
src/config.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Central Configuration for SQL Learning Assistant
|
| 3 |
+
Handles local vs HuggingFace paths automatically
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
from dotenv import load_dotenv
|
| 8 |
+
|
| 9 |
+
load_dotenv()
|
| 10 |
+
|
| 11 |
+
# =============================================================================
|
| 12 |
+
# HUGGINGFACE CONFIGURATION (for cloud deployment)
|
| 13 |
+
# Set these in .env or Streamlit secrets
|
| 14 |
+
# =============================================================================
|
| 15 |
+
|
| 16 |
+
HF_MODEL_ID = os.getenv("HF_MODEL_ID", None) # e.g., "username/sql-tinyllama-lora"
|
| 17 |
+
HF_CHROMADB_ID = os.getenv("HF_CHROMADB_ID", None) # e.g., "username/sql-chromadb"
|
| 18 |
+
HF_TOKEN = os.getenv("HF_TOKEN", None)
|
| 19 |
+
|
| 20 |
+
# =============================================================================
|
| 21 |
+
# LOCAL PATHS
|
| 22 |
+
# =============================================================================
|
| 23 |
+
|
| 24 |
+
LOCAL_MODEL_DIR = "outputs/finetuning/checkpoints/final"
|
| 25 |
+
LOCAL_CHROMADB_DIR = "chromadb_data"
|
| 26 |
+
LOCAL_DATA_DIR = "data"
|
| 27 |
+
|
| 28 |
+
# =============================================================================
|
| 29 |
+
# GEMINI CONFIGURATION
|
| 30 |
+
# =============================================================================
|
| 31 |
+
|
| 32 |
+
GEMINI_KEYS = [
|
| 33 |
+
os.getenv("GEMINI_API_KEY"),
|
| 34 |
+
os.getenv("GEMINI_API_KEY_FALLBACK_1"),
|
| 35 |
+
os.getenv("GEMINI_API_KEY_FALLBACK_2"),
|
| 36 |
+
]
|
| 37 |
+
GEMINI_KEYS = [k for k in GEMINI_KEYS if k] # Remove None values
|
| 38 |
+
|
| 39 |
+
GEMINI_MODELS = [
|
| 40 |
+
os.getenv("GEMINI_MODEL", "gemini-2.5-flash"),
|
| 41 |
+
os.getenv("GEMINI_MODEL_FALLBACK_1"),
|
| 42 |
+
]
|
| 43 |
+
GEMINI_MODELS = [m for m in GEMINI_MODELS if m] # Remove None values
|
| 44 |
+
|
| 45 |
+
# =============================================================================
|
| 46 |
+
# RAG CONFIGURATION
|
| 47 |
+
# =============================================================================
|
| 48 |
+
|
| 49 |
+
EMBEDDING_MODEL = "all-MiniLM-L6-v2"
|
| 50 |
+
COLLECTION_NAME = "sql_knowledge"
|
| 51 |
+
|
| 52 |
+
# =============================================================================
|
| 53 |
+
# HELPER FUNCTIONS
|
| 54 |
+
# =============================================================================
|
| 55 |
+
|
| 56 |
+
def is_local():
    """Check if running locally (has local model/data)."""
    has_model = os.path.exists(LOCAL_MODEL_DIR)
    has_chromadb = os.path.exists(LOCAL_CHROMADB_DIR)
    return has_model and has_chromadb
|
| 59 |
+
|
| 60 |
+
def is_cloud():
    """Check if running in cloud (has HF config)."""
    # Either a model repo or a ChromaDB repo configured counts as "cloud".
    return any(repo is not None for repo in (HF_MODEL_ID, HF_CHROMADB_ID))
|
| 63 |
+
|
| 64 |
+
def get_model_source():
    """Get where the model will be loaded from.

    Returns:
        A (source, path) tuple where source is one of
        'local', 'huggingface', or 'base'.
    """
    # Prefer a local checkpoint that actually contains files.
    if os.path.exists(LOCAL_MODEL_DIR) and os.listdir(LOCAL_MODEL_DIR):
        return "local", LOCAL_MODEL_DIR
    # Fall back to a configured HuggingFace Hub repo.
    if HF_MODEL_ID:
        return "huggingface", HF_MODEL_ID
    # Last resort: the untuned base model.
    return "base", "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
| 72 |
+
|
| 73 |
+
def get_chromadb_source():
    """Get where ChromaDB will be loaded from.

    Returns:
        A (source, path) tuple where source is one of
        'local', 'huggingface', or 'build'.
    """
    # Prefer an existing non-empty local ChromaDB directory.
    if os.path.exists(LOCAL_CHROMADB_DIR) and os.listdir(LOCAL_CHROMADB_DIR):
        return "local", LOCAL_CHROMADB_DIR
    # Fall back to a configured HuggingFace Hub dataset repo.
    if HF_CHROMADB_ID:
        return "huggingface", HF_CHROMADB_ID
    # Otherwise the knowledge base must be rebuilt from raw data.
    return "build", LOCAL_DATA_DIR
|
| 81 |
+
|
| 82 |
+
def print_config():
    """Print the resolved configuration (model/ChromaDB sources, Gemini setup)."""
    divider = "=" * 50
    print(divider)
    print("CONFIGURATION")
    print(divider)

    model_src, model_path = get_model_source()
    chromadb_src, chromadb_path = get_chromadb_source()

    print(f"Model: {model_src} → {model_path}")
    print(f"ChromaDB: {chromadb_src} → {chromadb_path}")
    print(f"Gemini Keys: {len(GEMINI_KEYS)} available")
    print(f"Gemini Models: {GEMINI_MODELS}")
    print(divider)
|
| 96 |
+
|
| 97 |
+
if __name__ == "__main__":
|
| 98 |
+
print_config()
|
src/finetuning/__init__.py
ADDED
|
File without changes
|
src/finetuning/evaluate.py
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Evaluation Module for Fine-Tuned SQL Model
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import json
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
from collections import Counter
|
| 10 |
+
|
| 11 |
+
# =============================================================================
|
| 12 |
+
# CONFIGURATION
|
| 13 |
+
# =============================================================================
|
| 14 |
+
|
| 15 |
+
OUTPUT_DIR = "outputs/finetuning"
|
| 16 |
+
RESULTS_DIR = f"{OUTPUT_DIR}/results"
|
| 17 |
+
VIZ_DIR = f"{OUTPUT_DIR}/visualizations"
|
| 18 |
+
|
| 19 |
+
# Number of samples to evaluate
|
| 20 |
+
NUM_EVAL_SAMPLES = 50 # Change for more/less evaluation
|
| 21 |
+
|
| 22 |
+
def setup_directories():
    """Ensure the results and visualization output folders exist."""
    os.makedirs(RESULTS_DIR, exist_ok=True)
    os.makedirs(VIZ_DIR, exist_ok=True)
|
| 25 |
+
|
| 26 |
+
# =============================================================================
|
| 27 |
+
# EVALUATION METRICS
|
| 28 |
+
# =============================================================================
|
| 29 |
+
|
| 30 |
+
def exact_match(pred, expected):
    """Return True when prediction and expected SQL are identical,
    ignoring case and surrounding whitespace."""
    return expected.strip().lower() == pred.strip().lower()
|
| 33 |
+
|
| 34 |
+
def token_accuracy(pred, expected):
    """Fraction of expected tokens that also appear in the prediction.

    Tokens are whitespace-split and lowercased; the score is in [0, 1].
    An empty expected string scores 0.0.
    """
    expected_tokens = set(expected.lower().split())
    if not expected_tokens:
        # Nothing to match against.
        return 0.0
    predicted_tokens = set(pred.lower().split())
    shared = expected_tokens.intersection(predicted_tokens)
    return len(shared) / len(expected_tokens)
|
| 41 |
+
|
| 42 |
+
def keyword_accuracy(pred, expected):
    """Fraction of SQL keywords used by *expected* that also occur in *pred*.

    Keywords are matched on word boundaries (case-insensitive), so e.g.
    'MIN' inside 'ADMIN' or 'COUNT' inside 'ACCOUNT' no longer counts —
    the previous plain-substring check produced such false positives.

    Returns:
        A score in [0, 1]. When *expected* uses no keywords, the score is
        1.0 only if *pred* uses none either, otherwise 0.0.
    """
    import re

    keywords = ['SELECT', 'FROM', 'WHERE', 'JOIN', 'GROUP BY',
                'ORDER BY', 'COUNT', 'SUM', 'AVG', 'MAX', 'MIN']

    def _keywords_in(text):
        # Word-boundary search avoids matching keywords embedded in
        # identifiers (e.g. MIN in "admin").
        upper = text.upper()
        return [k for k in keywords
                if re.search(r'\b' + re.escape(k) + r'\b', upper)]

    pred_kw = _keywords_in(pred)
    exp_kw = _keywords_in(expected)

    if not exp_kw:
        return 1.0 if not pred_kw else 0.0

    matches = sum(1 for k in exp_kw if k in pred_kw)
    return matches / len(exp_kw)
|
| 55 |
+
|
| 56 |
+
def structure_similarity(pred, expected):
    """Jaccard similarity between the sets of SQL clauses the two queries use.

    Returns a score in [0, 1]: 1.0 when both queries use exactly the same
    clauses (or neither uses any), 0.0 when only one of them uses clauses.
    """
    clauses = ['SELECT', 'FROM', 'WHERE', 'JOIN', 'GROUP BY', 'ORDER BY', 'LIMIT']

    pred_upper = pred.upper()
    exp_upper = expected.upper()
    pred_struct = {c for c in clauses if c in pred_upper}
    exp_struct = {c for c in clauses if c in exp_upper}

    if not pred_struct and not exp_struct:
        return 1.0
    if not pred_struct or not exp_struct:
        return 0.0

    overlap = pred_struct & exp_struct
    union = pred_struct | exp_struct
    return len(overlap) / len(union)
|
| 69 |
+
|
| 70 |
+
# =============================================================================
|
| 71 |
+
# EVALUATION RUNNER
|
| 72 |
+
# =============================================================================
|
| 73 |
+
|
| 74 |
+
def evaluate_predictions(predictions, ground_truth):
    """Score *predictions* against *ground_truth* with all four metrics.

    Args:
        predictions: list of generated SQL strings.
        ground_truth: list of reference SQL strings, paired positionally
            (extra items in the longer list are ignored, as with zip()).

    Returns:
        dict with aggregate scores ('exact_match_rate',
        'avg_token_accuracy', 'avg_keyword_accuracy',
        'avg_structure_similarity', 'total_samples') plus a 'detailed'
        key holding the per-sample score lists.
    """
    results = {
        'exact_match': [],
        'token_accuracy': [],
        'keyword_accuracy': [],
        'structure_similarity': [],
    }

    for pred, exp in zip(predictions, ground_truth):
        results['exact_match'].append(1 if exact_match(pred, exp) else 0)
        results['token_accuracy'].append(token_accuracy(pred, exp))
        results['keyword_accuracy'].append(keyword_accuracy(pred, exp))
        results['structure_similarity'].append(structure_similarity(pred, exp))

    # Guard the averages against an empty evaluation set — the previous
    # version raised ZeroDivisionError when no pairs were scored.
    num_pairs = len(results['exact_match'])

    def _avg(values):
        return sum(values) / num_pairs if num_pairs else 0.0

    return {
        'total_samples': len(predictions),
        'exact_match_rate': _avg(results['exact_match']),
        'avg_token_accuracy': _avg(results['token_accuracy']),
        'avg_keyword_accuracy': _avg(results['keyword_accuracy']),
        'avg_structure_similarity': _avg(results['structure_similarity']),
        'detailed': results,
    }
|
| 101 |
+
|
| 102 |
+
# =============================================================================
|
| 103 |
+
# VISUALIZATIONS
|
| 104 |
+
# =============================================================================
|
| 105 |
+
|
| 106 |
+
def _save_metric_histogram(values, color, xlabel, title, filename):
    """Render one histogram of per-sample scores with a mean marker and save it."""
    fig, ax = plt.subplots(figsize=(10, 6))

    mean = sum(values) / len(values)
    ax.hist(values, bins=20, color=color, edgecolor='black', alpha=0.7)
    ax.axvline(mean, color='red', linestyle='--', label=f"Mean: {mean:.2f}")
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Frequency')
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.legend()

    plt.tight_layout()
    plt.savefig(filename, dpi=150)
    plt.close()
    print(f" Saved: {filename}")

def create_visualizations(metrics):
    """Create evaluation charts.

    Writes three PNGs under VIZ_DIR: a bar chart of the aggregate
    metrics plus token- and keyword-accuracy histograms. Expects the
    dict produced by evaluate_predictions() (including 'detailed').
    """
    setup_directories()
    plt.style.use('seaborn-v0_8-whitegrid')

    # 1. Metrics Overview — one bar per aggregate metric, in percent.
    fig, ax = plt.subplots(figsize=(10, 6))

    names = ['Exact Match', 'Token Acc', 'Keyword Acc', 'Structure Sim']
    values = [
        metrics['exact_match_rate'] * 100,
        metrics['avg_token_accuracy'] * 100,
        metrics['avg_keyword_accuracy'] * 100,
        metrics['avg_structure_similarity'] * 100
    ]
    colors = ['#3498db', '#2ecc71', '#9b59b6', '#e74c3c']

    bars = ax.bar(names, values, color=colors, edgecolor='black')
    for bar, val in zip(bars, values):
        # Label each bar with its value just above the top.
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1,
                f'{val:.1f}%', ha='center', fontweight='bold')

    ax.set_ylabel('Score (%)')
    ax.set_title('Model Evaluation Metrics', fontsize=14, fontweight='bold')
    ax.set_ylim(0, 110)  # headroom for the 100% bar labels

    plt.tight_layout()
    plt.savefig(f'{VIZ_DIR}/01_metrics_overview.png', dpi=150)
    plt.close()
    print(f" Saved: {VIZ_DIR}/01_metrics_overview.png")

    # 2. Token Accuracy Distribution (deduplicated via helper).
    _save_metric_histogram(
        metrics['detailed']['token_accuracy'], '#2ecc71',
        'Token Accuracy', 'Token Accuracy Distribution',
        f'{VIZ_DIR}/02_token_accuracy_dist.png')

    # 3. Keyword Accuracy Distribution.
    _save_metric_histogram(
        metrics['detailed']['keyword_accuracy'], '#9b59b6',
        'Keyword Accuracy', 'Keyword Accuracy Distribution',
        f'{VIZ_DIR}/03_keyword_accuracy_dist.png')
|
| 171 |
+
|
| 172 |
+
# =============================================================================
|
| 173 |
+
# REPORT GENERATION
|
| 174 |
+
# =============================================================================
|
| 175 |
+
|
| 176 |
+
def generate_report(metrics):
    """Generate the evaluation report files.

    Writes a human-readable Markdown summary and a machine-readable JSON
    copy of the aggregate metrics to RESULTS_DIR. The per-sample
    'detailed' lists are deliberately excluded from the JSON.
    """

    # NOTE: the f-string below is the literal Markdown document — keep the
    # table keys in sync with what evaluate_predictions() returns.
    report = f"""# Fine-Tuning Evaluation Report

**Generated:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

## Metrics Summary

| Metric | Score |
|--------|-------|
| Samples Evaluated | {metrics['total_samples']} |
| Exact Match Rate | {metrics['exact_match_rate']*100:.2f}% |
| Token Accuracy | {metrics['avg_token_accuracy']*100:.2f}% |
| Keyword Accuracy | {metrics['avg_keyword_accuracy']*100:.2f}% |
| Structure Similarity | {metrics['avg_structure_similarity']*100:.2f}% |

## Metrics Explanation

- **Exact Match**: Predictions identical to ground truth
- **Token Accuracy**: Word overlap between prediction and expected
- **Keyword Accuracy**: SQL keywords (SELECT, WHERE, etc.) match
- **Structure Similarity**: Query structure (clauses used) match

## Visualizations

- `01_metrics_overview.png` - All metrics bar chart
- `02_token_accuracy_dist.png` - Token accuracy histogram
- `03_keyword_accuracy_dist.png` - Keyword accuracy histogram
"""

    with open(f'{RESULTS_DIR}/evaluation_report.md', 'w') as f:
        f.write(report)
    print(f" Saved: {RESULTS_DIR}/evaluation_report.md")

    # Save JSON copy of the aggregates, dropping the bulky per-sample lists.
    json_metrics = {k: v for k, v in metrics.items() if k != 'detailed'}
    with open(f'{RESULTS_DIR}/evaluation_results.json', 'w') as f:
        json.dump(json_metrics, f, indent=2)
    print(f" Saved: {RESULTS_DIR}/evaluation_results.json")
|
| 216 |
+
|
| 217 |
+
# =============================================================================
|
| 218 |
+
# MAIN EVALUATION
|
| 219 |
+
# =============================================================================
|
| 220 |
+
|
| 221 |
+
def run_evaluation():
    """Run the full evaluation pipeline.

    Steps: load the test split (test.jsonl), generate a prediction per
    question with the fine-tuned model, score with
    evaluate_predictions(), then write charts and reports. If the model
    cannot be loaded, ground truth is echoed as the predictions so the
    metric/report plumbing can still be exercised (scores will be 100%).

    Returns:
        The metrics dict from evaluate_predictions(), or None when the
        test file is missing.
    """

    print("=" * 60)
    print("EVALUATING FINE-TUNED MODEL")
    print("=" * 60)

    setup_directories()

    # [1/4] Load test data produced by prepare_data.py.
    print("\n[1/4] Loading test data...")
    test_file = f"{OUTPUT_DIR}/test.jsonl"

    if not os.path.exists(test_file):
        print("ERROR: Run prepare_data.py first!")
        return None

    test_data = []
    with open(test_file) as f:
        for line in f:
            test_data.append(json.loads(line))

    # Cap the evaluation set for speed (see NUM_EVAL_SAMPLES above).
    test_data = test_data[:NUM_EVAL_SAMPLES]
    print(f" Loaded {len(test_data)} samples")

    # [2/4] Generate predictions with the fine-tuned model.
    print("\n[2/4] Generating predictions...")

    try:
        import sys
        # Make the package root importable when run as a standalone script.
        sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
        from finetuning.inference import SQLGenerator
        generator = SQLGenerator()

        predictions = []
        ground_truth = []

        for i, item in enumerate(test_data):
            pred = generator.generate(item['question'])
            predictions.append(pred)
            ground_truth.append(item['sql'])

            if (i + 1) % 10 == 0:
                print(f" Progress: {i+1}/{len(test_data)}")

    except Exception as e:
        # Fallback: no usable model — score ground truth against itself so
        # the rest of the pipeline still runs. Results are NOT meaningful.
        print(f" Error loading model: {e}")
        print(" Using ground truth as predictions (for testing metrics)")
        predictions = [item['sql'] for item in test_data]
        ground_truth = [item['sql'] for item in test_data]

    # [3/4] Calculate metrics.
    print("\n[3/4] Calculating metrics...")
    metrics = evaluate_predictions(predictions, ground_truth)

    print(f" Exact Match: {metrics['exact_match_rate']*100:.2f}%")
    print(f" Token Accuracy: {metrics['avg_token_accuracy']*100:.2f}%")
    print(f" Keyword Accuracy: {metrics['avg_keyword_accuracy']*100:.2f}%")
    print(f" Structure Sim: {metrics['avg_structure_similarity']*100:.2f}%")

    # [4/4] Write visualizations and the Markdown/JSON reports.
    print("\n[4/4] Generating outputs...")
    create_visualizations(metrics)
    generate_report(metrics)

    print("\n" + "=" * 60)
    print("EVALUATION COMPLETE")
    print("=" * 60)

    return metrics
|
| 291 |
+
|
| 292 |
+
if __name__ == "__main__":
|
| 293 |
+
run_evaluation()
|
src/finetuning/inference.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Inference Module for Fine-Tuned SQL Model
|
| 3 |
+
Loads from: Local checkpoint OR Hugging Face Hub
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import torch
|
| 8 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 9 |
+
from dotenv import load_dotenv
|
| 10 |
+
|
| 11 |
+
load_dotenv()
|
| 12 |
+
|
| 13 |
+
# =============================================================================
|
| 14 |
+
# CONFIGURATION
|
| 15 |
+
# =============================================================================
|
| 16 |
+
|
| 17 |
+
# Hugging Face Model ID (set in .env or Streamlit secrets)
|
| 18 |
+
HF_MODEL_ID = os.getenv("HF_MODEL_ID", None)
|
| 19 |
+
|
| 20 |
+
# Local paths
|
| 21 |
+
LOCAL_MODEL_DIR = "outputs/finetuning/checkpoints/final"
|
| 22 |
+
BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
| 23 |
+
|
| 24 |
+
# =============================================================================
|
| 25 |
+
# SQL GENERATOR CLASS
|
| 26 |
+
# =============================================================================
|
| 27 |
+
|
| 28 |
+
class SQLGenerator:
|
| 29 |
+
"""SQL Generation using fine-tuned model."""
|
| 30 |
+
|
| 31 |
+
    def __init__(self):
        """Load the fine-tuned model from local checkpoint or HuggingFace.

        Picks the checkpoint via _get_model_path() (local → HF Hub →
        base model), loads tokenizer + causal LM with memory-efficient
        settings, and moves the model to CUDA when available.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Device: {self.device}")

        load_path = self._get_model_path()

        # Load tokenizer and model with memory optimization
        print(f"Loading model from: {load_path}")
        self.tokenizer = AutoTokenizer.from_pretrained(load_path)

        # Memory-efficient loading for cloud deployment
        # NOTE(review): float32 is forced even when CUDA is available —
        # consider torch.float16 on GPU to halve memory; confirm intent.
        self.model = AutoModelForCausalLM.from_pretrained(
            load_path,
            torch_dtype=torch.float32,  # Use float32 for CPU
            device_map=None,  # Don't use device_map on CPU
            low_cpu_mem_usage=True,  # Reduce memory during loading
            trust_remote_code=True
        )

        # Move to device after loading
        self.model = self.model.to(self.device)

        # TinyLlama has no dedicated pad token; reuse EOS for padding.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        print("✓ Model loaded!")
|
| 56 |
+
|
| 57 |
+
def _get_model_path(self):
|
| 58 |
+
"""Determine where to load model from."""
|
| 59 |
+
|
| 60 |
+
# Check for required model files (not just folder existence)
|
| 61 |
+
required_files = ['config.json', 'tokenizer.json', 'tokenizer_config.json']
|
| 62 |
+
|
| 63 |
+
# Priority 1: Local checkpoint with actual model files
|
| 64 |
+
if os.path.exists(LOCAL_MODEL_DIR):
|
| 65 |
+
local_files = os.listdir(LOCAL_MODEL_DIR) if os.path.isdir(LOCAL_MODEL_DIR) else []
|
| 66 |
+
has_model_files = any(f in local_files for f in required_files) or any(f.endswith('.safetensors') or f.endswith('.bin') for f in local_files)
|
| 67 |
+
|
| 68 |
+
if has_model_files:
|
| 69 |
+
print(f"📁 Found local model checkpoint: {LOCAL_MODEL_DIR}")
|
| 70 |
+
return LOCAL_MODEL_DIR
|
| 71 |
+
else:
|
| 72 |
+
print(f"⚠️ Local folder exists but no model files found")
|
| 73 |
+
|
| 74 |
+
# Priority 2: Download from HuggingFace Hub
|
| 75 |
+
if HF_MODEL_ID:
|
| 76 |
+
print(f"☁️ Downloading model from HuggingFace: {HF_MODEL_ID}")
|
| 77 |
+
return HF_MODEL_ID
|
| 78 |
+
|
| 79 |
+
# Priority 3: Base model fallback
|
| 80 |
+
print("⚠️ No fine-tuned model found, using base model")
|
| 81 |
+
return BASE_MODEL
|
| 82 |
+
|
| 83 |
+
def generate(self, question, context="", max_tokens=128):
|
| 84 |
+
"""Generate SQL from question."""
|
| 85 |
+
|
| 86 |
+
# Build prompt
|
| 87 |
+
if context:
|
| 88 |
+
prompt = f"""{context}
|
| 89 |
+
|
| 90 |
+
### Question:
|
| 91 |
+
{question}
|
| 92 |
+
|
| 93 |
+
### SQL:"""
|
| 94 |
+
else:
|
| 95 |
+
prompt = f"""### Question:
|
| 96 |
+
{question}
|
| 97 |
+
|
| 98 |
+
### SQL:"""
|
| 99 |
+
|
| 100 |
+
# Tokenize
|
| 101 |
+
inputs = self.tokenizer(
|
| 102 |
+
prompt,
|
| 103 |
+
return_tensors="pt",
|
| 104 |
+
truncation=True,
|
| 105 |
+
max_length=512
|
| 106 |
+
).to(self.device)
|
| 107 |
+
|
| 108 |
+
# Generate
|
| 109 |
+
with torch.no_grad():
|
| 110 |
+
outputs = self.model.generate(
|
| 111 |
+
**inputs,
|
| 112 |
+
max_new_tokens=max_tokens,
|
| 113 |
+
temperature=0.1,
|
| 114 |
+
do_sample=True,
|
| 115 |
+
top_p=0.95,
|
| 116 |
+
pad_token_id=self.tokenizer.eos_token_id
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
# Decode
|
| 120 |
+
generated = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 121 |
+
|
| 122 |
+
# Extract SQL
|
| 123 |
+
sql = generated[len(prompt):].strip()
|
| 124 |
+
if "###" in sql:
|
| 125 |
+
sql = sql.split("###")[0].strip()
|
| 126 |
+
|
| 127 |
+
return sql
|
| 128 |
+
|
| 129 |
+
# =============================================================================
|
| 130 |
+
# STANDALONE FUNCTION
|
| 131 |
+
# =============================================================================
|
| 132 |
+
|
| 133 |
+
# Lazily-constructed module singleton so the heavyweight model is loaded
# at most once per process.
_generator = None

def generate_sql(question, context=""):
    """Standalone SQL generation.

    Convenience wrapper that builds a cached SQLGenerator on first use and
    delegates every subsequent call to it.
    """
    global _generator
    if _generator is None:
        _generator = SQLGenerator()
    return _generator.generate(question, context)
|
| 141 |
+
|
| 142 |
+
# =============================================================================
|
| 143 |
+
# TEST
|
| 144 |
+
# =============================================================================
|
| 145 |
+
|
| 146 |
+
def test_inference():
    """Smoke-test SQL generation on a sample question."""
    banner = "=" * 60
    print(banner)
    print("TESTING SQL GENERATION")
    print(banner)

    generator = SQLGenerator()

    sample_questions = [
        "Find all employees with salary greater than 50000",
    ]

    print("\n" + "-" * 60)
    for question in sample_questions:
        print(f"Q: {question}")
        sql = generator.generate(question)
        print(f"SQL: {sql}")
        print("-" * 60)

    print("\n✓ Test complete")

if __name__ == "__main__":
    test_inference()
|
src/finetuning/prepare_data.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Data Preparation for Fine-Tuning
|
| 3 |
+
Uses train.csv, validation.csv, test.csv correctly.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import json
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
|
| 11 |
+
# =============================================================================
|
| 12 |
+
# CONFIGURATION
|
| 13 |
+
# =============================================================================
|
| 14 |
+
|
| 15 |
+
OUTPUT_DIR = "outputs/finetuning"  # where JSONL splits and stats are written
DATA_DIR = "data"  # where train/validation/test CSVs are expected

# Change this for testing vs full training
MAX_SAMPLES = 100  # Set to None for full data
|
| 20 |
+
|
| 21 |
+
def setup_directories():
    """Create the fine-tuning output folders if they do not already exist."""
    for directory in (OUTPUT_DIR, f"{OUTPUT_DIR}/results", f"{OUTPUT_DIR}/logs"):
        os.makedirs(directory, exist_ok=True)
|
| 24 |
+
|
| 25 |
+
# =============================================================================
|
| 26 |
+
# PROMPT TEMPLATE
|
| 27 |
+
# =============================================================================
|
| 28 |
+
|
| 29 |
+
def format_for_training(question, sql):
    """Render one (question, sql) pair as an instruction-tuning prompt.

    Produces the "### Question:\\n...\\n\\n### SQL:\\n..." layout expected by
    the training and inference code.
    """
    sections = ["### Question:", f"{question}", "", "### SQL:", f"{sql}"]
    return "\n".join(sections)
|
| 37 |
+
|
| 38 |
+
# =============================================================================
|
| 39 |
+
# DATA LOADING
|
| 40 |
+
# =============================================================================
|
| 41 |
+
|
| 42 |
+
def load_csv_file(filepath, max_samples=None):
    """Load a single CSV file, optionally down-sampled to max_samples rows.

    Returns the DataFrame, or None (after logging) when the file is missing.
    Sampling uses a fixed seed so repeated runs pick the same subset.
    """
    if not os.path.exists(filepath):
        print(f" File not found: {filepath}")
        return None

    frame = pd.read_csv(filepath)

    too_large = max_samples and len(frame) > max_samples
    if too_large:
        frame = frame.sample(n=max_samples, random_state=42)

    return frame
|
| 54 |
+
|
| 55 |
+
def format_dataframe(df, source_name):
    """Convert a question/sql dataframe into training records.

    Each record carries the rendered prompt text plus the raw question, sql
    and the split name it came from.
    """
    records = []
    for _, row in df.iterrows():
        question, sql = row['question'], row['sql']
        record = {
            'text': format_for_training(question, sql),
            'question': str(question),
            'sql': str(sql),
            'source': source_name,
        }
        records.append(record)
    return records
|
| 66 |
+
|
| 67 |
+
def save_jsonl(data, filepath):
    """Write records to a JSON Lines file, one JSON object per line."""
    with open(filepath, 'w', encoding='utf-8') as handle:
        handle.writelines(json.dumps(record) + '\n' for record in data)
    print(f" Saved: {filepath}")
|
| 73 |
+
|
| 74 |
+
# =============================================================================
|
| 75 |
+
# MAIN FUNCTION
|
| 76 |
+
# =============================================================================
|
| 77 |
+
|
| 78 |
+
def prepare_finetuning_data():
    """Prepare data for fine-tuning.

    Loads train/validation/test CSVs, formats every row as an instruction
    prompt, writes the three JSONL splits plus a data_stats.json summary.

    Returns:
        dict with per-split sample counts, the MAX_SAMPLES cap and a
        creation timestamp.

    Raises:
        FileNotFoundError: if a required CSV file is missing.
    """

    print("=" * 50)
    print("PREPARING FINE-TUNING DATA")
    print(f"Max samples per file: {MAX_SAMPLES if MAX_SAMPLES else 'ALL'}")
    print("=" * 50)

    setup_directories()

    # Load train data
    print("\n[1/5] Loading training data...")
    train_df = load_csv_file(f"{DATA_DIR}/train.csv", MAX_SAMPLES)
    # FIX: load_csv_file returns None for a missing file; previously this
    # crashed below with "TypeError: object of type 'NoneType' has no len()".
    # Fail fast with a clear error instead.
    if train_df is None:
        raise FileNotFoundError(f"Required file missing: {DATA_DIR}/train.csv")
    print(f" train.csv: {len(train_df):,} rows")

    # Load synthetic and combine with train
    # synthetic_df = load_csv_file(f"{DATA_DIR}/synthetic.csv", MAX_SAMPLES)
    # if synthetic_df is not None:
    #     print(f" synthetic.csv: {len(synthetic_df):,} rows")
    #     train_df = pd.concat([train_df, synthetic_df], ignore_index=True)
    #     print(f" Combined training: {len(train_df):,} rows")

    # Load validation data
    print("\n[2/5] Loading validation data...")
    val_df = load_csv_file(f"{DATA_DIR}/validation.csv", MAX_SAMPLES)
    if val_df is None:
        raise FileNotFoundError(f"Required file missing: {DATA_DIR}/validation.csv")
    print(f" validation.csv: {len(val_df):,} rows")

    # Load test data
    print("\n[3/5] Loading test data...")
    test_df = load_csv_file(f"{DATA_DIR}/test.csv", MAX_SAMPLES)
    if test_df is None:
        raise FileNotFoundError(f"Required file missing: {DATA_DIR}/test.csv")
    print(f" test.csv: {len(test_df):,} rows")

    # Format data
    print("\n[4/5] Formatting data...")
    train_data = format_dataframe(train_df, 'train')
    val_data = format_dataframe(val_df, 'validation')
    test_data = format_dataframe(test_df, 'test')

    # Save files
    print("\n[5/5] Saving files...")
    save_jsonl(train_data, f"{OUTPUT_DIR}/train.jsonl")
    save_jsonl(val_data, f"{OUTPUT_DIR}/val.jsonl")
    save_jsonl(test_data, f"{OUTPUT_DIR}/test.jsonl")

    # Save stats
    stats = {
        'train_samples': len(train_data),
        'val_samples': len(val_data),
        'test_samples': len(test_data),
        'max_samples': MAX_SAMPLES,
        'created_at': datetime.now().isoformat()
    }

    with open(f"{OUTPUT_DIR}/data_stats.json", 'w') as f:
        json.dump(stats, f, indent=2)

    # Summary
    print("\n" + "=" * 50)
    print("COMPLETE")
    print("=" * 50)
    print(f" Train: {len(train_data):,}")
    print(f" Val: {len(val_data):,}")
    print(f" Test: {len(test_data):,}")

    return stats

# =============================================================================
# ENTRY POINT
# =============================================================================

if __name__ == "__main__":
    prepare_finetuning_data()
|
src/finetuning/train.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Fine-Tuning Script for SQL Generation Model
|
| 3 |
+
Uses LoRA for efficient fine-tuning.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import json
|
| 8 |
+
import torch
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
from datasets import load_dataset
|
| 11 |
+
from transformers import (
|
| 12 |
+
AutoModelForCausalLM,
|
| 13 |
+
AutoTokenizer,
|
| 14 |
+
TrainingArguments,
|
| 15 |
+
Trainer,
|
| 16 |
+
DataCollatorForLanguageModeling
|
| 17 |
+
)
|
| 18 |
+
from peft import LoraConfig, get_peft_model, TaskType
|
| 19 |
+
|
| 20 |
+
# =============================================================================
|
| 21 |
+
# CONFIGURATION
|
| 22 |
+
# =============================================================================
|
| 23 |
+
|
| 24 |
+
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # base model to fine-tune
OUTPUT_DIR = "outputs/finetuning"
CHECKPOINT_DIR = f"{OUTPUT_DIR}/checkpoints"  # Trainer checkpoints + final model
LOGS_DIR = f"{OUTPUT_DIR}/logs"

# Training config (optimized for RTX 4070)
TRAINING_CONFIG = {
    'num_epochs': 3,
    'batch_size': 8,  # per-device train and eval batch size
    'learning_rate': 2e-4,
    'max_length': 256,  # tokenized sequence length (truncate/pad)
    'warmup_steps': 100,
    'logging_steps': 50,
    'save_steps': 500,  # also reused as eval_steps in train()
    'gradient_accumulation_steps': 2,  # effective batch = 8 * 2 = 16
}

# LoRA config
LORA_CONFIG = {
    'r': 16,  # adapter rank
    'lora_alpha': 32,  # adapter scaling
    'lora_dropout': 0.1,
    'target_modules': ['q_proj', 'v_proj', 'k_proj', 'o_proj']  # attention projections
}
|
| 48 |
+
|
| 49 |
+
def setup_directories():
    """Ensure the output, checkpoint and log directories exist."""
    for path in (OUTPUT_DIR, CHECKPOINT_DIR, LOGS_DIR):
        os.makedirs(path, exist_ok=True)
|
| 52 |
+
|
| 53 |
+
# =============================================================================
|
| 54 |
+
# TRAINING FUNCTIONS
|
| 55 |
+
# =============================================================================
|
| 56 |
+
|
| 57 |
+
def load_data():
    """Load the prepared train/validation JSONL splits.

    Raises FileNotFoundError when prepare_data.py has not been run yet.
    """
    data_files = {
        'train': f"{OUTPUT_DIR}/train.jsonl",
        'validation': f"{OUTPUT_DIR}/val.jsonl",
    }

    if not os.path.exists(data_files['train']):
        raise FileNotFoundError("Run prepare_data.py first!")

    return load_dataset('json', data_files=data_files)
|
| 69 |
+
|
| 70 |
+
def setup_model():
    """Load the base model and tokenizer, then wrap the model with LoRA.

    Returns the PEFT-wrapped model and its tokenizer.
    """
    print(f"Loading: {MODEL_NAME}")

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float16,
        device_map="auto",
    )

    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=LORA_CONFIG['r'],
        lora_alpha=LORA_CONFIG['lora_alpha'],
        lora_dropout=LORA_CONFIG['lora_dropout'],
        target_modules=LORA_CONFIG['target_modules'],
    )

    peft_model = get_peft_model(base_model, peft_config)
    peft_model.print_trainable_parameters()

    return peft_model, tokenizer
|
| 96 |
+
|
| 97 |
+
def tokenize(examples, tokenizer):
    """Tokenize a batch of prompt texts to fixed-length model inputs."""
    options = {
        'truncation': True,
        'padding': 'max_length',
        'max_length': TRAINING_CONFIG['max_length'],
    }
    return tokenizer(examples['text'], **options)
|
| 105 |
+
|
| 106 |
+
def train(model, tokenizer, dataset):
    """Train the model with HF Trainer and save the final checkpoint.

    Args:
        model: PEFT-wrapped causal LM from setup_model().
        tokenizer: Matching tokenizer (pad token already set).
        dataset: DatasetDict with 'train' and 'validation' splits, each
            holding a 'text' column of formatted prompts.

    Returns:
        dict of training statistics (also written to training_stats.json).
    """

    # Tokenize — remove_columns drops the raw text/question/sql fields so
    # only the tokenizer outputs (input_ids, attention_mask) remain.
    print("Tokenizing...")
    tokenized_train = dataset['train'].map(
        lambda x: tokenize(x, tokenizer),
        batched=True,
        remove_columns=dataset['train'].column_names
    )
    tokenized_val = dataset['validation'].map(
        lambda x: tokenize(x, tokenizer),
        batched=True,
        remove_columns=dataset['validation'].column_names
    )

    # Training args
    training_args = TrainingArguments(
        output_dir=CHECKPOINT_DIR,
        num_train_epochs=TRAINING_CONFIG['num_epochs'],
        per_device_train_batch_size=TRAINING_CONFIG['batch_size'],
        per_device_eval_batch_size=TRAINING_CONFIG['batch_size'],
        learning_rate=TRAINING_CONFIG['learning_rate'],
        warmup_steps=TRAINING_CONFIG['warmup_steps'],
        logging_steps=TRAINING_CONFIG['logging_steps'],
        save_steps=TRAINING_CONFIG['save_steps'],
        gradient_accumulation_steps=TRAINING_CONFIG['gradient_accumulation_steps'],
        eval_strategy="steps",
        eval_steps=TRAINING_CONFIG['save_steps'],  # evaluate at every checkpoint save
        save_total_limit=2,  # keep only the two most recent checkpoints
        fp16=True,  # NOTE(review): assumes a CUDA device; fp16 fails on CPU — confirm
        report_to="none",  # disable wandb/tensorboard reporting integrations
        logging_dir=LOGS_DIR,
        dataloader_pin_memory=False,
    )

    # Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_val,
        # mlm=False -> plain causal-LM objective (labels from input_ids)
        data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    )

    # Train
    print(f"\nTraining: {len(tokenized_train)} samples, {TRAINING_CONFIG['num_epochs']} epochs")
    result = trainer.train()

    # Save model and tokenizer together under .../final
    print("\nSaving model...")
    trainer.save_model(f"{CHECKPOINT_DIR}/final")
    tokenizer.save_pretrained(f"{CHECKPOINT_DIR}/final")

    # Stats
    stats = {
        'train_loss': result.training_loss,
        'runtime_seconds': result.metrics['train_runtime'],
        'samples_per_second': result.metrics['train_samples_per_second'],
        'epochs': TRAINING_CONFIG['num_epochs'],
        'total_steps': result.global_step,
        'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU',
        'completed_at': datetime.now().isoformat()
    }

    with open(f"{CHECKPOINT_DIR}/training_stats.json", 'w') as f:
        json.dump(stats, f, indent=2)

    return stats
|
| 175 |
+
|
| 176 |
+
# =============================================================================
|
| 177 |
+
# MAIN
|
| 178 |
+
# =============================================================================
|
| 179 |
+
|
| 180 |
+
def run_finetuning():
    """End-to-end fine-tuning driver: load data, build model, train."""
    banner = "=" * 60
    print(banner)
    print("FINE-TUNING SQL MODEL")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
    else:
        print("GPU: Not available (using CPU)")
    print(banner)

    setup_directories()

    # Load data
    print("\n[1/3] Loading data...")
    dataset = load_data()
    print(f" Train: {len(dataset['train']):,}")
    print(f" Val: {len(dataset['validation']):,}")

    # Setup model
    print("\n[2/3] Setting up model...")
    model, tokenizer = setup_model()

    # Train
    print("\n[3/3] Training...")
    stats = train(model, tokenizer, dataset)

    # Done
    print("\n" + banner)
    print("TRAINING COMPLETE")
    print(banner)
    print(f" Loss: {stats['train_loss']:.4f}")
    print(f" Time: {stats['runtime_seconds']/60:.1f} min")
    print(f" Model: {CHECKPOINT_DIR}/final")

    return stats

if __name__ == "__main__":
    run_finetuning()
|
src/outputs/finetuning/data_stats.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"train_samples": 100,
|
| 3 |
+
"val_samples": 100,
|
| 4 |
+
"test_samples": 100,
|
| 5 |
+
"max_samples": 100,
|
| 6 |
+
"created_at": "2025-12-08T01:22:29.002186"
|
| 7 |
+
}
|
src/outputs/finetuning/results/evaluation_report.md
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Fine-Tuning Evaluation Report
|
| 2 |
+
|
| 3 |
+
**Generated:** 2025-12-08 01:32:37
|
| 4 |
+
|
| 5 |
+
## Metrics Summary
|
| 6 |
+
|
| 7 |
+
| Metric | Score |
|
| 8 |
+
|--------|-------|
|
| 9 |
+
| Samples Evaluated | 50 |
|
| 10 |
+
| Exact Match Rate | 0.00% |
|
| 11 |
+
| Token Accuracy | 47.21% |
|
| 12 |
+
| Keyword Accuracy | 91.33% |
|
| 13 |
+
| Structure Similarity | 91.07% |
|
| 14 |
+
|
| 15 |
+
## Metrics Explanation
|
| 16 |
+
|
| 17 |
+
- **Exact Match**: Predictions identical to ground truth
|
| 18 |
+
- **Token Accuracy**: Word overlap between prediction and expected
|
| 19 |
+
- **Keyword Accuracy**: SQL keywords (SELECT, WHERE, etc.) match
|
| 20 |
+
- **Structure Similarity**: Query structure (clauses used) match
|
| 21 |
+
|
| 22 |
+
## Visualizations
|
| 23 |
+
|
| 24 |
+
- `01_metrics_overview.png` - All metrics bar chart
|
| 25 |
+
- `02_token_accuracy_dist.png` - Token accuracy histogram
|
| 26 |
+
- `03_keyword_accuracy_dist.png` - Keyword accuracy histogram
- `04_training_loss.png` - Training loss curve
|
src/outputs/finetuning/results/evaluation_results.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"total_samples": 50,
|
| 3 |
+
"exact_match_rate": 0.0,
|
| 4 |
+
"avg_token_accuracy": 0.472115604983252,
|
| 5 |
+
"avg_keyword_accuracy": 0.9133333333333334,
|
| 6 |
+
"avg_structure_similarity": 0.9106666666666667
|
| 7 |
+
}
|
src/outputs/finetuning/test.jsonl
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"text": "### Question:\nWhat is the whole of Drawn that has a Lost of 4?\n\n### SQL:\nSELECT SUM Drawn FROM table WHERE Lost = 4", "question": "What is the whole of Drawn that has a Lost of 4?", "sql": "SELECT SUM Drawn FROM table WHERE Lost = 4", "source": "test"}
|
| 2 |
+
{"text": "### Question:\nWhat is the rating of the episode with a rating/share of 0.9/4?\n\n### SQL:\nSELECT Rating FROM table WHERE Rating/Share (18\u201349) = 0.9/4", "question": "What is the rating of the episode with a rating/share of 0.9/4?", "sql": "SELECT Rating FROM table WHERE Rating/Share (18\u201349) = 0.9/4", "source": "test"}
|
| 3 |
+
{"text": "### Question:\nWhat is the last year that someone is first elected in this table?\n\n### SQL:\nSELECT MAX First elected FROM table", "question": "What is the last year that someone is first elected in this table?", "sql": "SELECT MAX First elected FROM table", "source": "test"}
|
| 4 |
+
{"text": "### Question:\nHow many poles had 72 points?\n\n### SQL:\nSELECT COUNT Poles FROM table WHERE Points = 72", "question": "How many poles had 72 points?", "sql": "SELECT COUNT Poles FROM table WHERE Points = 72", "source": "test"}
|
| 5 |
+
{"text": "### Question:\nWho are all the Moto2 winners when the grand prix was Shell Advance Malaysian Grand Prix?\n\n### SQL:\nSELECT Moto2 winner FROM table WHERE Grand Prix = Shell Advance Malaysian Grand Prix", "question": "Who are all the Moto2 winners when the grand prix was Shell Advance Malaysian Grand Prix?", "sql": "SELECT Moto2 winner FROM table WHERE Grand Prix = Shell Advance Malaysian Grand Prix", "source": "test"}
|
| 6 |
+
{"text": "### Question:\nHow many incumbents are there in the georgia 8 district when the party is democratic?\n\n### SQL:\nSELECT COUNT Incumbent FROM table WHERE Party = Democratic AND District = Georgia 8", "question": "How many incumbents are there in the georgia 8 district when the party is democratic?", "sql": "SELECT COUNT Incumbent FROM table WHERE Party = Democratic AND District = Georgia 8", "source": "test"}
|
| 7 |
+
{"text": "### Question:\nWhat is the largest Area (msr) that has an Area less than 291.045, is part of the Per family, and has a rank higher than 78?\n\n### SQL:\nSELECT MAX Area (msr) FROM table WHERE Area (sq.deg.) < 291.045 AND Family = per AND Rank > 78", "question": "What is the largest Area (msr) that has an Area less than 291.045, is part of the Per family, and has a rank higher than 78?", "sql": "SELECT MAX Area (msr) FROM table WHERE Area (sq.deg.) < 291.045 AND Family = per AND Rank > 78", "source": "test"}
|
| 8 |
+
{"text": "### Question:\nWho won women's double in 2002, the year that Kenneth Vella won men's singles?\n\n### SQL:\nSELECT Women's doubles FROM table WHERE Men's singles = kenneth vella AND Year = 2002", "question": "Who won women's double in 2002, the year that Kenneth Vella won men's singles?", "sql": "SELECT Women's doubles FROM table WHERE Men's singles = kenneth vella AND Year = 2002", "source": "test"}
|
| 9 |
+
{"text": "### Question:\nWhat's k. j. choi's to par?\n\n### SQL:\nSELECT To par FROM table WHERE Player = k. j. choi", "question": "What's k. j. choi's to par?", "sql": "SELECT To par FROM table WHERE Player = k. j. choi", "source": "test"}
|
| 10 |
+
{"text": "### Question:\nWho was the winner when the time was 1:24.00?\n\n### SQL:\nSELECT Winner/2nd FROM table WHERE Time = 1:24.00", "question": "Who was the winner when the time was 1:24.00?", "sql": "SELECT Winner/2nd FROM table WHERE Time = 1:24.00", "source": "test"}
|
| 11 |
+
{"text": "### Question:\nWhich couple had a week 2 score of exactly 23?\n\n### SQL:\nSELECT Couple FROM table WHERE Week 2 = 23", "question": "Which couple had a week 2 score of exactly 23?", "sql": "SELECT Couple FROM table WHERE Week 2 = 23", "source": "test"}
|
| 12 |
+
{"text": "### Question:\nWhat is the record when the result was w 52\u201343?\n\n### SQL:\nSELECT Record FROM table WHERE Result = w 52\u201343", "question": "What is the record when the result was w 52\u201343?", "sql": "SELECT Record FROM table WHERE Result = w 52\u201343", "source": "test"}
|
| 13 |
+
{"text": "### Question:\nWhat's the L3 cache that has a low power part number?\n\n### SQL:\nSELECT L3 cache FROM table WHERE Part number(s) = low power", "question": "What's the L3 cache that has a low power part number?", "sql": "SELECT L3 cache FROM table WHERE Part number(s) = low power", "source": "test"}
|
| 14 |
+
{"text": "### Question:\nWho won the race on 24 August?\n\n### SQL:\nSELECT Winning driver FROM table WHERE Date = 24 august", "question": "Who won the race on 24 August?", "sql": "SELECT Winning driver FROM table WHERE Date = 24 august", "source": "test"}
|
| 15 |
+
{"text": "### Question:\nWhat is the original artist when the vocal percussionist is Alexei Kalveks?\n\n### SQL:\nSELECT Original Artist FROM table WHERE Vocal Percussionist = Alexei Kalveks", "question": "What is the original artist when the vocal percussionist is Alexei Kalveks?", "sql": "SELECT Original Artist FROM table WHERE Vocal Percussionist = Alexei Kalveks", "source": "test"}
|
| 16 |
+
{"text": "### Question:\nwhat is the winning % for the years 2006-11?\n\n### SQL:\nSELECT Winning % FROM table WHERE Years = 2006-11", "question": "what is the winning % for the years 2006-11?", "sql": "SELECT Winning % FROM table WHERE Years = 2006-11", "source": "test"}
|
| 17 |
+
{"text": "### Question:\nWhat is Method, when Event is \"Reality Submission Fighting 2\"?\n\n### SQL:\nSELECT Method FROM table WHERE Event = reality submission fighting 2", "question": "What is Method, when Event is \"Reality Submission Fighting 2\"?", "sql": "SELECT Method FROM table WHERE Event = reality submission fighting 2", "source": "test"}
|
| 18 |
+
{"text": "### Question:\nWhat was the date of the game that had a loss of lidle (10-8)?\n\n### SQL:\nSELECT Date FROM table WHERE Loss = lidle (10-8)", "question": "What was the date of the game that had a loss of lidle (10-8)?", "sql": "SELECT Date FROM table WHERE Loss = lidle (10-8)", "source": "test"}
|
| 19 |
+
{"text": "### Question:\nWho played mixed doubles when Anna Keir played women's singles?\n\n### SQL:\nSELECT Mixed doubles FROM table WHERE Women's singles = anna keir", "question": "Who played mixed doubles when Anna Keir played women's singles?", "sql": "SELECT Mixed doubles FROM table WHERE Women's singles = anna keir", "source": "test"}
|
| 20 |
+
{"text": "### Question:\nWhich tone has a Standard Thai at \u0e1b\u0e25\u0e32?\n\n### SQL:\nSELECT Tone FROM table WHERE Standard Thai = \u0e1b\u0e25\u0e32", "question": "Which tone has a Standard Thai at \u0e1b\u0e25\u0e32?", "sql": "SELECT Tone FROM table WHERE Standard Thai = \u0e1b\u0e25\u0e32", "source": "test"}
|
| 21 |
+
{"text": "### Question:\nName the payload that has a weight of 12,000\n\n### SQL:\nSELECT Payload (kg) FROM table WHERE Weight (kg) = 12,000", "question": "Name the payload that has a weight of 12,000", "sql": "SELECT Payload (kg) FROM table WHERE Weight (kg) = 12,000", "source": "test"}
|
| 22 |
+
{"text": "### Question:\nWhich school's round was 24?\n\n### SQL:\nSELECT School FROM table WHERE Round = 24", "question": "Which school's round was 24?", "sql": "SELECT School FROM table WHERE Round = 24", "source": "test"}
|
| 23 |
+
{"text": "### Question:\nWhat is the least for Scottish Cup with a Challenge Cup greater than 0, Player Paul Keegan, and League Cup greater than 0?\n\n### SQL:\nSELECT MIN Scottish Cup FROM table WHERE Challenge Cup > 0 AND Player = paul keegan AND League Cup > 0", "question": "What is the least for Scottish Cup with a Challenge Cup greater than 0, Player Paul Keegan, and League Cup greater than 0?", "sql": "SELECT MIN Scottish Cup FROM table WHERE Challenge Cup > 0 AND Player = paul keegan AND League Cup > 0", "source": "test"}
|
| 24 |
+
{"text": "### Question:\nWhat is the only type of university that was founded in 1873?\n\n### SQL:\nSELECT Control FROM table WHERE Founded = 1873", "question": "What is the only type of university that was founded in 1873?", "sql": "SELECT Control FROM table WHERE Founded = 1873", "source": "test"}
|
| 25 |
+
{"text": "### Question:\nHow many games were there in the 1966 season?\n\n### SQL:\nSELECT MAX Game FROM table", "question": "How many games were there in the 1966 season?", "sql": "SELECT MAX Game FROM table", "source": "test"}
|
| 26 |
+
{"text": "### Question:\nWhen luz mcclinton is the name what is the season?\n\n### SQL:\nSELECT Season FROM table WHERE Name = Luz McClinton", "question": "When luz mcclinton is the name what is the season?", "sql": "SELECT Season FROM table WHERE Name = Luz McClinton", "source": "test"}
|
| 27 |
+
{"text": "### Question:\nWho was the original artist for First Solo?\n\n### SQL:\nSELECT Original artist FROM table WHERE Theme = First Solo", "question": "Who was the original artist for First Solo?", "sql": "SELECT Original artist FROM table WHERE Theme = First Solo", "source": "test"}
|
| 28 |
+
{"text": "### Question:\nWhat is the lowest grid for Roberto Rolfo with more than 26 laps?\n\n### SQL:\nSELECT MIN Grid FROM table WHERE Rider = roberto rolfo AND Laps > 26", "question": "What is the lowest grid for Roberto Rolfo with more than 26 laps?", "sql": "SELECT MIN Grid FROM table WHERE Rider = roberto rolfo AND Laps > 26", "source": "test"}
|
| 29 |
+
{"text": "### Question:\nWhat is the average lap for suzuki gsx-r1000 k7 and at grid 6?\n\n### SQL:\nSELECT AVG Laps FROM table WHERE Bike = suzuki gsx-r1000 k7 AND Grid = 6", "question": "What is the average lap for suzuki gsx-r1000 k7 and at grid 6?", "sql": "SELECT AVG Laps FROM table WHERE Bike = suzuki gsx-r1000 k7 AND Grid = 6", "source": "test"}
|
| 30 |
+
{"text": "### Question:\nWhat is the date when the Lakers were the home team?\n\n### SQL:\nSELECT Date FROM table WHERE Home = lakers", "question": "What is the date when the Lakers were the home team?", "sql": "SELECT Date FROM table WHERE Home = lakers", "source": "test"}
|
| 31 |
+
{"text": "### Question:\nWhen was there a score of 7-1?\n\n### SQL:\nSELECT Date FROM table WHERE Score = 7-1", "question": "When was there a score of 7-1?", "sql": "SELECT Date FROM table WHERE Score = 7-1", "source": "test"}
|
| 32 |
+
{"text": "### Question:\nWho received 6,131 televotes?\n\n### SQL:\nSELECT Televote Points FROM table WHERE Televotes = 6,131", "question": "Who received 6,131 televotes?", "sql": "SELECT Televote Points FROM table WHERE Televotes = 6,131", "source": "test"}
|
| 33 |
+
{"text": "### Question:\nHow many episodes aired Saturday, July 11, 2009\n\n### SQL:\nSELECT COUNT Episode # FROM table WHERE US air date = Saturday, July 11, 2009", "question": "How many episodes aired Saturday, July 11, 2009", "sql": "SELECT COUNT Episode # FROM table WHERE US air date = Saturday, July 11, 2009", "source": "test"}
|
| 34 |
+
{"text": "### Question:\nWhich episode aired in the USA on 20 May 2005?\n\n### SQL:\nSELECT Episode FROM table WHERE Airdate (USA) = 20 may 2005", "question": "Which episode aired in the USA on 20 May 2005?", "sql": "SELECT Episode FROM table WHERE Airdate (USA) = 20 may 2005", "source": "test"}
|
| 35 |
+
{"text": "### Question:\nWhat was the best finish for 206 on the money list?\n\n### SQL:\nSELECT Best finish FROM table WHERE Money list rank = 206", "question": "What was the best finish for 206 on the money list?", "sql": "SELECT Best finish FROM table WHERE Money list rank = 206", "source": "test"}
|
| 36 |
+
{"text": "### Question:\nName the sum of pick # for round less than 1\n\n### SQL:\nSELECT SUM Pick # FROM table WHERE Round < 1", "question": "Name the sum of pick # for round less than 1", "sql": "SELECT SUM Pick # FROM table WHERE Round < 1", "source": "test"}
|
| 37 |
+
{"text": "### Question:\nWhat engine did the Team Lotus have after 1965?\n\n### SQL:\nSELECT Engine FROM table WHERE Entrant = team lotus AND Year > 1965", "question": "What engine did the Team Lotus have after 1965?", "sql": "SELECT Engine FROM table WHERE Entrant = team lotus AND Year > 1965", "source": "test"}
|
| 38 |
+
{"text": "### Question:\nOpponent of chicago bulls had what location?\n\n### SQL:\nSELECT Location FROM table WHERE Opponent = chicago bulls", "question": "Opponent of chicago bulls had what location?", "sql": "SELECT Location FROM table WHERE Opponent = chicago bulls", "source": "test"}
|
| 39 |
+
{"text": "### Question:\nWhat religious groups made up 0.72% of the Indian population in 2001?\n\n### SQL:\nSELECT Religious group FROM table WHERE Population % 2001 = 0.72%", "question": "What religious groups made up 0.72% of the Indian population in 2001?", "sql": "SELECT Religious group FROM table WHERE Population % 2001 = 0.72%", "source": "test"}
|
| 40 |
+
{"text": "### Question:\nOn which date was the high assists Delonte West Earl Watson (6)?\n\n### SQL:\nSELECT Date FROM table WHERE High assists = delonte west earl watson (6)", "question": "On which date was the high assists Delonte West Earl Watson (6)?", "sql": "SELECT Date FROM table WHERE High assists = delonte west earl watson (6)", "source": "test"}
|
| 41 |
+
{"text": "### Question:\nWhat is the enrollment for the institution in Westfield, Massachusetts? \n\n### SQL:\nSELECT Enrollment FROM table WHERE Location = Westfield, Massachusetts", "question": "What is the enrollment for the institution in Westfield, Massachusetts? ", "sql": "SELECT Enrollment FROM table WHERE Location = Westfield, Massachusetts", "source": "test"}
|
| 42 |
+
{"text": "### Question:\nWhat is the Studio of the Film with a Gross rental of $7,500,000?\n\n### SQL:\nSELECT Studio FROM table WHERE Gross rental = $7,500,000", "question": "What is the Studio of the Film with a Gross rental of $7,500,000?", "sql": "SELECT Studio FROM table WHERE Gross rental = $7,500,000", "source": "test"}
|
| 43 |
+
{"text": "### Question:\nWhat was the Outcome of the match played on Hard (i) Surface?\n\n### SQL:\nSELECT Outcome FROM table WHERE Surface = hard (i)", "question": "What was the Outcome of the match played on Hard (i) Surface?", "sql": "SELECT Outcome FROM table WHERE Surface = hard (i)", "source": "test"}
|
| 44 |
+
{"text": "### Question:\nWhat is the highest rank of the player who played 30 events and made less than $2,708,005?\n\n### SQL:\nSELECT MAX Rank FROM table WHERE Earnings ( $ ) < 2,708,005 AND Events = 30", "question": "What is the highest rank of the player who played 30 events and made less than $2,708,005?", "sql": "SELECT MAX Rank FROM table WHERE Earnings ( $ ) < 2,708,005 AND Events = 30", "source": "test"}
|
| 45 |
+
{"text": "### Question:\nHow many games had Montreal Canadiens as an opponent?\n\n### SQL:\nSELECT SUM Game FROM table WHERE Opponent = montreal canadiens", "question": "How many games had Montreal Canadiens as an opponent?", "sql": "SELECT SUM Game FROM table WHERE Opponent = montreal canadiens", "source": "test"}
|
| 46 |
+
{"text": "### Question:\nWhat is the length of the highway with the route name sh 2?\n\n### SQL:\nSELECT Length FROM table WHERE Route Name = sh 2", "question": "What is the length of the highway with the route name sh 2?", "sql": "SELECT Length FROM table WHERE Route Name = sh 2", "source": "test"}
|
| 47 |
+
{"text": "### Question:\nCan you tell me total number of Silver that has the Republic of latvian ssr, and the Total larger than 6?\n\n### SQL:\nSELECT COUNT Silver FROM table WHERE Republic = latvian ssr AND Total > 6", "question": "Can you tell me total number of Silver that has the Republic of latvian ssr, and the Total larger than 6?", "sql": "SELECT COUNT Silver FROM table WHERE Republic = latvian ssr AND Total > 6", "source": "test"}
|
| 48 |
+
{"text": "### Question:\nWhat date did they play the Florida Panthers?\n\n### SQL:\nSELECT Date FROM table WHERE Opponent = Florida Panthers", "question": "What date did they play the Florida Panthers?", "sql": "SELECT Date FROM table WHERE Opponent = Florida Panthers", "source": "test"}
|
| 49 |
+
{"text": "### Question:\nwhich college has a player called Riley Clayton?\n\n### SQL:\nSELECT College FROM table WHERE Player = riley clayton", "question": "which college has a player called Riley Clayton?", "sql": "SELECT College FROM table WHERE Player = riley clayton", "source": "test"}
|
| 50 |
+
{"text": "### Question:\nWhat is the most minimal Final year that has a North or east end of covington?\n\n### SQL:\nSELECT MIN Final year FROM table WHERE North or east terminus = covington", "question": "What is the most minimal Final year that has a North or east end of covington?", "sql": "SELECT MIN Final year FROM table WHERE North or east terminus = covington", "source": "test"}
|
| 51 |
+
{"text": "### Question:\nWho directed the episode whose production code is pabf05?\n\n### SQL:\nSELECT Directed by FROM table WHERE Production code = PABF05", "question": "Who directed the episode whose production code is pabf05?", "sql": "SELECT Directed by FROM table WHERE Production code = PABF05", "source": "test"}
|
| 52 |
+
{"text": "### Question:\nWhat is the best fit (all data) when the best fit (WMAP, extra parameter) shows \u2014?\n\n### SQL:\nSELECT Best fit (all data) FROM table WHERE Best fit (WMAP, extra parameter) = \u2014", "question": "What is the best fit (all data) when the best fit (WMAP, extra parameter) shows \u2014?", "sql": "SELECT Best fit (all data) FROM table WHERE Best fit (WMAP, extra parameter) = \u2014", "source": "test"}
|
| 53 |
+
{"text": "### Question:\nCan you tell me the lowest Week that has the Attendance smaller than 34,336?\n\n### SQL:\nSELECT MIN Week FROM table WHERE Attendance < 34,336", "question": "Can you tell me the lowest Week that has the Attendance smaller than 34,336?", "sql": "SELECT MIN Week FROM table WHERE Attendance < 34,336", "source": "test"}
|
| 54 |
+
{"text": "### Question:\nWhat is the Lead in the 2004-05 Season?\n\n### SQL:\nSELECT Lead FROM table WHERE Season = 2004-05", "question": "What is the Lead in the 2004-05 Season?", "sql": "SELECT Lead FROM table WHERE Season = 2004-05", "source": "test"}
|
| 55 |
+
{"text": "### Question:\nWhat is the result for week 12 against the Green Bay Packers?\n\n### SQL:\nSELECT Result FROM table WHERE Week > 12 AND Opponent = green bay packers", "question": "What is the result for week 12 against the Green Bay Packers?", "sql": "SELECT Result FROM table WHERE Week > 12 AND Opponent = green bay packers", "source": "test"}
|
| 56 |
+
{"text": "### Question:\nWhat was the first season for the club that in 2012 was 2nd in Superettan?\n\n### SQL:\nSELECT First season FROM table WHERE Position in 2012 = 2nd in Superettan", "question": "What was the first season for the club that in 2012 was 2nd in Superettan?", "sql": "SELECT First season FROM table WHERE Position in 2012 = 2nd in Superettan", "source": "test"}
|
| 57 |
+
{"text": "### Question:\nWhat was the extra info for the Commonwealth Games?\n\n### SQL:\nSELECT Extra FROM table WHERE Tournament = commonwealth games", "question": "What was the extra info for the Commonwealth Games?", "sql": "SELECT Extra FROM table WHERE Tournament = commonwealth games", "source": "test"}
|
| 58 |
+
{"text": "### Question:\nWhich player played for the Grizzlies from 1997-1998?\n\n### SQL:\nSELECT Player FROM table WHERE Years for Grizzlies = 1997-1998", "question": "Which player played for the Grizzlies from 1997-1998?", "sql": "SELECT Player FROM table WHERE Years for Grizzlies = 1997-1998", "source": "test"}
|
| 59 |
+
{"text": "### Question:\nNone of the communities listed has a percentage smaller than 8.6 in 2006.\n\n### SQL:\nSELECT COUNT Seats 2001 FROM table WHERE % 2006 < 8.6", "question": "None of the communities listed has a percentage smaller than 8.6 in 2006.", "sql": "SELECT COUNT Seats 2001 FROM table WHERE % 2006 < 8.6", "source": "test"}
|
| 60 |
+
{"text": "### Question:\nWhat is the basketball status for Valparaiso who has an indoor track status of yes?\n\n### SQL:\nSELECT Bask FROM table WHERE Indoor track = yes AND School = valparaiso", "question": "What is the basketball status for Valparaiso who has an indoor track status of yes?", "sql": "SELECT Bask FROM table WHERE Indoor track = yes AND School = valparaiso", "source": "test"}
|
| 61 |
+
{"text": "### Question:\nWhat 1979 Hindi film had Ravindra Jain directing music?\n\n### SQL:\nSELECT Film name FROM table WHERE Language = hindi AND Lyricist = ravindra jain AND Music director = ravindra jain AND Year = 1979", "question": "What 1979 Hindi film had Ravindra Jain directing music?", "sql": "SELECT Film name FROM table WHERE Language = hindi AND Lyricist = ravindra jain AND Music director = ravindra jain AND Year = 1979", "source": "test"}
|
| 62 |
+
{"text": "### Question:\nWhat was the winning score in the Alfred Dunhill links championship?\n\n### SQL:\nSELECT Winning score FROM table WHERE Tournament = alfred dunhill links championship", "question": "What was the winning score in the Alfred Dunhill links championship?", "sql": "SELECT Winning score FROM table WHERE Tournament = alfred dunhill links championship", "source": "test"}
|
| 63 |
+
{"text": "### Question:\nTell me the highest Grid for Maurice Trintignant and laps less than 87\n\n### SQL:\nSELECT MAX Grid FROM table WHERE Driver = maurice trintignant AND Laps < 87", "question": "Tell me the highest Grid for Maurice Trintignant and laps less than 87", "sql": "SELECT MAX Grid FROM table WHERE Driver = maurice trintignant AND Laps < 87", "source": "test"}
|
| 64 |
+
{"text": "### Question:\nWhich Play-Off has Events of \u2013, and a Season of a-6?\n\n### SQL:\nSELECT Play-Off FROM table WHERE Events = \u2013 AND Season = a-6", "question": "Which Play-Off has Events of \u2013, and a Season of a-6?", "sql": "SELECT Play-Off FROM table WHERE Events = \u2013 AND Season = a-6", "source": "test"}
|
| 65 |
+
{"text": "### Question:\nWho did team phoenix visit in their home?\n\n### SQL:\nSELECT Home FROM table WHERE Visitor = phoenix", "question": "Who did team phoenix visit in their home?", "sql": "SELECT Home FROM table WHERE Visitor = phoenix", "source": "test"}
|
| 66 |
+
{"text": "### Question:\nName the record with home of bucks on 24 november 2007\n\n### SQL:\nSELECT Record FROM table WHERE Home = bucks AND Date = 24 november 2007", "question": "Name the record with home of bucks on 24 november 2007", "sql": "SELECT Record FROM table WHERE Home = bucks AND Date = 24 november 2007", "source": "test"}
|
| 67 |
+
{"text": "### Question:\nWhat is Myron Walwyn with a Territorial at-large Constiuency's First Elected Date\n\n### SQL:\nSELECT First elected FROM table WHERE Constiuency = territorial at-large AND Name = myron walwyn", "question": "What is Myron Walwyn with a Territorial at-large Constiuency's First Elected Date", "sql": "SELECT First elected FROM table WHERE Constiuency = territorial at-large AND Name = myron walwyn", "source": "test"}
|
| 68 |
+
{"text": "### Question:\nWhat competition had a Rank-Qualifying of 1st and a ball apparatus?\n\n### SQL:\nSELECT Competition Description FROM table WHERE Rank-Qualifying = 1st AND Apparatus = ball", "question": "What competition had a Rank-Qualifying of 1st and a ball apparatus?", "sql": "SELECT Competition Description FROM table WHERE Rank-Qualifying = 1st AND Apparatus = ball", "source": "test"}
|
| 69 |
+
{"text": "### Question:\nWhat is the average area larger than Code 19025 but a smaller region than 12?\n\n### SQL:\nSELECT AVG Area (km 2 ) FROM table WHERE Code > 19025 AND Region < 12", "question": "What is the average area larger than Code 19025 but a smaller region than 12?", "sql": "SELECT AVG Area (km 2 ) FROM table WHERE Code > 19025 AND Region < 12", "source": "test"}
|
| 70 |
+
{"text": "### Question:\nWhat is the lowest number of laps for Marco Simoncelli on a grid higher than 11?\n\n### SQL:\nSELECT MIN Laps FROM table WHERE Rider = marco simoncelli AND Grid > 11", "question": "What is the lowest number of laps for Marco Simoncelli on a grid higher than 11?", "sql": "SELECT MIN Laps FROM table WHERE Rider = marco simoncelli AND Grid > 11", "source": "test"}
|
| 71 |
+
{"text": "### Question:\nWhat is the highest Year, when Apparatus is \"Vault\", and when Rank-Final is less than 9?\n\n### SQL:\nSELECT MAX Year FROM table WHERE Apparatus = vault AND Rank-Final < 9", "question": "What is the highest Year, when Apparatus is \"Vault\", and when Rank-Final is less than 9?", "sql": "SELECT MAX Year FROM table WHERE Apparatus = vault AND Rank-Final < 9", "source": "test"}
|
| 72 |
+
{"text": "### Question:\nHow many poles has a percentage of 22.08%?\n\n### SQL:\nSELECT SUM Poles FROM table WHERE Percentage = 22.08%", "question": "How many poles has a percentage of 22.08%?", "sql": "SELECT SUM Poles FROM table WHERE Percentage = 22.08%", "source": "test"}
|
| 73 |
+
{"text": "### Question:\nName the number of score for sacramento\n\n### SQL:\nSELECT COUNT Score FROM table WHERE Team = Sacramento", "question": "Name the number of score for sacramento", "sql": "SELECT COUNT Score FROM table WHERE Team = Sacramento", "source": "test"}
|
| 74 |
+
{"text": "### Question:\nWhat was the nationality of the player with a score of 72-72-67=211?\n\n### SQL:\nSELECT Country FROM table WHERE Score = 72-72-67=211", "question": "What was the nationality of the player with a score of 72-72-67=211?", "sql": "SELECT Country FROM table WHERE Score = 72-72-67=211", "source": "test"}
|
| 75 |
+
{"text": "### Question:\nWhat was the venue that had 5000 m after 2009?\n\n### SQL:\nSELECT Venue FROM table WHERE Year > 2009 AND Notes = 5000 m", "question": "What was the venue that had 5000 m after 2009?", "sql": "SELECT Venue FROM table WHERE Year > 2009 AND Notes = 5000 m", "source": "test"}
|
| 76 |
+
{"text": "### Question:\nwhat is the sspec number when the part number is cw8064701470802?\n\n### SQL:\nSELECT sSpec number FROM table WHERE Part number(s) = cw8064701470802", "question": "what is the sspec number when the part number is cw8064701470802?", "sql": "SELECT sSpec number FROM table WHERE Part number(s) = cw8064701470802", "source": "test"}
|
| 77 |
+
{"text": "### Question:\nWhich Base has a Name of .44 wcf?\n\n### SQL:\nSELECT Base FROM table WHERE Name = .44 wcf", "question": "Which Base has a Name of .44 wcf?", "sql": "SELECT Base FROM table WHERE Name = .44 wcf", "source": "test"}
|
| 78 |
+
{"text": "### Question:\nWhat is the value of the runner up column for the Alberta province?\n\n### SQL:\nSELECT MAX Runner Up FROM table WHERE Province = Alberta", "question": "What is the value of the runner up column for the Alberta province?", "sql": "SELECT MAX Runner Up FROM table WHERE Province = Alberta", "source": "test"}
|
| 79 |
+
{"text": "### Question:\nWhat is the Athlete of the race with a Time of 9.78?\n\n### SQL:\nSELECT Athlete FROM table WHERE Time = 9.78", "question": "What is the Athlete of the race with a Time of 9.78?", "sql": "SELECT Athlete FROM table WHERE Time = 9.78", "source": "test"}
|
| 80 |
+
{"text": "### Question:\nWhat was the rating of the episode \"After Hours\"?\n\n### SQL:\nSELECT Rating (Millions) FROM table WHERE Title = \"after hours\"", "question": "What was the rating of the episode \"After Hours\"?", "sql": "SELECT Rating (Millions) FROM table WHERE Title = \"after hours\"", "source": "test"}
|
| 81 |
+
{"text": "### Question:\nWhat is the number for the forward position from the school/club team La Salle?\n\n### SQL:\nSELECT Number FROM table WHERE Position = forward AND School/Club Team = la salle", "question": "What is the number for the forward position from the school/club team La Salle?", "sql": "SELECT Number FROM table WHERE Position = forward AND School/Club Team = la salle", "source": "test"}
|
| 82 |
+
{"text": "### Question:\nwhat is the highest wickets when the best bowling is 2/32 and matches is less than 5?\n\n### SQL:\nSELECT MAX Wickets FROM table WHERE Best Bowling = 2/32 AND Matches < 5", "question": "what is the highest wickets when the best bowling is 2/32 and matches is less than 5?", "sql": "SELECT MAX Wickets FROM table WHERE Best Bowling = 2/32 AND Matches < 5", "source": "test"}
|
| 83 |
+
{"text": "### Question:\nWhich show runs on Friday at 05:00 AM?\n\n### SQL:\nSELECT 05:00 AM FROM table WHERE Time = friday", "question": "Which show runs on Friday at 05:00 AM?", "sql": "SELECT 05:00 AM FROM table WHERE Time = friday", "source": "test"}
|
| 84 |
+
{"text": "### Question:\nHow many players have the hometown Pennsauken, NJ?\n\n### SQL:\nSELECT COUNT Player FROM table WHERE Hometown = Pennsauken, NJ", "question": "How many players have the hometown Pennsauken, NJ?", "sql": "SELECT COUNT Player FROM table WHERE Hometown = Pennsauken, NJ", "source": "test"}
|
| 85 |
+
{"text": "### Question:\nHow many losses does Cross Keys RFC have?\n\n### SQL:\nSELECT COUNT Lost FROM table WHERE Club = Cross Keys RFC", "question": "How many losses does Cross Keys RFC have?", "sql": "SELECT COUNT Lost FROM table WHERE Club = Cross Keys RFC", "source": "test"}
|
| 86 |
+
{"text": "### Question:\nWhat was the time when the method was TKO?\n\n### SQL:\nSELECT Time FROM table WHERE Method = tko", "question": "What was the time when the method was TKO?", "sql": "SELECT Time FROM table WHERE Method = tko", "source": "test"}
|
| 87 |
+
{"text": "### Question:\nWhat was the type of sussex?\n\n### SQL:\nSELECT Type FROM table WHERE Name = Sussex", "question": "What was the type of sussex?", "sql": "SELECT Type FROM table WHERE Name = Sussex", "source": "test"}
|
| 88 |
+
{"text": "### Question:\nWhat Title has a Role of Mylene?\n\n### SQL:\nSELECT Title FROM table WHERE Role = mylene", "question": "What Title has a Role of Mylene?", "sql": "SELECT Title FROM table WHERE Role = mylene", "source": "test"}
|
| 89 |
+
{"text": "### Question:\nWho is the captain of Neil Warnock's team?\n\n### SQL:\nSELECT Team captain FROM table WHERE Manager = Neil Warnock", "question": "Who is the captain of Neil Warnock's team?", "sql": "SELECT Team captain FROM table WHERE Manager = Neil Warnock", "source": "test"}
|
| 90 |
+
{"text": "### Question:\nWhat is the winning % for the 2010 QF?\n\n### SQL:\nSELECT Win % FROM table WHERE 2010 = qf", "question": "What is the winning % for the 2010 QF?", "sql": "SELECT Win % FROM table WHERE 2010 = qf", "source": "test"}
|
| 91 |
+
{"text": "### Question:\nWhat's the court ranking of 5th son of tadayori and has revenues of 10,000 koku?\n\n### SQL:\nSELECT Court Rank FROM table WHERE Revenues = 10,000 koku AND Lineage = 5th son of tadayori", "question": "What's the court ranking of 5th son of tadayori and has revenues of 10,000 koku?", "sql": "SELECT Court Rank FROM table WHERE Revenues = 10,000 koku AND Lineage = 5th son of tadayori", "source": "test"}
|
| 92 |
+
{"text": "### Question:\nWhich race was on the Las Vegas Motor Speedway for 2 hours?\n\n### SQL:\nSELECT Race FROM table WHERE Circuit = las vegas motor speedway AND Length = 2 hours", "question": "Which race was on the Las Vegas Motor Speedway for 2 hours?", "sql": "SELECT Race FROM table WHERE Circuit = las vegas motor speedway AND Length = 2 hours", "source": "test"}
|
| 93 |
+
{"text": "### Question:\nWhat venue hsoted the european cross country championships with a notes of junior men individual 6.595km?\n\n### SQL:\nSELECT Venue FROM table WHERE Competition = european cross country championships AND Notes = junior men individual 6.595km", "question": "What venue hsoted the european cross country championships with a notes of junior men individual 6.595km?", "sql": "SELECT Venue FROM table WHERE Competition = european cross country championships AND Notes = junior men individual 6.595km", "source": "test"}
|
| 94 |
+
{"text": "### Question:\nWhat was the Opponent in Week 9?\n\n### SQL:\nSELECT Opponent FROM table WHERE Week = 9", "question": "What was the Opponent in Week 9?", "sql": "SELECT Opponent FROM table WHERE Week = 9", "source": "test"}
|
| 95 |
+
{"text": "### Question:\nBefore round 7, what is the greatest Pick # for a player that plays defensive tackle?\n\n### SQL:\nSELECT MAX Pick # FROM table WHERE Position = defensive tackle AND Round < 7", "question": "Before round 7, what is the greatest Pick # for a player that plays defensive tackle?", "sql": "SELECT MAX Pick # FROM table WHERE Position = defensive tackle AND Round < 7", "source": "test"}
|
| 96 |
+
{"text": "### Question:\nWhat was the highest Pick for Lonnie Brockman before round 9?\n\n### SQL:\nSELECT MAX Pick FROM table WHERE Player = lonnie brockman AND Round < 9", "question": "What was the highest Pick for Lonnie Brockman before round 9?", "sql": "SELECT MAX Pick FROM table WHERE Player = lonnie brockman AND Round < 9", "source": "test"}
|
| 97 |
+
{"text": "### Question:\nWhen was brian lemay born?\n\n### SQL:\nSELECT Date of Birth (Age) FROM table WHERE Player = brian lemay", "question": "When was brian lemay born?", "sql": "SELECT Date of Birth (Age) FROM table WHERE Player = brian lemay", "source": "test"}
|
| 98 |
+
{"text": "### Question:\nWhat was the time of the NJKF Titans Neo X event?\n\n### SQL:\nSELECT Time FROM table WHERE Event = njkf titans neo x", "question": "What was the time of the NJKF Titans Neo X event?", "sql": "SELECT Time FROM table WHERE Event = njkf titans neo x", "source": "test"}
|
| 99 |
+
{"text": "### Question:\nWhich kit maker have Trond Sollied as a manager?\n\n### SQL:\nSELECT Kit maker FROM table WHERE Manager = trond sollied", "question": "Which kit maker have Trond Sollied as a manager?", "sql": "SELECT Kit maker FROM table WHERE Manager = trond sollied", "source": "test"}
|
| 100 |
+
{"text": "### Question:\nWhich Locomotive Entered Service in November 1984 and has an Operator of Southern Shorthaul Railroad?\n\n### SQL:\nSELECT Locomotive FROM table WHERE Operator = southern shorthaul railroad AND Entered service = november 1984", "question": "Which Locomotive Entered Service in November 1984 and has an Operator of Southern Shorthaul Railroad?", "sql": "SELECT Locomotive FROM table WHERE Operator = southern shorthaul railroad AND Entered service = november 1984", "source": "test"}
|
src/outputs/finetuning/train.jsonl
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"text": "### Question:\nWhat was the year when Ch\u016bnqi\u016b Ch\u00e1sh\u00ec (\u6625\u79cb\u8336\u5ba4) was submitted?\n\n### SQL:\nSELECT Year (Ceremony) FROM table WHERE Original title = Ch\u016bnqi\u016b ch\u00e1sh\u00ec (\u6625\u79cb\u8336\u5ba4)", "question": "What was the year when Ch\u016bnqi\u016b Ch\u00e1sh\u00ec (\u6625\u79cb\u8336\u5ba4) was submitted?", "sql": "SELECT Year (Ceremony) FROM table WHERE Original title = Ch\u016bnqi\u016b ch\u00e1sh\u00ec (\u6625\u79cb\u8336\u5ba4)", "source": "train"}
|
| 2 |
+
{"text": "### Question:\nWhat is Name, when Builder is \"Kerr Stuart\"?\n\n### SQL:\nSELECT Name FROM table WHERE Builder = kerr stuart", "question": "What is Name, when Builder is \"Kerr Stuart\"?", "sql": "SELECT Name FROM table WHERE Builder = kerr stuart", "source": "train"}
|
| 3 |
+
{"text": "### Question:\nWhat series had more than 10 Podiums?\n\n### SQL:\nSELECT Series FROM table WHERE Podiums > 10", "question": "What series had more than 10 Podiums?", "sql": "SELECT Series FROM table WHERE Podiums > 10", "source": "train"}
|
| 4 |
+
{"text": "### Question:\nFor what ceremony was \"Fire Dancer\" not nominated? \n\n### SQL:\nSELECT Year (Ceremony) FROM table WHERE Original title = Fire Dancer", "question": "For what ceremony was \"Fire Dancer\" not nominated? ", "sql": "SELECT Year (Ceremony) FROM table WHERE Original title = Fire Dancer", "source": "train"}
|
| 5 |
+
{"text": "### Question:\nHow many people went to the game with Indiana visiting?\n\n### SQL:\nSELECT Attendance FROM table WHERE Visitor = indiana", "question": "How many people went to the game with Indiana visiting?", "sql": "SELECT Attendance FROM table WHERE Visitor = indiana", "source": "train"}
|
| 6 |
+
{"text": "### Question:\nWith a swim (1.5km) of 18:55 and a run (10km) of 32:37, what is the trans 2?\n\n### SQL:\nSELECT Trans 2 FROM table WHERE Swim (1.5km) = 18:55 AND Run (10km) = 32:37", "question": "With a swim (1.5km) of 18:55 and a run (10km) of 32:37, what is the trans 2?", "sql": "SELECT Trans 2 FROM table WHERE Swim (1.5km) = 18:55 AND Run (10km) = 32:37", "source": "train"}
|
| 7 |
+
{"text": "### Question:\nWhat is the region for Chep\u00e9n with 3 districts?\n\n### SQL:\nSELECT Region FROM table WHERE Districts = 3 AND Capital = chep\u00e9n", "question": "What is the region for Chep\u00e9n with 3 districts?", "sql": "SELECT Region FROM table WHERE Districts = 3 AND Capital = chep\u00e9n", "source": "train"}
|
| 8 |
+
{"text": "### Question:\nWhat was the third place of the performance in 2006 with the host Japan?\n\n### SQL:\nSELECT Third place FROM table WHERE Host = japan AND Season < 2006", "question": "What was the third place of the performance in 2006 with the host Japan?", "sql": "SELECT Third place FROM table WHERE Host = japan AND Season < 2006", "source": "train"}
|
| 9 |
+
{"text": "### Question:\nWhat is the 1999-2000 team, when the Height (cm) is less than 187, and when the Birthplace is Cloquet, Minnesota?\n\n### SQL:\nSELECT 1999-2000 team FROM table WHERE Height (cm) < 187 AND Birthplace = cloquet, minnesota", "question": "What is the 1999-2000 team, when the Height (cm) is less than 187, and when the Birthplace is Cloquet, Minnesota?", "sql": "SELECT 1999-2000 team FROM table WHERE Height (cm) < 187 AND Birthplace = cloquet, minnesota", "source": "train"}
|
| 10 |
+
{"text": "### Question:\nWhat is the highest Tournaments, when Pro Debut is \"July 2002\"?\n\n### SQL:\nSELECT MAX Tournaments FROM table WHERE Pro Debut = july 2002", "question": "What is the highest Tournaments, when Pro Debut is \"July 2002\"?", "sql": "SELECT MAX Tournaments FROM table WHERE Pro Debut = july 2002", "source": "train"}
|
| 11 |
+
{"text": "### Question:\nwhat is the total time in office when the assumed office is 1 november 1856?\n\n### SQL:\nSELECT TOTAL Time in Office: FROM table WHERE Assumed Office: = 1 november 1856", "question": "what is the total time in office when the assumed office is 1 november 1856?", "sql": "SELECT TOTAL Time in Office: FROM table WHERE Assumed Office: = 1 november 1856", "source": "train"}
|
| 12 |
+
{"text": "### Question:\nWhen did the term end for the term that had government 27 and Minister Tzachi Hanegbi?\n\n### SQL:\nSELECT Term end FROM table WHERE Governments = 27 AND Minister = tzachi hanegbi", "question": "When did the term end for the term that had government 27 and Minister Tzachi Hanegbi?", "sql": "SELECT Term end FROM table WHERE Governments = 27 AND Minister = tzachi hanegbi", "source": "train"}
|
| 13 |
+
{"text": "### Question:\nWhat is the Outcome of the Doubles played on Carpet?\n\n### SQL:\nSELECT Outcome FROM table WHERE Surface = carpet", "question": "What is the Outcome of the Doubles played on Carpet?", "sql": "SELECT Outcome FROM table WHERE Surface = carpet", "source": "train"}
|
| 14 |
+
{"text": "### Question:\nName the venue for geelong away team\n\n### SQL:\nSELECT Venue FROM table WHERE Away team = geelong", "question": "Name the venue for geelong away team", "sql": "SELECT Venue FROM table WHERE Away team = geelong", "source": "train"}
|
| 15 |
+
{"text": "### Question:\nWhat was the score for the final played on 2 July 2012?\n\n### SQL:\nSELECT Score FROM table WHERE Date = 2 july 2012", "question": "What was the score for the final played on 2 July 2012?", "sql": "SELECT Score FROM table WHERE Date = 2 july 2012", "source": "train"}
|
| 16 |
+
{"text": "### Question:\nBetween November 25\u201330, 2008 the sellout rate was at 75%, indicating that the ration between shows to sellout was what?\n\n### SQL:\nSELECT Shows / Sellout FROM table WHERE Sellout (%) = 75%", "question": "Between November 25\u201330, 2008 the sellout rate was at 75%, indicating that the ration between shows to sellout was what?", "sql": "SELECT Shows / Sellout FROM table WHERE Sellout (%) = 75%", "source": "train"}
|
| 17 |
+
{"text": "### Question:\nWhat is the total ranking when there are less than 16 draws, less than 1 point, and the English translation is in love with you?\n\n### SQL:\nSELECT SUM Place FROM table WHERE Draw < 16 AND English translation = in love with you AND Points < 1", "question": "What is the total ranking when there are less than 16 draws, less than 1 point, and the English translation is in love with you?", "sql": "SELECT SUM Place FROM table WHERE Draw < 16 AND English translation = in love with you AND Points < 1", "source": "train"}
|
| 18 |
+
{"text": "### Question:\nWhich Ratio as % has a Ratio of 8/9?\n\n### SQL:\nSELECT Ratio as % FROM table WHERE Ratio = 8/9", "question": "Which Ratio as % has a Ratio of 8/9?", "sql": "SELECT Ratio as % FROM table WHERE Ratio = 8/9", "source": "train"}
|
| 19 |
+
{"text": "### Question:\nWhat is the Early Modern English phonology used in the example b\u014dg > \"bough\"; pl\u014dg > pl\u014dh > \"plough\"?\n\n### SQL:\nSELECT Early Modern English FROM table WHERE Example = b\u014dg > \"bough\"; pl\u014dg > pl\u014dh > \"plough\"", "question": "What is the Early Modern English phonology used in the example b\u014dg > \"bough\"; pl\u014dg > pl\u014dh > \"plough\"?", "sql": "SELECT Early Modern English FROM table WHERE Example = b\u014dg > \"bough\"; pl\u014dg > pl\u014dh > \"plough\"", "source": "train"}
|
| 20 |
+
{"text": "### Question:\nWhat is the greatest Wins with Matches smaller than 5, and a Year of 1994?\n\n### SQL:\nSELECT MAX Wins FROM table WHERE Matches < 5 AND Year = 1994", "question": "What is the greatest Wins with Matches smaller than 5, and a Year of 1994?", "sql": "SELECT MAX Wins FROM table WHERE Matches < 5 AND Year = 1994", "source": "train"}
|
| 21 |
+
{"text": "### Question:\nWhat ranking is the Battersea Power Station?\n\n### SQL:\nSELECT Rank FROM table WHERE Name = battersea power station", "question": "What ranking is the Battersea Power Station?", "sql": "SELECT Rank FROM table WHERE Name = battersea power station", "source": "train"}
|
| 22 |
+
{"text": "### Question:\nWhat pick number was the player that was picked by Edmonton?\n\n### SQL:\nSELECT Pick # FROM table WHERE CFL Team = Edmonton", "question": "What pick number was the player that was picked by Edmonton?", "sql": "SELECT Pick # FROM table WHERE CFL Team = Edmonton", "source": "train"}
|
| 23 |
+
{"text": "### Question:\nWhat was the location of the fight when Gassaway fought kevin knabjian?\n\n### SQL:\nSELECT Location FROM table WHERE Opponent = kevin knabjian", "question": "What was the location of the fight when Gassaway fought kevin knabjian?", "sql": "SELECT Location FROM table WHERE Opponent = kevin knabjian", "source": "train"}
|
| 24 |
+
{"text": "### Question:\nHow many points did the song \"Stille f\u00f8r stormen\" get?\n\n### SQL:\nSELECT MIN Total Points FROM table WHERE Song = \"Stille f\u00f8r stormen\"", "question": "How many points did the song \"Stille f\u00f8r stormen\" get?", "sql": "SELECT MIN Total Points FROM table WHERE Song = \"Stille f\u00f8r stormen\"", "source": "train"}
|
| 25 |
+
{"text": "### Question:\nWht years did truck robinson play?\n\n### SQL:\nSELECT Years for Jazz FROM table WHERE Player = Truck Robinson", "question": "Wht years did truck robinson play?", "sql": "SELECT Years for Jazz FROM table WHERE Player = Truck Robinson", "source": "train"}
|
| 26 |
+
{"text": "### Question:\nWhich Score has a Couple of cristi\u00e1n & cheryl, and a Style of cha-cha-cha?\n\n### SQL:\nSELECT Score FROM table WHERE Couple = cristi\u00e1n & cheryl AND Style = cha-cha-cha", "question": "Which Score has a Couple of cristi\u00e1n & cheryl, and a Style of cha-cha-cha?", "sql": "SELECT Score FROM table WHERE Couple = cristi\u00e1n & cheryl AND Style = cha-cha-cha", "source": "train"}
|
| 27 |
+
{"text": "### Question:\nWhat is the IUPAC name for chloroform?\n\n### SQL:\nSELECT IUPAC name FROM table WHERE Common name = chloroform", "question": "What is the IUPAC name for chloroform?", "sql": "SELECT IUPAC name FROM table WHERE Common name = chloroform", "source": "train"}
|
| 28 |
+
{"text": "### Question:\nWhat is the Constituency Number when the Number of Electorates (2003) is more than 156,910, and Reserved for sc?\n\n### SQL:\nSELECT Constituency number FROM table WHERE Number of electorates (2003) > 156,910 AND Reserved for ( SC / ST /None) = sc", "question": "What is the Constituency Number when the Number of Electorates (2003) is more than 156,910, and Reserved for sc?", "sql": "SELECT Constituency number FROM table WHERE Number of electorates (2003) > 156,910 AND Reserved for ( SC / ST /None) = sc", "source": "train"}
|
| 29 |
+
{"text": "### Question:\nWhat was John Jones's pick#?\n\n### SQL:\nSELECT SUM Pick FROM table WHERE Player = john jones", "question": "What was John Jones's pick#?", "sql": "SELECT SUM Pick FROM table WHERE Player = john jones", "source": "train"}
|
| 30 |
+
{"text": "### Question:\nWhich runner(s)-up had a Winning score of \u201313 (68-70-66-71=275) and a Margin of victory of 3 strokes?\n\n### SQL:\nSELECT Runner(s)-up FROM table WHERE Margin of victory = 3 strokes AND Winning score = \u201313 (68-70-66-71=275)", "question": "Which runner(s)-up had a Winning score of \u201313 (68-70-66-71=275) and a Margin of victory of 3 strokes?", "sql": "SELECT Runner(s)-up FROM table WHERE Margin of victory = 3 strokes AND Winning score = \u201313 (68-70-66-71=275)", "source": "train"}
|
| 31 |
+
{"text": "### Question:\nWhat is the number of people in attendance when Tonbridge Angels is the opponent?\n\n### SQL:\nSELECT Attendance FROM table WHERE Opponent = tonbridge angels", "question": "What is the number of people in attendance when Tonbridge Angels is the opponent?", "sql": "SELECT Attendance FROM table WHERE Opponent = tonbridge angels", "source": "train"}
|
| 32 |
+
{"text": "### Question:\nWhich letter has the British a\u026a?\n\n### SQL:\nSELECT Letter FROM table WHERE British = a\u026a", "question": "Which letter has the British a\u026a?", "sql": "SELECT Letter FROM table WHERE British = a\u026a", "source": "train"}
|
| 33 |
+
{"text": "### Question:\nIn which city was the berlin marathon?\n\n### SQL:\nSELECT Location FROM table WHERE Road race = Berlin Marathon", "question": "In which city was the berlin marathon?", "sql": "SELECT Location FROM table WHERE Road race = Berlin Marathon", "source": "train"}
|
| 34 |
+
{"text": "### Question:\nName the year 2007 for 668 2008-q1\n\n### SQL:\nSELECT year 2007 FROM table WHERE 2008 - Q1 = 668", "question": "Name the year 2007 for 668 2008-q1", "sql": "SELECT year 2007 FROM table WHERE 2008 - Q1 = 668", "source": "train"}
|
| 35 |
+
{"text": "### Question:\nWhat was Collingwood's score when they played against North Melbourne at home?\n\n### SQL:\nSELECT Home team score FROM table WHERE Away team = north melbourne", "question": "What was Collingwood's score when they played against North Melbourne at home?", "sql": "SELECT Home team score FROM table WHERE Away team = north melbourne", "source": "train"}
|
| 36 |
+
{"text": "### Question:\nWhich division were the Brewers a part of in the 1987 season?\n\n### SQL:\nSELECT Division FROM table WHERE Team season = 1987", "question": "Which division were the Brewers a part of in the 1987 season?", "sql": "SELECT Division FROM table WHERE Team season = 1987", "source": "train"}
|
| 37 |
+
{"text": "### Question:\nWhich TV Station has a Romaji Title of kegareta shita?\n\n### SQL:\nSELECT TV Station FROM table WHERE Romaji Title = kegareta shita", "question": "Which TV Station has a Romaji Title of kegareta shita?", "sql": "SELECT TV Station FROM table WHERE Romaji Title = kegareta shita", "source": "train"}
|
| 38 |
+
{"text": "### Question:\nTell me the notes with method of points and event of adcc 2001 absolute with result of loss\n\n### SQL:\nSELECT Notes FROM table WHERE Method = points AND Event = adcc 2001 absolute AND Result = loss", "question": "Tell me the notes with method of points and event of adcc 2001 absolute with result of loss", "sql": "SELECT Notes FROM table WHERE Method = points AND Event = adcc 2001 absolute AND Result = loss", "source": "train"}
|
| 39 |
+
{"text": "### Question:\nWhat is the highest value for SF round for the country of England?\n\n### SQL:\nSELECT MAX SF Round FROM table WHERE Country = England", "question": "What is the highest value for SF round for the country of England?", "sql": "SELECT MAX SF Round FROM table WHERE Country = England", "source": "train"}
|
| 40 |
+
{"text": "### Question:\nWhat country id Bob Rosburg from?\n\n### SQL:\nSELECT Country FROM table WHERE Player = bob rosburg", "question": "What country id Bob Rosburg from?", "sql": "SELECT Country FROM table WHERE Player = bob rosburg", "source": "train"}
|
| 41 |
+
{"text": "### Question:\nWho is the Alternate for Sweden?\n\n### SQL:\nSELECT Alternate FROM table WHERE Nation = sweden", "question": "Who is the Alternate for Sweden?", "sql": "SELECT Alternate FROM table WHERE Nation = sweden", "source": "train"}
|
| 42 |
+
{"text": "### Question:\nWhich Format has a Frequency of 100.5 fm?\n\n### SQL:\nSELECT Format FROM table WHERE Frequency = 100.5 fm", "question": "Which Format has a Frequency of 100.5 fm?", "sql": "SELECT Format FROM table WHERE Frequency = 100.5 fm", "source": "train"}
|
| 43 |
+
{"text": "### Question:\nWhat country is Lee Janzen from?\n\n### SQL:\nSELECT Country FROM table WHERE Player = lee janzen", "question": "What country is Lee Janzen from?", "sql": "SELECT Country FROM table WHERE Player = lee janzen", "source": "train"}
|
| 44 |
+
{"text": "### Question:\nHead Coach casemiro mior is at which Club?\n\n### SQL:\nSELECT Club FROM table WHERE Head Coach = casemiro mior", "question": "Head Coach casemiro mior is at which Club?", "sql": "SELECT Club FROM table WHERE Head Coach = casemiro mior", "source": "train"}
|
| 45 |
+
{"text": "### Question:\nOn which date was the opponent the Chicago Bears?\n\n### SQL:\nSELECT Date FROM table WHERE Opponent = chicago bears", "question": "On which date was the opponent the Chicago Bears?", "sql": "SELECT Date FROM table WHERE Opponent = chicago bears", "source": "train"}
|
| 46 |
+
{"text": "### Question:\nWhich languages are offered in the coverage area of klang petaling jaya shah alam?\n\n### SQL:\nSELECT Language FROM table WHERE Coverage Area = Klang Petaling Jaya Shah Alam", "question": "Which languages are offered in the coverage area of klang petaling jaya shah alam?", "sql": "SELECT Language FROM table WHERE Coverage Area = Klang Petaling Jaya Shah Alam", "source": "train"}
|
| 47 |
+
{"text": "### Question:\nWhat is the to par for Jiyai Shin when the place is t1?\n\n### SQL:\nSELECT To par FROM table WHERE Place = t1 AND Player = jiyai shin", "question": "What is the to par for Jiyai Shin when the place is t1?", "sql": "SELECT To par FROM table WHERE Place = t1 AND Player = jiyai shin", "source": "train"}
|
| 48 |
+
{"text": "### Question:\nWhat is the total number of Division(s), when Team is Chongqing Lifan, and when Apps is greater than 9?\n\n### SQL:\nSELECT COUNT Division FROM table WHERE Team = chongqing lifan AND Apps > 9", "question": "What is the total number of Division(s), when Team is Chongqing Lifan, and when Apps is greater than 9?", "sql": "SELECT COUNT Division FROM table WHERE Team = chongqing lifan AND Apps > 9", "source": "train"}
|
| 49 |
+
{"text": "### Question:\nWhat is Mike Weir's To par?\n\n### SQL:\nSELECT To par FROM table WHERE Player = mike weir", "question": "What is Mike Weir's To par?", "sql": "SELECT To par FROM table WHERE Player = mike weir", "source": "train"}
|
| 50 |
+
{"text": "### Question:\nWhat was the away team when the home was st kilda?\n\n### SQL:\nSELECT Away team FROM table WHERE Home team = st kilda", "question": "What was the away team when the home was st kilda?", "sql": "SELECT Away team FROM table WHERE Home team = st kilda", "source": "train"}
|
| 51 |
+
{"text": "### Question:\nWhat is the % of same-sex marriages for the year of 2011?\n\n### SQL:\nSELECT % same-sex marriages FROM table WHERE Year = 2011", "question": "What is the % of same-sex marriages for the year of 2011?", "sql": "SELECT % same-sex marriages FROM table WHERE Year = 2011", "source": "train"}
|
| 52 |
+
{"text": "### Question:\nWhat Place has a To par of \u20134?\n\n### SQL:\nSELECT Place FROM table WHERE To par = \u20134", "question": "What Place has a To par of \u20134?", "sql": "SELECT Place FROM table WHERE To par = \u20134", "source": "train"}
|
| 53 |
+
{"text": "### Question:\nName the district for 1994\n\n### SQL:\nSELECT District FROM table WHERE First elected = 1994", "question": "Name the district for 1994", "sql": "SELECT District FROM table WHERE First elected = 1994", "source": "train"}
|
| 54 |
+
{"text": "### Question:\nWhich Body Width/mm has a Lead Pitch/mm smaller than 0.55, and a Part Number of tsop48?\n\n### SQL:\nSELECT MIN Body Width/mm FROM table WHERE Lead Pitch/mm < 0.55 AND Part Number = tsop48", "question": "Which Body Width/mm has a Lead Pitch/mm smaller than 0.55, and a Part Number of tsop48?", "sql": "SELECT MIN Body Width/mm FROM table WHERE Lead Pitch/mm < 0.55 AND Part Number = tsop48", "source": "train"}
|
| 55 |
+
{"text": "### Question:\nWhat is the lowest number of episodes for anabel barnston?\n\n### SQL:\nSELECT MIN Episodes FROM table WHERE Actor = anabel barnston", "question": "What is the lowest number of episodes for anabel barnston?", "sql": "SELECT MIN Episodes FROM table WHERE Actor = anabel barnston", "source": "train"}
|
| 56 |
+
{"text": "### Question:\nFor what league was the player in G position drafted?\n\n### SQL:\nSELECT League from FROM table WHERE Position = g", "question": "For what league was the player in G position drafted?", "sql": "SELECT League from FROM table WHERE Position = g", "source": "train"}
|
| 57 |
+
{"text": "### Question:\nWhat is the last episode which has segment d as blown glass?\n\n### SQL:\nSELECT MAX Episode FROM table WHERE Segment D = Blown Glass", "question": "What is the last episode which has segment d as blown glass?", "sql": "SELECT MAX Episode FROM table WHERE Segment D = Blown Glass", "source": "train"}
|
| 58 |
+
{"text": "### Question:\nWhat is the date for the 10b serial?\n\n### SQL:\nSELECT Date FROM table WHERE Serial = 10b", "question": "What is the date for the 10b serial?", "sql": "SELECT Date FROM table WHERE Serial = 10b", "source": "train"}
|
| 59 |
+
{"text": "### Question:\nWhen \"we're going to disney world (part 1)\" is the title what is the air date?\n\n### SQL:\nSELECT Original air date FROM table WHERE Title = \"We're Going to Disney World (Part 1)\"", "question": "When \"we're going to disney world (part 1)\" is the title what is the air date?", "sql": "SELECT Original air date FROM table WHERE Title = \"We're Going to Disney World (Part 1)\"", "source": "train"}
|
| 60 |
+
{"text": "### Question:\nHow did the School/Club Team of Manuel Luis Quezon acquire their Forward?\n\n### SQL:\nSELECT Acquisition via FROM table WHERE Position = forward AND School/Club Team = manuel luis quezon", "question": "How did the School/Club Team of Manuel Luis Quezon acquire their Forward?", "sql": "SELECT Acquisition via FROM table WHERE Position = forward AND School/Club Team = manuel luis quezon", "source": "train"}
|
| 61 |
+
{"text": "### Question:\nWhat is the Loss has an Attendance more than 43,095 and a Record of 31\u201329?\n\n### SQL:\nSELECT Loss FROM table WHERE Attendance > 43,095 AND Record = 31\u201329", "question": "What is the Loss has an Attendance more than 43,095 and a Record of 31\u201329?", "sql": "SELECT Loss FROM table WHERE Attendance > 43,095 AND Record = 31\u201329", "source": "train"}
|
| 62 |
+
{"text": "### Question:\nWhat is the ethernet ports of the u10 appliance?\n\n### SQL:\nSELECT Ethernet Ports FROM table WHERE Name = u10", "question": "What is the ethernet ports of the u10 appliance?", "sql": "SELECT Ethernet Ports FROM table WHERE Name = u10", "source": "train"}
|
| 63 |
+
{"text": "### Question:\nName the 2007 for 2005 of a and 003 of a with 2009 of sf\n\n### SQL:\nSELECT 2007 FROM table WHERE 2005 = a AND 2003 = a AND 2009 = sf", "question": "Name the 2007 for 2005 of a and 003 of a with 2009 of sf", "sql": "SELECT 2007 FROM table WHERE 2005 = a AND 2003 = a AND 2009 = sf", "source": "train"}
|
| 64 |
+
{"text": "### Question:\nWhat is the constructor for the race with Nigel Mansell as the fastest lap?\n\n### SQL:\nSELECT Constructor FROM table WHERE Fastest Lap = nigel mansell", "question": "What is the constructor for the race with Nigel Mansell as the fastest lap?", "sql": "SELECT Constructor FROM table WHERE Fastest Lap = nigel mansell", "source": "train"}
|
| 65 |
+
{"text": "### Question:\nWhich institution's nickname is the Polar Bears?\n\n### SQL:\nSELECT Institution FROM table WHERE Nickname = Polar Bears", "question": "Which institution's nickname is the Polar Bears?", "sql": "SELECT Institution FROM table WHERE Nickname = Polar Bears", "source": "train"}
|
| 66 |
+
{"text": "### Question:\nWhich Away team has an Away team score of 11.18 (84)?\n\n### SQL:\nSELECT Away team FROM table WHERE Away team score = 11.18 (84)", "question": "Which Away team has an Away team score of 11.18 (84)?", "sql": "SELECT Away team FROM table WHERE Away team score = 11.18 (84)", "source": "train"}
|
| 67 |
+
{"text": "### Question:\nwhat is the current version with license gpl v3?\n\n### SQL:\nSELECT Current version FROM table WHERE License = gpl v3", "question": "what is the current version with license gpl v3?", "sql": "SELECT Current version FROM table WHERE License = gpl v3", "source": "train"}
|
| 68 |
+
{"text": "### Question:\nWhat was the record on April 1?\n\n### SQL:\nSELECT Record FROM table WHERE Date = april 1", "question": "What was the record on April 1?", "sql": "SELECT Record FROM table WHERE Date = april 1", "source": "train"}
|
| 69 |
+
{"text": "### Question:\nWhich Winning score has a Margin of victory of 1 stroke, and a Date of 21 jun 1981?\n\n### SQL:\nSELECT Winning score FROM table WHERE Margin of victory = 1 stroke AND Date = 21 jun 1981", "question": "Which Winning score has a Margin of victory of 1 stroke, and a Date of 21 jun 1981?", "sql": "SELECT Winning score FROM table WHERE Margin of victory = 1 stroke AND Date = 21 jun 1981", "source": "train"}
|
| 70 |
+
{"text": "### Question:\nWhich Avoirdupois value is translated to grain?\n\n### SQL:\nSELECT Avoirdupois value FROM table WHERE Translation = grain", "question": "Which Avoirdupois value is translated to grain?", "sql": "SELECT Avoirdupois value FROM table WHERE Translation = grain", "source": "train"}
|
| 71 |
+
{"text": "### Question:\nName the most margin for nco party and p. ramachandran won\n\n### SQL:\nSELECT MAX Margin FROM table WHERE Party = NCO AND Winner = P. Ramachandran", "question": "Name the most margin for nco party and p. ramachandran won", "sql": "SELECT MAX Margin FROM table WHERE Party = NCO AND Winner = P. Ramachandran", "source": "train"}
|
| 72 |
+
{"text": "### Question:\nWhat is the Catalog with a Date that is february 20, 2002?\n\n### SQL:\nSELECT Catalog FROM table WHERE Date = february 20, 2002", "question": "What is the Catalog with a Date that is february 20, 2002?", "sql": "SELECT Catalog FROM table WHERE Date = february 20, 2002", "source": "train"}
|
| 73 |
+
{"text": "### Question:\nWho is the driver of the chassis-engine porsche 956 gti?\n\n### SQL:\nSELECT Driver FROM table WHERE Chassis \u2013 Engine = porsche 956 gti", "question": "Who is the driver of the chassis-engine porsche 956 gti?", "sql": "SELECT Driver FROM table WHERE Chassis \u2013 Engine = porsche 956 gti", "source": "train"}
|
| 74 |
+
{"text": "### Question:\nWhat is the 1st party with Charles Isaac Elton as the 2nd member?\n\n### SQL:\nSELECT 1st Party FROM table WHERE 2nd Member = charles isaac elton", "question": "What is the 1st party with Charles Isaac Elton as the 2nd member?", "sql": "SELECT 1st Party FROM table WHERE 2nd Member = charles isaac elton", "source": "train"}
|
| 75 |
+
{"text": "### Question:\nWhat is the earliest date of the game with a score of 2-2?\n\n### SQL:\nSELECT MIN Date FROM table WHERE Score = 2-2", "question": "What is the earliest date of the game with a score of 2-2?", "sql": "SELECT MIN Date FROM table WHERE Score = 2-2", "source": "train"}
|
| 76 |
+
{"text": "### Question:\nWhat was the original nfl team that the player was in from the midwestern conference?\n\n### SQL:\nSELECT Original NFL team FROM table WHERE Conf. = midwestern", "question": "What was the original nfl team that the player was in from the midwestern conference?", "sql": "SELECT Original NFL team FROM table WHERE Conf. = midwestern", "source": "train"}
|
| 77 |
+
{"text": "### Question:\nName the total number of domestic mail for 7853 for total frieght and mail\n\n### SQL:\nSELECT COUNT Domestic mail FROM table WHERE Total freight and mail = 7853", "question": "Name the total number of domestic mail for 7853 for total frieght and mail", "sql": "SELECT COUNT Domestic mail FROM table WHERE Total freight and mail = 7853", "source": "train"}
|
| 78 |
+
{"text": "### Question:\nWhat time is listed against the Wrestler Jimmy Rave?\n\n### SQL:\nSELECT Time FROM table WHERE Wrestler = jimmy rave", "question": "What time is listed against the Wrestler Jimmy Rave?", "sql": "SELECT Time FROM table WHERE Wrestler = jimmy rave", "source": "train"}
|
| 79 |
+
{"text": "### Question:\nWhat's the listed average of Cuts made that has a Top-5 of 3, and a Top-10 that's smaller than 5?\n\n### SQL:\nSELECT AVG Cuts made FROM table WHERE Top-5 = 3 AND Top-10 < 5", "question": "What's the listed average of Cuts made that has a Top-5 of 3, and a Top-10 that's smaller than 5?", "sql": "SELECT AVG Cuts made FROM table WHERE Top-5 = 3 AND Top-10 < 5", "source": "train"}
|
| 80 |
+
{"text": "### Question:\nHow large was the crowd at Carlton's home game?\n\n### SQL:\nSELECT COUNT Crowd FROM table WHERE Home team = carlton", "question": "How large was the crowd at Carlton's home game?", "sql": "SELECT COUNT Crowd FROM table WHERE Home team = carlton", "source": "train"}
|
| 81 |
+
{"text": "### Question:\nWhat date did \"The runner\" originally air on?\n\n### SQL:\nSELECT Original air date FROM table WHERE Title = \"The Runner\"", "question": "What date did \"The runner\" originally air on?", "sql": "SELECT Original air date FROM table WHERE Title = \"The Runner\"", "source": "train"}
|
| 82 |
+
{"text": "### Question:\nWhen 10th, south west district 1 is the mens 2nd xi what is the ladies 1st xi?\n\n### SQL:\nSELECT Ladies 1st XI FROM table WHERE Mens 2nd XI = 10th, South West District 1", "question": "When 10th, south west district 1 is the mens 2nd xi what is the ladies 1st xi?", "sql": "SELECT Ladies 1st XI FROM table WHERE Mens 2nd XI = 10th, South West District 1", "source": "train"}
|
| 83 |
+
{"text": "### Question:\nWhat is Show, when Episode Number is 1, when Year is less than 2010, and when Original Airdate is January 20, 2008?\n\n### SQL:\nSELECT Show FROM table WHERE Episode number = 1 AND Year < 2010 AND Original airdate = january 20, 2008", "question": "What is Show, when Episode Number is 1, when Year is less than 2010, and when Original Airdate is January 20, 2008?", "sql": "SELECT Show FROM table WHERE Episode number = 1 AND Year < 2010 AND Original airdate = january 20, 2008", "source": "train"}
|
| 84 |
+
{"text": "### Question:\nWhat is the verb for the Proto-Austronesian word *diri?\n\n### SQL:\nSELECT Verb FROM table WHERE Proto-Austronesian = *diri", "question": "What is the verb for the Proto-Austronesian word *diri?", "sql": "SELECT Verb FROM table WHERE Proto-Austronesian = *diri", "source": "train"}
|
| 85 |
+
{"text": "### Question:\nWhat did winner Gary Player par?\n\n### SQL:\nSELECT To par FROM table WHERE Winner = gary player", "question": "What did winner Gary Player par?", "sql": "SELECT To par FROM table WHERE Winner = gary player", "source": "train"}
|
| 86 |
+
{"text": "### Question:\nWhat notes have 2 as the rank?\n\n### SQL:\nSELECT Notes FROM table WHERE Rank = 2", "question": "What notes have 2 as the rank?", "sql": "SELECT Notes FROM table WHERE Rank = 2", "source": "train"}
|
| 87 |
+
{"text": "### Question:\nWho was the Away Captain when the Home Captain was Joe Darling at Melbourne Cricket Ground?\n\n### SQL:\nSELECT Away captain FROM table WHERE Venue = melbourne cricket ground AND Home captain = joe darling", "question": "Who was the Away Captain when the Home Captain was Joe Darling at Melbourne Cricket Ground?", "sql": "SELECT Away captain FROM table WHERE Venue = melbourne cricket ground AND Home captain = joe darling", "source": "train"}
|
| 88 |
+
{"text": "### Question:\nWhich spacecraft were launched by the Titan II?\n\n### SQL:\nSELECT Spacecraft FROM table WHERE Launcher = titan ii", "question": "Which spacecraft were launched by the Titan II?", "sql": "SELECT Spacecraft FROM table WHERE Launcher = titan ii", "source": "train"}
|
| 89 |
+
{"text": "### Question:\nWhat shows for 3:30 pm when 12:30 pm is the young and the restless?\n\n### SQL:\nSELECT noon FROM table WHERE 12:30 pm = the young and the restless", "question": "What shows for 3:30 pm when 12:30 pm is the young and the restless?", "sql": "SELECT noon FROM table WHERE 12:30 pm = the young and the restless", "source": "train"}
|
| 90 |
+
{"text": "### Question:\nWhat's the fed tax that has a total tax greater than 33.2, a minimum sales tax less than 41.01 and in Vancouver, BC?\n\n### SQL:\nSELECT AVG Federal excise tax ( CAD\u00a2 / L ) FROM table WHERE Total excise tax (CAD\u00a2/L) > 33.2 AND Government = vancouver, bc AND Minimum tax incl. sales taxes (CAD\u00a2/L) < 41.01", "question": "What's the fed tax that has a total tax greater than 33.2, a minimum sales tax less than 41.01 and in Vancouver, BC?", "sql": "SELECT AVG Federal excise tax ( CAD\u00a2 / L ) FROM table WHERE Total excise tax (CAD\u00a2/L) > 33.2 AND Government = vancouver, bc AND Minimum tax incl. sales taxes (CAD\u00a2/L) < 41.01", "source": "train"}
|
| 91 |
+
{"text": "### Question:\nWhat is the highest total number?\n\n### SQL:\nSELECT MAX Total# FROM table", "question": "What is the highest total number?", "sql": "SELECT MAX Total# FROM table", "source": "train"}
|
| 92 |
+
{"text": "### Question:\nWhat is the score of the match that was against alberto berasategui?\n\n### SQL:\nSELECT Score in the final FROM table WHERE Opponent in the final = alberto berasategui", "question": "What is the score of the match that was against alberto berasategui?", "sql": "SELECT Score in the final FROM table WHERE Opponent in the final = alberto berasategui", "source": "train"}
|
| 93 |
+
{"text": "### Question:\nWhat is the newest Cap with a Goals stat larger than 17 and which was done by Brian Turner?\n\n### SQL:\nSELECT Most Recent Cap FROM table WHERE Goals > 17 AND Name = brian turner", "question": "What is the newest Cap with a Goals stat larger than 17 and which was done by Brian Turner?", "sql": "SELECT Most Recent Cap FROM table WHERE Goals > 17 AND Name = brian turner", "source": "train"}
|
| 94 |
+
{"text": "### Question:\nWhat was the losing score on September 1?\n\n### SQL:\nSELECT Loss FROM table WHERE Date = september 1", "question": "What was the losing score on September 1?", "sql": "SELECT Loss FROM table WHERE Date = september 1", "source": "train"}
|
| 95 |
+
{"text": "### Question:\nWhat is the score of the away team whose opponent scored 14.8 (92)?\n\n### SQL:\nSELECT Away team score FROM table WHERE Home team score = 14.8 (92)", "question": "What is the score of the away team whose opponent scored 14.8 (92)?", "sql": "SELECT Away team score FROM table WHERE Home team score = 14.8 (92)", "source": "train"}
|
| 96 |
+
{"text": "### Question:\nWhat are donor payments in the country where there are 12 children to 6 families (2 per family)?\n\n### SQL:\nSELECT Donor payment FROM table WHERE Children per donor = 12 children to 6 families (2 per family)", "question": "What are donor payments in the country where there are 12 children to 6 families (2 per family)?", "sql": "SELECT Donor payment FROM table WHERE Children per donor = 12 children to 6 families (2 per family)", "source": "train"}
|
| 97 |
+
{"text": "### Question:\nWhat was the original air date for the episode with production code 1wab06?\n\n### SQL:\nSELECT Originalairdate FROM table WHERE Production code = 1WAB06", "question": "What was the original air date for the episode with production code 1wab06?", "sql": "SELECT Originalairdate FROM table WHERE Production code = 1WAB06", "source": "train"}
|
| 98 |
+
{"text": "### Question:\nWhat was the total for David O'Callaghan, and a Tally of 1-9?\n\n### SQL:\nSELECT SUM Total FROM table WHERE Player = david o'callaghan AND Tally = 1-9", "question": "What was the total for David O'Callaghan, and a Tally of 1-9?", "sql": "SELECT SUM Total FROM table WHERE Player = david o'callaghan AND Tally = 1-9", "source": "train"}
|
| 99 |
+
{"text": "### Question:\nAttendance larger than 17,001, and a Date of june 15 had what decision?\n\n### SQL:\nSELECT Decision FROM table WHERE Attendance > 17,001 AND Date = june 15", "question": "Attendance larger than 17,001, and a Date of june 15 had what decision?", "sql": "SELECT Decision FROM table WHERE Attendance > 17,001 AND Date = june 15", "source": "train"}
|
| 100 |
+
{"text": "### Question:\nWhat is Album Artist, when Song is \"\"Something New\" (with Mint Royale and Class A)\"?\n\n### SQL:\nSELECT Album artist FROM table WHERE Song = \"something new\" (with mint royale and class a)", "question": "What is Album Artist, when Song is \"\"Something New\" (with Mint Royale and Class A)\"?", "sql": "SELECT Album artist FROM table WHERE Song = \"something new\" (with mint royale and class a)", "source": "train"}
|
src/outputs/finetuning/val.jsonl
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"text": "### Question:\nName the play for 1976\n\n### SQL:\nSELECT Play FROM table WHERE Year = 1976", "question": "Name the play for 1976", "sql": "SELECT Play FROM table WHERE Year = 1976", "source": "validation"}
|
| 2 |
+
{"text": "### Question:\nwhat are all the playoffs for u.s. open cup in 1st round\n\n### SQL:\nSELECT Playoffs FROM table WHERE U.S. Open Cup = 1st Round", "question": "what are all the playoffs for u.s. open cup in 1st round", "sql": "SELECT Playoffs FROM table WHERE U.S. Open Cup = 1st Round", "source": "validation"}
|
| 3 |
+
{"text": "### Question:\nWhat is the location of the game that has a number smaller than 2?\n\n### SQL:\nSELECT Location FROM table WHERE Game < 2", "question": "What is the location of the game that has a number smaller than 2?", "sql": "SELECT Location FROM table WHERE Game < 2", "source": "validation"}
|
| 4 |
+
{"text": "### Question:\nWhat is 2004, when 2005 is \"Not Tier I\"?\n\n### SQL:\nSELECT 2004 FROM table WHERE 2005 = not tier i", "question": "What is 2004, when 2005 is \"Not Tier I\"?", "sql": "SELECT 2004 FROM table WHERE 2005 = not tier i", "source": "validation"}
|
| 5 |
+
{"text": "### Question:\nWhich venue led to a result of 13th and had an extra of Long Race?\n\n### SQL:\nSELECT Venue FROM table WHERE Extra = long race AND Result = 13th", "question": "Which venue led to a result of 13th and had an extra of Long Race?", "sql": "SELECT Venue FROM table WHERE Extra = long race AND Result = 13th", "source": "validation"}
|
| 6 |
+
{"text": "### Question:\nWhat was Anders Forsbrand's score when the TO par is +4?\n\n### SQL:\nSELECT Score FROM table WHERE To par = +4 AND Player = anders forsbrand", "question": "What was Anders Forsbrand's score when the TO par is +4?", "sql": "SELECT Score FROM table WHERE To par = +4 AND Player = anders forsbrand", "source": "validation"}
|
| 7 |
+
{"text": "### Question:\nWhat was the attendance of the game that had an away team of FK Mogren?\n\n### SQL:\nSELECT Attendance FROM table WHERE Guest = fk mogren", "question": "What was the attendance of the game that had an away team of FK Mogren?", "sql": "SELECT Attendance FROM table WHERE Guest = fk mogren", "source": "validation"}
|
| 8 |
+
{"text": "### Question:\nWhat is the Air Date that has a 18\u201349 larger than 1.9, less than 7.54 viewers and a rating less than 4.9?\n\n### SQL:\nSELECT Air Date FROM table WHERE 18\u201349 > 1.9 AND Viewers < 7.54 AND Rating < 4.9", "question": "What is the Air Date that has a 18\u201349 larger than 1.9, less than 7.54 viewers and a rating less than 4.9?", "sql": "SELECT Air Date FROM table WHERE 18\u201349 > 1.9 AND Viewers < 7.54 AND Rating < 4.9", "source": "validation"}
|
| 9 |
+
{"text": "### Question:\nWhich Avg/G is the lowest one that has a Long smaller than 47, and a Name of frank murphy, and a Gain smaller than 569?\n\n### SQL:\nSELECT MIN Avg/G FROM table WHERE Long < 47 AND Name = frank murphy AND Gain < 569", "question": "Which Avg/G is the lowest one that has a Long smaller than 47, and a Name of frank murphy, and a Gain smaller than 569?", "sql": "SELECT MIN Avg/G FROM table WHERE Long < 47 AND Name = frank murphy AND Gain < 569", "source": "validation"}
|
| 10 |
+
{"text": "### Question:\nWhich rank has 1 silver medal and more than 1 gold medal?\n\n### SQL:\nSELECT Rank FROM table WHERE Silver = 1 AND Gold > 1", "question": "Which rank has 1 silver medal and more than 1 gold medal?", "sql": "SELECT Rank FROM table WHERE Silver = 1 AND Gold > 1", "source": "validation"}
|
| 11 |
+
{"text": "### Question:\nName the number of candidates for # of seats won being 43\n\n### SQL:\nSELECT # of candidates FROM table WHERE # of seats won = 43", "question": "Name the number of candidates for # of seats won being 43", "sql": "SELECT # of candidates FROM table WHERE # of seats won = 43", "source": "validation"}
|
| 12 |
+
{"text": "### Question:\nWhat is the home team score at lake oval?\n\n### SQL:\nSELECT Home team score FROM table WHERE Venue = lake oval", "question": "What is the home team score at lake oval?", "sql": "SELECT Home team score FROM table WHERE Venue = lake oval", "source": "validation"}
|
| 13 |
+
{"text": "### Question:\nWhich loss has an attendance greater than 49,688 and 11-8 as the record?\n\n### SQL:\nSELECT Loss FROM table WHERE Attendance > 49,688 AND Record = 11-8", "question": "Which loss has an attendance greater than 49,688 and 11-8 as the record?", "sql": "SELECT Loss FROM table WHERE Attendance > 49,688 AND Record = 11-8", "source": "validation"}
|
| 14 |
+
{"text": "### Question:\nWhat is the sum of pick# for Don Majkowski?3\n\n### SQL:\nSELECT SUM Pick # FROM table WHERE Player = don majkowski", "question": "What is the sum of pick# for Don Majkowski?3", "sql": "SELECT SUM Pick # FROM table WHERE Player = don majkowski", "source": "validation"}
|
| 15 |
+
{"text": "### Question:\nWhat is the total number of wins for riders with fewer than 56 races and more than 0 titles?\n\n### SQL:\nSELECT COUNT Wins FROM table WHERE Races < 56 AND Titles > 0", "question": "What is the total number of wins for riders with fewer than 56 races and more than 0 titles?", "sql": "SELECT COUNT Wins FROM table WHERE Races < 56 AND Titles > 0", "source": "validation"}
|
| 16 |
+
{"text": "### Question:\nHow much did the girl, nicknamed Chidi, weigh at birth?\n\n### SQL:\nSELECT Weight at birth FROM table WHERE Gender = girl AND Nickname = chidi", "question": "How much did the girl, nicknamed Chidi, weigh at birth?", "sql": "SELECT Weight at birth FROM table WHERE Gender = girl AND Nickname = chidi", "source": "validation"}
|
| 17 |
+
{"text": "### Question:\nOn which apparatus did Kanayeva have a final score smaller than 75.5 and a qualifying score smaller than 18.7?\n\n### SQL:\nSELECT Apparatus FROM table WHERE Score-Final < 75.5 AND Score-Qualifying < 18.7", "question": "On which apparatus did Kanayeva have a final score smaller than 75.5 and a qualifying score smaller than 18.7?", "sql": "SELECT Apparatus FROM table WHERE Score-Final < 75.5 AND Score-Qualifying < 18.7", "source": "validation"}
|
| 18 |
+
{"text": "### Question:\nWhat are the rounds for the B tyres and Ferrari 053 engine +?\n\n### SQL:\nSELECT Rounds FROM table WHERE Tyre = b AND Engine \u2020 = ferrari 053", "question": "What are the rounds for the B tyres and Ferrari 053 engine +?", "sql": "SELECT Rounds FROM table WHERE Tyre = b AND Engine \u2020 = ferrari 053", "source": "validation"}
|
| 19 |
+
{"text": "### Question:\nWhat is the sexual abuse rate where the conflict is the Burundi Civil War?\n\n### SQL:\nSELECT MAX Sexual abuse 1 FROM table WHERE Conflict = Burundi Civil War", "question": "What is the sexual abuse rate where the conflict is the Burundi Civil War?", "sql": "SELECT MAX Sexual abuse 1 FROM table WHERE Conflict = Burundi Civil War", "source": "validation"}
|
| 20 |
+
{"text": "### Question:\nWhen 0-1 is the series who has the highest amount of assists?\n\n### SQL:\nSELECT High assists FROM table WHERE Series = 0-1", "question": "When 0-1 is the series who has the highest amount of assists?", "sql": "SELECT High assists FROM table WHERE Series = 0-1", "source": "validation"}
|
| 21 |
+
{"text": "### Question:\nWhich MLS team has the #41 pick?\n\n### SQL:\nSELECT MLS team FROM table WHERE Pick # = 41", "question": "Which MLS team has the #41 pick?", "sql": "SELECT MLS team FROM table WHERE Pick # = 41", "source": "validation"}
|
| 22 |
+
{"text": "### Question:\nWhat is the most bronze can be when silver is larger than 2, and the nation is germany, and gold is more than 8?\n\n### SQL:\nSELECT MAX Bronze FROM table WHERE Silver > 2 AND Nation = germany AND Gold > 8", "question": "What is the most bronze can be when silver is larger than 2, and the nation is germany, and gold is more than 8?", "sql": "SELECT MAX Bronze FROM table WHERE Silver > 2 AND Nation = germany AND Gold > 8", "source": "validation"}
|
| 23 |
+
{"text": "### Question:\nWhat dates contained matches at the venue Bourda?\n\n### SQL:\nSELECT Date FROM table WHERE Venue = bourda", "question": "What dates contained matches at the venue Bourda?", "sql": "SELECT Date FROM table WHERE Venue = bourda", "source": "validation"}
|
| 24 |
+
{"text": "### Question:\nWhat is the minimum amount of poles?\n\n### SQL:\nSELECT MIN Poles FROM table", "question": "What is the minimum amount of poles?", "sql": "SELECT MIN Poles FROM table", "source": "validation"}
|
| 25 |
+
{"text": "### Question:\nWhat is the average Episode # with a 7 share and 18\u201349 is less than 2 and the Air Date of may 21, 2009?\n\n### SQL:\nSELECT AVG Episode # FROM table WHERE Share = 7 AND 18\u201349 < 2 AND Air Date = may 21, 2009", "question": "What is the average Episode # with a 7 share and 18\u201349 is less than 2 and the Air Date of may 21, 2009?", "sql": "SELECT AVG Episode # FROM table WHERE Share = 7 AND 18\u201349 < 2 AND Air Date = may 21, 2009", "source": "validation"}
|
| 26 |
+
{"text": "### Question:\nWhat is the most lost games for the team with a difference smaller than 86 and points of 32?\n\n### SQL:\nSELECT MAX Lost FROM table WHERE Points = 32 AND Difference < 86", "question": "What is the most lost games for the team with a difference smaller than 86 and points of 32?", "sql": "SELECT MAX Lost FROM table WHERE Points = 32 AND Difference < 86", "source": "validation"}
|
| 27 |
+
{"text": "### Question:\nwhat's the\u00a0first elected\u00a0with\u00a0district\u00a0being florida 7\n\n### SQL:\nSELECT First elected FROM table WHERE District = Florida 7", "question": "what's the\u00a0first elected\u00a0with\u00a0district\u00a0being florida 7", "sql": "SELECT First elected FROM table WHERE District = Florida 7", "source": "validation"}
|
| 28 |
+
{"text": "### Question:\nWhat is the average number of points for a song ranked 2nd with a draw greater than 3?\n\n### SQL:\nSELECT AVG Points FROM table WHERE Rank = 2nd AND Draw > 3", "question": "What is the average number of points for a song ranked 2nd with a draw greater than 3?", "sql": "SELECT AVG Points FROM table WHERE Rank = 2nd AND Draw > 3", "source": "validation"}
|
| 29 |
+
{"text": "### Question:\nWhat is the destination when the train number is 16526?\n\n### SQL:\nSELECT Destination FROM table WHERE Train number = 16526", "question": "What is the destination when the train number is 16526?", "sql": "SELECT Destination FROM table WHERE Train number = 16526", "source": "validation"}
|
| 30 |
+
{"text": "### Question:\nWhat is the highest game that has 32 points and a team rank larger than 4 named montepaschi siena\n\n### SQL:\nSELECT MAX Games FROM table WHERE Points = 32 AND Team = montepaschi siena AND Rank > 4", "question": "What is the highest game that has 32 points and a team rank larger than 4 named montepaschi siena", "sql": "SELECT MAX Games FROM table WHERE Points = 32 AND Team = montepaschi siena AND Rank > 4", "source": "validation"}
|
| 31 |
+
{"text": "### Question:\nWhat is Episode, when Jeremy's Guest is \"Pauline McLynn\"?\n\n### SQL:\nSELECT Episode FROM table WHERE Jeremy's guest = pauline mclynn", "question": "What is Episode, when Jeremy's Guest is \"Pauline McLynn\"?", "sql": "SELECT Episode FROM table WHERE Jeremy's guest = pauline mclynn", "source": "validation"}
|
| 32 |
+
{"text": "### Question:\nWhat is the poor law union of the Kilmaloda townland?\n\n### SQL:\nSELECT Poor law union FROM table WHERE Townland = Kilmaloda", "question": "What is the poor law union of the Kilmaloda townland?", "sql": "SELECT Poor law union FROM table WHERE Townland = Kilmaloda", "source": "validation"}
|
| 33 |
+
{"text": "### Question:\nWhat is the largest pick in round 8?\n\n### SQL:\nSELECT MAX Pick FROM table WHERE Round = 8", "question": "What is the largest pick in round 8?", "sql": "SELECT MAX Pick FROM table WHERE Round = 8", "source": "validation"}
|
| 34 |
+
{"text": "### Question:\nOn what date was the attendance at TD Garden 18,624?\n\n### SQL:\nSELECT Date FROM table WHERE Location Attendance = TD Garden 18,624", "question": "On what date was the attendance at TD Garden 18,624?", "sql": "SELECT Date FROM table WHERE Location Attendance = TD Garden 18,624", "source": "validation"}
|
| 35 |
+
{"text": "### Question:\nWhat is canada's margin?\n\n### SQL:\nSELECT SUM Margin FROM table WHERE Country = canada", "question": "What is canada's margin?", "sql": "SELECT SUM Margin FROM table WHERE Country = canada", "source": "validation"}
|
| 36 |
+
{"text": "### Question:\nWhat Sweet Sixteen team is in the Colonial conference?\n\n### SQL:\nSELECT Sweet Sixteen FROM table WHERE Conference = colonial", "question": "What Sweet Sixteen team is in the Colonial conference?", "sql": "SELECT Sweet Sixteen FROM table WHERE Conference = colonial", "source": "validation"}
|
| 37 |
+
{"text": "### Question:\nHow many resorts have 118 runs?\n\n### SQL:\nSELECT COUNT Name FROM table WHERE Runs = 118", "question": "How many resorts have 118 runs?", "sql": "SELECT COUNT Name FROM table WHERE Runs = 118", "source": "validation"}
|
| 38 |
+
{"text": "### Question:\nWho was the winner against Lindsay Davenport?\n\n### SQL:\nSELECT Winner FROM table WHERE Finalist = lindsay davenport", "question": "Who was the winner against Lindsay Davenport?", "sql": "SELECT Winner FROM table WHERE Finalist = lindsay davenport", "source": "validation"}
|
| 39 |
+
{"text": "### Question:\nHow many laps for a grid larger than 1 with a Time/Retired of halfshaft?\n\n### SQL:\nSELECT Laps FROM table WHERE Grid > 1 AND Time/Retired = halfshaft", "question": "How many laps for a grid larger than 1 with a Time/Retired of halfshaft?", "sql": "SELECT Laps FROM table WHERE Grid > 1 AND Time/Retired = halfshaft", "source": "validation"}
|
| 40 |
+
{"text": "### Question:\nIn what year was the feature at a 33.3S latitude named? \n\n### SQL:\nSELECT MAX Year named FROM table WHERE Latitude = 33.3S", "question": "In what year was the feature at a 33.3S latitude named? ", "sql": "SELECT MAX Year named FROM table WHERE Latitude = 33.3S", "source": "validation"}
|
| 41 |
+
{"text": "### Question:\nWhich Thirds (Under 17's) have a Reserve of barnawartha?\n\n### SQL:\nSELECT Thirds (Under 17's) FROM table WHERE Reserves = barnawartha", "question": "Which Thirds (Under 17's) have a Reserve of barnawartha?", "sql": "SELECT Thirds (Under 17's) FROM table WHERE Reserves = barnawartha", "source": "validation"}
|
| 42 |
+
{"text": "### Question:\nWhat was the outcome of the match against Stacy Margolin?\n\n### SQL:\nSELECT Outcome FROM table WHERE Opponent = stacy margolin", "question": "What was the outcome of the match against Stacy Margolin?", "sql": "SELECT Outcome FROM table WHERE Opponent = stacy margolin", "source": "validation"}
|
| 43 |
+
{"text": "### Question:\nIf the working force of HK is 10.4%, what is the salary range?\n\n### SQL:\nSELECT Salary range FROM table WHERE Working force of HK = 10.4%", "question": "If the working force of HK is 10.4%, what is the salary range?", "sql": "SELECT Salary range FROM table WHERE Working force of HK = 10.4%", "source": "validation"}
|
| 44 |
+
{"text": "### Question:\nWhat is the sum of the pick from texas a&i college with a round greater than 1?\n\n### SQL:\nSELECT SUM Pick FROM table WHERE College = texas a&i AND Round > 1", "question": "What is the sum of the pick from texas a&i college with a round greater than 1?", "sql": "SELECT SUM Pick FROM table WHERE College = texas a&i AND Round > 1", "source": "validation"}
|
| 45 |
+
{"text": "### Question:\nWhich Second has a Lead of ben hebert?\n\n### SQL:\nSELECT Second FROM table WHERE Lead = ben hebert", "question": "Which Second has a Lead of ben hebert?", "sql": "SELECT Second FROM table WHERE Lead = ben hebert", "source": "validation"}
|
| 46 |
+
{"text": "### Question:\nWhich Genre has a Game of donkey kong country?\n\n### SQL:\nSELECT Genre FROM table WHERE Game = donkey kong country", "question": "Which Genre has a Game of donkey kong country?", "sql": "SELECT Genre FROM table WHERE Game = donkey kong country", "source": "validation"}
|
| 47 |
+
{"text": "### Question:\nWhat is the location of the Carousel toll plaza?\n\n### SQL:\nSELECT Location FROM table WHERE Name = Carousel Toll Plaza", "question": "What is the location of the Carousel toll plaza?", "sql": "SELECT Location FROM table WHERE Name = Carousel Toll Plaza", "source": "validation"}
|
| 48 |
+
{"text": "### Question:\nWhat is Turkey's average Gold entry that also has a Bronze entry that is smaller than 2 and the Total is greater than 1?\n\n### SQL:\nSELECT AVG Gold FROM table WHERE Bronze < 2 AND Nation = turkey AND Total > 1", "question": "What is Turkey's average Gold entry that also has a Bronze entry that is smaller than 2 and the Total is greater than 1?", "sql": "SELECT AVG Gold FROM table WHERE Bronze < 2 AND Nation = turkey AND Total > 1", "source": "validation"}
|
| 49 |
+
{"text": "### Question:\nWhich Class has a Quantity made of 29?\n\n### SQL:\nSELECT Class FROM table WHERE Quantity made = 29", "question": "Which Class has a Quantity made of 29?", "sql": "SELECT Class FROM table WHERE Quantity made = 29", "source": "validation"}
|
| 50 |
+
{"text": "### Question:\nWhich Oberliga Bayern has a Season of 1981-82?\n\n### SQL:\nSELECT Oberliga Bayern FROM table WHERE Season = 1981-82", "question": "Which Oberliga Bayern has a Season of 1981-82?", "sql": "SELECT Oberliga Bayern FROM table WHERE Season = 1981-82", "source": "validation"}
|
| 51 |
+
{"text": "### Question:\nWhat is the number of podiums with 0 wins, 0 F.L. and 35 points?\n\n### SQL:\nSELECT Podiums FROM table WHERE Wins = 0 AND F.L. = 0 AND Points = 35", "question": "What is the number of podiums with 0 wins, 0 F.L. and 35 points?", "sql": "SELECT Podiums FROM table WHERE Wins = 0 AND F.L. = 0 AND Points = 35", "source": "validation"}
|
| 52 |
+
{"text": "### Question:\nWho was the publisher of Martial Law: Dead Ringers?\n\n### SQL:\nSELECT Publisher FROM table WHERE Release title = martial law: dead ringers", "question": "Who was the publisher of Martial Law: Dead Ringers?", "sql": "SELECT Publisher FROM table WHERE Release title = martial law: dead ringers", "source": "validation"}
|
| 53 |
+
{"text": "### Question:\nWhat is the Almali village with the S\u00fcsk\u0259n village z\u0259rn\u0259?\n\n### SQL:\nSELECT Almal\u0131 (Qax) FROM table WHERE S\u00fcsk\u0259n = z\u0259rn\u0259", "question": "What is the Almali village with the S\u00fcsk\u0259n village z\u0259rn\u0259?", "sql": "SELECT Almal\u0131 (Qax) FROM table WHERE S\u00fcsk\u0259n = z\u0259rn\u0259", "source": "validation"}
|
| 54 |
+
{"text": "### Question:\nName the typed for formed from 6-pul trailer third in res unit\n\n### SQL:\nSELECT Type FROM table WHERE Formed from = 6-pul trailer third in res unit", "question": "Name the typed for formed from 6-pul trailer third in res unit", "sql": "SELECT Type FROM table WHERE Formed from = 6-pul trailer third in res unit", "source": "validation"}
|
| 55 |
+
{"text": "### Question:\nName the 2009/10 with 2011/12 of lq and 2008/09 of not held\n\n### SQL:\nSELECT 2009/ 10 FROM table WHERE 2011/ 12 = lq AND 2008/ 09 = not held", "question": "Name the 2009/10 with 2011/12 of lq and 2008/09 of not held", "sql": "SELECT 2009/ 10 FROM table WHERE 2011/ 12 = lq AND 2008/ 09 = not held", "source": "validation"}
|
| 56 |
+
{"text": "### Question:\nWhat is the tyres for the JBW type 2 chassis?\n\n### SQL:\nSELECT Tyres FROM table WHERE Chassis = jbw type 2", "question": "What is the tyres for the JBW type 2 chassis?", "sql": "SELECT Tyres FROM table WHERE Chassis = jbw type 2", "source": "validation"}
|
| 57 |
+
{"text": "### Question:\nHow many total appearances (league only) have a name of gavin dykes?\n\n### SQL:\nSELECT Total Appearances(league only) FROM table WHERE Name = gavin dykes", "question": "How many total appearances (league only) have a name of gavin dykes?", "sql": "SELECT Total Appearances(league only) FROM table WHERE Name = gavin dykes", "source": "validation"}
|
| 58 |
+
{"text": "### Question:\nWhat is the sum of laps that has a car number of larger than 1, is a ford, and has 155 points?\n\n### SQL:\nSELECT SUM Laps FROM table WHERE Car # > 1 AND Make = ford AND Points = 155", "question": "What is the sum of laps that has a car number of larger than 1, is a ford, and has 155 points?", "sql": "SELECT SUM Laps FROM table WHERE Car # > 1 AND Make = ford AND Points = 155", "source": "validation"}
|
| 59 |
+
{"text": "### Question:\nWhat was the average crowd size of games held at Glenferrie Oval?\n\n### SQL:\nSELECT AVG Crowd FROM table WHERE Venue = glenferrie oval", "question": "What was the average crowd size of games held at Glenferrie Oval?", "sql": "SELECT AVG Crowd FROM table WHERE Venue = glenferrie oval", "source": "validation"}
|
| 60 |
+
{"text": "### Question:\nWhat version of iWork was released on October 22, 2013 with a pages version greater than 2?\n\n### SQL:\nSELECT iWork version FROM table WHERE Release date = october 22, 2013 AND Pages version > 2", "question": "What version of iWork was released on October 22, 2013 with a pages version greater than 2?", "sql": "SELECT iWork version FROM table WHERE Release date = october 22, 2013 AND Pages version > 2", "source": "validation"}
|
| 61 |
+
{"text": "### Question:\nName the player for chicago black hawks\n\n### SQL:\nSELECT Player FROM table WHERE NHL team = Chicago Black Hawks", "question": "Name the player for chicago black hawks", "sql": "SELECT Player FROM table WHERE NHL team = Chicago Black Hawks", "source": "validation"}
|
| 62 |
+
{"text": "### Question:\nWhat is the streak for game 2?\n\n### SQL:\nSELECT Streak FROM table WHERE Game = 2", "question": "What is the streak for game 2?", "sql": "SELECT Streak FROM table WHERE Game = 2", "source": "validation"}
|
| 63 |
+
{"text": "### Question:\nI want the D 45 and D 42 of r 22\n\n### SQL:\nSELECT D 45 FROM table WHERE D 42 = r 22", "question": "I want the D 45 and D 42 of r 22", "sql": "SELECT D 45 FROM table WHERE D 42 = r 22", "source": "validation"}
|
| 64 |
+
{"text": "### Question:\nWho was the away team when Queensland Roar was the home team in the round less than 3?\n\n### SQL:\nSELECT Away Team FROM table WHERE Round < 3 AND Home Team = queensland roar", "question": "Who was the away team when Queensland Roar was the home team in the round less than 3?", "sql": "SELECT Away Team FROM table WHERE Round < 3 AND Home Team = queensland roar", "source": "validation"}
|
| 65 |
+
{"text": "### Question:\nHow many artists were there for the show thoroughly modern millie?\n\n### SQL:\nSELECT COUNT Artist FROM table WHERE Show = Thoroughly Modern Millie", "question": "How many artists were there for the show thoroughly modern millie?", "sql": "SELECT COUNT Artist FROM table WHERE Show = Thoroughly Modern Millie", "source": "validation"}
|
| 66 |
+
{"text": "### Question:\nWhich wrestling event was at the 2008 Beijing games?\n\n### SQL:\nSELECT Event FROM table WHERE Sport = wrestling AND Games = 2008 beijing", "question": "Which wrestling event was at the 2008 Beijing games?", "sql": "SELECT Event FROM table WHERE Sport = wrestling AND Games = 2008 beijing", "source": "validation"}
|
| 67 |
+
{"text": "### Question:\nWho was the opponent in London, England in a round less than 2?\n\n### SQL:\nSELECT Opponent FROM table WHERE Location = london, england AND Round < 2", "question": "Who was the opponent in London, England in a round less than 2?", "sql": "SELECT Opponent FROM table WHERE Location = london, england AND Round < 2", "source": "validation"}
|
| 68 |
+
{"text": "### Question:\nWhat is the ceremony year when Ganito Kami Noon, Paano Kayo Ngayon was the original title?\n\n### SQL:\nSELECT Year (Ceremony) FROM table WHERE Original title = ganito kami noon, paano kayo ngayon", "question": "What is the ceremony year when Ganito Kami Noon, Paano Kayo Ngayon was the original title?", "sql": "SELECT Year (Ceremony) FROM table WHERE Original title = ganito kami noon, paano kayo ngayon", "source": "validation"}
|
| 69 |
+
{"text": "### Question:\nWhen the total score is 740, what is tromso?\n\n### SQL:\nSELECT MIN Troms\u00f8 FROM table WHERE Total = 740", "question": "When the total score is 740, what is tromso?", "sql": "SELECT MIN Troms\u00f8 FROM table WHERE Total = 740", "source": "validation"}
|
| 70 |
+
{"text": "### Question:\nWhat is the result for director Said Elmarouk before 2008?\n\n### SQL:\nSELECT Result FROM table WHERE Director = said elmarouk AND Year < 2008", "question": "What is the result for director Said Elmarouk before 2008?", "sql": "SELECT Result FROM table WHERE Director = said elmarouk AND Year < 2008", "source": "validation"}
|
| 71 |
+
{"text": "### Question:\nWhen was the score 56-26?\n\n### SQL:\nSELECT Date FROM table WHERE Record = 56-26", "question": "When was the score 56-26?", "sql": "SELECT Date FROM table WHERE Record = 56-26", "source": "validation"}
|
| 72 |
+
{"text": "### Question:\nName the D 44 when it has a D 46 of d 31\n\n### SQL:\nSELECT D 44 FROM table WHERE D 46 = d 31", "question": "Name the D 44 when it has a D 46 of d 31", "sql": "SELECT D 44 FROM table WHERE D 46 = d 31", "source": "validation"}
|
| 73 |
+
{"text": "### Question:\nWhat was the date of the race that lasted 6 hours?\n\n### SQL:\nSELECT Date FROM table WHERE Length/Duration = 6 hours", "question": "What was the date of the race that lasted 6 hours?", "sql": "SELECT Date FROM table WHERE Length/Duration = 6 hours", "source": "validation"}
|
| 74 |
+
{"text": "### Question:\nWhich event is in the 1952 summer olympics?\n\n### SQL:\nSELECT Event FROM table WHERE Olympics = 1952 summer olympics", "question": "Which event is in the 1952 summer olympics?", "sql": "SELECT Event FROM table WHERE Olympics = 1952 summer olympics", "source": "validation"}
|
| 75 |
+
{"text": "### Question:\n the 2010 clausura tournament?\n\n### SQL:\nSELECT Coefficient FROM table WHERE Tournament = 2010 Clausura", "question": " the 2010 clausura tournament?", "sql": "SELECT Coefficient FROM table WHERE Tournament = 2010 Clausura", "source": "validation"}
|
| 76 |
+
{"text": "### Question:\nWhat was the score of the BCS National Championship game?\n\n### SQL:\nSELECT Score FROM table WHERE Bowl Game = bcs national championship", "question": "What was the score of the BCS National Championship game?", "sql": "SELECT Score FROM table WHERE Bowl Game = bcs national championship", "source": "validation"}
|
| 77 |
+
{"text": "### Question:\nWhat was the attendance when their record stood at 0-2-2?\n\n### SQL:\nSELECT SUM Attendance FROM table WHERE Record = 0-2-2", "question": "What was the attendance when their record stood at 0-2-2?", "sql": "SELECT SUM Attendance FROM table WHERE Record = 0-2-2", "source": "validation"}
|
| 78 |
+
{"text": "### Question:\nWhat were the results before the year 2000?\n\n### SQL:\nSELECT Result FROM table WHERE Year < 2000", "question": "What were the results before the year 2000?", "sql": "SELECT Result FROM table WHERE Year < 2000", "source": "validation"}
|
| 79 |
+
{"text": "### Question:\nHow much time is required for less than 35 laps and less than 10 grids?\n\n### SQL:\nSELECT Time/Retired FROM table WHERE Laps < 35 AND Grid < 10", "question": "How much time is required for less than 35 laps and less than 10 grids?", "sql": "SELECT Time/Retired FROM table WHERE Laps < 35 AND Grid < 10", "source": "validation"}
|
| 80 |
+
{"text": "### Question:\nWhen oslo is 48, what is stavanger?\n\n### SQL:\nSELECT MIN Stavanger FROM table WHERE Oslo = 48", "question": "When oslo is 48, what is stavanger?", "sql": "SELECT MIN Stavanger FROM table WHERE Oslo = 48", "source": "validation"}
|
| 81 |
+
{"text": "### Question:\nWhen was the year that had an average attendance of 5,445?\n\n### SQL:\nSELECT Year FROM table WHERE Avg. attendance = 5,445", "question": "When was the year that had an average attendance of 5,445?", "sql": "SELECT Year FROM table WHERE Avg. attendance = 5,445", "source": "validation"}
|
| 82 |
+
{"text": "### Question:\nFor the game with 528 attendance, what was the result?\n\n### SQL:\nSELECT Result FROM table WHERE Attendance = 528", "question": "For the game with 528 attendance, what was the result?", "sql": "SELECT Result FROM table WHERE Attendance = 528", "source": "validation"}
|
| 83 |
+
{"text": "### Question:\nWhat dated was the game played at the location delta center 19,911?\n\n### SQL:\nSELECT Date FROM table WHERE Location Attendance = Delta Center 19,911", "question": "What dated was the game played at the location delta center 19,911?", "sql": "SELECT Date FROM table WHERE Location Attendance = Delta Center 19,911", "source": "validation"}
|
| 84 |
+
{"text": "### Question:\nWhat is the ISBN of \"Dead as a Doornail?\n\n### SQL:\nSELECT Paperback FROM table WHERE Title = Dead as a Doornail", "question": "What is the ISBN of \"Dead as a Doornail?", "sql": "SELECT Paperback FROM table WHERE Title = Dead as a Doornail", "source": "validation"}
|
| 85 |
+
{"text": "### Question:\nWhat scored is recorded on April 24?\n\n### SQL:\nSELECT Score FROM table WHERE Date = april 24", "question": "What scored is recorded on April 24?", "sql": "SELECT Score FROM table WHERE Date = april 24", "source": "validation"}
|
| 86 |
+
{"text": "### Question:\nWho acquired tom norton?\n\n### SQL:\nSELECT Acquired FROM table WHERE Player = tom norton", "question": "Who acquired tom norton?", "sql": "SELECT Acquired FROM table WHERE Player = tom norton", "source": "validation"}
|
| 87 |
+
{"text": "### Question:\nWHAT IS THE HIGHEST VIEWERS WITH AN EPISODE LESS THAN 15 AND SHARE LAGER THAN 7?\n\n### SQL:\nSELECT MAX Viewers (millions) FROM table WHERE Episode number < 15 AND Share > 7", "question": "WHAT IS THE HIGHEST VIEWERS WITH AN EPISODE LESS THAN 15 AND SHARE LAGER THAN 7?", "sql": "SELECT MAX Viewers (millions) FROM table WHERE Episode number < 15 AND Share > 7", "source": "validation"}
|
| 88 |
+
{"text": "### Question:\nWhat is the English name given to the city of St. John's?\n\n### SQL:\nSELECT Capital ( exonym ) FROM table WHERE Capital ( endonym ) = St. John's", "question": "What is the English name given to the city of St. John's?", "sql": "SELECT Capital ( exonym ) FROM table WHERE Capital ( endonym ) = St. John's", "source": "validation"}
|
| 89 |
+
{"text": "### Question:\nWhat was the result of the game played on November 23, 2003?\n\n### SQL:\nSELECT Result FROM table WHERE Date = november 23, 2003", "question": "What was the result of the game played on November 23, 2003?", "sql": "SELECT Result FROM table WHERE Date = november 23, 2003", "source": "validation"}
|
| 90 |
+
{"text": "### Question:\nWho directed An Egg Scramble?\n\n### SQL:\nSELECT Director FROM table WHERE Title = an egg scramble", "question": "Who directed An Egg Scramble?", "sql": "SELECT Director FROM table WHERE Title = an egg scramble", "source": "validation"}
|
| 91 |
+
{"text": "### Question:\nWhat is the Bulgarian Commander of the Battle of Rusion?\n\n### SQL:\nSELECT Bulgarian Commander FROM table WHERE Battle = battle of rusion", "question": "What is the Bulgarian Commander of the Battle of Rusion?", "sql": "SELECT Bulgarian Commander FROM table WHERE Battle = battle of rusion", "source": "validation"}
|
| 92 |
+
{"text": "### Question:\nWhat was the location of the game when the record was 12-4?\n\n### SQL:\nSELECT Location FROM table WHERE Record = 12-4", "question": "What was the location of the game when the record was 12-4?", "sql": "SELECT Location FROM table WHERE Record = 12-4", "source": "validation"}
|
| 93 |
+
{"text": "### Question:\nWhat Service Name has UTV as the owner?\n\n### SQL:\nSELECT Service name FROM table WHERE Owner = utv", "question": "What Service Name has UTV as the owner?", "sql": "SELECT Service name FROM table WHERE Owner = utv", "source": "validation"}
|
| 94 |
+
{"text": "### Question:\nWhich Rebuilt has a Name as rebuilt of binevanagh?\n\n### SQL:\nSELECT Rebuilt FROM table WHERE Name as rebuilt = binevanagh", "question": "Which Rebuilt has a Name as rebuilt of binevanagh?", "sql": "SELECT Rebuilt FROM table WHERE Name as rebuilt = binevanagh", "source": "validation"}
|
| 95 |
+
{"text": "### Question:\nwhat is the margin of victory when the runner-up is amy alcott and the winning score is \u20139 (72-68-67=207)?\n\n### SQL:\nSELECT Margin of victory FROM table WHERE Runner(s)-up = amy alcott AND Winning score = \u20139 (72-68-67=207)", "question": "what is the margin of victory when the runner-up is amy alcott and the winning score is \u20139 (72-68-67=207)?", "sql": "SELECT Margin of victory FROM table WHERE Runner(s)-up = amy alcott AND Winning score = \u20139 (72-68-67=207)", "source": "validation"}
|
| 96 |
+
{"text": "### Question:\nWhat is the height of the building with 40 floors?\n\n### SQL:\nSELECT Height ft / m FROM table WHERE Floors = 40", "question": "What is the height of the building with 40 floors?", "sql": "SELECT Height ft / m FROM table WHERE Floors = 40", "source": "validation"}
|
| 97 |
+
{"text": "### Question:\nWhat is the total number drawn with goals against less than 55, and a total of 14 losses?\n\n### SQL:\nSELECT COUNT Drawn FROM table WHERE Goals Against < 55 AND Lost = 14", "question": "What is the total number drawn with goals against less than 55, and a total of 14 losses?", "sql": "SELECT COUNT Drawn FROM table WHERE Goals Against < 55 AND Lost = 14", "source": "validation"}
|
| 98 |
+
{"text": "### Question:\nWhich engine from 1973 has a Brabham bt37 chassis?\n\n### SQL:\nSELECT Engine FROM table WHERE Year = 1973 AND Chassis = brabham bt37", "question": "Which engine from 1973 has a Brabham bt37 chassis?", "sql": "SELECT Engine FROM table WHERE Year = 1973 AND Chassis = brabham bt37", "source": "validation"}
|
| 99 |
+
{"text": "### Question:\nTell me the final score for january 9 for cincinnati bengals\n\n### SQL:\nSELECT Final Score FROM table WHERE Date = january 9 AND Host Team = cincinnati bengals", "question": "Tell me the final score for january 9 for cincinnati bengals", "sql": "SELECT Final Score FROM table WHERE Date = january 9 AND Host Team = cincinnati bengals", "source": "validation"}
|
| 100 |
+
{"text": "### Question:\nWhat player was place of t1 in To Par and had a score of 70-73-69=212?\n\n### SQL:\nSELECT To par FROM table WHERE Place = t1 AND Score = 70-73-69=212", "question": "What player was place of t1 in To Par and had a score of 70-73-69=212?", "sql": "SELECT To par FROM table WHERE Place = t1 AND Score = 70-73-69=212", "source": "validation"}
|
src/outputs/finetuning/visualizations/01_metrics_overview.png
ADDED
|
src/outputs/finetuning/visualizations/02_token_accuracy_dist.png
ADDED
|
src/outputs/finetuning/visualizations/03_keyword_accuracy_dist.png
ADDED
|
src/outputs/finetuning/visualizations/04_training_loss.png
ADDED
|
src/outputs/rag/reports/knowledge_base_report.md
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# RAG Knowledge Base Report
|
| 2 |
+
|
| 3 |
+
**Generated:** 2025-12-07 23:58:56
|
| 4 |
+
|
| 5 |
+
## Overview
|
| 6 |
+
|
| 7 |
+
| Metric | Value |
|
| 8 |
+
|--------|-------|
|
| 9 |
+
| Total Documents | 80,654 |
|
| 10 |
+
| Collection Name | sql_knowledge |
|
| 11 |
+
| Embedding Model | all-MiniLM-L6-v2 |
|
| 12 |
+
|
| 13 |
+
## Data Sources
|
| 14 |
+
|
| 15 |
+
| Source | Documents |
|
| 16 |
+
|--------|-----------|
|
| 17 |
+
| train | 56,355 |
|
| 18 |
+
| validation | 8,421 |
|
| 19 |
+
| test | 15,878 |
|
| 20 |
+
|
| 21 |
+
## Chunking Strategies
|
| 22 |
+
|
| 23 |
+
1. **SQL Clause Extraction**: Identifies SELECT, FROM, WHERE, GROUP BY, etc.
|
| 24 |
+
2. **Complexity Classification**: Categorizes as simple/intermediate/complex
|
| 25 |
+
3. **Keyword Extraction**: Extracts SQL operations (JOIN, COUNT, etc.)
|
| 26 |
+
4. **Size Categorization**: Classifies question/SQL length
|
| 27 |
+
|
| 28 |
+
## Complexity Distribution
|
| 29 |
+
|
| 30 |
+
| Level | Count |
|
| 31 |
+
|-------|-------|
|
| 32 |
+
| Simple | 80,396 |
|
| 33 |
+
| Intermediate | 258 |
|
| 34 |
+
| Complex | 0 |
|
| 35 |
+
|
| 36 |
+
## Document Metadata Structure
|
| 37 |
+
|
| 38 |
+
Each document contains:
|
| 39 |
+
- `sql`: The SQL query
|
| 40 |
+
- `source`: Origin dataset
|
| 41 |
+
- `question`: Original question
|
| 42 |
+
- `complexity`: simple/intermediate/complex
|
| 43 |
+
- `sql_clauses`: Comma-separated clauses
|
| 44 |
+
- `keywords`: SQL keywords found
|
| 45 |
+
- `question_size`: short/medium/long
|
| 46 |
+
- `sql_size`: short/medium/long
|
src/outputs/rag/stats/knowledge_base_stats.json
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"total_documents": 80654,
|
| 3 |
+
"sources": {
|
| 4 |
+
"train": 56355,
|
| 5 |
+
"validation": 8421,
|
| 6 |
+
"test": 15878
|
| 7 |
+
},
|
| 8 |
+
"collection_name": "sql_knowledge",
|
| 9 |
+
"embedding_model": "all-MiniLM-L6-v2",
|
| 10 |
+
"chunking_strategies": [
|
| 11 |
+
"sql_clause_extraction",
|
| 12 |
+
"complexity_classification",
|
| 13 |
+
"keyword_extraction",
|
| 14 |
+
"size_categorization"
|
| 15 |
+
],
|
| 16 |
+
"complexity_distribution": {
|
| 17 |
+
"simple": 80396,
|
| 18 |
+
"intermediate": 258,
|
| 19 |
+
"complex": 0
|
| 20 |
+
},
|
| 21 |
+
"created_at": "2025-12-07T23:58:56.309706"
|
| 22 |
+
}
|
src/outputs/synthetic/reports/synthetic_report.md
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Synthetic Data Generation Report
|
| 2 |
+
|
| 3 |
+
**Generated:** 2025-12-07 23:24:17
|
| 4 |
+
|
| 5 |
+
## Dataset Statistics
|
| 6 |
+
|
| 7 |
+
| Metric | Original | Synthetic |
|
| 8 |
+
|--------|----------|-----------|
|
| 9 |
+
| Samples | 52,527 | 142,639 |
|
| 10 |
+
| Avg Length | 11.64 | 14.75 |
|
| 11 |
+
| Min Length | 3 | 3 |
|
| 12 |
+
| Max Length | 44 | 49 |
|
| 13 |
+
| Unique Words | 50,846 | 60,734 |
|
| 14 |
+
|
| 15 |
+
## Augmentation Results
|
| 16 |
+
|
| 17 |
+
- **Augmentation Factor:** 2.72x
|
| 18 |
+
- **Avg Diversity Score:** 0.2832
|
| 19 |
+
- **Min Diversity Score:** 0.103
|
| 20 |
+
- **Max Diversity Score:** 0.8
|
| 21 |
+
|
| 22 |
+
## Techniques Used
|
| 23 |
+
|
| 24 |
+
1. Synonym Replacement (40% probability)
|
| 25 |
+
2. Random Insertion (15% probability)
|
| 26 |
+
3. Random Swap (10% probability)
|
| 27 |
+
4. Structure Variation (prefix/suffix)
|
| 28 |
+
5. Case Variation
|
| 29 |
+
|
| 30 |
+
## Quality Controls
|
| 31 |
+
|
| 32 |
+
- Minimum question length: 10 characters
|
| 33 |
+
- Maximum question length: 500 characters
|
| 34 |
+
- Minimum diversity score: 0.1
|
| 35 |
+
- Duplicate removal via MD5 hashing
|
| 36 |
+
|
| 37 |
+
## Privacy Measures
|
| 38 |
+
|
| 39 |
+
- Email anonymization
|
| 40 |
+
- Phone number anonymization
|
| 41 |
+
- SSN anonymization
|
| 42 |
+
|
| 43 |
+
## Visualizations
|
| 44 |
+
|
| 45 |
+
- `01_size_comparison.png` - Dataset size comparison
|
| 46 |
+
- `02_length_distribution.png` - Question length distribution
|
| 47 |
+
- `03_diversity_distribution.png` - Diversity score distribution
|
src/outputs/synthetic/stats/statistics.json
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"original": {
|
| 3 |
+
"name": "Original",
|
| 4 |
+
"samples": 52527,
|
| 5 |
+
"avg_length": 11.64,
|
| 6 |
+
"min_length": 3,
|
| 7 |
+
"max_length": 44,
|
| 8 |
+
"unique_words": 50846
|
| 9 |
+
},
|
| 10 |
+
"synthetic": {
|
| 11 |
+
"name": "Synthetic",
|
| 12 |
+
"samples": 142639,
|
| 13 |
+
"avg_length": 14.75,
|
| 14 |
+
"min_length": 3,
|
| 15 |
+
"max_length": 49,
|
| 16 |
+
"unique_words": 60734
|
| 17 |
+
},
|
| 18 |
+
"diversity": {
|
| 19 |
+
"avg": 0.2832,
|
| 20 |
+
"min": 0.103,
|
| 21 |
+
"max": 0.8
|
| 22 |
+
},
|
| 23 |
+
"augmentation_factor": 2.72
|
| 24 |
+
}
|
src/outputs/synthetic/visualizations/01_size_comparison.png
ADDED
|
src/outputs/synthetic/visualizations/02_length_distribution.png
ADDED
|
src/outputs/synthetic/visualizations/03_diversity_distribution.png
ADDED
|
src/pipeline/integrated.py
ADDED
|
@@ -0,0 +1,584 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Integrated Pipeline: RAG + Fine-tuned Model + Gemini Enhancement
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
import json
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
from dotenv import load_dotenv
|
| 10 |
+
|
| 11 |
+
# Load environment variables from .env file
|
| 12 |
+
load_dotenv()
|
| 13 |
+
|
| 14 |
+
# Add parent directory
|
| 15 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 16 |
+
|
| 17 |
+
# =============================================================================
|
| 18 |
+
# CONFIGURATION
|
| 19 |
+
# =============================================================================
|
| 20 |
+
|
| 21 |
+
# Root folder for pipeline artifacts; the JSONL run log lives in a subfolder.
OUTPUT_DIR = "outputs/pipeline"
LOGS_DIR = f"{OUTPUT_DIR}/logs"

# Gemini config - loaded from .env with fallbacks.
# Keys are tried in list order; missing env vars yield None entries that are
# filtered out below, so the list length equals the number of usable keys.
GEMINI_KEYS = [
    os.getenv("GEMINI_API_KEY"),
    os.getenv("GEMINI_API_KEY_FALLBACK_1"),
    os.getenv("GEMINI_API_KEY_FALLBACK_2"),
]
# Remove None values
GEMINI_KEYS = [k for k in GEMINI_KEYS if k]

# Model names tried in order; the primary defaults to gemini-2.5-flash when
# GEMINI_MODEL is not set, the fallback is optional.
GEMINI_MODELS = [
    os.getenv("GEMINI_MODEL", "gemini-2.5-flash"),
    os.getenv("GEMINI_MODEL_FALLBACK_1"),
]
# Remove None values
GEMINI_MODELS = [m for m in GEMINI_MODELS if m]

# Report key/model discovery at import time so a missing .env is visible early.
if not GEMINI_KEYS:
    print("⚠️ Warning: No GEMINI_API_KEY found in .env file")
else:
    print(f"✓ Found {len(GEMINI_KEYS)} Gemini API key(s)")
    print(f"✓ Found {len(GEMINI_MODELS)} Gemini model(s)")
|
| 45 |
+
|
| 46 |
+
def setup_directories():
    """Ensure the pipeline output and log directories exist (idempotent)."""
    for target in (OUTPUT_DIR, LOGS_DIR):
        os.makedirs(target, exist_ok=True)
|
| 49 |
+
|
| 50 |
+
# =============================================================================
|
| 51 |
+
# GEMINI CLIENT WITH FALLBACK
|
| 52 |
+
# =============================================================================
|
| 53 |
+
|
| 54 |
+
class GeminiClient:
    """Gemini client with automatic fallback for rate limits.

    Walks the cross product of GEMINI_KEYS x GEMINI_MODELS: on a rate-limit
    error it first advances to the next model under the same key, and once
    models are exhausted it advances to the next key (restarting at the
    first model). Other errors are returned immediately without retrying.
    """

    def __init__(self):
        # Lazily-imported google.generativeai module object (None if missing).
        self.genai = None
        # Indices of the currently active entry in GEMINI_KEYS / GEMINI_MODELS.
        self.current_key_idx = 0
        self.current_model_idx = 0
        # Active GenerativeModel instance, set by _init_model().
        self.model = None
        # True once configure() + model construction has succeeded at least once.
        self.initialized = False

        try:
            import google.generativeai as genai
            self.genai = genai
            self._init_model()
        except ImportError:
            # SDK absent: client stays uninitialized; is_available() reports False.
            print("✗ google-generativeai not installed")

    def _init_model(self):
        """Configure the SDK with the current key and build the current model.

        Returns:
            bool: True when configuration succeeded, False otherwise.
        """
        if not GEMINI_KEYS:
            return False

        key = GEMINI_KEYS[self.current_key_idx]
        model_name = GEMINI_MODELS[self.current_model_idx]

        try:
            self.genai.configure(api_key=key)
            self.model = self.genai.GenerativeModel(model_name)
            self.initialized = True
            print(f" Using API key #{self.current_key_idx + 1}, model: {model_name}")
            return True
        except Exception as e:
            print(f" Failed to init Gemini: {e}")
            return False

    def _switch_to_next(self):
        """Advance to the next model/key combination and re-initialize.

        Returns:
            bool: True when a fallback combination was initialized, False
            when every combination has been exhausted.
        """
        # Try next model with same key
        if self.current_model_idx < len(GEMINI_MODELS) - 1:
            self.current_model_idx += 1
            print(f" ⟳ Switching to fallback model: {GEMINI_MODELS[self.current_model_idx]}")
            return self._init_model()

        # Try next key with first model
        if self.current_key_idx < len(GEMINI_KEYS) - 1:
            self.current_key_idx += 1
            self.current_model_idx = 0
            print(f" ⟳ Switching to fallback API key #{self.current_key_idx + 1}")
            return self._init_model()

        # No more fallbacks
        print(" ✗ All Gemini keys/models exhausted")
        return False

    def generate(self, prompt, max_retries=None):
        """Generate content, transparently falling back on rate limits.

        Args:
            prompt: Text prompt to send to the model.
            max_retries: Maximum fallback attempts; defaults to the number
                of available key/model combinations.

        Returns:
            tuple: (text, None) on success, or (None, error_message) on
            failure. Rate-limit errors are detected heuristically by
            substring ("429"/"quota"/"rate") and trigger a fallback switch;
            any other error aborts immediately.
        """
        if not self.initialized or not self.model:
            return None, "Gemini not initialized"

        # Calculate max retries based on available combinations
        if max_retries is None:
            max_retries = len(GEMINI_KEYS) * len(GEMINI_MODELS)

        attempts = 0
        while attempts < max_retries:
            try:
                response = self.model.generate_content(prompt)
                return response.text.strip(), None
            except Exception as e:
                error_str = str(e)

                # Check if rate limit error
                if "429" in error_str or "quota" in error_str.lower() or "rate" in error_str.lower():
                    print(f" ⚠️ Rate limit hit")
                    if not self._switch_to_next():
                        return None, "All API keys exhausted"
                    attempts += 1
                else:
                    # Other error, don't retry
                    return None, error_str

        return None, "Max retries exceeded"

    def is_available(self):
        """Return True when the client is initialized and has a live model."""
        return self.initialized and self.model is not None
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
# =============================================================================
|
| 143 |
+
# COMPONENT IMPORTS
|
| 144 |
+
# =============================================================================
|
| 145 |
+
|
| 146 |
+
def load_components():
    """Load every pipeline component, mapping each to None on failure.

    Returns:
        dict: keys 'rag', 'prompt_builder', 'finetuned_model', 'gemini';
        each value is a ready instance or None when loading failed.
    """
    components = {}

    def _register(key, factory, loaded_msg, missing_label):
        # Shared load-or-None logic: run the factory, record the result,
        # and report success/failure on the console.
        try:
            components[key] = factory()
            print(loaded_msg)
        except Exception as exc:
            components[key] = None
            print(f"✗ {missing_label} not available: {exc}")

    def _make_rag():
        from rag.retriever import SQLRetriever
        return SQLRetriever()

    def _make_prompt_builder():
        from prompts.prompt_builder import PromptBuilder
        return PromptBuilder()

    def _make_finetuned():
        from finetuning.inference import SQLGenerator
        return SQLGenerator()

    # 1-3. Components with uniform load semantics.
    _register('rag', _make_rag, "✓ RAG Retriever loaded", "RAG")
    _register('prompt_builder', _make_prompt_builder, "✓ Prompt Builder loaded", "Prompt Builder")
    _register('finetuned_model', _make_finetuned, "✓ Fine-tuned model loaded", "Fine-tuned model")

    # 4. Gemini needs extra handling: key discovery plus a post-init health check.
    try:
        if GEMINI_KEYS:
            components['gemini'] = GeminiClient()
            if components['gemini'].is_available():
                print("✓ Gemini loaded")
            else:
                components['gemini'] = None
                print("✗ Gemini failed to initialize")
        else:
            components['gemini'] = None
            print("✗ Gemini not available (no API keys)")
    except Exception as e:
        components['gemini'] = None
        print(f"✗ Gemini not available: {e}")

    return components
|
| 194 |
+
|
| 195 |
+
# =============================================================================
|
| 196 |
+
# GEMINI ENHANCEMENT PROMPTS
|
| 197 |
+
# =============================================================================
|
| 198 |
+
|
| 199 |
+
# Prompt asking Gemini to review/repair the smaller model's SQL; the response
# is expected to be the bare SQL string only (no prose, no markdown).
GEMINI_REFINE_PROMPT = """You are an SQL expert. Review and enhance this SQL query.

Original Question: {question}

Generated SQL (by a smaller model):
{sql}

Your tasks:
1. Check for syntax errors
2. Check for logical errors
3. Optimize if possible
4. Fix any issues

Rules:
- If the SQL is correct, return it unchanged
- If it needs fixes, return the corrected version
- Return ONLY the SQL query, no explanations

Enhanced SQL:"""

# Prompt asking Gemini for a structured validation verdict. Double braces
# escape str.format() so the literal JSON skeleton survives formatting.
GEMINI_VALIDATE_PROMPT = """You are an SQL validator. Check this SQL query.

Question: {question}
SQL: {sql}

Respond in JSON format:
{{
    "is_valid": true/false,
    "errors": ["list of errors if any"],
    "suggestions": ["list of suggestions if any"],
    "confidence": 0.0-1.0
}}

JSON Response:"""

# Prompt asking Gemini for a short plain-English explanation of a query.
GEMINI_EXPLAIN_PROMPT = """Explain this SQL query in simple terms.

SQL: {sql}

Provide a brief, beginner-friendly explanation (2-3 sentences):"""
|
| 239 |
+
|
| 240 |
+
# =============================================================================
|
| 241 |
+
# PIPELINE CLASS
|
| 242 |
+
# =============================================================================
|
| 243 |
+
|
| 244 |
+
class IntegratedPipeline:
    """
    Complete pipeline: RAG → Prompt → Fine-tuned Model → Gemini Enhancement

    Each stage degrades gracefully: a missing component produces an empty
    context / fallback prompt / error tuple instead of raising, so the
    pipeline still runs with whatever components loaded successfully.
    """

    def __init__(self):
        setup_directories()
        print("\n" + "=" * 50)
        print("LOADING PIPELINE COMPONENTS")
        print("=" * 50)
        # Dict of component instances (or None) keyed by component name.
        self.components = load_components()
        print("=" * 50 + "\n")

    # -------------------------------------------------------------------------
    # STEP 1: RAG Retrieval
    # -------------------------------------------------------------------------
    def retrieve_context(self, question, top_k=3):
        """Retrieve similar examples using RAG.

        Args:
            question: Natural language question to search with.
            top_k: Number of nearest examples to retrieve.

        Returns:
            tuple: (context_str, examples) where context_str is a formatted
            few-shot block and examples is the raw list of hits; returns
            ("", []) when the retriever is missing or fails.
        """
        if not self.components['rag']:
            return "", []

        try:
            # Use SQLRetriever's retrieve method
            results = self.components['rag'].retrieve(question, top_k=top_k)

            # Format as context string
            context = "Similar SQL examples:\n\n"
            examples = []
            for i, r in enumerate(results, 1):
                context += f"Example {i}:\n"
                context += f"Question: {r['question']}\n"
                context += f"SQL: {r['sql']}\n\n"
                examples.append(r)

            return context, examples
        except Exception as e:
            print(f"RAG error: {e}")
            return "", []

    def retrieve_context_formatted(self, question, top_k=3):
        """Use SQLRetriever's built-in context formatting ("" on failure)."""
        if not self.components['rag']:
            return ""

        try:
            return self.components['rag'].retrieve_as_context(question, top_k=top_k)
        except Exception as e:
            print(f"RAG error: {e}")
            return ""

    # -------------------------------------------------------------------------
    # STEP 2: Build Prompt
    # -------------------------------------------------------------------------
    def build_prompt(self, question, rag_context):
        """Build the model prompt, falling back to a minimal template.

        The PromptBuilder result is used only when it reports success;
        otherwise a simple hand-rolled prompt is returned.
        """
        if self.components['prompt_builder']:
            result = self.components['prompt_builder'].build_prompt(
                question=question,
                rag_context=rag_context
            )
            if result['success']:
                return result['prompt']

        # Fallback: simple prompt
        if rag_context:
            return f"{rag_context}\nQuestion: {question}\n\nSQL:"
        return f"Generate SQL for: {question}\n\nSQL:"

    # -------------------------------------------------------------------------
    # STEP 3: Fine-tuned Model Generation
    # -------------------------------------------------------------------------
    def generate_with_finetuned(self, question, context=""):
        """Generate SQL using the fine-tuned model.

        Returns:
            tuple: (sql, None) on success or (None, error_message) on failure.
        """
        if not self.components['finetuned_model']:
            return None, "Fine-tuned model not available"

        try:
            sql = self.components['finetuned_model'].generate(question, context)
            return sql, None
        except Exception as e:
            return None, str(e)

    # -------------------------------------------------------------------------
    # STEP 4: Gemini Enhancement
    # -------------------------------------------------------------------------
    def enhance_with_gemini(self, question, sql):
        """Ask Gemini to refine the SQL.

        Returns:
            tuple: (sql, info) where info['enhanced'] says whether the
            query was actually sent through Gemini; on any failure the
            original SQL is returned untouched.
        """
        if not self.components['gemini']:
            return sql, {"enhanced": False, "reason": "Gemini not available"}

        try:
            prompt = GEMINI_REFINE_PROMPT.format(question=question, sql=sql)
            enhanced_sql, error = self.components['gemini'].generate(prompt)

            if error:
                return sql, {"enhanced": False, "reason": error}

            # Clean up response
            enhanced_sql = self._clean_sql(enhanced_sql)

            return enhanced_sql, {"enhanced": True, "original": sql}
        except Exception as e:
            return sql, {"enhanced": False, "reason": str(e)}

    def validate_with_gemini(self, question, sql):
        """Ask Gemini for a structured validation verdict.

        Returns:
            dict: parsed JSON verdict, or a neutral
            {"is_valid": True, "confidence": 0.5} default when Gemini is
            unavailable or the response cannot be parsed.
        """
        if not self.components['gemini']:
            return {"is_valid": True, "confidence": 0.5}

        try:
            prompt = GEMINI_VALIDATE_PROMPT.format(question=question, sql=sql)
            text, error = self.components['gemini'].generate(prompt)

            if error:
                return {"is_valid": True, "confidence": 0.5, "error": error}

            # Remove markdown code blocks if present
            if text.startswith("```"):
                text = text.split("```")[1]
                if text.startswith("json"):
                    text = text[4:]

            return json.loads(text)
        except Exception:
            # FIX: narrowed from a bare `except:` so KeyboardInterrupt /
            # SystemExit are no longer swallowed; parse failures still fall
            # back to the neutral verdict.
            return {"is_valid": True, "confidence": 0.5}

    def explain_with_gemini(self, sql):
        """Ask Gemini for a short plain-English explanation of the SQL."""
        if not self.components['gemini']:
            return "Explanation not available"

        try:
            prompt = GEMINI_EXPLAIN_PROMPT.format(sql=sql)
            explanation, error = self.components['gemini'].generate(prompt)

            if error:
                return f"Explanation error: {error}"

            return explanation
        except Exception as e:
            return f"Explanation error: {e}"

    # -------------------------------------------------------------------------
    # MAIN PIPELINE
    # -------------------------------------------------------------------------
    def run(self, question, enhance=True, validate=False, explain=False, top_k=3):
        """
        Run the complete pipeline.

        Args:
            question: Natural language question
            enhance: Use Gemini to enhance SQL
            validate: Use Gemini to validate SQL
            explain: Use Gemini to explain SQL
            top_k: Number of RAG examples to retrieve

        Returns:
            dict with per-step details under 'steps', plus 'final_sql' and
            'success'. When the fine-tuned model produces nothing, the run
            aborts early with success=False and final_sql=None.
        """
        result = {
            'question': question,
            'timestamp': datetime.now().isoformat(),
            'steps': {}
        }

        # Step 1: RAG Retrieval
        rag_context, examples = self.retrieve_context(question, top_k=top_k)
        result['steps']['rag'] = {
            'context': rag_context,
            'examples': examples,
            'num_examples': len(examples)
        }

        # Step 2: Build Prompt
        prompt = self.build_prompt(question, rag_context)
        result['steps']['prompt'] = {
            'prompt': prompt,
            'length': len(prompt)
        }

        # Step 3: Fine-tuned Model
        finetuned_sql, error = self.generate_with_finetuned(question, rag_context)
        result['steps']['finetuned'] = {
            'sql': finetuned_sql,
            'error': error
        }

        if not finetuned_sql:
            result['success'] = False
            result['final_sql'] = None
            return result

        # Step 4: Gemini Enhancement
        if enhance:
            enhanced_sql, enhance_info = self.enhance_with_gemini(question, finetuned_sql)
            result['steps']['gemini_enhance'] = {
                'sql': enhanced_sql,
                'info': enhance_info
            }
            result['final_sql'] = enhanced_sql
        else:
            result['final_sql'] = finetuned_sql

        # Optional: Validation
        if validate:
            validation = self.validate_with_gemini(question, result['final_sql'])
            result['steps']['validation'] = validation

        # Optional: Explanation
        if explain:
            explanation = self.explain_with_gemini(result['final_sql'])
            result['explanation'] = explanation

        result['success'] = True

        # Log result (must not mutate `result` — the caller receives it next)
        self._log_result(result)

        return result

    # -------------------------------------------------------------------------
    # UTILITIES
    # -------------------------------------------------------------------------
    def _clean_sql(self, sql):
        """Strip markdown fences and a stray leading 'sql' tag from LLM output."""
        sql = sql.strip()
        # Remove markdown code blocks
        if sql.startswith("```"):
            lines = sql.split("\n")
            sql = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
        # Remove leading 'sql' keyword
        if sql.lower().startswith("sql"):
            sql = sql[3:].strip()
        return sql

    def _log_result(self, result):
        """Append a compact record of `result` to the pipeline JSONL log.

        FIX: the previous implementation took only a shallow copy of
        `result`, so replacing steps['rag'] mutated the caller's dict and
        silently dropped the RAG context/examples from the returned result.
        The nested 'steps' dict is now copied before being trimmed.
        """
        log_file = f"{LOGS_DIR}/pipeline_log.jsonl"
        log_result = dict(result)
        steps = log_result.get('steps')
        if steps is not None and 'rag' in steps:
            # Copy the nested dict so the caller's result stays intact;
            # drop the bulky context/examples from the log entry.
            steps = dict(steps)
            steps['rag'] = {
                'num_examples': steps['rag'].get('num_examples', 0)
            }
            log_result['steps'] = steps
        with open(log_file, 'a') as f:
            f.write(json.dumps(log_result, default=str) + '\n')

    def get_component_status(self):
        """Return {component_name: loaded?} for all four components."""
        return {
            'rag': self.components['rag'] is not None,
            'prompt_builder': self.components['prompt_builder'] is not None,
            'finetuned_model': self.components['finetuned_model'] is not None,
            'gemini': self.components['gemini'] is not None
        }
|
| 499 |
+
|
| 500 |
+
# =============================================================================
|
| 501 |
+
# SIMPLE INTERFACE
|
| 502 |
+
# =============================================================================
|
| 503 |
+
|
| 504 |
+
# Module-level singleton holding the lazily-created pipeline instance.
_pipeline = None

def get_pipeline():
    """Return the shared pipeline instance, creating it on first use."""
    global _pipeline
    if _pipeline is not None:
        return _pipeline
    _pipeline = IntegratedPipeline()
    return _pipeline

def generate_sql(question, enhance=True, explain=False):
    """Convenience wrapper: run the pipeline and return just the final SQL.

    Returns None when the pipeline could not produce a query.
    """
    outcome = get_pipeline().run(question, enhance=enhance, explain=explain)
    return outcome['final_sql'] if outcome['success'] else None
|
| 521 |
+
|
| 522 |
+
# =============================================================================
|
| 523 |
+
# TEST
|
| 524 |
+
# =============================================================================
|
| 525 |
+
|
| 526 |
+
def test_pipeline():
    """Smoke-test the integrated pipeline end to end.

    Builds a pipeline, prints which components loaded, runs one sample
    question through every stage (RAG + enhance + explain), and prints the
    intermediate and final outputs.
    """

    print("=" * 60)
    print("TESTING INTEGRATED PIPELINE")
    print("=" * 60)

    pipeline = IntegratedPipeline()

    # Show component status
    print("\nComponent Status:")
    status = pipeline.get_component_status()
    for comp, loaded in status.items():
        icon = "✓" if loaded else "✗"
        print(f" {icon} {comp}")

    # Sample questions to push through the full pipeline.
    questions = [
        "Find all employees with salary above 50000",
    ]

    for q in questions:
        print(f"\n{'='*60}")
        print(f"Question: {q}")
        print("-" * 60)

        result = pipeline.run(q, enhance=True, explain=True, top_k=3)

        # Show RAG results
        print(f"\n[RAG] Retrieved {result['steps']['rag']['num_examples']} examples")

        # Show fine-tuned output
        print(f"\n[Fine-tuned Model]")
        if result['steps']['finetuned']['sql']:
            print(f" SQL: {result['steps']['finetuned']['sql']}")
        else:
            print(f" Error: {result['steps']['finetuned']['error']}")

        # Show Gemini enhancement (only present when enhance=True and the
        # fine-tuned step produced SQL).
        if 'gemini_enhance' in result['steps']:
            print(f"\n[Gemini Enhanced]")
            print(f" SQL: {result['steps']['gemini_enhance']['sql']}")
            if result['steps']['finetuned']['sql'] != result['steps']['gemini_enhance']['sql']:
                print(f" ✨ Query was improved!")

        # Show final
        print(f"\n[Final SQL]")
        print(f" {result['final_sql']}")

        # Show explanation
        if 'explanation' in result:
            print(f"\n[Explanation]")
            print(f" {result['explanation']}")

    print("\n" + "=" * 60)
    print("✓ Pipeline test complete")
    print("=" * 60)
|
| 582 |
+
|
| 583 |
+
# Run the end-to-end smoke test when this module is executed directly.
if __name__ == "__main__":
    test_pipeline()
|
src/prompts/__init__.py
ADDED
|
File without changes
|
src/prompts/prompt_builder.py
ADDED
|
@@ -0,0 +1,440 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Prompt Builder for SQL Learning Assistant
|
| 3 |
+
Handles: Context Management, User Interaction Flows, Edge Cases
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import re
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
import json
|
| 10 |
+
from datetime import datetime
|
| 11 |
+
|
| 12 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 13 |
+
from prompts.system_prompts import (
|
| 14 |
+
get_system_prompt,
|
| 15 |
+
get_prompt_template,
|
| 16 |
+
CLARIFICATION_PROMPT,
|
| 17 |
+
ERROR_RECOVERY_PROMPT
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
# =============================================================================
# OUTPUT DIRECTORIES
# =============================================================================

OUTPUT_DIR = "outputs/prompts"
LOGS_DIR = f"{OUTPUT_DIR}/logs"

def setup_directories():
    """Ensure the prompt output and log directories exist (idempotent)."""
    for path in (OUTPUT_DIR, LOGS_DIR):
        os.makedirs(path, exist_ok=True)
|
| 30 |
+
|
| 31 |
+
# =============================================================================
|
| 32 |
+
# CONTEXT MANAGEMENT
|
| 33 |
+
# =============================================================================
|
| 34 |
+
|
| 35 |
+
class ConversationContext:
    """
    Manages conversation history and context for multi-turn interactions.

    Keeps a bounded list of (question, SQL) turns plus the current database
    schema, and renders both as text suitable for prompt injection.
    """

    def __init__(self, max_history=5):
        self.history = []
        self.max_history = max_history
        self.current_tables = []
        self.current_schema = {}
        self.user_preferences = {}

    def add_turn(self, question, sql_response, success=True):
        """Record one completed interaction, trimming old entries."""
        self.history.append({
            'question': question,
            'sql': sql_response,
            'success': success,
            'timestamp': datetime.now().isoformat()
        })
        # Bound memory: drop everything older than the newest max_history turns
        overflow = len(self.history) - self.max_history
        if overflow > 0:
            del self.history[:overflow]

    def get_history_context(self):
        """Render the last few turns as text for prompt injection."""
        if not self.history:
            return ""

        parts = ["Previous conversation:\n"]
        for turn in self.history[-3:]:  # only the 3 most recent turns
            parts.append(f"Q: {turn['question']}\n")
            parts.append(f"SQL: {turn['sql']}\n\n")
        return "".join(parts)

    def set_schema(self, schema_dict):
        """Set current database schema context (table -> column list)."""
        self.current_schema = schema_dict

    def get_schema_context(self):
        """Render the known schema as text for prompt injection."""
        if not self.current_schema:
            return ""

        lines = ["Available tables and columns:\n"]
        lines.extend(
            f"- {table}: {', '.join(columns)}\n"
            for table, columns in self.current_schema.items()
        )
        return "".join(lines)

    def clear(self):
        """Reset history, tables and schema (preferences are kept)."""
        self.history = []
        self.current_tables = []
        self.current_schema = {}
|
| 92 |
+
|
| 93 |
+
# =============================================================================
|
| 94 |
+
# QUERY ANALYSIS (For Specialized Flows)
|
| 95 |
+
# =============================================================================
|
| 96 |
+
|
| 97 |
+
def analyze_query_intent(question):
    """
    Analyze user question to determine query type and intent.

    Classification is keyword-driven, with later categories overriding
    earlier ones (aggregation < complex < modification); 'simple' is only
    assigned when nothing else matched.

    Args:
        question: Natural language question from the user.

    Returns:
        dict with:
            'query_type': one of 'simple', 'complex', 'aggregation',
                          'modification', 'general'
            'keywords': SQL keywords spotted in the question, upper-cased
            'question_length': word count of the question
    """
    question_lower = question.lower()

    def has_phrase(phrase):
        # Match on word boundaries so short patterns like 'in', 'min' or
        # 'sum' do not fire on substrings of unrelated words
        # (plain `'in' in "find"` is True, which mis-tagged most questions).
        return re.search(r'\b' + re.escape(phrase) + r'\b', question_lower) is not None

    # Detect query type
    query_type = 'general'

    # Aggregation patterns
    agg_patterns = ['count', 'sum', 'average', 'avg', 'total', 'maximum', 'max',
                    'minimum', 'min', 'how many', 'what is the total']
    if any(has_phrase(p) for p in agg_patterns):
        query_type = 'aggregation'

    # Complex query patterns
    complex_patterns = ['join', 'combine', 'merge', 'from multiple', 'across tables',
                        'subquery', 'nested', 'with the highest', 'with the lowest']
    if any(has_phrase(p) for p in complex_patterns):
        query_type = 'complex'

    # Modification patterns
    mod_patterns = ['insert', 'add new', 'update', 'change', 'modify', 'delete', 'remove']
    if any(has_phrase(p) for p in mod_patterns):
        query_type = 'modification'

    # Simple patterns (if nothing else matched)
    simple_patterns = ['show', 'list', 'get', 'find', 'select', 'display']
    if query_type == 'general' and any(has_phrase(p) for p in simple_patterns):
        query_type = 'simple'

    # Extract SQL keywords actually mentioned in the question
    keywords = []
    sql_keywords = ['where', 'group by', 'order by', 'having', 'limit', 'join',
                    'distinct', 'between', 'like', 'in']
    for kw in sql_keywords:
        if has_phrase(kw):
            keywords.append(kw.upper())

    return {
        'query_type': query_type,
        'keywords': keywords,
        'question_length': len(question.split())
    }
|
| 142 |
+
|
| 143 |
+
# =============================================================================
|
| 144 |
+
# EDGE CASE HANDLING
|
| 145 |
+
# =============================================================================
|
| 146 |
+
|
| 147 |
+
def detect_edge_cases(question):
    """
    Detect potential edge cases in user question.
    Returns: list of edge case types detected
    """
    found = []
    lowered = question.lower()

    # Empty or too short to be a meaningful request
    if len(question.strip()) < 5:
        found.append('too_short')

    # Too vague: generic filler words in a very short question
    vague_words = ['something', 'stuff', 'things', 'data', 'information']
    if len(question.split()) < 5 and any(w in lowered for w in vague_words):
        found.append('too_vague')

    # Several questions packed into one message
    if question.count('?') > 1:
        found.append('multiple_questions')

    # User pasted SQL instead of asking in natural language
    upper = question.upper()
    sql_tokens = ['SELECT', 'INSERT', 'UPDATE', 'DELETE', 'FROM', 'WHERE']
    if len([t for t in sql_tokens if t in upper]) >= 2:
        found.append('contains_sql')

    # Destructive operations need explicit confirmation
    for phrase in ('drop table', 'truncate', 'delete all', 'remove all'):
        if phrase in lowered:
            found.append('dangerous_operation')
            break

    # Off-topic / small-talk detection
    for phrase in ('weather', 'hello', 'how are you', 'thank', 'bye'):
        if phrase in lowered:
            found.append('not_sql_related')
            break

    return found
|
| 184 |
+
|
| 185 |
+
def handle_edge_case(edge_case_type, question):
    """
    Generate appropriate response for edge cases.
    Returns: (should_continue, message)
    """
    # Every recognized edge case blocks generation, so only the user-facing
    # message varies; unknown types fall through to (True, "").
    messages = {
        'too_short':
            "Your question is too short. Please provide more details about what data you want to retrieve.",
        'too_vague':
            "Your question is a bit vague. Could you specify:\n- Which table(s) to query?\n- What columns to retrieve?\n- Any conditions to filter by?",
        'multiple_questions':
            "I detected multiple questions. Please ask one question at a time for accurate SQL generation.",
        'contains_sql':
            "It looks like you've pasted SQL code. Please describe what you want in natural language, and I'll generate the SQL for you.",
        'dangerous_operation':
            "⚠️ This appears to be a destructive operation (DROP/TRUNCATE/DELETE ALL). Please confirm you want to proceed or rephrase your question.",
        'not_sql_related':
            "I'm an SQL assistant. Please ask me questions about querying databases, and I'll help generate SQL queries."
    }

    message = messages.get(edge_case_type)
    if message is None:
        return (True, "")
    return (False, message)
|
| 218 |
+
|
| 219 |
+
# =============================================================================
|
| 220 |
+
# PROMPT BUILDER CLASS
|
| 221 |
+
# =============================================================================
|
| 222 |
+
|
| 223 |
+
class PromptBuilder:
    """
    Main class for building prompts with context management.

    Orchestrates the module's helpers: edge-case screening
    (detect_edge_cases/handle_edge_case), intent analysis
    (analyze_query_intent), system-prompt/template selection from
    prompts.system_prompts, and the ConversationContext for schema and
    multi-turn history. Every built prompt is appended to a JSONL log
    under LOGS_DIR.
    """

    def __init__(self):
        # Per-builder conversation state (history + schema).
        self.context = ConversationContext()
        self.log_file = None
        # Make sure the log directory exists before _log_prompt writes to it.
        setup_directories()

    def build_prompt(self, question, rag_context="", include_history=True):
        """
        Build complete prompt for SQL generation.

        Args:
            question: User's natural language question
            rag_context: Retrieved examples from RAG
            include_history: Whether to include conversation history

        Returns:
            dict with 'success', 'prompt' or 'error'
        """
        # Check for edge cases — only the first detected case is reported.
        edge_cases = detect_edge_cases(question)

        if edge_cases:
            should_continue, message = handle_edge_case(edge_cases[0], question)
            if not should_continue:
                return {
                    'success': False,
                    'error': message,
                    'edge_case': edge_cases[0]
                }

        # Analyze query intent
        intent = analyze_query_intent(question)

        # Get appropriate system prompt
        system_prompt = get_system_prompt(intent['query_type'])

        # Build context parts
        context_parts = []

        # Add schema context if available
        schema_context = self.context.get_schema_context()
        if schema_context:
            context_parts.append(schema_context)

        # Add conversation history
        if include_history:
            history_context = self.context.get_history_context()
            if history_context:
                context_parts.append(history_context)

        # Add RAG context
        # NOTE: rag_context also appears inside the 'rag' template below,
        # so retrieved examples are included twice in the final prompt.
        if rag_context:
            context_parts.append(rag_context)

        # Build final prompt
        if rag_context:
            template = get_prompt_template('rag')
            prompt = template.format(
                context=rag_context,
                question=question
            )
        else:
            template = get_prompt_template('zero_shot')
            prompt = template.format(question=question)

        # Combine everything: system prompt, then context blocks, then task
        full_prompt = f"{system_prompt}\n\n"
        if context_parts:
            full_prompt += "\n".join(context_parts) + "\n\n"
        full_prompt += prompt

        # Log the prompt
        self._log_prompt(question, intent, full_prompt)

        return {
            'success': True,
            'prompt': full_prompt,
            'system_prompt': system_prompt,
            'query_type': intent['query_type'],
            'keywords': intent['keywords']
        }

    def add_response(self, question, sql_response, success=True):
        """Add a completed interaction to history."""
        self.context.add_turn(question, sql_response, success)

    def set_schema(self, schema_dict):
        """Set database schema for context (table name -> column list)."""
        self.context.set_schema(schema_dict)

    def clear_context(self):
        """Clear all context."""
        self.context.clear()

    def _log_prompt(self, question, intent, prompt):
        """Log prompt for debugging/analysis (appends one JSON line)."""
        # Only the prompt length is logged, not the full prompt text.
        log_entry = {
            'timestamp': datetime.now().isoformat(),
            'question': question,
            'intent': intent,
            'prompt_length': len(prompt)
        }

        log_file = f"{LOGS_DIR}/prompt_log.jsonl"
        with open(log_file, 'a') as f:
            f.write(json.dumps(log_entry) + '\n')
|
| 333 |
+
|
| 334 |
+
# =============================================================================
|
| 335 |
+
# USER INTERACTION FLOWS
|
| 336 |
+
# =============================================================================
|
| 337 |
+
|
| 338 |
+
def get_clarification_questions(question, intent):
    """
    Generate clarification questions for ambiguous queries.
    """
    lowered = question.lower()
    asks = []

    # Query-type specific follow-ups
    if intent['query_type'] == 'aggregation':
        asks += ["Which column should be aggregated?",
                 "Should results be grouped by any column?"]

    if intent['query_type'] == 'complex':
        asks += ["Which tables need to be joined?",
                 "What is the relationship between the tables?"]

    # Missing specifics shared by all query types
    if 'table' not in lowered:
        asks.append("Which table(s) should be queried?")

    if not any(marker in lowered for marker in ('all', 'specific', 'where', 'filter')):
        asks.append("Do you want all records or filtered results?")

    return asks
|
| 361 |
+
|
| 362 |
+
def create_error_recovery_prompt(original_question, error_message):
    """
    Build a retry prompt after a failed generation attempt.
    """
    recovery = ERROR_RECOVERY_PROMPT.format(
        question=original_question,
        error=error_message,
    )
    return recovery
|
| 370 |
+
|
| 371 |
+
# =============================================================================
|
| 372 |
+
# TEST
|
| 373 |
+
# =============================================================================
|
| 374 |
+
|
| 375 |
+
def test_prompt_builder():
    """Test the prompt builder functionality.

    Smoke-tests the main flows: a normal question with RAG context, the
    edge-case rejections, query-type classification, and schema/history
    context injection. Results are printed, not asserted.
    """

    print("=" * 60)
    print("TESTING PROMPT BUILDER")
    print("=" * 60)

    builder = PromptBuilder()

    # Test 1: Normal question — should succeed and use the RAG template.
    print("\n[TEST 1] Normal Question")
    print("-" * 40)
    result = builder.build_prompt(
        "Find all employees with salary above 50000",
        rag_context="Example 1:\nQ: Get workers earning more than 40000\nSQL: SELECT * FROM employees WHERE salary > 40000"
    )
    print(f"Success: {result['success']}")
    print(f"Query Type: {result.get('query_type')}")
    print(f"Prompt Length: {len(result.get('prompt', ''))}")

    # Test 2: Edge case - too short (under 5 characters)
    print("\n[TEST 2] Edge Case - Too Short")
    print("-" * 40)
    result = builder.build_prompt("SQL")
    print(f"Success: {result['success']}")
    print(f"Error: {result.get('error', 'None')}")

    # Test 3: Edge case - contains SQL (user pasted a query)
    print("\n[TEST 3] Edge Case - Contains SQL")
    print("-" * 40)
    result = builder.build_prompt("SELECT * FROM users WHERE id = 1")
    print(f"Success: {result['success']}")
    print(f"Error: {result.get('error', 'None')}")

    # Test 4: Edge case - dangerous operation (destructive DDL)
    print("\n[TEST 4] Edge Case - Dangerous Operation")
    print("-" * 40)
    result = builder.build_prompt("Drop table users")
    print(f"Success: {result['success']}")
    print(f"Error: {result.get('error', 'None')}")

    # Test 5: Aggregation query — should classify as 'aggregation'
    print("\n[TEST 5] Aggregation Query")
    print("-" * 40)
    result = builder.build_prompt("Count total orders by customer")
    print(f"Success: {result['success']}")
    print(f"Query Type: {result.get('query_type')}")

    # Test 6: Context management — schema + history should appear in prompt
    print("\n[TEST 6] Context Management")
    print("-" * 40)
    builder.set_schema({
        'employees': ['id', 'name', 'salary', 'dept_id'],
        'departments': ['id', 'name', 'location']
    })
    builder.add_response("Show all employees", "SELECT * FROM employees", success=True)
    result = builder.build_prompt("Now filter by department")
    print(f"Success: {result['success']}")
    print(f"Has History: {'Previous conversation' in result.get('prompt', '')}")

    print("\n" + "=" * 60)
    print("✓ All tests complete")
    print("=" * 60)

if __name__ == "__main__":
    test_prompt_builder()
|
src/prompts/system_prompts.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
System Prompts for SQL Learning Assistant
|
| 3 |
+
Systematic prompting strategies for different use cases.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
# =============================================================================
|
| 7 |
+
# BASE SYSTEM PROMPT
|
| 8 |
+
# =============================================================================
|
| 9 |
+
|
| 10 |
+
BASE_SYSTEM_PROMPT = """You are an expert SQL assistant. Your task is to generate accurate SQL queries based on natural language questions.
|
| 11 |
+
|
| 12 |
+
Rules:
|
| 13 |
+
1. Generate ONLY the SQL query, no explanations unless asked
|
| 14 |
+
2. Use standard SQL syntax
|
| 15 |
+
3. Be precise and efficient in your queries
|
| 16 |
+
4. If the question is ambiguous, make reasonable assumptions
|
| 17 |
+
5. Always use proper SQL formatting
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
# =============================================================================
|
| 21 |
+
# SPECIALIZED PROMPTS BY USE CASE
|
| 22 |
+
# =============================================================================
|
| 23 |
+
|
| 24 |
+
# For simple SELECT queries
|
| 25 |
+
SIMPLE_QUERY_PROMPT = """You are an SQL assistant specializing in simple queries.
|
| 26 |
+
|
| 27 |
+
Your task: Convert the natural language question into a basic SQL SELECT query.
|
| 28 |
+
|
| 29 |
+
Guidelines:
|
| 30 |
+
- Use simple SELECT, FROM, WHERE clauses
|
| 31 |
+
- Avoid complex joins unless necessary
|
| 32 |
+
- Keep queries straightforward and readable
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
# For complex queries with JOINs
|
| 36 |
+
COMPLEX_QUERY_PROMPT = """You are an SQL assistant specializing in complex queries.
|
| 37 |
+
|
| 38 |
+
Your task: Convert the natural language question into an SQL query that may involve:
|
| 39 |
+
- Multiple JOINs (INNER, LEFT, RIGHT)
|
| 40 |
+
- Subqueries
|
| 41 |
+
- Multiple conditions
|
| 42 |
+
- Aggregations with GROUP BY
|
| 43 |
+
|
| 44 |
+
Guidelines:
|
| 45 |
+
- Use appropriate JOIN types
|
| 46 |
+
- Structure subqueries clearly
|
| 47 |
+
- Use aliases for readability
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
# For aggregation queries
|
| 51 |
+
AGGREGATION_PROMPT = """You are an SQL assistant specializing in aggregation queries.
|
| 52 |
+
|
| 53 |
+
Your task: Convert the natural language question into an SQL query using aggregate functions.
|
| 54 |
+
|
| 55 |
+
Guidelines:
|
| 56 |
+
- Use COUNT, SUM, AVG, MAX, MIN appropriately
|
| 57 |
+
- Include GROUP BY when aggregating
|
| 58 |
+
- Use HAVING for aggregate conditions
|
| 59 |
+
- Consider ORDER BY for ranked results
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
# For data modification (if needed)
|
| 63 |
+
MODIFICATION_PROMPT = """You are an SQL assistant for data modification queries.
|
| 64 |
+
|
| 65 |
+
Your task: Convert the natural language request into INSERT, UPDATE, or DELETE statements.
|
| 66 |
+
|
| 67 |
+
Guidelines:
|
| 68 |
+
- Be cautious with DELETE and UPDATE
|
| 69 |
+
- Always include WHERE clause for UPDATE/DELETE
|
| 70 |
+
- Validate data types for INSERT
|
| 71 |
+
"""
|
| 72 |
+
|
| 73 |
+
# =============================================================================
|
| 74 |
+
# PROMPT TEMPLATES WITH CONTEXT
|
| 75 |
+
# =============================================================================
|
| 76 |
+
|
| 77 |
+
RAG_CONTEXT_TEMPLATE = """You are an expert SQL assistant.
|
| 78 |
+
|
| 79 |
+
Here are similar examples to help you:
|
| 80 |
+
|
| 81 |
+
{context}
|
| 82 |
+
|
| 83 |
+
Based on these examples, generate the SQL query for:
|
| 84 |
+
Question: {question}
|
| 85 |
+
|
| 86 |
+
SQL:"""
|
| 87 |
+
|
| 88 |
+
FEW_SHOT_TEMPLATE = """You are an expert SQL assistant. Learn from these examples:
|
| 89 |
+
|
| 90 |
+
{examples}
|
| 91 |
+
|
| 92 |
+
Now generate SQL for this question:
|
| 93 |
+
Question: {question}
|
| 94 |
+
|
| 95 |
+
SQL:"""
|
| 96 |
+
|
| 97 |
+
ZERO_SHOT_TEMPLATE = """You are an expert SQL assistant.
|
| 98 |
+
|
| 99 |
+
Generate the SQL query for:
|
| 100 |
+
Question: {question}
|
| 101 |
+
|
| 102 |
+
SQL:"""
|
| 103 |
+
|
| 104 |
+
# =============================================================================
|
| 105 |
+
# ERROR HANDLING PROMPTS
|
| 106 |
+
# =============================================================================
|
| 107 |
+
|
| 108 |
+
CLARIFICATION_PROMPT = """I need more information to generate the SQL query.
|
| 109 |
+
|
| 110 |
+
Original question: {question}
|
| 111 |
+
|
| 112 |
+
Please clarify:
|
| 113 |
+
{clarification_points}
|
| 114 |
+
"""
|
| 115 |
+
|
| 116 |
+
ERROR_RECOVERY_PROMPT = """I encountered an issue with the previous query.
|
| 117 |
+
|
| 118 |
+
Error: {error}
|
| 119 |
+
Original question: {question}
|
| 120 |
+
|
| 121 |
+
Let me try a different approach:
|
| 122 |
+
"""
|
| 123 |
+
|
| 124 |
+
# =============================================================================
|
| 125 |
+
# PROMPT SELECTOR
|
| 126 |
+
# =============================================================================
|
| 127 |
+
|
| 128 |
+
def get_system_prompt(query_type='general'):
    """
    Get appropriate system prompt based on query type.

    Args:
        query_type: 'simple', 'complex', 'aggregation', 'modification', 'general'

    Returns:
        System prompt string (BASE_SYSTEM_PROMPT for 'general' or any
        unrecognized type)
    """
    if query_type == 'simple':
        return SIMPLE_QUERY_PROMPT
    if query_type == 'complex':
        return COMPLEX_QUERY_PROMPT
    if query_type == 'aggregation':
        return AGGREGATION_PROMPT
    if query_type == 'modification':
        return MODIFICATION_PROMPT
    return BASE_SYSTEM_PROMPT
|
| 146 |
+
|
| 147 |
+
def get_prompt_template(template_type='rag'):
    """
    Get prompt template by type.

    Args:
        template_type: 'rag', 'few_shot', 'zero_shot'

    Returns:
        Template string (RAG_CONTEXT_TEMPLATE for any unrecognized type)
    """
    if template_type == 'few_shot':
        return FEW_SHOT_TEMPLATE
    if template_type == 'zero_shot':
        return ZERO_SHOT_TEMPLATE
    # 'rag' and anything unknown fall back to the RAG template
    return RAG_CONTEXT_TEMPLATE
|
src/rag/__init__.py
ADDED
|
File without changes
|
src/rag/embeddings.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Embedding Module for RAG System
|
| 3 |
+
Uses FREE sentence-transformers (no API costs).
|
| 4 |
+
Gemini is ONLY used for final SQL generation.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from sentence_transformers import SentenceTransformer
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
# =============================================================================
|
| 11 |
+
# FREE LOCAL EMBEDDING MODEL
|
| 12 |
+
# =============================================================================
|
| 13 |
+
|
| 14 |
+
# Using all-MiniLM-L6-v2: fast, good quality, 384 dimensions
|
| 15 |
+
MODEL_NAME = "all-MiniLM-L6-v2"
|
| 16 |
+
|
| 17 |
+
# Global model instance (loaded once)
|
| 18 |
+
_model = None
|
| 19 |
+
|
| 20 |
+
def get_model():
    """Get or load the embedding model.

    Lazily instantiates a module-level singleton so the SentenceTransformer
    weights are downloaded/loaded at most once per process.
    """
    global _model
    if _model is None:
        print(f" Loading embedding model: {MODEL_NAME}")
        _model = SentenceTransformer(MODEL_NAME)
    return _model
|
| 27 |
+
|
| 28 |
+
# =============================================================================
|
| 29 |
+
# EMBEDDING FUNCTIONS
|
| 30 |
+
# =============================================================================
|
| 31 |
+
|
| 32 |
+
def get_embedding(text):
    """Get embedding for a single text.

    Returns a plain list of floats, or None if embedding fails.
    """
    try:
        vector = get_model().encode(text, convert_to_numpy=True)
        return vector.tolist()
    except Exception as e:
        print(f"Error getting embedding: {e}")
        return None
|
| 41 |
+
|
| 42 |
+
def get_embeddings_batch(texts):
    """Get embeddings for multiple texts at once (efficient).

    Returns one list-of-floats per input text; on failure returns a list
    of None placeholders matching the input length.
    """
    try:
        matrix = get_model().encode(texts, convert_to_numpy=True, show_progress_bar=False)
        return [row.tolist() for row in matrix]
    except Exception as e:
        print(f"Error in batch embedding: {e}")
        return [None] * len(texts)
|
| 51 |
+
|
| 52 |
+
# =============================================================================
|
| 53 |
+
# TEST
|
| 54 |
+
# =============================================================================
|
| 55 |
+
|
| 56 |
+
def test_embedding():
    """Test embedding functionality.

    Exercises single and batch embedding with the local model and prints
    the resulting dimensions. Always returns True when it runs to the end
    (failures surface via the printed error messages from the helpers).
    """
    print("=" * 50)
    print("TESTING EMBEDDINGS (FREE - No API)")
    print("=" * 50)

    test_texts = [
        "Find all employees with salary greater than 50000",
        "Show customers who ordered last month",
        "Count products by category"
    ]

    print(f"\nModel: {MODEL_NAME}")
    print(f"Testing with {len(test_texts)} texts...\n")

    # Single embedding — emb is None if the model failed to load/encode
    emb = get_embedding(test_texts[0])
    if emb:
        print(f"✓ Single embedding works")
        print(f" Dimension: {len(emb)}")

    # Batch embedding
    embs = get_embeddings_batch(test_texts)
    if embs and embs[0]:
        print(f"✓ Batch embedding works")
        print(f" Got {len(embs)} embeddings")

    print("\n✓ All tests passed (FREE - No Gemini used)")
    return True

if __name__ == "__main__":
    test_embedding()
|
src/rag/knowledge_base.py
ADDED
|
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Knowledge Base Builder for RAG System
|
| 3 |
+
Includes: Chunking Strategies, Vector Storage
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import chromadb
|
| 9 |
+
import json
|
| 10 |
+
import re
|
| 11 |
+
from datetime import datetime
|
| 12 |
+
import sys
|
| 13 |
+
|
| 14 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 15 |
+
from rag.embeddings import get_embeddings_batch
|
| 16 |
+
|
| 17 |
+
# =============================================================================
|
| 18 |
+
# CONFIGURATION
|
| 19 |
+
# =============================================================================
|
| 20 |
+
|
| 21 |
+
CHROMA_DIR = "chromadb_data"
|
| 22 |
+
COLLECTION_NAME = "sql_knowledge"
|
| 23 |
+
OUTPUT_DIR = "outputs/rag"
|
| 24 |
+
STATS_DIR = f"{OUTPUT_DIR}/stats"
|
| 25 |
+
REPORT_DIR = f"{OUTPUT_DIR}/reports"
|
| 26 |
+
|
| 27 |
+
def setup_directories():
    """Create the storage/output folders; existing ones are left alone."""
    for directory in (CHROMA_DIR, OUTPUT_DIR, STATS_DIR, REPORT_DIR):
        os.makedirs(directory, exist_ok=True)
|
| 31 |
+
|
| 32 |
+
# =============================================================================
|
| 33 |
+
# CHUNKING STRATEGIES
|
| 34 |
+
# =============================================================================
|
| 35 |
+
|
| 36 |
+
def chunk_by_sql_clauses(sql):
    """
    Chunking Strategy 1: detect which top-level clauses a SQL string has.

    Returns clause labels ('SELECT', 'FROM', 'WHERE', ...) in canonical
    clause order, one label per clause found.
    """
    # Each pattern spans a clause up to the start of the next possible one.
    clause_patterns = (
        ('SELECT', r'\bSELECT\b.*?(?=\bFROM\b|$)'),
        ('FROM', r'\bFROM\b.*?(?=\bWHERE\b|\bGROUP\b|\bORDER\b|\bLIMIT\b|$)'),
        ('WHERE', r'\bWHERE\b.*?(?=\bGROUP\b|\bORDER\b|\bLIMIT\b|$)'),
        ('GROUP BY', r'\bGROUP BY\b.*?(?=\bHAVING\b|\bORDER\b|\bLIMIT\b|$)'),
        ('HAVING', r'\bHAVING\b.*?(?=\bORDER\b|\bLIMIT\b|$)'),
        ('ORDER BY', r'\bORDER BY\b.*?(?=\bLIMIT\b|$)'),
        ('LIMIT', r'\bLIMIT\b.*'),
    )

    normalized = sql.upper()
    return [
        label
        for label, pattern in clause_patterns
        if re.search(pattern, normalized, re.IGNORECASE | re.DOTALL)
    ]
|
| 61 |
+
|
| 62 |
+
def chunk_by_complexity(question, sql):
    """
    Chunking Strategy 2: score a query and bucket it by complexity.

    Parameters:
        question: Natural-language question (currently unused; kept so the
            call signature stays compatible with existing callers).
        sql: The SQL query to score.

    Returns:
        'simple', 'intermediate', or 'complex'.
    """
    sql_upper = sql.upper()

    # Heavier structural features contribute more points.
    score = 0
    if 'JOIN' in sql_upper:
        score += 2
    # A nested SELECT indicates a subquery. (The previous version also
    # tested for a literal 'SUBQUERY' substring, which is not a SQL
    # keyword and never appears in real queries — removed.)
    if sql_upper.count('SELECT') > 1:
        score += 2
    if 'GROUP BY' in sql_upper:
        score += 1
    if 'HAVING' in sql_upper:
        score += 1
    if 'ORDER BY' in sql_upper:
        score += 1
    if any(agg in sql_upper for agg in ('COUNT', 'SUM', 'AVG', 'MAX', 'MIN')):
        score += 1
    if 'UNION' in sql_upper:
        score += 2

    # Bucket the total score.
    if score <= 1:
        return 'simple'
    if score <= 3:
        return 'intermediate'
    return 'complex'
|
| 94 |
+
|
| 95 |
+
def extract_sql_keywords(sql):
    """
    Chunking Strategy 3: list notable SQL keywords for retrieval metadata.
    """
    upper = sql.upper()
    found = []

    # Statement types.
    found.extend(op for op in ('SELECT', 'INSERT', 'UPDATE', 'DELETE') if op in upper)

    # Most specific join flavour wins; plain JOIN is the fallback.
    for join_kind in ('INNER JOIN', 'LEFT JOIN', 'RIGHT JOIN', 'JOIN'):
        if join_kind in upper:
            found.append(join_kind)
            break

    # Clauses.
    found.extend(c for c in ('WHERE', 'GROUP BY', 'HAVING', 'ORDER BY', 'LIMIT') if c in upper)

    # Aggregate functions.
    found.extend(a for a in ('COUNT', 'SUM', 'AVG', 'MAX', 'MIN') if a in upper)

    # More than one SELECT implies a nested query.
    if upper.count('SELECT') > 1:
        found.append('SUBQUERY')

    return found
|
| 138 |
+
|
| 139 |
+
def calculate_chunk_size(text):
    """Bucket text by whitespace-separated word count: short/medium/long."""
    words = len(text.split())
    if words <= 10:
        return 'short'
    return 'medium' if words <= 25 else 'long'
|
| 149 |
+
|
| 150 |
+
# =============================================================================
|
| 151 |
+
# DOCUMENT PREPARATION WITH CHUNKING
|
| 152 |
+
# =============================================================================
|
| 153 |
+
|
| 154 |
+
def prepare_documents_with_chunking(datasets):
    """
    Turn {source_name: DataFrame} into parallel (documents, metadatas, ids)
    lists for ChromaDB, enriching each entry with chunking metadata.
    """
    documents, metadatas, ids = [], [], []

    doc_index = 0
    for source_name, frame in datasets.items():
        for _, record in frame.iterrows():
            question_text = str(record['question'])
            sql_text = str(record['sql'])

            # Derive metadata via the chunking strategies above.
            clause_list = chunk_by_sql_clauses(sql_text)
            keyword_list = extract_sql_keywords(sql_text)

            metadatas.append({
                'sql': sql_text,
                'source': source_name,
                'question': question_text,
                # Chunking metadata used later for filtering / re-ranking.
                'complexity': chunk_by_complexity(question_text, sql_text),
                'sql_clauses': ','.join(clause_list),
                'keywords': ','.join(keyword_list),
                'question_size': calculate_chunk_size(question_text),
                'sql_size': calculate_chunk_size(sql_text),
                'keyword_count': len(keyword_list),
                'clause_count': len(clause_list),
            })
            documents.append(question_text)
            ids.append(f"doc_{doc_index}")
            doc_index += 1

    return documents, metadatas, ids
|
| 197 |
+
|
| 198 |
+
# =============================================================================
|
| 199 |
+
# CHROMADB CLIENT
|
| 200 |
+
# =============================================================================
|
| 201 |
+
|
| 202 |
+
def get_chroma_client():
    """Get ChromaDB persistent client."""
    # Persistent client keeps the index on disk under CHROMA_DIR so the
    # knowledge base survives process restarts.
    return chromadb.PersistentClient(path=CHROMA_DIR)
|
| 205 |
+
|
| 206 |
+
def get_or_create_collection(client):
    """Get or create the SQL knowledge collection."""
    # Idempotent: returns the existing collection if one with this name
    # already exists, otherwise creates it with the descriptive metadata.
    return client.get_or_create_collection(
        name=COLLECTION_NAME,
        metadata={"description": "SQL question-answer pairs with chunking metadata"}
    )
|
| 212 |
+
|
| 213 |
+
# =============================================================================
|
| 214 |
+
# DATA LOADING
|
| 215 |
+
# =============================================================================
|
| 216 |
+
|
| 217 |
+
def load_datasets(data_dir="data"):
    """Load every available CSV split from *data_dir* into DataFrames.

    Missing files are skipped (with a notice) rather than raising, so a
    partial data directory still yields a usable dict.
    """
    split_files = {
        'train': 'train.csv',
        'validation': 'validation.csv',
        'test': 'test.csv'
        # 'synthetic': 'synthetic.csv'
    }

    loaded = {}
    for split_name, csv_name in split_files.items():
        csv_path = os.path.join(data_dir, csv_name)
        if not os.path.exists(csv_path):
            print(f" Skipped {split_name}: file not found")
            continue
        frame = pd.read_csv(csv_path)
        loaded[split_name] = frame
        print(f" Loaded {split_name}: {len(frame):,} rows")

    return loaded
|
| 238 |
+
|
| 239 |
+
# =============================================================================
|
| 240 |
+
# KNOWLEDGE BASE BUILDING
|
| 241 |
+
# =============================================================================
|
| 242 |
+
|
| 243 |
+
def build_knowledge_base(data_dir="data", batch_size=500):
    """Build the ChromaDB knowledge base with chunking strategies.

    Loads the CSV datasets, enriches each row with chunking metadata,
    embeds the questions in batches and stores everything in a freshly
    recreated ChromaDB collection, then writes stats and a report.

    Parameters:
        data_dir: Folder containing train/validation/test CSV files.
        batch_size: Number of documents embedded/inserted per batch.

    Returns:
        The populated collection, or None when no datasets were found.
    """

    print("=" * 50)
    print("BUILDING RAG KNOWLEDGE BASE")
    print("With Chunking Strategies")
    print("=" * 50)

    setup_directories()

    # Step 1: Load data
    print(f"\n[1/5] Loading datasets...")
    datasets = load_datasets(data_dir)

    if not datasets:
        print("ERROR: No datasets found!")
        return None

    total_rows = sum(len(df) for df in datasets.values())
    print(f" Total rows: {total_rows:,}")

    # Step 2: Prepare documents with chunking
    print(f"\n[2/5] Applying chunking strategies...")
    documents, metadatas, ids = prepare_documents_with_chunking(datasets)
    print(f" Total documents: {len(documents):,}")

    # Show chunking stats
    complexities = [m['complexity'] for m in metadatas]
    print(f" Complexity distribution:")
    print(f" Simple: {complexities.count('simple'):,}")
    print(f" Intermediate: {complexities.count('intermediate'):,}")
    print(f" Complex: {complexities.count('complex'):,}")

    # Step 3: Initialize ChromaDB
    print(f"\n[3/5] Initializing ChromaDB...")
    client = get_chroma_client()

    # Rebuild from scratch: drop any stale collection first.
    # Narrowed from a bare `except:` so Ctrl-C / SystemExit still propagate.
    try:
        client.delete_collection(COLLECTION_NAME)
        print(" Deleted existing collection")
    except Exception:
        # Collection did not exist yet - nothing to delete.
        pass

    collection = get_or_create_collection(client)
    print(f" Collection: {COLLECTION_NAME}")

    # Step 4: Generate embeddings and store
    print(f"\n[4/5] Generating embeddings...")

    total_added = 0

    for i in range(0, len(documents), batch_size):
        batch_docs = documents[i:i + batch_size]
        batch_meta = metadatas[i:i + batch_size]
        batch_ids = ids[i:i + batch_size]

        embeddings = get_embeddings_batch(batch_docs)

        # Skip a batch whose embedding failed (placeholders are None).
        if embeddings and embeddings[0] is not None:
            collection.add(
                documents=batch_docs,
                metadatas=batch_meta,
                ids=batch_ids,
                embeddings=embeddings
            )
            total_added += len(batch_docs)

        progress = min(i + batch_size, len(documents))
        pct = (progress / len(documents)) * 100
        print(f" Progress: {progress:,}/{len(documents):,} ({pct:.1f}%)")

    # Step 5: Save statistics
    print(f"\n[5/5] Saving statistics...")
    stats = {
        'total_documents': total_added,
        'sources': {name: len(df) for name, df in datasets.items()},
        'collection_name': COLLECTION_NAME,
        'embedding_model': 'all-MiniLM-L6-v2',
        'chunking_strategies': [
            'sql_clause_extraction',
            'complexity_classification',
            'keyword_extraction',
            'size_categorization'
        ],
        'complexity_distribution': {
            'simple': complexities.count('simple'),
            'intermediate': complexities.count('intermediate'),
            'complex': complexities.count('complex')
        },
        'created_at': datetime.now().isoformat()
    }

    with open(f'{STATS_DIR}/knowledge_base_stats.json', 'w') as f:
        json.dump(stats, f, indent=2)

    generate_report(stats)

    print("\n" + "=" * 50)
    print("COMPLETE")
    print("=" * 50)
    print(f" Documents indexed: {total_added:,}")
    print(f" Storage: {CHROMA_DIR}/")

    return collection
|
| 347 |
+
|
| 348 |
+
# =============================================================================
|
| 349 |
+
# REPORT GENERATION
|
| 350 |
+
# =============================================================================
|
| 351 |
+
|
| 352 |
+
def generate_report(stats):
    """Generate knowledge base report.

    Renders the stats dict produced by build_knowledge_base() as a
    Markdown report and writes it to REPORT_DIR/knowledge_base_report.md.
    """

    # Header and fixed-shape tables come from one f-string template.
    report = f"""# RAG Knowledge Base Report

**Generated:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

## Overview

| Metric | Value |
|--------|-------|
| Total Documents | {stats['total_documents']:,} |
| Collection Name | {stats['collection_name']} |
| Embedding Model | {stats['embedding_model']} |

## Data Sources

| Source | Documents |
|--------|-----------|
"""

    # Variable-length table: one row per data source.
    for source, count in stats['sources'].items():
        report += f"| {source} | {count:,} |\n"

    report += f"""
## Chunking Strategies

1. **SQL Clause Extraction**: Identifies SELECT, FROM, WHERE, GROUP BY, etc.
2. **Complexity Classification**: Categorizes as simple/intermediate/complex
3. **Keyword Extraction**: Extracts SQL operations (JOIN, COUNT, etc.)
4. **Size Categorization**: Classifies question/SQL length

## Complexity Distribution

| Level | Count |
|-------|-------|
| Simple | {stats['complexity_distribution']['simple']:,} |
| Intermediate | {stats['complexity_distribution']['intermediate']:,} |
| Complex | {stats['complexity_distribution']['complex']:,} |

## Document Metadata Structure

Each document contains:
- `sql`: The SQL query
- `source`: Origin dataset
- `question`: Original question
- `complexity`: simple/intermediate/complex
- `sql_clauses`: Comma-separated clauses
- `keywords`: SQL keywords found
- `question_size`: short/medium/long
- `sql_size`: short/medium/long
"""

    with open(f'{REPORT_DIR}/knowledge_base_report.md', 'w') as f:
        f.write(report)

    print(f" Report saved to {REPORT_DIR}/")
|
| 409 |
+
|
| 410 |
+
# =============================================================================
|
| 411 |
+
# ENTRY POINT
|
| 412 |
+
# =============================================================================
|
| 413 |
+
|
| 414 |
+
# Build the full knowledge base when run as a script.
if __name__ == "__main__":
    build_knowledge_base(data_dir="data", batch_size=500)
|
src/rag/retriever.py
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Retriever Module for RAG System
|
| 3 |
+
Loads from: Local ChromaDB OR HuggingFace Hub
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
from dotenv import load_dotenv
|
| 9 |
+
|
| 10 |
+
load_dotenv()
|
| 11 |
+
|
| 12 |
+
# Try new imports first, fall back to old
|
| 13 |
+
try:
|
| 14 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
| 15 |
+
from langchain_chroma import Chroma
|
| 16 |
+
except ImportError:
|
| 17 |
+
from langchain_community.vectorstores import Chroma
|
| 18 |
+
from langchain_community.embeddings import HuggingFaceEmbeddings
|
| 19 |
+
|
| 20 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 21 |
+
|
| 22 |
+
# =============================================================================
|
| 23 |
+
# CONFIGURATION
|
| 24 |
+
# =============================================================================
|
| 25 |
+
|
| 26 |
+
LOCAL_CHROMADB_DIR = "chromadb_data"
|
| 27 |
+
HF_CHROMADB_ID = os.getenv("HF_CHROMADB_ID", None)
|
| 28 |
+
COLLECTION_NAME = "sql_knowledge"
|
| 29 |
+
EMBEDDING_MODEL = "all-MiniLM-L6-v2"
|
| 30 |
+
|
| 31 |
+
# =============================================================================
|
| 32 |
+
# CHROMADB LOADER
|
| 33 |
+
# =============================================================================
|
| 34 |
+
|
| 35 |
+
def ensure_chromadb_exists():
    """Ensure ChromaDB data exists - download from HF if needed.

    Resolution order:
      1. Use the local ChromaDB folder if it contains real index files.
      2. Otherwise download a snapshot from HuggingFace (HF_CHROMADB_ID).
      3. Otherwise build the knowledge base from the local CSV data.

    Returns the path of the ready-to-use ChromaDB directory.
    """

    # Check if local has actual ChromaDB files (not just empty folder)
    if os.path.exists(LOCAL_CHROMADB_DIR):
        local_files = os.listdir(LOCAL_CHROMADB_DIR) if os.path.isdir(LOCAL_CHROMADB_DIR) else []
        # ChromaDB creates files like chroma.sqlite3 or folders.
        # NOTE(review): the `len(local_files) > 2` fallback is a heuristic —
        # any 3+ unrelated files would also count as "has data"; confirm.
        has_chroma_files = any('chroma' in f.lower() or 'sqlite' in f.lower() for f in local_files) or len(local_files) > 2

        if has_chroma_files:
            print(f"📁 Using local ChromaDB: {LOCAL_CHROMADB_DIR}")
            return LOCAL_CHROMADB_DIR
        else:
            print(f"⚠️ ChromaDB folder exists but is empty or incomplete")

    # Download from HuggingFace
    if HF_CHROMADB_ID:
        print(f"☁️ Downloading ChromaDB from HuggingFace: {HF_CHROMADB_ID}")
        # Imported lazily so huggingface_hub is only required on this path.
        from huggingface_hub import snapshot_download

        # Create folder if not exists
        os.makedirs(LOCAL_CHROMADB_DIR, exist_ok=True)

        snapshot_download(
            repo_id=HF_CHROMADB_ID,
            repo_type="dataset",
            local_dir=LOCAL_CHROMADB_DIR
        )
        print("✓ ChromaDB downloaded!")
        return LOCAL_CHROMADB_DIR

    # Need to build it from data (slow path; requires the CSVs locally).
    print("⚠️ ChromaDB not found and no HF_CHROMADB_ID set. Building from data...")
    from rag.knowledge_base import build_knowledge_base
    build_knowledge_base(data_dir="data", batch_size=500)
    return LOCAL_CHROMADB_DIR
|
| 71 |
+
|
| 72 |
+
# =============================================================================
|
| 73 |
+
# LANGCHAIN EMBEDDINGS
|
| 74 |
+
# =============================================================================
|
| 75 |
+
|
| 76 |
+
def get_embeddings():
    """Get HuggingFace embeddings for LangChain."""
    # Must match the model used when the knowledge base was built
    # (EMBEDDING_MODEL = all-MiniLM-L6-v2); CPU device keeps the app
    # deployable without a GPU.
    return HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL,
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': True}
    )
|
| 83 |
+
|
| 84 |
+
# =============================================================================
|
| 85 |
+
# RANKING FUNCTIONS
|
| 86 |
+
# =============================================================================
|
| 87 |
+
|
| 88 |
+
def calculate_relevance_score(result, query):
    """Combine the raw similarity score with lexical/complexity boosts."""
    score = result.get('score', 0.5)

    # Lexical overlap boost: +0.05 per shared word, capped at 5 words.
    shared = set(query.lower().split()) & set(result.get('question', '').lower().split())
    if shared:
        score += 0.05 * min(len(shared), 5)

    # Complexity-match boost: short queries pair with simple SQL, long
    # queries with complex SQL (the two conditions are mutually exclusive).
    n_words = len(query.split())
    complexity = result.get('complexity')
    if (n_words <= 8 and complexity == 'simple') or (n_words > 15 and complexity == 'complex'):
        score += 0.1

    return score
|
| 106 |
+
|
| 107 |
+
def rerank_results(results, query):
    """Attach an enhanced relevance score to every result, then sort the
    list in place so the highest-scoring entries come first."""
    for entry in results:
        entry['relevance_score'] = calculate_relevance_score(entry, query)
    results.sort(key=lambda entry: entry['relevance_score'], reverse=True)
    return results
|
| 113 |
+
|
| 114 |
+
# =============================================================================
|
| 115 |
+
# FILTERING FUNCTIONS
|
| 116 |
+
# =============================================================================
|
| 117 |
+
|
| 118 |
+
def filter_by_threshold(results, min_score=0.0):
    """Drop results whose similarity score falls below *min_score*."""
    return list(filter(lambda entry: entry.get('score', 0) >= min_score, results))
|
| 120 |
+
|
| 121 |
+
def filter_by_complexity(results, complexity=None):
    """Keep only results whose complexity matches; None means no filtering."""
    if complexity is None:
        return results
    return list(filter(lambda entry: entry.get('complexity') == complexity, results))
|
| 125 |
+
|
| 126 |
+
# =============================================================================
|
| 127 |
+
# SQL RETRIEVER CLASS
|
| 128 |
+
# =============================================================================
|
| 129 |
+
|
| 130 |
+
class SQLRetriever:
    """LangChain-based retriever with local/HuggingFace support."""

    def __init__(self):
        """Initialize the retriever.

        Ensures the ChromaDB data is present (downloading or building it
        if necessary), then opens the vector store with the same embedding
        model used at index-build time.
        """
        print("Initializing SQL Retriever...")

        # Ensure ChromaDB exists
        chromadb_path = ensure_chromadb_exists()

        # Load embeddings
        self.embeddings = get_embeddings()

        # Load ChromaDB
        self.vectorstore = Chroma(
            collection_name=COLLECTION_NAME,
            persist_directory=chromadb_path,
            embedding_function=self.embeddings
        )

        # NOTE(review): reaches into the private `_collection` attribute of
        # the LangChain wrapper; may break across langchain-chroma versions.
        self.doc_count = self.vectorstore._collection.count()
        print(f"✓ Loaded {self.doc_count:,} documents from {chromadb_path}")

    def retrieve(self, query, top_k=5, min_score=None, complexity=None, rerank=True):
        """Retrieve similar questions with filtering and ranking.

        Args:
            query: Natural-language question to search for.
            top_k: Maximum number of results to return.
            min_score: Optional minimum similarity score filter.
            complexity: Optional 'simple'/'intermediate'/'complex' filter.
            rerank: Whether to apply the enhanced relevance re-ranking.

        Returns:
            A list of result dicts, best match first.
        """

        # Over-fetch so post-filtering still leaves enough candidates.
        fetch_k = min(top_k * 3, 50)
        docs_with_scores = self.vectorstore.similarity_search_with_score(query, k=fetch_k)

        # Format results
        formatted = []
        for doc, score in docs_with_scores:
            formatted.append({
                'question': doc.page_content,
                'sql': doc.metadata.get('sql', ''),
                'source': doc.metadata.get('source', 'unknown'),
                'complexity': doc.metadata.get('complexity', 'unknown'),
                'keywords': doc.metadata.get('keywords', ''),
                'sql_clauses': doc.metadata.get('sql_clauses', ''),
                'distance': score,
                # Convert distance to a similarity: 1-d when d <= 1, else a
                # 1/(1+d) squash. Assumes smaller distance = better match —
                # TODO confirm against the collection's distance metric.
                'score': 1 - score if score <= 1 else 1 / (1 + score)
            })

        # Apply filters
        if min_score is not None:
            formatted = filter_by_threshold(formatted, min_score)

        if complexity is not None:
            formatted = filter_by_complexity(formatted, complexity)

        # Apply re-ranking
        if rerank:
            formatted = rerank_results(formatted, query)

        return formatted[:top_k]

    def retrieve_as_context(self, query, top_k=5):
        """Retrieve and format as context for LLM prompt.

        Returns an empty string when nothing is retrieved, otherwise a
        numbered "Question/SQL" example block ready to paste into a prompt.
        """
        results = self.retrieve(query, top_k=top_k)

        if not results:
            return ""

        context = "Similar SQL examples:\n\n"
        for i, r in enumerate(results, 1):
            context += f"Example {i}:\n"
            context += f"Question: {r['question']}\n"
            context += f"SQL: {r['sql']}\n\n"

        return context

    def get_stats(self):
        """Get retriever statistics."""
        return {
            'total_documents': self.doc_count,
            'collection_name': COLLECTION_NAME,
            'embedding_model': EMBEDDING_MODEL,
        }
|
| 208 |
+
|
| 209 |
+
# =============================================================================
|
| 210 |
+
# TEST
|
| 211 |
+
# =============================================================================
|
| 212 |
+
|
| 213 |
+
def test_retriever():
    """Exercise the retriever end-to-end with one sample query."""
    banner = "=" * 60
    print(banner)
    print("TESTING SQL RETRIEVER")
    print(banner)

    retriever = SQLRetriever()

    sample_query = "Find all employees with salary above 50000"
    hits = retriever.retrieve(sample_query, top_k=3)

    print(f"\nQuery: {sample_query}\n")
    for rank, hit in enumerate(hits, 1):
        print(f"Result {rank}: (score: {hit['score']:.3f})")
        print(f" Q: {hit['question'][:60]}...")
        print(f" SQL: {hit['sql'][:60]}...")
        print()

    print("✓ Test complete")
|
| 232 |
+
|
| 233 |
+
# Run the retriever smoke test when executed directly.
if __name__ == "__main__":
    test_retriever()
|
src/requirements.txt
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
--extra-index-url https://download.pytorch.org/whl/cu130
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
streamlit
|
| 5 |
+
chromadb
|
| 6 |
+
google-generativeai
|
| 7 |
+
python-dotenv
|
| 8 |
+
pandas
|
| 9 |
+
datasets
|
| 10 |
+
transformers
|
| 11 |
+
peft
|
| 12 |
+
accelerate
|
| 13 |
+
bitsandbytes
|
| 14 |
+
torch
|
| 15 |
+
torchvision
|
| 16 |
+
torchaudio
|
| 17 |
+
sentencepiece
|
| 18 |
+
huggingface_hub
|
| 19 |
+
matplotlib
|
| 20 |
+
sentence-transformers
|
| 21 |
+
langchain
|
| 22 |
+
langchain-community
|
| 23 |
+
langchain-chroma
|
src/streamlit_app.py
DELETED
|
@@ -1,40 +0,0 @@
|
|
| 1 |
-
import altair as alt
|
| 2 |
-
import numpy as np
|
| 3 |
-
import pandas as pd
|
| 4 |
-
import streamlit as st
|
| 5 |
-
|
| 6 |
-
"""
|
| 7 |
-
# Welcome to Streamlit!
|
| 8 |
-
|
| 9 |
-
Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
|
| 10 |
-
If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
|
| 11 |
-
forums](https://discuss.streamlit.io).
|
| 12 |
-
|
| 13 |
-
In the meantime, below is an example of what you can do with just a few lines of code:
|
| 14 |
-
"""
|
| 15 |
-
|
| 16 |
-
num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
|
| 17 |
-
num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
|
| 18 |
-
|
| 19 |
-
indices = np.linspace(0, 1, num_points)
|
| 20 |
-
theta = 2 * np.pi * num_turns * indices
|
| 21 |
-
radius = indices
|
| 22 |
-
|
| 23 |
-
x = radius * np.cos(theta)
|
| 24 |
-
y = radius * np.sin(theta)
|
| 25 |
-
|
| 26 |
-
df = pd.DataFrame({
|
| 27 |
-
"x": x,
|
| 28 |
-
"y": y,
|
| 29 |
-
"idx": indices,
|
| 30 |
-
"rand": np.random.randn(num_points),
|
| 31 |
-
})
|
| 32 |
-
|
| 33 |
-
st.altair_chart(alt.Chart(df, height=700, width=700)
|
| 34 |
-
.mark_point(filled=True)
|
| 35 |
-
.encode(
|
| 36 |
-
x=alt.X("x", axis=None),
|
| 37 |
-
y=alt.Y("y", axis=None),
|
| 38 |
-
color=alt.Color("idx", legend=None, scale=alt.Scale()),
|
| 39 |
-
size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
|
| 40 |
-
))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/synthetic/__init__.py
ADDED
|
File without changes
|
src/synthetic/generate_data.py
ADDED
|
@@ -0,0 +1,401 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Synthetic Data Generation for SQL Learning Assistant
|
| 3 |
+
|
| 4 |
+
Covers:
|
| 5 |
+
1. Create synthetic datasets for training/testing
|
| 6 |
+
2. Implement data augmentation techniques
|
| 7 |
+
3. Ensure diversity and quality of generated data
|
| 8 |
+
4. Address privacy and ethical considerations
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import pandas as pd
|
| 12 |
+
import random
|
| 13 |
+
import re
|
| 14 |
+
import hashlib
|
| 15 |
+
import json
|
| 16 |
+
from collections import Counter
|
| 17 |
+
from datetime import datetime
|
| 18 |
+
import matplotlib.pyplot as plt
|
| 19 |
+
import os
|
| 20 |
+
import sys
|
| 21 |
+
|
| 22 |
+
# Add parent directory to path for imports
|
| 23 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 24 |
+
from synthetic.synonyms import SYNONYMS, get_synonym, has_synonym
|
| 25 |
+
|
| 26 |
+
# =============================================================================
|
| 27 |
+
# OUTPUT DIRECTORIES
|
| 28 |
+
# =============================================================================
|
| 29 |
+
|
| 30 |
+
# Base directory for all synthetic-data artifacts; the subdirectories below
# are created on demand by setup_directories().
OUTPUT_DIR = "outputs/synthetic"
VIZ_DIR = f"{OUTPUT_DIR}/visualizations"  # saved .png charts
REPORT_DIR = f"{OUTPUT_DIR}/reports"      # markdown report
STATS_DIR = f"{OUTPUT_DIR}/stats"         # JSON statistics
|
| 34 |
+
|
| 35 |
+
def setup_directories():
    """Ensure the base output directory and all of its subdirectories exist."""
    for directory in (OUTPUT_DIR, VIZ_DIR, REPORT_DIR, STATS_DIR):
        os.makedirs(directory, exist_ok=True)
|
| 39 |
+
|
| 40 |
+
# =============================================================================
|
| 41 |
+
# SENTENCE VARIATIONS
|
| 42 |
+
# =============================================================================
|
| 43 |
+
|
| 44 |
+
# Polite lead-ins prepended by structure_variation(); the empty string keeps
# the original sentence start.
PREFIXES = ["", "Can you ", "Please ", "I want to ", "I need to ",
            "Could you ", "Help me ", "Show me how to "]

# Optional endings appended by structure_variation(); "" keeps the original end.
SUFFIXES = ["", "?", " please", " for me", " please?"]
|
| 48 |
+
|
| 49 |
+
# =============================================================================
|
| 50 |
+
# AUGMENTATION TECHNIQUES
|
| 51 |
+
# =============================================================================
|
| 52 |
+
|
| 53 |
+
def replace_synonyms(text, prob=0.4):
    """Technique 1: swap eligible words for random synonyms with probability `prob`."""
    output = []
    for token in text.split():
        # Lookup key: the token stripped of punctuation, lowercased.
        stripped = re.sub(r'[^\w]', '', token).lower()
        if has_synonym(stripped) and random.random() < prob:
            replacement = get_synonym(stripped)
            # Re-attach trailing punctuation (. , ? !) to the replacement word.
            if token[-1] in '.,?!':
                replacement = replacement + token[-1]
            output.append(replacement)
        else:
            output.append(token)
    return ' '.join(output)
|
| 65 |
+
|
| 66 |
+
def random_insertion(text, prob=0.15):
    """Technique 2: with probability `prob`, insert one filler word mid-sentence."""
    fillers = ["also", "specifically", "exactly", "just", "only"]
    tokens = text.split()
    # Only touch sentences with more than three words; never insert at index 0.
    if len(tokens) > 3 and random.random() < prob:
        spot = random.randint(1, len(tokens) - 1)
        tokens.insert(spot, random.choice(fillers))
    return ' '.join(tokens)
|
| 74 |
+
|
| 75 |
+
def random_swap(text, prob=0.1):
    """Technique 3: with probability `prob`, transpose one interior pair of words."""
    tokens = text.split()
    # Requires at least five words; first and last words are never moved.
    if len(tokens) > 4 and random.random() < prob:
        i = random.randint(1, len(tokens) - 3)
        tokens[i], tokens[i + 1] = tokens[i + 1], tokens[i]
    return ' '.join(tokens)
|
| 82 |
+
|
| 83 |
+
def structure_variation(text):
    """Technique 4: wrap the question with a random polite prefix/suffix."""
    lead = random.choice(PREFIXES)
    tail = random.choice(SUFFIXES)
    if lead:
        # The sentence no longer starts the phrase, so lowercase its first letter.
        text = text[0].lower() + text[1:] if text else text
    combined = lead + text + tail
    # Re-capitalize the start of the final phrase.
    return combined[0].upper() + combined[1:] if combined else combined
|
| 91 |
+
|
| 92 |
+
def case_variation(text):
    """Technique 5: randomly pick sentence case (60%), lowercase (25%), or leave as-is."""
    roll = random.random()
    if roll < 0.6:
        # Sentence case: first character upper, remainder lower.
        return text[0].upper() + text[1:].lower() if text else text
    if roll < 0.85:
        return text.lower()
    return text
|
| 100 |
+
|
| 101 |
+
def generate_variation(question):
    """Run the full augmentation chain over one question, in fixed order."""
    transforms = (replace_synonyms, random_insertion, random_swap,
                  structure_variation, case_variation)
    result = question
    for transform in transforms:
        result = transform(result)
    return result
|
| 110 |
+
|
| 111 |
+
# =============================================================================
|
| 112 |
+
# QUALITY AND DIVERSITY
|
| 113 |
+
# =============================================================================
|
| 114 |
+
|
| 115 |
+
def diversity_score(original, variation):
    """Jaccard distance between the word sets of the two strings (0 = identical)."""
    base_words = set(original.lower().split())
    new_words = set(variation.lower().split())
    # Either side empty: no meaningful comparison, report zero diversity.
    if not base_words or not new_words:
        return 0
    return 1 - len(base_words & new_words) / len(base_words | new_words)
|
| 124 |
+
|
| 125 |
+
def quality_check(question, sql):
    """Reject pairs whose question is empty/too short/too long/letter-free or whose SQL is missing."""
    if not question or len(question.strip()) < 10:
        return False
    if not sql or len(sql.strip()) < 5:
        return False
    # A question must contain at least one ASCII letter.
    if not re.search(r'[a-zA-Z]', question):
        return False
    return len(question) <= 500
|
| 136 |
+
|
| 137 |
+
def remove_duplicates(data):
    """Drop entries whose normalized question text was already seen (first occurrence wins)."""
    seen_hashes = set()
    kept = []
    for entry in data:
        # Normalize: strip punctuation, lowercase, collapse whitespace.
        key = ' '.join(re.sub(r'[^\w\s]', '', entry['question'].lower()).split())
        digest = hashlib.md5(key.encode()).hexdigest()
        if digest not in seen_hashes:
            seen_hashes.add(digest)
            kept.append(entry)
    return kept
|
| 149 |
+
|
| 150 |
+
# =============================================================================
|
| 151 |
+
# PRIVACY (ETHICAL CONSIDERATIONS)
|
| 152 |
+
# =============================================================================
|
| 153 |
+
|
| 154 |
+
def anonymize(text):
    """Mask emails, phone numbers, and SSNs with placeholder tokens (in that order)."""
    replacements = (
        (r'\b[\w.-]+@[\w.-]+\.\w+\b', '[EMAIL]'),
        (r'\b\d{3}[-.]?\d{3}[-.]?\d{4}\b', '[PHONE]'),
        (r'\b\d{3}-\d{2}-\d{4}\b', '[SSN]'),
    )
    for pattern, token in replacements:
        text = re.sub(pattern, token, text)
    return text
|
| 160 |
+
|
| 161 |
+
# =============================================================================
|
| 162 |
+
# STATISTICS
|
| 163 |
+
# =============================================================================
|
| 164 |
+
|
| 165 |
+
def calculate_stats(original_df, synthetic_df):
    """Summarize both datasets and the diversity of the synthetic set.

    Args:
        original_df: DataFrame with a 'question' column (the sampled source rows).
        synthetic_df: DataFrame with 'question' and 'diversity_score' columns.

    Returns:
        Dict with 'original' and 'synthetic' per-dataset stats, 'diversity'
        aggregates, and 'augmentation_factor' (synthetic size / original size).

    Robustness fix: an empty DataFrame (e.g. every variation filtered out)
    previously raised ZeroDivisionError / min() on an empty sequence / KeyError
    on a column-less empty frame; it now yields zeroed stats instead.
    """
    def _summarize(df, name):
        # Guard len(df) before touching columns: pd.DataFrame() has no columns.
        questions = df['question'].tolist() if len(df) else []
        lengths = [len(q.split()) for q in questions]
        if not lengths:
            return {'name': name, 'samples': 0, 'avg_length': 0,
                    'min_length': 0, 'max_length': 0, 'unique_words': 0}
        return {
            'name': name,
            'samples': len(df),
            'avg_length': round(sum(lengths) / len(lengths), 2),
            'min_length': min(lengths),
            'max_length': max(lengths),
            'unique_words': len(set(' '.join(questions).lower().split())),
        }

    orig_stats = _summarize(original_df, 'Original')
    synth_stats = _summarize(synthetic_df, 'Synthetic')

    scores = synthetic_df['diversity_score'].tolist() if len(synthetic_df) else []
    if scores:
        diversity_stats = {
            'avg': round(sum(scores) / len(scores), 4),
            'min': round(min(scores), 4),
            'max': round(max(scores), 4),
        }
    else:
        diversity_stats = {'avg': 0, 'min': 0, 'max': 0}

    factor = round(len(synthetic_df) / len(original_df), 2) if len(original_df) else 0

    return {
        'original': orig_stats,
        'synthetic': synth_stats,
        'diversity': diversity_stats,
        'augmentation_factor': factor,
    }
|
| 195 |
+
|
| 196 |
+
# =============================================================================
|
| 197 |
+
# VISUALIZATIONS
|
| 198 |
+
# =============================================================================
|
| 199 |
+
|
| 200 |
+
def create_visualizations(original_df, synthetic_df):
    """Save three PNG charts to VIZ_DIR comparing original vs. synthetic data.

    Expects 'question' in both frames and 'diversity_score' in synthetic_df.
    Side effects only: writes 01/02/03 *.png files and prints the location.
    """
    plt.style.use('seaborn-v0_8-whitegrid')

    # 1. Dataset size comparison (bar chart with counts printed above bars).
    fig, ax = plt.subplots(figsize=(8, 5))
    sizes = [len(original_df), len(synthetic_df)]
    bars = ax.bar(['Original', 'Synthetic'], sizes, color=['#3498db', '#2ecc71'])
    for bar, size in zip(bars, sizes):
        # NOTE(review): the +20 label offset is absolute, so labels may touch
        # the bars for very small datasets — confirm acceptable.
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 20,
                f'{size:,}', ha='center', fontweight='bold')
    ax.set_ylabel('Samples')
    ax.set_title('Dataset Size Comparison')
    plt.savefig(f'{VIZ_DIR}/01_size_comparison.png', dpi=150, bbox_inches='tight')
    plt.close()

    # 2. Question length (word-count) histograms, side by side.
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    orig_len = [len(q.split()) for q in original_df['question']]
    synth_len = [len(q.split()) for q in synthetic_df['question']]

    axes[0].hist(orig_len, bins=25, color='#3498db', alpha=0.7)
    axes[0].set_title('Original - Question Length')
    axes[0].set_xlabel('Words')

    axes[1].hist(synth_len, bins=25, color='#2ecc71', alpha=0.7)
    axes[1].set_title('Synthetic - Question Length')
    axes[1].set_xlabel('Words')

    plt.tight_layout()
    plt.savefig(f'{VIZ_DIR}/02_length_distribution.png', dpi=150, bbox_inches='tight')
    plt.close()

    # 3. Diversity score histogram with the mean marked as a dashed line.
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.hist(synthetic_df['diversity_score'], bins=20, color='#9b59b6', alpha=0.7)
    ax.axvline(synthetic_df['diversity_score'].mean(), color='red', linestyle='--',
               label=f"Mean: {synthetic_df['diversity_score'].mean():.3f}")
    ax.set_xlabel('Diversity Score')
    ax.set_ylabel('Frequency')
    ax.set_title('Diversity Score Distribution')
    ax.legend()
    plt.savefig(f'{VIZ_DIR}/03_diversity_distribution.png', dpi=150, bbox_inches='tight')
    plt.close()

    print(f" Visualizations saved to {VIZ_DIR}/")
|
| 246 |
+
|
| 247 |
+
# =============================================================================
|
| 248 |
+
# REPORT GENERATION
|
| 249 |
+
# =============================================================================
|
| 250 |
+
|
| 251 |
+
def generate_report(stats):
    """Write a markdown summary to REPORT_DIR/synthetic_report.md.

    `stats` is the dict produced by calculate_stats(); its 'original',
    'synthetic', 'diversity', and 'augmentation_factor' keys are interpolated
    into the template below.
    """
    report = f"""# Synthetic Data Generation Report

**Generated:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

## Dataset Statistics

| Metric | Original | Synthetic |
|--------|----------|-----------|
| Samples | {stats['original']['samples']:,} | {stats['synthetic']['samples']:,} |
| Avg Length | {stats['original']['avg_length']} | {stats['synthetic']['avg_length']} |
| Min Length | {stats['original']['min_length']} | {stats['synthetic']['min_length']} |
| Max Length | {stats['original']['max_length']} | {stats['synthetic']['max_length']} |
| Unique Words | {stats['original']['unique_words']:,} | {stats['synthetic']['unique_words']:,} |

## Augmentation Results

- **Augmentation Factor:** {stats['augmentation_factor']}x
- **Avg Diversity Score:** {stats['diversity']['avg']}
- **Min Diversity Score:** {stats['diversity']['min']}
- **Max Diversity Score:** {stats['diversity']['max']}

## Techniques Used

1. Synonym Replacement (40% probability)
2. Random Insertion (15% probability)
3. Random Swap (10% probability)
4. Structure Variation (prefix/suffix)
5. Case Variation

## Quality Controls

- Minimum question length: 10 characters
- Maximum question length: 500 characters
- Minimum diversity score: 0.1
- Duplicate removal via MD5 hashing

## Privacy Measures

- Email anonymization
- Phone number anonymization
- SSN anonymization

## Visualizations

- `01_size_comparison.png` - Dataset size comparison
- `02_length_distribution.png` - Question length distribution
- `03_diversity_distribution.png` - Diversity score distribution
"""

    with open(f'{REPORT_DIR}/synthetic_report.md', 'w') as f:
        f.write(report)
    print(f" Report saved to {REPORT_DIR}/synthetic_report.md")
|
| 305 |
+
|
| 306 |
+
# =============================================================================
|
| 307 |
+
# MAIN PIPELINE
|
| 308 |
+
# =============================================================================
|
| 309 |
+
|
| 310 |
+
def generate_synthetic_data(input_csv, output_csv, sample_size=500, variations=3, min_diversity=0.1):
    """Main synthetic data generation pipeline.

    Args:
        input_csv: CSV with 'question' and 'sql' columns.
        output_csv: destination CSV for the accepted synthetic rows.
        sample_size: number of source rows to sample (capped at len(df)).
        variations: variations attempted per source question.
        min_diversity: minimum diversity_score a variation must reach to be kept.

    Returns:
        DataFrame of accepted synthetic rows (also written to output_csv).

    Side effects: creates output directories, writes the CSV, a JSON stats
    file, PNG visualizations, and a markdown report; prints progress.
    """

    print("=" * 50)
    print("SYNTHETIC DATA GENERATION")
    print("=" * 50)

    # Setup
    setup_directories()

    # Load data
    print(f"\n[1/6] Loading {input_csv}...")
    df = pd.read_csv(input_csv)
    # random_state=42 keeps the sample reproducible across runs.
    sample_df = df.sample(n=min(sample_size, len(df)), random_state=42)
    print(f" Sampled {len(sample_df)} rows")

    # Generate variations
    print(f"\n[2/6] Generating variations...")
    synthetic_data = []
    skipped = 0

    for _, row in sample_df.iterrows():
        # Scrub PII from both sides before any augmentation.
        question = anonymize(str(row['question']))
        sql = anonymize(str(row['sql']))

        for _ in range(variations):
            variation = generate_variation(question)
            div_score = diversity_score(question, variation)

            # Reject variations that are too similar or fail quality checks.
            if div_score < min_diversity or not quality_check(variation, sql):
                skipped += 1
                continue

            synthetic_data.append({
                'question': variation,
                'sql': sql,
                'original_question': question,
                'diversity_score': round(div_score, 3),
                'is_synthetic': True
            })

    print(f" Generated: {len(synthetic_data)}, Skipped: {skipped}")

    # Remove duplicates
    print(f"\n[3/6] Removing duplicates...")
    before = len(synthetic_data)
    synthetic_data = remove_duplicates(synthetic_data)
    print(f" Removed {before - len(synthetic_data)} duplicates")

    # Save data
    print(f"\n[4/6] Saving data...")
    synthetic_df = pd.DataFrame(synthetic_data)
    synthetic_df.to_csv(output_csv, index=False)
    print(f" Saved to {output_csv}")

    # Calculate stats
    print(f"\n[5/6] Calculating statistics...")
    stats = calculate_stats(sample_df, synthetic_df)

    # Save stats as JSON
    with open(f'{STATS_DIR}/statistics.json', 'w') as f:
        json.dump(stats, f, indent=2)
    print(f" Stats saved to {STATS_DIR}/statistics.json")

    # Generate visualizations and report
    print(f"\n[6/6] Creating outputs...")
    create_visualizations(sample_df, synthetic_df)
    generate_report(stats)

    # Summary
    print("\n" + "=" * 50)
    print("COMPLETE")
    print("=" * 50)
    print(f" Original: {stats['original']['samples']:,} samples")
    print(f" Synthetic: {stats['synthetic']['samples']:,} samples")
    print(f" Augmentation: {stats['augmentation_factor']}x")
    print(f" Avg Diversity: {stats['diversity']['avg']}")

    return synthetic_df
|
| 389 |
+
|
| 390 |
+
# =============================================================================
|
| 391 |
+
# ENTRY POINT
|
| 392 |
+
# =============================================================================
|
| 393 |
+
|
| 394 |
+
if __name__ == "__main__":
    # Default run over the training CSV. sample_size is capped at len(df)
    # inside the pipeline, so this effectively requests (up to) the whole file.
    generate_synthetic_data(
        input_csv="data/train.csv",
        output_csv="data/synthetic.csv",
        sample_size=52527,
        variations=3,
        min_diversity=0.1
    )
|
src/synthetic/synonyms.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Synonym Dictionary for SQL Question Augmentation
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import random
|
| 6 |
+
|
| 7 |
+
# =============================================================================
|
| 8 |
+
# SYNONYMS BY CATEGORY
|
| 9 |
+
# =============================================================================
|
| 10 |
+
|
| 11 |
+
# Action Verbs
# Each dictionary maps a lowercase word to its interchangeable alternatives;
# lookups happen after lowercasing (see get_synonym/has_synonym below).
QUERY_VERBS = {
    "find": ["get", "show", "display", "list", "retrieve", "fetch", "return", "select"],
    "show": ["display", "list", "get", "find", "retrieve", "present"],
    "get": ["find", "show", "display", "list", "retrieve", "fetch", "obtain"],
    "list": ["show", "display", "get", "find", "enumerate"],
    "retrieve": ["get", "fetch", "find", "obtain", "extract"],
    "select": ["choose", "pick", "get", "retrieve", "find"],
    "search": ["find", "look for", "look up", "query"],
}

CALCULATION_VERBS = {
    "calculate": ["compute", "determine", "find", "figure out", "work out"],
    "compute": ["calculate", "determine", "figure out"],
    "count": ["tally", "enumerate", "number", "total up"],
    "sum": ["total", "add up", "aggregate"],
    "average": ["mean", "find the average of"],
}

MANIPULATION_VERBS = {
    "sort": ["order", "arrange", "rank", "organize"],
    "filter": ["narrow down", "limit", "restrict", "select"],
    "group": ["categorize", "organize", "cluster", "aggregate"],
    "join": ["combine", "merge", "connect", "link"],
    "update": ["modify", "change", "edit", "alter"],
    "delete": ["remove", "erase", "drop", "eliminate"],
    "insert": ["add", "create", "put", "include"],
}

# Comparison Terms
COMPARISONS = {
    "greater than": ["more than", "above", "exceeding", "over", "higher than"],
    "less than": ["below", "under", "fewer than", "smaller than", "lower than"],
    "equal to": ["equals", "is", "matching", "same as"],
    "between": ["in the range of", "ranging from", "within"],
    "contains": ["includes", "has", "with"],
}

# Aggregation Terms
AGGREGATIONS = {
    "maximum": ["highest", "largest", "greatest", "max", "top"],
    "minimum": ["lowest", "smallest", "least", "min", "bottom"],
    "average": ["mean", "avg"],
    "total": ["sum", "combined", "overall", "aggregate"],
    "count": ["number of", "how many", "total number of"],
    "distinct": ["unique", "different", "separate"],
}

# Business Entities
ENTITIES = {
    "employees": ["workers", "staff", "personnel", "team members"],
    "customers": ["clients", "users", "buyers", "patrons"],
    "products": ["items", "goods", "merchandise"],
    "orders": ["purchases", "transactions", "sales"],
    "suppliers": ["vendors", "providers", "distributors"],
    "company": ["firm", "organization", "business"],
    "department": ["dept", "division", "section", "unit"],
    "manager": ["supervisor", "boss", "lead", "head"],
}

# Financial Terms
FINANCIAL = {
    "price": ["cost", "amount", "value", "rate"],
    "salary": ["pay", "wage", "income", "earnings"],
    "revenue": ["income", "earnings", "sales"],
    "profit": ["earnings", "gain", "margin"],
    "cost": ["price", "expense", "charge"],
}

# Time Terms
TIME_TERMS = {
    "date": ["day", "time", "period"],
    "year": ["annum", "calendar year"],
    "month": ["period", "calendar month"],
    "recent": ["latest", "newest", "most recent"],
    "current": ["present", "existing", "active"],
    "previous": ["prior", "former", "past", "earlier"],
    "last": ["final", "most recent", "latest"],
    "first": ["initial", "earliest", "beginning"],
}

# Quantifiers
QUANTIFIERS = {
    "all": ["every", "each", "the entire", "complete"],
    "some": ["a few", "certain", "several"],
    "many": ["numerous", "multiple", "several"],
    "few": ["some", "a small number of", "limited"],
    "only": ["just", "solely", "exclusively"],
}

# Adjectives
# NOTE: later dictionaries win on duplicate keys when merged by
# get_all_synonyms() — e.g. AGGREGATIONS["average"] overrides
# CALCULATION_VERBS["average"].
ADJECTIVES = {
    "highest": ["greatest", "maximum", "largest", "top"],
    "lowest": ["smallest", "minimum", "least", "bottom"],
    "active": ["current", "live", "enabled"],
    "inactive": ["disabled", "dormant", "idle"],
    "new": ["recent", "latest", "fresh"],
    "old": ["previous", "former", "past"],
}
|
| 110 |
+
|
| 111 |
+
# =============================================================================
|
| 112 |
+
# COMBINED DICTIONARY
|
| 113 |
+
# =============================================================================
|
| 114 |
+
|
| 115 |
+
def get_all_synonyms():
    """Merge every category dictionary into one flat word -> synonyms map.

    Later categories override earlier ones on duplicate keys.
    """
    categories = [QUERY_VERBS, CALCULATION_VERBS, MANIPULATION_VERBS,
                  COMPARISONS, AGGREGATIONS, ENTITIES, FINANCIAL,
                  TIME_TERMS, QUANTIFIERS, ADJECTIVES]
    merged = {}
    for category in categories:
        merged.update(category)
    return merged

# Flat lookup table used by the utility functions below.
SYNONYMS = get_all_synonyms()
|
| 125 |
+
|
| 126 |
+
# =============================================================================
|
| 127 |
+
# UTILITY FUNCTIONS
|
| 128 |
+
# =============================================================================
|
| 129 |
+
|
| 130 |
+
def get_synonym(word):
    """Return a random synonym for `word` (case-insensitive), or the word itself."""
    key = word.lower()
    if key in SYNONYMS:
        return random.choice(SYNONYMS[key])
    return word
|
| 136 |
+
|
| 137 |
+
def has_synonym(word):
    """True when `word` (case-insensitive) appears in the synonym table."""
    return word.lower() in SYNONYMS
|
| 140 |
+
|
| 141 |
+
def print_stats():
    """Print the size of the synonym table: distinct words and total synonyms."""
    print(f"Total words: {len(SYNONYMS)}")
    print(f"Total synonyms: {sum(len(v) for v in SYNONYMS.values())}")
|
| 147 |
+
|
| 148 |
+
if __name__ == "__main__":
    # Running this module directly just reports synonym-table statistics.
    print_stats()
|
src/tests/test_finetuned.py
ADDED
|
File without changes
|
src/tests/test_rag.py
ADDED
|
File without changes
|
src/tests/test_synthetic.py
ADDED
|
File without changes
|