"""
GRADIO DEMO UI - LAZY LOADING EDITION
NL → SQL → Result Table
"""
|
|
import base64
import html
import io
import json
import os
import re
import subprocess
import sys
import time
from pathlib import Path
from typing import Iterator

import gradio as gr
import pandas as pd
import torch
|
|
| |
| |
| |
# Directory that contains this file. ``__file__`` is undefined in some
# interactive contexts, in which case the current working directory is used.
try:
    PROJECT_ROOT = Path(__file__).resolve().parent
except NameError:
    PROJECT_ROOT = Path(".").resolve()

# Database root: ``data/database`` when present, otherwise ``final_databases``.
DB_ROOT = PROJECT_ROOT / "data" / "database"
if not DB_ROOT.exists():
    DB_ROOT = PROJECT_ROOT / "final_databases"
|
|
def get_db_path(db_id: str) -> str:
    """Return the SQLite file path for ``db_id`` as a string.

    The nested layout ``DB_ROOT/<db_id>/<db_id>.sqlite`` is returned when that
    file exists; otherwise the flat layout ``DB_ROOT/<db_id>.sqlite`` is
    returned without checking that it exists.
    """
    filename = f"{db_id}.sqlite"
    nested = DB_ROOT / db_id / filename
    if nested.exists():
        return str(nested)
    return str(DB_ROOT / filename)
|
|
| |
| |
| |
# On machines without CUDA, replace the CUDA timing primitives with CPU
# stand-ins so code that times itself with ``torch.cuda.Event`` still runs.
if not torch.cuda.is_available():
    class MockCUDAEvent:
        """CPU stand-in for ``torch.cuda.Event`` backed by ``time.perf_counter``."""

        def __init__(self, enable_timing=False, blocking=False, interprocess=False):
            # The flags are accepted for signature compatibility and ignored.
            self.t = 0.0

        def record(self, stream=None):
            """Store the current ``perf_counter`` reading; ``stream`` is ignored."""
            self.t = time.perf_counter()

        def elapsed_time(self, end_event):
            """Return the milliseconds between this event and ``end_event``."""
            return (end_event.t - self.t) * 1000.0

    torch.cuda.Event = MockCUDAEvent
    # ``torch.cuda.synchronize`` is always defined as an attribute, so guarding
    # on ``hasattr`` would never install the no-op. Install it unconditionally
    # here and accept the optional ``device`` argument of the real function.
    torch.cuda.synchronize = lambda device=None: None
|
|
| |
| |
| |
| from src.quantized_text2sql_engine import QuantizedText2SQLEngine |
| from src.schema_encoder import SchemaEncoder |
|
|
# Default directory of the quantized model artifact.
DEFAULT_QUANT_ARTIFACT = str(PROJECT_ROOT / "int8_dynamic")

# In-process state shared by the request handlers below.
_ENGINE_CACHE = {}  # engines keyed by their construction settings
_QUERY_LOG = []     # blocked / failed requests (dict entries)
_PERF_LOG = []      # per-request performance payloads
_SUCCESS_LOG = []   # successfully executed requests

# ok/fail counters per SQL clause, updated after every executed query.
_OP_STATS = {
    op: {"ok": 0, "fail": 0}
    for op in ("SELECT", "WHERE", "JOIN", "GROUP_BY", "ORDER_BY", "HAVING", "LIMIT")
}
|
|
def get_quant_engine(artifact_dir: str, use_constrained: bool = False, exec_workers: int = 8, use_cache: bool = True):
    """Return the engine for the given settings, constructing it on first use.

    Engines are memoised in ``_ENGINE_CACHE`` under the tuple
    ``(artifact_dir, use_constrained, exec_workers, use_cache)``.
    """
    constrained = bool(use_constrained)
    workers = int(exec_workers)
    cache_enabled = bool(use_cache)
    cache_key = (artifact_dir, constrained, workers, cache_enabled)

    if cache_key in _ENGINE_CACHE:
        return _ENGINE_CACHE[cache_key]

    try:
        engine = QuantizedText2SQLEngine(
            artifact_dir,
            device="cpu",
            use_constrained=constrained,
            exec_workers=workers,
            use_cache=cache_enabled,
        )
    except TypeError:
        # Fallback for an engine whose constructor only takes the artifact dir.
        engine = QuantizedText2SQLEngine(artifact_dir)

    _ENGINE_CACHE[cache_key] = engine
    return engine
|
|
| |
# The model engine is created lazily by the first ``run_query`` call.
quant_engine = None

# The schema encoder is optional: if it cannot be constructed the app keeps
# running with ``schema_encoder`` set to ``None``.
try:
    schema_encoder = SchemaEncoder(DB_ROOT)
except Exception as load_error:
    print(f"Warning: SchemaEncoder failed to load: {load_error}")
    schema_encoder = None
|
|
# (question, database id) pairs offered in the "Pick a Sample" dropdown.
SAMPLES = [
    ("Show 10 distinct employee first names.", "chinook_1"),
    ("Which artist has the most albums?", "chinook_1"),
    ("List all the tracks that belong to the 'Rock' genre.", "chinook_1"),
    ("What are the names of all the cities?", "flight_1"),
    ("Find the flight number and cost of the cheapest flight.", "flight_1"),
    ("List the airlines that fly out of New York.", "flight_1"),
    ("Which campus was opened between 1935 and 1939?", "csu_1"),
    ("Count the number of students in each department.", "college_2"),
    ("List the names of all clubs.", "club_1"),
    ("How many members does each club have?", "club_1"),
    ("Show the names of all cinemas.", "cinema"),
    ("Which cinema has the most screens?", "cinema"),
]
SAMPLE_QUESTIONS = [question for question, _db in SAMPLES]
|
|
def explain_sql(sql):
    """Return a short bullet-style English description of the clauses in ``sql``.

    Returns ``""`` for empty or ``None`` input. Keywords are matched
    case-insensitively as whole words, so identifiers that merely contain a
    keyword (``joined_at``, ``credit_limit``) do not add a bullet. Keywords
    inside string literals are still counted.
    """
    if not sql:
        return ""

    clause_notes = (
        (r"\bjoin\b", "\n• It combines data from multiple tables using JOIN."),
        (r"\bwhere\b", "\n• It filters rows using a WHERE condition."),
        (r"\bgroup\s+by\b", "\n• It groups results using GROUP BY."),
        (r"\border\s+by\b", "\n• It sorts the results using ORDER BY."),
        (r"\blimit\b", "\n• It limits the number of returned rows."),
    )

    explanation = "This SQL query retrieves information from the database."
    for pattern, note in clause_notes:
        if re.search(pattern, sql, flags=re.IGNORECASE):
            explanation += note
    return explanation
|
|
def sql_ops(sql: str) -> list[str]:
    """Return the clause labels present in ``sql``, always starting with ``"SELECT"``.

    Labels are emitted in the fixed order WHERE, JOIN, GROUP_BY, ORDER_BY,
    HAVING, LIMIT. Keywords are matched case-insensitively as whole words and
    tolerate any whitespace (including newlines) around and inside them.
    ``None`` or empty input yields ``["SELECT"]``.
    """
    text = sql or ""
    clause_patterns = (
        ("WHERE", r"\bwhere\b"),
        ("JOIN", r"\bjoin\b"),
        ("GROUP_BY", r"\bgroup\s+by\b"),
        ("ORDER_BY", r"\border\s+by\b"),
        ("HAVING", r"\bhaving\b"),
        ("LIMIT", r"\blimit\b"),
    )
    ops = ["SELECT"]
    ops.extend(
        label
        for label, pattern in clause_patterns
        if re.search(pattern, text, flags=re.IGNORECASE)
    )
    return ops
|
|
| def classify_error(sql: str, error_msg: str | None = None, *, timed_out: bool = False): |
| s = (sql or "").lower() |
| m = (error_msg or "").lower() |
| if timed_out or "interrupted" in m or "timeout" in m: return "timeout" |
| if not s.strip().startswith(("select", "with")): return "syntax_error" |
| if " join " in f" {s} " and " on " not in f" {s} ": return "missing_join" |
| if " where " in f" {s} " and not any(op in s for op in ["=", ">", "<", " in ", " like ", " between ", " is null", " is not null"]): return "wrong_where" |
| if ("is null" in s or "is not null" in s) and ("no such column" in m or "misuse" in m): return "null_handling" |
| if "no such table" in m: return "missing_table" |
| if "no such column" in m: return "missing_column" |
| if "ambiguous column name" in m: return "ambiguous_column" |
| if "datatype mismatch" in m or "type mismatch" in m: return "type_mismatch" |
| if "misuse of aggregate" in m or "misuse of aggregate function" in m: return "wrong_aggregation" |
| if "syntax error" in m: return "syntax_error" |
| if "near" in m and "syntax error" in m: return "syntax_error" |
| if "runtime" in m or "constraint failed" in m: return "runtime_error" |
| return "other" |
|
|
def get_hint(error_type):
    """Return a one-line remediation hint for ``error_type``.

    Covers the labels produced by ``classify_error`` as well as the gate
    labels that ``run_query`` writes to the query log. Unknown labels get the
    generic ``"Review query."``.
    """
    hints = {
        "missing_join": "Check JOIN conditions between tables.",
        "wrong_aggregation": "Use proper aggregation like avg(column).",
        "wrong_where": "Check WHERE condition syntax.",
        "syntax_error": "Ensure SQL starts with SELECT.",
        "missing_table": "Use only tables from the provided schema.",
        "missing_column": "Use only columns from the provided schema.",
        "ambiguous_column": "Disambiguate by using table.column.",
        "timeout": "Query took too long; simplify joins.",
        "other": "Review SQL logic.",
        # Labels that previously fell through to the generic hint.
        "null_handling": "Use IS NULL / IS NOT NULL on columns that exist in the schema.",
        "type_mismatch": "Compare columns with values of a matching type.",
        "runtime_error": "The database rejected the query at run time; check values and constraints.",
        "gibberish": "Ask a complete natural-language question.",
        "blocked_dml": "Only read-only questions are allowed.",
        "out_of_domain": "Ask about tables or columns in the selected schema.",
        "backend_crash": "The model backend failed; check the server logs.",
    }
    return hints.get(error_type, "Review query.")
|
|
def is_relevant_to_schema(question, db_id):
    """Return True when ``question`` shares at least one word with the schema of ``db_id``.

    The check is deliberately permissive: it returns True when no schema
    encoder is available, when the schema cannot be read, or when the question
    has no meaningful words left after removing stop words and numbers.
    Question words are also tried in simple singular forms
    (``tracks`` -> ``track``, ``classes`` -> ``class``, ``cities`` -> ``city``).
    """
    if schema_encoder is None:
        return True
    try:
        raw_schema = schema_encoder.structured_schema(db_id).lower()
    except Exception:
        # Best effort: an unreadable schema must not block the question.
        # (Was a bare ``except:``, which also swallowed KeyboardInterrupt.)
        return True

    schema_words = set(re.findall(r'[a-z0-9_]+', raw_schema))
    stop_words = {"show", "list", "all", "what", "is", "the", "how", "many", "count", "find", "get", "me", "a", "an", "of", "in", "for", "from", "with", "which", "are", "there", "give", "tell", "details", "info", "data", "everything"}
    meaningful_q_words = [
        w for w in re.findall(r'[a-z0-9_]+', question.lower())
        if w not in stop_words and not w.isdigit()
    ]
    if not meaningful_q_words:
        return True

    for word in meaningful_q_words:
        candidates = {word}
        if word.endswith("s"):
            candidates.add(word[:-1])
        if word.endswith("es"):
            candidates.add(word[:-2])
        if word.endswith("ies"):
            candidates.add(word[:-3] + "y")
        if not schema_words.isdisjoint(candidates):
            return True
    return False
