import redis.asyncio as redis
from decouple import config
import json
from typing import Optional, Any, Dict
import logging

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)

class RedisClient:
    """Thin async wrapper around a ``redis.asyncio`` client.

    The connection is opened lazily: every data operation connects on first
    use if :meth:`connect` has not been called yet. :meth:`connect` and
    :meth:`disconnect` propagate errors; all other operations log the failure
    and return a falsy sentinel (``False`` or ``None``) instead of raising.
    """

    def __init__(self):
        # URL comes from the REDIS_URL setting, falling back to a local server.
        self.redis_url = config('REDIS_URL', default='redis://localhost:6379')
        # Stays None until connect() succeeds; reset to None by disconnect().
        self.redis: Optional[redis.Redis] = None

    async def connect(self):
        """Create a client for ``self.redis_url`` and verify it with a PING.

        Returns:
            The connected client, which is also stored on ``self.redis``.

        Raises:
            Exception: whatever client creation or the PING raised; the error
                is logged and re-raised unchanged.
        """
        try:
            client = redis.from_url(self.redis_url, decode_responses=True)
            await client.ping()
        except Exception as e:
            logger.error("Failed to connect to Redis: %s", e)
            raise
        # Publish the client only after the PING succeeded, so a failed
        # attempt never leaves the wrapper looking connected.
        self.redis = client
        logger.info("Connected to Redis")
        return client

    async def disconnect(self):
        """Close the current client, if any, and forget it.

        After this call the next operation transparently reconnects.
        """
        client = self.redis
        if client is None:
            return
        # Drop the reference first so a failing close cannot leave a
        # half-closed client behind for later calls to reuse.
        self.redis = None
        # Newer redis-py releases expose aclose() and deprecate close();
        # fall back to close() where aclose() is not available.
        closer = getattr(client, "aclose", None) or client.close
        await closer()
        logger.info("Disconnected from Redis")

    async def _client(self):
        """Return the active client, connecting first if necessary."""
        if self.redis is None:
            await self.connect()
        return self.redis

    @staticmethod
    def _encode(value: Any) -> Any:
        """Serialize dicts and lists to JSON text; pass other values through."""
        if isinstance(value, (dict, list)):
            return json.dumps(value)
        return value

    async def set(self, key: str, value: Any, expire: Optional[int] = None) -> bool:
        """Store ``value`` under ``key``.

        Args:
            key: Redis key to write.
            value: Value to store; dicts and lists are JSON-encoded first.
            expire: Optional time-to-live in seconds (passed as ``ex``).

        Returns:
            ``True`` if Redis acknowledged the write, ``False`` otherwise
            (including when an error was logged).
        """
        try:
            client = await self._client()
            result = await client.set(key, self._encode(value), ex=expire)
            return bool(result)
        except Exception as e:
            logger.error("Redis set error: %s", e)
            return False

    async def get(self, key: str) -> Optional[str]:
        """Return the value stored at ``key``, or ``None`` on miss or error."""
        try:
            client = await self._client()
            return await client.get(key)
        except Exception as e:
            logger.error("Redis get error: %s", e)
            return None

    async def get_json(self, key: str) -> Optional[Dict[str, Any]]:
        """Fetch ``key`` and decode its value as JSON.

        Returns:
            The decoded value (whatever JSON type was stored), or ``None``
            when the key is missing, empty, unreadable, or not valid JSON.
        """
        try:
            value = await self.get(key)
            if value:
                return json.loads(value)
            return None
        except json.JSONDecodeError as e:
            logger.error("JSON decode error: %s", e)
            return None

    async def delete(self, key: str) -> bool:
        """Delete ``key``.

        Returns:
            ``True`` if a key was removed, ``False`` if it did not exist or
            an error was logged.
        """
        try:
            client = await self._client()
            result = await client.delete(key)
            return bool(result)
        except Exception as e:
            logger.error("Redis delete error: %s", e)
            return False

    async def exists(self, key: str) -> bool:
        """Return ``True`` if ``key`` exists; ``False`` if not or on error."""
        try:
            client = await self._client()
            result = await client.exists(key)
            return bool(result)
        except Exception as e:
            logger.error("Redis exists error: %s", e)
            return False

    async def increment(self, key: str, amount: int = 1) -> Optional[int]:
        """Increase the integer at ``key`` by ``amount``.

        Returns:
            The value after incrementing, or ``None`` if an error was logged.
        """
        try:
            client = await self._client()
            return await client.incr(key, amount)
        except Exception as e:
            logger.error("Redis increment error: %s", e)
            return None

    async def set_hash(self, key: str, field: str, value: Any) -> bool:
        """Set ``field`` of the hash at ``key`` to ``value``.

        Dicts and lists are JSON-encoded first.

        Returns:
            ``True`` when the write succeeded, whether the field was created
            or overwritten; ``False`` if an error was logged.
        """
        try:
            client = await self._client()
            # HSET replies with the number of *new* fields, which is 0 when an
            # existing field is overwritten. That is still a successful write,
            # so the reply must not be used as the success flag.
            await client.hset(key, field, self._encode(value))
            return True
        except Exception as e:
            logger.error("Redis hset error: %s", e)
            return False

    async def get_hash(self, key: str, field: str) -> Optional[str]:
        """Return ``field`` of the hash at ``key``, or ``None`` on miss or error."""
        try:
            client = await self._client()
            return await client.hget(key, field)
        except Exception as e:
            logger.error("Redis hget error: %s", e)
            return None

# Global Redis client instance shared by importers of this module.
# Constructing it only reads the REDIS_URL setting; no connection is opened
# until connect() is called or the first operation runs.
redis_client = RedisClient()
