from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple

import numpy as np


@dataclass
class CubeLUT:
    """A 3D LUT parsed from a ``.cube`` file.

    ``table`` is indexed as ``table[r, g, b]`` with lattice indices in
    ``[0, size)`` and yields the output RGB triple for that lattice point.
    """

    size: int
    table: np.ndarray  # shape (size, size, size, 3), float32, indexed [R, G, B]
    domain_min: np.ndarray  # shape (3,): input value mapped to lattice index 0
    domain_max: np.ndarray  # shape (3,): input value mapped to lattice index size-1


class CubeParseError(Exception):
    """Raised when a ``.cube`` file is missing or cannot be parsed."""


# Adobe Cube LUT spec: LUT_3D_SIZE must be at least 2 (interpolation needs two
# lattice points per axis).
_MIN_3D_SIZE = 2

# Keywords whose values this loader deliberately ignores.
_IGNORED_KEYWORDS = frozenset({"TITLE", "LUT_1D_INPUT_RANGE"})


def _parse_int_keyword(parts: list[str], keyword: str) -> int:
    """Parse a ``KEYWORD <int>`` line, raising CubeParseError if malformed."""
    if len(parts) != 2:
        raise CubeParseError(f"Invalid {keyword} line")
    try:
        return int(parts[1])
    except ValueError as err:
        raise CubeParseError(f"Invalid {keyword} line") from err


def _parse_float_keyword(parts: list[str], keyword: str, count: int) -> list[float]:
    """Parse a ``KEYWORD`` line followed by exactly ``count`` floats."""
    if len(parts) != count + 1:
        raise CubeParseError(f"Invalid {keyword} line")
    try:
        return [float(v) for v in parts[1:]]
    except ValueError as err:
        raise CubeParseError(f"Invalid {keyword} line") from err


