ccm commited on
Commit
6f386ce
·
1 Parent(s): d29161a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -9
app.py CHANGED
@@ -1,24 +1,26 @@
1
  import numpy
2
  import gradio
3
  from huggingface_hub import from_pretrained_keras
 
 
4
 
5
- S = 5
6
- N = 1000
7
- D = 3
8
- F = 64
9
- G = 32
10
 
11
  analysis_network = from_pretrained_keras("cmudrc/wave-energy-analysis")
12
  synthesis_network = from_pretrained_keras("cmudrc/wave-energy-synthesis")
13
 
14
  with gradio.Blocks() as demo:
15
- geometry = gradio.Textbox(label="geometry")
16
- spectrum = gradio.Textbox(label="spectrum")
17
 
18
  analyze_it = gradio.Button("Analyze")
19
  synthesize_it = gradio.Button("Synthesize")
20
 
21
- analyze_it.click(fn=lambda x: analysis_network.predict(numpy.fromstring(x, dtype=float)).tostring(), inputs=[geometry], outputs=[spectrum], api_name="analyze")
22
- synthesize_it.click(fn=lambda x: synthesis_network.predict(numpy.fromstring(x, dtype=int)).tostring(), inputs=[spectrum], outputs=[geometry], api_name="synthesize")
23
 
24
  demo.launch()
 
1
  import numpy
2
  import gradio
3
  from huggingface_hub import from_pretrained_keras
4
+ import json
5
+ from json import JSONEncoder
6
 
7
+ class NumpyArrayEncoder(JSONEncoder):
8
+ def default(self, obj):
9
+ if isinstance(obj, numpy.ndarray):
10
+ return obj.tolist()
11
+ return JSONEncoder.default(self, obj)
12
 
13
  analysis_network = from_pretrained_keras("cmudrc/wave-energy-analysis")
14
  synthesis_network = from_pretrained_keras("cmudrc/wave-energy-synthesis")
15
 
16
  with gradio.Blocks() as demo:
17
+ geometry = gradio.JSON(label="geometry")
18
+ spectrum = gradio.JSON(label="spectrum")
19
 
20
  analyze_it = gradio.Button("Analyze")
21
  synthesize_it = gradio.Button("Synthesize")
22
 
23
+ analyze_it.click(fn=lambda x: json.dumps(analysis_network.predict(numpy.asarray(json.loads(x))), cls=NumpyArrayEncoder), inputs=[geometry], outputs=[spectrum], api_name="analyze")
24
+ synthesize_it.click(fn=lambda x: json.dumps(synthesis_network.predict(numpy.asarray(json.loads(x))), inputs=[spectrum], outputs=[geometry], api_name="synthesize")
25
 
26
  demo.launch()