Files
aza/APP/fotoapp - Kopie/lut.py

223 lines
6.8 KiB
Python
Raw Normal View History

2026-03-25 14:14:07 +01:00
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple
import numpy as np
@dataclass(eq=False)
class CubeLUT:
    """
    A parsed 3D colour lookup table.

    Equality and hashing are by identity (``eq=False``). The field-wise
    ``__eq__`` a plain ``@dataclass`` would generate compares the NumPy array
    fields element-wise and then needs a single bool, which raises
    ``ValueError`` ("truth value of an array ... is ambiguous") for any two
    distinct instances.
    """

    # Number of grid points along each of the R, G and B axes.
    size: int
    # shape: (size, size, size, 3), float32 output RGB, indexed table[r, g, b].
    # Values are stored as read from the file; they are not clamped here.
    table: np.ndarray
    # shape (3,): input value that maps onto grid index 0 for each channel.
    domain_min: np.ndarray
    # shape (3,): input value that maps onto grid index size-1 for each channel.
    domain_max: np.ndarray
class CubeParseError(Exception):
    """Raised when a ``.cube`` LUT file is missing or cannot be parsed."""
def _cube_header_int(parts: list[str], keyword: str) -> int:
    """Parse the single integer argument of a ``KEYWORD N`` header line."""
    if len(parts) != 2:
        raise CubeParseError(f"Invalid {keyword} line")
    try:
        return int(parts[1])
    except ValueError:
        raise CubeParseError(f"Invalid {keyword} line") from None


def _cube_header_floats(parts: list[str], keyword: str, count: int) -> list[float]:
    """Parse exactly *count* float arguments that follow a header keyword."""
    if len(parts) != count + 1:
        raise CubeParseError(f"Invalid {keyword} line")
    try:
        return [float(v) for v in parts[1:]]
    except ValueError:
        raise CubeParseError(f"Invalid {keyword} line") from None


def load_cube(path: str | Path) -> CubeLUT:
    """
    Parse a ``.cube`` file containing a 3D LUT (common sizes 17/33/65).

    Recognised lines:

    * ``TITLE`` -- ignored.
    * ``LUT_3D_SIZE N`` -- grid size; must be at least 2.
    * ``LUT_1D_SIZE N`` -- the N 1D shaper rows are skipped, not applied,
      so LUTs that rely on a shaper are only approximated.
    * ``DOMAIN_MIN r g b`` / ``DOMAIN_MAX r g b`` -- per-channel input domain.
    * ``LUT_3D_INPUT_RANGE lo hi`` -- the same input domain for all channels.
    * Data rows of three floats. ``#`` comments and blank lines are ignored,
      as are other keyword lines.

    Args:
        path: Location of the ``.cube`` file.

    Returns:
        A ``CubeLUT`` whose ``table`` is indexed ``table[r, g, b]``.

    Raises:
        CubeParseError: if the file does not exist, a header line is
            malformed, ``LUT_3D_SIZE`` is missing or smaller than 2, or the
            number of 3D data rows is not ``size ** 3``.
    """
    p = Path(path)
    if not p.exists():
        raise CubeParseError(f"LUT file not found: {p}")

    size_3d: Optional[int] = None
    domain_min = np.array([0.0, 0.0, 0.0], dtype=np.float32)
    domain_max = np.array([1.0, 1.0, 1.0], dtype=np.float32)
    data_3d: list[Tuple[float, float, float]] = []
    # Number of 1D shaper rows still to discard. When a file carries both
    # tables, the 1D rows come before the 3D rows. Counting down from the
    # LUT_1D_SIZE declaration works whether or not LUT_3D_SIZE was already
    # seen; the old "only before LUT_3D_SIZE" condition mixed the 1D rows into
    # the 3D data when both size headers preceded the data.
    rows_1d_remaining = 0

    # "utf-8-sig" strips a leading BOM (common in files saved on Windows),
    # which would otherwise be glued to the first keyword and hide it.
    with p.open("r", encoding="utf-8-sig", errors="ignore") as f:
        for raw in f:
            # Drop any trailing "#" comment, then surrounding whitespace.
            line = raw.split("#", 1)[0].strip()
            if not line:
                continue
            parts = line.split()
            key = parts[0].upper()

            if key == "TITLE":
                continue
            if key == "LUT_3D_SIZE":
                size_3d = _cube_header_int(parts, key)
                continue
            if key == "LUT_1D_SIZE":
                # Lines with the wrong number of fields are still tolerated,
                # as before; a non-integer size is reported as a parse error.
                if len(parts) == 2:
                    rows_1d_remaining = _cube_header_int(parts, key)
                continue
            if key == "DOMAIN_MIN":
                domain_min = np.array(_cube_header_floats(parts, key, 3), dtype=np.float32)
                continue
            if key == "DOMAIN_MAX":
                domain_max = np.array(_cube_header_floats(parts, key, 3), dtype=np.float32)
                continue
            if key == "LUT_3D_INPUT_RANGE":
                lo, hi = _cube_header_floats(parts, key, 2)
                domain_min = np.full(3, lo, dtype=np.float32)
                domain_max = np.full(3, hi, dtype=np.float32)
                continue

            if len(parts) < 3:
                continue
            try:
                row = (float(parts[0]), float(parts[1]), float(parts[2]))
            except ValueError:
                # Another keyword line (e.g. LUT_1D_INPUT_RANGE); ignore it.
                continue
            if rows_1d_remaining > 0:
                rows_1d_remaining -= 1
                continue
            data_3d.append(row)

    if size_3d is None:
        raise CubeParseError("Missing LUT_3D_SIZE in .cube file")
    if size_3d < 2:
        raise CubeParseError(f"LUT_3D_SIZE must be at least 2, got {size_3d}")
    expected = size_3d ** 3
    if len(data_3d) != expected:
        raise CubeParseError(f"Expected {expected} LUT rows, got {len(data_3d)}")

    arr = np.array(data_3d, dtype=np.float32)
    # .cube standard: R varies fastest, then G, then B (slowest).
    # C-order reshape gives arr[B, G, R, channels].
    # Transpose axes 0<->2 so table[R, G, B] = output RGB for intuitive lookup.
    arr = arr.reshape((size_3d, size_3d, size_3d, 3), order="C")
    arr = arr.transpose(2, 1, 0, 3).copy()
    return CubeLUT(size=size_3d, table=arr, domain_min=domain_min, domain_max=domain_max)
