Spaces:
Running
Running
import gradio as gr | |
import numpy as np | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
import torch | |
from transformers import AutoImageProcessor, AutoModelForImageClassification | |
from PIL import Image | |
from difflib import get_close_matches | |
from typing import Optional, Dict, Any | |
import json | |
import io | |
from datasets import load_dataset # Import the datasets library | |
# -------------------------------------------------
# Configuration
# -------------------------------------------------
# Activity profile of each supported insulin type. All three values are hours:
#   onset     - time before the insulin starts acting
#   duration  - total time the insulin remains active
#   peak_time - time at which the modelled activity is highest
INSULIN_TYPES = dict(
    [
        ("Rapid-Acting", dict(onset=0.25, duration=4, peak_time=1.0)),
        ("Long-Acting", dict(onset=2, duration=24, peak_time=8)),
    ]
)
# Default basal rates (units/hour) keyed by "HH:MM-HH:MM" time-of-day window.
DEFAULT_BASAL_RATES = dict(
    zip(
        ("00:00-06:00", "06:00-12:00", "12:00-18:00", "18:00-24:00"),
        (0.8, 1.0, 0.9, 0.7),
    )
)
# ------------------------------------------------- | |
# Load Food Data from Hugging Face Dataset | |
# ------------------------------------------------- | |
def load_food_data(dataset_name="Anupam251272/food_nutrition"): | |
try: | |
dataset = load_dataset(dataset_name) | |
food_data = dataset['train'].to_pandas() | |
# Normalize column names to lowercase and remove spaces | |
food_data.columns = [col.lower().replace(' ', '') for col in food_data.columns] | |
# Remove unnamed columns | |
food_data = food_data.loc[:, ~food_data.columns.str.contains('^unnamed')] # This line removes the columns | |
# Normalize food_name column to lowercase: Crucial for matching | |
if 'food_name' in food_data.columns: | |
food_data['food_name'] = food_data['food_name'].str.lower() | |
print("Unique Food Names in Dataset:") # ADDED | |
print(food_data['food_name'].unique()) # ADDED | |
else: | |
print("Warning: 'food_name' column not found in dataset.") | |
food_data = pd.DataFrame({ | |
'food_category': ['starch'], | |
'food_subcategory': ['bread'], | |
'food_name': ['white bread'], # lowercase default | |
'serving_description': ['servingsize'], | |
'serving_amount': [29], | |
'serving_unit': ['g'], | |
'carbohydrate_grams': [15], | |
'notes': ['default'] | |
}) | |
#Print first 5 rows to check columns and values | |
print("First 5 rows of loaded data from Hugging Face Dataset:") | |
print(food_data.head()) | |
return food_data | |
except Exception as e: | |
print(f"Error loading Hugging Face Dataset: {e}") | |
# Provide minimal default data in case of error | |
food_data = pd.DataFrame({ | |
'food_category': ['starch'], | |
'food_subcategory': ['bread'], | |
'food_name': ['white bread'], # lowercase default | |
'serving_description': ['servingsize'], | |
'serving_amount': [29], | |
'serving_unit': ['g'], | |
'carbohydrate_grams': [15], | |
'notes': ['default'] | |
}) | |
return food_data | |
# -------------------------------------------------
# Load Food Classification Model
# -------------------------------------------------
# The model runs on CPU only. float32 is used deliberately. Half precision on
# CPU is slow, and some CPU kernels lack float16 support on older PyTorch
# releases. The image processor's pixel values are normally float32, so a
# float16 model would also see a dtype mismatch.
try:
    processor = AutoImageProcessor.from_pretrained("therealcyberlord/vit-indian-food")
    model = AutoModelForImageClassification.from_pretrained(
        "therealcyberlord/vit-indian-food",
        torch_dtype=torch.float32,
        device_map="cpu",
        low_cpu_mem_usage=True,  # reduce peak memory while loading weights
    )
    model_loaded = True  # checked by classify_food before running inference
except Exception as e:
    print(f"Model Load Error: {e}")
    model_loaded = False
    processor = None
    model = None
