UPDATE
This commit is contained in:
+44
-11
@@ -1,9 +1,12 @@
|
||||
"""
|
||||
سرویس تعبیهسازی متن — از Adapter Pattern برای سوئیچ بین providers استفاده میکند
|
||||
"""
|
||||
import logging
|
||||
import time
|
||||
|
||||
from .api_provider import get_embedding_client
|
||||
from .config import RAGConfig, load_rag_config
|
||||
import logging
|
||||
from .observability import classify_exception, log_event, observe_operation, record_metric
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -26,12 +29,13 @@ def embed_texts(
|
||||
لیست وکتورها
|
||||
"""
|
||||
if not texts:
|
||||
record_metric("rag.embedding.empty_input", operation="embed_texts")
|
||||
return []
|
||||
|
||||
cfg = config or load_rag_config()
|
||||
client = get_embedding_client(cfg)
|
||||
model_name = model or cfg.embedding.model
|
||||
logger.info(model_name)
|
||||
provider = cfg.embedding.provider or "unknown"
|
||||
batch_size = cfg.embedding.batch_size
|
||||
|
||||
all_embeddings: list[list[float]] = []
|
||||
@@ -39,15 +43,44 @@ def embed_texts(
|
||||
if dimensions is not None:
|
||||
extra["dimensions"] = dimensions
|
||||
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch = texts[i : i + batch_size]
|
||||
resp = client.embeddings.create(
|
||||
model=model_name,
|
||||
input=batch,
|
||||
**extra,
|
||||
)
|
||||
for item in sorted(resp.data, key=lambda x: x.index):
|
||||
all_embeddings.append(item.embedding)
|
||||
with observe_operation(source="rag.embedding", provider=provider, operation="embed_texts"):
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch = texts[i : i + batch_size]
|
||||
started_at = time.monotonic()
|
||||
try:
|
||||
resp = client.embeddings.create(
|
||||
model=model_name,
|
||||
input=batch,
|
||||
**extra,
|
||||
)
|
||||
except Exception as exc:
|
||||
failure = classify_exception(exc)
|
||||
log_event(
|
||||
level=logging.ERROR,
|
||||
message="embedding batch request failed",
|
||||
source="rag.embedding",
|
||||
provider=provider,
|
||||
operation="embed_batch",
|
||||
result_status="error",
|
||||
duration_ms=(time.monotonic() - started_at) * 1000,
|
||||
error_code=failure.error_code,
|
||||
batch_size=len(batch),
|
||||
model=model_name,
|
||||
)
|
||||
raise
|
||||
for item in sorted(resp.data, key=lambda x: x.index):
|
||||
all_embeddings.append(item.embedding)
|
||||
log_event(
|
||||
level=logging.INFO,
|
||||
message="embedding batch request completed",
|
||||
source="rag.embedding",
|
||||
provider=provider,
|
||||
operation="embed_batch",
|
||||
result_status="success",
|
||||
duration_ms=(time.monotonic() - started_at) * 1000,
|
||||
batch_size=len(batch),
|
||||
model=model_name,
|
||||
)
|
||||
|
||||
return all_embeddings
|
||||
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class FailureContract:
|
||||
status: str = "error"
|
||||
error_code: str = "internal_error"
|
||||
message: str = ""
|
||||
source: str = "application"
|
||||
warnings: list[str] = field(default_factory=list)
|
||||
retriable: bool = False
|
||||
details: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
payload = {
|
||||
"status": self.status,
|
||||
"error_code": self.error_code,
|
||||
"message": self.message,
|
||||
"source": self.source,
|
||||
"warnings": list(self.warnings),
|
||||
"retriable": self.retriable,
|
||||
}
|
||||
if self.details:
|
||||
payload["details"] = self.details
|
||||
return payload
|
||||
|
||||
|
||||
class RAGServiceError(Exception):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
error_code: str,
|
||||
message: str,
|
||||
source: str,
|
||||
warnings: list[str] | None = None,
|
||||
retriable: bool = False,
|
||||
details: dict[str, Any] | None = None,
|
||||
http_status: int = 500,
|
||||
) -> None:
|
||||
super().__init__(message)
|
||||
self.http_status = http_status
|
||||
self.contract = FailureContract(
|
||||
error_code=error_code,
|
||||
message=message,
|
||||
source=source,
|
||||
warnings=warnings or [],
|
||||
retriable=retriable,
|
||||
details=details or {},
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return self.contract.to_dict()
|
||||
+41
-23
@@ -12,6 +12,7 @@ from pathlib import Path
|
||||
from .chunker import chunk_text, chunk_texts
|
||||
from .config import load_rag_config, RAGConfig
|
||||
from .embedding import embed_texts
|
||||
from .observability import classify_exception, log_event, observe_operation, record_metric
|
||||
from .user_data import load_user_sources, build_user_weather_text
|
||||
from .vector_store import QdrantVectorStore
|
||||
|
||||
@@ -36,7 +37,19 @@ def _load_file(path: Path) -> str | None:
|
||||
return None
|
||||
try:
|
||||
return path.read_text(encoding="utf-8").strip()
|
||||
except Exception:
|
||||
except Exception as exc:
|
||||
failure = classify_exception(exc)
|
||||
log_event(
|
||||
level=40,
|
||||
message="rag ingest file load failed",
|
||||
source="rag.ingest",
|
||||
provider=None,
|
||||
operation="load_file",
|
||||
result_status="error",
|
||||
error_code=failure.error_code,
|
||||
path=str(path),
|
||||
)
|
||||
record_metric("rag.ingest.file_load_failure", error_code=failure.error_code)
|
||||
return None
|
||||
|
||||
|
||||
@@ -122,12 +135,14 @@ def ingest(
|
||||
"""
|
||||
cfg = config or load_rag_config()
|
||||
store = QdrantVectorStore(config=cfg)
|
||||
if recreate:
|
||||
store.ensure_collection(recreate=True)
|
||||
with observe_operation(source="rag.ingest", provider=cfg.embedding.provider, operation="ingest"):
|
||||
if recreate:
|
||||
store.ensure_collection(recreate=True)
|
||||
|
||||
sources = load_sources(config=cfg, kb_name=kb_name)
|
||||
if not sources:
|
||||
return {"chunks_added": 0, "sources": [], "error": "هیچ منبعی یافت نشد"}
|
||||
sources = load_sources(config=cfg, kb_name=kb_name)
|
||||
if not sources:
|
||||
record_metric("rag.ingest.empty_sources", kb_name=kb_name)
|
||||
return {"chunks_added": 0, "sources": [], "error": "هیچ منبعی یافت نشد"}
|
||||
|
||||
all_chunks: list[str] = []
|
||||
all_metas: list[dict] = []
|
||||
@@ -146,24 +161,27 @@ def ingest(
|
||||
"kb_name": src_kb,
|
||||
})
|
||||
|
||||
if not all_chunks:
|
||||
return {"chunks_added": 0, "sources": [s[0] for s in sources], "error": "هیچ چانکی ساخته نشد"}
|
||||
if not all_chunks:
|
||||
record_metric("rag.ingest.empty_chunks", kb_name=kb_name)
|
||||
return {"chunks_added": 0, "sources": [s[0] for s in sources], "error": "هیچ چانکی ساخته نشد"}
|
||||
|
||||
embeddings = embed_texts(all_chunks, config=cfg)
|
||||
if len(embeddings) != len(all_chunks):
|
||||
embeddings = embed_texts(all_chunks, config=cfg)
|
||||
if len(embeddings) != len(all_chunks):
|
||||
record_metric("rag.ingest.embedding_mismatch", kb_name=kb_name)
|
||||
return {
|
||||
"chunks_added": 0,
|
||||
"sources": [s[0] for s in sources],
|
||||
"error": f"تعداد embed با چانکها مطابقت ندارد: {len(embeddings)} vs {len(all_chunks)}",
|
||||
}
|
||||
|
||||
store.add_documents(
|
||||
ids=all_ids,
|
||||
embeddings=embeddings,
|
||||
documents=all_chunks,
|
||||
metadatas=all_metas,
|
||||
)
|
||||
record_metric("rag.ingest.success", kb_name=kb_name, chunks=len(all_chunks))
|
||||
return {
|
||||
"chunks_added": 0,
|
||||
"chunks_added": len(all_chunks),
|
||||
"sources": [s[0] for s in sources],
|
||||
"error": f"تعداد embed با چانکها مطابقت ندارد: {len(embeddings)} vs {len(all_chunks)}",
|
||||
}
|
||||
|
||||
store.add_documents(
|
||||
ids=all_ids,
|
||||
embeddings=embeddings,
|
||||
documents=all_chunks,
|
||||
metadatas=all_metas,
|
||||
)
|
||||
return {
|
||||
"chunks_added": len(all_chunks),
|
||||
"sources": [s[0] for s in sources],
|
||||
}
|
||||
|
||||
@@ -0,0 +1,129 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from collections import Counter
|
||||
from contextvars import ContextVar
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_request_id_ctx: ContextVar[str | None] = ContextVar("rag_request_id", default=None)
|
||||
METRICS: Counter[str] = Counter()
|
||||
|
||||
|
||||
def set_request_id(request_id: str | None) -> None:
|
||||
_request_id_ctx.set(request_id)
|
||||
|
||||
|
||||
def get_request_id() -> str | None:
|
||||
return _request_id_ctx.get()
|
||||
|
||||
|
||||
def record_metric(name: str, value: int = 1, **tags: Any) -> None:
|
||||
suffix = ",".join(f"{key}={tags[key]}" for key in sorted(tags) if tags[key] is not None)
|
||||
metric_key = f"{name}|{suffix}" if suffix else name
|
||||
METRICS[metric_key] += value
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClassifiedFailure:
|
||||
error_code: str
|
||||
failure_type: str
|
||||
retriable: bool
|
||||
|
||||
|
||||
def classify_exception(exc: Exception) -> ClassifiedFailure:
|
||||
exc_name = exc.__class__.__name__.lower()
|
||||
message = str(exc).lower()
|
||||
if "timeout" in exc_name or "timeout" in message:
|
||||
return ClassifiedFailure("timeout", "timeout", True)
|
||||
if "json" in exc_name or "json" in message:
|
||||
return ClassifiedFailure("parse_error", "parse_error", False)
|
||||
if "validation" in exc_name or "invalid" in message:
|
||||
return ClassifiedFailure("validation_failure", "validation_failure", False)
|
||||
if "connection" in exc_name or "unavailable" in message:
|
||||
return ClassifiedFailure("dependency_unavailable", "dependency_unavailable", True)
|
||||
return ClassifiedFailure("provider_error", "provider_error", True)
|
||||
|
||||
|
||||
def log_event(
|
||||
*,
|
||||
level: int,
|
||||
message: str,
|
||||
source: str,
|
||||
provider: str | None,
|
||||
operation: str,
|
||||
result_status: str,
|
||||
duration_ms: float | None = None,
|
||||
error_code: str | None = None,
|
||||
**extra: Any,
|
||||
) -> None:
|
||||
payload = {
|
||||
"source": source,
|
||||
"provider": provider,
|
||||
"operation": operation,
|
||||
"result_status": result_status,
|
||||
"duration_ms": round(duration_ms, 2) if duration_ms is not None else None,
|
||||
"error_code": error_code,
|
||||
"request_id": get_request_id(),
|
||||
}
|
||||
payload.update({key: value for key, value in extra.items() if value is not None})
|
||||
logger.log(level, message, extra={"event": payload})
|
||||
|
||||
|
||||
class observe_operation:
|
||||
def __init__(self, *, source: str, provider: str | None, operation: str):
|
||||
self.source = source
|
||||
self.provider = provider
|
||||
self.operation = operation
|
||||
self.started_at = 0.0
|
||||
|
||||
def __enter__(self):
|
||||
self.started_at = time.monotonic()
|
||||
log_event(
|
||||
level=logging.INFO,
|
||||
message="rag operation started",
|
||||
source=self.source,
|
||||
provider=self.provider,
|
||||
operation=self.operation,
|
||||
result_status="started",
|
||||
)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, _tb):
|
||||
duration_ms = (time.monotonic() - self.started_at) * 1000
|
||||
if exc is None:
|
||||
log_event(
|
||||
level=logging.INFO,
|
||||
message="rag operation completed",
|
||||
source=self.source,
|
||||
provider=self.provider,
|
||||
operation=self.operation,
|
||||
result_status="success",
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
record_metric("rag.operation.success", source=self.source, provider=self.provider, operation=self.operation)
|
||||
return False
|
||||
|
||||
failure = classify_exception(exc)
|
||||
log_event(
|
||||
level=logging.ERROR,
|
||||
message="rag operation failed",
|
||||
source=self.source,
|
||||
provider=self.provider,
|
||||
operation=self.operation,
|
||||
result_status="error",
|
||||
duration_ms=duration_ms,
|
||||
error_code=failure.error_code,
|
||||
failure_type=failure.failure_type,
|
||||
)
|
||||
record_metric(
|
||||
"rag.operation.failure",
|
||||
source=self.source,
|
||||
provider=self.provider,
|
||||
operation=self.operation,
|
||||
error_code=failure.error_code,
|
||||
)
|
||||
return False
|
||||
+37
-28
@@ -3,6 +3,7 @@
|
||||
"""
|
||||
from .config import load_rag_config, RAGConfig, get_service_config
|
||||
from .embedding import embed_single, embed_texts
|
||||
from .observability import observe_operation, record_metric
|
||||
from .vector_store import QdrantVectorStore
|
||||
|
||||
|
||||
@@ -63,15 +64,19 @@ def search_with_query(
|
||||
use_user_embeddings=use_user_embeddings,
|
||||
)
|
||||
|
||||
query_vector = embed_single(query, config=cfg)
|
||||
store = QdrantVectorStore(config=cfg)
|
||||
return store.search(
|
||||
query_vector=query_vector,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
sensor_uuids=sensor_filters,
|
||||
kb_names=kb_filters,
|
||||
)
|
||||
with observe_operation(source="rag.retrieve", provider=cfg.embedding.provider, operation="search_with_query"):
|
||||
query_vector = embed_single(query, config=cfg)
|
||||
store = QdrantVectorStore(config=cfg)
|
||||
results = store.search(
|
||||
query_vector=query_vector,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
sensor_uuids=sensor_filters,
|
||||
kb_names=kb_filters,
|
||||
)
|
||||
if not results:
|
||||
record_metric("rag.retrieve.empty_result", operation="search_with_query", service_id=service_id)
|
||||
return results
|
||||
|
||||
|
||||
def search_with_texts(
|
||||
@@ -102,24 +107,28 @@ def search_with_texts(
|
||||
)
|
||||
|
||||
store = QdrantVectorStore(config=cfg)
|
||||
vectors = embed_texts(normalized_texts, config=cfg)
|
||||
merged_results: dict[str, dict] = {}
|
||||
with observe_operation(source="rag.retrieve", provider=cfg.embedding.provider, operation="search_with_texts"):
|
||||
vectors = embed_texts(normalized_texts, config=cfg)
|
||||
merged_results: dict[str, dict] = {}
|
||||
|
||||
for vector in vectors:
|
||||
results = store.search(
|
||||
query_vector=vector,
|
||||
limit=per_text_limit,
|
||||
score_threshold=score_threshold,
|
||||
sensor_uuids=sensor_filters,
|
||||
kb_names=kb_filters,
|
||||
)
|
||||
for item in results:
|
||||
current = merged_results.get(item["id"])
|
||||
if current is None or item["score"] > current["score"]:
|
||||
merged_results[item["id"]] = item
|
||||
for vector in vectors:
|
||||
results = store.search(
|
||||
query_vector=vector,
|
||||
limit=per_text_limit,
|
||||
score_threshold=score_threshold,
|
||||
sensor_uuids=sensor_filters,
|
||||
kb_names=kb_filters,
|
||||
)
|
||||
for item in results:
|
||||
current = merged_results.get(item["id"])
|
||||
if current is None or item["score"] > current["score"]:
|
||||
merged_results[item["id"]] = item
|
||||
|
||||
return sorted(
|
||||
merged_results.values(),
|
||||
key=lambda item: item["score"],
|
||||
reverse=True,
|
||||
)[:limit]
|
||||
final_results = sorted(
|
||||
merged_results.values(),
|
||||
key=lambda item: item["score"],
|
||||
reverse=True,
|
||||
)[:limit]
|
||||
if not final_results:
|
||||
record_metric("rag.retrieve.empty_result", operation="search_with_texts", service_id=service_id)
|
||||
return final_results
|
||||
|
||||
@@ -18,6 +18,7 @@ from rag.chat import (
|
||||
build_rag_context,
|
||||
)
|
||||
from rag.config import RAGConfig, get_service_config, load_rag_config
|
||||
from rag.failure_contract import RAGServiceError
|
||||
from rag.user_data import build_plant_text
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -73,18 +74,47 @@ def _clean_json(raw: str) -> dict[str, Any]:
|
||||
cleaned = cleaned[4:]
|
||||
cleaned = cleaned.strip()
|
||||
if not cleaned:
|
||||
return {}
|
||||
raise RAGServiceError(
|
||||
error_code="empty_response",
|
||||
message="Pest disease LLM response was empty.",
|
||||
source="llm",
|
||||
retriable=True,
|
||||
details={"service_id": SERVICE_ID},
|
||||
http_status=502,
|
||||
)
|
||||
try:
|
||||
return json.loads(cleaned)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
parsed = json.loads(cleaned)
|
||||
except (json.JSONDecodeError, ValueError) as exc:
|
||||
logger.warning("Invalid JSON returned by pest_disease LLM: %s", cleaned[:500])
|
||||
return {}
|
||||
raise RAGServiceError(
|
||||
error_code="invalid_json",
|
||||
message="Pest disease LLM response was not valid JSON.",
|
||||
source="llm",
|
||||
retriable=True,
|
||||
details={"service_id": SERVICE_ID},
|
||||
http_status=502,
|
||||
) from exc
|
||||
if not isinstance(parsed, dict):
|
||||
raise RAGServiceError(
|
||||
error_code="invalid_schema",
|
||||
message="Pest disease LLM response root must be a JSON object.",
|
||||
source="llm",
|
||||
details={"service_id": SERVICE_ID},
|
||||
http_status=502,
|
||||
)
|
||||
return parsed
|
||||
|
||||
|
||||
def _load_farm_or_error(farm_uuid: str) -> dict[str, Any]:
|
||||
farm_details = get_farm_details(farm_uuid)
|
||||
if farm_details is None:
|
||||
raise ValueError("farm_uuid نامعتبر است یا اطلاعات مزرعه پیدا نشد.")
|
||||
raise RAGServiceError(
|
||||
error_code="farm_not_found",
|
||||
message="farm_uuid نامعتبر است یا اطلاعات مزرعه پیدا نشد.",
|
||||
source="farm_data",
|
||||
details={"farm_uuid": farm_uuid},
|
||||
http_status=404,
|
||||
)
|
||||
return farm_details
|
||||
|
||||
|
||||
@@ -213,9 +243,12 @@ def _validate_detection_result(parsed: dict[str, Any]) -> dict[str, Any]:
|
||||
}
|
||||
missing = [key for key in required_keys if key not in parsed]
|
||||
if missing:
|
||||
raise ValueError(
|
||||
"Pest disease detection response is missing required fields: "
|
||||
+ ", ".join(missing)
|
||||
raise RAGServiceError(
|
||||
error_code="invalid_schema",
|
||||
message="Pest disease detection response is missing required fields: " + ", ".join(missing),
|
||||
source="llm",
|
||||
details={"missing_fields": missing, "service_id": SERVICE_ID},
|
||||
http_status=502,
|
||||
)
|
||||
return parsed
|
||||
|
||||
@@ -232,9 +265,12 @@ def _validate_risk_result(parsed: dict[str, Any]) -> dict[str, Any]:
|
||||
}
|
||||
missing = [key for key in required_keys if key not in parsed]
|
||||
if missing:
|
||||
raise ValueError(
|
||||
"Pest disease risk response is missing required fields: "
|
||||
+ ", ".join(missing)
|
||||
raise RAGServiceError(
|
||||
error_code="invalid_schema",
|
||||
message="Pest disease risk response is missing required fields: " + ", ".join(missing),
|
||||
source="llm",
|
||||
details={"missing_fields": missing, "service_id": SERVICE_ID},
|
||||
http_status=502,
|
||||
)
|
||||
return parsed
|
||||
|
||||
@@ -301,7 +337,12 @@ def get_pest_disease_detection(
|
||||
) -> dict[str, Any]:
|
||||
normalized_images = _normalize_images(images)
|
||||
if not normalized_images:
|
||||
raise ValueError("حداقل یک تصویر برای تشخیص لازم است.")
|
||||
raise RAGServiceError(
|
||||
error_code="missing_images",
|
||||
message="حداقل یک تصویر برای تشخیص لازم است.",
|
||||
source="request",
|
||||
http_status=400,
|
||||
)
|
||||
|
||||
cfg = load_rag_config()
|
||||
service, client, model = _build_service_client(cfg)
|
||||
@@ -338,12 +379,25 @@ def get_pest_disease_detection(
|
||||
raw = response.choices[0].message.content.strip()
|
||||
parsed = _clean_json(raw)
|
||||
_complete_audit_log(audit_log, raw)
|
||||
except RAGServiceError as exc:
|
||||
logger.error("Pest disease detection failed for %s: %s", farm_uuid, exc)
|
||||
_fail_audit_log(audit_log, str(exc))
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error("Pest disease detection failed for %s: %s", farm_uuid, exc)
|
||||
_fail_audit_log(audit_log, str(exc))
|
||||
raise RuntimeError(f"Pest disease detection failed for farm {farm_uuid}.") from exc
|
||||
raise RAGServiceError(
|
||||
error_code="upstream_failure",
|
||||
message=f"Pest disease detection failed for farm {farm_uuid}.",
|
||||
source="llm",
|
||||
retriable=True,
|
||||
details={"farm_uuid": farm_uuid, "service_id": SERVICE_ID},
|
||||
http_status=503,
|
||||
) from exc
|
||||
|
||||
parsed = _validate_detection_result(parsed)
|
||||
parsed["status"] = "success"
|
||||
parsed["source"] = "llm"
|
||||
parsed["farm_uuid"] = farm_uuid
|
||||
parsed["raw_response"] = raw
|
||||
return parsed
|
||||
@@ -392,12 +446,25 @@ def get_pest_disease_risk(
|
||||
raw = response.choices[0].message.content.strip()
|
||||
parsed = _clean_json(raw)
|
||||
_complete_audit_log(audit_log, raw)
|
||||
except RAGServiceError as exc:
|
||||
logger.error("Pest disease risk prediction failed for %s: %s", farm_uuid, exc)
|
||||
_fail_audit_log(audit_log, str(exc))
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error("Pest disease risk prediction failed for %s: %s", farm_uuid, exc)
|
||||
_fail_audit_log(audit_log, str(exc))
|
||||
raise RuntimeError(f"Pest disease risk prediction failed for farm {farm_uuid}.") from exc
|
||||
raise RAGServiceError(
|
||||
error_code="upstream_failure",
|
||||
message=f"Pest disease risk prediction failed for farm {farm_uuid}.",
|
||||
source="llm",
|
||||
retriable=True,
|
||||
details={"farm_uuid": farm_uuid, "service_id": SERVICE_ID},
|
||||
http_status=503,
|
||||
) from exc
|
||||
|
||||
parsed = _validate_risk_result(parsed)
|
||||
parsed["status"] = "success"
|
||||
parsed["source"] = "llm"
|
||||
parsed["farm_uuid"] = farm_uuid
|
||||
parsed["raw_response"] = raw
|
||||
return parsed
|
||||
|
||||
@@ -14,6 +14,7 @@ from rag.chat import (
|
||||
build_rag_context,
|
||||
)
|
||||
from rag.config import RAGConfig, get_service_config, load_rag_config
|
||||
from rag.failure_contract import RAGServiceError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -39,18 +40,48 @@ def _clean_json(raw: str) -> dict[str, Any]:
|
||||
cleaned = cleaned[4:]
|
||||
cleaned = cleaned.strip()
|
||||
if not cleaned:
|
||||
return {}
|
||||
raise RAGServiceError(
|
||||
error_code="empty_response",
|
||||
message="Soil anomaly LLM response was empty.",
|
||||
source="llm",
|
||||
retriable=True,
|
||||
details={"service_id": SERVICE_ID},
|
||||
http_status=502,
|
||||
)
|
||||
try:
|
||||
return json.loads(cleaned)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
parsed = json.loads(cleaned)
|
||||
except (json.JSONDecodeError, ValueError) as exc:
|
||||
logger.warning("Invalid JSON returned by soil_anomaly LLM: %s", cleaned[:500])
|
||||
return {}
|
||||
raise RAGServiceError(
|
||||
error_code="invalid_json",
|
||||
message="Soil anomaly LLM response was not valid JSON.",
|
||||
source="llm",
|
||||
retriable=True,
|
||||
details={"service_id": SERVICE_ID},
|
||||
http_status=502,
|
||||
) from exc
|
||||
if not isinstance(parsed, dict):
|
||||
raise RAGServiceError(
|
||||
error_code="invalid_schema",
|
||||
message="Soil anomaly LLM response root must be a JSON object.",
|
||||
source="llm",
|
||||
retriable=False,
|
||||
details={"service_id": SERVICE_ID},
|
||||
http_status=502,
|
||||
)
|
||||
return parsed
|
||||
|
||||
|
||||
def _load_farm_or_error(farm_uuid: str) -> dict[str, Any]:
|
||||
farm_details = get_farm_details(farm_uuid)
|
||||
if farm_details is None:
|
||||
raise ValueError("farm_uuid نامعتبر است یا اطلاعات مزرعه پیدا نشد.")
|
||||
raise RAGServiceError(
|
||||
error_code="farm_not_found",
|
||||
message="farm_uuid نامعتبر است یا اطلاعات مزرعه پیدا نشد.",
|
||||
source="farm_data",
|
||||
details={"farm_uuid": farm_uuid},
|
||||
http_status=404,
|
||||
)
|
||||
return farm_details
|
||||
|
||||
|
||||
@@ -80,9 +111,12 @@ def _validate_anomaly_insight(parsed: dict[str, Any]) -> dict[str, Any]:
|
||||
}
|
||||
missing = [key for key in required_keys if key not in parsed]
|
||||
if missing:
|
||||
raise ValueError(
|
||||
"Soil anomaly insight response is missing required fields: "
|
||||
+ ", ".join(missing)
|
||||
raise RAGServiceError(
|
||||
error_code="invalid_schema",
|
||||
message="Soil anomaly insight response is missing required fields: " + ", ".join(missing),
|
||||
source="llm",
|
||||
details={"missing_fields": missing, "service_id": SERVICE_ID},
|
||||
http_status=502,
|
||||
)
|
||||
return parsed
|
||||
|
||||
@@ -156,12 +190,25 @@ def get_soil_anomaly_insight(
|
||||
raw = response.choices[0].message.content.strip()
|
||||
parsed = _clean_json(raw)
|
||||
_complete_audit_log(audit_log, raw)
|
||||
except RAGServiceError as exc:
|
||||
logger.error("Soil anomaly insight failed for %s: %s", farm_uuid, exc)
|
||||
_fail_audit_log(audit_log, str(exc))
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error("Soil anomaly insight failed for %s: %s", farm_uuid, exc)
|
||||
_fail_audit_log(audit_log, str(exc))
|
||||
raise RuntimeError(f"Soil anomaly insight failed for farm {farm_uuid}.") from exc
|
||||
raise RAGServiceError(
|
||||
error_code="upstream_failure",
|
||||
message=f"Soil anomaly insight failed for farm {farm_uuid}.",
|
||||
source="llm",
|
||||
retriable=True,
|
||||
details={"farm_uuid": farm_uuid, "service_id": SERVICE_ID},
|
||||
http_status=503,
|
||||
) from exc
|
||||
|
||||
parsed = _validate_anomaly_insight(parsed)
|
||||
parsed["status"] = "success"
|
||||
parsed["source"] = "llm"
|
||||
parsed["farm_uuid"] = farm_uuid
|
||||
parsed["raw_response"] = raw
|
||||
return parsed
|
||||
|
||||
@@ -14,6 +14,7 @@ from rag.chat import (
|
||||
build_rag_context,
|
||||
)
|
||||
from rag.config import RAGConfig, get_service_config, load_rag_config
|
||||
from rag.failure_contract import RAGServiceError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -38,18 +39,47 @@ def _clean_json(raw: str) -> dict[str, Any]:
|
||||
cleaned = cleaned[4:]
|
||||
cleaned = cleaned.strip()
|
||||
if not cleaned:
|
||||
return {}
|
||||
raise RAGServiceError(
|
||||
error_code="empty_response",
|
||||
message="Water need prediction LLM response was empty.",
|
||||
source="llm",
|
||||
retriable=True,
|
||||
details={"service_id": SERVICE_ID},
|
||||
http_status=502,
|
||||
)
|
||||
try:
|
||||
return json.loads(cleaned)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
parsed = json.loads(cleaned)
|
||||
except (json.JSONDecodeError, ValueError) as exc:
|
||||
logger.warning("Invalid JSON returned by water_need_prediction LLM: %s", cleaned[:500])
|
||||
return {}
|
||||
raise RAGServiceError(
|
||||
error_code="invalid_json",
|
||||
message="Water need prediction LLM response was not valid JSON.",
|
||||
source="llm",
|
||||
retriable=True,
|
||||
details={"service_id": SERVICE_ID},
|
||||
http_status=502,
|
||||
) from exc
|
||||
if not isinstance(parsed, dict):
|
||||
raise RAGServiceError(
|
||||
error_code="invalid_schema",
|
||||
message="Water need prediction LLM response root must be a JSON object.",
|
||||
source="llm",
|
||||
details={"service_id": SERVICE_ID},
|
||||
http_status=502,
|
||||
)
|
||||
return parsed
|
||||
|
||||
|
||||
def _load_farm_or_error(farm_uuid: str) -> dict[str, Any]:
|
||||
farm_details = get_farm_details(farm_uuid)
|
||||
if farm_details is None:
|
||||
raise ValueError("farm_uuid نامعتبر است یا اطلاعات مزرعه پیدا نشد.")
|
||||
raise RAGServiceError(
|
||||
error_code="farm_not_found",
|
||||
message="farm_uuid نامعتبر است یا اطلاعات مزرعه پیدا نشد.",
|
||||
source="farm_data",
|
||||
details={"farm_uuid": farm_uuid},
|
||||
http_status=404,
|
||||
)
|
||||
return farm_details
|
||||
|
||||
|
||||
@@ -78,9 +108,12 @@ def _validate_prediction_insight(parsed: dict[str, Any]) -> dict[str, Any]:
|
||||
}
|
||||
missing = [key for key in required_keys if key not in parsed]
|
||||
if missing:
|
||||
raise ValueError(
|
||||
"Water need prediction insight response is missing required fields: "
|
||||
+ ", ".join(missing)
|
||||
raise RAGServiceError(
|
||||
error_code="invalid_schema",
|
||||
message="Water need prediction insight response is missing required fields: " + ", ".join(missing),
|
||||
source="llm",
|
||||
details={"missing_fields": missing, "service_id": SERVICE_ID},
|
||||
http_status=502,
|
||||
)
|
||||
return parsed
|
||||
|
||||
@@ -154,12 +187,25 @@ def get_water_need_prediction_insight(
|
||||
raw = response.choices[0].message.content.strip()
|
||||
parsed = _clean_json(raw)
|
||||
_complete_audit_log(audit_log, raw)
|
||||
except RAGServiceError as exc:
|
||||
logger.error("Water need prediction insight failed for %s: %s", farm_uuid, exc)
|
||||
_fail_audit_log(audit_log, str(exc))
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error("Water need prediction insight failed for %s: %s", farm_uuid, exc)
|
||||
_fail_audit_log(audit_log, str(exc))
|
||||
raise RuntimeError(f"Water need prediction insight failed for farm {farm_uuid}.") from exc
|
||||
raise RAGServiceError(
|
||||
error_code="upstream_failure",
|
||||
message=f"Water need prediction insight failed for farm {farm_uuid}.",
|
||||
source="llm",
|
||||
retriable=True,
|
||||
details={"farm_uuid": farm_uuid, "service_id": SERVICE_ID},
|
||||
http_status=503,
|
||||
) from exc
|
||||
|
||||
parsed = _validate_prediction_insight(parsed)
|
||||
parsed["status"] = "success"
|
||||
parsed["source"] = "llm"
|
||||
parsed["farm_uuid"] = farm_uuid
|
||||
parsed["raw_response"] = raw
|
||||
return parsed
|
||||
|
||||
@@ -14,6 +14,7 @@ from rag.chat import (
|
||||
_load_service_tone,
|
||||
)
|
||||
from rag.config import RAGConfig, get_service_config, load_rag_config
|
||||
from rag.failure_contract import RAGServiceError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -90,6 +91,8 @@ class YieldHarvestRAGService:
|
||||
if audit_log is not None:
|
||||
_complete_audit_log(audit_log, raw)
|
||||
return {
|
||||
"status": "success",
|
||||
"source": "llm",
|
||||
"season_highlights_subtitle": validated.season_highlights_subtitle,
|
||||
"yield_prediction_explanation": validated.yield_prediction_explanation,
|
||||
"harvest_readiness_summary": validated.harvest_readiness_summary,
|
||||
@@ -99,12 +102,25 @@ class YieldHarvestRAGService:
|
||||
logger.warning("Yield harvest narrative parsing failed for farm_uuid=%s: %s", farm_uuid, exc)
|
||||
if audit_log is not None:
|
||||
_fail_audit_log(audit_log, str(exc))
|
||||
return {}
|
||||
raise RAGServiceError(
|
||||
error_code="invalid_payload",
|
||||
message=f"Yield harvest narrative parsing failed for farm_uuid={farm_uuid or 'unknown'}.",
|
||||
source="llm",
|
||||
details={"farm_uuid": farm_uuid or "unknown", "service_id": SERVICE_ID},
|
||||
http_status=502,
|
||||
) from exc
|
||||
except Exception as exc:
|
||||
logger.error("Yield harvest narrative LLM call failed for farm_uuid=%s: %s", farm_uuid, exc)
|
||||
if audit_log is not None:
|
||||
_fail_audit_log(audit_log, str(exc))
|
||||
return {}
|
||||
raise RAGServiceError(
|
||||
error_code="upstream_failure",
|
||||
message=f"Yield harvest narrative generation failed for farm_uuid={farm_uuid or 'unknown'}.",
|
||||
source="llm",
|
||||
retriable=True,
|
||||
details={"farm_uuid": farm_uuid or "unknown", "service_id": SERVICE_ID},
|
||||
http_status=503,
|
||||
) from exc
|
||||
|
||||
def _build_service_client(self, cfg: RAGConfig):
|
||||
service = get_service_config(SERVICE_ID, cfg)
|
||||
@@ -217,11 +233,31 @@ class YieldHarvestRAGService:
|
||||
cleaned = cleaned[4:]
|
||||
cleaned = cleaned.strip()
|
||||
if not cleaned:
|
||||
raise ValueError("Yield harvest narrative response was empty.")
|
||||
raise RAGServiceError(
|
||||
error_code="empty_response",
|
||||
message="Yield harvest narrative response was empty.",
|
||||
source="llm",
|
||||
retriable=True,
|
||||
details={"service_id": SERVICE_ID},
|
||||
http_status=502,
|
||||
)
|
||||
try:
|
||||
parsed = json.loads(cleaned)
|
||||
except (json.JSONDecodeError, ValueError) as exc:
|
||||
raise ValueError("Yield harvest narrative response was not valid JSON.") from exc
|
||||
raise RAGServiceError(
|
||||
error_code="invalid_json",
|
||||
message="Yield harvest narrative response was not valid JSON.",
|
||||
source="llm",
|
||||
retriable=True,
|
||||
details={"service_id": SERVICE_ID},
|
||||
http_status=502,
|
||||
) from exc
|
||||
if not isinstance(parsed, dict):
|
||||
raise ValueError("Yield harvest narrative response root must be a JSON object.")
|
||||
raise RAGServiceError(
|
||||
error_code="invalid_schema",
|
||||
message="Yield harvest narrative response root must be a JSON object.",
|
||||
source="llm",
|
||||
details={"service_id": SERVICE_ID},
|
||||
http_status=502,
|
||||
)
|
||||
return parsed
|
||||
|
||||
@@ -0,0 +1,88 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from django.test import SimpleTestCase
|
||||
|
||||
from rag.failure_contract import RAGServiceError
|
||||
from rag.services.pest_disease import get_pest_disease_detection
|
||||
from rag.services.soil_anomaly import get_soil_anomaly_insight
|
||||
from rag.services.water_need_prediction import get_water_need_prediction_insight
|
||||
from rag.services.yield_harvest import YieldHarvestRAGService
|
||||
|
||||
|
||||
class RAGFailureContractTests(SimpleTestCase):
|
||||
@patch("rag.services.soil_anomaly._create_audit_log", return_value=object())
|
||||
@patch("rag.services.soil_anomaly._fail_audit_log")
|
||||
@patch("rag.services.soil_anomaly._build_service_client")
|
||||
@patch("rag.services.soil_anomaly.build_rag_context", return_value="")
|
||||
@patch("rag.services.soil_anomaly._load_farm_or_error", return_value={"farm_uuid": "farm-1"})
|
||||
def test_soil_anomaly_invalid_json_raises_structured_error(
|
||||
self,
|
||||
_mock_load_farm,
|
||||
_mock_context,
|
||||
mock_build_client,
|
||||
_mock_fail,
|
||||
_mock_audit,
|
||||
):
|
||||
client = Mock()
|
||||
client.chat.completions.create.return_value = SimpleNamespace(
|
||||
choices=[SimpleNamespace(message=SimpleNamespace(content="not-json"))]
|
||||
)
|
||||
mock_build_client.return_value = (SimpleNamespace(system_prompt=""), client, "gpt-test")
|
||||
|
||||
with self.assertRaises(RAGServiceError) as exc_info:
|
||||
get_soil_anomaly_insight(farm_uuid="farm-1", anomaly_payload={})
|
||||
|
||||
self.assertEqual(exc_info.exception.contract.error_code, "invalid_json")
|
||||
|
||||
@patch("rag.services.water_need_prediction._create_audit_log", return_value=object())
|
||||
@patch("rag.services.water_need_prediction._fail_audit_log")
|
||||
@patch("rag.services.water_need_prediction._build_service_client")
|
||||
@patch("rag.services.water_need_prediction.build_rag_context", return_value="")
|
||||
@patch("rag.services.water_need_prediction._load_farm_or_error", return_value={"farm_uuid": "farm-1"})
|
||||
def test_water_need_invalid_json_raises_structured_error(
|
||||
self,
|
||||
_mock_load_farm,
|
||||
_mock_context,
|
||||
mock_build_client,
|
||||
_mock_fail,
|
||||
_mock_audit,
|
||||
):
|
||||
client = Mock()
|
||||
client.chat.completions.create.return_value = SimpleNamespace(
|
||||
choices=[SimpleNamespace(message=SimpleNamespace(content="not-json"))]
|
||||
)
|
||||
mock_build_client.return_value = (SimpleNamespace(system_prompt=""), client, "gpt-test")
|
||||
|
||||
with self.assertRaises(RAGServiceError) as exc_info:
|
||||
get_water_need_prediction_insight(farm_uuid="farm-1", prediction_payload={})
|
||||
|
||||
self.assertEqual(exc_info.exception.contract.error_code, "invalid_json")
|
||||
|
||||
def test_pest_detection_requires_image_with_structured_error(self):
|
||||
with self.assertRaises(RAGServiceError) as exc_info:
|
||||
get_pest_disease_detection(farm_uuid="farm-1", images=[])
|
||||
|
||||
self.assertEqual(exc_info.exception.contract.error_code, "missing_images")
|
||||
|
||||
@patch("rag.services.yield_harvest._create_audit_log", return_value=object())
|
||||
@patch("rag.services.yield_harvest._fail_audit_log")
|
||||
@patch("rag.services.yield_harvest.YieldHarvestRAGService._build_service_client")
|
||||
def test_yield_harvest_invalid_json_raises_structured_error(
|
||||
self,
|
||||
mock_build_client,
|
||||
_mock_fail,
|
||||
_mock_audit,
|
||||
):
|
||||
client = Mock()
|
||||
client.chat.completions.create.return_value = SimpleNamespace(
|
||||
choices=[SimpleNamespace(message=SimpleNamespace(content="not-json"))]
|
||||
)
|
||||
mock_build_client.return_value = (SimpleNamespace(system_prompt=""), client, "gpt-test")
|
||||
|
||||
with self.assertRaises(RAGServiceError) as exc_info:
|
||||
YieldHarvestRAGService().generate_narrative({"farm_uuid": "farm-1"})
|
||||
|
||||
self.assertEqual(exc_info.exception.contract.error_code, "invalid_json")
|
||||
@@ -0,0 +1,51 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from django.test import SimpleTestCase
|
||||
|
||||
from rag.embedding import embed_texts
|
||||
from rag.ingest import ingest
|
||||
from rag.observability import METRICS
|
||||
from rag.retrieve import search_with_query
|
||||
|
||||
|
||||
class RAGObservabilityTests(SimpleTestCase):
|
||||
def tearDown(self):
|
||||
METRICS.clear()
|
||||
|
||||
def test_embed_texts_records_empty_input_metric(self):
|
||||
result = embed_texts([])
|
||||
|
||||
self.assertEqual(result, [])
|
||||
self.assertEqual(METRICS["rag.embedding.empty_input|operation=embed_texts"], 1)
|
||||
|
||||
@patch("rag.retrieve.QdrantVectorStore")
|
||||
@patch("rag.retrieve.embed_single", return_value=[0.1, 0.2])
|
||||
@patch("rag.retrieve.load_rag_config")
|
||||
def test_search_with_query_records_empty_result_metric(self, mock_load_config, _mock_embed, mock_store_cls):
|
||||
mock_load_config.return_value = SimpleNamespace(
|
||||
embedding=SimpleNamespace(provider="gapgpt"),
|
||||
)
|
||||
mock_store = Mock()
|
||||
mock_store.search.return_value = []
|
||||
mock_store_cls.return_value = mock_store
|
||||
|
||||
result = search_with_query("query")
|
||||
|
||||
self.assertEqual(result, [])
|
||||
self.assertEqual(METRICS["rag.retrieve.empty_result|operation=search_with_query,service_id=None"], 1)
|
||||
|
||||
@patch("rag.ingest.load_sources", return_value=[])
|
||||
@patch("rag.ingest.QdrantVectorStore")
|
||||
@patch("rag.ingest.load_rag_config")
|
||||
def test_ingest_records_empty_sources_metric(self, mock_load_config, _mock_store_cls, _mock_sources):
|
||||
mock_load_config.return_value = SimpleNamespace(
|
||||
embedding=SimpleNamespace(provider="gapgpt"),
|
||||
)
|
||||
|
||||
result = ingest()
|
||||
|
||||
self.assertEqual(result["chunks_added"], 0)
|
||||
self.assertEqual(METRICS["rag.ingest.empty_sources|kb_name=None"], 1)
|
||||
@@ -4,10 +4,10 @@ from unittest.mock import Mock, patch
|
||||
|
||||
from django.test import TestCase
|
||||
|
||||
from farm_data.models import SensorData
|
||||
from farm_data.models import PlantCatalogSnapshot, SensorData
|
||||
from farm_data.services import assign_farm_plants_from_backend_ids
|
||||
from irrigation.models import IrrigationMethod
|
||||
from location_data.models import SoilLocation
|
||||
from plant.models import Plant
|
||||
from rag.services.fertilization import get_fertilization_recommendation
|
||||
from rag.services.irrigation import get_irrigation_recommendation
|
||||
from weather.models import WeatherForecast
|
||||
@@ -27,8 +27,8 @@ class RecommendationServiceDefaultsTests(TestCase):
|
||||
temperature_max=23.0,
|
||||
temperature_mean=18.0,
|
||||
)
|
||||
self.plant = Plant.objects.create(name="گوجهفرنگی")
|
||||
self.onion = Plant.objects.create(name="پیاز")
|
||||
self.plant = PlantCatalogSnapshot.objects.create(backend_plant_id=101, name="گوجهفرنگی")
|
||||
self.onion = PlantCatalogSnapshot.objects.create(backend_plant_id=102, name="پیاز")
|
||||
self.irrigation_method = IrrigationMethod.objects.create(name="آبیاری قطرهای")
|
||||
self.farm_uuid = uuid.uuid4()
|
||||
self.farm = SensorData.objects.create(
|
||||
@@ -45,7 +45,7 @@ class RecommendationServiceDefaultsTests(TestCase):
|
||||
}
|
||||
},
|
||||
)
|
||||
self.farm.plants.set([self.plant])
|
||||
assign_farm_plants_from_backend_ids(self.farm, [self.plant.backend_plant_id])
|
||||
|
||||
def build_irrigation_optimizer_result(self):
|
||||
return {
|
||||
@@ -162,6 +162,39 @@ class RecommendationServiceDefaultsTests(TestCase):
|
||||
self.assertEqual(result["sections"][1]["type"], "tip")
|
||||
self.assertEqual(result["water_balance"]["active_kc"], 0.9)
|
||||
|
||||
@patch("rag.services.irrigation.calculate_forecast_water_needs", return_value=[])
|
||||
@patch("rag.services.irrigation.resolve_kc", return_value=0.9)
|
||||
@patch("rag.services.irrigation.resolve_crop_profile", return_value={})
|
||||
@patch("rag.services.irrigation.build_irrigation_method_text", return_value="method text")
|
||||
@patch("rag.services.irrigation.build_plant_text", return_value="plant text")
|
||||
@patch("rag.services.irrigation.build_rag_context", return_value="")
|
||||
@patch("rag.services.irrigation._get_optimizer")
|
||||
@patch("rag.services.irrigation.get_chat_client")
|
||||
def test_irrigation_recommendation_reads_from_canonical_farm_data_assignments(
|
||||
self,
|
||||
mock_get_chat_client,
|
||||
mock_get_optimizer,
|
||||
_mock_build_rag_context,
|
||||
mock_build_plant_text,
|
||||
_mock_build_irrigation_method_text,
|
||||
_mock_resolve_crop_profile,
|
||||
_mock_resolve_kc,
|
||||
_mock_calculate_forecast_water_needs,
|
||||
):
|
||||
assign_farm_plants_from_backend_ids(self.farm, [self.onion.backend_plant_id, self.plant.backend_plant_id])
|
||||
mock_get_optimizer.return_value.optimize_irrigation.return_value = self.build_irrigation_optimizer_result()
|
||||
mock_response = Mock()
|
||||
mock_response.choices = [Mock(message=Mock(content=self.build_irrigation_llm_result()))]
|
||||
mock_get_chat_client.return_value.chat.completions.create.return_value = mock_response
|
||||
|
||||
result = get_irrigation_recommendation(
|
||||
farm_uuid=str(self.farm_uuid),
|
||||
growth_stage="میوهدهی",
|
||||
)
|
||||
|
||||
self.assertEqual(result["selected_plant"]["name"], "پیاز")
|
||||
mock_build_plant_text.assert_called_once_with("پیاز", "میوهدهی")
|
||||
|
||||
@patch("rag.services.irrigation.calculate_forecast_water_needs", return_value=[])
|
||||
@patch("rag.services.irrigation.resolve_kc", return_value=0.9)
|
||||
@patch("rag.services.irrigation.resolve_crop_profile", return_value={})
|
||||
@@ -299,6 +332,34 @@ class RecommendationServiceDefaultsTests(TestCase):
|
||||
mock_build_plant_text.assert_called_once_with("پیاز", "flowering")
|
||||
self.assertEqual(result["data"]["primary_recommendation"]["npk_ratio"]["label"], "20-20-20")
|
||||
|
||||
@patch("rag.services.fertilization.build_plant_text", return_value="plant text")
|
||||
@patch("rag.services.fertilization.build_rag_context", return_value="")
|
||||
@patch("rag.services.fertilization._get_optimizer")
|
||||
@patch("rag.services.fertilization.get_chat_client")
|
||||
def test_fertilization_recommendation_uses_canonical_assignment_lookup_for_requested_catalog_plant(
|
||||
self,
|
||||
mock_get_chat_client,
|
||||
mock_get_optimizer,
|
||||
_mock_build_rag_context,
|
||||
mock_build_plant_text,
|
||||
):
|
||||
assign_farm_plants_from_backend_ids(self.farm, [self.plant.backend_plant_id, self.onion.backend_plant_id])
|
||||
mock_get_optimizer.return_value.optimize_fertilization.return_value = self.build_fertilization_optimizer_result()
|
||||
mock_response = Mock()
|
||||
mock_response.choices = [Mock(message=Mock(content="not-json"))]
|
||||
mock_get_chat_client.return_value.chat.completions.create.return_value = mock_response
|
||||
|
||||
result = get_fertilization_recommendation(
|
||||
farm_uuid=str(self.farm_uuid),
|
||||
plant_name="پیاز",
|
||||
growth_stage="گلدهی",
|
||||
)
|
||||
|
||||
optimizer_call = mock_get_optimizer.return_value.optimize_fertilization.call_args.kwargs
|
||||
self.assertEqual(getattr(optimizer_call["plant"], "name", None), "پیاز")
|
||||
mock_build_plant_text.assert_called_once_with("پیاز", "flowering")
|
||||
self.assertEqual(result["data"]["primary_recommendation"]["npk_ratio"]["label"], "20-20-20")
|
||||
|
||||
@patch("rag.services.fertilization.build_plant_text", return_value="plant text")
|
||||
@patch("rag.services.fertilization.build_rag_context", return_value="")
|
||||
@patch("rag.services.fertilization._get_optimizer")
|
||||
|
||||
Reference in New Issue
Block a user