from typing import List, Optional
from pathlib import Path
import uuid

from PIL import Image
from psd_tools import PSDImage

from app.utils.image import (
    render_layer_alpha_fullsize,
    resize_and_crop_to_fill
)
from app.core.constants import KEYWORDS, LOGO_LAYER


def _safe_bbox(layer) -> Optional[tuple]:
    """Return ``(x1, y1, x2, y2)`` for *layer*, or ``None`` if unusable.

    ``layer.bbox`` may be a 4-sequence or an object exposing ``x1/y1/x2/y2``.
    Missing boxes, values that cannot be coerced to ``int`` and degenerate
    boxes (zero or negative width/height) all yield ``None``.
    """
    bbox = getattr(layer, "bbox", None)
    if not bbox:
        return None
    try:
        if isinstance(bbox, (tuple, list)):
            x1, y1, x2, y2 = map(int, bbox)
        else:
            x1, y1, x2, y2 = int(bbox.x1), int(bbox.y1), int(bbox.x2), int(bbox.y2)
    except Exception:
        return None
    return (x1, y1, x2, y2) if (x2 > x1 and y2 > y1) else None


def _collect_layers(layers) -> tuple:
    """Split visible layers into ``(design_layers, logo_layers)``.

    Each returned list holds ``(layer, bbox)`` pairs. A layer whose lowercase
    name contains any ``LOGO_LAYER`` entry is a logo layer. Otherwise it is a
    design layer if it is a smart object or its name contains any
    ``KEYWORDS`` entry. Layers without a usable bbox are skipped.
    """
    design: List[tuple] = []
    logos: List[tuple] = []
    for layer in layers:
        try:
            if hasattr(layer, "visible") and not layer.visible:
                continue

            name = (getattr(layer, "name", "") or "").lower()

            if any(k in name for k in LOGO_LAYER):
                bbox = _safe_bbox(layer)
                if bbox:
                    logos.append((layer, bbox))
                continue

            kind = getattr(layer, "kind", None)
            if kind == "smartobject" or any(k in name for k in KEYWORDS):
                bbox = _safe_bbox(layer)
                if bbox:
                    design.append((layer, bbox))
        except Exception:
            # Deliberate best effort: one malformed layer must not abort
            # detection for the rest of the document.
            continue
    return design, logos


def _drop_containers(items: List[tuple]) -> List[tuple]:
    """Keep only entries whose bbox does not strictly contain another's.

    A container (e.g. a group) and its child often both match the
    detection rules; only the innermost one should be painted. "Strictly"
    means the outer box covers the inner one and has a larger area, so an
    entry never excludes itself and equal boxes never exclude each other.
    Returns *items* unchanged if nothing would survive.
    """
    def area(b):
        return (b[2] - b[0]) * (b[3] - b[1])

    def contains(outer, inner):
        return (
            outer[0] <= inner[0] and outer[1] <= inner[1]
            and outer[2] >= inner[2] and outer[3] >= inner[3]
            and area(outer) > area(inner)
        )

    kept = [
        (layer, bb)
        for layer, bb in items
        if not any(contains(bb, other) for _, other in items)
    ]
    return kept or items


def _same_wrap(b1, b2, y_tol_ratio=0.25, gap_tol_ratio=0.5) -> bool:
    """Return True if bboxes *b1* and *b2* look like parts of one wrap.

    The vertical centres must differ by at most ``y_tol_ratio`` of the
    taller box's height. The horizontal gap between the boxes must be at
    most ``gap_tol_ratio`` of their average width.
    """
    try:
        x11, y11, x12, y12 = b1
        x21, y21, x22, y22 = b2

        h1, h2 = y12 - y11, y22 - y21
        yc1 = (y11 + y12) / 2
        yc2 = (y21 + y22) / 2

        if abs(yc1 - yc2) > max(h1, h2) * y_tol_ratio:
            return False

        gap = max(0, x21 - x12, x11 - x22)
        avgw = ((x12 - x11) + (x22 - x21)) / 2
        return gap <= avgw * gap_tol_ratio
    except Exception:
        return False


def _union_bbox(bboxes) -> tuple:
    """Return the smallest ``(x1, y1, x2, y2)`` enclosing every bbox given."""
    boxes = list(bboxes)
    return (
        min(b[0] for b in boxes),
        min(b[1] for b in boxes),
        max(b[2] for b in boxes),
        max(b[3] for b in boxes),
    )


def _group_by_wrap(items: List[tuple]) -> List[tuple]:
    """Chain ``(layer, bbox)`` pairs, sorted by left edge, into wrap groups.

    Each layer joins the current group when ``_same_wrap`` accepts it against
    the group's most recently added layer. Returns a list of
    ``(members, union_bbox)`` pairs.
    """
    groups: List[tuple] = []
    cur: List[tuple] = []
    cur_bbox = None

    for layer, bbox in sorted(items, key=lambda x: x[1][0]):
        if cur and _same_wrap(cur[-1][1], bbox):
            cur.append((layer, bbox))
            cur_bbox = _union_bbox((cur_bbox, bbox))
        else:
            if cur:
                groups.append((cur, cur_bbox))
            cur = [(layer, bbox)]
            cur_bbox = bbox

    if cur:
        groups.append((cur, cur_bbox))
    return groups


