from typing import Any, Callable, Dict, List, Optional, Sequence, Union
import httpx
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils import pre_init
from langchain_core.utils.function_calling import convert_to_openai_tool
from pydantic import Field


def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Convert a LangChain message to a RITS (OpenAI-style) message dict."""
    message_dict: Dict[str, Any] = {"role": "", "content": message.content}
    if isinstance(message, ChatMessage):
        message_dict["role"] = message.role
    elif isinstance(message, HumanMessage):
        message_dict["role"] = "user"
    elif isinstance(message, AIMessage):
        message_dict["role"] = "assistant"
    elif isinstance(message, SystemMessage):
        message_dict["role"] = "system"
    elif isinstance(message, ToolMessage):
        message_dict["role"] = "tool"
        message_dict["tool_call_id"] = message.tool_call_id
    else:
        raise ValueError(f"Unknown message type: {type(message)}")
    if message.name:
        message_dict["name"] = message.name
    return message_dict
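
# A quick sketch of what the conversion produces (illustrative only; the
# tool_call_id value below is hypothetical):
#
#     _convert_message_to_dict(HumanMessage(content="Hello"))
#     # -> {"role": "user", "content": "Hello"}
#
#     _convert_message_to_dict(ToolMessage(content="42", tool_call_id="call_abc123"))
#     # -> {"role": "tool", "content": "42", "tool_call_id": "call_abc123"}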


def _convert_dict_to_message(response_dict: Dict[str, Any]) -> BaseMessage:
    """Convert a RITS message dict to a LangChain message."""
    role = response_dict["role"]
    # Assistant messages that carry tool calls may have null content.
    content = response_dict.get("content") or ""
    if role == "user":
        return HumanMessage(content=content)
    elif role == "assistant":
        additional_kwargs: Dict[str, Any] = {}
        if tool_calls := response_dict.get("tool_calls"):
            additional_kwargs["tool_calls"] = tool_calls
        # Non-streaming responses are complete messages, so return an
        # AIMessage rather than an AIMessageChunk.
        return AIMessage(content=content, additional_kwargs=additional_kwargs)
    elif role == "system":
        return SystemMessage(content=content)
    elif role == "tool":
        return ToolMessage(
            content=content,
            tool_call_id=response_dict["tool_call_id"],
            name=response_dict.get("name"),
        )
    else:
        return ChatMessage(content=content, role=role)
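
# Sketch of the reverse direction (assumes the OpenAI-compatible message
# schema this module targets):
#
#     _convert_dict_to_message({"role": "assistant", "content": "Hi there"})
#     # -> AIMessage(content="Hi there")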


class ChatRITS(BaseChatModel):
    """Chat model served over the RITS OpenAI-compatible REST API."""

    model_name: str = Field("meta-llama/Llama-3.2-90B-Vision-Instruct", alias="model")
    """Model name to use."""
    rits_api_key: Optional[str] = Field(default=None, alias="api_key")
    """API key for RITS."""
    rits_base_url: str = Field(
        default="https://inference-3scale-apicast-production.apps.rits.fmaas.res.ibm.com",
        alias="base_url",
    )
    """Base URL for RITS."""
    best_of: int = 1
    """Generates best_of completions server-side and returns the "best"."""
    frequency_penalty: float = 0
    """Penalizes repeated tokens according to frequency."""
    logit_bias: Optional[Dict[str, float]] = Field(default_factory=dict)
    """Adjust the probability of specific tokens being generated."""
    max_tokens: int = 16
    """The maximum number of tokens to generate in the completion."""
    min_tokens: int = 0
    """The minimum number of tokens to generate in the completion."""
    n: int = 1
    """How many completions to generate for each prompt."""
    presence_penalty: float = 0
    """Penalizes tokens that have already appeared in the output."""
    seed: Optional[int] = None
    """Seed for reproducible sampling."""
    temperature: float = 0.7
    """What sampling temperature to use."""
    top_p: float = 1
    """Total probability mass of tokens to consider at each step."""
    top_k: int = -1
    """Number of top tokens to consider at each step (-1 disables top-k)."""
    streaming: bool = False
    """Whether to stream the results or not."""
    repetition_penalty: float = 1
    """Multiplicative penalty applied to repeated tokens."""
    length_penalty: float = 1
    """Penalizes longer completions."""
    ignore_eos: bool = False
    """Whether to ignore the EOS token and keep generating."""
    stop: Optional[List[str]] = None
    """Stop words to use when generating. Model output is cut off at the
    first occurrence of any of the stop substrings."""

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that an API key was supplied."""
        # The key may arrive under the field name or its alias.
        if not values.get("api_key") and not values.get("rits_api_key"):
            raise ValueError("RITS API key is required.")
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Return default sampling parameters for the model."""
        return {
            "best_of": self.best_of,
            "frequency_penalty": self.frequency_penalty,
            "logit_bias": self.logit_bias,
            "max_tokens": self.max_tokens,
            "min_tokens": self.min_tokens,
            "n": self.n,
            "presence_penalty": self.presence_penalty,
            "seed": self.seed,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "repetition_penalty": self.repetition_penalty,
            "length_penalty": self.length_penalty,
            "ignore_eos": self.ignore_eos,
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "rits-chat"

    @staticmethod
    def _convert_messages_to_dicts(messages: List[BaseMessage]) -> List[dict]:
        """Convert LangChain messages to RITS message dicts."""
        return [_convert_message_to_dict(message) for message in messages]

    def _create_chat_result(self, response: Dict) -> ChatResult:
        """Build a ChatResult from a parsed RITS response payload."""
        generations = []
        for choice in response["choices"]:
            message = _convert_dict_to_message(choice["message"])
            gen = ChatGeneration(
                message=message,
                generation_info=dict(finish_reason=choice.get("finish_reason")),
            )
            generations.append(gen)
        token_usage = response.get("usage", {})
        llm_output = {
            "token_usage": token_usage,
            "model_name": self.model_name,
            "system_fingerprint": response.get("system_fingerprint", ""),
        }
        return ChatResult(generations=generations, llm_output=llm_output)
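
    # The response payload is assumed to follow the OpenAI chat-completions
    # schema; the values below are illustrative:
    #
    #     {
    #         "choices": [
    #             {
    #                 "message": {"role": "assistant", "content": "..."},
    #                 "finish_reason": "stop",
    #             }
    #         ],
    #         "usage": {"prompt_tokens": 9, "completion_tokens": 12, "total_tokens": 21},
    #         "system_fingerprint": "",
    #     }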

    def bind_tools(
        self,
        tools: Sequence[Union[Dict[str, Any], type, Callable, BaseTool]],  # noqa: UP006
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind tools to the model in the OpenAI function-calling format."""
        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
        return super().bind(tools=formatted_tools, **kwargs)
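
    # Sketch of binding a plain function as a tool (get_weather is a
    # hypothetical example, not part of this module):
    #
    #     def get_weather(city: str) -> str:
    #         """Get the current weather for a city."""
    #         ...
    #
    #     model_with_tools = ChatRITS(api_key="...").bind_tools([get_weather])
    #     result = model_with_tools.invoke("What's the weather in Paris?")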

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Call the RITS chat-completions endpoint and parse the result."""
        if self.stop is not None and stop is not None:
            raise ValueError("`stop` found in both the input and default params.")
        elif self.stop is not None:
            stop = self.stop
        elif stop is None:
            stop = []
        params = {
            **self._default_params,
            **kwargs,
        }
        message_dicts = self._convert_messages_to_dicts(messages)
        with httpx.Client(timeout=60.0) as client:
            response = client.post(
                url=f"{self.rits_base_url}/v1/chat/completions",
                headers={"RITS_API_KEY": self.rits_api_key},
                json={
                    "messages": message_dicts,
                    "stop": stop,
                    "model": self.model_name,
                    **params,
                },
            )
        if response.status_code != 200:
            raise ValueError(
                f"Failed to call RITS: {response.text} "
                f"with status code {response.status_code}"
            )
        return self._create_chat_result(response.json())
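
# Minimal usage sketch (assumes a valid RITS API key and network access to the
# endpoint; the key value below is a placeholder):
#
#     chat = ChatRITS(
#         model="meta-llama/Llama-3.2-90B-Vision-Instruct",
#         api_key="YOUR_RITS_API_KEY",
#         max_tokens=256,
#     )
#     result = chat.invoke([HumanMessage(content="Say hello in one sentence.")])
#     print(result.content)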