|
|
def _apply_limit_heuristics(sql: str, q_lower: str) -> str:
    """Rewrite model SQL when the question asks for a specific number of rows.

    When ``q_lower`` (the lower-cased question) contains a row count such as
    "show 10 ...", filters that compare against that number are stripped from
    ``sql``, aggregation the question did not ask for is removed, and
    ``LIMIT <n>`` is appended when the statement has no LIMIT. ``sql`` is
    returned unchanged when no count is found or ``sql`` is empty.
    """
    num_match = re.search(r'\b(?:show|list|top|limit|get|first|last|sample|of)\s+(?:[a-zA-Z_]+\s+)?(\d+)\b', q_lower)
    if not num_match and q_lower.startswith(("show", "list", "get")):
        num_match = re.search(r'\b(\d+)\b', q_lower)
    if not num_match or not sql:
        return sql

    limit_val = num_match.group(1)
    # The negative lookahead stops "= 10" from matching the front of "= 100"
    # or "= 10.5" and leaving stray digits behind.
    sql = re.sub(rf"(?i)\s*(?:where|having|and)?\s*count\s*\(\s*\*\s*\)\s*=\s*{limit_val}(?![\d.])", "", sql)
    sql = re.sub(rf"(?i)\s*(?:where|and)\s+[a-zA-Z0-9_.]+\s*=\s*['\"]?{limit_val}(?![\d.])['\"]?", "", sql)
    # Tidy up a WHERE left without a condition.
    sql = re.sub(r"(?i)\s*where\s*$", "", sql)
    sql = re.sub(r"(?i)\s*where\s+(group by|order by|limit)", r" \1", sql)

    agg_kws = ["most", "top", "highest", "lowest", "count", "many", "group", "frequent", "popular"]
    if not any(k in q_lower for k in agg_kws):
        # The question did not ask for aggregation: drop count(*) ordering/columns.
        sql = re.sub(r"(?i)\s*group by\s+[a-zA-Z0-9_.]+\s*order by\s+count\(\*\)\s*(?:desc|asc)?", "", sql)
        sql = re.sub(r"(?i)\s*order by\s+count\(\*\)\s*(?:desc|asc)?", "", sql)
        sql = re.sub(r"(?i),\s*count\(\*\)", "", sql)
        sql = re.sub(r"(?i)count\(\*\)\s*,", "", sql)

        # NOTE(review): the source this was reconstructed from had lost its
        # indentation; this cleanup is assumed to belong to the
        # "no aggregation requested" branch - confirm.
        if "group by" in sql.lower() and not re.search(r'(?i)\b(count|sum|avg|max|min)\b\(', sql):
            sql = re.sub(r"(?i)\s*group by\s+[a-zA-Z0-9_.]+", "", sql)

    # Whole-word test so a column such as ``credit_limit`` does not suppress LIMIT.
    if not re.search(r"\blimit\b", sql, flags=re.IGNORECASE):
        sql = f"{sql.strip().rstrip(';')} LIMIT {limit_val}"
    return sql


