import asyncio
import base64
import csv
import io
import logging
import os
import re
import time
import traceback
from datetime import datetime
from typing import List, Optional

import google.generativeai as genai
from dotenv import load_dotenv
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from PIL import Image, ImageOps
from pydantic import BaseModel

# Load variables from a local .env file into the environment; this must run
# before the os.getenv() calls that read the Gemini keys further down.
load_dotenv()

# Process-wide logging at INFO: timestamp, logger name, level, then message.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# The ASGI application; title/description/version feed the OpenAPI docs at /docs.
app = FastAPI(
    version="2.0.0",
    title="Tree Species and DBH Detection API",
    description="Upload 1-10 tree images for batch analysis",
)

# CORS: accept every origin, method and header.
# NOTE(review): with a wildcard origin plus allow_credentials=True, Starlette
# appears to echo the caller's Origin back, letting any site make credentialed
# requests — verify for the installed version and restrict allow_origins if
# cookies or auth are ever added.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# Gemini API keys: up to five environment variables, rotated between calls.
api_key_1 = os.getenv("GEMINI_API_KEY")
api_key_2 = os.getenv("GEMINI_API_KEY_2")
api_key_3 = os.getenv("GEMINI_API_KEY_3")
api_key_4 = os.getenv("GEMINI_API_KEY_4")
api_key_5 = os.getenv("GEMINI_API_KEY_5")

# Keep only the variables that are set and non-empty, in their numbered order.
GEMINI_API_KEYS = [
    candidate
    for candidate in (api_key_1, api_key_2, api_key_3, api_key_4, api_key_5)
    if candidate
]

# Refuse to start without at least one usable key.
if not GEMINI_API_KEYS:
    logger.error("❌ No GEMINI_API_KEY found!")
    raise ValueError("At least one GEMINI_API_KEY must be set in .env file!")

# Rotation cursor advanced by get_next_api_key().
current_key_index = 0
logger.info(f"✅ Loaded {len(GEMINI_API_KEYS)} API key(s) for rotation")

# ============== CONFIGURATION ==============
MAX_IMAGES = 10  # Upper bound on images accepted per batch request
RATE_LIMIT_DELAY = 5.5  # Seconds slept before each image after the first
MAX_RETRIES = len(GEMINI_API_KEYS)  # One attempt per configured key
# ===========================================

# In-memory log of every analysis result, shared by all requests. Nothing ever
# removes entries, so it grows for the life of the process (use a database in
# production).
analysis_history = []


class TreeAnalysis(BaseModel):
    """Analysis result for a single uploaded image.

    Text fields are ``Optional[str]``: they stay ``None`` when Gemini gave no
    usable value, or when the analysis failed (``success=False`` with ``error``
    set). The previous ``str = None`` annotations rejected an explicit ``None``
    under pydantic v2 validation.
    """
    success: bool = True
    filename: str
    image_id: int  # 1-based position of the image within its batch
    species: str = "Unknown"
    species_hindi: Optional[str] = None
    species_marathi: Optional[str] = None
    scientific_name: Optional[str] = None
    confidence: str = "Low"
    height: Optional[str] = None
    canopy: Optional[str] = None
    girth: Optional[str] = None
    dbh_estimate: Optional[str] = None
    condition: Optional[str] = None
    characteristics: List[str] = []  # pydantic copies this default per instance
    additional_info: Optional[str] = None
    recommendations: Optional[str] = None
    remarks: Optional[str] = None
    timestamp: Optional[str] = None  # datetime.now().isoformat() when produced here
    error: Optional[str] = None


class BatchAnalysisResponse(BaseModel):
    """Response for batch tree analysis: per-image results plus batch totals."""
    total_images: int  # number of files uploaded in the batch
    successful: int  # results with success=True
    failed: int  # results with success=False
    unhealthy_trees: int  # successful results the endpoint classified as unhealthy
    processing_time: float  # wall-clock seconds for the whole request
    results: List[TreeAnalysis]  # one entry per uploaded image, in upload order


def get_next_api_key():
    """Return the key under the rotation cursor, then move the cursor on by one.

    The cursor wraps back to the first key after the last one, so successive
    calls cycle through GEMINI_API_KEYS in order.
    """
    global current_key_index
    chosen = GEMINI_API_KEYS[current_key_index]
    current_key_index += 1
    current_key_index %= len(GEMINI_API_KEYS)
    return chosen


