"""
چت RAG برای API چت عمومی — با ارسال کامل داده مزرعه و retrieval تکمیلی از KB.
"""
|
|
import json
|
|
import logging
|
|
from pathlib import Path
|
|
|
|
from .api_provider import get_chat_client
|
|
from .chunker import chunk_text
|
|
from .config import RAGConfig, ServiceConfig, get_service_config, load_rag_config
|
|
from .retrieve import search_with_texts
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _load_tone(config: RAGConfig | None) -> str:
    """Load the default tone text from the `chat` knowledge base.

    Returns "" when no chat KB is configured or its tone file is missing.
    """
    cfg = config or load_rag_config()
    chat_kb = cfg.knowledge_bases.get("chat")
    if chat_kb:
        project_root = Path(__file__).resolve().parent.parent
        tone_path = project_root / chat_kb.tone_file
        if tone_path.exists():
            return tone_path.read_text(encoding="utf-8").strip()
        logger.warning("Default tone file not found: %s", tone_path)
    return ""
|
|
|
|
|
|
def _load_service_tone(service: ServiceConfig, config: RAGConfig | None = None) -> str:
    """Return the service-specific tone text, falling back to the default chat tone."""
    cfg = config or load_rag_config()
    tone_file = service.tone_file
    if tone_file:
        tone_path = Path(__file__).resolve().parent.parent / tone_file
        if tone_path.exists():
            return tone_path.read_text(encoding="utf-8").strip()
        logger.warning("Service tone file not found: %s", tone_path)
    return _load_tone(cfg)
|
|
|
|
|
|
def _format_farm_context(farm_uuid: str) -> str:
    """Fetch the full farm details for *farm_uuid* and render them as a context block.

    Raises:
        ValueError: if the farm cannot be found for the given UUID.
    """
    from farm_data.services import get_farm_details

    details = get_farm_details(farm_uuid)
    if not details:
        raise ValueError("farm_uuid نامعتبر است یا اطلاعات مزرعه پیدا نشد.")

    payload = json.dumps(details, ensure_ascii=False, indent=2, default=str)
    return "[اطلاعات کامل مزرعه]\n" + payload
|
|
|
|
|
|
def _format_farm_context_from_details(farm_details: dict) -> str:
    """Render an already-fetched farm-details dict as a farm-context block."""
    payload = json.dumps(
        farm_details, ensure_ascii=False, indent=2, default=str
    )
    return "[اطلاعات کامل مزرعه]\n" + payload
|
|
|
|
|
|
def _load_farm_details_context(
    sensor_uuid: str | None,
    farm_details: dict | None = None,
) -> str:
    """Build the farm-context section, preferring pre-fetched details over a lookup.

    Returns "" when no sensor/farm UUID is provided.
    """
    if not sensor_uuid:
        return ""
    if farm_details is None:
        # No pre-fetched payload: fall back to fetching by UUID.
        return _format_farm_context(sensor_uuid)
    return _format_farm_context_from_details(farm_details)
|
|
|
|
|
|
def _build_system_prompt(
    service: ServiceConfig,
    query: str,
    farm_context: str,
    config: RAGConfig | None = None,
) -> str:
    """Assemble the system prompt: tone, service prompt, instructions, context, query.

    Empty sections are dropped; the remaining parts are joined with blank lines.
    """
    parts: list[str] = []

    tone = _load_service_tone(service, config)
    if tone:
        parts.append(tone)
    if service.system_prompt:
        parts.append(service.system_prompt)

    # Fixed answering instructions (kept verbatim — they are part of the prompt).
    parts.append(
        "با استفاده از اطلاعات کامل مزرعه و اطلاعات بازیابیشده از پایگاه دانش که در ادامه آمده "
        "به سوال کاربر پاسخ بده. "
        "اگر دادهای در اطلاعات مزرعه وجود دارد، همان را مبنای پاسخ قرار بده و چیزی حدس نزن. "
        "نتایج بازیابیشده از پایگاه دانش را برای تکمیل یا توضیح پاسخ استفاده کن. "
        "اگر داده کافی نبود، این کمبود را شفاف بگو. "
        "پاسخ را به زبان کاربر بنویس."
    )
    parts.append(farm_context)
    parts.append(f"[سوال کاربر]\n{query}")
    return "\n\n".join(section for section in parts if section)
