| | """ |
| | CPU-optimized retrieval for efficient context handling. |
| | """ |
| |
|
| | import logging |
| | import heapq |
| | from typing import List, Dict, Any, Optional, Tuple, Union |
| | import numpy as np |
| |
|
| | from efficient_context.retrieval.base import BaseRetriever |
| | from efficient_context.chunking.base import Chunk |
| |
|
| | |
| | logging.basicConfig(level=logging.INFO) |
| | logger = logging.getLogger(__name__) |
| |
|

class CPUOptimizedRetriever(BaseRetriever):
    """
    Retriever optimized for CPU performance and low memory usage.

    Computational cost is kept low with a compact embedding model, batched
    encoding, pre-normalized embeddings, and heap-based top-k selection,
    while still providing high-quality retrieval results.
    """

    def __init__(
        self,
        embedding_model: str = "lightweight",
        similarity_metric: str = "cosine",
        use_batching: bool = True,
        batch_size: int = 32,
        max_index_size: Optional[int] = None,
    ):
        """
        Initialize the CPUOptimizedRetriever.

        Args:
            embedding_model: Model to use for embeddings; "lightweight"
                selects a small sentence-transformers model, any other value
                is passed through as a sentence-transformers model name
            similarity_metric: Metric for comparing embeddings ("cosine",
                "dot", or "euclidean")
            use_batching: Whether to batch embedding operations
            batch_size: Size of batches for embedding
            max_index_size: Maximum number of chunks to keep in the index
        """
        self.embedding_model = embedding_model
        self.similarity_metric = similarity_metric
        self.use_batching = use_batching
        self.batch_size = batch_size
        self.max_index_size = max_index_size

        # Index state.
        self.chunks = []
        self.chunk_embeddings = None
        self.chunk_ids_to_index = {}

        # TF-IDF vectorizer for the fallback embedding path; fitted lazily on
        # the first batch of indexed chunks so queries share its vocabulary.
        self._vectorizer = None

        self._init_embedding_model()

        logger.info("CPUOptimizedRetriever initialized with model: %s", embedding_model)

    def _init_embedding_model(self):
        """Initialize the embedding model, falling back to TF-IDF if needed."""
        try:
            from sentence_transformers import SentenceTransformer

            if self.embedding_model == "lightweight":
                # Small, fast model suited to CPU-only environments.
                self.model = SentenceTransformer('paraphrase-MiniLM-L3-v2')
            else:
                self.model = SentenceTransformer(self.embedding_model)

            logger.info(
                "Embedding dimension: %d",
                self.model.get_sentence_embedding_dimension(),
            )
        except ImportError:
            logger.warning(
                "sentence-transformers not available; using a TF-IDF fallback "
                "(less accurate)"
            )
            self.model = None

    def _get_embeddings(self, texts: List[str]) -> np.ndarray:
        """
        Get embeddings for a list of texts.

        Args:
            texts: List of texts to embed

        Returns:
            embeddings: Array of text embeddings, one row per text
        """
        if not texts:
            return np.array([])

        if self.model is not None:
            # Encode in batches to bound peak memory usage on CPU.
            if self.use_batching and len(texts) > self.batch_size:
                embeddings = []
                for i in range(0, len(texts), self.batch_size):
                    batch = texts[i:i + self.batch_size]
                    batch_embeddings = self.model.encode(
                        batch,
                        show_progress_bar=False,
                        convert_to_numpy=True,
                    )
                    embeddings.append(batch_embeddings)
                return np.vstack(embeddings)
            return self.model.encode(
                texts, show_progress_bar=False, convert_to_numpy=True
            )

        # TF-IDF fallback: fit the vectorizer once, on the first texts seen,
        # and reuse its vocabulary afterwards so that query vectors live in
        # the same space as the indexed chunk vectors.
        from sklearn.feature_extraction.text import TfidfVectorizer

        if self._vectorizer is None:
            self._vectorizer = TfidfVectorizer(max_features=5000)
            return self._vectorizer.fit_transform(texts).toarray()
        return self._vectorizer.transform(texts).toarray()

    def _compute_similarities(self, query_embedding: np.ndarray, chunk_embeddings: np.ndarray) -> np.ndarray:
        """
        Compute similarities between a query and all chunk embeddings.

        Args:
            query_embedding: Embedding of the query
            chunk_embeddings: Embeddings of the chunks

        Returns:
            similarities: Array of similarity scores, higher is better
        """
        if self.similarity_metric == "cosine":
            # Chunk embeddings are pre-normalized at index time, so only the
            # query needs normalizing; cosine then reduces to a dot product.
            query_norm = np.linalg.norm(query_embedding)
            if query_norm > 0:
                query_embedding = query_embedding / query_norm
            return np.dot(chunk_embeddings, query_embedding)
        if self.similarity_metric == "dot":
            return np.dot(chunk_embeddings, query_embedding)
        if self.similarity_metric == "euclidean":
            # Negate distances so that larger scores always mean more similar.
            return -np.sqrt(np.sum((chunk_embeddings - query_embedding) ** 2, axis=1))

        logger.warning(
            "Unknown similarity metric %r; falling back to dot product",
            self.similarity_metric,
        )
        return np.dot(chunk_embeddings, query_embedding)

    def index_chunks(self, chunks: List[Chunk]) -> None:
        """
        Index chunks for future retrieval.

        Args:
            chunks: Chunks to index
        """
        if not chunks:
            return

        # Add new chunks, skipping IDs that are already indexed.
        for chunk in chunks:
            if chunk.chunk_id in self.chunk_ids_to_index:
                continue
            self.chunks.append(chunk)
            self.chunk_ids_to_index[chunk.chunk_id] = len(self.chunks) - 1

        # Enforce the index size limit before embedding, keeping the most
        # recently added chunks, and rebuild the ID-to-index mapping.
        if self.max_index_size is not None and len(self.chunks) > self.max_index_size:
            self.chunks = self.chunks[-self.max_index_size:]
            self.chunk_ids_to_index = {
                chunk.chunk_id: i for i, chunk in enumerate(self.chunks)
            }

        # Re-embed the full index in a single pass.
        chunk_texts = [chunk.content for chunk in self.chunks]
        self.chunk_embeddings = self._get_embeddings(chunk_texts)

        # Pre-normalize for cosine similarity so retrieval reduces to one
        # matrix-vector product; zero-norm rows are left unchanged to avoid
        # division by zero.
        if self.similarity_metric == "cosine" and self.chunk_embeddings is not None:
            norms = np.linalg.norm(self.chunk_embeddings, axis=1, keepdims=True)
            self.chunk_embeddings = self.chunk_embeddings / np.where(norms > 0, norms, 1.0)

        logger.info("Indexed %d chunks (total: %d)", len(chunks), len(self.chunks))

    def retrieve(self, query: str, top_k: Optional[int] = None) -> List[Chunk]:
        """
        Retrieve the chunks most relevant to a query.

        Args:
            query: Query to retrieve chunks for
            top_k: Number of chunks to retrieve (default: 5)

        Returns:
            chunks: List of retrieved chunks, most similar first
        """
        if not self.chunks:
            logger.warning("No chunks indexed for retrieval")
            return []

        if not query:
            logger.warning("Empty query provided")
            return []

        if top_k is None:
            top_k = 5

        # Embed the query and score it against every indexed chunk.
        query_embedding = self._get_embeddings([query])[0]
        similarities = self._compute_similarities(query_embedding, self.chunk_embeddings)

        # heapq.nlargest runs in O(n log k), cheaper than a full O(n log n)
        # sort when k is small relative to the index size.
        if top_k >= len(similarities):
            top_indices = sorted(
                range(len(similarities)), key=lambda i: similarities[i], reverse=True
            )
        else:
            top_indices = heapq.nlargest(
                top_k, range(len(similarities)), key=lambda i: similarities[i]
            )

        retrieved_chunks = [self.chunks[i] for i in top_indices]

        logger.info("Retrieved %d chunks for query", len(retrieved_chunks))
        return retrieved_chunks

    def clear(self) -> None:
        """Clear all indexed chunks."""
        self.chunks = []
        self.chunk_embeddings = None
        self.chunk_ids_to_index = {}
        # Reset the fallback vectorizer so a fresh index fits a fresh vocabulary.
        self._vectorizer = None
        logger.info("Cleared chunk index")
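

# A minimal usage sketch, not part of the library API. It assumes Chunk can be
# constructed with `chunk_id` and `content` keyword arguments; that matches how
# those attributes are read above, but the Chunk constructor is defined in
# efficient_context.chunking.base and may differ.
if __name__ == "__main__":
    docs = [
        Chunk(chunk_id="c1", content="Pre-normalized embeddings make cosine retrieval a dot product."),
        Chunk(chunk_id="c2", content="Batched encoding keeps peak memory low on CPU."),
        Chunk(chunk_id="c3", content="Euclidean distances are negated so larger scores are better."),
    ]

    retriever = CPUOptimizedRetriever(similarity_metric="cosine", max_index_size=100)
    retriever.index_chunks(docs)

    for chunk in retriever.retrieve("How does batching affect memory use?", top_k=2):
        print(chunk.chunk_id, "->", chunk.content)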