def resize_image(image: Image.Image, max_size: int = 2048) -> Image.Image:
    """Cap the longest side of *image* at *max_size* pixels.

    Images already within the limit are returned untouched. Larger ones are
    shrunk in place with ``thumbnail`` (aspect ratio kept, LANCZOS filter), so
    the caller's object is modified and also returned.
    """
    longest_side = max(image.size)
    if longest_side <= max_size:
        return image
    image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
    logger.info(f"   Resized to {image.size}")
    return image


# Prompt label -> TreeAnalysis field, for single-line values stored as given.
_PLAIN_FIELDS = {
    "ENGLISH NAME": "species",
    "SCIENTIFIC NAME": "scientific_name",
    "CONFIDENCE": "confidence",
}
# Regional-name labels; "Not available"-style placeholders are discarded.
_REGIONAL_NAME_FIELDS = {
    "HINDI NAME": "species_hindi",
    "MARATHI NAME": "species_marathi",
}
# Measurement labels; "Not clearly visible"-style placeholders are discarded.
_MEASUREMENT_FIELDS = {
    "HEIGHT": "height",
    "CANOPY": "canopy",
    "GIRTH": "girth",
    "DBH": "dbh_estimate",
}
# Labels that open a section: text on the label line and on following lines
# (bulleted or plain) belongs to that field until the next label.
_SECTION_FIELDS = {
    "CONDITION": "condition",
    "CHARACTERISTICS": "characteristics",
    "ADDITIONAL INFO": "additional_info",
    "RECOMMENDATIONS": "recommendations",
    "REMARKS": "remarks",
}
# "LABEL: value" with the exact upper-case labels requested by the prompt.
# Matching stays case-sensitive so prose such as "Height: up to 30 m in the
# wild" inside a section is not mistaken for a field.
_LABEL_RE = re.compile(
    r"^("
    + "|".join(
        re.escape(label)
        for label in (*_PLAIN_FIELDS, *_REGIONAL_NAME_FIELDS, *_MEASUREMENT_FIELDS, *_SECTION_FIELDS)
    )
    + r")\s*:\s*(.*)$"
)
# List markers: "- x", "• x", "* x", "1. x", "1) x".
_BULLET_RE = re.compile(r"^(?:[-•*]|\d+[.)])\s+(.*)$")
_PLACEHOLDERS = frozenset({"not available", "na", "n/a"})
# A measurement whose value *starts* with "Not"/"Cannot" is a placeholder.
_UNMEASURED_RE = re.compile(r"(?:not|cannot)\b", re.IGNORECASE)


def _clean_value(raw: str) -> str:
    """Trim whitespace and one pair of wrapping quotes (the prompt shows quoted examples)."""
    value = raw.strip()
    if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
        value = value[1:-1].strip()
    return value


def _is_placeholder(value: str) -> bool:
    """True for "Not available" / "NA" / "N/A"-style answers that carry no data."""
    return value.lower().rstrip(".") in _PLACEHOLDERS


