dylanglenister commited on
Commit
833527f
·
1 Parent(s): 3d81965

FEAT: Reranker file.

Browse files

Uses an nvidia reranker model to find the most relevant information.

Files changed (2) hide show
  1. src/config/settings.py +3 -0
  2. src/services/reranker.py +70 -0
src/config/settings.py CHANGED
@@ -1,6 +1,7 @@
1
  # src/config/settings.py
2
  import os
3
 
 
4
  class Settings:
5
  """Application-wide settings."""
6
  # Memory settings
@@ -9,6 +10,8 @@ class Settings:
9
  SEMANTIC_CONTEXT_SIZE: int = 17
10
  SIMILARITY_THRESHOLD: float = 0.15
11
  EMBEDDING_MODEL_NAME: str = "MedEmbed-large-v0.1"
 
 
12
 
13
  # Safety Guard settings
14
  SAFETY_GUARD_ENABLED: bool = os.getenv("SAFETY_GUARD_ENABLED", "true").lower() == "true"
 
1
  # src/config/settings.py
2
  import os
3
 
4
+
5
  class Settings:
6
  """Application-wide settings."""
7
  # Memory settings
 
10
  SEMANTIC_CONTEXT_SIZE: int = 17
11
  SIMILARITY_THRESHOLD: float = 0.15
12
  EMBEDDING_MODEL_NAME: str = "MedEmbed-large-v0.1"
13
+ NVIDIA_RERANKER_MODEL: str = "rerank-qa-mistral-4b"
14
+ NVIDIA_RERANKER_ENDPOINT: str = "" # TODO
15
 
16
  # Safety Guard settings
17
  SAFETY_GUARD_ENABLED: bool = os.getenv("SAFETY_GUARD_ENABLED", "true").lower() == "true"
src/services/reranker.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/services/reranker.py
2
+
3
+ from src.config.settings import settings
4
+ from src.models.information import InfoChunk
5
+ from src.utils.logger import logger
6
+ from src.utils.rotator import APIKeyRotator, robust_post_json
7
+
8
+
9
+ async def rerank_documents(
10
+ query: str,
11
+ documents: list[InfoChunk],
12
+ rotator: APIKeyRotator,
13
+ top_k: int = 3,
14
+ ) -> list[InfoChunk]:
15
+ """
16
+ Reranks a list of documents based on a query using the NVIDIA Rerank API.
17
+
18
+ Args:
19
+ query: The user's query string.
20
+ documents: A list of InfoChunk objects retrieved from the initial search.
21
+ rotator: The API key rotator for NVIDIA services.
22
+ top_k: The final number of documents to return after reranking.
23
+
24
+ Returns:
25
+ A sorted list of the top_k most relevant InfoChunk objects.
26
+ Returns the original list sliced to top_k if reranking fails.
27
+ """
28
+ if not documents:
29
+ return []
30
+
31
+ headers = {
32
+ "Authorization": f"Bearer {rotator.get_key() or ''}",
33
+ "Accept": "application/json",
34
+ "Content-Type": "application/json",
35
+ }
36
+
37
+ passages = [doc.content for doc in documents]
38
+
39
+ payload = {
40
+ "model": settings.NVIDIA_RERANKER_MODEL,
41
+ "query": query,
42
+ "passages": passages,
43
+ "top_n": top_k,
44
+ }
45
+
46
+ try:
47
+ # Use the existing robust helper for consistency
48
+ data = await robust_post_json(settings.NVIDIA_RERANKER_ENDPOINT, headers, payload, rotator)
49
+ results = data.get("results", [])
50
+
51
+ if not results:
52
+ logger().warning("Reranking returned no results, falling back to original order.")
53
+ return documents[:top_k]
54
+
55
+ # Create a mapping of original document content to the document object
56
+ doc_map = {doc.content: doc for doc in documents}
57
+
58
+ # Reconstruct the sorted list of documents based on rerank results
59
+ reranked_docs = []
60
+ for result in sorted(results, key=lambda x: x["rank"]):
61
+ if result["passage"] in doc_map:
62
+ reranked_docs.append(doc_map[result["passage"]])
63
+
64
+ return reranked_docs
65
+
66
+ except Exception as e:
67
+ logger().error(f"An unexpected error occurred during reranking: {e}")
68
+
69
+ # Fallback: return the top_k documents from the original list
70
+ return documents[:top_k]