Spaces:
Running
Running
| import logging | |
| import os | |
| from typing import Any, Dict, List | |
| import duckdb | |
| import gradio as gr | |
| import lancedb | |
| import pandas as pd | |
| import pyarrow as pa | |
| from dotenv import load_dotenv | |
| from src.client import LLMChain, embed_client | |
| from src.pipelines import SQLPipeline | |
load_dotenv()

# ======== Environment-driven configuration ========
MD_TOKEN = os.getenv("MD_TOKEN")
HF_TOKEN = os.getenv("HF_TOKEN")
# Read-only MotherDuck connection; the token is passed in the connection string.
conn = duckdb.connect(f"md:my_db?motherduck_token={MD_TOKEN}", read_only=True)
# Quieter logging in production: WARNING when ENV == "PROD", INFO otherwise.
LEVEL = "WARNING" if os.getenv("ENV") == "PROD" else "INFO"
EMB_URL = os.getenv("EMB_URL")
EMB_MODEL = os.getenv("EMB_MODEL")
TAB_LINES = 8  # height (in lines) of the read-only textboxes in the result tabs
# ==================================================

logging.basicConfig(
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    level=getattr(logging, LEVEL, logging.INFO),
)
logger = logging.getLogger(__name__)

# Text-to-SQL pipeline bound to the MotherDuck connection opened above.
pipe = SQLPipeline(chain=LLMChain(), duckdb=conn)
def _setup_lancedb() -> lancedb.table.Table:
    """Connect to LanceDB and return the "SQL-Queries" table.

    Connection settings come from the ``lancedb_uri``, ``lancedb_api_key`` and
    ``lancedb_region`` environment variables. The function first tries to create
    the table with a (vector, sql-query) schema. If creation raises for any
    reason (presumably because the table already exists), it opens the existing
    table instead.
    """
    db = lancedb.connect(
        uri=os.getenv("lancedb_uri"),
        api_key=os.getenv("lancedb_api_key"),
        region=os.getenv("lancedb_region"),
    )
    columns = [
        pa.field("vector", pa.list_(pa.float32())),
        pa.field("sql-query", pa.utf8()),
    ]
    query_log_schema = pa.schema(columns)
    try:
        return db.create_table(name="SQL-Queries", schema=query_log_schema)
    except Exception:
        # Any creation failure falls back to opening the existing table.
        return db.open_table(name="SQL-Queries")


# Module-level table handle that log2lancedb() appends to.
lance_table = _setup_lancedb()
def get_schemas() -> List:
    """Return the names of the schemas visible on ``conn``.

    The ``information_schema`` and ``pg_catalog`` system schemas are excluded.
    """
    query = """
SELECT DISTINCT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('information_schema', 'pg_catalog')
"""
    rows = conn.execute(query).fetchall()
    # Each row holds a single column: the schema name.
    return [name for (name,) in rows]
def get_tables(schema_name: str) -> List:
    """Return the names of the tables in ``schema_name``.

    Args:
        schema_name: Schema to list, typically the value selected in the schema
            dropdown. ``None`` or an empty string yields an empty list.

    Returns:
        Table names as plain strings. The list is empty when the schema has no
        tables or does not exist.
    """
    if not schema_name:
        return []
    # Bind the schema name as a query parameter instead of interpolating it
    # into the SQL text: dropdown values come from the client and must not be
    # trusted to be free of quotes or SQL.
    tables = conn.execute(
        "SELECT table_name FROM information_schema.tables WHERE table_schema = ?",
        [schema_name],
    ).fetchall()
    return [table[0] for table in tables]
def update_tables(schema_name: str):
    """Refresh the table dropdown after the selected schema changes.

    Args:
        schema_name: The newly selected schema.

    Returns:
        A Gradio update that replaces the dropdown's choices with the tables of
        ``schema_name``. It also clears the current selection, so a table from
        the previously selected schema cannot stay selected while no longer
        being among the choices.
    """
    tables = get_tables(schema_name)
    return gr.update(choices=tables, value=None)
