import torch
import csv
from gritlm import GritLM
import pandas as pd
import ast
import numpy as np

# Species habitat/range descriptions used as embedding inputs.
input_text4 = [
    'The hyacinth macaw prefers semi-open, somewhat wooded habitats. It usually avoids dense, humid forest, and in regions dominated by such habitats, it is generally restricted to the edge or relatively open sections (e.g. along major rivers). In different areas of their range, these parrots are found in savannah grasslands, in dry thorn forests known as caatinga, and in palm stands or swamps, particularly the moriche palm (Mauritia flexuosa).',
    'The hyacinth macaw occurs today in three main areas in South America: In the Pantanal region of Brazil, and adjacent eastern Bolivia and northeastern Paraguay, in the cerrado regions of the eastern interior of Brazil (Maranhão, Piauí, Bahia, Tocantins, Goiás, Mato Grosso, Mato Grosso do Sul, and Minas Gerais), and in the relatively open areas associated with the Tocantins River, Xingu River, Tapajós River, and the Marajó island in the eastern Amazon Basin of Brazil.',
    'They are diurnal, terrestrial, and live in complex, mixed-gender social groups of 8 to 200 individuals per troop. They prefer savannas and light forests with a climate that is suitable for their omnivorous diet.',
    'Yellow baboons inhabit savannas and light forests in eastern Africa, from Kenya and Tanzania to Zimbabwe and Botswana.',
]

# Short free-form queries used as embedding inputs.
input_text5 = ['chappell roan', 'europe', 'pawpaw', 'sierra nevada', 'great lakes',
               'Treaty of Waitangi', 'hello kitty', 'disney', 'madagascar', 'Andes',
               'africa', 'dessert', 'whale', 'moon snail', 'unicorn', 'rainfall',
               'species occurs above 2000m of elevation', 'froyo', 'desert', 'dragon',
               'bear', 'selkie', 'loch ness monster']


def extract_grit_token(model, text: str):
    """Embed a single string with GritLM and return it as a torch tensor."""
    def gritlm_instruction(instruction):
        return ("<|user|>\n" + instruction + "\n<|embed|>\n") if instruction else "<|embed|>\n"
    d_rep = model.encode([text], instruction=gritlm_instruction(""))
    d_rep = torch.from_numpy(d_rep)
    return d_rep


def generate_text_embs(texts, output_file):
    """Embed each string in `texts` and write (text, embedding) rows to a CSV."""
    grit = GritLM("GritLM/GritLM-7B", torch_dtype="auto", mode="embedding")
    with open(output_file, mode='w') as file:
        writer = csv.writer(file)
        writer.writerow(['Text', 'Embedding'])
        for t in texts:
            text_emb = extract_grit_token(grit, t).to('cpu')
            print(f" {t}: {text_emb} ")
            writer.writerow([t, text_emb.tolist()])


# TODO: max's generate text using grit
def generate_text_emb(text):
    """Embed a single string (loads the full GritLM model on every call)."""
    grit = GritLM("GritLM/GritLM-7B", torch_dtype="auto", mode="embedding")
    text_emb = extract_grit_token(grit, text)
    return text_emb


def use_pregenerated_textemb_fromgpt(taxon_id):
    """Look up pregenerated embeddings for a taxon in experiments/gpt_data.pt."""
    embs_loaded = torch.load('experiments/gpt_data.pt', map_location='cpu')
    emb_ids = embs_loaded['taxon_id'].tolist()  # (2785,)
    keys1 = embs_loaded['keys']                 # (11140, 2)
    embs = embs_loaded['data']                  # torch.Size([11140, 4096])
    print(embs_loaded['taxon_id'].size())
    # NOTE: matching_indices index into the 2785-long taxon_id list, while
    # `embs`/`keys1` have 11140 rows keyed by (taxon_index, text_type); see
    # use_pregenerated_textemb_fromchris for the row-level lookup.
    matching_indices = [i for i, tid in enumerate(emb_ids) if tid == taxon_id]
    print(matching_indices)
    taxon_embeddings = embs[matching_indices, :]          # embeddings for the matching indices
    matching_keys = [keys1[i] for i in matching_indices]  # corresponding (taxon_id, text_type) keys
    print(f"Found {len(matching_keys)} embeddings for taxon ID {taxon_id}:")
    for i, key in enumerate(matching_keys):
        print(f"Text Type: {key[1]}, Embedding: {taxon_embeddings[i, :]}")
    # Return the last matching embedding. (The original indexed with the leftover
    # loop variable `i`, which raises a NameError when there are no matches.)
    return taxon_embeddings[-1, :]
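
# Usage sketch (not part of the original pipeline): rank the habitat descriptions
# in input_text4 against a short query from input_text5 by cosine similarity.
# `rank_descriptions_by_query` is a hypothetical helper added for illustration;
# it assumes enough memory to load GritLM-7B and reuses extract_grit_token above.
def rank_descriptions_by_query(query, descriptions):
    grit = GritLM("GritLM/GritLM-7B", torch_dtype="auto", mode="embedding")
    query_emb = extract_grit_token(grit, query)                                 # (1, 4096)
    desc_embs = torch.cat([extract_grit_token(grit, d) for d in descriptions])  # (N, 4096)
    sims = torch.nn.functional.cosine_similarity(query_emb, desc_embs, dim=1)   # (N,)
    order = torch.argsort(sims, descending=True)
    # Return (truncated text, similarity) pairs, most similar first.
    return [(descriptions[int(i)][:60], sims[int(i)].item()) for i in order]
# Example: rank_descriptions_by_query('desert', input_text4)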
def use_pregenerated_textemb_fromchris(taxon_id, text_type):
    # The zero vector stands in for "no text input".
    text_embedding = torch.zeros(1, 4096)
    if text_type is None or text_type == 'none':
        return text_embedding, 0
    embs1 = torch.load('experiments/gpt_data.pt', map_location='cpu')
    emb_ids1 = embs1['taxon_id'].tolist()
    keys1 = embs1['keys']
    embs1 = embs1['data']
    taxa_index_of_interest = emb_ids1.index(taxon_id)  # e.g. 5 for the macaw
    possible_text_embedding_indexes = [i for i, key in enumerate(keys1)
                                       if key[0] == taxa_index_of_interest and key[1] == text_type]
    if len(possible_text_embedding_indexes) != 1:
        return text_embedding, 0
    # To inspect the available keys for a taxon, enumerate the entries whose
    # key[0] matches taxa_index_of_interest, e.g.:
    #   ((5, 'range'), 20)
    #   ((5, 'habitat'), 21)
    #   ((5, 'species_description'), 22)
    #   ((5, 'overview_summary'), 23)
    # Known row indices:
    #   macaw: range 20, habitat 21
    #   baboon: range 7928, habitat 7929
    #   black-and-white warbler: range 16, habitat 17
    #   barn swallow: range 1652, habitat 1653
    #   pika: range 7116, habitat 7117
    #   loon: range 11056, habitat 11057
    #   european robin: range 2020, habitat 2021
    #   sfs: range 7148, habitat 7149
    text_embedding_index = possible_text_embedding_indexes[0]
    text_embedding = embs1[text_embedding_index].unsqueeze(0)
    return text_embedding, text_embedding_index


def use_pregenerated_textemb_fromcsv(input_text):
    """Load the embedding for `input_text` from the CSV written by generate_text_embs."""
    text_data = pd.read_csv('data/text_embs/text_embeddings_fig4.csv')
    result_row = text_data[text_data['Text'] == input_text]
    text_emb = ast.literal_eval(result_row['Embedding'].values[0])  # raises IndexError if the text is absent
    embedding_tensor = torch.FloatTensor(text_emb)
    return embedding_tensor


def get_eval_context_points(taxa_id, context_data, size):
    """Take up to `size` context points for a taxon, prepend a dummy class token,
    and normalize (lon, lat) to [-1, 1]."""
    taxa_index = np.argwhere(context_data['class_to_taxa'] == taxa_id)[0]
    all_context_pts = context_data['locs'][context_data['labels'] == taxa_index][1:]
    context_pts = all_context_pts[0:size]
    dummy_classtoken = np.array([[0, 0]])
    context_pts = np.vstack((dummy_classtoken, context_pts))
    normalized_pts = torch.from_numpy(context_pts) * torch.tensor([[1 / 180, 1 / 90]], device='cpu')
    return normalized_pts


if __name__ == '__main__':
    print('starting to generate text_embs')
    output_file = './data/text_embs/text_embeddings_fig4.csv'
    # The original called use_pregenerated_textemb_fromchris() with no arguments,
    # which raises a TypeError; generating the CSV consumed by
    # use_pregenerated_textemb_fromcsv appears to be the intent here.
    generate_text_embs(input_text4 + input_text5, output_file)
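
# Optional sanity check (illustrative, not part of the original script): verify
# that an embedding written to the CSV by generate_text_embs survives the
# stringify/parse round trip through use_pregenerated_textemb_fromcsv.
# `check_csv_roundtrip` is a hypothetical helper; run it manually once the CSV exists.
def check_csv_roundtrip(sample_text='desert'):
    fresh = generate_text_emb(sample_text).to('cpu').float()  # (1, 4096) straight from GritLM
    stored = use_pregenerated_textemb_fromcsv(sample_text)    # (1, 4096) parsed back from the CSV
    print('max abs difference:', (fresh - stored).abs().max().item())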