def parse_gemini_response(response_text: str, filename: str, image_id: int) -> TreeAnalysis:
    """Parse Gemini's labelled text reply into a TreeAnalysis.

    Args:
        response_text: reply expected to follow the "LABEL: value" layout the
            prompt requests; Markdown decoration is tolerated.
        filename: original upload name, copied into the result.
        image_id: 1-based position of the image in its batch.

    Returns:
        A successful TreeAnalysis stamped with the current local time. Labels
        that are missing or hold placeholders leave the field at its default.
    """
    data = {
        "success": True,
        "filename": filename,
        "image_id": image_id,
        "species": "Unknown",
        "confidence": "Low",
        "characteristics": [],
        "timestamp": datetime.now().isoformat()
    }
    section = None  # field collecting lines that follow a section label

    for raw_line in response_text.split("\n"):
        # Gemini often wraps the requested layout in Markdown
        # ("**HEIGHT:** 12 m", "## REMARKS:", "* item"); strip that first.
        line = raw_line.replace("**", "").strip().lstrip("#").strip()
        if not any(ch.isalnum() for ch in line):
            continue  # blank lines and rules such as "---"

        bullet = _BULLET_RE.match(line)
        body = bullet.group(1).strip() if bullet else line

        label_match = _LABEL_RE.match(body)
        if label_match:
            label = label_match.group(1)
            value = _clean_value(label_match.group(2))
            section = None
            if label in _PLAIN_FIELDS:
                if value:
                    data[_PLAIN_FIELDS[label]] = value
            elif label in _REGIONAL_NAME_FIELDS:
                if value and not _is_placeholder(value):
                    data[_REGIONAL_NAME_FIELDS[label]] = value
            elif label in _MEASUREMENT_FIELDS:
                # Only a leading "Not"/"Cannot" marks a placeholder; a substring
                # test would also reject "knotted trunk", "notable lean", ...
                if value and not _is_placeholder(value) and not _UNMEASURED_RE.match(value):
                    data[_MEASUREMENT_FIELDS[label]] = value
            else:
                section = _SECTION_FIELDS[label]
                if value:
                    if section == "characteristics":
                        data["characteristics"].append(value)
                    else:
                        # The prompt puts ADDITIONAL INFO / RECOMMENDATIONS /
                        # REMARKS text on the label line itself, so keep it.
                        data[section] = value
            continue

        # Any other line continues the current section (ignored outside one).
        if section is None or not body:
            continue
        if section == "characteristics":
            data["characteristics"].append(body)
        else:
            existing = data.get(section)
            data[section] = f"{existing} {body}" if existing else body

    return TreeAnalysis(**data)


# Prompt sent with every image. parse_gemini_response() keys off these exact
# upper-case labels, so keep the two in sync when editing either.
_TREE_ANALYSIS_PROMPT = """Analyze this tree image thoroughly and provide detailed information in the EXACT format below:

ENGLISH NAME: [Common English name of the tree species]
HINDI NAME: [Hindi name if known, otherwise "Not available"]
MARATHI NAME: [Marathi name if known, otherwise "Not available"]
SCIENTIFIC NAME: [Full scientific name - Genus species]
CONFIDENCE: [High/Medium/Low - based on visible features and image clarity]
HEIGHT: [Estimated height in meters, e.g., "12-15 meters" or "Not clearly visible"]
CANOPY: [Canopy diameter/spread in meters, e.g., "8-10 meters" or "Not clearly visible"]
GIRTH: [Tree circumference at breast height (1.3m), e.g., "2.5 meters" or "Not measurable from image"]
DBH: [Diameter at Breast Height calculated from girth, e.g., "79.6 cm" or "Not calculable"]
CONDITION: [Overall health - Healthy/Good/Fair/Poor/Declining with detailed explanation of visible signs]
CHARACTERISTICS:
- [Bark texture, color, and pattern details]
- [Leaf shape, size, arrangement, and color]
- [Branch structure and growth pattern]
- [Any flowers, fruits, or seeds visible]
- [Unique identifying features]
- [Additional observable characteristics]
ADDITIONAL INFO: [Ecological information - native habitat, typical growth conditions, cultural significance, common uses, interesting facts about this species]
RECOMMENDATIONS: [Specific care recommendations for this species - watering needs, pruning guidelines, fertilization, pest management, best growing conditions, seasonal care tips]
REMARKS: [Specific observations about THIS tree - any visible issues, notable features, age estimation if possible, overall assessment]

IMPORTANT:
1. Be as specific and detailed as possible with all measurements and descriptions
2. For DBH calculation: DBH = Girth / π (if girth is measurable)
3. Provide botanical details to justify species identification
4. Include both common and scientific nomenclature
5. Be honest about confidence level - explain if uncertain
6. Base all observations on what's actually visible in the image"""


def _failed_analysis(filename: str, image_id: int, error_msg: str) -> TreeAnalysis:
    """Build a failed TreeAnalysis, append it to analysis_history, and return it."""
    error_result = TreeAnalysis(
        success=False,
        filename=filename,
        image_id=image_id,
        error=error_msg,
        timestamp=datetime.now().isoformat()
    )
    analysis_history.append(error_result.model_dump())
    return error_result


