Commit 2f15e921 by zlj

add sample module and tgnn trainer

parent b84f3d4c
def foo():
    """Build three closures that see the local counter ``a`` in different ways.

    Returns:
        A tuple ``(fa, apply, fc)``:
        ``fa`` prints ``a``, increments it, and prints it again;
        ``apply`` prints the value ``a`` had when ``fb`` was called;
        ``fc`` prints the current value of ``a`` at call time.
    """
    a = 1

    def fa():
        # Without ``nonlocal`` the augmented assignment makes ``a`` local to
        # ``fa`` and the first print raises UnboundLocalError.
        nonlocal a
        print(a)
        a += 1
        print(a)

    def fb(a):
        # ``a`` is a parameter here, so ``apply`` captures the passed-in value.
        def apply():
            print(a)
        return apply

    fc = lambda: print(a)
    return fa, fb(a), fc


fa, fb, fc = foo()
fa()
fb()
fc()
sampling:
- layer: 1
neighbor:
- 10
strategy: 'recent'
prop_time: False
history: 1
duration: 0
num_thread: 32
no_neg: True
memory:
- type: 'node'
dim_time: 100
deliver_to: 'neighbors'
mail_combine: 'last'
memory_update: 'transformer'
attention_head: 2
mailbox_size: 10
combine_node_feature: False
dim_out: 100
gnn:
- arch: 'identity'
train:
- epoch: 100
batch_size: 600
lr: 0.0001
dropout: 0.1
att_dropout: 0.1
# all_on_gpu: True
\ No newline at end of file
sampling:
- layer: 2
neighbor:
- 10
- 10
strategy: 'uniform'
prop_time: True
history: 3
duration: 10000
num_thread: 32
memory:
- type: 'none'
dim_out: 0
gnn:
- arch: 'transformer_attention'
layer: 2
att_head: 2
dim_time: 0
dim_out: 100
combine: 'rnn'
train:
- epoch: 50
batch_size: 600
lr: 0.0001
dropout: 0.1
att_dropout: 0.1
all_on_gpu: True
\ No newline at end of file
sampling:
- no_sample: True
history: 1
memory:
- type: 'node'
dim_time: 100
deliver_to: 'self'
mail_combine: 'last'
memory_update: 'rnn'
mailbox_size: 1
combine_node_feature: True
dim_out: 100
gnn:
- arch: 'identity'
time_transform: 'JODIE'
train:
- epoch: 100
batch_size: 600
lr: 0.0001
dropout: 0.1
all_on_gpu: True
\ No newline at end of file
sampling:
- layer: 2
neighbor:
- 10
- 10
strategy: 'uniform'
prop_time: False
history: 1
duration: 0
num_thread: 32
memory:
- type: 'none'
dim_out: 0
gnn:
- arch: 'transformer_attention'
layer: 2
att_head: 2
dim_time: 100
dim_out: 100
train:
- epoch: 10
batch_size: 600
lr: 0.0001
dropout: 0.1
att_dropout: 0.1
all_on_gpu: True
\ No newline at end of file
sampling:
- layer: 1
neighbor:
- 10
strategy: 'recent'
prop_time: False
history: 1
duration: 0
num_thread: 32
memory:
- type: 'node'
dim_time: 100
deliver_to: 'self'
mail_combine: 'last'
memory_update: 'gru'
mailbox_size: 1
combine_node_feature: True
dim_out: 100
gnn:
- arch: 'transformer_attention'
layer: 1
att_head: 2
dim_time: 100
dim_out: 100
train:
- epoch: 10
#batch_size: 100
# reorder: 16
lr: 0.0001
dropout: 0.2
att_dropout: 0.2
all_on_gpu: True
\ No newline at end of file
sampling:
- layer: <number of layers to sample>
neighbor: <a list of integers indicating how many neighbors are sampled in each layer>
strategy: <'recent' that samples most recent neighbors or 'uniform' that uniformly samples neighbors from the past>
prop_time: <False or True that specifies whether to use the timestamp of the root nodes when sampling for their multi-hop neighbors>
history: <number of snapshots to sample on>
duration: <length in time of each snapshot, 0 for infinite length (used in non-snapshot-based methods)>
num_thread: <number of threads of the sampler>
memory:
- type: <'node' (the only memory type supported now) or 'none' to disable the memory>
dim_time: <an integer, the dimension of the time embedding>
deliver_to: <'self' that delivers the mails only to involved nodes or 'neighbors' that deliver the mails to neighbors>
mail_combine: <'last' that uses the latest mail as the input to the memory updater>
memory_update: <'gru', 'rnn' or 'transformer'>
mailbox_size: <an integer, the size of the mailbox for each node>
combine_node_feature: <False or True that specifies whether to combine node features (with the updated memory) as the input to the GNN>
dim_out: <an integer, the dimension of the output node memory>
gnn:
- arch: <'transformer_attention' or 'identity' (no GNN)>
layer: <an integer, number of layers>
att_head: <an integer, number of attention heads>
dim_time: <an integer, the dimension of the time embedding>
dim_out: <an integer, the dimension of the output dynamic node embedding>
train:
- epoch: <an integer, number of epochs to train>
batch_size: <an integer, the batch size (of edges); for multi-gpu training, this is the local batchsize>
reorder: <(optional) an integer that is divisible by batch size and specifies how many chunks per batch are used in the random chunk scheduling>
lr: <floating point, learning rate>
dropout: <floating point, dropout>
att_dropout: <floating point, dropout for attention>
all_on_gpu: <False or True that decides if the node/edge features and node memory are completely stored on GPU>
\ No newline at end of file
import os
import argparse
import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Data
from starrygl.sample.sample_core.base import NegativeSampling
from starrygl.sample.part_utils.partition_tgnn import partition_save
# Command-line interface of this preprocessing script: the dataset name and
# the number of negative samples.
# NOTE(review): the description string below does not describe this script
# (it reads like a leftover from another example) — confirm and reword.
parser = argparse.ArgumentParser(
description="RPC Reinforcement Learning Example",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--data_name', default='WIKI', type=str, metavar='W',
help='name of dataset')
parser.add_argument('--num_neg_sample', default=1, type=int, metavar='W',
help='number of negative samples')
# Parsed at module level, so the arguments are read as soon as the file runs.
args = parser.parse_args()
def load_feat(d, rand_de=0, rand_dn=0):
    """Load node and edge features of dataset ``d`` from ``DATA/<d>/``.

    Args:
        d: dataset name, used as the directory name under ``DATA/``.
        rand_de: if > 0, replace the edge features of 'LASTFM' / 'MOOC' with
            random normal features of this dimension.
        rand_dn: if > 0, replace the node features of 'LASTFM' / 'MOOC' with
            random normal features of this dimension.

    Returns:
        ``(node_feats, edge_feats)``; an entry is ``None`` when the feature
        file does not exist and no random features were requested for ``d``.
    """
    node_feats = None
    node_path = 'DATA/{}/node_features.pt'.format(d)
    if os.path.exists(node_path):
        node_feats = torch.load(node_path)
        if node_feats.dtype == torch.bool:
            node_feats = node_feats.type(torch.float32)

    edge_feats = None
    edge_path = 'DATA/{}/edge_features.pt'.format(d)
    if os.path.exists(edge_path):
        edge_feats = torch.load(edge_path)
        if edge_feats.dtype == torch.bool:
            edge_feats = edge_feats.type(torch.float32)

    if rand_de > 0:
        if d == 'LASTFM':
            edge_feats = torch.randn(1293103, rand_de)
        elif d == 'MOOC':
            edge_feats = torch.randn(411749, rand_de)
    if rand_dn > 0:
        if d == 'LASTFM':
            node_feats = torch.randn(1980, rand_dn)
        elif d == 'MOOC':
            # Fixed: this branch used to overwrite edge_feats instead of
            # producing the requested random node features.
            node_feats = torch.randn(7144, rand_dn)
    return node_feats, edge_feats
# --- Load the raw temporal graph -------------------------------------------
data_name = args.data_name
data_dir = '../tgl_main/DATA/' + data_name
# NOTE(review): `g` is loaded but never read in this script.
g = np.load(data_dir + '/ext_full.npz')
df = pd.read_csv(data_dir + '/edges.csv')

# Feature files are optional; a missing file leaves the feature as None.
node_feat_path = data_dir + '/node_features.pt'
edge_feat_path = data_dir + '/edge_features.pt'
n_feat = torch.load(node_feat_path) if os.path.exists(node_feat_path) else None
e_feat = torch.load(edge_feat_path) if os.path.exists(edge_feat_path) else None

src = torch.from_numpy(np.array(df.src.values)).long()
dst = torch.from_numpy(np.array(df.dst.values)).long()
ts = torch.from_numpy(np.array(df.time.values)).long()
neg_nums = args.num_neg_sample

edge_index = torch.cat((src[np.newaxis, :], dst[np.newaxis, :]), 0)
num_nodes = edge_index.view(-1).max().item()+1
num_edges = edge_index.shape[1]
print('the number of nodes in graph is {}, the number of edges in graph is {}'.format(num_nodes, num_edges))

# --- Sampling graph ----------------------------------------------------------
# Every edge is stored in both directions, interleaved, so edge i occupies
# columns 2*i (src->dst) and 2*i+1 (dst->src) and both share edge id i.
sample_src = torch.cat([src.view(-1, 1), dst.view(-1, 1)], dim=1).reshape(1, -1)
sample_dst = torch.cat([dst.view(-1, 1), src.view(-1, 1)], dim=1).reshape(1, -1)
sample_ts = torch.cat([ts.view(-1, 1), ts.view(-1, 1)], dim=1).reshape(-1)
sample_eid = torch.arange(num_edges).view(-1, 1).repeat(1, 2).reshape(-1)
sample_graph = {
    'edge_index': torch.cat([sample_src, sample_dst], dim=0),
    'ts': sample_ts,
    'eids': sample_eid,
}

# --- Negative samples: neg_nums per edge -------------------------------------
neg_sampler = NegativeSampling('triplet')
neg_src = neg_sampler.sample(edge_index.shape[1]*neg_nums, num_nodes)
neg_sample = neg_src.reshape(-1, neg_nums)
# `ts` is already a tensor, so no numpy round trip is needed.
edge_ts = ts.float()

# --- Assemble the Data object -------------------------------------------------
data = Data()
data.num_nodes = num_nodes
data.num_edges = num_edges
data.edge_index = edge_index
data.edge_ts = edge_ts
data.neg_sample = neg_sample
if n_feat is not None:
    data.x = n_feat
if e_feat is not None:
    data.edge_attr = e_feat

# NOTE(review): ext_roll == 1 is mapped to the test split and == 2 to the
# validation split — confirm this matches the convention of edges.csv.
ext_roll = torch.from_numpy(np.array(df.ext_roll.values))
data.train_mask = (ext_roll == 0)
data.test_mask = (ext_roll == 1)
data.val_mask = (ext_roll == 2)
sample_graph['train_mask'] = data.train_mask[sample_eid]
sample_graph['test_mask'] = data.test_mask[sample_eid]
sample_graph['val_mask'] = data.val_mask[sample_eid]
data.sample_graph = sample_graph
data.y = torch.zeros(edge_index.shape[1])

# --- Edge sets and weights handed to the partitioner -------------------------
edge_index_dict = {}
edge_index_dict['edata'] = data.edge_index
edge_index_dict['sample_data'] = data.sample_graph['edge_index']
edge_index_dict['neg_data'] = torch.cat(
    [neg_src.view(1, -1), dst.view(-1, 1).repeat(1, neg_nums).reshape(1, -1)],
    dim=0)
data.edge_index_dict = edge_index_dict

edge_weight_dict = {}
edge_weight_dict['edata'] = 2*neg_nums
edge_weight_dict['sample_data'] = 1*neg_nums
edge_weight_dict['neg_data'] = 1

# Save a 1-part and a 2-part partition; add more part counts here if needed.
for num_parts in (1, 2):
    partition_save('./dataset/here/'+data_name, data, num_parts, 'metis_for_tgnn',
                   edge_weight_dict=edge_weight_dict)
#!/bin/sh
#conda activate gnn
# Rebuild the extension modules of starrygl.sample in place.
# Abort when a directory is missing so the clean-up below can never run in
# the wrong place; -f keeps the clean-up quiet on a fresh checkout.
cd ./starrygl/sample/sample_core || exit 1
if [ -f "setup.py" ]; then
    rm -rf build
    rm -f sample_cores.cpython-*.so
    python setup.py build_ext --inplace
fi
cd ../part_utils || exit 1
if [ -f "setup.py" ]; then
    rm -rf build
    rm -f torch_utils.cpython-*.so
    python setup.py build_ext --inplace