# Debug switch used by apply_lut_rgb_float01 when its swap_rb argument is
# None. When True, the R and B channels are swapped before the LUT lookup and
# swapped back afterwards, which helps diagnose channel-order mix-ups.
# Leave False in normal use.
_DEBUG_SWAP_RB: bool = False
def apply_lut_rgb_float01(
    rgb: np.ndarray,
    lut: CubeLUT,
    swap_rb: Optional[bool] = None,
) -> np.ndarray:
    """
    Apply a 3D LUT to RGB pixels using trilinear interpolation.

    Input must be sRGB-encoded float values in [0,1] whose last axis holds the
    channels, e.g. an image of shape (H, W, 3) or a pixel list of shape (N, 3).
    The output has the same shape as the input. Values are mapped from the
    LUT's [domain_min, domain_max] onto the grid and clamped to it.
    LUTs designed for Log/ACES input will NOT produce correct results
    without an additional input transform.

    Args:
        rgb: Pixel data with a trailing channel axis; converted to float32
            if needed.
        lut: The LUT to apply; ``lut.table`` is indexed ``[r, g, b]``.
        swap_rb: If True, swap R and B before the lookup and swap them back
            afterwards (debug aid for channel-order issues). ``None`` falls
            back to the module-level ``_DEBUG_SWAP_RB`` flag.

    Returns:
        A float32 array shaped like ``rgb``, clipped to [0, 1].
    """
    do_swap = swap_rb if swap_rb is not None else _DEBUG_SWAP_RB
    rgb = np.asarray(rgb, dtype=np.float32)
    if do_swap:
        rgb = rgb[..., ::-1]

    # Map from the LUT domain to [0, 1]. The domain vectors are kept as (3,)
    # so they broadcast against any (..., 3) input without adding axes. An
    # explicit (1, 1, 3) reshape turned (N, 3) input into (1, N, 3) output.
    dmin = np.asarray(lut.domain_min).reshape(3)
    dmax = np.asarray(lut.domain_max).reshape(3)
    denom = np.maximum(dmax - dmin, 1e-12)  # avoid dividing by a zero-width domain
    x = np.clip((rgb - dmin) / denom, 0.0, 1.0)

    n = lut.size
    xg = x * (n - 1)
    # Floor indices clamped to [0, n-2] so i0+1 is always valid. At x == 1
    # this gives i0 = n-2 with weight t = 1, i.e. exactly the last grid node.
    i0 = np.clip(np.floor(xg).astype(np.int32), 0, n - 2)
    i1 = i0 + 1
    t = np.clip(xg - i0.astype(np.float32), 0.0, 1.0)

    r0, g0, b0 = i0[..., 0], i0[..., 1], i0[..., 2]
    r1, g1, b1 = i1[..., 0], i1[..., 1], i1[..., 2]
    # Keep a trailing length-1 axis so the weights broadcast over RGB outputs.
    tr = t[..., 0:1]
    tg = t[..., 1:2]
    tb = t[..., 2:3]

    table = lut.table
    # 8 corners of the enclosing grid cell
    c000 = table[r0, g0, b0]
    c001 = table[r0, g0, b1]
    c010 = table[r0, g1, b0]
    c011 = table[r0, g1, b1]
    c100 = table[r1, g0, b0]
    c101 = table[r1, g0, b1]
    c110 = table[r1, g1, b0]
    c111 = table[r1, g1, b1]

    # Interpolate along B axis
    c00 = c000 + (c001 - c000) * tb
    c01 = c010 + (c011 - c010) * tb
    c10 = c100 + (c101 - c100) * tb
    c11 = c110 + (c111 - c110) * tb
    # Interpolate along G axis
    c0 = c00 + (c01 - c00) * tg
    c1 = c10 + (c11 - c10) * tg
    # Interpolate along R axis
    out = c0 + (c1 - c0) * tr

    out = np.clip(out, 0.0, 1.0).astype(np.float32)
    if do_swap:
        out = out[..., ::-1].copy()
    return out
