import gradio as gr

from viz_ls_map import main
from get_gt import generate_ground_truth

# Text-conditioning modes rendered on every request. The order must match the
# order of the output image components wired up in the UI below.
_TEXT_TYPES = ('none', 'range', 'habitat')

# Context-point counts rendered for each text type. The UI declares exactly
# one output image per text type, so this must stay a single value unless the
# output components are extended accordingly.
_CONTEXT_POINT_COUNTS = (1,)


def predict_species_distribution(taxa_id, taxa_name, text_type, num_context_points):
    """Run the FS-SINR model for one taxon and return the map image paths.

    The model entry point ``main`` is invoked once per text type in
    ``_TEXT_TYPES`` ('none', 'range', 'habitat'), always with the context-point
    counts in ``_CONTEXT_POINT_COUNTS``.

    Args:
        taxa_id: Numeric taxon identifier from the UI; converted to ``int``.
        taxa_name: Name forwarded to the model and embedded in the expected
            output image filenames.
        text_type: Accepted for interface compatibility with the UI wiring;
            currently ignored because every text type is always rendered.
        num_context_points: Accepted for interface compatibility with the UI
            wiring; currently ignored in favour of ``_CONTEXT_POINT_COUNTS``.

    Returns:
        A list of image paths: the ground-truth map first, followed by one
        predicted map per (text type, context-point count) combination, in
        ``_TEXT_TYPES`` order.

    Raises:
        gr.Error: If no taxa ID was supplied.
    """
    # An emptied number field is assumed to arrive as None; without this guard
    # int(None) would surface as an opaque TypeError in the UI.
    if taxa_id is None:
        raise gr.Error("Please enter a Taxa ID.")
    taxa_id = int(taxa_id)

    # Fresh list per request; the same list object is passed for both
    # 'num_context' and 'context_pt_trial', as before.
    context_counts = list(_CONTEXT_POINT_COUNTS)

    # NOTE(review): the ground-truth map is not generated here; the image is
    # expected to already exist at this path (presumably produced beforehand
    # via get_gt.generate_ground_truth -- confirm).
    image_path_gt = f'images/species_presence_hr_{taxa_id}.png'

    output_images = []
    for text_type_i in _TEXT_TYPES:
        # Evaluation parameters handed to the FS-SINR entry point.
        eval_params = {
            'model_path': './experiments/zero_shot_ls_sin_cos_env_cap_1000_text_context_20_sinr_two_layer_nn/model.pt',
            'taxa_id': taxa_id,
            'threshold': -1,
            'op_path': './images/',
            'rand_taxa': False,
            'high_res': True,
            'disable_ocean_mask': False,
            'set_max_cmap_to_1': False,
            'device': 'cpu',
            'show_map': 1,
            'show_context_points': 1,
            'prefix': '',
            'num_context': context_counts,
            'choose_context_points': 1,
            'additional_save_name': "",
            'taxa_name': taxa_name,
            'test_taxa': taxa_id,
            'text_type': text_type_i,  # 'none', 'habitat', or 'range'
            'context_pt_trial': context_counts,
        }

        # Run the FS-SINR model with the specified parameters.
        main(eval_params)

        # The predicted maps are assumed to be saved under './images/' using
        # this filename pattern.
        output_images.extend(
            f'./images/testenv_{taxa_name}(selected_points)_{text_type_i}_{k}.png'
            for k in context_counts
        )

    return [image_path_gt] + output_images


# Define the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# View Species Distribution Predictions using FS-SINR")

    # Input fields for the Gradio interface
    taxa_id = gr.Number(label="Taxa ID", value=43188)
    taxa_name = gr.Textbox(label="Taxa Name", value="test_pika")
    # NOTE: the two selectors below are passed to the callback, which
    # currently ignores them and always renders every text type with the
    # counts in _CONTEXT_POINT_COUNTS.
    text_type = gr.Radio(label="Text Type", choices=['none', 'habitat', 'range'], value='none')
    num_context_points = gr.CheckboxGroup([0, 1, 2, 3, 4, 5, 10, 15, 20], label="Number of Context Points")

    # Button to trigger the prediction
    predict_button = gr.Button("Predict Species Distribution")

    # Outputs: ground-truth map followed by one predicted map per text type,
    # in the same order as _TEXT_TYPES.
    ground_truth = gr.Image(label="Ground Truth Map")
    none_maps = gr.Image(label=f"Map for No Text Input and Context Point {_CONTEXT_POINT_COUNTS[0]}")
    range_maps = gr.Image(label=f"Map for Range Text input and Context Point {_CONTEXT_POINT_COUNTS[0]}")
    hab_maps = gr.Image(label=f"Map for Habitat Text input and Context Point {_CONTEXT_POINT_COUNTS[0]}")
    output_images = [ground_truth, none_maps, range_maps, hab_maps]

    # Link the button to the function and inputs
    predict_button.click(
        fn=predict_species_distribution,
        inputs=[taxa_id, taxa_name, text_type, num_context_points],
        outputs=output_images,
    )

# Launch the Gradio interface
demo.launch()