|
|
|
|
|
|
def _create_audit_log(
    farm_uuid: str,
    service_id: str,
    model: str,
    query: str,
    system_prompt: str,
    messages: list[dict],
) -> "ChatAuditLog":
    """Persist a ChatAuditLog row in the STARTED state and return it."""
    from .models import ChatAuditLog

    record = ChatAuditLog.objects.create(
        status=ChatAuditLog.STATUS_STARTED,
        farm_uuid=farm_uuid,
        service_id=service_id,
        model=model,
        user_query=query,
        system_prompt=system_prompt,
        messages=messages,
    )
    logger.info(
        "Created chat audit log id=%s service_id=%s farm_uuid=%s model=%s",
        record.id,
        service_id,
        farm_uuid,
        model,
    )
    return record
|
|
|
|
|
|
def _complete_audit_log(audit_log: "ChatAuditLog", response_text: str) -> None:
    """Mark *audit_log* as COMPLETED and store the full model response."""
    from .models import ChatAuditLog

    audit_log.status = ChatAuditLog.STATUS_COMPLETED
    audit_log.response_text = response_text
    audit_log.save(update_fields=["response_text", "status", "updated_at"])
|
|
|
|
|
|
def _fail_audit_log(
    audit_log: "ChatAuditLog",
    error_message: str,
    response_text: str = "",
) -> None:
    """Mark *audit_log* as FAILED, recording the error and any partial response."""
    from .models import ChatAuditLog

    audit_log.status = ChatAuditLog.STATUS_FAILED
    audit_log.error_message = error_message
    audit_log.response_text = response_text
    audit_log.save(
        update_fields=["response_text", "error_message", "status", "updated_at"]
    )
|
|
|
|
|
|
def build_rag_context(
    query: str,
    sensor_uuid: str | None = None,
    config: RAGConfig | None = None,
    limit: int = 8,
    kb_name: str | None = None,
    service_id: str | None = None,
    farm_details: dict | None = None,
    per_text_limit: int = 3,
) -> str:
    """Build the shared RAG context used by all RAG services.

    The context contains:
      - the full farm details (from ``farm_data.services``),
      - KB results retrieved for the user query,
      - KB results retrieved for each chunk of the farm-details text.

    Args:
        query: The user's message, used as the primary retrieval text.
        sensor_uuid: Farm identifier; when falsy, no farm context is included.
        config: Optional pre-loaded RAG configuration.
        limit: Overall cap on retrieved KB results.
        kb_name: Knowledge-base override; defaults to the service's KB.
        service_id: Service whose settings (KB, user embeddings) should apply.
        farm_details: Pre-fetched farm details, to avoid a second lookup.
        per_text_limit: Max KB hits per search text (previously hard-coded to 3).

    Returns:
        The assembled context string, or "" when nothing was found.
    """
    logger.info(
        "Building RAG context sensor_uuid=%s kb_name=%s limit=%s query_len=%s",
        sensor_uuid,
        kb_name,
        limit,
        len(query or ""),
    )
    cfg = config or load_rag_config()
    service = get_service_config(service_id, cfg) if service_id else None
    include_user_embeddings = service.use_user_embeddings if service else True
    resolved_kb_name = kb_name or (service.knowledge_base if service else None)

    parts: list[str] = []
    farm_context = _load_farm_details_context(
        sensor_uuid=sensor_uuid,
        farm_details=farm_details,
    )
    if farm_context:
        parts.append(farm_context)

    # Search with the user query plus every chunk of the farm data, so the KB
    # can surface material relevant to the farm itself, not only the question.
    search_texts = [query]
    if farm_context:
        search_texts.extend(chunk_text(farm_context, config=cfg))

    results = search_with_texts(
        search_texts,
        sensor_uuid=sensor_uuid,
        limit=limit,
        per_text_limit=per_text_limit,
        config=cfg,
        kb_name=resolved_kb_name,
        service_id=service_id,
        use_user_embeddings=include_user_embeddings,
    )
    if results:
        rag_texts = [r.get("text", "").strip() for r in results if r.get("text")]
        if rag_texts:
            parts.append("[متنهای مرجع]\n" + "\n\n---\n\n".join(rag_texts))

    return "\n\n---\n\n".join(parts) if parts else ""
