AmitIsraeli committed on
Commit
2cb6621
·
1 Parent(s): 68fcb41

Add application file

Browse files
Files changed (2) hide show
  1. app.py +14 -0
  2. help_function.py +40 -0
app.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from help_function import help_function
3
+ from PIL import Image
4
+ import numpy as np
5
+
6
+ model_helper = help_function()
7
+
8
+ def greet(numpy_image,text,float_value):
9
+ PIL_image = Image.fromarray(np.uint8(numpy_image)).convert('RGB')
10
+ image_edit = model_helper.image_from_text(text,PIL_image,float_value)
11
+ return image_edit
12
+
13
+ iface = gr.Interface(fn=greet, inputs=["image", "text", gr.inputs.Slider(0.0, 1.0)], outputs="image")
14
+ iface.launch()
help_function.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import open_clip
3
+ from torchvision import transforms
4
+ from torchvision.transforms import ToPILImage
5
+
6
+ class help_function:
7
+ def __init__(self):
8
+ self.clip_text_model = torch.jit.load('jit_models/clip_text_jit.pt', map_location=torch.device('cpu'))
9
+ self.decoder = torch.jit.load('jit_models/decoder_16w.pt', map_location=torch.device('cpu'))
10
+ self.mapper_clip = torch.jit.load('jit_models/mapper_clip_jit.pt', map_location=torch.device('cpu'))
11
+ self.mean_clip = torch.load('jit_models/mean_clip.pt')
12
+ self.mean_person = torch.load('jit_models/mean_person.pt')
13
+ self.encoder = torch.jit.load('jit_models/combined_encoder.pt', map_location=torch.device('cpu'))
14
+ self.tokenizer = open_clip.get_tokenizer('ViT-B-32')
15
+ self.transform = transforms.Compose([
16
+ transforms.Resize(224),
17
+ transforms.ToTensor(),
18
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
19
+ ])
20
+
21
+ def get_text_embedding(self, text):
22
+ text = self.clip_text_model(self.tokenizer(text))
23
+ return text
24
+
25
+ def get_image_inversion(self, image):
26
+ image = self.transform(image)
27
+ w_inversion = self.encoder(image.reshape(1,3,224,224)).reshape(1,16,512)
28
+ return w_inversion + self.mean_person
29
+
30
+ def get_text_delta(self,text_feachers):
31
+ w_delta = self.mapper_clip(text_feachers - self.mean_clip)
32
+ return w_delta
33
+ def image_from_text(self,text,image,power = 1.0):
34
+ w_inversion = self.get_image_inversion(image)
35
+ text_embedding = self.get_text_embedding(text)
36
+ w_delta = self.get_text_delta(text_embedding)
37
+
38
+ w_edit = w_inversion + w_delta * power
39
+ image_edit = self.decoder(w_edit)
40
+ return ToPILImage()((image_edit[0]+0.5)*0.5)