import streamlit as st
from PIL import Image
from transformers import (
    ViTForImageClassification,
    ViTImageProcessor,  # ViTFeatureExtractor is deprecated in favor of ViTImageProcessor
    pipeline,
    CLIPTokenizerFast,
    CLIPTextModel
)
import torch
import emoji
# Load all models once and cache them so Streamlit reruns don't reload them
@st.cache_resource
def load_models():
    age_model = ViTForImageClassification.from_pretrained('nateraw/vit-age-classifier')
    age_transforms = ViTImageProcessor.from_pretrained('nateraw/vit-age-classifier')
    gender_model = ViTForImageClassification.from_pretrained('rizvandwiki/gender-classification-2')
    gender_transforms = ViTImageProcessor.from_pretrained('rizvandwiki/gender-classification-2')
    emotion_model = ViTForImageClassification.from_pretrained('dima806/facial_emotions_image_detection')
    emotion_transforms = ViTImageProcessor.from_pretrained('dima806/facial_emotions_image_detection')
    object_detector = pipeline("object-detection", model="facebook/detr-resnet-50")
    action_model = ViTForImageClassification.from_pretrained('rvv-karma/Human-Action-Recognition-VIT-Base-patch16-224')
    action_transforms = ViTImageProcessor.from_pretrained('rvv-karma/Human-Action-Recognition-VIT-Base-patch16-224')
    # Note: loaded but not used below; generate_prompt() builds the prompt with an f-string instead
    prompt_generator = pipeline("text2text-generation", model="succinctly/text2image-prompt-generator")
    clip_tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32")
    clip_model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
    return (age_model, age_transforms, gender_model, gender_transforms,
            emotion_model, emotion_transforms, object_detector,
            action_model, action_transforms, prompt_generator,
            clip_tokenizer, clip_model)

models = load_models()
(age_model, age_transforms, gender_model, gender_transforms,
 emotion_model, emotion_transforms, object_detector,
 action_model, action_transforms, prompt_generator,
 clip_tokenizer, clip_model) = models

def predict(image, model, transforms):
    """Run a single image classifier and return the argmax class index."""
    inputs = transforms(image, return_tensors='pt')
    with torch.no_grad():  # inference only; no gradients needed
        output = model(**inputs)
    proba = output.logits.softmax(1)
    return proba.argmax(1).item()
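
# Illustrative standalone usage (assumes a local file "face.jpg"; not part of the app flow):
#   img = Image.open("face.jpg").convert("RGB")
#   idx = predict(img, age_model, age_transforms)
#   print(age_model.config.id2label[idx])
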
def detect_attributes(image):
    """Run all classifiers plus the object detector and map indices to labels."""
    age = predict(image, age_model, age_transforms)
    gender = predict(image, gender_model, gender_transforms)
    emotion = predict(image, emotion_model, emotion_transforms)
    action = predict(image, action_model, action_transforms)
    objects = object_detector(image)
    return {
        'age': age_model.config.id2label[age],
        'gender': gender_model.config.id2label[gender],
        'emotion': emotion_model.config.id2label[emotion],
        'action': action_model.config.id2label[action],
        'objects': [obj['label'] for obj in objects]
    }

def generate_prompt(attributes):
    prompt = f"A {attributes['age']} {attributes['gender']} person feeling {attributes['emotion']} "
    prompt += f"while {attributes['action']}. "
    if attributes['objects']:
        prompt += f"Surrounded by {', '.join(attributes['objects'])}. "
    return prompt
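
# Illustrative input/output (actual strings depend on each model's id2label labels):
#   generate_prompt({'age': '20-29', 'gender': 'male', 'emotion': 'happy',
#                    'action': 'dancing', 'objects': ['person']})
#   -> "A 20-29 male person feeling happy while dancing. Surrounded by person. "
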
def generate_emoji(prompt):
    inputs = clip_tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():  # inference only; no gradients needed
        outputs = clip_model(**inputs)
    embedding = outputs.last_hidden_state.mean(dim=1)  # mean-pool token embeddings
    # Simple emoji mapping based on the signs of the first embedding dimensions
    if embedding[0, 0] > 0:
        return emoji.emojize(":grinning_face:")
    elif embedding[0, 1] > 0:
        return emoji.emojize(":smiling_face_with_heart-eyes:")
    elif embedding[0, 2] > 0:
        return emoji.emojize(":face_with_tears_of_joy:")
    elif embedding[0, 3] > 0:
        return emoji.emojize(":thinking_face:")
    else:
        return emoji.emojize(":neutral_face:")
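
# Note: the mapping above is a placeholder heuristic. Individual CLIP embedding
# dimensions have no fixed meaning, so the chosen emoji is effectively arbitrary;
# a more principled approach would compare the prompt embedding against embeddings
# of candidate emoji descriptions via cosine similarity.
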
st.title("Image Attribute Detection and Emoji Generation")

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    # Convert to RGB so RGBA/grayscale uploads don't break the image processors
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption='Uploaded Image', use_column_width=True)
    if st.button('Analyze and Generate Emoji'):
        with st.spinner('Detecting attributes...'):
            attributes = detect_attributes(image)
        st.write("Detected Attributes:")
        for key, value in attributes.items():
            st.write(f"{key.capitalize()}: {value}")
        with st.spinner('Generating prompt...'):
            prompt = generate_prompt(attributes)
        st.write("Generated Prompt:")
        st.write(prompt)
        with st.spinner('Generating emoji...'):
            emoji_result = generate_emoji(prompt)
        st.write("Generated Emoji:")
        st.write(emoji_result)
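
# To run locally (assuming this file is saved as app.py):
#   pip install streamlit transformers torch pillow emoji
#   streamlit run app.py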