Spaces:
Sleeping
Sleeping
import os | |
import gradio as gr | |
import tensorflow as tf | |
from tensorflow.keras.preprocessing import image as image_processor | |
import numpy as np | |
from tensorflow.keras.applications.vgg16 import preprocess_input | |
from tensorflow.keras.models import load_model | |
from PIL import Image, ImageDraw, ImageFont | |
from ultralytics import YOLO | |
import cv2 | |
class Config:
    """Filesystem locations and task registries used by the demo app."""

    # Root folder for bundled assets (label font and example images).
    ASSETS_DIR = "./assets"
    # Folder holding the trained model weight files.
    MODELS_DIR = "./models"
    # NOTE: despite its name, this is the path of a single TrueType font file.
    FONT_DIR = "./assets/arial.ttf"

    # Task name (also the UI dropdown label) -> weights file inside MODELS_DIR.
    MODELS = {
        "Calculus and Caries Classification": "classification.h5",
        "Caries Detection": "detection.pt",
        "Dental X-Ray Segmentation": "dental_xray_seg.h5",
    }

    # Task name -> directory containing example input images for that task.
    EXAMPLES = {
        "Calculus and Caries Classification": os.path.join(ASSETS_DIR, "classification"),
        "Caries Detection": os.path.join(ASSETS_DIR, "detection"),
        "Dental X-Ray Segmentation": os.path.join(ASSETS_DIR, "segmentation"),
    }
class ModelManager:
    """Load the models listed in ``Config.MODELS`` and cache them per process."""

    # Loaded models keyed by task name. Loading weights from disk is slow, so
    # each model is loaded once and then reused for every later request.
    _cache = {}

    @staticmethod
    def load_model(model_name: str):
        """Return the model for *model_name*, loading it on first use.

        ``"Caries Detection"`` is loaded as an Ultralytics ``YOLO`` model.
        Every other entry is loaded with Keras ``load_model``; these are the
        ``.h5`` classification and segmentation models.

        Args:
            model_name: A key of ``Config.MODELS``.

        Returns:
            The loaded (and cached) model object.

        Raises:
            KeyError: If *model_name* is not a key of ``Config.MODELS``.
        """
        cached = ModelManager._cache.get(model_name)
        if cached is not None:
            return cached

        model_path = os.path.join(Config.MODELS_DIR, Config.MODELS[model_name])
        if model_name == "Caries Detection":
            model = YOLO(model_path)
        else:
            # Inside this function body, ``load_model`` resolves to the
            # module-level Keras loader, not to this method (class attributes
            # are not in a method's lookup scope).
            model = load_model(model_path)

        ModelManager._cache[model_name] = model
        return model