def load_cube(path: str | Path) -> CubeLUT:
    """Parse a 3D ``.cube`` LUT file (common sizes 17/33/65).

    Recognised keywords: ``TITLE`` (ignored), ``LUT_3D_SIZE``,
    ``DOMAIN_MIN`` / ``DOMAIN_MAX``, ``LUT_3D_INPUT_RANGE`` (one min/max
    applied to all three channels), ``LUT_1D_SIZE`` and
    ``LUT_1D_INPUT_RANGE``. A 1D shaper block, if declared, is skipped: its
    rows are assumed to precede the 3D rows (the combined-file layout),
    regardless of the order in which the two size keywords appear. ``#``
    starts a comment, blank lines are ignored, and a UTF-8 BOM is tolerated.
    Unrecognised lines that do not parse as three floats are skipped.

    Args:
        path: Path to the ``.cube`` file.

    Returns:
        A CubeLUT whose ``table`` is indexed ``[R, G, B]``.

    Raises:
        CubeParseError: if the file does not exist, a keyword line is
            malformed, ``LUT_3D_SIZE`` is missing or smaller than 2, or the
            number of 3D data rows differs from ``LUT_3D_SIZE ** 3``.
    """
    p = Path(path)
    if not p.is_file():
        raise CubeParseError(f"LUT file not found: {p}")

    size_3d: Optional[int] = None
    domain_min = np.array([0.0, 0.0, 0.0], dtype=np.float32)
    domain_max = np.array([1.0, 1.0, 1.0], dtype=np.float32)
    data_3d: list[Tuple[float, float, float]] = []
    rows_1d_remaining = 0

    # utf-8-sig strips a leading BOM, which would otherwise glue itself onto
    # the first keyword (e.g. "\ufeffLUT_3D_SIZE") and hide it from the parser.
    with p.open("r", encoding="utf-8-sig", errors="ignore") as f:
        for raw in f:
            line = raw.split("#", 1)[0].strip()
            if not line:
                continue
            parts = line.split()
            key = parts[0].upper()

            if key in _IGNORED_KEYWORDS:
                continue
            if key == "LUT_3D_SIZE":
                size_3d = _parse_int_keyword(parts, key)
                if size_3d < _MIN_3D_SIZE:
                    raise CubeParseError(
                        f"LUT_3D_SIZE must be >= {_MIN_3D_SIZE}, got {size_3d}"
                    )
                continue
            if key == "LUT_1D_SIZE":
                rows_1d_remaining = max(_parse_int_keyword(parts, key), 0)
                continue
            if key == "DOMAIN_MIN":
                domain_min = np.array(_parse_float_keyword(parts, key, 3), dtype=np.float32)
                continue
            if key == "DOMAIN_MAX":
                domain_max = np.array(_parse_float_keyword(parts, key, 3), dtype=np.float32)
                continue
            if key == "LUT_3D_INPUT_RANGE":
                lo, hi = _parse_float_keyword(parts, key, 2)
                domain_min = np.full(3, lo, dtype=np.float32)
                domain_max = np.full(3, hi, dtype=np.float32)
                continue

            if len(parts) < 3:
                continue  # unrecognised short keyword line
            try:
                rgb = (float(parts[0]), float(parts[1]), float(parts[2]))
            except ValueError:
                continue  # unrecognised keyword line
            if rows_1d_remaining > 0:
                # Still inside the 1D shaper block; its data is not used here.
                # (Previously this was gated on "3D header not yet seen", which
                # let shaper rows leak into the 3D data in headers-first files.)
                rows_1d_remaining -= 1
                continue
            data_3d.append(rgb)

    if size_3d is None:
        raise CubeParseError("Missing LUT_3D_SIZE in .cube file")
    expected = size_3d ** 3
    if len(data_3d) != expected:
        raise CubeParseError(f"Expected {expected} LUT rows, got {len(data_3d)}")

    # .cube order: R varies fastest, then G, then B (slowest). A C-order
    # reshape therefore yields arr[B, G, R, channel]; swapping axes 0 and 2
    # gives table[R, G, B] for intuitive lookup.
    arr = np.array(data_3d, dtype=np.float32).reshape(
        (size_3d, size_3d, size_3d, 3), order="C"
    )
    table = arr.transpose(2, 1, 0, 3).copy()
    return CubeLUT(size=size_3d, table=table, domain_min=domain_min, domain_max=domain_max)


# Set True to swap R<->B channels before/after LUT lookup (debug only).
_DEBUG_SWAP_RB = False


