Add LLM configuration and update URL routing

- Introduced LLM configuration in rag_config.yaml and a corresponding LLMConfig class in config.py (a sketch of the new config section follows below).
- Updated the load_rag_config function to parse LLM settings from the configuration file.
- Added a new API route for RAG in urls.py to expose the chat endpoint.
- Modified QdrantVectorStore to use the query_points method in place of the older search call.
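The rag_config.yaml diff itself is not shown in this view, so the following is only a sketch of what the new llm section might look like. The key names (model, base_url, api_key_env) come from the LLMConfig dataclass below; the values are illustrative:

llm:
  model: gpt-4o                       # default used by LLMConfig when omitted
  base_url: https://api.avalai.ir/v1  # optional; chat.py falls back to AVALAI_BASE_URL, then this URL
  api_key_env: AVALAI_API_KEY         # name of the env var holding the API key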
commit 94355af62b
parent 197f70ee12
2026-02-27 19:44:49 +03:30
8 changed files with 187 additions and 6 deletions
+2
@@ -3,6 +3,7 @@
Phase one: Qdrant as the vector store
"""
from .chat import chat_rag_stream
from .chunker import chunk_text, chunk_texts
from .client import get_qdrant_client
from .config import load_rag_config
@@ -12,6 +13,7 @@ from .retrieve import search_with_query
from .vector_store import QdrantVectorStore
__all__ = [
"chat_rag_stream",
"chunk_text",
"chunk_texts",
"embed_single",
+102
@@ -0,0 +1,102 @@
"""
Streaming RAG chat: uses the user's embedded data and the Avalai API
"""
import os
from pathlib import Path
from openai import OpenAI
from .config import load_rag_config, RAGConfig
from .retrieve import search_with_query
def _get_chat_client(config: RAGConfig | None) -> OpenAI:
"""ساخت کلاینت OpenAI برای Avalai Chat API."""
cfg = config or load_rag_config()
llm = cfg.llm
env_var = llm.api_key_env or "AVALAI_API_KEY"
    api_key = os.environ.get(env_var)
    if not api_key:
        # Fail fast: the OpenAI client's own error would point at OPENAI_API_KEY instead.
        raise RuntimeError(f"Missing API key: set the {env_var} environment variable.")
base_url = llm.base_url or os.environ.get(
"AVALAI_BASE_URL", "https://api.avalai.ir/v1"
)
return OpenAI(api_key=api_key, base_url=base_url)
def _load_tone(config: RAGConfig | None) -> str:
"""بارگذاری فایل لحن."""
cfg = config or load_rag_config()
base = Path(__file__).resolve().parent.parent
tone_path = base / cfg.tone_file
if tone_path.exists():
return tone_path.read_text(encoding="utf-8").strip()
return ""
def build_rag_context(query: str, config: RAGConfig | None = None, limit: int = 5) -> str:
"""
    Retrieve texts relevant to the user's query from the vector store.
"""
results = search_with_query(query, limit=limit, config=config)
if not results:
return ""
parts = []
for r in results:
text = r.get("text", "").strip()
if text:
parts.append(text)
return "\n\n---\n\n".join(parts)
def chat_rag_stream(
query: str,
config: RAGConfig | None = None,
limit: int = 5,
system_override: str | None = None,
):
"""
چت RAG با استریم: دیتای embed شده را بازیابی می‌کند و با LLM جواب می‌دهد.
Args:
query: پیام کاربر
config: تنظیمات RAG
limit: تعداد چانک‌های بازیابی‌شده
system_override: جایگزین system prompt (اختیاری)
Yields:
تک‌تک deltaهای content به‌صورت رشته
"""
cfg = config or load_rag_config()
client = _get_chat_client(cfg)
model = cfg.llm.model
context = build_rag_context(query, config=cfg, limit=limit)
if system_override is not None:
system_content = system_override
else:
tone = _load_tone(cfg)
system_parts = [tone] if tone else []
        system_parts.append(
            # Persian instruction, roughly: "Answer the user's question using
            # the 'reference texts' section below. Use the references only as
            # needed, and write the answer in the user's language."
            "با استفاده از بخش «متن‌های مرجع» زیر به سوال کاربر پاسخ بده. "
            "فقط در حد نیاز از مرجع استفاده کن و پاسخ را به زبان کاربر بنویس."
        )
        if context:
            # "متن‌های مرجع" means "Reference texts": the retrieval results header.
            system_parts.append("\n\nمتن‌های مرجع:\n" + context)
system_content = "\n".join(system_parts)
messages = [
{"role": "system", "content": system_content},
{"role": "user", "content": query},
]
stream = client.chat.completions.create(
model=model,
messages=messages,
stream=True,
)
for chunk in stream:
delta = chunk.choices[0].delta if chunk.choices else None
content = delta.content if delta else ""
if content:
yield content
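A minimal sketch of consuming the generator; the import path rag.chat is an assumption, since the diff view does not show file paths:

from rag.chat import chat_rag_stream  # assumed package/module name

# Stream the answer to stdout, one content delta at a time.
for delta in chat_rag_stream("What does phase one cover?"):
    print(delta, end="", flush=True)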
+16
@@ -32,11 +32,19 @@ class ChunkingConfig:
overlap_tokens: int = 50
@dataclass
class LLMConfig:
model: str = "gpt-4o"
base_url: str | None = None
api_key_env: str | None = None
@dataclass
class RAGConfig:
embedding: EmbeddingConfig
qdrant: QdrantConfig
chunking: ChunkingConfig
llm: LLMConfig = field(default_factory=LLMConfig)
tone_file: str = "config/tone.txt"
knowledge_base_path: str = "config/knowledge_base"
user_info_path: str = "config/user_info"
@@ -82,10 +90,18 @@ def load_rag_config(config_path: str | Path | None = None) -> RAGConfig:
overlap_tokens=ch.get("overlap_tokens", 50),
)
llm_data = data.get("llm", {})
llm = LLMConfig(
model=llm_data.get("model", "gpt-4o"),
base_url=llm_data.get("base_url"),
api_key_env=llm_data.get("api_key_env"),
)
return RAGConfig(
embedding=embedding,
qdrant=qdrant,
chunking=chunking,
llm=llm,
tone_file=data.get("tone_file", "config/tone.txt"),
knowledge_base_path=data.get("knowledge_base_path", "config/knowledge_base"),
user_info_path=data.get("user_info_path", "config/user_info"),
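A short sketch of how the parsed configuration is consumed; load_rag_config and the field defaults are taken from the diff above, while the import path is an assumption:

from rag.config import load_rag_config  # assumed package name

cfg = load_rag_config()       # parses rag_config.yaml, including the new llm block
print(cfg.llm.model)          # "gpt-4o" unless overridden in the llm: section
print(cfg.llm.api_key_env)    # None unless set; chat.py then falls back to AVALAI_API_KEY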
+7
@@ -0,0 +1,7 @@
from django.urls import path
from .views import ChatView
urlpatterns = [
path("chat/", ChatView.as_view()),
]
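For the route to be reachable, this module still has to be mounted in the project URLconf. A sketch, where the api/rag/ prefix and the rag app label are assumptions:

# project-level urls.py (sketch)
from django.urls import include, path

urlpatterns = [
    path("api/rag/", include("rag.urls")),  # would expose POST /api/rag/chat/
]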
+10 -6
@@ -98,20 +98,24 @@ class QdrantVectorStore:
) -> list[dict]:
"""
        Similarity search over the given query vector.
        Uses query_points (the universal Query API, qdrant-client >= 1.10).
"""
results = self.client.search(
response = self.client.query_points(
collection_name=self.qdrant.collection_name,
query_vector=query_vector,
query=query_vector,
limit=limit,
score_threshold=score_threshold,
)
points = getattr(response, "points", []) or []
return [
{
"id": str(r.id),
"score": r.score,
"text": r.payload.get("text", ""),
"metadata": {k: v for k, v in r.payload.items() if k != "text"},
"score": float(r.score) if r.score is not None else 0.0,
"text": (r.payload or {}).get("text", ""),
"metadata": {
k: v for k, v in (r.payload or {}).items() if k != "text"
},
}
for r in results
for r in points
]
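For reference, a minimal standalone sketch of the qdrant-client call the new code relies on. query_points returns a QueryResponse object whose points attribute holds the scored hits; the collection name and vector below are illustrative:

from qdrant_client import QdrantClient

client = QdrantClient(url="http://localhost:6333")  # illustrative local instance
response = client.query_points(
    collection_name="docs",   # illustrative collection name
    query=[0.1, 0.2, 0.3],    # dense query vector (dimension must match the collection)
    limit=5,
)
for point in response.points:  # ScoredPoint objects, as unpacked in the code above
    print(point.id, point.score, point.payload)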
+43
@@ -0,0 +1,43 @@
"""
RAG views: streaming chat
"""
from django.http import StreamingHttpResponse
from rest_framework import status
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.views import APIView
from .chat import chat_rag_stream
class ChatView(APIView):
"""
    Streaming RAG chat.
    POST with {"message": "question text"} or the message query param.
"""
def post(self, request: Request):
message = request.data.get("message") or request.query_params.get("message")
        if not message or not isinstance(message, str):
            return Response(
                # msg (Persian): "The message parameter is required."
                {"code": 400, "msg": "پارامتر message الزامی است."},
                status=status.HTTP_400_BAD_REQUEST,
            )
        message = message.strip()  # the isinstance check above guarantees str
if not message:
return Response(
{"code": 400, "msg": "پیام نباید خالی باشد."},
status=status.HTTP_400_BAD_REQUEST,
)
def generate():
try:
for chunk in chat_rag_stream(message):
yield chunk
except Exception as e:
                # Persian: "[Error: ...]"; surface the failure in-stream.
                yield f"\n[خطا: {e}]"
return StreamingHttpResponse(
generate(),
content_type="text/plain; charset=utf-8",
)
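A client-side sketch for consuming the stream, assuming the view is mounted at /api/rag/chat/ as in the URLconf sketch earlier; requests is a third-party dependency:

import requests

resp = requests.post(
    "http://localhost:8000/api/rag/chat/",  # assumed mount point
    json={"message": "What does phase one cover?"},
    stream=True,  # keep the connection open and read chunks as they arrive
)
for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
    print(chunk, end="", flush=True)  # charset=utf-8 from the response header drives decoding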