import logging import os from typing import Any, Dict, List import duckdb import gradio as gr import lancedb import pandas as pd import pyarrow as pa from dotenv import load_dotenv from src.client import LLMChain, embed_client from src.pipelines import SQLPipeline load_dotenv() # ========ENV's======== MD_TOKEN = os.getenv("MD_TOKEN") HF_TOKEN = os.getenv("HF_TOKEN") conn = duckdb.connect(f"md:my_db?motherduck_token={MD_TOKEN}", read_only=True) LEVEL = "INFO" if not os.getenv("ENV") == "PROD" else "WARNING" EMB_URL = os.getenv("EMB_URL") EMB_MODEL = os.getenv("EMB_MODEL") TAB_LINES = 8 # ===================== logging.basicConfig( level=getattr(logging, LEVEL, logging.INFO), format="%(asctime)s %(levelname)s %(name)s: %(message)s", ) logger = logging.getLogger(__name__) pipe = SQLPipeline(duckdb=conn, chain=LLMChain()) def _setup_lancedb() -> lancedb.table.Table: lance_db = lancedb.connect( uri=os.getenv("lancedb_uri"), api_key=os.getenv("lancedb_api_key"), region=os.getenv("lancedb_region"), ) lance_schema = pa.schema( [pa.field("vector", pa.list_(pa.float32())), pa.field("sql-query", pa.utf8())] ) try: table = lance_db.create_table(name="SQL-Queries", schema=lance_schema) except Exception: table = lance_db.open_table(name="SQL-Queries") return table lance_table = _setup_lancedb() def get_schemas() -> List: schemas = conn.execute(""" SELECT DISTINCT schema_name FROM information_schema.schemata WHERE schema_name NOT IN ('information_schema', 'pg_catalog') """).fetchall() return [item[0] for item in schemas] def get_tables(schema_name: str) -> List: tables = conn.execute( f"SELECT table_name FROM information_schema.tables WHERE table_schema = '{schema_name}'" ).fetchall() return [table[0] for table in tables] def update_tables(schema_name: str): tables = get_tables(schema_name) return gr.update(choices=tables) def get_table_schema(table: str) -> str: result = conn.sql( f"SELECT sql, database_name, schema_name FROM duckdb_tables() where table_name ='{table}';" ).df() ddl_create = result.iloc[0, 0] parent_database = result.iloc[0, 1] schema_name = result.iloc[0, 2] full_path = f"{parent_database}.{schema_name}.{table}" if schema_name != "main": old_path = f"{schema_name}.{table}" else: old_path = table ddl_create = ddl_create.replace(old_path, full_path) return ddl_create def run_pipeline(table: str, query_input: str) -> Dict[str, Any]: if table is None: return _error_response( query_input=query_input, message="❌ Please select a table/schema." ) schema = "" try: schema = get_table_schema(table=table) sql, df = pipe.try_sql_with_retries( user_question=query_input, context=schema, ) if not sql or df is None: return _error_response( query_input=query_input, schema=schema, message="❌ Unable to generate SQL from the input text.", ) except Exception as exc: logger.exception("Pipeline execution failed") return _error_response( query_input=query_input, schema=schema, message=f"❌ Pipeline error: {exc}" ) try: sql_str = f"{query_input}\n{sql.get('sql_query', '')}" embeddings = embed_query(sql_str) log2lancedb(embeddings, sql_str) except Exception as exc: logger.warning("Embedding/logging failed: %s", exc) return { table_schema: schema, input_prompt: query_input, generated_query: sql.get("sql_query", ""), result_output: df, } def _error_response( *, query_input: str, message: str, schema: str = "", ) -> Dict[str, Any]: return { table_schema: schema, input_prompt: query_input, generated_query: "", result_output: pd.DataFrame([{"error": message}]), } def embed_query(data: str) -> List: logger.info(f"Creating Emebeddings {data}") try: results = embed_client.feature_extraction(text=data, model=EMB_MODEL) return results.tolist() except Exception as e: logger.error(f"Unable to Generate embedding for the given query: {e}") return [] def log2lancedb(embeddings: List, sql_query: str) -> None: data = [{"sql-query": sql_query, "vector": embeddings}] lance_table.add(data) logger.info("Added to Lance DB.") custom_css = """ .gradio-container { background-color: #f0f4f8; } .logo { max-width: 200px; margin: 20px auto; display: block; } .gr-button { background-color: #4a90e2 !important; } .gr-button:hover { background-color: #3a7bc8 !important; } """ with gr.Blocks( theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"), css=custom_css ) as demo: gr.Image("logo.png", label=None, show_label=False, container=False, height=100) gr.Markdown("""