def classify_food(image):
    """Classify a food photo with the module-level image-classification model.

    Args:
        image: a PIL image (the Gradio input uses ``type="pil"``) or a numpy
            array, which is converted to a PIL image first. ``None`` (nothing
            uploaded) is accepted.

    Returns:
        The predicted label, lower-cased. Returns the sentinel ``"Unknown"``
        when the model failed to load, no image was given, or inference raised.
    """
    print("classify_food function called")
    try:
        if not model_loaded:
            print("Model not loaded, returning 'Unknown'")
            return "Unknown"
        if image is None:
            print("No image provided, returning 'Unknown'")
            return "Unknown"
        print(f"Image type: {type(image)}")
        if isinstance(image, np.ndarray):
            print("Image is a numpy array, converting to PIL Image")
            image = Image.fromarray(image)
        print(f"Image mode: {image.mode}")
        # Normalise to 3-channel RGB. RGBA, greyscale and palette images
        # would not match the processor's per-channel normalisation.
        if image.mode != "RGB":
            image = image.convert("RGB")
        inputs = processor(images=image, return_tensors="pt")
        # Match the model's parameter dtype so a half-precision model never
        # receives float32 pixels (or vice versa).
        inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)
        print(f"Processed pixel_values shape: {tuple(inputs['pixel_values'].shape)}")
        with torch.no_grad():
            outputs = model(**inputs)
        predicted_idx = torch.argmax(outputs.logits, dim=-1).item()
        food_name = model.config.id2label.get(predicted_idx, "Unknown Food")
        print(f"Predicted food name: {food_name}")
        return food_name.lower()
    except Exception as e:  # best-effort: any failure degrades to "Unknown"
        print(f"Classify food error: {e}")
        return "Unknown"
# -------------------------------------------------
# Local food lookup (the former USDA API integration was removed for the
# self-contained Hugging Face Spaces deployment)
# -------------------------------------------------
def _optional_field(row, column, default):
    """Return ``row[column]``, or *default* when the column is absent or NaN."""
    if column in row and not pd.isna(row[column]):
        return row[column]
    return default


def get_food_nutrition(food_name: str, food_data, portion_size: float = 1.0) -> Optional[Dict[str, Any]]:
    """Look up carbohydrate information for a food in the local table.

    The single closest fuzzy match (``difflib.get_close_matches``, cutoff 0.5)
    against the ``food_name`` column is used. The comparison is case-insensitive.

    Args:
        food_name: name to search for.
        food_data: DataFrame with a ``food_name`` column. The columns
            ``carbohydrate_grams``, ``serving_amount``, ``serving_unit``,
            ``food_category``, ``food_subcategory`` and ``notes`` are used when
            present.
        portion_size: multiplier applied to the per-serving carbohydrate grams.

    Returns:
        A dict with ``matched_food``, ``category``, ``subcategory``,
        ``base_carbs``, ``adjusted_carbs``, ``serving_size``,
        ``portion_multiplier`` and ``notes``. Returns ``None`` when nothing
        matches or an error occurs. Missing or invalid carbohydrate values count
        as 0.0 g.
    """
    try:
        food_name_lower = food_name.lower()
        # Lower-case candidate names. Non-string entries (NaN/None) become None
        # and are excluded, because difflib calls len() on every candidate and
        # would abort the whole lookup on them.
        lowered_names = food_data['food_name'].map(
            lambda name: name.lower() if isinstance(name, str) else None
        )
        candidates = [name for name in lowered_names if name is not None]
        print(f"Searching for: {food_name_lower}")
        matches = get_close_matches(food_name_lower, candidates, n=1, cutoff=0.5)
        if matches:
            matched_row = food_data[lowered_names == matches[0]]
            if not matched_row.empty:
                row = matched_row.iloc[0]
                print(f"Matched row from CSV: {row}")

                carb_col = 'carbohydrate_grams'
                raw_carbs = _optional_field(row, carb_col, None)
                if raw_carbs is None:
                    print(f"Warning: '{carb_col}' is missing or NaN in CSV")
                    base_carbs = 0.0
                else:
                    try:
                        base_carbs = float(raw_carbs)
                    except (TypeError, ValueError):
                        print(f"Warning: '{carb_col}' is not a valid number in CSV")
                        base_carbs = 0.0

                amount = _optional_field(row, 'serving_amount', None)
                unit = _optional_field(row, 'serving_unit', None)
                if amount is None or unit is None:
                    serving_size = "Unknown"
                    print("Warning: 'serving_amount' or 'serving_unit' is missing in CSV")
                else:
                    serving_size = f"{amount} {unit}"

                return {
                    'matched_food': row['food_name'],
                    'category': _optional_field(row, 'food_category', 'Unknown'),
                    'subcategory': _optional_field(row, 'food_subcategory', 'Unknown'),
                    'base_carbs': base_carbs,
                    'adjusted_carbs': base_carbs * portion_size,
                    'serving_size': serving_size,
                    'portion_multiplier': portion_size,
                    'notes': _optional_field(row, 'notes', ''),
                }
        print(f"No match found in CSV for {food_name}")
        print(f"No nutrition information found for {food_name} in the local database.")
        return None
    except Exception as e:  # best-effort lookup: callers treat None as "not found"
        print(f"Error in get_food_nutrition: {e}")
        return None
