import streamlit as st
from PIL import Image
from transformers import (
    ViTFeatureExtractor,
    ViTForImageClassification,
    pipeline,
    CLIPTokenizerFast,
    CLIPTextModel
)
import torch
import emoji

# Load all models once; st.cache_resource keeps them in memory across Streamlit reruns
@st.cache_resource
def load_models():
    age_model = ViTForImageClassification.from_pretrained('nateraw/vit-age-classifier')
    age_transforms = ViTFeatureExtractor.from_pretrained('nateraw/vit-age-classifier')
    
    gender_model = ViTForImageClassification.from_pretrained('rizvandwiki/gender-classification-2')
    gender_transforms = ViTFeatureExtractor.from_pretrained('rizvandwiki/gender-classification-2')
    
    emotion_model = ViTForImageClassification.from_pretrained('dima806/facial_emotions_image_detection')
    emotion_transforms = ViTFeatureExtractor.from_pretrained('dima806/facial_emotions_image_detection')
    
    object_detector = pipeline("object-detection", model="facebook/detr-resnet-50")
    
    action_model = ViTForImageClassification.from_pretrained('rvv-karma/Human-Action-Recognition-VIT-Base-patch16-224')
    action_transforms = ViTFeatureExtractor.from_pretrained('rvv-karma/Human-Action-Recognition-VIT-Base-patch16-224')
    
    # succinctly/text2image-prompt-generator is a GPT-2 style causal LM, so it uses the text-generation task
    prompt_generator = pipeline("text-generation", model="succinctly/text2image-prompt-generator")
    
    clip_tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32")
    clip_model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
    
    return (age_model, age_transforms, gender_model, gender_transforms, 
            emotion_model, emotion_transforms, object_detector, 
            action_model, action_transforms, prompt_generator, 
            clip_tokenizer, clip_model)

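# Unpack the cached model bundle for use in the helper functions below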
models = load_models()
(age_model, age_transforms, gender_model, gender_transforms, 
 emotion_model, emotion_transforms, object_detector, 
 action_model, action_transforms, prompt_generator, 
 clip_tokenizer, clip_model) = models

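# Run a single ViT classifier on the image and return the index of the top-scoring class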
def predict(image, model, transforms):
    inputs = transforms(image, return_tensors='pt')
    with torch.no_grad():  # inference only; no gradients needed
        output = model(**inputs)
    proba = output.logits.softmax(dim=1)
    return proba.argmax(dim=1).item()

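# Run every attribute classifier plus the object detector and collect human-readable labels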
def detect_attributes(image):
    age = predict(image, age_model, age_transforms)
    gender = predict(image, gender_model, gender_transforms)
    emotion = predict(image, emotion_model, emotion_transforms)
    action = predict(image, action_model, action_transforms)
    
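    # The object-detection pipeline returns a list of detections, each with 'label', 'score', and 'box'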
    objects = object_detector(image)
    
    return {
        'age': age_model.config.id2label[age],
        'gender': gender_model.config.id2label[gender],
        'emotion': emotion_model.config.id2label[emotion],
        'action': action_model.config.id2label[action],
        'objects': [obj['label'] for obj in objects]
    }

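# Assemble a simple natural-language description directly from the detected attributes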
def generate_prompt(attributes):
    prompt = f"A {attributes['age']} {attributes['gender']} person feeling {attributes['emotion']} "
    prompt += f"while {attributes['action']}. "
    if attributes['objects']:
        prompt += f"Surrounded by {', '.join(attributes['objects'])}. "
    return prompt

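# Embed the prompt with CLIP's text encoder, then map a few embedding dimensions to emojis
# (a rough hand-written heuristic, not a learned mapping)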
def generate_emoji(prompt):
    inputs = clip_tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        outputs = clip_model(**inputs)
    embedding = outputs.last_hidden_state.mean(dim=1)
    
    # Simple emoji mapping based on embedding features
    if embedding[0, 0] > 0:
        return emoji.emojize(":grinning_face:")
    elif embedding[0, 1] > 0:
        return emoji.emojize(":smiling_face_with_heart-eyes:")
    elif embedding[0, 2] > 0:
        return emoji.emojize(":face_with_tears_of_joy:")
    elif embedding[0, 3] > 0:
        return emoji.emojize(":thinking_face:")
    else:
        return emoji.emojize(":neutral_face:")

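# Streamlit UI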
st.title("Image Attribute Detection and Emoji Generation")

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    image = Image.open(uploaded_file).convert('RGB')  # ensure 3-channel input for the ViT models
    st.image(image, caption='Uploaded Image', use_column_width=True)
    
    if st.button('Analyze and Generate Emoji'):
        with st.spinner('Detecting attributes...'):
            attributes = detect_attributes(image)
        
        st.write("Detected Attributes:")
        for key, value in attributes.items():
            st.write(f"{key.capitalize()}: {value}")
        
        with st.spinner('Generating prompt...'):
            prompt = generate_prompt(attributes)
        st.write("Generated Prompt:")
        st.write(prompt)
        
        with st.spinner('Generating emoji...'):
            emoji_result = generate_emoji(prompt)
        st.write("Generated Emoji:")
        st.write(emoji_result)