Files
Ai/rag/config.py
T
2026-03-22 03:08:27 +03:30

185 lines
5.9 KiB
Python

"""
بارگذاری تنظیمات RAG از rag_config.yaml — با پشتیبانی از چند provider،
چند پایگاه دانش و چند سرویس.
"""
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import yaml
@dataclass
class EmbeddingConfig:
provider: str
model: str
batch_size: int = 32
api_key_env: str | None = None
base_url: str | None = None
avalai_base_url: str | None = None
avalai_api_key_env: str | None = None
@dataclass
class QdrantConfig:
    """Connection and collection settings for the Qdrant vector store.

    Built by ``load_rag_config`` from the ``qdrant`` section of the YAML file;
    there, ``host`` and ``port`` can additionally be overridden through the
    ``QDRANT_HOST`` / ``QDRANT_PORT`` environment variables.
    """

    host: str = "localhost"
    port: int = 6333
    collection_name: str = "croplogic_kb"
    # NOTE(review): load_rag_config falls back to 1536 when the YAML omits
    # vector_size, while this class default is 384 — confirm which is intended.
    vector_size: int = 384
@dataclass
class ChunkingConfig:
    """Chunk-size settings read from the ``chunking`` section of the YAML file.

    The values are only stored here; how they are applied to documents is
    decided by whatever code consumes this config.
    """

    # presumably an upper bound on the size of one chunk, in tokens
    max_chunk_tokens: int = 500
    # presumably the number of tokens shared between consecutive chunks
    overlap_tokens: int = 50
@dataclass
class LLMConfig:
provider: str = "gapgpt"
model: str = "gpt-4o"
base_url: str | None = None
api_key_env: str | None = None
avalai_base_url: str | None = None
avalai_api_key_env: str | None = None
@dataclass
class KnowledgeBaseConfig:
    """One entry of the ``knowledge_bases`` section of the YAML file.

    ``path`` and ``tone_file`` are stored as plain strings exactly as given;
    this module does not resolve them or check that they exist.
    """

    path: str
    tone_file: str
    description: str = ""
@dataclass
class ServiceConfig:
    """Configuration of one service (one entry of the ``services`` section).

    ``knowledge_base`` is the name of a knowledge base, not the
    ``KnowledgeBaseConfig`` object itself. When built by ``load_rag_config``,
    ``llm`` is the global LLM config with the service's own ``llm`` overrides
    applied, and ``tone_file`` falls back to the knowledge base's tone file
    when the service does not set one.
    """

    service_id: str
    knowledge_base: str
    # default_factory gives each instance its own LLMConfig object
    llm: LLMConfig = field(default_factory=LLMConfig)
    tone_file: str | None = None
    system_prompt: str | None = None
    # stored as given; its effect is defined by the consumers of this config
    use_user_embeddings: bool = True
    description: str = ""
@dataclass
class RAGConfig:
    """Top-level container for everything read from ``rag_config.yaml``.

    ``embedding``, ``qdrant`` and ``chunking`` are required; the remaining
    fields default to a default ``LLMConfig`` or to empty dicts.
    """

    embedding: EmbeddingConfig
    qdrant: QdrantConfig
    chunking: ChunkingConfig
    llm: LLMConfig = field(default_factory=LLMConfig)
    # keyed by knowledge-base name
    knowledge_bases: dict[str, KnowledgeBaseConfig] = field(default_factory=dict)
    # keyed by service_id
    services: dict[str, ServiceConfig] = field(default_factory=dict)
    # raw ``chromadb`` section, kept as an untyped mapping
    chromadb: dict[str, Any] = field(default_factory=dict)
def _build_llm_config(data: dict[str, Any] | None, default: LLMConfig | None = None) -> LLMConfig:
    """Build an ``LLMConfig`` from a raw mapping, falling back per key to ``default``.

    Args:
        data: Raw ``llm`` mapping. ``None`` or an empty mapping means
            "no overrides".
        default: Supplies the value for every key that is absent from
            ``data``. When omitted, a default-constructed ``LLMConfig`` is used.

    Returns:
        A new ``LLMConfig``; ``default`` itself is never returned or modified.

    Note:
        The fallback applies only to *missing* keys. A key that is present with
        a null value is stored as ``None`` rather than replaced by the fallback.
    """
    overrides = data or {}
    fallback = default or LLMConfig()
    return LLMConfig(
        provider=overrides.get("provider", fallback.provider),
        model=overrides.get("model", fallback.model),
        base_url=overrides.get("base_url", fallback.base_url),
        api_key_env=overrides.get("api_key_env", fallback.api_key_env),
        avalai_base_url=overrides.get("avalai_base_url", fallback.avalai_base_url),
        avalai_api_key_env=overrides.get("avalai_api_key_env", fallback.avalai_api_key_env),
    )
def get_service_config(service_id: str, config: RAGConfig | None = None) -> ServiceConfig:
    """Return the configuration of the service registered under ``service_id``.

    Args:
        service_id: Key to look up in ``config.services``.
        config: Already-loaded configuration. When ``None``, the configuration
            is loaded with ``load_rag_config()`` (default path) on every call.

    Returns:
        The matching ``ServiceConfig``.

    Raises:
        KeyError: If no service with that id is configured.
    """
    cfg = load_rag_config() if config is None else config
    service = cfg.services.get(service_id)
    if service is None:
        raise KeyError(f"Unknown service_id: {service_id}")
    return service