fi
cd ../../ || exit 1
......@@ -93,6 +93,8 @@ class DistIndex:
# Read-only accessor that returns the object's ``data`` attribute unchanged.
@property
def dist(self) -> Tensor:
return self.data
# Return a new DistIndex wrapping ``self.data.to(device)``; the receiver is
# left untouched. The annotation names the type that is actually returned.
def to(self,device) -> "DistIndex":
return DistIndex(self.data.to(device))
class DistributedTensor:
......@@ -109,6 +111,21 @@ class DistributedTensor:
self.num_nodes = DistInt(local_sizes)
self.num_parts = DistInt([1] * len(self.rrefs))
# dtype of the accessor's underlying ``data``.
@property
def dtype(self):
return self.accessor.data.dtype
# Device of the accessor's underlying ``data``.
@property
def device(self):
return self.accessor.data.device
# Returns ``self.accessor.data.to(device)``, i.e. the accessor's data moved to
# ``device`` rather than a new DistributedTensor.
# NOTE(review): presumably a plain local tensor — confirm callers expect that.
def to(self,device):
return self.accessor.data.to(device)
# Index directly into the accessor's underlying ``data``.
def __getitem__(self,index):
return self.accessor.data[index]
# Context object held by the accessor.
@property
def ctx(self):
return self.accessor.ctx
......
from os.path import abspath, join, dirname
import os
import sys
from os.path import abspath, join, dirname
sys.path.insert(0, join(abspath(dirname(__file__))))
import torch
import dgl
import math
import numpy as np
class TimeEncode(torch.nn.Module):
    """Cosine time encoding ``cos(w * t + b)`` with ``dim`` output channels.

    The linear layer is initialised with weights ``1 / 10**e`` for ``e``
    evenly spaced in ``[0, 9]`` and a zero bias.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.w = torch.nn.Linear(1, dim)
        exponents = np.linspace(0, 9, dim, dtype=np.float32)
        init_weight = torch.from_numpy(1 / 10 ** exponents)
        self.w.weight = torch.nn.Parameter(init_weight.reshape(dim, -1))
        self.w.bias = torch.nn.Parameter(torch.zeros(dim))

    def forward(self, t):
        """Encode time values ``t`` (any shape) into a ``(t.numel(), dim)`` tensor."""
        column = t.float().reshape((-1, 1))
        return torch.cos(self.w(column))
class EdgePredictor(torch.nn.Module):
    """Scores positive and negative edges from stacked node embeddings.

    ``forward`` splits ``h`` along dim 0 into ``neg_samples + 2`` equal
    segments: the first goes through ``dst_fc``, the second through
    ``src_fc``, and the remainder (the negatives) through ``src_fc``.
    """

    def __init__(self, dim_in):
        super().__init__()
        self.dim_in = dim_in
        self.src_fc = torch.nn.Linear(dim_in, dim_in)
        self.dst_fc = torch.nn.Linear(dim_in, dim_in)
        self.out_fc = torch.nn.Linear(dim_in, 1)

    def forward(self, h, neg_samples=1):
        """Return ``(positive_scores, negative_scores)`` for the edges in ``h``."""
        relu = torch.nn.functional.relu
        n = h.shape[0] // (neg_samples + 2)
        first, second, rest = h[:n], h[n:2 * n], h[2 * n:]
        proj_second = self.src_fc(second)
        proj_first = self.dst_fc(first)
        proj_neg = self.src_fc(rest)
        pos_edge = relu(proj_second + proj_first)
        # The first-segment projection is repeated once per negative sample.
        neg_edge = relu(proj_neg + proj_first.tile(neg_samples, 1))
        return self.out_fc(pos_edge), self.out_fc(neg_edge)
class TransfomerAttentionLayer(torch.nn.Module):
    """Multi-head attention layer over a DGL block with node, edge and time features.

    Args:
        dim_node_feat: width of ``srcdata['h']`` (0 if absent).
        dim_edge_feat: width of ``edata['f']`` (0 if absent).
        dim_time: width of the time encoding of ``edata['dt']`` (0 to disable).
        num_head: number of attention heads; ``dim_out`` is split across them.
        dropout: dropout applied to the output projection.
        att_dropout: dropout applied to the attention weights.
        dim_out: output embedding width.
        combined: if True, use one projection per feature type and sum them
            instead of projecting the concatenated features.
    """

    def __init__(self, dim_node_feat, dim_edge_feat, dim_time, num_head, dropout, att_dropout, dim_out, combined=False):
        super(TransfomerAttentionLayer, self).__init__()
        self.num_head = num_head
        self.dim_node_feat = dim_node_feat
        self.dim_edge_feat = dim_edge_feat
        self.dim_time = dim_time
        self.dim_out = dim_out
        self.dropout = torch.nn.Dropout(dropout)
        self.att_dropout = torch.nn.Dropout(att_dropout)
        self.att_act = torch.nn.LeakyReLU(0.2)
        self.combined = combined
        if dim_time > 0:
            self.time_enc = TimeEncode(dim_time)
        if combined:
            if dim_node_feat > 0:
                self.w_q_n = torch.nn.Linear(dim_node_feat, dim_out)
                self.w_k_n = torch.nn.Linear(dim_node_feat, dim_out)
                self.w_v_n = torch.nn.Linear(dim_node_feat, dim_out)
            if dim_edge_feat > 0:
                self.w_k_e = torch.nn.Linear(dim_edge_feat, dim_out)
                self.w_v_e = torch.nn.Linear(dim_edge_feat, dim_out)
            if dim_time > 0:
                self.w_q_t = torch.nn.Linear(dim_time, dim_out)
                self.w_k_t = torch.nn.Linear(dim_time, dim_out)
                self.w_v_t = torch.nn.Linear(dim_time, dim_out)
        else:
            if dim_node_feat + dim_time > 0:
                self.w_q = torch.nn.Linear(dim_node_feat + dim_time, dim_out)
            self.w_k = torch.nn.Linear(dim_node_feat + dim_edge_feat + dim_time, dim_out)
            self.w_v = torch.nn.Linear(dim_node_feat + dim_edge_feat + dim_time, dim_out)
        self.w_out = torch.nn.Linear(dim_node_feat + dim_out, dim_out)
        self.layer_norm = torch.nn.LayerNorm(dim_out)

    def _attend_and_aggregate(self, b, Q, K, V):
        """Apply per-edge multi-head attention and sum messages into ``b.dstdata['h']``."""
        Q = torch.reshape(Q, (Q.shape[0], self.num_head, -1))
        K = torch.reshape(K, (K.shape[0], self.num_head, -1))
        V = torch.reshape(V, (V.shape[0], self.num_head, -1))
        att = dgl.ops.edge_softmax(b, self.att_act(torch.sum(Q*K, dim=2)))
        att = self.att_dropout(att)
        b.edata['v'] = torch.reshape(V*att[:, :, None], (V.shape[0], -1))
        # copy_e in both modes (the combined mode used the old copy_edge alias).
        b.update_all(dgl.function.copy_e('v', 'm'), dgl.function.sum('m', 'h'))

    def forward(self, b):
        """Compute the ``(num_dst_nodes, dim_out)`` embeddings for block ``b``.

        Reads ``srcdata['h']``, ``edata['f']`` and ``edata['dt']`` according to
        the configured feature widths; writes ``edata['v']`` and
        ``dstdata['h']`` on ``b``.
        """
        assert(self.dim_time + self.dim_node_feat + self.dim_edge_feat > 0)
        self.device = b.device
        num_dst = b.num_dst_nodes()
        if b.num_edges() == 0:
            return torch.zeros((num_dst, self.dim_out), device=self.device)
        edge_src, edge_dst = b.edges()
        if self.dim_time > 0:
            time_feat = self.time_enc(b.edata['dt'])
            # Queries are encoded at time offset zero.
            zero_time_feat = self.time_enc(torch.zeros(num_dst, dtype=torch.float32, device=self.device))
        if self.combined:
            Q = torch.zeros((b.num_edges(), self.dim_out), device=self.device)
            K = torch.zeros((b.num_edges(), self.dim_out), device=self.device)
            V = torch.zeros((b.num_edges(), self.dim_out), device=self.device)
            if self.dim_node_feat > 0:
                # Project the non-destination source rows once, then gather per edge.
                nbr_h = b.srcdata['h'][num_dst:]
                nbr_idx = edge_src - num_dst
                Q += self.w_q_n(b.srcdata['h'][:num_dst])[edge_dst]
                K += self.w_k_n(nbr_h)[nbr_idx]
                V += self.w_v_n(nbr_h)[nbr_idx]
            if self.dim_edge_feat > 0:
                K += self.w_k_e(b.edata['f'])
                V += self.w_v_e(b.edata['f'])
            if self.dim_time > 0:
                Q += self.w_q_t(zero_time_feat)[edge_dst]
                K += self.w_k_t(time_feat)
                V += self.w_v_t(time_feat)
        else:
            # Query input: [dst node feature | zero-time encoding], whichever exist.
            q_parts = []
            if self.dim_node_feat > 0:
                q_parts.append(b.srcdata['h'][:num_dst])
            if self.dim_time > 0:
                q_parts.append(zero_time_feat)
            if q_parts:
                Q = self.w_q(torch.cat(q_parts, dim=1))[edge_dst]
            else:
                Q = torch.ones((b.num_edges(), self.dim_out), device=self.device)
            # Key/value input: [src node feature | edge feature | time encoding].
            kv_parts = []
            if self.dim_node_feat > 0:
                kv_parts.append(b.srcdata['h'][edge_src])
            if self.dim_edge_feat > 0:
                kv_parts.append(b.edata['f'])
            if self.dim_time > 0:
                kv_parts.append(time_feat)
            kv_input = torch.cat(kv_parts, dim=1)
            K = self.w_k(kv_input)
            V = self.w_v(kv_input)
        self._attend_and_aggregate(b, Q, K, V)
        if self.dim_node_feat != 0:
            rst = torch.cat([b.dstdata['h'], b.srcdata['h'][:num_dst]], dim=1)
        else:
            rst = b.dstdata['h']
        rst = self.w_out(rst)
        rst = torch.nn.functional.relu(self.dropout(rst))
        return self.layer_norm(rst)
class IdentityNormLayer(torch.nn.Module):
    """Pass-through layer: layer-normalises ``srcdata['h']`` of a block."""

    def __init__(self, dim_out):
        super().__init__()
        self.norm = torch.nn.LayerNorm(dim_out)

    def forward(self, b):
        """Return the normalised source-node features of block ``b``."""
        features = b.srcdata['h']
        return self.norm(features)
class JODIETimeEmbedding(torch.nn.Module):
    """Scales embeddings by ``1 + Linear(normalised time gap)``."""

    def __init__(self, dim_out):
        super().__init__()
        self.dim_out = dim_out

        class NormalLinear(torch.nn.Linear):
            # From Jodie code
            def reset_parameters(self):
                # Normal initialisation with std 1/sqrt(fan_in) for weight and bias.
                std = 1. / math.sqrt(self.weight.size(1))
                self.weight.data.normal_(0, std)
                if self.bias is not None:
                    self.bias.data.normal_(0, std)

        self.time_emb = NormalLinear(1, dim_out)

    def forward(self, h, mem_ts, ts):
        """Return ``h`` scaled by the projected gap between ``ts`` and ``mem_ts``."""
        gap = (ts - mem_ts) / (ts + 1)
        scale = 1 + self.time_emb(gap.unsqueeze(1))
        return h * scale
\ No newline at end of file
import torch
import dgl
from os.path import abspath, join, dirname
import sys
sys.path.insert(0, join(abspath(dirname(__file__))))
from layers import *
from memorys import *
class GeneralModel(torch.nn.Module):
    """Temporal GNN assembled from config: optional node memory, GNN layers, edge predictor.

    Args:
        dim_node: width of the raw node features.
        dim_edge: width of the edge features.
        sample_param: 'sampling' section of the config (reads 'history').
        memory_param: 'memory' section of the config.
        gnn_param: 'gnn' section of the config. It is updated in place:
            'dim_out' defaults to the memory width and 'layer' is forced to 1
            for the 'identity' architecture.
        train_param: 'train' section of the config (dropout rates).
        combined: forwarded to the first attention layer.
    """

    def __init__(self, dim_node, dim_edge, sample_param, memory_param, gnn_param, train_param, combined=False):
        super(GeneralModel, self).__init__()
        self.dim_node = dim_node
        self.dim_node_input = dim_node
        self.dim_edge = dim_edge
        self.sample_param = sample_param
        self.memory_param = memory_param
        if 'dim_out' not in gnn_param:
            gnn_param['dim_out'] = memory_param['dim_out']
        self.gnn_param = gnn_param
        self.train_param = train_param
        if memory_param['type'] == 'node':
            dim_mail = 2 * memory_param['dim_out'] + dim_edge
            update = memory_param['memory_update']
            if update in ('gru', 'rnn'):
                # NOTE(review): 'gru' builds the same RNNMemeoryUpdater as
                # 'rnn' — confirm whether a GRU-based updater was intended.
                self.memory_updater = RNNMemeoryUpdater(memory_param, dim_mail, memory_param['dim_out'], memory_param['dim_time'], dim_node)
            elif update == 'transformer':
                self.memory_updater = TransformerMemoryUpdater(memory_param, dim_mail, memory_param['dim_out'], memory_param['dim_time'], train_param)
            else:
                raise NotImplementedError
            # With memory enabled the GNN consumes the memory, not raw features.
            self.dim_node_input = memory_param['dim_out']
        # Layers are keyed 'l<layer>h<history>' (plus a 't' suffix for the
        # JODIE time embedding).
        self.layers = torch.nn.ModuleDict()
        if gnn_param['arch'] == 'transformer_attention':
            for h in range(sample_param['history']):
                self.layers['l0h' + str(h)] = TransfomerAttentionLayer(self.dim_node_input, dim_edge, gnn_param['dim_time'], gnn_param['att_head'], train_param['dropout'], train_param['att_dropout'], gnn_param['dim_out'], combined=combined)
            for l in range(1, gnn_param['layer']):
                for h in range(sample_param['history']):
                    self.layers['l' + str(l) + 'h' + str(h)] = TransfomerAttentionLayer(gnn_param['dim_out'], dim_edge, gnn_param['dim_time'], gnn_param['att_head'], train_param['dropout'], train_param['att_dropout'], gnn_param['dim_out'], combined=False)
        elif gnn_param['arch'] == 'identity':
            self.gnn_param['layer'] = 1
            for h in range(sample_param['history']):
                self.layers['l0h' + str(h)] = IdentityNormLayer(self.dim_node_input)
                if 'time_transform' in gnn_param and gnn_param['time_transform'] == 'JODIE':
                    self.layers['l0h' + str(h) + 't'] = JODIETimeEmbedding(gnn_param['dim_out'])
        else:
            raise NotImplementedError
        self.edge_predictor = EdgePredictor(gnn_param['dim_out'])
        if 'combine' in gnn_param and gnn_param['combine'] == 'rnn':
            self.combiner = torch.nn.RNN(gnn_param['dim_out'], gnn_param['dim_out'])

    def forward(self, mfgs, metadata = None,neg_samples=1):
        """Score edges for a batch of message-flow graphs.

        Args:
            mfgs: blocks indexed as ``mfgs[layer][history]``.
            metadata: optional dict with 'src_pos_index', 'dst_pos_index' and
                'src_neg_index' used to gather rows of the embeddings before
                prediction.
            neg_samples: number of negative samples per positive edge.

        Returns:
            ``(positive_scores, negative_scores)`` from the edge predictor.
        """
        out = self.get_emb(mfgs)
        # The ids in metadata have to be recorded during the earlier
        # de-duplication step.
        if metadata is not None:
            out = torch.cat((out[metadata['src_pos_index']],out[metadata['dst_pos_index']],out[metadata['src_neg_index']]),0)
        return self.edge_predictor(out, neg_samples=neg_samples)

    def get_emb(self, mfgs):
        """Return the node embeddings for ``mfgs`` (memory update + GNN layers).

        Intermediate results are written to ``srcdata['h']`` of the next
        layer's blocks. With more than one history snapshot the per-snapshot
        outputs are stacked and reduced by ``self.combiner``.
        """
        if self.memory_param['type'] == 'node':
            self.memory_updater(mfgs[0])
        use_jodie = 'time_transform' in self.gnn_param and self.gnn_param['time_transform'] == 'JODIE'
        last_layer = self.gnn_param['layer'] - 1
        out = list()
        for l in range(self.gnn_param['layer']):
            for h in range(self.sample_param['history']):
                block = mfgs[l][h]
                rst = self.layers['l' + str(l) + 'h' + str(h)](block)
                if use_jodie:
                    rst = self.layers['l0h' + str(h) + 't'](rst, block.srcdata['mem_ts'], block.srcdata['ts'])
                if l != last_layer:
                    mfgs[l + 1][h].srcdata['h'] = rst
                else:
                    out.append(rst)
        if self.sample_param['history'] == 1:
            out = out[0]
        else:
            out = torch.stack(out, dim=0)
            out = self.combiner(out)[0][-1, :, :]
        return out
class NodeClassificationModel(torch.nn.Module):
    """Two-layer MLP classifier: Linear -> ReLU -> Linear."""

    def __init__(self, dim_in, dim_hid, num_class):
        super().__init__()
        self.fc1 = torch.nn.Linear(dim_in, dim_hid)
        self.fc2 = torch.nn.Linear(dim_hid, num_class)

    def forward(self, x):
        """Return the ``num_class`` logits for each row of ``x``."""
        hidden = torch.nn.functional.relu(self.fc1(x))
        return self.fc2(hidden)
\ No newline at end of file
import yaml
def parse_config(f):
    """Read a YAML model config and return its four parameter sections.

    Args:
        f: path of the YAML file. Each of the top-level keys 'sampling',
            'memory', 'gnn' and 'train' must hold a list; only its first
            element is used.

    Returns:
        ``(sample_param, memory_param, gnn_param, train_param)``.
    """
    # The context manager closes the file; it used to be left open.
    with open(f, 'r') as conf_file:
        conf = yaml.safe_load(conf_file)
    sample_param = conf['sampling'][0]
    memory_param = conf['memory'][0]
    gnn_param = conf['gnn'][0]
    train_param = conf['train'][0]
    return sample_param, memory_param, gnn_param, train_param
\ No newline at end of file
from typing import List, Tuple
import torch
import torch.distributed as dist
from starrygl.sample.graph_core import DataSet
from starrygl.sample.graph_core import GraphData
from starrygl.sample.sample_core.base import BaseSampler, NegativeSampling
import dgl
"""
Input arguments are unchanged; the outputs become:
sample_from_nodes
node: list[tensor,tensor, tensor...]
eid: list[tensor,tensor, tensor...]
src_index: list[tensor,tensor, tensor...]
sample_from_edges:
node
eid: list[tensor,tensor, tensor...]
src_index: list[tensor,tensor, tensor...]
delta_ts: list[tensor,tensor, tensor...]
metadata
"""
def prepare_input(node_feat, edge_feat, mem_embedding,mfgs,dist_nid,dist_eid):
    """Attach ids, features and memory to the blocks in ``mfgs`` (in place).

    On entry ``edata['ID']`` / ``srcdata['ID']`` of every block hold positions
    into ``dist_eid`` / ``dist_nid``; they are replaced by the looked-up ids.
    Edge features go to ``edata['f']`` of every block. Node features and the
    memory tuple ``(memory, memory_ts, mailbox, mailbox_ts)`` are attached
    only to the first block of each inner list.

    Returns:
        The same ``mfgs`` object.
    """
    for layer_blocks in mfgs:
        for pos, block in enumerate(layer_blocks):
            # Read the local positions before they are overwritten below.
            local_eid = block.edata['ID']
            local_nid = block.srcdata['ID']
            block.edata['ID'] = dist_eid[local_eid]
            block.srcdata['ID'] = dist_nid[local_nid]
            if edge_feat is not None:
                block.edata['f'] = edge_feat[local_eid]
            if pos != 0:
                continue
            if node_feat is not None:
                block.srcdata['h'] = node_feat[local_nid]
            if mem_embedding is None:
                continue
            memory, memory_ts, mail, mail_ts = mem_embedding
            block.srcdata['mem'] = memory[local_nid]
            block.srcdata['mem_ts'] = memory_ts[local_nid]
            # Flatten each node's mailbox into a single row.
            num_rows = block.srcdata['ID'].shape[0]
            block.srcdata['mem_input'] = mail[local_nid].reshape(num_rows, -1)
            block.srcdata['mail_ts'] = mail_ts[local_nid]
    return mfgs
def to_block(graph: GraphData, data, sample_out, mailbox = None,device = torch.device('cuda')):
    """Turn sampler output into DGL blocks with ids, features and memory attached.

    Args:
        graph: graph holding the id mappers, the sampling graph and the
            feature accessors.
        data: the batch that was sampled; returned unchanged.
        sample_out: either ``(layers, metadata)`` or a length-1 container of
            per-layer sample results exposing ``eid()``, ``src_index()`` and
            ``delta_ts()``.
        mailbox: optional memory store queried with ``_get_memory``.
        device: device on which the blocks are built.

    Returns:
        ``(data, mfgs, metadata)``. When the world size is greater than 1 the
        second element is a future that resolves to ``mfgs`` once the
        features have arrived.

    Raises:
        ValueError: if ``metadata`` is given but has no 'seed' entry.
    """
    if len(sample_out) > 1:
        sample_out,metadata = sample_out
    else:
        metadata = None
    eid = [ret.eid() for ret in sample_out]
    eid_len = [e.shape[0] for e in eid]
    eid_mapper: torch.Tensor = graph.eids_mapper
    nid_mapper: torch.Tensor = graph.nids_mapper
    eid_tensor = torch.cat(eid,dim = 0).to(eid_mapper.device)
    dist_eid = eid_mapper[eid_tensor].to(device)
    dist_eid,eid_inv = dist_eid.unique(return_inverse=True)
    # Column 2*i of the sampling graph is read as the source endpoint of edge i.
    src_node = graph.sample_graph['edge_index'][0,eid_tensor*2].to(graph.nids_mapper.device)
    src_ts = None
    edge_feat = graph._get_edge_attr(dist_eid)
    if metadata is None:
        root_node = data.nodes.to(graph.nids_mapper.device)
        # Plain int: it was wrapped in a list, which broke torch.arange and
        # the row/column arithmetic when building the blocks.
        root_len = root_node.shape[0]
        if hasattr(data,'ts'):
            src_ts = torch.cat([data.ts,
                                graph.sample_graph['ts'][eid_tensor*2].to(device)])
    elif 'seed' in metadata:
        root_node = metadata.pop('seed').to(graph.nids_mapper.device)
        root_len = root_node.shape[0]
        if 'seed_ts' in metadata:
            src_ts = torch.cat([metadata.pop('seed_ts').to(device),
                                graph.sample_graph['ts'][eid_tensor*2].to(device)])
        for k in metadata:
            metadata[k] = metadata[k].to(device)
    else:
        # Previously this fell through to an unbound `root_node`.
        raise ValueError("sample metadata must contain a 'seed' entry")
    nid_tensor = torch.cat([root_node,src_node],dim = 0)
    dist_nid = nid_mapper[nid_tensor].to(device)
    dist_nid,nid_inv = dist_nid.unique(return_inverse = True)
    node_feat = graph._get_node_attr(dist_nid)
    if mailbox is not None:
        mem = mailbox._get_memory(dist_nid)
    else:
        mem = None

    def build_block():
        """Create one block per sampled layer; ids are positions into dist_nid/dist_eid."""
        mfgs = list()
        col = torch.arange(0,root_len,device = device)
        col_len = 0
        row_len = root_len
        for r in range(len(eid_len)):
            elen = eid_len[r]
            # Each sampled edge contributes one new source row.
            row = torch.arange(row_len,row_len+elen,device = device)
            b = dgl.create_block((row,col[sample_out[r].src_index().to(device)]),
                                 num_src_nodes = row_len + elen,
                                 num_dst_nodes = row_len,
                                 device = device)
            idx = nid_inv[0:row_len + elen]
            e_idx = eid_inv[col_len:col_len+elen]
            b.srcdata['ID'] = idx
            if sample_out[r].delta_ts().shape[0] > 0:
                b.edata['dt'] = sample_out[r].delta_ts().to(device)
            if src_ts is not None:
                b.srcdata['ts'] = src_ts[0:row_len + elen]
            b.edata['ID'] = e_idx
            col = row
            col_len += elen
            row_len += elen
            mfgs.append(b)
        # Wrap every block in its own list, then put the last sampled layer first.
        mfgs = list(map(list, zip(*[iter(mfgs)])))
        mfgs.reverse()
        return data,mfgs,metadata

    data,mfgs,metadata = build_block()
    if dist.get_world_size() > 1:
        # Missing inputs become already-completed futures so that all three
        # can be collected uniformly.
        if(node_feat is None):
            node_feat = torch.futures.Future()
            node_feat.set_result(None)
        if(edge_feat is None):
            edge_feat = torch.futures.Future()
            edge_feat.set_result(None)
        if(mem is None):
            mem = torch.futures.Future()
            mem.set_result(None)

        def callback(fs,mfgs,dist_nid,dist_eid):
            node_feat,edge_feat,mem_embedding = fs.value()
            node_feat = node_feat.value()
            edge_feat = edge_feat.value()
            mem_embedding = mem_embedding.value()
            return prepare_input(node_feat,edge_feat,mem_embedding,mfgs,dist_nid,dist_eid)

        cal = lambda fut: callback(fs=fut,mfgs = mfgs,dist_nid = dist_nid,dist_eid =dist_eid)
        return data,torch.futures.collect_all([node_feat,edge_feat,mem]).then(cal),metadata
    else:
        mfgs = prepare_input(node_feat,edge_feat,mem,mfgs,dist_nid,dist_eid)
        return data,mfgs,metadata
def graph_sample(graph, sampler:BaseSampler,
                 sample_fn, data,
                 neg_sampling = None,
                 mailbox = None,
                 device = torch.device('cuda')):
    """Sample around ``data`` with ``sample_fn`` and convert the result via ``to_block``.

    ``sample_fn`` is called as ``sample_fn(sampler, data, neg_sampling)``.
    Returns whatever ``to_block`` returns.
    """
    sampled = sample_fn(sampler,data,neg_sampling)
    blocks = to_block(graph,data,sampled,mailbox,device)
    return blocks
def sample_from_nodes(sampler:BaseSampler, data:DataSet, neg_sampling=None, **kwargs):
    """Sample neighbours of the nodes in ``data.nodes``.

    ``neg_sampling`` is accepted (and ignored) so that this function can be
    called positionally with three arguments like the edge-based samplers.
    """
    seeds = data.nodes.reshape(-1)
    return sampler.sample_from_nodes(nodes=seeds)
def sample_from_edges(sampler:BaseSampler,
                      data:DataSet,
                      neg_sampling:NegativeSampling = None):
    """Sample around the edges in ``data.edges``, forwarding ``neg_sampling``."""
    # The unused `edge_label` lookup was dropped; labels are not passed on.
    return sampler.sample_from_edges(edges = data.edges,
                                     neg_sampling=neg_sampling)
def sample_from_temporal_nodes(sampler:BaseSampler,data:DataSet,
                               neg_sampling=None, **kwargs):
    """Sample neighbours of ``data.nodes`` at the timestamps ``data.ts``.

    ``neg_sampling`` is accepted (and ignored) so that this function can be
    called positionally with three arguments like the edge-based samplers.
    """
    seeds = data.nodes.reshape(-1)
    seed_ts = data.ts.reshape(-1)
    return sampler.sample_from_nodes(nodes=seeds, ts=seed_ts)
def sample_from_temporal_edges(sampler:BaseSampler, data:DataSet,
                               neg_sampling: NegativeSampling = None):
    """Sample around ``data.edges`` at timestamps ``data.ts``.

    Edges and timestamps are moved to the CPU before being handed to the
    sampler.
    """
    # The unused `edge_label` lookup was dropped; labels are not passed on.
    return sampler.sample_from_edges(edges=data.edges.to('cpu'),
                                     ets=data.ts.to('cpu'),
                                     neg_sampling = neg_sampling)
class SAMPLE_TYPE:
    """Named sampling modes; each attribute is the function implementing it."""
    # No trailing commas: a trailing comma binds a 1-tuple ``(fn,)`` instead
    # of the function, which cannot be called as ``sample_fn(...)``.
    SAMPLE_FROM_NODES = sample_from_nodes
    SAMPLE_FROM_EDGES = sample_from_edges
    SAMPLE_FROM_TEMPORAL_NODES = sample_from_temporal_nodes
    SAMPLE_FROM_TEMPORAL_EDGES = sample_from_temporal_edges
\ No newline at end of file
from collections import deque
from enum import Enum
import queue
import torch
import sys
from os.path import abspath, join, dirname
import numpy as np
from starrygl.sample.batch_data import graph_sample
from starrygl.sample.sample_core.PreNegSampling import PreNegativeSampling
sys.path.insert(0, join(abspath(dirname(__file__))))
from typing import Deque, Optional
import torch.distributed as dist
from torch_geometric.data import Data
import os.path as osp
import math
class DistributedDataLoader:
'''
Args:
data_path: the path of loaded graph ,each part 0 of graph is saved on $path$/rank_0
num_replicas: the num of worker
'''
def __init__(
self,
graph,
dataset = None,
sampler = None,
sampler_fn = None,
neg_sampler = None,
batch_size: Optional[int]=None,
drop_last = False,
device: torch.device = torch.device('cuda'),
shuffle:bool = True,
chunk_size = None,
train = False,
queue_size = 10,
mailbox = None,
**kwargs
):
assert sampler is not None
self.chunk_size = chunk_size
self.batch_size = batch_size
self.queue_size = queue_size
self.num_pending = 0
self.current_pos = 0
self.recv_idxs = 0
self.drop_last = drop_last
self.result_queue = deque(maxlen = self.queue_size)
self.shuffle = shuffle
self.is_closed = False
self.sampler = sampler
self.sampler_fn = sampler_fn
self.neg_sampler = neg_sampler
self.graph = graph
self.shuffle=shuffle
self.dataset = dataset
self.mailbox = mailbox
self.device = device
if train is True:
self._get_expected_idx(self.dataset.len)
else:
self.expected_idx = int(math.ceil(self.dataset.len/self.batch_size))
def __iter__(self):
if self.chunk_size is None:
if self.shuffle:
self.input_dataset = self.dataset.shuffle()
else:
self.input_dataset = self.dataset
self.recv_idxs = 0
self.current_pos = 0
self.num_pending = 0
self.submitted = 0
else:
self.input_dataset = self.dataset
self.recv_idxs = 0
self.num_pending = 0
self.submitted = 0
if dist.get_rank == 0:
self.current_pos = int(
math.floor(
np.random.uniform(0,self.batch_size/self.chunk_size)
)*self.chunk_size
)
else:
self.current_pos = 0
current_pos = torch.tensor([self.current_pos],dtype = torch.long,device=self.device)
dist.broadcast(current_pos, src = 0)
self.current_pos = int(current_pos.item())
self._get_expected_idx(self.dataset.len-self.current_pos)
if self.neg_sampler is not None \
and isinstance(self.neg_sampler,PreNegativeSampling):
self.neg_sampler.set_next_pos(self.current_pos)
return self
def _get_expected_idx(self,data_size):
world_size = dist.get_world_size()
self.expected_idx = data_size // self.batch_size if self.drop_last is True else int(math.ceil(data_size/self.batch_size))
if dist.get_world_size() > 1:
num_epochs = torch.tensor([self.expected_idx],dtype = torch.long,device=self.device)
dist.all_reduce(num_epochs, op=dist.ReduceOp.MIN)
self.expected_idx = int(num_epochs.item())
def _next_data(self):
if self.current_pos >= self.dataset.len:
return None
if self.current_pos + self.batch_size > self.input_dataset.len:
if self.drop_last:
return None
else:
next_data = self.input_dataset.get_next(
slice(self.current_pos,None,None)
)
self.current_pos = 0
else:
next_data = self.input_dataset.get_next(
slice(self.current_pos,self.current_pos + self.batch_size,None)
)
self.current_pos += self.batch_size
return next_data
def __next__(self):
# Return the next sampled batch, or raise StopIteration once
# `self.expected_idx` batches have been received.
# NOTE(review): the indentation of this method was lost, so the nesting
# described below is inferred from statement order - confirm against the
# original file before relying on it.
if(dist.get_world_size() == 1):
# Single-process path: fetch the next slice and sample it synchronously.
if self.recv_idxs < self.expected_idx:
data = self._next_data()
batch_data = graph_sample(self.graph,
self.sampler,
self.sampler_fn,
data,self.neg_sampler,
self.mailbox,
self.device)
self.recv_idxs += 1
assert batch_data is not None
return batch_data
else :
raise StopIteration
else :
# Multi-process path: at most `queue_size` requests are kept pending and
# never more than `expected_idx` are submitted in total.
num_reqs = min(self.queue_size - self.num_pending,self.expected_idx - self.submitted)
for _ in range(num_reqs):
# Queue entries are indexed as result[0], result[1], result[2];
# result[1] exposes done() and value(), presumably a future - verify
# against graph_sample.
if len(self.result_queue)>0:
result = self.result_queue[0]
if result[1].done() == True:
self.recv_idxs += 1
self.num_pending -= 1
self.result_queue.popleft()
return result[0],result[1].value(),result[2]
# NOTE(review): it is ambiguous which `if` this `else` pairs with.
else:
next_data = self._next_data()
assert next_data is not None
batch_data = graph_sample(self.graph,
self.sampler,
self.sampler_fn,
next_data,self.neg_sampler,
self.mailbox,
self.device)
self.result_queue.append(batch_data)
self.submitted = self.submitted + 1
self.num_pending = self.num_pending + 1
# Busy-wait on the head of the queue until it reports done().
while(self.recv_idxs < self.expected_idx):
assert len(self.result_queue) > 0
result= self.result_queue[0]
if result[1].done() == True:
self.recv_idxs += 1
self.num_pending -= 1
self.result_queue.popleft()
return result[0],result[1].value(),result[2]
# All expected batches were received; nothing may still be pending.
assert self.num_pending == 0
raise StopIteration
from starrygl.distributed.utils import DistIndex, DistributedTensor
from starrygl.sample.graph_core.utils import build_mapper
import os.path as osp
import torch
import torch.distributed as dist
from torch_geometric.data import Data
class GraphData():
    """Holds one partition's graph tensors plus id mappers and features.

    Args:
        pdata: partition object exposing ``ids``, ``eids``, ``edge_index``,
            ``sample_graph`` and optionally ``edge_ts``, ``x`` and
            ``edge_attr``.
        device: device the graph tensors and features are moved to.
        all_on_gpu: when true the node/edge id mappers are kept on
            ``device``; otherwise they stay on the CPU.

    Node and edge features are wrapped in ``DistributedTensor`` when the
    process group has more than one member, and stored as plain float
    tensors on ``device`` otherwise.
    """

    def __init__(self, pdata, device = torch.device('cuda'), all_on_gpu = False):
        self.device = device
        self.ids = pdata.ids.to(device)
        self.edge_index = pdata.edge_index.to(device)
        if hasattr(pdata,'edge_ts'):
            self.edge_ts = pdata.edge_ts.to(device).to(torch.float)
        else:
            self.edge_ts = None
        self.sample_graph = pdata.sample_graph
        # The mappers are built on `device` and then parked on the CPU unless
        # the caller asked for everything to stay on the device.
        self.nids_mapper = build_mapper(nids=self.ids).data.to('cpu')
        self.eids_mapper = build_mapper(nids=pdata.eids.to(device)).data.to('cpu')
        if all_on_gpu:
            self.nids_mapper = self.nids_mapper.to(device)
            self.eids_mapper = self.eids_mapper.to(device)
        self.num_nodes = self.nids_mapper.data.shape[0]
        world_size = dist.get_world_size()
        self.x = self._load_attr(getattr(pdata, 'x', None), world_size)
        self.edge_attr = self._load_attr(getattr(pdata, 'edge_attr', None), world_size)

    def _load_attr(self, attr, world_size):
        """Place a feature tensor on ``self.device`` (distributed if needed)."""
        if attr is None:
            return None
        if world_size > 1:
            return DistributedTensor(attr.to(self.device))
        # Fix: edge features used to be sent to a hard-coded 'cuda' even when
        # another device was requested.
        return attr.to(self.device).to(torch.float)

    def _get_node_attr(self,ids):
        """Return node features for ``ids`` or ``None`` when there are none."""
        if self.x is None:
            return None
        elif dist.get_world_size() == 1:
            return self.x[ids]
        else:
            return self.x.index_select(ids)

    def _get_edge_attr(self,ids):
        """Return edge features for ``ids`` or ``None`` when there are none."""
        if self.edge_attr is None:
            return None
        elif dist.get_world_size() == 1:
            # Index on the device that holds the features; the result is
            # therefore on `self.device` instead of a hard-coded 'cuda'.
            return self.edge_attr[ids.to(self.edge_attr.device)]
        else:
            return self.edge_attr.index_select(ids)
class DataSet:
    """A batchable collection of seed nodes or seed edges with extra columns.

    Args:
        nodes: optional 1-D tensor of seed nodes; defines ``len`` if given.
        edges: optional ``(2, E)`` style tensor of seed edges; ``len`` is its
            second dimension when ``nodes`` is absent.
        labels: optional labels, stored as given (not moved to ``device``).
        ts: optional timestamps, moved to ``device``.
        device: device that ``nodes``, ``edges``, ``ts`` and keyword columns
            are moved to.
        **kwargs: additional per-item tensors whose first dimension must
            equal ``len``; each becomes an attribute of the same name.
    """

    def __init__(self,nodes = None,
                 edges = None,
                 labels = None,
                 ts = None,
                 device = torch.device('cuda'),**kwargs):
        if nodes is not None:
            self.nodes = nodes.to(device)
        if edges is not None:
            self.edges = edges.to(device)
        if ts is not None:
            self.ts = ts.to(device)
        if labels is not None:
            self.labels = labels
        self.len = self.nodes.shape[0] if nodes is not None else self.edges.shape[1]
        for k, v in kwargs.items():
            assert isinstance(v,torch.Tensor) and v.shape[0]==self.len
            setattr(self, k, v.to(device))

    def _select(self, indx):
        """Build a new ``DataSet`` holding the items addressed by ``indx``."""
        nodes = self.nodes[indx] if hasattr(self,'nodes') else None
        edges = self.edges[:,indx] if hasattr(self,'edges') else None
        # Fix: keep the selection on the device the data already lives on
        # instead of silently moving it to the default 'cuda' device.
        anchor = nodes if nodes is not None else edges
        d = DataSet(nodes, edges, device=anchor.device)
        for k, v in self.__dict__.items():
            if k in ('edges', 'nodes', 'len'):
                continue
            setattr(d, k, v[indx])
        return d

    def get_next(self,indx):
        """Return the sub-dataset selected by ``indx`` (slice or index tensor)."""
        return self._select(indx)

    def shuffle(self):
        """Return a copy of the dataset with all columns permuted together."""
        return self._select(torch.randperm(self.len))
class TemporalGraphData(GraphData):
    """Temporal variant of ``GraphData``; the cache hooks are placeholders.

    Fix: the class previously declared no base class while forwarding
    ``pdata`` and ``device`` to ``super().__init__``, which raises
    ``TypeError`` on ``object``. It now derives from ``GraphData``, whose
    constructor accepts those arguments.
    """

    def __init__(self,pdata,device):
        super().__init__(pdata,device)

    def _set_temporal_batch_cache(self,size,pin_size):
        """Placeholder; currently does nothing."""
        pass

    def _load_feature_to_cuda(self,ids):
        """Placeholder; currently does nothing."""
        pass
class TemporalNeighborSampleGraph(GraphData):
    """Edge list (with optional timestamps) handed to the neighbor sampler.

    Args:
        sample_graph: mapping with ``'edge_index'`` and ``'eids'``, optionally
            ``'ts'``, and the mask named by ``mode``.
        mode: ``'full'`` keeps every edge; ``'train'``, ``'val'`` and
            ``'test'`` keep the edges selected by ``'train_mask'``,
            ``'val_mask'`` and ``'test_mask'`` respectively.
        eids_mapper: accepted for interface compatibility; not used here.

    Raises:
        ValueError: if ``mode`` is not one of the four supported values
            (previously this surfaced as an unbound-variable error).

    Note: ``GraphData.__init__`` is intentionally not called; only the
    attributes set below exist on the instance. ``num_edges`` is the edge
    count before any mask is applied, as in the original code.
    """

    _MASK_KEYS = {'train': 'train_mask', 'val': 'val_mask', 'test': 'test_mask'}

    def __init__(self, sample_graph=None, mode='full', eids_mapper=None):
        self.edge_index = sample_graph['edge_index']
        self.num_edges = self.edge_index.shape[1]
        self.edge_ts = sample_graph['ts'] if 'ts' in sample_graph else None
        self.eid = sample_graph['eids']
        if mode == 'full':
            return
        if mode not in self._MASK_KEYS:
            raise ValueError(
                f"mode must be 'full', 'train', 'val' or 'test' (got {mode!r})")
        mask = sample_graph[self._MASK_KEYS[mode]]
        self.edge_index = self.edge_index[:, mask]
        # Fix: graphs without timestamps used to crash on `None[mask]`.
        if self.edge_ts is not None:
            self.edge_ts = self.edge_ts[mask]
        self.eid = self.eid[mask]
from starrygl.distributed.context import DistributedContext
from typing import *
import torch.distributed as dist
import torch
from starrygl.distributed.utils import DistIndex
def build_mapper(nids):
    """Build a global-id -> (local index, owning rank) lookup.

    Every rank contributes the ids it owns. The total id count is obtained by
    summing the local counts across ranks, and for each gathered id the
    owning rank and the position inside that rank's id list are recorded.

    Args:
        nids: 1-D long tensor with the ids owned by the calling rank.

    Returns:
        ``DistIndex(ind_mp, part_mp)`` where both tensors have one entry per
        id in ``[0, total)``.

    Note: the lookup tables are allocated with ``torch.empty``, so entries
    for ids that no rank contributed are uninitialized. This presumably
    relies on the ranks' ids forming a partition of ``0..total-1`` - verify
    against the partitioner.
    """
    world_size = dist.get_world_size()
    ikw = dict(dtype=torch.long, device=nids.device)

    num_nodes = torch.zeros(1, **ikw)
    num_nodes[0] = nids.size(0)
    dist.all_reduce(num_nodes, op=dist.ReduceOp.SUM)
    # Use a plain int as the size instead of passing a tensor as a shape.
    total = int(num_nodes.item())

    all_ids = [None] * world_size
    dist.all_gather_object(all_ids, nids)

    part_mp = torch.empty(total, **ikw)
    ind_mp = torch.empty(total, **ikw)
    for part, gathered in enumerate(all_ids):
        # Gathered tensors may come back on another rank's device; indexing
        # needs them on the same device as the lookup tables.
        iid = gathered.to(nids.device)
        part_mp[iid] = part
        ind_mp[iid] = torch.arange(iid.shape[0], **ikw)
    return DistIndex(ind_mp, part_mp)
\ No newline at end of file
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension
# Build script for the `torch_utils` C++ extension (compiled from
# ./torch_utils.cpp with OpenMP enabled).
setup(
name='torch_utils',
ext_modules=[
CppExtension(
name='torch_utils',
sources=['./torch_utils.cpp'],
# NOTE(review): the third flag starts with a space (' -export-dynamic');
# confirm this is intentional and that the flags are meant to be passed at
# compile time rather than link time.
extra_compile_args=['-fopenmp','-Xlinker',' -export-dynamic'],
# Headers are looked up relative to the directory the build is run from.
include_dirs=["../sample_core"],
),
],
cmdclass={
'build_ext': BuildExtension
})#
#setup(
# name='cpu_cache_manager',
# ext_modules=[
# CppExtension(
# name='cpu_cache_manager',
# sources=['cpu_cache_manager.cpp'],
# extra_compile_args=['-fopenmp','-Xlinker',' -export-dynamic'],
# include_dirs=["./"],
# ),
# ],
# cmdclass={
# 'build_ext': BuildExtension
# })#
#
\ No newline at end of file
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#include <torch/extension.h>
#include <parallel_hashmap/phmap.h>
#include <cstring>
#include <vector>
#include <iostream>
#include <map>
#include <boost/thread/mutex.hpp>
using namespace std;
#define MTX boost::mutex
namespace py = pybind11;
namespace th = torch;
typedef int64_t NodeIDType;
typedef float TimeStampType;
#define EXTRAARGS , phmap::priv::hash_default_hash<K>, \
phmap::priv::hash_default_eq<K>, \
std::allocator<K>, 4, MTX
template <class K>
using HashT = phmap::parallel_flat_hash_set<K EXTRAARGS>;
template <class K, class V>
using HashM = phmap::parallel_flat_hash_map<K, V EXTRAARGS>;
#define work_thread 10
// For every id in `in`, return its position inside `map_key`.
//
// Both tensors must hold int64 values (data_ptr<NodeIDType> enforces the
// dtype). If a key occurs more than once in `map_key`, the first occurrence
// wins. Raises an error if any id of `in` is absent from `map_key`; the
// original dereferenced end() in that case, which is undefined behaviour.
th::Tensor sparse_get_index(th::Tensor in,th::Tensor map_key){
    // Raw-pointer access below assumes a dense layout.
    in = in.contiguous();
    map_key = map_key.contiguous();
    auto key_ptr = map_key.data_ptr<NodeIDType>();
    auto in_ptr = in.data_ptr<NodeIDType>();
    // 64-bit counters: tensor sizes may exceed the range of int.
    const int64_t key_cnt = map_key.size(0);
    const int64_t query_cnt = in.size(0);

    phmap::parallel_flat_hash_map<NodeIDType,NodeIDType> dict;
    dict.reserve(key_cnt);
    for(int64_t i=0;i<key_cnt;i++){
        dict.emplace(key_ptr[i], i);   // emplace keeps the first occurrence
    }

    vector<NodeIDType> out(query_cnt);
    int64_t missing = 0;
    // The map is only read inside the parallel region.
    #pragma omp parallel for reduction(+:missing)
    for(int64_t i=0;i<query_cnt;i++){
        auto it = dict.find(in_ptr[i]);
        if(it == dict.end()){
            out[i] = -1;
            missing += 1;
        }
        else{
            out[i] = it->second;
        }
    }
    TORCH_CHECK(missing == 0,
                "sparse_get_index: ", missing, " id(s) of `in` are not present in `map_key`");
    return th::tensor(out);
}
// Statistics of the time gaps between consecutive events of the same node.
//
// Edges are scanned in the order given. For every source node (row) and
// every destination node (col), the gap between an edge's timestamp and the
// previous timestamp seen for that node is collected; the first event of a
// node contributes no gap.
//
// Returns {src_mean, src_var, dst_mean, dst_var}, where the variance is the
// population variance of the gaps. A side with no gaps reports 0 for both.
//
// Changes from the previous implementation:
//  * the source variance was accumulated into `srcavg` and never into
//    `srcvar`, so srcvar was always 0 and srcavg was corrupted;
//  * squared deviations were divided by the mean instead of the gap count;
//  * the scan ran in an OpenMP parallel loop, but find-then-update on the
//    shared maps is not atomic and gaps depend on edge order, so it now runs
//    sequentially;
//  * the reductions wrote shared doubles from parallel loops without a
//    reduction clause;
//  * empty inputs divided by zero.
// `num_nodes` is unused and kept only for interface compatibility.
vector<double> get_norm_temporal(th::Tensor row,th::Tensor col,th::Tensor timestamp,int num_nodes){
    (void)num_nodes;
    TORCH_CHECK(row.size(0) == col.size(0) && row.size(0) == timestamp.size(0),
                "get_norm_temporal: row, col and timestamp must have the same length");
    // Raw-pointer access below assumes a dense layout.
    row = row.contiguous();
    col = col.contiguous();
    timestamp = timestamp.contiguous();
    auto rowptr = row.data_ptr<NodeIDType>();
    auto colptr = col.data_ptr<NodeIDType>();
    auto time_ptr = timestamp.data_ptr<TimeStampType>();
    const int64_t num_edges = row.size(0);

    typedef phmap::flat_hash_map<NodeIDType,TimeStampType> LastSeen;
    LastSeen last_src;
    LastSeen last_dst;
    vector<double> src_gap;
    vector<double> dst_gap;

    // Record the gap to the node's previous event (if any) and remember `t`.
    auto record = [](LastSeen &last, NodeIDType node, TimeStampType t, vector<double> &gaps){
        auto it = last.find(node);
        if(it == last.end()){
            last.emplace(node, t);
        }
        else{
            gaps.push_back(static_cast<double>(t - it->second));
            it->second = t;
        }
    };
    for(int64_t i = 0;i<num_edges;i++){
        record(last_src, rowptr[i], time_ptr[i], src_gap);
        record(last_dst, colptr[i], time_ptr[i], dst_gap);
    }

    // Mean and population variance; both stay 0 when there are no gaps.
    auto mean_var = [](const vector<double> &gaps, double &mean, double &var){
        mean = 0;
        var = 0;
        if(gaps.empty()) return;
        for(double g : gaps) mean += g;
        mean /= static_cast<double>(gaps.size());
        for(double g : gaps) var += (g-mean)*(g-mean);
        var /= static_cast<double>(gaps.size());
    };

    vector<double> ret(4);
    mean_var(src_gap, ret[0], ret[1]);
    mean_var(dst_gap, ret[2], ret[3]);
    return ret;
}
// Python bindings: exposes the two helpers above as functions of the
// `torch_utils` extension module.
PYBIND11_MODULE(torch_utils, m)
{
    m.def("sparse_get_index",
          &sparse_get_index,
          py::return_value_policy::reference);

    m.def("get_norm_temporal",
          &get_norm_temporal,
          py::return_value_policy::reference);
}
\ No newline at end of file
import sys
from os.path import abspath, join, dirname
sys.path.insert(0, join(abspath(dirname(__file__))))
from torch import Tensor
import torch
from base import NegativeSampling
from base import NegativeSamplingMode
from typing import Any, List, Optional, Tuple, Union
class PreNegativeSampling(NegativeSampling):
    r"""Negative sampler that replays a pre-computed table of negatives.

    Instead of drawing random nodes, :meth:`sample` returns consecutive rows
    of :obj:`neg_sample_list`, starting at an internal cursor that advances
    with every call and can be repositioned with :meth:`set_next_pos`.

    Args:
        mode (str): The negative sampling mode
            (:obj:`"binary"` or :obj:`"triplet"`), forwarded to
            :class:`NegativeSampling`.
        neg_sample_list (torch.Tensor): The pre-computed negatives; it is
            sliced as ``neg_sample_list[start:end, :]``, so it is expected to
            be two-dimensional with one row per positive sample.
    """

    def __init__(
        self,
        mode: Union[NegativeSamplingMode, str],
        neg_sample_list: torch.Tensor
    ):
        super().__init__(mode)
        self.neg_sample_list = neg_sample_list
        self.next_pos = 0

    def set_next_pos(self,pos):
        """Move the cursor so the next call to :meth:`sample` starts at ``pos``."""
        self.next_pos = pos

    def sample(self, num_samples: int,
               num_nodes: Optional[int] = None) -> Tensor:
        r"""Returns the next :obj:`num_samples` rows of pre-computed
        negatives, flattened to one dimension, and advances the cursor.

        :obj:`num_nodes` is not used for sampling but must be provided;
        a :class:`ValueError` is raised otherwise.
        """
        if num_nodes is None:
            raise ValueError(
                f"Cannot sample negatives in '{self.__class__.__name__}' "
                f"without passing the 'num_nodes' argument")
        begin = self.next_pos
        end = begin + num_samples
        self.next_pos = end
        return self.neg_sample_list[begin:end, :].reshape(-1)
#return torch.from_numpy(np.random.randint(num_nodes, size=num_samples))
import os.path as osp
import torch
class GraphData():
    """Graph partition description used by the neighbor sampler.

    Args:
        id: index of the current partition.
        edge_index: ``(2, E)`` edge tensor of the whole graph, or ``None``.
        data: PyG ``Data`` object of this partition (features and sub-graph
            structure); may be ``None`` when only the structure is needed.
        partptr: 1-D tensor of partition boundaries; partition ``i`` owns the
            global ids ``partptr[i] .. partptr[i+1]-1``.
        timestamp: optional per-edge timestamps.

    Changes from the previous version:
      * the class defined ``__init__`` twice, so the path-based constructor
        was unreachable; it is now available as :meth:`from_file`;
      * ``num_edges`` is 0 (instead of undefined) when ``edge_index`` is
        ``None``;
      * ``select_attr`` was defined twice with the same body;
      * ``localId_to_globalId`` ignored its ``partitionId`` argument and used
        the start of the *previous* partition as the offset.
    """

    def __init__(self, id, edge_index, data, partptr, timestamp=None):
        # index of the current partition
        self.partition_id = id
        # total number of partitions
        self.partitions = partptr.numel() - 1
        # structure of the whole graph
        self.num_nodes = partptr[self.partitions]
        self.num_edges = edge_index[0].numel() if edge_index is not None else 0
        self.edge_index = edge_index
        self.edge_ts = timestamp
        # data of this partition (features and sub-graph), a PyG Data object
        self.data = data
        # partition boundaries
        self.partptr = partptr
        # edge ids 0 .. num_edges-1
        self.eid = torch.arange(self.num_edges)

    @classmethod
    def from_file(cls, path):
        """Load ``(id, edge_index, data, partptr)`` from ``path``.

        NOTE: ``torch.load`` unpickles the file; only use it on trusted data.
        """
        assert path is not None and osp.exists(path),'path 不存在'
        part_id, edge_index, data, partptr = torch.load(path)
        return cls(part_id, edge_index, data, partptr)

    def select_attr(self,index):
        """Return the rows of ``data.x`` selected by ``index``."""
        return torch.index_select(self.data.x,0,index)

    def get_part_num(self):
        """Return the number of feature rows held by this partition."""
        return self.data.x.size()[0]

    def select_y(self,index):
        """Return the entries of ``data.y`` selected by ``index``."""
        return torch.index_select(self.data.y,0,index)

    def get_localId_by_partitionId(self,id,index):
        """Convert global ids ``index`` to ids local to partition ``id``.

        Partition ``-1`` and partition ``0`` return ``index`` unchanged.
        """
        if(id == -1 or id == 0):
            return index
        else:
            return torch.add(index,-self.partptr[id])

    def get_globalId_by_partitionId(self,id,index):
        """Convert ids ``index`` local to partition ``id`` to global ids.

        Partition ``-1`` and partition ``0`` return ``index`` unchanged.
        """
        if(id == -1 or id == 0):
            return index
        else:
            return torch.add(index,self.partptr[id])

    def get_node_num(self):
        """Return the total number of nodes in the graph."""
        return self.num_nodes

    def localId_to_globalId(self,id,partitionId:int = -1):
        '''
        Map the node id ``id`` local to partition ``partitionId`` to its
        global id. ``partitionId == -1`` means the current partition.
        '''
        if partitionId == -1:
            partitionId = self.partition_id
        part_size = self.partptr[partitionId+1] - self.partptr[partitionId]
        assert id >= 0 and id < part_size
        return id + self.partptr[partitionId]

    def get_partitionId_by_globalId(self,id):
        '''
        Return the index of the partition that owns the global id ``id``.
        '''
        partitionId = -1
        assert id>=0 and id<self.num_nodes,'id 超过范围'
        for i in range(self.partitions):
            if id>=self.partptr[i] and id<self.partptr[i+1]:
                partitionId = i
                break
        assert partitionId>=0, 'id 不存在对应的分区'
        return partitionId

    def get_nodes_by_partitionId(self,id):
        '''
        Return the number of nodes owned by partition ``id``.
        '''
        assert id>=0 and id<self.partitions,'partitionId 非法'
        return int(self.partptr[id+1]-self.partptr[id])

    def __repr__(self):
        return (f'{self.__class__.__name__}(\n'
                f'  partition_id={self.partition_id}\n'
                f'  data={self.data},\n'
                f'  global_info('
                f'num_nodes={self.num_nodes},'
                f' num_edges={self.num_edges},'
                f' num_parts={self.partitions},'
                f' edge_index=[2,{self.num_edges}])\n'
                f')')
import torch
from torch import Tensor
from enum import Enum
import math
from abc import ABC
from typing import Any, List, Optional, Tuple, Union
class SampleType(Enum):
    """Selects which part of the graph a sampler draws neighbors from.

    The sampler docstrings in this package describe the values as: 0 - sample
    in the whole graph structure, 1 - sample one-hop outer nodes, 2 - cross
    partition sampling.
    """

    Whole = 0
    Inner = 1
    Outer = 2
class NegativeSamplingMode(Enum):
    """The two supported ways of drawing negative samples."""

    # Randomly sample negative edges in the graph.
    binary = 'binary'
    # Randomly sample negative destination nodes for each positive source
    # node.
    triplet = 'triplet'
class NegativeSampling:
    r"""Configuration and implementation of random negative sampling.

    Args:
        mode (str): The negative sampling mode
            (:obj:`"binary"` or :obj:`"triplet"`).
            :obj:`"binary"` samples negative links from the graph;
            :obj:`"triplet"` samples negative destination nodes for each
            positive source node.
        amount (int or float, optional): The ratio of sampled negative edges
            to the number of positive edges. Must be positive, and integral
            in :obj:`"triplet"` mode. (default: :obj:`1`)
        weight (torch.Tensor, optional): A node-level vector determining the
            sampling of nodes. Does not necessarily need to sum up to one.
            If not given, negative nodes will be sampled uniformly.
            (default: :obj:`None`)
        unique (bool, optional): Stored on the instance; not consulted by
            :meth:`sample`. (default: :obj:`False`)
    """
    mode: NegativeSamplingMode
    amount: Union[int, float] = 1
    weight: Optional[Tensor] = None
    unique: bool

    def __init__(
        self,
        mode: Union[NegativeSamplingMode, str],
        amount: Union[int, float] = 1,
        weight: Optional[Tensor] = None,
        unique: bool = False
    ):
        self.mode = NegativeSamplingMode(mode)
        self.amount = amount
        self.weight = weight
        self.unique = unique

        if self.amount <= 0:
            raise ValueError(f"The attribute 'amount' needs to be positive "
                             f"for '{self.__class__.__name__}' "
                             f"(got {self.amount})")

        if not self.is_triplet():
            return
        # Triplet mode only accepts whole numbers; normalise them to int.
        rounded = math.ceil(self.amount)
        if self.amount != rounded:
            raise ValueError(f"The attribute 'amount' needs to be an "
                             f"integer for '{self.__class__.__name__}' "
                             f"with 'triplet' negative sampling "
                             f"(got {self.amount}).")
        self.amount = rounded

    def is_binary(self) -> bool:
        """Return ``True`` when the mode is ``binary``."""
        return self.mode == NegativeSamplingMode.binary

    def is_triplet(self) -> bool:
        """Return ``True`` when the mode is ``triplet``."""
        return self.mode == NegativeSamplingMode.triplet

    def sample(self, num_samples: int,
               num_nodes: Optional[int] = None) -> Tensor:
        r"""Generates :obj:`num_samples` negative samples.

        Without a weight vector, node ids are drawn uniformly from
        ``[0, num_nodes)`` and :obj:`num_nodes` is required. With a weight
        vector, ids are drawn with replacement according to the weights, and
        :obj:`num_nodes` (if given) must match the weight length.
        """
        weight = self.weight
        if weight is not None:
            if num_nodes is not None and weight.numel() != num_nodes:
                raise ValueError(
                    f"The 'weight' attribute in '{self.__class__.__name__}' "
                    f"needs to match the number of nodes {num_nodes} "
                    f"(got {self.weight.numel()})")
            return torch.multinomial(weight, num_samples, replacement=True)

        if num_nodes is None:
            raise ValueError(
                f"Cannot sample negatives in '{self.__class__.__name__}' "
                f"without passing the 'num_nodes' argument")
        return torch.randint(num_nodes, (num_samples, ))
class SampleOutput:
    """Plain record for the result of a sampling call; every field defaults
    to ``None``."""

    # sampled nodes
    node: Optional[torch.Tensor] = None
    # one edge-index tensor per entry of the list
    edge_index_list: Optional[List[torch.Tensor]] = None
    # one edge-id tensor per entry of the list
    eid_list: Optional[List[torch.Tensor]] = None
    # one timestamp-delta tensor per entry of the list
    delta_ts_list: Optional[List[torch.Tensor]] = None
    # any additional information
    metadata: Optional[Any] = None
class BaseSampler(ABC):
    r"""Base class for graph samplers.

    Sub-classes provide :meth:`sample_from_nodes` and
    :meth:`sample_from_edges`; the implementations here only raise
    :class:`NotImplementedError`.
    """

    def sample_from_nodes(
        self,
        nodes: torch.Tensor,
        with_outer_sample: SampleType,
        **kwargs
    ) -> Tuple[torch.Tensor, list]:
        r"""Performs multi-layer sampling starting from :obj:`nodes`.

        Args:
            nodes: the seed node indices
            with_outer_sample: 0-sample in the whole graph structure;
                1-sample one-hop outer nodes; 2-cross partition sampling
            **kwargs: other keyword arguments

        Returns:
            sampled_nodes: the nodes sampled
            sampled_edge_index_list: the edges sampled
        """
        raise NotImplementedError

    def sample_from_edges(
        self,
        edges: torch.Tensor,
        with_outer_sample: SampleType,
        edge_label: Optional[torch.Tensor] = None,
        neg_sampling: Optional[NegativeSampling] = None
    ) -> Tuple[torch.Tensor, list]:
        r"""Performs sampling starting from the seed :obj:`edges`.

        Args:
            edges: the seed edge indices
            with_outer_sample: 0-sample in the whole graph structure;
                1-sample one-hop outer nodes; 2-cross partition sampling
            edge_label: the labels of the seed edges
            neg_sampling: the negative sampling configuration

        Returns:
            sampled_nodes: the nodes sampled
            sampled_edge_index_list: the edges sampled
            metadata: other information
        """
        raise NotImplementedError
# def _sample_one_layer_from_nodes(
# self,
# nodes:torch.Tensor,
# **kwargs
# ) -> Tuple[torch.Tensor, torch.Tensor]:
# r"""Performs sampling from the nodes specified in: nodes,
# returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, torch.Tensor].
# Args:
# nodes: the list of seed nodes index
# **kwargs: other kwargs
# Returns:
# sampled_nodes: the nodes sampled
# sampled_edge_index: the edges sampled
# """
# raise NotImplementedError
# def _sample_one_layer_from_nodes_parallel(
# self,
# nodes: torch.Tensor,
# **kwargs
# ) -> Tuple[torch.Tensor, torch.Tensor]:
# r"""Performs sampling paralleled from the nodes specified in: nodes,
# returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, torch.Tensor].
# Args:
# nodes: the list of seed nodes index
# **kwargs: other kwargs
# Returns:
# sampled_nodes: the nodes sampled
# sampled_edge_index: the edges sampled
# """
# raise NotImplementedError
from enum import Enum
import sys
import argparse
from os.path import abspath, join, dirname
import time
sys.path.insert(0, join(abspath(dirname(__file__))))
class SampleType(Enum):
    """Selects which part of the graph a sampler draws neighbors from."""

    Whole = 0
    Inner = 1
    Outer = 2