# -------------------------------------------------
# Insulin and Glucose Calculations
# -------------------------------------------------
def _clock_to_hours(clock):
    """Convert "HH" or "HH:MM" to fractional hours (e.g. "06:30" -> 6.5)."""
    hours, _, minutes = str(clock).strip().partition(":")
    return int(hours) + (int(minutes) / 60 if minutes else 0)


def get_basal_rate(current_time_hour, basal_rates):
    """Return the basal rate whose time window contains *current_time_hour*.

    Args:
        current_time_hour: hour of day to look up (may be fractional).
        basal_rates: mapping of "start-end" windows to rates. Start and end are
            either "HH:MM" (e.g. "06:00-12:00") or bare hours (e.g. "6-12").
            Windows are half-open, [start, end). A window ending at 24 also
            covers any hour at or after its start. A window whose start is later
            than its end wraps past midnight (e.g. "22:00-06:00").

    Returns:
        The rate of the first matching window, or 0 when none matches.
        Malformed windows are reported and skipped.
    """
    for interval, rate in basal_rates.items():
        try:
            start_text, end_text = interval.split("-")
            start_hour = _clock_to_hours(start_text)
            end_hour = _clock_to_hours(end_text)
        except (AttributeError, ValueError):
            print(f"Warning: Invalid interval format: {interval}. Skipping.")
            continue
        if start_hour <= end_hour:
            in_window = (start_hour <= current_time_hour < end_hour
                         or (end_hour == 24 and current_time_hour >= start_hour))
        else:  # window wraps past midnight
            in_window = current_time_hour >= start_hour or current_time_hour < end_hour
        if in_window:
            return rate
    return 0  # no window matched
def insulin_activity(t, insulin_type, bolus_dose, bolus_duration=0):
    """Return the modelled activity of a bolus *t* hours after it is given.

    Standard bolus (``bolus_duration`` 0 or None):
      * before ``peak_time``, activity rises as
        ``dose * (t / peak) * exp(1 - t / peak)``, reaching ``dose`` at the peak;
      * until ``duration``, it decays as
        ``dose * exp((peak - t) / (duration - peak))``;
      * after that it is 0.

    Extended bolus (``bolus_duration`` > 0): the dose is spread evenly. Activity
    is a constant ``dose / bolus_duration`` while ``0 <= t <= bolus_duration``
    and ``t`` is still within the insulin's duration, and 0 otherwise.

    Args:
        t: hours elapsed since the dose.
        insulin_type: key of INSULIN_TYPES.
        bolus_dose: dose in units.
        bolus_duration: hours over which an extended bolus is delivered.

    Returns:
        The activity value. Returns 0 for an unknown insulin type and for
        ``t < 0``, i.e. before the dose was given.
    """
    insulin_data = INSULIN_TYPES.get(insulin_type)
    if not insulin_data or t < 0:
        return 0
    peak_time = insulin_data['peak_time']  # hours until maximum activity
    duration = insulin_data['duration']  # hours the insulin stays active
    if bolus_duration and bolus_duration > 0:  # extended bolus: flat release
        if t <= bolus_duration and t < duration:
            return bolus_dose / bolus_duration
        return 0
    if t < peak_time:
        return (bolus_dose * t / peak_time) * np.exp(1 - t / peak_time)  # rising
    if t < duration:
        return bolus_dose * np.exp((peak_time - t) / (duration - peak_time))  # decaying
    return 0
