| | import os |
| | import dataclasses |
| | import base64 |
| | import copy |
| | import hashlib |
| | import datetime |
| | from io import BytesIO |
| | from PIL import Image |
| | from typing import Any, List, Dict, Union |
| | from dataclasses import field |
| |
|
| | from utils import LOGDIR |
| |
|
| |
|
| | def pil2base64(img: Image.Image) -> str: |
| | buffered = BytesIO() |
| | img.save(buffered, format="PNG") |
| | return base64.b64encode(buffered.getvalue()).decode() |
| |
|
| |
|
| | def resize_img(img: Image.Image, max_len: int, min_len: int) -> Image.Image: |
| | max_hw, min_hw = max(img.size), min(img.size) |
| | aspect_ratio = max_hw / min_hw |
| | |
| | shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) |
| | longest_edge = int(shortest_edge * aspect_ratio) |
| | W, H = img.size |
| | if H > W: |
| | H, W = longest_edge, shortest_edge |
| | else: |
| | H, W = shortest_edge, longest_edge |
| | return img.resize((W, H)) |
| |
|
| |
|
| | @dataclasses.dataclass |
| | class Conversation: |
| | """A class that keeps all conversation history.""" |
| |
|
| | SYSTEM = "system" |
| | USER = "user" |
| | ASSISTANT = "assistant" |
| |
|
| | roles: List[str] = field( |
| | default_factory=lambda: [ |
| | Conversation.SYSTEM, |
| | Conversation.USER, |
| | Conversation.ASSISTANT, |
| | ] |
| | ) |
| | mandatory_system_message = "我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。" |
| | system_message: str = "请尽可能详细地回答用户的问题。" |
| | messages: List[Dict[str, Any]] = field(default_factory=lambda: []) |
| | max_image_limit: int = 4 |
| | skip_next: bool = False |
| | streaming_placeholder: str = "▌" |
| |
|
| | def get_system_message(self): |
| | return self.mandatory_system_message + "\n\n" + self.system_message |
| |
|
| | def set_system_message(self, system_message: str): |
| | self.system_message = system_message |
| | return self |
| |
|
| | def get_prompt(self, inlude_image=False): |
| | send_messages = [{"role": "system", "content": self.get_system_message()}] |
| | |
| | for message in self.messages: |
| | if message["role"] == self.USER: |
| | user_message = { |
| | "role": self.USER, |
| | "content": message["content"], |
| | } |
| | if inlude_image and "image" in message: |
| | user_message["image"] = [] |
| | for image in message["image"]: |
| | user_message["image"].append(pil2base64(image)) |
| | send_messages.append(user_message) |
| | elif message["role"] == self.ASSISTANT: |
| | send_messages.append( |
| | {"role": self.ASSISTANT, "content": message["content"]} |
| | ) |
| | elif message["role"] == self.SYSTEM: |
| | send_messages.append( |
| | { |
| | "role": self.SYSTEM, |
| | "content": message["content"], |
| | } |
| | ) |
| | else: |
| | raise ValueError(f"Invalid role: {message['role']}") |
| | return send_messages |
| | |
| | def get_prompt_v2(self, inlude_image=False, max_dynamic_patch=12): |
| | send_messages = [ |
| | { |
| | "role": "system", |
| | "content": self.get_system_message(), |
| | } |
| | ] |
| | for message in self.messages: |
| | if message["role"] == self.USER: |
| | user_message = { |
| | "role": self.USER, |
| | "content": message["content"], |
| | } |
| | if inlude_image and "image" in message: |
| | user_message["image"] = [] |
| | for image in message["image"]: |
| | user_message["image"].append(pil2base64(image)) |
| | |
| | content = [{"type": "text", "text": message["content"]}] |
| | for image_base64 in user_message["image"]: |
| | content.append({ |
| | "type": "image_url", |
| | "image_url": { |
| | "url": f"data:image/jpeg;base64,{image_base64}", |
| | "max_dynamic_patch": max_dynamic_patch |
| | } |
| | }) |
| | send_messages.append({'role': self.USER, 'content': content}) |
| | else: |
| | send_messages.append(user_message) |
| | elif message["role"] == self.ASSISTANT: |
| | send_messages.append( |
| | {"role": self.ASSISTANT, "content": message["content"]} |
| | ) |
| | elif message["role"] == self.SYSTEM: |
| | send_messages.append( |
| | { |
| | "role": self.SYSTEM, |
| | "content": message["content"], |
| | } |
| | ) |
| | else: |
| | raise ValueError(f"Invalid role: {message['role']}") |
| | return send_messages |
| |
|
| | def append_message( |
| | self, |
| | role, |
| | content, |
| | image_list=None, |
| | ): |
| | self.messages.append( |
| | { |
| | "role": role, |
| | "content": content, |
| | "image": [] if image_list is None else image_list, |
| | |
| | } |
| | ) |
| |
|
| | def get_images( |
| | self, |
| | return_copy=False, |
| | return_base64=False, |
| | source: Union[str, None] = None, |
| | ): |
| | assert source in [self.USER, self.ASSISTANT, None], f"Invalid source: {soure}" |
| | images = [] |
| | for i, msg in enumerate(self.messages): |
| | if source and msg["role"] != source: |
| | continue |
| |
|
| | for image in msg.get("image", []): |
| | |
| | if return_copy: |
| | image = image.copy() |
| |
|
| | if return_base64: |
| | image = pil2base64(image) |
| |
|
| | images.append(image) |
| |
|
| | return images |
| |
|
| | def to_gradio_chatbot(self): |
| | ret = [] |
| | for i, msg in enumerate(self.messages): |
| | if msg["role"] == self.SYSTEM: |
| | continue |
| |
|
| | alt_str = ( |
| | "user upload image" if msg["role"] == self.USER else "output image" |
| | ) |
| | image = msg.get("image", []) |
| | if not isinstance(image, list): |
| | images = [image] |
| | else: |
| | images = image |
| |
|
| | img_str_list = [] |
| | for i in range(len(images)): |
| | image = resize_img( |
| | images[i], |
| | 400, |
| | 200, |
| | ) |
| | img_b64_str = pil2base64(image) |
| | W, H = image.size |
| | img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="{alt_str}" style="width: {W}px; max-width:none; max-height:none"></img>' |
| | |
| | |
| | |
| | img_str_list.append(img_str) |
| |
|
| | if msg["role"] == self.USER: |
| | msg_str = " ".join(img_str_list) + msg["content"] |
| | ret.append([msg_str, None]) |
| | else: |
| | msg_str = msg["content"] + " ".join(img_str_list) |
| | ret[-1][-1] = msg_str |
| | return ret |
| |
|
| | def update_message(self, role, content, image=None, idx=-1): |
| | assert len(self.messages) > 0, "No message in the conversation." |
| |
|
| | idx = (idx + len(self.messages)) % len(self.messages) |
| |
|
| | assert ( |
| | self.messages[idx]["role"] == role |
| | ), f"Role mismatch: {role} vs {self.messages[idx]['role']}" |
| |
|
| | self.messages[idx]["content"] = content |
| | if image is not None: |
| | if image not in self.messages[idx]["image"]: |
| | self.messages[idx]["image"] = [] |
| | if not isinstance(image, list): |
| | image = [image] |
| | self.messages[idx]["image"].extend(image) |
| |
|
| | def return_last_message(self): |
| | return self.messages[-1]["content"] |
| |
|
| | def end_of_current_turn(self): |
| | assert len(self.messages) > 0, "No message in the conversation." |
| | assert ( |
| | self.messages[-1]["role"] == self.ASSISTANT |
| | ), f"It should end with the message from assistant instead of {self.messages[-1]['role']}." |
| |
|
| | if self.messages[-1]["content"][-1] != self.streaming_placeholder: |
| | return |
| |
|
| | self.update_message(self.ASSISTANT, self.messages[-1]["content"][:-1], None) |
| |
|
| | def copy(self): |
| | return Conversation( |
| | mandatory_system_message=self.mandatory_system_message, |
| | system_message=self.system_message, |
| | roles=copy.deepcopy(self.roles), |
| | messages=copy.deepcopy(self.messages), |
| | ) |
| |
|
| | def dict(self): |
| | """ |
| | all_images = state.get_images() |
| | all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images] |
| | t = datetime.datetime.now() |
| | for image, hash in zip(all_images, all_image_hash): |
| | filename = os.path.join( |
| | LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg" |
| | ) |
| | if not os.path.isfile(filename): |
| | os.makedirs(os.path.dirname(filename), exist_ok=True) |
| | image.save(filename) |
| | """ |
| | messages = [] |
| | for message in self.messages: |
| | images = [] |
| | for image in message.get("image", []): |
| | filename = self.save_image(image) |
| | images.append(filename) |
| |
|
| | messages.append( |
| | { |
| | "role": message["role"], |
| | "content": message["content"], |
| | "image": images, |
| | } |
| | ) |
| | if len(images) == 0: |
| | messages[-1].pop("image") |
| |
|
| | return { |
| | "mandatory_system_message": self.mandatory_system_message, |
| | "system_message": self.system_message, |
| | "roles": self.roles, |
| | "messages": messages, |
| | } |
| |
|
| | def save_image(self, image: Image.Image) -> str: |
| | t = datetime.datetime.now() |
| | image_hash = hashlib.md5(image.tobytes()).hexdigest() |
| | filename = os.path.join( |
| | LOGDIR, |
| | "serve_images", |
| | f"{t.year}-{t.month:02d}-{t.day:02d}", |
| | f"{image_hash}.jpg", |
| | ) |
| | if not os.path.isfile(filename): |
| | os.makedirs(os.path.dirname(filename), exist_ok=True) |
| | image.save(filename) |
| |
|
| | return filename |
| |
|
| |
|
| |
|