def _mapping_section(container: dict[str, Any], key: Any) -> dict[str, Any]:
    """Return ``container[key]`` as a mapping, treating a missing or null entry as empty."""
    section = container.get(key)
    if section is None:
        # A YAML key written without a value (e.g. ``qdrant:``) parses to None.
        return {}
    if not isinstance(section, dict):
        raise ValueError(
            f"RAG config entry {key!r} must be a mapping, got {type(section).__name__}"
        )
    return section


def load_rag_config(config_path: str | Path | None = None) -> RAGConfig:
    """Load the RAG settings from a YAML file, applying environment overrides.

    The ``QDRANT_HOST`` and ``QDRANT_PORT`` environment variables, when set,
    take precedence over the ``qdrant.host`` / ``qdrant.port`` values in the
    file. Sections or entries that are missing or written without a value are
    treated as empty and filled with defaults.

    Args:
        config_path: Path to the YAML file. When ``None``, the file
            ``config/rag_config.yaml`` located one directory above this
            module's directory is used.

    Returns:
        The fully populated ``RAGConfig``. If the file declares no services,
        one service is generated per knowledge base, using the knowledge
        base's name as ``service_id``.

    Raises:
        FileNotFoundError: If the config file does not exist.
        ValueError: If the top level of the file, or a section that must be a
            mapping, is not a mapping; also raised by ``int()`` when the
            Qdrant port is not numeric.
    """
    if config_path is None:
        base = Path(__file__).resolve().parent.parent
        config_path = base / "config" / "rag_config.yaml"
    path = Path(config_path)
    if not path.exists():
        raise FileNotFoundError(f"RAG config not found: {path}")

    with open(path, encoding="utf-8") as f:
        # An empty file parses to None; treat it as an empty mapping.
        data = yaml.safe_load(f) or {}
    if not isinstance(data, dict):
        raise ValueError(f"RAG config must be a YAML mapping at the top level: {path}")

    emb = _mapping_section(data, "embedding")
    # NOTE(review): the default provider/model pair looks mismatched
    # ("sentence_transformers" with "text-embedding-3-small") — confirm.
    embedding = EmbeddingConfig(
        provider=emb.get("provider", "sentence_transformers"),
        model=emb.get("model", "text-embedding-3-small"),
        batch_size=emb.get("batch_size", 32),
        api_key_env=emb.get("api_key_env"),
        base_url=emb.get("base_url"),
        avalai_base_url=emb.get("avalai_base_url"),
        avalai_api_key_env=emb.get("avalai_api_key_env"),
    )

    qd = _mapping_section(data, "qdrant")
    qdrant = QdrantConfig(
        # Environment variables win over the file so deployments can repoint
        # the service without editing the YAML.
        host=os.environ.get("QDRANT_HOST", qd.get("host", "localhost")),
        port=int(os.environ.get("QDRANT_PORT", qd.get("port", 6333))),
        collection_name=qd.get("collection_name", "croplogic_kb"),
        # NOTE(review): fallback is 1536 here but QdrantConfig's own default
        # is 384 — confirm which is intended.
        vector_size=qd.get("vector_size", 1536),
    )

    ch = _mapping_section(data, "chunking")
    chunking = ChunkingConfig(
        max_chunk_tokens=ch.get("max_chunk_tokens", 500),
        overlap_tokens=ch.get("overlap_tokens", 50),
    )

    llm = _build_llm_config(_mapping_section(data, "llm"))

    kb_data = _mapping_section(data, "knowledge_bases")
    knowledge_bases: dict[str, KnowledgeBaseConfig] = {}
    for kb_name in kb_data:
        kb_raw = _mapping_section(kb_data, kb_name)
        knowledge_bases[kb_name] = KnowledgeBaseConfig(
            path=kb_raw.get("path", f"config/knowledge_base/{kb_name}"),
            tone_file=kb_raw.get("tone_file", f"config/tones/{kb_name}_tone.txt"),
            description=kb_raw.get("description", ""),
        )

    services_data = _mapping_section(data, "services")
    services: dict[str, ServiceConfig] = {}
    for service_id in services_data:
        service_raw = _mapping_section(services_data, service_id)
        # A service without an explicit knowledge base uses the one named
        # after the service itself.
        kb_name = service_raw.get("knowledge_base", service_id)
        kb = knowledge_bases.get(kb_name)
        services[service_id] = ServiceConfig(
            service_id=service_id,
            knowledge_base=kb_name,
            llm=_build_llm_config(_mapping_section(service_raw, "llm"), default=llm),
            # Service-level tone file wins; otherwise inherit the knowledge
            # base's tone file (None if the knowledge base is not declared).
            tone_file=service_raw.get("tone_file") or (kb.tone_file if kb else None),
            system_prompt=service_raw.get("system_prompt"),
            use_user_embeddings=service_raw.get("use_user_embeddings", True),
            description=service_raw.get("description", ""),
        )

    if not services:
        # No services declared: expose every knowledge base as its own service.
        for kb_name, kb in knowledge_bases.items():
            services[kb_name] = ServiceConfig(
                service_id=kb_name,
                knowledge_base=kb_name,
                # Copy the global config so each service owns its LLMConfig,
                # as explicitly declared services do.
                llm=_build_llm_config(None, default=llm),
                tone_file=kb.tone_file,
                use_user_embeddings=True,
                description=kb.description,
            )

    return RAGConfig(
        embedding=embedding,
        qdrant=qdrant,
        chunking=chunking,
        llm=llm,
        knowledge_bases=knowledge_bases,
        services=services,
        # A null ``chromadb:`` entry must not leak None into a dict-typed field.
        chromadb=data.get("chromadb") or {},
    )