UPDATE
This commit is contained in:
@@ -0,0 +1,434 @@
|
||||
"""
|
||||
چت RAG برای API چت عمومی — با ارسال کامل داده مزرعه و retrieval تکمیلی از KB.
|
||||
"""
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .api_provider import get_chat_client
|
||||
from .chunker import chunk_text
|
||||
from .config import RAGConfig, ServiceConfig, get_service_config, load_rag_config
|
||||
from .retrieve import search_with_texts
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _coerce_text_content(value: Any) -> str:
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
if isinstance(value, list):
|
||||
parts: list[str] = []
|
||||
for item in value:
|
||||
if isinstance(item, dict) and item.get("type") == "text":
|
||||
text_value = item.get("text")
|
||||
if isinstance(text_value, str) and text_value.strip():
|
||||
parts.append(text_value.strip())
|
||||
elif isinstance(item, str) and item.strip():
|
||||
parts.append(item.strip())
|
||||
return "\n".join(parts)
|
||||
return str(value)
|
||||
|
||||
|
||||
def _normalize_image_inputs(images: list[Any] | None) -> list[dict[str, str]]:
|
||||
normalized: list[dict[str, str]] = []
|
||||
for item in images or []:
|
||||
if isinstance(item, str):
|
||||
value = item.strip()
|
||||
if value:
|
||||
normalized.append({"url": value})
|
||||
continue
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
url = item.get("url") or item.get("image_url") or item.get("data_url")
|
||||
if not isinstance(url, str) or not url.strip():
|
||||
continue
|
||||
entry = {"url": url.strip()}
|
||||
detail = item.get("detail")
|
||||
if isinstance(detail, str) and detail.strip():
|
||||
entry["detail"] = detail.strip()
|
||||
normalized.append(entry)
|
||||
return normalized
|
||||
|
||||
|
||||
def _build_content_parts(text: str, images: list[dict[str, str]] | None = None) -> str | list[dict[str, Any]]:
    """Assemble OpenAI chat message content.

    Returns a plain (stripped) string when there are no valid images,
    otherwise a list of ``text`` / ``image_url`` content parts.
    """
    clean_text = (text or "").strip()
    clean_images = _normalize_image_inputs(images)
    if not clean_images:
        return clean_text

    content: list[dict[str, Any]] = (
        [{"type": "text", "text": clean_text}] if clean_text else []
    )
    for img in clean_images:
        payload: dict[str, Any] = {"url": img["url"]}
        detail = img.get("detail")
        if detail:
            payload["detail"] = detail
        content.append({"type": "image_url", "image_url": payload})
    return content
|
||||
|
||||
|
||||
def _normalize_history_messages(history: list[dict[str, Any]] | None) -> list[dict[str, Any]]:
    """Sanitize prior chat turns into well-formed user/assistant messages.

    Entries that are not dicts, have a role other than ``user``/``assistant``,
    or carry neither text nor images are dropped. Images are only attached to
    user turns.
    """
    messages: list[dict[str, Any]] = []
    for entry in history or []:
        if not isinstance(entry, dict):
            continue
        role = str(entry.get("role") or "").strip().lower()
        if role not in {"user", "assistant"}:
            continue
        raw_content = entry.get("content", entry.get("message", entry.get("text")))
        text = _coerce_text_content(raw_content).strip()
        images = _normalize_image_inputs(entry.get("images") or entry.get("image_urls"))
        if not (text or images):
            continue
        attachable = images if role == "user" else None
        messages.append({"role": role, "content": _build_content_parts(text, attachable)})
    return messages
|
||||
|
||||
|
||||
def encode_uploaded_image(uploaded_file: Any) -> dict[str, str]:
    """Encode an uploaded image as a base64 data-URL image entry.

    The MIME type comes from the file's ``content_type`` attribute, then a
    guess from its ``name``, then ``application/octet-stream``.

    Raises:
        ValueError: if ``read()`` yields anything other than bytes.
    """
    mime = getattr(uploaded_file, "content_type", None)
    if not mime:
        mime = mimetypes.guess_type(getattr(uploaded_file, "name", ""))[0]
    if not mime:
        mime = "application/octet-stream"

    payload = uploaded_file.read()
    if not isinstance(payload, (bytes, bytearray)):
        raise ValueError("Uploaded image payload is invalid.")
    encoded = base64.b64encode(payload).decode("ascii")
    return {"url": f"data:{mime};base64,{encoded}", "detail": "auto"}
