Commit 36c9a61a by Wenjie Huang

Merge branch 'doc-v2' of http://10.75.75.11:7001/wjie98/starrygl

parents 4140455c 8ee95360
......@@ -179,4 +179,9 @@ cython_debug/
/test_*
/*.ipynb
saved_models/
saved_checkpoints/
\ No newline at end of file
saved_checkpoints/
.history/
.preprocess_data/
.processed_data/
.DG_data/
.examples/
......@@ -25,8 +25,8 @@ gnn:
dim_time: 100
dim_out: 100
train:
- epoch: 50
batch_size: 100
- epoch: 100
batch_size: 1000
# reorder: 16
lr: 0.0001
dropout: 0.1
......
......@@ -16,8 +16,8 @@ gnn:
use_dst_emb: False
time_transform: 'JODIE'
train:
- epoch: 20
batch_size: 200
- epoch: 100
batch_size: 1000
lr: 0.0001
dropout: 0.1
all_on_gpu: True
\ No newline at end of file
......@@ -13,13 +13,15 @@ memory:
dim_out: 0
gnn:
- arch: 'transformer_attention'
use_src_emb: False
use_dst_emb: False
layer: 2
att_head: 2
dim_time: 100
dim_out: 100
train:
- epoch: 100
batch_size: 600
- epoch: 50
batch_size: 1000
lr: 0.0001
dropout: 0.1
att_dropout: 0.1
......
......@@ -25,8 +25,8 @@ gnn:
dim_time: 100
dim_out: 100
train:
- epoch: 20
batch_size: 200
- epoch: 5
batch_size: 1000
# reorder: 16
lr: 0.0001
dropout: 0.2
......
......@@ -25,8 +25,8 @@ gnn:
dim_time: 100
dim_out: 100
train:
- epoch: 20
batch_size: 200
- epoch: 50
batch_size: 1000
# reorder: 16
lr: 0.0001
dropout: 0.2
......
......@@ -19,29 +19,6 @@ parser.add_argument('--num_neg_sample', default=1, type=int, metavar='W',
args = parser.parse_args()
def load_feat(d, rand_de=0, rand_dn=0):
node_feats = None
if os.path.exists('DATA/{}/node_features.pt'.format(d)):
node_feats = torch.load('DATA/{}/node_features.pt'.format(d))
if node_feats.dtype == torch.bool:
node_feats = node_feats.type(torch.float32)
edge_feats = None
if os.path.exists('DATA/{}/edge_features.pt'.format(d)):
edge_feats = torch.load('DATA/{}/edge_features.pt'.format(d))
if edge_feats.dtype == torch.bool:
edge_feats = edge_feats.type(torch.float32)
if rand_de > 0:
if d == 'LASTFM':
edge_feats = torch.randn(1293103, rand_de)
elif d == 'MOOC':
edge_feats = torch.randn(411749, rand_de)
if rand_dn > 0:
if d == 'LASTFM':
node_feats = torch.randn(1980, rand_dn)
elif d == 'MOOC':
edge_feats = torch.randn(7144, rand_dn)
return node_feats, edge_feats
data_name = args.data_name
g = np.load('/mnt/nfs/fzz/TGL-DATA/'+data_name+'/ext_full.npz')
......@@ -95,8 +72,9 @@ if e_feat is not None:
data.edge_attr = e_feat
data.train_mask = (torch.from_numpy(np.array(df.ext_roll.values)) == 0)
data.test_mask = (torch.from_numpy(np.array(df.ext_roll.values)) == 1)
data.val_mask = (torch.from_numpy(np.array(df.ext_roll.values)) == 2)
data.val_mask = (torch.from_numpy(np.array(df.ext_roll.values)) == 1)
data.test_mask = (torch.from_numpy(np.array(df.ext_roll.values)) == 2)
print(ts[data.train_mask].min(),ts[data.train_mask].max(),ts[data.val_mask].min(),ts[data.val_mask].max(),ts[data.test_mask].min(),ts[data.test_mask].max())
sample_graph['train_mask'] = data.train_mask[sample_eid]
sample_graph['test_mask'] = data.test_mask[sample_eid]
sample_graph['val_mask'] = data.val_mask[sample_eid]
......@@ -106,28 +84,19 @@ data.y = torch.zeros(edge_index.shape[1])
edge_index_dict = {}
edge_index_dict['edata'] = data.edge_index
edge_index_dict['sample_data'] = data.sample_graph['edge_index']
edge_index_dict['neg_data'] = torch.cat([neg_src.view(1, -1),
dst.view(-1, 1).repeat(1, neg_nums).
reshape(1, -1)], dim=0)
data.edge_index_dict = edge_index_dict
edge_weight_dict = {}
edge_weight_dict['edata'] = 2*neg_nums
edge_weight_dict['edata'] = 1*neg_nums
edge_weight_dict['sample_data'] = 1*neg_nums
edge_weight_dict['neg_data'] = 1
#partition_save('./dataset/here/'+data_name, data, 1, 'metis_for_tgnn',
# edge_weight_dict=edge_weight_dict)
#partition_save('./dataset/here/'+data_name, data, 2, 'metis_for_tgnn',
# edge_weight_dict=edge_weight_dict)
#partition_save('./dataset/here/'+data_name, data, 4, 'metis_for_tgnn',
# edge_weight_dict=edge_weight_dict)
#partition_save('./dataset/here/'+data_name, data, 8, 'metis_for_tgnn',
# edge_weight_dict=edge_weight_dict)
partition_save('/mnt/data/part_data/v2/here/'+data_name, data, 1, 'metis_for_tgnn',
edge_weight_dict=edge_weight_dict)
partition_save('/mnt/data/part_data/v2/here/'+data_name, data, 2, 'metis_for_tgnn',
edge_weight_dict=edge_weight_dict)
partition_save('/mnt/data/part_data/v2/here/'+data_name, data, 4, 'metis_for_tgnn',
edge_weight_dict=edge_weight_dict)
partition_save('/mnt/data/part_data/v2/here/'+data_name, data, 8, 'metis_for_tgnn',
edge_weight_dict=edge_weight_dict)
partition_save('./dataset/here/'+data_name, data, 16, 'metis_for_tgnn',
edge_weight_dict=edge_weight_dict)
#
# partition_save('./dataset/here/'+data_name, data, 4, 'metis_for_tgnn',
# edge_weight_dict=edge_weight_dict )
# partition_save('./dataset/here'+data_name, data, 8, 'metis')
# partition_save('./dataset/'+data_name, data, 12, 'metis')
# partition_save('./dataset/'+data_name, data, 16, 'metis')
......@@ -3,16 +3,11 @@
mkdir -p build && cd build
cmake .. \
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON \
-DCMAKE_PREFIX_PATH=$(python -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())") \
-DPython3_ROOT_DIR=$(python -c "import sys; print(sys.prefix)") \
-DCUDA_TOOLKIT_ROOT_DIR=${CUDA_HOME:-"$(realpath $(dirname $(which nvcc))/../)"} \
-DCMAKE_PREFIX_PATH="/home/zlj/.miniconda3/envs/sgl/lib/python3.10/site-packages" \
-DPython3_ROOT_DIR="/home/zlj/.miniconda3/envs/sgl" \
-DCUDA_TOOLKIT_ROOT_DIR="/home/zlj/.local/cuda-11.8" \
&& make -j32 \
&& rm -rf ../starrygl/lib \
&& mkdir ../starrygl/lib \
&& cp lib*.so ../starrygl/lib/ \
&& patchelf --set-rpath '$ORIGIN:$ORIGIN/lib' --force-rpath ../starrygl/lib/*.so
# -DCMAKE_PREFIX_PATH="/home/zlj/.miniconda3/envs/dgnn/lib/python3.10/site-packages" \
# -DPython3_ROOT_DIR="/home/zlj/.miniconda3/envs/dgnn" \
# -DCUDA_TOOLKIT_ROOT_DIR="/home/zlj/local/cuda-12.2" \
\ No newline at end of file
import random
import pandas as pd
import numpy as np
import os
import torch
from torch_geometric.data import Data
from starrygl.sample.graph_core import DataSet, DistributedGraphStore
def get_link_prediction_data(data_name: str, val_ratio, test_ratio):
"""
generate data for link prediction task (inductive & transductive settings)
:param dataset_name: str, dataset name
:param val_ratio: float, validation data ratio
:param test_ratio: float, test data ratio
:return: node_raw_features, edge_raw_features, (np.ndarray),
full_data, train_data, val_data, test_data, new_node_val_data, new_node_test_data, (Data object)
"""
# Load data and train val test split
#graph_df = pd.read_csv('/mnt/nfs/fzz/TGL-DATA/'+data_name+'/edges.csv')
#if os.path.exists('/mnt/nfs/fzz/TGL-DATA/'+data_name+'/node_features.pt'):
# n_feat = torch.load('/mnt/nfs/fzz/TGL-DATA/'+data_name+'/node_features.pt')
#else:
# n_feat = None
#if os.path.exists('/mnt/nfs/fzz/TGL-DATA/'+data_name+'/edge_features.pt'):
# e_feat = torch.load('/mnt/nfs/fzz/TGL-DATA/'+data_name+'/edge_features.pt')
#else:
# e_feat = None
#
## get the timestamp of validate and test set
#src_node_ids = torch.from_numpy(np.array(graph_df.src.values)).long()
#dst_node_ids = torch.from_numpy(np.array(graph_df.dst.values)).long()
#node_interact_times = torch.from_numpy(np.array(graph_df.time.values)).long()
#
#train_mask = (torch.from_numpy(np.array(graph_df.ext_roll.values)) == 0)
#test_mask = (torch.from_numpy(np.array(graph_df.ext_roll.values)) == 1)
#val_mask = (torch.from_numpy(np.array(graph_df.ext_roll.values)) == 2)
# the setting of seed follows previous works
graph_df = pd.read_csv('./processed_data/{}/ml_{}.csv'.format(data_name, data_name))
edge_raw_features = np.load('./processed_data/{}/ml_{}.npy'.format(data_name, data_name))
node_raw_features = np.load('./processed_data/{}/ml_{}_node.npy'.format(data_name, data_name))
NODE_FEAT_DIM = EDGE_FEAT_DIM = 172
assert NODE_FEAT_DIM >= node_raw_features.shape[1], f'Node feature dimension in dataset {data_name} is bigger than {NODE_FEAT_DIM}!'
assert EDGE_FEAT_DIM >= edge_raw_features.shape[1], f'Edge feature dimension in dataset {data_name} is bigger than {EDGE_FEAT_DIM}!'
# padding the features of edges and nodes to the same dimension (172 for all the datasets)
if node_raw_features.shape[1] < NODE_FEAT_DIM:
node_zero_padding = np.zeros((node_raw_features.shape[0], NODE_FEAT_DIM - node_raw_features.shape[1]))
node_raw_features = np.concatenate([node_raw_features, node_zero_padding], axis=1)
if edge_raw_features.shape[1] < EDGE_FEAT_DIM:
edge_zero_padding = np.zeros((edge_raw_features.shape[0], EDGE_FEAT_DIM - edge_raw_features.shape[1]))
edge_raw_features = np.concatenate([edge_raw_features, edge_zero_padding], axis=1)
e_feat = edge_raw_features
n_feat = torch.from_numpy(node_raw_features.astype(np.float32))
e_feat = torch.from_numpy(edge_raw_features.astype(np.float32))
assert NODE_FEAT_DIM == node_raw_features.shape[1] and EDGE_FEAT_DIM == edge_raw_features.shape[1], 'Unaligned feature dimensions after feature padding!'
# get the timestamp of validate and test set
val_time, test_time = list(np.quantile(graph_df.ts, [(1 - val_ratio - test_ratio), (1 - test_ratio)]))
src_node_ids = torch.from_numpy(graph_df.u.values.astype(np.longlong))
dst_node_ids = torch.from_numpy(graph_df.i.values.astype(np.longlong))
node_interact_times = torch.from_numpy(graph_df.ts.values.astype(np.float32))
#edge_ids = torch.from_numpy(graph_df.idx.values.astype(np.longlong))
labels = torch.from_numpy(graph_df.label.values)
unique_node_ids = torch.cat((src_node_ids,dst_node_ids)).unique()
train_mask = node_interact_times <= val_time
val_mask = ((node_interact_times > val_time)&(node_interact_times <= test_time))
test_mask = (node_interact_times > test_time)
torch.manual_seed(2020)
train_node_set = torch.cat((src_node_ids[train_mask],dst_node_ids[train_mask])).unique()
test_node_set = set(src_node_ids[node_interact_times > val_time]).union(set(dst_node_ids[node_interact_times > val_time]))
new_test_node_set = set(random.sample(test_node_set, int(0.1 * unique_node_ids.shape[0])))
new_test_source_mask = graph_df.u.map(lambda x: x in new_test_node_set).values
new_test_destination_mask = graph_df.i.map(lambda x: x in new_test_node_set).values
# mask, which is true for edges with both destination and source not being new test nodes (because we want to remove all edges involving any new test node)
observed_edges_mask = torch.from_numpy(np.logical_and(~new_test_source_mask, ~new_test_destination_mask)).long()
train_mask = (train_mask & observed_edges_mask)
mask = torch.isin(unique_node_ids,train_node_set,invert = True)
new_node_set = unique_node_ids[mask]
edge_contains_new_node_mask = (torch.isin(src_node_ids,new_node_set) | torch.isin(dst_node_ids,new_node_set))
new_node_val_mask = (val_mask & edge_contains_new_node_mask)
new_node_test_mask = (test_mask & edge_contains_new_node_mask)
full_data = Data()
full_data.edge_index = torch.stack((src_node_ids,dst_node_ids))
sample_graph = {}
sample_src = torch.cat([src_node_ids.view(-1, 1), dst_node_ids.view(-1, 1)], dim=1)\
.reshape(1, -1)
sample_dst = torch.cat([dst_node_ids.view(-1, 1), src_node_ids.view(-1, 1)], dim=1)\
.reshape(1, -1)
sample_ts = torch.cat([node_interact_times.view(-1, 1), node_interact_times.view(-1, 1)], dim=1).reshape(-1)
sample_eid = torch.arange(full_data.edge_index.shape[1]).view(-1, 1).repeat(1, 2).reshape(-1)
sample_graph['edge_index'] = torch.cat([sample_src, sample_dst], dim=0)
sample_graph['ts'] = sample_ts
sample_graph['eids'] = sample_eid
sample_graph['train_mask'] = train_mask
sample_graph['val_mask'] = val_mask
sample_graph['test_mask'] = val_mask
sample_graph['new_node_val_mask'] = new_node_val_mask
sample_graph['new_node_test_mask'] = new_node_test_mask
print(unique_node_ids.max().item(),unique_node_ids.shape[0])
full_data.num_nodes = int(unique_node_ids.max().item())+1
full_data.num_edges = node_interact_times.shape[0]
full_data.sample_graph = sample_graph
full_data.x = n_feat
full_data.edge_attr = e_feat
full_data.y = labels
full_data.edge_ts = node_interact_times
full_data.train_mask = train_mask
full_data.val_mask = val_mask
full_data.test_mask = test_mask
full_data.new_node_val_mask = new_node_val_mask
full_data.new_node_test_mask = new_node_test_mask
return full_data
#full_graph = DistributedGraphStore(full_data, device, uvm_node, uvm_edge)
#train_data = torch.masked_select(full_data.edge_index,train_mask.to(device)).reshape(2,-1)
#train_ts = torch.masked_select(full_data.edge_ts,train_mask.to(device))
#val_data = torch.masked_select(full_data.edge_index,val_mask.to(device)).reshape(2,-1)
#val_ts = torch.masked_select(full_data.edge_ts,val_mask.to(device))
#test_data = torch.masked_select(full_data.edge_index,test_mask.to(device)).reshape(2,-1)
#test_ts = torch.masked_select(full_data.edge_ts,test_mask.to(device))
##print(train_data.shape[1],val_data.shape[1],test_data.shape[1])
#train_data = DataSet(edges = train_data,ts =train_ts,eids = torch.nonzero(train_mask).view(-1))
#test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.nonzero(test_mask).view(-1))
#val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.nonzero(val_mask).view(-1))
#new_node_val_data = torch.masked_select(full_data.edge_index,new_node_val_mask.to(device)).reshape(2,-1)
#new_node_val_ts = torch.masked_select(full_data.edge_ts,new_node_val_mask.to(device))
#new_node_test_data = torch.masked_select(full_data.edge_index,new_node_test_mask.to(device)).reshape(2,-1)
#new_node_test_ts = torch.masked_select(full_data.edge_ts,new_node_test_mask.to(device))
#return full_data, train_data, val_data, test_data, new_node_val_data, new_node_test_data
import torch
from sklearn.metrics import average_precision_score, roc_auc_score
def get_link_prediction_metrics(predicts: torch.Tensor, labels: torch.Tensor):
"""
get metrics for the link prediction task
:param predicts: Tensor, shape (num_samples, )
:param labels: Tensor, shape (num_samples, )
:return:
dictionary of metrics {'metric_name_1': metric_1, ...}
"""
predicts = predicts.cpu().detach().numpy()
labels = labels.cpu().numpy()
average_precision = average_precision_score(y_true=labels, y_score=predicts)
roc_auc = roc_auc_score(y_true=labels, y_score=predicts)
return {'average_precision': average_precision, 'roc_auc': roc_auc}
def get_node_classification_metrics(predicts: torch.Tensor, labels: torch.Tensor):
"""
get metrics for the node classification task
:param predicts: Tensor, shape (num_samples, )
:param labels: Tensor, shape (num_samples, )
:return:
dictionary of metrics {'metric_name_1': metric_1, ...}
"""
predicts = predicts.cpu().detach().numpy()
labels = labels.cpu().numpy()
roc_auc = roc_auc_score(y_true=labels, y_score=predicts)
return {'roc_auc': roc_auc}
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import logging
import time
import argparse
import os
import json
from models.EdgeBank import edge_bank_link_prediction
from starrygl.evaluation.metrics import get_link_prediction_metrics, get_node_classification_metrics
from utils.utils import set_random_seed
from starrygl.sample.sample_core import EvaluateNegativeSampling
from utils.DataLoader import Data
def evaluate_model_link_prediction(model_name: str, model: nn.Module, neighbor_sampler: NeighborSampler, evaluate_idx_data_loader: DataLoader,
evaluate_neg_edge_sampler: NegativeEdgeSampler, evaluate_data: Data, loss_func: nn.Module,
num_neighbors: int = 20, time_gap: int = 2000):
"""
evaluate models on the link prediction task
:param model_name: str, name of the model
:param model: nn.Module, the model to be evaluated
:param neighbor_sampler: NeighborSampler, neighbor sampler
:param evaluate_idx_data_loader: DataLoader, evaluate index data loader
:param evaluate_neg_edge_sampler: NegativeEdgeSampler, evaluate negative edge sampler
:param evaluate_data: Data, data to be evaluated
:param loss_func: nn.Module, loss function
:param num_neighbors: int, number of neighbors to sample for each node
:param time_gap: int, time gap for neighbors to compute node features
:return:
"""
# Ensures the random sampler uses a fixed seed for evaluation (i.e. we always sample the same negatives for validation / test set)
assert evaluate_neg_edge_sampler.seed is not None
evaluate_neg_edge_sampler.reset_random_state()
if model_name in ['DyRep', 'TGAT', 'TGN', 'CAWN', 'TCL', 'GraphMixer', 'DyGFormer']:
# evaluation phase use all the graph information
model[0].set_neighbor_sampler(neighbor_sampler)
model.eval()
with torch.no_grad():
# store evaluate losses and metrics
evaluate_losses, evaluate_metrics = [], []
evaluate_idx_data_loader_tqdm = tqdm(evaluate_idx_data_loader, ncols=120)
for batch_idx, evaluate_data_indices in enumerate(evaluate_idx_data_loader_tqdm):
evaluate_data_indices = evaluate_data_indices.numpy()
batch_src_node_ids, batch_dst_node_ids, batch_node_interact_times, batch_edge_ids = \
evaluate_data.src_node_ids[evaluate_data_indices], evaluate_data.dst_node_ids[evaluate_data_indices], \
evaluate_data.node_interact_times[evaluate_data_indices], evaluate_data.edge_ids[evaluate_data_indices]
if evaluate_neg_edge_sampler.negative_sample_strategy != 'random':
batch_neg_src_node_ids, batch_neg_dst_node_ids = evaluate_neg_edge_sampler.sample(size=len(batch_src_node_ids),
batch_src_node_ids=batch_src_node_ids,
batch_dst_node_ids=batch_dst_node_ids,
current_batch_start_time=batch_node_interact_times[0],
current_batch_end_time=batch_node_interact_times[-1])
else:
_, batch_neg_dst_node_ids = evaluate_neg_edge_sampler.sample(size=len(batch_src_node_ids))
batch_neg_src_node_ids = batch_src_node_ids
# we need to compute for positive and negative edges respectively, because the new sampling strategy (for evaluation) allows the negative source nodes to be
# different from the source nodes, this is different from previous works that just replace destination nodes with negative destination nodes
if model_name in ['TGAT', 'CAWN', 'TCL']:
# get temporal embedding of source and destination nodes
# two Tensors, with shape (batch_size, node_feat_dim)
batch_src_node_embeddings, batch_dst_node_embeddings = \
model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_src_node_ids,
dst_node_ids=batch_dst_node_ids,
node_interact_times=batch_node_interact_times,
num_neighbors=num_neighbors)
# get temporal embedding of negative source and negative destination nodes
# two Tensors, with shape (batch_size, node_feat_dim)
batch_neg_src_node_embeddings, batch_neg_dst_node_embeddings = \
model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_neg_src_node_ids,
dst_node_ids=batch_neg_dst_node_ids,
node_interact_times=batch_node_interact_times,
num_neighbors=num_neighbors)
elif model_name in ['JODIE', 'DyRep', 'TGN']:
# note that negative nodes do not change the memories while the positive nodes change the memories,
# we need to first compute the embeddings of negative nodes for memory-based models
# get temporal embedding of negative source and negative destination nodes
# two Tensors, with shape (batch_size, node_feat_dim)
batch_neg_src_node_embeddings, batch_neg_dst_node_embeddings = \
model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_neg_src_node_ids,
dst_node_ids=batch_neg_dst_node_ids,
node_interact_times=batch_node_interact_times,
edge_ids=None,
edges_are_positive=False,
num_neighbors=num_neighbors)
# get temporal embedding of source and destination nodes
# two Tensors, with shape (batch_size, node_feat_dim)
batch_src_node_embeddings, batch_dst_node_embeddings = \
model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_src_node_ids,
dst_node_ids=batch_dst_node_ids,
node_interact_times=batch_node_interact_times,
edge_ids=batch_edge_ids,
edges_are_positive=True,
num_neighbors=num_neighbors)
elif model_name in ['GraphMixer']:
# get temporal embedding of source and destination nodes
# two Tensors, with shape (batch_size, node_feat_dim)
batch_src_node_embeddings, batch_dst_node_embeddings = \
model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_src_node_ids,
dst_node_ids=batch_dst_node_ids,
node_interact_times=batch_node_interact_times,
num_neighbors=num_neighbors,
time_gap=time_gap)
# get temporal embedding of negative source and negative destination nodes
# two Tensors, with shape (batch_size, node_feat_dim)
batch_neg_src_node_embeddings, batch_neg_dst_node_embeddings = \
model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_neg_src_node_ids,
dst_node_ids=batch_neg_dst_node_ids,
node_interact_times=batch_node_interact_times,
num_neighbors=num_neighbors,
time_gap=time_gap)
elif model_name in ['DyGFormer']:
# get temporal embedding of source and destination nodes
# two Tensors, with shape (batch_size, node_feat_dim)
batch_src_node_embeddings, batch_dst_node_embeddings = \
model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_src_node_ids,
dst_node_ids=batch_dst_node_ids,
node_interact_times=batch_node_interact_times)
# get temporal embedding of negative source and negative destination nodes
# two Tensors, with shape (batch_size, node_feat_dim)
batch_neg_src_node_embeddings, batch_neg_dst_node_embeddings = \
model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_neg_src_node_ids,
dst_node_ids=batch_neg_dst_node_ids,
node_interact_times=batch_node_interact_times)
else:
raise ValueError(f"Wrong value for model_name {model_name}!")
# get positive and negative probabilities, shape (batch_size, )
positive_probabilities = model[1](input_1=batch_src_node_embeddings, input_2=batch_dst_node_embeddings).squeeze(dim=-1).sigmoid()
negative_probabilities = model[1](input_1=batch_neg_src_node_embeddings, input_2=batch_neg_dst_node_embeddings).squeeze(dim=-1).sigmoid()
predicts = torch.cat([positive_probabilities, negative_probabilities], dim=0)
labels = torch.cat([torch.ones_like(positive_probabilities), torch.zeros_like(negative_probabilities)], dim=0)
loss = loss_func(input=predicts, target=labels)
evaluate_losses.append(loss.item())
evaluate_metrics.append(get_link_prediction_metrics(predicts=predicts, labels=labels))
evaluate_idx_data_loader_tqdm.set_description(f'evaluate for the {batch_idx + 1}-th batch, evaluate loss: {loss.item()}')
return evaluate_losses, evaluate_metrics
def evaluate_edge_bank_link_prediction(args: argparse.Namespace, train_data: Data, val_data: Data, test_idx_data_loader: DataLoader,
test_neg_edge_sampler: NegativeEdgeSampler, test_data: Data):
"""
evaluate the EdgeBank model for link prediction
:param args: argparse.Namespace, configuration
:param train_data: Data, train data
:param val_data: Data, validation data
:param test_idx_data_loader: DataLoader, test index data loader
:param test_neg_edge_sampler: NegativeEdgeSampler, test negative edge sampler
:param test_data: Data, test data
:return:
"""
# generate the train_validation split of the data: needed for constructing the memory for EdgeBank
train_val_data = Data(src_node_ids=np.concatenate([train_data.src_node_ids, val_data.src_node_ids]),
dst_node_ids=np.concatenate([train_data.dst_node_ids, val_data.dst_node_ids]),
node_interact_times=np.concatenate([train_data.node_interact_times, val_data.node_interact_times]),
edge_ids=np.concatenate([train_data.edge_ids, val_data.edge_ids]),
labels=np.concatenate([train_data.labels, val_data.labels]))
test_metric_all_runs = []
for run in range(args.num_runs):
set_random_seed(seed=run)
args.seed = run
args.save_result_name = f'{args.negative_sample_strategy}_negative_sampling_{args.model_name}_seed{args.seed}'
# set up logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
os.makedirs(f"./logs/{args.model_name}/{args.dataset_name}/{args.save_result_name}/", exist_ok=True)
# create file handler that logs debug and higher level messages
fh = logging.FileHandler(f"./logs/{args.model_name}/{args.dataset_name}/{args.save_result_name}/{str(time.time())}.log")
fh.setLevel(logging.DEBUG)
# create console handler with a higher log level
ch = logging.StreamHandler()
ch.setLevel(logging.WARNING)
# create formatter and add it to the handlers
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
fh.setFormatter(formatter)
ch.setFormatter(formatter)
# add the handlers to logger
logger.addHandler(fh)
logger.addHandler(ch)
run_start_time = time.time()
logger.info(f"********** Run {run + 1} starts. **********")
logger.info(f'configuration is {args}')
loss_func = nn.BCELoss()
# evaluate EdgeBank
logger.info(f'get final performance on dataset {args.dataset_name}...')
# Ensures the random sampler uses a fixed seed for evaluation (i.e. we always sample the same negatives for validation / test set)
assert test_neg_edge_sampler.seed is not None
test_neg_edge_sampler.reset_random_state()
test_losses, test_metrics = [], []
test_idx_data_loader_tqdm = tqdm(test_idx_data_loader, ncols=120)
for batch_idx, test_data_indices in enumerate(test_idx_data_loader_tqdm):
test_data_indices = test_data_indices.numpy()
batch_src_node_ids, batch_dst_node_ids, batch_node_interact_times = \
test_data.src_node_ids[test_data_indices], test_data.dst_node_ids[test_data_indices], \
test_data.node_interact_times[test_data_indices]
if test_neg_edge_sampler.negative_sample_strategy != 'random':
batch_neg_src_node_ids, batch_neg_dst_node_ids = test_neg_edge_sampler.sample(size=len(batch_src_node_ids),
batch_src_node_ids=batch_src_node_ids,
batch_dst_node_ids=batch_dst_node_ids,
current_batch_start_time=batch_node_interact_times[0],
current_batch_end_time=batch_node_interact_times[-1])
else:
_, batch_neg_dst_node_ids = test_neg_edge_sampler.sample(size=len(batch_src_node_ids))
batch_neg_src_node_ids = batch_src_node_ids
positive_edges = (batch_src_node_ids, batch_dst_node_ids)
negative_edges = (batch_neg_src_node_ids, batch_neg_dst_node_ids)
# incorporate the testing data before the current batch to history_data, which is similar to memory-based models
history_data = Data(src_node_ids=np.concatenate([train_val_data.src_node_ids, test_data.src_node_ids[: test_data_indices[0]]]),
dst_node_ids=np.concatenate([train_val_data.dst_node_ids, test_data.dst_node_ids[: test_data_indices[0]]]),
node_interact_times=np.concatenate([train_val_data.node_interact_times, test_data.node_interact_times[: test_data_indices[0]]]),
edge_ids=np.concatenate([train_val_data.edge_ids, test_data.edge_ids[: test_data_indices[0]]]),
labels=np.concatenate([train_val_data.labels, test_data.labels[: test_data_indices[0]]]))
# perform link prediction for EdgeBank
positive_probabilities, negative_probabilities = edge_bank_link_prediction(history_data=history_data,
positive_edges=positive_edges,
negative_edges=negative_edges,
edge_bank_memory_mode=args.edge_bank_memory_mode,
time_window_mode=args.time_window_mode,
time_window_proportion=args.test_ratio)
predicts = torch.from_numpy(np.concatenate([positive_probabilities, negative_probabilities])).float()
labels = torch.cat([torch.ones(len(positive_probabilities)), torch.zeros(len(negative_probabilities))], dim=0)
loss = loss_func(input=predicts, target=labels)
test_losses.append(loss.item())
test_metrics.append(get_link_prediction_metrics(predicts=predicts, labels=labels))
test_idx_data_loader_tqdm.set_description(f'test for the {batch_idx + 1}-th batch, test loss: {loss.item()}')
# store the evaluation metrics at the current run
test_metric_dict = {}
logger.info(f'test loss: {np.mean(test_losses):.4f}')
for metric_name in test_metrics[0].keys():
average_test_metric = np.mean([test_metric[metric_name] for test_metric in test_metrics])
logger.info(f'test {metric_name}, {average_test_metric:.4f}')
test_metric_dict[metric_name] = average_test_metric
single_run_time = time.time() - run_start_time
logger.info(f'Run {run + 1} cost {single_run_time:.2f} seconds.')
test_metric_all_runs.append(test_metric_dict)
# avoid the overlap of logs
if run < args.num_runs - 1:
logger.removeHandler(fh)
logger.removeHandler(ch)
# save model result
result_json = {
"test metrics": {metric_name: f'{test_metric_dict[metric_name]:.4f}'for metric_name in test_metric_dict}
}
result_json = json.dumps(result_json, indent=4)
save_result_folder = f"./saved_results/{args.model_name}/{args.dataset_name}"
os.makedirs(save_result_folder, exist_ok=True)
save_result_path = os.path.join(save_result_folder, f"{args.save_result_name}.json")
with open(save_result_path, 'w') as file:
file.write(result_json)
logger.info(f'save negative sampling results at {save_result_path}')
# store the average metrics at the log of the last run
logger.info(f'metrics over {args.num_runs} runs:')
for metric_name in test_metric_all_runs[0].keys():
logger.info(f'test {metric_name}, {[test_metric_single_run[metric_name] for test_metric_single_run in test_metric_all_runs]}')
logger.info(f'average test {metric_name}, {np.mean([test_metric_single_run[metric_name] for test_metric_single_run in test_metric_all_runs]):.4f} '
f'± {np.std([test_metric_single_run[metric_name] for test_metric_single_run in test_metric_all_runs], ddof=1):.4f}')
......@@ -32,11 +32,16 @@ class EdgePredictor(torch.nn.Module):
def forward(self, h, neg_samples=1):
num_edge = h.shape[0] // (neg_samples + 2)
h_src = self.src_fc(h[:num_edge])
h_pos_dst = self.dst_fc(h[num_edge:num_edge*2]) #
h_neg_src = self.src_fc(h[2 * num_edge:])
h_src = self.src_fc(h[:num_edge])
h_pos_dst = self.dst_fc(h[num_edge:num_edge*2])
h_neg_dst = self.dst_fc(h[2 * num_edge:])
h_pos_edge = torch.nn.functional.relu(h_src + h_pos_dst)
h_neg_edge = torch.nn.functional.relu(h_neg_src + h_pos_dst.tile(neg_samples, 1))
h_neg_edge = torch.nn.functional.relu(h_src + h_neg_dst.tile(neg_samples, 1))
#h_src = self.src_fc(h[num_edge:2 * num_edge])#self.src_fc(h[:num_edge])
#h_pos_dst = self.dst_fc(h[:num_edge]) #
#h_neg_src = self.src_fc(h[2 * num_edge:])
#h_pos_edge = torch.nn.functional.relu(h_src + #h_pos_dst)
#h_neg_edge = torch.nn.functional.relu(h_neg_src #+ h_pos_dst.tile(neg_samples, 1))
#h_neg_edge = torch.nn.functional.relu(h_neg_dst.tile(neg_samples, 1) + h_pos_dst)
#print(h_src,h_pos_dst,h_neg_dst)
return self.out_fc(h_pos_edge), self.out_fc(h_neg_edge)
......@@ -123,6 +128,10 @@ class TransfomerAttentionLayer(torch.nn.Module):
V = self.w_v(torch.cat([b.srcdata['h'][b.edges()[0]], b.edata['f']], dim=1))
#K = self.w_k(torch.cat([b.srcdata['h'][b.num_dst_nodes():], b.edata['f']], dim=1))
#V = self.w_v(torch.cat([b.srcdata['h'][b.num_dst_nodes():], b.edata['f']], dim=1))
elif self.dim_node_feat == 0 and self.dim_edge_feat == 0:
Q = self.w_q(zero_time_feat)[b.edges()[1]]
K = self.w_k(time_feat)
V = self.w_v(time_feat)
elif self.dim_node_feat == 0:
Q = self.w_q(zero_time_feat)[b.edges()[1]]
K = self.w_k(torch.cat([b.edata['f'], time_feat], dim=1))
......@@ -140,6 +149,7 @@ class TransfomerAttentionLayer(torch.nn.Module):
#Q = self.w_q(torch.cat([b.srcdata['h'][:b.num_dst_nodes()], zero_time_feat], dim=1))[b.edges()[1]]
#K = self.w_k(torch.cat([b.srcdata['h'][b.num_dst_nodes():], b.edata['f'], time_feat], dim=1))
#V = self.w_v(torch.cat([b.srcdata['h'][b.num_dst_nodes():], b.edata['f'], time_feat], dim=1))
Q = torch.reshape(Q, (Q.shape[0], self.num_head, -1))
K = torch.reshape(K, (K.shape[0], self.num_head, -1))
V = torch.reshape(V, (V.shape[0], self.num_head, -1))
......
......@@ -203,9 +203,6 @@ class GRUMemeoryUpdater(torch.nn.Module):
self.last_updated_ts = b.srcdata['ts'].detach().clone()
self.last_updated_memory = updated_memory.detach().clone()
self.last_updated_nid = b.srcdata['ID'].detach().clone()
x1 = torch.sum(b.srcdata['mem']**2,dim = 1)
self.delta_memory = torch.sum((updated_memory - b.srcdata['mem'])**2,dim = 1)/torch.sum(b.srcdata['mem']**2,dim = 1)
#print(torch.dist(b.srcdata['mem'],updated_memory))
if self.memory_param['combine_node_feature']:
if self.dim_node_feat > 0:
if self.dim_node_feat == self.dim_hid:
......
......@@ -25,8 +25,8 @@ delta_ts: list[tensor,tensor, tensor...]
metadata
"""
def prepare_input(node_feat, edge_feat, mem_embedding,mfgs,dist_nid,dist_eid):
for mfg in mfgs:
for i,b in enumerate(mfg):
for i,mfg in enumerate(mfgs):
for b in mfg:
e_idx = b.edata['ID']
idx = b.srcdata['ID']
b.edata['ID'] = dist_eid[e_idx]
......@@ -43,73 +43,61 @@ def prepare_input(node_feat, edge_feat, mem_embedding,mfgs,dist_nid,dist_eid):
b.srcdata['mem_input'] = mailbox[idx].reshape(b.srcdata['ID'].shape[0], -1)
b.srcdata['mail_ts'] = mailbox_ts[idx]
#print(idx.shape[0],b.srcdata['mem_ts'].shape)
return mfgs
return mfgs
def to_block(graph: DistributedGraphStore, data, sample_out, mailbox:MailBox = None,device = torch.device('cuda'),group = None):
if len(sample_out) > 1:
sample_out,metadata = sample_out
else:
metadata = None
# print(sample_out)
eid = [ret.eid() for ret in sample_out]
eid_len = [e.shape[0] for e in eid ]
eid_mapper: torch.Tensor = graph.eids_mapper
nid_mapper: torch.Tensor = graph.nids_mapper
eid_tensor = torch.cat(eid,dim = 0).to(eid_mapper.device)
dist_eid = eid_mapper[eid_tensor].to(device)
dist_eid = graph.sample_graph['dist_eid'][eid_tensor].to(device)#eid_mapper[eid_tensor].to(device)
dist_eid,eid_inv = dist_eid.unique(return_inverse=True)
src_node = graph.sample_graph['edge_index'][0,eid_tensor*2].to(graph.nids_mapper.device)
src_node = graph.sample_graph['edge_index'][0,eid_tensor].to(graph.nids_mapper.device)
src_ts = None
if metadata is None:
root_node = data.nodes.to(graph.nids_mapper.device)
root_node = data.nodes.to(graph.nidst_eid_mapper.device)
root_len = [root_node.shape[0]]
if hasattr(data,'ts'):
src_ts = torch.cat([data.ts,
graph.sample_graph['ts'][eid_tensor*2].to(device)])
graph.sample_graph['ts'][eid_tensor].to(device)])
elif 'seed' in metadata:
root_node = metadata.pop('seed').to(graph.nids_mapper.device)
root_len = root_node.shape[0]
if 'seed_ts' in metadata:
src_ts = torch.cat([metadata.pop('seed_ts').to(device),\
graph.sample_graph['ts'][eid_tensor*2].to(device)])
graph.sample_graph['ts'][eid_tensor].to(device)])
for k in metadata:
metadata[k] = metadata[k].to(device)
nid_tensor = torch.cat([root_node,src_node],dim = 0)
dist_nid = nid_mapper[nid_tensor].to(device)
dist_nid,nid_inv = dist_nid.unique(return_inverse = True)
fetchCache = FetchFeatureCache.getFetchCache()
if fetchCache is None:
if isinstance(graph.edge_attr,DistributedTensor):
ind_dict = graph.edge_attr.all_to_all_ind2ptr(dist_eid,group = group)
edge_feat = graph.edge_attr.all_to_all_get(group = group,**ind_dict)
else:
edge_feat = graph._get_edge_attr(dist_eid)
ind_dict = None
if isinstance(graph.x,DistributedTensor):
ind_dict = graph.x.all_to_all_ind2ptr(dist_nid,group = group)
node_feat = graph.x.all_to_all_get(group = group,**ind_dict)
else:
node_feat = graph._get_node_attr(dist_nid)
if mailbox is not None:
if torch.distributed.get_world_size() > 1:
if node_feat is None:
ind_dict = mailbox.node_memory.all_to_all_ind2ptr(dist_nid,group = group)
mem = mailbox.gather_memory(**ind_dict)
else:
mem = mailbox.get_memory(dist_nid)
if isinstance(graph.edge_attr,DistributedTensor):
ind_dict = graph.edge_attr.all_to_all_ind2ptr(dist_eid,group = group)
edge_feat = graph.edge_attr.all_to_all_get(group = group,**ind_dict)
else:
edge_feat = graph._get_edge_attr(dist_eid)
ind_dict = None
if isinstance(graph.x,DistributedTensor):
ind_dict = graph.x.all_to_all_ind2ptr(dist_nid,group = group)
node_feat = graph.x.all_to_all_get(group = group,**ind_dict)
else:
node_feat = graph._get_node_attr(dist_nid)
if mailbox is not None:
if torch.distributed.get_world_size() > 1:
if node_feat is None:
ind_dict = mailbox.node_memory.all_to_all_ind2ptr(dist_nid,group = group)
mem = mailbox.gather_memory(**ind_dict)
else:
mem = None
mem = mailbox.get_memory(dist_nid)
else:
raw_nid = torch.empty_like(dist_nid)
raw_eid = torch.empty_like(dist_eid)
nid_tensor = nid_tensor.to(device)
eid_tensor = eid_tensor.to(device)
raw_nid[nid_inv] = nid_tensor
raw_eid[eid_inv] = eid_tensor
node_feat,edge_feat,mem = fetchCache.fetch_feature(raw_nid,
dist_nid,raw_eid,
dist_eid)
mem = None
def build_block():
mfgs = list()
col = torch.arange(0,root_len,device = device)
......@@ -142,7 +130,6 @@ def to_block(graph: DistributedGraphStore, data, sample_out, mailbox:MailBox = N
#return build_block(node_feat,edge_feat,mem)#data,mfgs,metadata
return (data,mfgs,metadata)
def graph_sample(graph, sampler:BaseSampler,
sample_fn, data,
neg_sampling = None,
......
......@@ -7,7 +7,7 @@ import torch
import torch.distributed as dist
from torch_geometric.data import Data
from starrygl.utils.uvm import *
class DistributedGraphStore:
'''
......@@ -37,6 +37,7 @@ class DistributedGraphStore:
self.sample_graph = pdata.sample_graph
self.nids_mapper = build_mapper(nids=pdata.ids.to(device)).dist.to('cpu')
self.eids_mapper = build_mapper(nids=pdata.eids.to(device)).dist.to('cpu')
self.sample_graph['dist_eid'] = self.eids_mapper[pdata.sample_graph['eids']]
torch.cuda.empty_cache()
self.num_nodes = self.nids_mapper.data.shape[0]
......@@ -46,17 +47,18 @@ class DistributedGraphStore:
self.uvm_edge = uvm_edge
if hasattr(pdata,'x') and pdata.x is not None:
ctx = DistributedContext.get_default_context()
pdata.x = pdata.x.to(torch.float)
if uvm_node == False :
x = pdata.x.to(self.device)
else:
if self.device.type == 'cuda':
x = starrygl.utils.uvm.uvm_empty(*pdata.x.size(),
x = uvm_empty(*pdata.x.size(),
dtype=pdata.x.dtype,
device=ctx.device)
starrygl.utils.uvm.uvm_share(x,device = ctx.device)
starrygl.utils.uvm.uvm_advise(x,starrygl.utils.uvm.cudaMemoryAdvise.cudaMemAdviseSetAccessedBy)
starrygl.utils.uvm.uvm_prefetch(x)
uvm_share(x,device = ctx.device)
uvm_advise(x,cudaMemoryAdvise.cudaMemAdviseSetAccessedBy)
uvm_prefetch(x)
if world_size > 1:
self.x = DistributedTensor(pdata.x.to(self.device).to(torch.float))
else:
......@@ -71,12 +73,15 @@ class DistributedGraphStore:
edge_attr = pdata.edge_attr.to(self.device)
else:
if self.device.type == 'cuda':
edge_attr = starrygl.utils.uvm.uvm_empty(*pdata.edge_attr.size(),
edge_attr = uvm_empty(*pdata.edge_attr.size(),
dtype=pdata.edge_attr.dtype,
device=ctx.device)
starrygl.utils.uvm.uvm_share(edge_attr,device = ctx.device)
starrygl.utils.uvm.uvm_advise(edge_attr,starrygl.utils.uvm.cudaMemoryAdvise.cudaMemAdviseSetAccessedBy)
starrygl.utils.uvm.uvm_prefetch(edge_attr)
edge_attr = uvm_share(edge_attr,device = torch.device('cpu'))
edge_attr.copy_(pdata.edge_attr)
edge_attr = uvm_share(edge_attr,device = ctx.device)
uvm_advise(edge_attr,cudaMemoryAdvise.cudaMemAdviseSetAccessedBy)
uvm_prefetch(edge_attr)
if world_size > 1:
self.edge_attr = DistributedTensor(edge_attr)
else:
......@@ -251,7 +256,8 @@ class TemporalNeighborSampleGraph(DistributedGraphStore):
self.edge_ts = sample_graph['ts']
else:
self.edge_ts = None
self.eid = sample_graph['eids']
self.eid = sample_graph['eids']#torch.arange(self.num_edges,dtype = torch.long, device = sample_graph['eids'].device)
#sample_graph['eids']
if mode == 'train':
mask = sample_graph['train_mask']
if mode == 'val':
......
......@@ -115,7 +115,7 @@ class SharedMailBox():
def set_memory_local(self,index,source,source_ts,Reduce_Op = None):
if Reduce_Op == 'max':
if Reduce_Op == 'max' and self.num_parts > 1:
unq_id,inv = index.unique(return_inverse = True)
max_ts,id = torch_scatter.scatter_max(source_ts,inv,dim=0)
source_ts = max_ts
......@@ -243,7 +243,6 @@ class SharedMailBox():
#futs: List[torch.futures.Future] = []
if self.num_parts == 1:
dist_index = DistIndex(index)
part_idx = dist_index.part
index = dist_index.loc
self.set_mailbox_local(index,mail,mail_ts)
self.set_memory_local(index,memory,memory_ts)
......@@ -317,7 +316,7 @@ class SharedMailBox():
if edge_feats is not None:
src_mail = torch.cat([src_mail, edge_feats], dim=1)
dst_mail = torch.cat([dst_mail, edge_feats], dim=1)
mail = torch.cat([src_mail, dst_mail], dim=1).reshape(-1, src_mail.shape[1])
mail = torch.cat([src_mail, dst_mail], dim=0)#.reshape(-1, src_mail.shape[1])
mail_ts = torch.cat((ts,ts),-1).to(self.device).to(self.mailbox_ts.dtype)
unq_index,inv = torch.unique(index,return_inverse = True)
max_ts,idx = torch_scatter.scatter_max(mail_ts,inv,0)
......
from torch_sparse import SparseTensor
from torch_geometric.data import Data
from torch_geometric.utils import degree
import starrygl
import os.path as osp
import os
import shutil
import torch
import torch.utils.data
import metis
import networkx as nx
import torch.distributed as dist
from starrygl.lib.libstarrygl_sampler import get_norm_temporal
from starrygl.utils.partition import mt_metis_partition
def partition_load(root: str, algo: str = "metis") -> Data:
......@@ -21,7 +21,6 @@ def partition_load(root: str, algo: str = "metis") -> Data:
def partition_save(root: str, data: Data, num_parts: int,
algo: str = "metis",
node_weight = None,
edge_weight_dict=None):
root = osp.abspath(root)
if osp.exists(root) and not osp.isdir(root):
......@@ -46,7 +45,6 @@ def partition_save(root: str, data: Data, num_parts: int,
if algo == 'metis_for_tgnn':
for i, pdata in enumerate(partition_data_for_tgnn(
data, num_parts, algo, verbose=True,
node_weight = node_weight,
edge_weight_dict=edge_weight_dict)):
print(f"saving partition data: {i+1}/{num_parts}")
fn = osp.join(path, f"{i:03d}")
......@@ -154,41 +152,33 @@ def _nopart(edge_index: torch.LongTensor, num_nodes: int):
def metis_for_tgnn(edge_index_dict: dict,
num_nodes: int,
num_parts: int,
node_weight = None,
edge_weight_dict=None):
if num_parts <= 1:
return _nopart(edge_index_dict, num_nodes)
edge_list = []
weight_list = []
for i,key in enumerate(edge_index_dict):
G = nx.Graph()
G.add_nodes_from(torch.arange(0, num_nodes).tolist())
value, counts = torch.unique(edge_index_dict['edata'][1, :].view(-1),
return_counts=True)
nodes = torch.tensor(list(G.adj.keys()))
for i in range(value.shape[0]):
if (value[i].item() in G.nodes):
G.nodes[int(value[i].item())]['weight'] = counts[i]
G.nodes[int(value[i].item())]['ones'] = 1
G.graph['node_weight_attr'] = ['weight', 'ones']
edges = []
for i, key in enumerate(edge_index_dict):
v = edge_index_dict[key]
edge_list.append(v)
weight_list.append(torch.ones(v.shape[1])*edge_weight_dict[key])
edge_index = torch.cat(edge_list,dim = 1)
edge_weight = torch.cat(weight_list,dim = 0)
node_parts = mt_metis_partition(edge_index,num_nodes,num_parts,node_weight,edge_weight)
edge = torch.cat((v, (torch.ones(v.shape[1], dtype=torch.long) *
edge_weight_dict[key]).unsqueeze(0)), dim=0)
edges.append(edge)
# w = edges.T
edges = torch.cat(edges,dim = 1)
G.add_weighted_edges_from((edges.T).tolist())
G.graph['edge_weight_attr'] = 'weight'
cuts, part = metis.part_graph(G, num_parts)
node_parts = torch.zeros(num_nodes, dtype=torch.long)
node_parts[nodes] = torch.tensor(part)
return node_parts
#G = nx.Graph()
#G.add_nodes_from(torch.arange(0, num_nodes).tolist())
#value, counts = torch.unique(edge_index_dict['edata'][1, :].view(-1),
# return_counts=True)
#nodes = torch.tensor(list(G.adj.keys()))
#for i in range(value.shape[0]):
# if (value[i].item() in G.nodes):
# G.nodes[int(value[i].item())]['weight'] = counts[i]
# G.nodes[int(value[i].item())]['ones'] = 1
#G.graph['node_weight_attr'] = ['weight', 'ones']
#for i, key in enumerate(edge_index_dict):
# v = edge_index_dict[key]
# edges = torch.cat((v, (torch.ones(v.shape[1], dtype=torch.long) *
# edge_weight_dict[key]).unsqueeze(0)), dim=0)
# # w = edges.T
# G.add_weighted_edges_from((edges.T).tolist())
#G.graph['edge_weight_attr'] = 'weight'
#cuts, part = metis.part_graph(G, num_parts)
#node_parts = torch.zeros(num_nodes, dtype=torch.long)
#node_parts[nodes] = torch.tensor(part)
#return node_parts
"""
......@@ -199,7 +189,6 @@ weight: 各种工作负载边划分权重
def partition_data_for_tgnn(data: Data, num_parts: int, algo: str,
verbose: bool = False,
node_weight: torch.Tensor = None,
edge_weight_dict: dict = None):
if algo == "metis_for_tgnn":
part_fn = metis_for_tgnn
......@@ -213,7 +202,6 @@ def partition_data_for_tgnn(data: Data, num_parts: int, algo: str,
if verbose:
print(f"running partition algorithm: {algo}")
node_parts = part_fn(edge_index_dict, num_nodes, num_parts,
node_weight,
edge_weight_dict)
edge_parts = node_parts[data.edge_index[1, :]]
eids = torch.arange(num_edges, dtype=torch.long)
......@@ -304,7 +292,7 @@ def compute_gcn_norm(edge_index: torch.LongTensor, num_nodes: int):
def compute_temporal_norm(edge_index: torch.LongTensor,
timestamp: torch.FloatTensor,
num_nodes: int):
srcavg, srcvar, dstavg, dstvar = get_norm_temporal(edge_index[0, :],
srcavg, srcvar, dstavg, dstvar = starrygl.sampler_ops.get_norm_temporal(edge_index[0, :],
edge_index[1, :],
timestamp, num_nodes)
return srcavg, srcvar, dstavg, dstvar
......
import sys
from os.path import abspath, join, dirname
sys.path.insert(0, join(abspath(dirname(__file__))))
from torch import Tensor
import torch
from base import NegativeSampling
from base import NegativeSamplingMode
from typing import Any, List, Optional, Tuple, Union
class EvaluateNegativeSampling(NegativeSampling):
def __init__(
self,
mode: Union[NegativeSamplingMode, str],
src_node_ids: torch.Tensor,
dst_node_ids: torch.Tensor,
interact_times: torch.Tensor = None,
last_observed_time: float = None,
negative_sample_strategy: str = 'random',
seed: int = None
):
super(EvaluateNegativeSampling,self).__init__(mode)
self.seed = seed
self.negative_sample_strategy = negative_sample_strategy
self.src_node_ids = src_node_ids
self.dst_node_ids = dst_node_ids
self.interact_times = interact_times
self.unique_src_nodes_id = src_node_ids.unique()
self.unique_dst_nodes_id = dst_node_ids.unique()
self.src_id_mapper = torch.zeros(self.unique_src_nodes_id[-1])
self.dst_id_mapper = torch.zeros(self.unique_dst_nodes_id[-1])
self.src_id_mapper[self.unique_src_nodes_id] = torch.arange(self.unique_src_nodes_id.shape[0])
self.dst_id_mapper[self.unique_dst_nodes_id] = torch.arange(self.unique_dst_nodes_id.shape[0])
self.unique_interact_times = self.interact_times.unique()
self.earliest_time = self.unique_interact_times.min().item()
self.last_observed_time = last_observed_time
if self.negative_sample_strategy == 'inductive':
# set of observed edges
self.observed_edges = self.get_unique_edges_between_start_end_time(self.earliest_time, self.last_observed_time)
if self.seed is not None:
self.random_state = torch.Generator()
self.random_state.manual_seed(seed)
else:
self.random_state = torch.Generator()
def get_unique_edges_between_start_end_time(self, start_time: float, end_time: float):
selected_mask = ((self.interact_times >= start_time) and (self.interact_times <= end_time))
# return the unique select source and destination nodes in the selected time interval
return torch.cat((self.src_node_ids[selected_mask],self.dst_node_ids[selected_mask]),dim = 1)
def sample(self, num_samples: int, num_nodes: Optional[int] = None, batch_src_node_ids: Optional[torch.Tensor] = None,
batch_dst_node_ids: Optional[torch.Tensor] = None, current_batch_start_time: Optional[torch.Tensor] = None,
current_batch_end_time: Optional[torch.Tensor] = None) -> Tensor:
if self.negative_sample_strategy == 'random':
negative_src_node_ids, negative_dst_node_ids = self.random_sample(size=num_samples)
elif self.negative_sample_strategy == 'historical':
negative_src_node_ids, negative_dst_node_ids = self.historical_sample(size=num_samples, batch_src_node_ids=batch_src_node_ids,
batch_dst_node_ids=batch_dst_node_ids,
current_batch_start_time=current_batch_start_time,
current_batch_end_time=current_batch_end_time)
elif self.negative_sample_strategy == 'inductive':
negative_src_node_ids, negative_dst_node_ids = self.inductive_sample(size=num_samples, batch_src_node_ids=batch_src_node_ids,
batch_dst_node_ids=batch_dst_node_ids,
current_batch_start_time=current_batch_start_time,
current_batch_end_time=current_batch_end_time)
else:
raise ValueError(f'Not implemented error for negative_sample_strategy {self.negative_sample_strategy}!')
return negative_src_node_ids, negative_dst_node_ids
def random_sample(self, size: int):
if self.seed is None:
random_sample_edge_src_node_indices = torch.randint(0, len(self.unique_src_nodes_id), size)
random_sample_edge_dst_node_indices = torch.randint(0, len(self.unique_dst_nodes_id), size)
else:
random_sample_edge_src_node_indices = torch.randint(0, len(self.unique_src_nodes_id), size, generate = self.random_state)
random_sample_edge_dst_node_indices = torch.randint(0, len(self.unique_dst_nodes_id), size, generate = self.random_state)
return self.unique_src_nodes_id[random_sample_edge_src_node_indices], self.unique_dst_nodes_id[random_sample_edge_dst_node_indices]
def random_sample_with_collision_check(self, size: int, batch_src_nodes_id:torch.Tensor, batch_dst_nodes_id:torch.Tensor):
batch_edge = torch.stack((batch_src_nodes_id,batch_dst_nodes_id))
batch_src_index = self.src_id_mapper[batch_src_nodes_id]
batch_dst_index = self.dst_id_mapper[batch_dst_nodes_id]
return_edge = torch.tensor([[],[]])
while(True):
src_ = torch.randint(0, len(self.unique_src_nodes_id), size*2)
dst_ = torch.randint(0, len(self.unique_dst_nodes_id), size*2)
edge = torch.stack((src_,dst_))
sample_id = src_*self.unique_dst_nodes_id.shape[0] + dst_
batch_id = batch_src_index * self.unique_dst_nodes_id.shape[0] + batch_dst_index
mask = torch.isin(sample_id,batch_id,invert = True)
edge = edge[:,mask]
if(edge.shape[1] >= size):
return_edge = torch.cat((return_edge,edge[:,:size]),1)
break
else:
return_edge = torch.cat((return_edge,edge),1)
size = size - edge.shape[1]
return return_edge
def historical_sample(self, size: int, batch_src_nodes_id: torch.Tensor, batch_dst_nodes_id: torch.Tensor,
current_batch_start_time: float, current_batch_end_time: float):
assert self.seed is not None
historical_edges = self.get_unique_edges_between_start_end_time(start_time=self.earliest_time, end_time=current_batch_start_time)
current_batch_edges = self.get_unique_edges_between_start_end_time(start_time=current_batch_start_time, end_time=current_batch_end_time)
uni,ids = torch.cat((current_batch_edges, historical_edges), dim = 1).unique(dim = 1, return_inverse = False)
mask = torch.zeros(uni.shape[1],dtype = bool)
mask[ids[:current_batch_edges.shape[1]]] = True
mask = (~mask)
unique_historical_edges = uni[:,mask]
if size > unique_historical_edges.shape[1]:
num_random_sample_edges = size - len(unique_historical_edges)
random_sample_edge = self.random_sample_with_collision_check(size=num_random_sample_edges,batch_src_node_ids=batch_src_nodes_id,
batch_dst_node_ids=batch_dst_nodes_id)
sample_edges = torch.cat((unique_historical_edges,random_sample_edge),dim = 1)
else:
historical_sample_edge_node_indices = torch.randperm(unique_historical_edges.shape[1],generator=self.random_state)
sample_edges = unique_historical_edges[:,historical_sample_edge_node_indices[:size]]
return sample_edges
def inductive_sample(self, size: int, batch_src_node_ids: torch.Tensor, batch_dst_node_ids: torch.Tensor,
current_batch_start_time: float, current_batch_end_time: float):
assert self.seed is not None
historical_edges = self.get_unique_edges_between_start_end_time(start_time=self.earliest_time, end_time=current_batch_start_time)
current_batch_edges = self.get_unique_edges_between_start_end_time(start_time=current_batch_start_time, end_time=current_batch_end_time)
uni,ids = torch.cat((self.observed_edges,current_batch_edges, historical_edges), dim = 1).unique(dim = 1, return_inverse = False)
mask = torch.zeros(uni.shape[1],dtype = bool)
mask[ids[:current_batch_edges.shape[1]+historical_edges.shape[1]]] = True
mask = (~mask)
unique_inductive_edges = uni[:,mask]
if size > len(unique_inductive_edges):
num_random_sample_edges = size - len(unique_inductive_edges)
random_sample_edge = self.random_sample_with_collision_check(size=num_random_sample_edges,
batch_src_node_ids=batch_src_node_ids,
batch_dst_node_ids=batch_dst_node_ids)
sample_edges = torch.cat((unique_inductive_edges,random_sample_edge),dim = 1)
else:
inductive_sample_edge_node_indices = torch.randperm(unique_inductive_edges.shape[1],generator=self.random_state)
sample_edges = unique_inductive_edges[:, inductive_sample_edge_node_indices[:size]]
return sample_edges
......@@ -4,7 +4,7 @@ from enum import Enum
import math
from abc import ABC
from typing import Any, List, Optional, Tuple, Union
import numpy as np
class SampleType(Enum):
Whole = 0
Inner = 1
......@@ -82,7 +82,8 @@ class NegativeSampling:
f"Cannot sample negatives in '{self.__class__.__name__}' "
f"without passing the 'num_nodes' argument")
return torch.randint(num_nodes, (num_samples, ))
#return torch.from_numpy(np.random.randint(num_nodes, size=num_samples))
if num_nodes is not None and self.weight.numel() != num_nodes:
raise ValueError(
f"The 'weight' attribute in '{self.__class__.__name__}' "
......
......@@ -262,6 +262,8 @@ class NeighborSampler(BaseSampler):
seed = torch.cat([src, dst, src_neg], dim=0)
if with_timestap: # ts操作
seed_ts = torch.cat([ets, ets, ets], dim=0)
#if neg_sampling.is_evaluate():
#src,dst = neg_sampling.sample(num_samples=)
else:
seed = torch.cat([src, dst], dim=0)
if with_timestap: # ts操作
......
......@@ -40,6 +40,8 @@ parser = argparse.ArgumentParser(
)
parser.add_argument('--rank', default=0, type=int, metavar='W',
help='name of dataset')
parser.add_argument('--local_rank', default=0, type=int, metavar='W',
help='name of dataset')
parser.add_argument('--patience', type=int, default=5, help='Patience for early stopping')
parser.add_argument('--world_size', default=1, type=int, metavar='W',
help='number of negative samples')
......@@ -56,14 +58,16 @@ import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn.parallel import DistributedDataParallel as DDP
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.rank)
os.environ["RANK"] = str(args.rank)
os.environ["WORLD_SIZE"] = str(args.world_size)
os.environ["LOCAL_RANK"] = str(0)
#os.environ['CUDA_VISIBLE_DEVICES'] = str(args.rank)
if not 'WORLD_SIZE' in os.environ:
os.environ["RANK"] = str(args.rank)
os.environ["WORLD_SIZE"] = str(args.world_size)
os.environ["LOCAL_RANK"] = str(args.local_rank)
if not 'MASTER_ADDR' in os.environ:
os.environ["MASTER_ADDR"] = '192.168.2.107'
if not 'MASTER_PORT' in os.environ:
os.environ["MASTER_PORT"] = '9337'
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
os.environ["MASTER_ADDR"] = '10.214.211.186'
os.environ["MASTER_PORT"] = '9667'
def seed_everything(seed=42):
random.seed(seed)
np.random.seed(seed)
......@@ -72,18 +76,18 @@ def seed_everything(seed=42):
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
seed_everything(1234)
seed_everything(34)
def main():
print('main')
print('LOCAL RANK {}, RANK{}'.format(os.environ["LOCAL_RANK"],os.environ["RANK"]))
use_cuda = True
sample_param, memory_param, gnn_param, train_param = parse_config('./config/{}.yml'.format(args.model))
torch.set_num_threads(12)
ctx = DistributedContext.init(backend="nccl", use_gpu=True)
torch.set_num_threads(int(80/torch.distributed.get_world_size()))
device_id = torch.cuda.current_device()
print('use cuda on',device_id)
pdata = partition_load("/mnt/data/part_data/here/{}".format(args.dataname), algo="metis_for_tgnn")
pdata = partition_load("/mnt/data/part_data/v2/here/{}".format(args.dataname), algo="metis_for_tgnn")
graph = DistributedGraphStore(pdata = pdata)
print(graph.num_nodes)
Path("./saved_models/").mkdir(parents=True, exist_ok=True)
Path("./saved_checkpoints/").mkdir(parents=True, exist_ok=True)
get_checkpoint_path = lambda \
......@@ -92,27 +96,24 @@ def main():
use_src_emb = gnn_param['use_src_emb'] if 'use_src_emb' in gnn_param else False
use_dst_emb = gnn_param['use_dst_emb'] if 'use_dst_emb' in gnn_param else False
sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph,mode = 'full')
mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat = pdata.edge_attr.shape[1] if pdata.edge_attr is not None else 0)
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10],graph_data=sample_graph, workers=15,policy = 'recent',graph_name = "wiki_train")
if memory_param['type'] != 'none':
mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat = pdata.edge_attr.shape[1] if pdata.edge_attr is not None else 0)
else:
mailbox = None
fanout = []
num_layers = sample_param['layer'] if 'layer' in sample_param else 1
fanout = sample_param['neighbor'] if 'neighbor' in sample_param else [10]
policy = sample_param['strategy'] if 'strategy' in sample_param else 'recent'
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=num_layers, fanout=fanout,graph_data=sample_graph, workers=int(80/torch.distributed.get_world_size()),policy = policy, graph_name = "wiki_train")
train_data = torch.masked_select(graph.edge_index,pdata.train_mask.to(graph.edge_index.device)).reshape(2,-1)
train_ts = torch.masked_select(graph.edge_ts,pdata.train_mask.to(graph.edge_index.device))
val_data = torch.masked_select(graph.edge_index,pdata.val_mask.to(graph.edge_index.device)).reshape(2,-1)
val_ts = torch.masked_select(graph.edge_ts,pdata.val_mask.to(graph.edge_index.device))
test_data = torch.masked_select(graph.edge_index,pdata.test_mask.to(graph.edge_index.device)).reshape(2,-1)
test_ts = torch.masked_select(graph.edge_ts,pdata.test_mask.to(graph.edge_index.device))
##print(train_data.shape[1],val_data.shape[1],test_data.shape[1])
test_ts = torch.masked_select(graph.edge_ts,pdata.test_mask.to(graph.edge_index.device))
val_data = torch.masked_select(graph.edge_index,pdata.val_mask.to(graph.edge_index.device)).reshape(2,-1)
val_ts = torch.masked_select(graph.edge_ts,pdata.val_mask.to(graph.edge_index.device))
train_data = DataSet(edges = train_data,ts =train_ts,eids = torch.nonzero(pdata.train_mask).view(-1))
#if dist.get_rank() == 0:
test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.nonzero(pdata.test_mask).view(-1))
val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.nonzero(pdata.val_mask).view(-1))
#else:
#test_data = torch.tensor([[],[]],device = graph.edge_index.device,detype = graph.edge_index.#dtype)
#val_data = torch.tensor([[],[]],device = graph.edge_index.device,detype = graph.edge_index.dtype)
#test_ts = torch.tensor([[],[]],device = graph.ts.device,detype = graph.ts.dtype)
#val_ts = torch.tensor([[],[]],device = graph.ts.device,detype = graph.ts.dtype)
#test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.tensor([],dtype = torch.long,#device = torch.cuda))
#val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.tensor([],dtype = torch.long,device #= torch.cuda))
#train_neg_sampler = PreNegativeSampling('triplet',torch.masked_select(pdata.edge_index['pos_edge'],graph.data.train_mask).reshape(2,-1))
neg_sampler = NegativeSampling('triplet')
trainloader = DistributedDataLoader(graph,train_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
......@@ -122,7 +123,7 @@ def main():
drop_last=True,
chunk_size = None,
train=True,
queue_size = 1000,
queue_size = 200,
mailbox = mailbox,
)
testloader = DistributedDataLoader(graph,test_data,sampler = sampler,
......@@ -145,12 +146,10 @@ def main():
train=False,
queue_size = 100,
mailbox = mailbox)
#FetchFeatureCache.create_fetch_cache(graph.num_nodes,graph.eids_mapper.shape[0],0.1,0.1,graph,mailbox,policy = 'static')
#cache = FetchFeatureCache.getFetchCache()
#cache.init_cache_with_presample(trainloader,3)
gnn_dim_node = 0 if graph.x is None else pdata.x.shape[1]
gnn_dim_edge = 0 if graph.edge_attr is None else pdata.edge_attr.shape[1]
print("gnn_dim_node:", gnn_dim_node, "gnn_dim_edge:", gnn_dim_edge)
avg_time = 0
if use_cuda:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param).cuda()
......@@ -158,7 +157,7 @@ def main():
else:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param)
device = torch.device('cpu')
model = DDP(model,find_unused_parameters=False)
model = DDP(model,find_unused_parameters=True)
train_stream = torch.cuda.Stream()
send_stream = torch.cuda.Stream()
scatter_stream = torch.cuda.Stream()
......@@ -187,11 +186,11 @@ def main():
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
aucs_mrrs.append(roc_auc_score(y_true, y_pred))
if mailbox is not None:
src = metadata['src_pos_index']
dst = metadata['dst_pos_index']
ts = roots.ts
if graph.edge_attr is None:
edge_feats = None
......@@ -199,6 +198,7 @@ def main():
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda')
else:
edge_feats = graph.edge_attr[roots.eids]
dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
......@@ -214,6 +214,7 @@ def main():
model.module.embedding,use_src_emb,
use_dst_emb,
)
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
......@@ -253,27 +254,22 @@ def main():
model.module.memory_updater.last_updated_memory = None
model.module.memory_updater.last_updated_ts = None
for roots,mfgs,metadata in trainloader:
# fetch_time +=sample_time/1000
#fetch_time +=sample_time/1000
t_prep_s = time.time()
with torch.cuda.stream(train_stream):
optimizer.zero_grad()
pred_pos, pred_neg = model(mfgs,metadata)
loss = creterion(pred_pos, torch.ones_like(pred_pos))
loss += creterion(pred_neg, torch.zeros_like(pred_neg))
total_loss += float(loss)
loss.backward()
optimizer.step()
#torch.cuda.synchronize()
t_prep_s = time.time()
# y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
# y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
# train_aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
#start_event = torch.cuda.Event(enable_timing=True)
#end_event = torch.cuda.Event(enable_timing=True)
#start_event.record()
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
train_aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
if mailbox is not None:
src = metadata['src_pos_index']
dst = metadata['dst_pos_index']
......@@ -298,13 +294,8 @@ def main():
model.module.memory_updater.last_updated_memory,
model.module.embedding,use_src_emb,use_dst_emb,
)
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
end_event.record()
torch.cuda.synchronize()
write_back_time += start_event.elapsed_time(end_event)/1000
torch.cuda.synchronize()
time_prep = time.time() - epoch_start_time
......@@ -312,24 +303,24 @@ def main():
train_ap = float(torch.tensor(train_aps).mean())
ap = 0
auc = 0
#if cache.edge_cache is not None:
# print('hit {}'.format(cache.edge_cache.hit_/ cache.edge_cache.hit_sum))
#if cache.node_cache is not None:
# print('hit {}'.format(cache.node_cache.hit_/ cache.node_cache.hit_sum))
ap, auc = eval('val')
early_stop = early_stopper.early_stop_check(ap)
if early_stop:
print("Early stopping at epoch {:d}".format(e))
print(f"Loading the best model at epoch {early_stopper.best_epoch}")
print("Early stopping at epoch {:d}\n".format(e))
print(f"Loading the best model at epoch {early_stopper.best_epoch}\n")
best_model_path = get_checkpoint_path(early_stopper.best_epoch)
model.load_state_dict(torch.load(best_model_path))
break
else:
print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f}'.format(total_loss,train_ap, ap, auc))
print('\ttotal time:{:.2f}s prep time:{:.2f}s'.format(time.time()-epoch_start_time, time_prep))
# print('\t fetch time:{:.2f}s write back time:{:.2f}s'.format(fetch_time,write_back_time))
print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f}\n'.format(total_loss,train_ap, ap, auc))
print('\ttotal time:{:.2f}s prep time:{:.2f}s\n'.format(time.time()-epoch_start_time, time_prep))
print('\t fetch time:{:.2f}s write back time:{:.2f}s\n'.format(fetch_time,write_back_time))
torch.save(model.state_dict(), get_checkpoint_path(e))
if not early_stop:
print(f"Loading the best model at epoch {early_stopper.best_epoch}")
best_model_path = get_checkpoint_path(early_stopper.best_epoch)
model.load_state_dict(torch.load(best_model_path))
model.eval()
if mailbox is not None:
mailbox.reset()
......@@ -339,10 +330,10 @@ def main():
ap, auc = eval('test')
eval_neg_samples = 1
if eval_neg_samples > 1:
print('\ttest AP:{:4f} test MRR:{:4f}'.format(ap, auc))
print('\ttest AP:{:4f} test MRR:{:4f}\n'.format(ap, auc))
else:
print('\ttest AP:{:4f} test AUC:{:4f}'.format(ap, auc))
print('test_dataset',test_data.edges.shape[1],'avg_time',avg_time/train_param['epoch'])
print('\ttest AP:{:4f} test AUC:{:4f}\n'.format(ap, auc))
print('test_dataset {} avg_time {} \n'.format(test_data.edges.shape[1],avg_time/train_param['epoch']))
torch.save(model.state_dict(), MODEL_SAVE_PATH)
ctx.shutdown()
if __name__ == "__main__":
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment