130 lines
3.8 KiB
Python
130 lines
3.8 KiB
Python
"""
|
|
Load RAG settings from rag_config.yaml, with support for multiple providers and multiple knowledge bases.
|
|
"""
|
|
import os
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import yaml
|
|
|
|
|
|
@dataclass
|
|
class EmbeddingConfig:
|
|
provider: str
|
|
model: str
|
|
batch_size: int = 32
|
|
api_key_env: str | None = None
|
|
base_url: str | None = None
|
|
avalai_base_url: str | None = None
|
|
avalai_api_key_env: str | None = None
|
|
|
|
|
|
@dataclass
class QdrantConfig:
    """Connection and collection settings for the Qdrant vector store."""

    # Where the Qdrant server is reachable.
    host: str = "localhost"
    port: int = 6333

    # Collection to use and the dimensionality of the vectors stored in it.
    collection_name: str = "croplogic_kb"
    vector_size: int = 384
|
|
|
|
|
|
@dataclass
class ChunkingConfig:
    """Limits used when splitting documents into chunks."""

    # Upper bound on the size of one chunk, in tokens.
    max_chunk_tokens: int = 500

    # Number of tokens of overlap (presumably shared between neighbouring
    # chunks; confirm against the chunker implementation).
    overlap_tokens: int = 50
|
|
|
|
|
|
@dataclass
|
|
class LLMConfig:
|
|
model: str = "gpt-4o"
|
|
base_url: str | None = None
|
|
api_key_env: str | None = None
|
|
avalai_base_url: str | None = None
|
|
avalai_api_key_env: str | None = None
|
|
|
|
|
|
@dataclass
class KnowledgeBaseConfig:
    """Description of a single knowledge base entry."""

    # Location of the knowledge base content.
    path: str

    # Location of the tone file associated with this knowledge base.
    tone_file: str

    # Free-text description; empty when none is given.
    description: str = ""
|
|
|
|
|
|
@dataclass
class RAGConfig:
    """Top-level RAG configuration aggregating all per-component settings."""

    # Required sections.
    embedding: EmbeddingConfig
    qdrant: QdrantConfig
    chunking: ChunkingConfig

    # Optional sections; each instance gets its own fresh default object.
    llm: LLMConfig = field(default_factory=LLMConfig)
    knowledge_bases: dict[str, KnowledgeBaseConfig] = field(default_factory=dict)

    # Raw, unparsed mapping kept as-is (schema not defined in this module).
    chromadb: dict[str, Any] = field(default_factory=dict)
|
|
|
|
|
|
def load_rag_config(config_path: str | Path | None = None) -> RAGConfig:
    """Load the RAG configuration from a YAML file, with environment overrides.

    The ``QDRANT_HOST`` and ``QDRANT_PORT`` environment variables, when set,
    take precedence over the values in the ``qdrant`` section of the file.

    Args:
        config_path: Path to the YAML file. When ``None``, the file
            ``config/rag_config.yaml`` located under the parent of this
            module's directory is used.

    Returns:
        A populated ``RAGConfig``. Sections or keys that are missing (or
        present but empty) fall back to the built-in defaults.

    Raises:
        FileNotFoundError: If the configuration file does not exist.
        ValueError: If the top level of the YAML document is not a mapping.
    """
    if config_path is None:
        base = Path(__file__).resolve().parent.parent
        config_path = base / "config" / "rag_config.yaml"

    path = Path(config_path)
    if not path.exists():
        raise FileNotFoundError(f"RAG config not found: {path}")

    with open(path, encoding="utf-8") as f:
        # An empty file parses to None; treat it as an empty mapping.
        data = yaml.safe_load(f) or {}

    if not isinstance(data, dict):
        raise ValueError(
            f"RAG config must be a YAML mapping at top level: {path}"
        )

    # A key written with no value (e.g. ``qdrant:``) parses to None, so
    # ``data.get(name, {})`` would hand back None and the following ``.get``
    # calls would crash. ``or {}`` covers both the missing and the empty case.
    emb = data.get("embedding") or {}
    embedding = EmbeddingConfig(
        provider=emb.get("provider", "sentence_transformers"),
        model=emb.get("model", "text-embedding-3-small"),
        batch_size=emb.get("batch_size", 32),
        api_key_env=emb.get("api_key_env"),
        base_url=emb.get("base_url"),
        avalai_base_url=emb.get("avalai_base_url"),
        avalai_api_key_env=emb.get("avalai_api_key_env"),
    )

    qd = data.get("qdrant") or {}
    qdrant = QdrantConfig(
        # Environment variables win over the file so deployments can
        # redirect the connection without editing the YAML.
        host=os.environ.get("QDRANT_HOST", qd.get("host", "localhost")),
        # The environment value is a string, hence the int() conversion.
        port=int(os.environ.get("QDRANT_PORT", qd.get("port", 6333))),
        collection_name=qd.get("collection_name", "croplogic_kb"),
        # NOTE(review): this fallback (1536) differs from the QdrantConfig
        # dataclass default (384) - confirm which one is intended.
        vector_size=qd.get("vector_size", 1536),
    )

    ch = data.get("chunking") or {}
    chunking = ChunkingConfig(
        max_chunk_tokens=ch.get("max_chunk_tokens", 500),
        overlap_tokens=ch.get("overlap_tokens", 50),
    )

    llm_data = data.get("llm") or {}
    llm = LLMConfig(
        model=llm_data.get("model", "gpt-4o"),
        base_url=llm_data.get("base_url"),
        api_key_env=llm_data.get("api_key_env"),
        avalai_base_url=llm_data.get("avalai_base_url"),
        avalai_api_key_env=llm_data.get("avalai_api_key_env"),
    )

    kb_data = data.get("knowledge_bases") or {}
    knowledge_bases: dict[str, KnowledgeBaseConfig] = {}
    for kb_name, kb_conf in kb_data.items():
        # An entry declared with no body (``name:``) is None; use the
        # name-derived default paths for it instead of failing.
        kb_conf = kb_conf or {}
        knowledge_bases[kb_name] = KnowledgeBaseConfig(
            path=kb_conf.get("path", f"config/knowledge_base/{kb_name}"),
            tone_file=kb_conf.get("tone_file", f"config/tones/{kb_name}_tone.txt"),
            description=kb_conf.get("description", ""),
        )

    return RAGConfig(
        embedding=embedding,
        qdrant=qdrant,
        chunking=chunking,
        llm=llm,
        knowledge_bases=knowledge_bases,
        # Passed through unparsed; normalised to {} when missing or empty.
        chromadb=data.get("chromadb") or {},
    )
|