# Command-line configuration shared by the sampler processes.
# NOTE(review): the description string below mentions an "RPC Reinforcement
# Learning Example" and a --gamma reward option; they look inherited from a
# template - confirm whether they are still wanted.
parser = argparse.ArgumentParser(
description="RPC Reinforcement Learning Example",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Process-group layout.
parser.add_argument('--world_size', default=1, type=int, metavar='W',
help='number of workers')
parser.add_argument('--rank', default=0, type=int, metavar='W',
help='rank of the worker')
parser.add_argument('--log_interval', type=int, default=10, metavar='N',
help='interval between training status logs')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
help='how much to value future rewards')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed for reproducibility')
# Sampler pool and queue sizing.
parser.add_argument('--num_sampler', type=int, default=1, metavar='S',
help='number of samplers')
parser.add_argument('--queue_size', type=int, default=10000, metavar='S',
help='sampler queue size')
#parser = distparser.parser.add_subparsers().add_parser("train")#argparse.ArgumentParser(description='minibatch_gnn_models')
# Training options.
parser.add_argument('--data', type=str, help='dataset name')
parser.add_argument('--config', type=str, help='path to config file')
parser.add_argument('--gpu', type=str, default='0', help='which GPU to use')
parser.add_argument('--model_name', type=str, default='', help='name of stored model')
parser.add_argument('--rand_edge_features', type=int, default=0, help='use random edge featrues')
parser.add_argument('--rand_node_features', type=int, default=0, help='use random node featrues')
parser.add_argument('--eval_neg_samples', type=int, default=1, help='how many negative samples to use at inference. Note: this will change the metric of test set to AP+AUC to AP+MRR!')
# NOTE: the arguments are parsed when this module is imported, so importing
# it from a program with other command-line options fails on unknown flags.
args = parser.parse_args()
rpc_proxy=None
# Module-level constants derived from the parsed arguments.
WORKER_RANK = args.rank
NUM_SAMPLER = args.num_sampler
WORLD_SIZE = args.world_size
QUEUE_SIZE = args.queue_size
# The hard upper bound is five times the requested queue size.
MAX_QUEUE_SIZE = 5*args.queue_size
# Format template with one placeholder.
RPC_NAME = "rpcserver{}"
SAMPLE_TYPE = SampleType.Outer
def _get_worker_rank():
    """Return the rank given by ``--rank``."""
    return WORKER_RANK


