Spaces:
Runtime error
Runtime error
dylanglenister
commited on
Commit
·
6e61aeb
1
Parent(s):
d144879
Updated memory and connected.
Browse filesThe memory module has been updated to work with the new database files, some of those were fixed in this commit. Modules that rely on memory were also updated in suit.
- src/api/routes/chat.py +6 -3
- src/core/history.py +1 -1
- src/core/memory.py +58 -80
- src/core/state.py +15 -12
- src/data/repositories/account.py +19 -7
- src/data/repositories/session.py +15 -12
- src/main.py +2 -2
src/api/routes/chat.py
CHANGED
|
@@ -30,7 +30,10 @@ async def chat_endpoint(
|
|
| 30 |
if not user_profile:
|
| 31 |
state.memory_system.create_user(request.user_id, request.user_role or "Anonymous")
|
| 32 |
if request.user_specialty:
|
| 33 |
-
state.memory_system.
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
# Get or create session
|
| 36 |
session = state.memory_system.get_session(request.session_id)
|
|
@@ -43,8 +46,8 @@ async def chat_endpoint(
|
|
| 43 |
# Get medical context from memory
|
| 44 |
medical_context = state.history_manager.get_conversation_context(
|
| 45 |
request.user_id,
|
| 46 |
-
request.session_id,
|
| 47 |
-
request.message
|
| 48 |
)
|
| 49 |
|
| 50 |
# Generate response using Gemini AI
|
|
|
|
| 30 |
if not user_profile:
|
| 31 |
state.memory_system.create_user(request.user_id, request.user_role or "Anonymous")
|
| 32 |
if request.user_specialty:
|
| 33 |
+
state.memory_system.set_user_preferences(
|
| 34 |
+
request.user_id,
|
| 35 |
+
{"specialty": request.user_specialty}
|
| 36 |
+
)
|
| 37 |
|
| 38 |
# Get or create session
|
| 39 |
session = state.memory_system.get_session(request.session_id)
|
|
|
|
| 46 |
# Get medical context from memory
|
| 47 |
medical_context = state.history_manager.get_conversation_context(
|
| 48 |
request.user_id,
|
| 49 |
+
#request.session_id,
|
| 50 |
+
#request.message
|
| 51 |
)
|
| 52 |
|
| 53 |
# Generate response using Gemini AI
|
src/core/history.py
CHANGED
|
@@ -59,7 +59,7 @@ class MedicalHistoryManager:
|
|
| 59 |
user_id: str,
|
| 60 |
) -> str:
|
| 61 |
"""Retrieves relevant conversation context for a new question."""
|
| 62 |
-
return self.memory.get_medical_context(user_id
|
| 63 |
|
| 64 |
def get_user_medical_history(
|
| 65 |
self,
|
|
|
|
| 59 |
user_id: str,
|
| 60 |
) -> str:
|
| 61 |
"""Retrieves relevant conversation context for a new question."""
|
| 62 |
+
return self.memory.get_medical_context(user_id=user_id)
|
| 63 |
|
| 64 |
def get_user_medical_history(
|
| 65 |
self,
|
src/core/memory.py
CHANGED
|
@@ -6,132 +6,110 @@ from typing import Any
|
|
| 6 |
|
| 7 |
from src.core.profile import UserProfile
|
| 8 |
from src.core.session import ChatSession
|
| 9 |
-
from src.data import
|
|
|
|
|
|
|
| 10 |
from src.utils.logger import logger
|
| 11 |
|
| 12 |
|
| 13 |
class MemoryLRU:
|
| 14 |
"""
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
- Multiple chat sessions per user
|
| 18 |
-
- Chat history and continuity
|
| 19 |
-
- Medical context summaries
|
| 20 |
"""
|
|
|
|
| 21 |
def __init__(self, max_sessions_per_user: int = 10):
|
| 22 |
self.max_sessions_per_user = max_sessions_per_user
|
| 23 |
|
| 24 |
def create_user(self, user_id: str, name: str = "Anonymous") -> UserProfile:
|
| 25 |
-
"""
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
"_id": user_id,
|
| 29 |
-
"name": name,
|
| 30 |
-
"created_at": datetime.now(timezone.utc),
|
| 31 |
-
"last_seen": datetime.now(timezone.utc),
|
| 32 |
-
"preferences": {}
|
| 33 |
-
})
|
| 34 |
-
return user
|
| 35 |
|
| 36 |
def get_user(self, user_id: str) -> UserProfile | None:
|
| 37 |
-
"""
|
| 38 |
-
data =
|
| 39 |
return UserProfile.from_dict(data) if data else None
|
| 40 |
|
| 41 |
def create_session(self, user_id: str, title: str = "New Chat") -> str:
|
| 42 |
-
"""
|
| 43 |
-
|
| 44 |
-
mongodb.create_chat_session({
|
| 45 |
-
"_id": session_id,
|
| 46 |
-
"user_id": user_id,
|
| 47 |
-
"title": title,
|
| 48 |
-
"messages": []
|
| 49 |
-
})
|
| 50 |
-
return session_id
|
| 51 |
|
| 52 |
def get_session(self, session_id: str) -> ChatSession | None:
|
| 53 |
-
"""
|
| 54 |
try:
|
| 55 |
-
data =
|
| 56 |
-
if
|
| 57 |
-
logger().info(f"Session not found: {session_id}")
|
| 58 |
-
return None
|
| 59 |
-
|
| 60 |
-
logger().debug(f"Retrieved session data: {data}")
|
| 61 |
-
return ChatSession.from_dict(data)
|
| 62 |
except Exception as e:
|
| 63 |
logger().error(f"Error retrieving session {session_id}: {e}")
|
| 64 |
-
logger().error(f"Stack trace:", exc_info=True)
|
| 65 |
raise
|
| 66 |
|
| 67 |
def get_user_sessions(self, user_id: str) -> list[ChatSession]:
|
| 68 |
-
"""
|
| 69 |
-
sessions_data =
|
| 70 |
return [ChatSession.from_dict(data) for data in sessions_data]
|
| 71 |
|
| 72 |
-
def add_message_to_session(
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
message = {
|
| 75 |
"id": str(uuid.uuid4()),
|
| 76 |
"role": role,
|
| 77 |
"content": content,
|
| 78 |
"timestamp": datetime.now(timezone.utc),
|
| 79 |
-
"metadata": metadata
|
| 80 |
}
|
| 81 |
-
|
| 82 |
|
| 83 |
def update_session_title(self, session_id: str, title: str):
|
| 84 |
-
"""
|
| 85 |
-
|
| 86 |
|
| 87 |
def delete_session(self, session_id: str):
|
| 88 |
-
"""
|
| 89 |
-
|
| 90 |
|
| 91 |
-
def
|
| 92 |
-
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
-
# Medical context methods
|
| 96 |
def add(self, user_id: str, summary: str):
|
| 97 |
-
"""
|
| 98 |
-
|
| 99 |
|
| 100 |
def all(self, user_id: str) -> list[str]:
|
| 101 |
-
"""
|
| 102 |
-
contexts =
|
| 103 |
return [ctx["summary"] for ctx in contexts]
|
| 104 |
|
| 105 |
def recent(self, user_id: str, n: int) -> list[str]:
|
| 106 |
-
"""
|
| 107 |
-
contexts =
|
| 108 |
return [ctx["summary"] for ctx in contexts]
|
| 109 |
|
| 110 |
def rest(self, user_id: str, skip: int) -> list[str]:
|
| 111 |
-
"""
|
| 112 |
-
|
| 113 |
-
return [
|
| 114 |
-
|
| 115 |
-
def get_medical_context(
|
| 116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
try:
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
if not contexts:
|
| 121 |
-
return ""
|
| 122 |
-
|
| 123 |
-
# Format contexts into a string
|
| 124 |
-
context_texts = []
|
| 125 |
-
for ctx in contexts:
|
| 126 |
-
summary = ctx.get("summary")
|
| 127 |
-
if summary:
|
| 128 |
-
context_texts.append(summary)
|
| 129 |
-
|
| 130 |
-
if not context_texts:
|
| 131 |
-
return ""
|
| 132 |
-
|
| 133 |
-
return "\n\n".join(context_texts)
|
| 134 |
except Exception as e:
|
| 135 |
-
logger().error(f"Error getting medical context: {e}")
|
| 136 |
-
logger().error("Stack trace:", exc_info=True)
|
| 137 |
return ""
|
|
|
|
| 6 |
|
| 7 |
from src.core.profile import UserProfile
|
| 8 |
from src.core.session import ChatSession
|
| 9 |
+
from src.data.repositories import account as account_repo
|
| 10 |
+
from src.data.repositories import medical as medical_repo
|
| 11 |
+
from src.data.repositories import session as chat_repo
|
| 12 |
from src.utils.logger import logger
|
| 13 |
|
| 14 |
|
| 15 |
class MemoryLRU:
|
| 16 |
"""
|
| 17 |
+
A memory system that orchestrates data access between the application core
|
| 18 |
+
and the data repositories, managing users, sessions, and medical context.
|
|
|
|
|
|
|
|
|
|
| 19 |
"""
|
| 20 |
+
|
| 21 |
def __init__(self, max_sessions_per_user: int = 10):
|
| 22 |
self.max_sessions_per_user = max_sessions_per_user
|
| 23 |
|
| 24 |
def create_user(self, user_id: str, name: str = "Anonymous") -> UserProfile:
|
| 25 |
+
"""Creates a new user profile."""
|
| 26 |
+
account_repo.create_account(name=name, user_id=user_id)
|
| 27 |
+
return UserProfile(user_id, name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
def get_user(self, user_id: str) -> UserProfile | None:
|
| 30 |
+
"""Retrieves a user profile by its ID."""
|
| 31 |
+
data = account_repo.get_user_profile(user_id)
|
| 32 |
return UserProfile.from_dict(data) if data else None
|
| 33 |
|
| 34 |
def create_session(self, user_id: str, title: str = "New Chat") -> str:
|
| 35 |
+
"""Creates a new chat session for a user."""
|
| 36 |
+
return chat_repo.create_session(user_id, title)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
def get_session(self, session_id: str) -> ChatSession | None:
|
| 39 |
+
"""Retrieves a single chat session by its ID."""
|
| 40 |
try:
|
| 41 |
+
data = chat_repo.get_session(session_id)
|
| 42 |
+
return ChatSession.from_dict(data) if data else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
except Exception as e:
|
| 44 |
logger().error(f"Error retrieving session {session_id}: {e}")
|
|
|
|
| 45 |
raise
|
| 46 |
|
| 47 |
def get_user_sessions(self, user_id: str) -> list[ChatSession]:
|
| 48 |
+
"""Retrieves all sessions for a specific user."""
|
| 49 |
+
sessions_data = chat_repo.get_user_sessions(user_id, limit=self.max_sessions_per_user)
|
| 50 |
return [ChatSession.from_dict(data) for data in sessions_data]
|
| 51 |
|
| 52 |
+
def add_message_to_session(
|
| 53 |
+
self,
|
| 54 |
+
session_id: str,
|
| 55 |
+
role: str,
|
| 56 |
+
content: str,
|
| 57 |
+
metadata: dict = {}
|
| 58 |
+
):
|
| 59 |
+
"""Adds a message to a chat session."""
|
| 60 |
message = {
|
| 61 |
"id": str(uuid.uuid4()),
|
| 62 |
"role": role,
|
| 63 |
"content": content,
|
| 64 |
"timestamp": datetime.now(timezone.utc),
|
| 65 |
+
"metadata": metadata
|
| 66 |
}
|
| 67 |
+
chat_repo.add_message(session_id, message)
|
| 68 |
|
| 69 |
def update_session_title(self, session_id: str, title: str):
|
| 70 |
+
"""Updates the title of a session."""
|
| 71 |
+
chat_repo.update_session_title(session_id, title)
|
| 72 |
|
| 73 |
def delete_session(self, session_id: str):
|
| 74 |
+
"""Deletes a chat session."""
|
| 75 |
+
chat_repo.delete_chat_session(session_id)
|
| 76 |
|
| 77 |
+
def set_user_preferences(
|
| 78 |
+
self,
|
| 79 |
+
user_id: str,
|
| 80 |
+
update_data: dict[str, Any]
|
| 81 |
+
):
|
| 82 |
+
"""Sets a preference for a user."""
|
| 83 |
+
account_repo.set_user_preferences(user_id, update_data)
|
| 84 |
|
|
|
|
| 85 |
def add(self, user_id: str, summary: str):
|
| 86 |
+
"""Adds a medical context summary for a user."""
|
| 87 |
+
medical_repo.add_medical_context(user_id, summary)
|
| 88 |
|
| 89 |
def all(self, user_id: str) -> list[str]:
|
| 90 |
+
"""Retrieves all medical context summaries for a user."""
|
| 91 |
+
contexts = medical_repo.get_medical_context(user_id)
|
| 92 |
return [ctx["summary"] for ctx in contexts]
|
| 93 |
|
| 94 |
def recent(self, user_id: str, n: int) -> list[str]:
|
| 95 |
+
"""Retrieves the N most recent medical context summaries."""
|
| 96 |
+
contexts = medical_repo.get_medical_context(user_id, limit=n)
|
| 97 |
return [ctx["summary"] for ctx in contexts]
|
| 98 |
|
| 99 |
def rest(self, user_id: str, skip: int) -> list[str]:
|
| 100 |
+
"""Retrieves all summaries except for the N most recent ones."""
|
| 101 |
+
all_contexts = self.all(user_id)
|
| 102 |
+
return all_contexts[skip:]
|
| 103 |
+
|
| 104 |
+
def get_medical_context(
|
| 105 |
+
self,
|
| 106 |
+
user_id: str,
|
| 107 |
+
limit: int = 5
|
| 108 |
+
) -> str:
|
| 109 |
+
"""Retrieves and formats recent medical context into a single string."""
|
| 110 |
try:
|
| 111 |
+
contexts = self.recent(user_id, limit)
|
| 112 |
+
return "\n\n".join(contexts)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
except Exception as e:
|
| 114 |
+
logger().error(f"Error getting medical context for user {user_id}: {e}")
|
|
|
|
| 115 |
return ""
|
src/core/state.py
CHANGED
|
@@ -7,30 +7,33 @@ from src.utils.rotator import APIKeyRotator
|
|
| 7 |
|
| 8 |
|
| 9 |
class MedicalState:
|
| 10 |
-
"""
|
| 11 |
-
_instance
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
def __init__(self):
|
|
|
|
|
|
|
| 14 |
self.memory_system: MemoryLRU
|
| 15 |
self.embedding_client: EmbeddingClient
|
| 16 |
self.history_manager: MedicalHistoryManager
|
| 17 |
self.gemini_rotator: APIKeyRotator
|
| 18 |
self.nvidia_rotator: APIKeyRotator
|
|
|
|
| 19 |
|
| 20 |
def initialize(self):
|
| 21 |
-
"""
|
| 22 |
self.memory_system = MemoryLRU(max_sessions_per_user=20)
|
| 23 |
-
self.embedding_client = EmbeddingClient("all-MiniLM-L6-v2", 384)
|
| 24 |
self.history_manager = MedicalHistoryManager(self.memory_system, self.embedding_client)
|
| 25 |
self.gemini_rotator = APIKeyRotator("GEMINI_API_", max_slots=5)
|
| 26 |
self.nvidia_rotator = APIKeyRotator("NVIDIA_API_", max_slots=5)
|
| 27 |
|
| 28 |
-
@classmethod
|
| 29 |
-
def get_instance(cls) -> 'MedicalState':
|
| 30 |
-
if cls._instance is None:
|
| 31 |
-
cls._instance = MedicalState()
|
| 32 |
-
return cls._instance
|
| 33 |
-
|
| 34 |
def get_state() -> MedicalState:
|
| 35 |
-
"""
|
| 36 |
-
return MedicalState
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
class MedicalState:
|
| 10 |
+
"""Manages the global state of the application using a Singleton pattern."""
|
| 11 |
+
_instance = None
|
| 12 |
+
|
| 13 |
+
def __new__(cls):
|
| 14 |
+
if cls._instance is None:
|
| 15 |
+
cls._instance = super(MedicalState, cls).__new__(cls)
|
| 16 |
+
cls._instance._initialized = False
|
| 17 |
+
return cls._instance
|
| 18 |
|
| 19 |
def __init__(self):
|
| 20 |
+
if self._initialized:
|
| 21 |
+
return
|
| 22 |
self.memory_system: MemoryLRU
|
| 23 |
self.embedding_client: EmbeddingClient
|
| 24 |
self.history_manager: MedicalHistoryManager
|
| 25 |
self.gemini_rotator: APIKeyRotator
|
| 26 |
self.nvidia_rotator: APIKeyRotator
|
| 27 |
+
self._initialized = True
|
| 28 |
|
| 29 |
def initialize(self):
|
| 30 |
+
"""Initializes all core application components."""
|
| 31 |
self.memory_system = MemoryLRU(max_sessions_per_user=20)
|
| 32 |
+
self.embedding_client = EmbeddingClient(model_name="all-MiniLM-L6-v2", dimension=384)
|
| 33 |
self.history_manager = MedicalHistoryManager(self.memory_system, self.embedding_client)
|
| 34 |
self.gemini_rotator = APIKeyRotator("GEMINI_API_", max_slots=5)
|
| 35 |
self.nvidia_rotator = APIKeyRotator("NVIDIA_API_", max_slots=5)
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
def get_state() -> MedicalState:
|
| 38 |
+
"""Provides access to the application state, for use as a dependency."""
|
| 39 |
+
return MedicalState()
|
src/data/repositories/account.py
CHANGED
|
@@ -16,16 +16,28 @@ class UserNotFound(Exception):
|
|
| 16 |
pass
|
| 17 |
|
| 18 |
def create_account(
|
| 19 |
-
|
|
|
|
| 20 |
*,
|
|
|
|
| 21 |
collection_name: str = ACCOUNTS_COLLECTION
|
| 22 |
) -> str:
|
| 23 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
collection = get_collection(collection_name)
|
| 25 |
now = datetime.now(timezone.utc)
|
| 26 |
-
user_data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
try:
|
| 28 |
result = collection.insert_one(user_data)
|
|
|
|
| 29 |
logger().info(f"Created new account: {result.inserted_id}")
|
| 30 |
return str(result.inserted_id)
|
| 31 |
except DuplicateKeyError as e:
|
|
@@ -72,19 +84,19 @@ def get_user_profile(
|
|
| 72 |
def set_user_preferences(
|
| 73 |
user_id: str,
|
| 74 |
/,
|
| 75 |
-
|
| 76 |
*,
|
| 77 |
collection_name: str = ACCOUNTS_COLLECTION
|
| 78 |
) -> bool:
|
| 79 |
"""Sets a preference for a user."""
|
| 80 |
try:
|
| 81 |
collection = get_collection(collection_name)
|
| 82 |
-
|
| 83 |
-
|
| 84 |
result = collection.update_one(
|
| 85 |
{"_id": user_id},
|
| 86 |
{
|
| 87 |
-
"$set":
|
| 88 |
}
|
| 89 |
)
|
| 90 |
if result.matched_count == 0:
|
|
|
|
| 16 |
pass
|
| 17 |
|
| 18 |
def create_account(
|
| 19 |
+
name: str,
|
| 20 |
+
preferences: dict[str, Any] = {},
|
| 21 |
*,
|
| 22 |
+
user_id: str,
|
| 23 |
collection_name: str = ACCOUNTS_COLLECTION
|
| 24 |
) -> str:
|
| 25 |
+
"""
|
| 26 |
+
Creates a new user account.
|
| 27 |
+
|
| 28 |
+
@TODO Revise if the user id should be passed in or generated by the database.
|
| 29 |
+
"""
|
| 30 |
collection = get_collection(collection_name)
|
| 31 |
now = datetime.now(timezone.utc)
|
| 32 |
+
user_data: dict[str, Any] = {
|
| 33 |
+
"_id": user_id,
|
| 34 |
+
"name": name,
|
| 35 |
+
"created_at": now,
|
| 36 |
+
"updated_at": now
|
| 37 |
+
}
|
| 38 |
try:
|
| 39 |
result = collection.insert_one(user_data)
|
| 40 |
+
set_user_preferences(user_id, preferences)
|
| 41 |
logger().info(f"Created new account: {result.inserted_id}")
|
| 42 |
return str(result.inserted_id)
|
| 43 |
except DuplicateKeyError as e:
|
|
|
|
| 84 |
def set_user_preferences(
|
| 85 |
user_id: str,
|
| 86 |
/,
|
| 87 |
+
preferences: dict[str, Any],
|
| 88 |
*,
|
| 89 |
collection_name: str = ACCOUNTS_COLLECTION
|
| 90 |
) -> bool:
|
| 91 |
"""Sets a preference for a user."""
|
| 92 |
try:
|
| 93 |
collection = get_collection(collection_name)
|
| 94 |
+
preferences = {f"preferences.{key}": value for key, value in preferences}
|
| 95 |
+
preferences["updated_at"] = datetime.now(timezone.utc)
|
| 96 |
result = collection.update_one(
|
| 97 |
{"_id": user_id},
|
| 98 |
{
|
| 99 |
+
"$set": preferences
|
| 100 |
}
|
| 101 |
)
|
| 102 |
if result.matched_count == 0:
|
src/data/repositories/session.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
# data/repositories/chat.py
|
| 2 |
|
|
|
|
| 3 |
from datetime import datetime, timezone
|
| 4 |
from typing import Any
|
| 5 |
|
| 6 |
-
from bson import ObjectId
|
| 7 |
from pymongo import DESCENDING
|
| 8 |
from pymongo.errors import (ConnectionFailure, DuplicateKeyError,
|
| 9 |
OperationFailure, PyMongoError)
|
|
@@ -13,23 +13,26 @@ from src.utils.logger import logger
|
|
| 13 |
|
| 14 |
CHAT_SESSIONS_COLLECTION = "chat_sessions"
|
| 15 |
|
| 16 |
-
def
|
| 17 |
-
|
|
|
|
| 18 |
*,
|
| 19 |
collection_name: str = CHAT_SESSIONS_COLLECTION
|
| 20 |
) -> str:
|
| 21 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
collection = get_collection(collection_name)
|
| 23 |
now = datetime.now(timezone.utc)
|
| 24 |
-
session_data
|
|
|
|
|
|
|
|
|
|
| 25 |
"created_at": now,
|
| 26 |
-
"updated_at": now
|
| 27 |
-
|
| 28 |
-
"title": str(session_data.get("title", "New Chat")),
|
| 29 |
-
"user_id": str(session_data["user_id"])
|
| 30 |
-
})
|
| 31 |
-
#if "_id" not in session_data:
|
| 32 |
-
# session_data["_id"] = str(ObjectId())
|
| 33 |
try:
|
| 34 |
result = collection.insert_one(session_data)
|
| 35 |
return str(result.inserted_id)
|
|
|
|
| 1 |
# data/repositories/chat.py
|
| 2 |
|
| 3 |
+
import uuid
|
| 4 |
from datetime import datetime, timezone
|
| 5 |
from typing import Any
|
| 6 |
|
|
|
|
| 7 |
from pymongo import DESCENDING
|
| 8 |
from pymongo.errors import (ConnectionFailure, DuplicateKeyError,
|
| 9 |
OperationFailure, PyMongoError)
|
|
|
|
| 13 |
|
| 14 |
CHAT_SESSIONS_COLLECTION = "chat_sessions"
|
| 15 |
|
| 16 |
+
def create_session(
|
| 17 |
+
user_id: str,
|
| 18 |
+
title: str,
|
| 19 |
*,
|
| 20 |
collection_name: str = CHAT_SESSIONS_COLLECTION
|
| 21 |
) -> str:
|
| 22 |
+
"""
|
| 23 |
+
Creates a new chat session.
|
| 24 |
+
|
| 25 |
+
@TODO Revise if the session id should be passed in or generated by the database.
|
| 26 |
+
"""
|
| 27 |
collection = get_collection(collection_name)
|
| 28 |
now = datetime.now(timezone.utc)
|
| 29 |
+
session_data: dict[str, Any] = {
|
| 30 |
+
"_id": str(uuid.uuid4()),
|
| 31 |
+
"user_id": user_id,
|
| 32 |
+
"title": title,
|
| 33 |
"created_at": now,
|
| 34 |
+
"updated_at": now
|
| 35 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
try:
|
| 37 |
result = collection.insert_one(session_data)
|
| 38 |
return str(result.inserted_id)
|
src/main.py
CHANGED
|
@@ -24,7 +24,7 @@ except Exception as e:
|
|
| 24 |
|
| 25 |
# Import project modules after trying to load environment variables
|
| 26 |
from src.api.routes import chat, session, static, system, user
|
| 27 |
-
from src.core.state import MedicalState
|
| 28 |
|
| 29 |
|
| 30 |
def startup_event(state: MedicalState):
|
|
@@ -73,7 +73,7 @@ def shutdown_event():
|
|
| 73 |
@asynccontextmanager
|
| 74 |
async def lifespan(app: FastAPI):
|
| 75 |
# Initialize state
|
| 76 |
-
state =
|
| 77 |
state.initialize()
|
| 78 |
|
| 79 |
# Startup code here
|
|
|
|
| 24 |
|
| 25 |
# Import project modules after trying to load environment variables
|
| 26 |
from src.api.routes import chat, session, static, system, user
|
| 27 |
+
from src.core.state import MedicalState, get_state
|
| 28 |
|
| 29 |
|
| 30 |
def startup_event(state: MedicalState):
|
|
|
|
| 73 |
@asynccontextmanager
|
| 74 |
async def lifespan(app: FastAPI):
|
| 75 |
# Initialize state
|
| 76 |
+
state = get_state()
|
| 77 |
state.initialize()
|
| 78 |
|
| 79 |
# Startup code here
|