from typing import Any, Callable, Dict, List, Optional, Sequence, Union

import httpx
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils import pre_init
from langchain_core.utils.function_calling import convert_to_openai_tool
from pydantic import Field


def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Convert a LangChain message to a RITS message dict."""
    message_dict: Dict[str, Any] = {"role": "", "content": message.content}
    if isinstance(message, HumanMessage):
        message_dict["role"] = "user"
    elif isinstance(message, AIMessage):
        message_dict["role"] = "assistant"
    elif isinstance(message, SystemMessage):
        message_dict["role"] = "system"
    elif isinstance(message, ToolMessage):
        message_dict["role"] = "tool"
        message_dict["tool_call_id"] = message.tool_call_id
    else:
        raise ValueError(f"Unknown message type: {type(message)}")
    if message.name:
        message_dict["name"] = message.name
    return message_dict


def _convert_dict_to_message(response_dict: Dict[str, Any]) -> BaseMessage:
    """Convert a RITS message dict to a LangChain message."""
    role = response_dict["role"]
    # `content` may be absent or None (e.g. for tool-call-only responses).
    content = response_dict.get("content") or ""
    if role == "user":
        return HumanMessage(content=content)
    elif role == "assistant":
        additional_kwargs = {}
        if tool_calls := response_dict.get("tool_calls"):
            additional_kwargs["tool_calls"] = tool_calls
        # A complete (non-streaming) response maps to AIMessage, not AIMessageChunk.
        return AIMessage(content=content, additional_kwargs=additional_kwargs)
    elif role == "system":
        return SystemMessage(content=content)
    elif role == "tool":
        return ToolMessage(
            content=content,
            tool_call_id=response_dict["tool_call_id"],
            name=response_dict.get("name"),
        )
    else:
        return ChatMessage(content=content, role=role)
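

# Illustrative round trip (comments only, not executed):
#   _convert_message_to_dict(HumanMessage(content="hi"))
#       -> {"role": "user", "content": "hi"}
#   _convert_dict_to_message({"role": "assistant", "content": "hello"})
#       -> AIMessage(content="hello", additional_kwargs={})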


class ChatRITS(BaseChatModel):
    """Chat model for the RITS OpenAI-compatible chat completions API."""

    model_name: str = Field("meta-llama/Llama-3.2-90B-Vision-Instruct", alias="model")
    """Model name to use."""
    rits_api_key: Optional[str] = Field(default=None, alias="api_key")
    """API key for RITS."""
    rits_base_url: str = Field(
        default="https://inference-3scale-apicast-production.apps.rits.fmaas.res.ibm.com",
        alias="base_url",
    )
    """Base URL for RITS."""
    best_of: int = 1
    """Generates best_of completions server-side and returns the "best"."""
    frequency_penalty: float = 0
    """Penalizes repeated tokens according to frequency."""
    logit_bias: Optional[Dict[str, float]] = Field(default_factory=dict)
    """Adjust the probability of specific tokens being generated."""
    max_tokens: int = 16
    """The maximum number of tokens to generate in the completion."""
    min_tokens: int = 0
    """The minimum number of tokens to generate in the completion."""
    n: int = 1
    """How many completions to generate for each prompt."""
    presence_penalty: float = 0
    """Penalizes repeated tokens."""
    seed: Optional[int] = None
    """Seed for randomness."""
    temperature: float = 0.7
    """What sampling temperature to use."""
    top_p: float = 1
    """Total probability mass of tokens to consider at each step."""
    top_k: int = -1
    """Number of top tokens to consider at each step."""
    streaming: bool = False
    """Whether to stream the results or not."""
    repetition_penalty: float = 1
    """Penalizes repeated tokens."""
    length_penalty: float = 1
    """Penalizes longer completions."""
    ignore_eos: bool = False
    """Whether to ignore the eos token."""
    stop: Optional[List[str]] = None
    """Stop words to use when generating. Model output is cut off at the first
    occurrence of any of these substrings."""

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that an API key was provided."""
        if values.get("api_key") is None and values.get("rits_api_key") is None:
            raise ValueError("RITS API key is required.")
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Return default params for the model."""
        return {
            "best_of": self.best_of,
            "frequency_penalty": self.frequency_penalty,
            "logit_bias": self.logit_bias,
            "max_tokens": self.max_tokens,
            "min_tokens": self.min_tokens,
            "n": self.n,
            "presence_penalty": self.presence_penalty,
            "seed": self.seed,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "repetition_penalty": self.repetition_penalty,
            "length_penalty": self.length_penalty,
            "ignore_eos": self.ignore_eos,
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "rits-chat"

    def _convert_messages_to_dicts(self, messages: list[BaseMessage]) -> list[dict]:
        """Convert LangChain messages to RITS message dicts."""
        return [_convert_message_to_dict(message) for message in messages]

    def _create_chat_result(self, response: Dict) -> ChatResult:
        """Build a ChatResult from a RITS chat completions response."""
        generations = []
        for choice in response["choices"]:
            message = _convert_dict_to_message(choice["message"])
            gen = ChatGeneration(
                message=message,
                generation_info=dict(finish_reason=choice.get("finish_reason")),
            )
            generations.append(gen)
        token_usage = response.get("usage", {})
        llm_output = {
            "token_usage": token_usage,
            "model_name": self.model_name,
            "system_fingerprint": response.get("system_fingerprint", ""),
        }
        return ChatResult(generations=generations, llm_output=llm_output)

    def bind_tools(
        self,
        tools: Sequence[Union[Dict[str, Any], type, Callable, BaseTool]],  # noqa: UP006
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind tool definitions to the model in the OpenAI tool format."""
        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
        return super().bind(tools=formatted_tools, **kwargs)
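
    # Usage sketch (illustrative): given a @tool-decorated function `get_weather`,
    # `ChatRITS(api_key="...").bind_tools([get_weather])` returns a Runnable whose
    # AIMessage responses may carry OpenAI-style `tool_calls` in additional_kwargs.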

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: Optional[list[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Call the RITS chat completions endpoint and return the parsed result."""
        # `stop` may come from either the constructor or the call, but not both.
        if self.stop is not None and stop is not None:
            raise ValueError("`stop` found in both the input and default params.")
        elif self.stop is not None:
            stop = self.stop
        elif stop is None:
            stop = []
        params = {
            **self._default_params,
            **kwargs,
        }
        message_dicts = self._convert_messages_to_dicts(messages)
        with httpx.Client(timeout=60.0) as client:
            response = client.post(
                url=f"{self.rits_base_url}/v1/chat/completions",
                headers={"RITS_API_KEY": self.rits_api_key},
                json={
                    "messages": message_dicts,
                    "stop": stop,
                    "model": self.model_name,
                    **params,
                },
            )
        if response.status_code != 200:
            raise ValueError(
                f"Failed to call RITS: {response.text} "
                f"with status code {response.status_code}"
            )
        response_json = response.json()
        return self._create_chat_result(response_json)
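

# Minimal usage sketch (illustrative): assumes a valid key is exported in the
# RITS_API_KEY environment variable and that the RITS endpoint is reachable.
if __name__ == "__main__":
    import os

    chat = ChatRITS(
        api_key=os.environ["RITS_API_KEY"],
        model="meta-llama/Llama-3.2-90B-Vision-Instruct",
        max_tokens=128,
    )
    result = chat.invoke([HumanMessage(content="Say hello in one sentence.")])
    print(result.content)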