def get_table_schema(table: str) -> str:
    """Return the CREATE TABLE DDL for ``table`` with a fully qualified name.

    The DDL is read from ``duckdb_tables()``. The table reference inside it is
    rewritten to ``<database>.<schema>.<table>``, presumably so that SQL
    generated from this context targets the right table.

    Args:
        table: Table name, as selected in the tables dropdown.

    Returns:
        The CREATE TABLE statement with the table reference qualified.

    Raises:
        ValueError: If no table with that name exists.

    NOTE(review): if several schemas contain a table with this name, the first
    row returned by duckdb_tables() is used — confirm that is acceptable.
    """
    # Bind the name as a parameter rather than interpolating client input.
    row = conn.execute(
        "SELECT sql, database_name, schema_name FROM duckdb_tables() "
        "WHERE table_name = ?;",
        [table],
    ).fetchone()
    if row is None:
        raise ValueError(f"Table not found: {table!r}")
    ddl_create, parent_database, schema_name = row

    full_path = f"{parent_database}.{schema_name}.{table}"
    # The code assumes tables in "main" appear unqualified in the DDL and all
    # others appear as "<schema>.<table>".
    old_path = table if schema_name == "main" else f"{schema_name}.{table}"

    # Rewrite only the table reference that follows "CREATE TABLE ". Replacing
    # every occurrence would also rewrite column names or literals that happen
    # to contain the table name.
    anchor = "CREATE TABLE "
    anchor_pos = ddl_create.find(anchor)
    search_from = anchor_pos + len(anchor) if anchor_pos != -1 else 0
    idx = ddl_create.find(old_path, search_from)
    if idx == -1:
        return ddl_create
    return ddl_create[:idx] + full_path + ddl_create[idx + len(old_path):]
def _log_query(query_input: str, sql_query: str) -> None:
    """Best-effort: embed the question plus SQL and append it to LanceDB.

    The question and the SQL are joined by a newline before embedding. Failures
    are logged as warnings and never propagate to the caller.
    """
    sql_str = f"{query_input}\n{sql_query}"
    try:
        embeddings = embed_query(sql_str)
        if not embeddings:
            # embed_query returns [] when embedding fails; storing a row
            # without a vector would only pollute the table.
            logger.warning("Skipping LanceDB logging: no embedding was produced.")
            return
        log2lancedb(embeddings, sql_str)
    except Exception as exc:
        logger.warning("Embedding/logging failed: %s", exc)


def run_pipeline(table: str, query_input: str) -> Dict[str, Any]:
    """Generate and run SQL for ``query_input`` against ``table``.

    The function works in three steps:
      1. Load the table's DDL with ``get_table_schema``; it becomes the LLM
         context.
      2. Ask ``pipe.try_sql_with_retries`` for a SQL query and its result
         frame.
      3. Best-effort: embed the question plus SQL and log it to LanceDB.

    Args:
        table: The selected table name, or ``None``/empty when nothing is
            selected.
        query_input: The user's natural-language question.

    Returns:
        A mapping from the output components (schema, prompt, SQL, result) to
        their new values. Failures are reported through ``_error_response``
        rather than raised, so the UI always receives a complete update.
    """
    if not table:
        return _error_response(
            query_input=query_input, message="❌ Please select a table/schema."
        )
    if not (query_input or "").strip():
        # Avoid an LLM round-trip for an empty question.
        return _error_response(
            query_input=query_input, message="❌ Please enter a text query."
        )

    schema = ""
    try:
        schema = get_table_schema(table=table)
        sql, df = pipe.try_sql_with_retries(
            user_question=query_input,
            context=schema,
        )
    except Exception as exc:
        # UI boundary: show the failure in the result tab instead of raising.
        logger.exception("Pipeline execution failed")
        return _error_response(
            query_input=query_input, schema=schema, message=f"❌ Pipeline error: {exc}"
        )

    if not sql or df is None:
        return _error_response(
            query_input=query_input,
            schema=schema,
            message="❌ Unable to generate SQL from the input text.",
        )

    # `sql` is used as a mapping with a "sql_query" entry (see .get below).
    sql_query = sql.get("sql_query", "")
    _log_query(query_input, sql_query)
    return {
        table_schema: schema,
        input_prompt: query_input,
        generated_query: sql_query,
        result_output: df,
    }
def _error_response(
    *,
    query_input: str,
    message: str,
    schema: str = "",
) -> Dict[str, Any]:
    """Build the UI update used when the pipeline produces no result.

    The prompt and schema tabs still show ``query_input`` and ``schema``, and
    the SQL tab is cleared. The result table shows a one-row frame whose single
    ``error`` column holds ``message``.
    """
    error_frame = pd.DataFrame({"error": [message]})
    no_sql = ""
    return {
        table_schema: schema,
        input_prompt: query_input,
        generated_query: no_sql,
        result_output: error_frame,
    }
