introvoyz041 committed on
Commit
84fbe5c
·
verified ·
1 Parent(s): bff8e7d

Upload handler.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. handler.py +139 -0
handler.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import traceback
4
+ from typing import Dict, List, Any
5
+
6
+ from nemo_skills.inference.server.code_execution_model import get_code_execution_model
7
+ from nemo_skills.code_execution.sandbox import get_sandbox
8
+ from nemo_skills.prompt.utils import get_prompt
9
+
10
+ # Configure logging
11
+ logging.basicConfig(level=logging.INFO)
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
class EndpointHandler:
    """Custom endpoint handler for NeMo Skills code execution inference.

    Lazily creates a local code-execution sandbox, a vLLM-backed code
    execution model, and (optionally) a prompt template, then serves
    generation requests through ``__call__``.
    """

    def __init__(self):
        """Read configuration from the environment; defer heavy setup.

        Components are initialized lazily on the first request so that
        constructing the handler (e.g. during container health checks)
        stays cheap and cannot fail on model/sandbox startup.
        """
        self.model = None
        self.prompt = None
        self.initialized = False

        # Prompt configuration (overridable via environment variables).
        self.prompt_config_path = os.getenv("PROMPT_CONFIG_PATH", "generic/math")
        self.prompt_template_path = os.getenv("PROMPT_TEMPLATE_PATH", "openmath-instruct")
        # Model server location; defaults preserve the previously
        # hard-coded values (127.0.0.1:5000).
        self.server_host = os.getenv("MODEL_SERVER_HOST", "127.0.0.1")
        self.server_port = int(os.getenv("MODEL_SERVER_PORT", "5000"))

    def _initialize_components(self):
        """Initialize the model, sandbox, and prompt components lazily.

        Idempotent: returns immediately once ``self.initialized`` is set.

        Raises:
            RuntimeError: if any component fails to initialize. The error is
                logged with its traceback and re-raised — a half-initialized
                handler must not silently accept requests (previously the
                failure was swallowed and later surfaced as an opaque
                ``AttributeError`` on ``self.model.generate``).
        """
        if self.initialized:
            return

        try:
            logger.info("Initializing sandbox...")
            sandbox = get_sandbox(sandbox_type="local")

            logger.info("Initializing code execution model...")
            self.model = get_code_execution_model(
                server_type="vllm",
                sandbox=sandbox,
                host=self.server_host,
                port=self.server_port,
            )

            logger.info("Initializing prompt...")
            if self.prompt_config_path:
                self.prompt = get_prompt(
                    prompt_config=self.prompt_config_path,
                    prompt_template=self.prompt_template_path,
                )

            self.initialized = True
            logger.info("All components initialized successfully")
        except Exception as e:
            # logger.exception records the full traceback; re-raise so the
            # request-level handler reports the real cause to the caller.
            logger.exception("Failed to initialize components")
            raise RuntimeError(f"Component initialization failed: {e}") from e

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process inference requests.

        Args:
            data: Dictionary containing the request data
                Expected keys:
                - inputs: str or list of str - the input prompts/problems
                - parameters: dict (optional) - generation parameters

        Returns:
            List of dictionaries containing the generated responses, each
            with ``generated_text`` and ``code_rounds_executed`` keys. On
            failure, a single-element list with ``error`` and an empty
            ``generated_text`` is returned instead of raising.
        """
        try:
            # Initialize components if not already done (raises on failure).
            self._initialize_components()

            # Extract inputs and parameters.
            inputs = data.get("inputs", "")
            parameters = data.get("parameters", {})

            # Normalize to a list of prompt strings.
            if isinstance(inputs, str):
                prompts = [inputs]
            elif isinstance(inputs, list):
                prompts = inputs
            else:
                raise ValueError("inputs must be a string or list of strings")

            # If a prompt template is configured, format each input with it.
            if self.prompt is not None:
                prompts = [
                    self.prompt.fill({"problem": text, "total_code_executions": 8})
                    for text in prompts
                ]

            # Code-execution arguments from the prompt, if available.
            extra_generate_params = (
                self.prompt.get_code_execution_args() if self.prompt is not None else {}
            )

            # Defaults, overridden first by caller-supplied parameters and
            # then by prompt-derived code-execution arguments.
            generation_params = {
                "tokens_to_generate": 12000,
                "temperature": 0.0,
                "top_p": 0.95,
                "top_k": 0,
                "repetition_penalty": 1.0,
                "random_seed": 0,
            }
            generation_params.update(parameters)
            generation_params.update(extra_generate_params)

            logger.info(f"Processing {len(prompts)} prompt(s)")

            # Generate responses.
            outputs = self.model.generate(prompts=prompts, **generation_params)

            # Shape each output into the response schema.
            results = [
                {
                    "generated_text": output.get("generation", ""),
                    "code_rounds_executed": output.get("code_rounds_executed", 0),
                }
                for output in outputs
            ]

            logger.info(f"Successfully processed {len(results)} request(s)")
            return results

        except Exception as e:
            # Boundary handler: report the failure in-band rather than
            # letting the serving framework see an unhandled exception.
            logger.error(f"Error processing request: {str(e)}")
            logger.error(traceback.format_exc())
            return [{"error": str(e), "generated_text": ""}]