Files
aza/APP/fotoapp - Kopie/segmentation.py
2026-03-25 14:14:07 +01:00

207 lines
6.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""KI-based person segmentation, mask utilities and compositing."""
from __future__ import annotations
import os
from collections import deque
from pathlib import Path
from typing import Optional, Tuple
import numpy as np
from PIL import Image
# Hosted MediaPipe Tasks model: multiclass selfie segmentation, float32,
# 256x256 input (~16 MB download).
_MODEL_URL = (
    "https://storage.googleapis.com/mediapipe-models/"
    "image_segmenter/selfie_multiclass_256x256/float32/latest/"
    "selfie_multiclass_256x256.tflite"
)
# Per-user cache directory. NOTE(review): %APPDATA% is Windows-specific; on
# other platforms the env var is unset and this becomes a relative
# "FotoApp" path — confirm the app only targets Windows.
_MODEL_DIR = Path(os.environ.get("APPDATA", "")) / "FotoApp"
_MODEL_PATH = _MODEL_DIR / "selfie_multiclass.tflite"
# Process-wide cached ImageSegmenter instance; lazily created by
# _get_segmenter() on first use.
_segmenter = None
def _ensure_model(progress_cb=None) -> str:
    """Download the segmentation model on first use (~16 MB).

    Parameters
    ----------
    progress_cb : callable, optional
        Called with a human-readable status string before the slow download.

    Returns
    -------
    str
        Filesystem path of the cached ``.tflite`` model file.
    """
    _MODEL_DIR.mkdir(parents=True, exist_ok=True)
    path = str(_MODEL_PATH)
    # Size sanity check in addition to existence: guards against a tiny,
    # obviously-broken file being treated as a valid model.
    if os.path.isfile(path) and os.path.getsize(path) > 1_000_000:
        return path
    if progress_cb:
        progress_cb("KI-Modell wird heruntergeladen (~16 MB) …")
    import urllib.request
    # Download to a temporary name and move it into place atomically.
    # Previously urlretrieve wrote straight to *path*, so an interrupted
    # download could leave a corrupt file large enough to pass the size
    # check above on the next start.
    tmp_path = path + ".part"
    try:
        urllib.request.urlretrieve(_MODEL_URL, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        # Best-effort cleanup of the partial file before re-raising.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    return path
def _get_segmenter(progress_cb=None):
    """Lazily create and cache the MediaPipe ImageSegmenter.

    Parameters
    ----------
    progress_cb : callable, optional
        Called with a status string before slow steps (download / load).

    Returns
    -------
    mediapipe.tasks.python.vision.ImageSegmenter
        The process-wide cached segmenter instance.
    """
    global _segmenter
    if _segmenter is not None:
        return _segmenter
    # The previously present ``import mediapipe as mp`` was unused here;
    # importing the Tasks API below pulls in mediapipe anyway.
    from mediapipe.tasks.python import BaseOptions, vision

    model_path = _ensure_model(progress_cb)
    if progress_cb:
        progress_cb("KI-Modell wird geladen …")
    options = vision.ImageSegmenterOptions(
        base_options=BaseOptions(model_asset_path=model_path),
        # request the per-pixel integer class-label mask
        output_category_mask=True,
    )
    _segmenter = vision.ImageSegmenter.create_from_options(options)
    return _segmenter
# ─── KI segmentation ────────────────────────────────────────────────────────
def segment_person(img: Image.Image, progress_cb=None) -> np.ndarray:
    """Segment the person in *img*, returning a float01 alpha mask.

    Runs MediaPipe Selfie Segmentation (Tasks API). Fast (~1-2 s) and
    lightweight (~16 MB model, low RAM); the model is downloaded
    automatically on first use. The model labels category 0 as background
    and categories 1-5 as person parts.
    """
    import mediapipe as mp

    segmenter = _get_segmenter(progress_cb)
    if progress_cb:
        progress_cb("Segmentierung läuft …")
    pixels = np.asarray(img.convert("RGB"))
    frame = mp.Image(image_format=mp.ImageFormat.SRGB, data=pixels)
    labels = segmenter.segment(frame).category_mask.numpy_view()
    # collapse per-pixel class ids into a binary person/background mask
    return (labels.squeeze() > 0).astype(np.float32)
# ─── Mask operations ────────────────────────────────────────────────────────
def feather_mask(mask: np.ndarray, radius_px: float) -> np.ndarray:
    """Soften the edges of an alpha *mask* with a Gaussian blur.

    *radius_px* <= 0 returns the input unchanged; otherwise the mask is
    blurred and clipped back into [0, 1] as float32.
    """
    if radius_px <= 0:
        return mask
    import cv2
    kernel = 2 * int(radius_px) + 1  # always odd, as cv2 requires
    softened = cv2.GaussianBlur(mask, (kernel, kernel), sigmaX=radius_px / 2.0)
    return np.clip(softened, 0.0, 1.0).astype(np.float32)
def apply_brush_stroke(
    mask: np.ndarray,
    points: list[Tuple[int, int]],
    radius: int,
    hardness: float,
    add: bool,
) -> np.ndarray:
    """Paint a circular brush along *points* directly into *mask* (in place).

    *hardness* 0..1 controls the edge falloff (1 = hard-edged circle,
    0 = very soft). *add* selects painting towards white (True) versus
    erasing towards black (False). Returns the same *mask* array.
    """
    height, width = mask.shape[:2]
    for px, py in points:
        # clip the brush's bounding box to the image
        top, bottom = max(0, py - radius), min(height, py + radius + 1)
        left, right = max(0, px - radius), min(width, px + radius + 1)
        if top >= bottom or left >= right:
            continue  # brush lies entirely outside the image
        rows, cols = np.ogrid[top:bottom, left:right]
        d = np.sqrt((cols - px) ** 2 + (rows - py) ** 2).astype(np.float32)
        if hardness >= 0.99:
            # hard brush: full strength inside the radius, zero outside
            weight = (d <= radius).astype(np.float32)
        else:
            # soft brush: linear falloff from the hard core to the rim
            core = radius * hardness
            falloff = np.clip(
                (d - core) / max(float(radius) - core, 1e-6), 0.0, 1.0
            )
            weight = 1.0 - falloff
            weight[d > radius] = 0.0
        region = mask[top:bottom, left:right]
        mask[top:bottom, left:right] = (
            np.maximum(region, weight) if add else np.minimum(region, 1.0 - weight)
        )
    return mask
# ─── Compositing ─────────────────────────────────────────────────────────────
def composite_fg_bg(
    fg_rgb: np.ndarray,
    alpha: np.ndarray,
    bg_mode: str,
    bg_color: Tuple[int, int, int] = (255, 255, 255),
    bg_blur_radius: float = 0.0,
    bg_image: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Blend the person foreground over the selected background.

    *fg_rgb* is the colour-graded image as float01 (H, W, 3); *alpha* the
    float01 (H, W) person mask (feathered). *bg_mode* is one of
    ``"original"`` | ``"blur"`` | ``"color"`` | ``"transparent"`` | ``"image"``.
    Returns float01 (H, W, 3), or (H, W, 4) for the transparent mode.
    Unknown modes (and "image" without *bg_image*) fall back to the
    foreground unchanged.
    """
    if bg_mode == "original":
        return fg_rgb
    if bg_mode == "transparent":
        stacked = np.concatenate([fg_rgb, alpha[..., None]], axis=-1)
        return np.clip(stacked, 0.0, 1.0).astype(np.float32)
    if bg_mode == "blur":
        import cv2
        k = max(1, int(bg_blur_radius)) * 2 + 1  # odd kernel size for cv2
        backdrop = cv2.GaussianBlur(fg_rgb, (k, k), sigmaX=bg_blur_radius / 2.0)
    elif bg_mode == "color":
        backdrop = np.full_like(fg_rgb, [c / 255.0 for c in bg_color])
    elif bg_mode == "image" and bg_image is not None:
        h, w = fg_rgb.shape[:2]
        # resize the replacement background to match the foreground
        pil_bg = Image.fromarray((bg_image * 255).astype(np.uint8))
        resized = pil_bg.resize((w, h), Image.Resampling.LANCZOS)
        backdrop = np.asarray(resized, dtype=np.float32) / 255.0
    else:
        return fg_rgb
    weight = alpha[..., None]
    blended = fg_rgb * weight + backdrop * (1.0 - weight)
    return np.clip(blended, 0.0, 1.0).astype(np.float32)
# ─── Mask undo stack ─────────────────────────────────────────────────────────
class MaskHistory:
    """Bounded undo buffer holding up to *maxlen* mask snapshots."""

    def __init__(self, maxlen: int = 20):
        # deque discards the oldest snapshot automatically once full
        self._snapshots: deque[np.ndarray] = deque(maxlen=maxlen)

    def push(self, mask: np.ndarray):
        """Store an independent copy of *mask* as the newest undo state."""
        self._snapshots.append(mask.copy())

    def undo(self) -> Optional[np.ndarray]:
        """Pop and return the newest snapshot, or ``None`` when empty."""
        return self._snapshots.pop() if self._snapshots else None

    def clear(self):
        """Drop every stored snapshot."""
        self._snapshots.clear()

    @property
    def can_undo(self) -> bool:
        """True while at least one snapshot can be undone."""
        return bool(self._snapshots)