hema1 commited on
Commit
e58607b
·
1 Parent(s): 30a20fb

Create text_to_image.py

Browse files
Files changed (1) hide show
  1. text_to_image.py +52 -0
text_to_image.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.tools.base import Tool, get_default_device
2
+ from transformers.utils import is_accelerate_available
3
+ import torch
4
+
5
+ from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
6
+
7
+
8
+ TEXT_TO_IMAGE_DESCRIPTION = (
9
+ "This is a tool that creates an image according to a prompt, which is a text description. It takes an input named `prompt` which "
10
+ "contains the image description and outputs an image."
11
+ )
12
+
13
+
14
+ class TextToImageTool(Tool):
15
+ default_checkpoint = "runwayml/stable-diffusion-v1-5"
16
+ description = TEXT_TO_IMAGE_DESCRIPTION
17
+ inputs = ['text']
18
+ outputs = ['image']
19
+
20
+ def __init__(self, device=None, **hub_kwargs) -> None:
21
+ if not is_accelerate_available():
22
+ raise ImportError("Accelerate should be installed in order to use tools.")
23
+
24
+ super().__init__()
25
+
26
+ self.device = device
27
+ self.pipeline = None
28
+ self.hub_kwargs = hub_kwargs
29
+
30
+ def setup(self):
31
+ if self.device is None:
32
+ self.device = get_default_device()
33
+
34
+ self.pipeline = DiffusionPipeline.from_pretrained(self.default_checkpoint)
35
+ self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(self.pipeline.scheduler.config)
36
+ self.pipeline.to(self.device)
37
+
38
+ if self.device.type == "cuda":
39
+ self.pipeline.to(torch_dtype=torch.float16)
40
+
41
+ self.is_initialized = True
42
+
43
+ def __call__(self, prompt):
44
+ if not self.is_initialized:
45
+ self.setup()
46
+
47
+ negative_prompt = "low quality, bad quality, deformed, low resolution"
48
+ added_prompt = " , highest quality, highly realistic, very high resolution"
49
+
50
+ return self.pipeline(prompt + added_prompt, negative_prompt=negative_prompt, num_inference_steps=25).images[0]
51
+
52
+