def _to_model_jpeg(image_data: bytes) -> bytes:
    """Decode an upload and re-encode it as an upright, RGB, size-capped JPEG.

    Raises whatever Pillow raises for unreadable or oversized images.
    """
    image = Image.open(io.BytesIO(image_data))
    logger.info(f"   Image size: {image.size}, Mode: {image.mode}")

    # Re-encoding below discards EXIF, so bake any camera rotation into the
    # pixels first; otherwise portrait phone photos reach the model sideways.
    try:
        image = ImageOps.exif_transpose(image)
    except Exception as exc:  # best effort: malformed EXIF must not reject the photo
        logger.warning(f"   Could not apply EXIF orientation: {exc}")

    # Flatten any transparency onto white. A plain convert('RGB') would drop
    # the alpha channel and expose whatever colour sits under transparent pixels.
    if image.mode in ('RGBA', 'LA') or (image.mode == 'P' and 'transparency' in image.info):
        rgba = image.convert('RGBA')
        background = Image.new('RGB', rgba.size, (255, 255, 255))
        background.paste(rgba, mask=rgba.getchannel('A'))
        image = background
    elif image.mode != 'RGB':
        image = image.convert('RGB')

    image = resize_image(image)

    img_buffer = io.BytesIO()
    image.save(img_buffer, format='JPEG', quality=85)
    return img_buffer.getvalue()


async def analyze_tree_image(image_data: bytes, filename: str, image_id: int, add_delay: bool = False) -> TreeAnalysis:
    """Analyse one tree image with Gemini, rotating API keys when a quota is hit.

    Args:
        image_data: raw uploaded file bytes.
        filename: original upload name, copied into the result.
        image_id: 1-based position of the image within its batch.
        add_delay: sleep RATE_LIMIT_DELAY seconds before starting; used for
            every image after the first in a batch.

    Returns:
        A TreeAnalysis. Failures (unreadable image, all keys over quota, any
        other API error) come back with success=False instead of raising.
        Every result, successful or not, is appended to analysis_history.
    """
    if add_delay:
        logger.info(f"⏳ Rate limit delay: {RATE_LIMIT_DELAY}s before {filename}...")
        await asyncio.sleep(RATE_LIMIT_DELAY)

    logger.info(f"📸 [{image_id}] Processing: {filename}")

    try:
        img_bytes = _to_model_jpeg(image_data)
    except Exception as e:
        error_msg = f"Image processing error: {str(e)}"
        logger.error(f"❌ [{image_id}] {error_msg}")
        return _failed_analysis(filename, image_id, error_msg)

    for retry in range(MAX_RETRIES):
        # Read the cursor before get_next_api_key() advances it so the log names
        # the key actually used (the old log reported the next key instead).
        key_number = current_key_index + 1
        try:
            genai.configure(api_key=get_next_api_key())
            model = genai.GenerativeModel('models/gemini-2.5-flash')

            logger.info(f"   🔑 Using API key #{key_number}/{len(GEMINI_API_KEYS)} (attempt {retry + 1}/{MAX_RETRIES})...")

            # Async call: the synchronous generate_content() blocked the event
            # loop, stalling every other request for the duration of the call.
            # NOTE(review): genai.configure() is process-global; this relies on
            # generate_content_async() binding its client before its first
            # await yields — confirm for the installed google-generativeai.
            response = await model.generate_content_async([
                _TREE_ANALYSIS_PROMPT,
                {"mime_type": "image/jpeg", "data": img_bytes},
            ])

            result = parse_gemini_response(response.text.strip(), filename, image_id)
        except Exception as e:
            error_msg = str(e)

            # Quota errors rotate to the next key; anything else fails at once.
            if "429" in error_msg or "quota" in error_msg.lower() or "exceeded" in error_msg.lower():
                logger.warning(f"⚠️ [{image_id}] Quota exceeded on attempt {retry + 1}/{MAX_RETRIES}")
                if retry < MAX_RETRIES - 1:
                    logger.info("🔄 Trying next API key...")
                    await asyncio.sleep(2)  # Small delay before retry
                    continue
                logger.error(f"❌ [{image_id}] All {len(GEMINI_API_KEYS)} API keys exhausted")
                return _failed_analysis(
                    filename,
                    image_id,
                    f"All {len(GEMINI_API_KEYS)} API keys have exceeded their quota. Please try again later or add more API keys.",
                )

            logger.error(f"❌ [{image_id}] FAILED: {filename} - {error_msg}")
            return _failed_analysis(filename, image_id, error_msg)

        logger.info(f"✅ [{image_id}] SUCCESS: {result.species} ({result.scientific_name})")
        logger.info(f"   DBH: {result.dbh_estimate}, Condition: {result.condition}")
        analysis_history.append(result.model_dump())
        return result

    # Only reachable if MAX_RETRIES is 0; start-up refuses to run without keys.
    return _failed_analysis(filename, image_id, "No Gemini API keys configured")


