File size: 1,693 Bytes
caaa800
e228074
caaa800
 
4cef1a6
caaa800
 
9045a87
308dd86
e228074
308dd86
caaa800
e228074
 
 
48847f0
e228074
 
66476a4
caaa800
e228074
308dd86
e228074
 
 
 
 
 
 
 
 
 
 
 
 
 
66476a4
e228074
 
 
 
 
 
3857659
 
48af855
e228074
 
 
4cef1a6
e228074
caaa800
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
#############################################################################
# Title:  Gradio Interface to AI hosted by Huggingface
# Author: Andreas Fischer
# Date:   October 7th, 2023
# Last update: December 29th, 2023
#############################################################################

import gradio as gr
import requests
import time
import json

def response(message, history, model):
  if(model=="Default"): model = "mistralai/Mixtral-8x7B-Instruct-v0.1" 
  model_id = model
  params={"max_new_tokens":600, "return_full_text":False} #, "max_length":500, "stream":True
  url = f"https://api-inference.huggingface.co/models/{model_id}"
  correction=1
  prompt=f"[INST] {message} [/INST]" # skipped <s>
  print("URL: "+url)
  print(params)
  print("User: "+message+"\nAI: ")
  response=""
  for text in requests.post(url, json={"inputs":prompt, "parameters":params}, stream=True):
    text=text.decode('UTF-8')
    print(text)
    if(correction==3):
      text='"}]'+text
      correction=2
    if(correction==1):
      text=text.lstrip('[{"generated_text":"')
      correction=2
    if(text.endswith('"}]')):
      text=text.rstrip('"}]')
      correction=3 
    response=response+text
    print(text)
    time.sleep(0.2)
    yield response

# Ask the Inference API which models are served through
# text-generation-inference and offer the ones whose id starts with "mistral"
# in the model dropdown. The "Default" entry works without this list, so a
# failed lookup must not prevent the interface from starting.
try:
  listing = requests.get("https://api-inference.huggingface.co/framework/text-generation-inference", timeout=30)
  available_models = [entry["model_id"] for entry in listing.json()]
except (requests.RequestException, ValueError, KeyError, TypeError) as err:
  print(f"Could not fetch the model list, offering only the default model: {err}")
  available_models = []
print(available_models)
model_choices = [name for name in available_models if isinstance(name, str) and name.startswith("mistral")]
print(model_choices)
model_choices.insert(0, "Default")

# Build the chat UI; the dropdown value is passed to response() as `model`.
# For a local server instead of a public share link use:
#   launch(share=False, server_name="0.0.0.0", server_port=7864)
gr.ChatInterface(
  response,
  title="AI-Interface to HuggingFace-Models",
  additional_inputs=[gr.Dropdown(model_choices, value="Default", label="Model")],
).queue().launch(share=True)