rajistics commited on
Commit
87050c9
1 Parent(s): 77427ef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -205
app.py CHANGED
@@ -1,207 +1,65 @@
1
- import logging
2
-
3
- from h2o_wave import Q, main, app, copy_expando, handle_on, on
4
- import whisper
5
-
6
- import cards
7
- from utils import get_inline_script
8
-
9
- # Set up logging
10
- logging.basicConfig(format='%(levelname)s:\t[%(asctime)s]\t%(message)s', level=logging.INFO)
11
-
12
-
13
- @app('/')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  async def serve(q: Q):
15
- """
16
- Main entry point. All queries pass through this function.
17
- """
18
-
19
- try:
20
- # Initialize the app if not already
21
- if not q.app.initialized:
22
- await initialize_app(q)
23
-
24
- # Initialize the client if not already
25
- if not q.client.initialized:
26
- await initialize_client(q)
27
-
28
- # Update theme if toggled
29
- elif q.args.theme_dark is not None and q.args.theme_dark != q.client.theme_dark:
30
- await update_theme(q)
31
-
32
- # Run inference if audio is recorded
33
- elif q.events.audio:
34
- await audio_inference(q)
35
-
36
- # Delegate query to query handlers
37
- elif await handle_on(q):
38
- pass
39
-
40
- # Adding this condition to help in identifying bugs
41
- else:
42
- await handle_fallback(q)
43
-
44
- except Exception as error:
45
- await show_error(q, error=str(error))
46
-
47
-
48
- async def initialize_app(q: Q):
49
- """
50
- Initialize the app.
51
- """
52
-
53
- logging.info('Initializing app')
54
-
55
- # Set initial argument values
56
- q.app.cards = ['main', 'error']
57
-
58
- q.app.model = whisper.load_model('base')
59
-
60
- q.app.initialized = True
61
-
62
-
63
- async def initialize_client(q: Q):
64
- """
65
- Initialize the client (browser tab).
66
- """
67
-
68
- logging.info('Initializing client')
69
-
70
- # Set initial argument values
71
- q.client.theme_dark = True
72
-
73
- # Add layouts, scripts, header and footer
74
- q.page['meta'] = cards.meta
75
- q.page['header'] = cards.header
76
- q.page['footer'] = cards.footer
77
-
78
- # Add cards for the main page
79
- q.page['asr'] = cards.asr()
80
-
81
- q.client.initialized = True
82
-
83
- await q.page.save()
84
-
85
-
86
- async def update_theme(q: Q):
87
- """
88
- Update theme of app.
89
- """
90
-
91
- # Copying argument values to client
92
- copy_expando(q.args, q.client)
93
-
94
- if q.client.theme_dark:
95
- logging.info('Updating theme to dark mode')
96
-
97
- # Update theme from light to dark mode
98
- q.page['meta'].theme = 'h2o-dark'
99
- q.page['header'].icon_color = 'black'
100
  else:
101
- logging.info('Updating theme to light mode')
102
-
103
- # Update theme from dark to light mode
104
- q.page['meta'].theme = 'light'
105
- q.page['header'].icon_color = '#FEC924'
106
-
107
- await q.page.save()
108
-
109
-
110
- @on('start')
111
- async def start_recording(q: Q):
112
- """
113
- Start recording audio.
114
- """
115
-
116
- logging.info('Starting recording')
117
-
118
- q.page['meta'].script = get_inline_script('startRecording()')
119
- q.page['asr'] = cards.asr(recording=True)
120
-
121
- await q.page.save()
122
-
123
-
124
- @on('stop')
125
- async def stop_recording(q: Q):
126
- """
127
- Stop recording audio.
128
- """
129
-
130
- logging.info('Stopping recording')
131
-
132
- q.page['meta'].script = get_inline_script('stopRecording()')
133
- q.page['asr'] = cards.asr()
134
-
135
- await q.page.save()
136
-
137
-
138
- @on('audio')
139
- async def audio_inference(q: Q):
140
- """
141
- Running ASR inference on audio.
142
- """
143
-
144
- logging.info('Inferencing recorded audio')
145
-
146
- audio_path = await q.site.download(q.events.audio.captured, '.')
147
-
148
- q.client.transcription = q.app.model.transcribe(audio_path)['text']
149
-
150
- q.page['asr'] = cards.asr(audio_path=q.events.audio.captured, transcription=q.client.transcription)
151
-
152
- await q.page.save()
153
-
154
-
155
- def clear_cards(q: Q, card_names: list):
156
- """
157
- Clear cards from the page.
158
- """
159
-
160
- logging.info('Clearing cards')
161
-
162
- # Delete cards from the page
163
- for card_name in card_names:
164
- del q.page[card_name]
165
-
166
-
167
- async def show_error(q: Q, error: str):
168
- """
169
- Displays errors.
170
- """
171
-
172
- logging.error(error)
173
-
174
- # Clear all cards
175
- clear_cards(q, q.app.cards)
176
-
177
- # Format and display the error
178
- q.page['error'] = cards.crash_report(q)
179
-
180
- await q.page.save()
181
-
182
-
183
- @on('reload')
184
- async def reload_client(q: Q):
185
- """
186
- Reset the client.
187
- """
188
-
189
- logging.info('Reloading client')
190
-
191
- # Clear all cards
192
- clear_cards(q, q.app.cards)
193
-
194
- # Reload the client
195
- await initialize_client(q)
196
-
197
-
198
- async def handle_fallback(q: Q):
199
- """
200
- Handle fallback cases.
201
- """
202
-
203
- logging.info('Adding fallback page')
204
-
205
- q.page['fallback'] = cards.fallback
206
-
207
- await q.page.save()
 
1
+ from h2o_wave import main, app, Q, ui, copy_expando
2
+ from transformers import pipeline
3
+
4
+ async def init(q: Q):
5
+ if not q.client.app_initialized:
6
+ q.app.model = pipeline("text-generation")
7
+ q.client.app_initialized = True
8
+
9
+ q.page.drop()
10
+
11
+ q.page["title"] = ui.header_card(
12
+ box="1 1 8 1",
13
+ title="Text Generation",
14
+ subtitle="Generate text using Huggingface pipelines",
15
+ icon="AddNotes",
16
+ icon_color="Blue",
17
+ )
18
+
19
+ async def get_inputs(q: Q):
20
+ q.page['main'] = ui.form_card(box="1 2 8 5", items=[
21
+ ui.text_xl('Enter your text input for generation:'),
22
+ ui.textbox(name="input_text",
23
+ label='',
24
+ value=q.app.input_text,
25
+ multiline=True),
26
+ ui.separator(),
27
+ ui.slider(name="num_words_to_generate",
28
+ label="Maximum number of words to generate (including input text)",
29
+ min=5,
30
+ max=50,
31
+ step=1,
32
+ value=q.app.num_words_to_generate if q.app.num_words_to_generate else 12,
33
+ ),
34
+ ui.separator(),
35
+ ui.buttons([ui.button(name="generate_text", label='Generate', primary=True),
36
+ ])
37
+ ])
38
+
39
+ async def show_results(q: Q):
40
+ q.page['main'] = ui.form_card(box="1 2 4 5", items=[
41
+ ui.text_xl("Input Text:"),
42
+ ui.separator(),
43
+ ui.text(q.app.input_text),
44
+ ui.separator(),
45
+ ui.buttons([ui.button(name="get_inputs", label='Try Again!', primary=True),
46
+ ])
47
+ ])
48
+
49
+ result = q.app.model(q.app.input_text, max_length=q.app.num_words_to_generate, do_sample=False)[0]
50
+ q.app.generated_text = result["generated_text"]
51
+ q.page['visualization'] = ui.form_card(box="5 2 4 5", items=[
52
+ ui.text_xl("Generated Text:"),
53
+ ui.separator(''),
54
+ ui.text(q.app.generated_text)
55
+ ])
56
+
57
+ @app("/")
58
  async def serve(q: Q):
59
+ await init(q)
60
+ if q.args.generate_text:
61
+ copy_expando(q.args, q.app)
62
+ await show_results(q)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  else:
64
+ await get_inputs(q)
65
+ await q.page.save()