Files
Ai/knowledge_base/embeddings.py
T

85 lines
2.6 KiB
Python
Raw Normal View History

"""
لایه Embedding سازگار با چند provider (sentence_transformers، openai).
"""
from typing import Protocol
from .rag_settings import EmbeddingConfig, RAGConfig
class Embedder(Protocol):
"""پروتکل embedder."""
def encode(self, texts: list[str], batch_size: int | None = None) -> list[list[float]]:
...
class SentenceTransformerEmbedder:
"""Embedder با استفاده از sentence-transformers."""
def __init__(self, model_name: str):
from sentence_transformers import SentenceTransformer
self._model = SentenceTransformer(model_name)
def encode(self, texts: list[str], batch_size: int | None = None) -> list[list[float]]:
embeddings = self._model.encode(
texts,
batch_size=batch_size or 32,
show_progress_bar=len(texts) > 50,
convert_to_numpy=True,
)
return embeddings.tolist()
class OpenAIEmbedder:
"""Embedder با استفاده از OpenAI API."""
def __init__(self, model_name: str, api_key: str | None = None):
import os
from openai import OpenAI
key = api_key or os.environ.get("OPENAI_API_KEY")
if not key:
raise ValueError(
"OpenAI API key required. Set OPENAI_API_KEY env or pass api_key."
)
self._client = OpenAI(api_key=key)
self._model = model_name
def encode(self, texts: list[str], batch_size: int | None = None) -> list[list[float]]:
# OpenAI limits batch size (max ~2048 inputs); we use smaller batches
batch_size = min(batch_size or 100, 100)
all_embeddings = []
for i in range(0, len(texts), batch_size):
batch = texts[i : i + batch_size]
resp = self._client.embeddings.create(
model=self._model,
input=batch,
)
for e in resp.data:
all_embeddings.append(e.embedding)
return all_embeddings
def get_embedder(config: RAGConfig | EmbeddingConfig) -> Embedder:
    """Build the embedder selected by the config's ``provider`` field.

    Args:
        config: Either a ``RAGConfig`` (its ``embedding`` section is used) or
            an ``EmbeddingConfig`` directly.

    Returns:
        A ``SentenceTransformerEmbedder`` for ``"sentence_transformers"`` or an
        ``OpenAIEmbedder`` for ``"openai"``.

    Raises:
        ValueError: If ``provider`` is neither of the supported names.
    """
    cfg = config.embedding if isinstance(config, RAGConfig) else config
    provider = cfg.provider

    if provider == "sentence_transformers":
        return SentenceTransformerEmbedder(model_name=cfg.model)

    if provider == "openai":
        import os

        # If api_key_env names a variable, read the key from it. A missing or
        # empty value passes None, and OpenAIEmbedder then falls back to
        # OPENAI_API_KEY itself.
        env_name = cfg.api_key_env
        api_key = os.environ.get(env_name) if env_name else None
        return OpenAIEmbedder(model_name=cfg.model, api_key=api_key)

    raise ValueError(f"Unknown embedding provider: {provider}")