AI-in-Dentistry / app.py
AI-RESEARCHER-2024's picture
Update app.py
7d05b33 verified
raw
history blame
12.5 kB
import functools
import os

import gradio as gr
import tensorflow as tf
from tensorflow.keras.preprocessing import image as image_processor
import numpy as np
from tensorflow.keras.applications.vgg16 import preprocess_input
from tensorflow.keras.models import load_model
from PIL import Image, ImageDraw, ImageFont
from ultralytics import YOLO
import cv2
class Config:
    """Static application settings: filesystem locations and per-task lookup tables."""

    # Folder containing the weight files named in MODELS.
    MODELS_DIR = './models'
    # Folder containing the bundled font and the example-image sub-folders.
    ASSETS_DIR = './assets'
    # NOTE: despite its name, this is the path of the TrueType font *file*.
    FONT_DIR = './assets/arial.ttf'

    # Task label -> weights file name inside MODELS_DIR.
    # Insertion order matters: the UI builds its dropdown choices from these keys.
    MODELS = {
        "Calculus and Caries Classification": "classification.h5",
        "Caries Detection": "detection.pt",
        "Dental X-Ray Segmentation": "dental_xray_seg.h5",
    }

    # Task label -> folder holding example inputs for that task.
    EXAMPLES = {
        "Calculus and Caries Classification": os.path.join(ASSETS_DIR, 'classification'),
        "Caries Detection": os.path.join(ASSETS_DIR, 'detection'),
        "Dental X-Ray Segmentation": os.path.join(ASSETS_DIR, 'segmentation'),
    }
class ModelManager:
    """Resolve a task label from ``Config.MODELS`` to a loaded model object."""

    @staticmethod
    @functools.lru_cache(maxsize=None)
    def load_model(model_name: str):
        """Load and return the model registered for *model_name*.

        The result is memoized per task label, so the weights are read from
        disk once instead of on every request. The cache is unbounded, but
        its key space is limited to the keys of ``Config.MODELS``.
        A load that raises is not cached.

        Args:
            model_name: A key of ``Config.MODELS``.

        Returns:
            An ``ultralytics.YOLO`` instance for "Caries Detection". For every
            other task, the object returned by Keras ``load_model``.

        Raises:
            KeyError: If *model_name* is not a key of ``Config.MODELS``.
        """
        model_path = os.path.join(Config.MODELS_DIR, Config.MODELS[model_name])
        if model_name == "Caries Detection":
            return YOLO(model_path)
        # Classification and segmentation weights are both loaded through Keras.
        return load_model(model_path)
