import ast
import logging
import os

import pandas as pd
from dotenv import load_dotenv
from duckdb import DuckDBPyConnection

from src.models import PanderaSchemaModel, SQLQueryModel

load_dotenv()

logger = logging.getLogger(__name__)

# Number of extra attempts made after the first one; the first attempt is not
# counted as a retry, so up to SQL_GENERATION_RETRIES + 1 attempts are made.
SQL_GENERATION_RETRIES = int(os.getenv("SQL_GENERATION_RETRIES", "5"))
# Prompt templates come from the environment (optionally populated from a
# .env file by load_dotenv above); each is None when its variable is unset.
PANDERA_PROMPT = os.getenv("PANDERA_PROMPT")
PANDERA_USER_PROMPT = os.getenv("PANDERA_USER_PROMPT")
SQL_PROMPT = os.getenv("SQL_PROMPT")
USER_PROMPT = os.getenv("USER_PROMPT")

# Shape of whatever ``chain.run`` returns for a SQL request.
_SQLResult = str | dict[str, str | int | float | None] | list[str] | None


class Query2Schema:
    """Generate SQL from a question, run it on DuckDB, and derive a Pandera schema.

    SQL text and Pandera schemas are both produced by an LLM ``chain``.
    Generated SQL is executed against the supplied DuckDB connection.
    """

    def __init__(
        self,
        duckdb: DuckDBPyConnection,
        chain,
    ) -> None:
        """Store the DuckDB connection and the LLM chain.

        Args:
            duckdb: Connection used to execute generated SQL.
            chain: Object exposing ``run(system_prompt=..., user_prompt=...,
                format_name=..., response_format=...)``. Its concrete type is
                defined elsewhere.
        """
        self._duckdb = duckdb  # connection that executes generated SQL
        self.chain = chain  # LLM wrapper used for SQL and schema generation

    def generate_sql(
        self, user_question: str, context: str, errors: str | None = None
    ) -> _SQLResult:
        """Ask the LLM chain for a SQL query (plus description) answering a question.

        Args:
            user_question: Natural-language question to answer.
            context: Context text substituted into the ``USER_PROMPT`` template.
            errors: Accumulated failure descriptions from earlier attempts.
                When non-empty, they are appended to the prompt so the model
                can avoid repeating the same mistake.

        Returns:
            Whatever ``self.chain.run`` returns for the ``SQLQueryModel``
            response format. Callers in this class read a dict result's query
            from its ``"sql_query"`` key.
        """
        user_prompt_formatted = USER_PROMPT.format(
            question=user_question, context=context
        )
        if errors:
            # Start a new paragraph so the instruction is not glued onto the
            # end of the formatted user prompt.
            user_prompt_formatted += (
                "\n\nCarefully review the previous error or exception and "
                "rewrite the SQL so that the error does not occur again. "
                "Try a different approach or rewrite SQL if needed.\n"
                f"Last error: {errors}"
            )
        sql = self.chain.run(
            system_prompt=SQL_PROMPT,
            user_prompt=user_prompt_formatted,
            format_name="sql_query",
            response_format=SQLQueryModel,
        )
        logger.info("SQL Generated Successfully: %s", sql)
        return sql

    def run_query(self, sql_query: str) -> pd.DataFrame | None:
        """Execute ``sql_query`` on the DuckDB connection and return a DataFrame.

        Returns:
            The result as a pandas DataFrame, or ``None`` when DuckDB yields no
            relation. The None case presumably covers statements without a
            result set, such as DDL; confirm against the DuckDB version in use.

        Raises:
            Any DuckDB error raised while executing the query propagates.
        """
        # NOTE(review): the SQL comes from an LLM prompted with user input. If
        # the connection is writable, consider opening it read-only.
        logger.info("Query Execution Started.")
        relation = self._duckdb.query(sql_query)
        if relation is None:
            return None
        return relation.df()

    def try_sql_with_retries(
        self,
        user_question: str,
        context: str,
        max_retries: int = SQL_GENERATION_RETRIES,
    ) -> tuple[_SQLResult, pd.DataFrame | None]:
        """Generate and execute SQL, retrying with failure feedback.

        The first attempt is not counted as a retry, so up to
        ``max_retries + 1`` attempts are made, and always at least one.
        An attempt fails when it raises, or when its query produces no rows
        or no result set. Each failure is described, accumulated, and passed
        to :meth:`generate_sql` on the next attempt.

        Returns:
            ``(sql, dataframe)`` for the first attempt whose query returns a
            non-empty DataFrame. Returns ``(None, None)`` if the model returns
            a falsy SQL result, or if every attempt fails.
        """
        all_errors = ""
        total_attempts = max(max_retries, 0) + 1
        for attempt in range(1, total_attempts + 1):
            if attempt > 1:
                logger.info("Retrying: %d", attempt - 1)
            try:
                sql = self.generate_sql(
                    user_question, context, errors=all_errors or None
                )
                if not sql:
                    return None, None

                # Structured output may wrap the query in a dict.
                sql_query_str = (
                    sql.get("sql_query") if isinstance(sql, dict) else sql
                )
                if not isinstance(sql_query_str, str):
                    raise ValueError(
                        "Expected SQL query to be a string, got "
                        f"{type(sql_query_str).__name__}"
                    )
                query_df = self.run_query(sql_query_str)
            except Exception as e:  # any failure is fed back to the model
                failure = f"{type(e).__name__}: {e}"
            else:
                if query_df is not None and not query_df.empty:
                    return sql, query_df
                # Without this feedback the retry prompt would be identical to
                # the one that just produced an unusable result.
                failure = (
                    "Query returned no result set."
                    if query_df is None
                    else "Query returned no rows."
                )

            last_error = f"\n[Attempt {attempt}] {failure}"
            logger.error("Error during SQL generation or execution: %s", last_error)
            all_errors += last_error

        logger.error(
            "Failed after %d attempts.\nErrors: %s", total_attempts, all_errors
        )
        return None, None

    def generate_pandera_schema(self, sql_query: str, user_instruction: str) -> str:
        """Ask the LLM chain for a Pandera schema and keep only its class definitions.

        Args:
            sql_query: SQL whose result the schema should describe.
            user_instruction: Extra validation instructions for the model.

        Returns:
            The source text of every top-level ``class`` in the model output,
            including any decorators, joined with newlines. Imports and other
            top-level statements in the output are discarded.

        Raises:
            SyntaxError: If the model output is not valid Python.
        """
        schema_str = self.chain.run(
            system_prompt=PANDERA_PROMPT,
            user_prompt=PANDERA_USER_PROMPT.format(
                sql_query=sql_query, instructions=user_instruction
            ),
            format_name="pandera_schema",
            response_format=PanderaSchemaModel,
        )
        # ast.parse only builds a syntax tree; the model output is not executed.
        parsed = ast.parse(schema_str)
        original_lines = schema_str.splitlines()
        class_lines: list[str] = []
        for node in parsed.body:
            if not isinstance(node, ast.ClassDef):
                continue
            # Since Python 3.8, ClassDef.lineno points at the ``class`` keyword.
            # Start at the first decorator instead so decorators are kept.
            first_line = min([node.lineno, *(d.lineno for d in node.decorator_list)])
            class_lines.extend(original_lines[first_line - 1 : node.end_lineno])
        return "\n".join(class_lines)