import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from skimage.exposure import match_histograms


class HistogramMatcher(nn.Module):
    """Match the per-channel histogram of ``dst`` to that of ``ref``.

    Inputs are unpacked as ``(B, C, H, W)`` tensors and are expected to hold
    values in ``[0, 1]``; they are scaled by 255 and quantised to 256 levels
    internally, and the result is returned scaled back to ``[0, 1]``.

    Args:
        differentiable: if True, histograms are built with a sigmoid-based soft
            binning (``soft_histc_batch``) instead of ``torch.histc``.
    """

    def __init__(self, differentiable=False):
        super().__init__()
        self.differentiable = differentiable

    def forward(self, dst, ref):
        """Return ``dst`` remapped so each channel's histogram follows ``ref``'s.

        ``dst`` and ``ref`` must share ``B`` and ``C`` (one lookup table is built
        per (batch, channel) pair); their spatial sizes may differ.
        """
        B, C, H, W = dst.size()
        hist_dst = self.cal_hist(dst)
        hist_ref = self.cal_hist(ref)
        tables = self.cal_trans_batch(hist_dst, hist_ref)  # (B*C, levels)
        # cal_hist orders rows batch-major (row b*C + c), so flattening (B, C)
        # the same way lets every pixel look up its own channel's table.
        # Clamp first so out-of-range inputs cannot index past the table.
        levels = (dst.clamp(0, 1) * 255).long().reshape(B * C, H * W)
        rst = torch.gather(tables, 1, levels).reshape(B, C, H, W)
        return rst.to(dst.dtype) / 255.

    def cal_hist(self, img):
        """Return the cumulative histogram (CDF) of every (batch, channel) plane.

        Each histogram is L1-normalised before accumulation. The result has
        shape ``(B*C, 256)`` with rows ordered batch-major (row ``b*C + c``).
        """
        B, C, H, W = img.size()
        if self.differentiable:
            hists = self.soft_histc_batch(img * 255, bins=256, min=0, max=256, sigma=75)
        else:
            # 256 bins over [0, 255]: an 8-bit level k falls into bin k.
            hists = torch.stack([
                torch.histc(plane * 255, bins=256, min=0, max=255)
                for plane in img.reshape(B * C, H * W)
            ])
        hists = F.normalize(hists.float(), p=1)
        # Running sum turns per-bin probabilities into a CDF without building
        # a (B*C, 256, 256) triangular matrix.
        return torch.cumsum(hists, dim=1)

    def soft_histc_batch(self, x, bins=256, min=0, max=256, sigma=75):
        """Differentiable histogram of each (batch, channel) plane of ``x``.

        Every sample contributes
        ``sigmoid(sigma * (d + w/2)) - sigmoid(sigma * (d - w/2))`` to each bin,
        where ``d`` is its offset from the bin centre and ``w`` the bin width;
        larger ``sigma`` approaches hard binning.

        Returns a ``(B*C, bins)`` tensor. The intermediate is
        ``(B*C, bins, H*W)``, so memory grows with ``bins * H * W``.
        """
        B, C, H, W = x.size()
        # reshape (not view) so non-contiguous inputs such as permuted tensors work.
        samples = x.reshape(B * C, 1, -1)
        delta = float(max - min) / float(bins)
        centers = float(min) + delta * (torch.arange(bins, device=x.device) + 0.5)
        offsets = samples - centers[None, :, None]
        weights = (torch.sigmoid(sigma * (offsets + delta / 2))
                   - torch.sigmoid(sigma * (offsets - delta / 2)))
        return weights.sum(dim=2)

    def cal_trans_batch(self, hist_dst, hist_ref):
        """Build per-channel lookup tables mapping ``dst`` levels to ``ref`` levels.

        For each row and source level ``j`` the target level is
        ``#{i : cdf_ref[i] <= cdf_dst[j]} - 1``, clamped to ``[0, n - 1]``.

        Args:
            hist_dst: CDFs of shape ``(B*C, n)`` as returned by ``cal_hist``.
            hist_ref: CDFs of the same shape for the reference.

        Returns:
            Float tensor of shape ``(B*C, n)``.
        """
        n = hist_dst.size(1)
        # Broadcast the comparison instead of materialising two repeated
        # (B*C, n, n) float copies; covered[bc, i, j] = cdf_ref[i] <= cdf_dst[j].
        covered = hist_dst[:, None, :] >= hist_ref[:, :, None]
        table = covered.float().sum(dim=1) - 1
        return torch.clamp(table, min=0, max=n - 1)


def apply_histogram_matching(image, reference, factor):
    """Match ``image``'s histogram to ``reference`` and blend with the original.

    Args:
        image: input image array with values on a 0-255 scale, either
            ``H x W x C`` (channels last) or ``H x W`` grayscale.
        reference: image whose histogram is the target; must have the same
            number of dimensions/channels as ``image``.
        factor: blend weight in ``[0, 1]``; 0 returns ``image`` unchanged and
            1 returns the fully matched image.

    Returns:
        ``uint8`` array shaped like ``image``, or None if either input is missing.
    """
    if image is None or reference is None:
        return None

    src = np.asarray(image).astype(np.float32) / 255.0
    ref = np.asarray(reference).astype(np.float32) / 255.0

    # Match each channel (last axis) independently; 2-D arrays are grayscale.
    channel_axis = -1 if src.ndim == 3 else None
    matched = match_histograms(src, ref, channel_axis=channel_axis)

    result = factor * matched + (1 - factor) * src

    # Round rather than truncate: in float32 k/255*255 can come back as
    # k - epsilon, which a plain astype(uint8) would floor to k - 1.
    return np.rint(np.clip(result, 0, 1) * 255).astype(np.uint8)


def create_histogram_tab():
    """Add a "Histogram Matching" tab with inputs, a blend slider and output.

    Intended to be called inside an enclosing ``gr.Blocks()`` context
    (presumably set up by the app's entry point — confirm at the call site).
    """
    with gr.Tab("Histogram Matching"):
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image", height=256)
                reference_image = gr.Image(label="Reference Image", height=256)
                factor = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.5,
                    step=0.05,
                    label="Blend Factor",
                )
                match_btn = gr.Button("Apply Histogram Matching")
            with gr.Column():
                output_image = gr.Image(label="Matched Image")

        match_btn.click(
            fn=apply_histogram_matching,
            inputs=[input_image, reference_image, factor],
            outputs=output_image,
        )