from fastapi import FastAPI, File, UploadFile, HTTPException
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
import torch
import torch.nn as nn
from pydantic import BaseModel
from PIL import Image
import numpy as np
import io
import base64
import logging
import requests

# Initialize the FastAPI app
app = FastAPI()

# Request body model for the URL endpoint
class ImageURL(BaseModel):
    url: str

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the SegFormer model and processor
try:
    logger.info("Loading the SegFormer model...")
    model = SegformerForSemanticSegmentation.from_pretrained("sayeed99/segformer-b3-fashion")
    processor = SegformerImageProcessor.from_pretrained("sayeed99/segformer-b3-fashion")
    model.to("cpu")  # Use CPU for the free tier
    model.eval()
    logger.info("Model loaded successfully.")
except Exception as e:
    logger.error(f"Error loading the model: {str(e)}")
    raise RuntimeError(f"Error loading the model: {str(e)}")

# Segment an image and return the mask (base64 PNG) plus annotations
def segment_image(image: Image.Image):
    # Prepare the input for SegFormer
    logger.info("Preparing the image for inference...")
    inputs = processor(images=image, return_tensors="pt").to("cpu")

    # Inference
    logger.info("Running inference...")
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits

    # Post-process the mask: per-pixel argmax over the class dimension
    logger.info("Post-processing the mask...")
    mask = torch.argmax(logits, dim=1)[0]
    mask = mask.cpu().numpy()

    # Scale class IDs into the 0-255 range so the mask is viewable as grayscale
    # (guard against division by zero when the mask is all background)
    max_id = max(int(mask.max()), 1)
    mask_img = Image.fromarray((mask * 255 / max_id).astype(np.uint8))

    # Encode the mask as base64 for the response
    buffered = io.BytesIO()
    mask_img.save(buffered, format="PNG")
    mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

    # Annotations: the raw mask plus the class IDs present in it.
    # (The logits tensor is not JSON serializable, so it is not returned.)
    annotations = {"mask": mask.tolist(), "labels": np.unique(mask).tolist()}

    return mask_base64, annotations

# API endpoint: segment an uploaded file
@app.post("/segment")
async def segment_endpoint(file: UploadFile = File(...)):
    try:
        logger.info("Receiving the file...")
        image_data = await file.read()
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
        logger.info("Segmenting the image...")
        mask_base64, annotations = segment_image(image)
        return {
            "mask": f"data:image/png;base64,{mask_base64}",
            "annotations": annotations
        }
    except Exception as e:
        logger.error(f"Error in the endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error during processing: {str(e)}")

# API endpoint: segment an image downloaded from a URL
@app.post("/segment-url")
async def segment_url_endpoint(image_data: ImageURL):
    try:
        logger.info("Downloading image from URL...")
        response = requests.get(image_data.url, timeout=10)
        if response.status_code != 200:
            raise HTTPException(status_code=400, detail="Could not download image from URL")

        # Open the downloaded image
        image = Image.open(io.BytesIO(response.content)).convert("RGB")

        # Process the image with SegFormer
        logger.info("Processing image...")
        inputs = processor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.logits.cpu()

        # Upsample the logits to the original image size
        # (PIL's size is (width, height); interpolate expects (height, width))
        upsampled_logits = nn.functional.interpolate(
            logits,
            size=image.size[::-1],
            mode="bilinear",
            align_corners=False,
        )

        # Per-pixel predicted class IDs
        pred_seg = upsampled_logits.argmax(dim=1)[0].numpy()

        # Scale class IDs into the 0-255 range for a viewable grayscale mask
        max_id = max(int(pred_seg.max()), 1)
        mask_img = Image.fromarray((pred_seg * 255 / max_id).astype(np.uint8))

        # Encode the mask as base64
        buffered = io.BytesIO()
        mask_img.save(buffered, format="PNG")
        mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

        return {
            "mask": f"data:image/png;base64,{mask_base64}",
            "size": image.size,
            # The upsampled logits tensor is not JSON serializable;
            # return the class IDs present in the prediction instead.
            "labels": np.unique(pred_seg).tolist()
        }
    except HTTPException:
        # Re-raise HTTP errors (e.g. the 400 above) instead of converting them to 500s
        raise
    except Exception as e:
        logger.error(f"Error processing URL: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")

# For compatibility with Hugging Face Spaces
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)