File size: 3,507 Bytes
a360e3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import logging
import os

import pandas as pd
from duckdb import DuckDBPyConnection

from src.models import SQLQueryModel
from src.prompts import SQL_PROMPT, USER_PROMPT

# Logger shared by everything in this module, named after the module itself.
logger = logging.getLogger(name=__name__)


# Extra generate-and-execute rounds allowed after the first attempt.
# Overridable through the SQL_GENERATION_RETRIES environment variable
# (default "5"); a non-integer value makes int() raise ValueError at import.
SQL_GENERATION_RETRIES = int(os.environ.get("SQL_GENERATION_RETRIES", "5"))


class SQLPipeline:
    """Generate SQL for a natural-language question and execute it on DuckDB.

    SQL text comes from ``chain.run(...)`` (presumably an LLM client; it is
    called with a system prompt, a user prompt and a structured response
    format). It is executed against the supplied DuckDB connection.
    ``try_sql_with_retries`` regenerates the SQL, feeding back earlier
    failures, until a query yields a non-empty DataFrame or the attempt
    budget runs out.
    """

    def __init__(
        self,
        duckdb: DuckDBPyConnection,
        chain,
    ) -> None:
        """Store the DuckDB connection and the SQL-generating chain.

        Args:
            duckdb: Connection used to execute generated queries.
            chain: Object exposing ``run(system_prompt=..., user_prompt=...,
                format_name=..., response_format=...)`` that returns the SQL.
        """
        self._duckdb = duckdb
        self.chain = chain

    def generate_sql(
        self, user_question: str, context: str, errors: str | None = None
    ) -> str | dict[str, str | int | float | None] | list[str] | None:
        """Generate SQL + description.

        Args:
            user_question: Natural-language question to answer.
            context: Text substituted into ``USER_PROMPT`` alongside the
                question (presumably schema/table information; confirm
                with callers).
            errors: Accumulated feedback from earlier attempts. When
                non-empty, an instruction to avoid those errors is appended
                to the user prompt.

        Returns:
            Whatever ``self.chain.run`` returns for the ``sql_query`` format:
            per the annotation, a string, a dict, a list of strings, or None.
        """
        user_prompt_formatted = USER_PROMPT.format(
            question=user_question, context=context
        )
        if errors:
            # Implicit literal concatenation (instead of a backslash
            # continuation inside the string) keeps source indentation out of
            # the prompt; the blank line separates it from the question.
            user_prompt_formatted += (
                "\n\nCarefully review the previous error or exception and "
                "rewrite the SQL so that the error does not occur again. "
                "Try a different approach or rewrite SQL if needed. "
                f"Last error: {errors}"
            )

        sql = self.chain.run(
            system_prompt=SQL_PROMPT,
            user_prompt=user_prompt_formatted,
            format_name="sql_query",
            response_format=SQLQueryModel,
        )
        logger.info("SQL Generated Successfully: %s", sql)
        return sql

    def run_query(self, sql_query: str) -> pd.DataFrame | None:
        """Execute SQL and return dataframe.

        Returns:
            The result as a pandas DataFrame, or None when DuckDB produces no
            result relation for the statement (``query`` returned None).
            Previously that case failed with an AttributeError on ``None.df()``.
        """
        logger.info("Query Execution Started.")
        # NOTE(review): sql_query is model-generated and executed verbatim on
        # this connection; consider a read-only connection if the database
        # must not be modified.
        relation = self._duckdb.query(sql_query)
        if relation is None:
            return None
        return relation.df()

    def try_sql_with_retries(
        self,
        user_question: str,
        context: str,
        max_retries: int = SQL_GENERATION_RETRIES,
    ) -> tuple[
        str | dict[str, str | int | float | None] | list[str] | None,
        pd.DataFrame | None,
    ]:
        """Try SQL generation + execution with retries.

        Makes one initial attempt plus up to ``max_retries`` retries (a
        negative value is treated as 0). Each attempt:
        - generates SQL, passing the feedback accumulated from earlier failed
          attempts (exceptions and empty results);
        - executes the SQL;
        - stops as soon as the query yields a non-empty DataFrame.

        Args:
            user_question: Natural-language question to answer.
            context: Context text forwarded to ``generate_sql``.
            max_retries: Number of retries after the first attempt.

        Returns:
            ``(sql, dataframe)`` on success. ``(None, None)`` in two cases:
            generation returns a falsy value (no further retries are made),
            or every attempt raised or produced no rows.
        """
        all_errors = ""
        # The first attempt is not counted as a retry; always make at least one.
        total_attempts = max(max_retries, 0) + 1

        for attempt in range(1, total_attempts + 1):
            if attempt > 1:
                logger.info("Retrying: %s", attempt - 1)
            try:
                # An empty feedback string means "no errors yet" -> plain prompt.
                sql = self.generate_sql(
                    user_question, context, errors=all_errors or None
                )
                if not sql:
                    return None, None

                # Structured responses carry the query under "sql_query".
                sql_query_str = sql.get("sql_query") if isinstance(sql, dict) else sql
                if not isinstance(sql_query_str, str):
                    raise ValueError(
                        f"Expected SQL query to be a string, got {type(sql_query_str).__name__}"
                    )
                query_df = self.run_query(sql_query_str)
            except Exception as e:  # retry boundary: any failure feeds the next prompt
                last_error = f"\n[Attempt {attempt}] {type(e).__name__}: {e}"
                logger.error("Error during SQL generation or execution: %s", last_error)
                all_errors += last_error
                continue

            if query_df is not None and not query_df.empty:
                return sql, query_df

            # The query ran but yielded nothing. Record that so the next
            # generation is told why it is being retried, rather than being
            # asked the identical prompt again.
            empty_note = f"\n[Attempt {attempt}] Query executed but returned no rows."
            logger.warning("Empty query result: %s", empty_note)
            all_errors += empty_note

        logger.error(
            "Failed after %s attempts. Errors: %s", total_attempts, all_errors
        )
        return None, None