class ImageProcessor:
    """Run one of the configured dental models on a PIL image and render its output."""

    def process_image(self, image: Image.Image, model_name: str):
        """Dispatch *image* to the handler for *model_name*.

        Args:
            image: Input picture, e.g. from the Gradio ``Image`` component.
            model_name: One of the keys of ``Config.MODELS``.

        Returns:
            A new ``PIL.Image.Image`` with the model output drawn on it.

        Raises:
            ValueError: If *image* is None or *model_name* is not a known task.
        """
        if image is None:
            raise ValueError("No input image provided.")
        if model_name == "Calculus and Caries Classification":
            return self.classify_image(image, model_name)
        if model_name == "Caries Detection":
            return self.detect_caries(image)
        if model_name == "Dental X-Ray Segmentation":
            return self.segment_dental_xray(image)
        # Fail loudly rather than returning None (an empty result pane).
        raise ValueError(f"Unknown model: {model_name!r}")

    @staticmethod
    def _text_size(draw: ImageDraw.ImageDraw, text: str, font):
        """Return the (width, height) extent of *text* drawn at the origin.

        ``ImageDraw.textsize`` was removed in Pillow 10, so the extent comes
        from ``textbbox`` instead. The right/bottom edges are measured from
        the anchor point, so a rectangle anchored at the same point fully
        covers the glyphs.
        """
        _, _, right, bottom = draw.textbbox((0, 0), text, font=font)
        return right, bottom

    def classify_image(self, image: Image.Image, model_name: str):
        """Classify *image* as calculus or caries and stamp the label top-left.

        Returns:
            A new RGB image. The caller's *image* is not modified.
        """
        model = ModelManager.load_model(model_name)
        # Work on an RGB copy. This guarantees 3 channels for the VGG16
        # preprocessing and avoids drawing onto the caller's image.
        annotated = image.convert("RGB")
        pixels = image_processor.img_to_array(annotated.resize((224, 224)))
        batch = preprocess_input(np.expand_dims(pixels, axis=0))
        scores = model.predict(batch)
        # assumes column 0 = Calculus, column 1 = Caries (mirrors original labels) — TODO confirm
        prediction = 'Calculus' if scores[0][0] > scores[0][1] else 'Caries'

        draw = ImageDraw.Draw(annotated)
        font = ImageFont.truetype(Config.FONT_DIR, 20)
        text = f"Classified as: {prediction}"
        text_width, text_height = self._text_size(draw, text, font)
        draw.rectangle([(0, 0), (text_width, text_height)], fill="black")
        draw.text((0, 0), text, fill="white", font=font)
        return annotated

    def detect_caries(self, image: Image.Image):
        """Run the YOLO detector and draw every box with its class name and confidence.

        Returns:
            A new RGB image. The caller's *image* is not modified.
        """
        model = ModelManager.load_model("Caries Detection")
        result = model.predict(image)[0]
        annotated = image.convert("RGB")
        draw = ImageDraw.Draw(annotated)
        font = ImageFont.truetype(Config.FONT_DIR, 20)
        for box in result.boxes:
            x1, y1, x2, y2 = [round(v) for v in box.xyxy[0].tolist()]
            # .item() yields a float; class-name lookups are keyed by int.
            class_id = int(box.cls[0].item())
            prob = round(box.conf[0].item(), 2)
            label = f"{result.names[class_id]}: {prob}"
            draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
            text_width, text_height = self._text_size(draw, label, font)
            # Place the label above the box, but clamp it to the image for boxes at the top edge.
            label_top = max(y1 - text_height, 0)
            draw.rectangle([(x1, label_top), (x1 + text_width, label_top + text_height)], fill="red")
            draw.text((x1, label_top), label, fill="white", font=font)
        return annotated

    def segment_dental_xray(self, image: Image.Image):
        """Segment the X-ray and outline every predicted region in red.

        The pipeline has these steps:
        - Convert to grayscale and resize to a 512x512 float input in [0, 1].
        - Run the model and resize its output back to the input size.
        - Binarise with Otsu, then clean up with a morphological open/close.
        - Draw all contours on an RGB copy of the input.
        """
        model = ModelManager.load_model("Dental X-Ray Segmentation")
        # Normalise the mode first so palette/alpha images don't reach the
        # model as palette indices or 4-channel arrays.
        rgb = np.asarray(image.convert("RGB"))
        # PIL arrays are RGB-ordered, so use the RGB luma weights.
        gray = self.convert_one_channel(rgb, color_order="RGB")
        net_in = cv2.resize(gray, (512, 512), interpolation=cv2.INTER_LANCZOS4)
        net_in = np.reshape(np.float32(net_in / 255), (1, 512, 512, 1))
        prob_map = model.predict(net_in)[0]
        # cv2 sizes are (width, height).
        prob_map = cv2.resize(prob_map, (rgb.shape[1], rgb.shape[0]), interpolation=cv2.INTER_LANCZOS4)
        # Lanczos ringing can push values slightly outside [0, 1]. Unclipped,
        # the uint8 cast may wrap them around and create false foreground.
        mask = np.uint8(np.clip(prob_map, 0.0, 1.0) * 255)
        _, mask = cv2.threshold(mask, thresh=0, maxval=255, type=cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        kernel = np.ones((5, 5), dtype=np.float32)
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel, iterations=1)
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=1)
        contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        # np.asarray on a PIL image is read-only; draw on a writable copy.
        canvas = rgb.copy()
        cv2.drawContours(canvas, contours, -1, (255, 0, 0), 3)
        return Image.fromarray(canvas)

    def convert_one_channel(self, img, color_order: str = "BGR"):
        """Return a single-channel version of *img*.

        Args:
            img: A 2-D array, returned unchanged, or a 3-D colour array.
            color_order: Channel order of a colour *img*. Use "BGR" (the
                OpenCV convention and the default) or "RGB" (e.g. arrays
                obtained from PIL).
        """
        if img.ndim <= 2:
            return img
        code = cv2.COLOR_RGB2GRAY if color_order == "RGB" else cv2.COLOR_BGR2GRAY
        return cv2.cvtColor(img, code)

    def convert_rgb(self, img):
        """Expand a 2-D grayscale array to 3-channel RGB; other arrays pass through unchanged."""
        if img.ndim == 2:
            return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        return img