@app.get("/")
async def root():
    """API Information"""
    # Service descriptor: version, batch limits, key count and endpoint map.
    # (The docstring above is the OpenAPI description, so it is kept verbatim.)
    endpoint_map = {
        "analyze": f"POST /api/analyze-multiple (upload 1-{MAX_IMAGES} images)",
        "health": "GET /api/health",
        "stats": "GET /api/stats",
        "export": "GET /api/export/csv",
        "docs": "GET /docs",
    }
    info = {
        "name": "🌳 Tree Species & DBH Detection API",
        "version": "2.0.0",
        "mode": "Flexible Batch Processing",
        "max_images": MAX_IMAGES,
        "rate_limit_delay": f"{RATE_LIMIT_DELAY} seconds between images",
        "api_keys_available": len(GEMINI_API_KEYS),
        "message": f"Upload 1-{MAX_IMAGES} tree images for analysis",
        "endpoints": endpoint_map,
    }
    return info


@app.get("/api/health")
async def health_check():
    """Health Check"""
    # Static status plus live counters: configured key count and the number of
    # entries in the in-memory analysis history.
    key_count = len(GEMINI_API_KEYS)
    return dict(
        status="healthy",
        api_configured=key_count > 0,
        api_keys_count=key_count,
        mode="FLEXIBLE BATCH PROCESSING",
        max_images=MAX_IMAGES,
        rate_limit_delay=RATE_LIMIT_DELAY,
        total_analyses=len(analysis_history),
    )


# Words whose presence in a condition description suggests an unhealthy tree.
_UNHEALTHY_KEYWORDS = ('poor', 'declining', 'unhealthy', 'diseased', 'stressed', 'damaged', 'weak')
# Leading overall-health ratings; the prompt asks Gemini to open CONDITION with
# one of Healthy/Good/Fair/Poor/Declining.
_HEALTHY_RATINGS = frozenset({'healthy', 'good', 'excellent'})
_UNHEALTHY_RATINGS = frozenset({'poor', 'declining', 'unhealthy', 'dead', 'dying'})


def _is_unhealthy(condition: str) -> bool:
    """Classify a CONDITION string, trusting its leading rating word over keywords.

    Pure keyword matching misfires on negated descriptions such as
    "Healthy - no diseased or damaged limbs", so an explicit leading rating
    decides. Only when the first word is not a known rating (e.g. "Fair") does
    the keyword scan apply.
    """
    text = condition.lower()
    first_word = re.match(r"[^a-z]*([a-z]*)", text).group(1)
    if first_word in _HEALTHY_RATINGS:
        return False
    if first_word in _UNHEALTHY_RATINGS:
        return True
    return any(word in text for word in _UNHEALTHY_KEYWORDS)


