import ast

import numpy as np
import pandas as pd
import torch


def use_pregenerated_textemb_fromchris(taxon_id, text_type):
    """Look up a precomputed 4096-d text embedding for a taxon.

    Embeddings are stored in 'experiments/gpt_data.pt' keyed by
    (taxon index, text type). Returns (embedding, index); a zero vector
    and index 0 are returned when no text input is requested or when
    the lookup is not unique.
    """
    # The zero vector is the "no text input" case.
    text_embedding = torch.zeros(1, 4096)
    if text_type is None or text_type == 'none':
        return text_embedding, 0

    embs1 = torch.load('experiments/gpt_data.pt', map_location='cpu')
    emb_ids1 = embs1['taxon_id'].tolist()
    keys1 = embs1['keys']
    embs1 = embs1['data']

    taxa_of_interest = taxon_id
    taxa_index_of_interest = emb_ids1.index(taxa_of_interest)  # e.g. 5 in the example below

    # Each key is a (taxon index, text type) pair; keep only entries that
    # match both the taxon and the requested text type.
    # keys_with_taxa_of_interest = [key for key in keys1 if key[0] == taxa_index_of_interest]
    # indices_with_taxa_of_interest = [(key, i) for i, key in enumerate(keys1) if key[0] == taxa_index_of_interest]
    possible_text_embedding_indexes = [
        i for i, key in enumerate(keys1)
        if key[0] == taxa_index_of_interest and key[1] == text_type
    ]
    if len(possible_text_embedding_indexes) != 1:
        # Zero or multiple matches: fall back to the zero embedding.
        return text_embedding, 0

    # To inspect the available entries for a taxon, uncomment:
    # for key in indices_with_taxa_of_interest:
    #     print(key)
    # Example output for taxon index 5:
    # ((5, 'range'), 20)
    # ((5, 'habitat'), 21)
    # ((5, 'species_description'), 22)
    # ((5, 'overview_summary'), 23)
    #
    # Known embedding indexes:
    # macaw: range: 20, habitat: 21
    # baboon: range: 7928, habitat: 7929
    # black & white warbler: range: 16, habitat: 17
    # barn swallow: range: 1652, habitat: 1653
    # pika: range: 7116, habitat: 7117
    # loon: range: 11056, habitat: 11057
    # euro robin: range: 2020, habitat: 2021
    # sfs: range: 7148, habitat: 7149

    text_embedding_index = possible_text_embedding_indexes[0]
    text_embedding = embs1[text_embedding_index].unsqueeze(0)
    return text_embedding, text_embedding_index


def use_pregenerated_textemb_fromcsv(input_text):
    """Fetch the embedding stored alongside `input_text` in the figure-4 CSV.

    The 'Embedding' column holds a stringified Python list, so it is parsed
    with ast.literal_eval before conversion to a tensor.
    """
    text_data = pd.read_csv('data/text_embs/text_embeddings_fig4.csv')
    result_row = text_data[text_data['Text'] == input_text]
    text_emb = ast.literal_eval(result_row['Embedding'].values[0])
    embedding_tensor = torch.FloatTensor(text_emb)
    return embedding_tensor


def get_eval_context_points(taxa_id, context_data, size):
    """Return up to `size` normalized context points for a taxon.

    A dummy (0, 0) class-token location is prepended, and lon/lat are
    scaled by (1/180, 1/90) to land in [-1, 1].
    """
    class_index = np.argwhere(context_data['class_to_taxa'] == taxa_id)[0]
    # Select this class's locations; the first stored location is skipped.
    all_context_pts = context_data['locs'][context_data['labels'] == class_index][1:]
    context_pts = all_context_pts[0:size]
    dummy_classtoken = np.array([[0, 0]])
    context_pts = np.vstack((dummy_classtoken, context_pts))
    normalized_pts = torch.from_numpy(context_pts) * torch.tensor([[1 / 180, 1 / 90]], device='cpu')
    return normalized_pts


if __name__ == '__main__':
    print('starting to generate text_embs')
    output_file = './data/text_embs/text_embeddings_fig4.csv'
    # The original call passed no arguments, which raises a TypeError; the id
    # below is a placeholder and must exist in experiments/gpt_data.pt.
    use_pregenerated_textemb_fromchris(taxon_id=12345, text_type='habitat')
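
# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the original script): shows
# how the helpers above might be combined for one species. The taxon id and
# the 'data/eval_context.npz' path are hypothetical stand-ins; `context_data`
# is assumed to expose the 'locs', 'labels', and 'class_to_taxa' arrays that
# get_eval_context_points() reads.
# ---------------------------------------------------------------------------
def _example_usage():
    taxon_id = 12345  # placeholder: use an id present in experiments/gpt_data.pt
    # Fetch the habitat-text embedding (falls back to zeros if not found).
    text_emb, emb_idx = use_pregenerated_textemb_fromchris(taxon_id, 'habitat')
    # Load an evaluation dump and build normalized context points; the file
    # name here is hypothetical and must match your local data layout.
    context_data = np.load('data/eval_context.npz', allow_pickle=True)
    context_pts = get_eval_context_points(taxon_id, context_data, size=20)
    print(text_emb.shape, emb_idx, context_pts.shape)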