# fs_sinr / app.py
# angelazhu96 — rewrite history — 0b54529
import gradio as gr
from viz_ls_map import main
from get_gt import generate_ground_truth
# Checkpoint used for every FS-SINR prediction made by this demo.
_MODEL_PATH = './experiments/zero_shot_ls_sin_cos_env_cap_1000_text_context_20_sinr_two_layer_nn/model.pt'
# Text conditions that are always rendered, in the order the UI's image
# outputs expect them.
_TEXT_TYPES = ('none', 'range', 'habitat')


def predict_species_distribution(taxa_id, taxa_name, text_type, num_context_points):
    """Run FS-SINR for one taxon and return the image paths to display.

    ``main`` (from ``viz_ls_map``) is called once per entry of ``_TEXT_TYPES``.
    The returned prediction paths are where that call is expected to have
    written its maps under ``./images/``.

    Args:
        taxa_id: Taxon identifier; converted with ``int()``.
        taxa_name: Name passed to ``main`` and embedded in the output file names.
        text_type: Currently ignored -- every text type in ``_TEXT_TYPES`` is run.
        num_context_points: Currently ignored -- a single context-point count
            of 1 is always used, so that the number of returned paths matches
            the four image outputs of the interface.

    Returns:
        list[str]: The ground-truth map path, followed by one predicted-map
        path per (text type, context-point count) pair, with text types in the
        outer order given by ``_TEXT_TYPES``.
    """
    taxa_id = int(taxa_id)

    # NOTE(review): the user's checkbox selection is deliberately discarded.
    # The UI has exactly one image slot per text type, for one context point.
    context_counts = [1]

    # The ground-truth image is expected to already exist on disk; on-demand
    # generation (get_gt.generate_ground_truth) is disabled in this app.
    image_path_gt = f'images/species_presence_hr_{taxa_id}.png'

    output_images = []
    for text_type_i in _TEXT_TYPES:
        # A fresh dict per call, so anything `main` does to it cannot leak
        # into the next iteration.
        eval_params = {
            'model_path': _MODEL_PATH,
            'taxa_id': taxa_id,
            'threshold': -1,
            'op_path': './images/',
            'rand_taxa': False,
            'high_res': True,
            'disable_ocean_mask': False,
            'set_max_cmap_to_1': False,
            'device': 'cpu',
            'show_map': 1,
            'show_context_points': 1,
            'prefix': '',
            'num_context': context_counts,
            'choose_context_points': 1,
            'additional_save_name': "",
            'taxa_name': taxa_name,
            'test_taxa': taxa_id,
            'text_type': text_type_i,  # 'none', 'habitat', or 'range'
            'context_pt_trial': context_counts,
        }
        # Run the FS-SINR model; it is expected to save its maps in ./images/.
        main(eval_params)

        # Assumed naming scheme for the maps written by `main`.
        output_images.extend(
            f'./images/testenv_{taxa_name}(selected_points)_{text_type_i}_{k}.png'
            for k in context_counts
        )

    return [image_path_gt] + output_images
# Define the Gradio interface
# Gradio UI: taxon inputs, a trigger button, and four image slots. The slots
# are filled, in order, by the list returned from predict_species_distribution.
with gr.Blocks() as demo:
    gr.Markdown("# View Species Distribution Predictions using FS-SINR")

    # --- Inputs ---
    taxa_id = gr.Number(label="Taxa ID", value=43188)
    taxa_name = gr.Textbox(label="Taxa Name", value="test_pika")
    # NOTE(review): predict_species_distribution always runs every text type
    # with a single context point, so these two selections do not currently
    # change the output.
    text_type = gr.Radio(
        label="Text Type",
        choices=['none', 'habitat', 'range'],
        value='none',
    )
    num_context_points = gr.CheckboxGroup(
        [0, 1, 2, 3, 4, 5, 10, 15, 20],
        label="Number of Context Points",
    )

    predict_button = gr.Button("Predict Species Distribution")

    # --- Outputs ---
    # Ground truth first, then one map per text type in the order
    # 'none', 'range', 'habitat' -- matching the handler's return order.
    ground_truth = gr.Image(label="Ground Truth Map")
    none_maps, range_maps, hab_maps = [
        gr.Image(label=caption)
        for caption in (
            "Map for No Text Input and Context Point 1",
            "Map for Range Text input and Context Point 1",
            "Map for Habitat Text input and Context Point 1",
        )
    ]
    output_images = [ground_truth, none_maps, range_maps, hab_maps]

    # Wire the button to the prediction handler.
    predict_button.click(
        fn=predict_species_distribution,
        inputs=[taxa_id, taxa_name, text_type, num_context_points],
        outputs=output_images,
    )

# Start the Gradio server.
demo.launch()