import streamlit as st
from PIL import Image
import requests
from io import BytesIO
from transformers import (
    ViTFeatureExtractor,
    ViTForImageClassification,
    pipeline,
    AutoFeatureExtractor,
    AutoModelForObjectDetection,
    CLIPTokenizerFast,
    CLIPTextModel
)
import torch
from torchvision.transforms import functional as F
import emoji


# Load models
@st.cache_resource
def load_models():
    age_model = ViTForImageClassification.from_pretrained('nateraw/vit-age-classifier')
    age_transforms = ViTFeatureExtractor.from_pretrained('nateraw/vit-age-classifier')

    gender_model = ViTForImageClassification.from_pretrained('rizvandwiki/gender-classification-2')
    gender_transforms = ViTFeatureExtractor.from_pretrained('rizvandwiki/gender-classification-2')

    emotion_model = ViTForImageClassification.from_pretrained('dima806/facial_emotions_image_detection')
    emotion_transforms = ViTFeatureExtractor.from_pretrained('dima806/facial_emotions_image_detection')

    object_detector = pipeline("object-detection", model="facebook/detr-resnet-50")

    action_model = ViTForImageClassification.from_pretrained('rvv-karma/Human-Action-Recognition-VIT-Base-patch16-224')
    action_transforms = ViTFeatureExtractor.from_pretrained('rvv-karma/Human-Action-Recognition-VIT-Base-patch16-224')

    # Loaded alongside the classifiers; not used by generate_prompt below.
    prompt_generator = pipeline("text2text-generation", model="succinctly/text2image-prompt-generator")

    clip_tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32")
    clip_model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")

    return (age_model, age_transforms, gender_model, gender_transforms,
            emotion_model, emotion_transforms, object_detector,
            action_model, action_transforms, prompt_generator,
            clip_tokenizer, clip_model)


models = load_models()
(age_model, age_transforms, gender_model, gender_transforms,
 emotion_model, emotion_transforms, object_detector,
 action_model, action_transforms, prompt_generator,
 clip_tokenizer, clip_model) = models


def predict(image, model, transforms):
    # Run a single ViT classifier and return the index of the top-probability class.
    inputs = transforms(image, return_tensors='pt')
    with torch.no_grad():
        output = model(**inputs)
    proba = output.logits.softmax(1)
    return proba.argmax(1).item()


def detect_attributes(image):
    # Classify age, gender, emotion, and action, and detect objects in the image.
    age = predict(image, age_model, age_transforms)
    gender = predict(image, gender_model, gender_transforms)
    emotion = predict(image, emotion_model, emotion_transforms)
    action = predict(image, action_model, action_transforms)
    objects = object_detector(image)

    return {
        'age': age_model.config.id2label[age],
        'gender': gender_model.config.id2label[gender],
        'emotion': emotion_model.config.id2label[emotion],
        'action': action_model.config.id2label[action],
        'objects': [obj['label'] for obj in objects]
    }


def generate_prompt(attributes):
    # Build a natural-language prompt from the detected attributes.
    prompt = f"A {attributes['age']} {attributes['gender']} person feeling {attributes['emotion']} "
    prompt += f"while {attributes['action']}. "
    if attributes['objects']:
        prompt += f"Surrounded by {', '.join(attributes['objects'])}. "
    return prompt
" return prompt def generate_emoji(prompt): inputs = clip_tokenizer(prompt, return_tensors="pt", padding=True, truncation=True) outputs = clip_model(**inputs) embedding = outputs.last_hidden_state.mean(dim=1) # Simple emoji mapping based on embedding features if embedding[0, 0] > 0: return emoji.emojize(":grinning_face:") elif embedding[0, 1] > 0: return emoji.emojize(":smiling_face_with_heart-eyes:") elif embedding[0, 2] > 0: return emoji.emojize(":face_with_tears_of_joy:") elif embedding[0, 3] > 0: return emoji.emojize(":thinking_face:") else: return emoji.emojize(":neutral_face:") st.title("Image Attribute Detection and Emoji Generation") uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"]) if uploaded_file is not None: image = Image.open(uploaded_file) st.image(image, caption='Uploaded Image', use_column_width=True) if st.button('Analyze and Generate Emoji'): with st.spinner('Detecting attributes...'): attributes = detect_attributes(image) st.write("Detected Attributes:") for key, value in attributes.items(): st.write(f"{key.capitalize()}: {value}") with st.spinner('Generating prompt...'): prompt = generate_prompt(attributes) st.write("Generated Prompt:") st.write(prompt) with st.spinner('Generating emoji...'): emoji_result = generate_emoji(prompt) st.write("Generated Emoji:") st.write(emoji_result)