|
||||
|
||||
|
||||
def _load_tone(config: RAGConfig | None) -> str:
    """Load the default tone file from the `chat` knowledge base, if present.

    Returns an empty string when the chat KB is not configured or its tone
    file does not exist (the latter is logged as a warning).
    """
    active = config or load_rag_config()
    root = Path(__file__).resolve().parent.parent
    chat_kb = active.knowledge_bases.get("chat")
    if not chat_kb:
        return ""
    tone_path = root / chat_kb.tone_file
    if not tone_path.exists():
        logger.warning("Default tone file not found: %s", tone_path)
        return ""
    return tone_path.read_text(encoding="utf-8").strip()
|
||||
|
||||
|
||||
def _load_service_tone(service: ServiceConfig, config: RAGConfig | None = None) -> str:
    """Load a service-specific tone file, falling back to the default chat tone.

    A configured-but-missing service tone file is logged as a warning before
    falling back.
    """
    active = config or load_rag_config()
    if service.tone_file:
        candidate = Path(__file__).resolve().parent.parent / service.tone_file
        if candidate.exists():
            return candidate.read_text(encoding="utf-8").strip()
        logger.warning("Service tone file not found: %s", candidate)
    return _load_tone(active)
|
||||
|
||||
|
||||
def _format_farm_context(farm_uuid: str) -> str:
    """Fetch full farm details by UUID and render them as a labelled JSON block.

    Raises:
        ValueError: when no farm matches the given UUID.
    """
    # Imported lazily to avoid a module-level dependency on the Django app.
    from farm_data.services import get_farm_details

    details = get_farm_details(farm_uuid)
    if not details:
        raise ValueError("farm_uuid نامعتبر است یا اطلاعات مزرعه پیدا نشد.")

    body = json.dumps(details, ensure_ascii=False, indent=2, default=str)
    return "[اطلاعات کامل مزرعه]\n" + body
|
||||
|
||||
|
||||
def _format_farm_context_from_details(farm_details: dict) -> str:
|
||||
serialized = json.dumps(
|
||||
farm_details,
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
default=str,
|
||||
)
|
||||
return "[اطلاعات کامل مزرعه]\n" + serialized
|
||||
|
||||
|
||||
def _load_farm_details_context(
    sensor_uuid: str | None,
    farm_details: dict | None = None,
) -> str:
    """Build the farm context, preferring pre-fetched details over a fresh lookup.

    Returns an empty string when no sensor/farm UUID is provided.
    """
    if not sensor_uuid:
        return ""
    if farm_details is None:
        return _format_farm_context(sensor_uuid)
    return _format_farm_context_from_details(farm_details)
|
||||
|
||||
|
||||
def _build_system_prompt(
    service: ServiceConfig,
    query: str,
    farm_context: str,
    config: RAGConfig | None = None,
) -> str:
    """Compose the system prompt: tone, service prompt, instructions, context, query.

    Empty sections are skipped; the rest are joined with blank lines.
    """
    sections: list[str] = []
    tone = _load_service_tone(service, config)
    if tone:
        sections.append(tone)
    if service.system_prompt:
        sections.append(service.system_prompt)
    sections.append(
        "با استفاده از اطلاعات کامل مزرعه و اطلاعات بازیابیشده از پایگاه دانش که در ادامه آمده "
        "به سوال کاربر پاسخ بده. "
        "اگر دادهای در اطلاعات مزرعه وجود دارد، همان را مبنای پاسخ قرار بده و چیزی حدس نزن. "
        "نتایج بازیابیشده از پایگاه دانش را برای تکمیل یا توضیح پاسخ استفاده کن. "
        "اگر داده کافی نبود، این کمبود را شفاف بگو. "
        "پاسخ را به زبان کاربر بنویس."
    )
    sections.append(farm_context)
    sections.append(f"[سوال کاربر]\n{query}")
    return "\n\n".join(section for section in sections if section)