def _get_num_sampler():
    """Return the number of samplers given by ``--num_sampler``."""
    return NUM_SAMPLER


def _get_world_size():
    """Return the number of workers given by ``--world_size``."""
    return WORLD_SIZE


def _get_RPC_NAME():
    """Return the RPC server name template (same as ``_get_rpc_name``)."""
    return RPC_NAME


def _get_queue_size():
    """Return the queue size given by ``--queue_size``."""
    return QUEUE_SIZE


def _get_max_queue_size():
    """Return the maximum queue size (five times ``--queue_size``)."""
    return MAX_QUEUE_SIZE


def _get_rpc_name():
    """Return the RPC server name template."""
    return RPC_NAME
\ No newline at end of file
import sys
from os.path import abspath, join, dirname
import time
sys.path.insert(0, join(abspath(dirname(__file__))))
graph_set={}
def _clear_all(barrier = None):
# Close every registered graph and empty the registry.
# `barrier` is optional; when given, the graph is additionally unlinked by
# the caller for which barrier.wait() returns 0.
# NOTE(review): indentation was lost - the barrier check uses `graph`, so it
# presumably sits inside the loop (one wait() per graph); confirm.
global graph_set
for key in graph_set:
graph = graph_set[key]
graph._close_graph_in_shame()
print('clear ',key)
if(barrier is not None and barrier.wait()==0):
graph._unlink_graph_in_shame()
# Rebind the module-level registry to a fresh dict.
graph_set = {}
def _set_graph(graph_name, graph_info):
    """Attach ``graph_info`` to shared memory and register it by name."""
    graph_info._get_graph_from_shm()
    graph_set[graph_name] = graph_info


