223 lines
6.8 KiB
Python
223 lines
6.8 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from dataclasses import dataclass
|
||
|
|
from pathlib import Path
|
||
|
|
from typing import Optional, Tuple
|
||
|
|
|
||
|
|
import numpy as np
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class CubeLUT:
    """A parsed 3D colour lookup table (the value returned by ``load_cube``).

    ``table`` is laid out so that ``table[r, g, b]`` gives the output RGB
    triple for grid node (r, g, b). ``domain_min``/``domain_max`` describe
    the input range that the grid spans per channel.
    """

    # Number of grid samples along each of the R, G and B axes.
    size: int
    # shape: (size, size, size, 3), float32; values are expected in [0, 1].
    table: np.ndarray
    # Lower bound of the input domain per channel, shape (3,).
    domain_min: np.ndarray
    # Upper bound of the input domain per channel, shape (3,).
    domain_max: np.ndarray
|
||
|
|
|
||
|
|
|
||
|
|
class CubeParseError(Exception):
    """Raised when a ``.cube`` LUT file is missing or cannot be parsed."""
|
||
|
|
|
||
|
|
|
||
|
|
def _parse_header_int(parts: list[str], keyword: str) -> int:
    """Parse a ``KEYWORD N`` header line (already split) into ``N``."""
    if len(parts) != 2:
        raise CubeParseError(f"Invalid {keyword} line")
    try:
        return int(parts[1])
    except ValueError as err:
        raise CubeParseError(f"Invalid {keyword} line") from err


def _parse_header_floats(parts: list[str], keyword: str, count: int) -> np.ndarray:
    """Parse a ``KEYWORD v1 .. v<count>`` header line into a float32 array."""
    if len(parts) != count + 1:
        raise CubeParseError(f"Invalid {keyword} line")
    try:
        return np.array([float(v) for v in parts[1:]], dtype=np.float32)
    except ValueError as err:
        raise CubeParseError(f"Invalid {keyword} line") from err


def load_cube(path: str | Path) -> CubeLUT:
    """Parse a 3D ``.cube`` LUT file (common sizes 17/33/65).

    Recognised lines:

    * ``TITLE`` - ignored.
    * ``LUT_3D_SIZE N`` - required; ``N`` must be at least 2.
    * ``DOMAIN_MIN r g b`` / ``DOMAIN_MAX r g b`` - per-channel input domain
      (defaults 0 and 1).
    * ``LUT_3D_INPUT_RANGE lo hi`` - DaVinci Resolve spelling of the input
      domain; sets ``lo``/``hi`` for all three channels.
    * ``LUT_1D_SIZE N`` / ``LUT_1D_INPUT_RANGE`` - a 1D shaper section. The
      next ``N`` numeric rows are skipped; the shaper is NOT applied.
    * ``#`` comments (whole-line or trailing), blank lines, and a leading
      UTF-8 byte-order mark.

    Numeric data rows are ``r g b`` triples with red varying fastest.

    Args:
        path: Location of the ``.cube`` file.

    Returns:
        A CubeLUT whose ``table`` is indexed as ``table[r, g, b]``.

    Raises:
        CubeParseError: if the file does not exist, a recognised header line
            is malformed, ``LUT_3D_SIZE`` is missing or below 2, or the number
            of 3D data rows is not ``LUT_3D_SIZE ** 3``.
    """
    p = Path(path)
    if not p.exists():
        raise CubeParseError(f"LUT file not found: {p}")

    size_3d: Optional[int] = None
    domain_min = np.array([0.0, 0.0, 0.0], dtype=np.float32)
    domain_max = np.array([1.0, 1.0, 1.0], dtype=np.float32)
    data_3d: list[Tuple[float, float, float]] = []
    # Numeric rows still to be treated as 1D-shaper data. Resolve-style files
    # declare LUT_1D_SIZE and LUT_3D_SIZE up front and then list the 1D rows
    # before the 3D rows, so the skip must not depend on whether LUT_3D_SIZE
    # has already been seen.
    rows_1d_remaining = 0

    # "utf-8-sig" drops a leading BOM; with plain "utf-8" the BOM would stay
    # glued to the first keyword (e.g. "\ufeffLUT_3D_SIZE") and be ignored.
    with p.open("r", encoding="utf-8-sig", errors="ignore") as f:
        for raw in f:
            # Remove whole-line and trailing comments in one step.
            line = raw.split("#", 1)[0].strip()
            if not line:
                continue

            parts = line.split()
            key = parts[0].upper()

            if key in ("TITLE", "LUT_1D_INPUT_RANGE"):
                continue
            if key == "LUT_3D_SIZE":
                size_3d = _parse_header_int(parts, "LUT_3D_SIZE")
                continue
            if key == "LUT_1D_SIZE":
                # Lenient on purpose: a 1D header with the wrong arity is ignored.
                if len(parts) == 2:
                    rows_1d_remaining = _parse_header_int(parts, "LUT_1D_SIZE")
                continue
            if key == "DOMAIN_MIN":
                domain_min = _parse_header_floats(parts, "DOMAIN_MIN", 3)
                continue
            if key == "DOMAIN_MAX":
                domain_max = _parse_header_floats(parts, "DOMAIN_MAX", 3)
                continue
            if key == "LUT_3D_INPUT_RANGE":
                lo, hi = _parse_header_floats(parts, "LUT_3D_INPUT_RANGE", 2)
                domain_min = np.full(3, lo, dtype=np.float32)
                domain_max = np.full(3, hi, dtype=np.float32)
                continue

            if len(parts) < 3:
                continue
            try:
                r, g, b = float(parts[0]), float(parts[1]), float(parts[2])
            except ValueError:
                # Unrecognised keyword lines with 3+ tokens end up here.
                continue

            if rows_1d_remaining > 0:
                rows_1d_remaining -= 1
                continue

            data_3d.append((r, g, b))

    if size_3d is None:
        raise CubeParseError("Missing LUT_3D_SIZE in .cube file")
    # Interpolation needs at least two nodes per axis (spec range is 2-256).
    if size_3d < 2:
        raise CubeParseError(f"LUT_3D_SIZE must be at least 2, got {size_3d}")

    expected = size_3d ** 3
    if len(data_3d) != expected:
        raise CubeParseError(f"Expected {expected} LUT rows, got {len(data_3d)}")

    rows = np.array(data_3d, dtype=np.float32)
    # .cube order: R varies fastest, then G, then B (slowest). A C-order
    # reshape therefore yields rows[B, G, R, channel]; swapping axes 0 and 2
    # gives table[R, G, B] = output RGB for intuitive lookup.
    table = rows.reshape((size_3d, size_3d, size_3d, 3), order="C")
    table = table.transpose(2, 1, 0, 3).copy()
    return CubeLUT(size=size_3d, table=table, domain_min=domain_min, domain_max=domain_max)