@app.post("/api/analyze-multiple", response_model=BatchAnalysisResponse)
async def analyze_multiple_trees(files: List[UploadFile] = File(...)):
    """
    🌳 FLEXIBLE TREE ANALYSIS - Upload 1-10 tree images
    
    Features:
    - Upload anywhere from 1 to 10 images (botanist's choice)
    - Rate limited (5.5s delay between images)
    - Automatic API key rotation on quota exceeded
    - Comprehensive analysis: Species, DBH, Height, Canopy, Girth, Health
    - Regional names (Hindi & Marathi)
    - Detailed characteristics and recommendations
    
    Returns:
    - Complete analysis for all trees
    - Success/failure statistics
    - Unhealthy tree count
    - Processing time
    """
    # The docstring above is the OpenAPI description, so it is kept verbatim.
    # Raises HTTPException 400 for an empty batch, too many files, or a
    # non-image content type, and 500 for any unexpected processing error.
    start_time = time.time()

    logger.info("="*80)
    logger.info("🌳 BATCH ANALYSIS REQUEST")
    logger.info(f"📊 Images received: {len(files)}")
    logger.info(f"📋 Allowed range: 1-{MAX_IMAGES} images")
    logger.info(f"🔑 API keys available: {len(GEMINI_API_KEYS)}")
    logger.info("="*80)

    # Validate: At least 1 image
    if len(files) < 1:
        logger.error("❌ No images uploaded")
        raise HTTPException(
            status_code=400,
            detail="Please upload at least 1 image."
        )

    # Validate: Maximum images
    if len(files) > MAX_IMAGES:
        logger.error(f"❌ Too many images: {len(files)} > {MAX_IMAGES}")
        raise HTTPException(
            status_code=400,
            detail=f"Maximum {MAX_IMAGES} images allowed per batch. You uploaded {len(files)} images. "
                   f"Please select up to {MAX_IMAGES} images only."
        )

    # Validate: All files must declare an image content type (client-supplied;
    # Pillow performs the real decode check later).
    for file in files:
        if not file.content_type or not file.content_type.startswith("image/"):
            logger.error(f"❌ Invalid file type: {file.filename}")
            raise HTTPException(
                status_code=400,
                detail=f"File '{file.filename}' is not an image. All files must be images (JPG, PNG, WEBP)."
            )

    try:
        logger.info("📂 Reading uploaded files...")
        files_data = []
        for position, file in enumerate(files, start=1):
            contents = await file.read()
            # TreeAnalysis.filename must be a str; some clients omit the name.
            name = file.filename or f"image_{position}"
            files_data.append((contents, name, position))
            logger.info(f"   ✓ [{position}] {name} - {len(contents):,} bytes")

        # Process images SEQUENTIALLY with rate limiting
        logger.info(f"🔄 Starting sequential processing with {RATE_LIMIT_DELAY}s rate limiting...")
        estimated_time = len(files) * RATE_LIMIT_DELAY
        logger.info(f"⏱️  Estimated time: ~{estimated_time:.0f} seconds")

        results = []
        for idx, (data, name, img_id) in enumerate(files_data):
            # Delay before every image except the first
            result = await analyze_tree_image(data, name, img_id, add_delay=idx > 0)
            results.append(result)

        successful = sum(1 for r in results if r.success)
        failed = len(results) - successful
        unhealthy = sum(
            1 for r in results
            if r.success and r.condition and _is_unhealthy(r.condition)
        )

        processing_time = time.time() - start_time

        logger.info("="*80)
        logger.info("✅ BATCH ANALYSIS COMPLETED")
        logger.info(f"📊 Total: {len(files)} | Success: {successful} | Failed: {failed} | Unhealthy: {unhealthy}")
        logger.info(f"⏱️  Processing time: {processing_time:.1f} seconds")
        logger.info("="*80)

        return BatchAnalysisResponse(
            total_images=len(files),
            successful=successful,
            failed=failed,
            unhealthy_trees=unhealthy,
            processing_time=round(processing_time, 2),
            results=results
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"❌ CRITICAL ERROR: {str(e)}")
        logger.error(traceback.format_exc())
        raise HTTPException(
            status_code=500,
            detail=f"Server error during analysis: {str(e)}"
        )


@app.get("/api/stats")
async def get_statistics():
    """Get analysis statistics"""
    # Totals over the in-memory history plus the ten most frequent species
    # among successful analyses (highest count first, ties in first-seen order).
    # (The docstring above is the OpenAPI description, so it is kept verbatim.)
    if not analysis_history:
        return {"message": "No analysis history available yet. Upload trees to begin!"}

    total = len(analysis_history)
    successful = len([entry for entry in analysis_history if entry.get('success', False)])

    counts = {}
    for entry in analysis_history:
        name = entry.get('species')
        if entry.get('success') and name:
            counts[name] = counts.get(name, 0) + 1

    ranked = sorted(counts.items(), key=lambda pair: pair[1], reverse=True)

    if total > 0:
        rate = f"{(successful/total*100):.1f}%"
    else:
        rate = "0%"

    return {
        "total_analyses": total,
        "successful": successful,
        "failed": total - successful,
        "success_rate": rate,
        "species_frequency": dict(ranked[:10]),
    }


