diff --git a/config/rag_config.yaml b/config/rag_config.yaml index 29a6287..a083764 100644 --- a/config/rag_config.yaml +++ b/config/rag_config.yaml @@ -18,6 +18,12 @@ chunking: max_chunk_tokens: 500 overlap_tokens: 50 +# تنظیمات مدل چت (LLM) — Avalai +llm: + model: "gpt-4o" + base_url: "https://api.avalai.ir/v1" + api_key_env: "AVALAI_API_KEY" + tone_file: "config/tone.txt" knowledge_base_path: "config/knowledge_base" user_info_path: "config/user_info" diff --git a/config/urls.py b/config/urls.py index b7776bd..0894916 100644 --- a/config/urls.py +++ b/config/urls.py @@ -3,6 +3,7 @@ from django.urls import include, path urlpatterns = [ path("admin/", admin.site.urls), + path("api/rag/", include("rag.urls")), path("api/tasks/", include("tasks.urls")), path("api/soil-data/", include("soil_data.urls")), path("api/sensor-data/", include("sensor_data.urls")), diff --git a/rag/__init__.py b/rag/__init__.py index fda6379..72f8864 100644 --- a/rag/__init__.py +++ b/rag/__init__.py @@ -3,6 +3,7 @@ فاز یک: Qdrant به‌عنوان vector store """ +from .chat import chat_rag_stream from .chunker import chunk_text, chunk_texts from .client import get_qdrant_client from .config import load_rag_config @@ -12,6 +13,7 @@ from .retrieve import search_with_query from .vector_store import QdrantVectorStore __all__ = [ + "chat_rag_stream", "chunk_text", "chunk_texts", "embed_single", diff --git a/rag/chat.py b/rag/chat.py new file mode 100644 index 0000000..0dae738 --- /dev/null +++ b/rag/chat.py @@ -0,0 +1,102 @@ +""" +چت RAG با استریم — استفاده از دیتای embed شده کاربر و Avalai API +""" +import os +from pathlib import Path + +from openai import OpenAI + +from .config import load_rag_config, RAGConfig +from .retrieve import search_with_query + + +def _get_chat_client(config: RAGConfig | None) -> OpenAI: + """ساخت کلاینت OpenAI برای Avalai Chat API.""" + cfg = config or load_rag_config() + llm = cfg.llm + env_var = llm.api_key_env or "AVALAI_API_KEY" + api_key = os.environ.get(env_var) + base_url = llm.base_url or os.environ.get( + "AVALAI_BASE_URL", "https://api.avalai.ir/v1" + ) + return OpenAI(api_key=api_key, base_url=base_url) + + +def _load_tone(config: RAGConfig | None) -> str: + """بارگذاری فایل لحن.""" + cfg = config or load_rag_config() + base = Path(__file__).resolve().parent.parent + tone_path = base / cfg.tone_file + if tone_path.exists(): + return tone_path.read_text(encoding="utf-8").strip() + return "" + + +def build_rag_context(query: str, config: RAGConfig | None = None, limit: int = 5) -> str: + """ + بازیابی متن‌های مرتبط از RAG برای کوئری کاربر. + """ + results = search_with_query(query, limit=limit, config=config) + if not results: + return "" + parts = [] + for r in results: + text = r.get("text", "").strip() + if text: + parts.append(text) + return "\n\n---\n\n".join(parts) + + +def chat_rag_stream( + query: str, + config: RAGConfig | None = None, + limit: int = 5, + system_override: str | None = None, +): + """ + چت RAG با استریم: دیتای embed شده را بازیابی می‌کند و با LLM جواب می‌دهد. + + Args: + query: پیام کاربر + config: تنظیمات RAG + limit: تعداد چانک‌های بازیابی‌شده + system_override: جایگزین system prompt (اختیاری) + + Yields: + تک‌تک deltaهای content به‌صورت رشته + """ + cfg = config or load_rag_config() + client = _get_chat_client(cfg) + model = cfg.llm.model + + context = build_rag_context(query, config=cfg, limit=limit) + + if system_override is not None: + system_content = system_override + else: + tone = _load_tone(cfg) + system_parts = [tone] if tone else [] + system_parts.append( + "با استفاده از بخش «متن‌های مرجع» زیر به سوال کاربر پاسخ بده. " + "فقط در حد نیاز از مرجع استفاده کن و پاسخ را به زبان کاربر بنویس." + ) + if context: + system_parts.append("\n\nمتن‌های مرجع:\n" + context) + system_content = "\n".join(system_parts) + + messages = [ + {"role": "system", "content": system_content}, + {"role": "user", "content": query}, + ] + + stream = client.chat.completions.create( + model=model, + messages=messages, + stream=True, + ) + + for chunk in stream: + delta = chunk.choices[0].delta if chunk.choices else None + content = delta.content if delta else "" + if content: + yield content diff --git a/rag/config.py b/rag/config.py index f1aa694..908f718 100644 --- a/rag/config.py +++ b/rag/config.py @@ -32,11 +32,19 @@ class ChunkingConfig: overlap_tokens: int = 50 +@dataclass +class LLMConfig: + model: str = "gpt-4o" + base_url: str | None = None + api_key_env: str | None = None + + @dataclass class RAGConfig: embedding: EmbeddingConfig qdrant: QdrantConfig chunking: ChunkingConfig + llm: LLMConfig = field(default_factory=LLMConfig) tone_file: str = "config/tone.txt" knowledge_base_path: str = "config/knowledge_base" user_info_path: str = "config/user_info" @@ -82,10 +90,18 @@ def load_rag_config(config_path: str | Path | None = None) -> RAGConfig: overlap_tokens=ch.get("overlap_tokens", 50), ) + llm_data = data.get("llm", {}) + llm = LLMConfig( + model=llm_data.get("model", "gpt-4o"), + base_url=llm_data.get("base_url"), + api_key_env=llm_data.get("api_key_env"), + ) + return RAGConfig( embedding=embedding, qdrant=qdrant, chunking=chunking, + llm=llm, tone_file=data.get("tone_file", "config/tone.txt"), knowledge_base_path=data.get("knowledge_base_path", "config/knowledge_base"), user_info_path=data.get("user_info_path", "config/user_info"), diff --git a/rag/urls.py b/rag/urls.py new file mode 100644 index 0000000..0058a15 --- /dev/null +++ b/rag/urls.py @@ -0,0 +1,7 @@ +from django.urls import path + +from .views import ChatView + +urlpatterns = [ + path("chat/", ChatView.as_view()), +] diff --git a/rag/vector_store.py b/rag/vector_store.py index a6deb15..11cbb57 100644 --- a/rag/vector_store.py +++ b/rag/vector_store.py @@ -98,20 +98,24 @@ class QdrantVectorStore: ) -> list[dict]: """ جستجوی شباهت بر اساس query vector. + از query_points استفاده می‌کند (qdrant-client >= 2.0). """ - results = self.client.search( + response = self.client.query_points( collection_name=self.qdrant.collection_name, - query_vector=query_vector, + query=query_vector, limit=limit, score_threshold=score_threshold, ) + points = getattr(response, "points", []) or [] return [ { "id": str(r.id), - "score": r.score, - "text": r.payload.get("text", ""), - "metadata": {k: v for k, v in r.payload.items() if k != "text"}, + "score": float(r.score) if r.score is not None else 0.0, + "text": (r.payload or {}).get("text", ""), + "metadata": { + k: v for k, v in (r.payload or {}).items() if k != "text" + }, } - for r in results + for r in points ] diff --git a/rag/views.py b/rag/views.py new file mode 100644 index 0000000..e33c038 --- /dev/null +++ b/rag/views.py @@ -0,0 +1,43 @@ +""" +ویوهای RAG — چت با استریم +""" +from django.http import StreamingHttpResponse +from rest_framework import status +from rest_framework.request import Request +from rest_framework.response import Response +from rest_framework.views import APIView + +from .chat import chat_rag_stream + + +class ChatView(APIView): + """ + چت RAG با استریم. + POST با {"message": "متن سوال"} یا query param message + """ + + def post(self, request: Request): + message = request.data.get("message") or request.query_params.get("message") + if not message or not isinstance(message, str): + return Response( + {"code": 400, "msg": "پارامتر message الزامی است."}, + status=status.HTTP_400_BAD_REQUEST, + ) + message = str(message).strip() + if not message: + return Response( + {"code": 400, "msg": "پیام نباید خالی باشد."}, + status=status.HTTP_400_BAD_REQUEST, + ) + + def generate(): + try: + for chunk in chat_rag_stream(message): + yield chunk + except Exception as e: + yield f"\n[خطا: {e}]" + + return StreamingHttpResponse( + generate(), + content_type="text/plain; charset=utf-8", + )