def calculate_active_insulin(insulin_history, current_time):
    """Sum the modelled activity of every recorded dose at *current_time*.

    Args:
        insulin_history: iterable of
            ``(dose_time, dose_amount, insulin_type, bolus_duration)`` tuples.
        current_time: time at which to evaluate, on the same clock as ``dose_time``.

    Returns:
        The summed ``insulin_activity`` of all doses, or 0 for an empty history.
    """
    return sum(
        insulin_activity(current_time - given_at, kind, units, spread_hours)
        for given_at, units, kind, spread_hours in insulin_history
    )
def calculate_insulin_needs(carbs, glucose_current, glucose_target, tdd, weight, insulin_type="Rapid-Acting", override_correction_dose = None):
    """Estimate bolus and basal insulin for a Type 1 diabetes meal.

    Both ratios are derived from the total daily dose (TDD):
      * ICR (grams of carbohydrate covered by 1 unit) is 450 / TDD when
        weight <= 45 kg, and 500 / TDD otherwise;
      * ISF (mg/dL glucose change per unit) is 1700 / TDD.

    Args:
        carbs: carbohydrate grams in the meal.
        glucose_current: current blood glucose (mg/dL).
        glucose_target: target blood glucose (mg/dL).
        tdd: total daily insulin dose in units; must be > 0.
        weight: body weight in kg.
        insulin_type: key of INSULIN_TYPES.
        override_correction_dose: when not None, used instead of the computed
            correction dose.

    Returns:
        On success, a dict with ``icr``, ``isf``, ``correction_dose``,
        ``carb_dose``, ``total_bolus`` and ``basal_dose`` (rounded to 2 decimals),
        plus the insulin type and its onset, duration and peak time.
        Otherwise ``{'error': message}`` for a missing or non-positive TDD or an
        unknown insulin type.
    """
    if tdd is None or tdd <= 0:
        return {
            'error': 'Total Daily Dose (TDD) must be greater than 0'
        }
    insulin_data = INSULIN_TYPES.get(insulin_type)
    if not insulin_data:
        return {
            'error': "Invalid insulin type. Choose from: " + ", ".join(INSULIN_TYPES.keys())
        }
    icr = (450 if weight <= 45 else 500) / tdd
    isf = 1700 / tdd
    if override_correction_dose is not None:
        correction_dose = override_correction_dose
    else:
        # Negative when glucose is below target, which reduces the bolus.
        correction_dose = (glucose_current - glucose_target) / isf
    carb_dose = carbs / icr
    # A negative correction can offset the carb dose, but a bolus is never < 0.
    total_bolus = max(0, carb_dose + correction_dose)
    # NOTE(review): basal is estimated as 0.5 U per kg of body weight. Common
    # rules of thumb instead take daily basal as roughly 40-50% of TDD.
    # Confirm which estimate is intended.
    basal_dose = weight * 0.5
    return {
        'icr': round(icr, 2),
        'isf': round(isf, 2),
        'correction_dose': round(correction_dose, 2),
        'carb_dose': round(carb_dose, 2),
        'total_bolus': round(total_bolus, 2),
        'basal_dose': round(basal_dose, 2),
        'insulin_type': insulin_type,
        'insulin_onset': insulin_data['onset'],
        'insulin_duration': insulin_data['duration'],
        'peak_time': insulin_data['peak_time'],
    }
def create_detailed_report(nutrition_info, insulin_info, current_basal_rate):
    """Render the carbohydrate and insulin sections of the dashboard report.

    Args:
        nutrition_info: food lookup result with ``matched_food``, ``category``,
            ``subcategory``, ``serving_size``, ``base_carbs``,
            ``portion_multiplier``, ``adjusted_carbs`` and ``notes``.
        insulin_info: dose calculation with ``icr``, ``isf``, ``insulin_type``,
            ``insulin_onset``, ``insulin_duration``, ``peak_time``,
            ``correction_dose``, ``carb_dose``, ``total_bolus`` and ``basal_dose``.
        current_basal_rate: basal rate shown on the last line (units/hour).

    Returns:
        ``(carb_details, insulin_details)``: two plain-text sections, each of
        which begins and ends with a newline.
    """
    food = nutrition_info
    dose = insulin_info
    carb_section = [
        "",
        "FOOD DETAILS:",
        "-------------",
        f"Detected Food: {food['matched_food']}",
        f"Category: {food['category']}",
        f"Subcategory: {food['subcategory']}",
        "CARBOHYDRATE INFORMATION:",
        "------------------------",
        f"Standard Serving Size: {food['serving_size']}",
        f"Carbs per Serving: {food['base_carbs']}g",
        f"Portion Multiplier: {food['portion_multiplier']}x",
        f"Total Carbs: {food['adjusted_carbs']}g",
        f"Notes: {food['notes']}",
        "",
    ]
    insulin_section = [
        "",
        "INSULIN CALCULATIONS:",
        "--------------------",
        f"ICR (Insulin to Carb Ratio): 1:{dose['icr']}",
        f"ISF (Insulin Sensitivity Factor): 1:{dose['isf']}",
        f"Insulin Type: {dose['insulin_type']}",
        f"Onset: {dose['insulin_onset']} hours",
        f"Duration: {dose['insulin_duration']} hours",
        f"Peak Time: {dose['peak_time']} hours",
        "RECOMMENDED DOSES:",
        "-----------------",
        f"Correction Dose: {dose['correction_dose']} units",
        f"Carb Dose: {dose['carb_dose']} units",
        f"Total Bolus: {dose['total_bolus']} units",
        f"Daily Basal: {dose['basal_dose']} units",
        f"Current Basal Rate: {current_basal_rate} units/hour",
        "",
    ]
    return "\n".join(carb_section), "\n".join(insulin_section)
# -------------------------------------------------
# Main Dashboard Function
# -------------------------------------------------
def _parse_basal_rates(basal_rates_input):
    """Parse the basal-rate JSON object text, falling back to DEFAULT_BASAL_RATES."""
    try:
        parsed = json.loads(basal_rates_input)
    except (TypeError, ValueError):  # json.JSONDecodeError is a ValueError
        parsed = None
    if not isinstance(parsed, dict):
        print("Basal rates JSON invalid, using default")
        return DEFAULT_BASAL_RATES
    return parsed


def _predict_glucose(hours, initial_glucose, carbs, insulin_history, basal_rates,
                     stress_level, sleep_hours, exercise_duration, exercise_intensity):
    """Step the simple hourly glucose model; return one value per hour, clamped to 70-400."""
    # These terms do not depend on t, so they are computed once.
    stress_effect = stress_level * 2
    sleep_effect = abs(8 - sleep_hours) * 5
    # NOTE(review): exercise is modelled as *raising* glucose. Confirm the sign
    # is intended, since exercise often lowers glucose in Type 1 diabetes.
    exercise_effect = (exercise_duration / 60) * exercise_intensity * 2
    glucose_levels = []
    current_glucose = initial_glucose
    for t in hours:
        # Carbohydrate absorption: Gaussian bump centred at 1.5 h.
        carb_effect = carbs * 0.1 * np.exp(-(t - 1.5) ** 2 / 2)
        # calculate_active_insulin already applies each dose's activity curve.
        # Passing its result through insulin_activity() again would apply the
        # curve twice, and divide extended boluses by their duration twice.
        # NOTE(review): insulin units are subtracted directly from mg/dL here;
        # an ISF scaling may be intended.
        insulin_effect = calculate_active_insulin(insulin_history, t)
        basal_insulin_effect = get_basal_rate(t, basal_rates)  # units per hour
        glucose = (current_glucose + carb_effect - insulin_effect +
                   stress_effect + sleep_effect + exercise_effect - basal_insulin_effect)
        current_glucose = max(70, min(400, glucose))
        glucose_levels.append(current_glucose)
    return glucose_levels


