Spaces:
Sleeping
Sleeping
File size: 3,507 Bytes
a360e3c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 | import logging
import os
import pandas as pd
from duckdb import DuckDBPyConnection
from src.models import SQLQueryModel
from src.prompts import SQL_PROMPT, USER_PROMPT
# Logger named after this module.
logger = logging.getLogger(name=__name__)

# Default number of *extra* attempts (after the first) made by
# SQLPipeline.try_sql_with_retries. Read once at import time from the
# environment variable of the same name; falls back to 5.
SQL_GENERATION_RETRIES = int(os.environ.get("SQL_GENERATION_RETRIES", "5"))
class SQLPipeline:
    """Turn a natural-language question into SQL and run it on DuckDB.

    SQL is produced by ``chain.run`` (presumably an LLM client -- confirm
    against the chain implementation) from ``SQL_PROMPT`` / ``USER_PROMPT``
    and then executed on the wrapped DuckDB connection.
    ``try_sql_with_retries`` feeds failures back into the prompt and retries.
    """

    def __init__(
        self,
        duckdb: DuckDBPyConnection,
        chain,
    ) -> None:
        """Store the DuckDB connection and the SQL-generating chain.

        Args:
            duckdb: Connection on which generated SQL is executed.
            chain: Object exposing ``run(system_prompt=..., user_prompt=...,
                format_name=..., response_format=...)``; used by
                ``generate_sql``.
        """
        self._duckdb = duckdb
        self.chain = chain

    def generate_sql(
        self, user_question: str, context: str, errors: str | None = None
    ) -> str | dict[str, str | int | float | None] | list[str] | None:
        """Generate SQL (plus description) for *user_question*.

        Args:
            user_question: The question, substituted into ``USER_PROMPT``.
            context: Extra context substituted into ``USER_PROMPT``
                (presumably schema/table info -- confirm with the prompt).
            errors: Error text from earlier failed attempts. When non-empty,
                an instruction to avoid repeating it is appended to the prompt.

        Returns:
            Whatever ``self.chain.run`` returns, unchanged.
            ``try_sql_with_retries`` treats a falsy value as "nothing
            generated".
        """
        user_prompt_formatted = USER_PROMPT.format(
            question=user_question, context=context
        )
        if errors:
            # Implicit literal concatenation instead of a backslash
            # continuation inside the string: the continuation glued words
            # together ("orexception", "again.Try") or leaked source
            # indentation into the prompt. The leading blank line separates
            # this instruction from the formatted USER_PROMPT.
            user_prompt_formatted += (
                "\n\nCarefully review the previous error or exception and "
                "rewrite the SQL so that the error does not occur again. "
                "Try a different approach or rewrite SQL if needed. "
                f"Last error: {errors}"
            )
        sql = self.chain.run(
            system_prompt=SQL_PROMPT,
            user_prompt=user_prompt_formatted,
            format_name="sql_query",
            response_format=SQLQueryModel,
        )
        if sql:
            logger.info("SQL Generated Successfully: %s", sql)
        else:
            logger.warning("SQL generation returned no result: %r", sql)
        return sql

    def run_query(self, sql_query: str) -> pd.DataFrame | None:
        """Execute *sql_query* on the DuckDB connection.

        NOTE: the SQL is executed as-is (it is model-generated); there is no
        sandboxing or statement filtering here.

        Returns:
            The result as a pandas DataFrame, or ``None`` when DuckDB returns
            no relation for the statement (e.g. a statement that is not a
            query).
        """
        logger.info("Query Execution Started.")
        relation = self._duckdb.query(sql_query)
        if relation is None:
            # Calling .df() here would raise AttributeError on None.
            return None
        return relation.df()

    @staticmethod
    def _extract_query_text(
        sql: str | dict[str, str | int | float | None] | list[str],
    ) -> str:
        """Return the SQL text from a ``generate_sql`` result.

        A dict result is read via its ``"sql_query"`` key; anything else is
        used directly.

        Raises:
            ValueError: if the extracted value is not a string.
        """
        sql_query_str = sql.get("sql_query") if isinstance(sql, dict) else sql
        if not isinstance(sql_query_str, str):
            raise ValueError(
                f"Expected SQL query to be a string, got {type(sql_query_str).__name__}"
            )
        return sql_query_str

    def try_sql_with_retries(
        self,
        user_question: str,
        context: str,
        max_retries: int = SQL_GENERATION_RETRIES,
    ) -> tuple[
        str | dict[str, str | int | float | None] | list[str] | None,
        pd.DataFrame | None,
    ]:
        """Generate and execute SQL, retrying with error feedback on failure.

        Makes one initial attempt plus up to *max_retries* retries (negative
        values are treated as 0). An attempt fails when generation or
        execution raises, or when execution yields no result set or no rows.
        Each failure reason is appended to an error log that is passed to
        ``generate_sql`` on the next attempt.

        Args:
            user_question: Natural-language question to answer.
            context: Context forwarded to ``generate_sql``.
            max_retries: Retries after the first attempt; defaults to the
                module-level ``SQL_GENERATION_RETRIES``.

        Returns:
            ``(sql, dataframe)`` from the first attempt that produced a
            non-empty DataFrame, or ``(None, None)`` if generation returned
            nothing or every attempt failed.
        """
        total_attempts = max(0, max_retries) + 1
        all_errors = ""
        last_error: str | None = None
        for attempt in range(1, total_attempts + 1):
            try:
                if attempt > 1:
                    logger.info("Retrying: %d", attempt - 1)
                # Pass every earlier failure back to the model (None on the
                # first attempt, so the prompt is unchanged then).
                sql = self.generate_sql(
                    user_question, context, errors=all_errors or None
                )
                if not sql:
                    # Nothing to execute; give up rather than loop.
                    return None, None
                query_df = self.run_query(self._extract_query_text(sql))
            except Exception as e:
                # Broad on purpose: generated SQL can fail in many ways, and
                # each failure is fed back to the model on the next attempt.
                last_error = f"\n[Attempt {attempt}] {type(e).__name__}: {e}"
                logger.error(
                    "Error during SQL generation or execution: %s", last_error
                )
                all_errors += last_error
                continue

            if query_df is not None and not query_df.empty:
                return sql, query_df

            # Execution succeeded but produced nothing usable. Record why so
            # the next prompt is not a blind repeat of this one.
            reason = (
                "query produced no result set"
                if query_df is None
                else "query returned no rows"
            )
            last_error = f"\n[Attempt {attempt}] EmptyResult: {reason}"
            logger.warning("SQL executed but returned nothing: %s", last_error)
            all_errors += last_error

        logger.error(
            "Failed after %d attempts. Last error: %s", total_attempts, last_error
        )
        return None, None
|