cahya committed on
Commit
832a8c0
·
1 Parent(s): 9aae25d

use external inference

Browse files
Files changed (2) hide show
  1. README.md +1 -1
  2. app.py +30 -30
README.md CHANGED
@@ -7,7 +7,7 @@ sdk: gradio
7
  sdk_version: 3.18.0
8
  app_file: app.py
9
  pinned: false
10
- license: cc
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
7
  sdk_version: 3.18.0
8
  app_file: app.py
9
  pinned: false
10
+ license: creativeml-openrail-m
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,38 +1,38 @@
1
- import torch
2
  import gradio as gr
3
- from transformers import pipeline
4
  import os
5
  from mtranslate import translate
 
6
 
7
- device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
8
  HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN")
9
- text_generation_model = "cahya/indochat-tiny"
10
- text_generation = pipeline("text-generation", text_generation_model, use_auth_token=HF_AUTH_TOKEN, device=device)
11
 
12
-
13
- def get_answer(user_input, decoding_methods, num_beams, top_k, top_p, temperature, repetition_penalty, penalty_alpha):
14
- if decoding_methods == "Beam Search":
15
- do_sample = False
16
- penalty_alpha = 0
17
- elif decoding_methods == "Sampling":
18
- do_sample = True
19
- penalty_alpha = 0
20
- num_beams = 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  else:
22
- do_sample = False
23
- num_beams = 1
24
- print(user_input, decoding_methods, do_sample, top_k, top_p, temperature, repetition_penalty, penalty_alpha)
25
- prompt = f"User: {user_input}\nAssistant: "
26
- generated_text = text_generation(f"{prompt}", min_length=50, max_length=200, num_return_sequences=1,
27
- num_beams=num_beams, do_sample=do_sample, top_k=top_k, top_p=top_p,
28
- temperature=temperature, repetition_penalty=repetition_penalty,
29
- penalty_alpha=penalty_alpha)
30
- answer = generated_text[0]["generated_text"]
31
- answer_without_prompt = answer[len(prompt)+1:]
32
- user_input_en = translate(user_input, "en", "id")
33
- answer_without_prompt_en = translate(answer_without_prompt, "en", "id")
34
- return [(f"{user_input}\n", None), (answer_without_prompt, "")], \
35
- [(f"{user_input_en}\n", None), (answer_without_prompt_en, "")]
36
 
37
 
38
  css = """
@@ -55,7 +55,7 @@ with gr.Blocks(css=css) as demo:
55
  user_input = gr.inputs.Textbox(placeholder="",
56
  label="Ask me something in Indonesian or English",
57
  default="Bagaimana cara mendidik anak supaya tidak berbohong?")
58
- decoding_methods = gr.inputs.Dropdown(["Beam Search", "Sampling", "Contrastive Search"],
59
  default="Sampling", label="Decoding Method")
60
  num_beams = gr.inputs.Slider(label="Number of beams for beam search",
61
  default=1, minimum=1, maximum=10, step=1)
@@ -85,7 +85,7 @@ with gr.Blocks(css=css) as demo:
85
  gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=cahya_indochat)")
86
 
87
  button_generate_story.click(get_answer,
88
- inputs=[user_input, decoding_methods, num_beams, top_k, top_p, temperature,
89
  repetition_penalty, penalty_alpha],
90
  outputs=[generated_answer, generated_answer_en])
91
 
 
 
1
  import gradio as gr
 
2
  import os
3
  from mtranslate import translate
4
+ import requests
5
 
 
6
  HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN")
7
+ indochat_api = 'https://cahya-indonesian-whisperer.hf.space/api/indochat/v1'
8
+ indochat_api_auth_token = os.getenv("INDOCHAT_API_AUTH_TOKEN", "")
9
 
10
def get_answer(user_input, decoding_method, num_beams, top_k, top_p, temperature, repetition_penalty, penalty_alpha):
    """Request an answer from the remote IndoChat API and translate the exchange.

    The generation settings are forwarded unchanged as form fields to the
    endpoint in ``indochat_api``; the request is authorised with
    ``indochat_api_auth_token``.

    Args:
        user_input: The question typed by the user.
        decoding_method: Name of the decoding strategy selected in the UI.
        num_beams: Beam count forwarded to the API.
        top_k: Top-k value forwarded to the API.
        top_p: Top-p value forwarded to the API.
        temperature: Sampling temperature forwarded to the API.
        repetition_penalty: Repetition penalty forwarded to the API.
        penalty_alpha: Penalty alpha forwarded to the API.

    Returns:
        A pair ``(exchange, exchange_en)``. Each element is a list of
        ``(text, label)`` tuples: the user input followed by the answer.
        ``exchange_en`` holds the same texts passed through
        ``translate(..., "en", "id")`` (presumably Indonesian -> English;
        confirm against mtranslate). On any failure both elements carry the
        user input followed by an ``"Error: ..."`` message instead.
    """

    def _error_outputs(message):
        # The click handler is wired to two output components, so the error
        # path must return two values shaped like the success path rather
        # than a single bare string.
        exchange = [(f"{user_input}\n", None), (message, "")]
        return exchange, list(exchange)

    print(user_input, decoding_method, top_k, top_p, temperature, repetition_penalty, penalty_alpha)
    headers = {'Authorization': 'Bearer ' + indochat_api_auth_token}
    data = {
        "text": user_input,
        "min_length": len(user_input) + 50,
        "max_length": 300,
        "decoding_method": decoding_method,
        "num_beams": num_beams,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "seed": -1,
        "repetition_penalty": repetition_penalty,
        "penalty_alpha": penalty_alpha,
    }

    try:
        # (connect, read) timeouts: without them a stalled remote service
        # would block this handler forever.
        r = requests.post(indochat_api, headers=headers, data=data, timeout=(10, 120))
    except requests.RequestException as err:
        return _error_outputs(f"Error: {err}")

    if r.status_code != 200:
        return _error_outputs("Error: " + r.text)

    try:
        answer = r.json()["generated_text"]
    except (ValueError, KeyError, TypeError) as err:
        # ValueError: body is not JSON; KeyError/TypeError: JSON does not
        # have the expected {"generated_text": ...} shape.
        return _error_outputs(f"Error: unexpected API response ({err})")

    user_input_en = translate(user_input, "en", "id")
    answer_en = translate(answer, "en", "id")
    return [(f"{user_input}\n", None), (answer, "")], \
           [(f"{user_input_en}\n", None), (answer_en, "")]
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
 
38
  css = """
 
55
  user_input = gr.inputs.Textbox(placeholder="",
56
  label="Ask me something in Indonesian or English",
57
  default="Bagaimana cara mendidik anak supaya tidak berbohong?")
58
+ decoding_method = gr.inputs.Dropdown(["Beam Search", "Sampling", "Contrastive Search"],
59
  default="Sampling", label="Decoding Method")
60
  num_beams = gr.inputs.Slider(label="Number of beams for beam search",
61
  default=1, minimum=1, maximum=10, step=1)
 
85
  gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=cahya_indochat)")
86
 
87
  button_generate_story.click(get_answer,
88
+ inputs=[user_input, decoding_method, num_beams, top_k, top_p, temperature,
89
  repetition_penalty, penalty_alpha],
90
  outputs=[generated_answer, generated_answer_en])
91