Commit · b3e45e6
1 Parent(s): 16f0db6
Update requirements.txt to include llama-cpp-python dependency; change default port in launch.json from 8000 to 8080; add VSCode settings for Python type checking; modify welcome message in main.py; enhance configuration in config.py with new model and file name; implement Message and ChatResponse models for structured messaging; refactor chat_request and chat_service to utilize new message structure; streamline chat response handling; and update client.py for improved OpenAI API integration.
- .vscode/launch.json +1 -1
- .vscode/settings.json +4 -0
- requirements.txt +2 -1
- src/constants/config.py +3 -0
- src/main.py +1 -1
- src/models/others/message.py +22 -0
- src/models/requests/chat_request.py +14 -5
- src/models/responses/chat_response.py +70 -0
- src/routes/chat_routes.py +1 -2
- src/routes/vector_store_routes.py +0 -6
- src/services/chat_service.py +49 -39
- src/utils/client.py +43 -1
- src/utils/image_pipeline.py +28 -28
.vscode/launch.json  CHANGED
@@ -24,7 +24,7 @@
     "src.main:app",
     "--reload",
     "--port",
-    "8000",
+    "8080",
     "--host",
     "0.0.0.0",
 ]
.vscode/settings.json  ADDED
@@ -0,0 +1,4 @@
+{
+    "python.analysis.typeCheckingMode": "basic",
+    "python.analysis.autoImportCompletions": true
+}
requirements.txt  CHANGED
@@ -21,4 +21,5 @@ langchain_chroma>=0.2.2
 chromadb>=0.6.3
 sentence_transformers>=4.1.0
 langchain_huggingface>=0.1.2
-huggingface_hub[hf_xet]
+huggingface_hub[hf_xet]
+llama-cpp-python==0.3.8
src/constants/config.py  CHANGED
@@ -7,9 +7,12 @@ TORCH_DEVICE = (
     else "cpu"
 )
 IMAGE_MODEL_ID_OR_LINK = "stable-diffusion-v1-5/stable-diffusion-v1-5"
+MODEL_NAME = "facebook/opt-125m"
 CACHE_DIR = "/tmp/cache"
 DATA_DIR = "/tmp/data"
 EMBEDDING_MODEL = "intfloat/multilingual-e5-large-instruct"
 UPLOAD_DIR = "/tmp/uploads"
 OUTPUT_DIR = "/tmp/outputs"
+FILE_NAME = "super-lite-model.gguf"
+# FILE_NAME = "llama_3.1_8b_instruct_q4_k_m.gguf"
 # EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
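How the GGUF file referenced by FILE_NAME gets onto disk is not shown in this commit; below is a minimal sketch, assuming the file is published on the Hugging Face Hub under a hypothetical repo id, using the huggingface_hub dependency already listed in requirements.txt.

# Hypothetical download step: the repo id is a placeholder, not taken from the commit.
# hf_hub_download returns the local path of the cached file, which could then be
# passed to Llama(model_path=...) instead of the bare FILE_NAME.
from huggingface_hub import hf_hub_download

local_path = hf_hub_download(
    repo_id="some-org/super-lite-model-gguf",  # assumption, not in the diff
    filename="super-lite-model.gguf",          # matches FILE_NAME above
    cache_dir="/tmp/cache",                    # matches CACHE_DIR above
)
print(local_path)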
src/main.py  CHANGED
@@ -47,7 +47,7 @@ app.include_router(process_file_routes.router, prefix="/api/v1")
 app.include_router(vector_store_routes.router, prefix="/api/v1")
 @app.get("/")
 def read_root():
-    return {"message": "Welcome my API"}
+    return {"message": "Welcome to my API"}
 
 os.makedirs(OUTPUT_DIR, exist_ok=True)
 app.mount(OUTPUT_DIR, StaticFiles(directory=OUTPUT_DIR), name="outputs")
src/models/others/message.py  ADDED
@@ -0,0 +1,22 @@
+from enum import Enum
+from typing import Optional
+
+from pydantic import BaseModel
+
+
+class Role(str, Enum):
+    assistant = "assistant"
+    user = "user"
+    system = "system"
+    tool = "tool"
+
+
+class Message(BaseModel):
+    role: Role
+    content: Optional[str] = None
+
+    def to_map(self):
+        return {
+            "role": self.role.value,
+            "content": self.content,
+        }
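As a quick illustration (not part of the commit), the new model can be constructed and serialized with to_map() like this:

# Illustration only: build a Message and serialize it to the plain dict
# shape that the llama.cpp client expects.
from models.others.message import Message, Role

msg = Message(role=Role.user, content="Hello, how are you?")
print(msg.to_map())  # {'role': 'user', 'content': 'Hello, how are you?'}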
src/models/requests/chat_request.py  CHANGED
@@ -1,18 +1,27 @@
+from typing import List, Optional
 from pydantic import BaseModel
 
+from constants.config import MODEL_NAME
+from models.others.message import Role, Message
+
+
 class ChatRequest(BaseModel):
-    …
+    messages: List[Message]
+    # temperature: Optional[float] = 0.7
+    # max_tokens: Optional[int] = -1
     has_file: bool = False
     chat_session_id: str | None = None
-
+
     model_config = {
        "json_schema_extra": {
            "examples": [
                {
-                    "prompt": [{"role": "user", "content": "Hello, how are you?"}],
                    "has_file": False,
-                    "chat_session_id": "123"
+                    "chat_session_id": "123",
+                    "messages": [{"role": Role.user, "content": "hello"}],
+                    # "temperature":0.7,
+                    # "max_tokens": -1,
                }
            ]
        }
-    }
+    }
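For reference (not part of the commit), a request body matching the new schema can be validated like this:

# Illustration only: validate a client payload against the new schema;
# pydantic coerces the role strings into the Role enum.
from models.requests.chat_request import ChatRequest

payload = {
    "messages": [{"role": "user", "content": "Hello, how are you?"}],
    "has_file": False,
    "chat_session_id": "123",
}
request = ChatRequest(**payload)
print(request.messages[0].to_map())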
src/models/responses/chat_response.py  ADDED
@@ -0,0 +1,70 @@
+from typing import Any, List, Optional
+from pydantic import BaseModel
+from models.others.message import Message, Role
+
+# class Usage(BaseModel):
+#     prompt_token: int
+#     completion_token: int
+#     total_tokens: int
+
+
+class Choice(BaseModel):
+    # index: int
+    # logprobs: Any
+    # finish_reason: Optional[str]
+    message: Optional[Message] = None
+    delta: Optional[Message] = None
+
+
+class ChatResponse(BaseModel):
+    id: Optional[str] = None
+    # object: Optional[str] = None
+    # created: Optional[int] = None
+    # model: Optional[str] = None
+    # system_fingerprint: Optional[str] = None
+    # usage: Optional[Usage] = None
+    choices: Optional[List[Choice]] = None
+
+    @classmethod
+    def from_stream_chunk(cls, chunk: dict, last_role: Optional[Role] = None):
+        choices = []
+        updated_role = last_role  # Default to last role
+
+        for choice in chunk.get("choices", []):
+            delta_data = choice.get("delta", {})
+
+            # Skip chunks that contain neither content nor role
+            if not delta_data.get("content") and not delta_data.get("role"):
+                continue
+
+            # Determine role
+            if "role" in delta_data and delta_data["role"] is not None:
+                try:
+                    updated_role = Role(delta_data["role"])
+                except ValueError:
+                    # Skip or log invalid role values
+                    continue
+
+            if not updated_role:
+                # Still no role available, skip
+                continue
+
+            message = Message(
+                role=updated_role,
+                content=delta_data.get("content"),
+            )
+
+            choices.append(
+                Choice(
+                    message=message,
+                    delta=message,
+                )
+            )
+
+        return (
+            cls(
+                id=chunk.get("id"),
+                choices=choices,
+            ),
+            updated_role,
+        )
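A short illustration (not part of the commit) of how from_stream_chunk converts OpenAI-style streaming deltas while carrying the last seen role forward:

# Illustration only: feed two OpenAI-style stream chunks through the parser.
from models.responses.chat_response import ChatResponse

chunks = [
    {"id": "c1", "choices": [{"delta": {"role": "assistant"}}]},
    {"id": "c1", "choices": [{"delta": {"content": "Hi there"}}]},
]

last_role = None
for chunk in chunks:
    response, last_role = ChatResponse.from_stream_chunk(chunk, last_role)
    for choice in response.choices or []:
        print(choice.delta.role, choice.delta.content)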
src/routes/chat_routes.py  CHANGED
@@ -9,7 +9,6 @@ from models.requests.chat_request import ChatRequest
 from models.responses.base_exception_response import BaseExceptionResponse
 from models.responses.base_response import BaseResponse
 from services import chat_service
-from services.process_file_service import get_file_content
 
 router = APIRouter(tags=["Chat"])
 
@@ -63,7 +62,7 @@ async def chat(request: ChatRequest):
 
     try:
        response = chat_service.chat_generate(request=request)
-        return BaseResponse(data=…
+        return BaseResponse(data=response)
     except Exception as e:
        raise BaseExceptionResponse(message=str(e))
 
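For context (not part of the commit), a client-side sketch of calling the non-streaming endpoint; the exact route path is not visible in this diff, so the "/chat" segment is an assumption, while the "/api/v1" prefix comes from main.py and port 8080 from launch.json.

# Hypothetical client call: only the "/api/v1" prefix and port 8080 are from the
# commit; the "/chat" path segment is assumed for illustration.
import requests

resp = requests.post(
    "http://localhost:8080/api/v1/chat",
    json={
        "messages": [{"role": "user", "content": "hello"}],
        "has_file": False,
        "chat_session_id": "123",
    },
    timeout=60,
)
print(resp.json())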
src/routes/vector_store_routes.py  CHANGED
@@ -1,13 +1,7 @@
-import json
-import uuid
 from fastapi import APIRouter
-from models.requests.chat_request import ChatRequest
 from models.responses.base_exception_response import BaseExceptionResponse
 from models.responses.base_response import BaseResponse
 from services import vector_store_service
-from utils.client import openai_client
-import os
-from chromadb import PersistentClient
 
 
 router = APIRouter(tags=["Vector Store"])
src/services/chat_service.py  CHANGED
@@ -1,30 +1,34 @@
 from constants import system_prompts
 from models.requests.chat_request import ChatRequest
 from services import vector_store_service
+from utils.client import create, create_stream
 from utils.timing import measure_time
 from utils.tools import tools_helper, tools_define
+from models.others.message import Message, Role
 
-def build_context_prompt(request: ChatRequest) -> list:
+
+def build_context_prompt(request: ChatRequest) -> list[Message]:
     """Build system prompt with context if file is provided."""
-    messages = […
-
-    if not request.has_file or not vector_store_service.check_if_collection_exists(
-        …
-        …
+    messages = [Message(role=Role.system, content=system_prompts.system_prompt)]
+
+    if not request.has_file or not vector_store_service.check_if_collection_exists(
+        request.chat_session_id
+    ):
+        return messages
 
     with measure_time("Get data from vector store"):
        vectorstore = vector_store_service.get_vector_store(request.chat_session_id)
-        query = request.…
-        results = vectorstore.similarity_search(query=query, k=10)
+        query = request.messages[-1].content
+        results = vectorstore.similarity_search(query=query or "", k=10)
 
     if not results:
        return messages
 
     with measure_time("Building context prompt"):
-        context = …
+        context = ""
        for document in results:
            # print(f"Document:{document.page_content[:50]}, score:{score}\n\n")
-            source = document.metadata.get(…
+            source = document.metadata.get("file_id", "Unknown File")
            context += f"Context from file: {source}\n\n{document.page_content}\n\n"
 
        embedded_prompt = (
@@ -35,56 +39,62 @@ def build_context_prompt(request: ChatRequest) -> list:
            f"CONTEXT: {context}\nQUESTION: {query}"
        )
 
-    messages.append(…
+    messages.append(Message(role=Role.system, content=embedded_prompt))
     return messages
 
-…
+
+def chat_generate_stream(
+    request: ChatRequest,
+):
     """Streaming chat generation."""
     messages = build_context_prompt(request)
-    messages.extend(request.…
+    messages.extend(request.messages)
 
-    stream = openai_client.chat.completions.create(
-        …
-        …
-        …
-        …
-    )
+    # stream = openai_client.chat.completions.create(
+    #     messages=messages,
+    #     model='my-model',
+    #     stream=True,
+    #     tools=tools_define.tools
+    # )
+
+    stream = create_stream(messages)
 
     final_tool_calls = {}
 
     for chunk in stream:
-        …
-        …
-        …
-        …
+        if chunk.choices and len(chunk.choices) > 0:
+            delta = chunk.choices[0].delta
+            if getattr(delta, "tool_calls", None):
+                final_tool_calls = tools_helper.final_tool_calls_handler(
+                    final_tool_calls, delta.tool_calls
+                )
+        yield chunk
 
     if not final_tool_calls:
        return
 
-    …
+    tool_call_result = tools_helper.process_tool_calls(final_tool_calls)
+    tool_call_message = Message(
+        role=Role.tool, content=tool_call_result.get("content", "")
+    )
     messages.append(tool_call_message)
 
-    new_stream = …
-        messages=messages,
-        model='my-model',
-        stream=True
-    )
+    new_stream = create_stream(messages)
 
     for chunk in new_stream:
        yield chunk
 
 
-def chat_generate(request: ChatRequest…
+def chat_generate(request: ChatRequest):
     """Non-streaming (batched) chat generation."""
     messages = build_context_prompt(request)
-    messages.extend(request.…
+    messages.extend(request.messages)
 
     with measure_time("Non-streaming chat generation"):
-        response = openai_client.chat.completions.create(
-            …
-            …
-            …
-        )
-        …
-        …
-        …
+        # response = openai_client.chat.completions.create(
+        #     messages=messages,
+        #     model='my-model',
+        #     tools=tools_define.tools
+        # )
+        output = create(messages=messages)
+        return output
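To make the refactored flow concrete, here is an illustration (not part of the commit) of driving both service functions; the streamed chunks are the ChatResponse objects yielded by create_stream.

# Illustration only: exercise the refactored chat service directly.
from models.requests.chat_request import ChatRequest
from services import chat_service

request = ChatRequest(messages=[{"role": "user", "content": "hello"}])

# Streaming: collect assistant text from each chunk's delta.
pieces = []
for chunk in chat_service.chat_generate_stream(request):
    for choice in chunk.choices or []:
        if choice.delta and choice.delta.content:
            pieces.append(choice.delta.content)
print("".join(pieces))

# Non-streaming: returns whatever llama.cpp's create_chat_completion returns.
print(chat_service.chat_generate(request))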
src/utils/client.py  CHANGED
@@ -1,8 +1,50 @@
+import os
+from typing import Generator, List
 import openai
+import torch
+
+from constants.config import FILE_NAME
+from models.others.message import Message
+from models.requests.chat_request import ChatRequest
+from models.responses.chat_response import ChatResponse
+from utils.tools import tools_define
 
 # Initialize OpenAI API client
 openai_client = openai.OpenAI(
-    base_url="http://localhost:…
+    base_url="http://localhost:8000/v1",
     api_key="none",
 )
 
+from llama_cpp import ChatCompletionTool, Llama
+
+# Determine number of CPU threads based on device
+if torch.cuda.is_available() or (
+    hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
+):
+    n_threads = 4  # Fewer threads if using GPU/MPS, adjust as needed
+    n_gpu_layers = 20
+else:
+    n_threads = os.cpu_count() or 4
+    n_gpu_layers = 0
+
+# Initialize the model from the GGUF file
+llm = Llama(
+    model_path=FILE_NAME,
+    n_threads=n_threads,
+    n_gpu_layers=n_gpu_layers,
+    n_ctx=4096,
+)
+
+def create(messages: List[Message]):
+    prompt = [message.to_map() for message in messages]
+    output = llm.create_chat_completion(prompt)  # type: ignore
+    return output
+
+def create_stream(messages: List[Message]) -> Generator[ChatResponse, None, None]:
+    prompt = [message.to_map() for message in messages]
+    output = llm.create_chat_completion(prompt, stream=True, tools=tools_define.tools)  # type: ignore
+    last_role = None
+    for chunk in output:
+        response, last_role = ChatResponse.from_stream_chunk(chunk, last_role)  # type: ignore
+        if response.choices:
+            yield response
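As a usage sketch (not part of the commit), the two helpers can be exercised directly once the GGUF file referenced by FILE_NAME is on disk:

# Illustration only: call the llama.cpp-backed helpers with Message objects.
from models.others.message import Message, Role
from utils.client import create, create_stream

messages = [
    Message(role=Role.system, content="You are a helpful assistant."),
    Message(role=Role.user, content="Say hi in one word."),
]

# Non-streaming: create() returns llama.cpp's chat-completion dict.
result = create(messages)
print(result["choices"][0]["message"]["content"])

# Streaming: create_stream() yields ChatResponse objects carrying delta messages.
for response in create_stream(messages):
    for choice in response.choices or []:
        if choice.delta and choice.delta.content:
            print(choice.delta.content, end="", flush=True)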
src/utils/image_pipeline.py  CHANGED
@@ -1,32 +1,32 @@
-import torch
-from diffusers import StableDiffusionPipeline
-from constants.config import IMAGE_MODEL_ID_OR_LINK, TORCH_DEVICE
+# import torch
+# from diffusers import StableDiffusionPipeline
+# from constants.config import IMAGE_MODEL_ID_OR_LINK, TORCH_DEVICE
 
-torch.backends.cuda.matmul.allow_tf32 = True  # Enable TF32 for performance on CUDA
+# torch.backends.cuda.matmul.allow_tf32 = True  # Enable TF32 for performance on CUDA
 
-_pipeline = None
+# _pipeline = None
 
-def get_pipeline() -> StableDiffusionPipeline:
-    global _pipeline
-    if _pipeline is None:
-        try:
-            _pipeline = StableDiffusionPipeline.from_pretrained(
-                IMAGE_MODEL_ID_OR_LINK,
-                torch_dtype=torch.bfloat16,
-                variant="fp16",
-                # safety_checker=True,
-                use_safetensors=True,
-            )
-            # _pipeline = StableDiffusionPipeline.from_single_file(
-            #     IMAGE_MODEL_ID_OR_LINK,
-            #     torch_dtype=torch.bfloat16,
-            #     variant="fp16",
-            #     # safety_checker=True,
-            #     use_safetensors=True,
-            # )
-            _pipeline.to(TORCH_DEVICE)
-        except Exception as e:
-            raise RuntimeError(f"Failed to load the model: {e}")
-    return _pipeline
+# def get_pipeline() -> StableDiffusionPipeline:
+#     global _pipeline
+#     if _pipeline is None:
+#         try:
+#             _pipeline = StableDiffusionPipeline.from_pretrained(
+#                 IMAGE_MODEL_ID_OR_LINK,
+#                 torch_dtype=torch.bfloat16,
+#                 variant="fp16",
+#                 # safety_checker=True,
+#                 use_safetensors=True,
+#             )
+#             # _pipeline = StableDiffusionPipeline.from_single_file(
+#             #     IMAGE_MODEL_ID_OR_LINK,
+#             #     torch_dtype=torch.bfloat16,
+#             #     variant="fp16",
+#             #     # safety_checker=True,
+#             #     use_safetensors=True,
+#             # )
+#             _pipeline.to(TORCH_DEVICE)
+#         except Exception as e:
+#             raise RuntimeError(f"Failed to load the model: {e}")
+#     return _pipeline
 
-pipeline = get_pipeline()
+# pipeline = get_pipeline()