Christopher Capobianco committed on
Commit
385b1f2
·
1 Parent(s): 938985b

Add Document Classifier project

Browse files
Home.py CHANGED
@@ -9,12 +9,24 @@ st.markdown('Please have a look at the descriptions below, and select a project
9
 
10
  st.header('Projects', divider='red')
11
 
 
12
  mv = Image.open("assets/movie.jpg")
13
  # wp = Image.open("assets/weather.png")
14
  sm = Image.open("assets/stock-market.png")
15
  mu = Image.open("assets/music.jpg")
16
  llm = Image.open("assets/llm.png")
17
 
 
 
 
 
 
 
 
 
 
 
 
18
  with st.container():
19
  text_column, image_column = st.columns((3,1))
20
  with text_column:
 
9
 
10
  st.header('Projects', divider='red')
11
 
12
+ do = Image.open("assets/document.jpg")
13
  mv = Image.open("assets/movie.jpg")
14
  # wp = Image.open("assets/weather.png")
15
  sm = Image.open("assets/stock-market.png")
16
  mu = Image.open("assets/music.jpg")
17
  llm = Image.open("assets/llm.png")
18
 
19
+ with st.container():
20
+ text_column, image_column = st.columns((3,1))
21
+ with text_column:
22
+ st.subheader("Document Classifier", divider="green")
23
+ st.markdown("""
24
+ - Used OCR text and a Random Forest classification model to predict a document's classification
25
+ - Trained on Real World Documents Collection at Kaggle
26
+ """)
27
+ with image_column:
28
+ st.image(do)
29
+
30
  with st.container():
31
  text_column, image_column = st.columns((3,1))
32
  with text_column:
app.py CHANGED
@@ -5,6 +5,7 @@ st.set_page_config(page_title="Chris Capobianco's Profile", page_icon=':rocket:'
5
 
6
  home = st.Page('Home.py', title = 'Home')
7
 
 
8
  movie_recommendation = st.Page('projects/02_Movie_Recommendation.py', title='Movie Recommendation')
9
  # weather_classification = st.Page('projects/04_Weather_Classification.py', title='Weather Classification')
10
  stock_market = st.Page('projects/05_Stock_Market.py', title='Stock Market Forecast')
@@ -17,6 +18,7 @@ pg = st.navigation(
17
  home
18
  ],
19
  'Projects': [
 
20
  movie_recommendation,
21
  # weather_classification,
22
  stock_market,
 
5
 
6
  home = st.Page('Home.py', title = 'Home')
7
 
8
+ document_classification = st.Page('projects/01_Document_Classifier.py', title='Document Classifier')
9
  movie_recommendation = st.Page('projects/02_Movie_Recommendation.py', title='Movie Recommendation')
10
  # weather_classification = st.Page('projects/04_Weather_Classification.py', title='Weather Classification')
11
  stock_market = st.Page('projects/05_Stock_Market.py', title='Stock Market Forecast')
 
18
  home
19
  ],
20
  'Projects': [
21
+ document_classification,
22
  movie_recommendation,
23
  # weather_classification,
24
  stock_market,
assets/document.jpg ADDED
models/autoclassifier.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:85fbfe655117e18cba957ced3fec41d9c243013461682d0f5c296762cda54d9c
3
+ size 5116548
projects/01_Document_Classifier.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import easyocr
3
+ import pickle
4
+ import spacy
5
+ import en_core_web_sm
6
+ import re
7
+ import os
8
+
9
# Function to load the spaCy pipeline used for tokenization
@st.cache_resource
def load_nlp():
    """Load the spaCy ``en_core_web_sm`` pipeline once and reuse it.

    ``st.cache_resource`` is used instead of ``st.cache_data`` because the
    pipeline is a heavyweight model object: ``cache_data`` serializes the
    return value and hands back a freshly deserialized copy on every call,
    whereas ``cache_resource`` keeps a single shared instance, which is what
    the other loaders in this file already do.

    Returns:
        The object returned by ``spacy.load('en_core_web_sm')``.
    """
    return spacy.load('en_core_web_sm')
13
+
14
# Function to initialize the OCR engine
@st.cache_resource
def load_ocr_engine():
    """Create the EasyOCR reader used to extract text from images.

    The reader is built for the ``'en'`` language code only and is cached
    through ``st.cache_resource``.

    Returns:
        The ``easyocr.Reader`` instance.
    """
    languages = ['en']
    reader = easyocr.Reader(languages)
    return reader
18
+
19
# Function to load the model
@st.cache_resource
def load_model():
    """Read the three objects pickled back-to-back in the model file.

    Returns:
        A ``(stopwords, punctuations, model_pipe)`` tuple, in the order the
        objects are read from ``models/autoclassifier.pkl``.
    """
    with open('models/autoclassifier.pkl', 'rb') as fh:
        # Three consecutive pickles live in the same file; the read order
        # determines which object is which.
        loaded = tuple(pickle.load(fh) for _ in range(3))
    return loaded
27
+
28
# Function to tokenize the text
def tokenizer(sentence):
    """Turn a sentence into a list of normalized tokens.

    Every spaCy token is replaced by its lower-cased, stripped lemma, except
    tokens whose lemma is the ``"-PRON-"`` placeholder, which keep their
    lower-cased text. Tokens found in ``stopwords`` or ``punctuations`` are
    dropped.

    Relies on the module-level ``nlp``, ``stopwords`` and ``punctuations``.

    Args:
        sentence: The text to process.

    Returns:
        The list of remaining normalized tokens, in document order.
    """
    kept = []
    for token in nlp(sentence):
        # Pronoun placeholder lemmas carry no information; fall back to the
        # token's own lower-cased text for those.
        if token.lemma_ == "-PRON-":
            normalized = token.lower_
        else:
            normalized = token.lemma_.lower().strip()

        # Skip stop words and punctuation
        if normalized in stopwords or normalized in punctuations:
            continue
        kept.append(normalized)
    return kept
41
+
42
def _extract_words(raw_ocr, min_prob=0.5):
    """Build the classifier input string from raw OCR results.

    Args:
        raw_ocr: Iterable of ``(bbox, text, prob)`` triples.
        min_prob: Detections are kept only when ``prob`` is strictly
            greater than this value.

    Returns:
        The kept text fragments with every run of digits removed, each
        fragment preceded by a single space (so a non-empty result starts
        with a space). Fragments that are empty after digit removal are
        skipped.
    """
    fragments = []
    for _bbox, text, prob in raw_ocr:
        if prob > min_prob:
            # Filter out any digits
            text = re.sub(r'[0-9]+', '', text)
            if text:
                fragments.append(text)
    return ''.join(f' {fragment}' for fragment in fragments)


# Function to process uploaded images
@st.cache_data
def autoclassifier(images):
    """Run OCR on each uploaded image and report its predicted document class.

    Uses the module-level ``ocr_engine`` and ``model_pipe``. One ``st.info``
    message is written per image; nothing is returned.

    Args:
        images: Uploaded files as returned by ``st.file_uploader``; each item
            must expose ``name`` and ``getvalue()``. May be empty.
    """
    # Nothing uploaded yet: there is no work to show a spinner for
    if not images:
        return

    with st.spinner("Processing Images"):
        for image in images:
            # Hand the raw bytes straight to the OCR engine rather than
            # writing them to a file named after the upload. The name is
            # chosen by the client, so writing it into the working directory
            # could overwrite an existing file, collide between concurrent
            # sessions, and leave the file behind if OCR raised before the
            # cleanup step. getvalue() also does not depend on the current
            # read position of the upload.
            raw_ocr = ocr_engine.readtext(image.getvalue())

            # Keep only confident, digit-free text for the model
            words = _extract_words(raw_ocr)

            # Pass filtered OCR string to the model
            doc_type = model_pipe.predict([words])

            # Report filename and document class
            st.info(f"filename: '{image.name}', doc_type: '{doc_type[0]}'")
73
+
74
# Page title
st.header('Document Classifier', divider='green')

# Introductory text, rendered one markdown block per entry, in order
INTRO_PARAGRAPHS = (
    "#### What is OCR?",
    "OCR stands for Optical Character Recognition, and the technology for it has been around for over 30 years.",
    "In this project, we leverage the extraction of the text from an image to classify the document. I am using EasyOCR as the OCR Engine, and I do some pre-processing of the raw OCR text to improve the quality of the words used to classify the documents.",
    "After an investigation I settled on a Random Forest classifier for this project, since it had the best classification accuracy of the different models I investigated.",
    "This project makes use of the [Real World Documents Collections](https://www.kaggle.com/datasets/shaz13/real-world-documents-collections) found at `Kaggle`",
    "*This project is based off the tutorial by Animesh Giri [Intelligent Document Classification](https://www.kaggle.com/code/animeshgiri/intelligent-document-classification)*",
    "*N.B. I created a similar document classifier in my first ML project, but that relied on IBM's Datacap for the OCR Engine. I also used a Support Vector Machine (SVM) classifier library (libsvm) at the time, but it was slow to train. I tried to re-create that document classifier again, using open source tools and modern techniques outlined in the referenced tutorial.*",
)
for paragraph in INTRO_PARAGRAPHS:
    st.markdown(paragraph)
st.divider()

# Load the spaCy tokenizer pipeline (read by tokenizer() as a global)
nlp = load_nlp()

# Initialize the OCR engine (read by autoclassifier() as a global)
ocr_engine = load_ocr_engine()

# Load the stop words, punctuation set and model pipeline
stopwords, punctuations, model_pipe = load_model()

# Fetch uploaded images; several files may be selected at once
images = st.file_uploader(
    "Choose an image to classify",
    type=['png','jpg','jpeg'],
    accept_multiple_files=True
)

# Process and predict document classification
autoclassifier(images)