MusIre commited on
Commit
b7c2afa
·
verified ·
1 Parent(s): f23dba2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -24
app.py CHANGED
@@ -86,8 +86,7 @@ scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, ver
86
 
87
  # Load GPT-Neo and CLIP
88
  model_clip = open_clip.create_model('ViT-B/32', pretrained='openai').to(device)
89
- image_size = (224, 224)
90
- preprocess_clip = open_clip.image_transform(image_size=image_size, is_train=False)
91
  tokenizer_clip = open_clip.get_tokenizer('ViT-B/32')
92
  model_clip.eval()
93
 
@@ -95,32 +94,43 @@ model_name = "EleutherAI/gpt-neo-1.3B"
95
  tokenizer = AutoTokenizer.from_pretrained(model_name)
96
  model_gptneo = AutoModelForCausalLM.from_pretrained(model_name).to(device)
97
 
98
- # Generate prediction using ResNet and CLIP
99
- def predict(image_path):
100
  image = Image.open(image_path).convert("RGB")
101
- image_tensor = data_transforms(image).unsqueeze(0).to(device)
102
-
103
- # Predict with ResNet
104
- style_logits, artist_logits = model_resnet(image_tensor)
105
- style_idx = torch.argmax(style_logits, dim=1).item()
106
- artist_idx = torch.argmax(artist_logits, dim=1).item()
107
-
108
- predicted_style = list(label_map_style.keys())[list(label_map_style.values()).index(style_idx)]
109
- predicted_artist = list(label_map_artist.keys())[list(label_map_artist.values()).index(artist_idx)]
110
-
111
- # Enrich prompt with additional information
112
- prompt = enrich_prompt(predicted_artist, predicted_style)
113
-
114
- # Generate text description using GPT-Neo
115
- input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
116
- output = model_gptneo.generate(input_ids, max_length=350, num_return_sequences=1)
117
- description = tokenizer.decode(output[0], skip_special_tokens=True)
118
-
119
- return predicted_style, predicted_artist, description
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
  # Gradio interface
122
  def gradio_interface(image):
123
- predicted_style, predicted_artist, description = predict(image)
124
  return f"Predicted Style: {predicted_style}\nPredicted Artist: {predicted_artist}\n\nDescription:\n{description}"
125
 
126
  iface = gr.Interface(
 
86
 
87
  # Load GPT-Neo and CLIP
88
  model_clip = open_clip.create_model('ViT-B/32', pretrained='openai').to(device)
89
+ preprocess_clip = open_clip.image_transform((224, 224), is_train=False)
 
90
  tokenizer_clip = open_clip.get_tokenizer('ViT-B/32')
91
  model_clip.eval()
92
 
 
94
  tokenizer = AutoTokenizer.from_pretrained(model_name)
95
  model_gptneo = AutoModelForCausalLM.from_pretrained(model_name).to(device)
96
 
97
+ def generate_description(image_path):
 
98
  image = Image.open(image_path).convert("RGB")
99
+ image_resnet = data_transforms(image).unsqueeze(0).to(device)
100
+
101
+ model_resnet.eval()
102
+ with torch.no_grad():
103
+ outputs_style, outputs_artist = model_resnet(image_resnet)
104
+ _, predicted_style_idx = torch.max(outputs_style, 1)
105
+ _, predicted_artist_idx = torch.max(outputs_artist, 1)
106
+
107
+ idx_to_style = {v: k for k, v in label_map_style.items()}
108
+ idx_to_artist = {v: k for k, v in label_map_artist.items()}
109
+ predicted_style = idx_to_style[predicted_style_idx.item()]
110
+ predicted_artist = idx_to_artist[predicted_artist_idx.item()]
111
+
112
+ enriched_prompt = enrich_prompt(predicted_artist, predicted_style)
113
+ full_prompt = (
114
+ f"This is an artwork created by {predicted_artist} in the style of {predicted_style}. {enriched_prompt} "
115
+ "Describe its distinctive features, considering both the artist's techniques and the artistic style."
116
+ )
117
+
118
+ input_ids = tokenizer.encode(full_prompt, return_tensors="pt").to(device)
119
+ output = model_gptneo.generate(
120
+ input_ids=input_ids,
121
+ max_length=300,
122
+ temperature=0.7,
123
+ top_p=0.9,
124
+ repetition_penalty=1.2
125
+ )
126
+
127
+ description_text = tokenizer.decode(output[0], skip_special_tokens=True)
128
+
129
+ return predicted_style, predicted_artist, description_text
130
 
131
  # Gradio interface
132
  def gradio_interface(image):
133
+ predicted_style, predicted_artist, description = generate_description(image)
134
  return f"Predicted Style: {predicted_style}\nPredicted Artist: {predicted_artist}\n\nDescription:\n{description}"
135
 
136
  iface = gr.Interface(