def _get_graph(graph_name):
    """Return the graph registered as ``graph_name`` (KeyError if absent)."""
    return graph_set[graph_name]


def _del_graph(graph_name):
    """Remove ``graph_name`` from the registry (KeyError if absent)."""
    del graph_set[graph_name]


def _get_size():
    """Return the number of registered graphs."""
    return len(graph_set)
# local_sampler=None
# Registry of samplers living in this process, keyed by graph name.
local_sampler = dict()


def set_local_sampler(graph_name, sampler):
    """Register ``sampler`` under ``graph_name``, replacing any previous one."""
    local_sampler[graph_name] = sampler


def get_local_sampler(sampler_name):
    """Return the sampler registered as ``sampler_name``.

    Fails with an ``AssertionError`` when no such sampler was registered.
    """
    assert sampler_name in local_sampler, "Local_sampler doesn't has sampler_name"
    return local_sampler[sampler_name]
\ No newline at end of file
import torch
import time
from Utils import GraphData
# Manual smoke test: one-layer sampling with policy="recent" on a 10-node toy
# graph with per-edge timestamps. Results are only printed, not asserted.
seed = 10 # any integer can be used as the seed
torch.manual_seed(seed)
num_nodes1 = 10
fanout1 = [2]
# 26 directed edges; row 0 and row 1 are the two endpoints of each edge.
edge_index1 = torch.tensor([[1, 5, 7, 9, 2, 4, 6, 7, 8, 0, 1, 6, 2, 0, 1, 3, 5, 8, 9, 7, 4, 8, 2, 3, 5, 8],
[0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 4, 4, 5, 6, 6, 7, 7, 8, 9]])
edge_ts = torch.tensor([1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 6, 6, 6, 6]).double()
edge_weight1 = torch.tensor([2, 1, 2, 1, 8, 6, 3, 1, 1, 1, 1, 5, 1, 1, 2, 1, 1, 1, 1, 5, 1, 2, 2, 2, 1, 1]).double()
# The weights defined above are discarded: the sampler is built unweighted.
edge_weight1 = None
# Single partition covering all nodes.
g_data = GraphData(id=0, edge_index=edge_index1, timestamp=edge_ts, data=None, partptr=torch.tensor([0, num_nodes1]))
from neighbor_sampler import NeighborSampler, SampleType
pre = time.time()
# from neighbor_sampler import get_neighbors, update_edge_weight
# row, col = edge_index1
# tnb = get_neighbors(row.contiguous(), col.contiguous(), num_nodes1, edge_weight1)
# print("tnb.neighbors:", tnb.neighbors)
# print("tnb.deg:", tnb.deg)
# print("tnb.weight:", tnb.edge_weight)
sampler = NeighborSampler(num_nodes1,
num_layers=1,
fanout=fanout1,
edge_weight=edge_weight1,
graph_data=g_data,
workers=2,
graph_name='a',
is_distinct = 0,
policy="recent")
end = time.time()
print("init time:", end-pre)
print("tnb.neighbors:", sampler.tnb.neighbors)
print("tnb.deg:", sampler.tnb.deg)
print("tnb.ts:", sampler.tnb.timestamp)
print("tnb.weight:", sampler.tnb.edge_weight)
# update_edge_row = row
# update_edge_col = col
# update_edge_w = torch.DoubleTensor([i for i in range(edge_weight1.size(0))])
# print('tnb.edge_weight:', tnb.edge_weight)
# print('begin update')
# pre = time.time()
# update_edge_weight(tnb, update_edge_row.contiguous(), update_edge_col.contiguous(), update_edge_w.contiguous())
# end = time.time()
# print("update time:", end-pre)
# print('update_edge_row:', update_edge_row)
# print('update_edge_col:', update_edge_col)
# print('tnb.edge_weight:', tnb.edge_weight)
pre = time.time()
# Sample around nodes 6 and 7, passing a timestamp of 9 for each seed.
out = sampler.sample_from_nodes(torch.tensor([6,7]),
with_outer_sample=SampleType.Whole,
ts=torch.tensor([9, 9]))
end = time.time()
# print('node:', out.node)
# print('edge_index_list:', out.edge_index_list)
# print('eid_list:', out.eid_list)
# print('eid_ts_list:', out.eid_ts_list)
print("sample time:", end-pre)
# NOTE(review): out[0] is read through plain attributes here, while the
# uniform-policy script calls sample_nodes()/eid() as methods - confirm which
# form the sampler output actually provides.
print("tot_time", out[0].tot_time)
print("sam_time", out[0].sample_time)
print("sam_edge", out[0].sample_edge_num)
print('eid_list:', out[0].eid)
print('delta_ts_list:', out[0].delta_ts)
print((out[0].sample_nodes<10000).sum())
print('node:', out[0].sample_nodes)
print('node_ts:', out[0].sample_nodes_ts)
\ No newline at end of file
import torch
import time
from Utils import GraphData
# Manual smoke test: two-layer sampling with policy="uniform" on a 10-node toy
# graph without timestamps. Results are only printed, not asserted.
seed = 10 # any integer can be used as the seed
torch.manual_seed(seed)
num_nodes1 = 10
fanout1 = [2,2] # index 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
edge_index1 = torch.tensor([[1, 5, 7, 9, 2, 4, 6, 7, 8, 0, 1, 6, 2, 0, 1, 3, 5, 8, 9, 7, 4, 8, 2, 3, 5, 8],
[0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 4, 4, 5, 6, 6, 7, 7, 8, 9]])
edge_weight1 = torch.tensor([2, 1, 2, 1, 8, 6, 3, 1, 1, 1, 1, 5, 1, 1, 2, 1, 1, 1, 1, 5, 1, 2, 2, 2, 1, 1]).double()
# Single partition covering all nodes.
g_data = GraphData(id=0, edge_index=edge_index1, data=None, partptr=torch.tensor([0, num_nodes1]))
# The weights defined above are discarded: the sampler is built unweighted.
edge_weight1 = None
# g_data.eid=None
from neighbor_sampler import NeighborSampler, SampleType
pre = time.time()
sampler = NeighborSampler(num_nodes1,
num_layers=2,
fanout=fanout1,
edge_weight=edge_weight1,
graph_data=g_data,
workers=2,
graph_name='a',
policy="uniform")
end = time.time()
print("init time:", end-pre)
print("tnb.neighbors:", sampler.tnb.neighbors)
print("tnb.eid:", sampler.tnb.eid)
print("tnb.deg:", sampler.tnb.deg)
print("tnb.weight:", sampler.tnb.edge_weight)
# row,col = edge_index1
# update_edge_row = row
# update_edge_col = col
# update_edge_w = torch.FloatTensor([i for i in range(edge_weight1.size(0))])
# print('tnb.edge_weight:', sampler.tnb.edge_weight)
# print('begin update')
# pre = time.time()
# sampler.tnb.update_edge_weight(sampler.tnb, update_edge_row.contiguous(), update_edge_col.contiguous(), update_edge_w.contiguous())
# end = time.time()
# print("update time:", end-pre)
# print('update_edge_row:', update_edge_row)
# print('update_edge_col:', update_edge_col)
# print('tnb.edge_weight:', sampler.tnb.edge_weight)
pre = time.time()
# Sample around nodes 1 and 2; one output entry per layer is printed below.
out = sampler.sample_from_nodes(torch.tensor([1,2]),
with_outer_sample=SampleType.Whole)# sampler.sample_from_nodes(torch.masked_select(torch.arange(g.num_nodes),node_data['train_mask']))
end = time.time()
# The sampled edge ids are used to look the edges up in edge_index1.
print('node1:\t', out[0].sample_nodes().tolist())
print('eid1:\t', out[0].eid().tolist())
print('edge1:\t', edge_index1[:, out[0].eid()].tolist())
print('node2:\t', out[1].sample_nodes().tolist())
print('eid2:\t', out[1].eid().tolist())
print('edge2:\t', edge_index1[:, out[1].eid()].tolist())
print("sample time:", end-pre)
\ No newline at end of file
import torch
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric import datasets
import time
from Utils import GraphData
def load_ogb_dataset(name, data_path):
    """Load an OGB node-property dataset and build boolean split masks.

    Args:
        name: dataset name passed to ``PygNodePropPredDataset``.
        data_path: root directory of the dataset.

    Returns:
        ``(g, node_data)`` where ``g`` is the first graph of the dataset and
        ``node_data`` maps ``'train_mask'``, ``'val_mask'`` and
        ``'test_mask'`` to boolean tensors with one entry per node.
    """
    dataset = PygNodePropPredDataset(name=name, root=data_path)
    split_idx = dataset.get_idx_split()
    g = dataset[0]
    n_node = g.num_nodes

    node_data = {}
    # (mask name in node_data, split name used by OGB)
    for mask_key, split_key in (('train_mask', "train"),
                                ('val_mask', "valid"),
                                ('test_mask', "test")):
        mask = torch.zeros(n_node, dtype=torch.bool)
        mask[split_idx[split_key]] = True
        node_data[mask_key] = mask
    return g, node_data
