# Author: anandhu-pk
# Last commit: "Update app.py" (0eb2b65)
import streamlit as st
import pickle
import tensorflow as tf
from tensorflow.keras.preprocessing import sequence
from tensorflow.keras.saving import load_model
import numpy as np
from PIL import Image
# Page header and the top-level task picker. The chosen label is stored in
# `task`; 'Choose one' is a placeholder that matches neither section below,
# so nothing further renders until a real task is picked.
st.title('Classifier System')
task = st.selectbox(
    label='Select Task',
    options=['Choose one', 'Sentiment Classification', 'Tumor Detection'],
)
if task == 'Tumor Detection':
    st.subheader('Tumor Detection with CNN')
    # CNN classifier stored as a Keras HDF5 file (it used to be pickled).
    # NOTE(review): Streamlit re-runs the whole script on every interaction,
    # so the model is reloaded each time; consider `st.cache_resource`.
    cnn_model = load_model("cnn_model1.h5")
    img = st.file_uploader('Upload image', type=['jpeg', 'jpg', 'png'])

    def cnn_make_prediction(img, model, threshold=0.5):
        """Classify an image as tumor / no tumor with *model*.

        Args:
            img: anything ``PIL.Image.open`` accepts (a path or a binary file
                object such as the Streamlit upload).
            model: Keras model whose ``predict`` is called on a batch holding
                one 128x128 image. Assumed (TODO confirm) to return a single
                score per image, like the thresholded text models in this app.
            threshold: score above which a tumor is reported. Defaults to 0.5,
                the cut-off used by the other Keras classifiers here.

        Returns:
            "Tumor Detected" or "No Tumor Detected".
        """
        image = Image.open(img).resize((128, 128))
        # NOTE(review): no colour-mode conversion or pixel scaling is done here
        # (e.g. an RGBA PNG keeps 4 channels); this must match how the CNN was
        # trained - confirm.
        input_img = np.expand_dims(np.array(image), axis=0)
        # Compare against a threshold instead of testing truthiness: a
        # probability-like score is practically never exactly 0, so the old
        # `if res:` reported a tumor for essentially every image.
        res = (model.predict(input_img) > threshold).astype("int32")
        return "Tumor Detected" if res else "No Tumor Detected"

    if img is not None:
        st.image(img, caption="Image preview")
        if st.button('Submit'):
            pred = cnn_make_prediction(img, cnn_model)
            st.write(pred)