class GradioInterface:
    """Build the Gradio Blocks UI around a shared ImageProcessor."""

    def __init__(self):
        # One processor serves every request; example paths are resolved once at start-up.
        self.image_processor = ImageProcessor()
        self.preloaded_examples = self.preload_examples()

    def preload_examples(self):
        """Collect example file paths for every task in ``Config.EXAMPLES``.

        Returns:
            A dict mapping each task label to a sorted list of file paths.
            Hidden files (e.g. ``.DS_Store``, ``.gitkeep``) and sub-directories
            are skipped so they are never offered as image examples.

        Raises:
            FileNotFoundError: If an example directory does not exist.
        """
        preloaded = {}
        for model_name, example_dir in Config.EXAMPLES.items():
            preloaded[model_name] = [
                os.path.join(example_dir, name)
                for name in sorted(os.listdir(example_dir))
                if not name.startswith(".") and os.path.isfile(os.path.join(example_dir, name))
            ]
        return preloaded

    def create_interface(self):
        """Assemble and return the ``gr.Blocks`` app (not yet launched)."""
        app_styles = """
<style>
/* Global Styles */
body, #root {
font-family: Helvetica, Arial, sans-serif;
background-color: #1a1a1a;
color: #fafafa;
}
/* Header Styles */
.app-header {
background: linear-gradient(45deg, #1a1a1a 0%, #333333 100%);
padding: 24px;
border-radius: 8px;
margin-bottom: 24px;
text-align: center;
}
.app-title {
font-size: 48px;
margin: 0;
color: #fafafa;
}
.app-subtitle {
font-size: 24px;
margin: 8px 0 16px;
color: #fafafa;
}
.app-description {
font-size: 16px;
line-height: 1.6;
opacity: 0.8;
margin-bottom: 24px;
}
/* Button Styles */
.publication-links {
display: flex;
justify-content: center;
flex-wrap: wrap;
gap: 8px;
margin-bottom: 16px;
}
.publication-link {
display: inline-flex;
align-items: center;
padding: 8px 16px;
background-color: #333;
color: #fff !important;
text-decoration: none !important;
border-radius: 20px;
font-size: 14px;
transition: background-color 0.3s;
}
.publication-link:hover {
background-color: #555;
}
.publication-link i {
margin-right: 8px;
}
/* Content Styles */
.content-container {
background-color: #2a2a2a;
border-radius: 8px;
padding: 24px;
margin-bottom: 24px;
}
/* Image Styles */
.image-preview img {
max-width: 512px;
max-height: 512px;
margin: 0 auto;
border-radius: 4px;
display: block;
object-fit: contain;
}
/* Control Styles */
.control-panel {
background-color: #333;
padding: 16px;
border-radius: 8px;
margin-top: 16px;
}
/* Gradio Component Overrides */
.gr-button {
background-color: #4a4a4a;
color: #fff;
border: none;
border-radius: 4px;
padding: 8px 16px;
cursor: pointer;
transition: background-color 0.3s;
}
.gr-button:hover {
background-color: #5a5a5a;
}
.gr-input, .gr-dropdown {
background-color: #3a3a3a;
color: #fff;
border: 1px solid #4a4a4a;
border-radius: 4px;
padding: 8px;
}
.gr-form {
background-color: transparent;
}
.gr-panel {
border: none;
background-color: transparent;
}
/* Override any conflicting styles from Bulma */
.button.is-normal.is-rounded.is-dark {
color: #fff !important;
text-decoration: none !important;
}
</style>
"""
        # NOTE(review): the Bulma stylesheet URL was corrupted in this file; it is
        # pinned here to 0.9.4 — confirm that is the release these styles target.
        header_html = f"""
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bulma@0.9.4/css/bulma.min.css">
<link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.15.4/css/all.css">
{app_styles}
<div class="app-header">
<h1 class="app-title">AI Dentistry Application</h1>
<h2 class="app-subtitle">Model Selection and Image Processing</h2>
<p class="app-description">
This application allows you to select different models for dental image processing tasks.
</p>
</div>
"""
        # Client-side hook: if the URL lacks ?__theme=dark, add it and reload the page.
        js_func = """
function refresh() {
const url = new URL(window.location);
if (url.searchParams.get('__theme') !== 'dark') {
url.searchParams.set('__theme', 'dark');
window.location.href = url.href;
}
}
"""

        # Task label of each example gallery. The order must match the `outputs`
        # lists of the visibility events wired below.
        example_models = (
            "Calculus and Caries Classification",
            "Caries Detection",
            "Dental X-Ray Segmentation",
        )

        def process_image(image, model_name):
            """Run-button callback: validate the input, then delegate to ImageProcessor."""
            if image is None:
                raise gr.Error("Please upload an image or select an example first.")
            return self.image_processor.process_image(image, model_name)

        def update_examples(model_name):
            """Show only the example gallery belonging to the selected model.

            Gradio expects one value per output component, so this returns one
            visibility update per gallery.
            """
            return [gr.update(visible=name == model_name) for name in example_models]

        with gr.Blocks(js=js_func, theme=gr.themes.Default()) as demo:
            gr.HTML(header_html)
            with gr.Row(elem_classes="content-container"):
                with gr.Column():
                    input_image = gr.Image(label="Input Image", type="pil", format="png", elem_classes="image-preview")
                    with gr.Row(elem_classes="control-panel"):
                        model_name = gr.Dropdown(
                            label="Model",
                            choices=list(Config.MODELS.keys()),
                            value="Calculus and Caries Classification",
                        )
                    examples_classification = gr.Examples(
                        label="Classification Examples",
                        inputs=input_image,
                        examples=self.preloaded_examples["Calculus and Caries Classification"],
                    )
                    examples_detection = gr.Examples(
                        label="Caries Detection Examples",
                        inputs=input_image,
                        examples=self.preloaded_examples["Caries Detection"],
                    )
                    examples_segmentation = gr.Examples(
                        label="Segmentation Examples",
                        inputs=input_image,
                        examples=self.preloaded_examples["Dental X-Ray Segmentation"],
                    )
                with gr.Column():
                    result = gr.Image(label="Result", elem_classes="image-preview")
                    run_button = gr.Button("Run", elem_classes="gr-button")

            example_datasets = [
                examples_classification.dataset,
                examples_detection.dataset,
                examples_segmentation.dataset,
            ]
            model_name.change(
                fn=update_examples,
                inputs=model_name,
                outputs=example_datasets,
            )
            # Apply the same filtering on page load so the initial view matches the default model.
            demo.load(
                fn=update_examples,
                inputs=model_name,
                outputs=example_datasets,
            )
            run_button.click(
                fn=process_image,
                inputs=[input_image, model_name],
                outputs=result,
            )
        return demo
def main():
    """Script entry point: build the Gradio app and serve it with debug output enabled."""
    app = GradioInterface().create_interface()
    app.launch(debug=True)


if __name__ == "__main__":
    main()