def _split_wrap_and_tilt(
    groups: List[tuple],
    canvas_h: int,
    min_height_ratio: float = 0.40,
    max_aspect: float = 3.5,
) -> tuple:
    """Partition groups into ``(wrap_groups, tilt_groups)``.

    A group is a "tilt" group when it is shorter than ``min_height_ratio`` of
    the canvas height or wider than ``max_aspect`` times its height. All
    tilt groups are kept, not just the last one.
    """
    wrap_groups: List[tuple] = []
    tilt_groups: List[tuple] = []
    for g, bb in groups:
        w = bb[2] - bb[0]
        h = bb[3] - bb[1]
        if h < min_height_ratio * canvas_h or (w / max(1, h)) > max_aspect:
            tilt_groups.append((g, bb))
        else:
            wrap_groups.append((g, bb))
    return wrap_groups, tilt_groups


def _internal_width(layer) -> int:
    """Return the layer's smart-object pixel width if discoverable, else its bbox width.

    Several candidate attribute names on ``layer.smart_object`` are probed on
    a best-effort basis. Any failure falls through to the bbox width. The
    result is always at least 1.
    """
    try:
        so = getattr(layer, "smart_object", None)
        if so:
            # Try loading the embedded smart-object image.
            for attr in ("image", "as_PIL", "as_pil", "as_PIL_image"):
                fn = getattr(so, attr, None)
                if callable(fn):
                    try:
                        img = fn()
                        if hasattr(img, "size"):
                            return max(1, img.size[0])
                    except Exception:
                        pass

            # Try plain width attributes.
            w = getattr(so, "width", None) or getattr(so, "w", None)
            if isinstance(w, int):
                return max(1, w)
            if isinstance(w, (tuple, list)):
                return max(1, w[0])
    except Exception:
        pass

    bbox = _safe_bbox(layer)
    return max(1, bbox[2] - bbox[0]) if bbox else 1


