from typing import Optional, List
from pathlib import Path
from fastapi import HTTPException, status
from fastapi.concurrency import run_in_threadpool
from sqlmodel import Session, select

from app.core.logging import setup_logger
from app.core.config import get_settings
from app.core.constants import ERR_NOT_FOUND, ERR_FIELD_UPDATE

from app.models.project_master import ProjectMaster
from app.models.product_image import ProductImage
from app.repositories.project_repo import ProjectRepository
from app.repositories.product_image_repo import ProductImageRepository
from app.repositories.product_template_repo import ProductTemplateRepository

from app.utils.image import validate_overlay_file, validate_logo_file, save_upload_file_streamed, safe_delete
from app.utils.generate_final_image import generate_final_image

logger = setup_logger(__name__)
settings = get_settings()

# Storage directories taken from settings. They are created at import time so
# the save/generate helpers used below can write into them.
STORAGE_OVERLAY = Path(settings.OVERLAY_IMAGES_DIRECTORY)  # uploaded overlay images
LOGO_OVERLAY = Path(settings.LOGO_IMAGES_DIRECTORY)  # uploaded logo/stamp images
STORAGE_FINAL = Path(settings.FINAL_IMAGES_DIRECTORY)  # generated final images

for _storage_dir in (STORAGE_OVERLAY, LOGO_OVERLAY, STORAGE_FINAL):
    _storage_dir.mkdir(parents=True, exist_ok=True)
del _storage_dir


