Files
Ai/rag/chat.py
T
2026-04-25 17:22:41 +03:30

430 lines
14 KiB
Python

"""
چت RAG برای API چت عمومی — با ارسال کامل داده مزرعه و retrieval تکمیلی از KB.
"""
import base64
import json
import logging
import mimetypes
from pathlib import Path
from typing import Any
from .api_provider import get_chat_client
from .chunker import chunk_text
from .config import RAGConfig, ServiceConfig, get_service_config, load_rag_config
from .retrieve import search_with_texts
logger = logging.getLogger(__name__)
def _coerce_text_content(value: Any) -> str:
if value is None:
return ""
if isinstance(value, str):
return value
if isinstance(value, list):
parts: list[str] = []
for item in value:
if isinstance(item, dict) and item.get("type") == "text":
text_value = item.get("text")
if isinstance(text_value, str) and text_value.strip():
parts.append(text_value.strip())
elif isinstance(item, str) and item.strip():
parts.append(item.strip())
return "\n".join(parts)
return str(value)
def _normalize_image_inputs(images: list[Any] | None) -> list[dict[str, str]]:
normalized: list[dict[str, str]] = []
for item in images or []:
if isinstance(item, str):
value = item.strip()
if value:
normalized.append({"url": value})
continue
if not isinstance(item, dict):
continue
url = item.get("url") or item.get("image_url") or item.get("data_url")
if not isinstance(url, str) or not url.strip():
continue
entry = {"url": url.strip()}
detail = item.get("detail")
if isinstance(detail, str) and detail.strip():
entry["detail"] = detail.strip()
normalized.append(entry)
return normalized
def _build_content_parts(text: str, images: list[dict[str, str]] | None = None) -> str | list[dict[str, Any]]:
    """Build an OpenAI-style message content payload.

    Returns the bare trimmed text when no usable images are supplied;
    otherwise a list of ``text`` / ``image_url`` content parts.
    """
    clean_text = (text or "").strip()
    clean_images = _normalize_image_inputs(images)
    if not clean_images:
        return clean_text
    content: list[dict[str, Any]] = (
        [{"type": "text", "text": clean_text}] if clean_text else []
    )
    for img in clean_images:
        payload: dict[str, Any] = {"url": img["url"]}
        detail = img.get("detail")
        if detail:
            payload["detail"] = detail
        content.append({"type": "image_url", "image_url": payload})
    return content
def _normalize_history_messages(history: list[dict[str, Any]] | None) -> list[dict[str, Any]]:
normalized: list[dict[str, Any]] = []
for item in history or []:
if not isinstance(item, dict):
continue
role = str(item.get("role") or "").strip().lower()
if role not in {"user", "assistant"}:
continue
text = _coerce_text_content(
item.get("content", item.get("message", item.get("text")))
).strip()
images = _normalize_image_inputs(item.get("images") or item.get("image_urls"))
if not text and not images:
continue
content = _build_content_parts(text, images if role == "user" else None)
normalized.append({"role": role, "content": content})
return normalized
def encode_uploaded_image(uploaded_file: Any) -> dict[str, str]:
    """Encode an uploaded image file as a base64 data-URL entry.

    The MIME type comes from the object's ``content_type`` attribute when
    present, otherwise from guessing on its ``name``, falling back to
    ``application/octet-stream``.

    Raises:
        ValueError: if ``read()`` yields something other than bytes.
    """
    mime = getattr(uploaded_file, "content_type", None)
    if not mime:
        mime = mimetypes.guess_type(getattr(uploaded_file, "name", ""))[0]
    if not mime:
        mime = "application/octet-stream"
    payload = uploaded_file.read()
    if not isinstance(payload, (bytes, bytearray)):
        raise ValueError("Uploaded image payload is invalid.")
    b64 = base64.b64encode(payload).decode("ascii")
    return {"url": f"data:{mime};base64,{b64}", "detail": "auto"}
def _load_tone(config: RAGConfig | None) -> str:
    """Load the default tone text from the "chat" knowledge base.

    Returns an empty string when the chat KB is not configured or its
    tone file does not exist (the missing file is logged).
    """
    cfg = config or load_rag_config()
    chat_kb = cfg.knowledge_bases.get("chat")
    if not chat_kb:
        return ""
    # Tone files are resolved relative to the package's parent directory.
    tone_path = Path(__file__).resolve().parent.parent / chat_kb.tone_file
    if not tone_path.exists():
        logger.warning("Default tone file not found: %s", tone_path)
        return ""
    return tone_path.read_text(encoding="utf-8").strip()
def _load_service_tone(service: ServiceConfig, config: RAGConfig | None = None) -> str:
    """Load a service-specific tone file, falling back to the default tone.

    When the service declares no tone file, or the declared file is
    missing (logged), the default "chat" KB tone is returned instead.
    """
    cfg = config or load_rag_config()
    if service.tone_file:
        tone_path = Path(__file__).resolve().parent.parent / service.tone_file
        if tone_path.exists():
            return tone_path.read_text(encoding="utf-8").strip()
        logger.warning("Service tone file not found: %s", tone_path)
    return _load_tone(cfg)
