import json
import os
import os.path as osp
import zipfile
from typing import Union

import numpy as np
import pandas as pd
import torch
from huggingface_hub import hf_hub_download
from langdetect import detect
from ogb.nodeproppred import NodePropPredDataset
from ogb.utils.url import download_url, extract_zip
from tqdm import tqdm

from stark_qa.skb.knowledge_base import SKB
from stark_qa.tools.download_hf import download_hf_file, download_hf_folder
from stark_qa.tools.io import load_files, save_files
from stark_qa.tools.process_text import compact_text


DATASET = {
    'repo': 'snap-stanford/stark',
    'metadata': 'skb/mag/schema',
    'raw': 'skb/mag/idx_title_abs.zip',
    'processed': 'skb/mag/processed.zip'
}

RAW_DATA = {
    'ogbn_papers100M': 'https://snap.stanford.edu/ogb/data/misc/ogbn_papers100M/paperinfo.zip',
    'mag_mapping': 'https://zenodo.org/records/2628216/files'
}


class MagSKB(SKB):
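    """
    Semi-structured knowledge base (SKB) over the MAG data used in STaRK.
    Nodes are authors, institutions, fields of study, and papers; paper nodes
    additionally carry English titles and abstracts taken from ogbn-papers100M.
    The node and edge schema is given by `node_type_dict`, `edge_type_dict`,
    and `node_attr_dict` below.

    Minimal usage sketch (the root path below is illustrative, not a default):

        skb = MagSKB(root='./stark_mag', download_processed=True)
        print(skb.get_doc_info(0, add_rel=True, n_rel=5))
    """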

    test_columns = ['title', 'abstract', 'text']
    candidate_types = ['paper']

    node_type_dict = {0: 'author', 1: 'institution', 2: 'field_of_study', 3: 'paper'}
    edge_type_dict = {
        0: 'author___affiliated_with___institution',
        1: 'paper___cites___paper',
        2: 'paper___has_topic___field_of_study',
        3: 'author___writes___paper'
    }
    node_attr_dict = {
        'paper': ['title', 'abstract', 'publication date', 'venue'],
        'author': ['name'],
        'institution': ['name'],
        'field_of_study': ['name']
    }

    def __init__(self,
                 root: Union[str, None] = None,
                 download_processed: bool = True,
                 **kwargs):
        """
        Initialize the MagSKB class.

        Args:
            root (Union[str, None]): Root directory to store the dataset. If None, the default HF cache path is used.
            download_processed (bool): Whether to download the processed data.
        """
        self.root = root

        if download_processed:
            if self.root is None or not osp.exists(osp.join(self.root, 'processed', 'node_info.pkl')):
                processed_path = hf_hub_download(
                    DATASET["repo"], DATASET["processed"], repo_type="dataset"
                )
                if self.root is None:
                    self.root = osp.dirname(processed_path)
                if not osp.exists(osp.join(self.root, 'processed', 'node_info.pkl')):
                    print(f"Extracting downloaded processed data to {self.root}")
                    with zipfile.ZipFile(processed_path, "r") as zip_ref:
                        zip_ref.extractall(self.root)

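        # Directory layout under `root`: raw graph data from OGB, raw paper text
        # from ogbn-papers100M, the downloaded schema, and the processed SKB files.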
        self.raw_data_dir = osp.join(self.root, 'raw')
        self.processed_data_dir = osp.join(self.root, 'processed')
        self.graph_data_root = osp.join(self.raw_data_dir, 'ogbn_mag')
        self.text_root = osp.join(self.raw_data_dir, 'ogbn_papers100M')

        self.schema_dir = osp.join(self.root, 'schema')
        if not osp.exists(self.schema_dir):
            download_hf_folder(
                DATASET["repo"], DATASET["metadata"],
                repo_type="dataset", save_as_folder=self.schema_dir
            )

        self.mag_mapping_dir = osp.join(self.graph_data_root, 'mag_mapping')
        self.ogbn_mag_mapping_dir = osp.join(self.graph_data_root, 'mapping')
        self.title_path = osp.join(self.text_root, 'paperinfo/idx_title.tsv')
        self.abstract_path = osp.join(self.text_root, 'paperinfo/idx_abs.tsv')

        self.mag_metadata_cache_dir = osp.join(self.processed_data_dir, 'mag_cache')
        self.paper100M_text_cache_dir = osp.join(self.processed_data_dir, 'paper100M_cache')
        self.merged_filtered_path = osp.join(self.paper100M_text_cache_dir, 'idx_title_abs.tsv')
        os.makedirs(self.mag_metadata_cache_dir, exist_ok=True)
        os.makedirs(self.paper100M_text_cache_dir, exist_ok=True)

        if osp.exists(osp.join(self.processed_data_dir, 'node_info.pkl')):
            print(f'Loading from {self.processed_data_dir}!')
            processed_data = load_files(self.processed_data_dir)
        else:
            print('Start processing raw data...')
            processed_data = self._process_raw()
        processed_data.update({
            'node_type_dict': self.node_type_dict,
            'edge_type_dict': self.edge_type_dict
        })
        super(MagSKB, self).__init__(**processed_data, **kwargs)

    def load_edge(self, edge_type: str) -> tuple:
        """
        Load edge data for the specified edge type.

        Args:
            edge_type (str): Type of edge to load.

        Returns:
            tuple: (edge, edge_num), where edge is a 3 x E LongTensor whose rows
                hold source indices, destination indices, and relation types, and
                edge_num is the list of edge counts read from num-edge-list.csv.gz.
        """
        edge_dir = osp.join(self.graph_data_root, f"raw/relations/{edge_type}/edge.csv.gz")
        edge_type_dir = osp.join(self.graph_data_root, f"raw/relations/{edge_type}/edge_reltype.csv.gz")
        num_dir = osp.join(self.graph_data_root, f"raw/relations/{edge_type}/num-edge-list.csv.gz")
        edge = pd.read_csv(edge_dir, names=['src', 'dst'])

        edge_t = pd.read_csv(edge_type_dir, names=['type'])
        edge_n = pd.read_csv(num_dir, names=['num'])
        edge_num = edge_n['num'].tolist()

        edge = [edge['src'].tolist(), edge['dst'].tolist(), edge_t['type'].tolist()]
        edge = torch.LongTensor(edge)

        return edge, edge_num

    def load_meta_data(self):
        """
        Load metadata for the MAG dataset.

        Returns:
            tuple: DataFrames for authors, fields of study, institutions, and papers.
        """
        mag_csv = {}
        if osp.exists(osp.join(self.mag_metadata_cache_dir, 'paper_data.csv')):
            print('Start loading MAG data from cache...')
            for t in ['author', 'institution', 'field_of_study', 'paper']:
                mag_csv[t] = pd.read_csv(osp.join(self.mag_metadata_cache_dir, f'{t}_data.csv'))
            author_data, paper_data = mag_csv['author'], mag_csv['paper']
            field_of_study_data = mag_csv['field_of_study']
            institution_data = mag_csv['institution']
            print('Done!')
        else:
            print('Start loading MAG data, it might take a while...')
            full_attr_path = osp.join(self.schema_dir, 'mag.json')
            reduced_attr_path = osp.join(self.schema_dir, 'reduced_mag.json')

            full_attr = json.load(open(full_attr_path, 'r'))
            reduced_attr = json.load(open(reduced_attr_path, 'r'))

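            # For each MAG table listed in the reduced schema, download the raw
            # dump from Zenodo if it is missing, then read only the columns named
            # in reduced_mag.json.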
            loaded_csv = {}
            for key in reduced_attr.keys():
                column_nums = [full_attr[key].index(i) for i in reduced_attr[key]]
                file = osp.join(self.mag_mapping_dir, key + '.txt.gz')
                if not osp.exists(file):
                    try:
                        download_url(f'{RAW_DATA["mag_mapping"]}/{key}.txt.gz', self.mag_mapping_dir)
                    except Exception as error:
                        print(f'Download of {key} failed or data not found; please download {key}.txt.gz from {RAW_DATA["mag_mapping"]} and place it at {file}')
                        raise error
                loaded_csv[key] = pd.read_csv(file, header=None, sep='\t', usecols=column_nums)
                loaded_csv[key].columns = reduced_attr[key]

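            # Map ogbn-mag integer node indices to MAG entity ids, then join the
            # MAG metadata tables onto each node type on those ids.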
            print('Processing and merging meta data...')
            author_data = pd.read_csv(osp.join(self.ogbn_mag_mapping_dir, "author_entidx2name.csv.gz"), names=['id', 'AuthorId'], skiprows=[0])
            field_of_study_data = pd.read_csv(osp.join(self.ogbn_mag_mapping_dir, "field_of_study_entidx2name.csv.gz"), names=['id', 'FieldOfStudyId'], skiprows=[0])
            institution_data = pd.read_csv(osp.join(self.ogbn_mag_mapping_dir, "institution_entidx2name.csv.gz"), names=['id', 'AffiliationId'], skiprows=[0])
            paper_data = pd.read_csv(osp.join(self.ogbn_mag_mapping_dir, "paper_entidx2name.csv.gz"), names=['id', 'PaperId'], skiprows=[0])

            loaded_csv['Papers'].rename(columns={'JournalId ': 'JournalId', 'Rank': 'PaperRank', 'CitationCount': 'PaperCitationCount'}, inplace=True)
            loaded_csv['Journals'].rename(columns={'DisplayName': 'JournalDisplayName', 'Rank': 'JournalRank', 'CitationCount': 'JournalCitationCount', 'PaperCount': 'JournalPaperCount'}, inplace=True)
            loaded_csv['ConferenceSeries'].rename(columns={'DisplayName': 'ConferenceSeriesDisplayName', 'Rank': 'ConferenceSeriesRank', 'CitationCount': 'ConferenceSeriesCitationCount', 'PaperCount': 'ConferenceSeriesPaperCount'}, inplace=True)
            loaded_csv['ConferenceInstances'].rename(columns={'DisplayName': 'ConferenceInstancesDisplayName', 'CitationCount': 'ConferenceInstanceCitationCount', 'PaperCount': 'ConferenceInstancesPaperCount'}, inplace=True)

            author_data = author_data.merge(loaded_csv['Authors'], on='AuthorId', how='left')
            field_of_study_data = field_of_study_data.merge(loaded_csv['FieldsOfStudy'], on='FieldOfStudyId', how='left')
            institution_data = institution_data.merge(loaded_csv['Affiliations'], on='AffiliationId', how='left')
            paper_data = paper_data.merge(loaded_csv['Papers'], on='PaperId', how='left')

            paper_data['JournalId'] = paper_data['JournalId'].apply(lambda x: float(x)).apply(lambda x: -1 if np.isnan(x) else int(x))
            paper_data = paper_data.merge(loaded_csv['Journals'], on='JournalId', how='left')

            paper_data = paper_data.merge(loaded_csv['ConferenceSeries'], on='ConferenceSeriesId', how='left')

            paper_data['ConferenceInstanceId'] = paper_data['ConferenceInstanceId'].apply(lambda x: float(x)).apply(lambda x: -1 if np.isnan(x) else int(x))
            paper_data = paper_data.merge(loaded_csv['ConferenceInstances'], on='ConferenceInstanceId', how='left')

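            # Normalize missing values: NaNs become -1, and rank/count/level/year/id
            # columns are cast back to int where they came in as floats.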
            for csv_data in [author_data, field_of_study_data, institution_data, paper_data]:
                csv_data.columns = csv_data.columns.str.strip()
                for col in csv_data.columns:
                    csv_data[col] = csv_data[col].apply(lambda x: -1 if isinstance(x, float) and np.isnan(x) else x)
                    if 'rank' in col.lower() or 'count' in col.lower() or 'level' in col.lower() or 'year' in col.lower() or col.lower().endswith('id'):
                        csv_data[col] = csv_data[col].apply(lambda x: int(x) if isinstance(x, float) else x)

            mag_csv = {
                'author': author_data,
                'institution': institution_data,
                'field_of_study': field_of_study_data,
                'paper': paper_data
            }

            for t in ['author', 'institution', 'field_of_study', 'paper']:
                mag_csv[t].to_csv(osp.join(self.mag_metadata_cache_dir, f'{t}_data.csv'), index=False)
            author_data, paper_data = mag_csv['author'], mag_csv['paper']
            field_of_study_data = mag_csv['field_of_study']
            institution_data = mag_csv['institution']

        author_data['type'] = 'author'
        author_data.rename(columns={'id': 'id', 'AuthorId': 'mag_id'}, inplace=True)

        institution_data['type'] = 'institution'
        institution_data.rename(columns={'id': 'id', 'AffiliationId': 'mag_id'}, inplace=True)

        field_of_study_data['type'] = 'field_of_study'
        field_of_study_data.rename(columns={'id': 'id', 'FieldOfStudyId': 'mag_id'}, inplace=True)

        paper_data['type'] = 'paper'
        paper_data.rename(columns={'id': 'id', 'PaperId': 'mag_id'}, inplace=True)
        return author_data, field_of_study_data, institution_data, paper_data

    def load_english_paper_text(self,
                                mag_ids: list,
                                download_cache: bool = True) -> pd.DataFrame:
        """
        Load English text data for the papers.

        Args:
            mag_ids (list): List of MAG IDs for the papers.
            download_cache (bool): Whether to download cached data.

        Returns:
            DataFrame: DataFrame containing English titles and abstracts.
        """
        def is_english(text):
            try:
                return detect(text) == 'en'
            except Exception:
                return False

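        # If the merged/filtered TSV is not cached yet, either pull the
        # prefiltered file from the STaRK HF dataset or rebuild it from the raw
        # ogbn-papers100M title/abstract dumps.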
        if not osp.exists(self.merged_filtered_path):
            if download_cache:
                merged_filtered_zip_path = self.merged_filtered_path.replace('tsv', 'zip')
                download_hf_file(
                    DATASET["repo"], DATASET["raw"],
                    repo_type="dataset", save_as_file=merged_filtered_zip_path
                )
                extract_zip(merged_filtered_zip_path, osp.dirname(self.merged_filtered_path))
            else:
                if not osp.exists(self.title_path):
                    raw_text_path = download_url(RAW_DATA['ogbn_papers100M'], self.text_root)
                    extract_zip(raw_text_path, self.text_root)
                print('Start reading title...')
                title = pd.read_csv(self.title_path, sep='\t', header=None)
                title.columns = ["mag_id", "title"]
                print('Filtering titles in English...')

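                # Restrict to papers that are ogbn-mag nodes, then keep rows whose
                # title is detected as English.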
                title = title[title['mag_id'].isin(mag_ids)]
                title_en = title[title['title'].apply(is_english)]

                print('Start reading abstract...')
                abstract = pd.read_csv(self.abstract_path, sep='\t', header=None)
                abstract.columns = ["mag_id", "abstract"]
                print('Filtering abstracts in English...')

                abstract = abstract[abstract['mag_id'].isin(mag_ids)]
                abstract_en = abstract[abstract['abstract'].apply(is_english)]

                print('Start merging titles and abstracts...')
                title_abs_en = pd.merge(title_en, abstract_en, how="outer", on="mag_id", sort=True)
                title_abs_en.to_csv(self.merged_filtered_path, sep="\t", header=True, index=False)

        print('Loading merged and filtered titles and abstracts (English)...')
        title_abs_en = pd.read_csv(self.merged_filtered_path, sep='\t')
        title_abs_en.columns = ['mag_id', 'title', 'abstract']
        print('Done!')

        return title_abs_en

    def get_map(self, df):
        """
        Create mappings between MAG IDs and internal IDs.

        Args:
            df (DataFrame): DataFrame containing MAG IDs.

        Returns:
            tuple: Mappings from MAG IDs to internal IDs and vice versa.
        """
        mag2id, id2mag = {}, {}
        for idx in range(len(df)):
            mag2id[df['mag_id'][idx]] = idx
            id2mag[idx] = df['mag_id'][idx]
        return mag2id, id2mag

    def get_doc_info(self,
                     idx: int,
                     compact: bool = False,
                     add_rel: bool = False,
                     n_rel: int = -1) -> str:
        """
        Get document information for the specified node.

        Args:
            idx (int): Index of the node.
            compact (bool): Whether to compact the text.
            add_rel (bool): Whether to add relation information.
            n_rel (int): Number of relations to add. Default is -1, which includes all relations.

        Returns:
            str: Document information.
        """
        node = self[idx]
        if node.type == 'author':
            doc = f'- author name: {node.DisplayName}\n'
            if node.PaperCount != -1:
                doc += f'- author paper count: {node.PaperCount}\n'
            if node.CitationCount != -1:
                doc += f'- author citation count: {node.CitationCount}\n'
            doc = doc.replace('-1', 'Unknown')

        elif node.type == 'paper':
            doc = f' - paper title: {node.title}\n'
            doc += ' - abstract: ' + node.abstract.replace('\r', '').rstrip('\n') + '\n'
            if str(node.Date) != '-1':
                doc += f' - publication date: {node.Date}\n'
            if str(node.OriginalVenue) != '-1':
                doc += f' - venue: {node.OriginalVenue}\n'
            elif str(node.JournalDisplayName) != '-1':
                doc += f' - journal: {node.JournalDisplayName}\n'
            elif str(node.ConferenceSeriesDisplayName) != '-1':
                doc += f' - conference: {node.ConferenceSeriesDisplayName}\n'
            elif str(node.ConferenceInstancesDisplayName) != '-1':
                doc += f' - conference: {node.ConferenceInstancesDisplayName}\n'

        elif node.type == 'field_of_study':
            doc = f' - field of study: {node.DisplayName}\n'
            if node.PaperCount != -1:
                doc += f'- field paper count: {node.PaperCount}\n'
            if node.CitationCount != -1:
                doc += f'- field citation count: {node.CitationCount}\n'
            doc = doc.replace('-1', 'Unknown')

        elif node.type == 'institution':
            doc = f' - institution: {node.DisplayName}\n'
            if node.PaperCount != -1:
                doc += f'- institution paper count: {node.PaperCount}\n'
            if node.CitationCount != -1:
                doc += f'- institution citation count: {node.CitationCount}\n'
            doc = doc.replace('-1', 'Unknown')

        if add_rel and node.type == 'paper':
            doc += self.get_rel_info(idx, n_rel=n_rel)

        if compact:
            doc = compact_text(doc)
        return doc

    def get_rel_info(self,
                     idx: int,
                     rel_types: Union[list, None] = None,
                     n_rel: int = -1) -> str:
        """
        Get relation information for the specified node.

        Args:
            idx (int): Index of the node.
            rel_types (Union[list, None]): List of relation types to include, or None to include all relation types.
            n_rel (int): Number of relations to include. Default is -1, which includes all relations.

        Returns:
            doc (str): Relation information.
        """
        doc = ''
        rel_types = self.rel_type_lst() if rel_types is None else rel_types
        for edge_t in rel_types:
            node_ids = torch.LongTensor(self.get_neighbor_nodes(idx, edge_t)).tolist()
            if not node_ids:
                continue
            node_type = self.node_types[node_ids[0]]
            str_edge = edge_t.replace('___', ' ')
            doc += f"\n{str_edge}: "

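            # When n_rel is positive, subsample the (typically long) citation list;
            # all other relation types are listed in full.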
            if n_rel > 0 and edge_t == 'paper___cites___paper':
                perm = torch.randperm(len(node_ids))[:n_rel].tolist()
                node_ids = [node_ids[i] for i in perm]
            neighbors = []
            for i in node_ids:
                if self[i].type == 'paper':
                    neighbors.append(f'\"{self[i].title}\"')
                elif self[i].type == 'author':
                    if str(self[i].DisplayName) != '-1':
                        institutions = self.get_neighbor_nodes(i, "author___affiliated_with___institution")
                        for inst in institutions:
                            assert self[inst].type == 'institution'
                        str_institutions = [self[j].DisplayName for j in institutions if str(self[j].DisplayName) != '-1']
                        if str_institutions:
                            str_institutions = ', '.join(str_institutions)
                            neighbors.append(f'{self[i].DisplayName} ({str_institutions})')
                        else:
                            neighbors.append(f'{self[i].DisplayName}')
                else:
                    if str(self[i].DisplayName) != '-1':
                        neighbors.append(f'{self[i].DisplayName}')
            neighbors = '(' + ', '.join(neighbors) + '),'
            doc += neighbors
        if doc:
            doc = '- relations:\n' + doc
        return doc

    def _process_raw(self):
        """
        Process raw data for the MAG dataset.

        Returns:
            processed_data (dict): Processed data.
        """
        NodePropPredDataset(name='ogbn-mag', root=self.raw_data_dir)
        author_data, field_of_study_data, institution_data, paper_data = self.load_meta_data()
        paper_text_data = self.load_english_paper_text(paper_data['mag_id'].tolist())

        print('Processing graph data...')
        author_id_to_mag = {row['id']: row['mag_id'] for _, row in author_data.iterrows()}
        institution_id_to_mag = {row['id']: row['mag_id'] for _, row in institution_data.iterrows()}
        field_of_study_id_to_mag = {row['id']: row['mag_id'] for _, row in field_of_study_data.iterrows()}
        paper_mapping = pd.read_csv(osp.join(self.ogbn_mag_mapping_dir, "paper_entidx2name.csv.gz"), names=['id', 'mag_id'], skiprows=[0])
        mag_to_paper_id, paper_id_to_mag = self.get_map(paper_mapping)

        unique_paper_id = paper_text_data['mag_id'].unique()
        unique_paper_id = torch.unique(torch.tensor(unique_paper_id))
        node_type_edge = {
            0: 'author___writes___paper',
            2: 'paper___has_topic___field_of_study',
            3: 'paper___cites___paper'
        }
        node_type_overlapping_node = {}
        node_type_overlapping_edge = {}

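        # Translate the MAG ids of papers that have English text into ogbn-mag
        # paper node indices; only these papers and their neighborhoods are kept.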
        unique_paper_id_list = unique_paper_id.tolist()
        mapping_list = [mag_to_paper_id.get(k, k) for k in tqdm(unique_paper_id_list)]
        unique_paper_id = torch.tensor(mapping_list)

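        # For each paper-incident relation, keep only edges touching a retained
        # paper (for citations, both endpoints must be retained), and record the
        # surviving nodes and edges per node type.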
        print('Start loading edge data...')
        for node_type, paper_rel in node_type_edge.items():
            print(node_type, paper_rel)
            edge, edge_num = self.load_edge(paper_rel)

            if node_type == 3:
                target_array = unique_paper_id.numpy()
                edge_array = edge.numpy()
                mask = np.isin(edge_array[0], target_array) & np.isin(edge_array[1], target_array)
                valid_edges_array = edge_array[:, mask]
                valid_edges_tensor = torch.from_numpy(valid_edges_array)
                node_type_overlapping_node[node_type] = unique_paper_id
                node_type_overlapping_edge[node_type] = valid_edges_tensor
                print(f'{node_type} has {unique_paper_id.shape[0]} nodes left, and {valid_edges_tensor.t().shape[0]} edges left.')
                continue
            else:
                edge = edge.t()
                connected_edges_list = []
                for target_node in tqdm(unique_paper_id):
                    if node_type == 0:
                        mask = edge[:, 1] == target_node.item()
                        current_connected_edges = edge[mask].clone()
                    elif node_type == 2:
                        mask = edge[:, 0] == target_node.item()
                        current_connected_edges = edge[mask].clone()

                    connected_edges_list.append(current_connected_edges)
                    del mask
                    del current_connected_edges

                connected_edges = torch.cat(connected_edges_list, dim=0)
                if node_type == 0:
                    other_ends = torch.unique(connected_edges.t()[0])
                elif node_type == 2:
                    other_ends = torch.unique(connected_edges.t()[1])

                node_type_overlapping_node[node_type] = other_ends
                node_type_overlapping_edge[node_type] = connected_edges.t()
                print(f'{node_type} has {other_ends.shape[0]} nodes left, and {connected_edges.shape[0]} edges left.')

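        # Institutions are attached through the retained authors: keep only
        # affiliation edges whose author survived the filtering above.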
        edge, edge_num = self.load_edge('author___affiliated_with___institution')
        edge = edge.t()
        connected_edges_list = []
        for target_node in node_type_overlapping_node[0]:
            mask = edge[:, 0] == target_node
            current_connected_edges = edge[mask].clone()

            connected_edges_list.append(current_connected_edges)

        connected_edges = torch.cat(connected_edges_list, dim=0)
        other_ends = torch.unique(connected_edges.t()[1])

        node_type_overlapping_node[1] = other_ends
        node_type_overlapping_edge[1] = connected_edges.t()
        print(f'1 has {other_ends.shape[0]} nodes left, and {connected_edges.shape[0]} edges left.')

        tot_n = sum([len(node_type_overlapping_node[i]) for i in range(4)])

        domain_mappings = {
            0: author_id_to_mag,
            1: institution_id_to_mag,
            2: field_of_study_id_to_mag,
            3: paper_id_to_mag
        }
        new_domain_mappings = {}
        domain_old_to_new = {}
        id_to_mag = {}
        offset = 0
        node_type_overlapping_node_sort = {k: node_type_overlapping_node[k] for k in sorted(node_type_overlapping_node.keys())}

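        # Re-index the surviving nodes into one contiguous id space: authors first,
        # then institutions, fields of study, and papers, with each block offset by
        # the sizes of the preceding types.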
        print('Start re-indexing...')
        for i, remain_node in node_type_overlapping_node_sort.items():
            old_to_new_mappings = {key: id + offset for id, key in enumerate(remain_node.tolist())}
            updated_dict = {value: domain_mappings[i][key] for key, value in old_to_new_mappings.items()}
            print(f'{i} has {len(updated_dict)} nodes left')
            domain_old_to_new[i] = old_to_new_mappings
            id_to_mag.update(updated_dict)
            new_domain_mappings[i] = updated_dict
            offset += len(node_type_overlapping_node[i])

        assert offset == tot_n
        edges_full = torch.cat([node_type_overlapping_edge[i] for i in range(4)], dim=1)

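        # Remap edge endpoints into the new id space. For each relation, dict1
        # remaps the source column and dict2 the destination column (e.g. writes:
        # author -> paper, has_topic: paper -> field_of_study).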
        d_of_mapping_dict = {
            0: [domain_old_to_new[0], domain_old_to_new[3]],
            1: [domain_old_to_new[0], domain_old_to_new[1]],
            2: [domain_old_to_new[3], domain_old_to_new[2]],
            3: [domain_old_to_new[3], domain_old_to_new[3]]
        }

        for i, remain_edge in tqdm(node_type_overlapping_edge.items()):
            edges = remain_edge[:2]
            edge_types = remain_edge[2]
            new_edges = edges.clone()
            dict1 = d_of_mapping_dict[i][0]
            dict2 = d_of_mapping_dict[i][1]

            for old, new in dict1.items():
                new_edges[0, edges[0] == old] = new

            for old, new in dict2.items():
                new_edges[1, edges[1] == old] = new

            final_edges = torch.cat([new_edges, edge_types.unsqueeze(0)], dim=0)
            node_type_overlapping_edge[i] = final_edges

        edges_final = torch.cat([node_type_overlapping_edge[i] for i in range(4)], dim=1)
        assert edges_final.shape == edges_full.shape
        edge_index = torch.LongTensor(edges_final[:2])
        edge_types = torch.LongTensor(edges_final[2])

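        # Attach the new contiguous ids to each metadata frame and drop rows for
        # nodes that were filtered out.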
        author_data['new_id'] = author_data['id'].map(domain_old_to_new[0])
        author_data.dropna(subset=['new_id'], inplace=True)
        author_data['new_id'] = author_data['new_id'].astype(int)
        institution_data['new_id'] = institution_data['id'].map(domain_old_to_new[1])
        institution_data.dropna(subset=['new_id'], inplace=True)
        institution_data['new_id'] = institution_data['new_id'].astype(int)
        field_of_study_data['new_id'] = field_of_study_data['id'].map(domain_old_to_new[2])
        field_of_study_data.dropna(subset=['new_id'], inplace=True)
        field_of_study_data['new_id'] = field_of_study_data['new_id'].astype(int)
        paper_data['new_id'] = paper_data['id'].map(domain_old_to_new[3])
        paper_data.dropna(subset=['new_id'], inplace=True)
        paper_data['new_id'] = paper_data['new_id'].astype(int)

        merged_df = pd.merge(paper_data, paper_text_data, on='mag_id', how='outer')
        merged_df.dropna(subset=['new_id'], inplace=True)
        merged_df['new_id'] = merged_df['new_id'].astype(int)
        merged_df['mag_id'] = merged_df['mag_id'].astype(int)
        merged_df = merged_df.drop_duplicates(subset=['new_id'])

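        # Build the per-node attribute dict keyed by the new ids, the node-type
        # tensor, and the final payload expected by the SKB constructor.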
        node_frame = {0: author_data, 1: institution_data, 2: field_of_study_data, 3: merged_df}
        node_info = {}
        node_types = []
        for node_type, frame in tqdm(node_frame.items()):
            for idx, row in frame.iterrows():
                node_info[row['new_id']] = row.to_dict()
                node_types.append(node_type)
        node_types = torch.tensor(node_types)
        if len(node_types) != tot_n:
            raise ValueError('node_types length does not match tot_n')

        processed_data = {
            'node_info': node_info,
            'edge_index': edge_index,
            'edge_types': edge_types,
            'node_types': node_types
        }

        print('Start saving processed data...')
        save_files(save_path=self.processed_data_dir, **processed_data)

        return processed_data