# app.py — author: yxmauw; commit 07e66aa ("Update app.py"); 2.33 kB
import streamlit as st
from PIL import Image
import pandas as pd
import numpy as np
from model_methods import predict
import base64 # for title image
from load_css import local_css # for highlighting text
# `st.cache` is deprecated (since Streamlit 1.18) in favour of `st.cache_resource`
# for one-off setup work like this; fall back to `st.cache` on older installs so
# the app keeps working either way.
_cache_once = getattr(st, "cache_resource", None) or st.cache


@_cache_once  # cached so that latency for subsequent runs is shorter
def import_nltk():
    """Download the NLTK corpora this app relies on, caching the call.

    Despite the name, this does not make ``nltk`` available to other code: the
    import is local to this function. Its job is to run ``nltk.download`` for
    the ``wordnet`` and ``omw-1.4`` corpora (presumably needed by the model's
    lemmatizer — verify against ``model_methods``) once, with Streamlit's cache
    skipping the work on subsequent reruns.

    Returns:
        None.
    """
    import nltk

    nltk.download('wordnet')
    nltk.download('omw-1.4')
# Page-level settings: tab title and icon, a centred layout, and an automatic
# sidebar. Streamlit's docs ask for this call before other st.* commands.
st.set_page_config(
    page_title='Marvel vs. DC comics',
    page_icon=Image.open('subreddit_icon.png'),
    layout='centered',
    initial_sidebar_state='auto',
)
# Title image: inline the subreddit icon as a base64 data URI inside a link to
# reddit.com, so st.markdown can render it (with a caption) without serving the
# file separately.
# reference: https://discuss.streamlit.io/t/how-to-show-local-gif-image/3408/4
# reference: https://discuss.streamlit.io/t/local-image-button/5409/4
# Use a context manager so the file handle is closed even if read() fails.
with open("subreddit_icon.png", "rb") as im:
    contents = im.read()
im_base64 = base64.b64encode(contents).decode("utf-8")
html = f'''<a href='https://www.reddit.com/'>
<img src='data:image/png;base64,{im_base64}' width='100'>
</a><figcaption>Credit: reddit.com</figcaption>'''
st.markdown(html, unsafe_allow_html=True)
# Page heading, then a short note on what the classifier can and cannot handle.
st.title('Subreddit Post classifier')
# Presumably injects the stylesheet that styles the `highlight blue` spans below.
local_css("highlight_text.css")
text = '\n'.join([
    'The algorithm driving this app is built using subreddit posts published',
    'between April and July 2022. It is only able to classify between',
    "<span class='highlight blue'> **Marvel** </span>",
    'and',
    "<span class='highlight blue'> **DC Comics** </span>",
    'subreddits.',
])
# unsafe_allow_html lets the <span> highlight markup through to the page.
st.markdown(text, unsafe_allow_html=True)
# Fetch the wordnet / omw-1.4 NLTK corpora (the download call is cached across
# reruns), then show the text box whose contents predict_post() classifies.
import_nltk()
new_post = st.text_input(
    'Please copy and paste the subreddit post here',
    value='',
)
# process new input
def predict_post():
    """Classify the text in the module-level ``new_post`` and display the result.

    The post is wrapped in a one-element ``pd.Series`` because that is the
    input format ``predict`` recognises. The predicted label is mapped to a
    subreddit name: 1 -> 'Marvel', 0 -> 'DC comics'. The result is rendered
    with ``st.write`` / ``st.success``.

    Assumes ``predict`` returns a single label (a scalar or one-element array)
    — TODO confirm against ``model_methods.predict``.

    Raises:
        ValueError: if ``predict`` returns a label other than 0 or 1. Previously
            this case surfaced as an UnboundLocalError on ``post``.
    """
    data = pd.Series(new_post)
    result = predict(data)
    if result == 1:
        post = 'Marvel'
    elif result == 0:
        post = 'DC comics'
    else:
        raise ValueError(f'Unexpected prediction label: {result!r}')
    st.write('### This post belongs to')
    st.success(f'# {post}')
    st.write('### subreddit')
# instantiate submit button: classify the post and report the outcome in the sidebar.
# Shown both when no text was entered and when classification fails.
_NO_TEXT_WARNING = '''
Unable to detect text.
Please enter text for prediction.
\n\n Thank you 🙏.
'''

if st.button('Submit'):
    with st.sidebar:
        if not (new_post or '').strip():
            # Nothing to classify: warn up front instead of relying on the
            # model to fail on empty input (it may well return a label).
            st.warning(_NO_TEXT_WARNING)
        else:
            try:
                predict_post()
            except Exception:
                # Narrowed from a bare `except:` so BaseException-derived
                # control flow (KeyboardInterrupt, SystemExit, and presumably
                # Streamlit's own stop/rerun signals) is not swallowed.
                st.warning(_NO_TEXT_WARNING)