import torch
import numpy as np

from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN

# from datasets import load_dataset
# from PIL import Image
# import paddlehub as hub
# import random
# from PIL import ImageDraw, ImageFont
# import streamlit as st


# @st.experimental_singleton
# def load_bg_model():
#     bg_model = hub.Module(name='U2NetP', directory='assets/models/')
#     return bg_model

# bg_model = load_bg_model()


# def remove_bg(img):
#     result = bg_model.Segmentation(
#         images=[np.array(img)[:, :, ::-1]],
#         paths=None,
#         batch_size=1,
#         input_size=320,
#         output_dir=None,
#         visualization=False)
#     output = result[0]
#     mask = Image.fromarray(output['mask'])
#     front = Image.fromarray(output['front'][:, :, ::-1]).convert("RGBA")
#     front.putalpha(mask)
#     return front


# meme_template = Image.open("./assets/pigeon_meme.jpg").convert("RGBA")

# def make_meme(pigeon, text="Is this a pigeon?", show_text=True, remove_background=True):
#     meme = meme_template.copy()
#     approx_butterfly_center = (850, 30)
#     if remove_background:
#         pigeon = remove_bg(pigeon)
#     else:
#         pigeon = Image.fromarray(pigeon).convert("RGBA")
#     random_rotate = random.randint(-30, 30)
#     random_size = random.randint(150, 200)
#     pigeon = pigeon.resize((random_size, random_size)).rotate(random_rotate, expand=True)
#     meme.alpha_composite(pigeon, approx_butterfly_center)
#
#     # ref: https://blog.lipsumarium.com/caption-memes-in-python/
#     def drawTextWithOutline(text, x, y):
#         draw.text((x - 2, y - 2), text, (0, 0, 0), font=font)
#         draw.text((x + 2, y - 2), text, (0, 0, 0), font=font)
#         draw.text((x + 2, y + 2), text, (0, 0, 0), font=font)
#         draw.text((x - 2, y + 2), text, (0, 0, 0), font=font)
#         draw.text((x, y), text, (255, 255, 255), font=font)
#
#     if show_text:
#         draw = ImageDraw.Draw(meme)
#         font_size = 52
#         font = ImageFont.truetype("assets/impact.ttf", font_size)
#         w, h = draw.textsize(text, font)  # measure the size the text will take
#         drawTextWithOutline(text, meme.width / 2 - w / 2, meme.height - font_size * 2)
#     meme = meme.convert("RGB")
#     return meme


# def get_train_data(dataset_name="huggan/smithsonian_butterflies_subset"):
#     dataset = load_dataset(dataset_name)
#     dataset = dataset.sort("sim_score")
#     return dataset["train"]


# from transformers import BeitFeatureExtractor, BeitForImageClassification
# emb_feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-patch16-224')
# emb_model = BeitForImageClassification.from_pretrained('microsoft/beit-base-patch16-224')

# def embed(images):
#     inputs = emb_feature_extractor(images=images, return_tensors="pt")
#     outputs = emb_model(**inputs, output_hidden_states=True)
#     last_hidden = outputs.hidden_states[-1]
#     pooler = emb_model.base_model.pooler
#     final_emb = pooler(last_hidden).detach().numpy()
#     return final_emb


# def build_index():
#     dataset = get_train_data()
#     ds_with_embeddings = dataset.map(lambda x: {"beit_embeddings": embed(x["image"])}, batched=True, batch_size=20)
#     ds_with_embeddings.add_faiss_index(column='beit_embeddings')
#     ds_with_embeddings.save_faiss_index('beit_embeddings', 'beit_index.faiss')


# def get_dataset():
#     dataset = get_train_data()
#     dataset.load_faiss_index('beit_embeddings', 'beit_index.faiss')
#     return dataset


def load_model(model_name='ceyda/butterfly_cropped_uniq1K_512', model_version=None):
    # Load a pretrained LightweightGAN from the Hugging Face Hub and switch
    # it to inference mode.
    gan = LightweightGAN.from_pretrained(model_name, version=model_version)
    gan.eval()
    return gan


def generate(gan, batch_size=1):
    # Sample random latent vectors, decode them with the generator, and
    # return the images as a (batch, height, width, channel) uint8 array.
    with torch.no_grad():
        ims = gan.G(torch.randn(batch_size, gan.latent_dim)).clamp_(0., 1.) * 255
        ims = ims.permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
        return ims


# def interpolate():
#     pass
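

# A minimal sketch of what the interpolate() stub above might do, assuming
# linear interpolation between two random latent vectors decoded with the
# same generator used by generate(). The signature, the `steps` parameter,
# and the use of torch.lerp are illustrative assumptions, not part of the
# original code.
def interpolate(gan, steps=8):
    with torch.no_grad():
        z0 = torch.randn(1, gan.latent_dim)  # start point in latent space
        z1 = torch.randn(1, gan.latent_dim)  # end point in latent space
        # Blend weights in [0, 1]; broadcasting (steps, 1) against (1, latent_dim)
        # yields a (steps, latent_dim) batch of intermediate latents.
        weights = torch.linspace(0., 1., steps).unsqueeze(1)
        zs = torch.lerp(z0, z1, weights)
        ims = gan.G(zs).clamp_(0., 1.) * 255
        return ims.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)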
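

# Minimal usage sketch (assumes the checkpoint above can be downloaded from
# the Hugging Face Hub; the expected 512x512 output size is inferred from
# the model name, not verified here).
if __name__ == "__main__":
    gan = load_model()
    ims = generate(gan, batch_size=4)
    print(ims.shape)  # expected: (4, 512, 512, 3)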