import json
import torch
from huggingnft.lightweight_gan.train import timestamped_filename
from streamlit_option_menu import option_menu
from huggingface_hub import hf_hub_download, file_download
from PIL import Image
from huggingface_hub.hf_api import HfApi
import streamlit as st
from huggingnft.lightweight_gan.lightweight_gan import Generator, LightweightGAN, evaluate_in_chunks, Trainer
from accelerate import Accelerator
from huggan.pytorch.cyclegan.modeling_cyclegan import GeneratorResNet
from torchvision import transforms as T
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, RandomCrop, RandomHorizontalFlip
from torchvision.utils import make_grid
import requests


def _repo_name(repo_id):
    """Return the part of *repo_id* after its first "/" (ValueError if there is none)."""
    slash_at = repo_id.index("/")
    return repo_id[slash_at + 1:]


# Hub client used to enumerate the models published under the "huggingnft" author.
hfapi = HfApi()

# Model names with the leading "<owner>/" stripped; these populate the model pickers.
model_names = [
    _repo_name(model.modelId)
    for model in hfapi.list_models(author="huggingnft")
]

# streamlit-option-menu
# st.set_page_config(page_title="Streamlit App Gallery", page_icon="", layout="wide")
# sysmenu = '''
#
""", unsafe_allow_html=True, ) if choose == "About": README = requests.get("https://raw.githubusercontent.com/AlekseyKorshuk/huggingnft/main/README.md").text README = str(README).replace('width="1200"','width="700"') # st.title(choose) st.markdown(README, unsafe_allow_html=True) if choose == "Contact": st.title(choose) st.markdown(CONTACT_TEXT) if choose == "Generate image": st.title(choose) st.markdown(GENERATE_IMAGE_TEXT) model_name = st.selectbox( 'Choose model:', clean_models(model_names, COLLECTION2COLLECTION_KEYS) ) generation_type = st.selectbox( 'Select generation type:', ["default", "ema"] ) nrows = st.number_input("Number of rows:", min_value=1, max_value=10, step=1, value=8, ) generate_image_button = st.button("Generate") if generate_image_button: with st.spinner(text=f"Downloading selected model..."): model = load_lightweight_model(f"huggingnft/{model_name}") with st.spinner(text=f"Generating..."): image = model.generate_app( num=timestamped_filename(), nrow=nrows, checkpoint=-1, types=generation_type )[0] st.markdown(TRAIN_TEXT) st.image( image ) if choose == "Interpolation": st.title(choose) st.markdown(INTERPOLATION_TEXT) model_name = st.selectbox( 'Choose model:', clean_models(model_names, COLLECTION2COLLECTION_KEYS) ) nrows = st.number_input("Number of rows:", min_value=1, max_value=10, step=1, value=1, ) num_steps = st.number_input("Number of steps:", min_value=1, max_value=1000, step=1, value=100, ) generate_image_button = st.button("Generate") if generate_image_button: with st.spinner(text=f"Downloading selected model..."): model = load_lightweight_model(f"huggingnft/{model_name}") my_bar = st.progress(0) result = model.generate_interpolation( num=timestamped_filename(), num_image_tiles=nrows, num_steps=num_steps, save_frames=False, progress_bar=my_bar ) my_bar.empty() st.markdown(TRAIN_TEXT) st.image( result ) if choose == "Collection2Collection": st.title(choose) st.markdown(COLLECTION2COLLECTION_TEXT) model_name = st.selectbox( 'Choose 
model:', set(model_names) - set(clean_models(model_names, COLLECTION2COLLECTION_KEYS)) ) nrows = st.number_input("Number of images to generate:", min_value=1, max_value=10, step=1, value=1, ) generate_image_button = st.button("Generate") if generate_image_button: n_channels = 3 image_size = 256 input_shape = (image_size, image_size) transform = Compose([ T.ToPILImage(), T.Resize(input_shape), ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ]) with st.spinner(text=f"Downloading selected model..."): translator = GeneratorResNet.from_pretrained(f'huggingnft/{model_name}', input_shape=(n_channels, image_size, image_size), num_residual_blocks=9) z = torch.randn(nrows, 100, 1, 1) with st.spinner(text=f"Downloading selected model..."): model = load_lightweight_model(f"huggingnft/{model_name.split('__2__')[0]}") with st.spinner(text=f"Generating input images..."): punks = model.generate_app( num=timestamped_filename(), nrow=nrows, checkpoint=-1, types="default" )[1] pipe_transform = T.Resize((256, 256)) input = pipe_transform(punks) with st.spinner(text=f"Generating output images..."): output = translator(input) out_img = make_grid(output, nrow=4, normalize=True) # out_img = make_grid(punks, # nrow=8, normalize=True) out_transform = Compose([ T.ToPILImage() ]) results = [] for out_punk, out_ape in zip(input, output): results.append( get_concat_h(out_transform(make_grid(out_punk, nrow=1, normalize=True)), out_transform(make_grid(out_ape, nrow=1, normalize=True))) ) st.markdown(TRAIN_TEXT) for result in results: st.image(result)