def _format_farm_context(farm_uuid: str) -> str:
    """Fetch farm details by UUID and serialize them as a labelled JSON block.

    Delegates the serialization to ``_format_farm_context_from_details``
    so the two context builders cannot drift apart.

    Raises:
        ValueError: when no farm details exist for ``farm_uuid``.
    """
    # Local import avoids a circular dependency with farm_data at load time.
    from farm_data.services import get_farm_details

    farm_details = get_farm_details(farm_uuid)
    if not farm_details:
        raise ValueError("farm_uuid نامعتبر است یا اطلاعات مزرعه پیدا نشد.")
    return _format_farm_context_from_details(farm_details)
def _format_farm_context_from_details(farm_details: dict) -> str:
serialized = json.dumps(
farm_details,
ensure_ascii=False,
indent=2,
default=str,
)
return "[اطلاعات کامل مزرعه]\n" + serialized
def _load_farm_details_context(
sensor_uuid: str | None,
farm_details: dict | None = None,
) -> str:
if not sensor_uuid:
return ""
if farm_details is not None:
return _format_farm_context_from_details(farm_details)
return _format_farm_context(sensor_uuid)
def _build_system_prompt(
    service: ServiceConfig,
    query: str,
    farm_context: str,
    config: RAGConfig | None = None,
) -> str:
    """Assemble the system prompt for a RAG chat turn.

    Sections, joined by blank lines and with empty ones dropped:
    service tone, service system prompt, the fixed answering
    instructions, the farm context, and the user's question.
    """
    sections: list[str] = []
    tone = _load_service_tone(service, config)
    if tone:
        sections.append(tone)
    if service.system_prompt:
        sections.append(service.system_prompt)
    sections.append(
        "با استفاده از اطلاعات کامل مزرعه و اطلاعات بازیابی‌شده از پایگاه دانش که در ادامه آمده "
        "به سوال کاربر پاسخ بده. "
        "اگر داده‌ای در اطلاعات مزرعه وجود دارد، همان را مبنای پاسخ قرار بده و چیزی حدس نزن. "
        "نتایج بازیابی‌شده از پایگاه دانش را برای تکمیل یا توضیح پاسخ استفاده کن. "
        "اگر داده کافی نبود، این کمبود را شفاف بگو. "
        "پاسخ را به زبان کاربر بنویس."
    )
    sections.append(farm_context)
    sections.append(f"[سوال کاربر]\n{query}")
    return "\n\n".join(section for section in sections if section)
def _create_audit_log(
    farm_uuid: str,
    service_id: str,
    model: str,
    query: str,
    system_prompt: str,
    messages: list[dict],
) -> "ChatAuditLog":
    """Persist a ``ChatAuditLog`` row in STARTED state and return it."""
    # Local import keeps Django model loading out of module import time.
    from .models import ChatAuditLog

    audit = ChatAuditLog.objects.create(
        farm_uuid=farm_uuid,
        service_id=service_id,
        model=model,
        user_query=query,
        system_prompt=system_prompt,
        messages=messages,
        status=ChatAuditLog.STATUS_STARTED,
    )
    logger.info(
        "Created chat audit log id=%s service_id=%s farm_uuid=%s model=%s",
        audit.id,
        service_id,
        farm_uuid,
        model,
    )
    return audit
def _complete_audit_log(audit_log: "ChatAuditLog", response_text: str) -> None:
    """Mark the audit log COMPLETED and store the full response text."""
    from .models import ChatAuditLog

    audit_log.status = ChatAuditLog.STATUS_COMPLETED
    audit_log.response_text = response_text
    # Only the touched columns are written back.
    audit_log.save(update_fields=["response_text", "status", "updated_at"])
def _fail_audit_log(
    audit_log: "ChatAuditLog",
    error_message: str,
    response_text: str = "",
) -> None:
    """Mark the audit log FAILED, keeping any partially streamed response."""
    from .models import ChatAuditLog

    audit_log.status = ChatAuditLog.STATUS_FAILED
    audit_log.error_message = error_message
    audit_log.response_text = response_text
    # Only the touched columns are written back.
    audit_log.save(
        update_fields=["response_text", "error_message", "status", "updated_at"]
    )
