# Muhammad Mustehson
# Update Old Code
# a360e3c
import logging
import os
from typing import Any, Dict, List
import duckdb
import gradio as gr
import lancedb
import pandas as pd
import pyarrow as pa
from dotenv import load_dotenv
from src.client import LLMChain, embed_client
from src.pipelines import SQLPipeline
load_dotenv()
# ======== Environment ========
MD_TOKEN = os.getenv("MD_TOKEN")
HF_TOKEN = os.getenv("HF_TOKEN")
# Only put the token in the connection string when it is set: interpolating an
# unset value would send the literal string "None" as the MotherDuck token.
_MD_URI = f"md:my_db?motherduck_token={MD_TOKEN}" if MD_TOKEN else "md:my_db"
# Read-only connection shared by every query in this app.
conn = duckdb.connect(_MD_URI, read_only=True)
# Verbose logging everywhere except production.
LEVEL = "WARNING" if os.getenv("ENV") == "PROD" else "INFO"
EMB_URL = os.getenv("EMB_URL")
EMB_MODEL = os.getenv("EMB_MODEL")
# Height (in text lines) of the read-only text boxes in the output tabs.
TAB_LINES = 8
# =============================
logging.basicConfig(
    level=getattr(logging, LEVEL, logging.INFO),
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
logger = logging.getLogger(__name__)
# Text-to-SQL pipeline bound to the read-only MotherDuck connection above.
pipe = SQLPipeline(duckdb=conn, chain=LLMChain())
def _setup_lancedb(table_name: str = "SQL-Queries") -> lancedb.table.Table:
    """Connect to LanceDB and return the table used to log generated queries.

    Connection settings are read from the ``lancedb_uri``, ``lancedb_api_key``
    and ``lancedb_region`` environment variables. The table is created with a
    ``vector`` (list<float32>) / ``sql-query`` (utf8) schema; if creation
    fails, the existing table is opened instead.

    Args:
        table_name: Name of the LanceDB table to create or open.

    Returns:
        The created or opened LanceDB table.
    """
    lance_db = lancedb.connect(
        uri=os.getenv("lancedb_uri"),
        api_key=os.getenv("lancedb_api_key"),
        region=os.getenv("lancedb_region"),
    )
    lance_schema = pa.schema(
        [pa.field("vector", pa.list_(pa.float32())), pa.field("sql-query", pa.utf8())]
    )
    try:
        return lance_db.create_table(name=table_name, schema=lance_schema)
    except Exception as create_exc:
        # Presumably the table already exists. Log the reason rather than
        # swallowing it, so an unrelated failure (auth, network) stays visible
        # if open_table below fails as well.
        logger.info(
            "create_table(%r) failed (%s); opening existing table.",
            table_name,
            create_exc,
        )
        return lance_db.open_table(name=table_name)


lance_table = _setup_lancedb()
def get_schemas() -> List:
    """List the user-visible schemas of the connected database.

    The ``information_schema`` and ``pg_catalog`` system schemas are filtered
    out. The result feeds the "Select Schema" dropdown.
    """
    schema_query = """
SELECT DISTINCT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('information_schema', 'pg_catalog')
"""
    rows = conn.execute(schema_query).fetchall()
    # Each row is a one-column tuple; unpack it to get the bare name.
    return [name for (name,) in rows]
def get_tables(schema_name: str) -> List:
    """Return the names of the tables that belong to ``schema_name``.

    Args:
        schema_name: Schema to list tables for. The value comes from the UI
            dropdown, which clients can set freely, so it is passed as a bound
            query parameter instead of being formatted into the SQL text.

    Returns:
        Table names in that schema. The list is empty when nothing matches.
    """
    rows = conn.execute(
        "SELECT table_name FROM information_schema.tables WHERE table_schema = ?",
        [schema_name],
    ).fetchall()
    return [row[0] for row in rows]
def update_tables(schema_name: str):
    """Refresh the table dropdown after the selected schema changes.

    The current selection is reset to ``None``. This prevents a table chosen
    under the previous schema from staying selected while the choices now
    belong to a different schema. When the schema is cleared, the choices are
    emptied without querying the database.

    Args:
        schema_name: Newly selected schema (may be ``None``/empty if cleared).

    Returns:
        A ``gr.update`` carrying the new choices and a cleared value.
    """
    if not schema_name:
        return gr.update(choices=[], value=None)
    return gr.update(choices=get_tables(schema_name), value=None)
def get_table_schema(table: str) -> str:
    """Return the DDL of ``table`` with its name rewritten as database.schema.table.

    The DDL comes from the ``sql`` column of ``duckdb_tables()``. The table name
    inside it is qualified so that SQL generated from this context references
    the table unambiguously.

    Args:
        table: Table name as selected in the UI. It is passed as a bound
            parameter, not formatted into the SQL text.

    Returns:
        The DDL string with the table name fully qualified.

    Raises:
        ValueError: If no table with that name exists.
    """
    result = conn.execute(
        "SELECT sql, database_name, schema_name FROM duckdb_tables() "
        "WHERE table_name = ?;",
        [table],
    ).df()
    if result.empty:
        raise ValueError(f"Table '{table}' was not found.")

    # NOTE(review): the same table name can exist in several schemas or
    # databases; as before, only the first matching row is used.
    ddl_create = result.iloc[0, 0]
    parent_database = result.iloc[0, 1]
    schema_name = result.iloc[0, 2]

    full_path = f"{parent_database}.{schema_name}.{table}"
    old_path = table if schema_name == "main" else f"{schema_name}.{table}"

    # Rewrite only the table-name occurrence, not every substring match.
    # Otherwise column names containing the table name (e.g. `users_id` in
    # table `users`) would be corrupted. Prefer the occurrence right after the
    # TABLE keyword; fall back to the first occurrence.
    marker = f"TABLE {old_path}"
    if marker in ddl_create:
        return ddl_create.replace(marker, f"TABLE {full_path}", 1)
    return ddl_create.replace(old_path, full_path, 1)
def run_pipeline(table: str, query_input: str) -> Dict[str, Any]:
    """Generate SQL for a natural-language question, run it and return UI updates.

    Args:
        table: Table selected in the "Available Tables" dropdown.
        query_input: The user's natural-language question.

    Returns:
        A mapping from the output components (schema, prompt, generated SQL,
        result grid) to their new values. Failures are reported through
        ``_error_response`` rather than raised.
    """
    if not table:
        return _error_response(
            query_input=query_input, message="❌ Please select a table/schema."
        )
    if not query_input or not query_input.strip():
        return _error_response(
            query_input=query_input, message="❌ Please enter a text query."
        )

    schema = ""
    try:
        schema = get_table_schema(table=table)
        sql, df = pipe.try_sql_with_retries(
            user_question=query_input,
            context=schema,
        )
    except Exception as exc:
        logger.exception("Pipeline execution failed")
        return _error_response(
            query_input=query_input, schema=schema, message=f"❌ Pipeline error: {exc}"
        )

    if not sql or df is None:
        return _error_response(
            query_input=query_input,
            schema=schema,
            message="❌ Unable to generate SQL from the input text.",
        )

    sql_query = sql.get("sql_query", "")

    # Storing the question + SQL in LanceDB is best-effort and must never fail
    # the request.
    try:
        sql_str = f"{query_input}\n{sql_query}"
        embeddings = embed_query(sql_str)
        # embed_query returns [] when embedding fails; don't store such rows.
        if embeddings:
            log2lancedb(embeddings, sql_str)
        else:
            logger.warning("Skipping LanceDB logging: no embedding was produced.")
    except Exception as exc:
        logger.warning("Embedding/logging failed: %s", exc)

    return {
        table_schema: schema,
        input_prompt: query_input,
        generated_query: sql_query,
        result_output: df,
    }
def _error_response(
    *,
    query_input: str,
    message: str,
    schema: str = "",
) -> Dict[str, Any]:
    """Build the output mapping shown when no query result can be produced.

    The schema and prompt tabs keep their values and the SQL tab is cleared.
    The result grid shows a single ``error`` row holding ``message``.
    """
    error_frame = pd.DataFrame({"error": [message]})
    response: Dict[str, Any] = {}
    response[table_schema] = schema
    response[input_prompt] = query_input
    response[generated_query] = ""
    response[result_output] = error_frame
    return response
def embed_query(data: str) -> List:
    """Embed ``data`` with the configured feature-extraction model.

    Args:
        data: Text to embed (the question followed by the generated SQL).

    Returns:
        The embedding converted with ``.tolist()``, or an empty list if the
        embedding call fails. Failures are logged, not raised, so embedding
        stays best-effort for callers.
    """
    # Lazy %-formatting: the message is only built if INFO is enabled.
    logger.info("Creating Embeddings %s", data)
    try:
        results = embed_client.feature_extraction(text=data, model=EMB_MODEL)
        # NOTE(review): assumes `results` is array-like (e.g. numpy) exposing
        # .tolist(); a 2-D result would yield a nested list — confirm shape.
        return results.tolist()
    except Exception as e:
        logger.error("Unable to Generate embedding for the given query: %s", e)
        return []
def log2lancedb(embeddings: List, sql_query: str) -> None:
    """Append one ``{"sql-query", "vector"}`` row to the LanceDB query log.

    An empty embedding is skipped with a warning instead of being written as a
    row without a usable vector. ``embed_query`` returns ``[]`` on failure.

    Args:
        embeddings: Embedding vector for ``sql_query``.
        sql_query: Text stored in the ``sql-query`` column.
    """
    # len() works for both lists and array-like vectors (truthiness would not).
    if embeddings is None or len(embeddings) == 0:
        logger.warning("Skipping LanceDB insert: empty embedding for query.")
        return
    lance_table.add([{"sql-query": sql_query, "vector": embeddings}])
    logger.info("Added to Lance DB.")
# Extra stylesheet passed to gr.Blocks(css=...): page background colour,
# `.logo` sizing/centering and button colours.
# NOTE(review): no visible component sets elem_classes="logo", and `.gr-button`
# may not match the button class names of the installed Gradio version; the
# blue button colours also differ from the purple theme hue — confirm intent.
custom_css = """
.gradio-container {
background-color: #f0f4f8;
}
.logo {
max-width: 200px;
margin: 20px auto;
display: block;
}
.gr-button {
background-color: #4a90e2 !important;
}
.gr-button:hover {
background-color: #3a7bc8 !important;
}
"""
# Gradio UI: a schema/table picker and question box on top, output tabs for
# result, SQL, prompt and schema, plus the event wiring.
# NOTE(review): this file's indentation was lost; the nesting below is a
# reconstruction — confirm the intended layout (e.g. whether the tabs belong
# inside the right-hand column).
with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"), css=custom_css
) as demo:
    # Header: logo followed by the centred title and subtitle.
    gr.Image("logo.png", label=None, show_label=False, container=False, height=100)
    gr.Markdown("""
<div style='text-align: center;'>
<strong style='font-size: 36px;'>Datajoi SQL Agent</strong>
<br>
<span style='font-size: 20px;'>Generate and Run SQL queries based on a given text for the dataset.</span>
</div>
""")
    with gr.Row():
        # Left panel: pick a schema, then one of its tables.
        with gr.Column(scale=1, variant="panel"):
            # Schema list is fetched once, when the UI is built.
            schema_dropdown = gr.Dropdown(
                choices=get_schemas(), label="Select Schema", interactive=True
            )
            # Starts empty; filled by update_tables when a schema is chosen.
            tables_dropdown = gr.Dropdown(
                choices=[], label="Available Tables", value=None
            )
        # Right panel: the natural-language question, Run button and outputs.
        with gr.Column(scale=2):
            query_input = gr.Textbox(
                lines=5, label="Text Query", placeholder="Enter your text query here..."
            )
            with gr.Row():
                # Empty wide column, presumably a spacer so the button sits
                # at the right edge.
                with gr.Column(scale=7):
                    pass
                with gr.Column(scale=1):
                    generate_query_button = gr.Button("Run Query", variant="primary")
            # Output tabs; these components are the keys of the dict that
            # run_pipeline / _error_response return.
            with gr.Tabs():
                with gr.Tab("Result"):
                    result_output = gr.DataFrame(
                        label="Query Results", value=[], interactive=False
                    )
                with gr.Tab("SQL Query"):
                    generated_query = gr.Textbox(
                        lines=TAB_LINES,
                        label="Generated SQL Query",
                        value="",
                        interactive=False,
                    )
                with gr.Tab("Prompt"):
                    input_prompt = gr.Textbox(
                        lines=TAB_LINES,
                        label="Input Prompt",
                        value="",
                        interactive=False,
                    )
                with gr.Tab("Schema"):
                    table_schema = gr.Textbox(
                        lines=TAB_LINES,
                        label="Table Schema",
                        value="",
                        interactive=False,
                    )
    # Re-populate the table list whenever the selected schema changes.
    schema_dropdown.change(
        update_tables, inputs=schema_dropdown, outputs=tables_dropdown
    )
    # Run text-to-SQL for the selected table and question.
    generate_query_button.click(
        run_pipeline,
        inputs=[tables_dropdown, query_input],
        outputs=[table_schema, input_prompt, generated_query, result_output],
    )

# Start the Gradio server only when executed as a script.
if __name__ == "__main__":
    demo.launch()