|
||||
|
||||
|
||||
def _create_audit_log(
    farm_uuid: str,
    service_id: str,
    model: str,
    query: str,
    system_prompt: str,
    messages: list[dict],
) -> "ChatAuditLog":
    """Persist a ChatAuditLog row in the STARTED state and return it.

    The caller is expected to later transition the row to COMPLETED or FAILED.
    """
    from .models import ChatAuditLog

    record = ChatAuditLog.objects.create(
        service_id=service_id,
        farm_uuid=farm_uuid,
        model=model,
        user_query=query,
        system_prompt=system_prompt,
        messages=messages,
        status=ChatAuditLog.STATUS_STARTED,
    )
    logger.info(
        "Created chat audit log id=%s service_id=%s farm_uuid=%s model=%s",
        record.id,
        service_id,
        farm_uuid,
        model,
    )
    return record
|
||||
|
||||
|
||||
def _complete_audit_log(audit_log: "ChatAuditLog", response_text: str) -> None:
    """Mark the audit log COMPLETED and store the full response text."""
    from .models import ChatAuditLog

    audit_log.status = ChatAuditLog.STATUS_COMPLETED
    audit_log.response_text = response_text
    audit_log.save(update_fields=["response_text", "status", "updated_at"])
|
||||
|
||||
|
||||
def _fail_audit_log(
    audit_log: "ChatAuditLog",
    error_message: str,
    response_text: str = "",
) -> None:
    """Mark the audit log FAILED, recording the error and any partial response."""
    from .models import ChatAuditLog

    audit_log.status = ChatAuditLog.STATUS_FAILED
    audit_log.error_message = error_message
    audit_log.response_text = response_text
    changed = ["response_text", "error_message", "status", "updated_at"]
    audit_log.save(update_fields=changed)
|
||||
|
||||
|
||||
def build_rag_context(
    query: str,
    sensor_uuid: str | None = None,
    config: RAGConfig | None = None,
    limit: int = 8,
    kb_name: str | None = None,
    service_id: str | None = None,
    farm_details: dict | None = None,
) -> str:
    """Build the shared RAG context used by all RAG services.

    Combines:
    - the full farm details (pre-fetched, or loaded via farm_data by UUID),
    - KB search results for the user's message,
    - KB search results for chunks of the farm-details text.

    Returns an empty string when neither farm data nor KB hits are available.
    """
    logger.info(
        "Building RAG context sensor_uuid=%s kb_name=%s limit=%s query_len=%s",
        sensor_uuid,
        kb_name,
        limit,
        len(query or ""),
    )
    cfg = config or load_rag_config()
    service = get_service_config(service_id, cfg) if service_id else None
    include_user_embeddings = service.use_user_embeddings if service else True
    resolved_kb_name = kb_name or (service.knowledge_base if service else None)

    sections: list[str] = []
    farm_context = _load_farm_details_context(
        sensor_uuid=sensor_uuid,
        farm_details=farm_details,
    )
    if farm_context:
        sections.append(farm_context)

    # Query the KB with the user's message plus chunks of the farm data so
    # retrieval covers both what was asked and what the farm looks like.
    search_texts = [query]
    if farm_context:
        search_texts.extend(chunk_text(farm_context, config=cfg))

    results = search_with_texts(
        search_texts,
        sensor_uuid=sensor_uuid,
        limit=limit,
        per_text_limit=3,
        config=cfg,
        kb_name=resolved_kb_name,
        service_id=service_id,
        use_user_embeddings=include_user_embeddings,
    )
    retrieved = [hit.get("text", "").strip() for hit in (results or []) if hit.get("text")]
    if retrieved:
        sections.append("[متنهای مرجع]\n" + "\n\n---\n\n".join(retrieved))

    return "\n\n---\n\n".join(sections) if sections else ""