|
||
|
|
|
||
|
|
|
||
|
|
# Debug-only switch. When True, apply_lut_rgb_float01 reverses the channel
# axis (R<->B) before the lookup and again afterwards, but only for calls
# that leave its ``swap_rb`` argument as None.
_DEBUG_SWAP_RB: bool = False
|
||
|
|
|
||
|
|
|
||
|
|
def apply_lut_rgb_float01(
    rgb: np.ndarray,
    lut: CubeLUT,
    swap_rb: Optional[bool] = None,
) -> np.ndarray:
    """Apply a 3D LUT with trilinear interpolation.

    Input must be sRGB-encoded float in [0,1] with RGB on the last axis,
    i.e. shape ``(..., 3)``. ``(H, W, 3)`` images are the usual case, but
    ``(N, 3)`` pixel lists and a single ``(3,)`` pixel work too. The output
    has the same shape as the input, is float32, and is clipped to [0, 1].
    LUTs designed for Log/ACES input will NOT produce correct results
    without an additional input transform.

    Args:
        rgb: Pixels to transform; converted to float32 if needed.
        lut: LUT whose ``table`` is indexed ``table[r, g, b]``.
        swap_rb: if True, swap R<->B before the lookup and back afterwards
            (debug aid for channel-order issues). Defaults to the
            module-level _DEBUG_SWAP_RB flag.

    Raises:
        ValueError: if the last axis of ``rgb`` does not have length 3.
    """
    do_swap = _DEBUG_SWAP_RB if swap_rb is None else swap_rb

    rgb = np.asarray(rgb)
    if rgb.shape[-1:] != (3,):
        raise ValueError(f"rgb must have a trailing axis of length 3, got shape {rgb.shape}")
    if rgb.dtype != np.float32:
        rgb = rgb.astype(np.float32)

    if do_swap:
        rgb = rgb[..., ::-1]

    # Map from the LUT domain to [0, 1]. Shape (3,) broadcasts against any
    # (..., 3) input without adding leading axes.
    dmin = lut.domain_min.reshape(3)
    dmax = lut.domain_max.reshape(3)
    denom = np.maximum(dmax - dmin, 1e-12)
    x = np.clip((rgb - dmin) / denom, 0.0, 1.0)

    n = lut.size
    xg = x * (n - 1)

    # Floor indices clamped to [0, n-2] so i0+1 is always valid; an input of
    # exactly 1.0 then lands on the upper node with weight t == 1.
    i0 = np.clip(np.floor(xg).astype(np.int32), 0, n - 2)
    i1 = i0 + 1
    t = np.clip(xg - i0.astype(np.float32), 0.0, 1.0)

    r0, g0, b0 = i0[..., 0], i0[..., 1], i0[..., 2]
    r1, g1, b1 = i1[..., 0], i1[..., 1], i1[..., 2]
    # Keep a trailing length-1 axis so the weights broadcast over RGB output.
    tr = t[..., 0:1]
    tg = t[..., 1:2]
    tb = t[..., 2:3]

    table = lut.table
    # 8 corners of the enclosing grid cell
    c000 = table[r0, g0, b0]
    c001 = table[r0, g0, b1]
    c010 = table[r0, g1, b0]
    c011 = table[r0, g1, b1]
    c100 = table[r1, g0, b0]
    c101 = table[r1, g0, b1]
    c110 = table[r1, g1, b0]
    c111 = table[r1, g1, b1]

    # Interpolate along B axis
    c00 = c000 + (c001 - c000) * tb
    c01 = c010 + (c011 - c010) * tb
    c10 = c100 + (c101 - c100) * tb
    c11 = c110 + (c111 - c110) * tb

    # Interpolate along G axis
    c0 = c00 + (c01 - c00) * tg
    c1 = c10 + (c11 - c10) * tg

    # Interpolate along R axis
    out = c0 + (c1 - c0) * tr
    out = np.clip(out, 0.0, 1.0).astype(np.float32)

    if do_swap:
        out = out[..., ::-1].copy()

    return out
