import argparse
import json
from typing import List, Dict

from transformers import AutoTokenizer
from langchain_community.llms import Ollama

import functions
from config import config
from logger import logger
from prompter import PromptManager
from validator import validate_function_call_schema
from utils import (
    get_chat_template,
    validate_and_extract_tool_calls
)
class ModelInference:
    def __init__(self, chat_template: str):
        self.prompter = PromptManager()
        # format='json' constrains Ollama to emit syntactically valid JSON.
        self.model = Ollama(model=config.ollama_model, temperature=0.0, format='json')

        self.tokenizer = AutoTokenizer.from_pretrained(config.hf_model, trust_remote_code=True)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "left"

        if self.tokenizer.chat_template is None:
            logger.info("No chat template defined, getting chat_template...")
            self.tokenizer.chat_template = get_chat_template(chat_template)

        logger.info(f"Model loaded: {self.model}")
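
    # Example construction (a minimal sketch; the template name "chatml" is an
    # assumption, use whatever names utils.get_chat_template supports):
    #   inference = ModelInference(chat_template="chatml")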

    def process_completion_and_validate(self, completion, chat_template):
        if completion:
            validation, tool_calls, error_message = validate_and_extract_tool_calls(completion)

            if validation:
                logger.info(f"parsed tool calls:\n{json.dumps(tool_calls, indent=2)}")
                return tool_calls, completion, error_message
            else:
                # Parsing failed: return no tool calls alongside the error message.
                return None, completion, error_message
        else:
            logger.warning("Assistant message is None")
            raise ValueError("Assistant message is None")
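
    # For reference, a completion that parses successfully follows the format the
    # system prompt requests, e.g. (the symbol value here is illustrative):
    #   <tool_call>
    #   {"arguments": {"symbol": "AAPL"}, "name": "get_stock_fundamentals"}
    #   </tool_call>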

    def execute_function_call(self, tool_call):
        function_name = tool_call.get("name")
        function_to_call = getattr(functions, function_name, None)
        if function_to_call is None:
            raise ValueError(f"Function {function_name} is not defined in the functions module")
        function_args = tool_call.get("arguments", {})

        logger.info(f"Invoking function call {function_name} ...")
        # Pass arguments by keyword so the call does not depend on dict ordering.
        function_response = function_to_call(**function_args)
        # Serialize with json.dumps so quotes in the response are escaped correctly.
        results_dict = json.dumps({"name": function_name, "content": function_response}, default=str)
        return results_dict
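
    # Example: {"name": "get_stock_fundamentals", "arguments": {"symbol": "TSLA"}}
    # resolves functions.get_stock_fundamentals and returns a JSON string like
    # '{"name": "get_stock_fundamentals", "content": {...}}' ("TSLA" is illustrative).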

    def run_inference(self, prompt: List[Dict[str, str]]):
        # prompt is a list of chat messages, e.g. [{"role": "user", "content": "..."}]
        inputs = self.tokenizer.apply_chat_template(
            prompt,
            add_generation_prompt=True,
            tokenize=False,
        )
        # Strip the BOS token some chat templates prepend; Ollama adds its own.
        inputs = inputs.replace("<|begin_of_text|>", "")
        # The Ollama LLM wrapper returns a plain string, not a message object,
        # so return it directly instead of reading a nonexistent .content attribute.
        completion = self.model.invoke(inputs)
        return completion

    def generate_function_call(self, query, chat_template, num_fewshot, max_depth=5):
        try:
            depth = 0
            user_message = f"{query}\nThis is the first turn and you don't have <tool_results> to analyze yet."
            chat = [{"role": "user", "content": user_message}]
            tools = functions.get_openai_tools()
            prompt = self.prompter.generate_prompt(chat, tools, num_fewshot)
            completion = self.run_inference(prompt)

            def recursive_loop(prompt, completion, depth):
                # max_depth is only read here, so no nonlocal declaration is needed.
                tool_calls, assistant_message, error_message = self.process_completion_and_validate(completion, chat_template)
                prompt.append({"role": "assistant", "content": assistant_message})

                tool_message = f"Agent iteration {depth} to assist with user query: {query}\n"
                if tool_calls:
                    logger.info(f"Found tool calls: {tool_calls}")
                    logger.info(f"Assistant Message:\n{assistant_message}")

                    for tool_call in tool_calls:
                        validation, message = validate_function_call_schema(tool_call, tools)
                        if validation:
                            try:
                                function_response = self.execute_function_call(tool_call)
                                tool_message += f"<tool_response>\n{function_response}\n</tool_response>\n"
                                logger.info(f"Here's the response from the function call: {tool_call.get('name')}\n{function_response}")
                            except Exception as e:
                                logger.warning(f"Could not execute function: {e}")
                                tool_message += f"<tool_response>\nThere was an error when executing the function: {tool_call.get('name')}\nHere's the error traceback: {e}\nPlease call this function again with correct arguments within XML tags <tool_call></tool_call>\n</tool_response>\n"
                        else:
                            logger.warning(message)
                            tool_message += f"<tool_response>\nThere was an error validating function call against function signature: {tool_call.get('name')}\nHere's the error traceback: {message}\nPlease call this function again with correct arguments within XML tags <tool_call></tool_call>\n</tool_response>\n"
                    prompt.append({"role": "tool", "content": tool_message})

                    depth += 1
                    if depth >= max_depth:
                        logger.warning(f"Maximum recursion depth reached ({max_depth}). Stopping recursion.")
                        # One final pass so the model can respond to the collected tool results.
                        completion = self.run_inference(prompt)
                        return completion

                    completion = self.run_inference(prompt)
                    return recursive_loop(prompt, completion, depth)
                elif error_message:
                    logger.info(f"Assistant Message:\n{assistant_message}")
                    tool_message += f"<tool_response>\nThere was an error parsing function calls\nHere's the error stack trace: {error_message}\nPlease call the function again with correct syntax\n</tool_response>"
                    prompt.append({"role": "tool", "content": tool_message})

                    depth += 1
                    if depth >= max_depth:
                        logger.warning(f"Maximum recursion depth reached ({max_depth}). Stopping recursion.")
                        return completion

                    completion = self.run_inference(prompt)
                    return recursive_loop(prompt, completion, depth)
                else:
                    logger.info(f"Assistant Message:\n{assistant_message}")
                    return assistant_message

            return recursive_loop(prompt, completion, depth)

        except Exception as e:
            logger.error(f"Exception occurred: {e}")
            # Re-raise as-is to preserve the original traceback.
            raise
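

# Minimal CLI sketch for running this module directly. This entry point is an
# assumption, not part of the original file: the flag names and the default
# template name "chatml" are illustrative and should match the project's
# actual conventions.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run function-calling inference against an Ollama model")
    parser.add_argument("--query", type=str, required=True, help="User query to resolve with tool calls")
    parser.add_argument("--chat-template", type=str, default="chatml", help="Chat template name passed to get_chat_template")
    parser.add_argument("--num-fewshot", type=int, default=None, help="Number of few-shot examples to include in the prompt")
    args = parser.parse_args()

    inference = ModelInference(chat_template=args.chat_template)
    result = inference.generate_function_call(args.query, args.chat_template, args.num_fewshot)
    print(result)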