Spaces:
Runtime error
Runtime error
add first commit
Browse files- README.md +40 -7
- app/SessionState.py +95 -0
- app/abstract_dataset.py +62 -0
- app/app.py +244 -0
- app/prompts.py +57 -0
- requirements.txt +10 -0
README.md
CHANGED
@@ -1,12 +1,45 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: streamlit
|
7 |
-
|
8 |
-
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
11 |
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
+
title: News Generator
|
3 |
+
emoji: 🦀
|
4 |
+
colorFrom: pink
|
5 |
+
colorTo: red
|
6 |
sdk: streamlit
|
7 |
+
app_file: app/app.py
|
|
|
8 |
pinned: false
|
9 |
---
|
10 |
|
11 |
+
# Indonesian GPT-2 Applications
|
12 |
+
This is Application that generates sentences using Indonesian GPT-2 models finetuned on 6GB online news dataset!
|
13 |
+
|
14 |
+
|
15 |
+
## How did we create it
|
16 |
+
|
17 |
+
## Development
|
18 |
+
|
19 |
+
### Dependencies Installation
|
20 |
+
|
21 |
+
### Inference Pipeline
|
22 |
+
|
23 |
+
## Authors
|
24 |
+
|
25 |
+
Following are the authors of this work (listed alphabetically):
|
26 |
+
- [Cahya Wirawan](https://github.com/cahya-wirawan)
|
27 |
+
|
28 |
+
## Acknowledgements
|
29 |
+
|
30 |
+
- 🤗 Hugging Face for organizing [the FLAX/JAX community week](https://github.com/huggingface/transformers/tree/master/examples/research_projects/jax-projects)
|
31 |
+
- Google [TPU Research Cloud (TRC) program](https://sites.research.google/trc/) for providing computing resources
|
32 |
+
- [Weights & Biases](https://wandb.com/) for providing the infrastructure for experiment tracking and model management
|
33 |
+
|
34 |
+
## Citing Indonesian GPT-2 Applications
|
35 |
+
|
36 |
+
If you find this is useful in your research or wish to refer, please use the following BibTeX entry.
|
37 |
+
|
38 |
+
```
|
39 |
+
@misc{Indonesian_GPT2_App_2021,
|
40 |
+
author = {Cahya Wirawan},
|
41 |
+
title = {Abstract Generator using Indonesian GPT-2},
|
42 |
+
url = {https://github.com/cahya-wirawan/abstract-generator},
|
43 |
+
year = {2021}
|
44 |
+
}
|
45 |
+
```
|
app/SessionState.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Hack to add per-session state to Streamlit.
|
2 |
+
Usage
|
3 |
+
-----
|
4 |
+
>>> import SessionState
|
5 |
+
>>>
|
6 |
+
>>> session_state = SessionState.get(user_name='', favorite_color='black')
|
7 |
+
>>> session_state.user_name
|
8 |
+
''
|
9 |
+
>>> session_state.user_name = 'Mary'
|
10 |
+
>>> session_state.favorite_color
|
11 |
+
'black'
|
12 |
+
Since you set user_name above, next time your script runs this will be the
|
13 |
+
result:
|
14 |
+
>>> session_state = get(user_name='', favorite_color='black')
|
15 |
+
>>> session_state.user_name
|
16 |
+
'Mary'
|
17 |
+
"""
|
18 |
+
from streamlit.scriptrunner import get_script_run_ctx
|
19 |
+
from streamlit.server.server import Server
|
20 |
+
|
21 |
+
|
22 |
+
class SessionState(object):
|
23 |
+
def __init__(self, **kwargs):
|
24 |
+
"""A new SessionState object.
|
25 |
+
Parameters
|
26 |
+
----------
|
27 |
+
**kwargs : any
|
28 |
+
Default values for the session state.
|
29 |
+
Example
|
30 |
+
-------
|
31 |
+
>>> session_state = SessionState(user_name='', favorite_color='black')
|
32 |
+
>>> session_state.user_name = 'Mary'
|
33 |
+
''
|
34 |
+
>>> session_state.favorite_color
|
35 |
+
'black'
|
36 |
+
"""
|
37 |
+
for key, val in kwargs.items():
|
38 |
+
setattr(self, key, val)
|
39 |
+
|
40 |
+
|
41 |
+
def get(**kwargs):
|
42 |
+
"""Gets a SessionState object for the current session.
|
43 |
+
Creates a new object if necessary.
|
44 |
+
Parameters
|
45 |
+
----------
|
46 |
+
**kwargs : any
|
47 |
+
Default values you want to add to the session state, if we're creating a
|
48 |
+
new one.
|
49 |
+
Example
|
50 |
+
-------
|
51 |
+
>>> session_state = get(user_name='', favorite_color='black')
|
52 |
+
>>> session_state.user_name
|
53 |
+
''
|
54 |
+
>>> session_state.user_name = 'Mary'
|
55 |
+
>>> session_state.favorite_color
|
56 |
+
'black'
|
57 |
+
Since you set user_name above, next time your script runs this will be the
|
58 |
+
result:
|
59 |
+
>>> session_state = get(user_name='', favorite_color='black')
|
60 |
+
>>> session_state.user_name
|
61 |
+
'Mary'
|
62 |
+
"""
|
63 |
+
# Hack to get the session object from Streamlit.
|
64 |
+
|
65 |
+
ctx = get_script_run_ctx()
|
66 |
+
|
67 |
+
this_session = None
|
68 |
+
|
69 |
+
current_server = Server.get_current()
|
70 |
+
if hasattr(current_server, '_session_infos'):
|
71 |
+
# Streamlit < 0.56
|
72 |
+
session_infos = Server.get_current()._session_infos.values()
|
73 |
+
else:
|
74 |
+
session_infos = Server.get_current()._session_info_by_id.values()
|
75 |
+
|
76 |
+
for session_info in session_infos:
|
77 |
+
s = session_info.session
|
78 |
+
if (
|
79 |
+
(not hasattr(s, '_main_dg') and s._uploaded_file_mgr == ctx.uploaded_file_mgr)
|
80 |
+
):
|
81 |
+
this_session = s
|
82 |
+
|
83 |
+
if this_session is None:
|
84 |
+
raise RuntimeError(
|
85 |
+
"Oh noes. Couldn't get your Streamlit Session object. "
|
86 |
+
'Are you doing something fancy with threads?')
|
87 |
+
|
88 |
+
# Got the session object! Now let's attach some state into it.
|
89 |
+
|
90 |
+
if not hasattr(this_session, '_custom_session_state'):
|
91 |
+
this_session._custom_session_state = SessionState(**kwargs)
|
92 |
+
|
93 |
+
return this_session._custom_session_state
|
94 |
+
|
95 |
+
__all__ = ['get']
|
app/abstract_dataset.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import random
|
3 |
+
from torch.utils.data import Dataset
|
4 |
+
|
5 |
+
|
6 |
+
class AbstractDataset(Dataset):
|
7 |
+
special_tokens = {"bos_token": "<|BOS|>",
|
8 |
+
"eos_token": "<|EOS|>",
|
9 |
+
"unk_token": "<|UNK|>",
|
10 |
+
"pad_token": "<|PAD|>",
|
11 |
+
"sep_token": "<|SEP|>"}
|
12 |
+
max_length = 1024
|
13 |
+
|
14 |
+
def __init__(self, data, tokenizer, randomize=True):
|
15 |
+
title, text, keywords = [], [], []
|
16 |
+
for k, v in data.items():
|
17 |
+
title.append(v[0])
|
18 |
+
text.append(v[1])
|
19 |
+
keywords.append(v[2])
|
20 |
+
|
21 |
+
self.randomize = randomize
|
22 |
+
self.tokenizer = tokenizer
|
23 |
+
self.title = title
|
24 |
+
self.text = text
|
25 |
+
self.keywords = keywords
|
26 |
+
|
27 |
+
@staticmethod
|
28 |
+
def join_keywords(keywords, randomize=True):
|
29 |
+
N = len(keywords)
|
30 |
+
|
31 |
+
# random sampling and shuffle
|
32 |
+
if randomize:
|
33 |
+
# M = random.choice(range(N + 1))
|
34 |
+
# keywords = keywords[:M]
|
35 |
+
random.shuffle(keywords)
|
36 |
+
|
37 |
+
return ','.join(keywords)
|
38 |
+
|
39 |
+
|
40 |
+
def __len__(self):
|
41 |
+
return len(self.text)
|
42 |
+
|
43 |
+
|
44 |
+
def __getitem__(self, i):
|
45 |
+
keywords = self.keywords[i].copy()
|
46 |
+
kw = self.join_keywords(keywords, self.randomize)
|
47 |
+
|
48 |
+
input = self.special_tokens['bos_token'] + self.title[i] + \
|
49 |
+
self.special_tokens['sep_token'] + kw + self.special_tokens['sep_token'] + \
|
50 |
+
self.text[i] + self.special_tokens['eos_token']
|
51 |
+
|
52 |
+
encodings_dict = self.tokenizer(input,
|
53 |
+
truncation=True,
|
54 |
+
max_length=self.max_length,
|
55 |
+
padding="max_length")
|
56 |
+
|
57 |
+
input_ids = encodings_dict['input_ids']
|
58 |
+
attention_mask = encodings_dict['attention_mask']
|
59 |
+
|
60 |
+
return {'label': torch.tensor(input_ids),
|
61 |
+
'input_ids': torch.tensor(input_ids),
|
62 |
+
'attention_mask': torch.tensor(attention_mask)}
|
app/app.py
ADDED
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import SessionState
|
3 |
+
from mtranslate import translate
|
4 |
+
from prompts import PROMPT_LIST
|
5 |
+
import random
|
6 |
+
import time
|
7 |
+
from transformers import pipeline, set_seed, AutoConfig, AutoTokenizer, GPT2LMHeadModel, GPT2Tokenizer
|
8 |
+
import psutil
|
9 |
+
import torch
|
10 |
+
import os
|
11 |
+
from abstract_dataset import AbstractDataset
|
12 |
+
|
13 |
+
|
14 |
+
# st.set_page_config(page_title="Indonesian GPT-2")
|
15 |
+
|
16 |
+
mirror_url = "https://abstract-generator.ai-research.id/"
|
17 |
+
if "MIRROR_URL" in os.environ:
|
18 |
+
mirror_url = os.environ["MIRROR_URL"]
|
19 |
+
|
20 |
+
MODELS = {
|
21 |
+
"Indonesian Academic Journal - Indonesian GPT-2 Medium": {
|
22 |
+
"group": "Indonesian Journal",
|
23 |
+
"name": "cahya/abstract-generator",
|
24 |
+
"description": "Abstract Generator using Indonesian GPT-2 Medium.",
|
25 |
+
"text_generator": None,
|
26 |
+
"tokenizer": None
|
27 |
+
},
|
28 |
+
}
|
29 |
+
|
30 |
+
st.sidebar.markdown("""
|
31 |
+
<style>
|
32 |
+
.centeralign {
|
33 |
+
text-align: center;
|
34 |
+
}
|
35 |
+
</style>
|
36 |
+
<p class="centeralign">
|
37 |
+
<img src="https://huggingface.co/spaces/flax-community/gpt2-indonesian/resolve/main/huggingwayang.png"/>
|
38 |
+
</p>
|
39 |
+
""", unsafe_allow_html=True)
|
40 |
+
st.sidebar.markdown(f"""
|
41 |
+
___
|
42 |
+
<p class="centeralign">
|
43 |
+
This is a collection of applications that generates sentences using Indonesian GPT-2 models!
|
44 |
+
</p>
|
45 |
+
<p class="centeralign">
|
46 |
+
Created by <a href="https://huggingface.co/indonesian-nlp">Indonesian NLP</a> team @2021
|
47 |
+
<br/>
|
48 |
+
<a href="https://github.com/indonesian-nlp/gpt2-app" target="_blank">GitHub</a> | <a href="https://github.com/indonesian-nlp/gpt2-app" target="_blank">Project Report</a>
|
49 |
+
<br/>
|
50 |
+
A mirror of the application is available <a href="{mirror_url}" target="_blank">here</a>
|
51 |
+
</p>
|
52 |
+
""", unsafe_allow_html=True)
|
53 |
+
|
54 |
+
st.sidebar.markdown("""
|
55 |
+
___
|
56 |
+
""", unsafe_allow_html=True)
|
57 |
+
|
58 |
+
model_type = st.sidebar.selectbox('Model', (MODELS.keys()))
|
59 |
+
|
60 |
+
|
61 |
+
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
|
62 |
+
def get_generator(model_name: str):
|
63 |
+
st.write(f"Loading the GPT2 model {model_name}, please wait...")
|
64 |
+
special_tokens = AbstractDataset.special_tokens
|
65 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
66 |
+
tokenizer.add_special_tokens(special_tokens)
|
67 |
+
config = AutoConfig.from_pretrained(model_name,
|
68 |
+
bos_token_id=tokenizer.bos_token_id,
|
69 |
+
eos_token_id=tokenizer.eos_token_id,
|
70 |
+
sep_token_id=tokenizer.sep_token_id,
|
71 |
+
pad_token_id=tokenizer.pad_token_id,
|
72 |
+
output_hidden_states=False)
|
73 |
+
model = GPT2LMHeadModel.from_pretrained(model_name, config=config)
|
74 |
+
model.resize_token_embeddings(len(tokenizer))
|
75 |
+
return model, tokenizer
|
76 |
+
|
77 |
+
|
78 |
+
# Disable the st.cache for this function due to issue on newer version of streamlit
|
79 |
+
# @st.cache(suppress_st_warning=True, hash_funcs={tokenizers.Tokenizer: id})
|
80 |
+
def process(text_generator, tokenizer, title: str, keywords: str, text: str,
|
81 |
+
max_length: int = 200, do_sample: bool = True, top_k: int = 50, top_p: float = 0.95,
|
82 |
+
temperature: float = 1.0, max_time: float = 120.0, seed=42, repetition_penalty=1.0):
|
83 |
+
# st.write("Cache miss: process")
|
84 |
+
set_seed(seed)
|
85 |
+
if repetition_penalty == 0.0:
|
86 |
+
min_penalty = 1.05
|
87 |
+
max_penalty = 1.5
|
88 |
+
repetition_penalty = max(min_penalty + (1.0-temperature) * (max_penalty-min_penalty), 0.8)
|
89 |
+
|
90 |
+
keywords = [keyword.strip() for keyword in keywords.split(",")]
|
91 |
+
keywords = AbstractDataset.join_keywords(keywords, randomize=False)
|
92 |
+
|
93 |
+
special_tokens = AbstractDataset.special_tokens
|
94 |
+
prompt = special_tokens['bos_token'] + title + \
|
95 |
+
special_tokens['sep_token'] + keywords + special_tokens['sep_token'] + text
|
96 |
+
|
97 |
+
print(f"title: {title}, keywords: {keywords}, text: {text}")
|
98 |
+
|
99 |
+
generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
|
100 |
+
# device = torch.device("cuda")
|
101 |
+
# generated = generated.to(device)
|
102 |
+
|
103 |
+
text_generator.eval()
|
104 |
+
sample_outputs = text_generator.generate(generated,
|
105 |
+
do_sample=do_sample,
|
106 |
+
min_length=200,
|
107 |
+
max_length=max_length,
|
108 |
+
top_k=top_k,
|
109 |
+
top_p=top_p,
|
110 |
+
temperature=temperature,
|
111 |
+
repetition_penalty=repetition_penalty,
|
112 |
+
num_return_sequences=1
|
113 |
+
)
|
114 |
+
result = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
|
115 |
+
print(f"result: {result}")
|
116 |
+
prefix_length = len(title) + len(keywords)
|
117 |
+
result = result[prefix_length:]
|
118 |
+
return result
|
119 |
+
|
120 |
+
|
121 |
+
st.title("Indonesian GPT-2 Applications")
|
122 |
+
prompt_group_name = MODELS[model_type]["group"]
|
123 |
+
st.header(prompt_group_name)
|
124 |
+
description = f"This is a bilingual (Indonesian and English) abstract generator using Indonesian GPT-2 Medium. We finetuned it with the Indonesian paper abstract dataset."
|
125 |
+
st.markdown(description)
|
126 |
+
model_name = f"Model name: [{MODELS[model_type]['name']}](https://huggingface.co/{MODELS[model_type]['name']})"
|
127 |
+
st.markdown(model_name)
|
128 |
+
if prompt_group_name in ["Indonesian GPT-2", "Indonesian Literature", "Indonesian Journal"]:
|
129 |
+
session_state = SessionState.get(prompt=None, prompt_box=None, text=None)
|
130 |
+
ALL_PROMPTS = list(PROMPT_LIST[prompt_group_name].keys())+["Custom"]
|
131 |
+
|
132 |
+
prompt = st.selectbox('Prompt', ALL_PROMPTS, index=len(ALL_PROMPTS)-1)
|
133 |
+
|
134 |
+
# Update prompt
|
135 |
+
if session_state.prompt is None:
|
136 |
+
session_state.prompt = prompt
|
137 |
+
elif session_state.prompt is not None and (prompt != session_state.prompt):
|
138 |
+
session_state.prompt = prompt
|
139 |
+
session_state.prompt_box = None
|
140 |
+
else:
|
141 |
+
session_state.prompt = prompt
|
142 |
+
|
143 |
+
# Update prompt box
|
144 |
+
if session_state.prompt == "Custom":
|
145 |
+
session_state.prompt_box = ""
|
146 |
+
session_state.title = ""
|
147 |
+
session_state.keywords = ""
|
148 |
+
else:
|
149 |
+
if session_state.prompt is not None and session_state.prompt_box is None:
|
150 |
+
session_state.prompt_box = random.choice(PROMPT_LIST[prompt_group_name][session_state.prompt])
|
151 |
+
|
152 |
+
session_state.title = st.text_input("Title", session_state.title)
|
153 |
+
session_state.keywords = st.text_input("Keywords", session_state.keywords)
|
154 |
+
session_state.text = st.text_area("Prompt", session_state.prompt_box)
|
155 |
+
|
156 |
+
max_length = st.sidebar.number_input(
|
157 |
+
"Maximum length",
|
158 |
+
value=200,
|
159 |
+
max_value=512,
|
160 |
+
help="The maximum length of the sequence to be generated."
|
161 |
+
)
|
162 |
+
|
163 |
+
temperature = st.sidebar.slider(
|
164 |
+
"Temperature",
|
165 |
+
value=0.4,
|
166 |
+
min_value=0.0,
|
167 |
+
max_value=2.0
|
168 |
+
)
|
169 |
+
|
170 |
+
do_sample = st.sidebar.checkbox(
|
171 |
+
"Use sampling",
|
172 |
+
value=True
|
173 |
+
)
|
174 |
+
|
175 |
+
top_k = 30
|
176 |
+
top_p = 0.95
|
177 |
+
|
178 |
+
if do_sample:
|
179 |
+
top_k = st.sidebar.number_input(
|
180 |
+
"Top k",
|
181 |
+
value=top_k,
|
182 |
+
help="The number of highest probability vocabulary tokens to keep for top-k-filtering."
|
183 |
+
)
|
184 |
+
top_p = st.sidebar.number_input(
|
185 |
+
"Top p",
|
186 |
+
value=top_p,
|
187 |
+
help="If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher "
|
188 |
+
"are kept for generation."
|
189 |
+
)
|
190 |
+
|
191 |
+
seed = st.sidebar.number_input(
|
192 |
+
"Random Seed",
|
193 |
+
value=25,
|
194 |
+
help="The number used to initialize a pseudorandom number generator"
|
195 |
+
)
|
196 |
+
|
197 |
+
repetition_penalty = 0.0
|
198 |
+
automatic_repetition_penalty = st.sidebar.checkbox(
|
199 |
+
"Automatic Repetition Penalty",
|
200 |
+
value=True
|
201 |
+
)
|
202 |
+
|
203 |
+
if not automatic_repetition_penalty:
|
204 |
+
repetition_penalty = st.sidebar.slider(
|
205 |
+
"Repetition Penalty",
|
206 |
+
value=1.0,
|
207 |
+
min_value=1.0,
|
208 |
+
max_value=2.0
|
209 |
+
)
|
210 |
+
|
211 |
+
for group_name in MODELS:
|
212 |
+
if MODELS[group_name]["group"] in ["Indonesian GPT-2", "Indonesian Literature", "Indonesian Journal"]:
|
213 |
+
MODELS[group_name]["text_generator"], MODELS[group_name]["tokenizer"] = \
|
214 |
+
get_generator(MODELS[group_name]["name"])
|
215 |
+
|
216 |
+
if st.button("Run"):
|
217 |
+
with st.spinner(text="Getting results..."):
|
218 |
+
memory = psutil.virtual_memory()
|
219 |
+
st.subheader("Result")
|
220 |
+
time_start = time.time()
|
221 |
+
# text_generator = MODELS[model_type]["text_generator"]
|
222 |
+
result = process(MODELS[model_type]["text_generator"], MODELS[model_type]["tokenizer"],
|
223 |
+
title=session_state.title,
|
224 |
+
keywords=session_state.keywords,
|
225 |
+
text=session_state.text, max_length=int(max_length),
|
226 |
+
temperature=temperature, do_sample=do_sample,
|
227 |
+
top_k=int(top_k), top_p=float(top_p), seed=seed, repetition_penalty=repetition_penalty)
|
228 |
+
time_end = time.time()
|
229 |
+
time_diff = time_end-time_start
|
230 |
+
#result = result[0]["generated_text"]
|
231 |
+
st.write(result.replace("\n", " \n"))
|
232 |
+
st.text("Translation")
|
233 |
+
translation = translate(result, "en", "id")
|
234 |
+
st.write(translation.replace("\n", " \n"))
|
235 |
+
# st.write(f"*do_sample: {do_sample}, top_k: {top_k}, top_p: {top_p}, seed: {seed}*")
|
236 |
+
info = f"""
|
237 |
+
*Memory: {memory.total/(1024*1024*1024):.2f}GB, used: {memory.percent}%, available: {memory.available/(1024*1024*1024):.2f}GB*
|
238 |
+
*Text generated in {time_diff:.5} seconds*
|
239 |
+
"""
|
240 |
+
st.write(info)
|
241 |
+
|
242 |
+
# Reset state
|
243 |
+
session_state.prompt = None
|
244 |
+
session_state.prompt_box = None
|
app/prompts.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
PROMPT_LIST = {
|
2 |
+
"Indonesian GPT-2": {
|
3 |
+
"Resep masakan (recipe)": [
|
4 |
+
"Berikut adalah cara memasak sate ayam:\n",
|
5 |
+
"Langkah-langkah membuat nasi goreng:\n",
|
6 |
+
"Berikut adalah bahan-bahan membuat nastar:\n"
|
7 |
+
],
|
8 |
+
"Puisi (poetry)": [
|
9 |
+
"Aku ingin jadi merpati\nTerbang di langit yang damai\nBernyanyi-nyanyi tentang masa depan\n",
|
10 |
+
"Terdiam aku satu persatu dengan tatapan binar\nSenyawa merasuk dalam sukma membuat lara\nKefanaan membentuk kelemahan"
|
11 |
+
],
|
12 |
+
"Cerpen (short story)": [
|
13 |
+
"Putri memakai sepatunya dengan malas. Kalau bisa, selama seminggu ini ia bolos sekolah saja. Namun, Mama pasti akan marah. Ulangan tengah semester telah selesai. Minggu ini, di sekolah sedang berlangsung pekan olahraga.",
|
14 |
+
"\"Wah, hari ini cerah sekali ya,\" ucap Budi ketika ia keluar rumah.",
|
15 |
+
"Sewindu sudah kita tak berjumpa, rinduku padamu sudah tak terkira."
|
16 |
+
],
|
17 |
+
"Sejarah (history)": [
|
18 |
+
"Mohammad Natsir adalah seorang ulama, politisi, dan pejuang kemerdekaan Indonesia.",
|
19 |
+
"Ir. H. Soekarno adalah Presiden pertama Republik Indonesia. Ia adalah seorang tokoh perjuangan yang memainkan peranan penting dalam memerdekakan bangsa Indonesia",
|
20 |
+
"Borobudur adalah sebuah candi Buddha yang terletak di sebelah barat laut Yogyakarta. Monumen ini merupakan model alam semesta dan dibangun sebagai tempat suci untuk memuliakan Buddha"
|
21 |
+
],
|
22 |
+
},
|
23 |
+
"Indonesian Literature": {
|
24 |
+
"Adult Romance": [
|
25 |
+
"Ini adalah kisah tentang seorang laki-laki yang berusaha memperjuangkan cintanya",
|
26 |
+
"Alunan musik terdengar memenuhi ruangan kantor, cowok itu duduk di balik meja kerjanya sambil memejamkan mata. Berusaha meresapi nada per nada",
|
27 |
+
"Aku mencari dan terus mencari\nDimana bahagia akan kutemui\nKumencari terus mencari\nHingga ku tak mengerti arti hari-hari",
|
28 |
+
"Gadis itu mengharuskan dirinya tegar, dan kuat dalam menghadapi masalah. Menahan air matanya jatuh setiap kali ingin menangis"
|
29 |
+
],
|
30 |
+
"Horror": [
|
31 |
+
"Ditengah-tengah perbincangan mereka berdua, datanglah sesosok mahluk tinggi hitam dan besar",
|
32 |
+
"Sesosok hantu perempuan seperti kuntilanak yang melayang keluar dan bergerak perlahan dari pintu kamar kecil tadi yang tertutup.",
|
33 |
+
"Sejak pertemuannya dengan leak, yang ternyata tinggal satu atap dengannya, hidupnya terus dihantui oleh berbagai sosok seram."
|
34 |
+
],
|
35 |
+
"Poetry": [
|
36 |
+
"Aku ingin menulis sajak\nyang melesat dalam kejap\nmenembus hati yang pejam\nmemaksa mimpimu terjaga\ndari semu",
|
37 |
+
"Malam ini langitku lengang\ntiada hujan yang membasuh rindu\npun awan yang biasanya temani seruput kopimu",
|
38 |
+
"Di sisimu waktu menjelma\nsetangkai kembang api\ngelora membakar tanpa jeda\nmemercik pijar binar kita."
|
39 |
+
]
|
40 |
+
},
|
41 |
+
"Indonesian Journal": {
|
42 |
+
"Biologi (biology)": [
|
43 |
+
"Tujuan penelitian ini untuk menentukan keanekaragaman Arthropoda pada lahan pertanian kacang",
|
44 |
+
"Identifikasi spesies secara molekuler sangat diperlukan dalam mempelajari taksonomi",
|
45 |
+
"Penelitian ini bertujuan untuk menentukan identitas invertebrata laut dari Perairan Papua dengan teknik DNA barcoding"],
|
46 |
+
"Psikologi (psychology)": [
|
47 |
+
"Penelitian ini bertujuan untuk mengetahui perilaku wirausaha remaja yang diprediksi dari motivasi intrinsik",
|
48 |
+
"Tujuan dari penelitian ini adalah untuk mendapatkan data empiris mengenai gambaran peta bakat mahasiswa Fakultas Psikologi Unjani"],
|
49 |
+
"Ekonomi (economics)": [
|
50 |
+
"Faktor kepuasan dan kepercayaan konsumen merupakan dua faktor kunci dalam meningkatkan penetrasi e-commerce. Penelitian yang dilakukan",
|
51 |
+
"Penelitian ini bertujuan untuk menganalisis pola konsumsi pangan di Indonesia",
|
52 |
+
"Model GTAP diimplementasikan untuk melihat dampak yang ditimbulkan pada PDB"],
|
53 |
+
"Teknologi Informasi (IT)": [
|
54 |
+
"pembuatan aplikasi ini menggunakan pengembangan metode Waterfall dan dirancang mengguynakan Unified Modeling Language (UML) dengan bahasa pemrograman",
|
55 |
+
"Berdasarkan masalah tersebut, maka penulis termotivasi untuk membangun Pengembangan Sistem Informasi Manajemen"]
|
56 |
+
},
|
57 |
+
}
|
requirements.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
numpy
|
2 |
+
torch
|
3 |
+
tokenizers
|
4 |
+
transformers
|
5 |
+
datasets
|
6 |
+
mtranslate
|
7 |
+
# streamlit version 0.67.1 is needed due to issue with caching
|
8 |
+
# streamlit==0.67.1
|
9 |
+
streamlit
|
10 |
+
psutil
|