thaboe01 committed on
Commit
a61e741
1 Parent(s): b9544c7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -0
app.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
3
+
4
# Build the tokenizer / model pair that performs the spelling correction.
@st.cache_resource
def load_model():
    """Load the tokenizer and the fine-tuned spelling-correction model.

    The tokenizer comes from the ``google/flan-t5-small`` checkpoint and the
    weights from ``thaboe01/t5-spelling-correctorv2``.  The function is
    decorated with ``st.cache_resource`` so Streamlit can reuse the loaded
    objects instead of loading them again on every script rerun.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    base_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
    spelling_model = T5ForConditionalGeneration.from_pretrained(
        "thaboe01/t5-spelling-correctorv2"
    )
    return base_tokenizer, spelling_model
10
+
11
# Number of whitespace-separated words sent to the model per request.
MAX_PHRASE_LENGTH = 5
# Instruction text prepended to every chunk before it is tokenized.
PREFIX = "Please correct the following sentence: "

# Fetch the (cached) tokenizer and model used by the rest of the script.
tokenizer, model = load_model()
16
+
17
def _correct_chunk(chunk_words):
    """Run the model on one chunk of words and return the corrected phrase."""
    input_text = PREFIX + " ".join(chunk_words)
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    outputs = model.generate(input_ids)
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Only drop the prompt when the model actually echoed it.  Slicing off
    # len(PREFIX) characters unconditionally would delete the beginning of
    # (or all of) an answer that does not start with the prompt.
    if decoded.startswith(PREFIX):
        decoded = decoded[len(PREFIX):]
    return decoded


def correct_text(text):
    """Spell-correct ``text`` by sending it to the model in short word chunks.

    The input is split on whitespace, grouped into consecutive chunks of at
    most ``MAX_PHRASE_LENGTH`` words, and each chunk is corrected
    independently with ``PREFIX`` prepended as the instruction.

    Args:
        text: The raw text to correct.

    Returns:
        The corrected chunks joined with single spaces.  Because the input is
        split on whitespace, the original line breaks and spacing are not
        preserved.  Empty or whitespace-only input yields ``""``.
    """
    words = text.split()
    # Guard against a non-positive chunk size so range() always gets a valid
    # step; with such a size every word is simply its own chunk.
    chunk_size = max(1, MAX_PHRASE_LENGTH)
    chunks = (
        words[start:start + chunk_size]
        for start in range(0, len(words), chunk_size)
    )
    return " ".join(_correct_chunk(chunk) for chunk in chunks)
43
+
44
+
45
# Streamlit UI: an editable input box, plus a read-only box with the model's
# corrected version that is rendered only once some text has been entered.
st.title("Shona Text Editor with Real-Time Spelling Correction")
text_input = st.text_area("Start typing here...", height=250)

if text_input:
    corrected_text = correct_text(text_input)
    st.text_area(
        "Corrected Text",
        value=corrected_text,
        height=250,
        disabled=True,
    )