Commit 2edc1f1e by zlj

fix some bugs found in the data maker and UVM

parent 2d0be982
sampling:
- layer: 2
neighbor:
- 10
- 10
strategy: 'uniform'
prop_time: False
history: 1
duration: 0
num_thread: 32
memory:
- type: 'none'
dim_out: 0
gnn:
- arch: 'transformer_attention'
use_src_emb: False
use_dst_emb: False
layer: 2
att_head: 2
dim_time: 100
dim_out: 100
train:
- epoch: 100
batch_size: 600
lr: 0.0001
dropout: 0.1
att_dropout: 0.1
all_on_gpu: True
\ No newline at end of file
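A minimal sketch of how a config like the one above can be loaded, assuming parse_config (used by the training scripts below) behaves like a plain YAML load; the TGN.yml path is hypothetical and the real starrygl.module.utils.parse_config may post-process these dicts.
import yaml

with open('./config/TGN.yml') as f:
    cfg = yaml.safe_load(f)
# each top-level section is a single-element list holding one dict of options
sample_param = cfg['sampling'][0]   # {'layer': 2, 'neighbor': [10, 10], ...}
memory_param = cfg['memory'][0]     # {'type': 'none', 'dim_out': 0}
gnn_param    = cfg['gnn'][0]
train_param  = cfg['train'][0]
print(train_param['batch_size'], train_param['lr'])  # 600 0.0001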
#pragma once
#include <head.h>
// Stub for collision-checked negative sampling; th::Tensor and NodeIDType come from head.h.
void random_sample_with_collision_check(int size, th::Tensor batch_src_node_id, th::Tensor batch_dst_node_id){
    // raw pointers into the batch tensors for direct indexing
    auto src_ptr = batch_src_node_id.data_ptr<NodeIDType>();
    auto dst_ptr = batch_dst_node_id.data_ptr<NodeIDType>();
    // num_threads is only valid on a parallel region, so use 'parallel for'
    #pragma omp parallel for num_threads(10)
    for(int i = 0; i < size; i++){
        // TODO: sample a candidate edge and re-draw on collision with (src_ptr[i], dst_ptr[i])
    }
}
\ No newline at end of file
......@@ -95,8 +95,8 @@ if e_feat is not None:
data.edge_attr = e_feat
data.train_mask = (torch.from_numpy(np.array(df.ext_roll.values)) == 0)
data.test_mask = (torch.from_numpy(np.array(df.ext_roll.values)) == 1)
data.val_mask = (torch.from_numpy(np.array(df.ext_roll.values)) == 2)
data.val_mask = (torch.from_numpy(np.array(df.ext_roll.values)) == 1)
data.test_mask = (torch.from_numpy(np.array(df.ext_roll.values)) == 2)
sample_graph['train_mask'] = data.train_mask[sample_eid]
sample_graph['test_mask'] = data.test_mask[sample_eid]
sample_graph['val_mask'] = data.val_mask[sample_eid]
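# Note on the fix above: ext_roll encodes the split (0 = train, 1 = validation,
# 2 = test); the old code had the val/test masks swapped. A sketch that reads the
# column once instead of three times:
# ext_roll = torch.from_numpy(np.array(df.ext_roll.values))
# data.train_mask, data.val_mask, data.test_mask = ext_roll == 0, ext_roll == 1, ext_roll == 2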
......@@ -106,24 +106,26 @@ data.y = torch.zeros(edge_index.shape[1])
edge_index_dict = {}
edge_index_dict['edata'] = data.edge_index
edge_index_dict['sample_data'] = data.sample_graph['edge_index']
edge_index_dict['neg_data'] = torch.cat([neg_src.view(1, -1),
dst.view(-1, 1).repeat(1, neg_nums).
reshape(1, -1)], dim=0)
#edge_index_dict['neg_data'] = torch.cat([neg_src.view(1, -1),
# dst.view(-1, 1).repeat(1, neg_nums).
# reshape(1, -1)], dim=0)
data.edge_index_dict = edge_index_dict
edge_weight_dict = {}
edge_weight_dict['edata'] = 2*neg_nums
edge_weight_dict['edata'] = 1*neg_nums
edge_weight_dict['sample_data'] = 1*neg_nums
edge_weight_dict['neg_data'] = 1
#partition_save('./dataset/here/'+data_name, data, 1, 'metis_for_tgnn',
# edge_weight_dict=edge_weight_dict)
#partition_save('./dataset/here/'+data_name, data, 2, 'metis_for_tgnn',
# edge_weight_dict=edge_weight_dict)
#partition_save('./dataset/here/'+data_name, data, 4, 'metis_for_tgnn',
# edge_weight_dict=edge_weight_dict)
#partition_save('./dataset/here/'+data_name, data, 8, 'metis_for_tgnn',
# edge_weight_dict=edge_weight_dict)
partition_save('./dataset/here/'+data_name, data, 16, 'metis_for_tgnn',
#edge_weight_dict['neg_data'] = 1
partition_save('/mnt/data/part_data/v2/here/'+data_name, data, 1, 'metis_for_tgnn',
edge_weight_dict=edge_weight_dict)
partition_save('/mnt/data/part_data/v2/here/'+data_name, data, 2, 'metis_for_tgnn',
edge_weight_dict=edge_weight_dict)
partition_save('/mnt/data/part_data/v2/here/'+data_name, data, 4, 'metis_for_tgnn',
edge_weight_dict=edge_weight_dict)
partition_save('/mnt/data/part_data/v2/here/'+data_name, data, 8, 'metis_for_tgnn',
edge_weight_dict=edge_weight_dict)
#partition_save('./dataset/here/'+data_name, data, 16, 'metis_for_tgnn',
# edge_weight_dict=edge_weight_dict)
#
# partition_save('./dataset/here/'+data_name, data, 4, 'metis_for_tgnn',
# edge_weight_dict=edge_weight_dict )
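# The four active partition_save calls differ only in num_parts, so an equivalent
# loop (assuming the same signature) avoids the repetition:
# for num_parts in (1, 2, 4, 8):
#     partition_save('/mnt/data/part_data/v2/here/' + data_name, data, num_parts,
#                    'metis_for_tgnn', edge_weight_dict=edge_weight_dict)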
......
import argparse
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
parser = argparse.ArgumentParser(
description="Download TGB link-prediction datasets",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--data_name', default='tgbl-wiki', type=str, metavar='W',
help='name of dataset')
parser.add_argument('--num_neg_sample', default=1, type=int, metavar='W',
help='number of negative samples')
args = parser.parse_args()
name = args.data_name
dataset = PyGLinkPropPredDataset(name=name, root="datasets")
print(dataset)
\ No newline at end of file
import argparse
import os
import sys
import time
from os.path import abspath, join, dirname
from pathlib import Path

import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group

from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex
from starrygl.module.modules import GeneralModel
from starrygl.module.utils import parse_config, EarlyStopMonitor
from starrygl.sample.cache.fetch_cache import FetchFeatureCache
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.sample.memory.shared_mailbox import SharedMailBox
from starrygl.sample.sample_core.base import NegativeSampling
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
from starrygl.sample.part_utils.partition_tgnn import partition_load
from starrygl.sample.data_loader import DistributedDataLoader
from starrygl.sample.batch_data import SAMPLE_TYPE
from starrygl.sample.stream_manager import getPipelineManger
parser = argparse.ArgumentParser(
description="StarryGL distributed training example",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--rank', default=0, type=int, metavar='W',
help='rank of the current process')
parser.add_argument('--patience', type=int, default=5, help='Patience for early stopping')
parser.add_argument('--world_size', default=1, type=int, metavar='W',
help='total number of processes')
parser.add_argument('--dataname', default='WIKI', type=str, metavar='W',
help='name of dataset')  # the old default (int 1) did not match type=str
parser.add_argument('--model', default='TGN', type=str, metavar='W',
help='name of model')
args = parser.parse_args()
from sklearn.metrics import average_precision_score, roc_auc_score
import random
import dgl
import numpy as np

#os.environ['CUDA_VISIBLE_DEVICES'] = str(args.rank)
#os.environ["RANK"] = str(args.rank)
#os.environ["WORLD_SIZE"] = str(args.world_size)
#os.environ["LOCAL_RANK"] = str(0)
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
os.environ["MASTER_ADDR"] = '10.214.211.187'
os.environ["MASTER_PORT"] = '9337'
def seed_everything(seed=42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
seed_everything(1234)
def main():
print('main')
use_cuda = True
sample_param, memory_param, gnn_param, train_param = parse_config('./config/{}.yml'.format(args.model))
torch.set_num_threads(12)
ctx = DistributedContext.init(backend="nccl", use_gpu=True)
device_id = torch.cuda.current_device()
print('use cuda on',device_id)
pdata = partition_load("/mnt/data/part_data/dataset/here/{}".format(args.dataname), algo="metis_for_tgnn")
graph = DistributedGraphStore(pdata = pdata)
Path("./saved_models/").mkdir(parents=True, exist_ok=True)
Path("./saved_checkpoints/").mkdir(parents=True, exist_ok=True)
get_checkpoint_path = lambda \
epoch: f'./saved_checkpoints/{args.model}-{args.dataname}-{epoch}.pth'
gnn_param['dyrep'] = True if args.model == 'DyRep' else False
use_src_emb = gnn_param['use_src_emb'] if 'use_src_emb' in gnn_param else False
use_dst_emb = gnn_param['use_dst_emb'] if 'use_dst_emb' in gnn_param else False
sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph,mode = 'full')
mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat = pdata.edge_attr.shape[1] if pdata.edge_attr is not None else 0)
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10],graph_data=sample_graph, workers=15,policy = 'recent',graph_name = "wiki_train")
neg_sampler = NegativeSampling('triplet')
gnn_dim_node = 0 if graph.x is None else pdata.x.shape[1]
gnn_dim_edge = 0 if graph.edge_attr is None else pdata.edge_attr.shape[1]
print(gnn_dim_node,gnn_dim_edge)
avg_time = 0
MODEL_SAVE_PATH = f'./saved_models/{args.model}-{args.dataname}.pth'
if use_cuda:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param).cuda()
device = torch.device('cuda')
else:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param)
device = torch.device('cpu')
model.load_state_dict(torch.load(MODEL_SAVE_PATH))
\ No newline at end of file
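# Hypothetical follow-up using the helpers defined above: restore a specific
# epoch's weights through get_checkpoint_path rather than the final MODEL_SAVE_PATH.
# state = torch.load(get_checkpoint_path(10), map_location=device)
# model.load_state_dict(state)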
# 'wait' after a foreground command is a no-op, so the original back-to-back
# invocations are equivalent to this loop:
for name in WIKI REDDIT MOOC LASTFM DGraphFin ML25M TaoBao GDELT; do
    python data_maker.py --data_name "$name"
done
import pandas as pd
import numpy as np
import os
import torch
# Data is assumed to be the container class from this repo's DataLoader utilities
from utils.DataLoader import Data
def get_link_prediction_data(data_name: str, val_ratio: float, test_ratio: float):
"""
generate data for link prediction task (inductive & transductive settings)
:param dataset_name: str, dataset name
:param val_ratio: float, validation data ratio
:param test_ratio: float, test data ratio
:return: node_raw_features, edge_raw_features, (np.ndarray),
full_data, train_data, val_data, test_data, new_node_val_data, new_node_test_data, (Data object)
"""
# Load data and train val test split
graph_df = pd.read_csv('/mnt/nfs/fzz/TGL-DATA/'+data_name+'/edges.csv')
if os.path.exists('/mnt/nfs/fzz/TGL-DATA/'+data_name+'/node_features.pt'):
n_feat = torch.load('/mnt/nfs/fzz/TGL-DATA/'+data_name+'/node_features.pt')
else:
n_feat = None
if os.path.exists('/mnt/nfs/fzz/TGL-DATA/'+data_name+'/edge_features.pt'):
e_feat = torch.load('/mnt/nfs/fzz/TGL-DATA/'+data_name+'/edge_features.pt')
else:
e_feat = None
# get the timestamp of validate and test set
src_node_ids = torch.from_numpy(np.array(graph_df.src.values)).long()
dst_node_ids = torch.from_numpy(np.array(graph_df.dst.values)).long()
node_interact_times = torch.from_numpy(np.array(graph_df.time.values)).long()
train_mask = (torch.from_numpy(np.array(graph_df.ext_roll.values)) == 0)
test_mask = (torch.from_numpy(np.array(graph_df.ext_roll.values)) == 1)
val_mask = (torch.from_numpy(np.array(graph_df.ext_roll.values)) == 2)
val_time, test_time = list(np.quantile(graph_df.time, [(1 - val_ratio - test_ratio), (1 - test_ratio)]))  # edges.csv uses the 'time' column (see above)
# the setting of seed follows previous works
unique_node_ids = torch.cat((src_node_ids,dst_node_ids)).unique()
torch.manual_seed(2020)
test_node_set = torch.cat((src_node_ids[test_mask],dst_node_ids[test_mask])).unique()
train_node_set = torch.cat((src_node_ids[train_mask],dst_node_ids[train_mask])).unique()
new_test_node_indices = torch.randint(0, test_node_set.shape[0], (int(0.1 * unique_node_ids.shape[0]),))  # randint needs a size tuple
new_test_node_set = test_node_set[new_test_node_indices]
# mask for each source and destination to denote whether they are new test nodes
new_test_node_id_set = set(new_test_node_set.tolist())  # set membership is much faster than tensor membership
new_test_source_mask = graph_df.src.map(lambda x: x in new_test_node_id_set).values
new_test_destination_mask = graph_df.dst.map(lambda x: x in new_test_node_id_set).values
# mask that is True for edges whose source and destination are both not new test nodes
# (because we want to remove all edges involving any new test node)
observed_edges_mask = torch.from_numpy(np.logical_and(~new_test_source_mask, ~new_test_destination_mask))
mask = torch.isin(unique_node_ids, train_node_set, invert=True)
new_node_set = set(unique_node_ids[mask].tolist())
edge_contains_new_node_mask = torch.tensor([(int(s) in new_node_set or int(d) in new_node_set)
                                            for s, d in zip(src_node_ids.tolist(), dst_node_ids.tolist())])
new_node_val_mask = torch.logical_and(val_mask, edge_contains_new_node_mask)
new_node_test_mask = torch.logical_and(test_mask, edge_contains_new_node_mask)
# edge ids follow row order in edges.csv; labels fall back to zeros when the csv
# has no label column (an assumption, since the loader above does not read one)
edge_ids = torch.arange(src_node_ids.shape[0], dtype=torch.long)
labels = torch.from_numpy(graph_df.label.values) if 'label' in graph_df.columns else torch.zeros(src_node_ids.shape[0])
full_data = Data(src_node_ids=src_node_ids, dst_node_ids=dst_node_ids,
                 node_interact_times=node_interact_times, edge_ids=edge_ids, labels=labels)
train_data = Data(src_node_ids=src_node_ids[train_mask], dst_node_ids=dst_node_ids[train_mask],
                  node_interact_times=node_interact_times[train_mask],
                  edge_ids=edge_ids[train_mask], labels=labels[train_mask])
# validation and test data
val_data = Data(src_node_ids=src_node_ids[val_mask], dst_node_ids=dst_node_ids[val_mask],
                node_interact_times=node_interact_times[val_mask], edge_ids=edge_ids[val_mask], labels=labels[val_mask])
test_data = Data(src_node_ids=src_node_ids[test_mask], dst_node_ids=dst_node_ids[test_mask],
                 node_interact_times=node_interact_times[test_mask], edge_ids=edge_ids[test_mask], labels=labels[test_mask])
# validation and test with edges that have at least one new node (not in the training set)
new_node_val_data = Data(src_node_ids=src_node_ids[new_node_val_mask], dst_node_ids=dst_node_ids[new_node_val_mask],
                         node_interact_times=node_interact_times[new_node_val_mask],
                         edge_ids=edge_ids[new_node_val_mask], labels=labels[new_node_val_mask])
new_node_test_data = Data(src_node_ids=src_node_ids[new_node_test_mask], dst_node_ids=dst_node_ids[new_node_test_mask],
                          node_interact_times=node_interact_times[new_node_test_mask],
                          edge_ids=edge_ids[new_node_test_mask], labels=labels[new_node_test_mask])
return n_feat, e_feat, full_data, train_data, val_data, test_data, new_node_val_data, new_node_test_data
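# Hypothetical call, assuming the TGL-DATA layout above exists for 'WIKI':
# (n_feat, e_feat, full_data, train_data, val_data, test_data,
#  new_node_val_data, new_node_test_data) = get_link_prediction_data('WIKI', val_ratio=0.15, test_ratio=0.15)
# print(train_data.src_node_ids.shape[0], val_data.src_node_ids.shape[0], test_data.src_node_ids.shape[0])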
import torch
from sklearn.metrics import average_precision_score, roc_auc_score
def get_link_prediction_metrics(predicts: torch.Tensor, labels: torch.Tensor):
"""
get metrics for the link prediction task
:param predicts: Tensor, shape (num_samples, )
:param labels: Tensor, shape (num_samples, )
:return:
dictionary of metrics {'metric_name_1': metric_1, ...}
"""
predicts = predicts.cpu().detach().numpy()
labels = labels.cpu().numpy()
average_precision = average_precision_score(y_true=labels, y_score=predicts)
roc_auc = roc_auc_score(y_true=labels, y_score=predicts)
return {'average_precision': average_precision, 'roc_auc': roc_auc}
def get_node_classification_metrics(predicts: torch.Tensor, labels: torch.Tensor):
"""
get metrics for the node classification task
:param predicts: Tensor, shape (num_samples, )
:param labels: Tensor, shape (num_samples, )
:return:
dictionary of metrics {'metric_name_1': metric_1, ...}
"""
predicts = predicts.cpu().detach().numpy()
labels = labels.cpu().numpy()
roc_auc = roc_auc_score(y_true=labels, y_score=predicts)
return {'roc_auc': roc_auc}
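# Quick self-contained check of the helpers above on dummy scores:
# preds  = torch.tensor([0.9, 0.2, 0.8, 0.1])
# labels = torch.tensor([1.0, 0.0, 1.0, 0.0])
# get_link_prediction_metrics(predicts=preds, labels=labels)
# # -> {'average_precision': 1.0, 'roc_auc': 1.0}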
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import logging
import time
import argparse
import os
import json
from models.EdgeBank import edge_bank_link_prediction
from starrygl.evaluation.metrics import get_link_prediction_metrics, get_node_classification_metrics
from utils.utils import set_random_seed
from starrygl.sample.sample_core import EvaluateNegativeSampling
# NeighborSampler is imported as in the training scripts above; NegativeEdgeSampler
# (referenced in the annotations below) is assumed to live in the same sample_core package
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
from utils.DataLoader import Data
def evaluate_model_link_prediction(model_name: str, model: nn.Module, neighbor_sampler: NeighborSampler, evaluate_idx_data_loader: DataLoader,
evaluate_neg_edge_sampler: NegativeEdgeSampler, evaluate_data: Data, loss_func: nn.Module,
num_neighbors: int = 20, time_gap: int = 2000):
"""
evaluate models on the link prediction task
:param model_name: str, name of the model
:param model: nn.Module, the model to be evaluated
:param neighbor_sampler: NeighborSampler, neighbor sampler
:param evaluate_idx_data_loader: DataLoader, evaluate index data loader
:param evaluate_neg_edge_sampler: NegativeEdgeSampler, evaluate negative edge sampler
:param evaluate_data: Data, data to be evaluated
:param loss_func: nn.Module, loss function
:param num_neighbors: int, number of neighbors to sample for each node
:param time_gap: int, time gap for neighbors to compute node features
:return:
"""
# Ensures the random sampler uses a fixed seed for evaluation (i.e. we always sample the same negatives for validation / test set)
assert evaluate_neg_edge_sampler.seed is not None
evaluate_neg_edge_sampler.reset_random_state()
if model_name in ['DyRep', 'TGAT', 'TGN', 'CAWN', 'TCL', 'GraphMixer', 'DyGFormer']:
# evaluation phase use all the graph information
model[0].set_neighbor_sampler(neighbor_sampler)
model.eval()
with torch.no_grad():
# store evaluate losses and metrics
evaluate_losses, evaluate_metrics = [], []
evaluate_idx_data_loader_tqdm = tqdm(evaluate_idx_data_loader, ncols=120)
for batch_idx, evaluate_data_indices in enumerate(evaluate_idx_data_loader_tqdm):
evaluate_data_indices = evaluate_data_indices.numpy()
batch_src_node_ids, batch_dst_node_ids, batch_node_interact_times, batch_edge_ids = \
evaluate_data.src_node_ids[evaluate_data_indices], evaluate_data.dst_node_ids[evaluate_data_indices], \
evaluate_data.node_interact_times[evaluate_data_indices], evaluate_data.edge_ids[evaluate_data_indices]
if evaluate_neg_edge_sampler.negative_sample_strategy != 'random':
batch_neg_src_node_ids, batch_neg_dst_node_ids = evaluate_neg_edge_sampler.sample(size=len(batch_src_node_ids),
batch_src_node_ids=batch_src_node_ids,
batch_dst_node_ids=batch_dst_node_ids,
current_batch_start_time=batch_node_interact_times[0],
current_batch_end_time=batch_node_interact_times[-1])
else:
_, batch_neg_dst_node_ids = evaluate_neg_edge_sampler.sample(size=len(batch_src_node_ids))
batch_neg_src_node_ids = batch_src_node_ids
# we need to compute for positive and negative edges respectively, because the new sampling strategy (for evaluation) allows the negative source nodes to be
# different from the source nodes, this is different from previous works that just replace destination nodes with negative destination nodes
if model_name in ['TGAT', 'CAWN', 'TCL']:
# get temporal embedding of source and destination nodes
# two Tensors, with shape (batch_size, node_feat_dim)
batch_src_node_embeddings, batch_dst_node_embeddings = \
model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_src_node_ids,
dst_node_ids=batch_dst_node_ids,
node_interact_times=batch_node_interact_times,
num_neighbors=num_neighbors)
# get temporal embedding of negative source and negative destination nodes
# two Tensors, with shape (batch_size, node_feat_dim)
batch_neg_src_node_embeddings, batch_neg_dst_node_embeddings = \
model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_neg_src_node_ids,
dst_node_ids=batch_neg_dst_node_ids,
node_interact_times=batch_node_interact_times,
num_neighbors=num_neighbors)
elif model_name in ['JODIE', 'DyRep', 'TGN']:
# note that negative nodes do not change the memories while the positive nodes change the memories,
# we need to first compute the embeddings of negative nodes for memory-based models
# get temporal embedding of negative source and negative destination nodes
# two Tensors, with shape (batch_size, node_feat_dim)
batch_neg_src_node_embeddings, batch_neg_dst_node_embeddings = \
model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_neg_src_node_ids,
dst_node_ids=batch_neg_dst_node_ids,
node_interact_times=batch_node_interact_times,
edge_ids=None,
edges_are_positive=False,
num_neighbors=num_neighbors)
# get temporal embedding of source and destination nodes
# two Tensors, with shape (batch_size, node_feat_dim)
batch_src_node_embeddings, batch_dst_node_embeddings = \
model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_src_node_ids,
dst_node_ids=batch_dst_node_ids,
node_interact_times=batch_node_interact_times,
edge_ids=batch_edge_ids,
edges_are_positive=True,
num_neighbors=num_neighbors)
elif model_name in ['GraphMixer']:
# get temporal embedding of source and destination nodes
# two Tensors, with shape (batch_size, node_feat_dim)
batch_src_node_embeddings, batch_dst_node_embeddings = \
model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_src_node_ids,
dst_node_ids=batch_dst_node_ids,
node_interact_times=batch_node_interact_times,
num_neighbors=num_neighbors,
time_gap=time_gap)
# get temporal embedding of negative source and negative destination nodes
# two Tensors, with shape (batch_size, node_feat_dim)
batch_neg_src_node_embeddings, batch_neg_dst_node_embeddings = \
model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_neg_src_node_ids,
dst_node_ids=batch_neg_dst_node_ids,
node_interact_times=batch_node_interact_times,
num_neighbors=num_neighbors,
time_gap=time_gap)
elif model_name in ['DyGFormer']:
# get temporal embedding of source and destination nodes
# two Tensors, with shape (batch_size, node_feat_dim)
batch_src_node_embeddings, batch_dst_node_embeddings = \
model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_src_node_ids,
dst_node_ids=batch_dst_node_ids,
node_interact_times=batch_node_interact_times)
# get temporal embedding of negative source and negative destination nodes
# two Tensors, with shape (batch_size, node_feat_dim)
batch_neg_src_node_embeddings, batch_neg_dst_node_embeddings = \
model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_neg_src_node_ids,
dst_node_ids=batch_neg_dst_node_ids,
node_interact_times=batch_node_interact_times)
else:
raise ValueError(f"Wrong value for model_name {model_name}!")
# get positive and negative probabilities, shape (batch_size, )
positive_probabilities = model[1](input_1=batch_src_node_embeddings, input_2=batch_dst_node_embeddings).squeeze(dim=-1).sigmoid()
negative_probabilities = model[1](input_1=batch_neg_src_node_embeddings, input_2=batch_neg_dst_node_embeddings).squeeze(dim=-1).sigmoid()
predicts = torch.cat([positive_probabilities, negative_probabilities], dim=0)
labels = torch.cat([torch.ones_like(positive_probabilities), torch.zeros_like(negative_probabilities)], dim=0)
loss = loss_func(input=predicts, target=labels)
evaluate_losses.append(loss.item())
evaluate_metrics.append(get_link_prediction_metrics(predicts=predicts, labels=labels))
evaluate_idx_data_loader_tqdm.set_description(f'evaluate for the {batch_idx + 1}-th batch, evaluate loss: {loss.item()}')
return evaluate_losses, evaluate_metrics
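# Hypothetical aggregation of the per-batch outputs above, mirroring the EdgeBank
# reporting further down:
# val_losses, val_metrics = evaluate_model_link_prediction(...)
# print(f'val loss: {np.mean(val_losses):.4f}')
# for name in val_metrics[0]:
#     print(f'val {name}: {np.mean([m[name] for m in val_metrics]):.4f}')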
def evaluate_edge_bank_link_prediction(args: argparse.Namespace, train_data: Data, val_data: Data, test_idx_data_loader: DataLoader,
test_neg_edge_sampler: NegativeEdgeSampler, test_data: Data):
"""
evaluate the EdgeBank model for link prediction
:param args: argparse.Namespace, configuration
:param train_data: Data, train data
:param val_data: Data, validation data
:param test_idx_data_loader: DataLoader, test index data loader
:param test_neg_edge_sampler: NegativeEdgeSampler, test negative edge sampler
:param test_data: Data, test data
:return:
"""
# generate the train_validation split of the data: needed for constructing the memory for EdgeBank
train_val_data = Data(src_node_ids=np.concatenate([train_data.src_node_ids, val_data.src_node_ids]),
dst_node_ids=np.concatenate([train_data.dst_node_ids, val_data.dst_node_ids]),
node_interact_times=np.concatenate([train_data.node_interact_times, val_data.node_interact_times]),
edge_ids=np.concatenate([train_data.edge_ids, val_data.edge_ids]),
labels=np.concatenate([train_data.labels, val_data.labels]))
test_metric_all_runs = []
for run in range(args.num_runs):
set_random_seed(seed=run)
args.seed = run
args.save_result_name = f'{args.negative_sample_strategy}_negative_sampling_{args.model_name}_seed{args.seed}'
# set up logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
os.makedirs(f"./logs/{args.model_name}/{args.dataset_name}/{args.save_result_name}/", exist_ok=True)
# create file handler that logs debug and higher level messages
fh = logging.FileHandler(f"./logs/{args.model_name}/{args.dataset_name}/{args.save_result_name}/{str(time.time())}.log")
fh.setLevel(logging.DEBUG)
# create console handler with a higher log level
ch = logging.StreamHandler()
ch.setLevel(logging.WARNING)
# create formatter and add it to the handlers
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
fh.setFormatter(formatter)
ch.setFormatter(formatter)
# add the handlers to logger
logger.addHandler(fh)
logger.addHandler(ch)
run_start_time = time.time()
logger.info(f"********** Run {run + 1} starts. **********")
logger.info(f'configuration is {args}')
loss_func = nn.BCELoss()
# evaluate EdgeBank
logger.info(f'get final performance on dataset {args.dataset_name}...')
# Ensures the random sampler uses a fixed seed for evaluation (i.e. we always sample the same negatives for validation / test set)
assert test_neg_edge_sampler.seed is not None
test_neg_edge_sampler.reset_random_state()
test_losses, test_metrics = [], []
test_idx_data_loader_tqdm = tqdm(test_idx_data_loader, ncols=120)
for batch_idx, test_data_indices in enumerate(test_idx_data_loader_tqdm):
test_data_indices = test_data_indices.numpy()
batch_src_node_ids, batch_dst_node_ids, batch_node_interact_times = \
test_data.src_node_ids[test_data_indices], test_data.dst_node_ids[test_data_indices], \
test_data.node_interact_times[test_data_indices]
if test_neg_edge_sampler.negative_sample_strategy != 'random':
batch_neg_src_node_ids, batch_neg_dst_node_ids = test_neg_edge_sampler.sample(size=len(batch_src_node_ids),
batch_src_node_ids=batch_src_node_ids,
batch_dst_node_ids=batch_dst_node_ids,
current_batch_start_time=batch_node_interact_times[0],
current_batch_end_time=batch_node_interact_times[-1])
else:
_, batch_neg_dst_node_ids = test_neg_edge_sampler.sample(size=len(batch_src_node_ids))
batch_neg_src_node_ids = batch_src_node_ids
positive_edges = (batch_src_node_ids, batch_dst_node_ids)
negative_edges = (batch_neg_src_node_ids, batch_neg_dst_node_ids)
# incorporate the testing data before the current batch to history_data, which is similar to memory-based models
history_data = Data(src_node_ids=np.concatenate([train_val_data.src_node_ids, test_data.src_node_ids[: test_data_indices[0]]]),
dst_node_ids=np.concatenate([train_val_data.dst_node_ids, test_data.dst_node_ids[: test_data_indices[0]]]),
node_interact_times=np.concatenate([train_val_data.node_interact_times, test_data.node_interact_times[: test_data_indices[0]]]),
edge_ids=np.concatenate([train_val_data.edge_ids, test_data.edge_ids[: test_data_indices[0]]]),
labels=np.concatenate([train_val_data.labels, test_data.labels[: test_data_indices[0]]]))
# perform link prediction for EdgeBank
positive_probabilities, negative_probabilities = edge_bank_link_prediction(history_data=history_data,
positive_edges=positive_edges,
negative_edges=negative_edges,
edge_bank_memory_mode=args.edge_bank_memory_mode,
time_window_mode=args.time_window_mode,
time_window_proportion=args.test_ratio)
predicts = torch.from_numpy(np.concatenate([positive_probabilities, negative_probabilities])).float()
labels = torch.cat([torch.ones(len(positive_probabilities)), torch.zeros(len(negative_probabilities))], dim=0)
loss = loss_func(input=predicts, target=labels)
test_losses.append(loss.item())
test_metrics.append(get_link_prediction_metrics(predicts=predicts, labels=labels))
test_idx_data_loader_tqdm.set_description(f'test for the {batch_idx + 1}-th batch, test loss: {loss.item()}')
# store the evaluation metrics at the current run
test_metric_dict = {}
logger.info(f'test loss: {np.mean(test_losses):.4f}')
for metric_name in test_metrics[0].keys():
average_test_metric = np.mean([test_metric[metric_name] for test_metric in test_metrics])
logger.info(f'test {metric_name}, {average_test_metric:.4f}')
test_metric_dict[metric_name] = average_test_metric
single_run_time = time.time() - run_start_time
logger.info(f'Run {run + 1} cost {single_run_time:.2f} seconds.')
test_metric_all_runs.append(test_metric_dict)
# avoid the overlap of logs
if run < args.num_runs - 1:
logger.removeHandler(fh)
logger.removeHandler(ch)
# save model result
result_json = {
"test metrics": {metric_name: f'{test_metric_dict[metric_name]:.4f}'for metric_name in test_metric_dict}
}
result_json = json.dumps(result_json, indent=4)
save_result_folder = f"./saved_results/{args.model_name}/{args.dataset_name}"
os.makedirs(save_result_folder, exist_ok=True)
save_result_path = os.path.join(save_result_folder, f"{args.save_result_name}.json")
with open(save_result_path, 'w') as file:
file.write(result_json)
logger.info(f'save negative sampling results at {save_result_path}')
# store the average metrics at the log of the last run
logger.info(f'metrics over {args.num_runs} runs:')
for metric_name in test_metric_all_runs[0].keys():
logger.info(f'test {metric_name}, {[test_metric_single_run[metric_name] for test_metric_single_run in test_metric_all_runs]}')
logger.info(f'average test {metric_name}, {np.mean([test_metric_single_run[metric_name] for test_metric_single_run in test_metric_all_runs]):.4f} '
f'± {np.std([test_metric_single_run[metric_name] for test_metric_single_run in test_metric_all_runs], ddof=1):.4f}')
......@@ -25,8 +25,8 @@ delta_ts: list[tensor,tensor, tensor...]
metadata
"""
def prepare_input(node_feat, edge_feat, mem_embedding,mfgs,dist_nid,dist_eid):
for mfg in mfgs:
for i,b in enumerate(mfg):
for i,mfg in enumerate(mfgs):
for b in mfg:
e_idx = b.edata['ID']
idx = b.srcdata['ID']
b.edata['ID'] = dist_eid[e_idx]
......@@ -52,6 +52,7 @@ def to_block(graph: DistributedGraphStore, data, sample_out, mailbox:MailBox = N
metadata = None
eid = [ret.eid() for ret in sample_out]
eid_len = [e.shape[0] for e in eid ]
#print(len(sample_out),eid,eid_len)
eid_mapper: torch.Tensor = graph.eids_mapper
nid_mapper: torch.Tensor = graph.nids_mapper
eid_tensor = torch.cat(eid,dim = 0).to(eid_mapper.device)
......
......@@ -7,7 +7,7 @@ import torch
import torch.distributed as dist
from torch_geometric.data import Data
from starrygl.utils.uvm import *
class DistributedGraphStore:
'''
......@@ -46,17 +46,18 @@ class DistributedGraphStore:
self.uvm_edge = uvm_edge
if hasattr(pdata,'x') and pdata.x is not None:
ctx = DistributedContext.get_default_context()
pdata.x = pdata.x.to(torch.float)
if uvm_node == False :
x = pdata.x.to(self.device)
else:
if self.device.type == 'cuda':
x = starrygl.utils.uvm.uvm_empty(*pdata.x.size(),
x = uvm_empty(*pdata.x.size(),
dtype=pdata.x.dtype,
device=ctx.device)
starrygl.utils.uvm.uvm_share(x,device = ctx.device)
starrygl.utils.uvm.uvm_advise(x,starrygl.utils.uvm.cudaMemoryAdvise.cudaMemAdviseSetAccessedBy)
starrygl.utils.uvm.uvm_prefetch(x)
uvm_share(x,device = ctx.device)
uvm_advise(x,cudaMemoryAdvise.cudaMemAdviseSetAccessedBy)
uvm_prefetch(x)
if world_size > 1:
self.x = DistributedTensor(pdata.x.to(self.device).to(torch.float))
else:
......@@ -71,12 +72,15 @@ class DistributedGraphStore:
edge_attr = pdata.edge_attr.to(self.device)
else:
if self.device.type == 'cuda':
edge_attr = starrygl.utils.uvm.uvm_empty(*pdata.edge_attr.size(),
edge_attr = uvm_empty(*pdata.edge_attr.size(),
dtype=pdata.edge_attr.dtype,
device=ctx.device)
starrygl.utils.uvm.uvm_share(edge_attr,device = ctx.device)
starrygl.utils.uvm.uvm_advise(edge_attr,starrygl.utils.uvm.cudaMemoryAdvise.cudaMemAdviseSetAccessedBy)
starrygl.utils.uvm.uvm_prefetch(edge_attr)
edge_attr = uvm_share(edge_attr,device = torch.device('cpu'))
edge_attr.copy_(pdata.edge_attr)
edge_attr = uvm_share(edge_attr,device = ctx.device)
uvm_advise(edge_attr,cudaMemoryAdvise.cudaMemAdviseSetAccessedBy)
uvm_prefetch(edge_attr)
if world_size > 1:
self.edge_attr = DistributedTensor(edge_attr)
else:
......
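# Consolidated sketch of the UVM path introduced in the hunks above, assuming the
# helpers re-exported by starrygl.utils.uvm (uvm_empty, uvm_share, uvm_advise,
# uvm_prefetch, cudaMemoryAdvise):
# buf = uvm_empty(*src.size(), dtype=src.dtype, device=ctx.device)  # managed allocation
# buf = uvm_share(buf, device=torch.device('cpu'))                  # map onto the host
# buf.copy_(src)                                                    # fill from the CPU tensor
# buf = uvm_share(buf, device=ctx.device)                           # map back onto the GPU
# uvm_advise(buf, cudaMemoryAdvise.cudaMemAdviseSetAccessedBy)      # hint the driver
# uvm_prefetch(buf)                                                 # prefetch to the device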
from torch_sparse import SparseTensor
from torch_geometric.data import Data
from torch_geometric.utils import degree
import starrygl
import os.path as osp
import os
import shutil
import torch
import torch.utils.data
import metis
import networkx as nx
import torch.distributed as dist
from starrygl.lib.libstarrygl_sampler import get_norm_temporal
from starrygl.utils.partition import mt_metis_partition
def partition_load(root: str, algo: str = "metis") -> Data:
......@@ -21,7 +22,6 @@ def partition_load(root: str, algo: str = "metis") -> Data:
def partition_save(root: str, data: Data, num_parts: int,
algo: str = "metis",
node_weight = None,
edge_weight_dict=None):
root = osp.abspath(root)
if osp.exists(root) and not osp.isdir(root):
......@@ -46,7 +46,6 @@ def partition_save(root: str, data: Data, num_parts: int,
if algo == 'metis_for_tgnn':
for i, pdata in enumerate(partition_data_for_tgnn(
data, num_parts, algo, verbose=True,
node_weight = node_weight,
edge_weight_dict=edge_weight_dict)):
print(f"saving partition data: {i+1}/{num_parts}")
fn = osp.join(path, f"{i:03d}")
......@@ -154,41 +153,33 @@ def _nopart(edge_index: torch.LongTensor, num_nodes: int):
def metis_for_tgnn(edge_index_dict: dict,
                   num_nodes: int,
                   num_parts: int,
                   edge_weight_dict=None):
    if num_parts <= 1:
        return _nopart(edge_index_dict, num_nodes)
    # concatenate every edge type into one edge list, weighting each type
    # by its entry in edge_weight_dict
    edge_list = []
    weight_list = []
    for i, key in enumerate(edge_index_dict):
        v = edge_index_dict[key]
        edge_list.append(v)
        weight_list.append(torch.ones(v.shape[1]) * edge_weight_dict[key])
    edge_index = torch.cat(edge_list, dim=1)
    edge_weight = torch.cat(weight_list, dim=0)
    # node weights were dropped from the signature, so pass None explicitly
    node_parts = mt_metis_partition(edge_index, num_nodes, num_parts, None, edge_weight)
    return node_parts
    # previous networkx + metis implementation, kept for reference:
    #G = nx.Graph()
    #G.add_nodes_from(torch.arange(0, num_nodes).tolist())
    #value, counts = torch.unique(edge_index_dict['edata'][1, :].view(-1),
    #                             return_counts=True)
    #nodes = torch.tensor(list(G.adj.keys()))
    #for i in range(value.shape[0]):
    #    if (value[i].item() in G.nodes):
    #        G.nodes[int(value[i].item())]['weight'] = counts[i]
    #        G.nodes[int(value[i].item())]['ones'] = 1
    #G.graph['node_weight_attr'] = ['weight', 'ones']
    #for i, key in enumerate(edge_index_dict):
    #    v = edge_index_dict[key]
    #    edges = torch.cat((v, (torch.ones(v.shape[1], dtype=torch.long) *
    #                           edge_weight_dict[key]).unsqueeze(0)), dim=0)
    #    # w = edges.T
    #    G.add_weighted_edges_from((edges.T).tolist())
    #G.graph['edge_weight_attr'] = 'weight'
    #cuts, part = metis.part_graph(G, num_parts)
    #node_parts = torch.zeros(num_nodes, dtype=torch.long)
    #node_parts[nodes] = torch.tensor(part)
    #return node_parts
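# Hypothetical usage, matching the edge_weight_dict built in the data maker above:
# edge_weight_dict = {'edata': 1 * neg_nums, 'sample_data': 1 * neg_nums}
# node_parts = metis_for_tgnn(data.edge_index_dict, num_nodes=data.num_nodes,
#                             num_parts=4, edge_weight_dict=edge_weight_dict)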
"""
......@@ -199,7 +190,6 @@ weight: edge-partition weights for the various workloads
def partition_data_for_tgnn(data: Data, num_parts: int, algo: str,
verbose: bool = False,
node_weight: torch.Tensor = None,
edge_weight_dict: dict = None):
if algo == "metis_for_tgnn":
part_fn = metis_for_tgnn
......@@ -213,7 +203,6 @@ def partition_data_for_tgnn(data: Data, num_parts: int, algo: str,
if verbose:
print(f"running partition algorithm: {algo}")
node_parts = part_fn(edge_index_dict, num_nodes, num_parts,
node_weight,
edge_weight_dict)
edge_parts = node_parts[data.edge_index[1, :]]
eids = torch.arange(num_edges, dtype=torch.long)
......@@ -304,7 +293,7 @@ def compute_gcn_norm(edge_index: torch.LongTensor, num_nodes: int):
def compute_temporal_norm(edge_index: torch.LongTensor,
timestamp: torch.FloatTensor,
num_nodes: int):
srcavg, srcvar, dstavg, dstvar = get_norm_temporal(edge_index[0, :],
srcavg, srcvar, dstavg, dstvar = starrygl.sampler_ops.get_norm_temporal(edge_index[0, :],
edge_index[1, :],
timestamp, num_nodes)
return srcavg, srcvar, dstavg, dstvar
......
'''
Adapted from DyLib.
'''
import sys
from os.path import abspath, join, dirname
sys.path.insert(0, join(abspath(dirname(__file__))))
from torch import Tensor
import torch
from base import NegativeSampling
from base import NegativeSamplingMode
from typing import Any, List, Optional, Tuple, Union
class EvaluateNegSampling(NegativeSampling):
def __init__(
self,
mode: Union[NegativeSamplingMode, str],
src_node_ids: torch.Tensor,
dst_node_ids: torch.Tensor,
interact_times: torch.Tensor = None,
last_observed_time: float = None,
negative_sample_strategy: str = 'random',
seed: int = None
):
super(EvaluateNegSampling,self).__init__(mode)
self.seed = seed
self.negative_sample_strategy = negative_sample_strategy
self.src_node_ids = src_node_ids
self.dst_node_ids = dst_node_ids
self.interact_times = interact_times
self.unique_src_nodes_id = src_node_ids.unique()
self.unique_dst_nodes_id = dst_node_ids.unique()
# id -> position lookup tables; sized to max id + 1 and integer-typed so they can index
self.src_id_mapper = torch.zeros(int(self.unique_src_nodes_id[-1]) + 1, dtype=torch.long)
self.dst_id_mapper = torch.zeros(int(self.unique_dst_nodes_id[-1]) + 1, dtype=torch.long)
self.src_id_mapper[self.unique_src_nodes_id] = torch.arange(self.unique_src_nodes_id.shape[0])
self.dst_id_mapper[self.unique_dst_nodes_id] = torch.arange(self.unique_dst_nodes_id.shape[0])
self.unique_interact_times = self.interact_times.unique()
self.earliest_time = self.unique_interact_times.min().item()
self.last_observed_time = last_observed_time
if self.negative_sample_strategy == 'inductive':
# set of observed edges
self.observed_edges = self.get_unique_edges_between_start_end_time(self.earliest_time, self.last_observed_time)
if self.seed is not None:
self.random_state = torch.Generator()
self.random_state.manual_seed(seed)
else:
self.random_state = torch.Generator()
def get_unique_edges_between_start_end_time(self, start_time: float, end_time: float):
    # '&' instead of 'and': elementwise logic on tensors
    selected_mask = ((self.interact_times >= start_time) & (self.interact_times <= end_time))
    # return the selected source/destination pairs as a 2 x E edge tensor
    # (callers deduplicate via unique)
    return torch.stack((self.src_node_ids[selected_mask], self.dst_node_ids[selected_mask]), dim=0)
def sample(self, num_samples: int, num_nodes: Optional[int] = None, batch_src_node_ids: Optional[torch.Tensor] = None,
batch_dst_node_ids: Optional[torch.Tensor] = None, current_batch_start_time: Optional[torch.Tensor] = None,
current_batch_end_time: Optional[torch.Tensor] = None) -> Tensor:
if self.negative_sample_strategy == 'random':
negative_src_node_ids, negative_dst_node_ids = self.random_sample(size=num_samples)
elif self.negative_sample_strategy == 'historical':
negative_src_node_ids, negative_dst_node_ids = self.historical_sample(size=num_samples, batch_src_node_ids=batch_src_node_ids,
batch_dst_node_ids=batch_dst_node_ids,
current_batch_start_time=current_batch_start_time,
current_batch_end_time=current_batch_end_time)
elif self.negative_sample_strategy == 'inductive':
negative_src_node_ids, negative_dst_node_ids = self.inductive_sample(size=num_samples, batch_src_node_ids=batch_src_node_ids,
batch_dst_node_ids=batch_dst_node_ids,
current_batch_start_time=current_batch_start_time,
current_batch_end_time=current_batch_end_time)
else:
raise ValueError(f'Not implemented error for negative_sample_strategy {self.negative_sample_strategy}!')
return negative_src_node_ids, negative_dst_node_ids
def random_sample(self, size: int):
    # randint needs a size tuple; the generator keyword is 'generator', not 'generate'
    if self.seed is None:
        random_sample_edge_src_node_indices = torch.randint(0, len(self.unique_src_nodes_id), (size,))
        random_sample_edge_dst_node_indices = torch.randint(0, len(self.unique_dst_nodes_id), (size,))
    else:
        random_sample_edge_src_node_indices = torch.randint(0, len(self.unique_src_nodes_id), (size,), generator=self.random_state)
        random_sample_edge_dst_node_indices = torch.randint(0, len(self.unique_dst_nodes_id), (size,), generator=self.random_state)
    return self.unique_src_nodes_id[random_sample_edge_src_node_indices], self.unique_dst_nodes_id[random_sample_edge_dst_node_indices]
def random_sample_with_collision_check(self, size: int, batch_src_node_ids: torch.Tensor, batch_dst_node_ids: torch.Tensor):
    # positions of the batch endpoints in the lookup tables built in __init__
    batch_src_index = self.src_id_mapper[batch_src_node_ids]
    batch_dst_index = self.dst_id_mapper[batch_dst_node_ids]
    return_edge = torch.empty((2, 0), dtype=torch.long)
    while(True):
        # oversample by 2x, then drop candidates that collide with batch edges
        src_ = torch.randint(0, len(self.unique_src_nodes_id), (size * 2,))
        dst_ = torch.randint(0, len(self.unique_dst_nodes_id), (size * 2,))
        # flatten (src_index, dst_index) pairs into single ids for the membership test
        sample_id = src_ * self.unique_dst_nodes_id.shape[0] + dst_
        batch_id = batch_src_index * self.unique_dst_nodes_id.shape[0] + batch_dst_index
        mask = torch.isin(sample_id, batch_id, invert=True)
        # map surviving indices back to real node ids
        edge = torch.stack((self.unique_src_nodes_id[src_], self.unique_dst_nodes_id[dst_]))[:, mask]
        if(edge.shape[1] >= size):
            return_edge = torch.cat((return_edge, edge[:, :size]), 1)
            break
        else:
            return_edge = torch.cat((return_edge, edge), 1)
            size = size - edge.shape[1]
    return return_edge
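    # Worked example of the collision check above: each (src, dst) pair is
    # flattened to id = src_index * num_unique_dst + dst_index, which is
    # injective because dst_index < num_unique_dst. With 3 unique destinations,
    # (src=2, dst=1) maps to 2*3 + 1 = 7 and no other pair does, so torch.isin
    # on these ids drops exactly the candidates that collide with batch edges.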
def historical_sample(self, size: int, batch_src_nodes_id: torch.Tensor, batch_dst_nodes_id: torch.Tensor,
                      current_batch_start_time: float, current_batch_end_time: float):
    assert self.seed is not None
    historical_edges = self.get_unique_edges_between_start_end_time(start_time=self.earliest_time, end_time=current_batch_start_time)
    current_batch_edges = self.get_unique_edges_between_start_end_time(start_time=current_batch_start_time, end_time=current_batch_end_time)
    # return_inverse is needed to recover which unique columns came from the current batch
    uni, ids = torch.cat((current_batch_edges, historical_edges), dim=1).unique(dim=1, return_inverse=True)
    mask = torch.zeros(uni.shape[1], dtype=torch.bool)
    mask[ids[:current_batch_edges.shape[1]]] = True
    mask = (~mask)
    unique_historical_edges = uni[:, mask]
    if size > unique_historical_edges.shape[1]:
        num_random_sample_edges = size - unique_historical_edges.shape[1]
        random_sample_edge = self.random_sample_with_collision_check(size=num_random_sample_edges,
                                                                     batch_src_node_ids=batch_src_nodes_id,
                                                                     batch_dst_node_ids=batch_dst_nodes_id)
        sample_edges = torch.cat((unique_historical_edges, random_sample_edge), dim=1)
    else:
        historical_sample_edge_node_indices = torch.randperm(unique_historical_edges.shape[1], generator=self.random_state)
        sample_edges = unique_historical_edges[:, historical_sample_edge_node_indices[:size]]
    return sample_edges
def inductive_sample(self, size: int, batch_src_node_ids: torch.Tensor, batch_dst_node_ids: torch.Tensor,
                     current_batch_start_time: float, current_batch_end_time: float):
    assert self.seed is not None
    historical_edges = self.get_unique_edges_between_start_end_time(start_time=self.earliest_time, end_time=current_batch_start_time)
    current_batch_edges = self.get_unique_edges_between_start_end_time(start_time=current_batch_start_time, end_time=current_batch_end_time)
    uni, ids = torch.cat((self.observed_edges, current_batch_edges, historical_edges), dim=1).unique(dim=1, return_inverse=True)
    # mark uniques that came from the observed (training) edges or the current batch,
    # so that ~mask keeps historical-only edges as the inductive negatives
    mask = torch.zeros(uni.shape[1], dtype=torch.bool)
    mask[ids[:self.observed_edges.shape[1] + current_batch_edges.shape[1]]] = True
    mask = (~mask)
    unique_inductive_edges = uni[:, mask]
    if size > unique_inductive_edges.shape[1]:
        num_random_sample_edges = size - unique_inductive_edges.shape[1]
        random_sample_edge = self.random_sample_with_collision_check(size=num_random_sample_edges,
                                                                     batch_src_node_ids=batch_src_node_ids,
                                                                     batch_dst_node_ids=batch_dst_node_ids)
        sample_edges = torch.cat((unique_inductive_edges, random_sample_edge), dim=1)
    else:
        inductive_sample_edge_node_indices = torch.randperm(unique_inductive_edges.shape[1], generator=self.random_state)
        sample_edges = unique_inductive_edges[:, inductive_sample_edge_node_indices[:size]]
    return sample_edges
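# Hypothetical usage, reusing the 'triplet' mode seen in the training script above:
# neg_sampler = EvaluateNegSampling(mode='triplet', src_node_ids=src, dst_node_ids=dst,
#                                   interact_times=ts, negative_sample_strategy='historical',
#                                   last_observed_time=float(ts.max()), seed=0)
# neg_edges = neg_sampler.sample(num_samples=200, batch_src_node_ids=bsrc,
#                                batch_dst_node_ids=bdst, current_batch_start_time=t0,
#                                current_batch_end_time=t1)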
......@@ -62,7 +62,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
#os.environ["LOCAL_RANK"] = str(0)
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
os.environ["MASTER_ADDR"] = '10.214.211.187'
os.environ["MASTER_PORT"] = '9337'
os.environ["MASTER_PORT"] = '9437'
def seed_everything(seed=42):
random.seed(seed)
np.random.seed(seed)
......@@ -81,7 +81,7 @@ def main():
device_id = torch.cuda.current_device()
print('use cuda on',device_id)
pdata = partition_load("/mnt/data/part_data/dataset/here/{}".format(args.dataname), algo="metis_for_tgnn")
graph = DistributedGraphStore(pdata = pdata)
graph = DistributedGraphStore(pdata = pdata,uvm_edge = True)
Path("./saved_models/").mkdir(parents=True, exist_ok=True)
Path("./saved_checkpoints/").mkdir(parents=True, exist_ok=True)
......@@ -91,8 +91,9 @@ def main():
use_src_emb = gnn_param['use_src_emb'] if 'use_src_emb' in gnn_param else False
use_dst_emb = gnn_param['use_dst_emb'] if 'use_dst_emb' in gnn_param else False
sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph,mode = 'full')
#mailbox = None
mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat = pdata.edge_attr.shape[1] if pdata.edge_attr is not None else 0)
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10],graph_data=sample_graph, workers=15,policy = 'recent',graph_name = "wiki_train")
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=2, fanout=[10,10],graph_data=sample_graph, workers=15,policy = 'uniform',graph_name = "wiki_train")
train_data = torch.masked_select(graph.edge_index,pdata.train_mask.to(graph.edge_index.device)).reshape(2,-1)
train_ts = torch.masked_select(graph.edge_ts,pdata.train_mask.to(graph.edge_index.device))
val_data = torch.masked_select(graph.edge_index,pdata.val_mask.to(graph.edge_index.device)).reshape(2,-1)
......@@ -157,7 +158,7 @@ def main():
else:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param)
device = torch.device('cpu')
model = DDP(model,find_unused_parameters=True)
model = DDP(model)
train_stream = torch.cuda.Stream()
send_stream = torch.cuda.Stream()
scatter_stream = torch.cuda.Stream()
......