Add LLM configuration and update URL routing
- Introduced LLM configuration in rag_config.yaml and corresponding LLMConfig class in config.py. - Updated load_rag_config function to parse LLM settings from the configuration file. - Added new API route for RAG in urls.py to facilitate access to the chat model. - Modified QdrantVectorStore to use query_points method for improved functionality.
This commit is contained in:
@@ -18,6 +18,12 @@ chunking:
|
|||||||
max_chunk_tokens: 500
|
max_chunk_tokens: 500
|
||||||
overlap_tokens: 50
|
overlap_tokens: 50
|
||||||
|
|
||||||
|
# تنظیمات مدل چت (LLM) — Avalai
|
||||||
|
llm:
|
||||||
|
model: "gpt-4o"
|
||||||
|
base_url: "https://api.avalai.ir/v1"
|
||||||
|
api_key_env: "AVALAI_API_KEY"
|
||||||
|
|
||||||
tone_file: "config/tone.txt"
|
tone_file: "config/tone.txt"
|
||||||
knowledge_base_path: "config/knowledge_base"
|
knowledge_base_path: "config/knowledge_base"
|
||||||
user_info_path: "config/user_info"
|
user_info_path: "config/user_info"
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from django.urls import include, path
|
|||||||
|
|
||||||
urlpatterns = [
|
urlpatterns = [
|
||||||
path("admin/", admin.site.urls),
|
path("admin/", admin.site.urls),
|
||||||
|
path("api/rag/", include("rag.urls")),
|
||||||
path("api/tasks/", include("tasks.urls")),
|
path("api/tasks/", include("tasks.urls")),
|
||||||
path("api/soil-data/", include("soil_data.urls")),
|
path("api/soil-data/", include("soil_data.urls")),
|
||||||
path("api/sensor-data/", include("sensor_data.urls")),
|
path("api/sensor-data/", include("sensor_data.urls")),
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
فاز یک: Qdrant بهعنوان vector store
|
فاز یک: Qdrant بهعنوان vector store
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from .chat import chat_rag_stream
|
||||||
from .chunker import chunk_text, chunk_texts
|
from .chunker import chunk_text, chunk_texts
|
||||||
from .client import get_qdrant_client
|
from .client import get_qdrant_client
|
||||||
from .config import load_rag_config
|
from .config import load_rag_config
|
||||||
@@ -12,6 +13,7 @@ from .retrieve import search_with_query
|
|||||||
from .vector_store import QdrantVectorStore
|
from .vector_store import QdrantVectorStore
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"chat_rag_stream",
|
||||||
"chunk_text",
|
"chunk_text",
|
||||||
"chunk_texts",
|
"chunk_texts",
|
||||||
"embed_single",
|
"embed_single",
|
||||||
|
|||||||
+102
@@ -0,0 +1,102 @@
|
|||||||
|
"""
|
||||||
|
چت RAG با استریم — استفاده از دیتای embed شده کاربر و Avalai API
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
from .config import load_rag_config, RAGConfig
|
||||||
|
from .retrieve import search_with_query
|
||||||
|
|
||||||
|
|
||||||
|
def _get_chat_client(config: RAGConfig | None = None) -> OpenAI:
    """Build the OpenAI-compatible client used for the Avalai chat API.

    Args:
        config: RAG configuration. When omitted (or falsy) it is loaded
            with ``load_rag_config()``.

    Returns:
        An ``OpenAI`` client bound to the resolved API key and base URL.
    """
    cfg = config or load_rag_config()
    llm = cfg.llm

    # The config carries the *name* of the environment variable, not the key.
    # NOTE(review): if the variable is unset, ``api_key`` is None and the SDK
    # applies its own handling — confirm that is the desired failure mode.
    api_key = os.environ.get(llm.api_key_env or "AVALAI_API_KEY")

    # Chain with ``or`` so that an empty AVALAI_BASE_URL falls through to the
    # built-in default instead of being handed to the client as "".
    base_url = (
        llm.base_url
        or os.environ.get("AVALAI_BASE_URL")
        or "https://api.avalai.ir/v1"
    )
    return OpenAI(api_key=api_key, base_url=base_url)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_tone(config: RAGConfig | None = None) -> str:
    """Load the tone text that prefixes the system prompt.

    Args:
        config: RAG configuration. When omitted (or falsy) it is loaded
            with ``load_rag_config()``.

    Returns:
        The stripped contents of the tone file, or ``""`` when the
        configured path is not a regular file.
    """
    cfg = config or load_rag_config()
    # ``tone_file`` is resolved against the directory that contains this
    # module's package directory.
    tone_path = Path(__file__).resolve().parent.parent / cfg.tone_file
    # is_file() rather than exists(): a directory at this path exists but
    # cannot be read as text.
    if not tone_path.is_file():
        return ""
    return tone_path.read_text(encoding="utf-8").strip()
|
||||||
|
|
||||||
|
|
||||||
|
def build_rag_context(query: str, config: RAGConfig | None = None, limit: int = 5) -> str:
    """Retrieve the stored texts relevant to the user's query.

    Args:
        query: The user's message, forwarded to ``search_with_query``.
        config: RAG configuration forwarded to ``search_with_query``.
        limit: Maximum number of results requested from the search.

    Returns:
        The non-empty result texts joined by a ``---`` separator line, or
        ``""`` when the search yields nothing usable.
    """
    results = search_with_query(query, limit=limit, config=config)
    if not results:
        return ""
    # ``get("text") or ""`` also covers a "text" key whose value is None,
    # which ``get("text", "")`` would return as-is and crash on ``.strip()``.
    texts = ((r.get("text") or "").strip() for r in results)
    return "\n\n---\n\n".join(text for text in texts if text)
|
||||||
|
|
||||||
|
|
||||||
|
def chat_rag_stream(
    query: str,
    config: RAGConfig | None = None,
    limit: int = 5,
    system_override: str | None = None,
):
    """Stream an LLM answer to ``query`` grounded in the retrieved context.

    Args:
        query: The user's message.
        config: RAG configuration. When omitted (or falsy) it is loaded
            with ``load_rag_config()``.
        limit: Number of chunks to retrieve for the context.
        system_override: Optional replacement for the system prompt. When
            given, it is used verbatim and no retrieval is performed,
            because the retrieved context would not be used.

    Yields:
        Each non-empty content delta of the completion, as a string.
    """
    cfg = config or load_rag_config()
    client = _get_chat_client(cfg)

    if system_override is not None:
        system_content = system_override
    else:
        # Retrieval is only needed on this branch; skipping it for an
        # override avoids a search whose result would be discarded.
        context = build_rag_context(query, config=cfg, limit=limit)
        tone = _load_tone(cfg)
        system_parts = [tone] if tone else []
        system_parts.append(
            "با استفاده از بخش «متن‌های مرجع» زیر به سوال کاربر پاسخ بده. "
            "فقط در حد نیاز از مرجع استفاده کن و پاسخ را به زبان کاربر بنویس."
        )
        if context:
            system_parts.append("\n\nمتن‌های مرجع:\n" + context)
        system_content = "\n".join(system_parts)

    stream = client.chat.completions.create(
        model=cfg.llm.model,
        messages=[
            {"role": "system", "content": system_content},
            {"role": "user", "content": query},
        ],
        stream=True,
    )

    try:
        for chunk in stream:
            # A chunk may carry no choices; treat it as having no content.
            delta = chunk.choices[0].delta if chunk.choices else None
            content = delta.content if delta else ""
            if content:
                yield content
    finally:
        # Runs on normal exhaustion, on errors, and when the consumer stops
        # iterating early, so the streaming response is not left open.
        close = getattr(stream, "close", None)
        if callable(close):
            close()
|
||||||
@@ -32,11 +32,19 @@ class ChunkingConfig:
|
|||||||
overlap_tokens: int = 50
|
overlap_tokens: int = 50
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LLMConfig:
|
||||||
|
model: str = "gpt-4o"
|
||||||
|
base_url: str | None = None
|
||||||
|
api_key_env: str | None = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RAGConfig:
|
class RAGConfig:
|
||||||
embedding: EmbeddingConfig
|
embedding: EmbeddingConfig
|
||||||
qdrant: QdrantConfig
|
qdrant: QdrantConfig
|
||||||
chunking: ChunkingConfig
|
chunking: ChunkingConfig
|
||||||
|
llm: LLMConfig = field(default_factory=LLMConfig)
|
||||||
tone_file: str = "config/tone.txt"
|
tone_file: str = "config/tone.txt"
|
||||||
knowledge_base_path: str = "config/knowledge_base"
|
knowledge_base_path: str = "config/knowledge_base"
|
||||||
user_info_path: str = "config/user_info"
|
user_info_path: str = "config/user_info"
|
||||||
@@ -82,10 +90,18 @@ def load_rag_config(config_path: str | Path | None = None) -> RAGConfig:
|
|||||||
overlap_tokens=ch.get("overlap_tokens", 50),
|
overlap_tokens=ch.get("overlap_tokens", 50),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
llm_data = data.get("llm", {})
|
||||||
|
llm = LLMConfig(
|
||||||
|
model=llm_data.get("model", "gpt-4o"),
|
||||||
|
base_url=llm_data.get("base_url"),
|
||||||
|
api_key_env=llm_data.get("api_key_env"),
|
||||||
|
)
|
||||||
|
|
||||||
return RAGConfig(
|
return RAGConfig(
|
||||||
embedding=embedding,
|
embedding=embedding,
|
||||||
qdrant=qdrant,
|
qdrant=qdrant,
|
||||||
chunking=chunking,
|
chunking=chunking,
|
||||||
|
llm=llm,
|
||||||
tone_file=data.get("tone_file", "config/tone.txt"),
|
tone_file=data.get("tone_file", "config/tone.txt"),
|
||||||
knowledge_base_path=data.get("knowledge_base_path", "config/knowledge_base"),
|
knowledge_base_path=data.get("knowledge_base_path", "config/knowledge_base"),
|
||||||
user_info_path=data.get("user_info_path", "config/user_info"),
|
user_info_path=data.get("user_info_path", "config/user_info"),
|
||||||
|
|||||||
@@ -0,0 +1,7 @@
|
|||||||
|
from django.urls import path
|
||||||
|
|
||||||
|
from .views import ChatView
|
||||||
|
|
||||||
|
# URL routes of the RAG app: "chat/" is served by ChatView.
urlpatterns = [path("chat/", ChatView.as_view())]
|
||||||
+10
-6
@@ -98,20 +98,24 @@ class QdrantVectorStore:
|
|||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
جستجوی شباهت بر اساس query vector.
|
جستجوی شباهت بر اساس query vector.
|
||||||
|
از query_points استفاده می‌کند (qdrant-client >= 1.10).
|
||||||
"""
|
"""
|
||||||
results = self.client.search(
|
response = self.client.query_points(
|
||||||
collection_name=self.qdrant.collection_name,
|
collection_name=self.qdrant.collection_name,
|
||||||
query_vector=query_vector,
|
query=query_vector,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
score_threshold=score_threshold,
|
score_threshold=score_threshold,
|
||||||
)
|
)
|
||||||
|
points = getattr(response, "points", []) or []
|
||||||
|
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
"id": str(r.id),
|
"id": str(r.id),
|
||||||
"score": r.score,
|
"score": float(r.score) if r.score is not None else 0.0,
|
||||||
"text": r.payload.get("text", ""),
|
"text": (r.payload or {}).get("text", ""),
|
||||||
"metadata": {k: v for k, v in r.payload.items() if k != "text"},
|
"metadata": {
|
||||||
|
k: v for k, v in (r.payload or {}).items() if k != "text"
|
||||||
|
},
|
||||||
}
|
}
|
||||||
for r in results
|
for r in points
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -0,0 +1,43 @@
|
|||||||
|
"""
|
||||||
|
ویوهای RAG — چت با استریم
|
||||||
|
"""
|
||||||
|
from django.http import StreamingHttpResponse
|
||||||
|
from rest_framework import status
|
||||||
|
from rest_framework.request import Request
|
||||||
|
from rest_framework.response import Response
|
||||||
|
from rest_framework.views import APIView
|
||||||
|
|
||||||
|
from .chat import chat_rag_stream
|
||||||
|
|
||||||
|
|
||||||
|
class ChatView(APIView):
    """Streaming RAG chat endpoint.

    POST with ``{"message": "<question>"}`` in the body, or with a
    ``message`` query parameter.
    """

    def post(self, request: Request):
        """Validate the message and stream the answer as plain text.

        Returns:
            A 400 ``Response`` when the message is missing, not a string,
            or blank; otherwise a ``StreamingHttpResponse`` that emits the
            chunks produced by ``chat_rag_stream``.
        """
        import logging

        # The parsed body is not necessarily a mapping (e.g. a JSON array),
        # so only call ``.get`` when it is available; otherwise fall back to
        # the query string instead of failing with an AttributeError.
        data = request.data
        body_message = data.get("message") if hasattr(data, "get") else None
        message = body_message or request.query_params.get("message")
        if not message or not isinstance(message, str):
            return Response(
                {"code": 400, "msg": "پارامتر message الزامی است."},
                status=status.HTTP_400_BAD_REQUEST,
            )

        message = message.strip()
        if not message:
            return Response(
                {"code": 400, "msg": "پیام نباید خالی باشد."},
                status=status.HTTP_400_BAD_REQUEST,
            )

        def generate():
            # Failures are caught here and reported in-band as a final
            # chunk, after being logged on the server.
            try:
                yield from chat_rag_stream(message)
            except Exception as e:
                # Record the failure server-side; previously it was only
                # echoed to the client and left no trace in the logs.
                logging.getLogger(__name__).exception("RAG chat stream failed")
                yield f"\n[خطا: {e}]"

        return StreamingHttpResponse(
            generate(),
            content_type="text/plain; charset=utf-8",
        )
|
||||||
Reference in New Issue
Block a user