diff --git a/account/backends.py b/account/backends.py index 4ec204f..c9fca01 100644 --- a/account/backends.py +++ b/account/backends.py @@ -6,10 +6,7 @@ User = get_user_model() class MultiFieldBackend(ModelBackend): - """ - Authenticate against username, email, or phone_number. - Used for password-based login where the user can enter any of the three. - """ + """Authenticate with username, email, or phone_number.""" def authenticate(self, request, username=None, password=None, **kwargs): if username is None or password is None: @@ -19,12 +16,10 @@ class MultiFieldBackend(ModelBackend): user = User.objects.get( Q(username=username) | Q(email=username) | Q(phone_number=username) ) - print(user) except (User.DoesNotExist, User.MultipleObjectsReturned): User().set_password(password) return None - print(user.check_password(password) , self.user_can_authenticate(user)) - + if user.check_password(password) and self.user_can_authenticate(user): return user return None diff --git a/auth/views.py b/auth/views.py index 97cb4ef..fee89b2 100644 --- a/auth/views.py +++ b/auth/views.py @@ -1,17 +1,16 @@ import secrets -from django.contrib.auth import authenticate from django.conf import settings +from django.contrib.auth import authenticate from django.core.cache import cache from django.core.signing import BadSignature, SignatureExpired, TimestampSigner from django.db import IntegrityError -from rest_framework import serializers -from rest_framework import status +from rest_framework import serializers, status from rest_framework.permissions import AllowAny from rest_framework.response import Response from rest_framework.views import APIView from drf_spectacular.utils import extend_schema, extend_schema_view -from rest_framework_simplejwt.tokens import RefreshToken +from rest_framework_simplejwt.tokens import AccessToken from account.models import User from config.swagger import code_response @@ -30,7 +29,6 @@ OTP_SIGNER = TimestampSigner(salt="auth.otp") def _auth_user_to_data(user): - """Build AuthUser-shaped dict from Django User.""" if user is None or not getattr(user, "pk", None): return None return { @@ -43,6 +41,10 @@ def _auth_user_to_data(user): } +def _issue_token(user): + return str(AccessToken.for_user(user)) + + @extend_schema_view( post=extend_schema( tags=["Authentication"], @@ -54,13 +56,6 @@ def _auth_user_to_data(user): ), ) class RegisterView(APIView): - """ - POST /api/auth/register/ - Creates a new user with username, email, phone_number, and password. - All fields are required (first_name, last_name optional). - Returns JWT tokens and user data on success. - """ - permission_classes = [AllowAny] def post(self, request): @@ -87,23 +82,14 @@ class RegisterView(APIView): detail = "A user with this phone number already exists." else: detail = "A user with these credentials already exists." - return Response( - {"code": 400, "msg": detail}, - status=status.HTTP_400_BAD_REQUEST, - ) - - refresh = RefreshToken.for_user(user) - user_data = _auth_user_to_data(user) + return Response({"code": 400, "msg": detail}, status=status.HTTP_400_BAD_REQUEST) return Response( { "code": 201, "msg": "success", - "data": user_data, - "token": { - "access": str(refresh.access_token), - "refresh": str(refresh), - }, + "data": _auth_user_to_data(user), + "token": _issue_token(user), }, status=status.HTTP_201_CREATED, ) @@ -120,12 +106,6 @@ class RegisterView(APIView): ), ) class LoginView(APIView): - """ - POST /api/auth/login/ - Accepts identifier (username, email, or phone_number) + password. - Returns JWT tokens and user data on success. - """ - permission_classes = [AllowAny] def post(self, request): @@ -134,27 +114,17 @@ class LoginView(APIView): identifier = serializer.validated_data["identifier"] password = serializer.validated_data["password"] - user = authenticate(request, username=identifier, password=password) if user is None: - return Response( - {"code": 401, "msg": "Invalid credentials."}, - status=status.HTTP_401_UNAUTHORIZED, - ) - - refresh = RefreshToken.for_user(user) - user_data = _auth_user_to_data(user) + return Response({"code": 401, "msg": "Invalid credentials."}, status=status.HTTP_401_UNAUTHORIZED) return Response( { "code": 200, "msg": "success", - "data": user_data, - "token": { - "access": str(refresh.access_token), - "refresh": str(refresh), - }, + "data": _auth_user_to_data(user), + "token": _issue_token(user), }, status=status.HTTP_200_OK, ) @@ -178,12 +148,6 @@ class LoginView(APIView): ), ) class AuthenticationView(APIView): - """ - Single view for auth flows: request-otp and verify-otp. - Dispatches by path: .../request-otp/ -> request_otp, .../verify-otp/ -> verify_otp. - Response format: RequestOTPResponse / VerifyOTPResponse (code, msg, token, data when applicable). - """ - permission_classes = [AllowAny] def post(self, request): @@ -197,10 +161,8 @@ class AuthenticationView(APIView): phone_number = serializer.validated_data["phone_number"].strip() otp_code = f"{secrets.randbelow(1_000_000):06d}" - cache.set(f"otp_code:{phone_number}", otp_code, timeout=OTP_TTL_SECONDS) otp_token = OTP_SIGNER.sign(phone_number) - sms_sent = send_otp_sms(phone_number, otp_code) payload = {"code": 200, "msg": "success", "token": otp_token} @@ -208,7 +170,6 @@ class AuthenticationView(APIView): payload["sms_warning"] = "SMS delivery failed; OTP stored server-side." if settings.DEBUG: payload["debug_otp"] = otp_code - return Response(payload, status=status.HTTP_200_OK) def _verify_otp(self, request): @@ -219,24 +180,15 @@ class AuthenticationView(APIView): otp_code = serializer.validated_data["otp_code"].strip() try: - phone_number = OTP_SIGNER.unsign( - token, max_age=OTP_TTL_SECONDS - ) + phone_number = OTP_SIGNER.unsign(token, max_age=OTP_TTL_SECONDS) except (BadSignature, SignatureExpired): - return Response( - {"code": 400, "msg": "Token is invalid or expired."}, - status=status.HTTP_400_BAD_REQUEST, - ) + return Response({"code": 400, "msg": "Token is invalid or expired."}, status=status.HTTP_400_BAD_REQUEST) cached_otp = cache.get(f"otp_code:{phone_number}") if cached_otp is None or cached_otp != otp_code: - return Response( - {"code": 400, "msg": "OTP code is invalid or expired."}, - status=status.HTTP_400_BAD_REQUEST, - ) + return Response({"code": 400, "msg": "OTP code is invalid or expired."}, status=status.HTTP_400_BAD_REQUEST) cache.delete(f"otp_code:{phone_number}") - user, created = User.objects.get_or_create( phone_number=phone_number, defaults={ @@ -245,19 +197,12 @@ class AuthenticationView(APIView): }, ) - refresh = RefreshToken.for_user(user) - - user_data = _auth_user_to_data(user) - return Response( { "code": 200, "msg": "success", - "data": user_data, - "token": { - "access": str(refresh.access_token), - "refresh": str(refresh), - }, + "data": _auth_user_to_data(user), + "token": _issue_token(user), }, status=status.HTTP_200_OK, ) diff --git a/config/settings.py b/config/settings.py index 5656519..621940e 100644 --- a/config/settings.py +++ b/config/settings.py @@ -1,4 +1,5 @@ import os +from datetime import timedelta from pathlib import Path from dotenv import load_dotenv @@ -150,3 +151,11 @@ EXTERNAL_SERVICES = { "api_key": os.getenv("SENSOR_HUB_SERVICE_API_KEY", ""), }, } + + +SIMPLE_JWT = { + "ACCESS_TOKEN_LIFETIME": timedelta(days=7), + "REFRESH_TOKEN_LIFETIME": timedelta(days=7), + "ROTATE_REFRESH_TOKENS": False, + "BLACKLIST_AFTER_ROTATION": False, +} diff --git a/config/swagger.py b/config/swagger.py index 2f79a12..73b6fe6 100644 --- a/config/swagger.py +++ b/config/swagger.py @@ -3,9 +3,8 @@ from rest_framework import serializers from drf_spectacular.utils import inline_serializer -class TokenPairSerializer(serializers.Serializer): - access = serializers.CharField() - refresh = serializers.CharField() +class AuthTokenSerializer(serializers.Serializer): + token = serializers.CharField() def code_response(name, data=None, token=False, extra_fields=None): @@ -16,7 +15,7 @@ def code_response(name, data=None, token=False, extra_fields=None): if data is not None: fields["data"] = data if token: - fields["token"] = TokenPairSerializer() + fields["token"] = serializers.CharField() if extra_fields: fields.update(extra_fields) return inline_serializer(name=name, fields=fields) diff --git a/external_api_adapter/adapter.py b/external_api_adapter/adapter.py index 6fade40..31aac0c 100644 --- a/external_api_adapter/adapter.py +++ b/external_api_adapter/adapter.py @@ -25,7 +25,7 @@ class ExternalAPIAdapter: request_method = method.upper() self._validate_method(request_method) service = self.service_registry.get(service_name) - + if getattr(settings, "USE_EXTERNAL_API_MOCK", False): mock_response = self.mock_loader.load(service_name=service_name, path=path, method=request_method) return AdapterResponse( diff --git a/farm_ai_assistant/serializers.py b/farm_ai_assistant/serializers.py index 44b701d..00be2b5 100644 --- a/farm_ai_assistant/serializers.py +++ b/farm_ai_assistant/serializers.py @@ -5,29 +5,15 @@ from .models import Conversation, Message class ConversationListSerializer(serializers.ModelSerializer): conversation_id = serializers.UUIDField(source="uuid", read_only=True) - last_message_preview = serializers.SerializerMethodField() - message_count = serializers.IntegerField(read_only=True) class Meta: model = Conversation fields = [ "conversation_id", "title", - "farm_context", - "message_count", - "last_message_preview", - "created_at", "updated_at", ] - def get_last_message_preview(self, obj): - last_message = getattr(obj, "last_message", None) - if last_message is None: - last_message = obj.messages.order_by("-created_at", "-id").first() - if last_message is None: - return "" - return (last_message.content or "")[:120] - class MessageSerializer(serializers.ModelSerializer): message_id = serializers.UUIDField(source="uuid", read_only=True) diff --git a/farm_ai_assistant/views.py b/farm_ai_assistant/views.py index 73d141c..661d71f 100644 --- a/farm_ai_assistant/views.py +++ b/farm_ai_assistant/views.py @@ -1,6 +1,5 @@ """Farm AI Assistant API views.""" -from django.db.models import Count, OuterRef, Subquery from django.http import Http404, HttpResponse from rest_framework import serializers, status from rest_framework.permissions import IsAuthenticated @@ -36,15 +35,7 @@ class ChatListView(APIView): responses={200: status_response("FarmAiAssistantConversationListResponse", data=ConversationListSerializer(many=True))}, ) def get(self, request): - last_message_subquery = Message.objects.filter(conversation=OuterRef("pk")).order_by("-created_at", "-id") - conversations = ( - Conversation.objects.filter(owner=request.user) - .annotate( - message_count=Count("messages"), - last_message=Subquery(last_message_subquery.values("content")[:1]), - ) - .order_by("-updated_at", "-created_at") - ) + conversations = Conversation.objects.filter(owner=request.user).order_by("-updated_at", "-created_at") serializer = ConversationListSerializer(conversations, many=True) return Response({"status": "success", "data": serializer.data}, status=status.HTTP_200_OK) @@ -139,11 +130,29 @@ class ChatView(APIView): ) if isinstance(adapter_response.data, dict) and "body" in adapter_response.data: - conversation.save(update_fields=["updated_at"]) - return HttpResponse( - adapter_response.data["body"], + assistant_content = adapter_response.data.get("body") or "" + assistant_message = Message.objects.create( + conversation=conversation, + role=Message.ROLE_ASSISTANT, + content=assistant_content, + raw_response=adapter_response.data, + ) + + if not conversation.title: + conversation.title = (validated.get("content", "") or assistant_content or "New chat")[:255] + conversation.save(update_fields=["title", "updated_at"]) + else: + conversation.save(update_fields=["updated_at"]) + + return Response( + { + "conversation_id": str(conversation.uuid), + "user_message_id": str(user_message.uuid), + "assistant_message_id": str(assistant_message.uuid), + "content": assistant_content, + "content_type": adapter_response.data.get("content_type", "text/plain; charset=utf-8"), + }, status=adapter_response.status_code, - content_type=adapter_response.data.get("content_type", "text/plain; charset=utf-8"), ) assistant_content = "" @@ -165,23 +174,22 @@ class ChatView(APIView): else: conversation.save(update_fields=["updated_at"]) + conversation_uuid = str(conversation.uuid) response_data = adapter_response.data if isinstance(response_data, dict): + response_data.setdefault("conversation_id", conversation_uuid) + data = response_data.get("data") if isinstance(data, dict): - data.setdefault("conversation_id", str(conversation.uuid)) + data.setdefault("conversation_id", conversation_uuid) data.setdefault("user_message_id", str(user_message.uuid)) data.setdefault("assistant_message_id", str(assistant_message.uuid)) else: - response_data = { - "conversation_id": str(conversation.uuid), - "user_message_id": str(user_message.uuid), - "assistant_message_id": str(assistant_message.uuid), - "response": response_data, - } + response_data.setdefault("user_message_id", str(user_message.uuid)) + response_data.setdefault("assistant_message_id", str(assistant_message.uuid)) else: response_data = { - "conversation_id": str(conversation.uuid), + "conversation_id": conversation_uuid, "user_message_id": str(user_message.uuid), "assistant_message_id": str(assistant_message.uuid), "response": response_data,