85 lines
2.6 KiB
Python
85 lines
2.6 KiB
Python
|
|
"""
|
||
|
|
لایه Embedding سازگار با چند provider (sentence_transformers، openai).
|
||
|
|
"""
|
||
|
|
from typing import Protocol
|
||
|
|
|
||
|
|
from .rag_settings import EmbeddingConfig, RAGConfig
|
||
|
|
|
||
|
|
|
||
|
|
class Embedder(Protocol):
|
||
|
|
"""پروتکل embedder."""
|
||
|
|
|
||
|
|
def encode(self, texts: list[str], batch_size: int | None = None) -> list[list[float]]:
|
||
|
|
...
|
||
|
|
|
||
|
|
|
||
|
|
class SentenceTransformerEmbedder:
|
||
|
|
"""Embedder با استفاده از sentence-transformers."""
|
||
|
|
|
||
|
|
def __init__(self, model_name: str):
|
||
|
|
from sentence_transformers import SentenceTransformer
|
||
|
|
|
||
|
|
self._model = SentenceTransformer(model_name)
|
||
|
|
|
||
|
|
def encode(self, texts: list[str], batch_size: int | None = None) -> list[list[float]]:
|
||
|
|
embeddings = self._model.encode(
|
||
|
|
texts,
|
||
|
|
batch_size=batch_size or 32,
|
||
|
|
show_progress_bar=len(texts) > 50,
|
||
|
|
convert_to_numpy=True,
|
||
|
|
)
|
||
|
|
return embeddings.tolist()
|
||
|
|
|
||
|
|
|
||
|
|
class OpenAIEmbedder:
|
||
|
|
"""Embedder با استفاده از OpenAI API."""
|
||
|
|
|
||
|
|
def __init__(self, model_name: str, api_key: str | None = None):
|
||
|
|
import os
|
||
|
|
|
||
|
|
from openai import OpenAI
|
||
|
|
|
||
|
|
key = api_key or os.environ.get("OPENAI_API_KEY")
|
||
|
|
if not key:
|
||
|
|
raise ValueError(
|
||
|
|
"OpenAI API key required. Set OPENAI_API_KEY env or pass api_key."
|
||
|
|
)
|
||
|
|
self._client = OpenAI(api_key=key)
|
||
|
|
self._model = model_name
|
||
|
|
|
||
|
|
def encode(self, texts: list[str], batch_size: int | None = None) -> list[list[float]]:
|
||
|
|
# OpenAI limits batch size (max ~2048 inputs); we use smaller batches
|
||
|
|
batch_size = min(batch_size or 100, 100)
|
||
|
|
all_embeddings = []
|
||
|
|
for i in range(0, len(texts), batch_size):
|
||
|
|
batch = texts[i : i + batch_size]
|
||
|
|
resp = self._client.embeddings.create(
|
||
|
|
model=self._model,
|
||
|
|
input=batch,
|
||
|
|
)
|
||
|
|
for e in resp.data:
|
||
|
|
all_embeddings.append(e.embedding)
|
||
|
|
return all_embeddings
|
||
|
|
|
||
|
|
|
||
|
|
def get_embedder(config: RAGConfig | EmbeddingConfig) -> Embedder:
|
||
|
|
"""
|
||
|
|
بر اساس config، embedder مناسب را برمیگرداند.
|
||
|
|
"""
|
||
|
|
if isinstance(config, RAGConfig):
|
||
|
|
cfg = config.embedding
|
||
|
|
else:
|
||
|
|
cfg = config
|
||
|
|
|
||
|
|
if cfg.provider == "sentence_transformers":
|
||
|
|
return SentenceTransformerEmbedder(model_name=cfg.model)
|
||
|
|
if cfg.provider == "openai":
|
||
|
|
api_key = None
|
||
|
|
if cfg.api_key_env:
|
||
|
|
import os
|
||
|
|
|
||
|
|
api_key = os.environ.get(cfg.api_key_env)
|
||
|
|
return OpenAIEmbedder(model_name=cfg.model, api_key=api_key)
|
||
|
|
|
||
|
|
raise ValueError(f"Unknown embedding provider: {cfg.provider}")
|