# Source: Hugging Face Space (status: Running)
# File size: 4,789 bytes (revision 4489c0c)
import streamlit as st
from PIL import Image
import requests
from io import BytesIO
from transformers import (
ViTFeatureExtractor,
ViTForImageClassification,
pipeline,
AutoFeatureExtractor,
AutoModelForObjectDetection,
CLIPTokenizerFast,
CLIPTextModel
)
import torch
from torchvision.transforms import functional as F
import emoji
# Model loading (decorated with st.cache_resource so Streamlit caches the result)
@st.cache_resource
def load_models():
    """Instantiate every model and preprocessor the app uses.

    Returns:
        A 12-tuple, in this order: age model, age preprocessor, gender
        model, gender preprocessor, emotion model, emotion preprocessor,
        object-detection pipeline, action model, action preprocessor,
        prompt-generation pipeline, CLIP tokenizer, CLIP text model.
    """

    def _load_vit_pair(repo_id):
        # A ViT classifier and its feature extractor come from the same repo.
        classifier = ViTForImageClassification.from_pretrained(repo_id)
        extractor = ViTFeatureExtractor.from_pretrained(repo_id)
        return classifier, extractor

    age_pair = _load_vit_pair('nateraw/vit-age-classifier')
    gender_pair = _load_vit_pair('rizvandwiki/gender-classification-2')
    emotion_pair = _load_vit_pair('dima806/facial_emotions_image_detection')
    detector = pipeline("object-detection", model="facebook/detr-resnet-50")
    action_pair = _load_vit_pair(
        'rvv-karma/Human-Action-Recognition-VIT-Base-patch16-224'
    )
    prompt_pipe = pipeline(
        "text2text-generation",
        model="succinctly/text2image-prompt-generator",
    )
    tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")

    return (
        *age_pair,
        *gender_pair,
        *emotion_pair,
        detector,
        *action_pair,
        prompt_pipe,
        tokenizer,
        text_encoder,
    )
# Build the model bundle once at module level and give each element a
# module-level name, since the helper functions below refer to them directly.
models = load_models()
(
    age_model,
    age_transforms,
    gender_model,
    gender_transforms,
    emotion_model,
    emotion_transforms,
    object_detector,
    action_model,
    action_transforms,
    prompt_generator,
    clip_tokenizer,
    clip_model,
) = models
def predict(image, model, transforms):
    """Return the index of the highest-scoring class for ``image``.

    Args:
        image: Input accepted by ``transforms``.
        model: Callable that takes the keyword tensors produced by
            ``transforms`` and returns an object exposing ``logits``.
        transforms: Callable invoked as ``transforms(image,
            return_tensors='pt')`` that returns a mapping of model inputs.

    Returns:
        int: Position of the largest logit along dimension 1. The logits
        must describe a single item, because ``.item()`` is called on the
        argmax result.
    """
    inputs = transforms(image, return_tensors='pt')
    # Pure inference: disable autograd so no graph is recorded for the
    # forward pass.
    with torch.no_grad():
        logits = model(**inputs).logits
    # Softmax is monotonic, so the argmax of the raw logits already
    # identifies the most probable class.
    return logits.argmax(1).item()
def detect_attributes(image):
    """Run all classifiers and the object detector on ``image``.

    Args:
        image: The picture to analyse; a PIL image when called from the
            upload handler in this file.

    Returns:
        dict: Keys ``'age'``, ``'gender'``, ``'emotion'`` and ``'action'``
        map to the predicted label of the corresponding classifier;
        ``'objects'`` is the list of labels returned by the object
        detector (one entry per detection, duplicates included).
    """
    # Uploaded PNGs may be RGBA, palette or grayscale. The classifiers'
    # preprocessors are expected to need 3-channel input, so normalise the
    # mode once up front. Inputs without a ``mode`` attribute are passed
    # through untouched.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    classifiers = {
        'age': (age_model, age_transforms),
        'gender': (gender_model, gender_transforms),
        'emotion': (emotion_model, emotion_transforms),
        'action': (action_model, action_transforms),
    }
    # Translate each predicted class index into its human-readable label.
    attributes = {
        name: model.config.id2label[predict(image, model, transforms)]
        for name, (model, transforms) in classifiers.items()
    }
    attributes['objects'] = [obj['label'] for obj in object_detector(image)]
    return attributes
def generate_prompt(attributes):
    """Compose a one-line textual description from detected attributes.

    Args:
        attributes: Mapping with the keys ``'age'``, ``'gender'``,
            ``'emotion'``, ``'action'`` and ``'objects'`` (an iterable of
            label strings, possibly empty).

    Returns:
        str: The description. Each sentence ends with ``". "``, so the
        result carries a trailing space.

    Raises:
        KeyError: If any of the keys listed above is missing.
    """
    sentences = [
        f"A {attributes['age']} {attributes['gender']} person "
        f"feeling {attributes['emotion']} while {attributes['action']}. "
    ]
    # The surroundings sentence is only added when something was detected.
    if attributes['objects']:
        sentences.append(f"Surrounded by {', '.join(attributes['objects'])}. ")
    return "".join(sentences)
def generate_emoji(prompt):
    """Pick an emoji for ``prompt`` from its CLIP text embedding.

    The prompt is encoded with the CLIP text model and the token states are
    mean-pooled into one vector. The emoji is chosen by the first of the
    leading four components that is positive; this is a simple heuristic,
    not a semantic mapping.

    Args:
        prompt: Text to encode.

    Returns:
        str: The result of ``emoji.emojize`` for the selected emoji name;
        the neutral face when none of the four components is positive.
    """
    inputs = clip_tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    # Inference only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        outputs = clip_model(**inputs)
    embedding = outputs.last_hidden_state.mean(dim=1)

    # Ordered by priority: position i is tested against embedding[0, i].
    candidates = (
        ":grinning_face:",
        ":smiling_face_with_heart-eyes:",
        ":face_with_tears_of_joy:",
        ":thinking_face:",
    )
    for index, name in enumerate(candidates):
        if embedding[0, index] > 0:
            return emoji.emojize(name)
    return emoji.emojize(":neutral_face:")
# ---- Streamlit page: upload an image, analyse it, show the results ----
st.title("Image Attribute Detection and Emoji Generation")

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    image = Image.open(uploaded_file)
    st.image(image, caption='Uploaded Image', use_column_width=True)

    # The three stages below run only when the user clicks the button.
    if st.button('Analyze and Generate Emoji'):
        with st.spinner('Detecting attributes...'):
            attributes = detect_attributes(image)
        st.write("Detected Attributes:")
        for key, value in attributes.items():
            st.write(f"{key.capitalize()}: {value}")

        with st.spinner('Generating prompt...'):
            prompt = generate_prompt(attributes)
        st.write("Generated Prompt:")
        st.write(prompt)

        with st.spinner('Generating emoji...'):
            emoji_result = generate_emoji(prompt)
        st.write("Generated Emoji:")
        st.write(emoji_result)