# Cross Encoders + LLM Judge Evaluator
import os
import json
import torch
import numpy as np
from typing import Dict, Any
import google.generativeai as genai
from dotenv import load_dotenv

from engineering_parser import extract_steps
from sentence_transformers import CrossEncoder
from rouge_score import rouge_scorer
from bert_score import BERTScorer


load_dotenv()

# Initialize the client for the judge LLM.
JUDGE_MODEL = None
try:
    api_key = os.getenv("JUDGE_API_KEY")
    model_name = os.getenv("JUDGE_MODEL_NAME")
    if not api_key or not model_name:
        print("Warning: JUDGE_API_KEY or JUDGE_MODEL_NAME not found in .env file. LLM Judge will be skipped.")
    else:
        genai.configure(api_key=api_key)
        JUDGE_MODEL = genai.GenerativeModel(model_name)
        print(f"LLM Judge initialized with model: {model_name}")
except Exception as e:
    print(f"Warning: Could not initialize Gemini Judge client. Error: {e}")


# Model & Scorer Initialization
print("Initializing evaluation models...")
CROSS_ENCODER = CrossEncoder('cross-encoder/stsb-roberta-large')
ROUGE_SCORER = rouge_scorer.RougeScorer(['rouge2', 'rougeL', 'rougeLsum'], use_stemmer=True)
BERT_SCORER = BERTScorer(model_type='allenai/longformer-base-4096', 
                         device='cuda' if torch.cuda.is_available() else 'cpu')
print("Evaluation models initialized.")


# LLM Judge Helper Function (Gemini)
def get_error_analysis_from_llm(gt_step: str, pred_step: str, problem_context: str) -> Dict[str, Any]:
    """

    Uses a Gemini model to categorize the error in a predicted reasoning step.

    """
    if not JUDGE_MODEL:
        return {"error_category": "Analysis Skipped", "explanation": "LLM Judge client not initialized."}
    
    prompt = f"""

    You are an expert engineering professor acting as an automated evaluator. Your task is to analyze the 'MODEL'S STEP' against the 'GROUND-TRUTH STEP' and return a structured JSON object based on the rules and inputs below.



    **CRITICAL RULES:**

    1. Your 'explanation' MUST logically justify your chosen 'error_category'. Do not contradict yourself. For example, do not choose "Calculation Error" and then state that the calculation is correct.

    2. If the model's step is factually correct but takes a different path than the ground-truth, you MUST use the "Alternative Correct" category. Do not classify a correct step as "Other".

    3. The final output must be only a raw JSON object. Do not include any introductory text, concluding remarks, or markdown formatting.



    **Error Categories:**

    - "Conceptual Error": The model applied the wrong scientific principle or formula (e.g., used addition instead of subtraction).

    - "Calculation Error": The model used the correct formula but made a mathematical mistake (e.g., 2 * 3 = 5).

    - "Input Error": The model used the correct formula but pulled the wrong number from the problem context or a previous step.

    - "Alternative Correct": The model's step is valid and logically sound, but follows a different method or phrasing than the ground-truth step.

    - "Other": The model's step is nonsensical, irrelevant, a hallucination, or contains only formatting errors.



    **Input for Analysis:**

    [CONTEXT]: {problem_context}

    [GROUND-TRUTH STEP]: {gt_step}

    [MODEL'S STEP]: {pred_step}



    **OUTPUT FORMAT:**

    You must now provide your analysis. Your entire response will be a single, raw JSON object. Adhere strictly to the following format with exactly two keys:

    {{"error_category": "...", "explanation": "..."}}

    """
    
    try:
        generation_config = genai.types.GenerationConfig(
            response_mime_type="application/json",
            temperature=0.0
        )
        response = JUDGE_MODEL.generate_content(prompt, generation_config=generation_config)
        cleaned_json = response.text.strip().replace("```json", "").replace("```", "").strip()
        analysis = json.loads(cleaned_json)
        
        # Validation and Standardization Step 
        # Ensure the output has the correct key, even if the model makes a mistake.
        if "error_classification" in analysis:
            analysis["error_category"] = analysis.pop("error_classification")
            
        # Ensure the dictionary has the required keys, providing defaults if missing.
        if "error_category" not in analysis:
            analysis["error_category"] = "Formatting Error"
        if "explanation" not in analysis:
            analysis["explanation"] = "Model failed to provide an explanation."
            
        return analysis
        
    except Exception as e:
        return {"error_category": "Analysis Failed", "explanation": str(e)}


def safe_bert_score(gt: str, pred: str) -> float:
    """ A wrapper for BERTScore to handle potential errors and empty strings. """
    if not all(isinstance(s, str) for s in [gt, pred]) or not gt.strip() or not pred.strip():
        return 0.0
    try:
        _, _, f1 = BERT_SCORER.score([pred], [gt])
        return f1.item()
    except Exception as e:
        print(f"Warning: BERTScore failed with error: {e}")
        return 0.0


def evaluate_trace_eng(gt_solution: str, pred_generation: str, problem_context: str) -> Dict[str, Any]:
    """

    Compares a ground-truth engineering solution with a model's generation.

    """

    # Handle empty inputs gracefully
    if not gt_solution or not pred_generation:
        return {
            'error': 'Input solution or generation is empty.', 
            'recall': 0, 
            'precision': 0, 
            'step_f1': 0, 
            'final_answer_match': 0, 
            'rouge2': 0, 
            'rougeL': 0, 
            'rougeLsum': 0, 
            'bertscore': 0, 
            'error_analysis': []
        }
    
    # Compute textual similarity metrics
    rouge_scores = ROUGE_SCORER.score(gt_solution, pred_generation)
    bertscore = safe_bert_score(gt_solution, pred_generation)
    error_analyses = []

    # Extract structured reasoning steps and final answers
    gt_steps, gt_step_answers, gt_final_answer = extract_steps(gt_solution)
    pred_steps, pred_step_answers, pred_final_answer = extract_steps(pred_generation)
    final_answer_match = 0
    FINAL_ANSWER_TOLERANCE = 0.01

    # Final answer comparison with tolerance
    if gt_final_answer is not None and pred_final_answer is not None:
        if abs(gt_final_answer) > 1e-9:
            if abs(gt_final_answer - pred_final_answer) / abs(gt_final_answer) < FINAL_ANSWER_TOLERANCE:
                final_answer_match = 1
        elif abs(gt_final_answer - pred_final_answer) < 1e-9:
            final_answer_match = 1

    # Step-level similarity and recall/precision computation
    recall, precision = 0, 0
    if gt_steps and pred_steps:
        sentence_pairs = [[gt_step, pred_step] for gt_step in gt_steps for pred_step in pred_steps]
        scores = CROSS_ENCODER.predict(sentence_pairs, show_progress_bar=False)
        semantic_similarity = np.array(scores).reshape(len(gt_steps), len(pred_steps))

        # Numeric correctness check
        numeric_correctness = np.zeros((len(gt_steps), len(pred_steps)))
        STEP_ANSWER_TOLERANCE = 0.02

        for i, gt_ans in enumerate(gt_step_answers):
            if gt_ans is None: continue
            for j, pred_ans in enumerate(pred_step_answers):
                if pred_ans is None: continue
                if abs(gt_ans) > 1e-9:
                    if (abs(gt_ans - pred_ans) / abs(gt_ans)) < STEP_ANSWER_TOLERANCE:
                        numeric_correctness[i, j] = 1
                elif abs(gt_ans - pred_ans) < 1e-9:
                    numeric_correctness[i, j] = 1
        
        # Combine semantic and numeric correctness matrices
        combined_matrix = np.multiply(semantic_similarity, numeric_correctness)
        SIMILARITY_THRESHOLD = 0.7
        best_matches_scores = np.max(combined_matrix, axis=1)
        
        # Identify and analyze mismatched steps
        for i, score in enumerate(best_matches_scores):
            if score < SIMILARITY_THRESHOLD:
                gt_mismatched_step = gt_steps[i]
                best_pred_index = np.argmax(combined_matrix[i, :])
                pred_mismatched_step = pred_steps[best_pred_index]
                analysis = get_error_analysis_from_llm(gt_mismatched_step, pred_mismatched_step, problem_context)
                error_analyses.append({
                    "mismatched_gt_step_index": i,
                    "ground_truth_step": gt_mismatched_step,
                    "closest_predicted_step": pred_mismatched_step,
                    "analysis": analysis
                })
        
        # Compute recall and precision based on matched steps
        recall = float(np.sum(best_matches_scores > SIMILARITY_THRESHOLD) / len(gt_steps))
        precision = float(np.sum(np.max(combined_matrix, axis=0) > SIMILARITY_THRESHOLD) / len(pred_steps))

    # Compute step-level F1 score
    step_f1 = 0
    if recall + precision > 0:
        step_f1 = 2 * (recall * precision) / (recall + precision)
    
    # Final result dictionary
    return {
        'recall': recall, 
        'precision': precision, 
        'step_f1': step_f1,
        'final_answer_match': final_answer_match, 
        'rouge2': rouge_scores['rouge2'].fmeasure,
        'rougeL': rouge_scores['rougeL'].fmeasure, 
        'rougeLsum': rouge_scores['rougeLsum'].fmeasure,
        'bertscore': bertscore, 
        'error_analysis': error_analyses
    }
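

# Usage sketch (illustrative only)
# A minimal example of how evaluate_trace_eng might be called. The problem
# context and solution strings below are hypothetical, and their step/answer
# layout is an assumption about the format expected by extract_steps in
# engineering_parser; adjust them to the format that parser actually handles.
if __name__ == "__main__":
    example_context = "A 10 N force acts on a 2 kg mass. Find the acceleration."
    ground_truth = (
        "Step 1: Apply Newton's second law, a = F / m = 10 / 2. Answer: 5\n"
        "Final Answer: 5"
    )
    model_output = (
        "Step 1: Using a = F / m gives 10 / 2 = 5 m/s^2. Answer: 5\n"
        "Final Answer: 5"
    )
    results = evaluate_trace_eng(ground_truth, model_output, example_context)
    print(json.dumps(results, indent=2))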