Spaces:
Runtime error
Runtime error
add the max length of text
Browse files- app/app.py +11 -4
app/app.py
CHANGED
@@ -14,16 +14,16 @@ model_name = "cahya/gpt2-small-indonesian-story"
|
|
14 |
|
15 |
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
|
16 |
def get_generator():
|
17 |
-
st.write("Loading the GPT2 model...")
|
18 |
text_generator = pipeline('text-generation', model=model_name)
|
19 |
return text_generator
|
20 |
|
21 |
|
22 |
#@st.cache(suppress_st_warning=True)
|
23 |
def process(text: str, max_length: int = 100, do_sample: bool = True, top_k: int = 50, top_p: float = 0.95,
|
24 |
-
temperature: float = 1.0, max_time: float = None):
|
25 |
st.write("Cache miss: process")
|
26 |
-
set_seed(
|
27 |
result = text_generator(text, max_length=max_length, do_sample=do_sample,
|
28 |
top_k=top_k, top_p=top_p, temperature=temperature, max_time=max_time)
|
29 |
return result
|
@@ -58,6 +58,13 @@ else:
|
|
58 |
|
59 |
session_state.text = st.text_area("Enter text", session_state.prompt_box)
|
60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
temp = st.sidebar.slider(
|
62 |
"Temperature",
|
63 |
value=1.0,
|
@@ -80,7 +87,7 @@ if st.button("Run"):
|
|
80 |
with st.spinner(text="Getting results..."):
|
81 |
st.subheader("Result")
|
82 |
time_start = time.time()
|
83 |
-
result = process(text=session_state.text, top_k=int(top_k), top_p=float(top_p))
|
84 |
time_end = time.time()
|
85 |
time_diff = time_end-time_start
|
86 |
#print(f"Text generated in {time_diff} seconds")
|
|
|
14 |
|
15 |
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
|
16 |
def get_generator():
|
17 |
+
st.write(f"Loading the GPT2 model {model_name}, please wait...")
|
18 |
text_generator = pipeline('text-generation', model=model_name)
|
19 |
return text_generator
|
20 |
|
21 |
|
22 |
#@st.cache(suppress_st_warning=True)
|
23 |
def process(text: str, max_length: int = 100, do_sample: bool = True, top_k: int = 50, top_p: float = 0.95,
|
24 |
+
temperature: float = 1.0, max_time: float = None, seed=42):
|
25 |
st.write("Cache miss: process")
|
26 |
+
set_seed(seed)
|
27 |
result = text_generator(text, max_length=max_length, do_sample=do_sample,
|
28 |
top_k=top_k, top_p=top_p, temperature=temperature, max_time=max_time)
|
29 |
return result
|
|
|
58 |
|
59 |
session_state.text = st.text_area("Enter text", session_state.prompt_box)
|
60 |
|
61 |
+
max_length = st.sidebar.number_input(
|
62 |
+
"Maximum length",
|
63 |
+
value=100,
|
64 |
+
max_value=512,
|
65 |
+
help="The maximum length of the sequence to be generated."
|
66 |
+
)
|
67 |
+
|
68 |
temp = st.sidebar.slider(
|
69 |
"Temperature",
|
70 |
value=1.0,
|
|
|
87 |
with st.spinner(text="Getting results..."):
|
88 |
st.subheader("Result")
|
89 |
time_start = time.time()
|
90 |
+
result = process(text=session_state.text, max_length=int(max_length), top_k=int(top_k), top_p=float(top_p))
|
91 |
time_end = time.time()
|
92 |
time_diff = time_end-time_start
|
93 |
#print(f"Text generated in {time_diff} seconds")
|