if task == 'Sentiment Classification':
    arcs = ['Perceptron', 'Backpropagation', 'DNN', 'RNN', 'LSTM']
    arc = st.radio('Pick one:', arcs, horizontal=True)
    # NOTE(review): Streamlit re-runs this whole script on every widget
    # interaction, so the chosen model and tokeniser are re-read from disk each
    # time; consider caching the loads with `st.cache_resource`.

    def _load_pickle(path):
        """Return the object unpickled from the local file *path*.

        Only the fixed artefact file names below are passed in. Never point
        this at user-supplied files: unpickling can execute arbitrary code.
        """
        with open(path, 'rb') as file:
            return pickle.load(file)

    def _encode_texts(inp, tokeniser, maxlen, padding='pre'):
        """Tokenise and pad message(s) into the input matrix for a model.

        Args:
            inp: a single message string or a list of message strings.
            tokeniser: fitted tokeniser providing ``texts_to_sequences``.
            maxlen: sequence length handed to ``pad_sequences``.
            padding: ``'pre'`` (the ``pad_sequences`` default) or ``'post'``.

        Returns:
            The padded array from ``pad_sequences``, one row per message.
        """
        # Wrap a bare string exactly once. The Perceptron/Backpropagation paths
        # used to wrap the already-listed message a second time; the Keras
        # Tokenizer treats a nested list as pre-split tokens, so the whole
        # message became a single, almost certainly unknown, token.
        texts = [inp] if isinstance(inp, str) else list(inp)
        encoded = tokeniser.texts_to_sequences(texts)
        return sequence.pad_sequences(encoded, maxlen=maxlen, padding=padding)

    def _render_checker(title, predict, model):
        """Render *title*, a message box and a Check button.

        On Check, writes ``predict([message], model)``. An empty or
        whitespace-only message shows a warning instead of a prediction.
        """
        st.subheader(title)
        inp = st.text_area('Enter message')
        if st.button('Check'):
            if inp.strip():
                st.write(predict([inp], model))
            else:
                st.warning('Please enter a message to classify.')

    if arc == arcs[0]:
        # Perceptron
        perceptron = _load_pickle("pnn_model.pkl")
        ppn_tokeniser = _load_pickle("pnn_tokeniser.pkl")

        def ppn_make_predictions(inp, model):
            """Return "Not spam" or "Spam" for message(s) *inp*.

            *model*'s raw ``predict`` output is used without a threshold; a
            truthy result means "Not spam".
            """
            res = model.predict(_encode_texts(inp, ppn_tokeniser, maxlen=500))
            return "Not spam" if res else "Spam"

        _render_checker('SMS spam Classification using Perceptron',
                        ppn_make_predictions, perceptron)

    elif arc == arcs[1]:
        # Backpropagation
        backprop = _load_pickle("bpn_model.pkl")
        bp_tokeniser = _load_pickle("bpn_tokeniser.pkl")

        def bp_make_predictions(inp, model):
            """Return "Not spam" or "Spam" for message(s) *inp*.

            *model*'s raw ``predict`` output is used without a threshold; a
            truthy result means "Not spam".
            """
            res = model.predict(_encode_texts(inp, bp_tokeniser, maxlen=500))
            return "Not spam" if res else "Spam"

        _render_checker('SMS spam Classification using Backpropagation',
                        bp_make_predictions, backprop)

    elif arc == arcs[2]:
        # DNN, stored as a Keras HDF5 file (it used to be pickled).
        dnn_model = load_model("dnn_model1.h5")
        dnn_tokeniser = _load_pickle("dnn_tokeniser.pkl")

        def dnn_make_predictions(inp, model):
            """Return "Not spam" if *model*'s score exceeds 0.5, else "Spam"."""
            padded_inp = _encode_texts(inp, dnn_tokeniser, maxlen=500)
            res = (model.predict(padded_inp) > 0.5).astype("int32")
            return "Not spam" if res else "Spam"

        _render_checker('SMS spam Classification using DNN',
                        dnn_make_predictions, dnn_model)

    elif arc == arcs[3]:
        # RNN, stored as a Keras HDF5 file (it used to be pickled).
        rnn_model = load_model("rnn_model1.h5")
        rnn_tokeniser = _load_pickle("rnn_tokeniser.pkl")

        def rnn_make_predictions(inp, model):
            """Return "Spam" if *model*'s score exceeds 0.5, else "Not spam".

            Unlike the other models, this one is fed post-padded sequences of
            length 10, and its positive class is "Spam".
            NOTE(review): confirm both against how the RNN was trained.
            """
            padded_inp = _encode_texts(inp, rnn_tokeniser, maxlen=10, padding='post')
            res = (model.predict(padded_inp) > 0.5).astype("int32")
            return "Spam" if res else "Not spam"

        _render_checker('SMS Spam Classification using RNN',
                        rnn_make_predictions, rnn_model)

    elif arc == arcs[4]:
        # LSTM, stored as a Keras HDF5 file (it used to be pickled).
        lstm_model = load_model("lstm_model1.h5")
        lstm_tokeniser = _load_pickle("lstm_tokeniser.pkl")

        def lstm_make_predictions(inp, model):
            """Return "Not spam" if *model*'s score exceeds 0.5, else "Spam"."""
            padded_inp = _encode_texts(inp, lstm_tokeniser, maxlen=500)
            res = (model.predict(padded_inp) > 0.5).astype("int32")
            return "Not spam" if res else "Spam"

        _render_checker('SMS spam Classification using LSTM',
                        lstm_make_predictions, lstm_model)