class ImageProcessor:
    """Run the selected dental model on an image and render its output.

    Results are drawn on a converted RGB copy of the input, so the caller's
    image is never modified.
    """

    # Point size for the text labels drawn onto result images.
    FONT_SIZE = 20

    def process_image(self, image: Image.Image, model_name: str):
        """Dispatch *image* to the handler registered for *model_name*.

        Args:
            image: Input image (PIL) from the UI.
            model_name: One of the keys of ``Config.MODELS``.

        Returns:
            The annotated result image.

        Raises:
            ValueError: If *image* is None or *model_name* is unknown.
        """
        if image is None:
            raise ValueError("No input image was provided.")
        if model_name == "Calculus and Caries Classification":
            return self.classify_image(image, model_name)
        if model_name == "Caries Detection":
            return self.detect_caries(image)
        if model_name == "Dental X-Ray Segmentation":
            return self.segment_dental_xray(image)
        # Fail loudly instead of silently returning None (an empty output).
        raise ValueError(f"Unknown model: {model_name!r}")

    def classify_image(self, image: Image.Image, model_name: str):
        """Classify *image* as 'Calculus' or 'Caries' and stamp the label on it.

        Args:
            image: Input image; converted to RGB before inference.
            model_name: Key of the classification model in ``Config.MODELS``.

        Returns:
            An RGB copy of *image* with "Classified as: <label>" drawn in
            white on a black box in the top-left corner.
        """
        model = ModelManager.load_model(model_name)
        # ``convert`` always returns a new image, so drawing below never touches
        # the caller's image. RGB also guarantees the 3 channels that the VGG16
        # ``preprocess_input`` expects; RGBA or greyscale input would not.
        annotated = image.convert("RGB")
        batch = np.expand_dims(
            image_processor.img_to_array(annotated.resize((224, 224))), axis=0
        )
        scores = model.predict(preprocess_input(batch))
        # Two-way output: column 0 is read as Calculus, column 1 as Caries.
        prediction = "Calculus" if scores[0][0] > scores[0][1] else "Caries"

        draw = ImageDraw.Draw(annotated)
        font = self._load_font()
        text = f"Classified as: {prediction}"
        text_width, text_height = self._text_extent(draw, text, font)
        draw.rectangle([(0, 0), (text_width, text_height)], fill="black")
        draw.text((0, 0), text, fill="white", font=font)
        return annotated

    def detect_caries(self, image: Image.Image):
        """Run the YOLO detector and draw every non-"tooth" detection.

        Args:
            image: Input image; converted to RGB before inference.

        Returns:
            An RGB copy of *image*. Each detection gets a red box and a
            "<class>: <confidence>" label.
        """
        model = ModelManager.load_model("Caries Detection")
        annotated = image.convert("RGB")
        result = model.predict(annotated)[0]

        draw = ImageDraw.Draw(annotated)
        font = self._load_font()
        for box in result.boxes:
            # ``cls`` holds a float; ``names`` is looked up by integer class id.
            class_id = int(box.cls[0].item())
            class_name = result.names[class_id]
            # Teeth are detected as well but are not findings, so skip them.
            if class_name.lower() == "tooth":
                continue

            x1, y1, x2, y2 = [round(v) for v in box.xyxy[0].tolist()]
            prob = round(box.conf[0].item(), 2)
            label = f"{class_name}: {prob}"
            draw.rectangle([x1, y1, x2, y2], outline="red", width=2)

            text_width, text_height = self._text_extent(draw, label, font)
            # Put the label above the box. If the box touches the top edge,
            # put it just inside the box so it is not clipped off the image.
            label_y = y1 - text_height if y1 >= text_height else y1
            draw.rectangle(
                [(x1, label_y), (x1 + text_width, label_y + text_height)], fill="red"
            )
            draw.text((x1, label_y), label, fill="white", font=font)
        return annotated

    def segment_dental_xray(self, image: Image.Image):
        """Run the segmentation model on an X-ray and outline the predicted regions.

        The image is converted to greyscale, resized to 512x512, scaled to
        [0, 1] and passed to the model. The prediction is resized back to the
        input size, binarised with Otsu's threshold and cleaned up with
        morphology. Its contours are then drawn on the input.

        Args:
            image: Input X-ray image.

        Returns:
            An RGB image the size of the input, with the mask contours drawn
            in red (2 px).
        """
        model = ModelManager.load_model("Dental X-Ray Segmentation")
        # np.array (not asarray) yields a writable copy; drawContours draws in place.
        rgb = np.array(image.convert("RGB"))
        height, width = rgb.shape[:2]

        gray = self.convert_one_channel(rgb)
        gray = cv2.resize(gray, (512, 512), interpolation=cv2.INTER_LANCZOS4)
        batch = np.reshape(np.float32(gray / 255), (1, 512, 512, 1))
        predicted = model.predict(batch)[0]
        predicted = cv2.resize(predicted, (width, height), interpolation=cv2.INTER_LANCZOS4)

        # Lanczos resampling can overshoot [0, 1]. Clip before the uint8 cast
        # so out-of-range values saturate instead of wrapping around.
        mask = np.uint8(np.clip(predicted, 0.0, 1.0) * 255)
        _, mask = cv2.threshold(mask, thresh=0, maxval=255, type=cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        kernel = np.ones((5, 5), dtype=np.float32)
        # Opening removes small specks; closing fills small holes.
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel, iterations=1)
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=1)

        # OpenCV 3.x returns (image, contours, hierarchy); 4.x returns
        # (contours, hierarchy). Index [-2] is the contour list in both.
        contours = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
        cv2.drawContours(rgb, contours, -1, (255, 0, 0), 2)
        return Image.fromarray(rgb)

    def convert_one_channel(self, img):
        """Return a single-channel greyscale version of *img*.

        2-D arrays are returned unchanged. Multi-channel arrays are expected to
        be RGB(A), as produced by NumPy from a PIL image, so ``COLOR_RGB2GRAY``
        is used rather than the BGR variant (which would swap the red and blue
        weights).
        """
        if len(img.shape) > 2:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        return img

    def convert_rgb(self, img):
        """Return *img* as 3-channel RGB; 2-D greyscale input is expanded."""
        if len(img.shape) == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        return img

    @classmethod
    def _load_font(cls):
        """Load the label font, or Pillow's built-in font if the TrueType file is unusable."""
        try:
            return ImageFont.truetype(Config.FONT_DIR, cls.FONT_SIZE)
        except OSError:
            # A missing or unreadable font file should not make inference fail.
            return ImageFont.load_default()

    @staticmethod
    def _text_extent(draw, text, font):
        """Return the (width, height) extent of *text* drawn at the origin.

        ``ImageDraw.textsize`` was removed in Pillow 10. The right/bottom
        edges of ``textbbox`` anchored at (0, 0) give the same extent, so use
        it when available and fall back to ``textsize`` on old Pillow versions.
        """
        if hasattr(draw, "textbbox"):
            _, _, right, bottom = draw.textbbox((0, 0), text, font=font)
            return right, bottom
        return draw.textsize(text, font=font)
class GradioInterface:
    """Gradio front end: input image, model picker, per-model examples, Run button."""

    # Stylesheet embedded in the page header (dark palette and layout helpers).
    _APP_STYLES = """
    <style>
    /* Global Styles */
    body, #root {
        font-family: Helvetica, Arial, sans-serif;
        background-color: #1a1a1a;
        color: #fafafa;
    }
    /* Header Styles */
    .app-header {
        background: linear-gradient(45deg, #1a1a1a 0%, #333333 100%);
        padding: 24px;
        border-radius: 8px;
        margin-bottom: 24px;
        text-align: center;
    }
    .app-title {
        font-size: 48px;
        margin: 0;
        color: #fafafa;
    }
    .app-subtitle {
        font-size: 24px;
        margin: 8px 0 16px;
        color: #fafafa;
    }
    .app-description {
        font-size: 16px;
        line-height: 1.6;
        opacity: 0.8;
        margin-bottom: 24px;
    }
    /* Button Styles */
    .publication-links {
        display: flex;
        justify-content: center;
        flex-wrap: wrap;
        gap: 8px;
        margin-bottom: 16px;
    }
    .publication-link {
        display: inline-flex;
        align-items: center;
        padding: 8px 16px;
        background-color: #333;
        color: #fff !important;
        text-decoration: none !important;
        border-radius: 20px;
        font-size: 14px;
        transition: background-color 0.3s;
    }
    .publication-link:hover {
        background-color: #555;
    }
    .publication-link i {
        margin-right: 8px;
    }
    /* Content Styles */
    .content-container {
        background-color: #2a2a2a;
        border-radius: 8px;
        padding: 24px;
        margin-bottom: 24px;
    }
    /* Image Styles */
    .image-preview img {
        max-width: 512px;
        max-height: 512px;
        margin: 0 auto;
        border-radius: 4px;
        display: block;
        object-fit: contain;
    }
    /* Control Styles */
    .control-panel {
        background-color: #333;
        padding: 16px;
        border-radius: 8px;
        margin-top: 16px;
    }
    /* Gradio Component Overrides */
    .gr-button {
        background-color: #4a4a4a;
        color: #fff;
        border: none;
        border-radius: 4px;
        padding: 8px 16px;
        cursor: pointer;
        transition: background-color 0.3s;
    }
    .gr-button:hover {
        background-color: #5a5a5a;
    }
    .gr-input, .gr-dropdown {
        background-color: #3a3a3a;
        color: #fff;
        border: 1px solid #4a4a4a;
        border-radius: 4px;
        padding: 8px;
    }
    .gr-form {
        background-color: transparent;
    }
    .gr-panel {
        border: none;
        background-color: transparent;
    }
    /* Override any conflicting styles from Bulma */
    .button.is-normal.is-rounded.is-dark {
        color: #fff !important;
        text-decoration: none !important;
    }
    </style>
    """

    # Passed as ``gr.Blocks(js=...)``. It reloads the page with
    # ``?__theme=dark`` unless that query parameter is already "dark",
    # which forces Gradio's dark theme.
    _FORCE_DARK_JS = """
    function refresh() {
        const url = new URL(window.location);
        if (url.searchParams.get('__theme') !== 'dark') {
            url.searchParams.set('__theme', 'dark');
            window.location.href = url.href;
        }
    }
    """

    # Caption of each model's example gallery, in display order.
    _EXAMPLE_LABELS = {
        "Calculus and Caries Classification": "Classification Examples",
        "Caries Detection": "Caries Detection Examples",
        "Dental X-Ray Segmentation": "Segmentation Examples",
    }

    def __init__(self):
        self.image_processor = ImageProcessor()
        self.preloaded_examples = self.preload_examples()

    def preload_examples(self):
        """Return ``{model_name: [example image paths]}`` for ``Config.EXAMPLES``.

        File names are sorted so the gallery order is stable (``os.listdir``
        order is arbitrary). Hidden entries such as ``.DS_Store`` and
        sub-directories are skipped, because they cannot be loaded as images.
        """
        preloaded = {}
        for model_name, example_dir in Config.EXAMPLES.items():
            paths = [os.path.join(example_dir, name) for name in sorted(os.listdir(example_dir))
                     if not name.startswith(".")]
            preloaded[model_name] = [path for path in paths if os.path.isfile(path)]
        return preloaded

    @classmethod
    def _build_header_html(cls):
        """Return the page header: external stylesheets, app styles and title block."""
        # Bulma is loaded because the app styles override some of its button classes.
        return f"""
        <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bulma@0.9.4/css/bulma.min.css">
        <link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.15.4/css/all.css">
        {cls._APP_STYLES}
        <div class="app-header">
            <h1 class="app-title">AI Dentistry Application</h1>
            <h2 class="app-subtitle">Steven Fernandes, Ph.D.</h2>
            <p class="app-description">
                This application demonstrates the use of AI in dentistry for calculus and caries classification, caries detection, and dental x-ray segmentation.
            </p>
        </div>
        """

    def create_interface(self):
        """Build and return the (not yet launched) ``gr.Blocks`` demo."""

        def process_image(image, selected_model):
            # Without this guard, clicking Run with no image fails deep inside PIL.
            if image is None:
                raise gr.Error("Please upload or select an input image first.")
            return self.image_processor.process_image(image, selected_model)

        with gr.Blocks(js=self._FORCE_DARK_JS, theme=gr.themes.Default()) as demo:
            gr.HTML(self._build_header_html())
            with gr.Row(elem_classes="content-container"):
                with gr.Column():
                    input_image = gr.Image(label="Input Image", type="pil", format="png", elem_classes="image-preview")
                    with gr.Row(elem_classes="control-panel"):
                        model_name = gr.Dropdown(
                            label="Model",
                            choices=list(Config.MODELS.keys()),
                            value="Calculus and Caries Classification",
                        )
                    example_sets = {
                        name: gr.Examples(
                            label=label,
                            inputs=input_image,
                            examples=self.preloaded_examples[name],
                        )
                        for name, label in self._EXAMPLE_LABELS.items()
                    }
                with gr.Column():
                    result = gr.Image(label="Result", elem_classes="image-preview")
                    run_button = gr.Button("Run", elem_classes="gr-button")

            example_datasets = [examples.dataset for examples in example_sets.values()]

            def update_examples(selected_model):
                """Show only the example gallery that belongs to *selected_model*."""
                # Gradio requires exactly one return value per output component,
                # in the same order as ``example_datasets``.
                return [gr.Dataset(visible=name == selected_model) for name in example_sets]

            model_name.change(
                fn=update_examples,
                inputs=model_name,
                outputs=example_datasets,
            )
            # Apply the same filtering on page load so only the default model's
            # examples are visible initially.
            demo.load(
                fn=update_examples,
                inputs=model_name,
                outputs=example_datasets,
            )
            run_button.click(
                fn=process_image,
                inputs=[input_image, model_name],
                outputs=result,
            )
        return demo
def main():
    """Script entry point: build the Gradio UI and serve it.

    ``debug=True`` is passed to ``launch``; per Gradio's documentation this
    keeps the call blocking and surfaces errors in the console.
    """
    app = GradioInterface().create_interface()
    app.launch(debug=True)


if __name__ == "__main__":
    # Only start the server when executed directly, not on import.
    main()