"""API Gateway entry point for the WebSearch microservices system.

Accepts search requests over HTTP, publishes them to per-engine Kafka topics,
and exposes status/result lookups backed by Redis.
"""

import os
import sys
import logging
from typing import Dict, Any
from datetime import datetime

from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn

# Add shared modules to path
sys.path.append('/app')
sys.path.append('/app/shared')

from shared.models.schemas import (
    SearchRequest, SearchResponse, BatchSearchRequest, 
    ProcessingStatus, ErrorResponse, HealthStatus
)
from shared.utils.kafka_utils import KafkaClient, get_search_topic, TOPICS
from shared.utils.redis_utils import RedisClient, get_cache_key, CACHE_KEYS, CACHE_EXPIRATION

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="WebSearch Microservices API",
    description="API Gateway for distributed web search system",
    version="1.0.0"
)

# CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global clients
kafka_client = None
redis_client = None

@app.on_event("startup")
async def startup_event():
    """Initialize connections on startup"""
    global kafka_client, redis_client
    
    try:
        kafka_client = KafkaClient()
        redis_client = RedisClient()
        logger.info("API Gateway started successfully")
    except Exception as e:
        logger.error(f"Failed to start API Gateway: {e}")
        raise

@app.on_event("shutdown")
async def shutdown_event():
    """Clean up connections on shutdown"""
    global kafka_client, redis_client
    
    if kafka_client:
        kafka_client.close()
    if redis_client:
        redis_client.close()
    logger.info("API Gateway shutdown completed")

class SingleSearchRequest(BaseModel):
    query: str
    search_engine: str = "google"
    max_results: int = 10
    delay: int = 2
    timeout: int = 30
    headless: bool = True

@app.get("/")
async def root():
    """Root endpoint"""
    return {
        "message": "WebSearch Microservices API Gateway",
        "version": "1.0.0",
        "timestamp": datetime.now().isoformat()
    }

@app.get("/health")
async def health_check():
    """Health check endpoint"""
    try:
        # Kafka: we can only confirm the client object was created; the
        # underlying broker connection is not probed here
        kafka_healthy = kafka_client is not None
        
        # Redis: issue a PING to verify the connection is alive
        redis_healthy = False
        if redis_client:
            try:
                redis_client.client.ping()
                redis_healthy = True
            except Exception:
                pass
        
        status = "healthy" if kafka_healthy and redis_healthy else "unhealthy"
        
        return HealthStatus(
            service_name="api-gateway",
            status=status,
            details={
                "kafka": "healthy" if kafka_healthy else "unhealthy",
                "redis": "healthy" if redis_healthy else "unhealthy"
            }
        )
    except Exception as e:
        logger.error(f"Health check failed: {e}")
        return HealthStatus(
            service_name="api-gateway",
            status="unhealthy",
            details={"error": str(e)}
        )

@app.post("/search", response_model=Dict[str, Any])
async def single_search(request: SingleSearchRequest, background_tasks: BackgroundTasks):
    """Perform a single search query"""
    try:
        # Create search request
        search_request = SearchRequest(
            query=request.query,
            search_engine=request.search_engine,
            max_results=request.max_results,
            delay=request.delay,
            timeout=request.timeout,
            headless=request.headless
        )
        
        # Store request status in Redis
        status_key = get_cache_key(CACHE_KEYS['REQUEST_STATUS'], request_id=search_request.request_id)
        initial_status = ProcessingStatus(
            request_id=search_request.request_id,
            status="pending",
            message=f"Search request for '{request.query}' using {request.search_engine}"
        )
        redis_client.set(status_key, initial_status.dict(), expire=CACHE_EXPIRATION['REQUEST_STATUS'])
        
        # Send message to appropriate search engine topic
        topic = get_search_topic(request.search_engine)
        message_sent = kafka_client.send_message(
            topic=topic,
            message=search_request.dict(),
            key=search_request.request_id
        )
        
        if not message_sent:
            raise HTTPException(status_code=500, detail="Failed to send search request")
        
        logger.info(f"Search request {search_request.request_id} sent to {topic}")
        
        return {
            "request_id": search_request.request_id,
            "status": "pending",
            "message": f"Search request submitted for '{request.query}' using {request.search_engine}",
            "check_status_url": f"/status/{search_request.request_id}"
        }
        
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in single search: {e}")
        raise HTTPException(status_code=500, detail=str(e))
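
# Example client call for the /search endpoint (a sketch; assumes the gateway is
# reachable at http://localhost:8000, the port configured at the bottom of this file):
#
#   curl -X POST http://localhost:8000/search \
#        -H "Content-Type: application/json" \
#        -d '{"query": "python asyncio", "search_engine": "google", "max_results": 5}'
#
# The response contains a request_id that can be polled via /status/{request_id}
# and, once processing completes, fetched via /results/{request_id}.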

@app.post("/search/batch", response_model=Dict[str, Any])
async def batch_search(request: BatchSearchRequest, background_tasks: BackgroundTasks):
    """Perform batch search with multiple queries and search engines"""
    try:
        # Store batch request status in Redis
        status_key = get_cache_key(CACHE_KEYS['REQUEST_STATUS'], request_id=request.request_id)
        initial_status = ProcessingStatus(
            request_id=request.request_id,
            status="pending",
            message=f"Batch search with {len(request.queries)} queries across {len(request.search_engines)} engines"
        )
        redis_client.set(status_key, initial_status.dict(), expire=CACHE_EXPIRATION['REQUEST_STATUS'])
        
        # Create individual search requests for each query-engine combination
        search_requests = []
        for query in request.queries:
            for engine in request.search_engines:
                search_req = SearchRequest(
                    query=query,
                    search_engine=engine,
                    max_results=request.max_results,
                    delay=request.delay,
                    user_id=request.user_id,
                    session_id=request.session_id
                )
                search_requests.append(search_req)
        
        # Send all search requests to their engine-specific topics, logging any
        # sends that the Kafka client reports as failed
        failed_sends = 0
        for search_req in search_requests:
            topic = get_search_topic(search_req.search_engine)
            sent = kafka_client.send_message(
                topic=topic,
                message=search_req.dict(),
                key=search_req.request_id
            )
            if not sent:
                failed_sends += 1
                logger.warning(f"Failed to send search request {search_req.request_id} to {topic}")
        
        logger.info(
            f"Batch search request {request.request_id} dispatched "
            f"{len(search_requests) - failed_sends}/{len(search_requests)} individual searches"
        )
        
        return {
            "request_id": request.request_id,
            "status": "pending",
            "total_searches": len(search_requests),
            "message": f"Batch search submitted with {len(request.queries)} queries across {len(request.search_engines)} engines",
            "check_status_url": f"/status/{request.request_id}"
        }
        
    except Exception as e:
        logger.error(f"Error in batch search: {e}")
        raise HTTPException(status_code=500, detail=str(e))
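
# Example batch payload (a sketch based on the BatchSearchRequest fields used
# above; the exact schema lives in shared.models.schemas):
#
#   {
#     "queries": ["fastapi tutorial", "kafka consumer groups"],
#     "search_engines": ["google", "bing"],
#     "max_results": 10,
#     "delay": 2
#   }
#
# Each query/engine combination is fanned out as an individual SearchRequest,
# so this example produces four messages across the search topics.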

@app.get("/status/{request_id}", response_model=ProcessingStatus)
async def get_status(request_id: str):
    """Get the status of a search request"""
    try:
        status_key = get_cache_key(CACHE_KEYS['REQUEST_STATUS'], request_id=request_id)
        status_data = redis_client.get(status_key)
        
        if not status_data:
            raise HTTPException(status_code=404, detail="Request not found")
        
        return ProcessingStatus(**status_data)
        
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting status for {request_id}: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/results/{request_id}")
async def get_results(request_id: str):
    """Get the results of a completed search request"""
    try:
        results_key = get_cache_key(CACHE_KEYS['SEARCH_RESULTS'], request_id=request_id)
        results_data = redis_client.get(results_key)
        
        if not results_data:
            raise HTTPException(status_code=404, detail="Results not found")
        
        return results_data
        
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting results for {request_id}: {e}")
        raise HTTPException(status_code=500, detail=str(e))
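
# Typical client flow (a sketch; assumes the gateway is at http://localhost:8000
# and that a downstream consumer writes completed results to Redis under the
# SEARCH_RESULTS cache key read above):
#
#   import time, requests
#   resp = requests.post("http://localhost:8000/search", json={"query": "example"})
#   request_id = resp.json()["request_id"]
#   while requests.get(f"http://localhost:8000/status/{request_id}").json()["status"] == "pending":
#       time.sleep(2)
#   results = requests.get(f"http://localhost:8000/results/{request_id}").json()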

@app.post("/upload/queries")
async def upload_queries(file: UploadFile = File(...)):
    """Upload a file containing queries for batch processing"""
    try:
        if not file.filename or not file.filename.lower().endswith(('.csv', '.xlsx', '.xls')):
            raise HTTPException(status_code=400, detail="Only CSV and Excel files are supported")
        
        # Save the uploaded file; keep only the base name to avoid path traversal
        safe_name = os.path.basename(file.filename)
        file_path = f"/app/results/uploads/{safe_name}"
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        
        with open(file_path, "wb") as buffer:
            content = await file.read()
            buffer.write(content)
        
        # TODO: Process the file and extract queries
        # For now, return file info
        return {
            "filename": safe_name,
            "size": len(content),
            "message": "File uploaded successfully. Processing functionality to be implemented."
        }
        
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error uploading file: {e}")
        raise HTTPException(status_code=500, detail=str(e))
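
# A minimal sketch of how the TODO above could be implemented, assuming the
# uploaded file has a "query" column and pandas is available in this image.
# The helper name is hypothetical and is not yet called anywhere.
#
# import pandas as pd
#
# def extract_queries(file_path: str) -> list:
#     """Read queries from the 'query' column (or first column) of a CSV/Excel file."""
#     if file_path.lower().endswith(".csv"):
#         df = pd.read_csv(file_path)
#     else:
#         df = pd.read_excel(file_path)
#     column = "query" if "query" in df.columns else df.columns[0]
#     return df[column].dropna().astype(str).tolist()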

@app.get("/engines")
async def get_available_engines():
    """Get list of available search engines"""
    return {
        "engines": ["google", "bing", "yandex", "duckduckgo", "yahoo", "baidu"],
        "default": "google"
    }

@app.get("/metrics")
async def get_metrics():
    """Get system metrics and statistics"""
    try:
        # TODO: Implement metrics collection
        return {
            "requests_processed": 0,
            "active_requests": 0,
            "success_rate": 0.0,
            "average_response_time": 0.0,
            "timestamp": datetime.now().isoformat()
        }
    except Exception as e:
        logger.error(f"Error getting metrics: {e}")
        raise HTTPException(status_code=500, detail=str(e))
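
# One possible approach for the metrics TODO above (a sketch, not wired in):
# keep simple counters in Redis, e.g. redis_client.client.incr("metrics:requests")
# when a request is accepted, and read them back here. The exact key layout
# would belong in shared.utils.redis_utils alongside the other cache keys.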

if __name__ == "__main__":
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8000,
        log_level="info",
        reload=False
    )