|
|
|
|
|
|
def chat_rag_stream(
    query: str,
    farm_uuid: str,
    config: RAGConfig | None = None,
    system_override: str | None = None,
    farm_details: dict | None = None,
):
    """Streamed chat using the fixed `chat` service with direct farm context.

    Args:
        query: The user's message.
        farm_uuid: Farm identifier.
        config: RAG configuration (loaded from disk when omitted).
        system_override: Optional replacement for the built system prompt.
        farm_details: Pre-fetched farm details, forwarded to context building.

    Yields:
        Streamed response chunks from the model.
    """
    cfg = config or load_rag_config()
    service_id = "chat"
    service = get_service_config(service_id, cfg)
    service_llm_config = service.llm
    # Clone the config with this service's LLM settings so the chat client
    # is created against the service-specific model/provider.
    service_cfg = RAGConfig(
        embedding=cfg.embedding,
        qdrant=cfg.qdrant,
        chunking=cfg.chunking,
        llm=service_llm_config,
        knowledge_bases=cfg.knowledge_bases,
        services=cfg.services,
        chromadb=cfg.chromadb,
    )
    client = get_chat_client(service_cfg)
    model = service_llm_config.model

    logger.info(
        "chat_rag_stream started service_id=%s farm_uuid=%s query_len=%s",
        service_id,
        farm_uuid,
        len(query or ""),
    )

    # Farm data + KB retrieval; farm_uuid doubles as the sensor UUID here.
    context = build_rag_context(
        query=query,
        sensor_uuid=farm_uuid,
        config=cfg,
        service_id=service_id,
        farm_details=farm_details,
    )
    logger.info(
        "Loaded augmented context for farm_uuid=%s context_len=%s",
        farm_uuid,
        len(context),
    )

    # A caller-supplied system prompt (even "") takes precedence over the built one.
    if system_override is not None:
        system_prompt = system_override
    else:
        system_prompt = _build_system_prompt(service, query, context, cfg)

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": query},
    ]

    logger.info(
        "Final prompt prepared service_id=%s farm_uuid=%s model=%s messages_count=%s",
        service_id,
        farm_uuid,
        model,
        len(messages),
    )
    logger.info("Final system prompt for farm_uuid=%s:\n%s", farm_uuid, system_prompt)

    # Record the request before streaming so a failure mid-stream is still audited.
    audit_log = _create_audit_log(
        farm_uuid=farm_uuid,
        service_id=service_id,
        model=model,
        query=query,
        system_prompt=system_prompt,
        messages=messages,
    )

    # Accumulates streamed chunks so the audit log can store the full (or
    # partial, on failure) response text.
    response_chunks: list[str] = []
    try:
        stream = client.chat.completions.create(
            model=model,
            messages=messages,
            stream=True,
        )
        logger.info(
            "Started streaming response id=%s service_id=%s farm_uuid=%s",
            audit_log.id,
            service_id,
            farm_uuid,
        )

        for chunk in stream:
            # Chunks may arrive with empty choices or an empty delta; skip those.
            delta = chunk.choices[0].delta if chunk.choices else None
            content = delta.content if delta else ""
            if content:
                response_chunks.append(content)
                yield content

        full_response = "".join(response_chunks)
        _complete_audit_log(audit_log, full_response)
        logger.info(
            "Completed chat response id=%s farm_uuid=%s response_len=%s response=\n%s",
            audit_log.id,
            farm_uuid,
            len(full_response),
            full_response,
        )
    except Exception as exc:
        # Persist whatever was streamed before the failure, then re-raise so
        # the caller sees the original error.
        partial_response = "".join(response_chunks)
        _fail_audit_log(audit_log, str(exc), partial_response)
        logger.exception(
            "Chat request failed id=%s service_id=%s farm_uuid=%s partial_response_len=%s",
            audit_log.id,
            service_id,
            farm_uuid,
            len(partial_response),
        )
        raise
|