from datasets import load_dataset

# Hugging Face dataset used for inference samples.
# (An earlier revision of this script used "dim/nfs_pix2pix_1920_1080_v5".)
dataset_name = "dim/nfs_pix2pix_1920_1080_v6"
# Load with 4 worker processes and keep only the training split.
dataset = load_dataset(dataset_name, num_proc=4)["train"]
import os

# Switch to the img2img-turbo source directory BEFORE importing its local
# modules (pix2pix_turbo, image_prep) below; this line must stay above them.
# NOTE(review): presumably this is what lets those imports resolve (via the
# current-directory entry on sys.path) and may also matter for relative paths
# used inside them — confirm.
os.chdir("/code/img2img-turbo/src")
# NOTE(review): argparse, numpy, PIL.Image and canny_from_pil are not used in
# the visible part of this script.
import argparse
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
import torchvision.transforms.functional as F
from pix2pix_turbo import Pix2Pix_Turbo
from image_prep import canny_from_pil

# Model configuration. With an empty pretrained name, the weights are taken
# from the checkpoint at model_path (presumably — confirm against
# Pix2Pix_Turbo's constructor).
model_name = ""
model_path = "/code/img2img-turbo/output/pix2pix_turbo/nfs_pix2pix_1736564855/checkpoints/model_16001.pkl"
use_fp16 = False


def _build_model(name, path, fp16):
    """Construct a Pix2Pix_Turbo model, switch it to eval mode, optionally cast to fp16.

    Args:
        name: value forwarded as ``pretrained_name``.
        path: value forwarded as ``pretrained_path``.
        fp16: when true, ``.half()`` is called on the model.

    Returns:
        The constructed model instance.
    """
    net = Pix2Pix_Turbo(pretrained_name=name, pretrained_path=path)
    net.set_eval()
    if fp16:
        net.half()
    return net


model = _build_model(model_name, model_path, use_fp16)

# Preprocessing: resize so the shorter side is 512 px (LANCZOS), then take the
# central 512x512 crop.
T = transforms.Compose(
    [
        transforms.Resize(512, interpolation=transforms.InterpolationMode.LANCZOS),
        transforms.CenterCrop(512),
    ]
)

# Row of the training split to run inference on. The image and the edit
# prompt are read from the SAME row so they stay paired (previously the
# prompt always came from row 0 regardless of the selected image). The row is
# indexed once rather than twice.
sample_idx = 290
sample = dataset[sample_idx]
input_image = sample["input_image"].convert("RGB")
prompt = sample["edit_prompt"]

with torch.no_grad():
    i_t = T(input_image)
    # PIL image -> float tensor in [0, 1], add a batch dimension, move to GPU.
    c_t = F.to_tensor(i_t).unsqueeze(0).cuda()
    if use_fp16:
        c_t = c_t.half()
    output_image = model(c_t, prompt)

    # The `* 0.5 + 0.5` mapping assumes the model output lies in [-1, 1]
    # (TODO confirm against Pix2Pix_Turbo). Cast to float (fp16 runs produce
    # half tensors) and clamp to [0, 1]: ToPILImage scales by 255 and casts to
    # uint8, which wraps rather than saturates on out-of-range values.
    output_pil = transforms.ToPILImage()(
        (output_image[0].float().cpu() * 0.5 + 0.5).clamp(0, 1)
    )

# Bare trailing expression so a notebook cell renders the resulting image.
output_pil
Downloads last month: -

Downloads are not tracked for this model. (How to track them?)
Inference Providers (new)
This model is not currently available via any of the supported Inference Providers.
The model cannot be deployed to the HF Inference API because it has no library tag.