# logreader/retrieval.py
import os
from dataclasses import dataclass
from typing import List, Optional
import numpy as np
import torch
from sentence_transformers import CrossEncoder, SentenceTransformer, util


@dataclass
class RunbookDoc:
    """
    Represents a single Markdown runbook in the local knowledge base.
    """
path: str
title: str
content: str


class RunbookRetriever:
    """
    Handles loading the local knowledge base and searching over it.
    """

    def __init__(
self,
kb_dir: str = "kb",
embed_model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
reranker_name: Optional[str] = "cross-encoder/ms-marco-MiniLM-L-6-v2",
):
"""
Загружает все ранбуки и подготавливает модели (эмбеддер + опциональный reranker).
"""
self.kb_dir = kb_dir
        # Force CPU to avoid CUDA capability mismatches (e.g. under WSL or on older GPUs).
self.device = torch.device("cpu")
self.embed_model = SentenceTransformer(embed_model_name, device=self.device)
self.reranker: Optional[CrossEncoder] = None
if reranker_name:
try:
self.reranker = CrossEncoder(reranker_name, device=self.device)
except Exception:
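                # Fall back to embedding-only search if the cross-encoder
                # cannot be loaded (e.g. no local cache and no network access).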
self.reranker = None
self.docs = self._load_docs()
if self.docs:
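            # Embed the whole corpus once at startup so that search() only
            # has to embed the incoming query.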
self.doc_embeddings = self.embed_model.encode(
[doc.content for doc in self.docs],
convert_to_tensor=True,
device=self.device,
)
else:
self.doc_embeddings = None

    def _load_docs(self) -> List[RunbookDoc]:
        """
        Reads the Markdown files from kb_dir and turns them into a list of RunbookDoc objects.
        """
docs: List[RunbookDoc] = []
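        # A missing knowledge-base directory is not an error: the retriever
        # simply behaves as if the corpus were empty.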
if not os.path.isdir(self.kb_dir):
return docs
for fname in os.listdir(self.kb_dir):
if not fname.endswith(".md"):
continue
path = os.path.join(self.kb_dir, fname)
with open(path, "r", encoding="utf-8") as f:
content = f.read()
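            # Use the first line (normally the "# Title" heading) as the title,
            # falling back to the file name for empty documents.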
title = content.splitlines()[0].lstrip("# ").strip() if content else fname
docs.append(RunbookDoc(path=path, title=title, content=content))
return docs

    def search(self, query: str, top_k: int = 3) -> List[dict]:
        """
        Finds the top-k most relevant runbooks by cosine similarity (and the reranker, if available).
        """
if not self.docs or self.doc_embeddings is None:
return []
query_emb = self.embed_model.encode(query, convert_to_tensor=True, device=self.device)
scores = util.cos_sim(query_emb, self.doc_embeddings)[0]
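        # Over-fetch candidates (4x top_k) so the optional reranker has a
        # wider pool to reorder before the final cut.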
top_results = np.argsort(-scores.cpu().numpy())[: top_k * 4]
candidates = [
{"doc": self.docs[idx], "score": float(scores[idx])} for idx in top_results
]
if self.reranker:
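            # The cross-encoder scores each (query, document) pair jointly,
            # which is slower than the bi-encoder but usually more precise.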
pairs = [[query, c["doc"].content] for c in candidates]
rerank_scores = self.reranker.predict(pairs)
for cand, rscore in zip(candidates, rerank_scores):
cand["rerank_score"] = float(rscore)
candidates = sorted(candidates, key=lambda x: x.get("rerank_score", x["score"]), reverse=True)
else:
candidates = sorted(candidates, key=lambda x: x["score"], reverse=True)
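        # Return lightweight result dicts (title, score, path, short excerpt)
        # rather than the full documents.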
return [
{
"title": cand["doc"].title,
"score": cand.get("rerank_score", cand["score"]),
"path": cand["doc"].path,
"excerpt": cand["doc"].content[:500],
}
for cand in candidates[:top_k]
]
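

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): assumes a local
    # "kb/" directory with Markdown runbooks and uses a made-up example query;
    # adjust both to your own setup.
    retriever = RunbookRetriever(kb_dir="kb")
    for hit in retriever.search("postgres connection pool exhausted", top_k=3):
        print(f"{hit['score']:.3f}  {hit['title']}  ({hit['path']})")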