multimodalart HF staff commited on
Commit
49240f9
1 Parent(s): 4830021

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -0
app.py CHANGED
@@ -2,6 +2,7 @@ import os
2
  from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
3
  from knowledge_storm.lm import OpenAIModel
4
  from knowledge_storm.rm import YouRM
 
5
 
6
  lm_configs = STORMWikiLMConfigs()
7
  openai_kwargs = {
@@ -25,6 +26,7 @@ engine_args = STORMWikiRunnerArguments("outputs")
25
  rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k)
26
  runner = STORMWikiRunner(engine_args, lm_configs, rm)
27
 
 
28
  def generate_article(prompt, progress=gr.Progress(track_tqdm=True)):
29
  runner.run(
30
  topic=prompt,
@@ -36,4 +38,11 @@ def generate_article(prompt, progress=gr.Progress(track_tqdm=True)):
36
  runner.post_run()
37
  runner.summary()
38
 
 
 
 
 
 
39
 
 
 
 
2
  from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
3
  from knowledge_storm.lm import OpenAIModel
4
  from knowledge_storm.rm import YouRM
5
+ import spaces
6
 
7
  lm_configs = STORMWikiLMConfigs()
8
  openai_kwargs = {
 
26
  rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k)
27
  runner = STORMWikiRunner(engine_args, lm_configs, rm)
28
 
29
+ @spaces.GPU
30
  def generate_article(prompt, progress=gr.Progress(track_tqdm=True)):
31
  runner.run(
32
  topic=prompt,
 
38
  runner.post_run()
39
  runner.summary()
40
 
41
+ with gr.Blocks() as demo:
42
+ prompt = gr.Textbox(label="Prompt")
43
+ output = gr.Markdown(label="Output")
44
+ btn = gr.Button("Generate")
45
+ btn.click(fn=generate_article, inputs=prompt, outputs=output)
46
 
47
+ if __name__ == "__main__":
48
+ demo.launch()