|
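# Hugging Face Space that fine-tunes CLIP ("openai/clip-vit-base-patch32") on an
# uploaded folder of images paired with a folder of .txt prompt files.
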
import os

import gradio as gr
import spaces
import torch
from PIL import Image
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import CLIPModel, CLIPTokenizer
|
|
@spaces.GPU()
def train_image_generation_model(image_folder, text_folder, model_name="image_generation_model"):
    """Fine-tunes a CLIP model on an image-prompt dataset and saves its weights.

    Args:
        image_folder (str): Path to the folder containing training images.
        text_folder (str): Path to the folder containing one .txt prompt file per image.
        model_name (str, optional): Name for the saved model file. Defaults to
            "image_generation_model".

    Returns:
        str: Path to the saved model checkpoint (a state_dict .pt file).
    """
|
    class ImageTextDataset(Dataset):
        def __init__(self, image_folder, text_folder, transform=None):
            # Sort both listings: os.listdir() order is arbitrary, and sorting
            # keeps the i-th image aligned with the i-th prompt file.
            self.image_paths = sorted(os.path.join(image_folder, f) for f in os.listdir(image_folder)
                                      if f.lower().endswith((".png", ".jpg", ".jpeg")))
            self.text_paths = sorted(os.path.join(text_folder, f) for f in os.listdir(text_folder)
                                     if f.lower().endswith(".txt"))
            self.transform = transform

        def __len__(self):
            return len(self.image_paths)

        def __getitem__(self, idx):
            image = Image.open(self.image_paths[idx]).convert("RGB")
            if self.transform:
                image = self.transform(image)
            with open(self.text_paths[idx], "r") as f:
                text = f.read().strip()
            return image, text

    # Load the pretrained CLIP backbone and tokenizer; use the GPU when available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

    # CLIP's preprocessing: 224x224 inputs normalized with the statistics
    # the model was pretrained with.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                             std=[0.26862954, 0.26130258, 0.27577711]),
    ])

    dataset = ImageTextDataset(image_folder, text_folder, transform=transform)
    dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

    optimizer = torch.optim.Adam(clip_model.parameters(), lr=1e-5)
    loss_fn = nn.CrossEntropyLoss()

    clip_model.train()
    for epoch in range(10):
        for i, (images, texts) in enumerate(dataloader):
            optimizer.zero_grad()
            images = images.to(device)
            # Pad/truncate so variable-length prompts can be batched together.
            tokens = tokenizer(list(texts), padding=True, truncation=True, return_tensors="pt").to(device)
            image_features = clip_model.get_image_features(pixel_values=images)
            text_features = clip_model.get_text_features(**tokens)
            # Contrastive objective: normalize, scale by CLIP's learned logit_scale,
            # and require the i-th image to score highest against the i-th prompt.
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            similarity = clip_model.logit_scale.exp() * (image_features @ text_features.T)
            loss = loss_fn(similarity, torch.arange(images.size(0), device=device))
            loss.backward()
            optimizer.step()
            print(f"Epoch: {epoch} | Iteration: {i} | Loss: {loss.item():.4f}")

    model_path = os.path.join(os.getcwd(), model_name + ".pt")
    torch.save(clip_model.state_dict(), model_path)
    return model_path


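# A minimal direct-call sketch (hypothetical paths; any folder of images plus a
# matching folder of .txt prompts works), including reloading the checkpoint:
#   model_path = train_image_generation_model("data/images", "data/prompts", "demo_run")
#   clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
#   clip_model.load_state_dict(torch.load(model_path, map_location="cpu"))
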
def train_from_uploads(image_files, text_files, model_name):
    # gr.File(file_count="directory") passes a list of uploaded file paths, not
    # a folder path, so recover each parent directory here (assumes a flat
    # directory upload with no nested subfolders).
    image_folder = os.path.dirname(image_files[0])
    text_folder = os.path.dirname(text_files[0])
    return train_image_generation_model(image_folder, text_folder, model_name or "image_generation_model")


iface = gr.Interface(
    fn=train_from_uploads,
    inputs=[
        gr.File(label="Image Folder", file_count="directory"),
        gr.File(label="Text Prompts Folder", file_count="directory"),
        gr.Textbox(label="Model Name", value="image_generation_model"),
    ],
    outputs=gr.File(label="Model File"),
    title="Image Generation Model Trainer",
    description=(
        "Upload a folder of images and a folder of corresponding text prompts to train the model.\n"
        "The images folder should contain .png/.jpg/.jpeg files and the prompts folder one .txt file "
        "per image; files are paired by sorted filename order."
    ),
)

iface.launch(share=True) |
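
# Note: share=True asks Gradio for a temporary public link in addition to the
# local server; remove it to keep the interface local-only.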