def run_query(method, sample_q, custom_q, db_id):
    """Turn a natural-language question into SQL, execute it, and describe the outcome.

    Args:
        method: Input mode; ``"💡 Pick a Sample"`` selects ``sample_q``, any
            other value selects ``custom_q``.
        sample_q: The selected sample question.
        custom_q: The free-text question.
        db_id: Identifier of the database to query.

    Returns:
        A 3-tuple ``(sql_text, result_dataframe, explanation_text)``. Blocked
        or failed requests return a notice in ``sql_text`` and an empty
        DataFrame with a single descriptive column.

    Side effects:
        Loads the engine on the first call (module global ``quant_engine``) and
        appends to ``_QUERY_LOG``, ``_PERF_LOG``, ``_SUCCESS_LOG`` and
        ``_OP_STATS``.
    """
    global quant_engine

    # Lazy model load on the first request.
    if quant_engine is None:
        print(f"First request detected! Loading AI model from {DEFAULT_QUANT_ARTIFACT}...", flush=True)
        try:
            quant_engine = get_quant_engine(DEFAULT_QUANT_ARTIFACT, use_constrained=False, exec_workers=8, use_cache=True)
            if quant_engine is None:
                return "-- ❌ ENGINE CRASH", pd.DataFrame(columns=["Error"]), "Failed to load model. Did you move the tokenizer files and add config.json to int8_dynamic/?"
        except Exception as e:
            return f"-- ❌ ENGINE CRASH\n-- {str(e)}", pd.DataFrame(columns=["Error Status"]), f"Critical failure loading model: {e}"

    def _log(error_type: str, *, question: str, db_id_val: str, sql: str = "", error_msg: str = "") -> None:
        _QUERY_LOG.append({"t": time.time(), "db_id": str(db_id_val), "question": str(question), "sql": str(sql), "error_type": str(error_type), "error_msg": str(error_msg)})
        # Bound the log like _PERF_LOG; the dashboard reads at most the last 1000.
        if len(_QUERY_LOG) > 2000:
            del _QUERY_LOG[:1000]

    def _perf_log(payload: dict) -> None:
        _PERF_LOG.append(payload)
        if len(_PERF_LOG) > 1000:
            del _PERF_LOG[:200]

    raw_question = sample_q if method == "💡 Pick a Sample" else custom_q

    # --- Input validation -------------------------------------------------
    if not raw_question or str(raw_question).strip() == "":
        return "-- No input provided", pd.DataFrame(columns=["Warning"]), "⚠️ Please enter a question."
    if not db_id or str(db_id).strip() == "":
        return "-- No database selected", pd.DataFrame(columns=["Warning"]), "⚠️ Please select a database."

    # Correct a few common typos of the leading verb before further checks.
    typo_corrections = [(r'\bshaw\b', 'show'), (r'\bshw\b', 'show'), (r'\bsho\b', 'show'), (r'\blsit\b', 'list'), (r'\blis\b', 'list'), (r'\bfidn\b', 'find'), (r'\bfnd\b', 'find'), (r'\bgte\b', 'get')]
    question = str(raw_question)
    for bad, good in typo_corrections:
        question = re.sub(bad, good, question, flags=re.IGNORECASE)
    q_lower = question.strip().lower()

    # --- Gates: too short, data modification, unrelated to the schema -----
    if len(q_lower.split()) < 2:
        _log("gibberish", question=question, db_id_val=str(db_id), error_msg="gibberish filtered")
        return "-- Input Blocked", pd.DataFrame(columns=["Warning"]), "⚠️ Please enter a clear, meaningful natural language question (more than one word)."

    if re.search(r'\b(delete|update|insert|drop|alter|truncate)\b', q_lower):
        _log("blocked_dml", question=question, db_id_val=str(db_id), error_msg="DML blocked")
        return "-- ❌ BLOCKED: Data Modification", pd.DataFrame(columns=["Security Alert"]), "🛑 Security Alert: Modifying or deleting data is strictly prohibited."

    if not is_relevant_to_schema(question, db_id):
        _log("out_of_domain", question=question, db_id_val=str(db_id), error_msg="out of domain")
        return "-- ❌ BLOCKED: Out of Domain", pd.DataFrame(columns=["Domain Alert"]), f"🛑 Relevance Alert: I don't see anything related to your question in the '{db_id}' schema."

    # --- Generation -------------------------------------------------------
    t0 = time.perf_counter()
    try:
        try:
            result = quant_engine.ask(question, str(db_id), num_beams=4, max_new_tokens=120, timeout_s=2.0)
        except TypeError:
            # Fallback for an engine whose ``ask`` does not take these keywords.
            result = quant_engine.ask(question, str(db_id))
    except Exception as e:
        _log("backend_crash", question=question, db_id_val=str(db_id), error_msg=str(e))
        return f"-- ❌ BACKEND CRASH\n-- {str(e)}", pd.DataFrame(columns=["Error Status"]), f"❌ CRITICAL BACKEND CRASH:\n{str(e)}"

    # ``or ""`` keeps a missing/None value from turning into the text "None".
    model_sql = str(result.get("sql") or "")
    final_sql = _apply_limit_heuristics(model_sql, q_lower)

    # --- Validation -------------------------------------------------------
    db_path = get_db_path(str(db_id))
    try:
        # Imported lazily, and inside the try so that a missing validator only
        # marks the query as "not strictly validated" instead of crashing.
        from src.sql_validator import validate_sql_schema
        strict_valid, _ = validate_sql_schema(final_sql, db_path)
    except Exception:
        strict_valid = False

    # --- Execution --------------------------------------------------------
    error_msg = None
    rows, cols = [], []
    sqlite_success = False

    if not final_sql.strip():
        # Nothing to execute; report it as a failure instead of running "".
        error_msg = "Model returned empty SQL"
    else:
        try:
            rows, cols = quant_engine._execute_one(final_sql, db_path, timeout_s=2.0)
            sqlite_success = True
        except Exception as e:
            error_msg = str(e)

    # If the rewritten SQL failed, retry with the model's untouched SQL.
    if not sqlite_success and model_sql and model_sql != final_sql:
        try:
            rows, cols = quant_engine._execute_one(model_sql, db_path, timeout_s=2.0)
            final_sql = model_sql
            sqlite_success = True
            error_msg = None
        except Exception:
            # Deliberate: keep reporting the error from the first attempt.
            pass

    failed = bool(error_msg) or not sqlite_success
    error_type = None
    if failed:
        # Classified once so the log entry and the UI always agree.
        error_type = classify_error(final_sql, str(error_msg or ""), timed_out=("interrupted" in str(error_msg or "").lower()))
        _log(error_type, question=str(question), db_id_val=str(db_id), sql=str(final_sql), error_msg=str(error_msg or "Execution failed"))

    # --- Performance telemetry --------------------------------------------
    t1 = time.perf_counter()
    latency = round(t1 - t0, 3)
    engine_stats_after = quant_engine.stats() if hasattr(quant_engine, 'stats') else {}

    perf = {
        "db_id": str(db_id), "use_constrained_decoding": False, "num_beams": 4,
        "latency_total_ms": round((t1 - t0) * 1000.0, 2), "constraint_ok": bool(strict_valid), "has_error": bool(error_msg),
        "exec_cache_hit_rate": float(engine_stats_after.get("exec_cache_hit_rate", 0.0) or 0.0),
    }
    _perf_log(perf)

    window = _PERF_LOG[-50:]
    avg_ms = sum(float(x.get("latency_total_ms", 0.0) or 0.0) for x in window) / len(window) if window else 0.0
    constraint_rate = sum(1 for x in window if x.get("constraint_ok")) / len(window) if window else 0.0

    perf_block = (
        "\n\n---\nPerformance (task impact)\n"
        f"- Total latency (ms): {perf['latency_total_ms']}\n"
        f"- Strict Python Validator OK (Task 3): {perf['constraint_ok']}\n"
        f"- Exec cache hit-rate (Task 1/5): {round(perf['exec_cache_hit_rate'], 3)}\n"
        f"- Rolling avg latency last 50 (ms): {round(avg_ms, 2)}\n"
        f"- Rolling constraint rate last 50: {round(constraint_rate, 3)}\n"
    )

    ops = sql_ops(final_sql)

    # --- Failure response -------------------------------------------------
    if failed:
        display_sql = final_sql if final_sql.strip() else "-- ❌ INVALID SQL"
        explanation = "❌ Error Details:\n\n"
        if error_msg:
            explanation += f"{error_msg}\n\n"
        explanation += f"Error Type: {error_type}\nHint: {get_hint(error_type)}"
        explanation += perf_block
        for op in ops:
            if op in _OP_STATS:
                _OP_STATS[op]["fail"] += 1
        return display_sql, pd.DataFrame(columns=["Execution Notice"]), explanation

    # --- Success response -------------------------------------------------
    safe_cols = cols if cols else ["Result"]
    explanation = f"✅ Query executed successfully\n\nRows returned: {len(rows)}\nExecution Time: {latency} sec\n\n{explain_sql(final_sql)}{perf_block}"

    for op in ops:
        if op in _OP_STATS:
            _OP_STATS[op]["ok"] += 1
    _SUCCESS_LOG.append({"t": time.time(), "db_id": str(db_id), "question": question, "sql": final_sql, "ops": ops})
    if len(_SUCCESS_LOG) > 2000:
        del _SUCCESS_LOG[:1000]

    limit_match = re.search(r'LIMIT\s+(\d+)', final_sql, re.IGNORECASE)
    if limit_match and len(rows) < int(limit_match.group(1)):
        explanation += f"\n\nℹ️ Query allowed up to {int(limit_match.group(1))} rows but only {len(rows)} matched."

    return final_sql, pd.DataFrame(rows, columns=safe_cols), explanation