|
||||
|
||||
|
||||
def chat_rag_stream(
    query: str,
    farm_uuid: str,
    config: RAGConfig | None = None,
    system_override: str | None = None,
    farm_details: dict | None = None,
    history: list[dict[str, Any]] | None = None,
    images: list[dict[str, str]] | None = None,
):
    """Stream a chat completion for the fixed `chat` service with farm context.

    Builds the augmented RAG context for the farm, composes the system prompt
    (unless overridden), records an audit log, and streams the model response.
    On any failure the audit log is marked FAILED with the partial response
    and the exception is re-raised.

    Args:
        query: The user's message.
        farm_uuid: Farm identifier used to load farm data and scope retrieval.
        config: Optional RAG configuration; loaded from disk when omitted.
        system_override: Optional replacement for the generated system prompt.
        farm_details: Pre-fetched farm details; skips the farm lookup when given.
        history: Prior user/assistant messages to include in the conversation.
        images: Images attached to the current user message.

    Yields:
        Streamed response text chunks from the model.
    """
    cfg = config or load_rag_config()
    service_id = "chat"
    service = get_service_config(service_id, cfg)
    service_llm_config = service.llm
    # Clone the config with the service's own LLM settings so the chat client
    # is created against the service-specific model/provider.
    service_cfg = RAGConfig(
        embedding=cfg.embedding,
        qdrant=cfg.qdrant,
        chunking=cfg.chunking,
        llm=service_llm_config,
        knowledge_bases=cfg.knowledge_bases,
        services=cfg.services,
        chromadb=cfg.chromadb,
    )
    client = get_chat_client(service_cfg)
    model = service_llm_config.model

    logger.info(
        "chat_rag_stream started service_id=%s farm_uuid=%s query_len=%s",
        service_id,
        farm_uuid,
        len(query or ""),
    )

    # Farm details + KB retrieval, merged into one context string.
    context = build_rag_context(
        query=query,
        sensor_uuid=farm_uuid,
        config=cfg,
        service_id=service_id,
        farm_details=farm_details,
    )
    logger.info(
        "Loaded augmented context for farm_uuid=%s context_len=%s",
        farm_uuid,
        len(context),
    )

    if system_override is not None:
        system_prompt = system_override
    else:
        system_prompt = _build_system_prompt(
            service,
            query,
            context,
            config=cfg,
        )

    # system prompt, then sanitized history, then the current user message.
    messages = [{"role": "system", "content": system_prompt}]
    messages.extend(_normalize_history_messages(history))
    messages.append({"role": "user", "content": _build_content_parts(query, images)})

    logger.info(
        "Final prompt prepared service_id=%s farm_uuid=%s model=%s messages_count=%s",
        service_id,
        farm_uuid,
        model,
        len(messages),
    )
    logger.info("Final system prompt for farm_uuid=%s:\n%s", farm_uuid, system_prompt)

    # Audit row is created BEFORE the model call so failures are traceable.
    audit_log = _create_audit_log(
        farm_uuid=farm_uuid,
        service_id=service_id,
        model=model,
        query=query,
        system_prompt=system_prompt,
        messages=messages,
    )

    response_chunks: list[str] = []
    try:
        stream = client.chat.completions.create(
            model=model,
            messages=messages,
            stream=True,
        )
        logger.info(
            "Started streaming response id=%s service_id=%s farm_uuid=%s",
            audit_log.id,
            service_id,
            farm_uuid,
        )

        for chunk in stream:
            delta = chunk.choices[0].delta if chunk.choices else None
            content = delta.content if delta else ""
            if content:
                # Accumulate for the audit log while forwarding to the caller.
                response_chunks.append(content)
                yield content

        full_response = "".join(response_chunks)
        _complete_audit_log(audit_log, full_response)
        logger.info(
            "Completed chat response id=%s farm_uuid=%s response_len=%s response=\n%s",
            audit_log.id,
            farm_uuid,
            len(full_response),
            full_response,
        )
    except Exception as exc:
        # Preserve whatever was streamed before the failure in the audit log.
        partial_response = "".join(response_chunks)
        _fail_audit_log(audit_log, str(exc), partial_response)
        logger.exception(
            "Chat request failed id=%s service_id=%s farm_uuid=%s partial_response_len=%s",
            audit_log.id,
            service_id,
            farm_uuid,
            len(partial_response),
        )
        raise
|
||||
Reference in New Issue
Block a user