import os
import json
import time
from tqdm import tqdm
from dotenv import load_dotenv
import google.generativeai as genai
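# Third-party dependencies: pip install tqdm python-dotenv google-generativeai
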
# 1. Configuration
load_dotenv()
# Configure the API client
API_KEY = os.getenv("API_KEY")
if not API_KEY:
    raise ValueError("API_KEY not found in .env file or environment variables.")
genai.configure(api_key=API_KEY)
# Define model and file paths
MODEL_NAME = os.getenv("MODEL_NAME")
if not MODEL_NAME:
    raise ValueError("MODEL_NAME not found in .env file or environment variables.")
INPUT_DIR = "../testset"
OUTPUT_FILE = f"inference_results/{MODEL_NAME}_inference_results.jsonl"
# Create the output directory up front so the later open() cannot fail
os.makedirs(os.path.dirname(OUTPUT_FILE), exist_ok=True)
# Define the prompt template for the model
PROMPT_TEMPLATE = """
You are an expert engineer. Solve the following problem by providing a detailed, structured solution. Use the exact headings and formatting provided below.
## Given
List all known variables and their values with units.
## Find
State the variable(s) to be calculated.
## Formulae
Write down all necessary governing equations before substituting any values.
## Solution
Provide a step-by-step calculation. Each step must start on a new line and be formatted exactly as '**Step X:**', where X is the step number. Show the substitution of values into the formulae clearly.
## Final Answer
State the final numerical result with its units in the format: **Answer:** [value] [units]
Problem:
{question}
"""
# 2. Data Loading Function
def load_all_problems(directory: str) -> list:
    """
    Walks through the nested directory structure, finds all .jsonl files,
    and loads all problems into a single list.
    """
    all_problems = []
    print(f"Loading problems from '{directory}'...")
    for root, _, files in os.walk(directory):
        for file in files:
            if file.endswith('.jsonl'):
                file_path = os.path.join(root, file)
                with open(file_path, 'r', encoding='utf-8') as f:
                    for line in f:
                        if line.strip():  # Skip blank lines
                            all_problems.append(json.loads(line))
    print(f"Successfully loaded {len(all_problems)} problems.")
    return all_problems
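
# Assumed testset record shape, inferred from the fields this script reads
# ('question' feeds PROMPT_TEMPLATE; 'branch' and 'id' label the progress bar).
# The example values below are illustrative only:
#   {"id": "fluids_001", "branch": "fluid_mechanics", "question": "..."}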
# 3. Main Inference Logic (with Smart Rate Limiter)
if __name__ == "__main__":
    # Load all problems from the testset directory
    problems = load_all_problems(INPUT_DIR)
    if not problems:
        print(f"Error: No problems found in '{INPUT_DIR}'. Please check the path.")
    else:
        # Initialize the generative model
        model = genai.GenerativeModel(MODEL_NAME)
        print(f"Initialized model: {MODEL_NAME}")

        # Open the output file in write mode to start fresh
        with open(OUTPUT_FILE, "w", encoding='utf-8') as f_out:
            print(f"Starting inference... Results will be saved to '{OUTPUT_FILE}'")

            # Initialize variables for rate limiting
            request_counter = 0
            start_time = time.time()
            REQUESTS_PER_MINUTE = 8  # Our safe target
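            # How the limiter works: after every REQUESTS_PER_MINUTE calls, the
            # loop sleeps for whatever remains of the current 60-second window,
            # capping effective throughput at 8 requests/minute even when
            # individual calls return quickly.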
            progress_bar = tqdm(problems, desc="Initializing Inference")
            for problem in progress_bar:
                # Rate Limiting Logic
                # Check if we have exhausted this minute's request budget
                if request_counter >= REQUESTS_PER_MINUTE:
                    elapsed_time = time.time() - start_time
                    # If the batch took less than a minute, wait for the remainder
                    if elapsed_time < 60:
                        wait_time = 60 - elapsed_time
                        tqdm.write(f"Rate limit reached. Pausing for {wait_time:.2f} seconds...")
                        time.sleep(wait_time)
                    # Reset the counter and timer for the next batch
                    request_counter = 0
                    start_time = time.time()

                # Update the progress bar's description
                branch = problem.get('branch', 'unknown_branch')
                problem_id = problem.get('id', 'unknown_id')
                progress_bar.set_description(f"Processing '{problem_id}' from '{branch}'")
                prompt = PROMPT_TEMPLATE.format(question=problem['question'])
                try:
                    # Call the Gemini API
                    response = model.generate_content(prompt)
                    problem['generation'] = response.text
                except Exception as e:
                    tqdm.write(f"\nAn error occurred for problem ID {problem_id}: {e}")
                    problem['generation'] = f"ERROR: {e}"
                    # Back off for a full minute if an error occurs (e.g., server-side issues)
                    time.sleep(60)

                # Write the result immediately and flush, so a crash mid-run
                # loses at most the in-flight problem
                f_out.write(json.dumps(problem) + '\n')
                f_out.flush()

                # Count every API attempt, successful or not, toward the rate limit
                request_counter += 1

        print(f"\nInference complete. All {len(problems)} results saved to '{OUTPUT_FILE}'.")
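
# Example usage (the file names and model value here are illustrative, not
# taken from this repo):
#   $ cat .env
#   API_KEY=your-google-ai-studio-key
#   MODEL_NAME=gemini-1.5-flash
#   $ python inference.py
# Each line of the output .jsonl is the original problem record plus a
# 'generation' field holding the model's structured solution.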