|
|
def task1_benchmark(n_rollouts: int, max_workers: int) -> Iterator[tuple[str, str]]:
    """Run ``scripts/benchmark_parallel_reward.py`` and stream its output.

    Yields ``(console_text, plot_html)`` pairs: a start message, then the last
    200 output lines roughly every 0.5 s while the script runs, and finally the
    full output together with the plot from ``results/task1_plot.png`` (or a
    placeholder when no plot file exists). A non-zero exit code is appended to
    the final output.

    The child process is terminated if the generator is closed early.
    """
    try:
        n_arg = str(int(n_rollouts))
        workers_arg = str(int(max_workers))
    except (TypeError, ValueError):
        # A cleared number field arrives as None.
        yield "Rollouts (n) and Max workers must be whole numbers.", "<i>Not started</i>"
        return

    project_root = str(PROJECT_ROOT)
    env = os.environ.copy()
    env["PYTHONPATH"] = project_root + (os.pathsep + env["PYTHONPATH"] if env.get("PYTHONPATH") else "")
    env.setdefault("MPLBACKEND", "Agg")
    env.setdefault("MPLCONFIGDIR", "/tmp/mplconfig")
    try:
        os.makedirs(env["MPLCONFIGDIR"], exist_ok=True)
    except OSError:
        pass  # best effort; the script can still run without this directory

    cmd = [sys.executable, "-u", "scripts/benchmark_parallel_reward.py", "--n", n_arg, "--max-workers", workers_arg, "--skip-profile"]
    proc = subprocess.Popen(cmd, cwd=project_root, env=env, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1)
    lines: list[str] = []
    try:
        last_yield = time.perf_counter()
        yield "Running Task 1 benchmark...\n", "<i>Running...</i>"

        if proc.stdout is not None:
            for line in proc.stdout:
                lines.append(line)
                now = time.perf_counter()
                if now - last_yield >= 0.5:
                    last_yield = now
                    yield "".join(lines[-200:]).strip(), "<i>Running...</i>"

        return_code = proc.wait()
    finally:
        # Reached on normal completion, on errors, and when the consumer stops
        # iterating: never leave the benchmark running or the pipe open.
        if proc.poll() is None:
            proc.kill()
            proc.wait()
        if proc.stdout is not None:
            proc.stdout.close()

    out = "".join(lines).strip()
    if return_code != 0:
        out += f"\n\n[benchmark exited with code {return_code}]"

    plot_path = str(PROJECT_ROOT / "results" / "task1_plot.png")
    if os.path.exists(plot_path):
        try:
            b64 = base64.b64encode(Path(plot_path).read_bytes()).decode("ascii")
        except OSError:
            # The file exists but cannot be read: show its path instead.
            yield out, f"<pre>{plot_path}</pre>"
            return
        yield out, f"<img src='data:image/png;base64,{b64}' style='max-width: 100%; border: 1px solid #e2e8f0; border-radius: 8px;' />"
        return

    yield out, "<i>No plot generated</i>"
|
|
def task2_dashboard_structured():
    """Build the error dashboard from ``_QUERY_LOG``.

    Returns ``(counts_df, recent_df, dropdown_update)``: error counts over the
    last 1000 log entries (most frequent first, ties by name), the last 100
    entries as rows, and a ``gr.update`` that fills the error-type dropdown and
    preselects the most frequent type. Empty frames are returned when nothing
    has been logged.
    """
    if not _QUERY_LOG:
        return (
            pd.DataFrame(columns=["error_type", "count", "hint"]),
            pd.DataFrame(columns=["time", "db_id", "error_type", "question", "error_msg"]),
            gr.update(choices=[], value=None),
        )

    # Tally error types; insertion order is irrelevant because we sort below.
    tally = {}
    for entry in _QUERY_LOG[-1000:]:
        label = entry.get("error_type") or "other"
        tally[label] = tally.get(label, 0) + 1
    ranked = sorted(tally.items(), key=lambda pair: (-pair[1], pair[0]))
    rows = [
        {"error_type": label, "count": int(total), "hint": get_hint(label)}
        for label, total in ranked
    ]
    counts_df = pd.DataFrame(rows)

    def _clock(ts):
        """Format a timestamp as HH:MM:SS; empty string when missing or invalid."""
        try:
            return time.strftime("%H:%M:%S", time.localtime(float(ts))) if ts else ""
        except Exception:
            return ""

    recent = [
        {
            "time": _clock(entry.get("t")),
            "db_id": entry.get("db_id", ""),
            "error_type": entry.get("error_type", ""),
            "question": entry.get("question", ""),
            "error_msg": entry.get("error_msg", ""),
        }
        for entry in _QUERY_LOG[-100:]
    ]
    recent_df = pd.DataFrame(recent)

    choices = [str(row["error_type"]) for row in rows]
    default = choices[0] if choices else None
    return counts_df, recent_df, gr.update(choices=choices, value=default)
|
|
def task2_error_examples(error_type: str) -> str:
    """Return the hint and up to three most recent logged examples for ``error_type``.

    Returns ``""`` when ``error_type`` is empty.
    """
    if not error_type:
        return ""

    hint = get_hint(error_type)
    wanted = str(error_type)

    # Walk the log newest-first and stop once three examples are collected.
    picked = []
    for entry in reversed(_QUERY_LOG):
        if (entry.get("error_type") or "") == wanted:
            picked.append(entry)
            if len(picked) == 3:
                break

    if not picked:
        return f"Error type: {error_type}\nHint: {hint}\n\nNo examples yet."

    out = [f"Error type: {error_type}", f"Hint: {hint}", ""]
    for number, entry in enumerate(picked, 1):
        out.append(f"Example {number}")
        out.append(f"DB: {entry.get('db_id','')}")
        out.append(f"Q: {entry.get('question','')}")
        out.append(f"SQL: {entry.get('sql','')}")
        out.append(f"Msg: {entry.get('error_msg','')}")
        out.append("")
    return "\n".join(out).strip()
|
|
def _plot_op_stats_html() -> str:
    """Render ``_OP_STATS`` as a stacked ok/fail bar chart embedded in an ``<img>`` tag.

    Returns the HTML for a base64-encoded PNG, or a ``<pre>`` block with the
    (HTML-escaped) error text when plotting fails, e.g. when matplotlib is not
    installed.
    """
    try:
        # Build the figure directly instead of through pyplot: pyplot keeps
        # every figure in global state until it is closed, which leaks in a
        # long-running server whenever an error skips the close call.
        from matplotlib.figure import Figure

        labels = list(_OP_STATS)
        oks = [int(_OP_STATS[k]["ok"]) for k in labels]
        fails = [int(_OP_STATS[k]["fail"]) for k in labels]

        fig = Figure(figsize=(9, 3.5))
        ax = fig.subplots()
        x = list(range(len(labels)))
        ax.bar(x, oks, label="ok", color="#16a34a")
        ax.bar(x, fails, bottom=oks, label="fail", color="#dc2626")  # stacked on top of "ok"
        ax.set_xticks(x)
        ax.set_xticklabels(labels, rotation=30, ha="right")
        ax.set_title("Success/Failure by SQL operation")
        ax.legend()
        fig.tight_layout()

        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=160)
        b64 = base64.b64encode(buf.getvalue()).decode("ascii")
        return f"<img src='data:image/png;base64,{b64}' style='max-width: 100%; border: 1px solid #e2e8f0; border-radius: 8px;' />"
    except Exception as e:
        return f"<pre>Plot error: {html.escape(str(e))}</pre>"
|
|
def task2_ops_table():
    """Return ``(stats_dataframe, plot_html)`` for the per-clause ok/fail counters."""

    def _row(op_name, tally):
        # One table row per SQL clause; success_rate is 0.0 when nothing was counted.
        n_ok = int(tally.get("ok", 0))
        n_fail = int(tally.get("fail", 0))
        n_total = n_ok + n_fail
        rate = (n_ok / n_total) if n_total else 0.0
        return {"op": op_name, "ok": n_ok, "fail": n_fail, "total": n_total, "success_rate": rate}

    table = [_row(op_name, tally) for op_name, tally in _OP_STATS.items()]
    return pd.DataFrame(table), _plot_op_stats_html()
|
|
def toggle_input_method(method, current_sample):
    """Return updates for (sample dropdown, warning, custom textbox, database dropdown).

    Sample mode shows the sample dropdown and locks the database to the one
    paired with ``current_sample`` (``"chinook_1"`` when it is not a known
    sample). Any other mode shows the warning and the custom textbox and
    unlocks the database dropdown.
    """
    if method != "💡 Pick a Sample":
        return (
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=True),
            gr.update(interactive=True),
        )

    db = "chinook_1"
    for sample_question, sample_db in SAMPLES:
        if sample_question == current_sample:
            db = sample_db
            break
    return (
        gr.update(visible=True),
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(value=db, interactive=False),
    )
|
|
def load_sample(selected_question):
    """Return an update that selects the database paired with ``selected_question``.

    An empty selection yields a no-op update; an unknown question falls back to
    ``"chinook_1"``.
    """
    if not selected_question:
        return gr.update()

    db = "chinook_1"
    for sample_question, sample_db in SAMPLES:
        if sample_question == selected_question:
            db = sample_db
            break
    return gr.update(value=db)
|
|
def clear_inputs():
    """Return the updates that reset the Inference tab to its initial state.

    Order: input method, sample dropdown, warning, custom textbox, database
    dropdown, SQL output, result table, explanation.
    """
    method_reset = gr.update(value="💡 Pick a Sample")
    sample_reset = gr.update(value=SAMPLE_QUESTIONS[0], visible=True)
    warning_reset = gr.update(visible=False)
    custom_reset = gr.update(value="", visible=False)
    db_reset = gr.update(value="chinook_1", interactive=False)
    return (
        method_reset,
        sample_reset,
        warning_reset,
        custom_reset,
        db_reset,
        "",
        pd.DataFrame(),
        "",
    )
|
|
def update_schema(db_id):
    """Return the schema of ``db_id`` rendered as an HTML snippet.

    Lines of the form ``name ( ... )`` are shown as an upper-cased table name
    followed by the lower-cased content of the parentheses; other non-empty
    lines are shown as-is. All schema text and error text is HTML-escaped.
    Returns ``""`` when ``db_id`` is empty or no schema encoder is available,
    and a red error ``<div>`` when the schema cannot be loaded.
    """
    if not db_id or schema_encoder is None:
        return ""
    try:
        raw_schema = schema_encoder.structured_schema(db_id)
        parts = ["<div style='max-height: 250px; overflow-y: auto; background: #f8fafc; padding: 12px; border-radius: 8px; border: 1px solid #e2e8f0; font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace; font-size: 0.9em; line-height: 1.6;'>"]
        for line in raw_schema.strip().split('\n'):
            line = line.strip()
            if not line:
                continue
            match = re.search(r'^([a-zA-Z0-9_]+)\s*\((.*)\)', line)
            if match:
                table_name = html.escape(match.group(1).upper())
                column_text = html.escape(match.group(2).lower())
                parts.append(f"<div style='margin-bottom: 8px;'><strong style='color: #0f172a; font-size: 1.05em; font-weight: 800;'>{table_name}</strong> <span style='color: #64748b;'>( {column_text} )</span></div>")
            else:
                parts.append(f"<div style='color: #475569;'>{html.escape(line)}</div>")
        parts.append("</div>")
        return "".join(parts)
    except Exception as e:
        return f"<div style='color: red;'>Error loading schema: {html.escape(str(e))}</div>"
