Spaces:
Sleeping
Sleeping
import gradio as gr | |
from viz_ls_map import main | |
from get_gt import generate_ground_truth | |
def predict_species_distribution(taxa_id, taxa_name, text_type, num_context_points):
    """
    Run the FS-SINR model for one taxon and collect the map images to display.

    The model is run once per text type ('none', 'range', 'habitat'), in that
    order, and the predicted map paths are returned after the ground-truth
    map path.

    Args:
        taxa_id: Taxon identifier entered in the UI; converted with ``int()``.
        taxa_name: Name interpolated into the predicted-map filenames.
        text_type: Currently unused -- every text type is always rendered.
        num_context_points: Currently unused -- a single context point is
            always used, so the number of returned paths (1 ground truth +
            1 per text type = 4) matches the four image outputs of the UI.

    Returns:
        list[str]: The ground-truth image path followed by one predicted-map
        path per (text type, context count) pair.

    Raises:
        gr.Error: If no Taxa ID was provided.
    """
    if taxa_id is None:
        # NOTE(review): presumably gr.Number yields None when the field is
        # cleared; int(None) would otherwise surface as an opaque TypeError.
        raise gr.Error("Please enter a Taxa ID.")
    taxa_id = int(taxa_id)

    # Deliberately fixed to one context point (see docstring); kept separate
    # from the ignored `num_context_points` argument to make that explicit.
    context_counts = [1]

    # This function does not create the ground-truth image; it only
    # returns its path.
    image_path_gt = f'images/species_presence_hr_{taxa_id}.png'

    output_images = []
    for text_type_i in ['none', 'range', 'habitat']:
        # A fresh parameter dict is built for every run of the model.
        eval_params = {
            'model_path': './experiments/zero_shot_ls_sin_cos_env_cap_1000_text_context_20_sinr_two_layer_nn/model.pt',
            'taxa_id': taxa_id,
            'threshold': -1,
            'op_path': './images/',
            'rand_taxa': False,
            'high_res': True,
            'disable_ocean_mask': False,
            'set_max_cmap_to_1': False,
            'device': 'cpu',
            'show_map': 1,
            'show_context_points': 1,
            'prefix': '',
            'num_context': context_counts,
            'choose_context_points': 1,
            'additional_save_name': "",
            'taxa_name': taxa_name,
            'test_taxa': taxa_id,
            'text_type': text_type_i,  # 'none', 'habitat', or 'range'
            'context_pt_trial': context_counts,
        }
        # Run the FS-SINR model with the specified parameters.
        main(eval_params)

        # Filenames assumed to follow this pattern for output written by
        # main() -- TODO confirm against viz_ls_map.
        output_images.extend(
            f'./images/testenv_{taxa_name}(selected_points)_{text_type_i}_{k}.png'
            for k in context_counts
        )

    return [image_path_gt] + output_images
# Define the Gradio interface.
with gr.Blocks() as demo:
    gr.Markdown("# View Species Distribution Predictions using FS-SINR")

    # Input fields for the Gradio interface.
    taxa_id = gr.Number(label="Taxa ID", value=43188)
    taxa_name = gr.Textbox(label="Taxa Name", value="test_pika")
    # NOTE(review): predict_species_distribution currently ignores these two
    # selections -- it always renders every text type with one context point.
    text_type = gr.Radio(label="Text Type", choices=['none', 'habitat', 'range'], value='none')
    num_context_points = gr.CheckboxGroup(
        [0, 1, 2, 3, 4, 5, 10, 15, 20], label="Number of Context Points"
    )

    # Button to trigger the prediction.
    predict_button = gr.Button("Predict Species Distribution")

    # Outputs: the ground-truth map, then one predicted map per text type, in
    # the same order ('none', 'range', 'habitat') the prediction returns them.
    ground_truth = gr.Image(label="Ground Truth Map")
    none_maps = gr.Image(label="Map for No Text Input and Context Point 1")
    range_maps = gr.Image(label="Map for Range Text input and Context Point 1")
    hab_maps = gr.Image(label="Map for Habitat Text input and Context Point 1")
    output_images = [ground_truth, none_maps, range_maps, hab_maps]

    # Link the button to the function, its inputs, and its outputs.
    predict_button.click(
        fn=predict_species_distribution,
        inputs=[taxa_id, taxa_name, text_type, num_context_points],
        outputs=output_images,
    )

# Launch only when run as a script, so importing this module does not start
# a server as a side effect.
if __name__ == "__main__":
    demo.launch()