# File size: 2,662 Bytes
# 636182d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from typing import Dict, List, Any
from transformers import pipeline
from PIL import Image
import requests
import os
from io import BytesIO
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from diffusers import DiffusionPipeline
import torch
from torch import autocast
import base64


# Hugging Face access token, read from the environment rather than committed
# to source. The token previously hard-coded here is exposed in version history
# and should be revoked. When unset this is ``None`` and the Hub client uses its
# default credential lookup.
auth_token = os.environ.get("HF_TOKEN")


class EndpointHandler:
    """Inference Endpoint handler for text-guided inpainting.

    A CLIPSeg model (loaded from ``./clipseg-rd64-refined``) is handed to the
    ``text_inpainting`` custom diffusers pipeline as its segmentation model.
    The pipeline is invoked with an image, a segmentation ``text`` and a
    ``prompt``.
    """

    def __init__(self, path=""):
        """Load the CLIPSeg model/processor and the inpainting pipeline.

        Args:
            path: Accepted for the endpoint handler interface.
                NOTE(review): currently unused -- all weights are loaded
                relative to the current working directory.
        """
        self.processor = CLIPSegProcessor.from_pretrained("./clipseg-rd64-refined")
        self.model = CLIPSegForImageSegmentation.from_pretrained("./clipseg-rd64-refined")

        # Choose the device before loading so the dtype can match it: many CPU
        # kernels have no float16 implementation, so use float32 off-GPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        dtype = torch.float16 if self.device == "cuda" else torch.float32

        self.pipe = DiffusionPipeline.from_pretrained(
            "./",
            custom_pipeline="text_inpainting",
            segmentation_model=self.model,
            segmentation_processor=self.processor,
            revision="fp16",
            torch_dtype=dtype,
            use_auth_token=auth_token,
        )
        self.pipe = self.pipe.to(self.device)

    def pad_image(self, image):
        """Center ``image`` on a square canvas whose side is its longer edge.

        Square images are returned unchanged. The padding is zero in every
        band (black; transparent black for modes with alpha).
        """
        width, height = image.size
        if width == height:
            return image
        side = max(width, height)
        # Default fill (0) is valid for every mode; an RGB tuple is rejected
        # by single-band modes such as "L".
        padded = Image.new(image.mode, (side, side))
        padded.paste(image, ((side - width) // 2, (side - height) // 2))
        return padded

    def process_image(self, image, text, prompt):
        """Pad ``image`` to a square, resize it to 512x512 and run the pipeline.

        Args:
            image: PIL image to edit.
            text: Passed to the pipeline as ``text`` (segmentation query).
            prompt: Passed to the pipeline as ``prompt``.

        Returns:
            The first image of the pipeline output.
        """
        square = self.pad_image(image).resize((512, 512))
        # Mixed precision only on CUDA; CPU autocast defaults to bfloat16,
        # which the float32 CPU path does not need.
        with autocast(self.device, enabled=self.device == "cuda"):
            inpainted_image = self.pipe(image=square, text=text, prompt=prompt).images[0]
        return inpainted_image

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one inference request.

        Args:
            data: Request payload. Fields are read from ``data["inputs"]``
                (which is popped from ``data``) when present, otherwise from
                ``data`` itself:

                - ``image``: base64-encoded image bytes.
                - ``class_text``: segmentation query for the pipeline.
                - ``prompt``: inpainting prompt for the pipeline.

        Returns:
            ``{"image": str}`` -- the result encoded as base64 JPEG.

        Raises:
            KeyError: if ``image``, ``class_text`` or ``prompt`` is missing.
        """
        inputs = data.pop("inputs", data)

        # Normalize to RGB so grayscale, palette or RGBA uploads all go through
        # padding and the pipeline in one consistent mode.
        image = Image.open(BytesIO(base64.b64decode(inputs["image"]))).convert("RGB")
        class_text = inputs["class_text"]
        prompt = inputs["prompt"]

        # process_image already applies autocast; no outer context needed.
        result = self.process_image(image, class_text, prompt)

        buffered = BytesIO()
        result.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue())
        return {"image": img_str.decode()}