dylanglenister commited on
Commit
6e61aeb
·
1 Parent(s): d144879

Updated memory and connected.

Browse files

The memory module has been updated to work with the new database files, some of those were fixed in this commit. Modules that rely on memory were also updated in suit.

src/api/routes/chat.py CHANGED
@@ -30,7 +30,10 @@ async def chat_endpoint(
30
  if not user_profile:
31
  state.memory_system.create_user(request.user_id, request.user_role or "Anonymous")
32
  if request.user_specialty:
33
- state.memory_system.set_user_preference(request.user_id, "specialty", request.user_specialty)
 
 
 
34
 
35
  # Get or create session
36
  session = state.memory_system.get_session(request.session_id)
@@ -43,8 +46,8 @@ async def chat_endpoint(
43
  # Get medical context from memory
44
  medical_context = state.history_manager.get_conversation_context(
45
  request.user_id,
46
- request.session_id,
47
- request.message
48
  )
49
 
50
  # Generate response using Gemini AI
 
30
  if not user_profile:
31
  state.memory_system.create_user(request.user_id, request.user_role or "Anonymous")
32
  if request.user_specialty:
33
+ state.memory_system.set_user_preferences(
34
+ request.user_id,
35
+ {"specialty": request.user_specialty}
36
+ )
37
 
38
  # Get or create session
39
  session = state.memory_system.get_session(request.session_id)
 
46
  # Get medical context from memory
47
  medical_context = state.history_manager.get_conversation_context(
48
  request.user_id,
49
+ #request.session_id,
50
+ #request.message
51
  )
52
 
53
  # Generate response using Gemini AI
src/core/history.py CHANGED
@@ -59,7 +59,7 @@ class MedicalHistoryManager:
59
  user_id: str,
60
  ) -> str:
61
  """Retrieves relevant conversation context for a new question."""
62
- return self.memory.get_medical_context(user_id, session_id, question)
63
 
64
  def get_user_medical_history(
65
  self,
 
59
  user_id: str,
60
  ) -> str:
61
  """Retrieves relevant conversation context for a new question."""
62
+ return self.memory.get_medical_context(user_id=user_id)
63
 
64
  def get_user_medical_history(
65
  self,
src/core/memory.py CHANGED
@@ -6,132 +6,110 @@ from typing import Any
6
 
7
  from src.core.profile import UserProfile
8
  from src.core.session import ChatSession
9
- from src.data import mongodb
 
 
10
  from src.utils.logger import logger
11
 
12
 
13
  class MemoryLRU:
14
  """
15
- Memory system using MongoDB for persistence, supporting:
16
- - Multiple users with profiles
17
- - Multiple chat sessions per user
18
- - Chat history and continuity
19
- - Medical context summaries
20
  """
 
21
  def __init__(self, max_sessions_per_user: int = 10):
22
  self.max_sessions_per_user = max_sessions_per_user
23
 
24
  def create_user(self, user_id: str, name: str = "Anonymous") -> UserProfile:
25
- """Create a new user profile"""
26
- user = UserProfile(user_id, name)
27
- mongodb.create_account({
28
- "_id": user_id,
29
- "name": name,
30
- "created_at": datetime.now(timezone.utc),
31
- "last_seen": datetime.now(timezone.utc),
32
- "preferences": {}
33
- })
34
- return user
35
 
36
  def get_user(self, user_id: str) -> UserProfile | None:
37
- """Get user profile by ID"""
38
- data = mongodb.get_user_profile(user_id)
39
  return UserProfile.from_dict(data) if data else None
40
 
41
  def create_session(self, user_id: str, title: str = "New Chat") -> str:
42
- """Create a new chat session"""
43
- session_id = str(uuid.uuid4())
44
- mongodb.create_chat_session({
45
- "_id": session_id,
46
- "user_id": user_id,
47
- "title": title,
48
- "messages": []
49
- })
50
- return session_id
51
 
52
  def get_session(self, session_id: str) -> ChatSession | None:
53
- """Get chat session by ID"""
54
  try:
55
- data = mongodb.get_session(session_id)
56
- if not data:
57
- logger().info(f"Session not found: {session_id}")
58
- return None
59
-
60
- logger().debug(f"Retrieved session data: {data}")
61
- return ChatSession.from_dict(data)
62
  except Exception as e:
63
  logger().error(f"Error retrieving session {session_id}: {e}")
64
- logger().error(f"Stack trace:", exc_info=True)
65
  raise
66
 
67
  def get_user_sessions(self, user_id: str) -> list[ChatSession]:
68
- """Get all sessions for a user"""
69
- sessions_data = mongodb.get_user_sessions(user_id, limit=self.max_sessions_per_user)
70
  return [ChatSession.from_dict(data) for data in sessions_data]
71
 
72
- def add_message_to_session(self, session_id: str, role: str, content: str, metadata: dict | None = None):
73
- """Add a message to a session"""
 
 
 
 
 
 
74
  message = {
75
  "id": str(uuid.uuid4()),
76
  "role": role,
77
  "content": content,
78
  "timestamp": datetime.now(timezone.utc),
79
- "metadata": metadata or {}
80
  }
81
- mongodb.add_message(session_id, message)
82
 
83
  def update_session_title(self, session_id: str, title: str):
84
- """Update session title"""
85
- mongodb.update_session_title(session_id, title)
86
 
87
  def delete_session(self, session_id: str):
88
- """Delete a chat session"""
89
- mongodb.delete_chat_session(session_id)
90
 
91
- def set_user_preference(self, user_id: str, key: str, value: Any):
92
- """Set user preference"""
93
- mongodb.set_user_preference(user_id, key, value)
 
 
 
 
94
 
95
- # Medical context methods
96
  def add(self, user_id: str, summary: str):
97
- """Add a medical context summary"""
98
- mongodb.add_medical_context(user_id, summary)
99
 
100
  def all(self, user_id: str) -> list[str]:
101
- """Get all medical context summaries for a user"""
102
- contexts = mongodb.get_medical_context(user_id)
103
  return [ctx["summary"] for ctx in contexts]
104
 
105
  def recent(self, user_id: str, n: int) -> list[str]:
106
- """Get n most recent medical context summaries"""
107
- contexts = mongodb.get_medical_context(user_id, limit=n)
108
  return [ctx["summary"] for ctx in contexts]
109
 
110
  def rest(self, user_id: str, skip: int) -> list[str]:
111
- """Get all summaries except the most recent n"""
112
- contexts = mongodb.get_medical_context(user_id)
113
- return [ctx["summary"] for ctx in contexts[skip:]]
114
-
115
- def get_medical_context(self, user_id: str, session_id: str, question: str) -> str:
116
- """Get relevant medical context for a question"""
 
 
 
 
117
  try:
118
- # Get recent contexts
119
- contexts = mongodb.get_medical_context(user_id, limit=5)
120
- if not contexts:
121
- return ""
122
-
123
- # Format contexts into a string
124
- context_texts = []
125
- for ctx in contexts:
126
- summary = ctx.get("summary")
127
- if summary:
128
- context_texts.append(summary)
129
-
130
- if not context_texts:
131
- return ""
132
-
133
- return "\n\n".join(context_texts)
134
  except Exception as e:
135
- logger().error(f"Error getting medical context: {e}")
136
- logger().error("Stack trace:", exc_info=True)
137
  return ""
 
6
 
7
  from src.core.profile import UserProfile
8
  from src.core.session import ChatSession
9
+ from src.data.repositories import account as account_repo
10
+ from src.data.repositories import medical as medical_repo
11
+ from src.data.repositories import session as chat_repo
12
  from src.utils.logger import logger
13
 
14
 
15
  class MemoryLRU:
16
  """
17
+ A memory system that orchestrates data access between the application core
18
+ and the data repositories, managing users, sessions, and medical context.
 
 
 
19
  """
20
+
21
  def __init__(self, max_sessions_per_user: int = 10):
22
  self.max_sessions_per_user = max_sessions_per_user
23
 
24
  def create_user(self, user_id: str, name: str = "Anonymous") -> UserProfile:
25
+ """Creates a new user profile."""
26
+ account_repo.create_account(name=name, user_id=user_id)
27
+ return UserProfile(user_id, name)
 
 
 
 
 
 
 
28
 
29
  def get_user(self, user_id: str) -> UserProfile | None:
30
+ """Retrieves a user profile by its ID."""
31
+ data = account_repo.get_user_profile(user_id)
32
  return UserProfile.from_dict(data) if data else None
33
 
34
  def create_session(self, user_id: str, title: str = "New Chat") -> str:
35
+ """Creates a new chat session for a user."""
36
+ return chat_repo.create_session(user_id, title)
 
 
 
 
 
 
 
37
 
38
  def get_session(self, session_id: str) -> ChatSession | None:
39
+ """Retrieves a single chat session by its ID."""
40
  try:
41
+ data = chat_repo.get_session(session_id)
42
+ return ChatSession.from_dict(data) if data else None
 
 
 
 
 
43
  except Exception as e:
44
  logger().error(f"Error retrieving session {session_id}: {e}")
 
45
  raise
46
 
47
  def get_user_sessions(self, user_id: str) -> list[ChatSession]:
48
+ """Retrieves all sessions for a specific user."""
49
+ sessions_data = chat_repo.get_user_sessions(user_id, limit=self.max_sessions_per_user)
50
  return [ChatSession.from_dict(data) for data in sessions_data]
51
 
52
+ def add_message_to_session(
53
+ self,
54
+ session_id: str,
55
+ role: str,
56
+ content: str,
57
+ metadata: dict = {}
58
+ ):
59
+ """Adds a message to a chat session."""
60
  message = {
61
  "id": str(uuid.uuid4()),
62
  "role": role,
63
  "content": content,
64
  "timestamp": datetime.now(timezone.utc),
65
+ "metadata": metadata
66
  }
67
+ chat_repo.add_message(session_id, message)
68
 
69
  def update_session_title(self, session_id: str, title: str):
70
+ """Updates the title of a session."""
71
+ chat_repo.update_session_title(session_id, title)
72
 
73
  def delete_session(self, session_id: str):
74
+ """Deletes a chat session."""
75
+ chat_repo.delete_chat_session(session_id)
76
 
77
+ def set_user_preferences(
78
+ self,
79
+ user_id: str,
80
+ update_data: dict[str, Any]
81
+ ):
82
+ """Sets a preference for a user."""
83
+ account_repo.set_user_preferences(user_id, update_data)
84
 
 
85
  def add(self, user_id: str, summary: str):
86
+ """Adds a medical context summary for a user."""
87
+ medical_repo.add_medical_context(user_id, summary)
88
 
89
  def all(self, user_id: str) -> list[str]:
90
+ """Retrieves all medical context summaries for a user."""
91
+ contexts = medical_repo.get_medical_context(user_id)
92
  return [ctx["summary"] for ctx in contexts]
93
 
94
  def recent(self, user_id: str, n: int) -> list[str]:
95
+ """Retrieves the N most recent medical context summaries."""
96
+ contexts = medical_repo.get_medical_context(user_id, limit=n)
97
  return [ctx["summary"] for ctx in contexts]
98
 
99
  def rest(self, user_id: str, skip: int) -> list[str]:
100
+ """Retrieves all summaries except for the N most recent ones."""
101
+ all_contexts = self.all(user_id)
102
+ return all_contexts[skip:]
103
+
104
+ def get_medical_context(
105
+ self,
106
+ user_id: str,
107
+ limit: int = 5
108
+ ) -> str:
109
+ """Retrieves and formats recent medical context into a single string."""
110
  try:
111
+ contexts = self.recent(user_id, limit)
112
+ return "\n\n".join(contexts)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  except Exception as e:
114
+ logger().error(f"Error getting medical context for user {user_id}: {e}")
 
115
  return ""
src/core/state.py CHANGED
@@ -7,30 +7,33 @@ from src.utils.rotator import APIKeyRotator
7
 
8
 
9
  class MedicalState:
10
- """Global state management for Medical AI system"""
11
- _instance: 'MedicalState | None' = None
 
 
 
 
 
 
12
 
13
  def __init__(self):
 
 
14
  self.memory_system: MemoryLRU
15
  self.embedding_client: EmbeddingClient
16
  self.history_manager: MedicalHistoryManager
17
  self.gemini_rotator: APIKeyRotator
18
  self.nvidia_rotator: APIKeyRotator
 
19
 
20
  def initialize(self):
21
- """Initialize all core components"""
22
  self.memory_system = MemoryLRU(max_sessions_per_user=20)
23
- self.embedding_client = EmbeddingClient("all-MiniLM-L6-v2", 384)
24
  self.history_manager = MedicalHistoryManager(self.memory_system, self.embedding_client)
25
  self.gemini_rotator = APIKeyRotator("GEMINI_API_", max_slots=5)
26
  self.nvidia_rotator = APIKeyRotator("NVIDIA_API_", max_slots=5)
27
 
28
- @classmethod
29
- def get_instance(cls) -> 'MedicalState':
30
- if cls._instance is None:
31
- cls._instance = MedicalState()
32
- return cls._instance
33
-
34
  def get_state() -> MedicalState:
35
- """FastAPI dependency for getting application state"""
36
- return MedicalState.get_instance()
 
7
 
8
 
9
  class MedicalState:
10
+ """Manages the global state of the application using a Singleton pattern."""
11
+ _instance = None
12
+
13
+ def __new__(cls):
14
+ if cls._instance is None:
15
+ cls._instance = super(MedicalState, cls).__new__(cls)
16
+ cls._instance._initialized = False
17
+ return cls._instance
18
 
19
  def __init__(self):
20
+ if self._initialized:
21
+ return
22
  self.memory_system: MemoryLRU
23
  self.embedding_client: EmbeddingClient
24
  self.history_manager: MedicalHistoryManager
25
  self.gemini_rotator: APIKeyRotator
26
  self.nvidia_rotator: APIKeyRotator
27
+ self._initialized = True
28
 
29
  def initialize(self):
30
+ """Initializes all core application components."""
31
  self.memory_system = MemoryLRU(max_sessions_per_user=20)
32
+ self.embedding_client = EmbeddingClient(model_name="all-MiniLM-L6-v2", dimension=384)
33
  self.history_manager = MedicalHistoryManager(self.memory_system, self.embedding_client)
34
  self.gemini_rotator = APIKeyRotator("GEMINI_API_", max_slots=5)
35
  self.nvidia_rotator = APIKeyRotator("NVIDIA_API_", max_slots=5)
36
 
 
 
 
 
 
 
37
  def get_state() -> MedicalState:
38
+ """Provides access to the application state, for use as a dependency."""
39
+ return MedicalState()
src/data/repositories/account.py CHANGED
@@ -16,16 +16,28 @@ class UserNotFound(Exception):
16
  pass
17
 
18
  def create_account(
19
- user_data: dict[str, Any],
 
20
  *,
 
21
  collection_name: str = ACCOUNTS_COLLECTION
22
  ) -> str:
23
- """Creates a new user account."""
 
 
 
 
24
  collection = get_collection(collection_name)
25
  now = datetime.now(timezone.utc)
26
- user_data.update({"created_at": now, "updated_at": now})
 
 
 
 
 
27
  try:
28
  result = collection.insert_one(user_data)
 
29
  logger().info(f"Created new account: {result.inserted_id}")
30
  return str(result.inserted_id)
31
  except DuplicateKeyError as e:
@@ -72,19 +84,19 @@ def get_user_profile(
72
  def set_user_preferences(
73
  user_id: str,
74
  /,
75
- update_data: dict[str, Any],
76
  *,
77
  collection_name: str = ACCOUNTS_COLLECTION
78
  ) -> bool:
79
  """Sets a preference for a user."""
80
  try:
81
  collection = get_collection(collection_name)
82
- update_data = {f"preferences.{key}": value for key, value in update_data}
83
- update_data["updated_at"] = datetime.now(timezone.utc)
84
  result = collection.update_one(
85
  {"_id": user_id},
86
  {
87
- "$set": update_data
88
  }
89
  )
90
  if result.matched_count == 0:
 
16
  pass
17
 
18
  def create_account(
19
+ name: str,
20
+ preferences: dict[str, Any] = {},
21
  *,
22
+ user_id: str,
23
  collection_name: str = ACCOUNTS_COLLECTION
24
  ) -> str:
25
+ """
26
+ Creates a new user account.
27
+
28
+ @TODO Revise if the user id should be passed in or generated by the database.
29
+ """
30
  collection = get_collection(collection_name)
31
  now = datetime.now(timezone.utc)
32
+ user_data: dict[str, Any] = {
33
+ "_id": user_id,
34
+ "name": name,
35
+ "created_at": now,
36
+ "updated_at": now
37
+ }
38
  try:
39
  result = collection.insert_one(user_data)
40
+ set_user_preferences(user_id, preferences)
41
  logger().info(f"Created new account: {result.inserted_id}")
42
  return str(result.inserted_id)
43
  except DuplicateKeyError as e:
 
84
  def set_user_preferences(
85
  user_id: str,
86
  /,
87
+ preferences: dict[str, Any],
88
  *,
89
  collection_name: str = ACCOUNTS_COLLECTION
90
  ) -> bool:
91
  """Sets a preference for a user."""
92
  try:
93
  collection = get_collection(collection_name)
94
+ preferences = {f"preferences.{key}": value for key, value in preferences}
95
+ preferences["updated_at"] = datetime.now(timezone.utc)
96
  result = collection.update_one(
97
  {"_id": user_id},
98
  {
99
+ "$set": preferences
100
  }
101
  )
102
  if result.matched_count == 0:
src/data/repositories/session.py CHANGED
@@ -1,9 +1,9 @@
1
  # data/repositories/chat.py
2
 
 
3
  from datetime import datetime, timezone
4
  from typing import Any
5
 
6
- from bson import ObjectId
7
  from pymongo import DESCENDING
8
  from pymongo.errors import (ConnectionFailure, DuplicateKeyError,
9
  OperationFailure, PyMongoError)
@@ -13,23 +13,26 @@ from src.utils.logger import logger
13
 
14
  CHAT_SESSIONS_COLLECTION = "chat_sessions"
15
 
16
- def create_chat_session(
17
- session_data: dict[str, Any],
 
18
  *,
19
  collection_name: str = CHAT_SESSIONS_COLLECTION
20
  ) -> str:
21
- """Creates a new chat session."""
 
 
 
 
22
  collection = get_collection(collection_name)
23
  now = datetime.now(timezone.utc)
24
- session_data.update({
 
 
 
25
  "created_at": now,
26
- "updated_at": now,
27
- "messages": session_data.get("messages", []),
28
- "title": str(session_data.get("title", "New Chat")),
29
- "user_id": str(session_data["user_id"])
30
- })
31
- #if "_id" not in session_data:
32
- # session_data["_id"] = str(ObjectId())
33
  try:
34
  result = collection.insert_one(session_data)
35
  return str(result.inserted_id)
 
1
  # data/repositories/chat.py
2
 
3
+ import uuid
4
  from datetime import datetime, timezone
5
  from typing import Any
6
 
 
7
  from pymongo import DESCENDING
8
  from pymongo.errors import (ConnectionFailure, DuplicateKeyError,
9
  OperationFailure, PyMongoError)
 
13
 
14
  CHAT_SESSIONS_COLLECTION = "chat_sessions"
15
 
16
+ def create_session(
17
+ user_id: str,
18
+ title: str,
19
  *,
20
  collection_name: str = CHAT_SESSIONS_COLLECTION
21
  ) -> str:
22
+ """
23
+ Creates a new chat session.
24
+
25
+ @TODO Revise if the session id should be passed in or generated by the database.
26
+ """
27
  collection = get_collection(collection_name)
28
  now = datetime.now(timezone.utc)
29
+ session_data: dict[str, Any] = {
30
+ "_id": str(uuid.uuid4()),
31
+ "user_id": user_id,
32
+ "title": title,
33
  "created_at": now,
34
+ "updated_at": now
35
+ }
 
 
 
 
 
36
  try:
37
  result = collection.insert_one(session_data)
38
  return str(result.inserted_id)
src/main.py CHANGED
@@ -24,7 +24,7 @@ except Exception as e:
24
 
25
  # Import project modules after trying to load environment variables
26
  from src.api.routes import chat, session, static, system, user
27
- from src.core.state import MedicalState
28
 
29
 
30
  def startup_event(state: MedicalState):
@@ -73,7 +73,7 @@ def shutdown_event():
73
  @asynccontextmanager
74
  async def lifespan(app: FastAPI):
75
  # Initialize state
76
- state = MedicalState.get_instance()
77
  state.initialize()
78
 
79
  # Startup code here
 
24
 
25
  # Import project modules after trying to load environment variables
26
  from src.api.routes import chat, session, static, system, user
27
+ from src.core.state import MedicalState, get_state
28
 
29
 
30
  def startup_event(state: MedicalState):
 
73
  @asynccontextmanager
74
  async def lifespan(app: FastAPI):
75
  # Initialize state
76
+ state = get_state()
77
  state.initialize()
78
 
79
  # Startup code here