tyri zamasam commited on
Commit
b55ce60
β€’
0 Parent(s):

Duplicate from zamasam/allExtensionsHentai

Browse files

Co-authored-by: no <[email protected]>

Files changed (7) hide show
  1. .gitattributes +35 -0
  2. Dockerfile +21 -0
  3. README.md +10 -0
  4. constants.py +51 -0
  5. requirements.txt +18 -0
  6. server.py +856 -0
  7. tts_edge.py +34 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10
2
+
3
+ WORKDIR /app
4
+
5
+ COPY requirements.txt .
6
+ RUN pip install -r requirements.txt
7
+
8
+ RUN mkdir /.cache && chmod -R 777 /.cache
9
+ RUN mkdir .chroma && chmod -R 777 .chroma
10
+
11
+ COPY . .
12
+
13
+
14
+ RUN chmod -R 777 /app
15
+
16
+ RUN --mount=type=secret,id=password,mode=0444,required=true \
17
+ cat /run/secrets/password > /test
18
+
19
+ EXPOSE 7860
20
+
21
+ CMD ["python", "server.py", "--cpu", "--enable-modules=caption,summarize,classify,silero-tts,edge-tts,chromadb"]
README.md ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: data
3
+ emoji: 😭
4
+ colorFrom: blue
5
+ colorTo: green
6
+ sdk: docker
7
+ pinned: false
8
+ license: mit
9
+ duplicated_from: zamasam/allExtensionsHentai
10
+ ---
constants.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Constants
2
+ # Also try: 'slauw87/bart-large-cnn-samsum'
3
+ # Qiliang/bart-large-cnn-samsum-ElectrifAi_v14
4
+ DEFAULT_SUMMARIZATION_MODEL = "slauw87/bart_summarisation"
5
+ # Also try: 'nateraw/bert-base-uncased-emotion'
6
+ DEFAULT_CLASSIFICATION_MODEL = "joeddav/distilbert-base-uncased-go-emotions-student"
7
+ # Also try: 'Salesforce/blip-image-captioning-base'
8
+ DEFAULT_CAPTIONING_MODEL = "Salesforce/blip-image-captioning-large"
9
+ # Also try: 'ckpt/anything-v4.5-vae-swapped'
10
+ DEFAULT_SD_MODEL = "sinkinai/MeinaHentai-v3-baked-vae"
11
+ DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
12
+ DEFAULT_REMOTE_SD_HOST = "127.0.0.1"
13
+ DEFAULT_REMOTE_SD_PORT = 7860
14
+ DEFAULT_CHROMA_PORT = 8000
15
+ SILERO_SAMPLES_PATH = "tts_samples"
16
+ SILERO_SAMPLE_TEXT = "Doctor is your lord and savior"
17
+ # ALL_MODULES = ['caption', 'summarize', 'classify', 'keywords', 'prompt', 'sd']
18
+ DEFAULT_SUMMARIZE_PARAMS = {
19
+ "temperature": 1.0,
20
+ "repetition_penalty": 1.0,
21
+ "max_length": 500,
22
+ "min_length": 200,
23
+ "length_penalty": 1.5,
24
+ "bad_words": [
25
+ "\n",
26
+ '"',
27
+ "*",
28
+ "[",
29
+ "]",
30
+ "{",
31
+ "}",
32
+ ":",
33
+ "(",
34
+ ")",
35
+ "<",
36
+ ">",
37
+ "Γ‚",
38
+ "The text ends",
39
+ "The story ends",
40
+ "The text is",
41
+ "The story is",
42
+ ],
43
+ }
44
+
45
+ PROMPT_PREFIX = "best quality, absurdres, "
46
+ NEGATIVE_PROMPT = """lowres, bad anatomy, error body, error hair, error arm,
47
+ error hands, bad hands, error fingers, bad fingers, missing fingers
48
+ error legs, bad legs, multiple legs, missing legs, error lighting,
49
+ error shadow, error reflection, text, error, extra digit, fewer digits,
50
+ cropped, worst quality, low quality, normal quality, jpeg artifacts,
51
+ signature, watermark, username, blurry"""
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ flask
2
+ flask-cors
3
+ flask-compress
4
+ markdown
5
+ Pillow
6
+ colorama
7
+ webuiapi
8
+ --extra-index-url https://download.pytorch.org/whl/cu117
9
+ torch==2.0.0+cu117
10
+ torchvision==0.15.1
11
+ torchaudio==2.0.1+cu117
12
+ accelerate
13
+ transformers==4.28.1
14
+ diffusers==0.16.1
15
+ silero-api-server
16
+ chromadb==0.3.26
17
+ sentence_transformers
18
+ edge-tts
server.py ADDED
@@ -0,0 +1,856 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import wraps
2
+ from flask import (
3
+ Flask,
4
+ jsonify,
5
+ request,
6
+ Response,
7
+ render_template_string,
8
+ abort,
9
+ send_from_directory,
10
+ send_file,
11
+ )
12
+ from flask_cors import CORS
13
+ from flask_compress import Compress
14
+ import markdown
15
+ import argparse
16
+ from transformers import AutoTokenizer, AutoProcessor, pipeline
17
+ from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM
18
+ from transformers import BlipForConditionalGeneration
19
+ import unicodedata
20
+ import torch
21
+ import time
22
+ import os
23
+ import gc
24
+ import secrets
25
+ from PIL import Image
26
+ import base64
27
+ from io import BytesIO
28
+ from random import randint
29
+ import webuiapi
30
+ import hashlib
31
+ from constants import *
32
+ from colorama import Fore, Style, init as colorama_init
33
+
34
+ colorama_init()
35
+
36
+
37
+ class SplitArgs(argparse.Action):
38
+ def __call__(self, parser, namespace, values, option_string=None):
39
+ setattr(
40
+ namespace, self.dest, values.replace('"', "").replace("'", "").split(",")
41
+ )
42
+
43
+
44
+ # Script arguments
45
+ parser = argparse.ArgumentParser(
46
+ prog="SillyTavern Extras", description="Web API for transformers models"
47
+ )
48
+ parser.add_argument(
49
+ "--port", type=int, help="Specify the port on which the application is hosted"
50
+ )
51
+ parser.add_argument(
52
+ "--listen", action="store_true", help="Host the app on the local network"
53
+ )
54
+ parser.add_argument(
55
+ "--share", action="store_true", help="Share the app on CloudFlare tunnel"
56
+ )
57
+ parser.add_argument("--cpu", action="store_true", help="Run the models on the CPU")
58
+ parser.add_argument("--cuda", action="store_false", dest="cpu", help="Run the models on the GPU")
59
+ parser.set_defaults(cpu=True)
60
+ parser.add_argument("--summarization-model", help="Load a custom summarization model")
61
+ parser.add_argument(
62
+ "--classification-model", help="Load a custom text classification model"
63
+ )
64
+ parser.add_argument("--captioning-model", help="Load a custom captioning model")
65
+ parser.add_argument("--embedding-model", help="Load a custom text embedding model")
66
+ parser.add_argument("--chroma-host", help="Host IP for a remote ChromaDB instance")
67
+ parser.add_argument("--chroma-port", help="HTTP port for a remote ChromaDB instance (defaults to 8000)")
68
+ parser.add_argument("--chroma-folder", help="Path for chromadb persistence folder", default='.chroma_db')
69
+ parser.add_argument('--chroma-persist', help="Chromadb persistence", default=True, action=argparse.BooleanOptionalAction)
70
+ parser.add_argument(
71
+ "--secure", action="store_true", help="Enforces the use of an API key"
72
+ )
73
+
74
+ sd_group = parser.add_mutually_exclusive_group()
75
+
76
+ local_sd = sd_group.add_argument_group("sd-local")
77
+ local_sd.add_argument("--sd-model", help="Load a custom SD image generation model")
78
+ local_sd.add_argument("--sd-cpu", help="Force the SD pipeline to run on the CPU", action="store_true")
79
+
80
+ remote_sd = sd_group.add_argument_group("sd-remote")
81
+ remote_sd.add_argument(
82
+ "--sd-remote", action="store_true", help="Use a remote backend for SD"
83
+ )
84
+ remote_sd.add_argument(
85
+ "--sd-remote-host", type=str, help="Specify the host of the remote SD backend"
86
+ )
87
+ remote_sd.add_argument(
88
+ "--sd-remote-port", type=int, help="Specify the port of the remote SD backend"
89
+ )
90
+ remote_sd.add_argument(
91
+ "--sd-remote-ssl", action="store_true", help="Use SSL for the remote SD backend"
92
+ )
93
+ remote_sd.add_argument(
94
+ "--sd-remote-auth",
95
+ type=str,
96
+ help="Specify the username:password for the remote SD backend (if required)",
97
+ )
98
+
99
+ parser.add_argument(
100
+ "--enable-modules",
101
+ action=SplitArgs,
102
+ default=[],
103
+ help="Override a list of enabled modules",
104
+ )
105
+
106
+ args = parser.parse_args()
107
+
108
+ port = 7860
109
+ host = "0.0.0.0"
110
+ summarization_model = (
111
+ args.summarization_model
112
+ if args.summarization_model
113
+ else DEFAULT_SUMMARIZATION_MODEL
114
+ )
115
+ classification_model = (
116
+ args.classification_model
117
+ if args.classification_model
118
+ else DEFAULT_CLASSIFICATION_MODEL
119
+ )
120
+ captioning_model = (
121
+ args.captioning_model if args.captioning_model else DEFAULT_CAPTIONING_MODEL
122
+ )
123
+ embedding_model = (
124
+ args.embedding_model if args.embedding_model else DEFAULT_EMBEDDING_MODEL
125
+ )
126
+
127
+ sd_use_remote = False if args.sd_model else True
128
+ sd_model = args.sd_model if args.sd_model else DEFAULT_SD_MODEL
129
+ sd_remote_host = args.sd_remote_host if args.sd_remote_host else DEFAULT_REMOTE_SD_HOST
130
+ sd_remote_port = args.sd_remote_port if args.sd_remote_port else DEFAULT_REMOTE_SD_PORT
131
+ sd_remote_ssl = args.sd_remote_ssl
132
+ sd_remote_auth = args.sd_remote_auth
133
+
134
+ modules = (
135
+ args.enable_modules if args.enable_modules and len(args.enable_modules) > 0 else []
136
+ )
137
+
138
+ if len(modules) == 0:
139
+ print(
140
+ f"{Fore.RED}{Style.BRIGHT}You did not select any modules to run! Choose them by adding an --enable-modules option"
141
+ )
142
+ print(f"Example: --enable-modules=caption,summarize{Style.RESET_ALL}")
143
+
144
+ # Models init
145
+ device_string = "cuda:0" if torch.cuda.is_available() and not args.cpu else "cpu"
146
+ device = torch.device(device_string)
147
+ torch_dtype = torch.float32 if device_string == "cpu" else torch.float16
148
+
149
+ if not torch.cuda.is_available() and not args.cpu:
150
+ print(f"{Fore.YELLOW}{Style.BRIGHT}torch-cuda is not supported on this device. Defaulting to CPU mode.{Style.RESET_ALL}")
151
+
152
+ print(f"{Fore.GREEN}{Style.BRIGHT}Using torch device: {device_string}{Style.RESET_ALL}")
153
+
154
+ if "caption" in modules:
155
+ print("Initializing an image captioning model...")
156
+ captioning_processor = AutoProcessor.from_pretrained(captioning_model)
157
+ if "blip" in captioning_model:
158
+ captioning_transformer = BlipForConditionalGeneration.from_pretrained(
159
+ captioning_model, torch_dtype=torch_dtype
160
+ ).to(device)
161
+ else:
162
+ captioning_transformer = AutoModelForCausalLM.from_pretrained(
163
+ captioning_model, torch_dtype=torch_dtype
164
+ ).to(device)
165
+
166
+ if "summarize" in modules:
167
+ print("Initializing a text summarization model...")
168
+ summarization_tokenizer = AutoTokenizer.from_pretrained(summarization_model)
169
+ summarization_transformer = AutoModelForSeq2SeqLM.from_pretrained(
170
+ summarization_model, torch_dtype=torch_dtype
171
+ ).to(device)
172
+
173
+ if "classify" in modules:
174
+ print("Initializing a sentiment classification pipeline...")
175
+ classification_pipe = pipeline(
176
+ "text-classification",
177
+ model=classification_model,
178
+ top_k=None,
179
+ device=device,
180
+ torch_dtype=torch_dtype,
181
+ )
182
+
183
+ if "sd" in modules and not sd_use_remote:
184
+ from diffusers import StableDiffusionPipeline
185
+ from diffusers import EulerAncestralDiscreteScheduler
186
+
187
+ print("Initializing Stable Diffusion pipeline")
188
+ sd_device_string = (
189
+ "cuda" if torch.cuda.is_available() and not args.sd_cpu else "cpu"
190
+ )
191
+ sd_device = torch.device(sd_device_string)
192
+ sd_torch_dtype = torch.float32 if sd_device_string == "cpu" else torch.float16
193
+ sd_pipe = StableDiffusionPipeline.from_pretrained(
194
+ sd_model, custom_pipeline="lpw_stable_diffusion", torch_dtype=sd_torch_dtype
195
+ ).to(sd_device)
196
+ sd_pipe.safety_checker = lambda images, clip_input: (images, False)
197
+ sd_pipe.enable_attention_slicing()
198
+ # pipe.scheduler = KarrasVeScheduler.from_config(pipe.scheduler.config)
199
+ sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
200
+ sd_pipe.scheduler.config
201
+ )
202
+ elif "sd" in modules and sd_use_remote:
203
+ print("Initializing Stable Diffusion connection")
204
+ try:
205
+ sd_remote = webuiapi.WebUIApi(
206
+ host=sd_remote_host, port=sd_remote_port, use_https=sd_remote_ssl
207
+ )
208
+ if sd_remote_auth:
209
+ username, password = sd_remote_auth.split(":")
210
+ sd_remote.set_auth(username, password)
211
+ sd_remote.util_wait_for_ready()
212
+ except Exception as e:
213
+ # remote sd from modules
214
+ print(
215
+ f"{Fore.RED}{Style.BRIGHT}Could not connect to remote SD backend at http{'s' if sd_remote_ssl else ''}://{sd_remote_host}:{sd_remote_port}! Disabling SD module...{Style.RESET_ALL}"
216
+ )
217
+ modules.remove("sd")
218
+
219
+ if "tts" in modules:
220
+ print("tts module is deprecated. Please use silero-tts instead.")
221
+ modules.remove("tts")
222
+ modules.append("silero-tts")
223
+
224
+
225
+ if "silero-tts" in modules:
226
+ if not os.path.exists(SILERO_SAMPLES_PATH):
227
+ os.makedirs(SILERO_SAMPLES_PATH)
228
+ print("Initializing Silero TTS server")
229
+ from silero_api_server import tts
230
+
231
+ tts_service = tts.SileroTtsService(SILERO_SAMPLES_PATH)
232
+ if len(os.listdir(SILERO_SAMPLES_PATH)) == 0:
233
+ print("Generating Silero TTS samples...")
234
+ tts_service.update_sample_text(SILERO_SAMPLE_TEXT)
235
+ tts_service.generate_samples()
236
+
237
+
238
+ if "edge-tts" in modules:
239
+ print("Initializing Edge TTS client")
240
+ import tts_edge as edge
241
+
242
+
243
+ if "chromadb" in modules:
244
+ print("Initializing ChromaDB")
245
+ import chromadb
246
+ import posthog
247
+ from chromadb.config import Settings
248
+ from sentence_transformers import SentenceTransformer
249
+
250
+ # Assume that the user wants in-memory unless a host is specified
251
+ # Also disable chromadb telemetry
252
+ posthog.capture = lambda *args, **kwargs: None
253
+ if args.chroma_host is None:
254
+ if args.chroma_persist:
255
+ chromadb_client = chromadb.Client(Settings(anonymized_telemetry=False, persist_directory=args.chroma_folder, chroma_db_impl='duckdb+parquet'))
256
+ print(f"ChromaDB is running in-memory with persistence. Persistence is stored in {args.chroma_folder}. Can be cleared by deleting the folder or purging db.")
257
+ else:
258
+ chromadb_client = chromadb.Client(Settings(anonymized_telemetry=False))
259
+ print(f"ChromaDB is running in-memory without persistence.")
260
+ else:
261
+ chroma_port=(
262
+ args.chroma_port if args.chroma_port else DEFAULT_CHROMA_PORT
263
+ )
264
+ chromadb_client = chromadb.Client(
265
+ Settings(
266
+ anonymized_telemetry=False,
267
+ chroma_api_impl="rest",
268
+ chroma_server_host=args.chroma_host,
269
+ chroma_server_http_port=chroma_port
270
+ )
271
+ )
272
+ print(f"ChromaDB is remotely configured at {args.chroma_host}:{chroma_port}")
273
+
274
+ chromadb_embedder = SentenceTransformer(embedding_model)
275
+ chromadb_embed_fn = lambda *args, **kwargs: chromadb_embedder.encode(*args, **kwargs).tolist()
276
+
277
+ # Check if the db is connected and running, otherwise tell the user
278
+ try:
279
+ chromadb_client.heartbeat()
280
+ print("Successfully pinged ChromaDB! Your client is successfully connected.")
281
+ except:
282
+ print("Could not ping ChromaDB! If you are running remotely, please check your host and port!")
283
+
284
+ # Flask init
285
+ app = Flask(__name__)
286
+ CORS(app) # allow cross-domain requests
287
+ Compress(app) # compress responses
288
+ app.config["MAX_CONTENT_LENGTH"] = 100 * 1024 * 1024
289
+
290
+
291
+ def require_module(name):
292
+ def wrapper(fn):
293
+ @wraps(fn)
294
+ def decorated_view(*args, **kwargs):
295
+ if name not in modules:
296
+ abort(403, "Module is disabled by config")
297
+ return fn(*args, **kwargs)
298
+
299
+ return decorated_view
300
+
301
+ return wrapper
302
+
303
+
304
+ # AI stuff
305
+ def classify_text(text: str) -> list:
306
+ output = classification_pipe(
307
+ text,
308
+ truncation=True,
309
+ max_length=classification_pipe.model.config.max_position_embeddings,
310
+ )[0]
311
+ return sorted(output, key=lambda x: x["score"], reverse=True)
312
+
313
+
314
+ def caption_image(raw_image: Image, max_new_tokens: int = 20) -> str:
315
+ inputs = captioning_processor(raw_image.convert("RGB"), return_tensors="pt").to(
316
+ device, torch_dtype
317
+ )
318
+ outputs = captioning_transformer.generate(**inputs, max_new_tokens=max_new_tokens)
319
+ caption = captioning_processor.decode(outputs[0], skip_special_tokens=True)
320
+ return caption
321
+
322
+
323
+ def summarize_chunks(text: str, params: dict) -> str:
324
+ try:
325
+ return summarize(text, params)
326
+ except IndexError:
327
+ print(
328
+ "Sequence length too large for model, cutting text in half and calling again"
329
+ )
330
+ new_params = params.copy()
331
+ new_params["max_length"] = new_params["max_length"] // 2
332
+ new_params["min_length"] = new_params["min_length"] // 2
333
+ return summarize_chunks(
334
+ text[: (len(text) // 2)], new_params
335
+ ) + summarize_chunks(text[(len(text) // 2) :], new_params)
336
+
337
+
338
+ def summarize(text: str, params: dict) -> str:
339
+ # Tokenize input
340
+ inputs = summarization_tokenizer(text, return_tensors="pt").to(device)
341
+ token_count = len(inputs[0])
342
+
343
+ bad_words_ids = [
344
+ summarization_tokenizer(bad_word, add_special_tokens=False).input_ids
345
+ for bad_word in params["bad_words"]
346
+ ]
347
+ summary_ids = summarization_transformer.generate(
348
+ inputs["input_ids"],
349
+ num_beams=2,
350
+ max_new_tokens=max(token_count, int(params["max_length"])),
351
+ min_new_tokens=min(token_count, int(params["min_length"])),
352
+ repetition_penalty=float(params["repetition_penalty"]),
353
+ temperature=float(params["temperature"]),
354
+ length_penalty=float(params["length_penalty"]),
355
+ bad_words_ids=bad_words_ids,
356
+ )
357
+ summary = summarization_tokenizer.batch_decode(
358
+ summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
359
+ )[0]
360
+ summary = normalize_string(summary)
361
+ return summary
362
+
363
+
364
+ def normalize_string(input: str) -> str:
365
+ output = " ".join(unicodedata.normalize("NFKC", input).strip().split())
366
+ return output
367
+
368
+
369
+ def generate_image(data: dict) -> Image:
370
+ prompt = normalize_string(f'{data["prompt_prefix"]} {data["prompt"]}')
371
+
372
+ if sd_use_remote:
373
+ image = sd_remote.txt2img(
374
+ prompt=prompt,
375
+ negative_prompt=data["negative_prompt"],
376
+ sampler_name=data["sampler"],
377
+ steps=data["steps"],
378
+ cfg_scale=data["scale"],
379
+ width=data["width"],
380
+ height=data["height"],
381
+ restore_faces=data["restore_faces"],
382
+ enable_hr=data["enable_hr"],
383
+ save_images=True,
384
+ send_images=True,
385
+ do_not_save_grid=False,
386
+ do_not_save_samples=False,
387
+ ).image
388
+ else:
389
+ image = sd_pipe(
390
+ prompt=prompt,
391
+ negative_prompt=data["negative_prompt"],
392
+ num_inference_steps=data["steps"],
393
+ guidance_scale=data["scale"],
394
+ width=data["width"],
395
+ height=data["height"],
396
+ ).images[0]
397
+
398
+ image.save("./debug.png")
399
+ return image
400
+
401
+
402
+ def image_to_base64(image: Image, quality: int = 75) -> str:
403
+ buffer = BytesIO()
404
+ image.convert("RGB")
405
+ image.save(buffer, format="JPEG", quality=quality)
406
+ img_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
407
+ return img_str
408
+
409
+ ignore_auth = []
410
+
411
+ api_key = os.environ.get("password")
412
+
413
+ def is_authorize_ignored(request):
414
+ view_func = app.view_functions.get(request.endpoint)
415
+
416
+ if view_func is not None:
417
+ if view_func in ignore_auth:
418
+ return True
419
+ return False
420
+
421
+ @app.before_request
422
+ def before_request():
423
+ # Request time measuring
424
+ request.start_time = time.time()
425
+
426
+ # Checks if an API key is present and valid, otherwise return unauthorized
427
+ # The options check is required so CORS doesn't get angry
428
+ try:
429
+ if request.method != 'OPTIONS' and is_authorize_ignored(request) == False and getattr(request.authorization, 'token', '') != api_key:
430
+ print(f"WARNING: Unauthorized API key access from {request.remote_addr}")
431
+ if request.method == 'POST':
432
+ print(f"Incoming POST request with {request.headers.get('Authorization')}")
433
+ response = jsonify({ 'error': '401: Invalid API key' })
434
+ response.status_code = 401
435
+ return "https://(hf_name)-(space_name).hf.space/"
436
+ except Exception as e:
437
+ print(f"API key check error: {e}")
438
+ return "https://(hf_name)-(space_name).hf.space/"
439
+
440
+
441
+ @app.after_request
442
+ def after_request(response):
443
+ duration = time.time() - request.start_time
444
+ response.headers["X-Request-Duration"] = str(duration)
445
+ return response
446
+
447
+
448
+ @app.route("/", methods=["GET"])
449
+ def index():
450
+ with open("./README.md", "r", encoding="utf8") as f:
451
+ content = f.read()
452
+ return render_template_string(markdown.markdown(content, extensions=["tables"]))
453
+
454
+
455
+ @app.route("/api/extensions", methods=["GET"])
456
+ def get_extensions():
457
+ extensions = dict(
458
+ {
459
+ "extensions": [
460
+ {
461
+ "name": "not-supported",
462
+ "metadata": {
463
+ "display_name": """<span style="white-space:break-spaces;">Extensions serving using Extensions API is no longer supported. Please update the mod from: <a href="https://github.com/Cohee1207/SillyTavern">https://github.com/Cohee1207/SillyTavern</a></span>""",
464
+ "requires": [],
465
+ "assets": [],
466
+ },
467
+ }
468
+ ]
469
+ }
470
+ )
471
+ return jsonify(extensions)
472
+
473
+
474
+ @app.route("/api/caption", methods=["POST"])
475
+ @require_module("caption")
476
+ def api_caption():
477
+ data = request.get_json()
478
+
479
+ if "image" not in data or not isinstance(data["image"], str):
480
+ abort(400, '"image" is required')
481
+
482
+ image = Image.open(BytesIO(base64.b64decode(data["image"])))
483
+ image = image.convert("RGB")
484
+ image.thumbnail((512, 512))
485
+ caption = caption_image(image)
486
+ thumbnail = image_to_base64(image)
487
+ print("Caption:", caption, sep="\n")
488
+ gc.collect()
489
+ return jsonify({"caption": caption, "thumbnail": thumbnail})
490
+
491
+
492
+ @app.route("/api/summarize", methods=["POST"])
493
+ @require_module("summarize")
494
+ def api_summarize():
495
+ data = request.get_json()
496
+
497
+ if "text" not in data or not isinstance(data["text"], str):
498
+ abort(400, '"text" is required')
499
+
500
+ params = DEFAULT_SUMMARIZE_PARAMS.copy()
501
+
502
+ if "params" in data and isinstance(data["params"], dict):
503
+ params.update(data["params"])
504
+
505
+ print("Summary input:", data["text"], sep="\n")
506
+ summary = summarize_chunks(data["text"], params)
507
+ print("Summary output:", summary, sep="\n")
508
+ gc.collect()
509
+ return jsonify({"summary": summary})
510
+
511
+
512
+ @app.route("/api/classify", methods=["POST"])
513
+ @require_module("classify")
514
+ def api_classify():
515
+ data = request.get_json()
516
+
517
+ if "text" not in data or not isinstance(data["text"], str):
518
+ abort(400, '"text" is required')
519
+
520
+ print("Classification input:", data["text"], sep="\n")
521
+ classification = classify_text(data["text"])
522
+ print("Classification output:", classification, sep="\n")
523
+ gc.collect()
524
+ return jsonify({"classification": classification})
525
+
526
+
527
+ @app.route("/api/classify/labels", methods=["GET"])
528
+ @require_module("classify")
529
+ def api_classify_labels():
530
+ classification = classify_text("")
531
+ labels = [x["label"] for x in classification]
532
+ return jsonify({"labels": labels})
533
+
534
+
535
+ @app.route("/api/image", methods=["POST"])
536
+ @require_module("sd")
537
+ def api_image():
538
+ required_fields = {
539
+ "prompt": str,
540
+ }
541
+
542
+ optional_fields = {
543
+ "steps": 30,
544
+ "scale": 6,
545
+ "sampler": "DDIM",
546
+ "width": 512,
547
+ "height": 512,
548
+ "restore_faces": False,
549
+ "enable_hr": False,
550
+ "prompt_prefix": PROMPT_PREFIX,
551
+ "negative_prompt": NEGATIVE_PROMPT,
552
+ }
553
+
554
+ data = request.get_json()
555
+
556
+ # Check required fields
557
+ for field, field_type in required_fields.items():
558
+ if field not in data or not isinstance(data[field], field_type):
559
+ abort(400, f'"{field}" is required')
560
+
561
+ # Set optional fields to default values if not provided
562
+ for field, default_value in optional_fields.items():
563
+ type_match = (
564
+ (int, float)
565
+ if isinstance(default_value, (int, float))
566
+ else type(default_value)
567
+ )
568
+ if field not in data or not isinstance(data[field], type_match):
569
+ data[field] = default_value
570
+
571
+ try:
572
+ print("SD inputs:", data, sep="\n")
573
+ image = generate_image(data)
574
+ base64image = image_to_base64(image, quality=90)
575
+ return jsonify({"image": base64image})
576
+ except RuntimeError as e:
577
+ abort(400, str(e))
578
+
579
+
580
+ @app.route("/api/image/model", methods=["POST"])
581
+ @require_module("sd")
582
+ def api_image_model_set():
583
+ data = request.get_json()
584
+
585
+ if not sd_use_remote:
586
+ abort(400, "Changing model for local sd is not supported.")
587
+ if "model" not in data or not isinstance(data["model"], str):
588
+ abort(400, '"model" is required')
589
+
590
+ old_model = sd_remote.util_get_current_model()
591
+ sd_remote.util_set_model(data["model"], find_closest=False)
592
+ # sd_remote.util_set_model(data['model'])
593
+ sd_remote.util_wait_for_ready()
594
+ new_model = sd_remote.util_get_current_model()
595
+
596
+ return jsonify({"previous_model": old_model, "current_model": new_model})
597
+
598
+
599
+ @app.route("/api/image/model", methods=["GET"])
600
+ @require_module("sd")
601
+ def api_image_model_get():
602
+ model = sd_model
603
+
604
+ if sd_use_remote:
605
+ model = sd_remote.util_get_current_model()
606
+
607
+ return jsonify({"model": model})
608
+
609
+
610
+ @app.route("/api/image/models", methods=["GET"])
611
+ @require_module("sd")
612
+ def api_image_models():
613
+ models = [sd_model]
614
+
615
+ if sd_use_remote:
616
+ models = sd_remote.util_get_model_names()
617
+
618
+ return jsonify({"models": models})
619
+
620
+
621
+ @app.route("/api/image/samplers", methods=["GET"])
622
+ @require_module("sd")
623
+ def api_image_samplers():
624
+ samplers = ["Euler a"]
625
+
626
+ if sd_use_remote:
627
+ samplers = [sampler["name"] for sampler in sd_remote.get_samplers()]
628
+
629
+ return jsonify({"samplers": samplers})
630
+
631
+
632
+ @app.route("/api/modules", methods=["GET"])
633
+ def get_modules():
634
+ return jsonify({"modules": modules})
635
+
636
+
637
+ @app.route("/api/tts/speakers", methods=["GET"])
638
+ @require_module("silero-tts")
639
+ def tts_speakers():
640
+ voices = [
641
+ {
642
+ "name": speaker,
643
+ "voice_id": speaker,
644
+ "preview_url": f"{str(request.url_root)}api/tts/sample/{speaker}",
645
+ }
646
+ for speaker in tts_service.get_speakers()
647
+ ]
648
+ return jsonify(voices)
649
+
650
+
651
+ @app.route("/api/tts/generate", methods=["POST"])
652
+ @require_module("silero-tts")
653
+ def tts_generate():
654
+ voice = request.get_json()
655
+ if "text" not in voice or not isinstance(voice["text"], str):
656
+ abort(400, '"text" is required')
657
+ if "speaker" not in voice or not isinstance(voice["speaker"], str):
658
+ abort(400, '"speaker" is required')
659
+ # Remove asterisks
660
+ voice["text"] = voice["text"].replace("*", "")
661
+ try:
662
+ audio = tts_service.generate(voice["speaker"], voice["text"])
663
+ return send_file(audio, mimetype="audio/x-wav")
664
+ except Exception as e:
665
+ print(e)
666
+ abort(500, voice["speaker"])
667
+
668
+
669
+ @app.route("/api/tts/sample/<speaker>", methods=["GET"])
670
+ @require_module("silero-tts")
671
+ def tts_play_sample(speaker: str):
672
+ return send_from_directory(SILERO_SAMPLES_PATH, f"{speaker}.wav")
673
+
674
+
675
+ @app.route("/api/edge-tts/list", methods=["GET"])
676
+ @require_module("edge-tts")
677
+ def edge_tts_list():
678
+ voices = edge.get_voices()
679
+ return jsonify(voices)
680
+
681
+
682
+ @app.route("/api/edge-tts/generate", methods=["POST"])
683
+ @require_module("edge-tts")
684
+ def edge_tts_generate():
685
+ data = request.get_json()
686
+ if "text" not in data or not isinstance(data["text"], str):
687
+ abort(400, '"text" is required')
688
+ if "voice" not in data or not isinstance(data["voice"], str):
689
+ abort(400, '"voice" is required')
690
+ if "rate" in data and isinstance(data['rate'], int):
691
+ rate = data['rate']
692
+ else:
693
+ rate = 0
694
+ # Remove asterisks
695
+ data["text"] = data["text"].replace("*", "")
696
+ try:
697
+ audio = edge.generate_audio(text=data["text"], voice=data["voice"], rate=rate)
698
+ return Response(audio, mimetype="audio/mpeg")
699
+ except Exception as e:
700
+ print(e)
701
+ abort(500, data["voice"])
702
+
703
+
704
+ @app.route("/api/chromadb", methods=["POST"])
705
+ @require_module("chromadb")
706
+ def chromadb_add_messages():
707
+ data = request.get_json()
708
+ if "chat_id" not in data or not isinstance(data["chat_id"], str):
709
+ abort(400, '"chat_id" is required')
710
+ if "messages" not in data or not isinstance(data["messages"], list):
711
+ abort(400, '"messages" is required')
712
+
713
+ chat_id_md5 = hashlib.md5(data["chat_id"].encode()).hexdigest()
714
+ collection = chromadb_client.get_or_create_collection(
715
+ name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
716
+ )
717
+
718
+ documents = [m["content"] for m in data["messages"]]
719
+ ids = [m["id"] for m in data["messages"]]
720
+ metadatas = [
721
+ {"role": m["role"], "date": m["date"], "meta": m.get("meta", "")}
722
+ for m in data["messages"]
723
+ ]
724
+
725
+ collection.upsert(
726
+ ids=ids,
727
+ documents=documents,
728
+ metadatas=metadatas,
729
+ )
730
+
731
+ return jsonify({"count": len(ids)})
732
+
733
+
734
+ @app.route("/api/chromadb/purge", methods=["POST"])
735
+ @require_module("chromadb")
736
+ def chromadb_purge():
737
+ data = request.get_json()
738
+ if "chat_id" not in data or not isinstance(data["chat_id"], str):
739
+ abort(400, '"chat_id" is required')
740
+
741
+ chat_id_md5 = hashlib.md5(data["chat_id"].encode()).hexdigest()
742
+ collection = chromadb_client.get_or_create_collection(
743
+ name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
744
+ )
745
+
746
+ count = collection.count()
747
+ collection.delete()
748
+ #Write deletion to persistent folder
749
+ chromadb_client.persist()
750
+ print("ChromaDB embeddings deleted", count)
751
+ return 'Ok', 200
752
+
753
+
754
+ @app.route("/api/chromadb/query", methods=["POST"])
755
+ @require_module("chromadb")
756
+ def chromadb_query():
757
+ data = request.get_json()
758
+ if "chat_id" not in data or not isinstance(data["chat_id"], str):
759
+ abort(400, '"chat_id" is required')
760
+ if "query" not in data or not isinstance(data["query"], str):
761
+ abort(400, '"query" is required')
762
+
763
+ if "n_results" not in data or not isinstance(data["n_results"], int):
764
+ n_results = 1
765
+ else:
766
+ n_results = data["n_results"]
767
+
768
+ chat_id_md5 = hashlib.md5(data["chat_id"].encode()).hexdigest()
769
+ collection = chromadb_client.get_or_create_collection(
770
+ name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
771
+ )
772
+
773
+ n_results = min(collection.count(), n_results)
774
+ query_result = collection.query(
775
+ query_texts=[data["query"]],
776
+ n_results=n_results,
777
+ )
778
+
779
+ documents = query_result["documents"][0]
780
+ ids = query_result["ids"][0]
781
+ metadatas = query_result["metadatas"][0]
782
+ distances = query_result["distances"][0]
783
+
784
+ messages = [
785
+ {
786
+ "id": ids[i],
787
+ "date": metadatas[i]["date"],
788
+ "role": metadatas[i]["role"],
789
+ "meta": metadatas[i]["meta"],
790
+ "content": documents[i],
791
+ "distance": distances[i],
792
+ }
793
+ for i in range(len(ids))
794
+ ]
795
+
796
+ return jsonify(messages)
797
+
798
+
799
+ @app.route("/api/chromadb/export", methods=["POST"])
800
+ @require_module("chromadb")
801
+ def chromadb_export():
802
+ data = request.get_json()
803
+ if "chat_id" not in data or not isinstance(data["chat_id"], str):
804
+ abort(400, '"chat_id" is required')
805
+
806
+ chat_id_md5 = hashlib.md5(data["chat_id"].encode()).hexdigest()
807
+ collection = chromadb_client.get_or_create_collection(
808
+ name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
809
+ )
810
+ collection_content = collection.get()
811
+ documents = collection_content.get('documents', [])
812
+ ids = collection_content.get('ids', [])
813
+ metadatas = collection_content.get('metadatas', [])
814
+
815
+ unsorted_content = [
816
+ {
817
+ "id": ids[i],
818
+ "metadata": metadatas[i],
819
+ "document": documents[i],
820
+ }
821
+ for i in range(len(ids))
822
+ ]
823
+
824
+ sorted_content = sorted(unsorted_content, key=lambda x: x['metadata']['date'])
825
+
826
+ export = {
827
+ "chat_id": data["chat_id"],
828
+ "content": sorted_content
829
+ }
830
+
831
+ return jsonify(export)
832
+
833
+ @app.route("/api/chromadb/import", methods=["POST"])
834
+ @require_module("chromadb")
835
+ def chromadb_import():
836
+ data = request.get_json()
837
+ content = data['content']
838
+ if "chat_id" not in data or not isinstance(data["chat_id"], str):
839
+ abort(400, '"chat_id" is required')
840
+
841
+ chat_id_md5 = hashlib.md5(data["chat_id"].encode()).hexdigest()
842
+ collection = chromadb_client.get_or_create_collection(
843
+ name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
844
+ )
845
+
846
+ documents = [item['document'] for item in content]
847
+ metadatas = [item['metadata'] for item in content]
848
+ ids = [item['id'] for item in content]
849
+
850
+
851
+ collection.upsert(documents=documents, metadatas=metadatas, ids=ids)
852
+
853
+ return jsonify({"count": len(ids)})
854
+
855
+ ignore_auth.append(tts_play_sample)
856
+ app.run(host=host, port=port)
tts_edge.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import edge_tts
3
+ import asyncio
4
+
5
+
6
+ def get_voices():
7
+ voices = asyncio.run(edge_tts.list_voices())
8
+ return voices
9
+
10
+
11
+ async def _iterate_chunks(audio):
12
+ async for chunk in audio.stream():
13
+ if chunk["type"] == "audio":
14
+ yield chunk["data"]
15
+
16
+
17
+ async def _async_generator_to_list(async_gen):
18
+ result = []
19
+ async for item in async_gen:
20
+ result.append(item)
21
+ return result
22
+
23
+
24
+ def generate_audio(text: str, voice: str, rate: int) -> bytes:
25
+ sign = '+' if rate > 0 else '-'
26
+ rate = f'{sign}{abs(rate)}%'
27
+ audio = edge_tts.Communicate(text=text, voice=voice, rate=rate)
28
+ chunks = asyncio.run(_async_generator_to_list(_iterate_chunks(audio)))
29
+ buffer = io.BytesIO()
30
+
31
+ for chunk in chunks:
32
+ buffer.write(chunk)
33
+
34
+ return buffer.getvalue()