def blend_rgb(a: np.ndarray, b: np.ndarray, strength_0_1: float) -> np.ndarray:
    """
    Linearly mix *a* towards *b*.

    The weight is ``strength_0_1`` clamped to [0, 1]: 0 returns *a*,
    1 returns *b*. The result is cast to float32.
    """
    weight = float(np.clip(strength_0_1, 0.0, 1.0))
    keep = 1.0 - weight
    mixed = a * keep + b * weight
    return mixed.astype(np.float32)
# ─── LUT cache & chain ──────────────────────────────────────────────────────
# Parsed LUTs keyed by the resolved path string. Filled by get_cached_lut()
# and emptied by clear_lut_cache().
_lut_cache: dict[str, CubeLUT] = dict()
def get_cached_lut(path: str | Path) -> CubeLUT:
    """
    Return the LUT parsed from *path*, calling load_cube() only on first use.

    Entries are keyed by the resolved path, so different spellings of the same
    file share one entry. Every call returns the same ``CubeLUT`` instance
    (not a copy), so in-place edits to its arrays are seen by later callers.
    Edits to the file on disk are not picked up until clear_lut_cache().
    """
    cache_key = str(Path(path).resolve())
    try:
        return _lut_cache[cache_key]
    except KeyError:
        pass
    loaded = load_cube(path)
    _lut_cache[cache_key] = loaded
    return loaded
def clear_lut_cache() -> None:
    """Drop every LUT stored by get_cached_lut(), forcing reloads on next use."""
    _lut_cache.clear()
def apply_lut_chain(
    rgb: np.ndarray,
    chain: list[tuple[CubeLUT, float]],
) -> np.ndarray:
    """
    Apply several LUTs in sequence, each mixed in at its own strength.

    Every ``(lut, strength_0_100)`` pair updates the running image as::

        current = clip(lerp(current, apply_lut(current), strength / 100), 0, 1)

    so each step sees the previous step's result. Pairs with a strength of 0
    or less are skipped. Strengths above 100 are clamped by blend_rgb.

    Returns:
        A float32 array. If every step is skipped, this is ``rgb`` cast to
        float32 without clipping.
    """
    current = rgb
    for lut, strength in chain:
        if strength <= 0:
            continue  # disabled step
        graded = apply_lut_rgb_float01(current, lut)
        current = np.clip(blend_rgb(current, graded, strength / 100.0), 0.0, 1.0)
    return current.astype(np.float32)