class ProjectService:
    """High-level project orchestration.

    Validates and stores uploads, renders final images with
    ``generate_final_image`` (off the event loop via ``run_in_threadpool``),
    persists rows through the repositories, and removes files on disk on a
    best-effort basis.
    """

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _storage_root() -> Path:
        """Return the storage root: the grand-parent of the overlay directory.

        ``update_project`` persists overlay/final paths relative to this root.
        """
        return STORAGE_OVERLAY.parent.parent.resolve()

    @staticmethod
    def _resolve_stored_path(stored: Optional[str]) -> Optional[Path]:
        """Turn a path string persisted on a model into a filesystem path.

        This module persists paths in two shapes. ``create_project`` and
        ``regenerate_single_image`` store whatever the upload/generation
        helpers returned. ``update_project`` stores paths relative to the
        storage root. Absolute paths are returned unchanged. A relative path
        is looked up under the storage root first and otherwise resolved
        against the working directory.

        Returns ``None`` for an empty value, so that ``""`` is never resolved
        to a directory that a caller might then try to delete.
        """
        if not stored:
            return None
        path = Path(stored)
        if path.is_absolute():
            return path
        under_root = ProjectService._storage_root() / path
        try:
            if under_root.exists():
                return under_root.resolve()
        except OSError:
            pass
        return path.resolve()

    @staticmethod
    def _cleanup(paths: List[Optional[Path]], label: str = "file") -> None:
        """Best-effort delete of each path; ``None`` entries are skipped.

        Failures are logged and never raised. This lets cleanup run on error
        paths without masking the exception being handled.
        """
        for path in paths:
            if path is None:
                continue
            try:
                safe_delete(Path(path))
            except Exception:
                logger.exception("Failed to delete %s: %s", label, path)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    @staticmethod
    async def create_project(
        project_name: str,
        client_name: str,
        overlay_file,
        logo_file,
        current_user,
        session: Session,
    ) -> ProjectMaster:
        """
        Create a project, process overlay against all active product templates,
        and persist ProductImage rows for each generated image.

        ``logo_file`` is optional. When given, it is validated, stored and
        passed to every render.

        Raises:
            HTTPException(400): an upload fails validation, or the logo cannot
                be stored.
            HTTPException(404): no product templates exist; nothing is written.
            HTTPException(500): rendering fails for a template. Before
                re-raising, files written by this call are removed, the project
                is soft-deleted and its image rows are deleted.
        """
        # ---------- 1. Validate uploads before touching disk or DB ----------
        try:
            validate_overlay_file(overlay_file)
        except Exception as exc:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc

        if logo_file:
            # Raised exactly once, so the client sees the validator's message as-is.
            try:
                validate_logo_file(logo_file)
            except Exception as exc:
                raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc

        # ---------- 2. Templates must exist before anything is persisted ----------
        templates = ProductTemplateRepository.list_all(session)
        if not templates:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No product templates available.")

        # ---------- 3. Store uploads (streamed) ----------
        logo_path = None
        if logo_file:
            try:
                logo_path = await save_upload_file_streamed(logo_file, LOGO_OVERLAY)
            except HTTPException:
                raise
            except Exception as exc:
                raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc

        try:
            overlay_path = await save_upload_file_streamed(overlay_file, STORAGE_OVERLAY)
        except Exception:
            ProjectService._cleanup([Path(logo_path) if logo_path else None])
            raise

        # Every file written by this call; removed if a later step fails.
        written_files: List[Path] = [Path(overlay_path)]
        if logo_path:
            written_files.append(Path(logo_path))

        # ---------- 4. Create project master record (persist early so we have id) ----------
        try:
            project = ProjectRepository.create_project(
                session,
                ProjectMaster(
                    project_name=project_name,
                    client_name=client_name,
                    overlay_image_path=str(overlay_path),
                    logo_image_path=str(logo_path) if logo_path else None,
                ),
            )
        except Exception:
            ProjectService._cleanup(written_files)
            raise
        new_project_id = project.id

        # ---------- 5. Render one final image per template ----------
        logo_arg = str(logo_path) if logo_path else None
        try:
            for tpl in templates:
                try:
                    final_path_str = await run_in_threadpool(
                        generate_final_image,
                        tpl.file_path,
                        str(overlay_path),
                        str(STORAGE_FINAL),
                        logo_path=logo_arg,
                    )
                except Exception as gen_exc:
                    logger.exception("Failed to generate image for template %s", tpl.id)
                    raise HTTPException(
                        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                        detail=f"Failed to process template {tpl.id}",
                    ) from gen_exc

                written_files.append(Path(final_path_str))
                ProductImageRepository.create(
                    session,
                    ProductImage(
                        project_master_id=new_project_id,
                        product_template_id=tpl.id,
                        final_image_path=str(final_path_str),
                    ),
                )

        except Exception:
            # Undo everything: files on disk, the project and any image rows
            # created before the failure (same order as delete_project).
            ProjectService._cleanup(written_files)
            try:
                ProjectRepository.soft_delete_project(session, project)
            except Exception:
                logger.exception("Failed to rollback project after generation failure")
            try:
                ProductImageRepository.delete_by_project(session, new_project_id)
            except Exception:
                logger.exception("Failed to delete image rows after generation failure")
            raise

        # return created project (fresh)
        session.refresh(project)
        return project

    @staticmethod
    def list_projects(session: Session, page: int, limit: int, search: Optional[str], sort_by: str, order: str):
        """Return a page of projects; delegates to ``ProjectRepository.list_advanced``."""
        return ProjectRepository.list_advanced(
            session=session,
            page=page,
            limit=limit,
            search=search,
            sort_by=sort_by,
            order=order,
        )

    @staticmethod
    def get_project_detail(session: Session, project_id: int):
        """
        Return ``(project, images)`` for a project that is not soft-deleted.

        ``images`` is a list of dicts with keys ``id``, ``final_image_path``
        and ``product_template``.

        Raises:
            HTTPException(404): the project is missing or soft-deleted.
        """
        project = ProjectRepository.get_project(session, project_id)
        if not project or project.is_deleted:
            raise HTTPException(status.HTTP_404_NOT_FOUND, ERR_NOT_FOUND)

        # list_with_template yields (ProductImage, ProductTemplate) pairs.
        images = [
            {
                "id": img.id,
                "final_image_path": img.final_image_path,
                "product_template": tpl,
            }
            for img, tpl in ProductImageRepository.list_with_template(session, project_id)
        ]
        return project, images

    @staticmethod
    async def update_project(
        project_id: int,
        project_name: Optional[str],
        client_name: Optional[str],
        overlay_file,
        logo_file,
        current_user,
        db: Session,
    ):
        """
        Update project metadata and optionally regenerate final images.

        Final images are regenerated when a new overlay and/or a new logo is
        uploaded. A logo-only update re-renders with the project's current
        overlay. Template handling:
        - templates already mapped to the project -> UPDATE their rows
        - templates from ``list_all`` with no row yet -> INSERT new rows
        Only rows that were regenerated have their old final image removed.

        On failure, the session is rolled back and every file written by this
        call is removed. Old files (final images, overlay, logo) are deleted
        only after a successful commit.

        Raises:
            HTTPException(404): the project is missing or soft-deleted, or it
                has no image rows.
            HTTPException(400): any of the following:
                - nothing to update;
                - an upload fails validation;
                - the logo cannot be stored;
                - a logo-only update is requested for a project without an
                  overlay.
            HTTPException(500): regeneration or the commit fails.
        """
        # ---------- 1. Fetch project ----------
        project = ProjectRepository.get_project(db, project_id)
        if not project or project.is_deleted:
            raise HTTPException(status.HTTP_404_NOT_FOUND, ERR_NOT_FOUND)

        if not any([project_name, client_name, overlay_file, logo_file]):
            raise HTTPException(status.HTTP_400_BAD_REQUEST, ERR_FIELD_UPDATE)

        # CASE: metadata only
        if not overlay_file and not logo_file:
            if project_name:
                project.project_name = project_name
            if client_name:
                project.client_name = client_name
            db.add(project)
            db.commit()
            db.refresh(project)
            return {"success": True, "message": "Project updated successfully.", "project_id": project.id}

        # ---------- 2. Validate uploads before writing anything ----------
        try:
            if overlay_file:
                validate_overlay_file(overlay_file)
            if logo_file:
                validate_logo_file(logo_file)
        except Exception as exc:
            raise HTTPException(status.HTTP_400_BAD_REQUEST, str(exc)) from exc

        # ---------- 3. Existing rows, templates and current files ----------
        existing_rows = ProductImageRepository.list_by_project(db, project_id)
        if not existing_rows:
            raise HTTPException(404, "No image mapping found for this project")

        existing_map = {row.product_template_id: row for row in existing_rows}
        # Already-mapped templates first, then any template with no row yet.
        templates = {
            tpl.id: tpl
            for tpl in ProductTemplateRepository.list_by_ids(db, list(existing_map))
        }
        for tpl in ProductTemplateRepository.list_all(db):
            templates.setdefault(tpl.id, tpl)

        old_overlay_abs = ProjectService._resolve_stored_path(project.overlay_image_path)
        old_logo_abs = ProjectService._resolve_stored_path(project.logo_image_path)
        if not overlay_file and old_overlay_abs is None:
            raise HTTPException(status.HTTP_400_BAD_REQUEST, "Project has no overlay image to regenerate with")

        storage_root = ProjectService._storage_root()

        # ---------- 4. Store new uploads ----------
        # Every file written by this call; removed if a later step fails.
        new_files: List[Path] = []

        new_logo_path = None
        if logo_file:
            try:
                new_logo_path = await save_upload_file_streamed(logo_file, LOGO_OVERLAY)
            except HTTPException:
                raise
            except Exception as exc:
                raise HTTPException(status.HTTP_400_BAD_REQUEST, str(exc)) from exc
            new_files.append(Path(new_logo_path))

        # A logo-only update re-renders with the current overlay (never deleted below).
        overlay_abs = old_overlay_abs
        if overlay_file:
            try:
                overlay_abs = Path(await save_upload_file_streamed(overlay_file, STORAGE_OVERLAY)).resolve()
            except Exception:
                ProjectService._cleanup(new_files)
                raise
            new_files.append(overlay_abs)

        if new_logo_path:
            logo_arg = str(new_logo_path)
        elif project.logo_image_path:
            logo_arg = str(project.logo_image_path)
        else:
            # Pass None rather than str(None) == "None" when there is no logo.
            logo_arg = None

        # ---------- 5. Generate final images ----------
        new_overlay_rel = None
        superseded_finals: List[Optional[Path]] = []
        try:
            if overlay_file:
                new_overlay_rel = str(overlay_abs.relative_to(storage_root))

            for tpl in templates.values():
                final_path_str = await run_in_threadpool(
                    generate_final_image,
                    tpl.file_path,
                    str(overlay_abs),
                    str(STORAGE_FINAL),
                    logo_path=logo_arg,
                )
                final_abs = Path(final_path_str).resolve()
                new_files.append(final_abs)
                final_rel = str(final_abs.relative_to(storage_root))

                row = existing_map.get(tpl.id)
                if row is None:
                    db.add(
                        ProductImage(
                            project_master_id=project.id,
                            product_template_id=tpl.id,
                            final_image_path=final_rel,
                        )
                    )
                else:
                    # Only this row's previous image becomes obsolete.
                    old_final = ProjectService._resolve_stored_path(row.final_image_path)
                    if old_final != final_abs:
                        superseded_finals.append(old_final)
                    row.final_image_path = final_rel
                    db.add(row)

            db.flush()

        except Exception as exc:
            logger.exception("Regeneration failed; rolling back")
            db.rollback()
            ProjectService._cleanup(new_files)
            raise HTTPException(
                status.HTTP_500_INTERNAL_SERVER_ERROR, "Project update failed during regeneration"
            ) from exc

        # ---------- 6. Commit DB changes ----------
        try:
            if project_name:
                project.project_name = project_name
            if client_name:
                project.client_name = client_name
            if new_overlay_rel is not None:
                project.overlay_image_path = new_overlay_rel
            if new_logo_path:
                project.logo_image_path = str(Path(new_logo_path).resolve())
            db.add(project)
            db.commit()
        except Exception as exc:
            logger.exception("Commit failed; rolling back")
            db.rollback()
            ProjectService._cleanup(new_files)
            raise HTTPException(500, "Failed to commit regenerated images") from exc

        # Outside the try: the commit has landed, so a refresh error must not
        # delete files that the committed rows now reference.
        db.refresh(project)

        # ---------- 7. Delete files replaced by this update (best-effort) ----------
        ProjectService._cleanup(superseded_finals, label="old final image")
        if overlay_file and old_overlay_abs is not None and old_overlay_abs != overlay_abs:
            ProjectService._cleanup([old_overlay_abs], label="old overlay")
        if new_logo_path and old_logo_abs is not None and old_logo_abs != Path(new_logo_path).resolve():
            ProjectService._cleanup([old_logo_abs], label="old logo")

        return {
            "success": True,
            "message": "Project updated successfully.",
            "project_id": project.id,
        }

    @staticmethod
    async def regenerate_single_image(
        db: Session,
        product_image_id: int,
        new_overlay_file,
        current_user
    ):
        """
        Regenerate one final image using a temporary overlay.

        The uploaded overlay is kept only for the render and is always removed
        afterwards. The owning project's logo (if any) is passed to the render,
        as in ``create_project`` and ``update_project``. The row is switched to
        the new image only if rendering and the commit both succeed; the old
        image file is deleted afterwards.

        Raises:
            HTTPException(400): the overlay fails validation.
            HTTPException(404): the image row is missing or deleted, or its
                template is missing.
            HTTPException(500): rendering or the DB update fails.
        """
        # 1. Validate overlay
        try:
            validate_overlay_file(new_overlay_file)
        except Exception as exc:
            raise HTTPException(status.HTTP_400_BAD_REQUEST, str(exc)) from exc

        # 2. Fetch the ProductImage record
        img_row = db.get(ProductImage, product_image_id)
        if not img_row or img_row.is_deleted:
            raise HTTPException(status.HTTP_404_NOT_FOUND, "Image record not found")

        # 3. Fetch the template and the owning project's logo
        template = ProductTemplateRepository.get_by_id(db, img_row.product_template_id)
        if not template:
            raise HTTPException(404, "Associated template not found")

        project = ProjectRepository.get_project(db, img_row.project_master_id)
        logo_arg = str(project.logo_image_path) if project and project.logo_image_path else None

        old_final_path = ProjectService._resolve_stored_path(img_row.final_image_path)

        # 4. Temporary overlay; always removed in the finally below.
        temp_overlay_path = Path(await save_upload_file_streamed(new_overlay_file, STORAGE_OVERLAY))
        try:
            # 5. Generate new final image
            try:
                new_final_path = Path(
                    await run_in_threadpool(
                        generate_final_image,
                        template.file_path,
                        str(temp_overlay_path),
                        str(STORAGE_FINAL),
                        logo_path=logo_arg,
                    )
                )
            except Exception as exc:
                logger.exception("Failed to regenerate final image for product image %s", product_image_id)
                raise HTTPException(500, "Failed to regenerate final image") from exc

            # 6. Point the row at the new image; drop the new file if this fails
            try:
                img_row.final_image_path = str(new_final_path)
                db.add(img_row)
                db.commit()
            except Exception as exc:
                db.rollback()
                ProjectService._cleanup([new_final_path])
                raise HTTPException(500, "Failed to update database") from exc
        finally:
            ProjectService._cleanup([temp_overlay_path])

        db.refresh(img_row)

        # 7. Delete old final image (best-effort)
        if old_final_path is not None and old_final_path != new_final_path.resolve():
            ProjectService._cleanup([old_final_path], label="old final image")

        return {
            "success": True,
            "message": "Final image regenerated successfully",
            "product_image_id": img_row.id,
            "project_master_id": img_row.project_master_id,
            "new_final_image_path": str(new_final_path)
        }

    @staticmethod
    def delete_project(session: Session, project_id: int):
        """
        Soft-delete a project, remove its final image files (best-effort),
        then delete its rows via ``ProductImageRepository.delete_by_project``.

        Raises:
            HTTPException(404): the project does not exist.
        """
        project = ProjectRepository.get_project(session, project_id)
        if not project:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERR_NOT_FOUND)

        ProjectRepository.soft_delete_project(session, project)

        images = ProductImageRepository.list_by_project(session, project_id)
        # Stored paths may be root-relative (written by update_project), so
        # resolve them before deleting instead of using them against the CWD.
        ProjectService._cleanup(
            [ProjectService._resolve_stored_path(img.final_image_path) for img in images],
            label="final image file",
        )
        ProductImageRepository.delete_by_project(session, project_id)
