File size: 2,326 Bytes
4f78275
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
07e66aa
4f78275
 
 
 
 
 
 
07e66aa
4f78275
 
 
 
 
 
 
 
 
07e66aa
4f78275
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import streamlit as st
from PIL import Image
import pandas as pd
import numpy as np
from model_methods import predict
import base64 # for title image
from load_css import local_css # for highlighting text

@st.cache  # Streamlit caches the call so later reruns skip the download step
def import_nltk():
    """Download the NLTK corpora this app relies on.

    Fetches the ``wordnet`` and ``omw-1.4`` resources through
    ``nltk.download``. ``nltk`` is imported lazily inside the function, so
    the import cost is only paid when the function actually runs.

    Returns:
        None.
    """
    from nltk import download as fetch_corpus

    for corpus_name in ('wordnet', 'omw-1.4'):
        fetch_corpus(corpus_name)

# configuration of the page
st.set_page_config(
    layout='centered',
    page_icon=Image.open('subreddit_icon.png'),
    page_title='Marvel vs. DC comics',
    initial_sidebar_state='auto'
)

# Show the subreddit icon as a clickable image: the PNG is inlined into the
# HTML as a base64 data URI and wrapped in a link to reddit.com.
# reference: https://discuss.streamlit.io/t/how-to-show-local-gif-image/3408/4
# reference: https://discuss.streamlit.io/t/local-image-button/5409/4
# The context manager guarantees the file handle is released even if read()
# raises, which the previous manual open()/close() pair did not.
with open("subreddit_icon.png", "rb") as im:
    contents = im.read()
im_base64 = base64.b64encode(contents).decode("utf-8")
html = f'''<a href='https://www.reddit.com/'> 
            <img src='data:image/png;base64,{im_base64}' width='100'>
            </a><figcaption>Credit: reddit.com</figcaption>'''
# unsafe_allow_html=True is needed for the raw <a>/<img> markup to be rendered.
st.markdown(html, unsafe_allow_html=True)

# Page heading, followed by a short description of what the classifier covers.
st.title('Subreddit Post classifier')
# local_css is a project helper imported from load_css; presumably it injects
# the stylesheet that styles the 'highlight blue' spans below -- confirm
# against load_css.
local_css("highlight_text.css")
text = '''The algorithm driving this app is built using subreddit posts published 
between April and July 2022. It is only able to classify between 
<span class='highlight blue'> **Marvel** </span>
and 
<span class='highlight blue'> **DC Comics** </span>
subreddits.'''
# unsafe_allow_html=True so the <span> markup above is rendered as HTML.
st.markdown(text, unsafe_allow_html=True)

# Area for text input
# import_nltk() downloads the 'wordnet' and 'omw-1.4' NLTK resources; it is
# decorated with @st.cache so subsequent runs are faster.
import_nltk() # import nltk module if not yet cached in local computer
# new_post is a module-level value that predict_post() reads directly;
# the second argument ('') is the initial content of the text box.
new_post = st.text_input('Please copy and paste the subreddit post here', '')

# process new input
def predict_post():
    """Classify the text held in ``new_post`` and render the verdict.

    Reads the module-level ``new_post`` string, wraps it in a ``pd.Series``
    (the format ``predict()`` is called with) and writes the predicted
    subreddit name to the currently active Streamlit container.

    Returns:
        None. The result is displayed through ``st.write`` / ``st.success``.

    Raises:
        ValueError: If ``new_post`` is empty or whitespace-only, or if
            ``predict()`` yields a label other than 1 or 0.
    """
    # Fail explicitly on empty input instead of relying on predict() to
    # raise (or, worse, to "classify" an empty string).
    if not new_post.strip():
        raise ValueError('No text was entered for prediction.')

    data = pd.Series(new_post)
    result = predict(data)

    # Equality comparisons are kept (rather than a dict lookup) because the
    # return type of predict() is not visible here and may be array-like
    # -- TODO confirm against model_methods.predict.
    if result == 1:
        post = 'Marvel'
    elif result == 0:
        post = 'DC comics'
    else:
        # Previously `post` was left unbound here, surfacing as an unrelated
        # UnboundLocalError further down.
        raise ValueError(f'Unexpected prediction label: {result!r}')

    st.write('### This post belongs to')
    st.success(f'# {post}')
    st.write('### subreddit')

# instantiate submit button
if st.button('Submit'):
    # Render the outcome (prediction or warning) inside the sidebar.
    with st.sidebar:
        try:
            predict_post()
        # `except Exception` instead of a bare `except:` so that
        # KeyboardInterrupt, SystemExit and any other BaseException-derived
        # control-flow signals propagate instead of being reported to the
        # user as "no text".
        except Exception:
            st.warning('''
            Unable to detect text. 
            Please enter text for prediction. 
            \n\n Thank you 🙏. 
            ''')