diff --git a/rag/chat.py b/rag/chat.py index b56644f..4c8762d 100644 --- a/rag/chat.py +++ b/rag/chat.py @@ -1,12 +1,14 @@ """ -چت RAG برای API چت عمومی — استفاده مستقیم از داده مزرعه بدون retrieval/embedding. +چت RAG برای API چت عمومی — با ارسال کامل داده مزرعه و retrieval تکمیلی از KB. """ import json import logging from pathlib import Path from .api_provider import get_chat_client +from .chunker import chunk_text from .config import RAGConfig, ServiceConfig, get_service_config, load_rag_config +from .retrieve import search_with_texts logger = logging.getLogger(__name__) @@ -61,6 +63,17 @@ def _format_farm_context_from_details(farm_details: dict) -> str: return "[اطلاعات کامل مزرعه]\n" + serialized +def _load_farm_details_context( + sensor_uuid: str | None, + farm_details: dict | None = None, +) -> str: + if not sensor_uuid: + return "" + if farm_details is not None: + return _format_farm_context_from_details(farm_details) + return _format_farm_context(sensor_uuid) + + def _build_system_prompt( service: ServiceConfig, query: str, @@ -72,8 +85,10 @@ def _build_system_prompt( if service.system_prompt: system_parts.append(service.system_prompt) system_parts.append( - "با استفاده از اطلاعات کامل مزرعه که در ادامه آمده به سوال کاربر پاسخ بده. " + "با استفاده از اطلاعات کامل مزرعه و اطلاعات بازیابی‌شده از پایگاه دانش که در ادامه آمده " + "به سوال کاربر پاسخ بده. " "اگر داده‌ای در اطلاعات مزرعه وجود دارد، همان را مبنای پاسخ قرار بده و چیزی حدس نزن. " + "نتایج بازیابی‌شده از پایگاه دانش را برای تکمیل یا توضیح پاسخ استفاده کن. " "اگر داده کافی نبود، این کمبود را شفاف بگو. " "پاسخ را به زبان کاربر بنویس." ) @@ -141,13 +156,15 @@ def build_rag_context( limit: int = 8, kb_name: str | None = None, service_id: str | None = None, + farm_details: dict | None = None, ) -> str: """ - ساخت context برای سرویس‌های توصیه با استفاده از RAG قدیمی. - این تابع برای سازگاری با irrigation/fertilization حفظ شده است. + ساخت context مشترک برای همه سرویس‌های RAG. 
+ شامل: + - اطلاعات کامل مزرعه از farm_data/services.py + - جستجوی KB بر اساس پیام کاربر + - جستجوی KB بر اساس chunk های کامل داده مزرعه """ - from .retrieve import search_with_query - from .user_data import build_user_soil_text, build_user_weather_text logger.info( "Building RAG context sensor_uuid=%s kb_name=%s limit=%s query_len=%s", @@ -161,20 +178,23 @@ def build_rag_context( service = get_service_config(service_id, cfg) if service_id else None include_user_embeddings = service.use_user_embeddings if service else True resolved_kb_name = kb_name or (service.knowledge_base if service else None) + farm_context = _load_farm_details_context( + sensor_uuid=sensor_uuid, + farm_details=farm_details, + ) - if include_user_embeddings and sensor_uuid: - user_soil = build_user_soil_text(sensor_uuid) - if user_soil and user_soil.strip(): - parts.append("[داده‌های فعلی خاک شما]\n" + user_soil.strip()) + if farm_context: + parts.append(farm_context) - weather_text = build_user_weather_text(sensor_uuid) - if weather_text and weather_text.strip(): - parts.append("[پیش‌بینی هواشناسی]\n" + weather_text.strip()) + search_texts = [query] + if farm_context: + search_texts.extend(chunk_text(farm_context, config=cfg)) - results = search_with_query( - query, + results = search_with_texts( + search_texts, sensor_uuid=sensor_uuid, limit=limit, + per_text_limit=3, config=cfg, kb_name=resolved_kb_name, service_id=service_id, @@ -230,20 +250,23 @@ def chat_rag_stream( len(query or ""), ) - if farm_details is None: - farm_context = _format_farm_context(farm_uuid) - else: - farm_context = _format_farm_context_from_details(farm_details) + context = build_rag_context( + query=query, + sensor_uuid=farm_uuid, + config=cfg, + service_id=service_id, + farm_details=farm_details, + ) logger.info( - "Loaded farm context for farm_uuid=%s context_len=%s", + "Loaded augmented context for farm_uuid=%s context_len=%s", farm_uuid, - len(farm_context), + len(context), ) if system_override is not None: 
system_prompt = system_override else: - system_prompt = _build_system_prompt(service, query, farm_context, cfg) + system_prompt = _build_system_prompt(service, query, context, cfg) messages = [ {"role": "system", "content": system_prompt}, diff --git a/rag/retrieve.py b/rag/retrieve.py index 4eb20a2..6372fbe 100644 --- a/rag/retrieve.py +++ b/rag/retrieve.py @@ -2,10 +2,37 @@ بازیابی RAG: embed کوئری و جستجو در vector store """ from .config import load_rag_config, RAGConfig, get_service_config -from .embedding import embed_single +from .embedding import embed_single, embed_texts from .vector_store import QdrantVectorStore +def _resolve_search_options( + sensor_uuid: str | None = None, + config: RAGConfig | None = None, + kb_name: str | None = None, + service_id: str | None = None, + use_user_embeddings: bool | None = None, +) -> tuple[RAGConfig, list[str], list[str]]: + cfg = config or load_rag_config() + service = get_service_config(service_id, cfg) if service_id else None + resolved_kb_name = kb_name or (service.knowledge_base if service else None) + include_user_embeddings = ( + use_user_embeddings + if use_user_embeddings is not None + else (service.use_user_embeddings if service else True) + ) + + sensor_filters = ["__global__"] + if include_user_embeddings and sensor_uuid: + sensor_filters.insert(0, sensor_uuid) + + kb_filters = [resolved_kb_name] if resolved_kb_name else [] + if include_user_embeddings: + kb_filters.append("__all__") + + return cfg, sensor_filters, kb_filters + + def search_with_query( query: str, sensor_uuid: str | None = None, @@ -28,23 +55,14 @@ def search_with_query( Returns: لیست نتایج با id, score, text, metadata """ - cfg = config or load_rag_config() - service = get_service_config(service_id, cfg) if service_id else None - resolved_kb_name = kb_name or (service.knowledge_base if service else None) - include_user_embeddings = ( - use_user_embeddings - if use_user_embeddings is not None - else (service.use_user_embeddings if service 
else True) + cfg, sensor_filters, kb_filters = _resolve_search_options( + sensor_uuid=sensor_uuid, + config=config, + kb_name=kb_name, + service_id=service_id, + use_user_embeddings=use_user_embeddings, ) - sensor_filters = ["__global__"] - if include_user_embeddings and sensor_uuid: - sensor_filters.insert(0, sensor_uuid) - - kb_filters = [resolved_kb_name] if resolved_kb_name else [] - if include_user_embeddings: - kb_filters.append("__all__") - query_vector = embed_single(query, config=cfg) store = QdrantVectorStore(config=cfg) return store.search( @@ -54,3 +72,54 @@ def search_with_query( sensor_uuids=sensor_filters, kb_names=kb_filters, ) + + +def search_with_texts( + texts: list[str], + sensor_uuid: str | None = None, + limit: int = 8, + per_text_limit: int = 3, + score_threshold: float | None = None, + config: RAGConfig | None = None, + kb_name: str | None = None, + service_id: str | None = None, + use_user_embeddings: bool | None = None, +) -> list[dict]: + """ + چند متن را embed می‌کند و نتیجه جستجوها را به صورت dedupe شده برمی‌گرداند. + برای حالتی مناسب است که هم پیام کاربر و هم داده‌های مزرعه را علیه KB جستجو کنیم. 
+ """ + normalized_texts = [text.strip() for text in texts if text and text.strip()] + if not normalized_texts: + return [] + + cfg, sensor_filters, kb_filters = _resolve_search_options( + sensor_uuid=sensor_uuid, + config=config, + kb_name=kb_name, + service_id=service_id, + use_user_embeddings=use_user_embeddings, + ) + + store = QdrantVectorStore(config=cfg) + vectors = embed_texts(normalized_texts, config=cfg) + merged_results: dict[str, dict] = {} + + for vector in vectors: + results = store.search( + query_vector=vector, + limit=per_text_limit, + score_threshold=score_threshold, + sensor_uuids=sensor_filters, + kb_names=kb_filters, + ) + for item in results: + current = merged_results.get(item["id"]) + if current is None or item["score"] > current["score"]: + merged_results[item["id"]] = item + + return sorted( + merged_results.values(), + key=lambda item: item["score"], + reverse=True, + )[:limit] diff --git a/rag/services/irrigation.py b/rag/services/irrigation.py index 46f6456..4b66e58 100644 --- a/rag/services/irrigation.py +++ b/rag/services/irrigation.py @@ -5,6 +5,8 @@ import json import logging +from django.db import transaction +from irrigation.models import IrrigationMethod from irrigation.evapotranspiration import calculate_forecast_water_needs, resolve_crop_profile, resolve_kc from farm_data.models import SensorData from rag.api_provider import get_chat_client @@ -42,6 +44,31 @@ DEFAULT_IRRIGATION_PROMPT = ( ) +def _resolve_irrigation_method( + sensor: SensorData | None, + irrigation_method_name: str | None, +) -> IrrigationMethod | None: + if irrigation_method_name: + return IrrigationMethod.objects.filter(name=irrigation_method_name).first() + if sensor is not None: + return sensor.irrigation_method + return None + + +def _persist_irrigation_method_on_farm( + sensor: SensorData | None, + irrigation_method: IrrigationMethod | None, +) -> None: + if sensor is None or irrigation_method is None: + return + if sensor.irrigation_method_id == 
irrigation_method.id: + return + + with transaction.atomic(): + sensor.irrigation_method = irrigation_method + sensor.save(update_fields=["irrigation_method", "updated_at"]) + + def get_irrigation_recommendation( sensor_uuid: str, plant_name: str | None = None, @@ -89,6 +116,9 @@ def get_irrigation_recommendation( .filter(farm_uuid=sensor_uuid) .first() ) + irrigation_method = _resolve_irrigation_method(sensor, irrigation_method_name) + _persist_irrigation_method_on_farm(sensor, irrigation_method) + plant = None resolved_plant_name = plant_name if sensor is not None and plant_name: @@ -106,19 +136,11 @@ def get_irrigation_recommendation( WeatherForecast.objects.filter(location=sensor.center_location, forecast_date__isnull=False) .order_by("forecast_date")[:7] ) - efficiency_percent = None - resolved_irrigation_method_name = irrigation_method_name - method = None - if irrigation_method_name: - from irrigation.models import IrrigationMethod - - method = IrrigationMethod.objects.filter(name=irrigation_method_name).first() - elif sensor is not None: - method = sensor.irrigation_method - if method is not None: - resolved_irrigation_method_name = method.name - - efficiency_percent = getattr(method, "water_efficiency_percent", None) if method else None + efficiency_percent = ( + getattr(irrigation_method, "water_efficiency_percent", None) + if irrigation_method + else None + ) daily_water_needs = calculate_forecast_water_needs( forecasts=forecasts, latitude_deg=float(sensor.center_location.latitude), @@ -132,8 +154,8 @@ def get_irrigation_recommendation( ) extra_parts: list[str] = [] - resolved_irrigation_method_name = irrigation_method_name or ( - sensor.irrigation_method.name if sensor is not None and sensor.irrigation_method else None + resolved_irrigation_method_name = ( + irrigation_method.name if irrigation_method is not None else None ) if resolved_plant_name and growth_stage: plant_text = build_plant_text(resolved_plant_name, growth_stage) @@ -222,6 +244,16 @@ def 
get_irrigation_recommendation( "crop_profile": crop_profile, "active_kc": active_kc, } + result["selected_irrigation_method"] = ( + { + "id": irrigation_method.id, + "name": irrigation_method.name, + "category": irrigation_method.category, + "water_efficiency_percent": irrigation_method.water_efficiency_percent, + } + if irrigation_method is not None + else None + ) _complete_audit_log( audit_log, json.dumps(result, ensure_ascii=False, default=str), diff --git a/rag/tests/test_chat_context.py b/rag/tests/test_chat_context.py index f2209d0..fd9d14e 100644 --- a/rag/tests/test_chat_context.py +++ b/rag/tests/test_chat_context.py @@ -2,51 +2,54 @@ from unittest.mock import patch from django.test import SimpleTestCase -from rag.chat import build_chat_context +from rag.chat import build_rag_context class ChatContextTests(SimpleTestCase): - @patch("rag.chat.search_with_query") - @patch("rag.chat._rank_text_chunks_by_query") + @patch("rag.chat.search_with_texts") @patch("rag.chat.chunk_text") - def test_build_chat_context_combines_farm_and_kb_context( + def test_build_rag_context_includes_full_farm_and_kb_results( self, mock_chunk_text, - mock_rank_text_chunks_by_query, - mock_search_with_query, + mock_search_with_texts, ): - mock_chunk_text.return_value = ["chunk-a", "chunk-b"] - mock_rank_text_chunks_by_query.return_value = ["chunk-b"] - mock_search_with_query.return_value = [ - {"text": "kb text 1"}, - {"text": "kb text 2"}, + mock_chunk_text.return_value = ["farm chunk 1", "farm chunk 2"] + mock_search_with_texts.return_value = [ + {"id": "kb-1", "score": 0.8, "text": "kb text 1", "metadata": {}}, + {"id": "kb-2", "score": 0.7, "text": "kb text 2", "metadata": {}}, ] - context = build_chat_context( + context = build_rag_context( query="وضعیت مزرعه چطور است؟", - farm_uuid="farm-123", + sensor_uuid="farm-123", + service_id="chat", farm_details={"sensor_payload": {"sensor-7-1": {"soil_moisture": 30}}}, ) - self.assertIn("[بخش‌های مرتبط بازیابی‌شده از اطلاعات مزرعه]", 
context) - self.assertIn("chunk-b", context) - self.assertIn("[اطلاعات بازیابی‌شده از پایگاه دانش]", context) + self.assertIn("[اطلاعات کامل مزرعه]", context) + self.assertIn("soil_moisture", context) + self.assertIn("[متن‌های مرجع]", context) self.assertIn("kb text 1", context) self.assertIn("kb text 2", context) + mock_search_with_texts.assert_called_once() + sent_texts = mock_search_with_texts.call_args.args[0] + self.assertEqual(sent_texts[0], "وضعیت مزرعه چطور است؟") + self.assertIn("farm chunk 1", sent_texts) + self.assertIn("farm chunk 2", sent_texts) - @patch("rag.chat.search_with_query", return_value=[]) - @patch("rag.chat._rank_text_chunks_by_query", return_value=[]) + @patch("rag.chat.search_with_texts", return_value=[]) @patch("rag.chat.chunk_text", return_value=["farm chunk"]) - def test_build_chat_context_falls_back_to_full_farm_context( + def test_build_rag_context_returns_full_farm_when_kb_empty( self, _mock_chunk_text, - _mock_rank_text_chunks_by_query, - _mock_search_with_query, + _mock_search_with_texts, ): - context = build_chat_context( + context = build_rag_context( query="رطوبت چقدر است؟", - farm_uuid="farm-123", + sensor_uuid="farm-123", + service_id="chat", farm_details={"sensor_payload": {"sensor-7-1": {"soil_moisture": 30}}}, ) - self.assertEqual(context, "") + self.assertIn("[اطلاعات کامل مزرعه]", context) + self.assertIn("soil_moisture", context) diff --git a/rag/tests/test_recommendation_services.py b/rag/tests/test_recommendation_services.py index 8f8a143..001261b 100644 --- a/rag/tests/test_recommendation_services.py +++ b/rag/tests/test_recommendation_services.py @@ -68,6 +68,46 @@ class RecommendationServiceDefaultsTests(TestCase): mock_build_rag_context.assert_called_once() mock_build_plant_text.assert_called_once_with("گوجه‌فرنگی", "میوه‌دهی") mock_build_irrigation_method_text.assert_called_once_with("آبیاری قطره‌ای") + self.assertEqual( + result["selected_irrigation_method"]["name"], + "آبیاری قطره‌ای", + ) + + 
@patch("rag.services.irrigation.calculate_forecast_water_needs", return_value=[]) + @patch("rag.services.irrigation.resolve_kc", return_value=0.9) + @patch("rag.services.irrigation.resolve_crop_profile", return_value={}) + @patch("rag.services.irrigation.build_irrigation_method_text", return_value="method text") + @patch("rag.services.irrigation.build_plant_text", return_value="plant text") + @patch("rag.services.irrigation.build_rag_context", return_value="") + @patch("rag.services.irrigation.get_chat_client") + def test_irrigation_recommendation_persists_selected_method_on_farm( + self, + mock_get_chat_client, + _mock_build_rag_context, + _mock_build_plant_text, + mock_build_irrigation_method_text, + _mock_resolve_crop_profile, + _mock_resolve_kc, + _mock_calculate_forecast_water_needs, + ): + sprinkler = IrrigationMethod.objects.create(name="بارانی") + self.farm.irrigation_method = None + self.farm.save(update_fields=["irrigation_method", "updated_at"]) + + mock_response = Mock() + mock_response.choices = [Mock(message=Mock(content='{"plan": {"frequencyPerWeek": 4}}'))] + mock_get_chat_client.return_value.chat.completions.create.return_value = mock_response + + result = get_irrigation_recommendation( + sensor_uuid=str(self.farm_uuid), + growth_stage="میوه‌دهی", + irrigation_method_name="بارانی", + ) + + self.farm.refresh_from_db() + self.assertEqual(self.farm.irrigation_method_id, sprinkler.id) + self.assertEqual(result["selected_irrigation_method"]["id"], sprinkler.id) + mock_build_irrigation_method_text.assert_called_once_with("بارانی") @patch("rag.services.fertilization.build_plant_text", return_value="plant text") @patch("rag.services.fertilization.build_rag_context", return_value="") diff --git a/rag/urls.py b/rag/urls.py index 3b64bcb..63d75d3 100644 --- a/rag/urls.py +++ b/rag/urls.py @@ -8,6 +8,6 @@ from .views import ( urlpatterns = [ path("chat/", ChatView.as_view()), - path("recommend/irrigation/", IrrigationRecommendationView.as_view(), 
name="recommend-irrigation"), - path("recommend/fertilization/", FertilizationRecommendationView.as_view(), name="recommend-fertilization"), + # path("recommend/irrigation/", IrrigationRecommendationView.as_view(), name="recommend-irrigation"), + # path("recommend/fertilization/", FertilizationRecommendationView.as_view(), name="recommend-fertilization"), ]