|
||
|
|
|
||
|
|
|
||
|
|
def blend_rgb(a: np.ndarray, b: np.ndarray, strength_0_1: float) -> np.ndarray:
    """Mix ``b`` over ``a`` by ``strength_0_1`` and return a float32 array.

    The strength is clamped to [0, 1] first, so 0 yields ``a`` and 1 yields ``b``.
    """
    weight_b = float(np.clip(strength_0_1, 0.0, 1.0))
    weight_a = 1.0 - weight_b
    mixed = a * weight_a + b * weight_b
    return mixed.astype(np.float32)
|
||
|
|
|
||
|
|
|
||
|
|
# ─── LUT cache & chain ──────────────────────────────────────────────────────

# Parsed LUTs keyed by the resolved (absolute) path string. Entries are added
# by get_cached_lut and removed all at once by clear_lut_cache.
_lut_cache: dict[str, CubeLUT] = dict()
|
||
|
|
|
||
|
|
|
||
|
|
def get_cached_lut(path: str | Path) -> CubeLUT:
    """Load a .cube LUT, returning a cached copy on repeat calls.

    The cache key is the resolved path, so different spellings of the same
    file share one entry. The same CubeLUT object is handed to every caller,
    and edits to the file on disk are not picked up until clear_lut_cache()
    is called.
    """
    cache_key = str(Path(path).resolve())
    cached = _lut_cache.get(cache_key)
    if cached is None:
        cached = load_cube(path)
        _lut_cache[cache_key] = cached
    return cached
|
||
|
|
|
||
|
|
|
||
|
|
def clear_lut_cache() -> None:
    """Drop every LUT stored by get_cached_lut so later calls reload from disk."""
    _lut_cache.clear()
|
||
|
|
|
||
|
|
|
||
|
|
def apply_lut_chain(
    rgb: np.ndarray,
    chain: list[tuple[CubeLUT, float]],
) -> np.ndarray:
    """Run a sequence of LUTs, each blended in by a 0-100 strength.

    Every ``(lut, strength_0_100)`` entry whose strength is > 0 performs::

        out = clip(lerp(input, apply_lut(input), strength / 100), 0, 1)

    and its output becomes the next entry's input. Entries with strength
    <= 0 are skipped. Lookups use apply_lut_rgb_float01 with its default
    ``swap_rb`` (None). The final result is returned as float32.
    """
    current = rgb
    for lut, strength in chain:
        if strength <= 0:
            continue
        graded = apply_lut_rgb_float01(current, lut)
        mixed = blend_rgb(current, graded, strength / 100.0)
        current = np.clip(mixed, 0.0, 1.0)
    return current.astype(np.float32)
|