# Manual benchmark: temporal sampling on ogbn-products with a pre-built
# neighbor table. All paths below are hard-coded to one developer's machine.
g, node_data = load_ogb_dataset('ogbn-products', "/home/zlj/hzq/code/gnn/dataset/")
print(g)
# for worker in [1,2,3,4,5,6,7,8,9,10,20,30]:
# import random
# timestamp = [random.randint(1, 5) for i in range(0, g.num_edges)]
# timestamp = torch.FloatTensor(timestamp)
print('begin load')
pre = time.time()
# Timestamps and the neighbor table are loaded from previously saved files.
timestamp = torch.load('/home/zlj/hzq/code/gnn/my_sampler/TemporalSample/timestamp.my')
tnb = torch.load("tnb_before.my")
end = time.time()
print("load time:", end-pre)
row, col = g.edge_index
edge_weight=None
# Four equally sized partitions; this process acts as partition 1.
g_data = GraphData(id=1, edge_index=g.edge_index, timestamp=timestamp, data=g, partptr=torch.tensor([0, g.num_nodes//4, g.num_nodes//4*2, g.num_nodes//4*3, g.num_nodes]))
from neighbor_sampler import NeighborSampler, SampleType, get_neighbors
# print('begin tnb')
# pre = time.time()
# tnb = get_neighbors(row.contiguous(), col.contiguous(), g.num_nodes, 0, g_data.eid, edge_weight, timestamp)
# end = time.time()
# print("init tnb time:", end-pre)
# torch.save(tnb, "tnb_before.my")
pre = time.time()
sampler = NeighborSampler(g.num_nodes,
tnb=tnb,
num_layers=2,
fanout=[100,100],
graph_data=g_data,
workers=10,
policy="uniform",
is_root_ts=0,
graph_name='a')
end = time.time()
print("init time:", end-pre)
# from torch_geometric.sampler import NeighborSampler, NumNeighbors, NodeSamplerInput, SamplerOutput
# pre = time.time()
# num_nei = NumNeighbors([100, 100])
# node_idx = NodeSamplerInput(input_id=None, node=torch.tensor(range(g.num_nodes//4, g.num_nodes//4+600000)))# (input_id=None, node=torch.masked_select(torch.arange(g.num_nodes),node_data['train_mask']))
# sampler = NeighborSampler(g, num_nei)
# end = time.time()
# print("init time:", end-pre)
# One timestamp in 1..5 per seed node.
ts = torch.tensor([i%5+1 for i in range(0, 600000)])
pre = time.time()
# Seeds are the first 600000 node ids of partition 1.
out = sampler.sample_from_nodes(torch.tensor(range(g.num_nodes//4, g.num_nodes//4+600000)),
ts=ts,
with_outer_sample=SampleType.Inner)# sampler.sample_from_nodes(torch.masked_select(torch.arange(g.num_nodes),node_data['train_mask']))
# out = sampler.sample_from_nodes(node_idx)
# node = out.node
# edge = [out.row, out.col]
end = time.time()
# NOTE(review): the toy-graph scripts index the result as out[0]/out[1];
# here it is read as a single object with node/edge_index_list/eid_list/
# eid_ts_list attributes - confirm which form the sampler returns.
print('node:', out.node)
print('edge_index_list:', out.edge_index_list)
print('eid_list:', out.eid_list)
print('eid_ts_list:', out.eid_ts_list)
print("sample time", end-pre)
\ No newline at end of file
import torch
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric import datasets
import time
from Utils import GraphData
def load_ogb_dataset(name, data_path):
    """Load an OGB node-property dataset and build boolean split masks.

    Args:
        name: dataset name passed to ``PygNodePropPredDataset``.
        data_path: root directory of the dataset.

    Returns:
        ``(g, node_data)`` where ``g`` is the first graph of the dataset and
        ``node_data`` maps ``'train_mask'``, ``'val_mask'`` and
        ``'test_mask'`` to boolean tensors with one entry per node.
    """
    dataset = PygNodePropPredDataset(name=name, root=data_path)
    split_idx = dataset.get_idx_split()
    g = dataset[0]
    n_node = g.num_nodes

    node_data = {}
    # (mask name in node_data, split name used by OGB)
    for mask_key, split_key in (('train_mask', "train"),
                                ('val_mask', "valid"),
                                ('test_mask', "test")):
        mask = torch.zeros(n_node, dtype=torch.bool)
        mask[split_idx[split_key]] = True
        node_data[mask_key] = mask
    return g, node_data
# Benchmark script: builds the neighbor table for ogbn-products, constructs a
# 2-layer uniform NeighborSampler and times one sampling call on 600000 seeds.
# All dataset/output paths below are hard-coded absolute paths.
g, node_data = load_ogb_dataset('ogbn-products', "/home/zlj/hzq/code/gnn/dataset/")
print(g)
# for worker in [1,2,3,4,5,6,7,8,9,10,20,30]:
# partptr splits the node id range into four contiguous, roughly equal parts.
g_data = GraphData(id=1, edge_index=g.edge_index, data=g, partptr=torch.tensor([0, g.num_nodes//4, g.num_nodes//4*2, g.num_nodes//4*3, g.num_nodes]))
row, col = g.edge_index
# edge_weight = torch.ones(g.num_edges).float()
# indices = [x for x in range(0, g.num_edges, 5)]
# edge_weight[indices] = 2.0
# g_data.eid = None
# No edge weights and no timestamps are passed to get_neighbors below.
edge_weight = None
timestamp = None
from neighbor_sampler import NeighborSampler, SampleType
from neighbor_sampler import get_neighbors
# NOTE(review): the three update_edge_* values are only referenced by the
# commented-out update_edge_weight call, so the "update time" printed below
# measures an empty interval.
update_edge_row = row
update_edge_col = col
update_edge_w = torch.DoubleTensor([i for i in range(g.num_edges)])
print('begin update')
pre = time.time()
# update_edge_weight(tnb, update_edge_row.contiguous(), update_edge_col.contiguous(), update_edge_w.contiguous())
end = time.time()
print("update time:", end-pre)
# Time the construction of the neighbor table (tnb) from the raw edge list.
print('begin tnb')
pre = time.time()
tnb = get_neighbors("a",
row.contiguous(),
col.contiguous(),
g.num_nodes, 0,
g_data.eid,
edge_weight,
timestamp)
end = time.time()
print("init tnb time:", end-pre)
# Persist the table so later runs can use the commented-out torch.load path.
torch.save(tnb, "/home/zlj/hzq/code/gnn/my_sampler/MergeSample/tnb_static.my")
# print('begin load')
# pre = time.time()
# tnb = torch.load("/home/zlj/hzq/code/gnn/my_sampler/MergeSample/tnb_static.my")
# end = time.time()
# print("load time:", end-pre)
# Time the sampler construction on top of the prebuilt table.
print('begin init')
pre = time.time()
sampler = NeighborSampler(g.num_nodes,
tnb = tnb,
num_layers=2,
fanout=[100,100],
graph_data=g_data,
workers=10,
graph_name='a',
policy="uniform")
end = time.time()
print("init time:", end-pre)
# from torch_geometric.sampler import NeighborSampler, NumNeighbors, NodeSamplerInput, SamplerOutput
# pre = time.time()
# num_nei = NumNeighbors([100, 100])
# node_idx = NodeSamplerInput(input_id=None, node=torch.tensor(range(g.num_nodes//4, g.num_nodes//4+600000)))# (input_id=None, node=torch.masked_select(torch.arange(g.num_nodes),node_data['train_mask']))
# sampler = NeighborSampler(g, num_nei)
# end = time.time()
# print("init time:", end-pre)
# Seeds are the 600000 node ids starting at the second partition boundary.
# NOTE(review): the seed tensor is built from a Python range inside the timed
# region, so "sample time" also includes that construction.
pre = time.time()
out = sampler.sample_from_nodes(torch.tensor(range(g.num_nodes//4, g.num_nodes//4+600000)), with_outer_sample=SampleType.Inner)# sampler.sample_from_nodes(torch.masked_select(torch.arange(g.num_nodes),node_data['train_mask']))
# out = sampler.sample_from_nodes(node_idx)
# node = out.node
# edge = [out.row, out.col]
end = time.time()
# out is indexed per layer; print sampled nodes, edge ids and the matching
# columns of g.edge_index for both layers.
print('node1:\t', out[0].sample_nodes())
print('eid1:\t', out[0].eid())
print('edge1:\t', g.edge_index[:, out[0].eid()])
print('node2:\t', out[1].sample_nodes())
print('eid2:\t', out[1].eid())
print('edge2:\t', g.edge_index[:, out[1].eid()])
print("sample time", end-pre)
\ No newline at end of file
This source diff could not be displayed because it is too large. You can view the blob instead.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from conans import ConanFile, tools
import os
class SparseppConan(ConanFile):
    """Conan recipe packaging the header-only parallel_hashmap library."""

    name = "parallel_hashmap"
    version = "1.27"
    description = "A header-only, very fast and memory-friendly hash map"

    # License of the packaged library (given as a link to the license text)
    license = "https://github.com/greg7mdp/parallel-hashmap/blob/master/LICENSE"

    # Ship the LICENSE file together with this conanfile.py
    exports = ["LICENSE"]

    # Bincrafters convention: directory the fetched sources are renamed to
    source_subfolder = "source_subfolder"

    def source(self):
        """Fetch the release archive for ``version`` and rename the extracted
        directory to ``source_subfolder``."""
        repo_url = "https://github.com/greg7mdp/parallel-hashmap"
        archive_url = "{0}/archive/{1}.tar.gz".format(repo_url, self.version)
        tools.get(archive_url)

        # NOTE(review): the directory name is built from ``name`` (underscore)
        # while the repository is spelled with a hyphen -- confirm the archive
        # really unpacks to "<name>-<version>".
        unpacked_dir = "-".join((self.name, self.version))
        os.rename(unpacked_dir, self.source_subfolder)

    def package(self):
        """Copy the LICENSE file and every header into the package."""
        headers_dir = os.path.join(self.source_subfolder, "parallel_hashmap")
        self.copy(pattern="LICENSE")
        self.copy(pattern="*", dst="include/parallel_hashmap", src=headers_dir)

    def package_id(self):
        """Mark the package as header-only for package-id computation."""
        self.info.header_only()
#if !defined(spp_memory_h_guard)
#define spp_memory_h_guard
#include <cstdint>
#include <cstring>
#include <cstdlib>
#if defined(_WIN32) || defined( __CYGWIN__)
#define SPP_WIN
#endif
#ifdef SPP_WIN
#include <windows.h>
#include <Psapi.h>
#undef min
#undef max
#elif defined(__linux__)
#include <sys/types.h>
#include <sys/sysinfo.h>
#elif defined(__FreeBSD__)
#include <paths.h>
#include <fcntl.h>
#include <kvm.h>
#include <unistd.h>
#include <sys/sysctl.h>
#include <sys/user.h>
#endif
namespace spp
{
uint64_t GetSystemMemory();
uint64_t GetTotalMemoryUsed();
uint64_t GetProcessMemoryUsed();
uint64_t GetPhysicalMemory();
// Returns the total amount of virtual memory known to the system, in bytes:
// the total page file on Windows, RAM + swap on Linux and FreeBSD.
// Returns 0 on unsupported platforms or (Linux) when sysinfo() fails.
uint64_t GetSystemMemory()
{
#ifdef SPP_WIN
    MEMORYSTATUSEX memInfo;
    memInfo.dwLength = sizeof(MEMORYSTATUSEX);
    GlobalMemoryStatusEx(&memInfo);
    return static_cast<uint64_t>(memInfo.ullTotalPageFile);
#elif defined(__linux__)
    struct sysinfo memInfo;
    if (sysinfo(&memInfo) != 0)
        return 0;

    // Widen to 64 bits before any arithmetic: the sysinfo counters are
    // `unsigned long`, so on 32-bit targets the multiplication by mem_unit
    // would otherwise wrap.
    uint64_t totalVirtualMem = static_cast<uint64_t>(memInfo.totalram);
    totalVirtualMem += static_cast<uint64_t>(memInfo.totalswap);
    totalVirtualMem *= memInfo.mem_unit;
    return totalVirtualMem;
#elif defined(__FreeBSD__)
    u_int pageCnt = 0;
    size_t pageCntLen = sizeof(pageCnt);
    // Keep the page size in 64 bits so that page-count products do not wrap
    // at 4 GiB (the original multiplied two u_int values).
    const uint64_t pageSize = static_cast<uint64_t>(getpagesize());

    sysctlbyname("vm.stats.vm.v_page_count", &pageCnt, &pageCntLen, NULL, 0);
    uint64_t totalVirtualMem = pageCnt * pageSize;

    struct kvm_swap kswap;
    memset(&kswap, 0, sizeof(kswap));
    kvm_t *kd = kvm_open(NULL, _PATH_DEVNULL, NULL, O_RDONLY, "kvm_open");
    if (kd != NULL)
    {
        // Only add swap when the query succeeded; kswap stays zeroed otherwise.
        if (kvm_getswapinfo(kd, &kswap, 1, 0) != -1)
            totalVirtualMem += kswap.ksw_total * pageSize;
        kvm_close(kd);
    }
    return totalVirtualMem;
#else
    return 0;
#endif
}
// Returns the amount of virtual memory currently in use system-wide, in
// bytes: used page file on Windows, (RAM - free RAM) + (swap - free swap) on
// Linux, (pages - free pages) + used swap on FreeBSD.
// Returns 0 on unsupported platforms or (Linux) when sysinfo() fails.
uint64_t GetTotalMemoryUsed()
{
#ifdef SPP_WIN
    MEMORYSTATUSEX memInfo;
    memInfo.dwLength = sizeof(MEMORYSTATUSEX);
    GlobalMemoryStatusEx(&memInfo);
    return static_cast<uint64_t>(memInfo.ullTotalPageFile - memInfo.ullAvailPageFile);
#elif defined(__linux__)
    struct sysinfo memInfo;
    if (sysinfo(&memInfo) != 0)
        return 0;

    // Do the arithmetic in 64 bits: the sysinfo counters are `unsigned long`
    // and the multiplication by mem_unit can wrap on 32-bit targets.
    uint64_t virtualMemUsed = static_cast<uint64_t>(memInfo.totalram) - memInfo.freeram;
    virtualMemUsed += static_cast<uint64_t>(memInfo.totalswap) - memInfo.freeswap;
    virtualMemUsed *= memInfo.mem_unit;
    return virtualMemUsed;
#elif defined(__FreeBSD__)
    u_int pageCnt = 0, freeCnt = 0;
    size_t pageCntLen = sizeof(pageCnt);
    size_t freeCntLen = sizeof(freeCnt);
    // 64-bit page size keeps the products below from wrapping at 4 GiB.
    const uint64_t pageSize = static_cast<uint64_t>(getpagesize());

    sysctlbyname("vm.stats.vm.v_page_count", &pageCnt, &pageCntLen, NULL, 0);
    sysctlbyname("vm.stats.vm.v_free_count", &freeCnt, &freeCntLen, NULL, 0);
    uint64_t virtualMemUsed = static_cast<uint64_t>(pageCnt - freeCnt) * pageSize;

    struct kvm_swap kswap;
    memset(&kswap, 0, sizeof(kswap));
    kvm_t *kd = kvm_open(NULL, _PATH_DEVNULL, NULL, O_RDONLY, "kvm_open");
    if (kd != NULL)
    {
        // Only add swap when the query succeeded; kswap stays zeroed otherwise.
        if (kvm_getswapinfo(kd, &kswap, 1, 0) != -1)
            virtualMemUsed += kswap.ksw_used * pageSize;
        kvm_close(kd);
    }
    return virtualMemUsed;
#else
    return 0;
#endif
}
// Returns the memory used by the current process, in bytes: PrivateUsage on
// Windows, the VmSize entry of /proc/self/status on Linux, resident pages
// times the page size on FreeBSD.
// Returns 0 on unsupported platforms, when /proc/self/status cannot be
// opened, when it has no VmSize line, or (FreeBSD) when sysctl() fails.
uint64_t GetProcessMemoryUsed()
{
#ifdef SPP_WIN
    PROCESS_MEMORY_COUNTERS_EX pmc;
    GetProcessMemoryInfo(GetCurrentProcess(), reinterpret_cast<PPROCESS_MEMORY_COUNTERS>(&pmc), sizeof(pmc));
    return static_cast<uint64_t>(pmc.PrivateUsage);
#elif defined(__linux__)
    auto file = fopen("/proc/self/status", "r");
    if (file == nullptr)
        return 0;   // the original passed a null FILE* to fgets here

    uint64_t vmSizeKb = 0;
    char line[128];
    while (fgets(line, sizeof(line), file) != nullptr)
    {
        if (strncmp(line, "VmSize:", 7) == 0)
        {
            // The value follows the label after some whitespace and is
            // followed by " kB"; strtoull skips the leading whitespace and
            // stops at the first non-digit, so the buffer is not modified.
            vmSizeKb = strtoull(line + 7, nullptr, 10);
            break;
        }
    }
    fclose(file);
    return vmSizeKb * 1024;
#elif defined(__FreeBSD__)
    struct kinfo_proc info;
    size_t infoLen = sizeof(info);
    int mib[] = { CTL_KERN, KERN_PROC, KERN_PROC_PID, getpid() };
    if (sysctl(mib, sizeof(mib) / sizeof(*mib), &info, &infoLen, NULL, 0) != 0)
        return 0;
    // Widen both factors so the product is computed in 64 bits.
    return static_cast<uint64_t>(info.ki_rssize) * static_cast<uint64_t>(getpagesize());
#else
    return 0;
#endif
}
// Returns the amount of physical RAM installed, in bytes.
// Returns 0 on unsupported platforms or when the underlying system query
// (sysinfo on Linux, sysctl on FreeBSD) fails.
uint64_t GetPhysicalMemory()
{
#ifdef SPP_WIN
    MEMORYSTATUSEX memInfo;
    memInfo.dwLength = sizeof(MEMORYSTATUSEX);
    GlobalMemoryStatusEx(&memInfo);
    return static_cast<uint64_t>(memInfo.ullTotalPhys);
#elif defined(__linux__)
    struct sysinfo memInfo;
    if (sysinfo(&memInfo) != 0)
        return 0;

    // Widen before multiplying: totalram is `unsigned long`, so on 32-bit
    // targets totalram * mem_unit would otherwise wrap.
    uint64_t totalPhysMem = static_cast<uint64_t>(memInfo.totalram);
    totalPhysMem *= memInfo.mem_unit;
    return totalPhysMem;
#elif defined(__FreeBSD__)
    u_long physMem = 0;
    size_t physMemLen = sizeof(physMem);
    int mib[] = { CTL_HW, HW_PHYSMEM };
    if (sysctl(mib, sizeof(mib) / sizeof(*mib), &physMem, &physMemLen, NULL, 0) != 0)
        return 0;
    return physMem;
#else
    return 0;
#endif
}
}
#endif // spp_memory_h_guard
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
#if !defined(phmap_dump_h_guard_)
#define phmap_dump_h_guard_
// ---------------------------------------------------------------------------
// Copyright (c) 2019, Gregory Popovitch - greg7mdp@gmail.com
//
// providing dump/load/mmap_load
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ---------------------------------------------------------------------------
#include <iostream>
#include <fstream>
#include <sstream>
#include "phmap.h"
namespace phmap
{
// Trait used by the dump/load code below to static_assert that a container's
// value_type can be serialized as raw bytes.
namespace type_traits_internal {
// libstdc++ releases dated before 2015-08-01 are handled with the
// __has_trivial_copy compiler intrinsic instead of std::is_trivially_copyable.
#if defined(__GLIBCXX__) && __GLIBCXX__ < 20150801
template<typename T> struct IsTriviallyCopyable : public std::integral_constant<bool, __has_trivial_copy(T)> {};
#else
template<typename T> struct IsTriviallyCopyable : public std::is_trivially_copyable<T> {};
#endif
// Specialization for std::pair: a pair qualifies exactly when both of its
// element types qualify.
template <class T1, class T2>
struct IsTriviallyCopyable<std::pair<T1, T2>> {
static constexpr bool value = IsTriviallyCopyable<T1>::value && IsTriviallyCopyable<T2>::value;
};
}
namespace priv {
// ------------------------------------------------------------------------
// dump/load for raw_hash_set
// ------------------------------------------------------------------------
// Writes this table to `ar` as raw bytes.
// Layout, in order: size_, then (only when size_ != 0) capacity_, the control
// bytes and the slot array. Requires a trivially copyable value_type.
// Returns false, after printing a message to std::cerr, as soon as one of the
// archive's dump calls returns false; returns true otherwise.
template <class Policy, class Hash, class Eq, class Alloc>
template<typename OutputArchive>
bool raw_hash_set<Policy, Hash, Eq, Alloc>::dump(OutputArchive& ar) const {
static_assert(type_traits_internal::IsTriviallyCopyable<value_type>::value,
"value_type should be trivially copyable");
if (!ar.dump(size_)) {
std::cerr << "Failed to dump size_" << std::endl;
return false;
}
// An empty table is fully described by its size; nothing else is written.
if (size_ == 0) {
return true;
}
if (!ar.dump(capacity_)) {
std::cerr << "Failed to dump capacity_" << std::endl;
return false;
}
// The control array is written with capacity_ + Group::kWidth + 1 entries.
if (!ar.dump(reinterpret_cast<char*>(ctrl_),
sizeof(ctrl_t) * (capacity_ + Group::kWidth + 1))) {
std::cerr << "Failed to dump ctrl_" << std::endl;
return false;
}
// All capacity_ slots are written, not only the size_ occupied ones.
if (!ar.dump(reinterpret_cast<char*>(slots_),
sizeof(slot_type) * capacity_)) {
std::cerr << "Failed to dump slot_" << std::endl;
return false;
}
return true;
}
// Replaces the content of this table with data read from `ar`, expecting the
// layout produced by dump(): size_, then (when size_ != 0) capacity_, the
// control bytes and the slot array. Requires a trivially copyable value_type.
// Returns false, after printing a message to std::cerr, as soon as one of the
// archive's load calls returns false; returns true otherwise.
// NOTE(review): on a failure after size_/capacity_ were read, the members
// already overwritten are not rolled back.
template <class Policy, class Hash, class Eq, class Alloc>
template<typename InputArchive>
bool raw_hash_set<Policy, Hash, Eq, Alloc>::load(InputArchive& ar) {
static_assert(type_traits_internal::IsTriviallyCopyable<value_type>::value,
"value_type should be trivially copyable");
raw_hash_set<Policy, Hash, Eq, Alloc>().swap(*this); // clear any existing content
if (!ar.load(&size_)) {
std::cerr << "Failed to load size_" << std::endl;
return false;
}
// dump() writes nothing after a zero size, so there is nothing more to read.
if (size_ == 0) {
return true;
}
if (!ar.load(&capacity_)) {
std::cerr << "Failed to load capacity_" << std::endl;
return false;
}
// allocate memory for ctrl_ and slots_
initialize_slots();
// Byte counts below mirror the ones used by dump().
if (!ar.load(reinterpret_cast<char*>(ctrl_),
sizeof(ctrl_t) * (capacity_ + Group::kWidth + 1))) {
std::cerr << "Failed to load ctrl" << std::endl;
return false;
}
if (!ar.load(reinterpret_cast<char*>(slots_),
sizeof(slot_type) * capacity_)) {
std::cerr << "Failed to load slot" << std::endl;
return false;
}
return true;
}
// ------------------------------------------------------------------------
// dump/load for parallel_hash_set
// ------------------------------------------------------------------------
// Writes this parallel table to `ar`: first the value of subcnt(), then each
// inner set in index order via its own dump(). Requires a trivially copyable
// value_type.
// Returns false, after printing a message to std::cerr, on the first failing
// write; returns true otherwise.
template <size_t N,
template <class, class, class, class> class RefSet,
class Mtx_,
class Policy, class Hash, class Eq, class Alloc>
template<typename OutputArchive>
bool parallel_hash_set<N, RefSet, Mtx_, Policy, Hash, Eq, Alloc>::dump(OutputArchive& ar) const {
static_assert(type_traits_internal::IsTriviallyCopyable<value_type>::value,
"value_type should be trivially copyable");
if (! ar.dump(subcnt())) {
std::cerr << "Failed to dump meta!" << std::endl;
return false;
}
for (size_t i = 0; i < sets_.size(); ++i) {
auto& inner = sets_[i];
// This method is const, so constness is cast away to take the inner set's
// lock for the duration of its dump.
typename Lockable::UniqueLock m(const_cast<Inner&>(inner));
if (!inner.set_.dump(ar)) {
std::cerr << "Failed to dump submap " << i << std::endl;
return false;
}
}
return true;
}
// Reads a parallel table from `ar`, expecting the layout produced by dump():
// a submap count followed by each inner set. Requires a trivially copyable
// value_type.
// Returns false, after printing a message to std::cerr, when the count cannot
// be read, when it differs from subcnt(), or when an inner set fails to load;
// returns true otherwise.
template <size_t N,
template <class, class, class, class> class RefSet,
class Mtx_,
class Policy, class Hash, class Eq, class Alloc>
template<typename InputArchive>
bool parallel_hash_set<N, RefSet, Mtx_, Policy, Hash, Eq, Alloc>::load(InputArchive& ar) {
static_assert(type_traits_internal::IsTriviallyCopyable<value_type>::value,
"value_type should be trivially copyable");
size_t submap_count = 0;
if (!ar.load(&submap_count)) {
std::cerr << "Failed to load submap count!" << std::endl;
return false;
}
// The archive must have been written by a table with the same submap count.
if (submap_count != subcnt()) {
std::cerr << "submap count(" << submap_count << ") != N(" << N << ")" << std::endl;
return false;
}
for (size_t i = 0; i < submap_count; ++i) {
auto& inner = sets_[i];
// Hold the inner set's lock while its content is replaced.
typename Lockable::UniqueLock m(const_cast<Inner&>(inner));
if (!inner.set_.load(ar)) {
std::cerr << "Failed to load submap " << i << std::endl;
return false;
}
}
return true;
}
} // namespace priv
// ------------------------------------------------------------------------
// BinaryArchive
// File is closed when archive object is destroyed
// ------------------------------------------------------------------------
// ------------------------------------------------------------------------
// ------------------------------------------------------------------------
// Output archive that writes raw bytes to a binary file.
// The file is closed when the archive object is destroyed.
class BinaryOutputArchive {
public:
    // Opens `file_path` for binary writing. A failed open is not reported
    // here; it leaves the stream in a failed state, so every subsequent
    // dump() returns false.
    BinaryOutputArchive(const char *file_path) {
        ofs_.open(file_path, std::ios_base::binary);
    }

    // Writes `sz` bytes starting at `p`.
    // Returns true only if the stream is still good after the write (the
    // original returned true unconditionally, hiding open/write failures from
    // the callers that check this result).
    bool dump(const char *p, size_t sz) {
        ofs_.write(p, static_cast<std::streamsize>(sz));
        return ofs_.good();
    }

    // Writes the object representation of a trivially copyable value.
    // Returns true only if the stream is still good after the write.
    template<typename V>
    typename std::enable_if<type_traits_internal::IsTriviallyCopyable<V>::value, bool>::type
    dump(const V& v) {
        ofs_.write(reinterpret_cast<const char *>(&v), sizeof(V));
        return ofs_.good();
    }

private:
    std::ofstream ofs_;
};
// Input archive that reads raw bytes from a binary file.
// The file is closed when the archive object is destroyed.
class BinaryInputArchive {
public:
    // Opens `file_path` for binary reading. A failed open is not reported
    // here; it leaves the stream in a failed state, so every subsequent
    // load() returns false.
    BinaryInputArchive(const char * file_path) {
        ifs_.open(file_path, std::ios_base::binary);
    }

    // Reads exactly `sz` bytes into `p`.
    // Returns false when the stream failed, which includes a short read at
    // end of file (the original returned true unconditionally, so truncated
    // or missing files were accepted silently).
    bool load(char* p, size_t sz) {
        ifs_.read(p, static_cast<std::streamsize>(sz));
        return !ifs_.fail();
    }

    // Reads the object representation of a trivially copyable value into *v.
    // Returns false when fewer than sizeof(V) bytes could be read.
    template<typename V>
    typename std::enable_if<type_traits_internal::IsTriviallyCopyable<V>::value, bool>::type
    load(V* v) {
        ifs_.read(reinterpret_cast<char *>(v), sizeof(V));
        return !ifs_.fail();
    }

private:
    std::ifstream ifs_;
};
} // namespace phmap
#endif // phmap_dump_h_guard_
#if !defined(phmap_fwd_decl_h_guard_)
#define phmap_fwd_decl_h_guard_
// ---------------------------------------------------------------------------
// Copyright (c) 2019, Gregory Popovitch - greg7mdp@gmail.com
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
// ---------------------------------------------------------------------------
#ifdef _MSC_VER
#pragma warning(push)
#pragma warning(disable : 4514) // unreferenced inline function has been removed
#pragma warning(disable : 4710) // function not inlined
#pragma warning(disable : 4711) // selected for automatic inline expansion
#endif
#include <memory>
#include <utility>
#if defined(PHMAP_USE_ABSL_HASH) && !defined(ABSL_HASH_HASH_H_)
namespace absl { template <class T> struct Hash; };
#endif
namespace phmap {
#if defined(PHMAP_USE_ABSL_HASH)
template <class T> using Hash = ::absl::Hash<T>;
#else
template <class T> struct Hash;
#endif
template <class T> struct EqualTo;
template <class T> struct Less;
template <class T> using Allocator = typename std::allocator<T>;
template<class T1, class T2> using Pair = typename std::pair<T1, T2>;
class NullMutex;
// Helper aliases used as default template arguments by the container forward
// declarations that follow.
namespace priv {
// The hash of an object of type T is computed by using phmap::Hash.
// Equality is likewise taken from phmap::EqualTo.
// NOTE(review): the second parameter E is unused in this primary template;
// presumably it is a hook for constrained specializations -- confirm in phmap.h.
template <class T, class E = void>
struct HashEq
{
using Hash = phmap::Hash<T>;
using Eq = phmap::EqualTo<T>;
};
// Default hasher for keys of type T.
template <class T>
using hash_default_hash = typename priv::HashEq<T>::Hash;
// Default equality functor for keys of type T.
template <class T>
using hash_default_eq = typename priv::HashEq<T>::Eq;
// type alias for std::allocator so we can forward declare without including other headers
template <class T>
using Allocator = typename phmap::Allocator<T>;
// type alias for std::pair so we can forward declare without including other headers
template<class T1, class T2>
using Pair = typename phmap::Pair<T1, T2>;
} // namespace priv
// ------------- forward declarations for hash containers ----------------------------------
template <class T,
class Hash = phmap::priv::hash_default_hash<T>,
class Eq = phmap::priv::hash_default_eq<T>,
class Alloc = phmap::priv::Allocator<T>> // alias for std::allocator
class flat_hash_set;
template <class K, class V,
class Hash = phmap::priv::hash_default_hash<K>,
class Eq = phmap::priv::hash_default_eq<K>,
class Alloc = phmap::priv::Allocator<
phmap::priv::Pair<const K, V>>> // alias for std::allocator
class flat_hash_map;
template <class T,
class Hash = phmap::priv::hash_default_hash<T>,
class Eq = phmap::priv::hash_default_eq<T>,
class Alloc = phmap::priv::Allocator<T>> // alias for std::allocator
class node_hash_set;
template <class Key, class Value,
class Hash = phmap::priv::hash_default_hash<Key>,
class Eq = phmap::priv::hash_default_eq<Key>,
class Alloc = phmap::priv::Allocator<
phmap::priv::Pair<const Key, Value>>> // alias for std::allocator
class node_hash_map;
template <class T,
class Hash = phmap::priv::hash_default_hash<T>,
class Eq = phmap::priv::hash_default_eq<T>,
class Alloc = phmap::priv::Allocator<T>, // alias for std::allocator
size_t N = 4, // 2**N submaps
class Mutex = phmap::NullMutex> // use std::mutex to enable internal locks
class parallel_flat_hash_set;
template <class K, class V,
class Hash = phmap::priv::hash_default_hash<K>,
class Eq = phmap::priv::hash_default_eq<K>,
class Alloc = phmap::priv::Allocator<
phmap::priv::Pair<const K, V>>, // alias for std::allocator
size_t N = 4, // 2**N submaps
class Mutex = phmap::NullMutex> // use std::mutex to enable internal locks
class parallel_flat_hash_map;
template <class T,
class Hash = phmap::priv::hash_default_hash<T>,
class Eq = phmap::priv::hash_default_eq<T>,
class Alloc = phmap::priv::Allocator<T>, // alias for std::allocator
size_t N = 4, // 2**N submaps
class Mutex = phmap::NullMutex> // use std::mutex to enable internal locks
class parallel_node_hash_set;
template <class Key, class Value,
class Hash = phmap::priv::hash_default_hash<Key>,
class Eq = phmap::priv::hash_default_eq<Key>,
class Alloc = phmap::priv::Allocator<
phmap::priv::Pair<const Key, Value>>, // alias for std::allocator
size_t N = 4, // 2**N submaps
class Mutex = phmap::NullMutex> // use std::mutex to enable internal locks
class parallel_node_hash_map;
// ------------- forward declarations for btree containers ----------------------------------
template <typename Key, typename Compare = phmap::Less<Key>,
typename Alloc = phmap::Allocator<Key>>
class btree_set;
template <typename Key, typename Compare = phmap::Less<Key>,
typename Alloc = phmap::Allocator<Key>>
class btree_multiset;
template <typename Key, typename Value, typename Compare = phmap::Less<Key>,
typename Alloc = phmap::Allocator<phmap::priv::Pair<const Key, Value>>>
class btree_map;
template <typename Key, typename Value, typename Compare = phmap::Less<Key>,
typename Alloc = phmap::Allocator<phmap::priv::Pair<const Key, Value>>>
class btree_multimap;
} // namespace phmap
#ifdef _MSC_VER
#pragma warning(pop)
#endif
#endif // phmap_fwd_decl_h_guard_
import torch
import torch.multiprocessing as mp
from typing import Optional, Tuple
from base import BaseSampler, NegativeSampling, SampleOutput
from neighbor_sampler import NeighborSampler, SampleType
class RandomWalkSampler(BaseSampler):
    """Random-walk sampler: a ``NeighborSampler`` with a fanout of 1 per layer.

    Sampling requests are forwarded to the wrapped ``NeighborSampler``; the
    update methods operate on the neighbor table ``self.tnb``.
    """

    def __init__(
        self,
        num_nodes: int,
        num_layers: int,
        graph_data,
        workers = 1,
        tnb = None,
        is_distinct = 0,
        policy = "uniform",
        edge_weight: Optional[torch.Tensor] = None,
        graph_name = None
    ) -> None:
        r"""__init__
        Args:
            num_nodes: the num of all nodes in the graph
            num_layers: the num of layers to be sampled (the walk length)
            graph_data: graph object forwarded to NeighborSampler
            workers: the number of threads, default value is 1
            tnb: neighbor information table
            is_distinct: 1-need distinct, 0-don't need distinct
            policy: "uniform" or "recent" or "weighted"
            edge_weight: the initial weights of edges
            graph_name: the name of graph
        """
        super().__init__()
        self.sampler = NeighborSampler(
            num_nodes=num_nodes,
            tnb=tnb,
            num_layers=num_layers,
            fanout=[1] * num_layers,
            graph_data=graph_data,
            edge_weight=edge_weight,
            workers=workers,
            policy=policy,
            graph_name=graph_name,
            is_distinct=is_distinct
        )
        # The methods below read these attributes; previously only
        # ``sampler`` and ``num_layers`` were stored, so they raised
        # AttributeError on first use.
        self.num_nodes = num_nodes
        self.num_layers = num_layers
        self.fanout = [1] * num_layers
        self.workers = workers
        self.is_distinct = is_distinct
        self.policy = policy
        # NOTE(review): when no table is passed in, this assumes NeighborSampler
        # exposes the one it builds as ``.tnb`` -- confirm in neighbor_sampler.py.
        self.tnb = tnb if tnb is not None else getattr(self.sampler, "tnb", None)

    def _get_sample_info(self):
        """Return ``(num_nodes, num_layers, fanout, workers)``."""
        return self.num_nodes,self.num_layers,self.fanout,self.workers

    def _get_sample_options(self):
        """Return the sampling options, including the flags read from ``tnb``."""
        return {"is_distinct" : self.is_distinct,
                "policy" : self.policy,
                "with_eid" : self.tnb.with_eid,
                "weighted" : self.tnb.weighted,
                "with_timestamp" : self.tnb.with_timestamp}

    def insert_edges_with_timestamp(
            self,
            edge_index : torch.Tensor,
            eid : torch.Tensor,
            timestamp : torch.Tensor,
            edge_weight : Optional[torch.Tensor] = None):
        """Insert timestamped edges into the neighbor table.

        Args:
            edge_index: tensor unpacked into ``(row, col)``
            eid: ids of the inserted edges
            timestamp: timestamps of the inserted edges
            edge_weight: optional weights; ``None`` is forwarded unchanged
                (previously ``None`` raised AttributeError on ``.contiguous()``)
        """
        row, col = edge_index
        weight = None if edge_weight is None else edge_weight.contiguous()
        # Update the node count and the neighbor table.
        # NOTE(review): the wrapped NeighborSampler keeps its own node count,
        # which is not refreshed here.
        self.num_nodes = self.tnb.update_neighbors_with_time(
            row.contiguous(),
            col.contiguous(),
            timestamp.contiguous(),
            eid.contiguous(),
            self.is_distinct,
            weight)

    def update_edges_weight(
            self,
            edge_index : torch.Tensor,
            eid : torch.Tensor,
            edge_weight : Optional[torch.Tensor] = None):
        """Update edge weights in the neighbor table.

        Raises:
            ValueError: if ``edge_weight`` is ``None`` (there is nothing to
                write; this used to surface as an AttributeError).
        """
        if edge_weight is None:
            raise ValueError("edge_weight must be provided to update edge weights")
        row, col = edge_index
        # Update the weight information held by tnb; tables built with edge
        # ids are addressed by eid instead of the source row.
        if self.tnb.with_eid:
            self.tnb.update_edge_weight(
                eid.contiguous(),
                col.contiguous(),
                edge_weight.contiguous()
            )
        else:
            self.tnb.update_edge_weight(
                row.contiguous(),
                col.contiguous(),
                edge_weight.contiguous()
            )

    def update_nodes_weight(
            self,
            nid : torch.Tensor,
            node_weight : Optional[torch.Tensor] = None):
        """Update the weights of the nodes listed in ``nid``.

        Raises:
            ValueError: if ``node_weight`` is ``None``.
        """
        if node_weight is None:
            raise ValueError("node_weight must be provided to update node weights")
        # Update the weight information held by tnb.
        self.tnb.update_node_weight(
            nid.contiguous(),
            node_weight.contiguous()
        )

    def update_all_node_weight(
            self,
            node_weight : torch.Tensor):
        """Replace the weights of all nodes in the neighbor table."""
        # Update the weight information held by tnb.
        self.tnb.update_all_node_weight(node_weight.contiguous())

    def sample_from_nodes(
        self,
        nodes: torch.Tensor,
        with_outer_sample: SampleType,
        ts: Optional[torch.Tensor] = None
    ) -> SampleOutput:
        r"""Performs multilayer sampling from the nodes specified in: nodes
        The specific number of layers is determined by parameter: num_layers
        returning a sampled subgraph in the specified output format.
        Args:
            nodes: the list of seed nodes index
            with_outer_sample: 0-sample in whole graph structure; 1-sample onehop outer node; 2-cross partition sampling
            ts: optional timestamps of the seed nodes
        Returns:
            the output of ``NeighborSampler.sample_from_nodes``
        """
        # Forward by keyword: this method's parameter order is
        # (nodes, with_outer_sample, ts), so passing ``ts`` and
        # ``with_outer_sample`` positionally risks swapping them.
        return self.sampler.sample_from_nodes(
            nodes, ts=ts, with_outer_sample=with_outer_sample)

    def sample_from_edges(
        self,
        edges: torch.Tensor,
        ets: Optional[torch.Tensor] = None,
        neg_sampling: Optional[NegativeSampling] = None,
        with_outer_sample: SampleType = SampleType.Whole
    ) -> SampleOutput:
        r"""Performs sampling from the edges specified in :obj:`edges`,
        returning a sampled subgraph in the specified output format.
        Args:
            edges: the list of seed edges index
            ets: optional timestamps of the seed edges
            neg_sampling: The negative sampling configuration
            with_outer_sample: 0-sample in whole graph structure; 1-sample onehop outer node; 2-cross partition sampling
        Returns:
            the output of ``NeighborSampler.sample_from_edges``
        """
        return self.sampler.sample_from_edges(edges, ets, neg_sampling, with_outer_sample)
if __name__=="__main__":
    # Small demo: 6 nodes, 13 timestamped directed edges, 3-step random walks
    # started from nodes 1 and 2.
    edge_index1 = torch.tensor([[0, 1, 1, 1, 1, 2, 2, 2, 2, 4, 4, 4, 5], # , 3, 3
                                [1, 0, 2, 0, 4, 1, 3, 0, 3, 3, 5, 0, 2]])# , 2, 5
    timeStamp=torch.FloatTensor([1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4])
    edge_weight1 = None
    num_nodes1 = 6

    from Utils import GraphData
    g_data = GraphData(id=0, edge_index=edge_index1, timestamp=timeStamp, data=None, partptr=torch.tensor([0, num_nodes1]))

    # Run the random walk sampling.
    # ``is_root_ts`` is not a parameter of RandomWalkSampler.__init__, so
    # passing it raised TypeError; it is no longer passed.
    sampler=RandomWalkSampler(num_nodes=num_nodes1,
                              num_layers=3,
                              edge_weight=edge_weight1,
                              graph_data=g_data,
                              graph_name='a',
                              workers=4,
                              is_distinct = 0)

    out = sampler.sample_from_nodes(torch.tensor([1,2]),
                                    with_outer_sample=SampleType.Whole,
                                    ts=torch.tensor([1, 2]))
    # out = sampler.sample_from_edges(torch.tensor([[1,2],[4,0]]),
    #                                 with_outer_sample=SampleType.Whole,
    #                                 ets = torch.tensor([1, 2]))

    # Print the result
    print('node:', out.node)
    print('edge_index_list:', out.edge_index_list)
    print('eid_list:', out.eid_list)
    print('eid_ts_list:', out.eid_ts_list)
    print('metadata: ', out.metadata)
import argparse
import random
import pandas as pd
import numpy as np
import torch
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric.datasets import Reddit
import time
from tqdm import tqdm
from Utils import GraphData
class NegLinkSampler:
    """Draws negative node ids uniformly at random from ``[0, num_nodes)``."""

    def __init__(self, num_nodes):
        # Exclusive upper bound of the ids that can be drawn.
        self.num_nodes = num_nodes

    def sample(self, n):
        """Return an array of ``n`` random node ids (with replacement)."""
        return np.random.randint(low=0, high=self.num_nodes, size=n)
class NegLinkInductiveSampler:
    """Draws negative node ids at random from a fixed collection of nodes."""

    def __init__(self, nodes):
        # Materialise the candidates once so any iterable can be passed in.
        self.nodes = [*nodes]

    def sample(self, n):
        """Return an array of ``n`` candidates chosen with replacement."""
        candidates = self.nodes
        return np.random.choice(candidates, size=n)
def load_reddit_dataset(data_path):
    """Load the REDDIT temporal edge list as a single-partition ``GraphData``.

    NOTE(review): ``data_path`` is accepted but never read -- the CSV location
    is the hard-coded path below. Confirm whether that is intended.

    Returns:
        ``(g, df)``: the ``GraphData`` built from the ``src``/``dst``/``time``
        columns, and the raw ``DataFrame``.
    """
    csv_path = '/home/zlj/hzq/project/code/TGL/DATA/{}/edges.csv'.format("REDDIT")
    df = pd.read_csv(csv_path)

    # Node ids are taken to be 0-based, so the count is the largest id plus one.
    largest_src = int(df['src'].max())
    largest_dst = int(df['dst'].max())
    num_nodes = max(largest_src, largest_dst) + 1

    endpoints = [torch.tensor(df[column].to_numpy(dtype=int)) for column in ('src', 'dst')]
    edge_index = torch.stack(endpoints)
    timestamp = torch.tensor(df['time']).float()

    # A single partition covering every node.
    partptr = torch.tensor([0, num_nodes])
    g = GraphData(0, edge_index, timestamp=timestamp, data=None, partptr=partptr)
    return g, df
# Benchmark script: builds a temporal neighbor table for REDDIT, constructs a
# 1-layer "recent" NeighborSampler and times edge-seeded sampling over the
# whole edge list in batches.
parser=argparse.ArgumentParser()
parser.add_argument('--data', type=str, help='dataset name',default="REDDIT")
parser.add_argument('--config', type=str, help='path to config file',default="/home/zlj/hzq/project/code/TGL/config/TGN.yml")
# The help text used to be a copy of the --config one.
parser.add_argument('--batch_size', type=int, default=600, help='number of edges per sampling batch')
parser.add_argument('--num_thread', type=int, default=64, help='number of thread')
args=parser.parse_args()

# Make the run reproducible.
seed=10
torch.manual_seed(seed) # seed the CPU generator
torch.cuda.manual_seed(seed) # seed the current GPU
torch.cuda.manual_seed_all(seed) # seed every GPU when using multi-GPU
np.random.seed(seed) # Numpy module.
random.seed(seed) # Python random module.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

g_data, df = load_reddit_dataset("/home/zlj/hzq/code/dataset/reddit")
print(g_data)
edge_weight=None

from neighbor_sampler import NeighborSampler, SampleType, get_neighbors
print('begin tnb')
# Add the reverse of every edge, then order all edges by timestamp.
# Both concatenations are built from the original endpoints; previously
# ``col`` was concatenated with the already-doubled ``row``, producing a
# tensor of three times the edge count that only worked because the sort
# permutation never indexes past the first two thirds.
src, dst = g_data.edge_index
row = torch.cat([src, dst])
col = torch.cat([dst, src])
eid = torch.cat([g_data.eid, g_data.eid])
timestamp = torch.cat([g_data.edge_ts, g_data.edge_ts])
timestamp,ind = timestamp.sort()
timestamp = timestamp.float().contiguous()
eid = eid[ind].contiguous()
row = row[ind]
col = col[ind]
print(row, col)

pre = time.time()
tnb = get_neighbors("reddit", row.contiguous(), col.contiguous(), g_data.num_nodes, 0, eid, edge_weight, timestamp)
end = time.time()
print("init tnb time:", end-pre)
torch.save(tnb, "tnb_reddit_before.my")

pre = time.time()
sampler = NeighborSampler(g_data.num_nodes,
                          tnb=tnb,
                          num_layers=1,
                          fanout=[10],
                          graph_data=g_data,
                          workers=10,
                          policy="recent",
                          graph_name='a')
end = time.time()
print("init time:", end-pre)

from base import NegativeSampling, NegativeSamplingMode
neg_link_sampler = NegativeSampling(NegativeSamplingMode.triplet)

out = []
tot_time = 0
sam_time = 0
sam_edge = 0
# groupby(index // batch_size) yields a final partial batch as well, so the
# progress-bar total is rounded up rather than down.
num_batches = (len(df) + args.batch_size - 1) // args.batch_size
pre = time.time()
for _, rows in tqdm(df.groupby(df.index // args.batch_size), total=num_batches):
    edges = torch.tensor(np.stack([rows.src.values, rows.dst.values])).long()
    outi, meta = sampler.sample_from_edges(edges=edges, ets=torch.tensor(rows.time.values).float(), neg_sampling=neg_link_sampler)
    # Accumulate the counters reported on the first layer's output.
    tot_time += outi[0].tot_time
    sam_time += outi[0].sample_time
    sam_edge += outi[0].sample_edge_num
    out.append(outi)
end = time.time()
print("sample time", end-pre)
print("tot_time", tot_time)
print("sam_time", sam_time)
print("sam_edge", sam_edge)
if out:
    # Inspect batch 23 as before, falling back to the last batch when the run
    # produced fewer batches (the fixed index used to raise IndexError).
    probe = out[min(23, len(out) - 1)]
    print('eid_list:', probe[0].eid())
    print('node:', probe[0].sample_nodes())
\ No newline at end of file
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension
# Build configuration for the ``sample_cores`` C++ extension: a single source
# file compiled with OpenMP, -O3 and C++17, with the vendored parallel_hashmap
# headers on the include path.
# NOTE(review): '-Xlinker' followed by ' -export-dynamic' (leading space) is a
# linker option passed through the compile arguments -- confirm it is intended.
_sample_cores_ext = CppExtension(
    name='sample_cores',
    sources=['sample_cores_dist.cpp'],
    extra_compile_args=['-fopenmp', '-Xlinker', ' -export-dynamic', '-O3', '-std=c++17'],
    include_dirs=["./parallel_hashmap"],
)

setup(
    name='sample_cores',
    ext_modules=[_sample_cores_ext],
    cmdclass={'build_ext': BuildExtension},
)
import torch
class WorkStreamEvent:
    """Holds one dedicated CUDA stream per kind of work."""

    # Attribute names, in the order their streams are created.
    _STREAM_NAMES = (
        "train_stream",
        "write_memory_stream",
        "fetch_stream",
        "write_mail_stream",
    )

    def __init__(self):
        # Each attribute gets its own freshly created torch.cuda.Stream.
        for stream_name in self._STREAM_NAMES:
            setattr(self, stream_name, torch.cuda.Stream())
\ No newline at end of file
This diff is collapsed. Click to expand it.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment