from fastai.vision.all import *
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import logging
import tempfile
from pathlib import Path
import firebase_admin
from firebase_admin import credentials, firestore, storage
from pydantic import BaseModel
import torch
from transformers import AutoImageProcessor, AutoModelForObjectDetection
from PIL import Image, ImageDraw, ImageFont
import cv2
import random
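
# --- Application setup (a minimal sketch: the code below references `app`,
# `logger`, `db`, `bucket`, and `FileProcess`, which this file never defines;
# the credential path and bucket name here are placeholder assumptions to
# adjust for your deployment) ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # tighten for production
    allow_methods=["*"],
    allow_headers=["*"],
)

# Firebase Admin SDK; "serviceAccountKey.json" and the bucket name are assumptions.
cred = credentials.Certificate("serviceAccountKey.json")
firebase_admin.initialize_app(cred, {"storageBucket": "your-project.appspot.com"})
db = firestore.client()
bucket = storage.bucket()

class FileProcess(BaseModel):
    """Request body for /process: path of the file in Firebase Storage."""
    file_path: str
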
# Load model and processor
processor = AutoImageProcessor.from_pretrained("valentinafeve/yolos-fashionpedia")
model = AutoModelForObjectDetection.from_pretrained("valentinafeve/yolos-fashionpedia")
# Fashionpedia categories
FASHION_CATEGORIES = [
    'shirt, blouse', 'top, t-shirt, sweatshirt', 'sweater', 'cardigan', 'jacket', 'vest', 'pants', 'shorts', 'skirt',
    'coat', 'dress', 'jumpsuit', 'cape', 'glasses', 'hat', 'headband, head covering, hair accessory', 'tie', 'glove',
    'watch', 'belt', 'leg warmer', 'tights, stockings', 'sock', 'shoe', 'bag, wallet', 'scarf', 'umbrella', 'hood',
    'collar', 'lapel', 'epaulette', 'sleeve', 'pocket', 'neckline', 'buckle', 'zipper', 'applique', 'bead', 'bow',
    'flower', 'fringe', 'ribbon', 'rivet', 'ruffle', 'sequin', 'tassel'
]
def detect_fashion(image):
    """Run YOLOS on a PIL image; return a (category, score, box) tuple per detection."""
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():  # inference only; no gradients needed
        outputs = model(**inputs)
    # Convert outputs (bounding boxes and class logits) to COCO API format
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(outputs, threshold=0.1, target_sizes=target_sizes)[0]
    detected_items = []
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        if score > 0.5:  # adjust this threshold as needed
            detected_items.append((FASHION_CATEGORIES[label], score.item(), box.tolist()))
    return detected_items
def check_dress_code(detected_items):
    """Return True if any detected item counts as formal workplace attire."""
    formal_workplace_attire = {
        "shirt, blouse", "jacket", "tie", "coat", "sweater", "cardigan"
    }
    return any(item[0] in formal_workplace_attire for item in detected_items)
@app.post("/process")
async def process_file(file_data: FileProcess):
logger.info(f"Processing file from Firebase Storage: {file_data.file_path}")
try:
# Get the file from Firebase Storage
blob = bucket.blob(file_data.file_path)
# Create a temporary file
with tempfile.NamedTemporaryFile(delete=False, suffix=f".{file_data.file_path.split('.')[-1]}") as tmp_file:
blob.download_to_filename(tmp_file.name)
tmp_file_path = Path(tmp_file.name)
logger.info(f"File downloaded temporarily at: {tmp_file_path}")
file_type = file_data.file_path.split('.')[-1].lower()
try:
if file_type in ['mp4', 'avi', 'mov', 'wmv']:
output,testing = process_video(str(tmp_file_path))
result = {"type": "video", "data": {"result": output}}
else:
raise HTTPException(status_code=400, detail="Unsupported file type")
logger.info(f"Processing complete. Result: {result}")
# Store result in Firebase
try:
doc_ref = db.collection('results').add(result)
return {"message": "File processed successfully", "result": result}
except Exception as e:
logger.error(f"Failed to store result in Firebase: {str(e)}")
return {"message": "File processed successfully, but failed to store in Firebase", "result": result,
"error": str(e)}
finally:
# Clean up the temporary file
tmp_file_path.unlink()
except Exception as e:
logger.error(f"Error processing file: {str(e)}")
raise HTTPException(status_code=500, detail=f"Error processing file: {str(e)}")
def process_video(video_path, num_frames=10):
    """Sample up to num_frames random frames and score dress-code compliance."""
    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frame_indices = sorted(random.sample(range(total_frames), min(num_frames, total_frames)))
    compliance_results = []
    for frame_index in frame_indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
        ret, frame = cap.read()
        if ret:
            image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            detected_items = detect_fashion(image)
            is_compliant = check_dress_code(detected_items)
            compliance_results.append(is_compliant)
    cap.release()
    if not compliance_results:  # unreadable video or no frames decoded
        return 0.0, []
    average_compliance = sum(compliance_results) / len(compliance_results)
    return average_compliance, compliance_results
if __name__ == "__main__":
logger.info("Starting the Face Emotion Recognition API")
uvicorn.run(app, host="0.0.0.0", port=8000) |
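
# Example request (a sketch; assumes the server is running locally and that
# "videos/sample.mp4" exists in the configured Storage bucket):
#   curl -X POST http://localhost:8000/process \
#        -H "Content-Type: application/json" \
#        -d '{"file_path": "videos/sample.mp4"}'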