# Column order shared by both CSV exports; 'error' stays last as in the
# history export (the results export previously omitted it).
_CSV_FIELDNAMES = [
    'timestamp', 'image_id', 'filename', 'species', 'scientific_name',
    'species_hindi', 'species_marathi', 'confidence', 'height', 'canopy',
    'girth', 'dbh_estimate', 'condition', 'characteristics',
    'additional_info', 'recommendations', 'remarks', 'success', 'error'
]
# Cell prefixes that spreadsheet apps evaluate as formulas (CSV injection).
# Filenames and model output are untrusted and the history is shared between
# users, so such cells are neutralised with a leading apostrophe.
_CSV_FORMULA_PREFIXES = ('=', '+', '-', '@', '\t', '\r')


def _csv_cell(value):
    """Return *value*, prefixed with ' if it is a string a spreadsheet would run."""
    if isinstance(value, str) and value.startswith(_CSV_FORMULA_PREFIXES):
        return "'" + value
    return value


def _csv_row(record: dict) -> dict:
    """Flatten one analysis dict into a CSV row keyed by _CSV_FIELDNAMES.

    Missing keys become '' ('success' defaults to False). 'characteristics' is
    joined with '; '; a null, string or scalar value from a client-supplied
    record is tolerated instead of crashing the export.
    """
    characteristics = record.get('characteristics') or []
    if not isinstance(characteristics, (list, tuple)):
        characteristics = [characteristics]
    row = {field: record.get(field, '') for field in _CSV_FIELDNAMES}
    row['characteristics'] = '; '.join(str(item) for item in characteristics)
    row['success'] = record.get('success', False)
    return {field: _csv_cell(value) for field, value in row.items()}


def _csv_download(records, download_name: str) -> StreamingResponse:
    """Render *records* as CSV and wrap it in an attachment response."""
    output = io.StringIO()
    writer = csv.DictWriter(output, fieldnames=_CSV_FIELDNAMES)
    writer.writeheader()
    writer.writerows(_csv_row(record) for record in records)
    return StreamingResponse(
        iter([output.getvalue()]),
        media_type="text/csv",
        headers={"Content-Disposition": f"attachment; filename={download_name}"}
    )


@app.get("/api/export/csv")
async def export_all_history():
    """Export all analysis history as CSV"""
    # 404 when nothing has been analysed yet.
    if not analysis_history:
        raise HTTPException(status_code=404, detail="No analysis history available")
    return _csv_download(analysis_history, "tree_analysis_history.csv")


@app.post("/api/export/csv/results")
async def export_current_results(results: List[dict]):
    """Export provided results as CSV"""
    # 400 for an empty body; the download name carries a local timestamp.
    if not results:
        raise HTTPException(status_code=400, detail="No results to export")
    return _csv_download(
        results,
        f"tree_analysis_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv",
    )


if __name__ == "__main__":
    import uvicorn

    # Start-up banner, printed one line per print() call.
    key_count = len(GEMINI_API_KEYS)
    banner_lines = [
        "\n" + "=" * 80,
        "🌳 TREE SPECIES & DBH DETECTION API",
        "=" * 80,
        "🔧 AI System:             Advanced Vision Analysis",
        f"🔑 API Keys Loaded:       {key_count}",
        f"📊 Upload Mode:           FLEXIBLE (1-{MAX_IMAGES} images)",
        f"⏱️  Rate Limit:            {RATE_LIMIT_DELAY} seconds between images",
        f"🔄 Auto Retry:            Enabled (up to {key_count} keys)",
        "🌐 Server:                http://localhost:8000",
        "📚 API Documentation:     http://localhost:8000/docs",
        "",
        "📋 Analysis Includes:",
        "   ✓ Species identification (English, Hindi, Marathi)",
        "   ✓ DBH (Diameter at Breast Height) estimation",
        "   ✓ Height and Canopy measurements",
        "   ✓ Girth measurement",
        "   ✓ Health condition assessment",
        "   ✓ Detailed characteristics",
        "   ✓ Care recommendations",
        "   ✓ CSV export functionality",
        "",
        f"💡 Botanist-Friendly: Upload any number from 1 to {MAX_IMAGES} trees!",
        # The capacity figure assumes ~20 requests per key per day.
        f"🎯 Daily Capacity: ~{key_count * 20} requests with {key_count} API keys",
        "=" * 80 + "\n",
    ]
    for banner_line in banner_lines:
        print(banner_line)

    # Serve on all interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)