def _plot_glucose(hours, glucose_levels, target_glucose):
    """Build the predicted-glucose figure with the target line and the 70-180 band."""
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(hours, glucose_levels, 'b-', label='Predicted Glucose')
    ax.axhline(y=target_glucose, color='g', linestyle='--', label='Target')
    ax.fill_between(hours, [70] * len(hours), [180] * len(hours),
                    alpha=0.1, color='g', label='Target Range')
    ax.set_ylabel('Glucose (mg/dL)')
    ax.set_xlabel('Hours')
    ax.set_title('Predicted Blood Glucose Over Time')
    ax.legend()
    ax.grid(True)
    # Detach the figure from pyplot's global registry so repeated clicks do not
    # accumulate open figures. The returned Figure can still be rendered.
    plt.close(fig)
    return fig


def diabetes_dashboard(initial_glucose, food_image, stress_level, sleep_hours, time_hours,
                       weight, tdd, target_glucose, exercise_duration, exercise_intensity, portion_size, insulin_type,
                       override_correction_dose, extended_bolus_duration, basal_rates_input):
    """Run the full pipeline behind the Calculate button.

    Steps: classify the food image, look up its carbohydrates, compute insulin
    doses, render the text report, and simulate and plot glucose over
    ``time_hours`` hours.

    Returns:
        A 5-tuple ``(carb_details, insulin_details, basal_dose, total_bolus,
        figure)``. On failure, the first element (and, when no food is found,
        the second) carries a message and the remaining elements are None.
    """
    try:
        food_data = load_food_data()

        # 1. Food classification and carb lookup
        food_name = classify_food(food_image)
        print(f"Classified food name: {food_name}")
        # "Unknown" is classify_food's failure sentinel; real labels are
        # lower-cased. Fuzzy-matching the sentinel would return dose advice for
        # an arbitrary food, so stop here instead.
        if food_name == "Unknown":
            return (
                "Could not classify the food image (no image uploaded, model unavailable, or classification error)",
                "No insulin calculations available",
                None,
                None,
                None
            )
        nutrition_info = get_food_nutrition(food_name, food_data, portion_size)
        if not nutrition_info:
            # Fall back to matching the individual words of the label.
            for term in food_name.split():
                nutrition_info = get_food_nutrition(term, food_data, portion_size)
                if nutrition_info:
                    break
        if not nutrition_info:
            return (
                f"Could not find nutrition information for: {food_name} in the local database",
                "No insulin calculations available",
                None,
                None,
                None
            )

        # 2. Insulin calculations
        basal_rates_dict = _parse_basal_rates(basal_rates_input)
        insulin_info = calculate_insulin_needs(
            nutrition_info['adjusted_carbs'],
            initial_glucose,
            target_glucose,
            tdd,
            weight,
            insulin_type,
            override_correction_dose
        )
        if 'error' in insulin_info:
            return insulin_info['error'], None, None, None, None

        # 3. Text reports.
        # NOTE(review): the reported basal rate is always looked up at 12:00,
        # not the current time of day. Confirm this is intended.
        current_basal_rate = get_basal_rate(12, basal_rates_dict)
        carb_details, insulin_details = create_detailed_report(nutrition_info, insulin_info, current_basal_rate)

        # 4. Glucose prediction. Gradio numeric inputs may arrive as floats,
        # while range() needs an int.
        hours = list(range(int(time_hours)))
        # A single bolus given at t=0; an empty duration box (None) means a
        # standard bolus.
        insulin_history = [(0, insulin_info['total_bolus'], insulin_info['insulin_type'],
                            extended_bolus_duration or 0)]
        glucose_levels = _predict_glucose(
            hours, initial_glucose, nutrition_info['adjusted_carbs'], insulin_history,
            basal_rates_dict, stress_level, sleep_hours, exercise_duration, exercise_intensity
        )

        # 5. Visualisation
        fig = _plot_glucose(hours, glucose_levels, target_glucose)
        return (
            carb_details,
            insulin_details,
            insulin_info['basal_dose'],
            insulin_info['total_bolus'],
            fig
        )
    except Exception as e:  # UI boundary: surface any failure as a message
        return f"Error: {str(e)}", None, None, None, None
