# Flux-TRELLIS / lut_processor.py
# Author: gokaygokay
# Commit c63008c: "add tab"
# Size: 3.92 kB
import gradio as gr
import numpy as np
from pathlib import Path
from colour.io.luts.iridas_cube import read_LUT_IridasCube
import torch
import os
def get_available_luts():
    """List the LUT files available to the UI.

    Returns:
        Alphabetically sorted file names (not paths) of every ``*.cube``
        file directly inside the ``cube_luts`` directory relative to the
        current working directory, or an empty list if that directory
        does not exist.
    """
    lut_dir = Path('cube_luts')
    if lut_dir.exists():
        names = [entry.name for entry in lut_dir.glob('*.cube')]
        names.sort()
        return names
    return []
def _to_uint8(values):
    """Convert float values on a 0-255 scale to ``uint8``.

    Values are rounded to the nearest integer (a plain ``astype`` truncates,
    so float round-off such as 199.99998 would lose a whole level), NaNs
    become 0, and the result is clipped to [0, 255].
    """
    values = np.nan_to_num(np.asarray(values, dtype=np.float64), nan=0.0)
    return np.clip(np.rint(values), 0, 255).astype(np.uint8)


def _safe_power(values, exponent):
    """Raise ``values`` to ``exponent`` after clamping negatives to 0.

    In numpy a fractional power of a negative number is NaN. Negative values
    can appear after shifting into a LUT domain whose minimum is negative, or
    in the output of a LUT whose table was not clipped.
    """
    return np.power(np.maximum(values, 0.0), exponent)


def _read_cube_lut(lut_file):
    """Read an Iridas ``.cube`` LUT, retrying via a sanitised copy on decode errors.

    Returns the LUT object, or ``None`` if it cannot be read (the error is
    printed).
    """
    try:
        return read_LUT_IridasCube(str(lut_file))
    except UnicodeDecodeError:
        pass
    except Exception as e:
        print(f"Error reading LUT file: {e}")
        return None

    # read_LUT_IridasCube takes a path and opens the file itself, so handing it
    # an already-open latin-1 file object cannot work. Instead, decode the bytes
    # as latin-1 (this never fails) and write an ASCII-only temporary copy:
    # non-ASCII characters (expected only in TITLE/comment lines) become '?'.
    # Then read that copy.
    import tempfile

    tmp_path = None
    try:
        text = lut_file.read_bytes().decode('latin-1')
        with tempfile.NamedTemporaryFile(
            'w', suffix='.cube', encoding='ascii', errors='replace', delete=False
        ) as tmp:
            tmp_path = tmp.name
            tmp.write(text)
        return read_LUT_IridasCube(tmp_path)
    except Exception as e:
        print(f"Error reading LUT file with latin-1 encoding: {e}")
        return None
    finally:
        if tmp_path is not None:
            os.unlink(tmp_path)


def apply_lut(image, lut_name, gamma_correction=True, clip_values=True, strength=1.0):
    """Apply a ``.cube`` LUT from the ``cube_luts`` directory to an image.

    Args:
        image: Image array with 0-255 values (as delivered by ``gr.Image``).
            ``(H, W, 3)`` is the normal case; a ``(H, W)`` grayscale image is
            expanded to RGB, and a 4th (alpha) channel is passed through
            unchanged.
        lut_name: File name of the LUT inside ``cube_luts``.
        gamma_correction: If True, apply a 1/2.2 power before the LUT lookup
            and a 2.2 power after it.
        clip_values: If True, clip the LUT table values to the LUT's domain.
        strength: Blend factor between the original image (0.0) and the LUT
            result (1.0); values >= 1.0 return the pure LUT result.

    Returns:
        A ``uint8`` array with the LUT applied. If the LUT file is missing or
        unreadable, the unchanged input is returned as ``uint8``. Returns
        ``None`` if ``image`` or ``lut_name`` is missing.
    """
    if image is None or not lut_name:
        return None

    pixels = np.asarray(image, dtype=np.float32)

    lut_file = Path('cube_luts') / lut_name
    if not lut_file.is_file():
        print(f"Error reading LUT file: {lut_file} does not exist")
        return _to_uint8(pixels)
    lut = _read_cube_lut(lut_file)
    if lut is None:
        # Best effort: show the unchanged image rather than failing the UI.
        return _to_uint8(pixels)
    lut.name = lut_name

    if clip_values:
        # lut.domain holds per-channel [min, max] rows. np.clip broadcasts those
        # per-channel bounds over the last axis of both 3x1D (N, 3) and
        # 3D (N, N, N, 3) tables; a uniform domain is just a special case.
        lut.table = np.clip(lut.table, lut.domain[0], lut.domain[1])

    # The LUT's domain and table are per RGB channel, so feed it exactly three
    # channels: expand grayscale and carry any alpha channel through untouched.
    alpha = None
    if pixels.ndim == 2:
        pixels = np.repeat(pixels[..., np.newaxis], 3, axis=-1)
    elif pixels.shape[-1] == 4:
        pixels, alpha = pixels[..., :3], pixels[..., 3:]
    original = pixels / 255.0

    domain_min, domain_max = lut.domain[0], lut.domain[1]
    is_non_default_domain = not np.array_equal(lut.domain, np.array([[0., 0., 0.], [1., 1., 1.]]))
    lut_img = original
    if is_non_default_domain:
        # Map [0, 1] pixel values into the LUT's input domain (undone below).
        dom_scale = domain_max - domain_min
        lut_img = lut_img * dom_scale + domain_min
    if gamma_correction:
        lut_img = _safe_power(lut_img, 1 / 2.2)
    lut_img = lut.apply(lut_img)
    if gamma_correction:
        lut_img = _safe_power(lut_img, 2.2)
    if is_non_default_domain:
        lut_img = (lut_img - domain_min) / dom_scale

    lut_img = np.clip(lut_img, 0.0, 1.0)
    if strength < 1.0:
        lut_img = strength * lut_img + (1.0 - strength) * original

    result = _to_uint8(lut_img * 255.0)
    if alpha is not None:
        result = np.concatenate([result, _to_uint8(alpha)], axis=-1)
    return result
def create_lut_tab():
    """Build the "LUT" tab inside the enclosing Gradio Blocks context.

    The left column holds the input image and the LUT controls, and the right
    column holds the output image. The LUT dropdown is populated once, when
    the tab is built, from ``get_available_luts()``; its initial value is the
    first listed LUT, or ``None`` if there are none. Clicking "Apply LUT"
    calls ``apply_lut`` with the current control values.
    """
    lut_choices = get_available_luts()
    default_lut = lut_choices[0] if lut_choices else None

    with gr.Tab("LUT"):
        with gr.Row():
            with gr.Column():
                source_img = gr.Image(height=256, label="Input Image")
                lut_picker = gr.Dropdown(
                    label="Select LUT",
                    choices=lut_choices,
                    value=default_lut,
                )
                gamma_box = gr.Checkbox(value=True, label="Gamma Correction")
                clip_box = gr.Checkbox(value=True, label="Clip Values")
                strength_slider = gr.Slider(
                    label="Effect Strength",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.1,
                    value=1.0,
                )
                run_button = gr.Button("Apply LUT")
            with gr.Column():
                result_img = gr.Image(label="Output Image")

        # Argument order must match apply_lut's positional parameters.
        lut_inputs = [source_img, lut_picker, gamma_box, clip_box, strength_slider]
        run_button.click(fn=apply_lut, inputs=lut_inputs, outputs=result_img)