def build_rag_context(
    query: str,
    sensor_uuid: str | None = None,
    config: RAGConfig | None = None,
    limit: int = 8,
    kb_name: str | None = None,
    service_id: str | None = None,
    farm_details: dict | None = None,
) -> str:
    """Build the shared retrieval context for RAG services.

    The context combines, separated by ``---`` dividers:
      - the full farm details (from ``farm_data``),
      - KB search hits for the user's query,
      - KB search hits for chunks of the farm-details text.
    """
    logger.info(
        "Building RAG context sensor_uuid=%s kb_name=%s limit=%s query_len=%s",
        sensor_uuid,
        kb_name,
        limit,
        len(query or ""),
    )
    cfg = config or load_rag_config()
    service = get_service_config(service_id, cfg) if service_id else None
    # Without a service config, default to user-embedding search.
    use_user_embeddings = service.use_user_embeddings if service else True
    effective_kb = kb_name or (service.knowledge_base if service else None)

    sections: list[str] = []
    search_texts = [query]
    farm_context = _load_farm_details_context(
        sensor_uuid=sensor_uuid,
        farm_details=farm_details,
    )
    if farm_context:
        sections.append(farm_context)
        # Also retrieve KB passages relevant to the farm data itself.
        search_texts.extend(chunk_text(farm_context, config=cfg))

    hits = search_with_texts(
        search_texts,
        sensor_uuid=sensor_uuid,
        limit=limit,
        per_text_limit=3,
        config=cfg,
        kb_name=effective_kb,
        service_id=service_id,
        use_user_embeddings=use_user_embeddings,
    )
    reference_texts = [
        hit.get("text", "").strip() for hit in hits or [] if hit.get("text")
    ]
    if reference_texts:
        sections.append("[متن‌های مرجع]\n" + "\n\n---\n\n".join(reference_texts))
    return "\n\n---\n\n".join(sections) if sections else ""
def chat_rag_stream(
    query: str,
    farm_uuid: str,
    config: RAGConfig | None = None,
    system_override: str | None = None,
    farm_details: dict | None = None,
    history: list[dict[str, Any]] | None = None,
    images: list[dict[str, str]] | None = None,
):
    """Streaming chat using the fixed ``chat`` service with direct farm context.

    Args:
        query: The user's message.
        farm_uuid: Farm identifier used for context lookup and audit logging.
        config: RAG configuration; loaded from disk when omitted.
        system_override: Optional replacement for the built system prompt.
        farm_details: Pre-fetched farm details; skips the lookup when given.
        history: Prior user/assistant messages.
        images: Images attached to the current user message.

    Yields:
        Streamed response chunks from the model.
    """
    cfg = config or load_rag_config()
    service_id = "chat"
    service = get_service_config(service_id, cfg)
    service_llm_config = service.llm
    # Rebuild the config around this service's LLM settings so the chat
    # client is created against the service-specific model/provider.
    service_cfg = RAGConfig(
        embedding=cfg.embedding,
        qdrant=cfg.qdrant,
        chunking=cfg.chunking,
        llm=service_llm_config,
        knowledge_bases=cfg.knowledge_bases,
        services=cfg.services,
        chromadb=cfg.chromadb,
    )
    client = get_chat_client(service_cfg)
    model = service_llm_config.model
    logger.info(
        "chat_rag_stream started service_id=%s farm_uuid=%s query_len=%s",
        service_id,
        farm_uuid,
        len(query or ""),
    )
    # Farm details plus KB retrieval, shared with the other RAG services.
    context = build_rag_context(
        query=query,
        sensor_uuid=farm_uuid,
        config=cfg,
        service_id=service_id,
        farm_details=farm_details,
    )
    logger.info(
        "Loaded augmented context for farm_uuid=%s context_len=%s",
        farm_uuid,
        len(context),
    )
    if system_override is not None:
        system_prompt = system_override
    else:
        system_prompt = _build_system_prompt(service, query, context, cfg)
    # Message order: system prompt, sanitized history, then the current turn.
    messages = [{"role": "system", "content": system_prompt}]
    messages.extend(_normalize_history_messages(history))
    messages.append({"role": "user", "content": _build_content_parts(query, images)})
    logger.info(
        "Final prompt prepared service_id=%s farm_uuid=%s model=%s messages_count=%s",
        service_id,
        farm_uuid,
        model,
        len(messages),
    )
    logger.info("Final system prompt for farm_uuid=%s:\n%s", farm_uuid, system_prompt)
    # The audit row is created in STARTED state before calling the provider,
    # so even a failed request leaves a trace.
    audit_log = _create_audit_log(
        farm_uuid=farm_uuid,
        service_id=service_id,
        model=model,
        query=query,
        system_prompt=system_prompt,
        messages=messages,
    )
    response_chunks: list[str] = []
    try:
        stream = client.chat.completions.create(
            model=model,
            messages=messages,
            stream=True,
        )
        logger.info(
            "Started streaming response id=%s service_id=%s farm_uuid=%s",
            audit_log.id,
            service_id,
            farm_uuid,
        )
        for chunk in stream:
            # Chunks with empty choices or empty delta content are skipped.
            delta = chunk.choices[0].delta if chunk.choices else None
            content = delta.content if delta else ""
            if content:
                response_chunks.append(content)
                yield content
        full_response = "".join(response_chunks)
        _complete_audit_log(audit_log, full_response)
        logger.info(
            "Completed chat response id=%s farm_uuid=%s response_len=%s response=\n%s",
            audit_log.id,
            farm_uuid,
            len(full_response),
            full_response,
        )
    except Exception as exc:
        # Record whatever streamed before the failure, then re-raise so the
        # caller still sees the original error.
        partial_response = "".join(response_chunks)
        _fail_audit_log(audit_log, str(exc), partial_response)
        logger.exception(
            "Chat request failed id=%s service_id=%s farm_uuid=%s partial_response_len=%s",
            audit_log.id,
            service_id,
            farm_uuid,
            len(partial_response),
        )
        raise