Commit 2f15e921 by zlj

add sample module and tgnn trainer

parent b84f3d4c
def foo():
    a = 1
    def fa():
        nonlocal a  # without nonlocal, the a += 1 below would raise UnboundLocalError
        print(a)
        a += 1
        print(a)
    def fb(a):
        def apply():
            print(a)
        return apply
    fc = lambda: print(a)
    return fa, fb(a), fc

fa, fb, fc = foo()
fa()
fb()
fc()
sampling:
- layer: 1
neighbor:
- 10
strategy: 'recent'
prop_time: False
history: 1
duration: 0
num_thread: 32
no_neg: True
memory:
- type: 'node'
dim_time: 100
deliver_to: 'neighbors'
mail_combine: 'last'
memory_update: 'transformer'
attention_head: 2
mailbox_size: 10
combine_node_feature: False
dim_out: 100
gnn:
- arch: 'identity'
train:
- epoch: 100
batch_size: 600
lr: 0.0001
dropout: 0.1
att_dropout: 0.1
# all_on_gpu: True
sampling:
- layer: 2
neighbor:
- 10
- 10
strategy: 'uniform'
prop_time: True
history: 3
duration: 10000
num_thread: 32
memory:
- type: 'none'
dim_out: 0
gnn:
- arch: 'transformer_attention'
layer: 2
att_head: 2
dim_time: 0
dim_out: 100
combine: 'rnn'
train:
- epoch: 50
batch_size: 600
lr: 0.0001
dropout: 0.1
att_dropout: 0.1
all_on_gpu: True
sampling:
- no_sample: True
history: 1
memory:
- type: 'node'
dim_time: 100
deliver_to: 'self'
mail_combine: 'last'
memory_update: 'rnn'
mailbox_size: 1
combine_node_feature: True
dim_out: 100
gnn:
- arch: 'identity'
time_transform: 'JODIE'
train:
- epoch: 100
batch_size: 600
lr: 0.0001
dropout: 0.1
all_on_gpu: True
sampling:
- layer: 2
neighbor:
- 10
- 10
strategy: 'uniform'
prop_time: False
history: 1
duration: 0
num_thread: 32
memory:
- type: 'none'
dim_out: 0
gnn:
- arch: 'transformer_attention'
layer: 2
att_head: 2
dim_time: 100
dim_out: 100
train:
- epoch: 10
batch_size: 600
lr: 0.0001
dropout: 0.1
att_dropout: 0.1
all_on_gpu: True
sampling:
- layer: 1
neighbor:
- 10
strategy: 'recent'
prop_time: False
history: 1
duration: 0
num_thread: 32
memory:
- type: 'node'
dim_time: 100
deliver_to: 'self'
mail_combine: 'last'
memory_update: 'gru'
mailbox_size: 1
combine_node_feature: True
dim_out: 100
gnn:
- arch: 'transformer_attention'
layer: 1
att_head: 2
dim_time: 100
dim_out: 100
train:
- epoch: 10
#batch_size: 100
# reorder: 16
lr: 0.0001
dropout: 0.2
att_dropout: 0.2
all_on_gpu: True
sampling:
- layer: <number of layers to sample>
neighbor: <a list of integers indicating how many neighbors are sampled in each layer>
strategy: <'recent', which samples the most recent neighbors, or 'uniform', which uniformly samples neighbors from the past>
prop_time: <False or True, specifying whether to use the timestamp of the root nodes when sampling their multi-hop neighbors>
history: <number of snapshots to sample on>
duration: <length in time of each snapshot; 0 for infinite length (used in non-snapshot-based methods)>
num_thread: <number of threads of the sampler>
memory:
- type: <'node', we only support node memory now>
dim_time: <an integer, the dimension of the time embedding>
deliver_to: <'self', which delivers mails only to the involved nodes, or 'neighbors', which also delivers mails to their neighbors>
mail_combine: <'last', which uses the latest mail as the input to the memory updater>
memory_update: <'gru' or 'rnn'>
mailbox_size: <an integer, the size of the mailbox for each node>
combine_node_feature: <False or True, specifying whether to combine node features (with the updated memory) as the input to the GNN>
dim_out: <an integer, the dimension of the output node memory>
gnn:
- arch: <'transformer_attention' or 'identity' (no GNN)>
layer: <an integer, number of layers>
att_head: <an integer, number of attention heads>
dim_time: <an integer, the dimension of the time embedding>
dim_out: <an integer, the dimension of the output dynamic node embedding>
train:
- epoch: <an integer, number of epochs to train>
batch_size: <an integer, the batch size (of edges); for multi-GPU training, this is the local batch size>
reorder: <(optional) an integer that is divisible by the batch size and specifies how many chunks per batch are used in the random chunk scheduling>
lr: <floating point, learning rate>
dropout: <floating point, dropout>
att_dropout: <floating point, dropout for attention>
all_on_gpu: <False or True that decides if the node/edge features and node memory are completely stored on GPU>
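# Illustrative only (a sketch, not one of the committed files): a config with the
# layout documented above is read section by section; each top-level key holds a
# one-element list, mirroring the parse_config helper added later in this commit.
# The path 'config/TGN.yml' and the downstream dimension variables are placeholders.
import yaml

with open('config/TGN.yml', 'r') as fh:
    conf = yaml.safe_load(fh)
sample_param = conf['sampling'][0]
memory_param = conf['memory'][0]
gnn_param = conf['gnn'][0]
train_param = conf['train'][0]
# e.g. GeneralModel(dim_node, dim_edge, sample_param, memory_param, gnn_param, train_param)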
import os
import argparse
import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Data
from starrygl.sample.sample_core.base import NegativeSampling
from starrygl.sample.part_utils.partition_tgnn import partition_save
parser = argparse.ArgumentParser(
description="RPC Reinforcement Learning Example",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--data_name', default='WIKI', type=str, metavar='W',
help='name of dataset')
parser.add_argument('--num_neg_sample', default=1, type=int, metavar='W',
help='number of negative samples')
args = parser.parse_args()
def load_feat(d, rand_de=0, rand_dn=0):
node_feats = None
if os.path.exists('DATA/{}/node_features.pt'.format(d)):
node_feats = torch.load('DATA/{}/node_features.pt'.format(d))
if node_feats.dtype == torch.bool:
node_feats = node_feats.type(torch.float32)
edge_feats = None
if os.path.exists('DATA/{}/edge_features.pt'.format(d)):
edge_feats = torch.load('DATA/{}/edge_features.pt'.format(d))
if edge_feats.dtype == torch.bool:
edge_feats = edge_feats.type(torch.float32)
if rand_de > 0:
if d == 'LASTFM':
edge_feats = torch.randn(1293103, rand_de)
elif d == 'MOOC':
edge_feats = torch.randn(411749, rand_de)
if rand_dn > 0:
if d == 'LASTFM':
node_feats = torch.randn(1980, rand_dn)
elif d == 'MOOC':
node_feats = torch.randn(7144, rand_dn)
return node_feats, edge_feats
data_name = args.data_name
g = np.load('../tgl_main/DATA/'+data_name+'/ext_full.npz')
df = pd.read_csv('../tgl_main/DATA/'+data_name+'/edges.csv')
if os.path.exists('../tgl_main/DATA/'+data_name+'/node_features.pt'):
n_feat = torch.load('../tgl_main/DATA/'+data_name+'/node_features.pt')
else:
n_feat = None
if os.path.exists('../tgl_main/DATA/'+data_name+'/edge_features.pt'):
e_feat = torch.load('../tgl_main/DATA/'+data_name+'/edge_features.pt')
else:
e_feat = None
src = torch.from_numpy(np.array(df.src.values)).long()
dst = torch.from_numpy(np.array(df.dst.values)).long()
ts = torch.from_numpy(np.array(df.time.values)).long()
neg_nums = args.num_neg_sample
edge_index = torch.cat((src[np.newaxis, :], dst[np.newaxis, :]), 0)
num_nodes = edge_index.view(-1).max().item()+1
num_edges = edge_index.shape[1]
print('the number of nodes in graph is {}, \
the number of edges in graph is {}'.format(num_nodes, num_edges))
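# Build an undirected "sample graph" for the temporal neighbor sampler: every edge
# (src, dst, t) is stored in both directions, so its timestamp and edge id are
# repeated once per direction by the interleave/repeat operations below.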
sample_graph = {}
sample_src = torch.cat([src.view(-1, 1), dst.view(-1, 1)], dim=1)\
.reshape(1, -1)
sample_dst = torch.cat([dst.view(-1, 1), src.view(-1, 1)], dim=1)\
.reshape(1, -1)
sample_ts = torch.cat([ts.view(-1, 1), ts.view(-1, 1)], dim=1).reshape(-1)
sample_eid = torch.arange(num_edges).view(-1, 1).repeat(1, 2).reshape(-1)
sample_graph['edge_index'] = torch.cat([sample_src, sample_dst], dim=0)
sample_graph['ts'] = sample_ts
sample_graph['eids'] = sample_eid
neg_sampler = NegativeSampling('triplet')
neg_src = neg_sampler.sample(edge_index.shape[1]*neg_nums, num_nodes)
neg_sample = neg_src.reshape(-1, neg_nums)
edge_ts = torch.from_numpy(np.array(ts)).float()
data = Data()
data.num_nodes = num_nodes
data.num_edges = num_edges
data.edge_index = edge_index
data.edge_ts = edge_ts
data.neg_sample = neg_sample
if n_feat is not None:
data.x = n_feat
if e_feat is not None:
data.edge_attr = e_feat
data.train_mask = (torch.from_numpy(np.array(df.ext_roll.values)) == 0)
data.test_mask = (torch.from_numpy(np.array(df.ext_roll.values)) == 1)
data.val_mask = (torch.from_numpy(np.array(df.ext_roll.values)) == 2)
sample_graph['train_mask'] = data.train_mask[sample_eid]
sample_graph['test_mask'] = data.test_mask[sample_eid]
sample_graph['val_mask'] = data.val_mask[sample_eid]
data.sample_graph = sample_graph
data.y = torch.zeros(edge_index.shape[1])
edge_index_dict = {}
edge_index_dict['edata'] = data.edge_index
edge_index_dict['sample_data'] = data.sample_graph['edge_index']
edge_index_dict['neg_data'] = torch.cat([neg_src.view(1, -1),
dst.view(-1, 1).repeat(1, neg_nums).
reshape(1, -1)], dim=0)
data.edge_index_dict = edge_index_dict
edge_weight_dict = {}
edge_weight_dict['edata'] = 2*neg_nums
edge_weight_dict['sample_data'] = 1*neg_nums
edge_weight_dict['neg_data'] = 1
partition_save('./dataset/here/'+data_name, data, 1, 'metis_for_tgnn',
edge_weight_dict=edge_weight_dict)
partition_save('./dataset/here/'+data_name, data, 2, 'metis_for_tgnn',
edge_weight_dict=edge_weight_dict)
#
# partition_save('./dataset/here/'+data_name, data, 4, 'metis_for_tgnn',
# edge_weight_dict=edge_weight_dict )
# partition_save('./dataset/here'+data_name, data, 8, 'metis')
# partition_save('./dataset/'+data_name, data, 12, 'metis')
# partition_save('./dataset/'+data_name, data, 16, 'metis')
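# Illustrative invocation (the script name is a placeholder for this file):
#   python preprocess_data.py --data_name WIKI --num_neg_sample 1
# reads ../tgl_main/DATA/WIKI/ and writes partitions under ./dataset/here/WIKI.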
#!/bin/sh
#conda activate gnn
cd ./starrygl/sample/sample_core
if [ -f "setup.py" ]; then
rm -rf build
rm -f sample_cores.cpython-*.so
python setup.py build_ext --inplace
fi
cd ../part_utils
if [ -f "setup.py" ]; then
rm -rf build
rm -f torch_utils.cpython-*.so
python setup.py build_ext --inplace
fi
cd ../../
@@ -93,6 +93,8 @@ class DistIndex:
@property
def dist(self) -> Tensor:
return self.data
def to(self,device) -> Tensor:
return DistIndex(self.data.to(device))
class DistributedTensor:
@@ -109,6 +111,21 @@ class DistributedTensor:
self.num_nodes = DistInt(local_sizes)
self.num_parts = DistInt([1] * len(self.rrefs))
@property
def dtype(self):
return self.accessor.data.dtype
@property
def device(self):
return self.accessor.data.device
def to(self,device):
return self.accessor.data.to(device)
def __getitem__(self,index):
return self.accessor.data[index]
@property
def ctx(self):
return self.accessor.ctx
from os.path import abspath, join, dirname
import os
import sys
sys.path.insert(0, join(abspath(dirname(__file__))))
import torch
import dgl
import math
import numpy as np
class TimeEncode(torch.nn.Module):
def __init__(self, dim):
super(TimeEncode, self).__init__()
self.dim = dim
self.w = torch.nn.Linear(1, dim)
self.w.weight = torch.nn.Parameter((torch.from_numpy(1 / 10 ** np.linspace(0, 9, dim, dtype=np.float32))).reshape(dim, -1))
self.w.bias = torch.nn.Parameter(torch.zeros(dim))
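# Functional time encoding: a time delta t is mapped to cos(w * t + b), with w
# initialized to the geometric frequency sequence 1 / 10**linspace(0, 9, dim)
# (TGAT/TGL style); as written here the weights remain learnable parameters.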
def forward(self, t):
output = torch.cos(self.w(t.float().reshape((-1, 1))))
return output
class EdgePredictor(torch.nn.Module):
def __init__(self, dim_in):
super(EdgePredictor, self).__init__()
self.dim_in = dim_in
self.src_fc = torch.nn.Linear(dim_in, dim_in)
self.dst_fc = torch.nn.Linear(dim_in, dim_in)
self.out_fc = torch.nn.Linear(dim_in, 1)
def forward(self, h, neg_samples=1):
num_edge = h.shape[0] // (neg_samples + 2)
h_src = self.src_fc(h[num_edge:2 * num_edge])#self.src_fc(h[:num_edge])
h_pos_dst = self.dst_fc(h[:num_edge]) #
h_neg_src = self.src_fc(h[2 * num_edge:])
h_pos_edge = torch.nn.functional.relu(h_src + h_pos_dst)
h_neg_edge = torch.nn.functional.relu(h_neg_src + h_pos_dst.tile(neg_samples, 1))
#h_neg_edge = torch.nn.functional.relu(h_neg_dst.tile(neg_samples, 1) + h_pos_dst)
#print(h_src,h_pos_dst,h_neg_dst)
return self.out_fc(h_pos_edge), self.out_fc(h_neg_edge)
class TransfomerAttentionLayer(torch.nn.Module):
def __init__(self, dim_node_feat, dim_edge_feat, dim_time, num_head, dropout, att_dropout, dim_out, combined=False):
super(TransfomerAttentionLayer, self).__init__()
self.num_head = num_head
self.dim_node_feat = dim_node_feat
self.dim_edge_feat = dim_edge_feat
self.dim_time = dim_time
self.dim_out = dim_out
self.dropout = torch.nn.Dropout(dropout)
self.att_dropout = torch.nn.Dropout(att_dropout)
self.att_act = torch.nn.LeakyReLU(0.2)
self.combined = combined
if dim_time > 0:
self.time_enc = TimeEncode(dim_time)
if combined:
if dim_node_feat > 0:
self.w_q_n = torch.nn.Linear(dim_node_feat, dim_out)
self.w_k_n = torch.nn.Linear(dim_node_feat, dim_out)
self.w_v_n = torch.nn.Linear(dim_node_feat, dim_out)
if dim_edge_feat > 0:
self.w_k_e = torch.nn.Linear(dim_edge_feat, dim_out)
self.w_v_e = torch.nn.Linear(dim_edge_feat, dim_out)
if dim_time > 0:
self.w_q_t = torch.nn.Linear(dim_time, dim_out)
self.w_k_t = torch.nn.Linear(dim_time, dim_out)
self.w_v_t = torch.nn.Linear(dim_time, dim_out)
else:
if dim_node_feat + dim_time > 0:
self.w_q = torch.nn.Linear(dim_node_feat + dim_time, dim_out)
self.w_k = torch.nn.Linear(dim_node_feat + dim_edge_feat + dim_time, dim_out)
self.w_v = torch.nn.Linear(dim_node_feat + dim_edge_feat + dim_time, dim_out)
self.w_out = torch.nn.Linear(dim_node_feat + dim_out, dim_out)
self.layer_norm = torch.nn.LayerNorm(dim_out)
def forward(self, b):
assert(self.dim_time + self.dim_node_feat + self.dim_edge_feat > 0)
self.device = b.device
if b.num_edges() == 0:
return torch.zeros((b.num_dst_nodes(), self.dim_out), device=self.device)
if self.dim_time > 0:
time_feat = self.time_enc(b.edata['dt'])
zero_time_feat = self.time_enc(torch.zeros(b.num_dst_nodes(), dtype=torch.float32, device=self.device))
if self.combined:
Q = torch.zeros((b.num_edges(), self.dim_out), device=self.device)
K = torch.zeros((b.num_edges(), self.dim_out), device=self.device)
V = torch.zeros((b.num_edges(), self.dim_out), device=self.device)
if self.dim_node_feat > 0:
Q += self.w_q_n(b.srcdata['h'][:b.num_dst_nodes()])[b.edges()[1]]
K += self.w_k_n(b.srcdata['h'][b.num_dst_nodes():])[b.edges()[0] - b.num_dst_nodes()]
V += self.w_v_n(b.srcdata['h'][b.num_dst_nodes():])[b.edges()[0] - b.num_dst_nodes()]
if self.dim_edge_feat > 0:
K += self.w_k_e(b.edata['f'])
V += self.w_v_e(b.edata['f'])
if self.dim_time > 0:
Q += self.w_q_t(zero_time_feat)[b.edges()[1]]
K += self.w_k_t(time_feat)
V += self.w_v_t(time_feat)
Q = torch.reshape(Q, (Q.shape[0], self.num_head, -1))
K = torch.reshape(K, (K.shape[0], self.num_head, -1))
V = torch.reshape(V, (V.shape[0], self.num_head, -1))
att = dgl.ops.edge_softmax(b, self.att_act(torch.sum(Q*K, dim=2)))
att = self.att_dropout(att)
V = torch.reshape(V*att[:, :, None], (V.shape[0], -1))
b.edata['v'] = V
b.update_all(dgl.function.copy_e('v', 'm'), dgl.function.sum('m', 'h'))
else:
if self.dim_time == 0 and self.dim_node_feat == 0:
Q = torch.ones((b.num_edges(), self.dim_out), device=self.device)
K = self.w_k(b.edata['f'])
V = self.w_v(b.edata['f'])
elif self.dim_time == 0 and self.dim_edge_feat == 0:
Q = self.w_q(b.srcdata['h'][:b.num_dst_nodes()])[b.edges()[1]]
K = self.w_k(b.srcdata['h'][b.edges()[0]])
V = self.w_v(b.srcdata['h'][b.edges()[0]])
elif self.dim_time == 0:
Q = self.w_q(b.srcdata['h'][:b.num_dst_nodes()])[b.edges()[1]]
K = self.w_k(torch.cat([b.srcdata['h'][b.edges()[0]], b.edata['f']], dim=1))
V = self.w_v(torch.cat([b.srcdata['h'][b.edges()[0]], b.edata['f']], dim=1))
#K = self.w_k(torch.cat([b.srcdata['h'][b.num_dst_nodes():], b.edata['f']], dim=1))
#V = self.w_v(torch.cat([b.srcdata['h'][b.num_dst_nodes():], b.edata['f']], dim=1))
elif self.dim_node_feat == 0:
Q = self.w_q(zero_time_feat)[b.edges()[1]]
K = self.w_k(torch.cat([b.edata['f'], time_feat], dim=1))
V = self.w_v(torch.cat([b.edata['f'], time_feat], dim=1))
elif self.dim_edge_feat == 0:
Q = self.w_q(torch.cat([b.srcdata['h'][:b.num_dst_nodes()], zero_time_feat], dim=1))[b.edges()[1]]
K = self.w_k(torch.cat([b.srcdata['h'][b.edges()[0]], time_feat], dim=1))
V = self.w_v(torch.cat([b.srcdata['h'][b.edges()[0]], time_feat], dim=1))
#K = self.w_k(torch.cat([b.srcdata['h'][b.num_dst_nodes():], time_feat], dim=1))
#V = self.w_v(torch.cat([b.srcdata['h'][b.num_dst_nodes():], time_feat], dim=1))
else:
Q = self.w_q(torch.cat([b.srcdata['h'][:b.num_dst_nodes()], zero_time_feat], dim=1))[b.edges()[1]]
K = self.w_k(torch.cat([b.srcdata['h'][b.edges()[0]], b.edata['f'], time_feat], dim=1))
V = self.w_v(torch.cat([b.srcdata['h'][b.edges()[0]], b.edata['f'], time_feat], dim=1))
#Q = self.w_q(torch.cat([b.srcdata['h'][:b.num_dst_nodes()], zero_time_feat], dim=1))[b.edges()[1]]
#K = self.w_k(torch.cat([b.srcdata['h'][b.num_dst_nodes():], b.edata['f'], time_feat], dim=1))
#V = self.w_v(torch.cat([b.srcdata['h'][b.num_dst_nodes():], b.edata['f'], time_feat], dim=1))
Q = torch.reshape(Q, (Q.shape[0], self.num_head, -1))
K = torch.reshape(K, (K.shape[0], self.num_head, -1))
V = torch.reshape(V, (V.shape[0], self.num_head, -1))
att = dgl.ops.edge_softmax(b, self.att_act(torch.sum(Q*K, dim=2)))
att = self.att_dropout(att)
V = torch.reshape(V*att[:, :, None], (V.shape[0], -1))
b.edata['v'] = V
b.update_all(dgl.function.copy_e('v', 'm'), dgl.function.sum('m', 'h'))
#b.srcdata['v'] = torch.cat([torch.zeros((b.num_dst_nodes(), V.shape[1]), device=torch.device('cuda:0')), V], dim=0)
#b.update_all(dgl.function.copy_u('v', 'm'), dgl.function.sum('m', 'h'))
if self.dim_node_feat != 0:
rst = torch.cat([b.dstdata['h'], b.srcdata['h'][:b.num_dst_nodes()]], dim=1)
else:
rst = b.dstdata['h']
rst = self.w_out(rst)
rst = torch.nn.functional.relu(self.dropout(rst))
return self.layer_norm(rst)
class IdentityNormLayer(torch.nn.Module):
def __init__(self, dim_out):
super(IdentityNormLayer, self).__init__()
self.norm = torch.nn.LayerNorm(dim_out)
def forward(self, b):
return self.norm(b.srcdata['h'])
class JODIETimeEmbedding(torch.nn.Module):
def __init__(self, dim_out):
super(JODIETimeEmbedding, self).__init__()
self.dim_out = dim_out
class NormalLinear(torch.nn.Linear):
# From Jodie code
def reset_parameters(self):
stdv = 1. / math.sqrt(self.weight.size(1))
self.weight.data.normal_(0, stdv)
if self.bias is not None:
self.bias.data.normal_(0, stdv)
self.time_emb = NormalLinear(1, dim_out)
def forward(self, h, mem_ts, ts):
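# JODIE-style time projection: the embedding is scaled by (1 + w * normalized gap),
# where the gap (ts - mem_ts) is normalized by (ts + 1) below.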
time_diff = (ts - mem_ts) / (ts + 1)
rst = h * (1 + self.time_emb(time_diff.unsqueeze(1)))
return rst
from os.path import abspath, join, dirname
import os
import sys
sys.path.insert(0, join(abspath(dirname(__file__))))
import torch
import dgl
from layers import TimeEncode
from torch_scatter import scatter
class MailBox():
def __init__(self, memory_param, num_nodes, dim_edge_feat, _node_memory=None, _node_memory_ts=None,_mailbox=None, _mailbox_ts=None, _next_mail_pos=None, _update_mail_pos=None):
self.memory_param = memory_param
self.dim_edge_feat = dim_edge_feat
if memory_param['type'] != 'node':
raise NotImplementedError
self.node_memory = torch.zeros((num_nodes, memory_param['dim_out']), dtype=torch.float32) if _node_memory is None else _node_memory
self.node_memory_ts = torch.zeros(num_nodes, dtype=torch.float32) if _node_memory_ts is None else _node_memory_ts
self.mailbox = torch.zeros((num_nodes, memory_param['mailbox_size'], 2 * memory_param['dim_out'] + dim_edge_feat), dtype=torch.float32) if _mailbox is None else _mailbox
self.mailbox_ts = torch.zeros((num_nodes, memory_param['mailbox_size']), dtype=torch.float32) if _mailbox_ts is None else _mailbox_ts
self.next_mail_pos = torch.zeros((num_nodes), dtype=torch.long) if _next_mail_pos is None else _next_mail_pos
self.update_mail_pos = _update_mail_pos
self.device = torch.device('cpu')
def reset(self):
self.node_memory.fill_(0)
self.node_memory_ts.fill_(0)
self.mailbox.fill_(0)
self.mailbox_ts.fill_(0)
self.next_mail_pos.fill_(0)
def move_to_gpu(self):
self.node_memory = self.node_memory.cuda()
self.node_memory_ts = self.node_memory_ts.cuda()
self.mailbox = self.mailbox.cuda()
self.mailbox_ts = self.mailbox_ts.cuda()
self.next_mail_pos = self.next_mail_pos.cuda()
self.device = torch.device('cuda:0')
def allocate_pinned_memory_buffers(self, sample_param, batch_size):
limit = int(batch_size * 3.3)
if 'neighbor' in sample_param:
for i in sample_param['neighbor']:
limit *= i + 1
self.pinned_node_memory_buffs = list()
self.pinned_node_memory_ts_buffs = list()
self.pinned_mailbox_buffs = list()
self.pinned_mailbox_ts_buffs = list()
for _ in range(sample_param['history']):
self.pinned_node_memory_buffs.append(torch.zeros((limit, self.node_memory.shape[1]), pin_memory=True))
self.pinned_node_memory_ts_buffs.append(torch.zeros((limit,), pin_memory=True))
self.pinned_mailbox_buffs.append(torch.zeros((limit, self.mailbox.shape[1], self.mailbox.shape[2]), pin_memory=True))
self.pinned_mailbox_ts_buffs.append(torch.zeros((limit, self.mailbox_ts.shape[1]), pin_memory=True))
def prep_input_mails(self, mfg, global_id_list = None, use_pinned_buffers=False):
    for i, b in enumerate(mfg):
        # map the block's source node ids to memory indices (optionally through a global id mapping)
        if global_id_list is not None:
            idx = global_id_list[b.srcdata['ID']]
        else:
            idx = b.srcdata['ID']
        if use_pinned_buffers:
            idx = idx.cpu().long()
torch.index_select(self.node_memory, 0, idx, out=self.pinned_node_memory_buffs[i][:idx.shape[0]])
b.srcdata['mem'] = self.pinned_node_memory_buffs[i][:idx.shape[0]].cuda(non_blocking=True)
torch.index_select(self.node_memory_ts,0, idx, out=self.pinned_node_memory_ts_buffs[i][:idx.shape[0]])
b.srcdata['mem_ts'] = self.pinned_node_memory_ts_buffs[i][:idx.shape[0]].cuda(non_blocking=True)
torch.index_select(self.mailbox, 0, idx, out=self.pinned_mailbox_buffs[i][:idx.shape[0]])
b.srcdata['mem_input'] = self.pinned_mailbox_buffs[i][:idx.shape[0]].reshape(b.srcdata['ID'].shape[0], -1).cuda(non_blocking=True)
torch.index_select(self.mailbox_ts, 0, idx, out=self.pinned_mailbox_ts_buffs[i][:idx.shape[0]])
b.srcdata['mail_ts'] = self.pinned_mailbox_ts_buffs[i][:idx.shape[0]].cuda(non_blocking=True)
else:
b.srcdata['mem'] = self.node_memory[idx].cuda()
b.srcdata['mem_ts'] = self.node_memory_ts[idx].cuda()
b.srcdata['mem_input'] = self.mailbox[idx].cuda().reshape(b.srcdata['ID'].shape[0], -1)
b.srcdata['mail_ts'] = self.mailbox_ts[idx].cuda()
def update_memory(self, nid, memory, root_nodes, ts, global_node_list = None,neg_samples=1):
if nid is None:
return
num_true_src_dst = root_nodes.shape[0] // (neg_samples + 2) * 2
with torch.no_grad():
nid = nid[:num_true_src_dst].to(self.device)
memory = memory[:num_true_src_dst].to(self.device)
ts = ts[:num_true_src_dst].to(self.device)
self.node_memory[nid.long()] = memory
self.node_memory_ts[nid.long()] = ts
def update_mailbox(self, nid, memory, root_nodes, ts, edge_feats, block, global_node_list = None, neg_samples=1):
with torch.no_grad():
num_true_edges = root_nodes.shape[0] // (neg_samples + 2)
memory = memory.to(self.device)
if edge_feats is not None:
edge_feats = edge_feats.to(self.device)
if block is not None:
block = block.to(self.device)
# TGN/JODIE
if self.memory_param['deliver_to'] == 'self':
src = torch.from_numpy(root_nodes[:num_true_edges]).to(self.device)
dst = torch.from_numpy(root_nodes[num_true_edges:num_true_edges * 2]).to(self.device)
mem_src = memory[:num_true_edges]
mem_dst = memory[num_true_edges:num_true_edges * 2]
if self.dim_edge_feat > 0:
src_mail = torch.cat([mem_src, mem_dst, edge_feats], dim=1)
dst_mail = torch.cat([mem_dst, mem_src, edge_feats], dim=1)
else:
src_mail = torch.cat([mem_src, mem_dst], dim=1)
dst_mail = torch.cat([mem_dst, mem_src], dim=1)
mail = torch.cat([src_mail, dst_mail], dim=1).reshape(-1, src_mail.shape[1])
nid = torch.cat([src.unsqueeze(1), dst.unsqueeze(1)], dim=1).reshape(-1)
mail_ts = torch.from_numpy(ts[:num_true_edges * 2]).to(self.device)
assert mail_ts.dtype != torch.float64, 'mail_ts is expected to be float32'
# find unique nid to update mailbox
uni, inv = torch.unique(nid, return_inverse=True)
perm = torch.arange(inv.size(0), dtype=inv.dtype, device=inv.device)
perm = inv.new_empty(uni.size(0)).scatter_(0, inv, perm)
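# scatter_ writes the arange positions grouped by inv, so for duplicated node ids
# the intent is that the last occurrence (i.e. the most recent mail in this batch)
# survives, matching the 'last' mail_combine policy.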
nid = nid[perm]
mail = mail[perm]
mail_ts = mail_ts[perm]
if self.memory_param['mail_combine'] == 'last':
self.mailbox[nid.long(), self.next_mail_pos[nid.long()]] = mail
self.mailbox_ts[nid.long(), self.next_mail_pos[nid.long()]] = mail_ts
if self.memory_param['mailbox_size'] > 1:
self.next_mail_pos[nid.long()] = torch.remainder(self.next_mail_pos[nid.long()] + 1, self.memory_param['mailbox_size'])
# APAN
elif self.memory_param['deliver_to'] == 'neighbors':
mem_src = memory[:num_true_edges]
mem_dst = memory[num_true_edges:num_true_edges * 2]
if self.dim_edge_feat > 0:
src_mail = torch.cat([mem_src, mem_dst, edge_feats], dim=1)
dst_mail = torch.cat([mem_dst, mem_src, edge_feats], dim=1)
else:
src_mail = torch.cat([mem_src, mem_dst], dim=1)
dst_mail = torch.cat([mem_dst, mem_src], dim=1)
mail = torch.cat([src_mail, dst_mail], dim=0)
mail = torch.cat([mail, mail[block.edges()[0].long()]], dim=0)
mail_ts = torch.from_numpy(ts[:num_true_edges * 2]).to(self.device)
mail_ts = torch.cat([mail_ts, mail_ts[block.edges()[0].long()]], dim=0)
if self.memory_param['mail_combine'] == 'mean':
(nid, idx) = torch.unique(block.dstdata['ID'], return_inverse=True)
mail = scatter(mail, idx, reduce='mean', dim=0)
mail_ts = scatter(mail_ts, idx, reduce='mean')
self.mailbox[nid.long(), self.next_mail_pos[nid.long()]] = mail
self.mailbox_ts[nid.long(), self.next_mail_pos[nid.long()]] = mail_ts
elif self.memory_param['mail_combine'] == 'last':
nid = block.dstdata['ID']
# find unique nid to update mailbox
uni, inv = torch.unique(nid, return_inverse=True)
perm = torch.arange(inv.size(0), dtype=inv.dtype, device=inv.device)
perm = inv.new_empty(uni.size(0)).scatter_(0, inv, perm)
nid = nid[perm]
mail = mail[perm]
mail_ts = mail_ts[perm]
self.mailbox[nid.long(), self.next_mail_pos[nid.long()]] = mail
self.mailbox_ts[nid.long(), self.next_mail_pos[nid.long()]] = mail_ts
else:
raise NotImplementedError
if self.memory_param['mailbox_size'] > 1:
if self.update_mail_pos is None:
self.next_mail_pos[nid.long()] = torch.remainder(self.next_mail_pos[nid.long()] + 1, self.memory_param['mailbox_size'])
else:
self.update_mail_pos[nid.long()] = 1
else:
raise NotImplementedError
def update_next_mail_pos(self):
if self.update_mail_pos is not None:
nid = torch.where(self.update_mail_pos == 1)[0]
self.next_mail_pos[nid] = torch.remainder(self.next_mail_pos[nid] + 1, self.memory_param['mailbox_size'])
self.update_mail_pos.fill_(0)
class GRUMemeoryUpdater(torch.nn.Module):
def __init__(self, memory_param, dim_in, dim_hid, dim_time, dim_node_feat):
super(GRUMemeoryUpdater, self).__init__()
self.dim_hid = dim_hid
self.dim_node_feat = dim_node_feat
self.memory_param = memory_param
self.dim_time = dim_time
self.updater = torch.nn.GRUCell(dim_in + dim_time, dim_hid)
self.last_updated_memory = None
self.last_updated_ts = None
self.last_updated_nid = None
self.delta_memory = 0
if dim_time > 0:
self.time_enc = TimeEncode(dim_time)
if memory_param['combine_node_feature']:
if dim_node_feat > 0 and dim_node_feat != dim_hid:
self.node_feat_map = torch.nn.Linear(dim_node_feat, dim_hid)
def forward(self, mfg):
for b in mfg:
if self.dim_time > 0:
time_feat = self.time_enc(b.srcdata['ts'] - b.srcdata['mem_ts'])
b.srcdata['mem_input'] = torch.cat([b.srcdata['mem_input'], time_feat], dim=1)
updated_memory = self.updater(b.srcdata['mem_input'], b.srcdata['mem'])
self.last_updated_ts = b.srcdata['ts'].detach().clone()
self.last_updated_memory = updated_memory.detach().clone()
self.last_updated_nid = b.srcdata['ID'].detach().clone()
x1 = torch.sum(b.srcdata['mem']**2,dim = 1)
self.delta_memory = torch.sum((updated_memory - b.srcdata['mem'])**2,dim = 1)/torch.sum(b.srcdata['mem']**2,dim = 1)
#print(torch.dist(b.srcdata['mem'],updated_memory))
if self.memory_param['combine_node_feature']:
if self.dim_node_feat > 0:
if self.dim_node_feat == self.dim_hid:
b.srcdata['h'] += updated_memory
else:
b.srcdata['h'] = updated_memory + self.node_feat_map(b.srcdata['h'])
else:
b.srcdata['h'] = updated_memory
class RNNMemeoryUpdater(torch.nn.Module):
def __init__(self, memory_param, dim_in, dim_hid, dim_time, dim_node_feat):
super(RNNMemeoryUpdater, self).__init__()
self.dim_hid = dim_hid
self.dim_node_feat = dim_node_feat
self.memory_param = memory_param
self.dim_time = dim_time
self.updater = torch.nn.RNNCell(dim_in + dim_time, dim_hid)
self.last_updated_memory = None
self.last_updated_ts = None
self.last_updated_nid = None
self.delta_memory = 0
if dim_time > 0:
self.time_enc = TimeEncode(dim_time)
if memory_param['combine_node_feature']:
if dim_node_feat > 0 and dim_node_feat != dim_hid:
self.node_feat_map = torch.nn.Linear(dim_node_feat, dim_hid)
def forward(self, mfg):
for b in mfg:
if self.dim_time > 0:
#print(b.srcdata['ts'].shape,b.srcdata['mem_ts'].shape)
time_feat = self.time_enc(b.srcdata['ts'] - b.srcdata['mem_ts'])
b.srcdata['mem_input'] = torch.cat([b.srcdata['mem_input'], time_feat], dim=1)
updated_memory = self.updater(b.srcdata['mem_input'], b.srcdata['mem'])
self.last_updated_ts = b.srcdata['ts'].detach().clone()
self.last_updated_memory = updated_memory.detach().clone()
self.last_updated_nid = b.srcdata['ID'].detach().clone()
if self.memory_param['combine_node_feature']:
if self.dim_node_feat > 0:
if self.dim_node_feat == self.dim_hid:
b.srcdata['h'] += updated_memory
else:
b.srcdata['h'] = updated_memory + self.node_feat_map(b.srcdata['h'])
else:
b.srcdata['h'] = updated_memory
class TransformerMemoryUpdater(torch.nn.Module):
def __init__(self, memory_param, dim_in, dim_out, dim_time, train_param):
super(TransformerMemoryUpdater, self).__init__()
self.memory_param = memory_param
self.dim_time = dim_time
self.att_h = memory_param['attention_head']
if dim_time > 0:
self.time_enc = TimeEncode(dim_time)
self.w_q = torch.nn.Linear(dim_out, dim_out)
self.w_k = torch.nn.Linear(dim_in + dim_time, dim_out)
self.w_v = torch.nn.Linear(dim_in + dim_time, dim_out)
self.att_act = torch.nn.LeakyReLU(0.2)
self.layer_norm = torch.nn.LayerNorm(dim_out)
self.mlp = torch.nn.Linear(dim_out, dim_out)
self.dropout = torch.nn.Dropout(train_param['dropout'])
self.att_dropout = torch.nn.Dropout(train_param['att_dropout'])
self.last_updated_memory = None
self.last_updated_ts = None
self.last_updated_nid = None
def forward(self, mfg):
for b in mfg:
Q = self.w_q(b.srcdata['mem']).reshape((b.num_src_nodes(), self.att_h, -1))
mails = b.srcdata['mem_input'].reshape((b.num_src_nodes(), self.memory_param['mailbox_size'], -1))
if self.dim_time > 0:
time_feat = self.time_enc(b.srcdata['ts'][:, None] - b.srcdata['mail_ts']).reshape((b.num_src_nodes(), self.memory_param['mailbox_size'], -1))
mails = torch.cat([mails, time_feat], dim=2)
K = self.w_k(mails).reshape((b.num_src_nodes(), self.memory_param['mailbox_size'], self.att_h, -1))
V = self.w_v(mails).reshape((b.num_src_nodes(), self.memory_param['mailbox_size'], self.att_h, -1))
att = self.att_act((Q[:,None,:,:]*K).sum(dim=3))
att = torch.nn.functional.softmax(att, dim=1)
att = self.att_dropout(att)
rst = (att[:,:,:,None]*V).sum(dim=1)
rst = rst.reshape((rst.shape[0], -1))
rst += b.srcdata['mem']
rst = self.layer_norm(rst)
rst = self.mlp(rst)
rst = self.dropout(rst)
rst = torch.nn.functional.relu(rst)
b.srcdata['h'] = rst
self.last_updated_memory = rst.detach().clone()
self.last_updated_nid = b.srcdata['ID'].detach().clone()
self.last_updated_ts = b.srcdata['ts'].detach().clone()
import torch
import dgl
from os.path import abspath, join, dirname
import sys
sys.path.insert(0, join(abspath(dirname(__file__))))
from layers import *
from memorys import *
class GeneralModel(torch.nn.Module):
def __init__(self, dim_node, dim_edge, sample_param, memory_param, gnn_param, train_param, combined=False):
super(GeneralModel, self).__init__()
self.dim_node = dim_node
self.dim_node_input = dim_node
self.dim_edge = dim_edge
self.sample_param = sample_param
self.memory_param = memory_param
if not 'dim_out' in gnn_param:
gnn_param['dim_out'] = memory_param['dim_out']
self.gnn_param = gnn_param
self.train_param = train_param
if memory_param['type'] == 'node':
if memory_param['memory_update'] == 'gru':
self.memory_updater = GRUMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node)
elif memory_param['memory_update'] == 'rnn':
self.memory_updater = RNNMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node)
elif memory_param['memory_update'] == 'transformer':
self.memory_updater = TransformerMemoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], train_param)
else:
raise NotImplementedError
self.dim_node_input = memory_param['dim_out']
self.layers = torch.nn.ModuleDict()
if gnn_param['arch'] == 'transformer_attention':
for h in range(sample_param['history']):
self.layers['l0h' + str(h)] = TransfomerAttentionLayer(self.dim_node_input, dim_edge, gnn_param['dim_time'], gnn_param['att_head'], train_param['dropout'], train_param['att_dropout'], gnn_param['dim_out'], combined=combined)
for l in range(1, gnn_param['layer']):
for h in range(sample_param['history']):
self.layers['l' + str(l) + 'h' + str(h)] = TransfomerAttentionLayer(gnn_param['dim_out'], dim_edge, gnn_param['dim_time'], gnn_param['att_head'], train_param['dropout'], train_param['att_dropout'], gnn_param['dim_out'], combined=False)
elif gnn_param['arch'] == 'identity':
self.gnn_param['layer'] = 1
for h in range(sample_param['history']):
self.layers['l0h' + str(h)] = IdentityNormLayer(self.dim_node_input)
if 'time_transform' in gnn_param and gnn_param['time_transform'] == 'JODIE':
self.layers['l0h' + str(h) + 't'] = JODIETimeEmbedding(gnn_param['dim_out'])
else:
raise NotImplementedError
self.edge_predictor = EdgePredictor(gnn_param['dim_out'])
if 'combine' in gnn_param and gnn_param['combine'] == 'rnn':
self.combiner = torch.nn.RNN(gnn_param['dim_out'], gnn_param['dim_out'])
def forward(self, mfgs, metadata = None,neg_samples=1):
if self.memory_param['type'] == 'node':
self.memory_updater(mfgs[0])
out = list()
for l in range(self.gnn_param['layer']):
for h in range(self.sample_param['history']):
rst = self.layers['l' + str(l) + 'h' + str(h)](mfgs[l][h])
if 'time_transform' in self.gnn_param and self.gnn_param['time_transform'] == 'JODIE':
rst = self.layers['l0h' + str(h) + 't'](rst, mfgs[l][h].srcdata['mem_ts'], mfgs[l][h].srcdata['ts'])
if l != self.gnn_param['layer'] - 1:
mfgs[l + 1][h].srcdata['h'] = rst
else:
out.append(rst)
if self.sample_param['history'] == 1:
out = out[0]
else:
out = torch.stack(out, dim=0)
out = self.combiner(out)[0][-1, :, :]
# metadata needs to record the ids during the earlier deduplication step
if metadata is not None:
#out = torch.cat((out[metadata['dst_pos_pos']],out[metadata['src_id_pos']],out[metadata['dst_neg_pos']]),0)
out = torch.cat((out[metadata['src_pos_index']],out[metadata['dst_pos_index']],out[metadata['src_neg_index']]),0)
return self.edge_predictor(out, neg_samples=neg_samples)
def get_emb(self, mfgs):
if self.memory_param['type'] == 'node':
self.memory_updater(mfgs[0])
out = list()
for l in range(self.gnn_param['layer']):
for h in range(self.sample_param['history']):
rst = self.layers['l' + str(l) + 'h' + str(h)](mfgs[l][h])
if 'time_transform' in self.gnn_param and self.gnn_param['time_transform'] == 'JODIE':
rst = self.layers['l0h' + str(h) + 't'](rst, mfgs[l][h].srcdata['mem_ts'], mfgs[l][h].srcdata['ts'])
if l != self.gnn_param['layer'] - 1:
mfgs[l + 1][h].srcdata['h'] = rst
else:
out.append(rst)
if self.sample_param['history'] == 1:
out = out[0]
else:
out = torch.stack(out, dim=0)
out = self.combiner(out)[0][-1, :, :]
return out
class NodeClassificationModel(torch.nn.Module):
def __init__(self, dim_in, dim_hid, num_class):
super(NodeClassificationModel, self).__init__()
self.fc1 = torch.nn.Linear(dim_in, dim_hid)
self.fc2 = torch.nn.Linear(dim_hid, num_class)
def forward(self, x):
x = self.fc1(x)
x = torch.nn.functional.relu(x)
x = self.fc2(x)
return x
import yaml
def parse_config(f):
with open(f, 'r') as fh:
    conf = yaml.safe_load(fh)
sample_param = conf['sampling'][0]
memory_param = conf['memory'][0]
gnn_param = conf['gnn'][0]
train_param = conf['train'][0]
return sample_param, memory_param, gnn_param, train_param
from typing import List, Tuple
import torch
import torch.distributed as dist
from starrygl.sample.graph_core import DataSet
from starrygl.sample.graph_core import GraphData
from starrygl.sample.sample_core.base import BaseSampler, NegativeSampling
import dgl
"""
入参不变,出参变为:
sample_from_nodes
node: list[tensor,tensor, tensor...]
eid: list[tensor,tensor, tensor...]
src_index: list[tensor,tensor, tensor...]
sample_from_edges:
node
eid: list[tensor,tensor, tensor...]
src_index: list[tensor,tensor, tensor...]
delta_ts: list[tensor,tensor, tensor...]
metadata
"""
def prepare_input(node_feat, edge_feat, mem_embedding,mfgs,dist_nid,dist_eid):
for mfg in mfgs:
for i,b in enumerate(mfg):
e_idx = b.edata['ID']
idx = b.srcdata['ID']
b.edata['ID'] = dist_eid[e_idx]
b.srcdata['ID'] = dist_nid[idx]
if edge_feat is not None:
b.edata['f'] = edge_feat[e_idx]
if i == 0:
if node_feat is not None:
b.srcdata['h'] = node_feat[idx]
if mem_embedding is not None:
node_memory,node_memory_ts,mailbox,mailbox_ts = mem_embedding
b.srcdata['mem'] = node_memory[idx]
b.srcdata['mem_ts'] = node_memory_ts[idx]
b.srcdata['mem_input'] = mailbox[idx].reshape(b.srcdata['ID'].shape[0], -1)
b.srcdata['mail_ts'] = mailbox_ts[idx]
#print(idx.shape[0],b.srcdata['mem_ts'].shape)
return mfgs
def to_block(graph: GraphData, data, sample_out, mailbox = None,device = torch.device('cuda')):
if len(sample_out) > 1:
sample_out,metadata = sample_out
else:
metadata = None
eid = [ret.eid() for ret in sample_out]
eid_len = [e.shape[0] for e in eid ]
eid_mapper: torch.Tensor = graph.eids_mapper
nid_mapper: torch.Tensor = graph.nids_mapper
eid_tensor = torch.cat(eid,dim = 0).to(eid_mapper.device)
dist_eid = eid_mapper[eid_tensor].to(device)
dist_eid,eid_inv = dist_eid.unique(return_inverse=True)
src_node = graph.sample_graph['edge_index'][0,eid_tensor*2].to(graph.nids_mapper.device)
src_ts = None
edge_feat = graph._get_edge_attr(dist_eid)
if metadata is None:
root_node = data.nodes.to(graph.nids_mapper.device)
root_len = [root_node.shape[0]]
if hasattr(data,'ts'):
src_ts = torch.cat([data.ts,
graph.sample_graph['ts'][eid_tensor*2].to(device)])
elif 'seed' in metadata:
root_node = metadata.pop('seed').to(graph.nids_mapper.device)
root_len = root_node.shape[0]
if 'seed_ts' in metadata:
src_ts = torch.cat([metadata.pop('seed_ts').to(device),\
graph.sample_graph['ts'][eid_tensor*2].to(device)])
for k in metadata:
metadata[k] = metadata[k].to(device)
nid_tensor = torch.cat([root_node,src_node],dim = 0)
dist_nid = nid_mapper[nid_tensor].to(device)
dist_nid,nid_inv = dist_nid.unique(return_inverse = True)
node_feat = graph._get_node_attr(dist_nid)
if mailbox is not None:
mem = mailbox._get_memory(dist_nid)
else:
mem = None
def build_block():
mfgs = list()
col = torch.arange(0,root_len,device = device)
col_len = 0
row_len = root_len
for r in range(len(eid_len)):
elen = eid_len[r]
row = torch.arange(row_len,row_len+elen,device = device)
b = dgl.create_block((row,col[sample_out[r].src_index().to(device)]),
num_src_nodes = row_len + elen,
num_dst_nodes = row_len,
device = device)
idx = nid_inv[0:row_len + elen]
e_idx = eid_inv[col_len:col_len+elen]
b.srcdata['ID'] = idx#dist_nid[idx]
if sample_out[r].delta_ts().shape[0] > 0:
b.edata['dt'] = sample_out[r].delta_ts().to(device)
if src_ts is not None:
b.srcdata['ts'] = src_ts[0:row_len + eid_len[r]]
b.edata['ID'] = e_idx#dist_eid[e_idx]
#if edge_feat is not None:
# b.edata['f'] = edge_feat[e_idx]
#if r == len(eid_len)-1:
# if node_feat is not None:
# b.srcdata['h'] = node_feat[idx]
# if mem_embedding is not None:
# node_memory,node_memory_ts,mailbox,mailbox_ts = mem_embedding
# b.srcdata['mem'] = node_memory[idx]
# b.srcdata['mem_ts'] = node_memory_ts[idx]
# b.srcdata['mem_input'] = mailbox[idx].cuda().reshape(b.srcdata['ID'].shape[0], -1)
# b.srcdata['mail_ts'] = mailbox_ts[idx]
col = row
col_len += eid_len[r]
row_len += eid_len[r]
mfgs.append(b)
mfgs = list(map(list, zip(*[iter(mfgs)])))
mfgs.reverse()
return data,mfgs,metadata
data,mfgs,metadata = build_block()
if dist.get_world_size() > 1:
if(node_feat is None):
node_feat = torch.futures.Future()
node_feat.set_result(None)
if(edge_feat is None):
edge_feat = torch.futures.Future()
edge_feat.set_result(None)
if(mem is None):
mem = torch.futures.Future()
mem.set_result(None)
def callback(fs,mfgs,dist_nid,dist_eid):
node_feat,edge_feat,mem_embedding = fs.value()
node_feat = node_feat.value()
edge_feat = edge_feat.value()
mem_embedding = mem_embedding.value()
return prepare_input(node_feat,edge_feat,mem_embedding,mfgs,dist_nid,dist_eid)
cal = lambda fut: callback(fs=fut,mfgs = mfgs,dist_nid = dist_nid,dist_eid =dist_eid)
return data,torch.futures.collect_all([node_feat,edge_feat,mem]).then(cal),metadata
else:
mfgs = prepare_input(node_feat,edge_feat,mem,mfgs,dist_nid,dist_eid)
#return build_block(node_feat,edge_feat,mem)#data,mfgs,metadata
return data,mfgs,metadata
def graph_sample(graph, sampler:BaseSampler,
sample_fn, data,
neg_sampling = None,
mailbox = None,
device = torch.device('cuda')):
out = sample_fn(sampler,data,neg_sampling)
return to_block(graph,data,out,mailbox,device)
def sample_from_nodes(sampler:BaseSampler, data:DataSet, **kwargs):
out = sampler.sample_from_nodes(nodes=data.nodes.reshape(-1))
#out.metadata = None
return out
def sample_from_edges(sampler:BaseSampler,
data:DataSet,
neg_sampling:NegativeSampling = None):
edge_label = data.labels if hasattr(data,'labels') else None
out = sampler.sample_from_edges(edges = data.edges,
neg_sampling=neg_sampling)
return out
def sample_from_temporal_nodes(sampler:BaseSampler,data:DataSet,
**kwargs):
out = sampler.sample_from_nodes(nodes=data.nodes.reshape(-1),
ts = data.ts.reshape(-1))
#out.metadata = None
return out
def sample_from_temporal_edges(sampler:BaseSampler, data:DataSet,
neg_sampling: NegativeSampling = None):
edge_label = data.labels if hasattr(data,'labels') else None
out = sampler.sample_from_edges(edges=data.edges.to('cpu'),
ets=data.ts.to('cpu'),
neg_sampling = neg_sampling
)
return out
class SAMPLE_TYPE:
SAMPLE_FROM_NODES = sample_from_nodes
SAMPLE_FROM_EDGES = sample_from_edges
SAMPLE_FROM_TEMPORAL_NODES = sample_from_temporal_nodes
SAMPLE_FROM_TEMPORAL_EDGES = sample_from_temporal_edges
from collections import deque
from enum import Enum
import queue
import torch
import sys
from os.path import abspath, join, dirname
import numpy as np
from starrygl.sample.batch_data import graph_sample
from starrygl.sample.sample_core.PreNegSampling import PreNegativeSampling
sys.path.insert(0, join(abspath(dirname(__file__))))
from typing import Deque, Optional
import torch.distributed as dist
from torch_geometric.data import Data
import os.path as osp
import math
class DistributedDataLoader:
'''
Args:
data_path: the path of the loaded graph; partition 0 of the graph is saved under $path$/rank_0, and so on for each rank
num_replicas: the number of workers
'''
def __init__(
self,
graph,
dataset = None,
sampler = None,
sampler_fn = None,
neg_sampler = None,
batch_size: Optional[int]=None,
drop_last = False,
device: torch.device = torch.device('cuda'),
shuffle:bool = True,
chunk_size = None,
train = False,
queue_size = 10,
mailbox = None,
**kwargs
):
assert sampler is not None
self.chunk_size = chunk_size
self.batch_size = batch_size
self.queue_size = queue_size
self.num_pending = 0
self.current_pos = 0
self.recv_idxs = 0
self.drop_last = drop_last
self.result_queue = deque(maxlen = self.queue_size)
self.shuffle = shuffle
self.is_closed = False
self.sampler = sampler
self.sampler_fn = sampler_fn
self.neg_sampler = neg_sampler
self.graph = graph
self.shuffle=shuffle
self.dataset = dataset
self.mailbox = mailbox
self.device = device
if train is True:
self._get_expected_idx(self.dataset.len)
else:
self.expected_idx = int(math.ceil(self.dataset.len/self.batch_size))
def __iter__(self):
if self.chunk_size is None:
if self.shuffle:
self.input_dataset = self.dataset.shuffle()
else:
self.input_dataset = self.dataset
self.recv_idxs = 0
self.current_pos = 0
self.num_pending = 0
self.submitted = 0
else:
self.input_dataset = self.dataset
self.recv_idxs = 0
self.num_pending = 0
self.submitted = 0
if dist.get_rank() == 0:
self.current_pos = int(
math.floor(
np.random.uniform(0,self.batch_size/self.chunk_size)
)*self.chunk_size
)
else:
self.current_pos = 0
current_pos = torch.tensor([self.current_pos],dtype = torch.long,device=self.device)
dist.broadcast(current_pos, src = 0)
self.current_pos = int(current_pos.item())
self._get_expected_idx(self.dataset.len-self.current_pos)
if self.neg_sampler is not None \
and isinstance(self.neg_sampler,PreNegativeSampling):
self.neg_sampler.set_next_pos(self.current_pos)
return self
def _get_expected_idx(self,data_size):
world_size = dist.get_world_size()
self.expected_idx = data_size // self.batch_size if self.drop_last is True else int(math.ceil(data_size/self.batch_size))
if dist.get_world_size() > 1:
num_epochs = torch.tensor([self.expected_idx],dtype = torch.long,device=self.device)
dist.all_reduce(num_epochs, op=dist.ReduceOp.MIN)
self.expected_idx = int(num_epochs.item())
def _next_data(self):
if self.current_pos >= self.dataset.len:
return None
if self.current_pos + self.batch_size > self.input_dataset.len:
if self.drop_last:
return None
else:
next_data = self.input_dataset.get_next(
slice(self.current_pos,None,None)
)
self.current_pos = 0
else:
next_data = self.input_dataset.get_next(
slice(self.current_pos,self.current_pos + self.batch_size,None)
)
self.current_pos += self.batch_size
return next_data
def __next__(self):
if(dist.get_world_size() == 1):
if self.recv_idxs < self.expected_idx:
data = self._next_data()
batch_data = graph_sample(self.graph,
self.sampler,
self.sampler_fn,
data,self.neg_sampler,
self.mailbox,
self.device)
self.recv_idxs += 1
assert batch_data is not None
return batch_data
else :
raise StopIteration
else :
num_reqs = min(self.queue_size - self.num_pending,self.expected_idx - self.submitted)
for _ in range(num_reqs):
if len(self.result_queue)>0:
result = self.result_queue[0]
if result[1].done() == True:
self.recv_idxs += 1
self.num_pending -= 1
self.result_queue.popleft()
return result[0],result[1].value(),result[2]
else:
next_data = self._next_data()
assert next_data is not None
batch_data = graph_sample(self.graph,
self.sampler,
self.sampler_fn,
next_data,self.neg_sampler,
self.mailbox,
self.device)
self.result_queue.append(batch_data)
self.submitted = self.submitted + 1
self.num_pending = self.num_pending + 1
while(self.recv_idxs < self.expected_idx):
assert len(self.result_queue) > 0
result= self.result_queue[0]
if result[1].done() == True:
self.recv_idxs += 1
self.num_pending -= 1
self.result_queue.popleft()
return result[0],result[1].value(),result[2]
assert self.num_pending == 0
raise StopIteration
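# A rough usage sketch (assumptions, not part of the commit; pdata, sampler,
# neg_sampler, mailbox and model are placeholders built elsewhere). The loader
# yields the (root batch, mfgs, metadata) triples produced by graph_sample/to_block
# above and prefetches them through a result queue in the multi-worker path.
# graph = GraphData(pdata)
# train_data = DataSet(edges=train_edges, ts=train_ts)
# loader = DistributedDataLoader(graph, train_data, sampler=sampler,
#                                sampler_fn=SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
#                                neg_sampler=neg_sampler, batch_size=600,
#                                train=True, mailbox=mailbox)
# for root, mfgs, metadata in loader:
#     pred_pos, pred_neg = model(mfgs, metadata)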
from starrygl.distributed.utils import DistIndex, DistributedTensor
from starrygl.sample.graph_core.utils import build_mapper
import os.path as osp
import torch
import torch.distributed as dist
from torch_geometric.data import Data
class GraphData():
def __init__(self, pdata, device = torch.device('cuda'), all_on_gpu = False):
self.device = device
self.ids = pdata.ids.to(device)
self.edge_index = pdata.edge_index.to(device)
if hasattr(pdata,'edge_ts'):
self.edge_ts = pdata.edge_ts.to(device).to(torch.float)
else:
self.edge_ts = None
self.sample_graph = pdata.sample_graph
self.nids_mapper = build_mapper(nids=pdata.ids.to(device)).data.to('cpu')
self.eids_mapper = build_mapper(nids=pdata.eids.to(device)).data.to('cpu')
if all_on_gpu:
self.nids_mapper = self.nids_mapper.to(device)
self.eids_mapper = self.eids_mapper.to(device)
self.num_nodes = self.nids_mapper.data.shape[0]
world_size = dist.get_world_size()
if hasattr(pdata,'x') and pdata.x is not None:
if world_size > 1:
self.x = DistributedTensor(pdata.x.to(self.device))
else:
self.x = pdata.x.to(device).to(torch.float)
else:
self.x = None
if hasattr(pdata,'edge_attr') and pdata.edge_attr is not None:
if world_size > 1:
self.edge_attr = DistributedTensor(pdata.edge_attr.to(self.device))
else:
self.edge_attr = pdata.edge_attr.to('cuda').to(torch.float)
else:
self.edge_attr = None
def _get_node_attr(self,ids):
if self.x is None:
return None
elif dist.get_world_size() == 1:
return self.x[ids]
else:
return self.x.index_select(ids)
def _get_edge_attr(self,ids):
if self.edge_attr is None:
return None
elif dist.get_world_size() == 1:
return self.edge_attr[ids.to('cpu')].to('cuda')
else:
return self.edge_attr.index_select(ids)
class DataSet:
def __init__(self,nodes = None,
edges = None,
labels = None,
ts = None,
device = torch.device('cuda'),**kwargs):
if nodes is not None:
self.nodes = nodes.to(device)
if edges is not None:
self.edges = edges.to(device)
if ts is not None:
self.ts = ts.to(device)
if labels is not None:
self.labels = labels
self.len = self.nodes.shape[0] if nodes is not None else self.edges.shape[1]
for k, v in kwargs.items():
assert isinstance(v,torch.Tensor) and v.shape[0]==self.len
setattr(self, k, v.to(device))
#@staticmethod
def get_next(self,indx):
nodes = self.nodes[indx] if hasattr(self,'nodes') else None
edges = self.edges[:,indx] if hasattr(self,'edges') else None
d = DataSet(nodes,edges)
for k,v in self.__dict__.items():
if k == 'edges' or k=='nodes' or k == 'len':
continue
else:
setattr(d,k,v[indx])
return d
#@staticmethod
def shuffle(self):
indx = torch.randperm(self.len)
nodes = self.nodes[indx] if hasattr(self,'nodes') else None
edges = self.edges[:,indx] if hasattr(self,'edges') else None
d = DataSet(nodes,edges)
for k,v in self.__dict__.items():
if k == 'edges' or k=='nodes' or k == 'len':
continue
else:
setattr(d,k,v[indx])
return d
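# Illustrative only: a DataSet is built from nodes or edges plus any per-sample
# tensors whose leading dimension matches (edge_index and edge_ts are placeholders);
# shuffle() and get_next() return new DataSet views carrying the same extra fields.
# ds = DataSet(edges=edge_index, ts=edge_ts)
# batch = ds.shuffle().get_next(slice(0, 600))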
class TemporalGraphData(GraphData):
def __init__(self,pdata,device):
super(TemporalGraphData,self).__init__(pdata,device)
def _set_temporal_batch_cache(self,size,pin_size):
pass
def _load_feature_to_cuda(self,ids):
pass
class TemporalNeighborSampleGraph(GraphData):
def __init__(self, sample_graph=None, mode='full', eids_mapper=None):
self.edge_index = sample_graph['edge_index']
self.num_edges = self.edge_index.shape[1]
if 'ts' in sample_graph:
self.edge_ts = sample_graph['ts']
else:
self.edge_ts = None
self.eid = sample_graph['eids']
if mode == 'train':
mask = sample_graph['train_mask']
if mode == 'val':
mask = sample_graph['val_mask']
if mode == 'test':
mask = sample_graph['test_mask']
if mode != 'full':
self.edge_index = self.edge_index[:, mask]
self.edge_ts = self.edge_ts[mask]
self.eid = self.eid[mask]
from starrygl.distributed.context import DistributedContext
from typing import *
import torch.distributed as dist
import torch
from starrygl.distributed.utils import DistIndex
def build_mapper(nids):
rank = dist.get_rank()
world_size = dist.get_world_size()
dst_len = nids.size(0)
ikw = dict(dtype=torch.long, device=nids.device)
num_nodes = torch.zeros(1, **ikw)
num_nodes[0] = dst_len
dist.all_reduce(num_nodes, op=dist.ReduceOp.SUM)
all_ids: List[torch.Tensor] = [None] * world_size
dist.all_gather_object(all_ids,nids)
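# every rank now holds the node id lists of all partitions; the two maps below
# record, for each global id, the owning partition (part_mp) and its local offset
# on that partition (ind_mp), which DistIndex combines into one distributed index.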
part_mp = torch.empty(num_nodes,**ikw)
ind_mp = torch.empty(num_nodes,**ikw)
for i in range(world_size):
iid = all_ids[i]
part_mp[iid] = i
ind_mp[iid] = torch.arange(all_ids[i].shape[0],**ikw)
return DistIndex(ind_mp,part_mp)
from typing import List
import torch
from torch.distributed import rpc
import torch_scatter
from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex, DistributedTensor
import torch.distributed as dist
class SharedMailBox():
def __init__(self,
num_nodes,
memory_param,
dim_edge_feat,
device = torch.device('cuda')):
self.device = device
self.num_nodes = num_nodes
self.num_parts = dist.get_world_size()
if memory_param['type'] != 'node':
raise NotImplementedError
self.memory_param = memory_param
self.memory_size = memory_param['dim_out']
node_memory = torch.zeros((
self.num_nodes, memory_param['dim_out']),
dtype=torch.float32,device =self.device)
node_memory_ts = torch.zeros(self.num_nodes,
dtype=torch.float32,
device = self.device)
mailbox = torch.zeros(self.num_nodes,
memory_param['mailbox_size'],
2 * memory_param['dim_out'] + dim_edge_feat,
device = self.device, dtype=torch.float32)
mailbox_ts = torch.zeros((self.num_nodes,
memory_param['mailbox_size']),
dtype=torch.float32,device = self.device)
self.node_memory = DistributedTensor(node_memory)
self.node_memory_ts = DistributedTensor(node_memory_ts)
self.mailbox = DistributedTensor(mailbox)
self.mailbox_ts = DistributedTensor(mailbox_ts)
self.next_mail_pos = torch.zeros((self.num_nodes),
dtype=torch.long,
device = self.device)
self._ctx = DistributedContext.get_default_context()
self.rref = rpc.RRef(self)
self.rrefs = self._ctx.all_gather_remote_objects(self.rref)
self.partptr = torch.tensor([ ((i & 0xFFFF)<<48) for i in range(self.num_parts+1) ],device = device)
def reset(self):
self.node_memory.accessor.data.zero_()
self.node_memory_ts.accessor.data.zero_()
self.mailbox.accessor.data.zero_()
self.mailbox_ts.accessor.data.zero_()
self.next_mail_pos.zero_()
def set_memory_local(self,index,source,source_ts,Reduce_Op = None):
if Reduce_Op == 'max':
unq_id,inv = index.unique(return_inverse = True)
max_ts,id = torch_scatter.scatter_max(source_ts,inv,dim=0)
source_ts = max_ts
source = source[id]
index = unq_id
self.node_memory.accessor.data[index] = source
self.node_memory_ts.accessor.data[index] = source_ts
def set_mailbox_local(self,index,source,source_ts,Reduce_Op = None):
if Reduce_Op == 'max' and self.num_parts > 1:
unq_id,inv = index.unique(return_inverse = True)
max_ts,id = torch_scatter.scatter_max(source_ts,inv,dim=0)
source_ts = max_ts
source = source[id]
index = unq_id
self.mailbox_ts.accessor.data[index, self.next_mail_pos[index]] = source_ts
self.mailbox.accessor.data[index, self.next_mail_pos[index]] = source
if self.memory_param['mailbox_size'] > 1:
self.next_mail_pos[index] = torch.remainder(
self.next_mail_pos[index] + 1,
self.memory_param['mailbox_size'])
def set_memory_async(self,index,source,source_ts):
dist_index = DistIndex(index)
part_idx = dist_index.part
index = dist_index.loc
futs: List[torch.futures.Future] = []
if self.num_parts == 1:
    self.set_memory_local(index, source, source_ts)
else:
    for i in range(self.num_parts):
        fut = self._ctx.remote_call(
SharedMailBox.set_memory_local,
self.rrefs[i],
index[part_idx == i],
source[part_idx == i],
source_ts[part_idx == i])
futs.append(fut)
return torch.futures.collect_all(futs)
def add_to_mailbox_async(self,index,source,source_ts):
dist_index = DistIndex(index)
part_idx = dist_index.part
index = dist_index.loc
futs: List[torch.futures.Future] = []
if self.num_parts == 1:
self.set_mailbox_local(index,source,source_ts)
else:
for i in range(self.num_parts):
fut = self._ctx.remote_call(
SharedMailBox.set_mailbox_local,
self.rrefs[i],
index[part_idx == i],
source[part_idx == i],
source_ts[part_idx == i])
futs.append(fut)
return torch.futures.collect_all(futs)
def set_mailbox_all_to_all(self,index,memory,
memory_ts,mail,mail_ts,
reduce_Op = None):
#futs: List[torch.futures.Future] = []
if self.num_parts == 1:
dist_index = DistIndex(index)
part_idx = dist_index.part
index = dist_index.loc
self.set_mailbox_local(index,mail,mail_ts)
self.set_memory_local(index,memory,memory_ts)
else:
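# multi-partition path: index is assumed to be sorted by destination partition
# (DistIndex appears to keep the partition id in the high bits), so searchsorted
# against partptr gives per-partition send counts; the counts are exchanged first,
# then the ids, memory and mails are exchanged with all_to_all and applied locally.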
gather_len_list = torch.empty([self.num_parts],
dtype = int,
device = self.device)
indic = torch.searchsorted(index,self.partptr,right=False)
scatter_len_list = indic[1:] - indic[0:-1]
torch.distributed.all_to_all_single(gather_len_list,scatter_len_list)
input_split = scatter_len_list.tolist()
output_split = gather_len_list.tolist()
gather_id_list = torch.empty(
[gather_len_list.sum()],
dtype = torch.long,
device = self.device)
torch.distributed.all_to_all_single(
gather_id_list,index,output_split_sizes=output_split,
input_split_sizes=input_split)
index = gather_id_list
gather_memory = torch.empty(
[gather_len_list.sum(),memory.shape[1]],
dtype = memory.dtype,device = self.device)
gather_memory_ts = torch.empty(
[gather_len_list.sum()],
dtype = memory_ts.dtype,device = self.device)
gather_mail = torch.empty(
[gather_len_list.sum(),mail.shape[1]],
dtype = mail.dtype,device = self.device)
gather_mail_ts = torch.empty(
[gather_len_list.sum()],
dtype = mail_ts.dtype,device = self.device)
torch.distributed.all_to_all_single(
gather_memory,memory,
output_split_sizes=output_split,
input_split_sizes=input_split)
torch.distributed.all_to_all_single(
gather_memory_ts,memory_ts,
output_split_sizes=output_split,
input_split_sizes=input_split)
torch.distributed.all_to_all_single(
gather_mail,mail,
output_split_sizes=output_split,
input_split_sizes=input_split)
torch.distributed.all_to_all_single(
gather_mail_ts,mail_ts,
output_split_sizes=output_split,
input_split_sizes=input_split)
self.set_mailbox_local(DistIndex(index).loc,gather_mail,gather_mail_ts,Reduce_Op = reduce_Op)
self.set_memory_local(DistIndex(index).loc,gather_memory,gather_memory_ts, Reduce_Op = reduce_Op)
def get_update_mail(self,dist_indx_mapper,
src,dst,ts,edge_feats,
memory):
if edge_feats is not None:
edge_feats = edge_feats.to(self.device).to(self.mailbox.dtype)
src = src.to(self.device)
dst = dst.to(self.device)
index = torch.cat([src, dst]).reshape(-1)
index = dist_indx_mapper[index]
mem_src = memory[src]
mem_dst = memory[dst]
if edge_feats is not None:
src_mail = torch.cat([mem_src, mem_dst, edge_feats], dim=1)
dst_mail = torch.cat([mem_dst, mem_src, edge_feats], dim=1)
else:
src_mail = torch.cat([mem_src, mem_dst], dim=1)
dst_mail = torch.cat([mem_dst, mem_src], dim=1)
mail = torch.cat([src_mail, dst_mail], dim=1).reshape(-1, src_mail.shape[1])
mail_ts = torch.cat((ts,ts),-1).to(self.device).to(self.mailbox_ts.dtype)
unq_index,inv = torch.unique(index,return_inverse = True)
max_ts,idx = torch_scatter.scatter_max(mail_ts,inv,0)
mail_ts = max_ts
mail = mail[idx]
index = unq_index
return index,mail,mail_ts
def get_update_memory(self,index,memory,memory_ts):
unq_index,inv = torch.unique(index,return_inverse = True)
max_ts,idx = torch_scatter.scatter_max(memory_ts,inv,0)
ts = max_ts
memory = memory[idx]
index = unq_index
return index,memory,ts
def _get_memory(self,index):
if self.num_parts == 1:
return self.node_memory.accessor.data[index],\
self.node_memory_ts.accessor.data[index],\
self.mailbox.accessor.data[index],\
self.mailbox_ts.accessor.data[index]
else:
memory = self.node_memory.index_select(index)
memory_ts = self.node_memory_ts.index_select(index)
mail = self.mailbox.index_select(index)
mail_ts = self.mailbox_ts.index_select(index)
def callback(fs):
memory,memory_ts,mail,mail_ts = fs.value()
memory = memory.value()
memory_ts = memory_ts.value()
mail = mail.value()
mail_ts = mail_ts.value()
#print(memory.shape[0])
return memory,memory_ts,mail,mail_ts
return torch.futures.collect_all([memory,memory_ts,mail,mail_ts]).then(callback)
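# Hedged, self-contained sketch of two patterns used by SharedMailBox above;
# every value below is hypothetical and only meant to illustrate the mechanics.
if __name__ == "__main__":
    import torch
    import torch_scatter

    # (1) Per-partition send counts, as in set_mailbox_all_to_all: given a
    # global index sorted by owning partition and the partition boundaries
    # partptr, searchsorted yields cumulative offsets whose consecutive
    # differences are the number of entries destined for each partition.
    index = torch.tensor([0, 1, 2, 3, 4, 6, 6, 7])
    partptr = torch.tensor([0, 3, 5, 8])
    indic = torch.searchsorted(index, partptr, right=False)
    print((indic[1:] - indic[:-1]).tolist())  # [3, 2, 3]

    # (2) Keep-latest deduplication, as in get_update_mail/get_update_memory:
    # for ids that repeat within a batch, scatter_max over the inverse index
    # keeps only the entry carrying the largest timestamp.
    ids = torch.tensor([5, 3, 5])
    ts = torch.tensor([10.0, 20.0, 30.0])
    payload = torch.tensor([[1.0], [2.0], [3.0]])
    unq, inv = torch.unique(ids, return_inverse=True)
    max_ts, argmax = torch_scatter.scatter_max(ts, inv, 0)
    print(unq.tolist(), max_ts.tolist(), payload[argmax].squeeze(-1).tolist())
    # [3, 5] [20.0, 30.0] [2.0, 3.0]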
import parser
from torch_sparse import SparseTensor
from torch_geometric.data import Data
from torch_geometric.utils import degree
from .torch_utils import get_norm_temporal
import os.path as osp
import os
import shutil
import torch
import torch.utils.data
import metis
import networkx as nx
import torch.distributed as dist
def partition_load(root: str, algo: str = "metis") -> Data:
rank = dist.get_rank()
world_size = dist.get_world_size()
fn = osp.join(root, f"{algo}_{world_size}", f"{rank:03d}")
return torch.load(fn)
def partition_save(root: str, data: Data, num_parts: int,
algo: str = "metis",
edge_weight_dict=None):
root = osp.abspath(root)
if osp.exists(root) and not osp.isdir(root):
raise ValueError(f"path '{root}' should be a directory")
path = osp.join(root, f"{algo}_{num_parts}")
if osp.exists(path) and not osp.isdir(path):
raise ValueError(f"path '{path}' should be a directory")
if osp.exists(path) and os.listdir(path):
print(f"directory '{path}' not empty and cleared")
for p in os.listdir(path):
p = osp.join(path, p)
if osp.isdir(p):
shutil.rmtree(osp.join(path, p))
else:
os.remove(p)
if not osp.exists(path):
print(f"creating directory '{path}'")
os.makedirs(path)
if algo == 'metis_for_tgnn':
for i, pdata in enumerate(partition_data_for_tgnn(
data, num_parts, algo, verbose=True,
edge_weight_dict=edge_weight_dict)):
print(f"saving partition data: {i+1}/{num_parts}")
fn = osp.join(path, f"{i:03d}")
torch.save(pdata, fn)
else:
for i, pdata in enumerate(partition_data_for_gnn(data, num_parts,
algo, verbose=True)):
print(f"saving partition data: {i+1}/{num_parts}")
fn = osp.join(path, f"{i:03d}")
torch.save(pdata, fn)
def partition_data_for_gnn(data: Data, num_parts: int,
algo: str, verbose: bool = False):
if algo == "metis":
part_fn = metis_partition
else:
raise ValueError(f"invalid algorithm: {algo}")
num_nodes = data.num_nodes
num_edges = data.num_edges
edge_index = data.edge_index
if verbose:
print(f"running partition algorithm: {algo}")
node_parts, edge_parts = part_fn(edge_index, num_nodes, num_parts)
if verbose:
print("computing GCN normalized factor")
gcn_norm = compute_gcn_norm(edge_index, num_nodes)
if data.y.dtype == torch.long:
if verbose:
print("compute num_classes")
num_classes = data.y.max().item() + 1
else:
num_classes = None
eids = torch.zeros(num_edges, dtype=torch.long)
    offset = 0
    edgeptr = torch.zeros(num_parts+1, dtype=eids.dtype)
    for i in range(num_parts):
        epart_i = torch.where(edge_parts == i)[0]
        eids[epart_i] = torch.arange(epart_i.shape[0]) + offset
        offset += epart_i.shape[0]
        edgeptr[i+1] = offset
data.eids = eids
data.sample_graph.sample_eids = eids[data.sample_graph.sample_eid]
nids = torch.zeros(num_nodes, dtype=torch.long)
    offset = 0
    partptr = torch.zeros(num_parts+1, dtype=nids.dtype)
    for i in range(num_parts):
        npart_i = torch.where(node_parts == i)[0]
        nids[npart_i] = torch.arange(npart_i.shape[0]) + offset
        offset += npart_i.shape[0]
        partptr[i+1] = offset
data.edge_index = nids[data.edge_index]
data.sample_graph.edge_index = nids[data.sample_graph.edge_index]
for i in range(num_parts):
npart_i = torch.where(node_parts == i)[0]
epart_i = torch.where(edge_parts == i)[0]
npart = npart_i
epart = edge_index[:, epart_i]
pdata = {
"ids": npart,
"edge_index": epart,
"gcn_norm": gcn_norm[epart_i],
"sample_graph": data.sample_graph,
"partptr": partptr,
"edgeptr": edgeptr
}
if num_classes is not None:
pdata["num_classes"] = num_classes
for key, val in data:
if key == "edge_index" or key == "sample_graph":
continue
if isinstance(val, torch.Tensor):
if val.size(0) == num_nodes:
pdata[key] = val[npart_i]
elif val.size(0) == num_edges:
pdata[key] = val[epart_i]
# else:
# pdata[key] = val
elif isinstance(val, SparseTensor):
pass
else:
pdata[key] = val
pdata = Data(**pdata)
yield pdata
def _nopart(edge_index: torch.LongTensor, num_nodes: int):
node_parts = torch.zeros(num_nodes, dtype=torch.long)
if isinstance(edge_index, torch.Tensor):
edge_parts = torch.zeros(edge_index.size(1), dtype=torch.long)
return node_parts, edge_parts
return node_parts
def metis_for_tgnn(edge_index_dict: dict,
num_nodes: int,
num_parts: int,
edge_weight_dict=None):
if num_parts <= 1:
return _nopart(edge_index_dict, num_nodes)
G = nx.Graph()
G.add_nodes_from(torch.arange(0, num_nodes).tolist())
value, counts = torch.unique(edge_index_dict['edata'][1, :].view(-1),
return_counts=True)
nodes = torch.tensor(list(G.adj.keys()))
for i in range(value.shape[0]):
if (value[i].item() in G.nodes):
G.nodes[int(value[i].item())]['weight'] = counts[i]
G.nodes[int(value[i].item())]['ones'] = 1
G.graph['node_weight_attr'] = ['weight', 'ones']
for i, key in enumerate(edge_index_dict):
v = edge_index_dict[key]
edges = torch.cat((v, (torch.ones(v.shape[1], dtype=torch.long) *
edge_weight_dict[key]).unsqueeze(0)), dim=0)
# w = edges.T
G.add_weighted_edges_from((edges.T).tolist())
G.graph['edge_weight_attr'] = 'weight'
cuts, part = metis.part_graph(G, num_parts)
node_parts = torch.zeros(num_nodes, dtype=torch.long)
node_parts[nodes] = torch.tensor(part)
return node_parts
"""
weight: 各种工作负载边划分权重
按照点均衡划分
"""
def partition_data_for_tgnn(data: Data, num_parts: int, algo: str,
verbose: bool = False,
edge_weight_dict: dict = None):
if algo == "metis_for_tgnn":
part_fn = metis_for_tgnn
else:
raise ValueError(f"invalid algorithm: {algo}")
num_nodes = data.num_nodes
num_edges = data.num_edges
edge_index_dict = data.edge_index_dict
tgnn_norm = compute_temporal_norm(data.edge_index, data.edge_ts, num_nodes)
if verbose:
print(f"running partition algorithm: {algo}")
node_parts = part_fn(edge_index_dict, num_nodes, num_parts,
edge_weight_dict)
edge_parts = node_parts[data.edge_index[1, :]]
eids = torch.arange(num_edges, dtype=torch.long)
data.eids = eids
data.sample_graph['eids'] = eids[data.sample_graph['eids']]
if data.y.dtype == torch.long:
if verbose:
print("compute num_classes")
num_classes = data.y.max().item() + 1
else:
num_classes = None
for i in range(num_parts):
npart_i = torch.where(node_parts == i)[0]
epart_i = torch.where(edge_parts == i)[0]
pdata = {
"ids": npart_i,
"tgnn_norm": tgnn_norm,
"edge_index": data.edge_index[:, epart_i],
"sample_graph": data.sample_graph
}
if num_classes is not None:
pdata["num_classes"] = num_classes
for key, val in data:
if key == "edge_index" or key == "edge_index_dict" \
or key == "sample_graph":
continue
if isinstance(val, torch.Tensor):
if val.size(0) == num_nodes:
pdata[key] = val[npart_i]
elif val.size(0) == num_edges:
pdata[key] = val[epart_i]
# else:
# pdata[key] = val
elif isinstance(val, SparseTensor):
pass
else:
pdata[key] = val
pdata = Data(**pdata)
yield pdata
def metis_partition(edge_index, num_nodes: int, num_parts: int):
if num_parts <= 1:
return _nopart(edge_index, num_nodes)
G = nx.Graph()
G.add_nodes_from(torch.arange(0, num_nodes).tolist())
G.add_edges_from(edge_index.T.tolist())
nodes = torch.tensor(list(G.adj.keys()))
cuts, part = metis.part_graph(G, num_parts)
node_parts = torch.zeros(num_nodes, dtype=torch.long)
node_parts[nodes] = torch.tensor(part)
edge_parts = node_parts[edge_index[1]]
return node_parts, edge_parts
def metis_partition_bydegree(edge_index, num_nodes: int, num_parts: int):
if num_parts <= 1:
return _nopart(edge_index, num_nodes)
G = nx.Graph()
G.add_nodes_from(torch.arange(0, num_nodes).tolist())
G.add_edges_from(edge_index.T.tolist())
value, counts = torch.unique(edge_index[1, :].view(-1), return_counts=True)
nodes = torch.tensor(list(G.adj.keys()))
for i in range(value.shape[0]):
if (value[i].item() in G.nodes):
G.nodes[int(value[i].item())]['weight'] = counts[i]
G.graph['node_weight_attr'] = 'weight'
cuts, part = metis.part_graph(G, num_parts)
node_parts = torch.zeros(num_nodes, dtype=torch.long)
node_parts[nodes] = torch.tensor(part)
edge_parts = node_parts[edge_index[1]]
return node_parts, edge_parts
def compute_gcn_norm(edge_index: torch.LongTensor, num_nodes: int):
deg_j = degree(edge_index[0], num_nodes).pow(-0.5)
deg_i = degree(edge_index[1], num_nodes).pow(-0.5)
deg_i[deg_i.isinf() | deg_i.isnan()] = 0.0
deg_j[deg_j.isinf() | deg_j.isnan()] = 0.0
return deg_j[edge_index[0]] * deg_i[edge_index[1]]
def compute_temporal_norm(edge_index: torch.LongTensor,
timestamp: torch.FloatTensor,
num_nodes: int):
srcavg, srcvar, dstavg, dstvar = get_norm_temporal(edge_index[0, :],
edge_index[1, :],
timestamp, num_nodes)
return srcavg, srcvar, dstavg, dstvar
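# Hedged self-check of compute_gcn_norm on a toy graph (illustrative values
# only; it assumes this module's own imports, e.g. metis and .torch_utils,
# resolve, so run it with `python -m`): each edge (j -> i) gets the weight
# deg(j)^-0.5 * deg(i)^-0.5.
if __name__ == "__main__":
    toy_edge_index = torch.tensor([[0, 0, 1], [1, 2, 2]])
    print(compute_gcn_norm(toy_edge_index, 3).tolist())
    # [0.7071..., 0.5, 0.7071...]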
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension
setup(
name='torch_utils',
ext_modules=[
CppExtension(
name='torch_utils',
sources=['./torch_utils.cpp'],
extra_compile_args=['-fopenmp','-Xlinker',' -export-dynamic'],
include_dirs=["../sample_core"],
),
],
cmdclass={
'build_ext': BuildExtension
})#
#setup(
# name='cpu_cache_manager',
# ext_modules=[
# CppExtension(
# name='cpu_cache_manager',
# sources=['cpu_cache_manager.cpp'],
# extra_compile_args=['-fopenmp','-Xlinker',' -export-dynamic'],
# include_dirs=["./"],
# ),
# ],
# cmdclass={
# 'build_ext': BuildExtension
# })#
#
\ No newline at end of file
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#include <torch/extension.h>
#include <parallel_hashmap/phmap.h>
#include <cstring>
#include <vector>
#include <iostream>
#include <map>
#include <boost/thread/mutex.hpp>
using namespace std;
#define MTX boost::mutex
namespace py = pybind11;
namespace th = torch;
typedef int64_t NodeIDType;
typedef float TimeStampType;
#define EXTRAARGS , phmap::priv::hash_default_hash<K>, \
phmap::priv::hash_default_eq<K>, \
std::allocator<K>, 4, MTX
template <class K>
using HashT = phmap::parallel_flat_hash_set<K EXTRAARGS>;
template <class K, class V>
using HashM = phmap::parallel_flat_hash_map<K, V EXTRAARGS>;
#define work_thread 10
th::Tensor sparse_get_index(th::Tensor in,th::Tensor map_key){
auto key_ptr = map_key.data_ptr<NodeIDType>();
auto in_ptr = in.data_ptr<NodeIDType>();
int sz = map_key.size(0);
vector<pair<NodeIDType,NodeIDType>> mp(sz);
vector<NodeIDType> out(in.size(0));
#pragma omp parallel for
for(int i=0;i<sz;i++){
mp[i] = make_pair(key_ptr[i],i);
}
phmap::parallel_flat_hash_map<NodeIDType,NodeIDType> dict(mp.begin(),mp.end());
#pragma omp parallel for
for(int i=0;i<in.size(0);i++){
out[i] = dict.find(in_ptr[i])->second;
}
return th::tensor(out);
}
vector<double> get_norm_temporal(th::Tensor row,th::Tensor col,th::Tensor timestamp,int num_nodes){
vector<double> ret(4);
HashM<NodeIDType,TimeStampType> dict0;
HashM<NodeIDType,TimeStampType> dict1;
auto rowptr = row.data_ptr<NodeIDType>();
auto colptr = col.data_ptr<NodeIDType>();
auto time_ptr = timestamp.data_ptr<TimeStampType>();
vector<TimeStampType> out_timestamp[work_thread];
vector<TimeStampType> in_timestamp[work_thread];
#pragma omp parallel for num_threads(work_thread)
for(int i = 0;i<row.size(0);i++){
int tid = omp_get_thread_num();
if(dict0.find(rowptr[i])!=dict0.end()){
out_timestamp[tid].push_back(time_ptr[i]-dict0.find(rowptr[i])->second);
dict0.find(rowptr[i])->second = time_ptr[i];
}
else dict0.insert(make_pair(rowptr[i],time_ptr[i]));
if(dict1.find(colptr[i])!=dict1.end()){
in_timestamp[tid].push_back(time_ptr[i]-dict1.find(colptr[i])->second);
dict1.find(colptr[i])->second = time_ptr[i];
}
else dict1.insert(make_pair(colptr[i],time_ptr[i]));
}
double srcavg = 0;
double dstavg = 0;
double srcvar = 0;
double dstvar = 0;
double srccnt = 0;
double dstcnt = 0;
for(int i = 0;i<work_thread;i++){
        #pragma omp parallel for num_threads(work_thread) reduction(+:dstavg,dstcnt)
for(auto &v: in_timestamp[i]){
dstavg += v;
dstcnt++;
}
        #pragma omp parallel for num_threads(work_thread) reduction(+:srcavg,srccnt)
for(auto &v: out_timestamp[i]){
srcavg += v;
srccnt++;
}
}
dstavg /= dstcnt;
srcavg /= srccnt;
for(int i = 0;i<work_thread;i++){
        #pragma omp parallel for num_threads(work_thread) reduction(+:dstvar)
for(int j = 0;j<in_timestamp[i].size();j++){
TimeStampType v=in_timestamp[i][j];
dstvar += (v-dstavg)*(v-dstavg)/dstavg;
}
        #pragma omp parallel for num_threads(work_thread) reduction(+:srcvar)
        for(int j = 0;j<out_timestamp[i].size();j++){
            TimeStampType v=out_timestamp[i][j];
            srcvar += (v-srcavg)*(v-srcavg)/srcavg;
}
}
ret[0]=srcavg;
ret[1]=srcvar;
ret[2]=dstavg;
ret[3]=dstvar;
return ret;
}
PYBIND11_MODULE(torch_utils, m)
{
m
.def("sparse_get_index",
&sparse_get_index,
py::return_value_policy::reference)
.def("get_norm_temporal",
&get_norm_temporal,
py::return_value_policy::reference
);
}
\ No newline at end of file
import sys
from os.path import abspath, join, dirname
sys.path.insert(0, join(abspath(dirname(__file__))))
from torch import Tensor
import torch
from base import NegativeSampling
from base import NegativeSamplingMode
from typing import Any, List, Optional, Tuple, Union
class PreNegativeSampling(NegativeSampling):
r"""The negative sampling configuration of a
:class:`~torch_geometric.sampler.BaseSampler` when calling
:meth:`~torch_geometric.sampler.BaseSampler.sample_from_edges`.
Args:
mode (str): The negative sampling mode
(:obj:`"binary"` or :obj:`"triplet"`).
If set to :obj:`"binary"`, will randomly sample negative links
from the graph.
If set to :obj:`"triplet"`, will randomly sample negative
destination nodes for each positive source node.
amount (int or float, optional): The ratio of sampled negative edges to
the number of positive edges. (default: :obj:`1`)
weight (torch.Tensor, optional): A node-level vector determining the
sampling of nodes. Does not necessariyl need to sum up to one.
If not given, negative nodes will be sampled uniformly.
(default: :obj:`None`)
"""
def __init__(
self,
mode: Union[NegativeSamplingMode, str],
neg_sample_list: torch.Tensor
):
super(PreNegativeSampling,self).__init__(mode)
self.neg_sample_list = neg_sample_list
self.next_pos = 0
def set_next_pos(self,pos):
self.next_pos = pos
def sample(self, num_samples: int,
num_nodes: Optional[int] = None) -> Tensor:
r"""Generates :obj:`num_samples` negative samples."""
if num_nodes is None:
raise ValueError(
f"Cannot sample negatives in '{self.__class__.__name__}' "
f"without passing the 'num_nodes' argument")
neg_sample_out = self.neg_sample_list[
self.next_pos:self.next_pos+num_samples,:].reshape(-1)
self.next_pos = self.next_pos + num_samples
return neg_sample_out
#return torch.from_numpy(np.random.randint(num_nodes, size=num_samples))
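# Hedged usage sketch (hypothetical values): a precomputed list with two
# negatives per positive event, consumed sequentially by successive calls.
if __name__ == "__main__":
    neg_list = torch.arange(12).reshape(6, 2)
    pre_neg = PreNegativeSampling('triplet', neg_sample_list=neg_list)
    print(pre_neg.sample(num_samples=3, num_nodes=100).tolist())  # [0, 1, 2, 3, 4, 5]
    print(pre_neg.sample(num_samples=3, num_nodes=100).tolist())  # [6, 7, 8, 9, 10, 11]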
import os.path as osp
import torch
class GraphData():
def __init__(self, path):
        assert path is not None and osp.exists(path), 'path does not exist'
id,edge_index,data,partptr =torch.load(path)
        # index of the current partition
        self.partition_id = id
        # total number of partitions
        self.partitions = partptr.numel() - 1
        # global graph structure information
        self.num_nodes = partptr[self.partitions]
        self.num_edges = edge_index[0].numel()
        self.edge_index = edge_index
        # data held by this partition (feature vectors and subgraph structure), a PyG Data object
        self.data = data
        # partition mapping (node id ranges per partition)
        self.partptr = partptr
self.eid = [i for i in range(self.num_edges)]
def __init__(self, id, edge_index, data, partptr, timestamp=None):
        # index of the current partition
        self.partition_id = id
        # total number of partitions
        self.partitions = partptr.numel() - 1
        # global graph structure information
        self.num_nodes = partptr[self.partitions]
        if edge_index is not None:
            self.num_edges = edge_index[0].numel()
        self.edge_index = edge_index
        self.edge_ts = timestamp
        # data held by this partition (feature vectors and subgraph structure), a PyG Data object
        self.data = data
        # partition mapping (node id ranges per partition)
        self.partptr = partptr
        # edge id
self.eid = torch.tensor([i for i in range(0, self.num_edges)])
    def select_attr(self,index):
        return torch.index_select(self.data.x,0,index)
    # number of nodes stored in this partition
    def get_part_num(self):
        return self.data.x.size()[0]
    def select_y(self,index):
        return torch.index_select(self.data.y,0,index)
    # map global node ids to partition-local ids
def get_localId_by_partitionId(self,id,index):
#print(index)
if(id == -1 or id == 0):
return index
else:
return torch.add(index,-self.partptr[id])
def get_globalId_by_partitionId(self,id,index):
if(id == -1 or id == 0):
return index
else:
return torch.add(index,self.partptr[id])
def get_node_num(self):
return self.num_nodes
def localId_to_globalId(self,id,partitionId:int = -1):
        '''
        Map a node id inside partition `partitionId` to its global id.
        '''
if partitionId == -1:
partitionId = self.partition_id
assert id >=self.partptr[self.partition_id] and id < self.partptr[self.partition_id+1]
ids_before = 0
if self.partition_id>0:
ids_before = self.partptr[self.partition_id-1]
return id+ids_before
def get_partitionId_by_globalId(self,id):
        '''
        Get the partition index for a global node id.
        '''
partitionId = -1
        assert id>=0 and id<self.num_nodes, 'id out of range'
for i in range(self.partitions):
if id>=self.partptr[i] and id<self.partptr[i+1]:
partitionId = i
break
        assert partitionId>=0, 'no partition found for this id'
return partitionId
def get_nodes_by_partitionId(self,id):
        '''
        Return the number of nodes in partition `id`.
        '''
        assert id>=0 and id<self.partitions, 'invalid partitionId'
return (int)(self.partptr[id+1]-self.partptr[id])
def __repr__(self):
return (f'{self.__class__.__name__}(\n'
f' partition_id={self.partition_id}\n'
f' data={self.data},\n'
f' global_info('
f'num_nodes={self.num_nodes},'
f' num_edges={self.num_edges},'
f' num_parts={self.partitions},'
f' edge_index=[2,{self.edge_index[0].numel()}])\n'
f')')
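# Hedged self-check of the partition helpers above, using a hypothetical
# two-partition layout: partptr = [0, 3, 6] puts nodes 0-2 in partition 0 and
# nodes 3-5 in partition 1.
if __name__ == "__main__":
    _g = GraphData(id=1, edge_index=torch.tensor([[0], [1]]), data=None,
                   partptr=torch.tensor([0, 3, 6]))
    print(_g.get_partitionId_by_globalId(4))                    # 1
    print(_g.get_localId_by_partitionId(1, torch.tensor([4])))  # tensor([1])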
import torch
from torch import Tensor
from enum import Enum
import math
from abc import ABC
from typing import Any, List, Optional, Tuple, Union
class SampleType(Enum):
Whole = 0
Inner = 1
Outer =2
class NegativeSamplingMode(Enum):
# 'binary': Randomly sample negative edges in the graph.
binary = 'binary'
# 'triplet': Randomly sample negative destination nodes for each positive
# source node.
triplet = 'triplet'
class NegativeSampling:
r"""The negative sampling configuration of a
:class:`~torch_geometric.sampler.BaseSampler` when calling
:meth:`~torch_geometric.sampler.BaseSampler.sample_from_edges`.
Args:
mode (str): The negative sampling mode
(:obj:`"binary"` or :obj:`"triplet"`).
If set to :obj:`"binary"`, will randomly sample negative links
from the graph.
If set to :obj:`"triplet"`, will randomly sample negative
destination nodes for each positive source node.
amount (int or float, optional): The ratio of sampled negative edges to
the number of positive edges. (default: :obj:`1`)
weight (torch.Tensor, optional): A node-level vector determining the
            sampling of nodes. Does not necessarily need to sum up to one.
If not given, negative nodes will be sampled uniformly.
(default: :obj:`None`)
"""
mode: NegativeSamplingMode
amount: Union[int, float] = 1
weight: Optional[Tensor] = None
unique: bool
def __init__(
self,
mode: Union[NegativeSamplingMode, str],
amount: Union[int, float] = 1,
weight: Optional[Tensor] = None,
unique: bool = False
):
self.mode = NegativeSamplingMode(mode)
self.amount = amount
self.weight = weight
self.unique = unique
if self.amount <= 0:
raise ValueError(f"The attribute 'amount' needs to be positive "
f"for '{self.__class__.__name__}' "
f"(got {self.amount})")
if self.is_triplet():
if self.amount != math.ceil(self.amount):
raise ValueError(f"The attribute 'amount' needs to be an "
f"integer for '{self.__class__.__name__}' "
f"with 'triplet' negative sampling "
f"(got {self.amount}).")
self.amount = math.ceil(self.amount)
def is_binary(self) -> bool:
return self.mode == NegativeSamplingMode.binary
def is_triplet(self) -> bool:
return self.mode == NegativeSamplingMode.triplet
def sample(self, num_samples: int,
num_nodes: Optional[int] = None) -> Tensor:
r"""Generates :obj:`num_samples` negative samples."""
if self.weight is None:
if num_nodes is None:
raise ValueError(
f"Cannot sample negatives in '{self.__class__.__name__}' "
f"without passing the 'num_nodes' argument")
return torch.randint(num_nodes, (num_samples, ))
if num_nodes is not None and self.weight.numel() != num_nodes:
raise ValueError(
f"The 'weight' attribute in '{self.__class__.__name__}' "
f"needs to match the number of nodes {num_nodes} "
f"(got {self.weight.numel()})")
return torch.multinomial(self.weight, num_samples, replacement=True)
class SampleOutput:
node: Optional[torch.Tensor] = None
edge_index_list: Optional[List[torch.Tensor]] = None
eid_list: Optional[List[torch.Tensor]] = None
delta_ts_list: Optional[List[torch.Tensor]] = None
metadata: Optional[Any] = None
class BaseSampler(ABC):
r"""An abstract base class that initializes a graph sampler and provides
:meth:`_sample_one_layer_from_nodes`
:meth:`_sample_one_layer_from_nodes_parallel`
:meth:`sample_from_nodes` routines.
"""
def sample_from_nodes(
self,
nodes: torch.Tensor,
with_outer_sample: SampleType,
**kwargs
) -> Tuple[torch.Tensor, list]:
r"""Performs mutilayer sampling from the nodes specified in: nodes
The specific number of layers is determined by parameter: num_layers
returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, list].
Args:
nodes: the list of seed nodes index
            with_outer_sample: 0 - sample in the whole graph structure; 1 - sample one-hop outer nodes; 2 - cross-partition sampling
**kwargs: other kwargs
Returns:
sampled_nodes: the nodes sampled
sampled_edge_index_list: the edges sampled
"""
raise NotImplementedError
def sample_from_edges(
self,
edges: torch.Tensor,
with_outer_sample: SampleType,
edge_label: Optional[torch.Tensor] = None,
neg_sampling: Optional[NegativeSampling] = None
) -> Tuple[torch.Tensor, list]:
r"""Performs sampling from the edges specified in :obj:`index`,
returning a sampled subgraph in the specified output format.
Args:
edges: the list of seed edges index
            with_outer_sample: 0 - sample in the whole graph structure; 1 - sample one-hop outer nodes; 2 - cross-partition sampling
edge_label: the label for the seed edges.
neg_sampling: The negative sampling configuration
Returns:
sampled_nodes: the nodes sampled
sampled_edge_index_list: the edges sampled
            metadata: other information
"""
raise NotImplementedError
# def _sample_one_layer_from_nodes(
# self,
# nodes:torch.Tensor,
# **kwargs
# ) -> Tuple[torch.Tensor, torch.Tensor]:
# r"""Performs sampling from the nodes specified in: nodes,
# returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, torch.Tensor].
# Args:
# nodes: the list of seed nodes index
# **kwargs: other kwargs
# Returns:
# sampled_nodes: the nodes sampled
# sampled_edge_index: the edges sampled
# """
# raise NotImplementedError
# def _sample_one_layer_from_nodes_parallel(
# self,
# nodes: torch.Tensor,
# **kwargs
# ) -> Tuple[torch.Tensor, torch.Tensor]:
# r"""Performs sampling paralleled from the nodes specified in: nodes,
# returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, torch.Tensor].
# Args:
# nodes: the list of seed nodes index
# **kwargs: other kwargs
# Returns:
# sampled_nodes: the nodes sampled
# sampled_edge_index: the edges sampled
# """
# raise NotImplementedError
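# Hedged usage sketch of NegativeSampling (illustrative numbers only): draw
# binary negatives uniformly, and triplet negatives from an explicit
# node-level weight vector.
if __name__ == "__main__":
    uni = NegativeSampling('binary')
    print(uni.sample(num_samples=4, num_nodes=10).shape)  # torch.Size([4])
    w = torch.tensor([0.0, 0.0, 1.0, 1.0])
    weighted = NegativeSampling('triplet', amount=1, weight=w)
    print(weighted.sample(num_samples=4).tolist())  # only node ids 2 and 3 appear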
from enum import Enum
import sys
import argparse
from os.path import abspath, join, dirname
import time
sys.path.insert(0, join(abspath(dirname(__file__))))
class SampleType(Enum):
Whole = 0
Inner = 1
Outer =2
parser = argparse.ArgumentParser(
description="RPC Reinforcement Learning Example",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--world_size', default=1, type=int, metavar='W',
help='number of workers')
parser.add_argument('--rank', default=0, type=int, metavar='W',
help='rank of the worker')
parser.add_argument('--log_interval', type=int, default=10, metavar='N',
help='interval between training status logs')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
help='how much to value future rewards')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed for reproducibility')
parser.add_argument('--num_sampler', type=int, default=1, metavar='S',
help='number of samplers')
parser.add_argument('--queue_size', type=int, default=10000, metavar='S',
help='sampler queue size')
#parser = distparser.parser.add_subparsers().add_parser("train")#argparse.ArgumentParser(description='minibatch_gnn_models')
parser.add_argument('--data', type=str, help='dataset name')
parser.add_argument('--config', type=str, help='path to config file')
parser.add_argument('--gpu', type=str, default='0', help='which GPU to use')
parser.add_argument('--model_name', type=str, default='', help='name of stored model')
parser.add_argument('--rand_edge_features', type=int, default=0, help='use random edge features')
parser.add_argument('--rand_node_features', type=int, default=0, help='use random node features')
parser.add_argument('--eval_neg_samples', type=int, default=1, help='how many negative samples to use at inference. Note: this will change the test set metric from AP+AUC to AP+MRR!')
args = parser.parse_args()
rpc_proxy=None
WORKER_RANK = args.rank
NUM_SAMPLER = args.num_sampler
WORLD_SIZE = args.world_size
QUEUE_SIZE = args.queue_size
MAX_QUEUE_SIZE = 5*args.queue_size
RPC_NAME = "rpcserver{}"
SAMPLE_TYPE = SampleType.Outer
def _get_worker_rank():
return WORKER_RANK
def _get_num_sampler():
return NUM_SAMPLER
def _get_world_size():
return WORLD_SIZE
def _get_RPC_NAME():
return RPC_NAME
def _get_queue_size():
return QUEUE_SIZE
def _get_max_queue_size():
return MAX_QUEUE_SIZE
def _get_rpc_name():
return RPC_NAME
\ No newline at end of file
import sys
from os.path import abspath, join, dirname
import time
sys.path.insert(0, join(abspath(dirname(__file__))))
graph_set={}
def _clear_all(barrier = None):
global graph_set
for key in graph_set:
graph = graph_set[key]
graph._close_graph_in_shame()
print('clear ',key)
if(barrier is not None and barrier.wait()==0):
graph._unlink_graph_in_shame()
graph_set = {}
def _set_graph(graph_name,graph_info):
graph_info._get_graph_from_shm()
graph_set[graph_name]=graph_info
def _get_graph(graph_name):
return graph_set[graph_name]
def _del_graph(graph_name):
graph_set.pop(graph_name)
def _get_size():
return len(graph_set)
# local_sampler=None
local_sampler = {}
def set_local_sampler(graph_name,sampler):
local_sampler[graph_name] = sampler
def get_local_sampler(sampler_name):
    assert sampler_name in local_sampler, "local_sampler doesn't have sampler_name"
return local_sampler[sampler_name]
\ No newline at end of file
import torch
import time
from Utils import GraphData
seed = 10  # any integer works as the random seed
torch.manual_seed(seed)
num_nodes1 = 10
fanout1 = [2]
edge_index1 = torch.tensor([[1, 5, 7, 9, 2, 4, 6, 7, 8, 0, 1, 6, 2, 0, 1, 3, 5, 8, 9, 7, 4, 8, 2, 3, 5, 8],
[0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 4, 4, 5, 6, 6, 7, 7, 8, 9]])
edge_ts = torch.tensor([1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 6, 6, 6, 6]).double()
edge_weight1 = torch.tensor([2, 1, 2, 1, 8, 6, 3, 1, 1, 1, 1, 5, 1, 1, 2, 1, 1, 1, 1, 5, 1, 2, 2, 2, 1, 1]).double()
edge_weight1 = None
g_data = GraphData(id=0, edge_index=edge_index1, timestamp=edge_ts, data=None, partptr=torch.tensor([0, num_nodes1]))
from neighbor_sampler import NeighborSampler, SampleType
pre = time.time()
# from neighbor_sampler import get_neighbors, update_edge_weight
# row, col = edge_index1
# tnb = get_neighbors(row.contiguous(), col.contiguous(), num_nodes1, edge_weight1)
# print("tnb.neighbors:", tnb.neighbors)
# print("tnb.deg:", tnb.deg)
# print("tnb.weight:", tnb.edge_weight)
sampler = NeighborSampler(num_nodes1,
num_layers=1,
fanout=fanout1,
edge_weight=edge_weight1,
graph_data=g_data,
workers=2,
graph_name='a',
is_distinct = 0,
policy="recent")
end = time.time()
print("init time:", end-pre)
print("tnb.neighbors:", sampler.tnb.neighbors)
print("tnb.deg:", sampler.tnb.deg)
print("tnb.ts:", sampler.tnb.timestamp)
print("tnb.weight:", sampler.tnb.edge_weight)
# update_edge_row = row
# update_edge_col = col
# update_edge_w = torch.DoubleTensor([i for i in range(edge_weight1.size(0))])
# print('tnb.edge_weight:', tnb.edge_weight)
# print('begin update')
# pre = time.time()
# update_edge_weight(tnb, update_edge_row.contiguous(), update_edge_col.contiguous(), update_edge_w.contiguous())
# end = time.time()
# print("update time:", end-pre)
# print('update_edge_row:', update_edge_row)
# print('update_edge_col:', update_edge_col)
# print('tnb.edge_weight:', tnb.edge_weight)
pre = time.time()
out = sampler.sample_from_nodes(torch.tensor([6,7]),
with_outer_sample=SampleType.Whole,
ts=torch.tensor([9, 9]))
end = time.time()
# print('node:', out.node)
# print('edge_index_list:', out.edge_index_list)
# print('eid_list:', out.eid_list)
# print('eid_ts_list:', out.eid_ts_list)
print("sample time:", end-pre)
print("tot_time", out[0].tot_time)
print("sam_time", out[0].sample_time)
print("sam_edge", out[0].sample_edge_num)
print('eid_list:', out[0].eid)
print('delta_ts_list:', out[0].delta_ts)
print((out[0].sample_nodes<10000).sum())
print('node:', out[0].sample_nodes)
print('node_ts:', out[0].sample_nodes_ts)
\ No newline at end of file
import torch
import time
from Utils import GraphData
seed = 10  # any integer works as the random seed
torch.manual_seed(seed)
num_nodes1 = 10
fanout1 = [2,2] # index 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
edge_index1 = torch.tensor([[1, 5, 7, 9, 2, 4, 6, 7, 8, 0, 1, 6, 2, 0, 1, 3, 5, 8, 9, 7, 4, 8, 2, 3, 5, 8],
[0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 4, 4, 5, 6, 6, 7, 7, 8, 9]])
edge_weight1 = torch.tensor([2, 1, 2, 1, 8, 6, 3, 1, 1, 1, 1, 5, 1, 1, 2, 1, 1, 1, 1, 5, 1, 2, 2, 2, 1, 1]).double()
g_data = GraphData(id=0, edge_index=edge_index1, data=None, partptr=torch.tensor([0, num_nodes1]))
edge_weight1 = None
# g_data.eid=None
from neighbor_sampler import NeighborSampler, SampleType
pre = time.time()
sampler = NeighborSampler(num_nodes1,
num_layers=2,
fanout=fanout1,
edge_weight=edge_weight1,
graph_data=g_data,
workers=2,
graph_name='a',
policy="uniform")
end = time.time()
print("init time:", end-pre)
print("tnb.neighbors:", sampler.tnb.neighbors)
print("tnb.eid:", sampler.tnb.eid)
print("tnb.deg:", sampler.tnb.deg)
print("tnb.weight:", sampler.tnb.edge_weight)
# row,col = edge_index1
# update_edge_row = row
# update_edge_col = col
# update_edge_w = torch.FloatTensor([i for i in range(edge_weight1.size(0))])
# print('tnb.edge_weight:', sampler.tnb.edge_weight)
# print('begin update')
# pre = time.time()
# sampler.tnb.update_edge_weight(sampler.tnb, update_edge_row.contiguous(), update_edge_col.contiguous(), update_edge_w.contiguous())
# end = time.time()
# print("update time:", end-pre)
# print('update_edge_row:', update_edge_row)
# print('update_edge_col:', update_edge_col)
# print('tnb.edge_weight:', sampler.tnb.edge_weight)
pre = time.time()
out = sampler.sample_from_nodes(torch.tensor([1,2]),
with_outer_sample=SampleType.Whole)# sampler.sample_from_nodes(torch.masked_select(torch.arange(g.num_nodes),node_data['train_mask']))
end = time.time()
print('node1:\t', out[0].sample_nodes().tolist())
print('eid1:\t', out[0].eid().tolist())
print('edge1:\t', edge_index1[:, out[0].eid()].tolist())
print('node2:\t', out[1].sample_nodes().tolist())
print('eid2:\t', out[1].eid().tolist())
print('edge2:\t', edge_index1[:, out[1].eid()].tolist())
print("sample time:", end-pre)
\ No newline at end of file
import sys
from os.path import abspath, join, dirname
sys.path.insert(0, join(abspath(dirname(__file__))))
import math
import torch
import torch.multiprocessing as mp
from typing import Optional, Tuple
import graph_store
from distparser import SampleType, NUM_SAMPLER
from base import BaseSampler, NegativeSampling, SampleOutput
from sample_cores import ParallelSampler, get_neighbors, heads_unique
from torch.distributed.rpc import rpc_async
def outer_sample(graph_name, nodes, ts, fanout_index, with_outer_sample = SampleType.Outer):  # by default, keep sampling outward across partitions at this point
local_sampler = graph_store.get_local_sampler(graph_name)
assert local_sampler is not None, 'Local_sampler is None!!!'
out = local_sampler.sample_from_nodes(nodes, with_outer_sample, ts, fanout_index)
return out
class NeighborSampler(BaseSampler):
def __init__(
self,
num_nodes: int,
num_layers: int,
fanout: list,
graph_data,
workers = 1,
tnb = None,
is_distinct = 0,
policy = "uniform",
edge_weight: Optional[torch.Tensor] = None,
graph_name = None
) -> None:
r"""__init__
Args:
num_nodes: the num of all nodes in the graph
num_layers: the num of layers to be sampled
fanout: the list of max neighbors' number chosen for each layer
workers: the number of threads, default value is 1
            tnb: neighbor information table
            is_distinct: 1 - need to distinguish multi-edges, 0 - no need to distinguish multi-edges
policy: "uniform" or "recent" or "weighted"
edge_weight: the initial weights of edges
graph_name: the name of graph
should provide edge_index or (neighbors, deg)
"""
super().__init__()
self.num_layers = num_layers
        # the number of worker threads should not exceed torch's default OMP thread count
self.workers = workers # min(workers, torch.get_num_threads())
self.fanout = fanout
self.num_nodes = num_nodes
self.graph_data=graph_data
self.policy = policy
self.is_distinct = is_distinct
assert graph_name is not None
self.graph_name = graph_name
if(tnb is None):
if(graph_data.edge_ts is not None):
timestamp,ind = graph_data.edge_ts.sort()
timestamp = timestamp.float().contiguous()
eid = graph_data.eid[ind].contiguous()
row, col = graph_data.edge_index[:,ind]
else:
eid = graph_data.eid
timestamp = None
row, col = graph_data.edge_index
if(edge_weight is not None):
edge_weight = edge_weight.float().contiguous()
self.tnb = get_neighbors(graph_name, row.contiguous(), col.contiguous(), num_nodes, is_distinct, eid, edge_weight, timestamp)
else:
assert tnb is not None
self.tnb = tnb
self.p_sampler = ParallelSampler(self.tnb, num_nodes, graph_data.num_edges, workers,
fanout, num_layers, policy)
def _get_sample_info(self):
return self.num_nodes,self.num_layers,self.fanout,self.workers
def _get_sample_options(self):
return {"is_distinct" : self.is_distinct,
"policy" : self.policy,
"with_eid" : self.tnb.with_eid,
"weighted" : self.tnb.weighted,
"with_timestamp" : self.tnb.with_timestamp}
def insert_edges_with_timestamp(
self,
edge_index : torch.Tensor,
eid : torch.Tensor,
timestamp : torch.Tensor,
edge_weight : Optional[torch.Tensor] = None):
row, col = edge_index
        # update the node count and the neighbor table (tnb)
self.num_nodes = self.tnb.update_neighbors_with_time(
row.contiguous(),
col.contiguous(),
timestamp.contiguous(),
eid.contiguous(),
self.is_distinct,
edge_weight.contiguous())
def update_edges_weight(
self,
edge_index : torch.Tensor,
eid : torch.Tensor,
edge_weight : Optional[torch.Tensor] = None):
row, col = edge_index
        # update the edge weights stored in tnb
if self.tnb.with_eid:
self.tnb.update_edge_weight(
eid.contiguous(),
col.contiguous(),
edge_weight.contiguous()
)
else:
self.tnb.update_edge_weight(
row.contiguous(),
col.contiguous(),
edge_weight.contiguous()
)
def update_nodes_weight(
self,
nid : torch.Tensor,
node_weight : Optional[torch.Tensor] = None):
        # update the node weights stored in tnb
self.tnb.update_node_weight(
nid.contiguous(),
node_weight.contiguous()
)
def update_all_node_weight(
self,
node_weight : torch.Tensor):
        # update all node weights stored in tnb
self.tnb.update_all_node_weight(node_weight.contiguous())
def sample_from_nodes(
self,
nodes: torch.Tensor,
ts: Optional[torch.Tensor] = None,
with_outer_sample: SampleType = SampleType.Whole
) -> SampleOutput:
r"""Performs mutilayer sampling from the nodes specified in: nodes
The specific number of layers is determined by parameter: num_layers
returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, list].
Args:
nodes: the list of seed nodes index,
ts: the timestamp of nodes, optional,
            with_outer_sample: 0 - sample in the whole graph structure; 1 - sample one-hop outer nodes; 2 - cross-partition sampling
fanout_index: optional. Specify the index to fanout
Returns:
sampled_nodes: the node sampled
sampled_edge_index_list: the edge sampled
"""
if(ts is None):
self.p_sampler.neighbor_sample_from_nodes(nodes.contiguous(), None)
ret = self.p_sampler.get_ret()
return ret
else:
self.p_sampler.neighbor_sample_from_nodes(nodes.contiguous(), ts.float().contiguous())
ret = self.p_sampler.get_ret()
return ret
def sample_from_edges(
self,
edges: torch.Tensor,
ets: Optional[torch.Tensor] = None,
neg_sampling: Optional[NegativeSampling] = None,
with_outer_sample: SampleType = SampleType.Whole
) -> SampleOutput:
r"""Performs sampling from the edges specified in :obj:`index`,
returning a sampled subgraph in the specified output format.
Args:
edges: the list of seed edges index
            with_outer_sample: 0 - sample in the whole graph structure; 1 - sample one-hop outer nodes; 2 - cross-partition sampling
ets: the timestamp of edges, optional
neg_sampling: The negative sampling configuration
Returns:
sampled_edge_index_list: the edges sampled
sampled_eid_list: the edges' id sampled
sampled_delta_ts_list:the edges' delta time sampled
            metadata: other information
"""
src, dst = edges
num_pos = src.numel()
num_neg = 0
with_timestap = ets is not None
seed_ts = None
if neg_sampling is not None:
num_neg = math.ceil(num_pos * neg_sampling.amount)
if neg_sampling.is_binary():
src_neg = neg_sampling.sample(num_neg, self.num_nodes)
dst_neg = neg_sampling.sample(num_neg, self.num_nodes)
seed = torch.cat([src, dst, src_neg, dst_neg], dim=0)
                if with_timestap:  # timestamp handling
seed_ts = torch.cat([ets, ets, ets, ets], dim=0)
if neg_sampling.is_triplet():
src_neg = neg_sampling.sample(num_neg, self.num_nodes)
seed = torch.cat([src, dst, src_neg], dim=0)
                if with_timestap:  # timestamp handling
seed_ts = torch.cat([ets, ets, ets], dim=0)
else:
seed = torch.cat([src, dst], dim=0)
            if with_timestap:  # timestamp handling
seed_ts = torch.cat([ets, ets], dim=0)
        # deduplicate seeds when unique negative sampling is requested
if neg_sampling is not None and neg_sampling.unique:
            if with_timestap:  # timestamp handling
pair, inverse_seed= torch.unique(torch.stack([seed, seed_ts],0), return_inverse=True, dim=1)
seed, seed_ts = pair
seed = seed.long()
else:
seed, inverse_seed = seed.unique(return_inverse=True)
out = self.sample_from_nodes(seed, seed_ts, with_outer_sample)
if neg_sampling is None or (not neg_sampling.unique):
if with_timestap:
return out, {'seed':seed,'seed_ts':seed_ts,
'src_pos_index':torch.arange(0,num_pos),
'dst_pos_index':torch.arange(num_pos,2*num_pos),
'src_neg_index':torch.arange(2*num_pos,3*num_pos)}
else:
return out, {'seed':seed,
'src_pos_index':slice(0,num_pos),
'dst_pos_index':slice(num_pos,2*num_pos),
'src_neg_index':slice(2*num_pos,3*num_pos)}
metadata = {}
if neg_sampling.is_binary():
src_pos_index = inverse_seed[:num_pos]
dst_pos_index = inverse_seed[num_pos:2 * num_pos]
src_neg_index = inverse_seed[2 * num_pos:3 * num_pos]
src_neg_index = src_neg_index.view(num_pos, -1).squeeze(-1)
dst_neg_index = inverse_seed[3 * num_pos:]
dst_neg_index = dst_neg_index.view(num_pos, -1).squeeze(-1)
metadata = {'seed':seed, 'src_neg_index':src_neg_index, 'dst_pos_index':dst_pos_index, 'dst_neg_index':dst_neg_index}
if with_timestap:
metadata['seed_ts'] = seed_ts
elif neg_sampling.is_triplet():
src_pos_index = inverse_seed[:num_pos]
dst_pos_index = inverse_seed[num_pos:2 * num_pos]
src_neg_index = inverse_seed[2 * num_pos:]
src_neg_index = src_neg_index.view(num_pos, -1).squeeze(-1)
            # src_pos_index: positions of the src nodes within seed
            # dst_pos_index: positions of the dst_pos nodes within seed
            # src_neg_index: positions of the sampled negative nodes within seed
metadata = {'seed':seed, 'src_pos_index':src_pos_index, 'src_neg_index':src_neg_index, 'dst_pos_index':dst_pos_index}
if with_timestap:
metadata['seed_ts'] = seed_ts
        # the front of sampled_nodes holds the sampling seeds, i.e. the deduplicated seed sequence
return out, metadata
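    # Note added for clarity (not part of the original API): for triplet
    # sampling, the *_index entries of the returned metadata select rows of any
    # per-seed output aligned with `seed`, e.g. (hypothetical names):
    #   h = model(out)                      # one row per entry of `seed`
    #   h_src = h[metadata['src_pos_index']]
    #   h_pos = h[metadata['dst_pos_index']]
    #   h_neg = h[metadata['src_neg_index']]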
if __name__=="__main__":
# edge_index1 = torch.tensor([[0, 1, 1, 1, 2, 2, 2, 4, 4, 4, 5], # , 3, 3
# [1, 0, 2, 4, 1, 3, 0, 3, 5, 0, 2]])# , 2, 5
edge_index1 = torch.tensor([[0, 1, 1, 1, 1, 2, 2, 2, 2, 4, 4, 4, 5], # , 3, 3
[1, 0, 2, 0, 4, 1, 3, 0, 3, 3, 5, 0, 2]])# , 2, 5
edge_weight1 = None
timeStamp=torch.FloatTensor([1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4])
num_nodes1 = 6
num_neighbors = 2
# Run the neighbor sampling
from Utils import GraphData
g_data = GraphData(id=0, edge_index=edge_index1, timestamp=timeStamp, data=None, partptr=torch.tensor([0, num_nodes1]))
sampler = NeighborSampler(num_nodes=num_nodes1,
num_layers=3,
fanout=[2, 1, 1],
edge_weight=edge_weight1,
graph_data=g_data,
graph_name='a',
workers=4,
is_distinct = 0)
out = sampler.sample_from_nodes(torch.tensor([1,2]),
ts=torch.tensor([1, 2]),
with_outer_sample=SampleType.Whole)
# out = sampler.sample_from_edges(torch.tensor([[1,2],[4,0]]),
# with_outer_sample=SampleType.Whole,
# ets = torch.tensor([1, 2]))
# Print the result
print('node:', out.node)
print('edge_index_list:', out.edge_index_list)
print('eid_list:', out.eid_list)
print('delta_ts_list:', out.delta_ts_list)
print('metadata: ', out.metadata)
import torch
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric import datasets
import time
from Utils import GraphData
def load_ogb_dataset(name, data_path):
dataset = PygNodePropPredDataset(name=name, root=data_path)
split_idx = dataset.get_idx_split()
g = dataset[0]
n_node = g.num_nodes
node_data={}
node_data['train_mask'] = torch.zeros(n_node, dtype=torch.bool)
node_data['val_mask'] = torch.zeros(n_node, dtype=torch.bool)
node_data['test_mask'] = torch.zeros(n_node, dtype=torch.bool)
node_data['train_mask'][split_idx["train"]] = True
node_data['val_mask'][split_idx["valid"]] = True
node_data['test_mask'][split_idx["test"]] = True
return g, node_data
g, node_data = load_ogb_dataset('ogbn-products', "/home/zlj/hzq/code/gnn/dataset/")
print(g)
# for worker in [1,2,3,4,5,6,7,8,9,10,20,30]:
# import random
# timestamp = [random.randint(1, 5) for i in range(0, g.num_edges)]
# timestamp = torch.FloatTensor(timestamp)
print('begin load')
pre = time.time()
timestamp = torch.load('/home/zlj/hzq/code/gnn/my_sampler/TemporalSample/timestamp.my')
tnb = torch.load("tnb_before.my")
end = time.time()
print("load time:", end-pre)
row, col = g.edge_index
edge_weight=None
g_data = GraphData(id=1, edge_index=g.edge_index, timestamp=timestamp, data=g, partptr=torch.tensor([0, g.num_nodes//4, g.num_nodes//4*2, g.num_nodes//4*3, g.num_nodes]))
from neighbor_sampler import NeighborSampler, SampleType, get_neighbors
# print('begin tnb')
# pre = time.time()
# tnb = get_neighbors(row.contiguous(), col.contiguous(), g.num_nodes, 0, g_data.eid, edge_weight, timestamp)
# end = time.time()
# print("init tnb time:", end-pre)
# torch.save(tnb, "tnb_before.my")
pre = time.time()
sampler = NeighborSampler(g.num_nodes,
tnb=tnb,
num_layers=2,
fanout=[100,100],
graph_data=g_data,
workers=10,
policy="uniform",
is_root_ts=0,
graph_name='a')
end = time.time()
print("init time:", end-pre)
# from torch_geometric.sampler import NeighborSampler, NumNeighbors, NodeSamplerInput, SamplerOutput
# pre = time.time()
# num_nei = NumNeighbors([100, 100])
# node_idx = NodeSamplerInput(input_id=None, node=torch.tensor(range(g.num_nodes//4, g.num_nodes//4+600000)))# (input_id=None, node=torch.masked_select(torch.arange(g.num_nodes),node_data['train_mask']))
# sampler = NeighborSampler(g, num_nei)
# end = time.time()
# print("init time:", end-pre)
ts = torch.tensor([i%5+1 for i in range(0, 600000)])
pre = time.time()
out = sampler.sample_from_nodes(torch.tensor(range(g.num_nodes//4, g.num_nodes//4+600000)),
ts=ts,
with_outer_sample=SampleType.Inner)# sampler.sample_from_nodes(torch.masked_select(torch.arange(g.num_nodes),node_data['train_mask']))
# out = sampler.sample_from_nodes(node_idx)
# node = out.node
# edge = [out.row, out.col]
end = time.time()
print('node:', out.node)
print('edge_index_list:', out.edge_index_list)
print('eid_list:', out.eid_list)
print('eid_ts_list:', out.eid_ts_list)
print("sample time", end-pre)
\ No newline at end of file
import torch
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric import datasets
import time
from Utils import GraphData
def load_ogb_dataset(name, data_path):
dataset = PygNodePropPredDataset(name=name, root=data_path)
split_idx = dataset.get_idx_split()
g = dataset[0]
n_node = g.num_nodes
node_data={}
node_data['train_mask'] = torch.zeros(n_node, dtype=torch.bool)
node_data['val_mask'] = torch.zeros(n_node, dtype=torch.bool)
node_data['test_mask'] = torch.zeros(n_node, dtype=torch.bool)
node_data['train_mask'][split_idx["train"]] = True
node_data['val_mask'][split_idx["valid"]] = True
node_data['test_mask'][split_idx["test"]] = True
return g, node_data
g, node_data = load_ogb_dataset('ogbn-products', "/home/zlj/hzq/code/gnn/dataset/")
print(g)
# for worker in [1,2,3,4,5,6,7,8,9,10,20,30]:
g_data = GraphData(id=1, edge_index=g.edge_index, data=g, partptr=torch.tensor([0, g.num_nodes//4, g.num_nodes//4*2, g.num_nodes//4*3, g.num_nodes]))
row, col = g.edge_index
# edge_weight = torch.ones(g.num_edges).float()
# indices = [x for x in range(0, g.num_edges, 5)]
# edge_weight[indices] = 2.0
# g_data.eid = None
edge_weight = None
timestamp = None
from neighbor_sampler import NeighborSampler, SampleType
from neighbor_sampler import get_neighbors
update_edge_row = row
update_edge_col = col
update_edge_w = torch.DoubleTensor([i for i in range(g.num_edges)])
print('begin update')
pre = time.time()
# update_edge_weight(tnb, update_edge_row.contiguous(), update_edge_col.contiguous(), update_edge_w.contiguous())
end = time.time()
print("update time:", end-pre)
print('begin tnb')
pre = time.time()
tnb = get_neighbors("a",
row.contiguous(),
col.contiguous(),
g.num_nodes, 0,
g_data.eid,
edge_weight,
timestamp)
end = time.time()
print("init tnb time:", end-pre)
torch.save(tnb, "/home/zlj/hzq/code/gnn/my_sampler/MergeSample/tnb_static.my")
# print('begin load')
# pre = time.time()
# tnb = torch.load("/home/zlj/hzq/code/gnn/my_sampler/MergeSample/tnb_static.my")
# end = time.time()
# print("load time:", end-pre)
print('begin init')
pre = time.time()
sampler = NeighborSampler(g.num_nodes,
tnb = tnb,
num_layers=2,
fanout=[100,100],
graph_data=g_data,
workers=10,
graph_name='a',
policy="uniform")
end = time.time()
print("init time:", end-pre)
# from torch_geometric.sampler import NeighborSampler, NumNeighbors, NodeSamplerInput, SamplerOutput
# pre = time.time()
# num_nei = NumNeighbors([100, 100])
# node_idx = NodeSamplerInput(input_id=None, node=torch.tensor(range(g.num_nodes//4, g.num_nodes//4+600000)))# (input_id=None, node=torch.masked_select(torch.arange(g.num_nodes),node_data['train_mask']))
# sampler = NeighborSampler(g, num_nei)
# end = time.time()
# print("init time:", end-pre)
pre = time.time()
out = sampler.sample_from_nodes(torch.tensor(range(g.num_nodes//4, g.num_nodes//4+600000)), with_outer_sample=SampleType.Inner)# sampler.sample_from_nodes(torch.masked_select(torch.arange(g.num_nodes),node_data['train_mask']))
# out = sampler.sample_from_nodes(node_idx)
# node = out.node
# edge = [out.row, out.col]
end = time.time()
print('node1:\t', out[0].sample_nodes())
print('eid1:\t', out[0].eid())
print('edge1:\t', g.edge_index[:, out[0].eid()])
print('node2:\t', out[1].sample_nodes())
print('eid2:\t', out[1].eid())
print('edge2:\t', g.edge_index[:, out[1].eid()])
print("sample time", end-pre)
\ No newline at end of file
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from conans import ConanFile, tools
import os
class SparseppConan(ConanFile):
name = "parallel_hashmap"
version = "1.27"
description = "A header-only, very fast and memory-friendly hash map"
# Indicates License type of the packaged library
license = "https://github.com/greg7mdp/parallel-hashmap/blob/master/LICENSE"
# Packages the license for the conanfile.py
exports = ["LICENSE"]
# Custom attributes for Bincrafters recipe conventions
source_subfolder = "source_subfolder"
def source(self):
source_url = "https://github.com/greg7mdp/parallel-hashmap"
tools.get("{0}/archive/{1}.tar.gz".format(source_url, self.version))
extracted_dir = self.name + "-" + self.version
#Rename to "source_folder" is a convention to simplify later steps
os.rename(extracted_dir, self.source_subfolder)
def package(self):
include_folder = os.path.join(self.source_subfolder, "parallel_hashmap")
self.copy(pattern="LICENSE")
self.copy(pattern="*", dst="include/parallel_hashmap", src=include_folder)
def package_id(self):
self.info.header_only()
#if !defined(spp_memory_h_guard)
#define spp_memory_h_guard
#include <cstdint>
#include <cstring>
#include <cstdlib>
#if defined(_WIN32) || defined( __CYGWIN__)
#define SPP_WIN
#endif
#ifdef SPP_WIN
#include <windows.h>
#include <Psapi.h>
#undef min
#undef max
#elif defined(__linux__)
#include <sys/types.h>
#include <sys/sysinfo.h>
#elif defined(__FreeBSD__)
#include <paths.h>
#include <fcntl.h>
#include <kvm.h>
#include <unistd.h>
#include <sys/sysctl.h>
#include <sys/user.h>
#endif
namespace spp
{
uint64_t GetSystemMemory();
uint64_t GetTotalMemoryUsed();
uint64_t GetProcessMemoryUsed();
uint64_t GetPhysicalMemory();
uint64_t GetSystemMemory()
{
#ifdef SPP_WIN
MEMORYSTATUSEX memInfo;
memInfo.dwLength = sizeof(MEMORYSTATUSEX);
GlobalMemoryStatusEx(&memInfo);
return static_cast<uint64_t>(memInfo.ullTotalPageFile);
#elif defined(__linux__)
struct sysinfo memInfo;
sysinfo (&memInfo);
auto totalVirtualMem = memInfo.totalram;
totalVirtualMem += memInfo.totalswap;
totalVirtualMem *= memInfo.mem_unit;
return static_cast<uint64_t>(totalVirtualMem);
#elif defined(__FreeBSD__)
kvm_t *kd;
u_int pageCnt;
size_t pageCntLen = sizeof(pageCnt);
u_int pageSize;
struct kvm_swap kswap;
uint64_t totalVirtualMem;
pageSize = static_cast<u_int>(getpagesize());
sysctlbyname("vm.stats.vm.v_page_count", &pageCnt, &pageCntLen, NULL, 0);
totalVirtualMem = pageCnt * pageSize;
kd = kvm_open(NULL, _PATH_DEVNULL, NULL, O_RDONLY, "kvm_open");
kvm_getswapinfo(kd, &kswap, 1, 0);
kvm_close(kd);
totalVirtualMem += kswap.ksw_total * pageSize;
return totalVirtualMem;
#else
return 0;
#endif
}
uint64_t GetTotalMemoryUsed()
{
#ifdef SPP_WIN
MEMORYSTATUSEX memInfo;
memInfo.dwLength = sizeof(MEMORYSTATUSEX);
GlobalMemoryStatusEx(&memInfo);
return static_cast<uint64_t>(memInfo.ullTotalPageFile - memInfo.ullAvailPageFile);
#elif defined(__linux__)
struct sysinfo memInfo;
sysinfo(&memInfo);
auto virtualMemUsed = memInfo.totalram - memInfo.freeram;
virtualMemUsed += memInfo.totalswap - memInfo.freeswap;
virtualMemUsed *= memInfo.mem_unit;
return static_cast<uint64_t>(virtualMemUsed);
#elif defined(__FreeBSD__)
kvm_t *kd;
u_int pageSize;
u_int pageCnt, freeCnt;
size_t pageCntLen = sizeof(pageCnt);
size_t freeCntLen = sizeof(freeCnt);
struct kvm_swap kswap;
uint64_t virtualMemUsed;
pageSize = static_cast<u_int>(getpagesize());
sysctlbyname("vm.stats.vm.v_page_count", &pageCnt, &pageCntLen, NULL, 0);
sysctlbyname("vm.stats.vm.v_free_count", &freeCnt, &freeCntLen, NULL, 0);
virtualMemUsed = (pageCnt - freeCnt) * pageSize;
kd = kvm_open(NULL, _PATH_DEVNULL, NULL, O_RDONLY, "kvm_open");
kvm_getswapinfo(kd, &kswap, 1, 0);
kvm_close(kd);
virtualMemUsed += kswap.ksw_used * pageSize;
return virtualMemUsed;
#else
return 0;
#endif
}
uint64_t GetProcessMemoryUsed()
{
#ifdef SPP_WIN
PROCESS_MEMORY_COUNTERS_EX pmc;
GetProcessMemoryInfo(GetCurrentProcess(), reinterpret_cast<PPROCESS_MEMORY_COUNTERS>(&pmc), sizeof(pmc));
return static_cast<uint64_t>(pmc.PrivateUsage);
#elif defined(__linux__)
auto parseLine =
[](char* line)->int
{
auto i = strlen(line);
while(*line < '0' || *line > '9')
{
line++;
}
line[i-3] = '\0';
i = atoi(line);
return i;
};
auto file = fopen("/proc/self/status", "r");
auto result = -1;
char line[128];
while(fgets(line, 128, file) != nullptr)
{
if(strncmp(line, "VmSize:", 7) == 0)
{
result = parseLine(line);
break;
}
}
fclose(file);
return static_cast<uint64_t>(result) * 1024;
#elif defined(__FreeBSD__)
struct kinfo_proc info;
size_t infoLen = sizeof(info);
int mib[] = { CTL_KERN, KERN_PROC, KERN_PROC_PID, getpid() };
sysctl(mib, sizeof(mib) / sizeof(*mib), &info, &infoLen, NULL, 0);
return static_cast<uint64_t>(info.ki_rssize * getpagesize());
#else
return 0;
#endif
}
uint64_t GetPhysicalMemory()
{
#ifdef SPP_WIN
MEMORYSTATUSEX memInfo;
memInfo.dwLength = sizeof(MEMORYSTATUSEX);
GlobalMemoryStatusEx(&memInfo);
return static_cast<uint64_t>(memInfo.ullTotalPhys);
#elif defined(__linux__)
struct sysinfo memInfo;
sysinfo(&memInfo);
auto totalPhysMem = memInfo.totalram;
totalPhysMem *= memInfo.mem_unit;
return static_cast<uint64_t>(totalPhysMem);
#elif defined(__FreeBSD__)
u_long physMem;
size_t physMemLen = sizeof(physMem);
int mib[] = { CTL_HW, HW_PHYSMEM };
sysctl(mib, sizeof(mib) / sizeof(*mib), &physMem, &physMemLen, NULL, 0);
return physMem;
#else
return 0;
#endif
}
}
#endif // spp_memory_h_guard
#if !defined(phmap_bits_h_guard_)
#define phmap_bits_h_guard_
// ---------------------------------------------------------------------------
// Copyright (c) 2019, Gregory Popovitch - greg7mdp@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Includes work from abseil-cpp (https://github.com/abseil/abseil-cpp)
// with modifications.
//
// Copyright 2018 The Abseil Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ---------------------------------------------------------------------------
// The following guarantees declaration of the byte swap functions
#ifdef _MSC_VER
#include <stdlib.h> // NOLINT(build/include)
#elif defined(__APPLE__)
// Mac OS X / Darwin features
#include <libkern/OSByteOrder.h>
#elif defined(__FreeBSD__)
#include <sys/endian.h>
#elif defined(__GLIBC__)
#include <byteswap.h> // IWYU pragma: export
#endif
#include <string.h>
#include <cstdint>
#include "phmap_config.h"
#ifdef _MSC_VER
#pragma warning(push)
#pragma warning(disable : 4514) // unreferenced inline function has been removed
#endif
// -----------------------------------------------------------------------------
// unaligned APIs
// -----------------------------------------------------------------------------
// Portable handling of unaligned loads, stores, and copies.
// On some platforms, like ARM, the copy functions can be more efficient
// then a load and a store.
// -----------------------------------------------------------------------------
#if defined(ADDRESS_SANITIZER) || defined(THREAD_SANITIZER) ||\
defined(MEMORY_SANITIZER)
#include <stdint.h>
extern "C" {
uint16_t __sanitizer_unaligned_load16(const void *p);
uint32_t __sanitizer_unaligned_load32(const void *p);
uint64_t __sanitizer_unaligned_load64(const void *p);
void __sanitizer_unaligned_store16(void *p, uint16_t v);
void __sanitizer_unaligned_store32(void *p, uint32_t v);
void __sanitizer_unaligned_store64(void *p, uint64_t v);
} // extern "C"
namespace phmap {
namespace bits {
inline uint16_t UnalignedLoad16(const void *p) {
return __sanitizer_unaligned_load16(p);
}
inline uint32_t UnalignedLoad32(const void *p) {
return __sanitizer_unaligned_load32(p);
}
inline uint64_t UnalignedLoad64(const void *p) {
return __sanitizer_unaligned_load64(p);
}
inline void UnalignedStore16(void *p, uint16_t v) {
__sanitizer_unaligned_store16(p, v);
}
inline void UnalignedStore32(void *p, uint32_t v) {
__sanitizer_unaligned_store32(p, v);
}
inline void UnalignedStore64(void *p, uint64_t v) {
__sanitizer_unaligned_store64(p, v);
}
} // namespace bits
} // namespace phmap
#define PHMAP_INTERNAL_UNALIGNED_LOAD16(_p) (phmap::bits::UnalignedLoad16(_p))
#define PHMAP_INTERNAL_UNALIGNED_LOAD32(_p) (phmap::bits::UnalignedLoad32(_p))
#define PHMAP_INTERNAL_UNALIGNED_LOAD64(_p) (phmap::bits::UnalignedLoad64(_p))
#define PHMAP_INTERNAL_UNALIGNED_STORE16(_p, _val) (phmap::bits::UnalignedStore16(_p, _val))
#define PHMAP_INTERNAL_UNALIGNED_STORE32(_p, _val) (phmap::bits::UnalignedStore32(_p, _val))
#define PHMAP_INTERNAL_UNALIGNED_STORE64(_p, _val) (phmap::bits::UnalignedStore64(_p, _val))
#else
namespace phmap {
namespace bits {
inline uint16_t UnalignedLoad16(const void *p) {
uint16_t t;
memcpy(&t, p, sizeof t);
return t;
}
inline uint32_t UnalignedLoad32(const void *p) {
uint32_t t;
memcpy(&t, p, sizeof t);
return t;
}
inline uint64_t UnalignedLoad64(const void *p) {
uint64_t t;
memcpy(&t, p, sizeof t);
return t;
}
inline void UnalignedStore16(void *p, uint16_t v) { memcpy(p, &v, sizeof v); }
inline void UnalignedStore32(void *p, uint32_t v) { memcpy(p, &v, sizeof v); }
inline void UnalignedStore64(void *p, uint64_t v) { memcpy(p, &v, sizeof v); }
} // namespace bits
} // namespace phmap
#define PHMAP_INTERNAL_UNALIGNED_LOAD16(_p) (phmap::bits::UnalignedLoad16(_p))
#define PHMAP_INTERNAL_UNALIGNED_LOAD32(_p) (phmap::bits::UnalignedLoad32(_p))
#define PHMAP_INTERNAL_UNALIGNED_LOAD64(_p) (phmap::bits::UnalignedLoad64(_p))
#define PHMAP_INTERNAL_UNALIGNED_STORE16(_p, _val) (phmap::bits::UnalignedStore16(_p, _val))
#define PHMAP_INTERNAL_UNALIGNED_STORE32(_p, _val) (phmap::bits::UnalignedStore32(_p, _val))
#define PHMAP_INTERNAL_UNALIGNED_STORE64(_p, _val) (phmap::bits::UnalignedStore64(_p, _val))
#endif
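// Example (illustrative sketch): the memcpy-based helpers above are the
// portable way to read a value out of a possibly misaligned buffer without
// the undefined behavior of a raw pointer cast, e.g. decoding a 4-byte
// length prefix from a byte stream:
//
//     const unsigned char buf[] = {0x2A, 0x00, 0x00, 0x00};
//     uint32_t len = phmap::bits::UnalignedLoad32(buf);  // bytes in host order
//
// (combine with the little_endian helpers further down for a fixed wire format)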
// -----------------------------------------------------------------------------
// File: optimization.h
// -----------------------------------------------------------------------------
#if defined(__pnacl__)
#define PHMAP_BLOCK_TAIL_CALL_OPTIMIZATION() if (volatile int x = 0) { (void)x; }
#elif defined(__clang__)
// Clang will not tail call given inline volatile assembly.
#define PHMAP_BLOCK_TAIL_CALL_OPTIMIZATION() __asm__ __volatile__("")
#elif defined(__GNUC__)
// GCC will not tail call given inline volatile assembly.
#define PHMAP_BLOCK_TAIL_CALL_OPTIMIZATION() __asm__ __volatile__("")
#elif defined(_MSC_VER)
#include <intrin.h>
// The __nop() intrinsic blocks the optimisation.
#define PHMAP_BLOCK_TAIL_CALL_OPTIMIZATION() __nop()
#else
#define PHMAP_BLOCK_TAIL_CALL_OPTIMIZATION() if (volatile int x = 0) { (void)x; }
#endif
#if defined(__GNUC__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wpedantic"
#endif
#ifdef PHMAP_HAVE_INTRINSIC_INT128
__extension__ typedef unsigned __int128 phmap_uint128;
inline uint64_t umul128(uint64_t a, uint64_t b, uint64_t* high)
{
auto result = static_cast<phmap_uint128>(a) * static_cast<phmap_uint128>(b);
*high = static_cast<uint64_t>(result >> 64);
return static_cast<uint64_t>(result);
}
#define PHMAP_HAS_UMUL128 1
#elif (defined(_MSC_VER))
#if defined(_M_X64)
#pragma intrinsic(_umul128)
inline uint64_t umul128(uint64_t a, uint64_t b, uint64_t* high)
{
return _umul128(a, b, high);
}
#define PHMAP_HAS_UMUL128 1
#endif
#endif
#if defined(__GNUC__)
#pragma GCC diagnostic pop
#endif
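// Example (illustrative sketch, only valid when PHMAP_HAS_UMUL128 is defined):
// umul128 returns the low 64 bits of the full 128-bit product and writes the
// high 64 bits through the out-parameter.
//
//     uint64_t hi = 0;
//     uint64_t lo = umul128(0xFFFFFFFFFFFFFFFFULL, 2, &hi);
//     // hi == 1, lo == 0xFFFFFFFFFFFFFFFEULL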
#if defined(__GNUC__)
// Cache line alignment
#if defined(__i386__) || defined(__x86_64__)
#define PHMAP_CACHELINE_SIZE 64
#elif defined(__powerpc64__)
#define PHMAP_CACHELINE_SIZE 128
#elif defined(__aarch64__)
// We would need to read special register ctr_el0 to find out L1 dcache size.
// This value is a good estimate based on a real aarch64 machine.
#define PHMAP_CACHELINE_SIZE 64
#elif defined(__arm__)
// Cache line sizes for ARM: These values are not strictly correct since
// cache line sizes depend on implementations, not architectures. There
// are even implementations with cache line sizes configurable at boot
// time.
#if defined(__ARM_ARCH_5T__)
#define PHMAP_CACHELINE_SIZE 32
#elif defined(__ARM_ARCH_7A__)
#define PHMAP_CACHELINE_SIZE 64
#endif
#endif
#ifndef PHMAP_CACHELINE_SIZE
// A reasonable default guess. Note that overestimates tend to waste more
// space, while underestimates tend to waste more time.
#define PHMAP_CACHELINE_SIZE 64
#endif
#define PHMAP_CACHELINE_ALIGNED __attribute__((aligned(PHMAP_CACHELINE_SIZE)))
#elif defined(_MSC_VER)
#define PHMAP_CACHELINE_SIZE 64
#define PHMAP_CACHELINE_ALIGNED __declspec(align(PHMAP_CACHELINE_SIZE))
#else
#define PHMAP_CACHELINE_SIZE 64
#define PHMAP_CACHELINE_ALIGNED
#endif
#if PHMAP_HAVE_BUILTIN(__builtin_expect) || \
(defined(__GNUC__) && !defined(__clang__))
#define PHMAP_PREDICT_FALSE(x) (__builtin_expect(x, 0))
#define PHMAP_PREDICT_TRUE(x) (__builtin_expect(!!(x), 1))
#else
#define PHMAP_PREDICT_FALSE(x) (x)
#define PHMAP_PREDICT_TRUE(x) (x)
#endif
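// Example (illustrative sketch): PHMAP_PREDICT_TRUE/FALSE only hint at branch
// likelihood (via __builtin_expect where available) and never change
// semantics, so they are safe to wrap around rare error paths:
//
//     if (PHMAP_PREDICT_FALSE(p == nullptr)) {
//         return false;  // unlikely error path
//     }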
// -----------------------------------------------------------------------------
// File: bits.h
// -----------------------------------------------------------------------------
#if defined(_MSC_VER)
// We can achieve something similar to attribute((always_inline)) with MSVC by
// using the __forceinline keyword; however, this is not perfect. MSVC is
// much less aggressive about inlining, even with the __forceinline keyword.
#define PHMAP_BASE_INTERNAL_FORCEINLINE __forceinline
#else
// Use default attribute inline.
#define PHMAP_BASE_INTERNAL_FORCEINLINE inline PHMAP_ATTRIBUTE_ALWAYS_INLINE
#endif
namespace phmap {
namespace base_internal {
PHMAP_BASE_INTERNAL_FORCEINLINE int CountLeadingZeros64Slow(uint64_t n) {
int zeroes = 60;
if (n >> 32) zeroes -= 32, n >>= 32;
if (n >> 16) zeroes -= 16, n >>= 16;
if (n >> 8) zeroes -= 8, n >>= 8;
if (n >> 4) zeroes -= 4, n >>= 4;
return "\4\3\2\2\1\1\1\1\0\0\0\0\0\0\0"[n] + zeroes;
}
PHMAP_BASE_INTERNAL_FORCEINLINE int CountLeadingZeros64(uint64_t n) {
#if defined(_MSC_VER) && defined(_M_X64)
// MSVC does not have __builtin_clzll. Use _BitScanReverse64.
unsigned long result = 0; // NOLINT(runtime/int)
if (_BitScanReverse64(&result, n)) {
return (int)(63 - result);
}
return 64;
#elif defined(_MSC_VER) && !defined(__clang__)
// MSVC does not have __builtin_clzll. Compose two calls to _BitScanReverse
unsigned long result = 0; // NOLINT(runtime/int)
if ((n >> 32) && _BitScanReverse(&result, (unsigned long)(n >> 32))) {
return 31 - result;
}
if (_BitScanReverse(&result, (unsigned long)n)) {
return 63 - result;
}
return 64;
#elif defined(__GNUC__) || defined(__clang__)
// Use __builtin_clzll, which uses the following instructions:
// x86: bsr
// ARM64: clz
// PPC: cntlzd
static_assert(sizeof(unsigned long long) == sizeof(n), // NOLINT(runtime/int)
"__builtin_clzll does not take 64-bit arg");
// Handle 0 as a special case because __builtin_clzll(0) is undefined.
if (n == 0) {
return 64;
}
return __builtin_clzll(n);
#else
return CountLeadingZeros64Slow(n);
#endif
}
PHMAP_BASE_INTERNAL_FORCEINLINE int CountLeadingZeros32Slow(uint64_t n) {
int zeroes = 28;
if (n >> 16) zeroes -= 16, n >>= 16;
if (n >> 8) zeroes -= 8, n >>= 8;
if (n >> 4) zeroes -= 4, n >>= 4;
return "\4\3\2\2\1\1\1\1\0\0\0\0\0\0\0"[n] + zeroes;
}
PHMAP_BASE_INTERNAL_FORCEINLINE int CountLeadingZeros32(uint32_t n) {
#if defined(_MSC_VER) && !defined(__clang__)
unsigned long result = 0; // NOLINT(runtime/int)
if (_BitScanReverse(&result, n)) {
return (int)(31 - result);
}
return 32;
#elif defined(__GNUC__) || defined(__clang__)
// Use __builtin_clz, which uses the following instructions:
// x86: bsr
// ARM64: clz
// PPC: cntlzd
static_assert(sizeof(int) == sizeof(n),
"__builtin_clz does not take 32-bit arg");
// Handle 0 as a special case because __builtin_clz(0) is undefined.
if (n == 0) {
return 32;
}
return __builtin_clz(n);
#else
return CountLeadingZeros32Slow(n);
#endif
}
PHMAP_BASE_INTERNAL_FORCEINLINE int CountTrailingZerosNonZero64Slow(uint64_t n) {
int c = 63;
n &= ~n + 1;
if (n & 0x00000000FFFFFFFF) c -= 32;
if (n & 0x0000FFFF0000FFFF) c -= 16;
if (n & 0x00FF00FF00FF00FF) c -= 8;
if (n & 0x0F0F0F0F0F0F0F0F) c -= 4;
if (n & 0x3333333333333333) c -= 2;
if (n & 0x5555555555555555) c -= 1;
return c;
}
PHMAP_BASE_INTERNAL_FORCEINLINE int CountTrailingZerosNonZero64(uint64_t n) {
#if defined(_MSC_VER) && !defined(__clang__) && defined(_M_X64)
unsigned long result = 0; // NOLINT(runtime/int)
_BitScanForward64(&result, n);
return (int)result;
#elif defined(_MSC_VER) && !defined(__clang__)
unsigned long result = 0; // NOLINT(runtime/int)
if (static_cast<uint32_t>(n) == 0) {
_BitScanForward(&result, (unsigned long)(n >> 32));
return result + 32;
}
_BitScanForward(&result, (unsigned long)n);
return result;
#elif defined(__GNUC__) || defined(__clang__)
static_assert(sizeof(unsigned long long) == sizeof(n), // NOLINT(runtime/int)
"__builtin_ctzll does not take 64-bit arg");
return __builtin_ctzll(n);
#else
return CountTrailingZerosNonZero64Slow(n);
#endif
}
PHMAP_BASE_INTERNAL_FORCEINLINE int CountTrailingZerosNonZero32Slow(uint32_t n) {
int c = 31;
n &= ~n + 1;
if (n & 0x0000FFFF) c -= 16;
if (n & 0x00FF00FF) c -= 8;
if (n & 0x0F0F0F0F) c -= 4;
if (n & 0x33333333) c -= 2;
if (n & 0x55555555) c -= 1;
return c;
}
PHMAP_BASE_INTERNAL_FORCEINLINE int CountTrailingZerosNonZero32(uint32_t n) {
#if defined(_MSC_VER) && !defined(__clang__)
unsigned long result = 0; // NOLINT(runtime/int)
_BitScanForward(&result, n);
return (int)result;
#elif defined(__GNUC__) || defined(__clang__)
static_assert(sizeof(int) == sizeof(n),
"__builtin_ctz does not take 32-bit arg");
return __builtin_ctz(n);
#else
return CountTrailingZerosNonZero32Slow(n);
#endif
}
#undef PHMAP_BASE_INTERNAL_FORCEINLINE
} // namespace base_internal
} // namespace phmap
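// Example (illustrative sketch; BitWidth64 is a hypothetical helper, not part
// of this header): because CountLeadingZeros64(0) == 64, the function composes
// directly into a bit-width computation.
//
//     inline int BitWidth64(uint64_t v) {
//         return 64 - phmap::base_internal::CountLeadingZeros64(v);
//     }
//     // BitWidth64(0) == 0, BitWidth64(1) == 1, BitWidth64(255) == 8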
// -----------------------------------------------------------------------------
// File: endian.h
// -----------------------------------------------------------------------------
namespace phmap {
// Use compiler byte-swapping intrinsics if they are available. 32-bit
// and 64-bit versions are available in Clang and GCC as of GCC 4.3.0.
// The 16-bit version is available in Clang and GCC only as of GCC 4.8.0.
// For simplicity, we enable them all only for GCC 4.8.0 or later.
#if defined(__clang__) || \
(defined(__GNUC__) && \
((__GNUC__ == 4 && __GNUC_MINOR__ >= 8) || __GNUC__ >= 5))
inline uint64_t gbswap_64(uint64_t host_int) {
return __builtin_bswap64(host_int);
}
inline uint32_t gbswap_32(uint32_t host_int) {
return __builtin_bswap32(host_int);
}
inline uint16_t gbswap_16(uint16_t host_int) {
return __builtin_bswap16(host_int);
}
#elif defined(_MSC_VER)
inline uint64_t gbswap_64(uint64_t host_int) {
return _byteswap_uint64(host_int);
}
inline uint32_t gbswap_32(uint32_t host_int) {
return _byteswap_ulong(host_int);
}
inline uint16_t gbswap_16(uint16_t host_int) {
return _byteswap_ushort(host_int);
}
#elif defined(__APPLE__)
inline uint64_t gbswap_64(uint64_t host_int) { return OSSwapInt64(host_int); }
inline uint32_t gbswap_32(uint32_t host_int) { return OSSwapInt32(host_int); }
inline uint16_t gbswap_16(uint16_t host_int) { return OSSwapInt16(host_int); }
#else
inline uint64_t gbswap_64(uint64_t host_int) {
#if defined(__GNUC__) && defined(__x86_64__) && !defined(__APPLE__)
// Adapted from /usr/include/byteswap.h. Not available on Mac.
if (__builtin_constant_p(host_int)) {
return __bswap_constant_64(host_int);
} else {
uint64_t result;
__asm__("bswap %0" : "=r"(result) : "0"(host_int));
return result;
}
#elif defined(__GLIBC__)
return bswap_64(host_int);
#else
return (((host_int & uint64_t{0xFF}) << 56) |
((host_int & uint64_t{0xFF00}) << 40) |
((host_int & uint64_t{0xFF0000}) << 24) |
((host_int & uint64_t{0xFF000000}) << 8) |
((host_int & uint64_t{0xFF00000000}) >> 8) |
((host_int & uint64_t{0xFF0000000000}) >> 24) |
((host_int & uint64_t{0xFF000000000000}) >> 40) |
((host_int & uint64_t{0xFF00000000000000}) >> 56));
#endif // bswap_64
}
inline uint32_t gbswap_32(uint32_t host_int) {
#if defined(__GLIBC__)
return bswap_32(host_int);
#else
return (((host_int & uint32_t{0xFF}) << 24) |
((host_int & uint32_t{0xFF00}) << 8) |
((host_int & uint32_t{0xFF0000}) >> 8) |
((host_int & uint32_t{0xFF000000}) >> 24));
#endif
}
inline uint16_t gbswap_16(uint16_t host_int) {
#if defined(__GLIBC__)
return bswap_16(host_int);
#else
return (((host_int & uint16_t{0xFF}) << 8) |
((host_int & uint16_t{0xFF00}) >> 8));
#endif
}
#endif // intrinsics available
#ifdef PHMAP_IS_LITTLE_ENDIAN
// Definitions for ntohl etc. that don't require us to include
// netinet/in.h. We wrap gbswap_32 and gbswap_16 in functions rather
// than just #defining them because in debug mode, gcc doesn't
// correctly handle the (rather involved) definitions of bswap_32.
// gcc guarantees that inline functions are as fast as macros, so
// this isn't a performance hit.
inline uint16_t ghtons(uint16_t x) { return gbswap_16(x); }
inline uint32_t ghtonl(uint32_t x) { return gbswap_32(x); }
inline uint64_t ghtonll(uint64_t x) { return gbswap_64(x); }
#elif defined PHMAP_IS_BIG_ENDIAN
// These definitions are simpler on big-endian machines
// These are functions instead of macros to avoid self-assignment warnings
// on calls such as "i = ghtonl(i);". This also provides type checking.
inline uint16_t ghtons(uint16_t x) { return x; }
inline uint32_t ghtonl(uint32_t x) { return x; }
inline uint64_t ghtonll(uint64_t x) { return x; }
#else
#error \
"Unsupported byte order: Either PHMAP_IS_BIG_ENDIAN or " \
"PHMAP_IS_LITTLE_ENDIAN must be defined"
#endif // byte order
inline uint16_t gntohs(uint16_t x) { return ghtons(x); }
inline uint32_t gntohl(uint32_t x) { return ghtonl(x); }
inline uint64_t gntohll(uint64_t x) { return ghtonll(x); }
// Utilities to convert numbers between the current host's native byte
// order and little-endian byte order
//
// Load/Store methods are alignment safe
namespace little_endian {
// Conversion functions.
#ifdef PHMAP_IS_LITTLE_ENDIAN
inline uint16_t FromHost16(uint16_t x) { return x; }
inline uint16_t ToHost16(uint16_t x) { return x; }
inline uint32_t FromHost32(uint32_t x) { return x; }
inline uint32_t ToHost32(uint32_t x) { return x; }
inline uint64_t FromHost64(uint64_t x) { return x; }
inline uint64_t ToHost64(uint64_t x) { return x; }
inline constexpr bool IsLittleEndian() { return true; }
#elif defined PHMAP_IS_BIG_ENDIAN
inline uint16_t FromHost16(uint16_t x) { return gbswap_16(x); }
inline uint16_t ToHost16(uint16_t x) { return gbswap_16(x); }
inline uint32_t FromHost32(uint32_t x) { return gbswap_32(x); }
inline uint32_t ToHost32(uint32_t x) { return gbswap_32(x); }
inline uint64_t FromHost64(uint64_t x) { return gbswap_64(x); }
inline uint64_t ToHost64(uint64_t x) { return gbswap_64(x); }
inline constexpr bool IsLittleEndian() { return false; }
#endif /* ENDIAN */
// Functions to do unaligned loads and stores in little-endian order.
inline uint16_t Load16(const void *p) {
return ToHost16(PHMAP_INTERNAL_UNALIGNED_LOAD16(p));
}
inline void Store16(void *p, uint16_t v) {
PHMAP_INTERNAL_UNALIGNED_STORE16(p, FromHost16(v));
}
inline uint32_t Load32(const void *p) {
return ToHost32(PHMAP_INTERNAL_UNALIGNED_LOAD32(p));
}
inline void Store32(void *p, uint32_t v) {
PHMAP_INTERNAL_UNALIGNED_STORE32(p, FromHost32(v));
}
inline uint64_t Load64(const void *p) {
return ToHost64(PHMAP_INTERNAL_UNALIGNED_LOAD64(p));
}
inline void Store64(void *p, uint64_t v) {
PHMAP_INTERNAL_UNALIGNED_STORE64(p, FromHost64(v));
}
} // namespace little_endian
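// Example (illustrative sketch): Store followed by Load round-trips a value
// through a fixed little-endian wire format on any supported host.
//
//     unsigned char buf[4];
//     phmap::little_endian::Store32(buf, 0x11223344u);
//     // buf holds {0x44, 0x33, 0x22, 0x11} regardless of host byte order
//     uint32_t v = phmap::little_endian::Load32(buf);  // v == 0x11223344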
// Utilities to convert numbers between the current host's native byte
// order and big-endian byte order (same as network byte order)
//
// Load/Store methods are alignment safe
namespace big_endian {
#ifdef PHMAP_IS_LITTLE_ENDIAN
inline uint16_t FromHost16(uint16_t x) { return gbswap_16(x); }
inline uint16_t ToHost16(uint16_t x) { return gbswap_16(x); }
inline uint32_t FromHost32(uint32_t x) { return gbswap_32(x); }
inline uint32_t ToHost32(uint32_t x) { return gbswap_32(x); }
inline uint64_t FromHost64(uint64_t x) { return gbswap_64(x); }
inline uint64_t ToHost64(uint64_t x) { return gbswap_64(x); }
inline constexpr bool IsLittleEndian() { return true; }
#elif defined PHMAP_IS_BIG_ENDIAN
inline uint16_t FromHost16(uint16_t x) { return x; }
inline uint16_t ToHost16(uint16_t x) { return x; }
inline uint32_t FromHost32(uint32_t x) { return x; }
inline uint32_t ToHost32(uint32_t x) { return x; }
inline uint64_t FromHost64(uint64_t x) { return x; }
inline uint64_t ToHost64(uint64_t x) { return x; }
inline constexpr bool IsLittleEndian() { return false; }
#endif /* ENDIAN */
// Functions to do unaligned loads and stores in big-endian order.
inline uint16_t Load16(const void *p) {
return ToHost16(PHMAP_INTERNAL_UNALIGNED_LOAD16(p));
}
inline void Store16(void *p, uint16_t v) {
PHMAP_INTERNAL_UNALIGNED_STORE16(p, FromHost16(v));
}
inline uint32_t Load32(const void *p) {
return ToHost32(PHMAP_INTERNAL_UNALIGNED_LOAD32(p));
}
inline void Store32(void *p, uint32_t v) {
PHMAP_INTERNAL_UNALIGNED_STORE32(p, FromHost32(v));
}
inline uint64_t Load64(const void *p) {
return ToHost64(PHMAP_INTERNAL_UNALIGNED_LOAD64(p));
}
inline void Store64(void *p, uint64_t v) {
PHMAP_INTERNAL_UNALIGNED_STORE64(p, FromHost64(v));
}
} // namespace big_endian
} // namespace phmap
#ifdef _MSC_VER
#pragma warning(pop)
#endif
#endif // phmap_bits_h_guard_
#if !defined(phmap_config_h_guard_)
#define phmap_config_h_guard_
// ---------------------------------------------------------------------------
// Copyright (c) 2019, Gregory Popovitch - greg7mdp@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Includes work from abseil-cpp (https://github.com/abseil/abseil-cpp)
// with modifications.
//
// Copyright 2018 The Abseil Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ---------------------------------------------------------------------------
#define PHMAP_VERSION_MAJOR 1
#define PHMAP_VERSION_MINOR 0
#define PHMAP_VERSION_PATCH 0
// Included for the __GLIBC__ macro (or similar macros on other systems).
#include <limits.h>
#ifdef __cplusplus
// Included for __GLIBCXX__, _LIBCPP_VERSION
#include <cstddef>
#endif // __cplusplus
#if defined(__APPLE__)
// Included for TARGET_OS_IPHONE, __IPHONE_OS_VERSION_MIN_REQUIRED,
// __IPHONE_8_0.
#include <Availability.h>
#include <TargetConditionals.h>
#endif
#define PHMAP_XSTR(x) PHMAP_STR(x)
#define PHMAP_STR(x) #x
#define PHMAP_VAR_NAME_VALUE(var) #var "=" PHMAP_STR(var)
// -----------------------------------------------------------------------------
// Some sanity checks
// -----------------------------------------------------------------------------
//#if defined(__CYGWIN__)
// #error "Cygwin is not supported."
//#endif
#if defined(_MSC_FULL_VER) && _MSC_FULL_VER < 190023918 && !defined(__clang__)
#error "phmap requires Visual Studio 2015 Update 2 or higher."
#endif
// We support gcc 4.7 and later.
#if defined(__GNUC__) && !defined(__clang__)
#if __GNUC__ < 4 || (__GNUC__ == 4 && __GNUC_MINOR__ < 7)
#error "phmap requires gcc 4.7 or higher."
#endif
#endif
// We support Apple Xcode clang 4.2.1 (version 421.11.65) and later.
// This corresponds to Apple Xcode version 4.5.
#if defined(__apple_build_version__) && __apple_build_version__ < 4211165
#error "phmap requires __apple_build_version__ of 4211165 or higher."
#endif
// Enforce C++11 as the minimum.
#if defined(__cplusplus) && !defined(_MSC_VER)
#if __cplusplus < 201103L
#error "C++ versions less than C++11 are not supported."
#endif
#endif
// We have chosen glibc 2.12 as the minimum
#if defined(__GLIBC__) && defined(__GLIBC_PREREQ)
#if !__GLIBC_PREREQ(2, 12)
#error "Minimum required version of glibc is 2.12."
#endif
#endif
#if defined(_STLPORT_VERSION)
#error "STLPort is not supported."
#endif
#if CHAR_BIT != 8
#error "phmap assumes CHAR_BIT == 8."
#endif
// phmap currently assumes that an int is 4 bytes.
#if INT_MAX < 2147483647
#error "phmap assumes that int is at least 4 bytes. "
#endif
// -----------------------------------------------------------------------------
// Compiler Feature Checks
// -----------------------------------------------------------------------------
#ifdef __has_builtin
#define PHMAP_HAVE_BUILTIN(x) __has_builtin(x)
#else
#define PHMAP_HAVE_BUILTIN(x) 0
#endif
#if (defined(_MSVC_LANG) && _MSVC_LANG >= 201703) || __cplusplus >= 201703
#define PHMAP_HAVE_CC17 1
#else
#define PHMAP_HAVE_CC17 0
#endif
#define PHMAP_BRANCHLESS 1
// ----------------------------------------------------------------
// Checks whether `std::is_trivially_destructible<T>` is supported.
// ----------------------------------------------------------------
#ifdef PHMAP_HAVE_STD_IS_TRIVIALLY_DESTRUCTIBLE
#error PHMAP_HAVE_STD_IS_TRIVIALLY_DESTRUCTIBLE cannot be directly set
#elif defined(_LIBCPP_VERSION) || \
(!defined(__clang__) && defined(__GNUC__) && defined(__GLIBCXX__) && \
(__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 8))) || \
defined(_MSC_VER)
#define PHMAP_HAVE_STD_IS_TRIVIALLY_DESTRUCTIBLE 1
#endif
// --------------------------------------------------------------
// Checks whether `std::is_trivially_default_constructible<T>` is
// supported.
// --------------------------------------------------------------
#if defined(PHMAP_HAVE_STD_IS_TRIVIALLY_CONSTRUCTIBLE)
#error PHMAP_HAVE_STD_IS_TRIVIALLY_CONSTRUCTIBLE cannot be directly set
#elif defined(PHMAP_HAVE_STD_IS_TRIVIALLY_ASSIGNABLE)
#error PHMAP_HAVE_STD_IS_TRIVIALLY_ASSIGNABLE cannot be directly set
#elif (defined(__clang__) && defined(_LIBCPP_VERSION)) || \
(!defined(__clang__) && defined(__GNUC__) && \
(__GNUC__ > 5 || (__GNUC__ == 5 && __GNUC_MINOR__ >= 1)) && \
(defined(_LIBCPP_VERSION) || defined(__GLIBCXX__))) || \
(defined(_MSC_VER) && !defined(__NVCC__))
#define PHMAP_HAVE_STD_IS_TRIVIALLY_CONSTRUCTIBLE 1
#define PHMAP_HAVE_STD_IS_TRIVIALLY_ASSIGNABLE 1
#endif
// -------------------------------------------------------------------
// Checks whether C++11's `thread_local` storage duration specifier is
// supported.
// -------------------------------------------------------------------
#ifdef PHMAP_HAVE_THREAD_LOCAL
#error PHMAP_HAVE_THREAD_LOCAL cannot be directly set
#elif defined(__APPLE__)
#if __has_feature(cxx_thread_local) && \
!(TARGET_OS_IPHONE && __IPHONE_OS_VERSION_MIN_REQUIRED < __IPHONE_9_0)
#define PHMAP_HAVE_THREAD_LOCAL 1
#endif
#else // !defined(__APPLE__)
#define PHMAP_HAVE_THREAD_LOCAL 1
#endif
#if defined(__ANDROID__) && defined(__clang__)
#if __has_include(<android/ndk-version.h>)
#include <android/ndk-version.h>
#endif // __has_include(<android/ndk-version.h>)
#if defined(__ANDROID__) && defined(__clang__) && defined(__NDK_MAJOR__) && \
defined(__NDK_MINOR__) && \
((__NDK_MAJOR__ < 12) || ((__NDK_MAJOR__ == 12) && (__NDK_MINOR__ < 1)))
#undef PHMAP_HAVE_TLS
#undef PHMAP_HAVE_THREAD_LOCAL
#endif
#endif
// ------------------------------------------------------------
// Checks whether the __int128 compiler extension for a 128-bit
// integral type is supported.
// ------------------------------------------------------------
#ifdef PHMAP_HAVE_INTRINSIC_INT128
#error PHMAP_HAVE_INTRINSIC_INT128 cannot be directly set
#elif defined(__SIZEOF_INT128__)
#if (defined(__clang__) && !defined(_WIN32) && !defined(__aarch64__)) || \
(defined(__CUDACC__) && __CUDACC_VER_MAJOR__ >= 9) || \
(defined(__GNUC__) && !defined(__clang__) && !defined(__CUDACC__))
#define PHMAP_HAVE_INTRINSIC_INT128 1
#elif defined(__CUDACC__)
#if __CUDACC_VER__ >= 70000
#define PHMAP_HAVE_INTRINSIC_INT128 1
#endif // __CUDACC_VER__ >= 70000
#endif // defined(__CUDACC__)
#endif
// ------------------------------------------------------------------
// Checks whether the compiler both supports and enables exceptions.
// ------------------------------------------------------------------
#ifdef PHMAP_HAVE_EXCEPTIONS
#error PHMAP_HAVE_EXCEPTIONS cannot be directly set.
#elif defined(__clang__)
#if defined(__EXCEPTIONS) && __has_feature(cxx_exceptions)
#define PHMAP_HAVE_EXCEPTIONS 1
#endif // defined(__EXCEPTIONS) && __has_feature(cxx_exceptions)
#elif !(defined(__GNUC__) && (__GNUC__ < 5) && !defined(__EXCEPTIONS)) && \
!(defined(__GNUC__) && (__GNUC__ >= 5) && !defined(__cpp_exceptions)) && \
!(defined(_MSC_VER) && !defined(_CPPUNWIND))
#define PHMAP_HAVE_EXCEPTIONS 1
#endif
// -----------------------------------------------------------------------
// Checks whether the platform has an mmap(2) implementation as defined in
// POSIX.1-2001.
// -----------------------------------------------------------------------
#ifdef PHMAP_HAVE_MMAP
#error PHMAP_HAVE_MMAP cannot be directly set
#elif defined(__linux__) || defined(__APPLE__) || defined(__FreeBSD__) || \
defined(__ros__) || defined(__native_client__) || defined(__asmjs__) || \
defined(__wasm__) || defined(__Fuchsia__) || defined(__sun) || \
defined(__ASYLO__)
#define PHMAP_HAVE_MMAP 1
#endif
// -----------------------------------------------------------------------
// Checks the endianness of the platform.
// -----------------------------------------------------------------------
#if defined(PHMAP_IS_BIG_ENDIAN)
#error "PHMAP_IS_BIG_ENDIAN cannot be directly set."
#endif
#if defined(PHMAP_IS_LITTLE_ENDIAN)
#error "PHMAP_IS_LITTLE_ENDIAN cannot be directly set."
#endif
#if (defined(__BYTE_ORDER__) && defined(__ORDER_LITTLE_ENDIAN__) && \
__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)
#define PHMAP_IS_LITTLE_ENDIAN 1
#elif defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && \
__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
#define PHMAP_IS_BIG_ENDIAN 1
#elif defined(_WIN32)
#define PHMAP_IS_LITTLE_ENDIAN 1
#else
#error "phmap endian detection needs to be set up for your compiler"
#endif
#if defined(__APPLE__) && defined(_LIBCPP_VERSION) && \
defined(__MAC_OS_X_VERSION_MIN_REQUIRED__) && \
__ENVIRONMENT_MAC_OS_X_VERSION_MIN_REQUIRED__ < 101400
#define PHMAP_INTERNAL_MACOS_CXX17_TYPES_UNAVAILABLE 1
#else
#define PHMAP_INTERNAL_MACOS_CXX17_TYPES_UNAVAILABLE 0
#endif
// ---------------------------------------------------------------------------
// Checks whether C++17 std::any is available by checking whether <any> exists.
// ---------------------------------------------------------------------------
#ifdef PHMAP_HAVE_STD_ANY
#error "PHMAP_HAVE_STD_ANY cannot be directly set."
#endif
#ifdef __has_include
#if __has_include(<any>) && __cplusplus >= 201703L && \
!PHMAP_INTERNAL_MACOS_CXX17_TYPES_UNAVAILABLE
#define PHMAP_HAVE_STD_ANY 1
#endif
#endif
#ifdef PHMAP_HAVE_STD_OPTIONAL
#error "PHMAP_HAVE_STD_OPTIONAL cannot be directly set."
#endif
#ifdef __has_include
#if __has_include(<optional>) && __cplusplus >= 201703L && \
!PHMAP_INTERNAL_MACOS_CXX17_TYPES_UNAVAILABLE
#define PHMAP_HAVE_STD_OPTIONAL 1
#endif
#endif
#ifdef PHMAP_HAVE_STD_VARIANT
#error "PHMAP_HAVE_STD_VARIANT cannot be directly set."
#endif
#ifdef __has_include
#if __has_include(<variant>) && __cplusplus >= 201703L && \
!PHMAP_INTERNAL_MACOS_CXX17_TYPES_UNAVAILABLE
#define PHMAP_HAVE_STD_VARIANT 1
#endif
#endif
#ifdef PHMAP_HAVE_STD_STRING_VIEW
#error "PHMAP_HAVE_STD_STRING_VIEW cannot be directly set."
#endif
#ifdef __has_include
#if __has_include(<string_view>) && __cplusplus >= 201703L && \
(!defined(_MSC_VER) || _MSC_VER >= 1920) // vs2019
#define PHMAP_HAVE_STD_STRING_VIEW 1
#endif
#endif
// #pragma message(PHMAP_VAR_NAME_VALUE(_MSVC_LANG))
#if defined(_MSC_VER) && _MSC_VER >= 1910 && PHMAP_HAVE_CC17
// #define PHMAP_HAVE_STD_ANY 1
#define PHMAP_HAVE_STD_OPTIONAL 1
#define PHMAP_HAVE_STD_VARIANT 1
#if !defined(PHMAP_HAVE_STD_STRING_VIEW) && _MSC_VER >= 1920
#define PHMAP_HAVE_STD_STRING_VIEW 1
#endif
#endif
#if PHMAP_HAVE_CC17
#define PHMAP_HAVE_SHARED_MUTEX 1
#endif
#ifndef PHMAP_HAVE_STD_STRING_VIEW
#define PHMAP_HAVE_STD_STRING_VIEW 0
#endif
// In debug mode, MSVC 2017's std::variant throws a EXCEPTION_ACCESS_VIOLATION
// SEH exception from emplace for variant<SomeStruct> when constructing the
// struct can throw. This defeats some of variant_test and
// variant_exception_safety_test.
#if defined(_MSC_VER) && _MSC_VER >= 1700 && defined(_DEBUG)
#define PHMAP_INTERNAL_MSVC_2017_DBG_MODE
#endif
// -----------------------------------------------------------------------------
// Sanitizer Attributes
// -----------------------------------------------------------------------------
//
// Sanitizer-related attributes are not "defined" in this file (and indeed
// are not defined as such in any file). To utilize the following
// sanitizer-related attributes within your builds, define the following macros
// within your build using a `-D` flag, along with the given value for
// `-fsanitize`:
//
// * `ADDRESS_SANITIZER` + `-fsanitize=address` (Clang, GCC 4.8)
// * `MEMORY_SANITIZER` + `-fsanitize=memory` (Clang-only)
// * `THREAD_SANITIZER` + `-fsanitize=thread` (Clang, GCC 4.8+)
// * `UNDEFINED_BEHAVIOR_SANITIZER` + `-fsanitize=undefined` (Clang, GCC 4.9+)
// * `CONTROL_FLOW_INTEGRITY` + `-fsanitize=cfi` (Clang-only)
// -----------------------------------------------------------------------------
// -----------------------------------------------------------------------------
// A function-like feature checking macro that is a wrapper around
// `__has_attribute`, which is defined by GCC 5+ and Clang and evaluates to a
// nonzero constant integer if the attribute is supported or 0 if not.
//
// It evaluates to zero if `__has_attribute` is not defined by the compiler.
// -----------------------------------------------------------------------------
#ifdef __has_attribute
#define PHMAP_HAVE_ATTRIBUTE(x) __has_attribute(x)
#else
#define PHMAP_HAVE_ATTRIBUTE(x) 0
#endif
// -----------------------------------------------------------------------------
// A function-like feature checking macro that accepts C++11 style attributes.
// It's a wrapper around `__has_cpp_attribute`, defined by ISO C++ SD-6
// (https://en.cppreference.com/w/cpp/experimental/feature_test). If we don't
// find `__has_cpp_attribute`, it will evaluate to 0.
// -----------------------------------------------------------------------------
#if defined(__cplusplus) && defined(__has_cpp_attribute)
#define PHMAP_HAVE_CPP_ATTRIBUTE(x) __has_cpp_attribute(x)
#else
#define PHMAP_HAVE_CPP_ATTRIBUTE(x) 0
#endif
// -----------------------------------------------------------------------------
// Function Attributes
// -----------------------------------------------------------------------------
#if PHMAP_HAVE_ATTRIBUTE(format) || (defined(__GNUC__) && !defined(__clang__))
#define PHMAP_PRINTF_ATTRIBUTE(string_index, first_to_check) \
__attribute__((__format__(__printf__, string_index, first_to_check)))
#define PHMAP_SCANF_ATTRIBUTE(string_index, first_to_check) \
__attribute__((__format__(__scanf__, string_index, first_to_check)))
#else
#define PHMAP_PRINTF_ATTRIBUTE(string_index, first_to_check)
#define PHMAP_SCANF_ATTRIBUTE(string_index, first_to_check)
#endif
#if PHMAP_HAVE_ATTRIBUTE(always_inline) || \
(defined(__GNUC__) && !defined(__clang__))
#define PHMAP_ATTRIBUTE_ALWAYS_INLINE __attribute__((always_inline))
#define PHMAP_HAVE_ATTRIBUTE_ALWAYS_INLINE 1
#else
#define PHMAP_ATTRIBUTE_ALWAYS_INLINE
#endif
#if !defined(__INTEL_COMPILER) && (PHMAP_HAVE_ATTRIBUTE(noinline) || (defined(__GNUC__) && !defined(__clang__)))
#define PHMAP_ATTRIBUTE_NOINLINE __attribute__((noinline))
#define PHMAP_HAVE_ATTRIBUTE_NOINLINE 1
#else
#define PHMAP_ATTRIBUTE_NOINLINE
#endif
#if PHMAP_HAVE_ATTRIBUTE(disable_tail_calls)
#define PHMAP_HAVE_ATTRIBUTE_NO_TAIL_CALL 1
#define PHMAP_ATTRIBUTE_NO_TAIL_CALL __attribute__((disable_tail_calls))
#elif defined(__GNUC__) && !defined(__clang__)
#define PHMAP_HAVE_ATTRIBUTE_NO_TAIL_CALL 1
#define PHMAP_ATTRIBUTE_NO_TAIL_CALL \
__attribute__((optimize("no-optimize-sibling-calls")))
#else
#define PHMAP_ATTRIBUTE_NO_TAIL_CALL
#define PHMAP_HAVE_ATTRIBUTE_NO_TAIL_CALL 0
#endif
#if (PHMAP_HAVE_ATTRIBUTE(weak) || \
(defined(__GNUC__) && !defined(__clang__))) && \
!(defined(__llvm__) && defined(_WIN32))
#undef PHMAP_ATTRIBUTE_WEAK
#define PHMAP_ATTRIBUTE_WEAK __attribute__((weak))
#define PHMAP_HAVE_ATTRIBUTE_WEAK 1
#else
#define PHMAP_ATTRIBUTE_WEAK
#define PHMAP_HAVE_ATTRIBUTE_WEAK 0
#endif
#if PHMAP_HAVE_ATTRIBUTE(nonnull) || (defined(__GNUC__) && !defined(__clang__))
#define PHMAP_ATTRIBUTE_NONNULL(arg_index) __attribute__((nonnull(arg_index)))
#else
#define PHMAP_ATTRIBUTE_NONNULL(...)
#endif
#if PHMAP_HAVE_ATTRIBUTE(noreturn) || (defined(__GNUC__) && !defined(__clang__))
#define PHMAP_ATTRIBUTE_NORETURN __attribute__((noreturn))
#elif defined(_MSC_VER)
#define PHMAP_ATTRIBUTE_NORETURN __declspec(noreturn)
#else
#define PHMAP_ATTRIBUTE_NORETURN
#endif
#if defined(__GNUC__) && defined(ADDRESS_SANITIZER)
#define PHMAP_ATTRIBUTE_NO_SANITIZE_ADDRESS __attribute__((no_sanitize_address))
#else
#define PHMAP_ATTRIBUTE_NO_SANITIZE_ADDRESS
#endif
#if defined(__GNUC__) && defined(MEMORY_SANITIZER)
#define PHMAP_ATTRIBUTE_NO_SANITIZE_MEMORY __attribute__((no_sanitize_memory))
#else
#define PHMAP_ATTRIBUTE_NO_SANITIZE_MEMORY
#endif
#if defined(__GNUC__) && defined(THREAD_SANITIZER)
#define PHMAP_ATTRIBUTE_NO_SANITIZE_THREAD __attribute__((no_sanitize_thread))
#else
#define PHMAP_ATTRIBUTE_NO_SANITIZE_THREAD
#endif
#if defined(__GNUC__) && \
(defined(UNDEFINED_BEHAVIOR_SANITIZER) || defined(ADDRESS_SANITIZER))
#define PHMAP_ATTRIBUTE_NO_SANITIZE_UNDEFINED \
__attribute__((no_sanitize("undefined")))
#else
#define PHMAP_ATTRIBUTE_NO_SANITIZE_UNDEFINED
#endif
#if defined(__GNUC__) && defined(CONTROL_FLOW_INTEGRITY)
#define PHMAP_ATTRIBUTE_NO_SANITIZE_CFI __attribute__((no_sanitize("cfi")))
#else
#define PHMAP_ATTRIBUTE_NO_SANITIZE_CFI
#endif
#if defined(__GNUC__) && defined(SAFESTACK_SANITIZER)
#define PHMAP_ATTRIBUTE_NO_SANITIZE_SAFESTACK \
__attribute__((no_sanitize("safe-stack")))
#else
#define PHMAP_ATTRIBUTE_NO_SANITIZE_SAFESTACK
#endif
#if PHMAP_HAVE_ATTRIBUTE(returns_nonnull) || \
(defined(__GNUC__) && \
(__GNUC__ > 5 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 9)) && \
!defined(__clang__))
#define PHMAP_ATTRIBUTE_RETURNS_NONNULL __attribute__((returns_nonnull))
#else
#define PHMAP_ATTRIBUTE_RETURNS_NONNULL
#endif
#ifdef PHMAP_HAVE_ATTRIBUTE_SECTION
#error PHMAP_HAVE_ATTRIBUTE_SECTION cannot be directly set
#elif (PHMAP_HAVE_ATTRIBUTE(section) || \
(defined(__GNUC__) && !defined(__clang__))) && \
!defined(__APPLE__) && PHMAP_HAVE_ATTRIBUTE_WEAK
#define PHMAP_HAVE_ATTRIBUTE_SECTION 1
#ifndef PHMAP_ATTRIBUTE_SECTION
#define PHMAP_ATTRIBUTE_SECTION(name) \
__attribute__((section(#name))) __attribute__((noinline))
#endif
#ifndef PHMAP_ATTRIBUTE_SECTION_VARIABLE
#define PHMAP_ATTRIBUTE_SECTION_VARIABLE(name) __attribute__((section(#name)))
#endif
#ifndef PHMAP_DECLARE_ATTRIBUTE_SECTION_VARS
#define PHMAP_DECLARE_ATTRIBUTE_SECTION_VARS(name) \
extern char __start_##name[] PHMAP_ATTRIBUTE_WEAK; \
extern char __stop_##name[] PHMAP_ATTRIBUTE_WEAK
#endif
#ifndef PHMAP_DEFINE_ATTRIBUTE_SECTION_VARS
#define PHMAP_INIT_ATTRIBUTE_SECTION_VARS(name)
#define PHMAP_DEFINE_ATTRIBUTE_SECTION_VARS(name)
#endif
#define PHMAP_ATTRIBUTE_SECTION_START(name) \
(reinterpret_cast<void *>(__start_##name))
#define PHMAP_ATTRIBUTE_SECTION_STOP(name) \
(reinterpret_cast<void *>(__stop_##name))
#else // !PHMAP_HAVE_ATTRIBUTE_SECTION
#define PHMAP_HAVE_ATTRIBUTE_SECTION 0
#define PHMAP_ATTRIBUTE_SECTION(name)
#define PHMAP_ATTRIBUTE_SECTION_VARIABLE(name)
#define PHMAP_INIT_ATTRIBUTE_SECTION_VARS(name)
#define PHMAP_DEFINE_ATTRIBUTE_SECTION_VARS(name)
#define PHMAP_DECLARE_ATTRIBUTE_SECTION_VARS(name)
#define PHMAP_ATTRIBUTE_SECTION_START(name) (reinterpret_cast<void *>(0))
#define PHMAP_ATTRIBUTE_SECTION_STOP(name) (reinterpret_cast<void *>(0))
#endif // PHMAP_ATTRIBUTE_SECTION
#if PHMAP_HAVE_ATTRIBUTE(force_align_arg_pointer) || \
(defined(__GNUC__) && !defined(__clang__))
#if defined(__i386__)
#define PHMAP_ATTRIBUTE_STACK_ALIGN_FOR_OLD_LIBC \
__attribute__((force_align_arg_pointer))
#define PHMAP_REQUIRE_STACK_ALIGN_TRAMPOLINE (0)
#elif defined(__x86_64__)
#define PHMAP_REQUIRE_STACK_ALIGN_TRAMPOLINE (1)
#define PHMAP_ATTRIBUTE_STACK_ALIGN_FOR_OLD_LIBC
#else // !__i386__ && !__x86_64
#define PHMAP_REQUIRE_STACK_ALIGN_TRAMPOLINE (0)
#define PHMAP_ATTRIBUTE_STACK_ALIGN_FOR_OLD_LIBC
#endif // __i386__
#else
#define PHMAP_ATTRIBUTE_STACK_ALIGN_FOR_OLD_LIBC
#define PHMAP_REQUIRE_STACK_ALIGN_TRAMPOLINE (0)
#endif
#if PHMAP_HAVE_ATTRIBUTE(nodiscard)
#define PHMAP_MUST_USE_RESULT [[nodiscard]]
#elif defined(__clang__) && PHMAP_HAVE_ATTRIBUTE(warn_unused_result)
#define PHMAP_MUST_USE_RESULT __attribute__((warn_unused_result))
#else
#define PHMAP_MUST_USE_RESULT
#endif
#if PHMAP_HAVE_ATTRIBUTE(hot) || (defined(__GNUC__) && !defined(__clang__))
#define PHMAP_ATTRIBUTE_HOT __attribute__((hot))
#else
#define PHMAP_ATTRIBUTE_HOT
#endif
#if PHMAP_HAVE_ATTRIBUTE(cold) || (defined(__GNUC__) && !defined(__clang__))
#define PHMAP_ATTRIBUTE_COLD __attribute__((cold))
#else
#define PHMAP_ATTRIBUTE_COLD
#endif
#if defined(__clang__)
#if PHMAP_HAVE_CPP_ATTRIBUTE(clang::reinitializes)
#define PHMAP_ATTRIBUTE_REINITIALIZES [[clang::reinitializes]]
#else
#define PHMAP_ATTRIBUTE_REINITIALIZES
#endif
#else
#define PHMAP_ATTRIBUTE_REINITIALIZES
#endif
#if PHMAP_HAVE_ATTRIBUTE(unused) || (defined(__GNUC__) && !defined(__clang__))
#undef PHMAP_ATTRIBUTE_UNUSED
#define PHMAP_ATTRIBUTE_UNUSED __attribute__((__unused__))
#else
#define PHMAP_ATTRIBUTE_UNUSED
#endif
#if PHMAP_HAVE_ATTRIBUTE(tls_model) || (defined(__GNUC__) && !defined(__clang__))
#define PHMAP_ATTRIBUTE_INITIAL_EXEC __attribute__((tls_model("initial-exec")))
#else
#define PHMAP_ATTRIBUTE_INITIAL_EXEC
#endif
#if PHMAP_HAVE_ATTRIBUTE(packed) || (defined(__GNUC__) && !defined(__clang__))
#define PHMAP_ATTRIBUTE_PACKED __attribute__((__packed__))
#else
#define PHMAP_ATTRIBUTE_PACKED
#endif
#if PHMAP_HAVE_ATTRIBUTE(aligned) || (defined(__GNUC__) && !defined(__clang__))
#define PHMAP_ATTRIBUTE_FUNC_ALIGN(bytes) __attribute__((aligned(bytes)))
#else
#define PHMAP_ATTRIBUTE_FUNC_ALIGN(bytes)
#endif
// ----------------------------------------------------------------------
// Figure out SSE support
// ----------------------------------------------------------------------
#ifndef PHMAP_HAVE_SSE2
#if defined(__SSE2__) || \
(defined(_MSC_VER) && \
(defined(_M_X64) || (defined(_M_IX86) && _M_IX86_FP >= 2)))
#define PHMAP_HAVE_SSE2 1
#else
#define PHMAP_HAVE_SSE2 0
#endif
#endif
#ifndef PHMAP_HAVE_SSSE3
#if defined(__SSSE3__) || defined(__AVX2__)
#define PHMAP_HAVE_SSSE3 1
#else
#define PHMAP_HAVE_SSSE3 0
#endif
#endif
#if PHMAP_HAVE_SSSE3 && !PHMAP_HAVE_SSE2
#error "Bad configuration!"
#endif
#if PHMAP_HAVE_SSE2
#include <emmintrin.h>
#endif
#if PHMAP_HAVE_SSSE3
#include <tmmintrin.h>
#endif
// ----------------------------------------------------------------------
// constexpr if
// ----------------------------------------------------------------------
#if PHMAP_HAVE_CC17
#define PHMAP_IF_CONSTEXPR(expr) if constexpr ((expr))
#else
#define PHMAP_IF_CONSTEXPR(expr) if ((expr))
#endif
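// Example (illustrative sketch): before C++17 the macro degrades to a plain
// `if`, so both branches are still compiled for every instantiation; use it
// only where both branches are well-formed.
//
//     PHMAP_IF_CONSTEXPR(sizeof(void *) == 8) {
//         // 64-bit pointer path
//     } else {
//         // 32-bit pointer path
//     }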
// ----------------------------------------------------------------------
// base/macros.h
// ----------------------------------------------------------------------
// PHMAP_ARRAYSIZE()
//
// Returns the number of elements in an array as a compile-time constant, which
// can be used in defining new arrays. If you use this macro on a pointer by
// mistake, you will get a compile-time error.
#define PHMAP_ARRAYSIZE(array) \
(sizeof(::phmap::macros_internal::ArraySizeHelper(array)))
namespace phmap {
namespace macros_internal {
// Note: this internal template function declaration is used by PHMAP_ARRAYSIZE.
// The function doesn't need a definition, as we only use its type.
template <typename T, size_t N>
auto ArraySizeHelper(const T (&array)[N]) -> char (&)[N];
} // namespace macros_internal
} // namespace phmap
// TODO(zhangxy): Use c++17 standard [[fallthrough]] macro, when supported.
#if defined(__clang__) && defined(__has_warning)
#if __has_feature(cxx_attributes) && __has_warning("-Wimplicit-fallthrough")
#define PHMAP_FALLTHROUGH_INTENDED [[clang::fallthrough]]
#endif
#elif defined(__GNUC__) && __GNUC__ >= 7
#define PHMAP_FALLTHROUGH_INTENDED [[gnu::fallthrough]]
#endif
#ifndef PHMAP_FALLTHROUGH_INTENDED
#define PHMAP_FALLTHROUGH_INTENDED \
do { } while (0)
#endif
// PHMAP_DEPRECATED()
//
// Marks deprecated class, struct, enum, function, method, and variable
// declarations. The macro argument is used as a custom diagnostic message (e.g.
// suggestion of a better alternative).
//
// Example:
//
// class PHMAP_DEPRECATED("Use Bar instead") Foo {...};
// PHMAP_DEPRECATED("Use Baz instead") void Bar() {...}
//
// Every usage of a deprecated entity will trigger a warning when compiled with
// clang's `-Wdeprecated-declarations` option. This option is turned off by
// default, but the warnings will be reported by clang-tidy.
#if defined(__clang__) && __cplusplus >= 201103L
#define PHMAP_DEPRECATED(message) __attribute__((deprecated(message)))
#endif
#ifndef PHMAP_DEPRECATED
#define PHMAP_DEPRECATED(message)
#endif
// PHMAP_BAD_CALL_IF()
//
// Used on a function overload to trap bad calls: any call that matches the
// overload will cause a compile-time error. This macro uses a clang-specific
// "enable_if" attribute, as described at
// http://clang.llvm.org/docs/AttributeReference.html#enable-if
//
// Overloads which use this macro should be bracketed by
// `#ifdef PHMAP_BAD_CALL_IF`.
//
// Example:
//
// int isdigit(int c);
// #ifdef PHMAP_BAD_CALL_IF
// int isdigit(int c)
// PHMAP_BAD_CALL_IF(c <= -1 || c > 255,
// "'c' must have the value of an unsigned char or EOF");
// #endif // PHMAP_BAD_CALL_IF
#if defined(__clang__)
#if __has_attribute(enable_if)
#define PHMAP_BAD_CALL_IF(expr, msg) \
__attribute__((enable_if(expr, "Bad call trap"), unavailable(msg)))
#endif
#endif
// PHMAP_ASSERT()
//
// In C++11, `assert` can't be used portably within constexpr functions.
// PHMAP_ASSERT functions as a runtime assert but works in C++11 constexpr
// functions. Example:
//
// constexpr double Divide(double a, double b) {
// return PHMAP_ASSERT(b != 0), a / b;
// }
//
// This macro is inspired by
// https://akrzemi1.wordpress.com/2017/05/18/asserts-in-constexpr-functions/
#if defined(NDEBUG)
#define PHMAP_ASSERT(expr) (false ? (void)(expr) : (void)0)
#else
#define PHMAP_ASSERT(expr) \
(PHMAP_PREDICT_TRUE((expr)) ? (void)0 \
: [] { assert(false && #expr); }()) // NOLINT
#endif
#ifdef PHMAP_HAVE_EXCEPTIONS
#define PHMAP_INTERNAL_TRY try
#define PHMAP_INTERNAL_CATCH_ANY catch (...)
#define PHMAP_INTERNAL_RETHROW do { throw; } while (false)
#else // PHMAP_HAVE_EXCEPTIONS
#define PHMAP_INTERNAL_TRY if (true)
#define PHMAP_INTERNAL_CATCH_ANY else if (false)
#define PHMAP_INTERNAL_RETHROW do {} while (false)
#endif // PHMAP_HAVE_EXCEPTIONS
#endif // phmap_config_h_guard_
#if !defined(phmap_dump_h_guard_)
#define phmap_dump_h_guard_
// ---------------------------------------------------------------------------
// Copyright (c) 2019, Gregory Popovitch - greg7mdp@gmail.com
//
// providing dump/load/mmap_load
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ---------------------------------------------------------------------------
#include <iostream>
#include <fstream>
#include <sstream>
#include "phmap.h"
namespace phmap
{
namespace type_traits_internal {
#if defined(__GLIBCXX__) && __GLIBCXX__ < 20150801
template<typename T> struct IsTriviallyCopyable : public std::integral_constant<bool, __has_trivial_copy(T)> {};
#else
template<typename T> struct IsTriviallyCopyable : public std::is_trivially_copyable<T> {};
#endif
template <class T1, class T2>
struct IsTriviallyCopyable<std::pair<T1, T2>> {
static constexpr bool value = IsTriviallyCopyable<T1>::value && IsTriviallyCopyable<T2>::value;
};
}
namespace priv {
// ------------------------------------------------------------------------
// dump/load for raw_hash_set
// ------------------------------------------------------------------------
template <class Policy, class Hash, class Eq, class Alloc>
template<typename OutputArchive>
bool raw_hash_set<Policy, Hash, Eq, Alloc>::dump(OutputArchive& ar) const {
static_assert(type_traits_internal::IsTriviallyCopyable<value_type>::value,
"value_type should be trivially copyable");
if (!ar.dump(size_)) {
std::cerr << "Failed to dump size_" << std::endl;
return false;
}
if (size_ == 0) {
return true;
}
if (!ar.dump(capacity_)) {
std::cerr << "Failed to dump capacity_" << std::endl;
return false;
}
if (!ar.dump(reinterpret_cast<char*>(ctrl_),
sizeof(ctrl_t) * (capacity_ + Group::kWidth + 1))) {
std::cerr << "Failed to dump ctrl_" << std::endl;
return false;
}
if (!ar.dump(reinterpret_cast<char*>(slots_),
sizeof(slot_type) * capacity_)) {
std::cerr << "Failed to dump slot_" << std::endl;
return false;
}
return true;
}
template <class Policy, class Hash, class Eq, class Alloc>
template<typename InputArchive>
bool raw_hash_set<Policy, Hash, Eq, Alloc>::load(InputArchive& ar) {
static_assert(type_traits_internal::IsTriviallyCopyable<value_type>::value,
"value_type should be trivially copyable");
raw_hash_set<Policy, Hash, Eq, Alloc>().swap(*this); // clear any existing content
if (!ar.load(&size_)) {
std::cerr << "Failed to load size_" << std::endl;
return false;
}
if (size_ == 0) {
return true;
}
if (!ar.load(&capacity_)) {
std::cerr << "Failed to load capacity_" << std::endl;
return false;
}
// allocate memory for ctrl_ and slots_
initialize_slots();
if (!ar.load(reinterpret_cast<char*>(ctrl_),
sizeof(ctrl_t) * (capacity_ + Group::kWidth + 1))) {
std::cerr << "Failed to load ctrl" << std::endl;
return false;
}
if (!ar.load(reinterpret_cast<char*>(slots_),
sizeof(slot_type) * capacity_)) {
std::cerr << "Failed to load slot" << std::endl;
return false;
}
return true;
}
// ------------------------------------------------------------------------
// dump/load for parallel_hash_set
// ------------------------------------------------------------------------
template <size_t N,
template <class, class, class, class> class RefSet,
class Mtx_,
class Policy, class Hash, class Eq, class Alloc>
template<typename OutputArchive>
bool parallel_hash_set<N, RefSet, Mtx_, Policy, Hash, Eq, Alloc>::dump(OutputArchive& ar) const {
static_assert(type_traits_internal::IsTriviallyCopyable<value_type>::value,
"value_type should be trivially copyable");
if (! ar.dump(subcnt())) {
std::cerr << "Failed to dump meta!" << std::endl;
return false;
}
for (size_t i = 0; i < sets_.size(); ++i) {
auto& inner = sets_[i];
typename Lockable::UniqueLock m(const_cast<Inner&>(inner));
if (!inner.set_.dump(ar)) {
std::cerr << "Failed to dump submap " << i << std::endl;
return false;
}
}
return true;
}
template <size_t N,
template <class, class, class, class> class RefSet,
class Mtx_,
class Policy, class Hash, class Eq, class Alloc>
template<typename InputArchive>
bool parallel_hash_set<N, RefSet, Mtx_, Policy, Hash, Eq, Alloc>::load(InputArchive& ar) {
static_assert(type_traits_internal::IsTriviallyCopyable<value_type>::value,
"value_type should be trivially copyable");
size_t submap_count = 0;
if (!ar.load(&submap_count)) {
std::cerr << "Failed to load submap count!" << std::endl;
return false;
}
if (submap_count != subcnt()) {
std::cerr << "submap count(" << submap_count << ") != N(" << N << ")" << std::endl;
return false;
}
for (size_t i = 0; i < submap_count; ++i) {
auto& inner = sets_[i];
typename Lockable::UniqueLock m(const_cast<Inner&>(inner));
if (!inner.set_.load(ar)) {
std::cerr << "Failed to load submap " << i << std::endl;
return false;
}
}
return true;
}
} // namespace priv
// ------------------------------------------------------------------------
// BinaryArchive
// File is closed when archive object is destroyed
// ------------------------------------------------------------------------
// ------------------------------------------------------------------------
// ------------------------------------------------------------------------
class BinaryOutputArchive {
public:
BinaryOutputArchive(const char *file_path) {
ofs_.open(file_path, std::ios_base::binary);
}
bool dump(const char *p, size_t sz) {
ofs_.write(p, sz);
return true;
}
template<typename V>
typename std::enable_if<type_traits_internal::IsTriviallyCopyable<V>::value, bool>::type
dump(const V& v) {
ofs_.write(reinterpret_cast<const char *>(&v), sizeof(V));
return true;
}
private:
std::ofstream ofs_;
};
class BinaryInputArchive {
public:
BinaryInputArchive(const char * file_path) {
ifs_.open(file_path, std::ios_base::binary);
}
bool load(char* p, size_t sz) {
ifs_.read(p, sz);
return true;
}
template<typename V>
typename std::enable_if<type_traits_internal::IsTriviallyCopyable<V>::value, bool>::type
load(V* v) {
ifs_.read(reinterpret_cast<char *>(v), sizeof(V));
return true;
}
private:
std::ifstream ifs_;
};
} // namespace phmap
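// ------------------------------------------------------------------------
// Example (illustrative sketch, assuming the dump()/load() members declared
// above are reachable from the concrete map type and that the key/value pair
// is trivially copyable; the file name is made up):
//
//     phmap::flat_hash_map<uint32_t, uint32_t> m = {{1, 2}, {3, 4}};
//     {
//         phmap::BinaryOutputArchive out("map.bin");
//         m.dump(out);
//     }
//     phmap::flat_hash_map<uint32_t, uint32_t> copy;
//     phmap::BinaryInputArchive in("map.bin");
//     copy.load(in);
// ------------------------------------------------------------------------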
#endif // phmap_dump_h_guard_
#if !defined(phmap_fwd_decl_h_guard_)
#define phmap_fwd_decl_h_guard_
// ---------------------------------------------------------------------------
// Copyright (c) 2019, Gregory Popovitch - greg7mdp@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
// ---------------------------------------------------------------------------
#ifdef _MSC_VER
#pragma warning(push)
#pragma warning(disable : 4514) // unreferenced inline function has been removed
#pragma warning(disable : 4710) // function not inlined
#pragma warning(disable : 4711) // selected for automatic inline expansion
#endif
#include <memory>
#include <utility>
#if defined(PHMAP_USE_ABSL_HASH) && !defined(ABSL_HASH_HASH_H_)
namespace absl { template <class T> struct Hash; };
#endif
namespace phmap {
#if defined(PHMAP_USE_ABSL_HASH)
template <class T> using Hash = ::absl::Hash<T>;
#else
template <class T> struct Hash;
#endif
template <class T> struct EqualTo;
template <class T> struct Less;
template <class T> using Allocator = typename std::allocator<T>;
template<class T1, class T2> using Pair = typename std::pair<T1, T2>;
class NullMutex;
namespace priv {
// The hash of an object of type T is computed by using phmap::Hash.
template <class T, class E = void>
struct HashEq
{
using Hash = phmap::Hash<T>;
using Eq = phmap::EqualTo<T>;
};
template <class T>
using hash_default_hash = typename priv::HashEq<T>::Hash;
template <class T>
using hash_default_eq = typename priv::HashEq<T>::Eq;
// type alias for std::allocator so we can forward declare without including other headers
template <class T>
using Allocator = typename phmap::Allocator<T>;
// type alias for std::pair so we can forward declare without including other headers
template<class T1, class T2>
using Pair = typename phmap::Pair<T1, T2>;
} // namespace priv
// ------------- forward declarations for hash containers ----------------------------------
template <class T,
class Hash = phmap::priv::hash_default_hash<T>,
class Eq = phmap::priv::hash_default_eq<T>,
class Alloc = phmap::priv::Allocator<T>> // alias for std::allocator
class flat_hash_set;
template <class K, class V,
class Hash = phmap::priv::hash_default_hash<K>,
class Eq = phmap::priv::hash_default_eq<K>,
class Alloc = phmap::priv::Allocator<
phmap::priv::Pair<const K, V>>> // alias for std::allocator
class flat_hash_map;
template <class T,
class Hash = phmap::priv::hash_default_hash<T>,
class Eq = phmap::priv::hash_default_eq<T>,
class Alloc = phmap::priv::Allocator<T>> // alias for std::allocator
class node_hash_set;
template <class Key, class Value,
class Hash = phmap::priv::hash_default_hash<Key>,
class Eq = phmap::priv::hash_default_eq<Key>,
class Alloc = phmap::priv::Allocator<
phmap::priv::Pair<const Key, Value>>> // alias for std::allocator
class node_hash_map;
template <class T,
class Hash = phmap::priv::hash_default_hash<T>,
class Eq = phmap::priv::hash_default_eq<T>,
class Alloc = phmap::priv::Allocator<T>, // alias for std::allocator
size_t N = 4, // 2**N submaps
class Mutex = phmap::NullMutex> // use std::mutex to enable internal locks
class parallel_flat_hash_set;
template <class K, class V,
class Hash = phmap::priv::hash_default_hash<K>,
class Eq = phmap::priv::hash_default_eq<K>,
class Alloc = phmap::priv::Allocator<
phmap::priv::Pair<const K, V>>, // alias for std::allocator
size_t N = 4, // 2**N submaps
class Mutex = phmap::NullMutex> // use std::mutex to enable internal locks
class parallel_flat_hash_map;
template <class T,
class Hash = phmap::priv::hash_default_hash<T>,
class Eq = phmap::priv::hash_default_eq<T>,
class Alloc = phmap::priv::Allocator<T>, // alias for std::allocator
size_t N = 4, // 2**N submaps
class Mutex = phmap::NullMutex> // use std::mutex to enable internal locks
class parallel_node_hash_set;
template <class Key, class Value,
class Hash = phmap::priv::hash_default_hash<Key>,
class Eq = phmap::priv::hash_default_eq<Key>,
class Alloc = phmap::priv::Allocator<
phmap::priv::Pair<const Key, Value>>, // alias for std::allocator
size_t N = 4, // 2**N submaps
class Mutex = phmap::NullMutex> // use std::mutex to enable internal locks
class parallel_node_hash_map;
// ------------- forward declarations for btree containers ----------------------------------
template <typename Key, typename Compare = phmap::Less<Key>,
typename Alloc = phmap::Allocator<Key>>
class btree_set;
template <typename Key, typename Compare = phmap::Less<Key>,
typename Alloc = phmap::Allocator<Key>>
class btree_multiset;
template <typename Key, typename Value, typename Compare = phmap::Less<Key>,
typename Alloc = phmap::Allocator<phmap::priv::Pair<const Key, Value>>>
class btree_map;
template <typename Key, typename Value, typename Compare = phmap::Less<Key>,
typename Alloc = phmap::Allocator<phmap::priv::Pair<const Key, Value>>>
class btree_multimap;
} // namespace phmap
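// Example (illustrative sketch): the defaults above make the common case a
// plain two-parameter declaration, while the parallel containers additionally
// take the submap count exponent and a mutex type for internal locking:
//
//     phmap::flat_hash_map<std::string, int> counts;
//     phmap::parallel_flat_hash_map<uint64_t, int,
//         phmap::priv::hash_default_hash<uint64_t>,
//         phmap::priv::hash_default_eq<uint64_t>,
//         phmap::priv::Allocator<phmap::priv::Pair<const uint64_t, int>>,
//         4, std::mutex> shared_counts;  // 2**4 submaps, each guarded by std::mutex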
#ifdef _MSC_VER
#pragma warning(pop)
#endif
#endif // phmap_fwd_decl_h_guard_
#if !defined(phmap_utils_h_guard_)
#define phmap_utils_h_guard_
// ---------------------------------------------------------------------------
// Copyright (c) 2019, Gregory Popovitch - greg7mdp@gmail.com
//
// minimal header providing phmap::HashState
//
// use as: phmap::HashState().combine(0, _first_name, _last_name, _age);
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ---------------------------------------------------------------------------
#ifdef _MSC_VER
#pragma warning(push)
#pragma warning(disable : 4514) // unreferenced inline function has been removed
#pragma warning(disable : 4710) // function not inlined
#pragma warning(disable : 4711) // selected for automatic inline expansion
#endif
#include <cstdint>
#include <functional>
#include <tuple>
#include "phmap_bits.h"
// ---------------------------------------------------------------
// Absl forward declaration requires global scope.
// ---------------------------------------------------------------
#if defined(PHMAP_USE_ABSL_HASH) && !defined(phmap_fwd_decl_h_guard_) && !defined(ABSL_HASH_HASH_H_)
namespace absl { template <class T> struct Hash; };
#endif
namespace phmap
{
// ---------------------------------------------------------------
// ---------------------------------------------------------------
template<int n>
struct phmap_mix
{
inline size_t operator()(size_t) const;
};
template<>
struct phmap_mix<4>
{
inline size_t operator()(size_t a) const
{
static constexpr uint64_t kmul = 0xcc9e2d51UL;
// static constexpr uint64_t kmul = 0x3B9ACB93UL; // [greg] my own random prime
uint64_t l = a * kmul;
return static_cast<size_t>(l ^ (l >> 32));
}
};
#if defined(PHMAP_HAS_UMUL128)
template<>
struct phmap_mix<8>
{
// Very fast mixing (similar to Abseil)
inline size_t operator()(size_t a) const
{
static constexpr uint64_t k = 0xde5fb9d2630458e9ULL;
// static constexpr uint64_t k = 0x7C9D0BF0567102A5ULL; // [greg] my own random prime
uint64_t h;
uint64_t l = umul128(a, k, &h);
return static_cast<size_t>(h + l);
}
};
#else
template<>
struct phmap_mix<8>
{
inline size_t operator()(size_t a) const
{
a = (~a) + (a << 21); // a = (a << 21) - a - 1;
a = a ^ (a >> 24);
a = (a + (a << 3)) + (a << 8); // a * 265
a = a ^ (a >> 14);
a = (a + (a << 2)) + (a << 4); // a * 21
a = a ^ (a >> 28);
a = a + (a << 31);
return static_cast<size_t>(a);
}
};
#endif
// --------------------------------------------
template<int n>
struct fold_if_needed
{
inline size_t operator()(uint64_t) const;
};
template<>
struct fold_if_needed<4>
{
inline size_t operator()(uint64_t a) const
{
return static_cast<size_t>(a ^ (a >> 32));
}
};
template<>
struct fold_if_needed<8>
{
inline size_t operator()(uint64_t a) const
{
return static_cast<size_t>(a);
}
};
// ---------------------------------------------------------------
// see if class T has a hash_value() friend method
// ---------------------------------------------------------------
template<typename T>
struct has_hash_value
{
private:
typedef std::true_type yes;
typedef std::false_type no;
template<typename U> static auto test(int) -> decltype(hash_value(std::declval<const U&>()) == 1, yes());
template<typename> static no test(...);
public:
static constexpr bool value = std::is_same<decltype(test<T>(0)), yes>::value;
};
#if defined(PHMAP_USE_ABSL_HASH) && !defined(phmap_fwd_decl_h_guard_)
template <class T> using Hash = ::absl::Hash<T>;
#elif !defined(PHMAP_USE_ABSL_HASH)
// ---------------------------------------------------------------
// phmap::Hash
// ---------------------------------------------------------------
template <class T>
struct Hash
{
template <class U, typename std::enable_if<has_hash_value<U>::value, int>::type = 0>
size_t _hash(const T& val) const
{
return hash_value(val);
}
template <class U, typename std::enable_if<!has_hash_value<U>::value, int>::type = 0>
size_t _hash(const T& val) const
{
return std::hash<T>()(val);
}
inline size_t operator()(const T& val) const
{
return _hash<T>(val);
}
};
template <class T>
struct Hash<T *>
{
inline size_t operator()(const T *val) const noexcept
{
return static_cast<size_t>(reinterpret_cast<const uintptr_t>(val));
}
};
template<class ArgumentType, class ResultType>
struct phmap_unary_function
{
typedef ArgumentType argument_type;
typedef ResultType result_type;
};
template <>
struct Hash<bool> : public phmap_unary_function<bool, size_t>
{
inline size_t operator()(bool val) const noexcept
{ return static_cast<size_t>(val); }
};
template <>
struct Hash<char> : public phmap_unary_function<char, size_t>
{
inline size_t operator()(char val) const noexcept
{ return static_cast<size_t>(val); }
};
template <>
struct Hash<signed char> : public phmap_unary_function<signed char, size_t>
{
inline size_t operator()(signed char val) const noexcept
{ return static_cast<size_t>(val); }
};
template <>
struct Hash<unsigned char> : public phmap_unary_function<unsigned char, size_t>
{
inline size_t operator()(unsigned char val) const noexcept
{ return static_cast<size_t>(val); }
};
template <>
struct Hash<wchar_t> : public phmap_unary_function<wchar_t, size_t>
{
inline size_t operator()(wchar_t val) const noexcept
{ return static_cast<size_t>(val); }
};
template <>
struct Hash<int16_t> : public phmap_unary_function<int16_t, size_t>
{
inline size_t operator()(int16_t val) const noexcept
{ return static_cast<size_t>(val); }
};
template <>
struct Hash<uint16_t> : public phmap_unary_function<uint16_t, size_t>
{
inline size_t operator()(uint16_t val) const noexcept
{ return static_cast<size_t>(val); }
};
template <>
struct Hash<int32_t> : public phmap_unary_function<int32_t, size_t>
{
inline size_t operator()(int32_t val) const noexcept
{ return static_cast<size_t>(val); }
};
template <>
struct Hash<uint32_t> : public phmap_unary_function<uint32_t, size_t>
{
inline size_t operator()(uint32_t val) const noexcept
{ return static_cast<size_t>(val); }
};
template <>
struct Hash<int64_t> : public phmap_unary_function<int64_t, size_t>
{
inline size_t operator()(int64_t val) const noexcept
{ return fold_if_needed<sizeof(size_t)>()(static_cast<uint64_t>(val)); }
};
template <>
struct Hash<uint64_t> : public phmap_unary_function<uint64_t, size_t>
{
inline size_t operator()(uint64_t val) const noexcept
{ return fold_if_needed<sizeof(size_t)>()(val); }
};
template <>
struct Hash<float> : public phmap_unary_function<float, size_t>
{
inline size_t operator()(float val) const noexcept
{
// -0.0 and 0.0 should return same hash
uint32_t *as_int = reinterpret_cast<uint32_t *>(&val);
return (val == 0) ? static_cast<size_t>(0) :
static_cast<size_t>(*as_int);
}
};
template <>
struct Hash<double> : public phmap_unary_function<double, size_t>
{
inline size_t operator()(double val) const noexcept
{
// -0.0 and 0.0 should return same hash
uint64_t *as_int = reinterpret_cast<uint64_t *>(&val);
return (val == 0) ? static_cast<size_t>(0) :
fold_if_needed<sizeof(size_t)>()(*as_int);
}
};
#endif
template <class H, int sz> struct Combiner
{
H operator()(H seed, size_t value);
};
template <class H> struct Combiner<H, 4>
{
H operator()(H seed, size_t value)
{
return seed ^ (value + 0x9e3779b9 + (seed << 6) + (seed >> 2));
}
};
template <class H> struct Combiner<H, 8>
{
H operator()(H seed, size_t value)
{
return seed ^ (value + size_t(0xc6a4a7935bd1e995) + (seed << 6) + (seed >> 2));
}
};
// define HashState to combine member hashes... see example below
// -----------------------------------------------------------------------------
template <typename H>
class HashStateBase {
public:
template <typename T, typename... Ts>
static H combine(H state, const T& value, const Ts&... values);
static H combine(H state) { return state; }
};
template <typename H>
template <typename T, typename... Ts>
H HashStateBase<H>::combine(H seed, const T& v, const Ts&... vs)
{
return HashStateBase<H>::combine(Combiner<H, sizeof(H)>()(
seed, phmap::Hash<T>()(v)),
vs...);
}
using HashState = HashStateBase<size_t>;
// -----------------------------------------------------------------------------
#if !defined(PHMAP_USE_ABSL_HASH)
// define Hash for std::pair
// -------------------------
template<class T1, class T2>
struct Hash<std::pair<T1, T2>> {
size_t operator()(std::pair<T1, T2> const& p) const noexcept {
return phmap::HashState().combine(phmap::Hash<T1>()(p.first), p.second);
}
};
// define Hash for std::tuple
// --------------------------
template<class... T>
struct Hash<std::tuple<T...>> {
size_t operator()(std::tuple<T...> const& t) const noexcept {
return _hash_helper(t);
}
private:
template<size_t I = 0, class ...P>
typename std::enable_if<I == sizeof...(P), size_t>::type
_hash_helper(const std::tuple<P...> &) const noexcept { return 0; }
template<size_t I = 0, class ...P>
typename std::enable_if<I < sizeof...(P), size_t>::type
_hash_helper(const std::tuple<P...> &t) const noexcept {
const auto &el = std::get<I>(t);
using el_type = typename std::remove_cv<typename std::remove_reference<decltype(el)>::type>::type;
return Combiner<size_t, sizeof(size_t)>()(
phmap::Hash<el_type>()(el), _hash_helper<I + 1>(t));
}
};
#endif
} // namespace phmap
#ifdef _MSC_VER
#pragma warning(pop)
#endif
#endif // phmap_utils_h_guard_
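// ---------------------------------------------------------------------------
// Example (sketch, not part of the library): specializing phmap::Hash for a
// user-defined struct by combining its members with phmap::HashState, as the
// header comment above suggests. The struct `Person` is illustrative only.
//
//     struct Person { std::string first; std::string last; int age; };
//     namespace phmap {
//         template <> struct Hash<Person> {
//             size_t operator()(const Person &p) const {
//                 return phmap::HashState().combine(0, p.first, p.last, p.age);
//             }
//         };
//     }
// ---------------------------------------------------------------------------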
import torch
import torch.multiprocessing as mp
from typing import Optional, Tuple
from base import BaseSampler, NegativeSampling, SampleOutput
from neighbor_sampler import NeighborSampler, SampleType
class RandomWalkSampler(BaseSampler):
def __init__(
self,
num_nodes: int,
num_layers: int,
graph_data,
workers = 1,
tnb = None,
is_distinct = 0,
policy = "uniform",
edge_weight: Optional[torch.Tensor] = None,
graph_name = None
) -> None:
r"""__init__
Args:
num_nodes: the num of all nodes in the graph
num_layers: the num of layers to be sampled
workers: the number of threads, default value is 1
tnb: neighbor infomation table
is_distinct: 1-need distinct, 0-don't need distinct
policy: "uniform" or "recent" or "weighted"
is_root_ts: 1-base on root's ts, 0-base on parent node's ts
edge_weight: the initial weights of edges
graph_name: the name of graph
"""
super().__init__()
self.sampler = NeighborSampler(
num_nodes=num_nodes,
tnb=tnb,
num_layers=num_layers,
fanout=[1 for i in range(num_layers)],
graph_data=graph_data,
edge_weight = edge_weight,
workers=workers,
policy=policy,
graph_name=graph_name,
is_distinct = is_distinct
)
self.num_layers = num_layers
# Cache the constructor arguments so the helper methods below can report them.
self.num_nodes = num_nodes
self.fanout = [1 for _ in range(num_layers)]
self.workers = workers
self.policy = policy
self.is_distinct = is_distinct
# Assumption: the wrapped NeighborSampler exposes its neighbor table as `tnb`;
# fall back to the table passed to this constructor otherwise.
self.tnb = getattr(self.sampler, 'tnb', tnb)
def _get_sample_info(self):
return self.num_nodes,self.num_layers,self.fanout,self.workers
def _get_sample_options(self):
return {"is_distinct" : self.is_distinct,
"policy" : self.policy,
"with_eid" : self.tnb.with_eid,
"weighted" : self.tnb.weighted,
"with_timestamp" : self.tnb.with_timestamp}
def insert_edges_with_timestamp(
self,
edge_index : torch.Tensor,
eid : torch.Tensor,
timestamp : torch.Tensor,
edge_weight : Optional[torch.Tensor] = None):
row, col = edge_index
# update the node count and the tnb neighbor table
self.num_nodes = self.tnb.update_neighbors_with_time(
row.contiguous(),
col.contiguous(),
timestamp.contiguous(),
eid.contiguous(),
self.is_distinct,
edge_weight.contiguous())
def update_edges_weight(
self,
edge_index : torch.Tensor,
eid : torch.Tensor,
edge_weight : Optional[torch.Tensor] = None):
row, col = edge_index
# update the edge weights stored in tnb
if self.tnb.with_eid:
self.tnb.update_edge_weight(
eid.contiguous(),
col.contiguous(),
edge_weight.contiguous()
)
else:
self.tnb.update_edge_weight(
row.contiguous(),
col.contiguous(),
edge_weight.contiguous()
)
def update_nodes_weight(
self,
nid : torch.Tensor,
node_weight : Optional[torch.Tensor] = None):
# update the node weights stored in tnb
self.tnb.update_node_weight(
nid.contiguous(),
node_weight.contiguous()
)
def update_all_node_weight(
self,
node_weight : torch.Tensor):
# update all node weights stored in tnb
self.tnb.update_all_node_weight(node_weight.contiguous())
def sample_from_nodes(
self,
nodes: torch.Tensor,
with_outer_sample: SampleType,
ts: Optional[torch.Tensor] = None
) -> SampleOutput:
r"""Performs mutilayer sampling from the nodes specified in: nodes
The specific number of layers is determined by parameter: num_layers
returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, list].
Args:
nodes: the list of seed nodes index
with_outer_sample: 0-sample in whole graph structure; 1-sample onehop outer nodel; 2-cross partition sampling
Returns:
sampled_nodes: the node sampled
sampled_edge_index: the edge sampled
"""
return self.sampler.sample_from_nodes(nodes, ts, with_outer_sample)
def sample_from_edges(
self,
edges: torch.Tensor,
ets: Optional[torch.Tensor] = None,
neg_sampling: Optional[NegativeSampling] = None,
with_outer_sample: SampleType = SampleType.Whole
) -> SampleOutput:
r"""Performs sampling from the edges specified in :obj:`index`,
returning a sampled subgraph in the specified output format.
Args:
edges: the list of seed edges index
with_outer_sample: 0-sample in whole graph structure; 1-sample onehop outer nodel; 2-cross partition sampling
edge_label: the label for the seed edges.
neg_sampling: The negative sampling configuration
Returns:
sampled_nodes: the nodes sampled
sampled_edge_index_list: the edges sampled
"""
return self.sampler.sample_from_edges(edges, ets, neg_sampling, with_outer_sample)
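# Implementation note: the random walk is realized by delegating to NeighborSampler
# with a fanout of 1 at every layer (see __init__ above), so each layer extends the
# walk by exactly one sampled neighbor per seed node.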
if __name__=="__main__":
edge_index1 = torch.tensor([[0, 1, 1, 1, 1, 2, 2, 2, 2, 4, 4, 4, 5], # , 3, 3
[1, 0, 2, 0, 4, 1, 3, 0, 3, 3, 5, 0, 2]])# , 2, 5
timeStamp=torch.FloatTensor([1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4])
edge_weight1 = None
num_nodes1 = 6
num_neighbors = 2
# Run the neighbor sampling
from Utils import GraphData
g_data = GraphData(id=0, edge_index=edge_index1, timestamp=timeStamp, data=None, partptr=torch.tensor([0, num_nodes1]))
# Run the random walk sampling
sampler=RandomWalkSampler(num_nodes=num_nodes1,
num_layers=3,
edge_weight=edge_weight1,
graph_data=g_data,
graph_name='a',
workers=4,
is_distinct = 0)
out = sampler.sample_from_nodes(torch.tensor([1,2]),
with_outer_sample=SampleType.Whole,
ts=torch.tensor([1, 2]))
# out = sampler.sample_from_edges(torch.tensor([[1,2],[4,0]]),
# with_outer_sample=SampleType.Whole,
# ets = torch.tensor([1, 2]))
# Print the result
print('node:', out.node)
print('edge_index_list:', out.edge_index_list)
print('eid_list:', out.eid_list)
print('eid_ts_list:', out.eid_ts_list)
print('metadata: ', out.metadata)
import argparse
import random
import pandas as pd
import numpy as np
import torch
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric.datasets import Reddit
import time
from tqdm import tqdm
from Utils import GraphData
class NegLinkSampler:
def __init__(self, num_nodes):
self.num_nodes = num_nodes
def sample(self, n):
return np.random.randint(self.num_nodes, size=n)
class NegLinkInductiveSampler:
def __init__(self, nodes):
self.nodes = list(nodes)
def sample(self, n):
return np.random.choice(self.nodes, size=n)
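# Minimal usage sketch (hypothetical numbers): both samplers return an array of
# candidate negative destination node ids for a batch of positive edges.
#   neg = NegLinkSampler(num_nodes=1000)
#   neg_dst = neg.sample(600)               # 600 random ids in [0, 1000)
#   ind = NegLinkInductiveSampler(nodes={1, 5, 9})
#   ind_neg_dst = ind.sample(600)           # 600 ids drawn only from the given node set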
def load_reddit_dataset(data_path):
df = pd.read_csv('/home/zlj/hzq/project/code/TGL/DATA/{}/edges.csv'.format("REDDIT"))
num_nodes = max(int(df['src'].max()), int(df['dst'].max())) + 1
src = torch.tensor(df['src'].to_numpy(dtype=int))
dst = torch.tensor(df['dst'].to_numpy(dtype=int))
edge_index = torch.stack([src, dst])
timestamp = torch.tensor(df['time']).float()
g = GraphData(0, edge_index, timestamp=timestamp, data=None, partptr=torch.tensor([0, num_nodes]))
return g, df
parser=argparse.ArgumentParser()
parser.add_argument('--data', type=str, help='dataset name',default="REDDIT")
parser.add_argument('--config', type=str, help='path to config file',default="/home/zlj/hzq/project/code/TGL/config/TGN.yml")
parser.add_argument('--batch_size', type=int, default=600, help='batch size')
parser.add_argument('--num_thread', type=int, default=64, help='number of threads')
args=parser.parse_args()
seed=10
torch.manual_seed(seed) # set the random seed for the CPU
torch.cuda.manual_seed(seed) # set the random seed for the current GPU
torch.cuda.manual_seed_all(seed) # set the random seed for all GPUs (multi-GPU)
np.random.seed(seed) # Numpy module.
random.seed(seed) # Python random module.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
g_data, df = load_reddit_dataset("/home/zlj/hzq/code/dataset/reddit")
print(g_data)
# for worker in [1,2,3,4,5,6,7,8,9,10,20,30]:
# import random
# timestamp = [random.randint(1, 5) for i in range(0, g.num_edges)]
# timestamp = torch.FloatTensor(timestamp)
# print('begin load')
# pre = time.time()
# # timestamp = torch.load('/home/zlj/hzq/code/gnn/my_sampler/TemporalSample/timestamp.my')
# tnb = torch.load("tnb_reddit_before.my")
# end = time.time()
# print("load time:", end-pre)
# row, col = g.edge_index
edge_weight=None
# g_data = GraphData(id=1, edge_index=g.edge_index, timestamp=timestamp, data=g, partptr=torch.tensor([0, g.num_nodes//4, g.num_nodes//4*2, g.num_nodes//4*3, g.num_nodes]))
from neighbor_sampler import NeighborSampler, SampleType, get_neighbors
print('begin tnb')
src, dst = g_data.edge_index
# add both edge directions so the neighbor table treats the graph as undirected
row = torch.cat([src, dst])
col = torch.cat([dst, src])
eid = torch.cat([g_data.eid, g_data.eid])
timestamp = torch.cat([g_data.edge_ts, g_data.edge_ts])
timestamp,ind = timestamp.sort()
timestamp = timestamp.float().contiguous()
eid = eid[ind].contiguous()
row = row[ind]
col = col[ind]
print(row, col)
pre = time.time()
tnb = get_neighbors("reddit", row.contiguous(), col.contiguous(), g_data.num_nodes, 0, eid, edge_weight, timestamp)
end = time.time()
print("init tnb time:", end-pre)
torch.save(tnb, "tnb_reddit_before.my")
pre = time.time()
sampler = NeighborSampler(g_data.num_nodes,
tnb=tnb,
num_layers=1,
fanout=[10],
graph_data=g_data,
workers=10,
policy="recent",
graph_name='a')
end = time.time()
print("init time:", end-pre)
# neg_link_sampler = NegLinkSampler(g_data.num_nodes)
from base import NegativeSampling, NegativeSamplingMode
neg_link_sampler = NegativeSampling(NegativeSamplingMode.triplet)
# from torch_geometric.sampler import NeighborSampler, NumNeighbors, NodeSamplerInput, SamplerOutput
# pre = time.time()
# num_nei = NumNeighbors([100, 100])
# node_idx = NodeSamplerInput(input_id=None, node=torch.tensor(range(g.num_nodes//4, g.num_nodes//4+600000)))# (input_id=None, node=torch.masked_select(torch.arange(g.num_nodes),node_data['train_mask']))
# sampler = NeighborSampler(g, num_nei)
# end = time.time()
# print("init time:", end-pre)
out = []
tot_time = 0
sam_time = 0
sam_edge = 0
pre = time.time()
for _, rows in tqdm(df.groupby(df.index // args.batch_size), total=len(df) // args.batch_size):
# root_nodes = torch.tensor(np.concatenate([rows.src.values, rows.dst.values, neg_link_sampler.sample(len(rows))])).long()
# ts = torch.tensor(np.concatenate([rows.time.values, rows.time.values, rows.time.values]).astype(np.float32))
# outi = sampler.sample_from_nodes(root_nodes, ts=ts)
edges = torch.tensor(np.stack([rows.src.values, rows.dst.values])).long()
outi, meta = sampler.sample_from_edges(edges=edges, ets=torch.tensor(rows.time.values).float(), neg_sampling=neg_link_sampler)
tot_time += outi[0].tot_time
sam_time += outi[0].sample_time
# print(outi[0].sample_edge_num)
sam_edge += outi[0].sample_edge_num
out.append(outi)
end = time.time()
print("sample time", end-pre)
print("tot_time", tot_time)
print("sam_time", sam_time)
print("sam_edge", sam_edge)
print('eid_list:', out[23][0].eid())
# print('delta_ts_list:', out[10][0].delta_ts)
print('node:', out[23][0].sample_nodes())
# print('node_ts:', out[23][0].sample_nodes_ts)
# print('eid_list:', out[23][1].eid)
# print('node:', out[23][1].sample_nodes)
# print('node_ts:', out[23][1].sample_nodes_ts)
# print('edge_index_list:', out[0][0].edge_index)
\ No newline at end of file
#include <iostream>
#include <omp.h>
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#include <torch/extension.h>
#include <time.h>
#include <random>
#include <phmap.h>
#include <boost/thread/mutex.hpp>
#define MTX boost::mutex
using namespace std;
namespace py = pybind11;
namespace th = torch;
typedef int64_t NodeIDType;
typedef int64_t EdgeIDType;
typedef float WeightType;
typedef float TimeStampType;
#define EXTRAARGS , phmap::priv::hash_default_hash<K>, \
phmap::priv::hash_default_eq<K>, \
std::allocator<K>, 4, MTX
template <class K>
using HashT = phmap::parallel_flat_hash_set<K EXTRAARGS>;
template <class K, class V>
using HashM = phmap::parallel_flat_hash_map<K, V EXTRAARGS>;
class TemporalNeighborBlock;
class TemporalGraphBlock;
TemporalNeighborBlock& get_neighbors(string graph_name, th::Tensor row, th::Tensor col, int64_t num_nodes, int is_distinct, optional<th::Tensor> eid, optional<th::Tensor> edge_weight, optional<th::Tensor> time);
th::Tensor heads_unique(th::Tensor array, th::Tensor heads, int threads);
int nodeIdToInOut(NodeIDType nid, int pid, const vector<NodeIDType>& part_ptr);
int nodeIdToPartId(NodeIDType nid, const vector<NodeIDType>& part_ptr);
vector<th::Tensor> divide_nodes_to_part(th::Tensor nodes, const vector<NodeIDType>& part_ptr, int threads);
// vector<int64_t> sample_multinomial(vector<WeightType> weights, int num_samples, bool replacement, default_random_engine e);
NodeIDType sample_multinomial(const vector<WeightType>& weights, default_random_engine& e);
template<typename T>
inline py::array vec2npy(const std::vector<T> &vec)
{
// need to let python garbage collector handle C++ vector memory
// see https://github.com/pybind/pybind11/issues/1042
// non-copy value transfer
auto v = new std::vector<T>(vec);
auto capsule = py::capsule(v, [](void *v)
{ delete reinterpret_cast<std::vector<T> *>(v); });
return py::array(v->size(), v->data(), capsule);
// return py::array(vec.size(), vec.data());
}
/*
* NeighborSampler Utils
*/
class TemporalNeighborBlock
{
public:
vector<vector<NodeIDType>> neighbors;
vector<vector<TimeStampType>> timestamp;
vector<vector<EdgeIDType>> eid;
vector<vector<WeightType>> edge_weight;
vector<phmap::parallel_flat_hash_map<NodeIDType, int64_t>> inverted_index;
vector<int64_t> deg;
vector<phmap::parallel_flat_hash_set<NodeIDType>> neighbors_set;
bool with_eid = false;
bool weighted = false;
bool with_timestamp = false;
TemporalNeighborBlock(){}
// TemporalNeighborBlock(const TemporalNeighborBlock &tnb);
TemporalNeighborBlock(vector<vector<NodeIDType>>& neighbors,
vector<int64_t> &deg):
neighbors(neighbors), deg(deg){}
TemporalNeighborBlock(vector<vector<NodeIDType>>& neighbors,
vector<vector<WeightType>>& edge_weight,
vector<vector<EdgeIDType>>& eid,
vector<int64_t> &deg):
neighbors(neighbors), edge_weight(edge_weight),eid(eid), deg(deg)
{ this->with_eid=true;
this->weighted=true; }
TemporalNeighborBlock(vector<vector<NodeIDType>>& neighbors,
vector<vector<WeightType>>& edge_weight,
vector<vector<TimeStampType>>& timestamp,
vector<vector<EdgeIDType>>& eid,
vector<int64_t> &deg):
neighbors(neighbors), edge_weight(edge_weight), timestamp(timestamp),eid(eid), deg(deg)
{ this->with_eid=true;
this->weighted=true;
this->with_timestamp=true;}
py::array get_node_neighbor(NodeIDType node_id){
return vec2npy(neighbors[node_id]);
}
py::array get_node_neighbor_timestamp(NodeIDType node_id){
return vec2npy(timestamp[node_id]);
}
int64_t get_node_deg(NodeIDType node_id){
return deg[node_id];
}
bool empty(){
return this->deg.empty();
}
void update_edge_weight(th::Tensor row_or_eid, th::Tensor col, th::Tensor edge_weight);
void update_node_weight(th::Tensor nid, th::Tensor node_weight);
void update_all_node_weight(th::Tensor node_weight);
int64_t update_neighbors_with_time(th::Tensor row, th::Tensor col, th::Tensor time, th::Tensor eid, int is_distinct, std::optional<th::Tensor> edge_weight);
std::string serialize() const {
std::ostringstream oss;
// serialize scalar members
oss << with_eid << " " << weighted << " " << with_timestamp << " ";
// serialize vector<vector<T>> members
auto serializeVecVec = [&oss](const auto& vecVec) {
for (const auto& vec : vecVec) {
oss << vec.size() << " ";
for (const auto& elem : vec) {
oss << elem << " ";
}
}
oss << "|"; // 添加一个分隔符以区分不同的 vector
};
serializeVecVec(neighbors);
serializeVecVec(timestamp);
serializeVecVec(eid);
serializeVecVec(edge_weight);
// 序列化 vector<int64_t> 类型成员
oss << deg.size() << " ";
for (const auto& d : deg) {
oss << d << " ";
}
oss << "|";
// serialize inverted_index
for (const auto& map : inverted_index) {
oss << map.size() << " ";
for (const auto& [key, value] : map) {
oss << key << " " << value << " ";
}
}
oss << "|";
// serialize neighbors_set
for (const auto& set : neighbors_set) {
oss << set.size() << " ";
for (const auto& elem : set) {
oss << elem << " ";
}
}
oss << "|";
return oss.str();
}
static TemporalNeighborBlock deserialize(const std::string& s) {
std::istringstream iss(s);
TemporalNeighborBlock tnb;
// deserialize scalar members
iss >> tnb.with_eid >> tnb.weighted >> tnb.with_timestamp;
// deserialize vector<vector<T>> members
auto deserializeVecLong = [&iss](vector<vector<int64_t>>& vecVec) {
std::string segment;
std::getline(iss, segment, '|');
std::istringstream vec_iss(segment);
while (!vec_iss.eof()) {
size_t vec_size;
vec_iss >> vec_size;
if (vec_iss.eof()) break; // stop on trailing whitespace
vector<int64_t> vec(vec_size);
for (size_t i = 0; i < vec_size; ++i) {
vec_iss >> vec[i];
}
vecVec.push_back(vec);
}
};
auto deserializeVecFloat = [&iss](vector<vector<float>>& vecVec) {
std::string segment;
std::getline(iss, segment, '|');
std::istringstream vec_iss(segment);
while (!vec_iss.eof()) {
size_t vec_size;
vec_iss >> vec_size;
if (vec_iss.eof()) break; // stop on trailing whitespace
vector<float> vec(vec_size);
for (size_t i = 0; i < vec_size; ++i) {
vec_iss >> vec[i];
}
vecVec.push_back(vec);
}
};
deserializeVecLong(tnb.neighbors);
deserializeVecFloat(tnb.timestamp);
deserializeVecLong(tnb.eid);
deserializeVecFloat(tnb.edge_weight);
std::string segment;
// 反序列化 vector<int64_t> 类型成员
segment="";
std::getline(iss, segment, '|');
std::istringstream vec_iss(segment);
size_t vec_size;
vec_iss >> vec_size;
tnb.deg.resize(vec_size);
for (size_t i = 0; i < vec_size; ++i) {
vec_iss >> tnb.deg[i];
}
// deserialize inverted_index
segment="";
std::getline(iss, segment, '|');
std::istringstream map_iss(segment);
while (!map_iss.eof()) {
size_t map_size;
map_iss >> map_size;
if (map_iss.eof()) break;
phmap::parallel_flat_hash_map<NodeIDType, int64_t> map;
for (size_t i = 0; i < map_size; ++i) {
NodeIDType key;
int64_t value;
map_iss >> key >> value;
map[key] = value;
}
tnb.inverted_index.push_back(map);
}
// deserialize neighbors_set
std::getline(iss, segment, '|');
std::istringstream set_iss(segment);
while (!set_iss.eof()) {
size_t set_size;
set_iss >> set_size;
if (set_iss.eof()) break;
phmap::parallel_flat_hash_set<NodeIDType> set;
for (size_t i = 0; i < set_size; ++i) {
NodeIDType elem;
set_iss >> elem;
set.insert(elem);
}
tnb.neighbors_set.push_back(set);
}
return tnb;
}
};
class TemporalGraphBlock
{
public:
vector<NodeIDType> row;
vector<NodeIDType> col;
vector<EdgeIDType> eid;
vector<TimeStampType> delta_ts;
vector<int64_t> src_index;
vector<NodeIDType> sample_nodes;
vector<TimeStampType> sample_nodes_ts;
double sample_time = 0;
double tot_time = 0;
int64_t sample_edge_num = 0;
TemporalGraphBlock(){}
// TemporalGraphBlock(const TemporalGraphBlock &tgb);
TemporalGraphBlock(vector<NodeIDType> &_row, vector<NodeIDType> &_col,
vector<NodeIDType> &_sample_nodes):
row(_row), col(_col), sample_nodes(_sample_nodes){}
TemporalGraphBlock(vector<NodeIDType> &_row, vector<NodeIDType> &_col,
vector<NodeIDType> &_sample_nodes,
vector<TimeStampType> &_sample_nodes_ts):
row(_row), col(_col), sample_nodes(_sample_nodes),
sample_nodes_ts(_sample_nodes_ts){}
};
class T_TemporalGraphBlock
{
public:
th::Tensor row;
th::Tensor col;
th::Tensor eid;
th::Tensor delta_ts;
th::Tensor src_index;
th::Tensor sample_nodes;
th::Tensor sample_nodes_ts;
double sample_time = 0;
double tot_time = 0;
int64_t sample_edge_num = 0;
T_TemporalGraphBlock(){}
T_TemporalGraphBlock(th::Tensor &_row, th::Tensor &_col,
th::Tensor &_sample_nodes):
row(_row), col(_col), sample_nodes(_sample_nodes){}
};
// Helper functions
template <typename T>
T* get_data_ptr(const th::Tensor& tensor) {
AT_ASSERTM(tensor.is_contiguous(), "Offset tensor must be contiguous");
AT_ASSERTM(tensor.dim() == 1, "Offset tensor must be one-dimensional");
return tensor.data_ptr<T>();
}
template <typename T>
torch::Tensor vecToTensor(const std::vector<T>& vec) {
// determine the data type
torch::ScalarType dtype;
if (std::is_same<T, int64_t>::value) {
dtype = torch::kInt64;
} else if (std::is_same<T, float>::value) {
dtype = torch::kFloat32;
} else {
throw std::runtime_error("Unsupported data type");
}
// create a tensor that aliases the vector's memory (no copy is made)
torch::Tensor tensor = torch::from_blob(
const_cast<T*>(vec.data()), /* data pointer */
{static_cast<long>(vec.size())}, /* size */
dtype /* data type */
);
return tensor;//.clone(); // clone here if an owning copy of the data is needed
}
TemporalNeighborBlock& get_neighbors(
string graph_name, th::Tensor row, th::Tensor col, int64_t num_nodes, int is_distinct, optional<th::Tensor> eid, optional<th::Tensor> edge_weight, optional<th::Tensor> time)
{ // row, col and time must be sorted by time in ascending order (earliest first)
auto src = get_data_ptr<NodeIDType>(row);
auto dst = get_data_ptr<NodeIDType>(col);
EdgeIDType* eid_ptr = eid ? get_data_ptr<EdgeIDType>(eid.value()) : nullptr;
WeightType* ew = edge_weight ? get_data_ptr<WeightType>(edge_weight.value()) : nullptr;
TimeStampType* t = time ? get_data_ptr<TimeStampType>(time.value()) : nullptr;
int64_t edge_num = row.size(0);
static phmap::parallel_flat_hash_map<string, TemporalNeighborBlock> tnb_map;
if(tnb_map.count(graph_name)==1)
return tnb_map[graph_name];
tnb_map[graph_name] = TemporalNeighborBlock();
TemporalNeighborBlock& tnb = tnb_map[graph_name];
double start_time = omp_get_wtime();
// initialization
tnb.neighbors.resize(num_nodes);
tnb.deg.resize(num_nodes, 0);
// initialize the optional members
tnb.with_eid = eid.has_value();
tnb.weighted = edge_weight.has_value();
tnb.with_timestamp = time.has_value();
if (tnb.with_eid) tnb.eid.resize(num_nodes);
if (tnb.weighted) {
tnb.edge_weight.resize(num_nodes);
tnb.inverted_index.resize(num_nodes);
}
if (tnb.with_timestamp) tnb.timestamp.resize(num_nodes);
// build the table; the feature checks are hoisted out of the loops for efficiency
for(int64_t i=0; i<edge_num; i++){
// collect the neighbors of each node
tnb.neighbors[dst[i]].emplace_back(src[i]);
}
// insert edge ids if available
if(tnb.with_eid)
for(int64_t i=0; i<edge_num; i++){
tnb.eid[dst[i]].emplace_back(eid_ptr[i]);
}
// if edge weights are available, insert them together with the inverted index
if(tnb.weighted)
for(int64_t i=0; i<edge_num; i++){
tnb.edge_weight[dst[i]].emplace_back(ew[i]);
if(tnb.with_eid) tnb.inverted_index[dst[i]][eid_ptr[i]]=tnb.neighbors[dst[i]].size()-1;
else tnb.inverted_index[dst[i]][src[i]]=tnb.neighbors[dst[i]].size()-1;
}
// if timestamps are available, insert the edge timestamps
if(tnb.with_timestamp)
for(int64_t i=0; i<edge_num; i++){
tnb.timestamp[dst[i]].emplace_back(t[i]);
}
if(is_distinct){
for(int64_t i=0; i<num_nodes; i++){
// collect deduplicated node degrees
phmap::parallel_flat_hash_set<NodeIDType> temp_s;
temp_s.insert(tnb.neighbors[i].begin(), tnb.neighbors[i].end());
tnb.neighbors_set.emplace_back(temp_s);
tnb.deg[i] = tnb.neighbors_set[i].size();
}
}
else{
for(int64_t i=0; i<num_nodes; i++){
// collect raw node degrees
tnb.deg[i] = tnb.neighbors[i].size();
}
}
double end_time = omp_get_wtime();
cout<<"get_neighbors consume: "<<end_time-start_time<<"s"<<endl;
return tnb;
}
void TemporalNeighborBlock::update_edge_weight(
th::Tensor row_or_eid, th::Tensor col, th::Tensor edge_weight){
AT_ASSERTM(this->weighted, "This Graph has no edge weight infomation");
auto dst = get_data_ptr<NodeIDType>(col);
WeightType* ew = get_data_ptr<WeightType>(edge_weight);
NodeIDType* src;
EdgeIDType* eid_ptr;
if(this->with_eid) eid_ptr = get_data_ptr<EdgeIDType>(row_or_eid);
else src = get_data_ptr<NodeIDType>(row_or_eid);
int64_t edge_num = col.size(0);
for(int64_t i=0; i<edge_num; i++){
//修改节点与邻居边的权重
AT_ASSERTM(this->inverted_index[dst[i]].count(src[i])==1, "Unexist Edge Index: "+to_string(src[i])+", "+to_string(dst[i]));
int index;
if(this->with_eid) index = this->inverted_index[dst[i]][eid_ptr[i]];
else index = this->inverted_index[dst[i]][src[i]];
this->edge_weight[dst[i]][index] = ew[i];
}
}
void TemporalNeighborBlock::update_node_weight(th::Tensor nid, th::Tensor node_weight){
AT_ASSERTM(this->weighted, "This Graph has no edge weight information");
auto dst = get_data_ptr<NodeIDType>(nid);
WeightType* nw = get_data_ptr<WeightType>(node_weight);
int64_t node_num = nid.size(0);
for(int64_t i=0; i<node_num; i++){
// update the weights of all edges incident to this node
AT_ASSERTM(dst[i]<this->deg.size(), "Non-existent node index: "+to_string(dst[i]));
if(this->inverted_index[dst[i]].empty())
continue;
for(auto index : this->inverted_index[dst[i]]){
this->edge_weight[dst[i]][index.second] = nw[i];
}
}
}
void TemporalNeighborBlock::update_all_node_weight(th::Tensor node_weight){
AT_ASSERTM(this->weighted, "This Graph has no edge weight information");
WeightType* nw = get_data_ptr<WeightType>(node_weight);
int64_t node_num = node_weight.size(0);
AT_ASSERTM(node_num==this->neighbors.size(), "The node_weight tensor size does not match the number of nodes.");
for(int64_t i=0; i<node_num; i++){
// set each incident edge weight to the weight of the neighbor node
for(int j=0; j<this->neighbors[i].size();j++){
this->edge_weight[i][j] = nw[this->neighbors[i][j]];
}
}
}
int64_t TemporalNeighborBlock::update_neighbors_with_time(
th::Tensor row, th::Tensor col, th::Tensor time,th::Tensor eid, int is_distinct, std::optional<th::Tensor> edge_weight){
// row, col and time must be sorted by time in ascending order (earliest first)
AT_ASSERTM(this->empty(), "Empty TemporalNeighborBlock, please use get_neighbors_with_time");
AT_ASSERTM(this->with_timestamp == true, "This Graph has no time infomation!");
auto src = get_data_ptr<NodeIDType>(row);
auto dst = get_data_ptr<NodeIDType>(col);
auto eid_ptr = get_data_ptr<EdgeIDType>(eid);
auto t = get_data_ptr<TimeStampType>(time);
WeightType* ew = edge_weight ? get_data_ptr<WeightType>(edge_weight.value()) : nullptr;
int64_t edge_num = row.size(0);
int64_t num_nodes = this->neighbors.size();
// validate the optional values
if(edge_weight.has_value()){
AT_ASSERTM(this->weighted == true, "This Graph has no edge weight");
}
if(this->weighted){
AT_ASSERTM(edge_weight.has_value(), "This Graph needs edge weights");
}
// double start_time = omp_get_wtime();
if(is_distinct){
for(int64_t i=0; i<edge_num; i++){
// grow the table if a new node id appears
if(dst[i]>=num_nodes){
num_nodes = dst[i]+1;
this->neighbors.resize(num_nodes);
this->deg.resize(num_nodes, 0);
this->eid.resize(num_nodes);
this->timestamp.resize(num_nodes);
// initialize the optional members
if (this->weighted) {
this->edge_weight.resize(num_nodes);
this->inverted_index.resize(num_nodes);
}
}
// update node neighbors
this->neighbors[dst[i]].emplace_back(src[i]);
// insert the edge id
this->eid[dst[i]].emplace_back(eid_ptr[i]);
// insert the edge timestamp
this->timestamp[dst[i]].emplace_back(t[i]);
// if edge weights are available, insert them together with the inverted index
if(this->weighted){
this->edge_weight[dst[i]].emplace_back(ew[i]);
if(this->with_eid) this->inverted_index[dst[i]][eid_ptr[i]]=this->neighbors[dst[i]].size()-1;
else this->inverted_index[dst[i]][src[i]]=this->neighbors[dst[i]].size()-1;
}
this->neighbors_set[dst[i]].insert(src[i]);
this->deg[dst[i]]=this->neighbors_set[dst[i]].size();
}
}
else{
for(int64_t i=0; i<edge_num; i++){
// update node neighbors
this->neighbors[dst[i]].emplace_back(src[i]);
// insert the edge id
this->eid[dst[i]].emplace_back(eid_ptr[i]);
// insert the edge timestamp
this->timestamp[dst[i]].emplace_back(t[i]);
// if edge weights are available, insert them together with the inverted index
if(this->weighted){
this->edge_weight[dst[i]].emplace_back(ew[i]);
this->inverted_index[dst[i]][src[i]]=this->neighbors[dst[i]].size()-1;
}
this->deg[dst[i]]=this->neighbors[dst[i]].size();
}
}
// double end_time = omp_get_wtime();
// cout<<"update_neighbors consume: "<<end_time-start_time<<"s"<<endl;
return num_nodes;
}
class ParallelSampler
{
public:
TemporalNeighborBlock& tnb;
NodeIDType num_nodes;
EdgeIDType num_edges;
int threads;
vector<int> fanouts;
// vector<NodeIDType> part_ptr;
// int pid;
int num_layers;
string policy;
std::vector<TemporalGraphBlock> ret;
ParallelSampler(TemporalNeighborBlock& _tnb, NodeIDType _num_nodes, EdgeIDType _num_edges, int _threads,
vector<int>& _fanouts, int _num_layers, string _policy) :
tnb(_tnb), num_nodes(_num_nodes), num_edges(_num_edges), threads(_threads),
fanouts(_fanouts), num_layers(_num_layers), policy(_policy)
{
omp_set_num_threads(_threads);
ret.clear();
ret.resize(_num_layers);
}
void reset()
{
ret.clear();
ret.resize(num_layers);
}
void neighbor_sample_from_nodes(th::Tensor nodes, optional<th::Tensor> root_ts);
void neighbor_sample_from_nodes_static(th::Tensor nodes);
void neighbor_sample_from_nodes_static_layer(th::Tensor nodes, int cur_layer);
void neighbor_sample_from_nodes_with_before(th::Tensor nodes, th::Tensor root_ts);
void neighbor_sample_from_nodes_with_before_layer(th::Tensor nodes, th::Tensor root_ts, int cur_layer);
};
void ParallelSampler :: neighbor_sample_from_nodes(th::Tensor nodes, optional<th::Tensor> root_ts)
{
omp_set_num_threads(threads);
if(policy == "weighted")
AT_ASSERTM(tnb.weighted, "Tnb has no weight information!");
else if(policy == "recent")
AT_ASSERTM(tnb.with_timestamp, "Tnb has no timestamp information!");
else if(policy == "uniform")
;
else{
throw runtime_error("The policy \"" + policy + "\" does not exist!");
}
if(tnb.with_timestamp){
AT_ASSERTM(tnb.with_timestamp, "Tnb has no timestamp information!");
AT_ASSERTM(root_ts.has_value(), "Parameter mismatch!");
neighbor_sample_from_nodes_with_before(nodes, root_ts.value());
}
else{
neighbor_sample_from_nodes_static(nodes);
}
}
void ParallelSampler :: neighbor_sample_from_nodes_static_layer(th::Tensor nodes, int cur_layer){
py::gil_scoped_release release;
double tot_start_time = omp_get_wtime();
TemporalGraphBlock tgb = TemporalGraphBlock();
int fanout = fanouts[cur_layer];
ret[cur_layer] = TemporalGraphBlock();
auto nodes_data = get_data_ptr<NodeIDType>(nodes);
vector<phmap::parallel_flat_hash_set<NodeIDType>> node_s_threads(threads);
phmap::parallel_flat_hash_set<NodeIDType> node_s;
vector<vector<NodeIDType>> eid_threads(threads);//row_threads(threads),col_threads(threads);
AT_ASSERTM(tnb.with_eid, "Tnb has no eid information! We need eid!");
// double start_time = omp_get_wtime();
int reserve_capacity = int(ceil(static_cast<double>(nodes.size(0)) / threads)) * fanout;
#pragma omp parallel
{
int tid = omp_get_thread_num();
unsigned int loc_seed = tid;
// row_threads[tid].reserve(reserve_capacity);
// col_threads[tid].reserve(reserve_capacity);
eid_threads[tid].reserve(reserve_capacity);
#pragma omp for schedule(static, int(ceil(static_cast<float>((nodes.size(0)) / threads))))
for(int64_t i=0; i<nodes.size(0); i++){
// int tid = omp_get_thread_num();
NodeIDType node = nodes_data[i];
vector<NodeIDType> nei(tnb.neighbors[node]);
vector<EdgeIDType> edge;
edge = tnb.eid[node];
double s_start_time = omp_get_wtime();
if(tnb.deg[node]>fanout){
phmap::flat_hash_set<NodeIDType> temp_s;
default_random_engine e(8);//(time(0));
uniform_int_distribution<> u(0, tnb.deg[node]-1);
while(temp_s.size()!=fanout){
// keep drawing until fanout distinct neighbors are selected
NodeIDType indice;
if(policy == "weighted"){//考虑边权重信息
const vector<WeightType>& ew = tnb.edge_weight[node];
indice = sample_multinomial(ew, e);
}
else if(policy == "uniform"){//均匀采样
indice = u(e);
}
auto chosen_n_iter = nei.begin() + indice;
auto rst = temp_s.insert(*chosen_n_iter);
if(rst.second){ // not a duplicate
auto chosen_e_iter = edge.begin() + indice;
eid_threads[tid].emplace_back(*chosen_e_iter);
node_s_threads[tid].insert(*chosen_n_iter);
}
}
// row_threads[tid].insert(row_threads[tid].end(),temp_s.begin(),temp_s.end());
// col_threads[tid].insert(col_threads[tid].end(), fanout, node);
}
else{
node_s_threads[tid].insert(nei.begin(), nei.end());
// row_threads[tid].insert(row_threads[tid].end(),nei.begin(),nei.end());
// col_threads[tid].insert(col_threads[tid].end(), tnb.deg[node], node);
eid_threads[tid].insert(eid_threads[tid].end(),edge.begin(), edge.end());
}
if(tid==0)
ret[0].sample_time += omp_get_wtime() - s_start_time;
}
}
// double end_time = omp_get_wtime();
// cout<<"neighbor_sample_from_nodes parallel part consume: "<<end_time-start_time<<"s"<<endl;
int size = 0;
vector<int> each_begin(threads);
for(int i = 0; i<threads; i++){
int s = eid_threads[i].size();
each_begin[i]=size;
size += s;
}
// ret[cur_layer].row.resize(size);
// ret[cur_layer].col.resize(size);
ret[cur_layer].eid.resize(size);
#pragma omp parallel for schedule(static, 1)
for(int i = 0; i<threads; i++){
// copy(row_threads[i].begin(), row_threads[i].end(), ret[cur_layer].row.begin()+each_begin[i]);
// copy(col_threads[i].begin(), col_threads[i].end(), ret[cur_layer].col.begin()+each_begin[i]);
copy(eid_threads[i].begin(), eid_threads[i].end(), ret[cur_layer].eid.begin()+each_begin[i]);
}
for(int i = 0; i<threads; i++)
node_s.insert(node_s_threads[i].begin(), node_s_threads[i].end());
ret[cur_layer].sample_nodes.assign(node_s.begin(), node_s.end());
ret[0].tot_time += omp_get_wtime() - tot_start_time;
ret[0].sample_edge_num += ret[cur_layer].eid.size();
py::gil_scoped_acquire acquire;
}
void ParallelSampler :: neighbor_sample_from_nodes_static(th::Tensor nodes){
for(int i=0;i<num_layers;i++){
if(i==0) neighbor_sample_from_nodes_static_layer(nodes, i);
else neighbor_sample_from_nodes_static_layer(vecToTensor<NodeIDType>(ret[i-1].sample_nodes), i);
}
}
void ParallelSampler :: neighbor_sample_from_nodes_with_before_layer(
th::Tensor nodes, th::Tensor root_ts, int cur_layer){
py::gil_scoped_release release;
double tot_start_time = omp_get_wtime();
ret[cur_layer] = TemporalGraphBlock();
auto nodes_data = get_data_ptr<NodeIDType>(nodes);
auto ts_data = get_data_ptr<TimeStampType>(root_ts);
int fanout = fanouts[cur_layer];
// HashT<pair<NodeIDType,TimeStampType> > node_s;
vector<TemporalGraphBlock> tgb_i(threads);
default_random_engine e(8);//(time(0));
// double start_time = omp_get_wtime();
int reserve_capacity = int(ceil(static_cast<double>(nodes.size(0)) / threads)) * fanout;
#pragma omp parallel
{
int tid = omp_get_thread_num();
unsigned int loc_seed = tid;
tgb_i[tid].sample_nodes.reserve(reserve_capacity);
tgb_i[tid].sample_nodes_ts.reserve(reserve_capacity);
tgb_i[tid].delta_ts.reserve(reserve_capacity);
tgb_i[tid].eid.reserve(reserve_capacity);
tgb_i[tid].src_index.reserve(reserve_capacity);
#pragma omp for schedule(static, int(ceil(static_cast<float>((nodes.size(0)) / threads))))
for(int64_t i=0; i<nodes.size(0); i++){
// int tid = omp_get_thread_num();
NodeIDType node = nodes_data[i];
TimeStampType rtts = ts_data[i];
int end_index = lower_bound(tnb.timestamp[node].begin(), tnb.timestamp[node].end(), rtts)-tnb.timestamp[node].begin();
// cout<<node<<" "<<end_index<<" "<<tnb.deg[node]<<endl;
double s_start_time = omp_get_wtime();
if ((policy == "recent") || (end_index <= fanout)){
int start_index = max(0, end_index-fanout);
tgb_i[tid].src_index.insert(tgb_i[tid].src_index.end(), end_index-start_index, i);
tgb_i[tid].sample_nodes.insert(tgb_i[tid].sample_nodes.end(), tnb.neighbors[node].begin()+start_index, tnb.neighbors[node].begin()+end_index);
tgb_i[tid].sample_nodes_ts.insert(tgb_i[tid].sample_nodes_ts.end(), tnb.timestamp[node].begin()+start_index, tnb.timestamp[node].begin()+end_index);
tgb_i[tid].eid.insert(tgb_i[tid].eid.end(), tnb.eid[node].begin()+start_index, tnb.eid[node].begin()+end_index);
for(int cid = start_index; cid < end_index;cid++){
tgb_i[tid].delta_ts.emplace_back(rtts-tnb.timestamp[node][cid]);
}
}
else{
// more candidate edges than the fanout: randomly select fanout neighbors
tgb_i[tid].src_index.insert(tgb_i[tid].src_index.end(), fanout, i);
uniform_int_distribution<> u(0, end_index-1);
//cout<<end_index<<endl;
// cout<<"start:"<<start_index<<" end:"<<end_index<<endl;
for(int i=0; i<fanout;i++){
int cid;
if(policy == "uniform")
cid = u(e);
// cid = rand_r(&loc_seed) % (end_index);
else if(policy == "weighted"){
const vector<WeightType>& ew = tnb.edge_weight[node];
cid = sample_multinomial(ew, e);
}
tgb_i[tid].sample_nodes.emplace_back(tnb.neighbors[node][cid]);
tgb_i[tid].sample_nodes_ts.emplace_back(tnb.timestamp[node][cid]);
tgb_i[tid].delta_ts.emplace_back(rtts-tnb.timestamp[node][cid]);
tgb_i[tid].eid.emplace_back(tnb.eid[node][cid]);
}
}
if(tid==0)
ret[0].sample_time += omp_get_wtime() - s_start_time;
}
}
// double end_time = omp_get_wtime();
// cout<<"neighbor_sample_from_nodes parallel part consume: "<<end_time-start_time<<"s"<<endl;
// start_time = omp_get_wtime();
int size = 0;
vector<int> each_begin(threads);
for(int i = 0; i<threads; i++){
int s = tgb_i[i].eid.size();
each_begin[i]=size;
size += s;
}
ret[cur_layer].eid.resize(size);
ret[cur_layer].src_index.resize(size);
ret[cur_layer].delta_ts.resize(size);
ret[cur_layer].sample_nodes.resize(size);
ret[cur_layer].sample_nodes_ts.resize(size);
#pragma omp parallel for schedule(static, 1)
for(int i = 0; i<threads; i++){
copy(tgb_i[i].eid.begin(), tgb_i[i].eid.end(), ret[cur_layer].eid.begin()+each_begin[i]);
copy(tgb_i[i].src_index.begin(), tgb_i[i].src_index.end(), ret[cur_layer].src_index.begin()+each_begin[i]);
copy(tgb_i[i].delta_ts.begin(), tgb_i[i].delta_ts.end(), ret[cur_layer].delta_ts.begin()+each_begin[i]);
copy(tgb_i[i].sample_nodes.begin(), tgb_i[i].sample_nodes.end(), ret[cur_layer].sample_nodes.begin()+each_begin[i]);
copy(tgb_i[i].sample_nodes_ts.begin(), tgb_i[i].sample_nodes_ts.end(), ret[cur_layer].sample_nodes_ts.begin()+each_begin[i]);
}
// end_time = omp_get_wtime();
// cout<<"end union consume: "<<end_time-start_time<<"s"<<endl;
ret[0].tot_time += omp_get_wtime() - tot_start_time;
ret[0].sample_edge_num += ret[cur_layer].eid.size();
py::gil_scoped_acquire acquire;
}
void ParallelSampler :: neighbor_sample_from_nodes_with_before(th::Tensor nodes, th::Tensor root_ts){
for(int i=0;i<num_layers;i++){
if(i==0) neighbor_sample_from_nodes_with_before_layer(nodes, root_ts, i);
else neighbor_sample_from_nodes_with_before_layer(vecToTensor<NodeIDType>(ret[i-1].sample_nodes),
vecToTensor<TimeStampType>(ret[i-1].sample_nodes_ts), i);
}
}
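// Note: multi-layer sampling proceeds frontier by frontier; the nodes (and, in the
// timestamped variant, their sampled timestamps) produced by layer i-1 are converted
// back to tensors and become the seed set for layer i.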
/*-------------------------------------------------------------------------------------**
**------------Utils--------------------------------------------------------------------**
**-------------------------------------------------------------------------------------*/
th::Tensor heads_unique(th::Tensor array, th::Tensor heads, int threads){
auto array_ptr = array.data_ptr<NodeIDType>();
phmap::parallel_flat_hash_set<NodeIDType> s(array_ptr, array_ptr+array.numel());
if(heads.numel()==0) return th::tensor(vector<NodeIDType>(s.begin(), s.end()));
AT_ASSERTM(heads.is_contiguous(), "Offset tensor must be contiguous");
AT_ASSERTM(heads.dim() == 1, "Offset tensor must be one-dimensional");
auto heads_ptr = heads.data_ptr<NodeIDType>();
#pragma omp parallel for num_threads(threads)
for(int64_t i=0; i<heads.size(0); i++){
if(s.count(heads_ptr[i])==1){
#pragma omp critical(erase)
s.erase(heads_ptr[i]);
}
}
vector<NodeIDType> ret;
ret.reserve(s.size()+heads.numel());
ret.assign(heads_ptr, heads_ptr+heads.numel());
ret.insert(ret.end(), s.begin(), s.end());
// cout<<"s: "<<s.size()<<" array: "<<array.size()<<endl;
return th::tensor(ret);
}
int nodeIdToPartId(NodeIDType nid, const vector<NodeIDType>& part_ptr){
int partitionId = -1;
for(int i=0;i<part_ptr.size()-1;i++){
if(nid>=part_ptr[i]&&nid<part_ptr[i+1]){
partitionId = i;
break;
}
}
if(partitionId<0) throw "nid 不存在对应的分区";
return partitionId;
}
//0:inner; 1:outer
int nodeIdToInOut(NodeIDType nid, int pid, const vector<NodeIDType>& part_ptr){
if(nid>=part_ptr[pid]&&nid<part_ptr[pid+1]){
return 0;
}
return 1;
}
vector<th::Tensor> divide_nodes_to_part(
th::Tensor nodes, const vector<NodeIDType>& part_ptr, int threads){
double start_time = omp_get_wtime();
AT_ASSERTM(nodes.is_contiguous(), "Offset tensor must be contiguous");
AT_ASSERTM(nodes.dim() == 1, "Offset tensor must be one-dimensional");
auto nodes_id = nodes.data_ptr<NodeIDType>();
vector<vector<vector<NodeIDType>>> node_part_threads;
vector<th::Tensor> result(part_ptr.size()-1);
// initialize per-partition node buckets, one bucket per thread to avoid write conflicts
for(int i = 0; i<threads; i++){
vector<vector<NodeIDType>> node_parts;
for(int j=0;j<part_ptr.size()-1;j++){
node_parts.push_back(vector<NodeIDType>());
}
node_part_threads.push_back(node_parts);
}
#pragma omp parallel for num_threads(threads) default(shared)
for(int64_t i=0; i<nodes.size(0); i++){
int tid = omp_get_thread_num();
int pid = nodeIdToPartId(nodes_id[i], part_ptr);
node_part_threads[tid][pid].emplace_back(nodes_id[i]);
}
#pragma omp parallel for num_threads(part_ptr.size()-1) default(shared)
for(int i = 0; i<part_ptr.size()-1; i++){
vector<NodeIDType> temp;
for(int j=0;j<threads;j++){
temp.insert(temp.end(), node_part_threads[j][i].begin(), node_part_threads[j][i].end());
}
result[i]=th::tensor(temp);
}
double end_time = omp_get_wtime();
// cout<<"end divide consume: "<<end_time-start_time<<"s"<<endl;
return result;
}
NodeIDType sample_multinomial(const vector<WeightType>& weights, default_random_engine& e){
NodeIDType sample_indice;
vector<WeightType> cumulative_weights;
partial_sum(weights.begin(), weights.end(), back_inserter(cumulative_weights));
AT_ASSERTM(cumulative_weights.back() > 0, "Edge weight sum should be greater than 0.");
uniform_real_distribution<WeightType> distribution(0.0, cumulative_weights.back());
WeightType random_value = distribution(e);
auto it = lower_bound(cumulative_weights.begin(), cumulative_weights.end(), random_value);
sample_indice = distance(cumulative_weights.begin(), it);
return sample_indice;
}
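// Worked example (illustrative numbers): for weights {1, 3, 6} the cumulative sums
// are {1, 4, 10}; a random value of 5.2 drawn from [0, 10) lands in the third bucket
// (lower_bound returns index 2), so neighbor 2 is selected with probability 6/10,
// as expected for multinomial sampling.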
/*------------Python Bind--------------------------------------------------------------*/
PYBIND11_MODULE(sample_cores, m)
{
m
.def("get_neighbors",
&get_neighbors,
py::return_value_policy::reference)
.def("heads_unique",
&heads_unique,
py::return_value_policy::reference)
.def("divide_nodes_to_part",
&divide_nodes_to_part,
py::return_value_policy::reference);
py::class_<TemporalGraphBlock>(m, "TemporalGraphBlock")
.def(py::init<vector<NodeIDType> &, vector<NodeIDType> &,
vector<NodeIDType> &>())
.def("row", [](const TemporalGraphBlock &tgb) { return vecToTensor<NodeIDType>(tgb.row); })
.def("col", [](const TemporalGraphBlock &tgb) { return vecToTensor<NodeIDType>(tgb.col); })
.def("eid", [](const TemporalGraphBlock &tgb) { return vecToTensor<EdgeIDType>(tgb.eid); })
.def("delta_ts", [](const TemporalGraphBlock &tgb) { return vecToTensor<TimeStampType>(tgb.delta_ts); })
.def("src_index", [](const TemporalGraphBlock &tgb) { return vecToTensor<EdgeIDType>(tgb.src_index); })
.def("sample_nodes", [](const TemporalGraphBlock &tgb) { return vecToTensor<NodeIDType>(tgb.sample_nodes); })
.def("sample_nodes_ts", [](const TemporalGraphBlock &tgb) { return vecToTensor<TimeStampType>(tgb.sample_nodes_ts); })
.def_readonly("sample_time", &TemporalGraphBlock::sample_time, py::return_value_policy::reference)
.def_readonly("tot_time", &TemporalGraphBlock::tot_time, py::return_value_policy::reference)
.def_readonly("sample_edge_num", &TemporalGraphBlock::sample_edge_num, py::return_value_policy::reference);
py::class_<T_TemporalGraphBlock>(m, "T_TemporalGraphBlock")
.def(py::init<th::Tensor &, th::Tensor &,
th::Tensor &>())
.def_readonly("row", &T_TemporalGraphBlock::row, py::return_value_policy::reference)
.def_readonly("col", &T_TemporalGraphBlock::col, py::return_value_policy::reference)
.def_readonly("eid", &T_TemporalGraphBlock::eid, py::return_value_policy::reference)
.def_readonly("delta_ts", &T_TemporalGraphBlock::delta_ts, py::return_value_policy::reference)
.def_readonly("src_index", &T_TemporalGraphBlock::src_index, py::return_value_policy::reference)
.def_readonly("sample_nodes", &T_TemporalGraphBlock::sample_nodes, py::return_value_policy::reference)
.def_readonly("sample_nodes_ts", &T_TemporalGraphBlock::sample_nodes_ts, py::return_value_policy::reference)
.def_readonly("sample_time", &T_TemporalGraphBlock::sample_time, py::return_value_policy::reference)
.def_readonly("tot_time", &T_TemporalGraphBlock::tot_time, py::return_value_policy::reference)
.def_readonly("sample_edge_num", &T_TemporalGraphBlock::sample_edge_num, py::return_value_policy::reference);
py::class_<TemporalNeighborBlock>(m, "TemporalNeighborBlock")
.def(py::init<vector<vector<NodeIDType>>&,
vector<int64_t> &>())
.def(py::pickle(
[](const TemporalNeighborBlock& tnb) { return tnb.serialize(); },
[](const std::string& s) { return TemporalNeighborBlock::deserialize(s); }
))
.def("update_neighbors_with_time",
&TemporalNeighborBlock::update_neighbors_with_time)
.def("update_edge_weight",
&TemporalNeighborBlock::update_edge_weight)
.def("update_node_weight",
&TemporalNeighborBlock::update_node_weight)
.def("update_all_node_weight",
&TemporalNeighborBlock::update_all_node_weight)
// .def("get_node_neighbor",&TemporalNeighborBlock::get_node_neighbor)
// .def("get_node_deg", &TemporalNeighborBlock::get_node_deg)
.def_readonly("neighbors", &TemporalNeighborBlock::neighbors, py::return_value_policy::reference)
.def_readonly("timestamp", &TemporalNeighborBlock::timestamp, py::return_value_policy::reference)
.def_readonly("edge_weight", &TemporalNeighborBlock::edge_weight, py::return_value_policy::reference)
.def_readonly("eid", &TemporalNeighborBlock::eid, py::return_value_policy::reference)
.def_readonly("deg", &TemporalNeighborBlock::deg, py::return_value_policy::reference)
.def_readonly("with_eid", &TemporalNeighborBlock::with_eid, py::return_value_policy::reference)
.def_readonly("with_timestamp", &TemporalNeighborBlock::with_timestamp, py::return_value_policy::reference)
.def_readonly("weighted", &TemporalNeighborBlock::weighted, py::return_value_policy::reference);
py::class_<ParallelSampler>(m, "ParallelSampler")
.def(py::init<TemporalNeighborBlock &, NodeIDType, EdgeIDType, int,
vector<int>&, int, string>())
.def_readonly("ret", &ParallelSampler::ret, py::return_value_policy::reference)
.def("neighbor_sample_from_nodes", &ParallelSampler::neighbor_sample_from_nodes)
.def("reset", &ParallelSampler::reset)
.def("get_ret", [](const ParallelSampler &ps) { return ps.ret; });
}
\ No newline at end of file
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension
setup(
name='sample_cores',
ext_modules=[
CppExtension(
name='sample_cores',
sources=['sample_cores_dist.cpp'],
extra_compile_args=['-fopenmp', '-Xlinker', ' -export-dynamic', '-O3', '-std=c++17'],
include_dirs=["./parallel_hashmap"],
),
],
cmdclass={
'build_ext': BuildExtension
})
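# A typical local build/usage sketch (assumption: run from the directory that contains
# sample_cores_dist.cpp and the parallel_hashmap headers):
#   python setup.py build_ext --inplace
#   python -c "import sample_cores; print(sample_cores.__file__)"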
import torch
class WorkStreamEvent:
def __init__(self):
self.train_stream = torch.cuda.Stream()
self.write_memory_stream = torch.cuda.Stream()
self.fetch_stream = torch.cuda.Stream()
self.write_mail_stream = torch.cuda.Stream()
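# Minimal usage sketch (hypothetical): run training work on a dedicated CUDA stream so
# memory/mailbox writes can overlap with it.
#   event = WorkStreamEvent()
#   with torch.cuda.stream(event.train_stream):
#       ...  # forward/backward of the current batch
#   torch.cuda.current_stream().wait_stream(event.train_stream)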
\ No newline at end of file
import argparse
import os
import sys
from os.path import abspath, join, dirname
from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex
from starrygl.module.modules import GeneralModel
from starrygl.module.utils import parse_config
from starrygl.sample.graph_core import DataSet, GraphData, TemporalNeighborSampleGraph
from starrygl.sample.memory.shared_mailbox import SharedMailBox
from starrygl.sample.sample_core.base import NegativeSampling
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
from starrygl.sample.part_utils.partition_tgnn import partition_load
import torch
import time
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os
from starrygl.sample.data_loader import DistributedDataLoader
from starrygl.sample.batch_data import SAMPLE_TYPE
"""
test command
python test.py --world_size 2 --rank 0
--world_size', default=4, type=int, metavar='W',
help='number of workers')
parser.add_argument('--rank', default=0, type=int, metavar='W',
help='rank of the worker')
parser.add_argument('--log_interval', type=int, default=10, metavar='N',
help='interval between training status logs')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
help='how much to value future rewards')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed for reproducibility')
parser.add_argument('--num_sampler', type=int, default=10, metavar='S',
help='number of samplers')
parser.add_argument('--queue_size', type=int, default=10, metavar='S',
help='sampler queue size')
"""
parser = argparse.ArgumentParser(
description="distributed TGNN training example",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--rank', default=0, type=int, metavar='W',
help='rank of this worker')
parser.add_argument('--world_size', default=1, type=int, metavar='W',
help='number of workers')
args = parser.parse_args()
from sklearn.metrics import average_precision_score, roc_auc_score
import random
import dgl
import numpy as np
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.rank)
os.environ["RANK"] = str(args.rank)
os.environ["WORLD_SIZE"] = str(args.world_size)
os.environ["LOCAL_RANK"] = str(0)
os.environ["MASTER_ADDR"] = '127.0.0.1'
os.environ["MASTER_PORT"] = '9337'
def seed_everything(seed=42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
seed_everything(1234)
def main():
use_cuda = True
sample_param, memory_param, gnn_param, train_param = parse_config('./config/TGN.yml')
torch.set_num_threads(12)
ctx = DistributedContext.init(backend="nccl", use_gpu=True)
device_id = torch.cuda.current_device()
print('use cuda on',device_id)
pdata = partition_load("./dataset/here/WIKI", algo="metis_for_tgnn")
graph = GraphData(pdata = pdata)
sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph,mode = 'full')
mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat = pdata.edge_attr.shape[1] if pdata.edge_attr is not None else 0)
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10],graph_data=sample_graph, workers=10,policy = 'recent',graph_name = "wiki_train")
train_data = torch.masked_select(graph.edge_index,pdata.train_mask.to(graph.edge_index.device)).reshape(2,-1)
train_ts = torch.masked_select(graph.edge_ts,pdata.train_mask.to(graph.edge_index.device))
val_data = torch.masked_select(graph.edge_index,pdata.val_mask.to(graph.edge_index.device)).reshape(2,-1)
val_ts = torch.masked_select(graph.edge_ts,pdata.val_mask.to(graph.edge_index.device))
test_data = torch.masked_select(graph.edge_index,pdata.test_mask.to(graph.edge_index.device)).reshape(2,-1)
test_ts = torch.masked_select(graph.edge_ts,pdata.test_mask.to(graph.edge_index.device))
print(train_data.shape[1],val_data.shape[1],test_data.shape[1])
train_data = DataSet(edges = train_data,ts =train_ts,eids = torch.nonzero(pdata.train_mask).view(-1))
test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.nonzero(pdata.test_mask).view(-1))
val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.nonzero(pdata.val_mask).view(-1))
#train_neg_sampler = PreNegativeSampling('triplet',torch.masked_select(pdata.edge_index['pos_edge'],graph.data.train_mask).reshape(2,-1))
neg_sampler = NegativeSampling('triplet')
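    # The distributed loaders sample subgraphs for each batch of edges (without
    # shuffling) and are handed the shared mailbox and negative sampler.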
trainloader = DistributedDataLoader(graph,train_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = 2000,
shuffle=False,
drop_last=True,
chunk_size = None,
train=True,
queue_size = 100,
mailbox = mailbox)
testloader = DistributedDataLoader(graph,test_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = 2000,
shuffle=False,
drop_last=False,
chunk_size = None,
train=False,
queue_size = 100,
mailbox = mailbox)
valloader = DistributedDataLoader(graph,val_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = 2000,
shuffle=False,
drop_last=False,
chunk_size = None,
train=False,
queue_size = 100,
mailbox = mailbox)
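    # Node/edge feature dimensions are taken from the partition data; GeneralModel
    # is assembled from the parsed config sections and wrapped in DDP below.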
gnn_dim_node = 0 if graph.x is None else pdata.x.shape[1]
gnn_dim_edge = 0 if graph.edge_attr is None else pdata.edge_attr.shape[1]
avg_time = 0
if use_cuda:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param).cuda()
device = torch.device('cuda')
else:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param)
device = torch.device('cpu')
model = DDP(model,find_unused_parameters=True)
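    # CUDA streams for overlapping compute with communication; only train_stream is
    # used in the training loop below.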
train_stream = torch.cuda.Stream()
send_stream = torch.cuda.Stream()
scatter_stream = torch.cuda.Stream()
val_losses = list()
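    # Evaluate one split. The 'signal' tensor is all-reduced each batch so that a
    # worker that drains its loader early keeps joining the collective calls until
    # every worker has finished.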
def eval(mode='val'):
neg_samples = 1
model.eval()
aps = list()
aucs_mrrs = list()
if mode == 'val':
loader = valloader
elif mode == 'test':
loader = testloader
elif mode == 'train':
loader = trainloader
with torch.no_grad():
total_loss = 0
signal = torch.tensor([0],dtype = int,device = device)
for roots,mfgs,metadata in loader:
signal[0] = 0
dist.all_reduce(signal,async_op=False)
pred_pos, pred_neg = model(mfgs,metadata)
total_loss += creterion(pred_pos, torch.ones_like(pred_pos))
total_loss += creterion(pred_neg, torch.zeros_like(pred_neg))
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
aucs_mrrs.append(roc_auc_score(y_true, y_pred))
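                # Compute the memory/mail updates for this batch during evaluation as
                # well (the all-to-all exchange below is currently commented out).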
if mailbox is not None:
src = metadata['src_pos_index']
dst = metadata['dst_pos_index']
ts = roots.ts
if(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda') if graph.edge_attr is not None else None
else:
edge_feats = graph.edge_attr[roots.eids] if graph.edge_attr is not None else None
dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
last_updated_memory = model.module.memory_updater.last_updated_memory[root_index]
last_updated_ts=model.module.memory_updater.last_updated_ts[root_index]
index, memory, memory_ts = mailbox.get_update_memory(last_updated_nid,
last_updated_memory,
last_updated_ts)
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory,
)
#mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
if mode == 'val':
val_losses.append(float(total_loss))
while(signal[0].item() != dist.get_world_size()):
signal[0] = 1
dist.all_reduce(signal,async_op=False)
if(signal[0].item() == dist.get_world_size()):
break
if mailbox is not None:
mailbox.set_mailbox_all_to_all(torch.tensor([],device = device).reshape(-1),
torch.tensor([],device = device).reshape(-1,mailbox.memory_size),
torch.tensor([],device = device).reshape(-1),
torch.tensor([],device = device).reshape(-1,mailbox.mailbox.accessor.data.size(2)),
torch.tensor([],device = device).reshape(-1),
reduce_Op = 'max')
ap = float(torch.tensor(aps).mean())
if neg_samples > 1:
auc_mrr = float(torch.cat(aucs_mrrs).mean())
else:
auc_mrr = float(torch.tensor(aucs_mrrs).mean())
return ap, auc_mrr
creterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'])
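    # Training loop: node memory and the mailbox are reset at the start of each
    # epoch so that no state leaks across epochs.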
for e in range(train_param['epoch']):
epoch_start_time = time.time()
train_aps = list()
print('Epoch {:d}:'.format(e))
time_prep = 0
total_loss = 0
model.train()
if mailbox is not None:
mailbox.reset()
model.module.memory_updater.last_updated_nid = None
model.module.memory_updater.last_updated_memory = None
model.module.memory_updater.last_updated_ts = None
for roots,mfgs,metadata in trainloader:
t_prep_s = time.time()
optimizer.zero_grad()
with torch.cuda.stream(train_stream):
pred_pos, pred_neg = model(mfgs,metadata)
loss = creterion(pred_pos, torch.ones_like(pred_pos))
loss += creterion(pred_neg, torch.zeros_like(pred_neg))
total_loss += float(loss)
loss.backward()
optimizer.step()
t_prep_s = time.time()
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
train_aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
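            # Collect the memory and mail updates produced by this batch. The edge
            # feature lookup assumes graph.edge_attr is present (true for WIKI); the
            # cross-worker exchange via set_mailbox_all_to_all is commented out here.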
if mailbox is not None:
src = metadata['src_pos_index']
dst = metadata['dst_pos_index']
ts = roots.ts
if(graph.edge_attr.device == torch.device('cpu')):
edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda') if graph.edge_attr is not None else None
else:
edge_feats = graph.edge_attr[roots.eids] if graph.edge_attr is not None else None
dist_index_mapper = mfgs[0][0].srcdata['ID']
root_index = torch.cat((src,dst))
last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
last_updated_memory = model.module.memory_updater.last_updated_memory[root_index]
last_updated_ts=model.module.memory_updater.last_updated_ts[root_index]
index, memory, memory_ts = mailbox.get_update_memory(last_updated_nid,
last_updated_memory,
last_updated_ts)
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory,
)
#mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
torch.cuda.synchronize()
time_prep = time.time() - epoch_start_time
avg_time += time.time() - epoch_start_time
train_ap = float(torch.tensor(train_aps).mean())
ap = 0
auc = 0
ap, auc = eval('val')
        print('\ttrain loss:{:.4f} train ap:{:.4f} val ap:{:.4f} val auc:{:.4f}'.format(total_loss,train_ap, ap, auc))
print('\ttotal time:{:.2f}s prep time:{:.2f}s'.format(time.time()-epoch_start_time, time_prep))
model.eval()
if mailbox is not None:
mailbox.reset()
model.module.memory_updater.last_updated_nid = None
eval('train')
eval('val')
ap, auc = eval('test')
eval_neg_samples = 1
if eval_neg_samples > 1:
        print('\ttest AP:{:.4f} test MRR:{:.4f}'.format(ap, auc))
else:
        print('\ttest AP:{:.4f} test AUC:{:.4f}'.format(ap, auc))
print('test_dataset',test_data.edges.shape[1],'avg_time',avg_time/train_param['epoch'])
ctx.shutdown()
if __name__ == "__main__":
main()