|
|
| |
| |
| |
# UI layout and event wiring. Components are created in display order; the
# event handlers at the bottom connect them to the functions defined above.
with gr.Blocks(title="Text-to-SQL RLHF") as demo:
    gr.HTML("""
    <div style="text-align: center; background-color: #e0e7ff; padding: 20px; border-radius: 10px; margin-bottom: 20px; border: 1px solid #c7d2fe;">
    <h1 style="color: #3730a3; margin-top: 0; margin-bottom: 10px; font-size: 2.2em;"> Text-to-SQL using RLHF + Execution Reward</h1>
    <p style="color: #4f46e5; font-size: 1.1em; margin: 0;">Convert Natural Language to SQL, strictly validated and safely executed on local SQLite databases.</p>
    </div>
    """)

    # Databases offered in the "Select Database" dropdown.
    DBS = sorted([
        "flight_1", "student_assessment", "store_1", "bike_1", "book_2",
        "chinook_1", "academic", "aircraft", "car_1", "cinema", "club_1",
        "csu_1", "college_1", "college_2", "company_1", "company_employee",
        "customer_complaints", "department_store", "employee_hire_evaluation",
        "museum_visit", "products_for_hire", "restaurant_1", "school_finance",
        "shop_membership", "small_bank_1", "student_1", "tvshow", "voter_1",
        "world_1",
    ])

    with gr.Tabs():
        # ---------------------------------------------------------------
        # Tab 1: ask a question, see the SQL and its result
        # ---------------------------------------------------------------
        with gr.Tab("Inference"):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### 1. Configuration & Input")
                    input_method = gr.Radio(
                        choices=["💡 Pick a Sample", "✍️ Type my own"],
                        value="💡 Pick a Sample",
                        label="How do you want to ask?",
                    )
                    sample_dropdown = gr.Dropdown(
                        choices=SAMPLE_QUESTIONS,
                        value=SAMPLE_QUESTIONS[0],
                        label="Select a Sample Question",
                        info="The database will be selected automatically.",
                        visible=True,
                    )
                    type_own_warning = gr.Markdown(
                        "**⚠️ Please select a Database first, then type your custom question below:**",
                        visible=False,
                    )
                    gr.Markdown("---")
                    db_id = gr.Dropdown(
                        choices=DBS,
                        value="chinook_1",
                        label="Select Database",
                        interactive=False,
                    )
                    custom_question = gr.Textbox(
                        label="Ask your Custom Question",
                        placeholder="Type your own question here...",
                        lines=3,
                        visible=False,
                    )

                    gr.Markdown("#### 📋 Database Structure")
                    gr.HTML("<p style='font-size: 0.85em; color: #64748b; margin-top: -10px; margin-bottom: 5px;'>Use these exact names! Table names are <strong>Dark</strong>, Column names are <span style='color: #94a3b8;'>Light</span>.</p>")
                    schema_display = gr.HTML(value=update_schema("chinook_1"))

                    with gr.Row():
                        clear_btn = gr.Button("🗑️ Clear", variant="secondary")
                        run_btn = gr.Button(" Generate & Run SQL", variant="primary")

                with gr.Column(scale=2):
                    gr.Markdown("### 2. Execution Results")
                    final_sql = gr.Code(language="sql", label="Final Executed SQL")
                    result_table = gr.Dataframe(label="Query Result Table", interactive=False, wrap=True)
                    explanation = gr.Textbox(label="AI Explanation + Execution Details", lines=8)

        # ---------------------------------------------------------------
        # Tab 2: benchmark runner and live telemetry
        # ---------------------------------------------------------------
        with gr.Tab("Diagnostics"):
            gr.Markdown("## Diagnostics & Telemetry")

            with gr.Accordion("Task 1: Parallel Reward Benchmark", open=False):
                gr.Markdown("*(Simulates the heavy RLHF training workload by running hundreds of complex SQL queries concurrently to test SQLite multi-threading performance.)*")
                t1_n = gr.Number(value=20, precision=0, label="Rollouts (n)")
                t1_workers = gr.Number(value=10, precision=0, label="Max workers")
                t1_run = gr.Button("Run Task 1 benchmark")
                t1_out = gr.Textbox(label="Output", lines=12)
                t1_plot = gr.HTML(label="Plot (if generated)")
                t1_run.click(
                    fn=task1_benchmark,
                    inputs=[t1_n, t1_workers],
                    outputs=[t1_out, t1_plot],
                )

            with gr.Accordion("Task 2: Error Dashboard", open=True):
                gr.Markdown("*(Live telemetry tracking the most common SQL failures. Populates automatically when queries fail in the Inference tab.)*")
                t2_refresh = gr.Button("Refresh dashboard")
                t2_counts = gr.Dataframe(label="Error counts", interactive=False, wrap=True)
                t2_recent = gr.Dataframe(label="Recent errors", interactive=False, wrap=True)
                t2_type = gr.Dropdown(choices=[], value=None, label="Select error type")
                t2_examples = gr.Textbox(label="Examples + hint", lines=10)

                t2_refresh.click(
                    fn=task2_dashboard_structured,
                    inputs=[],
                    outputs=[t2_counts, t2_recent, t2_type],
                )
                t2_type.change(
                    fn=task2_error_examples,
                    inputs=[t2_type],
                    outputs=[t2_examples],
                )

            with gr.Accordion("Task 2: Clause Telemetry", open=False):
                gr.Markdown("*(Analyzes which specific SQL clauses—SELECT, WHERE, JOIN, etc.—are most prone to errors during natural language generation.)*")
                t2_ops_refresh = gr.Button("Refresh SQL-op stats")
                t2_ops_tbl = gr.Dataframe(label="Success/failure by op", interactive=False, wrap=True)
                t2_ops_plot = gr.HTML(label="Op plot")
                t2_ops_refresh.click(
                    fn=task2_ops_table,
                    inputs=[],
                    outputs=[t2_ops_tbl, t2_ops_plot],
                )

    # -------------------------------------------------------------------
    # Event wiring for the Inference tab
    # -------------------------------------------------------------------
    input_method.change(
        fn=toggle_input_method,
        inputs=[input_method, sample_dropdown],
        outputs=[sample_dropdown, type_own_warning, custom_question, db_id],
    )
    sample_dropdown.change(fn=load_sample, inputs=[sample_dropdown], outputs=[db_id])
    db_id.change(fn=update_schema, inputs=[db_id], outputs=[schema_display])

    # Run the query, then refresh both diagnostics panels.
    run_btn.click(
        fn=run_query,
        inputs=[input_method, sample_dropdown, custom_question, db_id],
        outputs=[final_sql, result_table, explanation],
    ).then(
        fn=task2_dashboard_structured,
        inputs=[],
        outputs=[t2_counts, t2_recent, t2_type],
    ).then(
        fn=task2_ops_table,
        inputs=[],
        outputs=[t2_ops_tbl, t2_ops_plot],
    )

    clear_btn.click(
        fn=clear_inputs,
        inputs=[],
        outputs=[input_method, sample_dropdown, type_own_warning, custom_question, db_id, final_sql, result_table, explanation],
    )
|
|
if __name__ == "__main__":
    # Host and first port come from the environment; up to ``max_retries``
    # consecutive ports are tried when the port is already taken.
    server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
    base_port = int(os.environ.get("GRADIO_SERVER_PORT", 7860))
    max_retries = 10

    for port in range(base_port, base_port + max_retries):
        try:
            print(f"Attempting to start Gradio UI on {server_name}:{port}...", flush=True)
            demo.launch(server_name=server_name, server_port=port)
            break
        except OSError as e:
            port_busy = "Cannot find empty port" in str(e) or "Address already in use" in str(e)
            if not port_busy:
                # Unrelated OS error: re-raise with the original traceback.
                raise
            print(f"⚠️ Port {port} is in use, trying next port...")
    else:
        # Every candidate port was busy: report it and exit with a failure code
        # so supervisors and scripts can detect that the UI did not start.
        print(f"❌ Could not find an open port between {base_port} and {base_port + max_retries - 1}.")
        sys.exit(1)