def embed_query(data: str) -> List:
    """Embed ``data`` with ``embed_client`` using the ``EMB_MODEL`` model.

    Args:
        data: Text to embed.

    Returns:
        The embedding converted with ``.tolist()``. If the embedding call fails,
        the error is logged with its traceback and an empty list is returned.
    """
    # Lazy %-formatting: the (possibly long) text is only formatted when INFO
    # logging is enabled.
    logger.info("Creating embeddings for: %s", data)
    try:
        results = embed_client.feature_extraction(text=data, model=EMB_MODEL)
        # NOTE(review): assumes `results` is an array-like holding a single 1-D
        # embedding — confirm; a 2-D (per-token) result would become a nested
        # list that does not fit the LanceDB "vector" column.
        return results.tolist()
    except Exception:
        logger.exception("Unable to generate embedding for the given query")
        return []
def log2lancedb(embeddings: List, sql_query: str) -> None:
    """Append one (sql-query, vector) row to the module-level ``lance_table``.

    Args:
        embeddings: Vector stored in the "vector" column.
        sql_query: Text stored in the "sql-query" column.
    """
    row = {"sql-query": sql_query, "vector": embeddings}
    lance_table.add([row])
    logger.info("Added to Lance DB.")
# Custom CSS passed to gr.Blocks(css=...) below.
# NOTE(review): no component in this file sets elem_classes="logo", so the
# .logo rule looks unused — confirm before relying on it.
custom_css = """
.gradio-container {
background-color: #f0f4f8;
}
.logo {
max-width: 200px;
margin: 20px auto;
display: block;
}
.gr-button {
background-color: #4a90e2 !important;
}
.gr-button:hover {
background-color: #3a7bc8 !important;
}
"""
# Build the Gradio UI. The Soft theme uses purple/indigo hues; the CSS above
# overrides the button background colors.
# NOTE(review): the nesting of the layout below was reconstructed from a paste
# that lost its indentation — confirm it against the deployed app.
with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"), css=custom_css
) as demo:
    # Header: logo image followed by the centered title and subtitle.
    gr.Image("logo.png", label=None, show_label=False, container=False, height=100)
    gr.Markdown("""
<div style='text-align: center;'>
<strong style='font-size: 36px;'>Datajoi SQL Agent</strong>
<br>
<span style='font-size: 20px;'>Generate and Run SQL queries based on a given text for the dataset.</span>
</div>
""")
    with gr.Row():
        # Left panel: pick a schema, then one of its tables.
        with gr.Column(scale=1, variant="panel"):
            # Schema choices are queried once, when the UI is built.
            schema_dropdown = gr.Dropdown(
                choices=get_schemas(), label="Select Schema", interactive=True
            )
            # Starts empty; filled by update_tables when a schema is chosen.
            tables_dropdown = gr.Dropdown(
                choices=[], label="Available Tables", value=None
            )
        # Right panel: question input, run button and the output tabs.
        with gr.Column(scale=2):
            query_input = gr.Textbox(
                lines=5, label="Text Query", placeholder="Enter your text query here..."
            )
            with gr.Row():
                # Presumably an empty spacer that pushes the button to the right.
                with gr.Column(scale=7):
                    pass
                with gr.Column(scale=1):
                    generate_query_button = gr.Button("Run Query", variant="primary")
            # Read-only output components filled by run_pipeline.
            with gr.Tabs():
                with gr.Tab("Result"):
                    result_output = gr.DataFrame(
                        label="Query Results", value=[], interactive=False
                    )
                with gr.Tab("SQL Query"):
                    generated_query = gr.Textbox(
                        lines=TAB_LINES,
                        label="Generated SQL Query",
                        value="",
                        interactive=False,
                    )
                with gr.Tab("Prompt"):
                    input_prompt = gr.Textbox(
                        lines=TAB_LINES,
                        label="Input Prompt",
                        value="",
                        interactive=False,
                    )
                with gr.Tab("Schema"):
                    table_schema = gr.Textbox(
                        lines=TAB_LINES,
                        label="Table Schema",
                        value="",
                        interactive=False,
                    )
    # Refresh the table list whenever the selected schema changes.
    schema_dropdown.change(
        update_tables, inputs=schema_dropdown, outputs=tables_dropdown
    )
    # run_pipeline returns a dict keyed by these four output components.
    generate_query_button.click(
        run_pipeline,
        inputs=[tables_dropdown, query_input],
        outputs=[table_schema, input_prompt, generated_query, result_output],
    )
# Launch the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()