Commit 36c9a61a by Wenjie Huang

Merge branch 'doc-v2' of http://10.75.75.11:7001/wjie98/starrygl

parents 4140455c 8ee95360
...@@ -179,4 +179,9 @@ cython_debug/ ...@@ -179,4 +179,9 @@ cython_debug/
/test_* /test_*
/*.ipynb /*.ipynb
saved_models/ saved_models/
saved_checkpoints/ saved_checkpoints/
\ No newline at end of file .history/
.preprocess_data/
.processed_data/
.DG_data/
.examples/
...@@ -25,8 +25,8 @@ gnn: ...@@ -25,8 +25,8 @@ gnn:
dim_time: 100 dim_time: 100
dim_out: 100 dim_out: 100
train: train:
- epoch: 50 - epoch: 100
batch_size: 100 batch_size: 1000
# reorder: 16 # reorder: 16
lr: 0.0001 lr: 0.0001
dropout: 0.1 dropout: 0.1
......
...@@ -16,8 +16,8 @@ gnn: ...@@ -16,8 +16,8 @@ gnn:
use_dst_emb: False use_dst_emb: False
time_transform: 'JODIE' time_transform: 'JODIE'
train: train:
- epoch: 20 - epoch: 100
batch_size: 200 batch_size: 1000
lr: 0.0001 lr: 0.0001
dropout: 0.1 dropout: 0.1
all_on_gpu: True all_on_gpu: True
\ No newline at end of file
...@@ -13,13 +13,15 @@ memory: ...@@ -13,13 +13,15 @@ memory:
dim_out: 0 dim_out: 0
gnn: gnn:
- arch: 'transformer_attention' - arch: 'transformer_attention'
use_src_emb: False
use_dst_emb: False
layer: 2 layer: 2
att_head: 2 att_head: 2
dim_time: 100 dim_time: 100
dim_out: 100 dim_out: 100
train: train:
- epoch: 100 - epoch: 50
batch_size: 600 batch_size: 1000
lr: 0.0001 lr: 0.0001
dropout: 0.1 dropout: 0.1
att_dropout: 0.1 att_dropout: 0.1
......
...@@ -25,8 +25,8 @@ gnn: ...@@ -25,8 +25,8 @@ gnn:
dim_time: 100 dim_time: 100
dim_out: 100 dim_out: 100
train: train:
- epoch: 20 - epoch: 5
batch_size: 200 batch_size: 1000
# reorder: 16 # reorder: 16
lr: 0.0001 lr: 0.0001
dropout: 0.2 dropout: 0.2
......
...@@ -25,8 +25,8 @@ gnn: ...@@ -25,8 +25,8 @@ gnn:
dim_time: 100 dim_time: 100
dim_out: 100 dim_out: 100
train: train:
- epoch: 20 - epoch: 50
batch_size: 200 batch_size: 1000
# reorder: 16 # reorder: 16
lr: 0.0001 lr: 0.0001
dropout: 0.2 dropout: 0.2
......
...@@ -19,29 +19,6 @@ parser.add_argument('--num_neg_sample', default=1, type=int, metavar='W', ...@@ -19,29 +19,6 @@ parser.add_argument('--num_neg_sample', default=1, type=int, metavar='W',
args = parser.parse_args() args = parser.parse_args()
def load_feat(d, rand_de=0, rand_dn=0):
node_feats = None
if os.path.exists('DATA/{}/node_features.pt'.format(d)):
node_feats = torch.load('DATA/{}/node_features.pt'.format(d))
if node_feats.dtype == torch.bool:
node_feats = node_feats.type(torch.float32)
edge_feats = None
if os.path.exists('DATA/{}/edge_features.pt'.format(d)):
edge_feats = torch.load('DATA/{}/edge_features.pt'.format(d))
if edge_feats.dtype == torch.bool:
edge_feats = edge_feats.type(torch.float32)
if rand_de > 0:
if d == 'LASTFM':
edge_feats = torch.randn(1293103, rand_de)
elif d == 'MOOC':
edge_feats = torch.randn(411749, rand_de)
if rand_dn > 0:
if d == 'LASTFM':
node_feats = torch.randn(1980, rand_dn)
elif d == 'MOOC':
edge_feats = torch.randn(7144, rand_dn)
return node_feats, edge_feats
data_name = args.data_name data_name = args.data_name
g = np.load('/mnt/nfs/fzz/TGL-DATA/'+data_name+'/ext_full.npz') g = np.load('/mnt/nfs/fzz/TGL-DATA/'+data_name+'/ext_full.npz')
...@@ -95,8 +72,9 @@ if e_feat is not None: ...@@ -95,8 +72,9 @@ if e_feat is not None:
data.edge_attr = e_feat data.edge_attr = e_feat
data.train_mask = (torch.from_numpy(np.array(df.ext_roll.values)) == 0) data.train_mask = (torch.from_numpy(np.array(df.ext_roll.values)) == 0)
data.test_mask = (torch.from_numpy(np.array(df.ext_roll.values)) == 1) data.val_mask = (torch.from_numpy(np.array(df.ext_roll.values)) == 1)
data.val_mask = (torch.from_numpy(np.array(df.ext_roll.values)) == 2) data.test_mask = (torch.from_numpy(np.array(df.ext_roll.values)) == 2)
print(ts[data.train_mask].min(),ts[data.train_mask].max(),ts[data.val_mask].min(),ts[data.val_mask].max(),ts[data.test_mask].min(),ts[data.test_mask].max())
sample_graph['train_mask'] = data.train_mask[sample_eid] sample_graph['train_mask'] = data.train_mask[sample_eid]
sample_graph['test_mask'] = data.test_mask[sample_eid] sample_graph['test_mask'] = data.test_mask[sample_eid]
sample_graph['val_mask'] = data.val_mask[sample_eid] sample_graph['val_mask'] = data.val_mask[sample_eid]
...@@ -106,28 +84,19 @@ data.y = torch.zeros(edge_index.shape[1]) ...@@ -106,28 +84,19 @@ data.y = torch.zeros(edge_index.shape[1])
edge_index_dict = {} edge_index_dict = {}
edge_index_dict['edata'] = data.edge_index edge_index_dict['edata'] = data.edge_index
edge_index_dict['sample_data'] = data.sample_graph['edge_index'] edge_index_dict['sample_data'] = data.sample_graph['edge_index']
edge_index_dict['neg_data'] = torch.cat([neg_src.view(1, -1),
dst.view(-1, 1).repeat(1, neg_nums).
reshape(1, -1)], dim=0)
data.edge_index_dict = edge_index_dict data.edge_index_dict = edge_index_dict
edge_weight_dict = {} edge_weight_dict = {}
edge_weight_dict['edata'] = 2*neg_nums edge_weight_dict['edata'] = 1*neg_nums
edge_weight_dict['sample_data'] = 1*neg_nums edge_weight_dict['sample_data'] = 1*neg_nums
edge_weight_dict['neg_data'] = 1 partition_save('/mnt/data/part_data/v2/here/'+data_name, data, 1, 'metis_for_tgnn',
#partition_save('./dataset/here/'+data_name, data, 1, 'metis_for_tgnn', edge_weight_dict=edge_weight_dict)
# edge_weight_dict=edge_weight_dict) partition_save('/mnt/data/part_data/v2/here/'+data_name, data, 2, 'metis_for_tgnn',
#partition_save('./dataset/here/'+data_name, data, 2, 'metis_for_tgnn', edge_weight_dict=edge_weight_dict)
# edge_weight_dict=edge_weight_dict) partition_save('/mnt/data/part_data/v2/here/'+data_name, data, 4, 'metis_for_tgnn',
#partition_save('./dataset/here/'+data_name, data, 4, 'metis_for_tgnn', edge_weight_dict=edge_weight_dict)
# edge_weight_dict=edge_weight_dict) partition_save('/mnt/data/part_data/v2/here/'+data_name, data, 8, 'metis_for_tgnn',
#partition_save('./dataset/here/'+data_name, data, 8, 'metis_for_tgnn', edge_weight_dict=edge_weight_dict)
# edge_weight_dict=edge_weight_dict)
partition_save('./dataset/here/'+data_name, data, 16, 'metis_for_tgnn', partition_save('./dataset/here/'+data_name, data, 16, 'metis_for_tgnn',
edge_weight_dict=edge_weight_dict) edge_weight_dict=edge_weight_dict)
#
# partition_save('./dataset/here/'+data_name, data, 4, 'metis_for_tgnn',
# edge_weight_dict=edge_weight_dict )
# partition_save('./dataset/here'+data_name, data, 8, 'metis')
# partition_save('./dataset/'+data_name, data, 12, 'metis')
# partition_save('./dataset/'+data_name, data, 16, 'metis')
...@@ -3,16 +3,11 @@ ...@@ -3,16 +3,11 @@
mkdir -p build && cd build mkdir -p build && cd build
cmake .. \ cmake .. \
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON \ -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \
-DCMAKE_PREFIX_PATH=$(python -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())") \ -DCMAKE_PREFIX_PATH="/home/zlj/.miniconda3/envs/sgl/lib/python3.10/site-packages" \
-DPython3_ROOT_DIR=$(python -c "import sys; print(sys.prefix)") \ -DPython3_ROOT_DIR="/home/zlj/.miniconda3/envs/sgl" \
-DCUDA_TOOLKIT_ROOT_DIR=${CUDA_HOME:-"$(realpath $(dirname $(which nvcc))/../)"} \ -DCUDA_TOOLKIT_ROOT_DIR="/home/zlj/.local/cuda-11.8" \
&& make -j32 \ && make -j32 \
&& rm -rf ../starrygl/lib \ && rm -rf ../starrygl/lib \
&& mkdir ../starrygl/lib \ && mkdir ../starrygl/lib \
&& cp lib*.so ../starrygl/lib/ \ && cp lib*.so ../starrygl/lib/ \
&& patchelf --set-rpath '$ORIGIN:$ORIGIN/lib' --force-rpath ../starrygl/lib/*.so && patchelf --set-rpath '$ORIGIN:$ORIGIN/lib' --force-rpath ../starrygl/lib/*.so
# -DCMAKE_PREFIX_PATH="/home/zlj/.miniconda3/envs/dgnn/lib/python3.10/site-packages" \
# -DPython3_ROOT_DIR="/home/zlj/.miniconda3/envs/dgnn" \
# -DCUDA_TOOLKIT_ROOT_DIR="/home/zlj/local/cuda-12.2" \
\ No newline at end of file
import random
import pandas as pd
import numpy as np
import os
import torch
from torch_geometric.data import Data
from starrygl.sample.graph_core import DataSet, DistributedGraphStore
def get_link_prediction_data(data_name: str, val_ratio, test_ratio):
    """
    Load one dataset and build the chronological split for the link prediction
    task (inductive & transductive settings).

    The following files are read relative to the current working directory:
      ./processed_data/<data_name>/ml_<data_name>.csv       (columns u, i, ts, label)
      ./processed_data/<data_name>/ml_<data_name>.npy       (edge features)
      ./processed_data/<data_name>/ml_<data_name>_node.npy  (node features)

    :param data_name: str, dataset name
    :param val_ratio: float, validation data ratio
    :param test_ratio: float, test data ratio
    :return: full_data (Data object) carrying edge_index, x, edge_attr, y, edge_ts,
        num_nodes, num_edges, the boolean edge masks train_mask / val_mask / test_mask /
        new_node_val_mask / new_node_test_mask, and sample_graph (a dict describing the
        graph with every edge listed in both directions).
    """
    # NOTE: an earlier variant of this loader read edges.csv / node_features.pt /
    # edge_features.pt from the TGL-DATA directory and reused its precomputed
    # `ext_roll` split; this version loads the ml_<data_name> files instead.
    graph_df = pd.read_csv('./processed_data/{}/ml_{}.csv'.format(data_name, data_name))
    edge_raw_features = np.load('./processed_data/{}/ml_{}.npy'.format(data_name, data_name))
    node_raw_features = np.load('./processed_data/{}/ml_{}_node.npy'.format(data_name, data_name))

    NODE_FEAT_DIM = EDGE_FEAT_DIM = 172
    assert NODE_FEAT_DIM >= node_raw_features.shape[1], f'Node feature dimension in dataset {data_name} is bigger than {NODE_FEAT_DIM}!'
    assert EDGE_FEAT_DIM >= edge_raw_features.shape[1], f'Edge feature dimension in dataset {data_name} is bigger than {EDGE_FEAT_DIM}!'

    # padding the features of edges and nodes to the same dimension (172 for all the datasets)
    if node_raw_features.shape[1] < NODE_FEAT_DIM:
        node_zero_padding = np.zeros((node_raw_features.shape[0], NODE_FEAT_DIM - node_raw_features.shape[1]))
        node_raw_features = np.concatenate([node_raw_features, node_zero_padding], axis=1)
    if edge_raw_features.shape[1] < EDGE_FEAT_DIM:
        edge_zero_padding = np.zeros((edge_raw_features.shape[0], EDGE_FEAT_DIM - edge_raw_features.shape[1]))
        edge_raw_features = np.concatenate([edge_raw_features, edge_zero_padding], axis=1)
    assert NODE_FEAT_DIM == node_raw_features.shape[1] and EDGE_FEAT_DIM == edge_raw_features.shape[1], 'Unaligned feature dimensions after feature padding!'

    n_feat = torch.from_numpy(node_raw_features.astype(np.float32))
    e_feat = torch.from_numpy(edge_raw_features.astype(np.float32))

    # get the timestamp of validate and test set
    val_time, test_time = list(np.quantile(graph_df.ts, [(1 - val_ratio - test_ratio), (1 - test_ratio)]))

    src_node_ids = torch.from_numpy(graph_df.u.values.astype(np.longlong))
    dst_node_ids = torch.from_numpy(graph_df.i.values.astype(np.longlong))
    node_interact_times = torch.from_numpy(graph_df.ts.values.astype(np.float32))
    labels = torch.from_numpy(graph_df.label.values)
    unique_node_ids = torch.cat((src_node_ids, dst_node_ids)).unique()

    train_mask = node_interact_times <= val_time
    val_mask = ((node_interact_times > val_time) & (node_interact_times <= test_time))
    test_mask = (node_interact_times > test_time)

    # the setting of seed follows previous works; the node sampling below draws
    # from Python's `random` module, so that generator has to be seeded as well
    torch.manual_seed(2020)
    random.seed(2020)

    # nodes that appear in any edge after the training period; converted to plain
    # ints because tensors hash by identity, which breaks set semantics
    after_train = node_interact_times > val_time
    test_node_set = set(src_node_ids[after_train].tolist()) | set(dst_node_ids[after_train].tolist())
    # random.sample needs a sequence (sets are rejected on Python >= 3.11) and a
    # sample size that does not exceed the population
    num_new_test_nodes = min(int(0.1 * unique_node_ids.shape[0]), len(test_node_set))
    new_test_node_set = set(random.sample(sorted(test_node_set), num_new_test_nodes))

    new_test_source_mask = graph_df.u.map(lambda x: x in new_test_node_set).values
    new_test_destination_mask = graph_df.i.map(lambda x: x in new_test_node_set).values
    # mask, which is true for edges with both destination and source not being new test nodes
    # (because we want to remove all edges involving any new test node); kept boolean so
    # that train_mask stays a boolean mask after the `&`
    observed_edges_mask = torch.from_numpy(np.logical_and(~new_test_source_mask, ~new_test_destination_mask))
    train_mask = (train_mask & observed_edges_mask)

    # nodes seen during training are taken from the filtered training edges, so the
    # held-out new test nodes count as unseen
    train_node_set = torch.cat((src_node_ids[train_mask], dst_node_ids[train_mask])).unique()
    mask = torch.isin(unique_node_ids, train_node_set, invert=True)
    new_node_set = unique_node_ids[mask]
    edge_contains_new_node_mask = (torch.isin(src_node_ids, new_node_set) | torch.isin(dst_node_ids, new_node_set))
    new_node_val_mask = (val_mask & edge_contains_new_node_mask)
    new_node_test_mask = (test_mask & edge_contains_new_node_mask)

    full_data = Data()
    full_data.edge_index = torch.stack((src_node_ids, dst_node_ids))

    # sample graph: every edge appears twice in a row, once per direction
    # (src->dst, then dst->src), sharing the timestamp and the original edge id
    sample_graph = {}
    sample_src = torch.cat([src_node_ids.view(-1, 1), dst_node_ids.view(-1, 1)], dim=1)\
        .reshape(1, -1)
    sample_dst = torch.cat([dst_node_ids.view(-1, 1), src_node_ids.view(-1, 1)], dim=1)\
        .reshape(1, -1)
    sample_ts = torch.cat([node_interact_times.view(-1, 1), node_interact_times.view(-1, 1)], dim=1).reshape(-1)
    sample_eid = torch.arange(full_data.edge_index.shape[1]).view(-1, 1).repeat(1, 2).reshape(-1)
    sample_graph['edge_index'] = torch.cat([sample_src, sample_dst], dim=0)
    sample_graph['ts'] = sample_ts
    sample_graph['eids'] = sample_eid
    # NOTE(review): these masks have one entry per original edge while the tensors
    # above have two; confirm whether consumers expect them indexed by `sample_eid`.
    sample_graph['train_mask'] = train_mask
    sample_graph['val_mask'] = val_mask
    sample_graph['test_mask'] = test_mask
    sample_graph['new_node_val_mask'] = new_node_val_mask
    sample_graph['new_node_test_mask'] = new_node_test_mask

    print(unique_node_ids.max().item(), unique_node_ids.shape[0])
    # node ids are used as indices, so the node count is max id + 1
    full_data.num_nodes = int(unique_node_ids.max().item()) + 1
    full_data.num_edges = node_interact_times.shape[0]
    full_data.sample_graph = sample_graph
    full_data.x = n_feat
    full_data.edge_attr = e_feat
    full_data.y = labels
    full_data.edge_ts = node_interact_times
    full_data.train_mask = train_mask
    full_data.val_mask = val_mask
    full_data.test_mask = test_mask
    full_data.new_node_val_mask = new_node_val_mask
    full_data.new_node_test_mask = new_node_test_mask
    return full_data
#full_graph = DistributedGraphStore(full_data, device, uvm_node, uvm_edge)
#train_data = torch.masked_select(full_data.edge_index,train_mask.to(device)).reshape(2,-1)
#train_ts = torch.masked_select(full_data.edge_ts,train_mask.to(device))
#val_data = torch.masked_select(full_data.edge_index,val_mask.to(device)).reshape(2,-1)
#val_ts = torch.masked_select(full_data.edge_ts,val_mask.to(device))
#test_data = torch.masked_select(full_data.edge_index,test_mask.to(device)).reshape(2,-1)
#test_ts = torch.masked_select(full_data.edge_ts,test_mask.to(device))
##print(train_data.shape[1],val_data.shape[1],test_data.shape[1])
#train_data = DataSet(edges = train_data,ts =train_ts,eids = torch.nonzero(train_mask).view(-1))
#test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.nonzero(test_mask).view(-1))
#val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.nonzero(val_mask).view(-1))
#new_node_val_data = torch.masked_select(full_data.edge_index,new_node_val_mask.to(device)).reshape(2,-1)
#new_node_val_ts = torch.masked_select(full_data.edge_ts,new_node_val_mask.to(device))
#new_node_test_data = torch.masked_select(full_data.edge_index,new_node_test_mask.to(device)).reshape(2,-1)
#new_node_test_ts = torch.masked_select(full_data.edge_ts,new_node_test_mask.to(device))
#return full_data, train_data, val_data, test_data, new_node_val_data, new_node_test_data
import torch
from sklearn.metrics import average_precision_score, roc_auc_score
def get_link_prediction_metrics(predicts: torch.Tensor, labels: torch.Tensor):
    """
    Compute the evaluation metrics of the link prediction task.

    :param predicts: Tensor, shape (num_samples, ), predicted scores
    :param labels: Tensor, shape (num_samples, ), ground-truth labels
    :return:
        dictionary with the keys 'average_precision' and 'roc_auc'
    """
    # scikit-learn works on NumPy arrays, so move both tensors to host memory first
    y_score = predicts.cpu().detach().numpy()
    y_true = labels.cpu().numpy()

    metrics = {}
    metrics['average_precision'] = average_precision_score(y_true=y_true, y_score=y_score)
    metrics['roc_auc'] = roc_auc_score(y_true=y_true, y_score=y_score)
    return metrics
def get_node_classification_metrics(predicts: torch.Tensor, labels: torch.Tensor):
    """
    Compute the evaluation metrics of the node classification task.

    :param predicts: Tensor, shape (num_samples, ), predicted scores
    :param labels: Tensor, shape (num_samples, ), ground-truth labels
    :return:
        dictionary with the single key 'roc_auc'
    """
    # scikit-learn works on NumPy arrays, so move both tensors to host memory first
    y_score = predicts.cpu().detach().numpy()
    y_true = labels.cpu().numpy()
    return {'roc_auc': roc_auc_score(y_true=y_true, y_score=y_score)}
...@@ -32,11 +32,16 @@ class EdgePredictor(torch.nn.Module): ...@@ -32,11 +32,16 @@ class EdgePredictor(torch.nn.Module):
def forward(self, h, neg_samples=1): def forward(self, h, neg_samples=1):
num_edge = h.shape[0] // (neg_samples + 2) num_edge = h.shape[0] // (neg_samples + 2)
h_src = self.src_fc(h[:num_edge]) h_src = self.src_fc(h[:num_edge])
h_pos_dst = self.dst_fc(h[num_edge:num_edge*2]) # h_pos_dst = self.dst_fc(h[num_edge:num_edge*2])
h_neg_src = self.src_fc(h[2 * num_edge:]) h_neg_dst = self.dst_fc(h[2 * num_edge:])
h_pos_edge = torch.nn.functional.relu(h_src + h_pos_dst) h_pos_edge = torch.nn.functional.relu(h_src + h_pos_dst)
h_neg_edge = torch.nn.functional.relu(h_neg_src + h_pos_dst.tile(neg_samples, 1)) h_neg_edge = torch.nn.functional.relu(h_src + h_neg_dst.tile(neg_samples, 1))
#h_src = self.src_fc(h[num_edge:2 * num_edge])#self.src_fc(h[:num_edge])
#h_pos_dst = self.dst_fc(h[:num_edge]) #
#h_neg_src = self.src_fc(h[2 * num_edge:])
#h_pos_edge = torch.nn.functional.relu(h_src + #h_pos_dst)
#h_neg_edge = torch.nn.functional.relu(h_neg_src #+ h_pos_dst.tile(neg_samples, 1))
#h_neg_edge = torch.nn.functional.relu(h_neg_dst.tile(neg_samples, 1) + h_pos_dst) #h_neg_edge = torch.nn.functional.relu(h_neg_dst.tile(neg_samples, 1) + h_pos_dst)
#print(h_src,h_pos_dst,h_neg_dst) #print(h_src,h_pos_dst,h_neg_dst)
return self.out_fc(h_pos_edge), self.out_fc(h_neg_edge) return self.out_fc(h_pos_edge), self.out_fc(h_neg_edge)
...@@ -123,6 +128,10 @@ class TransfomerAttentionLayer(torch.nn.Module): ...@@ -123,6 +128,10 @@ class TransfomerAttentionLayer(torch.nn.Module):
V = self.w_v(torch.cat([b.srcdata['h'][b.edges()[0]], b.edata['f']], dim=1)) V = self.w_v(torch.cat([b.srcdata['h'][b.edges()[0]], b.edata['f']], dim=1))
#K = self.w_k(torch.cat([b.srcdata['h'][b.num_dst_nodes():], b.edata['f']], dim=1)) #K = self.w_k(torch.cat([b.srcdata['h'][b.num_dst_nodes():], b.edata['f']], dim=1))
#V = self.w_v(torch.cat([b.srcdata['h'][b.num_dst_nodes():], b.edata['f']], dim=1)) #V = self.w_v(torch.cat([b.srcdata['h'][b.num_dst_nodes():], b.edata['f']], dim=1))
elif self.dim_node_feat == 0 and self.dim_edge_feat == 0:
Q = self.w_q(zero_time_feat)[b.edges()[1]]
K = self.w_k(time_feat)
V = self.w_v(time_feat)
elif self.dim_node_feat == 0: elif self.dim_node_feat == 0:
Q = self.w_q(zero_time_feat)[b.edges()[1]] Q = self.w_q(zero_time_feat)[b.edges()[1]]
K = self.w_k(torch.cat([b.edata['f'], time_feat], dim=1)) K = self.w_k(torch.cat([b.edata['f'], time_feat], dim=1))
...@@ -140,6 +149,7 @@ class TransfomerAttentionLayer(torch.nn.Module): ...@@ -140,6 +149,7 @@ class TransfomerAttentionLayer(torch.nn.Module):
#Q = self.w_q(torch.cat([b.srcdata['h'][:b.num_dst_nodes()], zero_time_feat], dim=1))[b.edges()[1]] #Q = self.w_q(torch.cat([b.srcdata['h'][:b.num_dst_nodes()], zero_time_feat], dim=1))[b.edges()[1]]
#K = self.w_k(torch.cat([b.srcdata['h'][b.num_dst_nodes():], b.edata['f'], time_feat], dim=1)) #K = self.w_k(torch.cat([b.srcdata['h'][b.num_dst_nodes():], b.edata['f'], time_feat], dim=1))
#V = self.w_v(torch.cat([b.srcdata['h'][b.num_dst_nodes():], b.edata['f'], time_feat], dim=1)) #V = self.w_v(torch.cat([b.srcdata['h'][b.num_dst_nodes():], b.edata['f'], time_feat], dim=1))
Q = torch.reshape(Q, (Q.shape[0], self.num_head, -1)) Q = torch.reshape(Q, (Q.shape[0], self.num_head, -1))
K = torch.reshape(K, (K.shape[0], self.num_head, -1)) K = torch.reshape(K, (K.shape[0], self.num_head, -1))
V = torch.reshape(V, (V.shape[0], self.num_head, -1)) V = torch.reshape(V, (V.shape[0], self.num_head, -1))
......
...@@ -203,9 +203,6 @@ class GRUMemeoryUpdater(torch.nn.Module): ...@@ -203,9 +203,6 @@ class GRUMemeoryUpdater(torch.nn.Module):
self.last_updated_ts = b.srcdata['ts'].detach().clone() self.last_updated_ts = b.srcdata['ts'].detach().clone()
self.last_updated_memory = updated_memory.detach().clone() self.last_updated_memory = updated_memory.detach().clone()
self.last_updated_nid = b.srcdata['ID'].detach().clone() self.last_updated_nid = b.srcdata['ID'].detach().clone()
x1 = torch.sum(b.srcdata['mem']**2,dim = 1)
self.delta_memory = torch.sum((updated_memory - b.srcdata['mem'])**2,dim = 1)/torch.sum(b.srcdata['mem']**2,dim = 1)
#print(torch.dist(b.srcdata['mem'],updated_memory))
if self.memory_param['combine_node_feature']: if self.memory_param['combine_node_feature']:
if self.dim_node_feat > 0: if self.dim_node_feat > 0:
if self.dim_node_feat == self.dim_hid: if self.dim_node_feat == self.dim_hid:
......
...@@ -25,8 +25,8 @@ delta_ts: list[tensor,tensor, tensor...] ...@@ -25,8 +25,8 @@ delta_ts: list[tensor,tensor, tensor...]
metadata metadata
""" """
def prepare_input(node_feat, edge_feat, mem_embedding,mfgs,dist_nid,dist_eid): def prepare_input(node_feat, edge_feat, mem_embedding,mfgs,dist_nid,dist_eid):
for mfg in mfgs: for i,mfg in enumerate(mfgs):
for i,b in enumerate(mfg): for b in mfg:
e_idx = b.edata['ID'] e_idx = b.edata['ID']
idx = b.srcdata['ID'] idx = b.srcdata['ID']
b.edata['ID'] = dist_eid[e_idx] b.edata['ID'] = dist_eid[e_idx]
...@@ -43,73 +43,61 @@ def prepare_input(node_feat, edge_feat, mem_embedding,mfgs,dist_nid,dist_eid): ...@@ -43,73 +43,61 @@ def prepare_input(node_feat, edge_feat, mem_embedding,mfgs,dist_nid,dist_eid):
b.srcdata['mem_input'] = mailbox[idx].reshape(b.srcdata['ID'].shape[0], -1) b.srcdata['mem_input'] = mailbox[idx].reshape(b.srcdata['ID'].shape[0], -1)
b.srcdata['mail_ts'] = mailbox_ts[idx] b.srcdata['mail_ts'] = mailbox_ts[idx]
#print(idx.shape[0],b.srcdata['mem_ts'].shape) #print(idx.shape[0],b.srcdata['mem_ts'].shape)
return mfgs return mfgs
def to_block(graph: DistributedGraphStore, data, sample_out, mailbox:MailBox = None,device = torch.device('cuda'),group = None): def to_block(graph: DistributedGraphStore, data, sample_out, mailbox:MailBox = None,device = torch.device('cuda'),group = None):
if len(sample_out) > 1: if len(sample_out) > 1:
sample_out,metadata = sample_out sample_out,metadata = sample_out
else: else:
metadata = None metadata = None
# print(sample_out)
eid = [ret.eid() for ret in sample_out] eid = [ret.eid() for ret in sample_out]
eid_len = [e.shape[0] for e in eid ] eid_len = [e.shape[0] for e in eid ]
eid_mapper: torch.Tensor = graph.eids_mapper eid_mapper: torch.Tensor = graph.eids_mapper
nid_mapper: torch.Tensor = graph.nids_mapper nid_mapper: torch.Tensor = graph.nids_mapper
eid_tensor = torch.cat(eid,dim = 0).to(eid_mapper.device) eid_tensor = torch.cat(eid,dim = 0).to(eid_mapper.device)
dist_eid = eid_mapper[eid_tensor].to(device) dist_eid = graph.sample_graph['dist_eid'][eid_tensor].to(device)#eid_mapper[eid_tensor].to(device)
dist_eid,eid_inv = dist_eid.unique(return_inverse=True) dist_eid,eid_inv = dist_eid.unique(return_inverse=True)
src_node = graph.sample_graph['edge_index'][0,eid_tensor*2].to(graph.nids_mapper.device) src_node = graph.sample_graph['edge_index'][0,eid_tensor].to(graph.nids_mapper.device)
src_ts = None src_ts = None
if metadata is None: if metadata is None:
root_node = data.nodes.to(graph.nids_mapper.device) root_node = data.nodes.to(graph.nidst_eid_mapper.device)
root_len = [root_node.shape[0]] root_len = [root_node.shape[0]]
if hasattr(data,'ts'): if hasattr(data,'ts'):
src_ts = torch.cat([data.ts, src_ts = torch.cat([data.ts,
graph.sample_graph['ts'][eid_tensor*2].to(device)]) graph.sample_graph['ts'][eid_tensor].to(device)])
elif 'seed' in metadata: elif 'seed' in metadata:
root_node = metadata.pop('seed').to(graph.nids_mapper.device) root_node = metadata.pop('seed').to(graph.nids_mapper.device)
root_len = root_node.shape[0] root_len = root_node.shape[0]
if 'seed_ts' in metadata: if 'seed_ts' in metadata:
src_ts = torch.cat([metadata.pop('seed_ts').to(device),\ src_ts = torch.cat([metadata.pop('seed_ts').to(device),\
graph.sample_graph['ts'][eid_tensor*2].to(device)]) graph.sample_graph['ts'][eid_tensor].to(device)])
for k in metadata: for k in metadata:
metadata[k] = metadata[k].to(device) metadata[k] = metadata[k].to(device)
nid_tensor = torch.cat([root_node,src_node],dim = 0) nid_tensor = torch.cat([root_node,src_node],dim = 0)
dist_nid = nid_mapper[nid_tensor].to(device) dist_nid = nid_mapper[nid_tensor].to(device)
dist_nid,nid_inv = dist_nid.unique(return_inverse = True) dist_nid,nid_inv = dist_nid.unique(return_inverse = True)
if isinstance(graph.edge_attr,DistributedTensor):
fetchCache = FetchFeatureCache.getFetchCache() ind_dict = graph.edge_attr.all_to_all_ind2ptr(dist_eid,group = group)
if fetchCache is None: edge_feat = graph.edge_attr.all_to_all_get(group = group,**ind_dict)
if isinstance(graph.edge_attr,DistributedTensor): else:
ind_dict = graph.edge_attr.all_to_all_ind2ptr(dist_eid,group = group) edge_feat = graph._get_edge_attr(dist_eid)
edge_feat = graph.edge_attr.all_to_all_get(group = group,**ind_dict) ind_dict = None
else: if isinstance(graph.x,DistributedTensor):
edge_feat = graph._get_edge_attr(dist_eid) ind_dict = graph.x.all_to_all_ind2ptr(dist_nid,group = group)
ind_dict = None node_feat = graph.x.all_to_all_get(group = group,**ind_dict)
if isinstance(graph.x,DistributedTensor): else:
ind_dict = graph.x.all_to_all_ind2ptr(dist_nid,group = group) node_feat = graph._get_node_attr(dist_nid)
node_feat = graph.x.all_to_all_get(group = group,**ind_dict) if mailbox is not None:
else: if torch.distributed.get_world_size() > 1:
node_feat = graph._get_node_attr(dist_nid) if node_feat is None:
if mailbox is not None: ind_dict = mailbox.node_memory.all_to_all_ind2ptr(dist_nid,group = group)
if torch.distributed.get_world_size() > 1: mem = mailbox.gather_memory(**ind_dict)
if node_feat is None:
ind_dict = mailbox.node_memory.all_to_all_ind2ptr(dist_nid,group = group)
mem = mailbox.gather_memory(**ind_dict)
else:
mem = mailbox.get_memory(dist_nid)
else: else:
mem = None mem = mailbox.get_memory(dist_nid)
else: else:
raw_nid = torch.empty_like(dist_nid) mem = None
raw_eid = torch.empty_like(dist_eid)
nid_tensor = nid_tensor.to(device)
eid_tensor = eid_tensor.to(device)
raw_nid[nid_inv] = nid_tensor
raw_eid[eid_inv] = eid_tensor
node_feat,edge_feat,mem = fetchCache.fetch_feature(raw_nid,
dist_nid,raw_eid,
dist_eid)
def build_block(): def build_block():
mfgs = list() mfgs = list()
col = torch.arange(0,root_len,device = device) col = torch.arange(0,root_len,device = device)
...@@ -142,7 +130,6 @@ def to_block(graph: DistributedGraphStore, data, sample_out, mailbox:MailBox = N ...@@ -142,7 +130,6 @@ def to_block(graph: DistributedGraphStore, data, sample_out, mailbox:MailBox = N
#return build_block(node_feat,edge_feat,mem)#data,mfgs,metadata #return build_block(node_feat,edge_feat,mem)#data,mfgs,metadata
return (data,mfgs,metadata) return (data,mfgs,metadata)
def graph_sample(graph, sampler:BaseSampler, def graph_sample(graph, sampler:BaseSampler,
sample_fn, data, sample_fn, data,
neg_sampling = None, neg_sampling = None,
......
...@@ -7,7 +7,7 @@ import torch ...@@ -7,7 +7,7 @@ import torch
import torch.distributed as dist import torch.distributed as dist
from torch_geometric.data import Data from torch_geometric.data import Data
from starrygl.utils.uvm import *
class DistributedGraphStore: class DistributedGraphStore:
''' '''
...@@ -37,6 +37,7 @@ class DistributedGraphStore: ...@@ -37,6 +37,7 @@ class DistributedGraphStore:
self.sample_graph = pdata.sample_graph self.sample_graph = pdata.sample_graph
self.nids_mapper = build_mapper(nids=pdata.ids.to(device)).dist.to('cpu') self.nids_mapper = build_mapper(nids=pdata.ids.to(device)).dist.to('cpu')
self.eids_mapper = build_mapper(nids=pdata.eids.to(device)).dist.to('cpu') self.eids_mapper = build_mapper(nids=pdata.eids.to(device)).dist.to('cpu')
self.sample_graph['dist_eid'] = self.eids_mapper[pdata.sample_graph['eids']]
torch.cuda.empty_cache() torch.cuda.empty_cache()
self.num_nodes = self.nids_mapper.data.shape[0] self.num_nodes = self.nids_mapper.data.shape[0]
...@@ -46,17 +47,18 @@ class DistributedGraphStore: ...@@ -46,17 +47,18 @@ class DistributedGraphStore:
self.uvm_edge = uvm_edge self.uvm_edge = uvm_edge
if hasattr(pdata,'x') and pdata.x is not None: if hasattr(pdata,'x') and pdata.x is not None:
ctx = DistributedContext.get_default_context()
pdata.x = pdata.x.to(torch.float) pdata.x = pdata.x.to(torch.float)
if uvm_node == False : if uvm_node == False :
x = pdata.x.to(self.device) x = pdata.x.to(self.device)
else: else:
if self.device.type == 'cuda': if self.device.type == 'cuda':
x = starrygl.utils.uvm.uvm_empty(*pdata.x.size(), x = uvm_empty(*pdata.x.size(),
dtype=pdata.x.dtype, dtype=pdata.x.dtype,
device=ctx.device) device=ctx.device)
starrygl.utils.uvm.uvm_share(x,device = ctx.device) uvm_share(x,device = ctx.device)
starrygl.utils.uvm.uvm_advise(x,starrygl.utils.uvm.cudaMemoryAdvise.cudaMemAdviseSetAccessedBy) uvm_advise(x,cudaMemoryAdvise.cudaMemAdviseSetAccessedBy)
starrygl.utils.uvm.uvm_prefetch(x) uvm_prefetch(x)
if world_size > 1: if world_size > 1:
self.x = DistributedTensor(pdata.x.to(self.device).to(torch.float)) self.x = DistributedTensor(pdata.x.to(self.device).to(torch.float))
else: else:
...@@ -71,12 +73,15 @@ class DistributedGraphStore: ...@@ -71,12 +73,15 @@ class DistributedGraphStore:
edge_attr = pdata.edge_attr.to(self.device) edge_attr = pdata.edge_attr.to(self.device)
else: else:
if self.device.type == 'cuda': if self.device.type == 'cuda':
edge_attr = starrygl.utils.uvm.uvm_empty(*pdata.edge_attr.size(), edge_attr = uvm_empty(*pdata.edge_attr.size(),
dtype=pdata.edge_attr.dtype, dtype=pdata.edge_attr.dtype,
device=ctx.device) device=ctx.device)
starrygl.utils.uvm.uvm_share(edge_attr,device = ctx.device) edge_attr = uvm_share(edge_attr,device = torch.device('cpu'))
starrygl.utils.uvm.uvm_advise(edge_attr,starrygl.utils.uvm.cudaMemoryAdvise.cudaMemAdviseSetAccessedBy) edge_attr.copy_(pdata.edge_attr)
starrygl.utils.uvm.uvm_prefetch(edge_attr)
edge_attr = uvm_share(edge_attr,device = ctx.device)
uvm_advise(edge_attr,cudaMemoryAdvise.cudaMemAdviseSetAccessedBy)
uvm_prefetch(edge_attr)
if world_size > 1: if world_size > 1:
self.edge_attr = DistributedTensor(edge_attr) self.edge_attr = DistributedTensor(edge_attr)
else: else:
...@@ -251,7 +256,8 @@ class TemporalNeighborSampleGraph(DistributedGraphStore): ...@@ -251,7 +256,8 @@ class TemporalNeighborSampleGraph(DistributedGraphStore):
self.edge_ts = sample_graph['ts'] self.edge_ts = sample_graph['ts']
else: else:
self.edge_ts = None self.edge_ts = None
self.eid = sample_graph['eids'] self.eid = sample_graph['eids']#torch.arange(self.num_edges,dtype = torch.long, device = sample_graph['eids'].device)
#sample_graph['eids']
if mode == 'train': if mode == 'train':
mask = sample_graph['train_mask'] mask = sample_graph['train_mask']
if mode == 'val': if mode == 'val':
......
...@@ -115,7 +115,7 @@ class SharedMailBox(): ...@@ -115,7 +115,7 @@ class SharedMailBox():
def set_memory_local(self,index,source,source_ts,Reduce_Op = None): def set_memory_local(self,index,source,source_ts,Reduce_Op = None):
if Reduce_Op == 'max': if Reduce_Op == 'max' and self.num_parts > 1:
unq_id,inv = index.unique(return_inverse = True) unq_id,inv = index.unique(return_inverse = True)
max_ts,id = torch_scatter.scatter_max(source_ts,inv,dim=0) max_ts,id = torch_scatter.scatter_max(source_ts,inv,dim=0)
source_ts = max_ts source_ts = max_ts
...@@ -243,7 +243,6 @@ class SharedMailBox(): ...@@ -243,7 +243,6 @@ class SharedMailBox():
#futs: List[torch.futures.Future] = [] #futs: List[torch.futures.Future] = []
if self.num_parts == 1: if self.num_parts == 1:
dist_index = DistIndex(index) dist_index = DistIndex(index)
part_idx = dist_index.part
index = dist_index.loc index = dist_index.loc
self.set_mailbox_local(index,mail,mail_ts) self.set_mailbox_local(index,mail,mail_ts)
self.set_memory_local(index,memory,memory_ts) self.set_memory_local(index,memory,memory_ts)
...@@ -317,7 +316,7 @@ class SharedMailBox(): ...@@ -317,7 +316,7 @@ class SharedMailBox():
if edge_feats is not None: if edge_feats is not None:
src_mail = torch.cat([src_mail, edge_feats], dim=1) src_mail = torch.cat([src_mail, edge_feats], dim=1)
dst_mail = torch.cat([dst_mail, edge_feats], dim=1) dst_mail = torch.cat([dst_mail, edge_feats], dim=1)
mail = torch.cat([src_mail, dst_mail], dim=1).reshape(-1, src_mail.shape[1]) mail = torch.cat([src_mail, dst_mail], dim=0)#.reshape(-1, src_mail.shape[1])
mail_ts = torch.cat((ts,ts),-1).to(self.device).to(self.mailbox_ts.dtype) mail_ts = torch.cat((ts,ts),-1).to(self.device).to(self.mailbox_ts.dtype)
unq_index,inv = torch.unique(index,return_inverse = True) unq_index,inv = torch.unique(index,return_inverse = True)
max_ts,idx = torch_scatter.scatter_max(mail_ts,inv,0) max_ts,idx = torch_scatter.scatter_max(mail_ts,inv,0)
......
from torch_sparse import SparseTensor from torch_sparse import SparseTensor
from torch_geometric.data import Data from torch_geometric.data import Data
from torch_geometric.utils import degree from torch_geometric.utils import degree
import starrygl
import os.path as osp import os.path as osp
import os import os
import shutil import shutil
import torch import torch
import torch.utils.data import torch.utils.data
import metis
import networkx as nx
import torch.distributed as dist import torch.distributed as dist
from starrygl.lib.libstarrygl_sampler import get_norm_temporal
from starrygl.utils.partition import mt_metis_partition
def partition_load(root: str, algo: str = "metis") -> Data: def partition_load(root: str, algo: str = "metis") -> Data:
...@@ -21,7 +21,6 @@ def partition_load(root: str, algo: str = "metis") -> Data: ...@@ -21,7 +21,6 @@ def partition_load(root: str, algo: str = "metis") -> Data:
def partition_save(root: str, data: Data, num_parts: int, def partition_save(root: str, data: Data, num_parts: int,
algo: str = "metis", algo: str = "metis",
node_weight = None,
edge_weight_dict=None): edge_weight_dict=None):
root = osp.abspath(root) root = osp.abspath(root)
if osp.exists(root) and not osp.isdir(root): if osp.exists(root) and not osp.isdir(root):
...@@ -46,7 +45,6 @@ def partition_save(root: str, data: Data, num_parts: int, ...@@ -46,7 +45,6 @@ def partition_save(root: str, data: Data, num_parts: int,
if algo == 'metis_for_tgnn': if algo == 'metis_for_tgnn':
for i, pdata in enumerate(partition_data_for_tgnn( for i, pdata in enumerate(partition_data_for_tgnn(
data, num_parts, algo, verbose=True, data, num_parts, algo, verbose=True,
node_weight = node_weight,
edge_weight_dict=edge_weight_dict)): edge_weight_dict=edge_weight_dict)):
print(f"saving partition data: {i+1}/{num_parts}") print(f"saving partition data: {i+1}/{num_parts}")
fn = osp.join(path, f"{i:03d}") fn = osp.join(path, f"{i:03d}")
...@@ -154,41 +152,33 @@ def _nopart(edge_index: torch.LongTensor, num_nodes: int): ...@@ -154,41 +152,33 @@ def _nopart(edge_index: torch.LongTensor, num_nodes: int):
def metis_for_tgnn(edge_index_dict: dict, def metis_for_tgnn(edge_index_dict: dict,
num_nodes: int, num_nodes: int,
num_parts: int, num_parts: int,
node_weight = None,
edge_weight_dict=None): edge_weight_dict=None):
if num_parts <= 1: if num_parts <= 1:
return _nopart(edge_index_dict, num_nodes) return _nopart(edge_index_dict, num_nodes)
edge_list = [] G = nx.Graph()
weight_list = [] G.add_nodes_from(torch.arange(0, num_nodes).tolist())
for i,key in enumerate(edge_index_dict): value, counts = torch.unique(edge_index_dict['edata'][1, :].view(-1),
return_counts=True)
nodes = torch.tensor(list(G.adj.keys()))
for i in range(value.shape[0]):
if (value[i].item() in G.nodes):
G.nodes[int(value[i].item())]['weight'] = counts[i]
G.nodes[int(value[i].item())]['ones'] = 1
G.graph['node_weight_attr'] = ['weight', 'ones']
edges = []
for i, key in enumerate(edge_index_dict):
v = edge_index_dict[key] v = edge_index_dict[key]
edge_list.append(v) edge = torch.cat((v, (torch.ones(v.shape[1], dtype=torch.long) *
weight_list.append(torch.ones(v.shape[1])*edge_weight_dict[key]) edge_weight_dict[key]).unsqueeze(0)), dim=0)
edge_index = torch.cat(edge_list,dim = 1) edges.append(edge)
edge_weight = torch.cat(weight_list,dim = 0) # w = edges.T
node_parts = mt_metis_partition(edge_index,num_nodes,num_parts,node_weight,edge_weight) edges = torch.cat(edges,dim = 1)
G.add_weighted_edges_from((edges.T).tolist())
G.graph['edge_weight_attr'] = 'weight'
cuts, part = metis.part_graph(G, num_parts)
node_parts = torch.zeros(num_nodes, dtype=torch.long)
node_parts[nodes] = torch.tensor(part)
return node_parts return node_parts
#G = nx.Graph()
#G.add_nodes_from(torch.arange(0, num_nodes).tolist())
#value, counts = torch.unique(edge_index_dict['edata'][1, :].view(-1),
# return_counts=True)
#nodes = torch.tensor(list(G.adj.keys()))
#for i in range(value.shape[0]):
# if (value[i].item() in G.nodes):
# G.nodes[int(value[i].item())]['weight'] = counts[i]
# G.nodes[int(value[i].item())]['ones'] = 1
#G.graph['node_weight_attr'] = ['weight', 'ones']
#for i, key in enumerate(edge_index_dict):
# v = edge_index_dict[key]
# edges = torch.cat((v, (torch.ones(v.shape[1], dtype=torch.long) *
# edge_weight_dict[key]).unsqueeze(0)), dim=0)
# # w = edges.T
# G.add_weighted_edges_from((edges.T).tolist())
#G.graph['edge_weight_attr'] = 'weight'
#cuts, part = metis.part_graph(G, num_parts)
#node_parts = torch.zeros(num_nodes, dtype=torch.long)
#node_parts[nodes] = torch.tensor(part)
#return node_parts
""" """
...@@ -199,7 +189,6 @@ weight: 各种工作负载边划分权重 ...@@ -199,7 +189,6 @@ weight: 各种工作负载边划分权重
def partition_data_for_tgnn(data: Data, num_parts: int, algo: str, def partition_data_for_tgnn(data: Data, num_parts: int, algo: str,
verbose: bool = False, verbose: bool = False,
node_weight: torch.Tensor = None,
edge_weight_dict: dict = None): edge_weight_dict: dict = None):
if algo == "metis_for_tgnn": if algo == "metis_for_tgnn":
part_fn = metis_for_tgnn part_fn = metis_for_tgnn
...@@ -213,7 +202,6 @@ def partition_data_for_tgnn(data: Data, num_parts: int, algo: str, ...@@ -213,7 +202,6 @@ def partition_data_for_tgnn(data: Data, num_parts: int, algo: str,
if verbose: if verbose:
print(f"running partition algorithm: {algo}") print(f"running partition algorithm: {algo}")
node_parts = part_fn(edge_index_dict, num_nodes, num_parts, node_parts = part_fn(edge_index_dict, num_nodes, num_parts,
node_weight,
edge_weight_dict) edge_weight_dict)
edge_parts = node_parts[data.edge_index[1, :]] edge_parts = node_parts[data.edge_index[1, :]]
eids = torch.arange(num_edges, dtype=torch.long) eids = torch.arange(num_edges, dtype=torch.long)
...@@ -304,7 +292,7 @@ def compute_gcn_norm(edge_index: torch.LongTensor, num_nodes: int): ...@@ -304,7 +292,7 @@ def compute_gcn_norm(edge_index: torch.LongTensor, num_nodes: int):
def compute_temporal_norm(edge_index: torch.LongTensor, def compute_temporal_norm(edge_index: torch.LongTensor,
timestamp: torch.FloatTensor, timestamp: torch.FloatTensor,
num_nodes: int): num_nodes: int):
srcavg, srcvar, dstavg, dstvar = get_norm_temporal(edge_index[0, :], srcavg, srcvar, dstavg, dstvar = starrygl.sampler_ops.get_norm_temporal(edge_index[0, :],
edge_index[1, :], edge_index[1, :],
timestamp, num_nodes) timestamp, num_nodes)
return srcavg, srcvar, dstavg, dstvar return srcavg, srcvar, dstavg, dstvar
......
import sys
from os.path import abspath, join, dirname
sys.path.insert(0, join(abspath(dirname(__file__))))
from torch import Tensor
import torch
from base import NegativeSampling
from base import NegativeSamplingMode
from typing import Any, List, Optional, Tuple, Union
class EvaluateNegativeSampling(NegativeSampling):
    """Negative-edge sampler for evaluation over a fixed set of timed edges.

    Three strategies are supported, selected by ``negative_sample_strategy``:

    * ``'random'``     -- draw a source and a destination independently from
      the unique source / destination ids seen at construction time.
    * ``'historical'`` -- draw edges seen in
      ``[earliest_time, current_batch_start_time]`` that do not occur in
      ``[current_batch_start_time, current_batch_end_time]``.
    * ``'inductive'``  -- like ``'historical'`` but additionally excluding the
      edges observed in ``[earliest_time, last_observed_time]``.

    For the last two strategies, when fewer candidate edges exist than
    requested, the remainder is filled with random pairs that do not collide
    with the current batch.

    All time windows are inclusive on both ends.  Edge sets are represented as
    ``(2, N)`` tensors with sources in row 0 and destinations in row 1.
    """

    def __init__(
        self,
        mode: Union[NegativeSamplingMode, str],
        src_node_ids: torch.Tensor,
        dst_node_ids: torch.Tensor,
        interact_times: torch.Tensor = None,
        last_observed_time: float = None,
        negative_sample_strategy: str = 'random',
        seed: int = None
    ):
        """Store the edge list and precompute the unique node pools.

        Args:
            mode: forwarded unchanged to ``NegativeSampling.__init__``.
            src_node_ids: 1-D tensor of source ids, one entry per edge.
            dst_node_ids: 1-D tensor of destination ids, one entry per edge.
            interact_times: 1-D tensor of edge timestamps aligned with the id
                tensors.  Required unless the strategy is ``'random'``.
            last_observed_time: upper bound of the "observed" window; required
                for the ``'inductive'`` strategy.
            negative_sample_strategy: ``'random'``, ``'historical'`` or
                ``'inductive'``.
            seed: optional seed for the private ``torch.Generator``.

        Raises:
            ValueError: if a time-based strategy is requested without the
                timestamps / observation bound it needs.
        """
        super().__init__(mode)
        if interact_times is None and negative_sample_strategy != 'random':
            raise ValueError(
                "'interact_times' is required unless negative_sample_strategy is 'random'")
        if negative_sample_strategy == 'inductive' and last_observed_time is None:
            raise ValueError(
                "'last_observed_time' is required for the 'inductive' strategy")

        self.seed = seed
        self.negative_sample_strategy = negative_sample_strategy
        self.src_node_ids = src_node_ids
        self.dst_node_ids = dst_node_ids
        self.interact_times = interact_times

        # torch.unique returns sorted values, so these double as sampling pools.
        self.unique_src_nodes_id = src_node_ids.unique()
        self.unique_dst_nodes_id = dst_node_ids.unique()
        # node id -> position inside the corresponding unique pool.
        self.src_id_mapper = self._build_id_mapper(self.unique_src_nodes_id)
        self.dst_id_mapper = self._build_id_mapper(self.unique_dst_nodes_id)

        if interact_times is None:
            self.unique_interact_times = None
            self.earliest_time = None
        else:
            self.unique_interact_times = interact_times.unique()
            self.earliest_time = self.unique_interact_times.min().item()
        self.last_observed_time = last_observed_time

        # Only the inductive strategy needs the set of already observed edges.
        self.observed_edges = None
        if self.negative_sample_strategy == 'inductive':
            self.observed_edges = self.get_unique_edges_between_start_end_time(
                self.earliest_time, self.last_observed_time)

        self.random_state = torch.Generator()
        if self.seed is not None:
            self.random_state.manual_seed(seed)

    @staticmethod
    def _build_id_mapper(unique_ids: torch.Tensor) -> torch.Tensor:
        """Return a long tensor mapping each id in ``unique_ids`` to its rank."""
        if unique_ids.numel() == 0:
            return torch.zeros(0, dtype=torch.long, device=unique_ids.device)
        # The table must be addressable by the largest id, hence max + 1.
        mapper = torch.zeros(int(unique_ids.max()) + 1, dtype=torch.long,
                             device=unique_ids.device)
        mapper[unique_ids.long()] = torch.arange(unique_ids.shape[0],
                                                 device=unique_ids.device)
        return mapper

    def _generator(self) -> Optional[torch.Generator]:
        """Return the seeded generator, or ``None`` to use the global RNG."""
        return self.random_state if self.seed is not None else None

    @staticmethod
    def _edge_difference(keep: torch.Tensor, drop: torch.Tensor) -> torch.Tensor:
        """Return the unique columns of ``keep`` that do not appear in ``drop``."""
        uni, inverse = torch.cat((drop, keep), dim=1).unique(dim=1, return_inverse=True)
        blocked = torch.zeros(uni.shape[1], dtype=torch.bool, device=uni.device)
        # The first drop.shape[1] inverse entries belong to the dropped edges.
        blocked[inverse[:drop.shape[1]]] = True
        return uni[:, ~blocked]

    def get_unique_edges_between_start_end_time(self, start_time: float, end_time: float):
        """Return the unique ``(2, N)`` edges whose time lies in ``[start_time, end_time]``."""
        # Element-wise '&' is required; Python 'and' fails on multi-element tensors.
        selected_mask = (self.interact_times >= start_time) & (self.interact_times <= end_time)
        edges = torch.stack((self.src_node_ids[selected_mask],
                             self.dst_node_ids[selected_mask]), dim=0)
        return edges.unique(dim=1)

    def sample(self, num_samples: int, num_nodes: Optional[int] = None,
               batch_src_node_ids: Optional[torch.Tensor] = None,
               batch_dst_node_ids: Optional[torch.Tensor] = None,
               current_batch_start_time: Optional[torch.Tensor] = None,
               current_batch_end_time: Optional[torch.Tensor] = None) -> Tuple[Tensor, Tensor]:
        """Draw ``num_samples`` negative edges with the configured strategy.

        ``num_nodes`` is accepted for signature compatibility and is not used.

        Returns:
            A pair ``(negative_src_node_ids, negative_dst_node_ids)``.

        Raises:
            ValueError: if the configured strategy is not recognised.
        """
        strategy = self.negative_sample_strategy
        if strategy == 'random':
            negative_src_node_ids, negative_dst_node_ids = self.random_sample(size=num_samples)
        elif strategy == 'historical':
            negative_src_node_ids, negative_dst_node_ids = self.historical_sample(
                size=num_samples,
                batch_src_nodes_id=batch_src_node_ids,
                batch_dst_nodes_id=batch_dst_node_ids,
                current_batch_start_time=current_batch_start_time,
                current_batch_end_time=current_batch_end_time)
        elif strategy == 'inductive':
            negative_src_node_ids, negative_dst_node_ids = self.inductive_sample(
                size=num_samples,
                batch_src_node_ids=batch_src_node_ids,
                batch_dst_node_ids=batch_dst_node_ids,
                current_batch_start_time=current_batch_start_time,
                current_batch_end_time=current_batch_end_time)
        else:
            raise ValueError(f'Not implemented error for negative_sample_strategy {self.negative_sample_strategy}!')
        return negative_src_node_ids, negative_dst_node_ids

    def random_sample(self, size: int):
        """Draw ``size`` sources and destinations independently from the unique pools.

        Uses the private generator when a seed was given, otherwise the global RNG.
        """
        generator = self._generator()
        # torch.randint expects the output shape as a tuple.
        src_indices = torch.randint(0, len(self.unique_src_nodes_id), (size,),
                                    generator=generator)
        dst_indices = torch.randint(0, len(self.unique_dst_nodes_id), (size,),
                                    generator=generator)
        return self.unique_src_nodes_id[src_indices], self.unique_dst_nodes_id[dst_indices]

    def random_sample_with_collision_check(self, size: int, batch_src_nodes_id: torch.Tensor,
                                           batch_dst_nodes_id: torch.Tensor):
        """Draw ``size`` random edges that are not edges of the given batch.

        Returns:
            A ``(2, size)`` tensor of node ids (row 0 sources, row 1 destinations).

        Raises:
            ValueError: if no source/destination pair outside the batch exists,
                in which case rejection sampling could never terminate.
        """
        src_pool = self.unique_src_nodes_id
        dst_pool = self.unique_dst_nodes_id
        if size <= 0:
            return torch.stack((src_pool[:0], dst_pool[:0]))
        num_src, num_dst = src_pool.shape[0], dst_pool.shape[0]
        if num_src == 0 or num_dst == 0:
            raise ValueError('Cannot sample negative edges from an empty node pool')

        batch_src = batch_src_nodes_id.to(src_pool.device)
        batch_dst = batch_dst_nodes_id.to(src_pool.device)
        # Encode each (src, dst) pair as one integer so torch.isin can compare pairs.
        stride = int(dst_pool.max()) + 1
        if batch_dst.numel() > 0:
            stride = max(stride, int(batch_dst.max()) + 1)
        batch_keys = batch_src.long() * stride + batch_dst.long()

        # Refuse to loop forever when the batch covers every candidate pair.
        in_pool = torch.isin(batch_src, src_pool) & torch.isin(batch_dst, dst_pool)
        if num_src * num_dst <= batch_keys[in_pool].unique().numel():
            raise ValueError('Every candidate edge collides with the current batch')

        generator = self._generator()
        accepted = []
        remaining = size
        while remaining > 0:
            # Oversample by 2x so that one round usually suffices after rejection.
            src_indices = torch.randint(0, num_src, (remaining * 2,), generator=generator)
            dst_indices = torch.randint(0, num_dst, (remaining * 2,), generator=generator)
            candidates = torch.stack((src_pool[src_indices], dst_pool[dst_indices]))
            candidate_keys = candidates[0].long() * stride + candidates[1].long()
            free = torch.isin(candidate_keys, batch_keys, invert=True)
            candidates = candidates[:, free][:, :remaining]
            accepted.append(candidates)
            remaining -= candidates.shape[1]
        return torch.cat(accepted, dim=1)

    def _sample_from_candidates(self, size: int, candidates: torch.Tensor,
                                batch_src: torch.Tensor, batch_dst: torch.Tensor):
        """Pick ``size`` columns of ``candidates``, topping up with random edges if short."""
        num_candidates = candidates.shape[1]
        if size > num_candidates:
            filler = self.random_sample_with_collision_check(
                size=size - num_candidates,
                batch_src_nodes_id=batch_src,
                batch_dst_nodes_id=batch_dst)
            return torch.cat((candidates, filler), dim=1)
        order = torch.randperm(num_candidates, generator=self.random_state)
        return candidates[:, order[:size]]

    def historical_sample(self, size: int, batch_src_nodes_id: torch.Tensor, batch_dst_nodes_id: torch.Tensor,
                          current_batch_start_time: float, current_batch_end_time: float):
        """Sample edges seen before the current batch that are absent from it.

        Returns:
            A ``(2, size)`` tensor of negative edges.
        """
        assert self.seed is not None, 'historical sampling requires a seed'
        historical_edges = self.get_unique_edges_between_start_end_time(
            start_time=self.earliest_time, end_time=current_batch_start_time)
        current_batch_edges = self.get_unique_edges_between_start_end_time(
            start_time=current_batch_start_time, end_time=current_batch_end_time)
        unique_historical_edges = self._edge_difference(historical_edges, current_batch_edges)
        return self._sample_from_candidates(size, unique_historical_edges,
                                            batch_src_nodes_id, batch_dst_nodes_id)

    def inductive_sample(self, size: int, batch_src_node_ids: torch.Tensor, batch_dst_node_ids: torch.Tensor,
                         current_batch_start_time: float, current_batch_end_time: float):
        """Sample historical edges that are neither observed nor in the current batch.

        Returns:
            A ``(2, size)`` tensor of negative edges.

        Raises:
            ValueError: if the sampler was not built with the ``'inductive'``
                strategy, so no observed-edge set is available.
        """
        assert self.seed is not None, 'inductive sampling requires a seed'
        if self.observed_edges is None:
            raise ValueError("inductive_sample requires negative_sample_strategy='inductive'")
        historical_edges = self.get_unique_edges_between_start_end_time(
            start_time=self.earliest_time, end_time=current_batch_start_time)
        current_batch_edges = self.get_unique_edges_between_start_end_time(
            start_time=current_batch_start_time, end_time=current_batch_end_time)
        excluded_edges = torch.cat((self.observed_edges, current_batch_edges), dim=1)
        unique_inductive_edges = self._edge_difference(historical_edges, excluded_edges)
        return self._sample_from_candidates(size, unique_inductive_edges,
                                            batch_src_node_ids, batch_dst_node_ids)
...@@ -4,7 +4,7 @@ from enum import Enum ...@@ -4,7 +4,7 @@ from enum import Enum
import math import math
from abc import ABC from abc import ABC
from typing import Any, List, Optional, Tuple, Union from typing import Any, List, Optional, Tuple, Union
import numpy as np
class SampleType(Enum): class SampleType(Enum):
Whole = 0 Whole = 0
Inner = 1 Inner = 1
...@@ -82,7 +82,8 @@ class NegativeSampling: ...@@ -82,7 +82,8 @@ class NegativeSampling:
f"Cannot sample negatives in '{self.__class__.__name__}' " f"Cannot sample negatives in '{self.__class__.__name__}' "
f"without passing the 'num_nodes' argument") f"without passing the 'num_nodes' argument")
return torch.randint(num_nodes, (num_samples, )) return torch.randint(num_nodes, (num_samples, ))
#return torch.from_numpy(np.random.randint(num_nodes, size=num_samples))
if num_nodes is not None and self.weight.numel() != num_nodes: if num_nodes is not None and self.weight.numel() != num_nodes:
raise ValueError( raise ValueError(
f"The 'weight' attribute in '{self.__class__.__name__}' " f"The 'weight' attribute in '{self.__class__.__name__}' "
......
...@@ -262,6 +262,8 @@ class NeighborSampler(BaseSampler): ...@@ -262,6 +262,8 @@ class NeighborSampler(BaseSampler):
seed = torch.cat([src, dst, src_neg], dim=0) seed = torch.cat([src, dst, src_neg], dim=0)
if with_timestap: # ts操作 if with_timestap: # ts操作
seed_ts = torch.cat([ets, ets, ets], dim=0) seed_ts = torch.cat([ets, ets, ets], dim=0)
#if neg_sampling.is_evaluate():
#src,dst = neg_sampling.sample(num_samples=)
else: else:
seed = torch.cat([src, dst], dim=0) seed = torch.cat([src, dst], dim=0)
if with_timestap: # ts操作 if with_timestap: # ts操作
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment