import numpy as np
import torch
import torch.nn as nn
from transformers import CLIPConfig, CLIPVisionModelWithProjection, PreTrainedModel

from ...utils import logging


logger = logging.get_logger(__name__)


class IFSafetyChecker(PreTrainedModel):
    config_class = CLIPConfig

    _no_split_modules = ["CLIPEncoderLayer"]

    def __init__(self, config: CLIPConfig):
        super().__init__(config)

        self.vision_model = CLIPVisionModelWithProjection(config.vision_config)
        # Linear probes over the CLIP image embedding: one head for NSFW content, one for watermarks.
        self.p_head = nn.Linear(config.vision_config.projection_dim, 1)
        self.w_head = nn.Linear(config.vision_config.projection_dim, 1)

    @torch.no_grad()
    def forward(self, clip_input, images, p_threshold=0.5, w_threshold=0.5):
        # Project the preprocessed inputs into the CLIP image embedding space.
        image_embeds = self.vision_model(clip_input)[0]

        # NSFW check: score each embedding and flag images whose score exceeds the threshold.
        nsfw_detected = self.p_head(image_embeds)
        nsfw_detected = nsfw_detected.flatten()
        nsfw_detected = nsfw_detected > p_threshold
        nsfw_detected = nsfw_detected.tolist()

        if any(nsfw_detected):
            logger.warning(
                "Potential NSFW content was detected in one or more images. A black image will be returned instead."
                " Try again with a different prompt and/or seed."
            )

        # Replace every flagged image with a black image of the same shape.
        for idx, nsfw_detected_ in enumerate(nsfw_detected):
            if nsfw_detected_:
                images[idx] = np.zeros(images[idx].shape)

        # Watermark check: same procedure with the watermark head and threshold.
        watermark_detected = self.w_head(image_embeds)
        watermark_detected = watermark_detected.flatten()
        watermark_detected = watermark_detected > w_threshold
        watermark_detected = watermark_detected.tolist()

        if any(watermark_detected):
            logger.warning(
                "Potential watermarked content was detected in one or more images. A black image will be returned instead."
                " Try again with a different prompt and/or seed."
            )

        for idx, watermark_detected_ in enumerate(watermark_detected):
            if watermark_detected_:
                images[idx] = np.zeros(images[idx].shape)

        return images, nsfw_detected, watermark_detected
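

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module itself. It assumes a checkpoint that hosts
    # "feature_extractor" and "safety_checker" subfolders (the repo id below is an example and
    # may need to be swapped for your own checkpoint). The checker scores CLIP image embeddings
    # and blacks out any image flagged as NSFW or watermarked.
    from transformers import CLIPImageProcessor

    repo_id = "DeepFloyd/IF-I-XL-v1.0"  # assumption: replace with the checkpoint you actually use
    feature_extractor = CLIPImageProcessor.from_pretrained(repo_id, subfolder="feature_extractor")
    safety_checker = IFSafetyChecker.from_pretrained(repo_id, subfolder="safety_checker")

    # Two random RGB test images standing in for pipeline outputs.
    images = [(np.random.rand(64, 64, 3) * 255).astype(np.uint8) for _ in range(2)]
    clip_input = feature_extractor(images, return_tensors="pt").pixel_values

    checked_images, nsfw_detected, watermark_detected = safety_checker(
        clip_input=clip_input, images=images
    )
    print(nsfw_detected, watermark_detected)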