# -------------------------------------------------
# Gradio Interface Setup
# -------------------------------------------------
# Components are created inside nested layout context managers (Tab/Row/Column/
# Accordion). Their creation order decides where they appear on the page.
with gr.Blocks() as app:  # Blocks API: the layout is declared manually below
    gr.Markdown("# Type 1 Diabetes Management Dashboard")
    with gr.Tab("Glucose & Meal"):
        with gr.Row():
            initial_glucose = gr.Number(label="Current Blood Glucose (mg/dL)", value=120)
            food_image = gr.Image(label="Food Image", type="pil")  # delivered to the callback as a PIL image
        with gr.Row():
            # Multiplies the per-serving carbohydrate grams of the matched food.
            portion_size = gr.Slider(0.1, 3, step=0.1, label="Portion Size Multiplier", value=1.0)
    with gr.Tab("Insulin"):
        with gr.Column():  # stack the insulin inputs vertically
            insulin_type = gr.Dropdown(choices=list(INSULIN_TYPES.keys()), label="Insulin Type", value="Rapid-Acting")
            # Left empty (None): the correction dose is computed from glucose and ISF.
            override_correction_dose = gr.Number(label="Override Correction Dose (Units)", value=None)
            # 0 means a standard bolus; > 0 spreads the dose evenly over this many hours.
            extended_bolus_duration = gr.Number(label="Extended Bolus Duration (Hours)", value=0)
    with gr.Tab("Basal Settings"):
        with gr.Column():
            # Default text mirrors DEFAULT_BASAL_RATES; invalid JSON falls back to it.
            basal_rates_input = gr.Textbox(label="Basal Rates (JSON)", lines=3,
                                           value="""{"00:00-06:00": 0.8, "06:00-12:00": 1.0, "12:00-18:00": 0.9, "18:00-24:00": 0.7}""")
    with gr.Tab("Other Factors"):
        with gr.Accordion("Factors affecting Glucose levels", open=False):  # collapsed by default
            weight = gr.Number(label="Weight (kg)", value=70)
            tdd = gr.Number(label="Total Daily Dose (TDD) of insulin (units)", value=40)
            target_glucose = gr.Number(label="Target Blood Glucose (mg/dL)", value=100)
            stress_level = gr.Slider(1, 10, step=1, label="Stress Level (1-10)", value=1)
            sleep_hours = gr.Number(label="Sleep Hours", value=7)
            exercise_duration = gr.Number(label="Exercise Duration (minutes)", value=0)
            exercise_intensity = gr.Slider(1, 10, step=1, label="Exercise Intensity (1-10)", value=1)
    with gr.Row():
        # Number of hourly steps in the glucose simulation and plot.
        time_hours = gr.Slider(1, 24, step=1, label="Prediction Time (hours)", value=6)
    with gr.Row():
        calculate_button = gr.Button("Calculate")
    with gr.Column():
        # Outputs, in the order of the 5-tuple returned by diabetes_dashboard.
        carb_details_output = gr.Textbox(label="Carbohydrate Details", lines=5)
        insulin_details_output = gr.Textbox(label="Insulin Calculation Details", lines=5)
        basal_dose_output = gr.Number(label="Basal Insulin Dose (units/day)")
        bolus_dose_output = gr.Number(label="Bolus Insulin Dose (units)")
        glucose_plot_output = gr.Plot(label="Glucose Prediction")
    # The inputs list must match diabetes_dashboard's positional parameter order.
    calculate_button.click(
        fn=diabetes_dashboard,
        inputs=[
            initial_glucose,
            food_image,
            stress_level,
            sleep_hours,
            time_hours,
            weight,
            tdd,
            target_glucose,
            exercise_duration,
            exercise_intensity,
            portion_size,
            insulin_type,
            override_correction_dose,
            extended_bolus_duration,
            basal_rates_input,
        ],
        outputs=[
            carb_details_output,
            insulin_details_output,
            basal_dose_output,
            bolus_dose_output,
            glucose_plot_output
        ]
    )
if __name__ == "__main__":
    # NOTE(review): share=True requests a public share link. On hosted platforms
    # such as Hugging Face Spaces this is likely unnecessary; confirm.
    app.launch(share=True)