UPDATE
This commit is contained in:
@@ -0,0 +1,34 @@
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("farm_hub", "0002_seed_default_catalog"),
|
||||
("farm_ai_assistant", "0001_initial"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="conversation",
|
||||
name="farm",
|
||||
field=models.ForeignKey(
|
||||
blank=True,
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
related_name="ai_conversations",
|
||||
to="farm_hub.farmhub",
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="message",
|
||||
name="farm",
|
||||
field=models.ForeignKey(
|
||||
blank=True,
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
related_name="ai_messages",
|
||||
to="farm_hub.farmhub",
|
||||
),
|
||||
),
|
||||
]
|
||||
@@ -3,6 +3,8 @@ import uuid
|
||||
from django.conf import settings
|
||||
from django.db import models
|
||||
|
||||
from farm_hub.models import FarmHub
|
||||
|
||||
|
||||
class Conversation(models.Model):
|
||||
uuid = models.UUIDField(default=uuid.uuid4, unique=True, editable=False, db_index=True)
|
||||
@@ -11,6 +13,13 @@ class Conversation(models.Model):
|
||||
on_delete=models.CASCADE,
|
||||
related_name="farm_ai_conversations",
|
||||
)
|
||||
farm = models.ForeignKey(
|
||||
FarmHub,
|
||||
on_delete=models.CASCADE,
|
||||
related_name="ai_conversations",
|
||||
null=True,
|
||||
blank=True,
|
||||
)
|
||||
title = models.CharField(max_length=255, blank=True, default="")
|
||||
farm_context = models.JSONField(default=dict, blank=True)
|
||||
created_at = models.DateTimeField(auto_now_add=True)
|
||||
@@ -38,6 +47,13 @@ class Message(models.Model):
|
||||
on_delete=models.CASCADE,
|
||||
related_name="messages",
|
||||
)
|
||||
farm = models.ForeignKey(
|
||||
FarmHub,
|
||||
on_delete=models.CASCADE,
|
||||
related_name="ai_messages",
|
||||
null=True,
|
||||
blank=True,
|
||||
)
|
||||
role = models.CharField(max_length=32, choices=ROLE_CHOICES)
|
||||
content = models.TextField(blank=True, default="")
|
||||
images = models.JSONField(default=list, blank=True)
|
||||
|
||||
@@ -17,10 +17,12 @@ class ChatSectionSerializer(serializers.Serializer):
|
||||
|
||||
class ConversationSummarySerializer(serializers.Serializer):
|
||||
id = serializers.UUIDField(source="uuid", read_only=True)
|
||||
farm_uuid = serializers.UUIDField(source="farm.farm_uuid", read_only=True)
|
||||
message_count = serializers.IntegerField(read_only=True)
|
||||
|
||||
|
||||
class ConversationCreateSerializer(serializers.Serializer):
|
||||
farm_uuid = serializers.UUIDField(required=True)
|
||||
title = serializers.CharField(required=False, allow_blank=True, max_length=255)
|
||||
farm_context = serializers.JSONField(required=False)
|
||||
|
||||
@@ -28,6 +30,7 @@ class ConversationCreateSerializer(serializers.Serializer):
|
||||
class ChatHistoryMessageSerializer(serializers.Serializer):
|
||||
message_id = serializers.UUIDField(read_only=True)
|
||||
conversation_id = serializers.UUIDField(read_only=True)
|
||||
farm_uuid = serializers.UUIDField(read_only=True)
|
||||
role = serializers.ChoiceField(choices=Message.ROLE_CHOICES, read_only=True)
|
||||
content = serializers.CharField(read_only=True, allow_blank=True)
|
||||
sections = ChatSectionSerializer(many=True, read_only=True)
|
||||
@@ -37,18 +40,21 @@ class ChatHistoryMessageSerializer(serializers.Serializer):
|
||||
|
||||
class ConversationMessagesSerializer(serializers.Serializer):
|
||||
conversation_id = serializers.UUIDField(read_only=True)
|
||||
farm_uuid = serializers.UUIDField(read_only=True)
|
||||
messages = ChatHistoryMessageSerializer(many=True, read_only=True)
|
||||
|
||||
|
||||
class ChatResponseDataSerializer(serializers.Serializer):
|
||||
message_id = serializers.UUIDField(read_only=True)
|
||||
conversation_id = serializers.UUIDField(read_only=True)
|
||||
farm_uuid = serializers.UUIDField(read_only=True)
|
||||
content = serializers.CharField(read_only=True, allow_blank=True)
|
||||
sections = ChatSectionSerializer(many=True, read_only=True)
|
||||
|
||||
|
||||
class ConversationDeleteSerializer(serializers.Serializer):
|
||||
conversation_id = serializers.UUIDField(read_only=True)
|
||||
farm_uuid = serializers.UUIDField(read_only=True)
|
||||
|
||||
|
||||
class ChatTaskSubmitDataSerializer(serializers.Serializer):
|
||||
@@ -57,18 +63,21 @@ class ChatTaskSubmitDataSerializer(serializers.Serializer):
|
||||
status_url = serializers.CharField(required=False, allow_blank=True)
|
||||
conversation_id = serializers.UUIDField(read_only=True)
|
||||
message_id = serializers.UUIDField(read_only=True)
|
||||
farm_uuid = serializers.UUIDField(read_only=True)
|
||||
|
||||
|
||||
class ChatTaskStatusDataSerializer(serializers.Serializer):
|
||||
task_id = serializers.CharField(required=False, allow_blank=True)
|
||||
status = serializers.CharField(required=False, allow_blank=True)
|
||||
conversation_id = serializers.UUIDField(read_only=True)
|
||||
farm_uuid = serializers.UUIDField(read_only=True)
|
||||
progress = serializers.JSONField(required=False)
|
||||
result = serializers.JSONField(required=False)
|
||||
error = serializers.CharField(required=False, allow_blank=True)
|
||||
|
||||
|
||||
class ChatPostSerializer(serializers.Serializer):
|
||||
farm_uuid = serializers.UUIDField(required=True)
|
||||
content = serializers.CharField(required=False, allow_blank=True, default="")
|
||||
images = serializers.ListField(
|
||||
child=serializers.CharField(),
|
||||
|
||||
@@ -2,6 +2,8 @@ from django.contrib.auth import get_user_model
|
||||
from django.test import TestCase, override_settings
|
||||
from rest_framework.test import APIRequestFactory, force_authenticate
|
||||
|
||||
from farm_hub.models import FarmHub, FarmType
|
||||
|
||||
from .models import Conversation, Message
|
||||
from .views import ChatTaskStatusView
|
||||
|
||||
@@ -16,24 +18,35 @@ class ChatTaskStatusViewTests(TestCase):
|
||||
email="farmer@example.com",
|
||||
phone_number="09120000000",
|
||||
)
|
||||
self.farm_type, _ = FarmType.objects.get_or_create(name="زراعی")
|
||||
self.farm = FarmHub.objects.create(
|
||||
owner=self.user,
|
||||
farm_type=self.farm_type,
|
||||
name="Farm 1",
|
||||
)
|
||||
self.conversation = Conversation.objects.create(
|
||||
owner=self.user,
|
||||
farm=self.farm,
|
||||
title="Irrigation chat",
|
||||
farm_context={},
|
||||
)
|
||||
self.user_message = Message.objects.create(
|
||||
conversation=self.conversation,
|
||||
farm=self.farm,
|
||||
role=Message.ROLE_USER,
|
||||
content="What is the best irrigation plan?",
|
||||
raw_response={
|
||||
"task_id": "farm-ai-chat-task-123",
|
||||
"status": "PENDING",
|
||||
"status_url": "/api/tasks/farm-ai-chat-task-123/status/",
|
||||
"farm_uuid": str(self.farm.farm_uuid),
|
||||
},
|
||||
)
|
||||
|
||||
def test_status_success_uses_chat_mock_result_and_persists_assistant_message(self):
|
||||
request = self.factory.get("/api/farm-ai-assistant/chat/task/farm-ai-chat-task-123/status/")
|
||||
request = self.factory.get(
|
||||
f"/api/farm-ai-assistant/chat/task/farm-ai-chat-task-123/status/?farm_uuid={self.farm.farm_uuid}"
|
||||
)
|
||||
force_authenticate(request, user=self.user)
|
||||
|
||||
response = ChatTaskStatusView.as_view()(request, task_id="farm-ai-chat-task-123")
|
||||
@@ -43,6 +56,7 @@ class ChatTaskStatusViewTests(TestCase):
|
||||
self.assertEqual(response.data["data"]["task_id"], "farm-ai-chat-task-123")
|
||||
self.assertEqual(response.data["data"]["status"], "SUCCESS")
|
||||
self.assertEqual(response.data["data"]["conversation_id"], str(self.conversation.uuid))
|
||||
self.assertEqual(response.data["data"]["farm_uuid"], str(self.farm.farm_uuid))
|
||||
self.assertEqual(response.data["data"]["result"]["content"], "Here is the recommended plan.")
|
||||
self.assertEqual(len(response.data["data"]["result"]["sections"]), 3)
|
||||
self.assertEqual(response.data["data"]["result"]["task_id"], "farm-ai-chat-task-123")
|
||||
@@ -53,6 +67,8 @@ class ChatTaskStatusViewTests(TestCase):
|
||||
.first()
|
||||
)
|
||||
self.assertIsNotNone(assistant_message)
|
||||
self.assertEqual(assistant_message.farm_id, self.farm.id)
|
||||
self.assertEqual(assistant_message.content, "Here is the recommended plan.")
|
||||
self.assertEqual(assistant_message.raw_response["task_id"], "farm-ai-chat-task-123")
|
||||
self.assertEqual(assistant_message.raw_response["farm_uuid"], str(self.farm.farm_uuid))
|
||||
self.assertEqual(len(assistant_message.raw_response["sections"]), 3)
|
||||
|
||||
+86
-30
@@ -14,6 +14,7 @@ from drf_spectacular.utils import OpenApiParameter, extend_schema
|
||||
from config.swagger import status_response
|
||||
from external_api_adapter import request as external_api_request
|
||||
from external_api_adapter.exceptions import ExternalAPIRequestError
|
||||
from farm_hub.models import FarmHub
|
||||
from .mock_data import CHAT_RESPONSE_DATA, CONTEXT_RESPONSE_DATA
|
||||
from .models import Conversation, Message
|
||||
from .serializers import (
|
||||
@@ -28,23 +29,45 @@ from .serializers import (
|
||||
)
|
||||
|
||||
|
||||
class ContextView(APIView):
|
||||
class FarmAccessMixin:
|
||||
@staticmethod
|
||||
def _get_farm(request, farm_uuid):
|
||||
if not farm_uuid:
|
||||
raise serializers.ValidationError({"farm_uuid": ["This field is required."]})
|
||||
try:
|
||||
return FarmHub.objects.get(farm_uuid=farm_uuid, owner=request.user)
|
||||
except FarmHub.DoesNotExist as exc:
|
||||
raise Http404("Farm not found") from exc
|
||||
|
||||
|
||||
class ContextView(FarmAccessMixin, APIView):
|
||||
permission_classes = [IsAuthenticated]
|
||||
|
||||
@extend_schema(
|
||||
tags=["Farm AI Assistant"],
|
||||
parameters=[
|
||||
OpenApiParameter(name="farm_uuid", type=OpenApiTypes.UUID, location=OpenApiParameter.QUERY, required=True),
|
||||
],
|
||||
responses={200: status_response("FarmAiAssistantContextResponse", data=serializers.JSONField())},
|
||||
)
|
||||
def get(self, request):
|
||||
farm = self._get_farm(request, request.query_params.get("farm_uuid"))
|
||||
data = deepcopy(CONTEXT_RESPONSE_DATA)
|
||||
data["farm_uuid"] = str(farm.farm_uuid)
|
||||
return Response(
|
||||
{"status": "success", "data": CONTEXT_RESPONSE_DATA},
|
||||
{"status": "success", "data": data},
|
||||
status=status.HTTP_200_OK,
|
||||
)
|
||||
|
||||
|
||||
class ConversationAccessMixin:
|
||||
class ConversationAccessMixin(FarmAccessMixin):
|
||||
@staticmethod
|
||||
def _get_conversation(request, conversation_id):
|
||||
def _get_conversation(request, conversation_id, farm_uuid=None):
|
||||
filters = {"uuid": conversation_id, "owner": request.user}
|
||||
if farm_uuid:
|
||||
filters["farm__farm_uuid"] = farm_uuid
|
||||
try:
|
||||
return Conversation.objects.get(uuid=conversation_id, owner=request.user)
|
||||
return Conversation.objects.select_related("farm").get(**filters)
|
||||
except Conversation.DoesNotExist as exc:
|
||||
raise Http404("Conversation not found") from exc
|
||||
|
||||
@@ -84,18 +107,20 @@ class ConversationAccessMixin:
|
||||
normalized_sections.append(normalized_section)
|
||||
return normalized_sections
|
||||
|
||||
def _build_mock_assistant_payload(self, conversation_id):
|
||||
def _build_mock_assistant_payload(self, conversation):
|
||||
payload = deepcopy(CHAT_RESPONSE_DATA)
|
||||
payload["conversation_id"] = str(conversation_id)
|
||||
payload["conversation_id"] = str(conversation.uuid)
|
||||
payload["farm_uuid"] = str(conversation.farm.farm_uuid)
|
||||
return payload
|
||||
|
||||
def _get_or_create_conversation(self, request, validated):
|
||||
conversation_id = validated.get("conversation_id")
|
||||
farm_context = validated.get("farm_context")
|
||||
title = validated.get("title", "").strip()
|
||||
farm = self._get_farm(request, validated.get("farm_uuid"))
|
||||
|
||||
if conversation_id:
|
||||
conversation = self._get_conversation(request, conversation_id)
|
||||
conversation = self._get_conversation(request, conversation_id, farm.farm_uuid)
|
||||
updated_fields = []
|
||||
if farm_context is not None:
|
||||
conversation.farm_context = farm_context
|
||||
@@ -110,6 +135,7 @@ class ConversationAccessMixin:
|
||||
|
||||
return Conversation.objects.create(
|
||||
owner=request.user,
|
||||
farm=farm,
|
||||
title=title or (validated.get("content", "")[:255]) or "New chat",
|
||||
farm_context=farm_context or {},
|
||||
)
|
||||
@@ -117,6 +143,7 @@ class ConversationAccessMixin:
|
||||
@staticmethod
|
||||
def _build_adapter_payload(request, validated, conversation):
|
||||
payload = {
|
||||
"farm_uuid": str(conversation.farm.farm_uuid),
|
||||
"content": validated.get("content", ""),
|
||||
"query": validated.get("content", ""),
|
||||
"images": validated.get("images", []),
|
||||
@@ -129,7 +156,7 @@ class ConversationAccessMixin:
|
||||
payload["title"] = validated.get("title", "")
|
||||
return payload
|
||||
|
||||
def _extract_assistant_payload(self, adapter_data, conversation_id):
|
||||
def _extract_assistant_payload(self, adapter_data, conversation):
|
||||
payload_source = adapter_data
|
||||
if isinstance(adapter_data, dict) and isinstance(adapter_data.get("data"), dict):
|
||||
payload_source = adapter_data["data"]
|
||||
@@ -149,13 +176,14 @@ class ConversationAccessMixin:
|
||||
|
||||
return {
|
||||
"message_id": "",
|
||||
"conversation_id": str(conversation_id),
|
||||
"conversation_id": str(conversation.uuid),
|
||||
"farm_uuid": str(conversation.farm.farm_uuid),
|
||||
"content": content,
|
||||
"sections": sections,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _extract_task_submit_payload(adapter_data, conversation_id, message_id):
|
||||
def _extract_task_submit_payload(adapter_data, conversation, message_id):
|
||||
payload_source = adapter_data
|
||||
if isinstance(adapter_data, dict) and isinstance(adapter_data.get("data"), dict):
|
||||
payload_source = adapter_data["data"]
|
||||
@@ -167,11 +195,12 @@ class ConversationAccessMixin:
|
||||
"task_id": str(payload_source.get("task_id") or ""),
|
||||
"status": str(payload_source.get("status") or ""),
|
||||
"status_url": str(payload_source.get("status_url") or ""),
|
||||
"conversation_id": str(conversation_id),
|
||||
"conversation_id": str(conversation.uuid),
|
||||
"message_id": str(message_id),
|
||||
"farm_uuid": str(conversation.farm.farm_uuid),
|
||||
}
|
||||
|
||||
def _extract_task_status_payload(self, adapter_data, task_id, conversation_id=None):
|
||||
def _extract_task_status_payload(self, adapter_data, task_id, conversation=None, farm_uuid=None):
|
||||
payload_source = adapter_data
|
||||
if isinstance(adapter_data, dict) and isinstance(adapter_data.get("data"), dict):
|
||||
payload_source = adapter_data["data"]
|
||||
@@ -183,8 +212,11 @@ class ConversationAccessMixin:
|
||||
"task_id": str(payload_source.get("task_id") or task_id),
|
||||
"status": str(payload_source.get("status") or ""),
|
||||
}
|
||||
if conversation_id:
|
||||
task_status_payload["conversation_id"] = str(conversation_id)
|
||||
if conversation:
|
||||
task_status_payload["conversation_id"] = str(conversation.uuid)
|
||||
task_status_payload["farm_uuid"] = str(conversation.farm.farm_uuid)
|
||||
elif farm_uuid:
|
||||
task_status_payload["farm_uuid"] = str(farm_uuid)
|
||||
|
||||
progress = payload_source.get("progress")
|
||||
if progress is not None:
|
||||
@@ -231,6 +263,7 @@ class ConversationAccessMixin:
|
||||
return {
|
||||
"message_id": str(message.uuid),
|
||||
"conversation_id": str(message.conversation.uuid),
|
||||
"farm_uuid": str(message.farm.farm_uuid),
|
||||
"role": message.role,
|
||||
"content": message.content,
|
||||
"sections": ConversationAccessMixin._normalize_sections(sections),
|
||||
@@ -239,11 +272,12 @@ class ConversationAccessMixin:
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _find_user_message_for_task(request, task_id):
|
||||
def _find_user_message_for_task(request, task_id, farm_uuid):
|
||||
return (
|
||||
Message.objects.select_related("conversation")
|
||||
Message.objects.select_related("conversation", "farm")
|
||||
.filter(
|
||||
conversation__owner=request.user,
|
||||
farm__farm_uuid=farm_uuid,
|
||||
role=Message.ROLE_USER,
|
||||
raw_response__task_id=task_id,
|
||||
)
|
||||
@@ -252,7 +286,7 @@ class ConversationAccessMixin:
|
||||
)
|
||||
|
||||
def _persist_task_result(self, user_message, task_id, result):
|
||||
assistant_payload = self._extract_assistant_payload(result, user_message.conversation.uuid)
|
||||
assistant_payload = self._extract_assistant_payload(result, user_message.conversation)
|
||||
assistant_message = (
|
||||
user_message.conversation.messages.filter(
|
||||
role=Message.ROLE_ASSISTANT,
|
||||
@@ -265,6 +299,7 @@ class ConversationAccessMixin:
|
||||
if assistant_message is None:
|
||||
assistant_message = Message.objects.create(
|
||||
conversation=user_message.conversation,
|
||||
farm=user_message.farm,
|
||||
role=Message.ROLE_ASSISTANT,
|
||||
content=assistant_payload.get("content", ""),
|
||||
raw_response={},
|
||||
@@ -293,11 +328,15 @@ class ChatListCreateView(ConversationAccessMixin, APIView):
|
||||
|
||||
@extend_schema(
|
||||
tags=["Farm AI Assistant"],
|
||||
parameters=[
|
||||
OpenApiParameter(name="farm_uuid", type=OpenApiTypes.UUID, location=OpenApiParameter.QUERY, required=True),
|
||||
],
|
||||
responses={200: status_response("FarmAiAssistantConversationListResponse", data=ConversationSummarySerializer(many=True))},
|
||||
)
|
||||
def get(self, request):
|
||||
farm = self._get_farm(request, request.query_params.get("farm_uuid"))
|
||||
conversations = (
|
||||
Conversation.objects.filter(owner=request.user)
|
||||
Conversation.objects.filter(owner=request.user, farm=farm)
|
||||
.annotate(message_count=Count("messages"))
|
||||
.order_by("-updated_at", "-created_at")
|
||||
)
|
||||
@@ -314,8 +353,10 @@ class ChatListCreateView(ConversationAccessMixin, APIView):
|
||||
serializer.is_valid(raise_exception=True)
|
||||
|
||||
validated = serializer.validated_data
|
||||
farm = self._get_farm(request, validated.get("farm_uuid"))
|
||||
conversation = Conversation.objects.create(
|
||||
owner=request.user,
|
||||
farm=farm,
|
||||
title=validated.get("title", "").strip() or "New chat",
|
||||
farm_context=validated.get("farm_context") or {},
|
||||
)
|
||||
@@ -323,6 +364,7 @@ class ChatListCreateView(ConversationAccessMixin, APIView):
|
||||
response_serializer = ConversationSummarySerializer(
|
||||
{
|
||||
"uuid": conversation.uuid,
|
||||
"farm": farm,
|
||||
"message_count": 0,
|
||||
}
|
||||
)
|
||||
@@ -336,18 +378,21 @@ class ChatMessagesView(ConversationAccessMixin, APIView):
|
||||
tags=["Farm AI Assistant"],
|
||||
parameters=[
|
||||
OpenApiParameter(name="conversation_id", type=OpenApiTypes.UUID, location=OpenApiParameter.PATH),
|
||||
OpenApiParameter(name="farm_uuid", type=OpenApiTypes.UUID, location=OpenApiParameter.QUERY, required=True),
|
||||
],
|
||||
responses={200: status_response("FarmAiAssistantMessageListResponse", data=ConversationMessagesSerializer())},
|
||||
)
|
||||
def get(self, request, conversation_id):
|
||||
conversation = self._get_conversation(request, conversation_id)
|
||||
messages = conversation.messages.all()
|
||||
farm = self._get_farm(request, request.query_params.get("farm_uuid"))
|
||||
conversation = self._get_conversation(request, conversation_id, farm.farm_uuid)
|
||||
messages = conversation.messages.select_related("farm").all()
|
||||
serialized_messages = [self._serialize_chat_message(message) for message in messages]
|
||||
return Response(
|
||||
{
|
||||
"status": "success",
|
||||
"data": {
|
||||
"conversation_id": str(conversation.uuid),
|
||||
"farm_uuid": str(farm.farm_uuid),
|
||||
"messages": serialized_messages,
|
||||
},
|
||||
},
|
||||
@@ -362,18 +407,22 @@ class ChatDetailView(ConversationAccessMixin, APIView):
|
||||
tags=["Farm AI Assistant"],
|
||||
parameters=[
|
||||
OpenApiParameter(name="conversation_id", type=OpenApiTypes.UUID, location=OpenApiParameter.PATH),
|
||||
OpenApiParameter(name="farm_uuid", type=OpenApiTypes.UUID, location=OpenApiParameter.QUERY, required=True),
|
||||
],
|
||||
responses={200: status_response("FarmAiAssistantConversationDeleteResponse", data=ConversationDeleteSerializer())},
|
||||
)
|
||||
def delete(self, request, conversation_id):
|
||||
conversation = self._get_conversation(request, conversation_id)
|
||||
farm = self._get_farm(request, request.query_params.get("farm_uuid"))
|
||||
conversation = self._get_conversation(request, conversation_id, farm.farm_uuid)
|
||||
deleted_conversation_id = str(conversation.uuid)
|
||||
deleted_farm_uuid = str(conversation.farm.farm_uuid)
|
||||
conversation.delete()
|
||||
return Response(
|
||||
{
|
||||
"status": "success",
|
||||
"data": {
|
||||
"conversation_id": deleted_conversation_id,
|
||||
"farm_uuid": deleted_farm_uuid,
|
||||
},
|
||||
},
|
||||
status=status.HTTP_200_OK,
|
||||
@@ -397,10 +446,11 @@ class ChatView(ConversationAccessMixin, APIView):
|
||||
|
||||
user_message = Message.objects.create(
|
||||
conversation=conversation,
|
||||
farm=conversation.farm,
|
||||
role=Message.ROLE_USER,
|
||||
content=validated.get("content", ""),
|
||||
images=validated.get("images", []),
|
||||
raw_response={},
|
||||
raw_response={"farm_uuid": str(conversation.farm.farm_uuid)},
|
||||
)
|
||||
|
||||
adapter_payload = self._build_adapter_payload(request, validated, conversation)
|
||||
@@ -420,14 +470,15 @@ class ChatView(ConversationAccessMixin, APIView):
|
||||
},
|
||||
status=adapter_response.status_code,
|
||||
)
|
||||
assistant_payload = self._extract_assistant_payload(adapter_response.data, conversation.uuid)
|
||||
assistant_payload = self._extract_assistant_payload(adapter_response.data, conversation)
|
||||
response_status_code = adapter_response.status_code
|
||||
except ExternalAPIRequestError:
|
||||
assistant_payload = self._build_mock_assistant_payload(conversation.uuid)
|
||||
assistant_payload = self._build_mock_assistant_payload(conversation)
|
||||
response_status_code = status.HTTP_200_OK
|
||||
|
||||
assistant_message = Message.objects.create(
|
||||
conversation=conversation,
|
||||
farm=conversation.farm,
|
||||
role=Message.ROLE_ASSISTANT,
|
||||
content=assistant_payload.get("content", ""),
|
||||
raw_response={},
|
||||
@@ -467,10 +518,11 @@ class ChatTaskCreateView(ConversationAccessMixin, APIView):
|
||||
conversation = self._get_or_create_conversation(request, validated)
|
||||
user_message = Message.objects.create(
|
||||
conversation=conversation,
|
||||
farm=conversation.farm,
|
||||
role=Message.ROLE_USER,
|
||||
content=validated.get("content", ""),
|
||||
images=validated.get("images", []),
|
||||
raw_response={},
|
||||
raw_response={"farm_uuid": str(conversation.farm.farm_uuid)},
|
||||
)
|
||||
|
||||
adapter_payload = self._build_adapter_payload(request, validated, conversation)
|
||||
@@ -503,7 +555,7 @@ class ChatTaskCreateView(ConversationAccessMixin, APIView):
|
||||
|
||||
task_payload = self._extract_task_submit_payload(
|
||||
adapter_response.data,
|
||||
conversation.uuid,
|
||||
conversation,
|
||||
user_message.uuid,
|
||||
)
|
||||
user_message.raw_response = task_payload
|
||||
@@ -526,15 +578,18 @@ class ChatTaskStatusView(ConversationAccessMixin, APIView):
|
||||
tags=["Farm AI Assistant"],
|
||||
parameters=[
|
||||
OpenApiParameter(name="task_id", type=OpenApiTypes.STR, location=OpenApiParameter.PATH),
|
||||
OpenApiParameter(name="farm_uuid", type=OpenApiTypes.UUID, location=OpenApiParameter.QUERY, required=True),
|
||||
],
|
||||
responses={200: status_response("FarmAiAssistantChatTaskStatusResponse", data=ChatTaskStatusDataSerializer())},
|
||||
)
|
||||
def get(self, request, task_id):
|
||||
farm = self._get_farm(request, request.query_params.get("farm_uuid"))
|
||||
try:
|
||||
adapter_response = external_api_request(
|
||||
"ai",
|
||||
f"/tasks/{task_id}/status",
|
||||
method="GET",
|
||||
query={"farm_uuid": str(farm.farm_uuid)},
|
||||
)
|
||||
except ExternalAPIRequestError:
|
||||
return Response(
|
||||
@@ -556,12 +611,13 @@ class ChatTaskStatusView(ConversationAccessMixin, APIView):
|
||||
status=adapter_response.status_code,
|
||||
)
|
||||
|
||||
user_message = self._find_user_message_for_task(request, task_id)
|
||||
conversation_id = user_message.conversation.uuid if user_message else None
|
||||
user_message = self._find_user_message_for_task(request, task_id, farm.farm_uuid)
|
||||
conversation = user_message.conversation if user_message else None
|
||||
task_status_payload = self._extract_task_status_payload(
|
||||
adapter_response.data,
|
||||
task_id,
|
||||
conversation_id=conversation_id,
|
||||
conversation=conversation,
|
||||
farm_uuid=farm.farm_uuid,
|
||||
)
|
||||
|
||||
result = self._extract_structured_task_result(adapter_response.data)
|
||||
|
||||
Reference in New Issue
Block a user