from datetime import datetime, timedelta, timezone
from typing import Optional

from beanie import Document
from bson import ObjectId
from pydantic import Field

def _utcnow() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Produces the same value as the deprecated ``datetime.utcnow()``
    (deprecated since Python 3.12) but derives it from an aware clock. The
    tzinfo is stripped so values keep the naive-UTC form this model writes.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _as_naive_utc(value: datetime) -> datetime:
    """Normalize *value* to a naive UTC ``datetime``.

    Naive inputs are returned unchanged (they are assumed to already be UTC,
    matching how this model writes timestamps). Aware inputs are converted
    to UTC and their tzinfo is dropped.
    """
    if value.utcoffset() is None:
        return value
    return value.astimezone(timezone.utc).replace(tzinfo=None)


class RefreshToken(Document):
    """A persisted refresh token for a user, optionally linked to an auth session.

    Timestamps written by this class are naive UTC datetimes.
    """

    user_id: ObjectId
    # NOTE(review): the token is stored exactly as passed in; if callers pass
    # the raw secret, consider persisting a hash instead — confirm.
    token: str
    session_id: Optional[ObjectId] = None  # Link to auth session

    # Token management (naive UTC timestamps)
    created_at: datetime = Field(default_factory=_utcnow)
    expires_at: datetime
    used_at: Optional[datetime] = None  # set by use_token()

    # Security
    is_active: bool = True
    revoked_at: Optional[datetime] = None  # set by revoke()
    revocation_reason: Optional[str] = None

    # Device info
    ip_address: Optional[str] = None
    user_agent: Optional[str] = None

    class Settings:
        name = "refresh_tokens"
        # NOTE(review): `token` is indexed but not declared unique; if each
        # token value must identify a single document, a unique index may be
        # intended — confirm.
        indexes = [
            "user_id",
            "token",
            "session_id",
            "expires_at",
            "is_active",
            "created_at",
        ]

    class Config:
        # Needed so pydantic accepts the bson ObjectId field type.
        arbitrary_types_allowed = True

    @classmethod
    async def create_token(
        cls,
        user_id: ObjectId,
        token: str,
        session_id: Optional[ObjectId] = None,
        ip_address: Optional[str] = None,
        user_agent: Optional[str] = None,
        expires_in_days: int = 7
    ) -> "RefreshToken":
        """Create, persist and return a new refresh token.

        Args:
            user_id: Owner of the token.
            token: Token value to store.
            session_id: Optional auth session this token belongs to.
            ip_address: Optional client IP recorded with the token.
            user_agent: Optional client user agent recorded with the token.
            expires_in_days: Lifetime in days, counted from now (UTC).

        Returns:
            The saved ``RefreshToken`` document.
        """
        expires_at = _utcnow() + timedelta(days=expires_in_days)

        refresh_token = cls(
            user_id=user_id,
            token=token,
            session_id=session_id,
            ip_address=ip_address,
            user_agent=user_agent,
            expires_at=expires_at,
        )

        await refresh_token.save()
        return refresh_token

    def is_expired(self) -> bool:
        """Return True when the current UTC time is past ``expires_at``.

        ``expires_at`` is normalized to naive UTC first. Without that, a
        timezone-aware value (e.g. set by a caller or loaded through a
        tz-aware client) would make the comparison raise ``TypeError``.
        """
        return _utcnow() > _as_naive_utc(self.expires_at)

    def is_valid(self) -> bool:
        """Return True if the token is active, unexpired, unrevoked and unused."""
        return (
            self.is_active and
            not self.is_expired() and
            self.revoked_at is None and
            self.used_at is None
        )

    async def use_token(self) -> None:
        """Record the current UTC time in ``used_at`` and save the document.

        This does not check ``is_valid()`` itself; callers are expected to.
        """
        # NOTE(review): read-then-save is not atomic — two concurrent callers
        # can both observe used_at is None and both succeed. If single use
        # must be enforced, consider a conditional update filtered on
        # used_at being None.
        self.used_at = _utcnow()
        await self.save()

    async def revoke(self, reason: str = "manual_revocation") -> None:
        """Deactivate the token, stamp ``revoked_at`` and record *reason*, then save."""
        self.is_active = False
        self.revoked_at = _utcnow()
        self.revocation_reason = reason
        await self.save()

    @classmethod
    async def revoke_all_user_tokens(
        cls, user_id: ObjectId, reason: str = "logout_all"
    ) -> int:
        """Revoke every active refresh token belonging to *user_id*.

        Each token goes through ``revoke()`` so per-document save behaviour
        is preserved.

        Args:
            user_id: Owner whose tokens are revoked.
            reason: Stored as ``revocation_reason`` on each token.

        Returns:
            The number of tokens revoked.
        """
        # `== True` builds a Beanie query expression; `is True` would not.
        tokens = await cls.find(
            cls.user_id == user_id,
            cls.is_active == True,  # noqa: E712
        ).to_list()

        for token in tokens:
            await token.revoke(reason)
        return len(tokens)
