"""Embedding layer compatible with multiple providers (sentence_transformers, openai)."""

import os
from typing import Protocol

from .rag_settings import EmbeddingConfig, RAGConfig

# Batch size handed to sentence-transformers when the caller does not give one.
_DEFAULT_ST_BATCH_SIZE = 32
# Default and upper bound for the number of inputs sent in one OpenAI request.
_OPENAI_MAX_BATCH_SIZE = 100


def _resolve_batch_size(batch_size: int | None, default: int) -> int:
    """Return a usable batch size.

    ``None`` and ``0`` both select *default*. A negative value is rejected,
    because it would otherwise produce an empty ``range`` and make the caller
    return no embeddings at all without any error.

    Raises:
        ValueError: If *batch_size* is negative.
    """
    if not batch_size:
        return default
    if batch_size < 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    return batch_size


class Embedder(Protocol):
    """Embedder protocol: anything exposing a compatible ``encode`` method."""

    def encode(self, texts: list[str], batch_size: int | None = None) -> list[list[float]]:
        """Encode *texts* into a list of float vectors."""
        ...


class SentenceTransformerEmbedder:
    """Embedder using sentence-transformers."""

    def __init__(self, model_name: str):
        """Load the SentenceTransformer model called *model_name*."""
        # Imported here rather than at module level, so sentence-transformers
        # is only needed when this class is actually instantiated.
        from sentence_transformers import SentenceTransformer

        self._model = SentenceTransformer(model_name)

    def encode(self, texts: list[str], batch_size: int | None = None) -> list[list[float]]:
        """Encode *texts* with the loaded model.

        Args:
            texts: Strings to embed.
            batch_size: Batch size passed to the model; ``None`` or ``0``
                selects 32.

        Returns:
            The model output converted to nested Python lists.

        Raises:
            ValueError: If *batch_size* is negative.
        """
        vectors = self._model.encode(
            texts,
            batch_size=_resolve_batch_size(batch_size, _DEFAULT_ST_BATCH_SIZE),
            # Only show a progress bar for larger inputs.
            show_progress_bar=len(texts) > 50,
            convert_to_numpy=True,
        )
        return vectors.tolist()


class OpenAIEmbedder:
    """Embedder using the OpenAI API."""

    def __init__(self, model_name: str, api_key: str | None = None):
        """Create an OpenAI client for the embedding model *model_name*.

        Args:
            model_name: Model identifier sent with every embeddings request.
            api_key: API key; when falsy, the ``OPENAI_API_KEY`` environment
                variable is used instead.

        Raises:
            ValueError: If no key is passed and ``OPENAI_API_KEY`` is unset
                or empty.
        """
        # Imported here rather than at module level, so the openai package
        # is only needed when this class is actually instantiated.
        from openai import OpenAI

        key = api_key or os.environ.get("OPENAI_API_KEY")
        if not key:
            raise ValueError(
                "OpenAI API key required. Set OPENAI_API_KEY env or pass api_key."
            )
        self._client = OpenAI(api_key=key)
        self._model = model_name

    def encode(self, texts: list[str], batch_size: int | None = None) -> list[list[float]]:
        """Encode *texts* through the embeddings endpoint, one request per batch.

        Args:
            texts: Strings to embed. An empty list yields an empty result
                without any request being sent.
            batch_size: Inputs per request; ``None`` or ``0`` selects 100 and
                larger values are capped at 100.

        Returns:
            The ``embedding`` of every response item, batches concatenated in
            request order.

        Raises:
            ValueError: If *batch_size* is negative.
        """
        # OpenAI limits batch size (max ~2048 inputs); we use smaller batches.
        step = min(
            _resolve_batch_size(batch_size, _OPENAI_MAX_BATCH_SIZE),
            _OPENAI_MAX_BATCH_SIZE,
        )
        all_embeddings: list[list[float]] = []
        for start in range(0, len(texts), step):
            resp = self._client.embeddings.create(
                model=self._model,
                input=texts[start : start + step],
            )
            all_embeddings.extend(item.embedding for item in resp.data)
        return all_embeddings


def get_embedder(config: RAGConfig | EmbeddingConfig) -> Embedder:
    """Return the embedder matching the provider named in *config*.

    Args:
        config: Either an ``EmbeddingConfig``, or a ``RAGConfig`` whose
            ``embedding`` attribute is used.

    Returns:
        A ``SentenceTransformerEmbedder`` for provider
        ``"sentence_transformers"`` or an ``OpenAIEmbedder`` for ``"openai"``.

    Raises:
        ValueError: If the provider is not one of the two above, or if the
            OpenAI embedder cannot find an API key.
    """
    cfg = config.embedding if isinstance(config, RAGConfig) else config

    if cfg.provider == "sentence_transformers":
        return SentenceTransformerEmbedder(model_name=cfg.model)

    if cfg.provider == "openai":
        # When api_key_env is not set, or names an unset variable, api_key is
        # None and OpenAIEmbedder falls back to OPENAI_API_KEY.
        api_key = os.environ.get(cfg.api_key_env) if cfg.api_key_env else None
        return OpenAIEmbedder(model_name=cfg.model, api_key=api_key)

    raise ValueError(f"Unknown embedding provider: {cfg.provider}")