def apply_lut_rgb_float01(
    rgb: np.ndarray,
    lut: CubeLUT,
    swap_rb: Optional[bool] = None,
) -> np.ndarray:
    """Apply a 3D LUT with trilinear interpolation.

    Args:
        rgb: Array-like of shape ``(..., 3)``, e.g. an ``(H, W, 3)`` image or
            an ``(N, 3)`` batch of pixels; converted to float32. Values are
            mapped from ``[lut.domain_min, lut.domain_max]`` onto the lattice
            and clamped to it. For typical display LUTs this should be
            sRGB-encoded data in [0, 1]; LUTs designed for Log/ACES input will
            NOT produce correct results without an additional input transform.
        lut: The LUT to apply.
        swap_rb: If True, reverse the channel axis (swap R<->B) before lookup
            and again afterwards (debug aid for channel-order issues).
            ``None`` uses the module-level ``_DEBUG_SWAP_RB`` flag.

    Returns:
        float32 array with the same shape as ``rgb``, clipped to [0, 1].
    """
    do_swap = _DEBUG_SWAP_RB if swap_rb is None else swap_rb
    rgb = np.asarray(rgb, dtype=np.float32)
    if do_swap:
        rgb = rgb[..., ::-1]  # view only; never written to below

    # Map from the LUT domain to [0, 1]. A (3,) shape broadcasts against any
    # (..., 3) input, not just (H, W, 3).
    dmin = np.asarray(lut.domain_min).reshape(3)
    dmax = np.asarray(lut.domain_max).reshape(3)
    denom = np.maximum(dmax - dmin, 1e-12)  # guard against a zero-width domain
    x = np.clip((rgb - dmin) / denom, 0.0, 1.0)

    n = lut.size
    xg = x * (n - 1)
    # Floor indices clamped to [0, n-2] so i0+1 is always valid; at x == 1
    # this yields i0 = n-2 with weight t = 1, i.e. exactly the last point.
    i0 = np.clip(np.floor(xg).astype(np.int32), 0, n - 2)
    i1 = i0 + 1
    t = np.clip(xg - i0.astype(np.float32), 0.0, 1.0)

    r0, g0, b0 = i0[..., 0], i0[..., 1], i0[..., 2]
    r1, g1, b1 = i1[..., 0], i1[..., 1], i1[..., 2]
    # Keep a trailing axis of length 1 so weights broadcast over the RGB output.
    tr = t[..., 0:1]
    tg = t[..., 1:2]
    tb = t[..., 2:3]

    table = lut.table
    # 8 corners of the enclosing lattice cell.
    c000 = table[r0, g0, b0]
    c001 = table[r0, g0, b1]
    c010 = table[r0, g1, b0]
    c011 = table[r0, g1, b1]
    c100 = table[r1, g0, b0]
    c101 = table[r1, g0, b1]
    c110 = table[r1, g1, b0]
    c111 = table[r1, g1, b1]

    # Interpolate along B, then G, then R.
    c00 = c000 + (c001 - c000) * tb
    c01 = c010 + (c011 - c010) * tb
    c10 = c100 + (c101 - c100) * tb
    c11 = c110 + (c111 - c110) * tb
    c0 = c00 + (c01 - c00) * tg
    c1 = c10 + (c11 - c10) * tg
    out = c0 + (c1 - c0) * tr

    out = np.clip(out, 0.0, 1.0).astype(np.float32)
    if do_swap:
        out = out[..., ::-1].copy()
    return out


def blend_rgb(a: np.ndarray, b: np.ndarray, strength_0_1: float) -> np.ndarray:
    """Linearly blend ``a`` towards ``b``: ``a * (1 - s) + b * s``.

    ``strength_0_1`` is clamped to [0, 1]; the result is float32.
    """
    s = float(np.clip(strength_0_1, 0.0, 1.0))
    return (a * (1.0 - s) + b * s).astype(np.float32)


# ─── LUT cache & chain ──────────────────────────────────────────────────────

# Parsed LUTs keyed by resolved file path. Entries are never invalidated
# automatically; call clear_lut_cache() after editing a LUT file on disk.
_lut_cache: dict[str, CubeLUT] = {}


def get_cached_lut(path: str | Path) -> CubeLUT:
    """Load a .cube LUT, returning the same cached object on repeat calls.

    The returned object is shared between callers, so do not mutate it.
    """
    key = str(Path(path).resolve())
    if key not in _lut_cache:
        _lut_cache[key] = load_cube(path)
    return _lut_cache[key]


def clear_lut_cache() -> None:
    """Drop every cached LUT so the next get_cached_lut() re-reads from disk."""
    _lut_cache.clear()


def apply_lut_chain(
    rgb: np.ndarray,
    chain: list[tuple[CubeLUT, float]],
) -> np.ndarray:
    """Apply LUTs sequentially, each blended with its own input.

    Each ``(lut, strength_0_100)`` step computes::

        out = lerp(input, apply_lut(input), strength / 100)

    and the result of each step feeds the next. Steps with strength <= 0 are
    skipped; each applied step's output is clipped to [0, 1]. Returns a new
    float32 array (an empty chain returns a float32 copy of ``rgb``).
    """
    for lut, strength in chain:
        if strength <= 0:
            continue
        lut_rgb = apply_lut_rgb_float01(rgb, lut)
        rgb = blend_rgb(rgb, lut_rgb, strength / 100.0)
        rgb = np.clip(rgb, 0.0, 1.0)
    return rgb.astype(np.float32)