def _slice_overlay(overlay_img, wrap_groups: List[tuple]) -> list:
    """Cut *overlay_img* into vertical strips, one per wrap group, left to right.

    Strip widths are proportional to each group's summed internal layer
    width. Cut positions use cumulative integer division, so rounding error
    is spread across the strips instead of piling up in the last one. The
    last cut always lands on the image's right edge.
    """
    overlay_w, overlay_h = overlay_img.size
    widths = [sum(max(_internal_width(l), 1) for l, _ in g) for g, _ in wrap_groups]
    # >= 1: wrap_groups is non-empty, every group has a member, every width >= 1.
    total = sum(widths)

    cuts = [0]
    acc = 0
    for w in widths:
        acc += w
        cuts.append(overlay_w * acc // total)

    return [
        overlay_img.crop((left, 0, right, overlay_h))
        for left, right in zip(cuts, cuts[1:])
    ]


def _hide_layers(items: List[tuple]) -> None:
    """Best-effort set ``visible = False`` on every layer in ``(layer, bbox)`` pairs."""
    for layer, _ in items:
        try:
            layer.visible = False
        except Exception:
            pass


def _paste_masked(base, img, bbox: tuple, mask) -> None:
    """Paste *img* into *base* at *bbox*, masked by the bbox region of *mask*.

    *mask* is a full-canvas mask image or ``None``. Without a mask, *img* is
    used as its own mask. If the masked paste raises, a plain paste is
    attempted instead.
    """
    x1, y1, x2, y2 = bbox
    w, h = x2 - x1, y2 - y1
    mask_zone = mask.crop((x1, y1, x2, y2)).resize((w, h)) if mask is not None else None
    try:
        base.paste(img, (x1, y1), mask_zone if mask_zone is not None else img)
    except Exception:
        base.paste(img, (x1, y1))


def _apply_group(base, group_layers: List[tuple], gbb: tuple, seg, layer_masks: dict) -> None:
    """Fill a group's bbox with *seg* and paste each member's slice of it.

    *seg* is resized and cropped to cover the whole group bbox *gbb*. Each
    member layer then receives the part of that image lying under its own
    bbox, masked by ``layer_masks[id(layer)]`` when one exists.
    """
    gx1, gy1, gx2, gy2 = gbb
    gw, gh = gx2 - gx1, gy2 - gy1
    if gw <= 0 or gh <= 0:
        return

    resized = resize_and_crop_to_fill(seg, gw, gh)

    for layer, bb in group_layers:
        x1, y1, x2, y2 = bb
        rx, ry = x1 - gx1, y1 - gy1
        piece = resized.crop((rx, ry, rx + (x2 - x1), ry + (y2 - y1)))
        _paste_masked(base, piece, bb, layer_masks.get(id(layer)))


def generate_final_image(
    base_path: str,
    overlay_path: str,
    output_dir: str,
    logo_path: Optional[str] = None
) -> str:
    """Render *overlay_path* (and optionally *logo_path*) into a PSD mockup.

    Design layers (smart objects, or names containing a ``KEYWORDS`` entry)
    and logo layers (names containing a ``LOGO_LAYER`` entry) are located in
    the PSD at *base_path*. Design layers that strictly enclose another
    design layer are ignored. The detected layers are hidden and the rest of
    the document is flattened as the background. Logo layers are hidden even
    when no logo is supplied.

    Design layers are chained into wrap groups by position. With two or more
    design layers, the overlay is split into vertical strips, one per wrap
    group, sized by the groups' (smart-object or bbox) widths. With a single
    design layer, the whole overlay is used. Groups that are short
    (< 40 % of canvas height) or very wide (aspect > 3.5) are "tilt" groups
    and get the last strip, or the whole overlay when not splitting. Every
    painted area is masked by its layer's rendered alpha. Logo layers are
    filled with the logo image the same way.

    Args:
        base_path: Path to the PSD template.
        overlay_path: Path to the design image to place.
        output_dir: Directory for the PNG result. Created if missing.
        logo_path: Optional logo image path. ``None``, ``""`` or ``"None"``
            means no logo.

    Returns:
        Path (as ``str``) of the written ``<uuid4 hex>_final.png``.

    Raises:
        ValueError: If no design layer with a usable bbox is found.
    """
    psd = PSDImage.open(base_path)
    design, logo_layers = _collect_layers(psd.descendants())

    if not design:
        raise ValueError("No design layers detected in PSD")

    design = _drop_containers(design)

    wrap_groups, tilt_groups = _split_wrap_and_tilt(
        _group_by_wrap(design), getattr(psd, "height", 0)
    )
    if not wrap_groups:
        # Every group was classified as tilt: treat all design layers as one
        # wrap group spanning their union bbox. The tilt groups are dropped
        # because their layers are now in that group and would otherwise be
        # pasted twice.
        wrap_groups = [(design, _union_bbox(bb for _, bb in design))]
        tilt_groups = []

    is_wrap = len(design) >= 2

    # Load images. Context managers close the file handles once the pixel
    # data has been copied by convert().
    with Image.open(overlay_path) as src:
        overlay_img = src.convert("RGBA")

    if logo_path in (None, "", "None"):
        logo_img = None
    else:
        with Image.open(logo_path) as src:
            logo_img = src.convert("RGBA")

    overlay_slices = _slice_overlay(overlay_img, wrap_groups) if is_wrap else [overlay_img]

    # Hide the placeholder layers so the flattened background excludes them.
    _hide_layers(design)
    _hide_layers(logo_layers)

    # NOTE(review): depending on the psd-tools version, PSDImage.composite()
    # may return the embedded pre-composited preview and ignore the
    # visibility changes above. Confirm the installed version re-renders
    # (or pass ignore_preview=True / force=True where supported).
    base = psd.composite().convert("RGBA")
    canvas_size = base.size

    # NOTE(review): masks are rendered after the layers were hidden. This
    # assumes render_layer_alpha_fullsize ignores the visible flag. TODO confirm.
    layer_masks = {
        id(layer): render_layer_alpha_fullsize(psd, layer, canvas_size)
        for layer, _ in design
    }
    if logo_img is not None:
        # Logo masks are only needed when a logo will actually be pasted.
        for layer, _ in logo_layers:
            layer_masks[id(layer)] = render_layer_alpha_fullsize(psd, layer, canvas_size)

    # Paint design overlays: wrap groups first, then every tilt group.
    if is_wrap:
        for (grp, bb), seg in zip(wrap_groups, overlay_slices):
            _apply_group(base, grp, bb, seg, layer_masks)
        tilt_seg = overlay_slices[-1]
    else:
        for grp, bb in wrap_groups:
            _apply_group(base, grp, bb, overlay_img, layer_masks)
        tilt_seg = overlay_img

    for grp, bb in tilt_groups:
        _apply_group(base, grp, bb, tilt_seg, layer_masks)

    # Paint the logo into each logo layer's area.
    if logo_img is not None:
        for layer, bb in logo_layers:
            x1, y1, x2, y2 = bb
            w, h = x2 - x1, y2 - y1
            if w <= 0 or h <= 0:
                continue
            resized_logo = resize_and_crop_to_fill(logo_img, w, h)
            _paste_masked(base, resized_logo, bb, layer_masks.get(id(layer)))

    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    final_path = out_dir / f"{uuid.uuid4().hex}_final.png"
    base.save(final_path)

    return str(final_path)
