Commit 47ea208c by zhlj

demo with tgl

parents
# Configuration files
.idea/
# Data
startGNN_sample/cora/
startGNN_sample/wiki/
startGNN_sample/GDELT/
startGNN_sample/reddit/
startGNN_sample/*/metis_*/
tgl/DATA/
dataset/
bak_startGNN_sample/
*.zip
*.tar.gz
*.npy
*.csv
*.my
# Logs and output
logs/
log/
*.pt
# Byte-compiled / optimized / DLL files
__pycache__/
*/__pycache__/
*.py[cod]
*$py
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# Other
*.png
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python: Current File",
"type": "python",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"args":["--world","1","--rank","0"]
}
]
}
\ No newline at end of file
import argparse
import os
import sys
from os.path import abspath, join, dirname
sys.path.insert(0, join(abspath(dirname(__file__)+'/startGNN_sample')))
from startGNN_sample.DistGraphLoader import DataSet, partition_load
from startGNN_sample.Sample.neighbor_sampler import NeighborSampler
from startGNN_sample.Sample.base import NegativeSampling
#path1=os.path.abspath('.')
import torch
from startGNN_sample.DistGraphLoader import DistGraphData
from startGNN_sample.DistGraphLoader import DistributedDataLoader
from startGNN_sample.DistGraphLoader import DistCustomPool
import startGNN_sample.distparser as distparser
import time
import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
from startGNN_sample.shared_mailbox import SharedMailBox, SharedRPCMemoryManager
import os
"""
test command
python test.py --world_size 2 --rank 0
--world_size', default=4, type=int, metavar='W',
help='number of workers')
parser.add_argument('--rank', default=0, type=int, metavar='W',
help='rank of the worker')
parser.add_argument('--log_interval', type=int, default=10, metavar='N',
help='interval between training status logs')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
help='how much to value future rewards')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed for reproducibility')
parser.add_argument('--num_sampler', type=int, default=10, metavar='S',
help='number of samplers')
parser.add_argument('--queue_size', type=int, default=10, metavar='S',
help='sampler queue size')
"""
from tgl.modules import *
from tgl.sampler import *
from tgl.utils import *
from sklearn.metrics import average_precision_score, roc_auc_score
import torch
import time
import random
import dgl
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn.parallel import DistributedDataParallel as DDP
def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def main():
sample_param, memory_param, gnn_param, train_param = parse_config('tgl/config/TGN.yml')
use_cuda = True
torch.set_num_threads(12)
args = distparser.args
rank = distparser._get_worker_rank()
if use_cuda:
torch.cuda.set_device(rank)
DistCustomPool.init_distribution('127.0.0.1',9675,'127.0.0.1',10023,backend = "nccl")
pdata = partition_load("./startGNN_sample/GDELT", algo="metis")
graph = DistGraphData(pdata = pdata,edge_index= pdata.edge_index, full_edge = False)
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10],graph_data=graph, workers=8,is_root_ts = 0,policy = 'recent',graph_name = "wiki_train")
#print(graph.num_nodes)
train_data = torch.masked_select(graph.edge_index,graph.data.train_mask).reshape(2,-1)
#print(train_data.shape[1])
train_ts = torch.masked_select(graph.edge_ts,graph.data.train_mask)
val_data = torch.masked_select(graph.edge_index,graph.data.val_mask).reshape(2,-1)
val_ts = torch.masked_select(graph.edge_ts,graph.data.val_mask)
test_data = torch.masked_select(graph.edge_index,graph.data.test_mask).reshape(2,-1)
test_ts = torch.masked_select(graph.edge_ts,graph.data.test_mask)
print(train_data.shape[1],val_data.shape[1],test_data.shape[1])
train_data = DataSet(edges = train_data,ts =train_ts,labels = torch.ones(train_data.shape[-1]),eids = torch.nonzero(graph.data.train_mask).view(-1))
test_data = DataSet(edges = test_data,ts =test_ts,labels = torch.ones(test_data.shape[-1]),eids = torch.nonzero(graph.data.test_mask).view(-1))
val_data = DataSet(edges = val_data,ts = val_ts,labels = torch.ones(val_data.shape[-1]),eids = torch.nonzero(graph.data.val_mask).view(-1))
neg_sampler = NegativeSampling('triplet')
trainloader = DistributedDataLoader('train',graph,train_data,sampler = sampler,neg_sampler=neg_sampler,batch_size = 4000,shuffle=False,cache_memory_size = 0,drop_last=True,cs = 1)
testloader = DistributedDataLoader('test',graph,test_data,sampler = sampler,neg_sampler=neg_sampler,batch_size = 4000,shuffle=False,cache_memory_size = 0,drop_last=True,cs = None)
valloader = DistributedDataLoader('val',graph,val_data,sampler = sampler,neg_sampler=neg_sampler,batch_size = 4000,shuffle=False,cache_memory_size = 0,drop_last=True,cs = None)
total_time = time.time()
total_tps = 0
#for e in range(4):
# for batch_data in trainloader:
# #for e in batch_data.eids:
# total_tps = batch_data.nids.shape[0]
# print(total_tps/(time.time()-total_time))
gnn_dim_node = 0 if graph.data.x is None else graph.data.x.shape[1]
gnn_dim_edge = 0 if graph.data.edge_attr is None else graph.data.edge_attr.shape[1]
avg_time = 0
combine_first = False
if 'combine_neighs' in train_param and train_param['combine_neighs']:
combine_first = True#
if use_cuda:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param, combined=combine_first).cuda()
else:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param, combined=combine_first)
model = DDP(model,find_unused_parameters=True)
#model.module.last
if memory_param['type'] != 'none':
#注册Memory
mailbox = SharedMailBox(device = torch.device('cuda'))
#mailbox = SharedRPCMemoryManager('cuda')#暂时不启用
#构建mailbox的映射,partptr是分区划分的数组,默认存在graph里面,local_num_node是本地数量
mailbox.build_map(local_num_nodes=graph.partptr[rank+1]-graph.partptr[rank],partptr=graph.partptr)
mailbox.create_empty_memory(memory_param,gnn_dim_edge)
dist.barrier()
else:
mailbox = None
val_losses = list()
def eval(mode='val'):
neg_samples = 1
model.eval()
aps = list()
y_true_list = list()
y_prep_list = list()
aucs_mrrs = list()
if mode == 'val':
loader = valloader
elif mode == 'test':
loader = testloader
elif mode == 'train':
loader = trainloader
with torch.no_grad():
total_loss = 0
t0 = time.time()
for batchData in loader:
t1 = time.time()
mfgs,metadata,unq = batch_data_prepare_input(batchData,sample_param['history'],cuda = use_cuda)
t2 = time.time()
if mailbox is not None:
if use_cuda:
#这里是对接原来结构的处理,换到你的模型上需要重新处理一下,
mailbox.prep_input_mails(batchData.nids.to('cuda'),mfgs[0])
else:
mailbox.prep_input_mails(batchData.nids,mfgs[0])
t3 = time.time()
optimizer.zero_grad()
pred_pos, pred_neg = model(mfgs,metadata)
total_loss += creterion(pred_pos, torch.ones_like(pred_pos))
total_loss += creterion(pred_neg, torch.zeros_like(pred_neg))
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
y_true_list.append(y_true)
y_prep_list.append(y_pred)
if neg_samples > 1:
aucs_mrrs.append(torch.reciprocal(torch.sum(pred_pos.squeeze() < pred_neg.squeeze().reshape(neg_samples, -1), dim=0) + 1).type(torch.float))
else:
aucs_mrrs.append(roc_auc_score(y_true, y_pred))
t4 = time.time()
if mailbox is not None:
src = metadata['src_id_pos']
dst = metadata['dst_pos_pos']
if use_cuda:
ts = batchData.roots.ts.to('cuda')
edge_feats = graph.data.edge_attr[batchData.roots.eids].to('cuda')
block_ptr = batchData.nids[unq].to('cuda')
else:
ts = batchData.roots.ts
edge_feats = graph.data.edge_attr[batchData.roots.eids]
block_ptr = batchData.nids[unq]
#这里也是按照block的数据格式更新的,需要按你的模型更新重写,第一个block_ptr是一个映射指针,因为我们batchData里的数据是重映射的结果,需要从atchData.nids再找到原来的id值,#内层用#的scatter更新函数用的是全局的id,这里还有一个unq是处理block的时候又映射了一次,你那里应该不用再处理这个了
mailbox.get_memory_to_update(block_ptr,src,dst,ts,edge_feats,model.module.memory_updater.last_updated_memory,model.module.memory_updater.last_updated_ts)
t5 = time.time()
#print(t1-t0,t2-t1,t3-t2,t4-t3,t5-t4)
t0 = time.time()
if mode == 'val':
val_losses.append(float(total_loss))
aps = torch.tensor(aps)
y_true_list = torch.cat(y_true_list)
y_prep_list = torch.cat(y_prep_list)
ap = float(average_precision_score(y_true_list,y_prep_list))
if neg_samples > 1:
auc_mrr = float(torch.cat(aucs_mrrs).mean())
else:
auc_mrr = float(torch.tensor(aucs_mrrs).mean())
return ap, auc_mrr
#
# se#t_seed(0)
#
#
#MailBox(memory_param, graph.partptr[-1], gnn_dim_edge) if memory_param['type'] != 'none' else None
creterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'])
#if 'all_on_gpu' in train_param and train_param['all_on_gpu']:
best_e = 0
best_ap = 0
#
sample_time = time.time()
for e in range(train_param['epoch']):
epoch_start_time = time.time()
train_aps = list()
print('Epoch {:d}:'.format(e))
time_prep = 0
time_tot = 0
total_loss = 0
# training
model.train()
if mailbox is not None:
mailbox.reset()
model.module.memory_updater.last_updated_nid = None
model.module.memory_updater.last_updated_memory = None
last_time = time.time()
t0 = time.time()
total_node = 0
total_edge = 0
total_batch = 0
total_root = 0
for batchData in trainloader:
total_batch = total_batch+1
total_root = total_root + batchData.meta_data['src_id'].shape[0]
total_node = total_node + batchData.nids.unique().shape[0]
total_edge = total_edge + batchData.eids[0].shape[0]
t1 = time.time()
sample_time += time.time() - last_time
t_tot_s = time.time()
t_prep_s = time.time()
mfgs,metadata,unq = batch_data_prepare_input(batchData,sample_param['history'],cuda = use_cuda)
t2 = time.time()
if mailbox is not None:
if use_cuda:
mailbox.prep_input_mails(batchData.nids.to('cuda'),mfgs[0])
else:
mailbox.prep_input_mails(batchData.nids,mfgs[0])
t3 = time.time()
time_prep += time.time() - t_prep_s
optimizer.zero_grad()
pred_pos, pred_neg = model(mfgs,metadata)
loss = creterion(pred_pos, torch.ones_like(pred_pos))
loss += creterion(pred_neg, torch.zeros_like(pred_neg))
total_loss += float(loss) #* train_param['batch_size']
loss.backward()
t4 = time.time()
optimizer.step()
t_prep_s = time.time()
#if mailbox is not None:
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
train_aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
if mailbox is not None:
src = metadata['src_id']
dst = metadata['dst_pos_id']
if use_cuda:
ts = batchData.roots.ts.to('cuda')
edge_feats = graph.data.edge_attr[batchData.roots.eids].to('cuda')
block_ptr = batchData.nids[unq].to('cuda')
else:
ts = batchData.roots.ts
edge_feats = graph.data.edge_attr[batchData.roots.eids]
block_ptr = batchData.nids[unq]
mailbox.get_memory_to_update(block_ptr,src,dst,ts,edge_feats,model.module.memory_updater.last_updated_memory,model.module.memory_updater.last_updated_ts)
t5 = time.time()
time_prep += time.time() - t_prep_s
time_tot += time.time() - t_tot_s
last_time = time.time()
#print(t1-t0,t2-t1,t3-t2,t4-t3,t5-t4)
t0 = time.time()
print(total_batch,total_root,total_node,total_edge)
avg_time += time.time()-epoch_start_time
train_ap = float(torch.tensor(train_aps).mean())
ap, auc = eval('val')
if e > 2 and ap > best_ap:
best_e = e
best_ap = ap
# torch.save(model.state_dict(), path_saver)
print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f}'.format(total_loss,train_ap, ap, auc))
print('\ttotal time:{:.2f}s sample time:{:.2f}s prep time:{:.2f}s'.format(time.time()-epoch_start_time, sample_time, time_prep))
dist.barrier()
print('Loading model at epoch {}...'.format(best_e))
#model.load_state_dict(torch.load(path_saver))
model.eval()
#if sampler is not None:
# sampler.reset()
if mailbox is not None:
mailbox.reset()
model.module.memory_updater.last_updated_nid = None
eval('train')
eval('val')
ap, auc = eval('test')
if args.eval_neg_samples > 1:
print('\ttest AP:{:4f} test MRR:{:4f}'.format(ap, auc))
else:
print('\ttest AP:{:4f} test AUC:{:4f}'.format(ap, auc))
print('avg_time',avg_time/train_param['epoch'])
DistCustomPool.close_distribution()
if __name__ == "__main__":
main()
import argparse
import os
import sys
from os.path import abspath, join, dirname
sys.path.insert(0, join(abspath(dirname(__file__)+'/startGNN_sample')))
from startGNN_sample.DistGraphLoader import DataSet, partition_load
from startGNN_sample.Sample.neighbor_sampler import NeighborSampler
from startGNN_sample.Sample.base import NegativeSampling
#path1=os.path.abspath('.')
import torch
from startGNN_sample.DistGraphLoader import DistGraphData
from startGNN_sample.DistGraphLoader import DistributedDataLoader
from startGNN_sample.DistGraphLoader import DistCustomPool
import startGNN_sample.distparser as distparser
import time
import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
from startGNN_sample.shared_mailbox import SharedMailBox, SharedRPCMemoryManager
import os
"""
test command
python test.py --world_size 2 --rank 0
--world_size', default=4, type=int, metavar='W',
help='number of workers')
parser.add_argument('--rank', default=0, type=int, metavar='W',
help='rank of the worker')
parser.add_argument('--log_interval', type=int, default=10, metavar='N',
help='interval between training status logs')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
help='how much to value future rewards')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed for reproducibility')
parser.add_argument('--num_sampler', type=int, default=10, metavar='S',
help='number of samplers')
parser.add_argument('--queue_size', type=int, default=10, metavar='S',
help='sampler queue size')
"""
from tgl.modules import *
from tgl.sampler import *
from tgl.utils import *
from sklearn.metrics import average_precision_score, roc_auc_score
import torch
import time
import random
import dgl
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn.parallel import DistributedDataParallel as DDP
import matplotlib.pyplot as plt
import math
def main():
sample_param, memory_param, gnn_param, train_param = parse_config('tgl/config/TGN.yml')
use_cuda = True
torch.set_num_threads(12)
args = distparser.args
rank = distparser._get_worker_rank()
if use_cuda:
torch.cuda.set_device(rank)
DistCustomPool.init_distribution('127.0.0.1',9675,'127.0.0.1',10023,backend = "nccl")
pdata = partition_load("./startGNN_sample/GDELT", algo="metis")
graph = DistGraphData(pdata = pdata,edge_index= pdata.edge_index, full_edge = False)
print(graph.edge_index[0,:].shape[0],graph.partptr)
value,counts = torch.unique(graph.edge_index[1,:].view(-1),return_counts=True)
counts,ind = counts.sort(descending=True)
#print(counts[:500])
plt.plot(counts)
plt.savefig('degree-{}.png'.format(rank))
#print(counts)
#node = counts
#plt.hist(node)
#plt.title('Histogram of Random Data')
#plt.xlabel('Value')
#plt.ylabel('Freuency')
print(counts[:2048])
##print(node.shape[1],node.unique(dim=1).shape[1])
# 添加标题和轴标签
#plt.title('Histogram of Random Data')
#plt.xlabel('Value')
#plt.ylabel('Frequency')
DistCustomPool.close_distribution()
if __name__ == "__main__":
main()
startGNN_sample @ a2522497
Subproject commit a2522497a20e77733d562db176b83475a5224610
{
"configurations": [
{
"name": "linux-gcc-x64",
"includePath": [
"${workspaceFolder}/**"
],
"compilerPath": "/usr/bin/gcc",
"cStandard": "${default}",
"cppStandard": "${default}",
"intelliSenseMode": "linux-gcc-x64",
"compilerArgs": [
""
]
}
],
"version": 4
}
\ No newline at end of file
{
"version": "0.2.0",
"configurations": [
/*{
"name": "C/C++ Runner: Debug Session",
"type": "cppdbg",
"request": "launch",
"args": [],
"stopAtEntry": false,
"externalConsole": false,
"cwd": "/home/sxx/zlj/tgl/tgl-main",
"program": "/home/sxx/zlj/tgl/tgl-main/build/Debug/outDebug",
"MIMode": "gdb",
"miDebuggerPath": "gdb",
"setupCommands": [
{
"description": "Enable pretty-printing for gdb",
"text": "-enable-pretty-printing",
"ignoreFailures": true
}
]
},*/
{
"name": "Python: Current File",
"type": "python",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"justMyCode": false,
"args": [
"--data","/WIKI","--config","config/TGN.yml"
]
}
]
}
\ No newline at end of file
{
"files.associations": {
"cctype": "cpp",
"cmath": "cpp",
"cstdarg": "cpp",
"cstddef": "cpp",
"cstdio": "cpp",
"cstdlib": "cpp",
"cstring": "cpp",
"ctime": "cpp",
"cwchar": "cpp",
"cwctype": "cpp",
"array": "cpp",
"atomic": "cpp",
"*.tcc": "cpp",
"bitset": "cpp",
"chrono": "cpp",
"cstdint": "cpp",
"unordered_map": "cpp",
"vector": "cpp",
"exception": "cpp",
"algorithm": "cpp",
"functional": "cpp",
"iterator": "cpp",
"memory": "cpp",
"memory_resource": "cpp",
"numeric": "cpp",
"random": "cpp",
"string": "cpp",
"string_view": "cpp",
"system_error": "cpp",
"tuple": "cpp",
"type_traits": "cpp",
"utility": "cpp",
"initializer_list": "cpp",
"iosfwd": "cpp",
"iostream": "cpp",
"istream": "cpp",
"limits": "cpp",
"mutex": "cpp",
"new": "cpp",
"ostream": "cpp",
"stdexcept": "cpp",
"streambuf": "cpp",
"typeinfo": "cpp",
"bit": "cpp",
"__bit_reference": "cpp",
"__bits": "cpp",
"__config": "cpp",
"__debug": "cpp",
"__errc": "cpp",
"__locale": "cpp",
"__split_buffer": "cpp",
"__threading_support": "cpp",
"__tuple": "cpp",
"__verbose_abort": "cpp",
"ios": "cpp",
"locale": "cpp"
},
"C_Cpp_Runner.cCompilerPath": "gcc",
"C_Cpp_Runner.cppCompilerPath": "g++",
"C_Cpp_Runner.debuggerPath": "gdb",
"C_Cpp_Runner.cStandard": "",
"C_Cpp_Runner.cppStandard": "",
"C_Cpp_Runner.msvcBatchPath": "",
"C_Cpp_Runner.useMsvc": false,
"C_Cpp_Runner.warnings": [
"-Wall",
"-Wextra",
"-Wpedantic",
"-Wshadow",
"-Wformat=2",
"-Wconversion",
"-Wnull-dereference",
"-Wsign-conversion"
],
"C_Cpp_Runner.enableWarnings": true,
"C_Cpp_Runner.warningsAsError": false,
"C_Cpp_Runner.compilerArgs": [],
"C_Cpp_Runner.linkerArgs": [],
"C_Cpp_Runner.includePaths": [],
"C_Cpp_Runner.includeSearch": [
"*",
"**/*"
],
"C_Cpp_Runner.excludeSearch": [
"**/build",
"**/build/**",
"**/.*",
"**/.*/**",
"**/.vscode",
"**/.vscode/**"
]
}
\ No newline at end of file
## Code of Conduct
This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
opensource-codeofconduct@amazon.com with any additional questions or comments.
# Contributing Guidelines
Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional
documentation, we greatly value feedback and contributions from our community.
Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
information to effectively respond to your bug report or contribution.
## Reporting Bugs/Feature Requests
We welcome you to use the GitHub issue tracker to report bugs or suggest features.
When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already
reported the issue. Please try to include as much information as you can. Details like these are incredibly useful:
* A reproducible test case or series of steps
* The version of our code being used
* Any modifications you've made relevant to the bug
* Anything unusual about your environment or deployment
## Contributing via Pull Requests
Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:
1. You are working against the latest source on the *main* branch.
2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
3. You open an issue to discuss any significant work - we would hate for your time to be wasted.
To send us a pull request, please:
1. Fork the repository.
2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
3. Ensure local tests pass.
4. Commit to your fork using clear commit messages.
5. Send us a pull request, answering any default questions in the pull request interface.
6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.
GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
[creating a pull request](https://help.github.com/articles/creating-a-pull-request/).
## Finding contributions to work on
Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start.
## Code of Conduct
This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
opensource-codeofconduct@amazon.com with any additional questions or comments.
## Security issue notifications
If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue.
## Licensing
See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# TGL: A General Framework for Temporal Graph Training on Billion-Scale Graphs
## Overview
This repo is the open-sourced code for our work *TGL: A General Framework for Temporal Graph Training on Billion-Scale Graphs*.
## Requirements
- python >= 3.6.13
- pytorch >= 1.8.1
- pandas >= 1.1.5
- numpy >= 1.19.5
- dgl >= 0.6.1
- pyyaml >= 5.4.1
- tqdm >= 4.61.0
- pybind11 >= 2.6.2
- g++ >= 7.5.0
- openmp >= 201511
Our temporal sampler is implemented using C++, please compile the sampler first with the following command
> python setup.py build_ext --inplace
## Dataset
The four datasets used in our paper are available to download from AWS S3 bucket using the `down.sh` script. The total download size is around 350GB.
To use your own dataset, you need to put the following files in the folder `\DATA\\<NameOfYourDataset>\`
1. `edges.csv`: The file that stores temporal edge informations. The csv should have the following columns with the header as `,src,dst,time,ext_roll` where each of the column refers to edge index (start with zero), source node index (start with zero), destination node index, time stamp, extrapolation roll (0 for training edges, 1 for validation edges, 2 for test edges). The CSV should be sorted by time ascendingly.
2. `ext_full.npz`: The T-CSR representation of the temporal graph. We provide a script to generate this file from `edges.csv`. You can use the following command to use the script
>python gen_graph.py --data \<NameOfYourDataset>
3. `edge_features.pt` (optional): The torch tensor that stores the edge featrues row-wise with shape (num edges, dim edge features). *Note: at least one of `edge_features.pt` or `node_features.pt` should present.*
4. `node_features.pt` (optional): The torch tensor that stores the node featrues row-wise with shape (num nodes, dim node features). *Note: at least one of `edge_features.pt` or `node_features.pt` should present.*
5. `labels.csv` (optional): The file contains node labels for dynamic node classification task. The csv should have the following columns with the header as `,node,time,label,ext_roll` where each of the column refers to node label index (start with zero), node index (start with zero), time stamp, node label, extrapolation roll (0 for training node labels, 1 for validation node labels, 2 for test node labels). The CSV should be sorted by time ascendingly.
## Configuration Files
We provide example configuration files for five temporal GNN methods: JODIE, DySAT, TGAT, TGN and TGAT. The configuration files for single GPU training are located at `/config/` while the multiple GPUs training configuration files are located at `/config/dist/`.
The provided configuration files are all tested to be working. If you want to use your own network architecture, please refer to `/config/readme.yml` for the meaining of each entry in the yaml configuration file. As our framework is still under development, it possible that some combination of the confiruations will lead to bug.
## Run
Currently, our framework only supports extrapolation setting (inference for the future).
### Single GPU Link Prediction
>python train.py --data \<NameOfYourDataset> --config \<PathToConfigFile>
### MultiGPU Link Prediction
>python -m torch.distributed.launch --nproc_per_node=\<NumberOfGPUs+1> train_dist.py --data \<NameOfYourDataset> --config \<PathToConfigFile> --num_gpus \<NumberOfGPUs>
### Dynamic Node Classification
Currenlty, TGL only supports performing dynamic node classification using the dynamic node embedding generated in link prediction.
For Single GPU models, directly run
>python train_node.py --data \<NameOfYourDATA> --config \<PathToConfigFile> --model \<PathToSavedModel>
For multi-GPU models, you need to first generate the dynamic node embedding
>python -m torch.distributed.launch --nproc_per_node=\<NumberOfGPUs+1> extract_node_dist.py --data \<NameOfYourDataset> --config \<PathToConfigFile> --num_gpus \<NumberOfGPUs> --model \<PathToSavedModel>
After generating the node embeding for multi-GPU models, run
>python train_node.py --data \<NameOfYourDATA> --model \<PathToSavedModel>
## Security
See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information.
## Cite
If you use TGL in a scientific publication, we would appreciate citations to the following paper:
```
@article{zhou2022tgl,
title={{TGL}: A General Framework for Temporal GNN Training on Billion-Scale Graphs},
author={Zhou, Hongkuan and Zheng, Da and Nisa, Israt and Ioannidis, Vasileios and Song, Xiang and Karypis, George},
year = {2022},
journal = {Proc. VLDB Endow.},
volume = {15},
number = {8},
}
```
## License
This project is licensed under the Apache-2.0 License.
sampling:
- layer: 1
neighbor:
- 10
strategy: 'recent'
prop_time: False
history: 1
duration: 0
num_thread: 32
no_neg: True
memory:
- type: 'node'
dim_time: 100
deliver_to: 'neighbors'
mail_combine: 'last'
memory_update: 'transformer'
attention_head: 2
mailbox_size: 10
combine_node_feature: False
dim_out: 100
gnn:
- arch: 'identity'
train:
- epoch: 100
batch_size: 600
lr: 0.0001
dropout: 0.1
att_dropout: 0.1
# all_on_gpu: True
\ No newline at end of file
sampling:
- layer: 2
neighbor:
- 10
- 10
strategy: 'uniform'
prop_time: True
history: 3
duration: 10000
num_thread: 32
memory:
- type: 'none'
dim_out: 0
gnn:
- arch: 'transformer_attention'
layer: 2
att_head: 2
dim_time: 0
dim_out: 100
combine: 'rnn'
train:
- epoch: 50
batch_size: 600
lr: 0.0001
dropout: 0.1
att_dropout: 0.1
all_on_gpu: True
\ No newline at end of file
sampling:
- no_sample: True
history: 1
memory:
- type: 'node'
dim_time: 100
deliver_to: 'self'
mail_combine: 'last'
memory_update: 'rnn'
mailbox_size: 1
combine_node_feature: True
dim_out: 100
gnn:
- arch: 'identity'
time_transform: 'JODIE'
train:
- epoch: 100
batch_size: 600
lr: 0.0001
dropout: 0.1
all_on_gpu: True
\ No newline at end of file
sampling:
- layer: 2
neighbor:
- 10
- 10
strategy: 'uniform'
prop_time: False
history: 1
duration: 0
num_thread: 32
memory:
- type: 'none'
dim_out: 0
gnn:
- arch: 'transformer_attention'
layer: 2
att_head: 2
dim_time: 100
dim_out: 100
train:
- epoch: 10
batch_size: 600
lr: 0.0001
dropout: 0.1
att_dropout: 0.1
all_on_gpu: True
\ No newline at end of file
sampling:
- layer: 1
neighbor:
- 10
strategy: 'recent'
prop_time: False
history: 1
duration: 0
num_thread: 32
memory:
- type: 'node'
dim_time: 100
deliver_to: 'self'
mail_combine: 'last'
memory_update: 'gru'
mailbox_size: 1
combine_node_feature: True
dim_out: 100
gnn:
- arch: 'transformer_attention'
layer: 1
att_head: 2
dim_time: 100
dim_out: 100
train:
- epoch: 4
#batch_size: 100
# reorder: 16
lr: 0.0001
dropout: 0.2
att_dropout: 0.2
all_on_gpu: True
\ No newline at end of file
sampling:
- layer: <number of layers to sample>
neighbor: <a list of integers indicating how many neighbors are sampled in each layer>
strategy: <'recent' that samples most recent neighbors or 'uniform' that uniformly samples neighbors form the past>
prop_time: <False or True that specifies wherether to use the timestamp of the root nodes when sampling for their multi-hop neighbors>
history: <number of snapshots to sample on>
duration: <length in time of each snapshot, 0 for infinite length (used in non-snapshot-based methods)
num_thread: <number of threads of the sampler>
memory:
- type: <'node', we only support node memory now>
dim_time: <an integer, the dimension of the time embedding>
deliver_to: <'self' that delivers the mails only to involved nodes or 'neighbors' that deliver the mails to neighbors>
mail_combine: <'last' that use the latest latest mail as the input to the memory updater>
memory_update: <'gru' or 'rnn'>
mailbox_size: <an integer, the size of the mailbox for each node>
combine_node_feature: <False or True that specifies whether to combine node features (with the updated memory) as the input to the GNN.
dim_out: <an integer, the dimension of the output node memory>
gnn:
- arch: <'transformer_attention' or 'identity' (no GNN)>
layer: <an integer, number of layers>
att_head: <an integer, number of attention heads>
dim_time: <an integer, the dimension of the time embedding>
dim_out: <an integer, the dimension of the output dynamic node embedding>
train:
- epoch: <an integer, number of epochs to train>
batch_size: <an integer, the batch size (of edges); for multi-gpu training, this is the local batchsize>
reorder: <(optional) an integer that is divisible by batch size the specifies how many chunks per batch used in the random chunk scheduling>
lr: <floating point, learning rate>
dropout: <floating point, dropout>
att_dropout: <floating point, dropout for attention>
all_on_gpu: <False or True that decides if the node/edge features and node memory are completely stored on GPU>
\ No newline at end of file
#!/bin/bash
wget -P ./DATA/GDELT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/int_train.npz
wget -P ./DATA/GDELT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/int_full.npz
wget -P ./DATA/GDELT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/node_features.pt
wget -P ./DATA/GDELT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/labels.csv
wget -P ./DATA/GDELT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/ext_full.npz
wget -P ./DATA/GDELT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/edges.csv
wget -P ./DATA/GDELT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/edge_features.pt
wget -P ./DATA/LASTFM https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/LASTFM/edges.csv
wget -P ./DATA/LASTFM https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/LASTFM/ext_full.npz
wget -P ./DATA/LASTFM https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/LASTFM/int_full.npz
wget -P ./DATA/LASTFM https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/LASTFM/int_train.npz
wget -P ./DATA/MAG https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MAG/int_train.npz
wget -P ./DATA/MAG https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MAG/labels.csv
wget -P ./DATA/MAG https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MAG/int_full.npz .)
wget -P ./DATA/MAG https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MAG/ext_full.npz
wget -P ./DATA/MAG https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MAG/edges.csv
wget -P ./DATA/MAG https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MAG/node_features.pt
wget -P ./DATA/MOOC https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MOOC/edges.csv
wget -P ./DATA/MOOC https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MOOC/ext_full.npz
wget -P ./DATA/MOOC https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MOOC/int_full.npz
wget -P ./DATA/MOOC https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MOOC/int_train.npz
wget -P ./DATA/REDDIT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/REDDIT/edge_features.pt
wget -P ./DATA/REDDIT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/REDDIT/edges.csv
wget -P ./DATA/REDDIT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/REDDIT/ext_full.npz
wget -P ./DATA/REDDIT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/REDDIT/int_full.npz
wget -P ./DATA/REDDIT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/REDDIT/int_train.npz
wget -P ./DATA/REDDIT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/REDDIT/labels.csv
wget -P ./DATA/WIKI https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/WIKI/edge_features.pt
wget -P ./DATA/WIKI https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/WIKI/edges.csv
wget -P ./DATA/WIKI https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/WIKI/ext_full.npz
wget -P ./DATA/WIKI https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/WIKI/int_full.npz
wget -P ./DATA/WIKI https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/WIKI/int_train.npz
wget -P ./DATA/WIKI https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/WIKI/labels.csv
\ No newline at end of file
#!/bin/bash
wget -P ./DATA/GDELT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/int_train.npz
wget -P ./DATA/GDELT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/int_full.npz
wget -P ./DATA/GDELT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/node_features.pt
wget -P ./DATA/GDELT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/labels.csv
wget -P ./DATA/GDELT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/ext_full.npz
wget -P ./DATA/GDELT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/edges.csv
wget -P ./DATA/GDELT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/edge_features.pt
\ No newline at end of file
import argparse
import os
parser=argparse.ArgumentParser()
parser.add_argument('--data', type=str, help='dataset name')
parser.add_argument('--config', type=str, help='path to config file')
parser.add_argument('--seed', type=int, default=0, help='random seed to use')
parser.add_argument('--num_gpus', type=int, default=4, help='number of gpus to use')
parser.add_argument('--model', type=str, help='path to model file')
parser.add_argument('--batch_size', type=int, default=4000, help='batch size to generate node embeddings')
parser.add_argument('--omp_num_threads', type=int, default=16)
parser.add_argument("--local_rank", type=int, default=-1)
args=parser.parse_args()
# set which GPU to use
if args.local_rank < args.num_gpus:
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.local_rank)
else:
os.environ['CUDA_VISIBLE_DEVICES'] = ''
os.environ['OMP_NUM_THREADS'] = str(args.omp_num_threads)
os.environ['MKL_NUM_THREADS'] = str(args.omp_num_threads)
import torch
import dgl
import random
import math
import hashlib
import numpy as np
from tqdm import tqdm
from dgl.utils.shared_mem import create_shared_mem_array, get_shared_mem_array
from sklearn.metrics import average_precision_score, roc_auc_score
from modules import *
from sampler import *
from utils import *
def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
set_seed(args.seed)
torch.distributed.init_process_group(backend='gloo')
nccl_group = torch.distributed.new_group(ranks=list(range(args.num_gpus)), backend='nccl')
if args.local_rank == 0:
_node_feats, _edge_feats = load_feat(args.data)
dim_feats = [0, 0, 0, 0, 0, 0]
if args.local_rank == 0:
if _node_feats is not None:
dim_feats[0] = _node_feats.shape[0]
dim_feats[1] = _node_feats.shape[1]
dim_feats[2] = _node_feats.dtype
node_feats = create_shared_mem_array('node_feats', _node_feats.shape, dtype=_node_feats.dtype)
node_feats.copy_(_node_feats)
del _node_feats
else:
node_feats = None
if _edge_feats is not None:
dim_feats[3] = _edge_feats.shape[0]
dim_feats[4] = _edge_feats.shape[1]
dim_feats[5] = _edge_feats.dtype
edge_feats = create_shared_mem_array('edge_feats', _edge_feats.shape, dtype=_edge_feats.dtype)
edge_feats.copy_(_edge_feats)
del _edge_feats
else:
edge_feats = None
torch.distributed.barrier()
torch.distributed.broadcast_object_list(dim_feats, src=0)
if args.local_rank > 0 and args.local_rank < args.num_gpus:
node_feats = None
edge_feats = None
if os.path.exists('DATA/{}/node_features.pt'.format(args.data)):
node_feats = get_shared_mem_array('node_feats', (dim_feats[0], dim_feats[1]), dtype=dim_feats[2])
if os.path.exists('DATA/{}/edge_features.pt'.format(args.data)):
edge_feats = get_shared_mem_array('edge_feats', (dim_feats[3], dim_feats[4]), dtype=dim_feats[5])
sample_param, memory_param, gnn_param, train_param = parse_config(args.config)
path_saver = args.model
if args.local_rank == args.num_gpus:
g, df = load_graph(args.data)
num_nodes = [g['indptr'].shape[0] - 1]
else:
num_nodes = [None]
torch.distributed.barrier()
torch.distributed.broadcast_object_list(num_nodes, src=args.num_gpus)
num_nodes = num_nodes[0]
mailbox = None
if memory_param['type'] != 'none':
if args.local_rank == 0:
node_memory = create_shared_mem_array('node_memory', torch.Size([num_nodes, memory_param['dim_out']]), dtype=torch.float32)
node_memory_ts = create_shared_mem_array('node_memory_ts', torch.Size([num_nodes]), dtype=torch.float32)
mails = create_shared_mem_array('mails', torch.Size([num_nodes, memory_param['mailbox_size'], 2 * memory_param['dim_out'] + dim_feats[4]]), dtype=torch.float32)
mail_ts = create_shared_mem_array('mail_ts', torch.Size([num_nodes, memory_param['mailbox_size']]), dtype=torch.float32)
next_mail_pos = create_shared_mem_array('next_mail_pos', torch.Size([num_nodes]), dtype=torch.long)
update_mail_pos = create_shared_mem_array('update_mail_pos', torch.Size([num_nodes]), dtype=torch.int32)
torch.distributed.barrier()
node_memory.zero_()
node_memory_ts.zero_()
mails.zero_()
mail_ts.zero_()
next_mail_pos.zero_()
update_mail_pos.zero_()
else:
torch.distributed.barrier()
node_memory = get_shared_mem_array('node_memory', torch.Size([num_nodes, memory_param['dim_out']]), dtype=torch.float32)
node_memory_ts = get_shared_mem_array('node_memory_ts', torch.Size([num_nodes]), dtype=torch.float32)
mails = get_shared_mem_array('mails', torch.Size([num_nodes, memory_param['mailbox_size'], 2 * memory_param['dim_out'] + dim_feats[4]]), dtype=torch.float32)
mail_ts = get_shared_mem_array('mail_ts', torch.Size([num_nodes, memory_param['mailbox_size']]), dtype=torch.float32)
next_mail_pos = get_shared_mem_array('next_mail_pos', torch.Size([num_nodes]), dtype=torch.long)
update_mail_pos = get_shared_mem_array('update_mail_pos', torch.Size([num_nodes]), dtype=torch.int32)
mailbox = MailBox(memory_param, num_nodes, dim_feats[4], node_memory, node_memory_ts, mails, mail_ts, next_mail_pos, update_mail_pos)
if args.local_rank < args.num_gpus:
# GPU worker process
model = GeneralModel(dim_feats[1], dim_feats[4], sample_param, memory_param, gnn_param, train_param).cuda()
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], process_group=nccl_group, output_device=args.local_rank)
model.load_state_dict(torch.load(path_saver, map_location=torch.device('cuda:0')))
creterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'])
while True:
my_model_state = [None]
model_state = [None] * (args.num_gpus + 1)
torch.distributed.scatter_object_list(my_model_state, model_state, src=args.num_gpus)
if my_model_state[0] == -1:
break
elif my_model_state[0] == 4:
continue
elif my_model_state[0] == 2:
torch.save(model.state_dict(), path_saver)
continue
elif my_model_state[0] == 3:
model.load_state_dict(torch.load(path_saver, map_location=torch.device('cuda:0')))
continue
my_mfgs = [None]
multi_mfgs = [None] * (args.num_gpus + 1)
torch.distributed.scatter_object_list(my_mfgs, multi_mfgs, src=args.num_gpus)
mfgs = mfgs_to_cuda(my_mfgs[0])
prepare_input(mfgs, node_feats, edge_feats)
if my_model_state[0] == 0:
model.train()
optimizer.zero_grad()
if mailbox is not None:
mailbox.prep_input_mails(mfgs[0])
pred_pos, pred_neg = model(mfgs)
loss = creterion(pred_pos, torch.ones_like(pred_pos))
loss += creterion(pred_neg, torch.zeros_like(pred_neg))
loss.backward()
optimizer.step()
if mailbox is not None:
with torch.no_grad():
my_root = [None]
multi_root = [None] * (args.num_gpus + 1)
my_ts = [None]
multi_ts = [None] * (args.num_gpus + 1)
my_eid = [None]
multi_eid = [None] * (args.num_gpus + 1)
torch.distributed.scatter_object_list(my_root, multi_root, src=args.num_gpus)
torch.distributed.scatter_object_list(my_ts, multi_ts, src=args.num_gpus)
torch.distributed.scatter_object_list(my_eid, multi_eid, src=args.num_gpus)
eid = my_eid[0]
mem_edge_feats = edge_feats[eid] if edge_feats is not None else None
root_nodes = my_root[0]
ts = my_ts[0]
block = None
if memory_param['deliver_to'] == 'neighbors':
my_block = [None]
multi_block = [None] * (args.num_gpus + 1)
torch.distributed.scatter_object_list(my_block, multi_block, src=args.num_gpus)
block = my_block[0]
mailbox.update_mailbox(model.module.memory_updater.last_updated_nid, model.module.memory_updater.last_updated_memory, root_nodes, ts, mem_edge_feats, block)
mailbox.update_memory(model.module.memory_updater.last_updated_nid, model.module.memory_updater.last_updated_memory, model.module.memory_updater.last_updated_ts)
if memory_param['deliver_to'] == 'neighbors':
torch.distributed.barrier(group=nccl_group)
if args.local_rank == 0:
mailbox.update_next_mail_pos()
torch.distributed.gather_object(float(loss), None, dst=args.num_gpus)
elif my_model_state[0] == 1:
model.eval()
with torch.no_grad():
if mailbox is not None:
mailbox.prep_input_mails(mfgs[0])
pred_pos, pred_neg = model(mfgs)
if mailbox is not None:
my_root = [None]
multi_root = [None] * (args.num_gpus + 1)
my_ts = [None]
multi_ts = [None] * (args.num_gpus + 1)
my_eid = [None]
multi_eid = [None] * (args.num_gpus + 1)
torch.distributed.scatter_object_list(my_root, multi_root, src=args.num_gpus)
torch.distributed.scatter_object_list(my_ts, multi_ts, src=args.num_gpus)
torch.distributed.scatter_object_list(my_eid, multi_eid, src=args.num_gpus)
eid = my_eid[0]
mem_edge_feats = edge_feats[eid] if edge_feats is not None else None
root_nodes = my_root[0]
ts = my_ts[0]
block = None
if memory_param['deliver_to'] == 'neighbors':
my_block = [None]
multi_block = [None] * (args.num_gpus + 1)
torch.distributed.scatter_object_list(my_block, multi_block, src=args.num_gpus)
block = my_block[0]
mailbox.update_mailbox(model.module.memory_updater.last_updated_nid, model.module.memory_updater.last_updated_memory, root_nodes, ts, mem_edge_feats, block)
mailbox.update_memory(model.module.memory_updater.last_updated_nid, model.module.memory_updater.last_updated_memory, model.module.memory_updater.last_updated_ts)
if memory_param['deliver_to'] == 'neighbors':
torch.distributed.barrier(group=nccl_group)
if args.local_rank == 0:
mailbox.update_next_mail_pos()
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
ap = average_precision_score(y_true, y_pred)
auc = roc_auc_score(y_true, y_pred)
torch.distributed.gather_object(float(ap), None, dst=args.num_gpus)
torch.distributed.gather_object(float(auc), None, dst=args.num_gpus)
elif my_model_state[0] == 5:
model.eval()
with torch.no_grad():
if mailbox is not None:
mailbox.prep_input_mails(mfgs[0])
emb = model.module.get_emb(mfgs).detach().cpu()
torch.distributed.gather_object(emb, None, dst=args.num_gpus)
else:
# hosting process
train_edge_end = df[df['ext_roll'].gt(0)].index[0]
val_edge_end = df[df['ext_roll'].gt(1)].index[0]
sampler = None
if not ('no_sample' in sample_param and sample_param['no_sample']):
sampler = ParallelSampler(g['indptr'], g['indices'], g['eid'], g['ts'].astype(np.float32),
sample_param['num_thread'], 1, sample_param['layer'], sample_param['neighbor'],
sample_param['strategy']=='recent', sample_param['prop_time'],
sample_param['history'], float(sample_param['duration']))
neg_link_sampler = NegLinkSampler(g['indptr'].shape[0] - 1)
ldf = pd.read_csv('DATA/{}/labels.csv'.format(args.data))
args.batch_size = math.ceil(len(ldf) / (len(ldf) // args.batch_size // args.num_gpus * args.num_gpus))
train_param['batch_size'] = math.ceil(len(df) / (len(df) // train_param['batch_size'] // args.num_gpus * args.num_gpus))
processed_edge_id = 0
def forward_model_to(time):
global processed_edge_id
if processed_edge_id >= len(df):
return
while df.time[processed_edge_id] < time:
# print('curr:',processed_edge_id,df.time[processed_edge_id],'target:',time)
multi_mfgs = list()
multi_root = list()
multi_ts = list()
multi_eid = list()
multi_block = list()
for _ in range(args.num_gpus):
if processed_edge_id >= len(df):
break
rows = df[processed_edge_id:min(len(df), processed_edge_id + train_param['batch_size'])]
root_nodes = np.concatenate([rows.src.values, rows.dst.values, neg_link_sampler.sample(len(rows))]).astype(np.int32)
ts = np.concatenate([rows.time.values, rows.time.values, rows.time.values]).astype(np.float32)
if sampler is not None:
if 'no_neg' in sample_param and sample_param['no_neg']:
pos_root_end = root_nodes.shape[0] * 2 // 3
sampler.sample(root_nodes[:pos_root_end], ts[:pos_root_end])
else:
sampler.sample(root_nodes, ts)
ret = sampler.get_ret()
if gnn_param['arch'] != 'identity':
mfgs = to_dgl_blocks(ret, sample_param['history'], cuda=False)
else:
mfgs = node_to_dgl_blocks(root_nodes, ts, cuda=False)
multi_mfgs.append(mfgs)
multi_root.append(root_nodes)
multi_ts.append(ts)
multi_eid.append(rows['Unnamed: 0'].values)
if mailbox is not None and memory_param['deliver_to'] == 'neighbors':
multi_block.append(to_dgl_blocks(ret, sample_param['history'], reverse=True, cuda=False)[0][0])
processed_edge_id += train_param['batch_size']
if processed_edge_id >= len(df):
return
model_state = [1] * (args.num_gpus + 1)
my_model_state = [None]
torch.distributed.scatter_object_list(my_model_state, model_state, src=args.num_gpus)
multi_mfgs.append(None)
my_mfgs = [None]
torch.distributed.scatter_object_list(my_mfgs, multi_mfgs, src=args.num_gpus)
if mailbox is not None:
multi_root.append(None)
multi_ts.append(None)
multi_eid.append(None)
my_root = [None]
my_ts = [None]
my_eid = [None]
torch.distributed.scatter_object_list(my_root, multi_root, src=args.num_gpus)
torch.distributed.scatter_object_list(my_ts, multi_ts, src=args.num_gpus)
torch.distributed.scatter_object_list(my_eid, multi_eid, src=args.num_gpus)
if memory_param['deliver_to'] == 'neighbors':
multi_block.append(None)
my_block = [None]
torch.distributed.scatter_object_list(my_block, multi_block, src=args.num_gpus)
gathered_ap = [None] * (args.num_gpus + 1)
gathered_auc = [None] * (args.num_gpus + 1)
torch.distributed.gather_object(float(0), gathered_ap, dst=args.num_gpus)
torch.distributed.gather_object(float(0), gathered_auc, dst=args.num_gpus)
if processed_edge_id >= len(df):
break
embs = list()
multi_mfgs = list()
for _, rows in tqdm(ldf.groupby(ldf.index // args.batch_size)):
root_nodes = rows.node.values.astype(np.int32)
ts = rows.time.values.astype(np.float32)
if args.data == 'MAG':
# allow paper to sample neighbors
ts += 1
if sampler is not None:
sampler.sample(root_nodes, ts)
ret = sampler.get_ret()
if gnn_param['arch'] != 'identity':
mfgs = to_dgl_blocks(ret, sample_param['history'], cuda=False)
else:
mfgs = node_to_dgl_blocks(root_nodes, ts, cuda=False)
multi_mfgs.append(mfgs)
if len(multi_mfgs) == args.num_gpus:
forward_model_to(ts[-1])
model_state = [5] * (args.num_gpus + 1)
my_model_state = [None]
torch.distributed.scatter_object_list(my_model_state, model_state, src=args.num_gpus)
multi_mfgs.append(None)
my_mfgs = [None]
torch.distributed.scatter_object_list(my_mfgs, multi_mfgs, src=args.num_gpus)
multi_embs = [None] * (args.num_gpus + 1)
torch.distributed.gather_object(None, multi_embs, dst=args.num_gpus)
embs += multi_embs[:-1]
multi_mfgs = list()
emb_file_name = hashlib.md5(str(torch.load(args.model, map_location=torch.device('cpu'))).encode('utf-8')).hexdigest() + '.pt'
if not os.path.isdir('embs'):
os.mkdir('embs')
embs = torch.cat(embs, dim=0)
print('Embedding shape:', embs.shape)
torch.save(embs, 'embs/' + emb_file_name)
# let all process exit
model_state = [-1] * (args.num_gpus + 1)
my_model_state = [None]
torch.distributed.scatter_object_list(my_model_state, model_state, src=args.num_gpus)
\ No newline at end of file
import argparse
import itertools
import pandas as pd
import numpy as np
from tqdm import tqdm
parser=argparse.ArgumentParser()
parser.add_argument('--data', type=str, help='dataset name')
parser.add_argument('--add_reverse', default=False, action='store_true')
args=parser.parse_args()
df = pd.read_csv('DATA/{}/edges.csv'.format(args.data))
num_nodes = max(int(df['src'].max()), int(df['dst'].max())) + 1
print('num_nodes: ', num_nodes)
int_train_indptr = np.zeros(num_nodes + 1, dtype=np.int)
int_train_indices = [[] for _ in range(num_nodes)]
int_train_ts = [[] for _ in range(num_nodes)]
int_train_eid = [[] for _ in range(num_nodes)]
int_full_indptr = np.zeros(num_nodes + 1, dtype=np.int)
int_full_indices = [[] for _ in range(num_nodes)]
int_full_ts = [[] for _ in range(num_nodes)]
int_full_eid = [[] for _ in range(num_nodes)]
ext_full_indptr = np.zeros(num_nodes + 1, dtype=np.int)
ext_full_indices = [[] for _ in range(num_nodes)]
ext_full_ts = [[] for _ in range(num_nodes)]
ext_full_eid = [[] for _ in range(num_nodes)]
for idx, row in tqdm(df.iterrows(), total=len(df)):
src = int(row['src'])
dst = int(row['dst'])
if row['int_roll'] == 0:
int_train_indices[src].append(dst)
int_train_ts[src].append(row['time'])
int_train_eid[src].append(idx)
if args.add_reverse:
int_train_indices[dst].append(src)
int_train_ts[dst].append(row['time'])
int_train_eid[dst].append(idx)
# int_train_indptr[src + 1:] += 1
if row['int_roll'] != 3:
int_full_indices[src].append(dst)
int_full_ts[src].append(row['time'])
int_full_eid[src].append(idx)
if args.add_reverse:
int_full_indices[dst].append(src)
int_full_ts[dst].append(row['time'])
int_full_eid[dst].append(idx)
# int_full_indptr[src + 1:] += 1
ext_full_indices[src].append(dst)
ext_full_ts[src].append(row['time'])
ext_full_eid[src].append(idx)
if args.add_reverse:
ext_full_indices[dst].append(src)
ext_full_ts[dst].append(row['time'])
ext_full_eid[dst].append(idx)
# ext_full_indptr[src + 1:] += 1
for i in tqdm(range(num_nodes)):
int_train_indptr[i + 1] = int_train_indptr[i] + len(int_train_indices[i])
int_full_indptr[i + 1] = int_full_indptr[i] + len(int_full_indices[i])
ext_full_indptr[i + 1] = ext_full_indptr[i] + len(ext_full_indices[i])
int_train_indices = np.array(list(itertools.chain(*int_train_indices)))
int_train_ts = np.array(list(itertools.chain(*int_train_ts)))
int_train_eid = np.array(list(itertools.chain(*int_train_eid)))
int_full_indices = np.array(list(itertools.chain(*int_full_indices)))
int_full_ts = np.array(list(itertools.chain(*int_full_ts)))
int_full_eid = np.array(list(itertools.chain(*int_full_eid)))
ext_full_indices = np.array(list(itertools.chain(*ext_full_indices)))
ext_full_ts = np.array(list(itertools.chain(*ext_full_ts)))
ext_full_eid = np.array(list(itertools.chain(*ext_full_eid)))
print('Sorting...')
def tsort(i, indptr, indices, t, eid):
beg = indptr[i]
end = indptr[i + 1]
sidx = np.argsort(t[beg:end])
indices[beg:end] = indices[beg:end][sidx]
t[beg:end] = t[beg:end][sidx]
eid[beg:end] = eid[beg:end][sidx]
for i in tqdm(range(int_train_indptr.shape[0] - 1)):
tsort(i, int_train_indptr, int_train_indices, int_train_ts, int_train_eid)
tsort(i, int_full_indptr, int_full_indices, int_full_ts, int_full_eid)
tsort(i, ext_full_indptr, ext_full_indices, ext_full_ts, ext_full_eid)
# import pdb; pdb.set_trace()
print('saving...')
np.savez('DATA/{}/int_train.npz'.format(args.data), indptr=int_train_indptr, indices=int_train_indices, ts=int_train_ts, eid=int_train_eid)
np.savez('DATA/{}/int_full.npz'.format(args.data), indptr=int_full_indptr, indices=int_full_indices, ts=int_full_ts, eid=int_full_eid)
np.savez('DATA/{}/ext_full.npz'.format(args.data), indptr=ext_full_indptr, indices=ext_full_indices, ts=ext_full_ts, eid=ext_full_eid)
\ No newline at end of file
name: gnn
channels:
- pyg
- pytorch
- conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- anyio=3.6.2=pyhd8ed1ab_0
- appdirs=1.4.4=pyhd3eb1b0_0
- argon2-cffi=21.3.0=pyhd8ed1ab_0
- argon2-cffi-bindings=21.2.0=py38h0a891b7_2
- asttokens=2.2.1=pyhd8ed1ab_0
- attrs=22.2.0=pyh71513ae_0
- backcall=0.2.0=pyh9f0ad1d_0
- backports=1.0=pyhd8ed1ab_3
- backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0
- beautifulsoup4=4.11.2=pyha770c72_0
- blas=1.0=mkl
- bleach=6.0.0=pyhd8ed1ab_0
- brotlipy=0.7.0=py38h0a891b7_1004
- bzip2=1.0.8=h7f98852_4
- ca-certificates=2022.12.7=ha878542_0
- cffi=1.15.0=py38h3931269_0
- charset-normalizer=2.1.1=pyhd8ed1ab_0
- cryptography=37.0.2=py38h2b5fc30_0
- cudatoolkit=11.6.0=hecad31d_10
- decorator=5.1.1=pyhd8ed1ab_0
- defusedxml=0.7.1=pyhd8ed1ab_0
- entrypoints=0.4=pyhd8ed1ab_0
- executing=1.2.0=pyhd8ed1ab_0
- ffmpeg=4.3=hf484d3e_0
- fftw=3.3.9=h27cfd23_1
- flit-core=3.8.0=pyhd8ed1ab_0
- freetype=2.10.4=h0708190_1
- gmp=6.2.1=h58526e2_0
- gnutls=3.6.13=h85f3911_1
- idna=3.4=pyhd8ed1ab_0
- importlib-metadata=6.0.0=pyha770c72_0
- importlib_resources=5.10.2=pyhd8ed1ab_0
- intel-openmp=2021.4.0=h06a4308_3561
- ipykernel=5.5.5=py38hd0cf306_0
- ipython=8.10.0=pyh41d4057_0
- ipython_genutils=0.2.0=py_1
- ipywidgets=8.0.4=pyhd8ed1ab_0
- jedi=0.18.2=pyhd8ed1ab_0
- jinja2=3.1.2=py38h06a4308_0
- joblib=1.1.1=py38h06a4308_0
- jpeg=9e=h166bdaf_1
- jsonschema=4.17.3=pyhd8ed1ab_0
- jupyter_client=7.0.6=pyhd8ed1ab_0
- jupyter_core=5.2.0=py38h578d9bd_0
- jupyter_server=1.23.5=pyhd8ed1ab_0
- jupyterlab_pygments=0.2.2=pyhd8ed1ab_0
- jupyterlab_widgets=3.0.5=pyhd8ed1ab_0
- lame=3.100=h7f98852_1001
- lcms2=2.12=hddcbb42_0
- ld_impl_linux-64=2.38=h1181459_1
- libffi=3.4.2=h6a678d5_6
- libgcc-ng=11.2.0=h1234567_1
- libgfortran-ng=11.2.0=h00389a5_1
- libgfortran5=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libiconv=1.17=h166bdaf_0
- libpng=1.6.37=h21135ba_2
- libsodium=1.0.18=h36c2ea0_1
- libstdcxx-ng=11.2.0=h1234567_1
- libtiff=4.2.0=hf544144_3
- libwebp-base=1.2.2=h7f98852_1
- lz4-c=1.9.3=h9c3ff4c_1
- markupsafe=2.1.1=py38h7f8727e_0
- matplotlib-inline=0.1.6=pyhd8ed1ab_0
- mistune=2.0.5=pyhd8ed1ab_0
- mkl=2021.4.0=h06a4308_640
- mkl-service=2.4.0=py38h95df7f1_0
- mkl_fft=1.3.1=py38h8666266_1
- mkl_random=1.2.2=py38h1abd341_0
- nbclassic=0.5.1=pyhd8ed1ab_0
- nbclient=0.7.2=pyhd8ed1ab_0
- nbconvert=7.2.9=pyhd8ed1ab_0
- nbconvert-core=7.2.9=pyhd8ed1ab_0
- nbconvert-pandoc=7.2.9=pyhd8ed1ab_0
- nbformat=5.7.3=pyhd8ed1ab_0
- ncurses=6.4=h6a678d5_0
- nest-asyncio=1.5.6=pyhd8ed1ab_0
- nettle=3.6=he412f7d_0
- networkx=2.8.4=py38h06a4308_0
- notebook=6.5.2=pyha770c72_1
- notebook-shim=0.2.2=pyhd8ed1ab_0
- numpy=1.23.5=py38h14f4228_0
- numpy-base=1.23.5=py38h31eccc5_0
- olefile=0.46=pyh9f0ad1d_1
- openh264=2.1.1=h780b84a_0
- openjpeg=2.4.0=hb52868f_1
- openssl=1.1.1t=h7f8727e_0
- packaging=22.0=py38h06a4308_0
- pandoc=2.19.2=ha770c72_0
- pandocfilters=1.5.0=pyhd8ed1ab_0
- parso=0.8.3=pyhd8ed1ab_0
- pexpect=4.8.0=pyh1a96a4e_2
- pickleshare=0.7.5=py_1003
- pillow=8.2.0=py38ha0e1e83_1
- pip=22.3.1=py38h06a4308_0
- pkgutil-resolve-name=1.3.10=pyhd8ed1ab_0
- platformdirs=3.0.0=pyhd8ed1ab_0
- pooch=1.4.0=pyhd3eb1b0_0
- prometheus_client=0.16.0=pyhd8ed1ab_0
- prompt-toolkit=3.0.36=pyha770c72_0
- psutil=5.9.0=py38h5eee18b_0
- ptyprocess=0.7.0=pyhd3deb0d_0
- pure_eval=0.2.2=pyhd8ed1ab_0
- pycparser=2.21=pyhd8ed1ab_0
- pyg=2.2.0=py38_torch_1.12.0_cu116
- pygments=2.14.0=pyhd8ed1ab_0
- pyopenssl=22.0.0=pyhd8ed1ab_1
- pyparsing=3.0.9=py38h06a4308_0
- pyrsistent=0.18.0=py38heee7806_0
- pysocks=1.7.1=pyha2e5f31_6
- python=3.8.16=h7a1cb2a_2
- python-dateutil=2.8.2=pyhd8ed1ab_0
- python-fastjsonschema=2.16.2=pyhd8ed1ab_0
- python_abi=3.8=2_cp38
- pytorch=1.12.1=py3.8_cuda11.6_cudnn8.3.2_0
- pytorch-cluster=1.6.0=py38_torch_1.12.0_cu116
- pytorch-mutex=1.0=cuda
- pytorch-scatter=2.1.0=py38_torch_1.12.0_cu116
- pytorch-sparse=0.6.16=py38_torch_1.12.0_cu116
- pyzmq=19.0.2=py38ha71036d_2
- readline=8.2=h5eee18b_0
- requests=2.28.2=pyhd8ed1ab_0
- scikit-learn=1.2.0=py38h6a678d5_0
- scipy=1.10.0=py38h14f4228_0
- send2trash=1.8.0=pyhd8ed1ab_0
- setuptools=65.6.3=py38h06a4308_0
- six=1.16.0=pyh6c4a22f_0
- sniffio=1.3.0=pyhd8ed1ab_0
- soupsieve=2.3.2.post1=pyhd8ed1ab_0
- sqlite=3.40.1=h5082296_0
- stack_data=0.6.2=pyhd8ed1ab_0
- terminado=0.17.1=pyh41d4057_0
- threadpoolctl=2.2.0=pyh0d69192_0
- tinycss2=1.2.1=pyhd8ed1ab_0
- tk=8.6.12=h1ccaba5_0
- torchaudio=0.12.1=py38_cu116
- torchvision=0.13.1=py38_cu116
- tornado=6.1=py38h0a891b7_3
- tqdm=4.64.1=py38h06a4308_0
- traitlets=5.9.0=pyhd8ed1ab_0
- typing-extensions=4.4.0=hd8ed1ab_0
- typing_extensions=4.4.0=pyha770c72_0
- urllib3=1.26.14=pyhd8ed1ab_0
- wcwidth=0.2.6=pyhd8ed1ab_0
- webencodings=0.5.1=py_1
- websocket-client=1.5.1=pyhd8ed1ab_0
- wheel=0.37.1=pyhd3eb1b0_0
- widgetsnbextension=4.0.5=pyhd8ed1ab_0
- xz=5.2.10=h5eee18b_1
- zeromq=4.3.4=h9c3ff4c_1
- zipp=3.13.0=pyhd8ed1ab_0
- zlib=1.2.13=h5eee18b_0
- zstd=1.5.0=ha95c52a_0
- pip:
- alabaster==0.7.13
- autopep8==2.0.2
- babel==2.12.1
- certifi==2023.5.7
- click==8.1.3
- contourpy==1.0.7
- cycler==0.11.0
- dgl==1.0.2
- dglgo==0.0.2
- docutils==0.19
- exceptiongroup==1.1.1
- flameprof==0.4
- fonttools==4.39.3
- hydra-core==1.3.2
- imagesize==1.4.1
- iniconfig==2.0.0
- isort==5.12.0
- kiwisolver==1.4.4
- littleutils==0.2.2
- lltm-cpp==0.0.0
- markdown-it-py==2.2.0
- matmul-cpu-pytorch==0.0.0
- matmul-gpu==0.0.0
- matplotlib==3.7.1
- mdurl==0.1.2
- memray==1.7.0
- numpydoc==1.5.0
- ogb==1.3.5
- omegaconf==2.3.0
- outdated==0.2.2
- pandas==1.5.3
- pluggy==1.0.0
- presample-cores==0.0.0
- pybind11==2.10.4
- pycodestyle==2.10.0
- pydantic==1.10.7
- pytest==7.2.2
- pytz==2022.7.1
- pyyaml==6.0
- rdkit-pypi==2022.9.5
- rich==13.4.1
- ruamel-yaml==0.17.21
- ruamel-yaml-clib==0.2.7
- sample-cores==0.0.0
- snowballstemmer==2.2.0
- sphinx==6.1.3
- sphinxcontrib-applehelp==1.0.4
- sphinxcontrib-devhelp==1.0.2
- sphinxcontrib-htmlhelp==2.0.1
- sphinxcontrib-jsmath==1.0.1
- sphinxcontrib-qthelp==1.0.3
- sphinxcontrib-serializinghtml==1.1.5
- tomli==2.0.1
- torch-geometric-autoscale==0.0.0
- torch-quiver==0.1.1
- typer==0.7.0
prefix: /home/sxx/miniconda3/envs/gnn
from os.path import abspath, join, dirname
import os
import sys
from os.path import abspath, join, dirname
sys.path.insert(0, join(abspath(dirname(__file__))))
import torch
import dgl
import math
import numpy as np
class TimeEncode(torch.nn.Module):
def __init__(self, dim):
super(TimeEncode, self).__init__()
self.dim = dim
self.w = torch.nn.Linear(1, dim)
self.w.weight = torch.nn.Parameter((torch.from_numpy(1 / 10 ** np.linspace(0, 9, dim, dtype=np.float32))).reshape(dim, -1))
self.w.bias = torch.nn.Parameter(torch.zeros(dim))
def forward(self, t):
output = torch.cos(self.w(t.float().reshape((-1, 1))))
return output
class EdgePredictor(torch.nn.Module):
def __init__(self, dim_in):
super(EdgePredictor, self).__init__()
self.dim_in = dim_in
self.src_fc = torch.nn.Linear(dim_in, dim_in)
self.dst_fc = torch.nn.Linear(dim_in, dim_in)
self.out_fc = torch.nn.Linear(dim_in, 1)
def forward(self, h, neg_samples=1):
num_edge = h.shape[0] // (neg_samples + 2)
h_src = self.src_fc(h[:num_edge])
h_pos_dst = self.dst_fc(h[num_edge:2 * num_edge])
h_neg_dst = self.dst_fc(h[2 * num_edge:])
h_pos_edge = torch.nn.functional.relu(h_src + h_pos_dst)
h_neg_edge = torch.nn.functional.relu(h_src.tile(neg_samples, 1) + h_neg_dst)
return self.out_fc(h_pos_edge), self.out_fc(h_neg_edge)
class TransfomerAttentionLayer(torch.nn.Module):
def __init__(self, dim_node_feat, dim_edge_feat, dim_time, num_head, dropout, att_dropout, dim_out, combined=False):
super(TransfomerAttentionLayer, self).__init__()
self.num_head = num_head
self.dim_node_feat = dim_node_feat
self.dim_edge_feat = dim_edge_feat
self.dim_time = dim_time
self.dim_out = dim_out
self.dropout = torch.nn.Dropout(dropout)
self.att_dropout = torch.nn.Dropout(att_dropout)
self.att_act = torch.nn.LeakyReLU(0.2)
self.combined = combined
if dim_time > 0:
self.time_enc = TimeEncode(dim_time)
if combined:
if dim_node_feat > 0:
self.w_q_n = torch.nn.Linear(dim_node_feat, dim_out)
self.w_k_n = torch.nn.Linear(dim_node_feat, dim_out)
self.w_v_n = torch.nn.Linear(dim_node_feat, dim_out)
if dim_edge_feat > 0:
self.w_k_e = torch.nn.Linear(dim_edge_feat, dim_out)
self.w_v_e = torch.nn.Linear(dim_edge_feat, dim_out)
if dim_time > 0:
self.w_q_t = torch.nn.Linear(dim_time, dim_out)
self.w_k_t = torch.nn.Linear(dim_time, dim_out)
self.w_v_t = torch.nn.Linear(dim_time, dim_out)
else:
if dim_node_feat + dim_time > 0:
self.w_q = torch.nn.Linear(dim_node_feat + dim_time, dim_out)
self.w_k = torch.nn.Linear(dim_node_feat + dim_edge_feat + dim_time, dim_out)
self.w_v = torch.nn.Linear(dim_node_feat + dim_edge_feat + dim_time, dim_out)
self.w_out = torch.nn.Linear(dim_node_feat + dim_out, dim_out)
self.layer_norm = torch.nn.LayerNorm(dim_out)
def forward(self, b):
assert(self.dim_time + self.dim_node_feat + self.dim_edge_feat > 0)
self.device = b.device
if b.num_edges() == 0:
return torch.zeros((b.num_dst_nodes(), self.dim_out), device=self.device)
if self.dim_time > 0:
time_feat = self.time_enc(b.edata['dt'])
zero_time_feat = self.time_enc(torch.zeros(b.num_dst_nodes(), dtype=torch.float32, device=self.device))
if self.combined:
Q = torch.zeros((b.num_edges(), self.dim_out), device=self.device)
K = torch.zeros((b.num_edges(), self.dim_out), device=self.device)
V = torch.zeros((b.num_edges(), self.dim_out), device=self.device)
if self.dim_node_feat > 0:
Q += self.w_q_n(b.srcdata['h'][:b.num_dst_nodes()])[b.edges()[1]]
K += self.w_k_n(b.srcdata['h'][b.num_dst_nodes():])[b.edges()[0] - b.num_dst_nodes()]
V += self.w_v_n(b.srcdata['h'][b.num_dst_nodes():])[b.edges()[0] - b.num_dst_nodes()]
if self.dim_edge_feat > 0:
K += self.w_k_e(b.edata['f'])
V += self.w_v_e(b.edata['f'])
if self.dim_time > 0:
Q += self.w_q_t(zero_time_feat)[b.edges()[1]]
K += self.w_k_t(time_feat)
V += self.w_v_t(time_feat)
Q = torch.reshape(Q, (Q.shape[0], self.num_head, -1))
K = torch.reshape(K, (K.shape[0], self.num_head, -1))
V = torch.reshape(V, (V.shape[0], self.num_head, -1))
att = dgl.ops.edge_softmax(b, self.att_act(torch.sum(Q*K, dim=2)))
att = self.att_dropout(att)
V = torch.reshape(V*att[:, :, None], (V.shape[0], -1))
b.edata['v'] = V
b.update_all(dgl.function.copy_edge('v', 'm'), dgl.function.sum('m', 'h'))
else:
if self.dim_time == 0 and self.dim_node_feat == 0:
Q = torch.ones((b.num_edges(), self.dim_out), device=self.device)
K = self.w_k(b.edata['f'])
V = self.w_v(b.edata['f'])
elif self.dim_time == 0 and self.dim_edge_feat == 0:
Q = self.w_q(b.srcdata['h'][:b.num_dst_nodes()])[b.edges()[1]]
K = self.w_k(b.srcdata['h'][b.edges()[0]])
V = self.w_v(b.srcdata['h'][b.edges()[0]])
elif self.dim_time == 0:
Q = self.w_q(b.srcdata['h'][:b.num_dst_nodes()])[b.edges()[1]]
K = self.w_k(torch.cat([b.srcdata['h'][b.edges()[0]], b.edata['f']], dim=1))
V = self.w_v(torch.cat([b.srcdata['h'][b.edges()[0]], b.edata['f']], dim=1))
#K = self.w_k(torch.cat([b.srcdata['h'][b.num_dst_nodes():], b.edata['f']], dim=1))
#V = self.w_v(torch.cat([b.srcdata['h'][b.num_dst_nodes():], b.edata['f']], dim=1))
elif self.dim_node_feat == 0:
Q = self.w_q(zero_time_feat)[b.edges()[1]]
K = self.w_k(torch.cat([b.edata['f'], time_feat], dim=1))
V = self.w_v(torch.cat([b.edata['f'], time_feat], dim=1))
elif self.dim_edge_feat == 0:
Q = self.w_q(torch.cat([b.srcdata['h'][:b.num_dst_nodes()], zero_time_feat], dim=1))[b.edges()[1]]
K = self.w_k(torch.cat([b.srcdata['h'][b.edges()[0]], time_feat], dim=1))
V = self.w_v(torch.cat([b.srcdata['h'][b.edges()[0]], time_feat], dim=1))
#K = self.w_k(torch.cat([b.srcdata['h'][b.num_dst_nodes():], time_feat], dim=1))
#V = self.w_v(torch.cat([b.srcdata['h'][b.num_dst_nodes():], time_feat], dim=1))
else:
Q = self.w_q(torch.cat([b.srcdata['h'][:b.num_dst_nodes()], zero_time_feat], dim=1))[b.edges()[1]]
K = self.w_k(torch.cat([b.srcdata['h'][b.edges()[0]], b.edata['f'], time_feat], dim=1))
V = self.w_v(torch.cat([b.srcdata['h'][b.edges()[0]], b.edata['f'], time_feat], dim=1))
#Q = self.w_q(torch.cat([b.srcdata['h'][:b.num_dst_nodes()], zero_time_feat], dim=1))[b.edges()[1]]
#K = self.w_k(torch.cat([b.srcdata['h'][b.num_dst_nodes():], b.edata['f'], time_feat], dim=1))
#V = self.w_v(torch.cat([b.srcdata['h'][b.num_dst_nodes():], b.edata['f'], time_feat], dim=1))
Q = torch.reshape(Q, (Q.shape[0], self.num_head, -1))
K = torch.reshape(K, (K.shape[0], self.num_head, -1))
V = torch.reshape(V, (V.shape[0], self.num_head, -1))
att = dgl.ops.edge_softmax(b, self.att_act(torch.sum(Q*K, dim=2)))
att = self.att_dropout(att)
V = torch.reshape(V*att[:, :, None], (V.shape[0], -1))
b.edata['v'] = V
b.update_all(dgl.function.copy_e('v', 'm'), dgl.function.sum('m', 'h'))
#b.srcdata['v'] = torch.cat([torch.zeros((b.num_dst_nodes(), V.shape[1]), device=torch.device('cuda:0')), V], dim=0)
#b.update_all(dgl.function.copy_u('v', 'm'), dgl.function.sum('m', 'h'))
if self.dim_node_feat != 0:
rst = torch.cat([b.dstdata['h'], b.srcdata['h'][:b.num_dst_nodes()]], dim=1)
else:
rst = b.dstdata['h']
rst = self.w_out(rst)
rst = torch.nn.functional.relu(self.dropout(rst))
return self.layer_norm(rst)
class IdentityNormLayer(torch.nn.Module):
def __init__(self, dim_out):
super(IdentityNormLayer, self).__init__()
self.norm = torch.nn.LayerNorm(dim_out)
def forward(self, b):
return self.norm(b.srcdata['h'])
class JODIETimeEmbedding(torch.nn.Module):
def __init__(self, dim_out):
super(JODIETimeEmbedding, self).__init__()
self.dim_out = dim_out
class NormalLinear(torch.nn.Linear):
# From Jodie code
def reset_parameters(self):
stdv = 1. / math.sqrt(self.weight.size(1))
self.weight.data.normal_(0, stdv)
if self.bias is not None:
self.bias.data.normal_(0, stdv)
self.time_emb = NormalLinear(1, dim_out)
def forward(self, h, mem_ts, ts):
time_diff = (ts - mem_ts) / (ts + 1)
rst = h * (1 + self.time_emb(time_diff.unsqueeze(1)))
return rst
\ No newline at end of file
import argparse
import os
import sys
from os.path import abspath, join, dirname
sys.path.insert(0, join(abspath(dirname(__file__)+'/startGNN_sample')))
from startGNN_sample.DistGraphLoader import DataSet, partition_load
from startGNN_sample.Sample.temporal_neighbor_sampler import TemporalNeighborSampler
from startGNN_sample.Sample.base import NegativeSampling
#path1=os.path.abspath('.')
import torch
from startGNN_sample.DistGraphLoader import DistGraphData
from startGNN_sample.DistGraphLoader import DistributedDataLoader
from startGNN_sample.DistGraphLoader import DistCustomPool
import startGNN_sample.distparser as distparser
import time
import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os
"""
test command
python test.py --world_size 2 --rank 0
--world_size', default=4, type=int, metavar='W',
help='number of workers')
parser.add_argument('--rank', default=0, type=int, metavar='W',
help='rank of the worker')
parser.add_argument('--log_interval', type=int, default=10, metavar='N',
help='interval between training status logs')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
help='how much to value future rewards')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed for reproducibility')
parser.add_argument('--num_sampler', type=int, default=10, metavar='S',
help='number of samplers')
parser.add_argument('--queue_size', type=int, default=10, metavar='S',
help='sampler queue size')
"""
from tgl.modules import *
from tgl.sampler import *
from tgl.utils import *
from sklearn.metrics import average_precision_score, roc_auc_score
import torch
import time
import random
import dgl
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn.parallel import DistributedDataParallel as DDP
def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def main():
args = distparser.args
DistCustomPool.init_distribution('127.0.0.1',9675,'127.0.0.1',10023,backend = "gloo")
pdata = partition_load("./startGNN_sample/wiki", algo="metis")
graph = DistGraphData(pdata = pdata,edge_index= pdata.edge_index, full_edge = False)
sampler = TemporalNeighborSampler(num_nodes=graph.num_nodes, num_layers=2, fanout=[10,10], graph_data=graph, workers=10,is_root_ts = True,graph_name = "wiki_train")
train_data = torch.masked_select(graph.edge_index,graph.data.train_mask).reshape(2,-1)
train_ts = torch.masked_select(graph.edge_ts,graph.data.train_mask)
val_data = torch.masked_select(graph.edge_index,graph.data.val_mask).reshape(2,-1)
val_ts = torch.masked_select(graph.edge_ts,graph.data.val_mask)
test_data = torch.masked_select(graph.edge_index,graph.data.test_mask).reshape(2,-1)
test_ts = torch.masked_select(graph.edge_ts,graph.data.test_mask)
train_data = DataSet(edges = train_data,ts =train_ts,labels = torch.ones(train_data.shape[1]))
test_data = DataSet(edges = test_data,ts =test_ts,labels = torch.ones(test_data.shape[1]))
val_data = DataSet(edges = val_data,ts = val_ts,labels = torch.ones(val_data.shape[1]))
neg_sampler = NegativeSampling('triplet')
trainloader = DistributedDataLoader('train',graph,train_data,sampler = sampler,neg_sampler=neg_sampler,batch_size = 600,shuffle=False,cache_memory_size = 0,drop_last=True,cs = 1)
testloader = DistributedDataLoader('test',graph,test_data,sampler = sampler,neg_sampler=neg_sampler,batch_size = 600,shuffle=False,cache_memory_size = 0,drop_last=True,cs = None)
valloader = DistributedDataLoader('val',graph,val_data,sampler = sampler,neg_sampler=neg_sampler,batch_size = 600,shuffle=False,cache_memory_size = 0,drop_last=True,cs = None)
val_losses = list()
def eval(mode='val'):
neg_samples = 1
model.eval()
aps = list()
aucs_mrrs = list()
if mode == 'val':
loader = valloader
elif mode == 'test':
loader = testloader
elif mode == 'train':
loader = trainloader
with torch.no_grad():
total_loss = 0
for batchData in loader:
mfgs,metadata = batch_data_prepare_input(batchData,sample_param['history'],cuda = use_cuda)
optimizer.zero_grad()
pred_pos, pred_neg = model(mfgs,metadata)
total_loss += creterion(pred_pos, torch.ones_like(pred_pos))
total_loss += creterion(pred_neg, torch.zeros_like(pred_neg))
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
aps.append(average_precision_score(y_true, y_pred))
if neg_samples > 1:
aucs_mrrs.append(torch.reciprocal(torch.sum(pred_pos.squeeze() < pred_neg.squeeze().reshape(neg_samples, -1), dim=0) + 1).type(torch.float))
else:
aucs_mrrs.append(roc_auc_score(y_true, y_pred))
if mode == 'val':
val_losses.append(float(total_loss))
ap = float(torch.tensor(aps).mean())
if neg_samples > 1:
auc_mrr = float(torch.cat(aucs_mrrs).mean())
else:
auc_mrr = float(torch.tensor(aucs_mrrs).mean())
return ap, auc_mrr
# set_seed(0)
sample_param, memory_param, gnn_param, train_param = parse_config('tgl/config/TGN.yml')
gnn_dim_node = 0 if graph.data.x is None else graph.data.x.shape[1]
gnn_dim_edge = 0 if graph.data.edge_attr is None else graph.data.edge_attr.shape[1]
combine_first = False
if 'combine_neighs' in train_param and train_param['combine_neighs']:
combine_first = True
use_cuda = False
if use_cuda:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param, combined=combine_first).cuda()
else:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param, combined=combine_first)
model = DDP(model)
mailbox = MailBox(memory_param, graph.partptr[-1], gnn_dim_edge) if memory_param['type'] != 'none' else None
creterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'])
if 'all_on_gpu' in train_param and train_param['all_on_gpu']:
#if node_feats is not None:
# node_feats = node_feats.cuda()
#if edge_feats is not None:
# edge_feats = edge_feats.cuda()
if mailbox is not None:
mailbox.move_to_gpu()
sampler = None
#if not ('no_sample' in sample_param and sample_param['no_sample']):
# sampler = ParallelSampler(g['indptr'], g['indices'], g['eid'], g['ts'].astype(np.float32),
# sample_param['num_thread'], 1, sample_param['layer'], sample_param['neighbor'],
# sample_param['strategy']=='recent', sample_param['prop_time'],
# sample_param['history'], float(sample_param['duration']))
# neg_link_sampler = NegLinkSampler(g['indptr'].shape[0] - 1)
neg_link_sampler = None#NegativeSampling()
if not os.path.isdir('models'):
os.mkdir('models')
#if args.model_name == '':
# path_saver = 'models/{}_{}.pkl'.format(args.data, time.time())
#else:
# path_saver = 'models/{}.pkl'.format(args.model_name)
best_ap = 0
best_e = 0
sample_time = time.time()
for e in range(train_param['epoch']):
train_aps = list()
print('Epoch {:d}:'.format(e))
time_prep = 0
time_tot = 0
total_loss = 0
# training
model.train()
if sampler is not None:
sampler.reset()
if mailbox is not None:
mailbox.reset()
model.memory_updater.last_updated_nid = None
last_time = time.time()
for batchData in trainloader:
sample_time += time.time() - last_time
##for _, rows in df[:train_edge_end].groupby(group_indexes[random.randint(0, len(group_indexes) - 1)]):
t_tot_s = time.time()
t_prep_s = time.time()
mfgs,metadata = batch_data_prepare_input(batchData,sample_param['history'],cuda = use_cuda)
#if mailbox is not None:
# mailbox.prep_input_mails(mfgs[0])
time_prep += time.time() - t_prep_s
optimizer.zero_grad()
pred_pos, pred_neg = model(mfgs,metadata)
loss = creterion(pred_pos, torch.ones_like(pred_pos))
loss += creterion(pred_neg, torch.zeros_like(pred_neg))
total_loss += float(loss) * train_param['batch_size']
loss.backward()
optimizer.step()
t_prep_s = time.time()
#if mailbox is not None:
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
train_aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
# eid = rows['Unnamed: 0'].values
# mem_edge_feats = edge_feats[eid] if edge_feats is not None else None
# block = None
# mailbox.update_mailbox(model.memory_updater.last_updated_nid, model.memory_updater.last_updated_memory, root_nodes, ts, mem_edge_feats,# block)
# mailbox.update_memory(model.memory_updater.last_updated_nid, model.memory_updater.last_updated_memory, root_nodes, model.memory_updater#.last_updated_ts)
time_prep += time.time() - t_prep_s
time_tot += time.time() - t_tot_s
last_time = time.time()
train_ap = float(torch.tensor(train_aps).mean())
ap, auc = eval('val')
if e > 2 and ap > best_ap:
best_e = e
best_ap = ap
# torch.save(model.state_dict(), path_saver)
print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f}'.format(total_loss,train_ap, ap, auc))
print('\ttotal time:{:.2f}s sample time:{:.2f}s prep time:{:.2f}s'.format(time_tot, sample_time, time_prep))
#
print('Loading model at epoch {}...'.format(best_e))
#model.load_state_dict(torch.load(path_saver))
model.eval()
if sampler is not None:
sampler.reset()
if mailbox is not None:
mailbox.reset()
model.memory_updater.last_updated_nid = None
eval('train')
eval('val')
ap, auc = eval('test')
if args.eval_neg_samples > 1:
print('\ttest AP:{:4f} test MRR:{:4f}'.format(ap, auc))
else:
print('\ttest AP:{:4f} test AUC:{:4f}'.format(ap, auc))
DistCustomPool.close_distribution()
if __name__ == "__main__":
main()
from os.path import abspath, join, dirname
import os
import sys
from os.path import abspath, join, dirname
sys.path.insert(0, join(abspath(dirname(__file__))))
import torch
import dgl
from layers import TimeEncode
from torch_scatter import scatter
class MailBox():
def __init__(self, memory_param, num_nodes, dim_edge_feat, _node_memory=None, _node_memory_ts=None,_mailbox=None, _mailbox_ts=None, _next_mail_pos=None, _update_mail_pos=None):
self.memory_param = memory_param
self.dim_edge_feat = dim_edge_feat
if memory_param['type'] != 'node':
raise NotImplementedError
self.node_memory = torch.zeros((num_nodes, memory_param['dim_out']), dtype=torch.float32) if _node_memory is None else _node_memory
self.node_memory_ts = torch.zeros(num_nodes, dtype=torch.float32) if _node_memory_ts is None else _node_memory_ts
self.mailbox = torch.zeros((num_nodes, memory_param['mailbox_size'], 2 * memory_param['dim_out'] + dim_edge_feat), dtype=torch.float32) if _mailbox is None else _mailbox
self.mailbox_ts = torch.zeros((num_nodes, memory_param['mailbox_size']), dtype=torch.float32) if _mailbox_ts is None else _mailbox_ts
self.next_mail_pos = torch.zeros((num_nodes), dtype=torch.long) if _next_mail_pos is None else _next_mail_pos
self.update_mail_pos = _update_mail_pos
self.device = torch.device('cpu')
def reset(self):
self.node_memory.fill_(0)
self.node_memory_ts.fill_(0)
self.mailbox.fill_(0)
self.mailbox_ts.fill_(0)
self.next_mail_pos.fill_(0)
def move_to_gpu(self):
self.node_memory = self.node_memory.cuda()
self.node_memory_ts = self.node_memory_ts.cuda()
self.mailbox = self.mailbox.cuda()
self.mailbox_ts = self.mailbox_ts.cuda()
self.next_mail_pos = self.next_mail_pos.cuda()
self.device = torch.device('cuda:0')
def allocate_pinned_memory_buffers(self, sample_param, batch_size):
limit = int(batch_size * 3.3)
if 'neighbor' in sample_param:
for i in sample_param['neighbor']:
limit *= i + 1
self.pinned_node_memory_buffs = list()
self.pinned_node_memory_ts_buffs = list()
self.pinned_mailbox_buffs = list()
self.pinned_mailbox_ts_buffs = list()
for _ in range(sample_param['history']):
self.pinned_node_memory_buffs.append(torch.zeros((limit, self.node_memory.shape[1]), pin_memory=True))
self.pinned_node_memory_ts_buffs.append(torch.zeros((limit,), pin_memory=True))
self.pinned_mailbox_buffs.append(torch.zeros((limit, self.mailbox.shape[1], self.mailbox.shape[2]), pin_memory=True))
self.pinned_mailbox_ts_buffs.append(torch.zeros((limit, self.mailbox_ts.shape[1]), pin_memory=True))
def prep_input_mails(self, mfg, global_id_list = None,use_pinned_buffers=False):
if(global_id_list is not None):
idx = global_id_list[b.srcdata['ID']]
else:
idx = b.srcdata['ID']
if use_pinned_buffers:
idx = idx.cpu().long()
for i, b in enumerate(mfg):
if use_pinned_buffers:
torch.index_select(self.node_memory, 0, idx, out=self.pinned_node_memory_buffs[i][:idx.shape[0]])
b.srcdata['mem'] = self.pinned_node_memory_buffs[i][:idx.shape[0]].cuda(non_blocking=True)
torch.index_select(self.node_memory_ts,0, idx, out=self.pinned_node_memory_ts_buffs[i][:idx.shape[0]])
b.srcdata['mem_ts'] = self.pinned_node_memory_ts_buffs[i][:idx.shape[0]].cuda(non_blocking=True)
torch.index_select(self.mailbox, 0, idx, out=self.pinned_mailbox_buffs[i][:idx.shape[0]])
b.srcdata['mem_input'] = self.pinned_mailbox_buffs[i][:idx.shape[0]].reshape(b.srcdata['ID'].shape[0], -1).cuda(non_blocking=True)
torch.index_select(self.mailbox_ts, 0, idx, out=self.pinned_mailbox_ts_buffs[i][:idx.shape[0]])
b.srcdata['mail_ts'] = self.pinned_mailbox_ts_buffs[i][:idx.shape[0]].cuda(non_blocking=True)
else:
b.srcdata['mem'] = self.node_memory[idx].cuda()
b.srcdata['mem_ts'] = self.node_memory_ts[idx].cuda()
b.srcdata['mem_input'] = self.mailbox[idx].cuda().reshape(b.srcdata['ID'].shape[0], -1)
b.srcdata['mail_ts'] = self.mailbox_ts[idx].cuda()
def update_memory(self, nid, memory, root_nodes, ts, global_node_list = None,neg_samples=1):
if nid is None:
return
num_true_src_dst = root_nodes.shape[0] // (neg_samples + 2) * 2
with torch.no_grad():
nid = nid[:num_true_src_dst].to(self.device)
memory = memory[:num_true_src_dst].to(self.device)
ts = ts[:num_true_src_dst].to(self.device)
self.node_memory[nid.long()] = memory
self.node_memory_ts[nid.long()] = ts
def update_mailbox(self, nid, memory, root_nodes, ts, edge_feats, block, global_node_list = None, neg_samples=1):
with torch.no_grad():
num_true_edges = root_nodes.shape[0] // (neg_samples + 2)
memory = memory.to(self.device)
if edge_feats is not None:
edge_feats = edge_feats.to(self.device)
if block is not None:
block = block.to(self.device)
# TGN/JODIE
if self.memory_param['deliver_to'] == 'self':
src = torch.from_numpy(root_nodes[:num_true_edges]).to(self.device)
dst = torch.from_numpy(root_nodes[num_true_edges:num_true_edges * 2]).to(self.device)
mem_src = memory[:num_true_edges]
mem_dst = memory[num_true_edges:num_true_edges * 2]
if self.dim_edge_feat > 0:
src_mail = torch.cat([mem_src, mem_dst, edge_feats], dim=1)
dst_mail = torch.cat([mem_dst, mem_src, edge_feats], dim=1)
else:
src_mail = torch.cat([mem_src, mem_dst], dim=1)
dst_mail = torch.cat([mem_dst, mem_src], dim=1)
mail = torch.cat([src_mail, dst_mail], dim=1).reshape(-1, src_mail.shape[1])
nid = torch.cat([src.unsqueeze(1), dst.unsqueeze(1)], dim=1).reshape(-1)
mail_ts = torch.from_numpy(ts[:num_true_edges * 2]).to(self.device)
if mail_ts.dtype == torch.float64:
import pdb; pdb.set_trace()
# find unique nid to update mailbox
uni, inv = torch.unique(nid, return_inverse=True)
perm = torch.arange(inv.size(0), dtype=inv.dtype, device=inv.device)
perm = inv.new_empty(uni.size(0)).scatter_(0, inv, perm)
nid = nid[perm]
mail = mail[perm]
mail_ts = mail_ts[perm]
if self.memory_param['mail_combine'] == 'last':
self.mailbox[nid.long(), self.next_mail_pos[nid.long()]] = mail
self.mailbox_ts[nid.long(), self.next_mail_pos[nid.long()]] = mail_ts
if self.memory_param['mailbox_size'] > 1:
self.next_mail_pos[nid.long()] = torch.remainder(self.next_mail_pos[nid.long()] + 1, self.memory_param['mailbox_size'])
# APAN
elif self.memory_param['deliver_to'] == 'neighbors':
mem_src = memory[:num_true_edges]
mem_dst = memory[num_true_edges:num_true_edges * 2]
if self.dim_edge_feat > 0:
src_mail = torch.cat([mem_src, mem_dst, edge_feats], dim=1)
dst_mail = torch.cat([mem_dst, mem_src, edge_feats], dim=1)
else:
src_mail = torch.cat([mem_src, mem_dst], dim=1)
dst_mail = torch.cat([mem_dst, mem_src], dim=1)
mail = torch.cat([src_mail, dst_mail], dim=0)
mail = torch.cat([mail, mail[block.edges()[0].long()]], dim=0)
mail_ts = torch.from_numpy(ts[:num_true_edges * 2]).to(self.device)
mail_ts = torch.cat([mail_ts, mail_ts[block.edges()[0].long()]], dim=0)
if self.memory_param['mail_combine'] == 'mean':
(nid, idx) = torch.unique(block.dstdata['ID'], return_inverse=True)
mail = scatter(mail, idx, reduce='mean', dim=0)
mail_ts = scatter(mail_ts, idx, reduce='mean')
self.mailbox[nid.long(), self.next_mail_pos[nid.long()]] = mail
self.mailbox_ts[nid.long(), self.next_mail_pos[nid.long()]] = mail_ts
elif self.memory_param['mail_combine'] == 'last':
nid = block.dstdata['ID']
# find unique nid to update mailbox
uni, inv = torch.unique(nid, return_inverse=True)
perm = torch.arange(inv.size(0), dtype=inv.dtype, device=inv.device)
perm = inv.new_empty(uni.size(0)).scatter_(0, inv, perm)
nid = nid[perm]
mail = mail[perm]
mail_ts = mail_ts[perm]
self.mailbox[nid.long(), self.next_mail_pos[nid.long()]] = mail
self.mailbox_ts[nid.long(), self.next_mail_pos[nid.long()]] = mail_ts
else:
raise NotImplementedError
if self.memory_param['mailbox_size'] > 1:
if self.update_mail_pos is None:
self.next_mail_pos[nid.long()] = torch.remainder(self.next_mail_pos[nid.long()] + 1, self.memory_param['mailbox_size'])
else:
self.update_mail_pos[nid.long()] = 1
else:
raise NotImplementedError
def update_next_mail_pos(self):
if self.update_mail_pos is not None:
nid = torch.where(self.update_mail_pos == 1)[0]
self.next_mail_pos[nid] = torch.remainder(self.next_mail_pos[nid] + 1, self.memory_param['mailbox_size'])
self.update_mail_pos.fill_(0)
class GRUMemeoryUpdater(torch.nn.Module):
def __init__(self, memory_param, dim_in, dim_hid, dim_time, dim_node_feat):
super(GRUMemeoryUpdater, self).__init__()
self.dim_hid = dim_hid
self.dim_node_feat = dim_node_feat
self.memory_param = memory_param
self.dim_time = dim_time
self.updater = torch.nn.GRUCell(dim_in + dim_time, dim_hid)
self.last_updated_memory = None
self.last_updated_ts = None
self.last_updated_nid = None
if dim_time > 0:
self.time_enc = TimeEncode(dim_time)
if memory_param['combine_node_feature']:
if dim_node_feat > 0 and dim_node_feat != dim_hid:
self.node_feat_map = torch.nn.Linear(dim_node_feat, dim_hid)
def forward(self, mfg):
for b in mfg:
if self.dim_time > 0:
time_feat = self.time_enc(b.srcdata['ts'] - b.srcdata['mem_ts'])
b.srcdata['mem_input'] = torch.cat([b.srcdata['mem_input'], time_feat], dim=1)
updated_memory = self.updater(b.srcdata['mem_input'], b.srcdata['mem'])
self.last_updated_ts = b.srcdata['ts'].detach().clone()
self.last_updated_memory = updated_memory.detach().clone()
self.last_updated_nid = b.srcdata['ID'].detach().clone()
if self.memory_param['combine_node_feature']:
if self.dim_node_feat > 0:
if self.dim_node_feat == self.dim_hid:
b.srcdata['h'] += updated_memory
else:
b.srcdata['h'] = updated_memory + self.node_feat_map(b.srcdata['h'])
else:
b.srcdata['h'] = updated_memory
class RNNMemeoryUpdater(torch.nn.Module):
def __init__(self, memory_param, dim_in, dim_hid, dim_time, dim_node_feat):
super(RNNMemeoryUpdater, self).__init__()
self.dim_hid = dim_hid
self.dim_node_feat = dim_node_feat
self.memory_param = memory_param
self.dim_time = dim_time
self.updater = torch.nn.RNNCell(dim_in + dim_time, dim_hid)
self.last_updated_memory = None
self.last_updated_ts = None
self.last_updated_nid = None
if dim_time > 0:
self.time_enc = TimeEncode(dim_time)
if memory_param['combine_node_feature']:
if dim_node_feat > 0 and dim_node_feat != dim_hid:
self.node_feat_map = torch.nn.Linear(dim_node_feat, dim_hid)
def forward(self, mfg):
for b in mfg:
if self.dim_time > 0:
time_feat = self.time_enc(b.srcdata['ts'] - b.srcdata['mem_ts'])
b.srcdata['mem_input'] = torch.cat([b.srcdata['mem_input'], time_feat], dim=1)
updated_memory = self.updater(b.srcdata['mem_input'], b.srcdata['mem'])
self.last_updated_ts = b.srcdata['ts'].detach().clone()
self.last_updated_memory = updated_memory.detach().clone()
self.last_updated_nid = b.srcdata['ID'].detach().clone()
if self.memory_param['combine_node_feature']:
if self.dim_node_feat > 0:
if self.dim_node_feat == self.dim_hid:
b.srcdata['h'] += updated_memory
else:
b.srcdata['h'] = updated_memory + self.node_feat_map(b.srcdata['h'])
else:
b.srcdata['h'] = updated_memory
class TransformerMemoryUpdater(torch.nn.Module):
def __init__(self, memory_param, dim_in, dim_out, dim_time, train_param):
super(TransformerMemoryUpdater, self).__init__()
self.memory_param = memory_param
self.dim_time = dim_time
self.att_h = memory_param['attention_head']
if dim_time > 0:
self.time_enc = TimeEncode(dim_time)
self.w_q = torch.nn.Linear(dim_out, dim_out)
self.w_k = torch.nn.Linear(dim_in + dim_time, dim_out)
self.w_v = torch.nn.Linear(dim_in + dim_time, dim_out)
self.att_act = torch.nn.LeakyReLU(0.2)
self.layer_norm = torch.nn.LayerNorm(dim_out)
self.mlp = torch.nn.Linear(dim_out, dim_out)
self.dropout = torch.nn.Dropout(train_param['dropout'])
self.att_dropout = torch.nn.Dropout(train_param['att_dropout'])
self.last_updated_memory = None
self.last_updated_ts = None
self.last_updated_nid = None
def forward(self, mfg):
for b in mfg:
Q = self.w_q(b.srcdata['mem']).reshape((b.num_src_nodes(), self.att_h, -1))
mails = b.srcdata['mem_input'].reshape((b.num_src_nodes(), self.memory_param['mailbox_size'], -1))
if self.dim_time > 0:
time_feat = self.time_enc(b.srcdata['ts'][:, None] - b.srcdata['mail_ts']).reshape((b.num_src_nodes(), self.memory_param['mailbox_size'], -1))
mails = torch.cat([mails, time_feat], dim=2)
K = self.w_k(mails).reshape((b.num_src_nodes(), self.memory_param['mailbox_size'], self.att_h, -1))
V = self.w_v(mails).reshape((b.num_src_nodes(), self.memory_param['mailbox_size'], self.att_h, -1))
att = self.att_act((Q[:,None,:,:]*K).sum(dim=3))
att = torch.nn.functional.softmax(att, dim=1)
att = self.att_dropout(att)
rst = (att[:,:,:,None]*V).sum(dim=1)
rst = rst.reshape((rst.shape[0], -1))
rst += b.srcdata['mem']
rst = self.layer_norm(rst)
rst = self.mlp(rst)
rst = self.dropout(rst)
rst = torch.nn.functional.relu(rst)
b.srcdata['h'] = rst
self.last_updated_memory = rst.detach().clone()
self.last_updated_nid = b.srcdata['ID'].detach().clone()
self.last_updated_ts = b.srcdata['ts'].detach().clone()
import torch
import dgl
from os.path import abspath, join, dirname
import os
import sys
from os.path import abspath, join, dirname
sys.path.insert(0, join(abspath(dirname(__file__))))
from memorys import *
from layers import *
class GeneralModel(torch.nn.Module):
def __init__(self, dim_node, dim_edge, sample_param, memory_param, gnn_param, train_param, combined=False):
super(GeneralModel, self).__init__()
self.dim_node = dim_node
self.dim_node_input = dim_node
self.dim_edge = dim_edge
self.sample_param = sample_param
self.memory_param = memory_param
if not 'dim_out' in gnn_param:
gnn_param['dim_out'] = memory_param['dim_out']
self.gnn_param = gnn_param
self.train_param = train_param
if memory_param['type'] == 'node':
if memory_param['memory_update'] == 'gru':
self.memory_updater = GRUMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node)
elif memory_param['memory_update'] == 'rnn':
self.memory_updater = RNNMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node)
elif memory_param['memory_update'] == 'transformer':
self.memory_updater = TransformerMemoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], train_param)
else:
raise NotImplementedError
self.dim_node_input = memory_param['dim_out']
self.layers = torch.nn.ModuleDict()
if gnn_param['arch'] == 'transformer_attention':
for h in range(sample_param['history']):
self.layers['l0h' + str(h)] = TransfomerAttentionLayer(self.dim_node_input, dim_edge, gnn_param['dim_time'], gnn_param['att_head'], train_param['dropout'], train_param['att_dropout'], gnn_param['dim_out'], combined=combined)
for l in range(1, gnn_param['layer']):
for h in range(sample_param['history']):
self.layers['l' + str(l) + 'h' + str(h)] = TransfomerAttentionLayer(gnn_param['dim_out'], dim_edge, gnn_param['dim_time'], gnn_param['att_head'], train_param['dropout'], train_param['att_dropout'], gnn_param['dim_out'], combined=False)
elif gnn_param['arch'] == 'identity':
self.gnn_param['layer'] = 1
for h in range(sample_param['history']):
self.layers['l0h' + str(h)] = IdentityNormLayer(self.dim_node_input)
if 'time_transform' in gnn_param and gnn_param['time_transform'] == 'JODIE':
self.layers['l0h' + str(h) + 't'] = JODIETimeEmbedding(gnn_param['dim_out'])
else:
raise NotImplementedError
self.edge_predictor = EdgePredictor(gnn_param['dim_out'])
if 'combine' in gnn_param and gnn_param['combine'] == 'rnn':
self.combiner = torch.nn.RNN(gnn_param['dim_out'], gnn_param['dim_out'])
def forward(self, mfgs, metadata = None,neg_samples=1):
if self.memory_param['type'] == 'node':
self.memory_updater(mfgs[0])
out = list()
for l in range(self.gnn_param['layer']):
for h in range(self.sample_param['history']):
rst = self.layers['l' + str(l) + 'h' + str(h)](mfgs[l][h])
if 'time_transform' in self.gnn_param and self.gnn_param['time_transform'] == 'JODIE':
rst = self.layers['l0h' + str(h) + 't'](rst, mfgs[l][h].srcdata['mem_ts'], mfgs[l][h].srcdata['ts'])
if l != self.gnn_param['layer'] - 1:
mfgs[l + 1][h].srcdata['h'] = rst
else:
out.append(rst)
if self.sample_param['history'] == 1:
out = out[0]
else:
out = torch.stack(out, dim=0)
out = self.combiner(out)[0][-1, :, :]
#metadata需要在前面去重的时候记一下id
if metadata is not None:
out = torch.cat((out[metadata['src_id_pos']],out[metadata['dst_pos_pos']],out[metadata['dst_neg_pos']]),0)
return self.edge_predictor(out, neg_samples=neg_samples)
def get_emb(self, mfgs):
if self.memory_param['type'] == 'node':
self.memory_updater(mfgs[0])
out = list()
for l in range(self.gnn_param['layer']):
for h in range(self.sample_param['history']):
rst = self.layers['l' + str(l) + 'h' + str(h)](mfgs[l][h])
if 'time_transform' in self.gnn_param and self.gnn_param['time_transform'] == 'JODIE':
rst = self.layers['l0h' + str(h) + 't'](rst, mfgs[l][h].srcdata['mem_ts'], mfgs[l][h].srcdata['ts'])
if l != self.gnn_param['layer'] - 1:
mfgs[l + 1][h].srcdata['h'] = rst
else:
out.append(rst)
if self.sample_param['history'] == 1:
out = out[0]
else:
out = torch.stack(out, dim=0)
out = self.combiner(out)[0][-1, :, :]
return out
class NodeClassificationModel(torch.nn.Module):
def __init__(self, dim_in, dim_hid, num_class):
super(NodeClassificationModel, self).__init__()
self.fc1 = torch.nn.Linear(dim_in, dim_hid)
self.fc2 = torch.nn.Linear(dim_hid, num_class)
def forward(self, x):
x = self.fc1(x)
x = torch.nn.functional.relu(x)
x = self.fc2(x)
return x
\ No newline at end of file
sudo mount -o remount,size=600G /dev/shm
\ No newline at end of file
alabaster==0.7.13
antlr4-python3-runtime==4.9.3
anyio @ file:///home/conda/feedstock_root/build_artifacts/anyio_1666191106763/work/dist
appdirs==1.4.4
argon2-cffi @ file:///home/conda/feedstock_root/build_artifacts/argon2-cffi_1640817743617/work
argon2-cffi-bindings @ file:///home/conda/feedstock_root/build_artifacts/argon2-cffi-bindings_1649500309442/work
asttokens==2.2.1
attrs @ file:///home/conda/feedstock_root/build_artifacts/attrs_1671632566681/work
autopep8==2.0.2
Babel==2.12.1
backcall==0.2.0
backports.functools-lru-cache @ file:///home/conda/feedstock_root/build_artifacts/backports.functools_lru_cache_1618230623929/work
beautifulsoup4 @ file:///home/conda/feedstock_root/build_artifacts/beautifulsoup4_1675252249248/work
bleach @ file:///home/conda/feedstock_root/build_artifacts/bleach_1674535352125/work
brotlipy @ file:///home/conda/feedstock_root/build_artifacts/brotlipy_1648854175163/work
certifi==2022.12.7
cffi @ file:///home/conda/feedstock_root/build_artifacts/cffi_1636046063618/work
charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1661170624537/work
click==8.1.3
comm==0.1.1
contourpy==1.0.7
cryptography @ file:///home/conda/feedstock_root/build_artifacts/cryptography_1652967113783/work
cycler==0.11.0
debugpy==1.6.4
decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1641555617451/work
defusedxml @ file:///home/conda/feedstock_root/build_artifacts/defusedxml_1615232257335/work
dgl==1.0.2
dglgo==0.0.2
docutils==0.19
entrypoints @ file:///home/conda/feedstock_root/build_artifacts/entrypoints_1643888246732/work
exceptiongroup==1.1.1
executing==1.2.0
fastgraph==0.2.0
fastjsonschema @ file:///home/conda/feedstock_root/build_artifacts/python-fastjsonschema_1663619548554/work/dist
flameprof==0.4
flit_core @ file:///home/conda/feedstock_root/build_artifacts/flit-core_1667734568827/work/source/flit_core
fonttools==4.39.3
hydra-core==1.3.2
idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1663625384323/work
imagesize==1.4.1
importlib-metadata @ file:///home/conda/feedstock_root/build_artifacts/importlib-metadata_1672612343532/work
importlib-resources @ file:///home/conda/feedstock_root/build_artifacts/importlib_resources_1672681417544/work
iniconfig==2.0.0
ipykernel==6.19.0
ipython==8.7.0
ipython-genutils==0.2.0
ipywidgets @ file:///home/conda/feedstock_root/build_artifacts/ipywidgets_1671720089366/work
isort==5.12.0
jedi==0.18.2
Jinja2 @ file:///croot/jinja2_1666908132255/work
joblib==1.2.0
jsonschema @ file:///home/conda/feedstock_root/build_artifacts/jsonschema-meta_1669810440410/work
jupyter-server @ file:///home/conda/feedstock_root/build_artifacts/jupyter_server_1673574555503/work
jupyter_client==7.4.8
jupyter_core==5.1.0
jupyterlab-pygments @ file:///home/conda/feedstock_root/build_artifacts/jupyterlab_pygments_1649936611996/work
jupyterlab-widgets @ file:///home/conda/feedstock_root/build_artifacts/jupyterlab_widgets_1671722028097/work
kiwisolver==1.4.4
littleutils==0.2.2
lltm-cpp==0.0.0
markdown-it-py==2.2.0
MarkupSafe @ file:///opt/conda/conda-bld/markupsafe_1654597864307/work
matmul-cpu-pytorch==0.0.0
matmul-gpu==0.0.0
matplotlib==3.7.1
matplotlib-inline==0.1.6
mdurl==0.1.2
memray==1.7.0
mistune @ file:///home/conda/feedstock_root/build_artifacts/mistune_1675771498296/work
mkl-fft==1.3.1
mkl-random==1.2.2
mkl-service==2.4.0
nbclassic @ file:///home/conda/feedstock_root/build_artifacts/nbclassic_1675369808718/work
nbclient @ file:///home/conda/feedstock_root/build_artifacts/nbclient_1669795076334/work
nbconvert @ file:///home/conda/feedstock_root/build_artifacts/nbconvert-meta_1674590374792/work
nbformat @ file:///home/conda/feedstock_root/build_artifacts/nbformat_1673560067442/work
nest-asyncio==1.5.6
networkx @ file:///opt/conda/conda-bld/networkx_1657784097507/work
notebook @ file:///home/conda/feedstock_root/build_artifacts/notebook_1667565639349/work
notebook_shim @ file:///home/conda/feedstock_root/build_artifacts/notebook-shim_1667478401171/work
numpy @ file:///croot/numpy_and_numpy_base_1672336185480/work
numpydoc==1.5.0
ogb==1.3.5
olefile @ file:///home/conda/feedstock_root/build_artifacts/olefile_1602866521163/work
omegaconf==2.3.0
outdated==0.2.2
packaging @ file:///croot/packaging_1671697413597/work
pandas==1.5.3
pandocfilters @ file:///home/conda/feedstock_root/build_artifacts/pandocfilters_1631603243851/work
parso==0.8.3
pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1667297516076/work
pickleshare==0.7.5
Pillow @ file:///home/conda/feedstock_root/build_artifacts/pillow_1621287973719/work
pkgutil_resolve_name @ file:///home/conda/feedstock_root/build_artifacts/pkgutil-resolve-name_1633981968097/work
platformdirs==2.6.0
pluggy==1.0.0
pooch @ file:///tmp/build/80754af9/pooch_1623324770023/work
presample-cores==0.0.0
prometheus-client @ file:///home/conda/feedstock_root/build_artifacts/prometheus_client_1674535637125/work
prompt-toolkit==3.0.36
psutil==5.9.4
ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
pure-eval==0.2.2
pybind11==2.10.4
pycodestyle==2.10.0
pycparser @ file:///home/conda/feedstock_root/build_artifacts/pycparser_1636257122734/work
pydantic==1.10.7
Pygments==2.13.0
pyOpenSSL @ file:///home/conda/feedstock_root/build_artifacts/pyopenssl_1663846997386/work
pyparsing @ file:///opt/conda/conda-bld/pyparsing_1661452539315/work
pyrsistent @ file:///tmp/build/80754af9/pyrsistent_1636110947380/work
PySocks @ file:///home/conda/feedstock_root/build_artifacts/pysocks_1661604839144/work
pytest==7.2.2
python-dateutil==2.8.2
pytz==2022.7.1
PyYAML==6.0
pyzmq==24.0.1
rdkit-pypi==2022.9.5
requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1673863902341/work
rich==13.4.1
ruamel.yaml==0.17.21
ruamel.yaml.clib==0.2.7
sample-cores==0.0.0
scikit-learn @ file:///croot/scikit-learn_1673957209982/work
scipy==1.10.0
Send2Trash @ file:///home/conda/feedstock_root/build_artifacts/send2trash_1628511208346/work
six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work
sniffio @ file:///home/conda/feedstock_root/build_artifacts/sniffio_1662051266223/work
snowballstemmer==2.2.0
soupsieve @ file:///home/conda/feedstock_root/build_artifacts/soupsieve_1658207591808/work
Sphinx==6.1.3
sphinxcontrib-applehelp==1.0.4
sphinxcontrib-devhelp==1.0.2
sphinxcontrib-htmlhelp==2.0.1
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==1.0.3
sphinxcontrib-serializinghtml==1.1.5
stack-data==0.6.2
terminado @ file:///home/conda/feedstock_root/build_artifacts/terminado_1670253674810/work
threadpoolctl==3.1.0
tinycss2 @ file:///home/conda/feedstock_root/build_artifacts/tinycss2_1666100256010/work
tomli==2.0.1
torch==1.12.1
torch-cluster @ file:///usr/share/miniconda/envs/test/conda-bld/pytorch-cluster_1656604322802/work
torch-geometric @ file:///usr/share/miniconda/envs/test/conda-bld/pyg_1669879221143/work
torch-geometric-autoscale==0.0.0
torch-quiver==0.1.1
torch-scatter @ file:///usr/share/miniconda/envs/test/conda-bld/pytorch-scatter_1669298837499/work
torch-sparse @ file:///usr/share/miniconda/envs/test/conda-bld/pytorch-sparse_1671714918898/work
torchaudio==0.12.1
torchvision==0.13.1
tornado==6.2
tqdm==4.64.1
traitlets==5.6.0
typer==0.7.0
typing_extensions==4.4.0
urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1673452138552/work
wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1673864653149/work
webencodings==0.5.1
websocket-client @ file:///home/conda/feedstock_root/build_artifacts/websocket-client_1675567828044/work
widgetsnbextension @ file:///home/conda/feedstock_root/build_artifacts/widgetsnbextension_1672066693230/work
zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1675982654259/work
import argparse
import yaml
import torch
import time
import numpy as np
import pandas as pd
from tqdm import tqdm
from sampler_core import ParallelSampler, TemporalGraphBlock
class NegLinkSampler:
def __init__(self, num_nodes):
self.num_nodes = num_nodes
def sample(self, n):
return np.random.randint(self.num_nodes, size=n)
if __name__ == '__main__':
parser=argparse.ArgumentParser()
parser.add_argument('--data', type=str, help='dataset name')
parser.add_argument('--config', type=str, help='path to config file')
parser.add_argument('--batch_size', type=int, default=600, help='path to config file')
parser.add_argument('--num_thread', type=int, default=64, help='number of thread')
args=parser.parse_args()
df = pd.read_csv('DATA/{}/edges.csv'.format(args.data))
g = np.load('DATA/{}/ext_full.npz'.format(args.data))
sample_config = yaml.safe_load(open(args.config, 'r'))['sampling'][0]
# sampling:
# - layer: 2
# neighbor:
# - 10
# - 10
# strategy: 'uniform'
# prop_time: True
# history: 3
# duration: 5
# num_thread: 64
#
# strategy有两种一种是过去邻居均匀采样:'uniform'
# 一种是最近时间采样
sampler = ParallelSampler(g['indptr'], g['indices'], g['eid'], g['ts'].astype(np.float32),
args.num_thread, 1, sample_config['layer'], sample_config['neighbor'],
sample_config['strategy']=='recent', sample_config['prop_time'],
sample_config['history'], float(sample_config['duration']))
num_nodes = max(int(df['src'].max()), int(df['dst'].max()))
neg_link_sampler = NegLinkSampler(num_nodes)
tot_time = 0
ptr_time = 0
coo_time = 0
sea_time = 0
sam_time = 0
uni_time = 0
total_nodes = 0
unique_nodes = 0
for _, rows in tqdm(df.groupby(df.index // args.batch_size), total=len(df) // args.batch_size):
root_nodes = np.concatenate([rows.src.values, rows.dst.values, neg_link_sampler.sample(len(rows))]).astype(np.int32)
ts = np.concatenate([rows.time.values, rows.time.values, rows.time.values]).astype(np.float32)
sampler.sample(root_nodes, ts)
ret = sampler.get_ret()
tot_time += ret[0].tot_time()
ptr_time += ret[0].ptr_time()
coo_time += ret[0].coo_time()
sea_time += ret[0].search_time()
sam_time += ret[0].sample_time()
# for i in range(sample_config['history']):
# total_nodes += ret[i].dim_in() - ret[i].dim_out()
# unique_nodes += ret[i].dim_in() - ret[i].dim_out()
# if ret[i].dim_in() > ret[i].dim_out():
# ts = torch.from_numpy(ret[i].ts()[ret[i].dim_out():])
# nid = torch.from_numpy(ret[i].nodes()[ret[i].dim_out():]).float()
# nts = torch.stack([ts,nid],dim=1).cuda()
# uni_t_s = time.time()
# unts, idx = torch.unique(nts, dim=0, return_inverse=True)
# uni_time += time.time() - uni_t_s
# total_nodes += idx.shape[0]
# unique_nodes += unts.shape[0]
print('total time : {:.4f}'.format(tot_time))
print('pointer time: {:.4f}'.format(ptr_time))
print('coo time : {:.4f}'.format(coo_time))
print('search time : {:.4f}'.format(sea_time))
print('sample time : {:.4f}'.format(sam_time))
# print('unique time : {:.4f}'.format(uni_time))
# print('unique per : {:.4f}'.format(unique_nodes / total_nodes))
\ No newline at end of file
#include <iostream>
#include <string>
#include <cstdlib>
#include <random>
#include <omp.h>
#include <math.h>
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
namespace py = pybind11;
typedef int NodeIDType;
typedef int EdgeIDType;
typedef float TimeStampType;
class TemporalGraphBlock
{
public:
std::vector<NodeIDType> row;
std::vector<NodeIDType> col;
std::vector<EdgeIDType> eid;
std::vector<TimeStampType> ts;
std::vector<TimeStampType> dts;
std::vector<NodeIDType> nodes;
NodeIDType dim_in, dim_out;
double ptr_time = 0;
double search_time = 0;
double sample_time = 0;
double tot_time = 0;
double coo_time = 0;
TemporalGraphBlock(){}
TemporalGraphBlock(std::vector<NodeIDType> &_row, std::vector<NodeIDType> &_col,
std::vector<EdgeIDType> &_eid, std::vector<TimeStampType> &_ts,
std::vector<TimeStampType> &_dts, std::vector<NodeIDType> &_nodes,
NodeIDType _dim_in, NodeIDType _dim_out) :
row(_row), col(_col), eid(_eid), ts(_ts), dts(_dts),
nodes(_nodes), dim_in(_dim_in), dim_out(_dim_out) {}
};
class ParallelSampler
{
public:
std::vector<EdgeIDType> indptr;
std::vector<EdgeIDType> indices;
std::vector<EdgeIDType> eid;
std::vector<TimeStampType> ts;
NodeIDType num_nodes;
EdgeIDType num_edges;
int num_thread_per_worker;
int num_workers;
int num_threads;
int num_layers;
std::vector<int> num_neighbors;
bool recent;
bool prop_time;
int num_history;
TimeStampType window_duration;
std::vector<std::vector<std::vector<EdgeIDType>::size_type>> ts_ptr;
omp_lock_t *ts_ptr_lock;
std::vector<TemporalGraphBlock> ret;
ParallelSampler(std::vector<EdgeIDType> &_indptr, std::vector<EdgeIDType> &_indices,
std::vector<EdgeIDType> &_eid, std::vector<TimeStampType> &_ts,
int _num_thread_per_worker, int _num_workers, int _num_layers,
std::vector<int> &_num_neighbors, bool _recent, bool _prop_time,
int _num_history, TimeStampType _window_duration) :
indptr(_indptr), indices(_indices), eid(_eid), ts(_ts), prop_time(_prop_time),
num_thread_per_worker(_num_thread_per_worker), num_workers(_num_workers),
num_layers(_num_layers), num_neighbors(_num_neighbors), recent(_recent),
num_history(_num_history), window_duration(_window_duration)
/*
sampling:
- layer: <number of layers to sample>
neighbor: <a list of integers indicating how many neighbors are sampled in each layer>
strategy: <'recent' that samples most recent neighbors or 'uniform' that uniformly samples neighbors form the past>
prop_time: <False or True that specifies wherether to use the timestamp of the root nodes when sampling for their multi-hop neighbors>
history: <number of snapshots to sample on>
duration: <length in time of each snapshot, 0 for infinite length (used in non-snapshot-based methods)
num_thread: <number of threads of the sampler>
*/
{
omp_set_num_threads(num_thread_per_worker * num_workers);
num_threads = num_thread_per_worker * num_workers;
num_nodes = indptr.size() - 1;
num_edges = indices.size();
ts_ptr_lock = (omp_lock_t *)malloc(num_nodes * sizeof(omp_lock_t));
for (int i = 0; i < num_nodes; i++)
omp_init_lock(&ts_ptr_lock[i]);
//初值为indptr
// ts_ptr 指向一个二维vector[[]]
// 外层的每个vector对应一个快照
ts_ptr.resize(num_history + 1);
// 对于每个快照都对应一个t-csr的结构中的indpt结构
// 对应的是每个节点按时间顺序的出边的dst id
// 初始化每个快照的大小都是未切分全图的indptr
for (auto it = ts_ptr.begin(); it != ts_ptr.end(); it++)
{
// 每个节点最多总结点个数-1条边
it->resize(indptr.size() - 1);
#pragma omp parallel for
for (auto itt = indptr.begin(); itt < indptr.end() - 1; itt++)
(*it)[itt - indptr.begin()] = *itt;
}
}
// 复原每个快照为全图
void reset()
{
for (auto it = ts_ptr.begin(); it != ts_ptr.end(); it++)
{
it->resize(indptr.size() - 1);
#pragma omp parallel for
for (auto itt = indptr.begin(); itt < indptr.end() - 1; itt++)
(*it)[itt - indptr.begin()] = *itt;
}
}
//num_history root_nodes root_ts offset?
// 对于第slc个快照以及确定的root_nodes进行采样
void update_ts_ptr(int slc, std::vector<NodeIDType> &root_nodes,
std::vector<TimeStampType> &root_ts, float offset)
{
#pragma omp parallel for schedule(static, int(ceil(static_cast<float>(root_nodes.size()) / num_threads)))
for (std::vector<NodeIDType>::size_type i = 0; i < root_nodes.size(); i++)
{
//获取root节点
NodeIDType n = root_nodes[i];
omp_set_lock(&(ts_ptr_lock[n]));
//从slc开始出发,依次遍历边表里的每条边
// 根据ts_ptr和indptr找到time的索引
for (std::vector<EdgeIDType>::size_type j = ts_ptr[slc][n]; j < indptr[n + 1]; j++)
{
// std::cout << "comparing " << ts[j] << " with " << root_ts[i] << std::endl;
if (ts[j] > (root_ts[i] + offset - 1e-7f))
{
if (j != ts_ptr[slc][n])
ts_ptr[slc][n] = j - 1;
break;
}
if (j == indptr[n + 1] - 1)
{
ts_ptr[slc][n] = j;
}
}
omp_unset_lock(&(ts_ptr_lock[n]));
}
}
inline void add_neighbor(std::vector<NodeIDType> *_row, std::vector<NodeIDType> *_col,
std::vector<EdgeIDType> *_eid, std::vector<TimeStampType> *_ts,
std::vector<TimeStampType> *_dts, std::vector<NodeIDType> *_nodes,
EdgeIDType &k, TimeStampType &src_ts, int &row_id)
{
_row->push_back(row_id);
_col->push_back(_nodes->size());
_eid->push_back(eid[k]);
if (prop_time)
_ts->push_back(src_ts);
else
_ts->push_back(ts[k]);
_dts->push_back(src_ts - ts[k]);
_nodes->push_back(indices[k]);
// _row.push_back(0);
// _col.push_back(0);
// _eid.push_back(0);
// if (prop_time)
// _ts.push_back(src_ts);
// else
// _ts.push_back(10000);
// _nodes.push_back(100);
}
inline void combine_coo(TemporalGraphBlock &_ret, std::vector<NodeIDType> **_row,
std::vector<NodeIDType> **_col,
std::vector<EdgeIDType> **_eid,
std::vector<TimeStampType> **_ts,
std::vector<TimeStampType> **_dts,
std::vector<NodeIDType> **_nodes,
std::vector<int> &_out_nodes)
{
std::vector<EdgeIDType> cum_row, cum_col;
cum_row.push_back(0);
cum_col.push_back(0);
for (int tid = 0; tid < num_threads; tid++)
{
// std::cout<<tid<<" here "<<_out_nodes[tid]<<std::endl;
cum_row.push_back(cum_row.back() + _out_nodes[tid]);
cum_col.push_back(cum_col.back() + _col[tid]->size());
}
int num_root_nodes = _ret.nodes.size();
_ret.row.resize(cum_col.back());
_ret.col.resize(cum_col.back());
_ret.eid.resize(cum_col.back());
_ret.ts.resize(cum_col.back() + num_root_nodes);
_ret.dts.resize(cum_col.back() + num_root_nodes);
_ret.nodes.resize(cum_col.back() + num_root_nodes);
#pragma omp parallel for schedule(static, 1)
for (int tid = 0; tid < num_threads; tid++)
{
std::transform(_row[tid]->begin(), _row[tid]->end(), _row[tid]->begin(),
[&](auto &v){ return v + cum_row[tid]; });
std::transform(_col[tid]->begin(), _col[tid]->end(), _col[tid]->begin(),
[&](auto &v){ return v + cum_col[tid] + num_root_nodes; });
std::copy(_row[tid]->begin(), _row[tid]->end(), _ret.row.begin() + cum_col[tid]);
std::copy(_col[tid]->begin(), _col[tid]->end(), _ret.col.begin() + cum_col[tid]);
std::copy(_eid[tid]->begin(), _eid[tid]->end(), _ret.eid.begin() + cum_col[tid]);
std::copy(_ts[tid]->begin(), _ts[tid]->end(), _ret.ts.begin() + cum_col[tid] + num_root_nodes);
std::copy(_dts[tid]->begin(), _dts[tid]->end(), _ret.dts.begin() + cum_col[tid] + num_root_nodes);
std::copy(_nodes[tid]->begin(), _nodes[tid]->end(), _ret.nodes.begin() + cum_col[tid] + num_root_nodes);
delete _row[tid];
delete _col[tid];
delete _eid[tid];
delete _ts[tid];
delete _dts[tid];
delete _nodes[tid];
}
_ret.dim_in = _ret.nodes.size();
_ret.dim_out = cum_row.back();
}
void sample_layer(std::vector<NodeIDType> &_root_nodes, std::vector<TimeStampType> &_root_ts,
int neighs, bool use_ptr, bool from_root)
{
double t_s = omp_get_wtime();
std::vector<NodeIDType> *root_nodes;
std::vector<TimeStampType> *root_ts;
if (from_root)
{
root_nodes = &_root_nodes;
root_ts = &_root_ts;
}
double t_ptr_s = omp_get_wtime();
if (use_ptr)
update_ts_ptr(num_history, *root_nodes, *root_ts, 0);
ret[0].ptr_time += omp_get_wtime() - t_ptr_s;
for (int i = 0; i < num_history; i++)
{
if (!from_root)
{
root_nodes = &(ret[ret.size() - 1 - i - num_history].nodes);
root_ts = &(ret[ret.size() - 1 - i - num_history].ts);
}
TimeStampType offset = -i * window_duration;
t_ptr_s = omp_get_wtime();
if ((use_ptr) && (std::abs(window_duration) > 1e-7f))
update_ts_ptr(num_history - 1 - i, *root_nodes, *root_ts, offset - window_duration);
ret[0].ptr_time += omp_get_wtime() - t_ptr_s;
std::vector<NodeIDType> *_row[num_threads];
std::vector<NodeIDType> *_col[num_threads];
std::vector<EdgeIDType> *_eid[num_threads];
std::vector<TimeStampType> *_ts[num_threads];
std::vector<TimeStampType> *_dts[num_threads];
std::vector<NodeIDType> *_nodes[num_threads];
std::vector<int> _out_node(num_threads, 0);
int reserve_capacity = int(ceil((*root_nodes).size() / num_threads)) * neighs;
#pragma omp parallel
{
int tid = omp_get_thread_num();
unsigned int loc_seed = tid;
_row[tid] = new std::vector<NodeIDType>;
_col[tid] = new std::vector<NodeIDType>;
_eid[tid] = new std::vector<EdgeIDType>;
_ts[tid] = new std::vector<TimeStampType>;
_dts[tid] = new std::vector<TimeStampType>;
_nodes[tid] = new std::vector<NodeIDType>;
_row[tid]->reserve(reserve_capacity);
_col[tid]->reserve(reserve_capacity);
_eid[tid]->reserve(reserve_capacity);
_ts[tid]->reserve(reserve_capacity);
_dts[tid]->reserve(reserve_capacity);
_nodes[tid]->reserve(reserve_capacity);
// #pragma omp critical
// std::cout<<tid<<" sampling: "<<root_nodes->size()<<" "<<int(ceil((*root_nodes).size() / num_threads))<<std::endl;
#pragma omp for schedule(static, int(ceil(static_cast<float>((*root_nodes).size()) / num_threads)))
for (std::vector<NodeIDType>::size_type j = 0; j < (*root_nodes).size(); j++)
{
NodeIDType n = (*root_nodes)[j];
// if (tid == 16)
// std::cout << _out_node[tid] << " " <<j << " " << n << std::endl;
TimeStampType nts = (*root_ts)[j];
EdgeIDType s_search, e_search;
if (use_ptr)
{
s_search = ts_ptr[num_history - 1 - i][n];
e_search = ts_ptr[num_history - i][n];
}
else
{
// search for start and end pointer
double t_search_s = omp_get_wtime();
if (num_history == 1)
{
// TGAT style
s_search = indptr[n];
auto e_it = std::upper_bound(ts.begin() + indptr[n],
ts.begin() + indptr[n + 1], nts);
e_search = std::max(int(e_it - ts.begin()) - 1, s_search);
}
else
{
// DySAT style
auto s_it = std::upper_bound(ts.begin() + indptr[n],
ts.begin() + indptr[n + 1],
nts + offset - window_duration);
s_search = std::max(int(s_it - ts.begin()) - 1, indptr[n]);
auto e_it = std::upper_bound(ts.begin() + indptr[n],
ts.begin() + indptr[n + 1], nts + offset);
e_search = std::max(int(e_it - ts.begin()) - 1, s_search);
}
if (tid == 0)
ret[0].search_time += omp_get_wtime() - t_search_s;
}
// std::cout << n << " " << s_search << " " << e_search << std::endl;
double t_sample_s = omp_get_wtime();
if ((recent) || (e_search - s_search < neighs))
{
// no sampling, pick recent neighbors
for (EdgeIDType k = e_search; k > std::max(s_search, e_search - neighs); k--)
{
//采样目标是下标k
if (ts[k] < nts + offset - 1e-7f)
{
add_neighbor(_row[tid], _col[tid], _eid[tid], _ts[tid],
_dts[tid], _nodes[tid], k, nts, _out_node[tid]);
}
}
}
else
{
// random sampling within ptr
for (int _i = 0; _i < neighs; _i++)
{
//采样目标picked下标,把k对应的数据加进去,nts是root的时间戳,_out_node
EdgeIDType picked = s_search + rand_r(&loc_seed) % (e_search - s_search + 1);
if (ts[picked] < nts + offset - 1e-7f)
{
add_neighbor(_row[tid], _col[tid], _eid[tid], _ts[tid],
_dts[tid], _nodes[tid], picked, nts, _out_node[tid]);
}
}
}
_out_node[tid] += 1;
if (tid == 0)
ret[0].sample_time += omp_get_wtime() - t_sample_s;
}
}
double t_coo_s = omp_get_wtime();
ret[ret.size() - 1 - i].ts.insert(ret[ret.size() - 1 - i].ts.end(),
root_ts->begin(), root_ts->end());
ret[ret.size() - 1 - i].nodes.insert(ret[ret.size() - 1 - i].nodes.end(),
root_nodes->begin(), root_nodes->end());
ret[ret.size() - 1 - i].dts.resize(root_nodes->size());
combine_coo(ret[ret.size() - 1 - i], _row, _col, _eid, _ts, _dts, _nodes, _out_node);
ret[0].coo_time += omp_get_wtime() - t_coo_s;
}
ret[0].tot_time += omp_get_wtime() - t_s;
}
void sample(std::vector<NodeIDType> &root_nodes, std::vector<TimeStampType> &root_ts)
{
// a weird bug, dgl library seems to modify the total number of threads
omp_set_num_threads(num_threads);
ret.resize(0);
bool first_layer = true;
bool use_ptr = false;
for (int i = 0; i < num_layers; i++)
{
ret.resize(ret.size() + num_history);
if ((first_layer) || ((prop_time) && num_history == 1) || (recent))
{
first_layer = false;
use_ptr = true;
}
else
use_ptr = false;
if (i==0)
sample_layer(root_nodes, root_ts, num_neighbors[i], use_ptr, true);
else
sample_layer(root_nodes, root_ts, num_neighbors[i], use_ptr, false);
}
}
};
template<typename T>
inline py::array vec2npy(const std::vector<T> &vec)
{
// need to let python garbage collector handle C++ vector memory
// see https://github.com/pybind/pybind11/issues/1042
auto v = new std::vector<T>(vec);
auto capsule = py::capsule(v, [](void *v)
{ delete reinterpret_cast<std::vector<T> *>(v); });
return py::array(v->size(), v->data(), capsule);
// return py::array(vec.size(), vec.data());
}
PYBIND11_MODULE(sampler_core, m)
{
py::class_<TemporalGraphBlock>(m, "TemporalGraphBlock")
.def(py::init<std::vector<NodeIDType> &, std::vector<NodeIDType> &,
std::vector<EdgeIDType> &, std::vector<TimeStampType> &,
std::vector<TimeStampType> &, std::vector<NodeIDType> &,
NodeIDType, NodeIDType>())
.def("row", [](const TemporalGraphBlock &tgb) { return vec2npy(tgb.row); })
.def("col", [](const TemporalGraphBlock &tgb) { return vec2npy(tgb.col); })
.def("eid", [](const TemporalGraphBlock &tgb) { return vec2npy(tgb.eid); })
.def("ts", [](const TemporalGraphBlock &tgb) { return vec2npy(tgb.ts); })
.def("dts", [](const TemporalGraphBlock &tgb) { return vec2npy(tgb.dts); })
.def("nodes", [](const TemporalGraphBlock &tgb) { return vec2npy(tgb.nodes); })
.def("dim_in", [](const TemporalGraphBlock &tgb) { return tgb.dim_in; })
.def("dim_out", [](const TemporalGraphBlock &tgb) { return tgb.dim_out; })
.def("tot_time", [](const TemporalGraphBlock &tgb) { return tgb.tot_time; })
.def("ptr_time", [](const TemporalGraphBlock &tgb) { return tgb.ptr_time; })
.def("search_time", [](const TemporalGraphBlock &tgb) { return tgb.search_time; })
.def("sample_time", [](const TemporalGraphBlock &tgb) { return tgb.sample_time; })
.def("coo_time", [](const TemporalGraphBlock &tgb) { return tgb.coo_time; });
py::class_<ParallelSampler>(m, "ParallelSampler")
.def(py::init<std::vector<EdgeIDType> &, std::vector<EdgeIDType> &,
std::vector<EdgeIDType> &, std::vector<TimeStampType> &,
int, int, int, std::vector<int> &, bool, bool,
int, TimeStampType>())
.def("sample", &ParallelSampler::sample)
.def("reset", &ParallelSampler::reset)
.def("get_ret", [](const ParallelSampler &ps) { return ps.ret; });
}
\ No newline at end of file
from glob import glob
from setuptools import setup
from pybind11.setup_helpers import Pybind11Extension
ext_modules = [
Pybind11Extension("sampler_core",
['sampler_core.cpp'],
extra_compile_args = ['-fopenmp'],
extra_link_args = ['-fopenmp'],),
]
setup(
name = "sampler_core",
version = "0.0.1",
author = "Hongkuan Zhou",
author_email = "hongkuaz@usc.edu",
url = "https://tedzhouhk.github.io/about/",
description = "Parallel Sampling for Temporal Graphs",
ext_modules = ext_modules,
)
\ No newline at end of file
import argparse
import os
parser=argparse.ArgumentParser()
parser.add_argument('--data', type=str, help='dataset name')
parser.add_argument('--config', type=str, help='path to config file')
parser.add_argument('--gpu', type=str, default='0', help='which GPU to use')
parser.add_argument('--model_name', type=str, default='', help='name of stored model')
parser.add_argument('--rand_edge_features', type=int, default=0, help='use random edge featrues')
parser.add_argument('--rand_node_features', type=int, default=0, help='use random node featrues')
parser.add_argument('--eval_neg_samples', type=int, default=1, help='how many negative samples to use at inference. Note: this will change the metric of test set to AP+AUC to AP+MRR!')
args=parser.parse_args()
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
import torch
import time
import random
import dgl
import numpy as np
from modules import *
from sampler import *
from utils import *
from sklearn.metrics import average_precision_score, roc_auc_score
def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# set_seed(0)
node_feats, edge_feats = load_feat(args.data, args.rand_edge_features, args.rand_node_features)
g, df = load_graph(args.data)
sample_param, memory_param, gnn_param, train_param = parse_config(args.config)
train_edge_end = df[df['ext_roll'].gt(0)].index[0]
val_edge_end = df[df['ext_roll'].gt(1)].index[0]
gnn_dim_node = 0 if node_feats is None else node_feats.shape[1]
gnn_dim_edge = 0 if edge_feats is None else edge_feats.shape[1]
combine_first = False
if 'combine_neighs' in train_param and train_param['combine_neighs']:
combine_first = True
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param, combined=combine_first).cuda()
mailbox = MailBox(memory_param, g['indptr'].shape[0] - 1, gnn_dim_edge) if memory_param['type'] != 'none' else None
creterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'])
if 'all_on_gpu' in train_param and train_param['all_on_gpu']:
if node_feats is not None:
node_feats = node_feats.cuda()
if edge_feats is not None:
edge_feats = edge_feats.cuda()
if mailbox is not None:
mailbox.move_to_gpu()
sampler = None
if not ('no_sample' in sample_param and sample_param['no_sample']):
sampler = ParallelSampler(g['indptr'], g['indices'], g['eid'], g['ts'].astype(np.float32),
sample_param['num_thread'], 1, sample_param['layer'], sample_param['neighbor'],
sample_param['strategy']=='recent', sample_param['prop_time'],
sample_param['history'], float(sample_param['duration']))
neg_link_sampler = NegLinkSampler(g['indptr'].shape[0] - 1)
def eval(mode='val'):
neg_samples = 1
model.eval()
aps = list()
aucs_mrrs = list()
if mode == 'val':
eval_df = df[train_edge_end:val_edge_end]
elif mode == 'test':
eval_df = df[val_edge_end:]
neg_samples = args.eval_neg_samples
elif mode == 'train':
eval_df = df[:train_edge_end]
with torch.no_grad():
total_loss = 0
for _, rows in eval_df.groupby(eval_df.index // train_param['batch_size']):
root_nodes = np.concatenate([rows.src.values, rows.dst.values, neg_link_sampler.sample(len(rows) * neg_samples)]).astype(np.int32)
ts = np.tile(rows.time.values, neg_samples + 2).astype(np.float32)
if sampler is not None:
if 'no_neg' in sample_param and sample_param['no_neg']:
pos_root_end = len(rows) * 2
sampler.sample(root_nodes[:pos_root_end], ts[:pos_root_end])
else:
sampler.sample(root_nodes, ts)
ret = sampler.get_ret()
if gnn_param['arch'] != 'identity':
mfgs = to_dgl_blocks(ret, sample_param['history'])
else:
mfgs = node_to_dgl_blocks(root_nodes, ts)
mfgs = prepare_input(mfgs, node_feats, edge_feats, combine_first=combine_first)
if mailbox is not None:
mailbox.prep_input_mails(mfgs[0])
pred_pos, pred_neg = model(mfgs, neg_samples=neg_samples)
total_loss += creterion(pred_pos, torch.ones_like(pred_pos))
total_loss += creterion(pred_neg, torch.zeros_like(pred_neg))
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
aps.append(average_precision_score(y_true, y_pred))
if neg_samples > 1:
aucs_mrrs.append(torch.reciprocal(torch.sum(pred_pos.squeeze() < pred_neg.squeeze().reshape(neg_samples, -1), dim=0) + 1).type(torch.float))
else:
aucs_mrrs.append(roc_auc_score(y_true, y_pred))
if mailbox is not None:
eid = rows['Unnamed: 0'].values
mem_edge_feats = edge_feats[eid] if edge_feats is not None else None
block = None
if memory_param['deliver_to'] == 'neighbors':
block = to_dgl_blocks(ret, sample_param['history'], reverse=True)[0][0]
mailbox.update_mailbox(model.memory_updater.last_updated_nid, model.memory_updater.last_updated_memory, root_nodes, ts, mem_edge_feats, block, neg_samples=neg_samples)
mailbox.update_memory(model.memory_updater.last_updated_nid, model.memory_updater.last_updated_memory, root_nodes, model.memory_updater.last_updated_ts, neg_samples=neg_samples)
if mode == 'val':
val_losses.append(float(total_loss))
ap = float(torch.tensor(aps).mean())
if neg_samples > 1:
auc_mrr = float(torch.cat(aucs_mrrs).mean())
else:
auc_mrr = float(torch.tensor(aucs_mrrs).mean())
return ap, auc_mrr
if not os.path.isdir('models'):
os.mkdir('models')
if args.model_name == '':
path_saver = 'models/{}_{}.pkl'.format(args.data, time.time())
else:
path_saver = 'models/{}.pkl'.format(args.model_name)
best_ap = 0
best_e = 0
val_losses = list()
group_indexes = list()
group_indexes.append(np.array(df[:train_edge_end].index // train_param['batch_size']))
if 'reorder' in train_param:
# random chunk shceduling
reorder = train_param['reorder']
group_idx = list()
for i in range(reorder):
group_idx += list(range(0 - i, reorder - i))
group_idx = np.repeat(np.array(group_idx), train_param['batch_size'] // reorder)
group_idx = np.tile(group_idx, train_edge_end // train_param['batch_size'] + 1)[:train_edge_end]
group_indexes.append(group_indexes[0] + group_idx)
base_idx = group_indexes[0]
for i in range(1, train_param['reorder']):
additional_idx = np.zeros(train_param['batch_size'] // train_param['reorder'] * i) - 1
group_indexes.append(np.concatenate([additional_idx, base_idx])[:base_idx.shape[0]])
for e in range(train_param['epoch']):
print('Epoch {:d}:'.format(e))
time_sample = 0
time_prep = 0
time_tot = 0
total_loss = 0
# training
model.train()
if sampler is not None:
sampler.reset()
if mailbox is not None:
mailbox.reset()
model.memory_updater.last_updated_nid = None
for _, rows in df[:train_edge_end].groupby(group_indexes[random.randint(0, len(group_indexes) - 1)]):
t_tot_s = time.time()
root_nodes = np.concatenate([rows.src.values, rows.dst.values, neg_link_sampler.sample(len(rows))]).astype(np.int32)
ts = np.concatenate([rows.time.values, rows.time.values, rows.time.values]).astype(np.float32)
if sampler is not None:
if 'no_neg' in sample_param and sample_param['no_neg']:
pos_root_end = root_nodes.shape[0] * 2 // 3
sampler.sample(root_nodes[:pos_root_end], ts[:pos_root_end])
else:
sampler.sample(root_nodes, ts)
ret = sampler.get_ret()
time_sample += ret[0].sample_time()
t_prep_s = time.time()
if gnn_param['arch'] != 'identity':
mfgs = to_dgl_blocks(ret, sample_param['history'])
else:
mfgs = node_to_dgl_blocks(root_nodes, ts)
mfgs = prepare_input(mfgs, node_feats, edge_feats, combine_first=combine_first)
if mailbox is not None:
mailbox.prep_input_mails(mfgs[0])
time_prep += time.time() - t_prep_s
optimizer.zero_grad()
pred_pos, pred_neg = model(mfgs)
loss = creterion(pred_pos, torch.ones_like(pred_pos))
loss += creterion(pred_neg, torch.zeros_like(pred_neg))
total_loss += float(loss) * train_param['batch_size']
loss.backward()
optimizer.step()
t_prep_s = time.time()
if mailbox is not None:
eid = rows['Unnamed: 0'].values
mem_edge_feats = edge_feats[eid] if edge_feats is not None else None
block = None
if memory_param['deliver_to'] == 'neighbors':
block = to_dgl_blocks(ret, sample_param['history'], reverse=True)[0][0]
mailbox.update_mailbox(model.memory_updater.last_updated_nid, model.memory_updater.last_updated_memory, root_nodes, ts, mem_edge_feats, block)
mailbox.update_memory(model.memory_updater.last_updated_nid, model.memory_updater.last_updated_memory, root_nodes, model.memory_updater.last_updated_ts)
time_prep += time.time() - t_prep_s
time_tot += time.time() - t_tot_s
ap, auc = eval('val')
if e > 2 and ap > best_ap:
best_e = e
best_ap = ap
torch.save(model.state_dict(), path_saver)
print('\ttrain loss:{:.4f} val ap:{:4f} val auc:{:4f}'.format(total_loss, ap, auc))
print('\ttotal time:{:.2f}s sample time:{:.2f}s prep time:{:.2f}s'.format(time_tot, time_sample, time_prep))
print('Loading model at epoch {}...'.format(best_e))
model.load_state_dict(torch.load(path_saver))
model.eval()
if sampler is not None:
sampler.reset()
if mailbox is not None:
mailbox.reset()
model.memory_updater.last_updated_nid = None
eval('train')
eval('val')
ap, auc = eval('test')
if args.eval_neg_samples > 1:
print('\ttest AP:{:4f} test MRR:{:4f}'.format(ap, auc))
else:
print('\ttest AP:{:4f} test AUC:{:4f}'.format(ap, auc))
import argparse
import os
parser=argparse.ArgumentParser()
parser.add_argument('--data', type=str, help='dataset name')
parser.add_argument('--config', type=str, help='path to config file')
parser.add_argument('--seed', type=int, default=0, help='random seed to use')
parser.add_argument('--num_gpus', type=int, default=4, help='number of gpus to use')
parser.add_argument('--omp_num_threads', type=int, default=8)
parser.add_argument("--local_rank", type=int, default=-1)
args=parser.parse_args()
# set which GPU to use
if args.local_rank < args.num_gpus:
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.local_rank)
else:
os.environ['CUDA_VISIBLE_DEVICES'] = ''
os.environ['OMP_NUM_THREADS'] = str(args.omp_num_threads)
os.environ['MKL_NUM_THREADS'] = str(args.omp_num_threads)
import torch
import dgl
import datetime
import random
import math
import threading
import numpy as np
from tqdm import tqdm
from dgl.utils.shared_mem import create_shared_mem_array, get_shared_mem_array
from sklearn.metrics import average_precision_score, roc_auc_score
from modules import *
from sampler import *
from utils import *
def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
set_seed(args.seed)
torch.distributed.init_process_group(backend='gloo', timeout=datetime.timedelta(0, 3600))
nccl_group = torch.distributed.new_group(ranks=list(range(args.num_gpus)), backend='nccl')
if args.local_rank == 0:
_node_feats, _edge_feats = load_feat(args.data)
dim_feats = [0, 0, 0, 0, 0, 0]
if args.local_rank == 0:
if _node_feats is not None:
dim_feats[0] = _node_feats.shape[0]
dim_feats[1] = _node_feats.shape[1]
dim_feats[2] = _node_feats.dtype
node_feats = create_shared_mem_array('node_feats', _node_feats.shape, dtype=_node_feats.dtype)
node_feats.copy_(_node_feats)
del _node_feats
else:
node_feats = None
if _edge_feats is not None:
dim_feats[3] = _edge_feats.shape[0]
dim_feats[4] = _edge_feats.shape[1]
dim_feats[5] = _edge_feats.dtype
edge_feats = create_shared_mem_array('edge_feats', _edge_feats.shape, dtype=_edge_feats.dtype)
edge_feats.copy_(_edge_feats)
del _edge_feats
else:
edge_feats = None
torch.distributed.barrier()
torch.distributed.broadcast_object_list(dim_feats, src=0)
if args.local_rank > 0 and args.local_rank < args.num_gpus:
node_feats = None
edge_feats = None
if os.path.exists('DATA/{}/node_features.pt'.format(args.data)):
node_feats = get_shared_mem_array('node_feats', (dim_feats[0], dim_feats[1]), dtype=dim_feats[2])
if os.path.exists('DATA/{}/edge_features.pt'.format(args.data)):
edge_feats = get_shared_mem_array('edge_feats', (dim_feats[3], dim_feats[4]), dtype=dim_feats[5])
sample_param, memory_param, gnn_param, train_param = parse_config(args.config)
orig_batch_size = train_param['batch_size']
if args.local_rank == 0:
if not os.path.isdir('models'):
os.mkdir('models')
path_saver = ['models/{}_{}.pkl'.format(args.data, time.time())]
else:
path_saver = [None]
torch.distributed.broadcast_object_list(path_saver, src=0)
path_saver = path_saver[0]
if args.local_rank == args.num_gpus:
g, df = load_graph(args.data)
num_nodes = [g['indptr'].shape[0] - 1]
else:
num_nodes = [None]
torch.distributed.barrier()
torch.distributed.broadcast_object_list(num_nodes, src=args.num_gpus)
num_nodes = num_nodes[0]
mailbox = None
if memory_param['type'] != 'none':
if args.local_rank == 0:
node_memory = create_shared_mem_array('node_memory', torch.Size([num_nodes, memory_param['dim_out']]), dtype=torch.float32)
node_memory_ts = create_shared_mem_array('node_memory_ts', torch.Size([num_nodes]), dtype=torch.float32)
mails = create_shared_mem_array('mails', torch.Size([num_nodes, memory_param['mailbox_size'], 2 * memory_param['dim_out'] + dim_feats[4]]), dtype=torch.float32)
mail_ts = create_shared_mem_array('mail_ts', torch.Size([num_nodes, memory_param['mailbox_size']]), dtype=torch.float32)
next_mail_pos = create_shared_mem_array('next_mail_pos', torch.Size([num_nodes]), dtype=torch.long)
update_mail_pos = create_shared_mem_array('update_mail_pos', torch.Size([num_nodes]), dtype=torch.int32)
torch.distributed.barrier()
node_memory.zero_()
node_memory_ts.zero_()
mails.zero_()
mail_ts.zero_()
next_mail_pos.zero_()
update_mail_pos.zero_()
else:
torch.distributed.barrier()
node_memory = get_shared_mem_array('node_memory', torch.Size([num_nodes, memory_param['dim_out']]), dtype=torch.float32)
node_memory_ts = get_shared_mem_array('node_memory_ts', torch.Size([num_nodes]), dtype=torch.float32)
mails = get_shared_mem_array('mails', torch.Size([num_nodes, memory_param['mailbox_size'], 2 * memory_param['dim_out'] + dim_feats[4]]), dtype=torch.float32)
mail_ts = get_shared_mem_array('mail_ts', torch.Size([num_nodes, memory_param['mailbox_size']]), dtype=torch.float32)
next_mail_pos = get_shared_mem_array('next_mail_pos', torch.Size([num_nodes]), dtype=torch.long)
update_mail_pos = get_shared_mem_array('update_mail_pos', torch.Size([num_nodes]), dtype=torch.int32)
mailbox = MailBox(memory_param, num_nodes, dim_feats[4], node_memory, node_memory_ts, mails, mail_ts, next_mail_pos, update_mail_pos)
class DataPipelineThread(threading.Thread):
def __init__(self, my_mfgs, my_root, my_ts, my_eid, my_block, stream):
super(DataPipelineThread, self).__init__()
self.my_mfgs = my_mfgs
self.my_root = my_root
self.my_ts = my_ts
self.my_eid = my_eid
self.my_block = my_block
self.stream = stream
self.mfgs = None
self.root = None
self.ts = None
self.eid = None
self.block = None
def run(self):
with torch.cuda.stream(self.stream):
# print(args.local_rank, 'start thread')
nids, eids = get_ids(self.my_mfgs[0], node_feats, edge_feats)
mfgs = mfgs_to_cuda(self.my_mfgs[0])
prepare_input(mfgs, node_feats, edge_feats, pinned=True, nfeat_buffs=pinned_nfeat_buffs, efeat_buffs=pinned_efeat_buffs, nids=nids, eids=eids)
if mailbox is not None:
mailbox.prep_input_mails(mfgs[0], use_pinned_buffers=True)
self.mfgs = mfgs
self.root = self.my_root[0]
self.ts = self.my_ts[0]
self.eid = self.my_eid[0]
if memory_param['deliver_to'] == 'neighbors':
self.block = self.my_block[0]
# print(args.local_rank, 'finished')
def get_stream(self):
return self.stream
def get_mfgs(self):
return self.mfgs
def get_root(self):
return self.root
def get_ts(self):
return self.ts
def get_eid(self):
return self.eid
def get_block(self):
return self.block
if args.local_rank < args.num_gpus:
# GPU worker process
model = GeneralModel(dim_feats[1], dim_feats[4], sample_param, memory_param, gnn_param, train_param).cuda()
find_unused_parameters = True if sample_param['history'] > 1 else False
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], process_group=nccl_group, output_device=args.local_rank, find_unused_parameters=find_unused_parameters)
creterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'])
pinned_nfeat_buffs, pinned_efeat_buffs = get_pinned_buffers(sample_param, train_param['batch_size'], node_feats, edge_feats)
if mailbox is not None:
mailbox.allocate_pinned_memory_buffers(sample_param, train_param['batch_size'])
tot_loss = 0
prev_thread = None
while True:
my_model_state = [None]
model_state = [None] * (args.num_gpus + 1)
torch.distributed.scatter_object_list(my_model_state, model_state, src=args.num_gpus)
if my_model_state[0] == -1:
break
elif my_model_state[0] == 4:
continue
elif my_model_state[0] == 2:
torch.save(model.state_dict(), path_saver)
continue
elif my_model_state[0] == 3:
model.load_state_dict(torch.load(path_saver, map_location=torch.device('cuda:0')))
continue
elif my_model_state[0] == 5:
torch.distributed.gather_object(float(tot_loss), None, dst=args.num_gpus)
tot_loss = 0
continue
elif my_model_state[0] == 0:
if prev_thread is not None:
my_mfgs = [None]
multi_mfgs = [None] * (args.num_gpus + 1)
my_root = [None]
multi_root = [None] * (args.num_gpus + 1)
my_ts = [None]
multi_ts = [None] * (args.num_gpus + 1)
my_eid = [None]
multi_eid = [None] * (args.num_gpus + 1)
my_block = [None]
multi_block = [None] * (args.num_gpus + 1)
torch.distributed.scatter_object_list(my_mfgs, multi_mfgs, src=args.num_gpus)
if mailbox is not None:
torch.distributed.scatter_object_list(my_root, multi_root, src=args.num_gpus)
torch.distributed.scatter_object_list(my_ts, multi_ts, src=args.num_gpus)
torch.distributed.scatter_object_list(my_eid, multi_eid, src=args.num_gpus)
if memory_param['deliver_to'] == 'neighbors':
torch.distributed.scatter_object_list(my_block, multi_block, src=args.num_gpus)
stream = torch.cuda.Stream()
curr_thread = DataPipelineThread(my_mfgs, my_root, my_ts, my_eid, my_block, stream)
curr_thread.start()
prev_thread.join()
# with torch.cuda.stream(prev_thread.get_stream()):
mfgs = prev_thread.get_mfgs()
model.train()
optimizer.zero_grad()
pred_pos, pred_neg = model(mfgs)
loss = creterion(pred_pos, torch.ones_like(pred_pos))
loss += creterion(pred_neg, torch.zeros_like(pred_neg))
loss.backward()
optimizer.step()
with torch.no_grad():
tot_loss += float(loss)
if mailbox is not None:
with torch.no_grad():
eid = prev_thread.get_eid()
mem_edge_feats = edge_feats[eid] if edge_feats is not None else None
root_nodes = prev_thread.get_root()
ts = prev_thread.get_ts()
block = prev_thread.get_block()
mailbox.update_mailbox(model.module.memory_updater.last_updated_nid, model.module.memory_updater.last_updated_memory, root_nodes, ts, mem_edge_feats, block)
mailbox.update_memory(model.module.memory_updater.last_updated_nid, model.module.memory_updater.last_updated_memory, model.module.memory_updater.last_updated_ts)
if memory_param['deliver_to'] == 'neighbors':
torch.distributed.barrier(group=nccl_group)
if args.local_rank == 0:
mailbox.update_next_mail_pos()
prev_thread = curr_thread
else:
my_mfgs = [None]
multi_mfgs = [None] * (args.num_gpus + 1)
my_root = [None]
multi_root = [None] * (args.num_gpus + 1)
my_ts = [None]
multi_ts = [None] * (args.num_gpus + 1)
my_eid = [None]
multi_eid = [None] * (args.num_gpus + 1)
my_block = [None]
multi_block = [None] * (args.num_gpus + 1)
torch.distributed.scatter_object_list(my_mfgs, multi_mfgs, src=args.num_gpus)
if mailbox is not None:
torch.distributed.scatter_object_list(my_root, multi_root, src=args.num_gpus)
torch.distributed.scatter_object_list(my_ts, multi_ts, src=args.num_gpus)
torch.distributed.scatter_object_list(my_eid, multi_eid, src=args.num_gpus)
if memory_param['deliver_to'] == 'neighbors':
torch.distributed.scatter_object_list(my_block, multi_block, src=args.num_gpus)
stream = torch.cuda.Stream()
prev_thread = DataPipelineThread(my_mfgs, my_root, my_ts, my_eid, my_block, stream)
prev_thread.start()
elif my_model_state[0] == 1:
if prev_thread is not None:
# finish last training mini-batch
prev_thread.join()
mfgs = prev_thread.get_mfgs()
model.train()
optimizer.zero_grad()
pred_pos, pred_neg = model(mfgs)
loss = creterion(pred_pos, torch.ones_like(pred_pos))
loss += creterion(pred_neg, torch.zeros_like(pred_neg))
loss.backward()
optimizer.step()
with torch.no_grad():
tot_loss += float(loss)
if mailbox is not None:
with torch.no_grad():
eid = prev_thread.get_eid()
mem_edge_feats = edge_feats[eid] if edge_feats is not None else None
root_nodes = prev_thread.get_root()
ts = prev_thread.get_ts()
block = prev_thread.get_block()
mailbox.update_mailbox(model.module.memory_updater.last_updated_nid, model.module.memory_updater.last_updated_memory, root_nodes, ts, mem_edge_feats, block)
mailbox.update_memory(model.module.memory_updater.last_updated_nid, model.module.memory_updater.last_updated_memory, model.module.memory_updater.last_updated_ts)
if memory_param['deliver_to'] == 'neighbors':
torch.distributed.barrier(group=nccl_group)
if args.local_rank == 0:
mailbox.update_next_mail_pos()
prev_thread = None
my_mfgs = [None]
multi_mfgs = [None] * (args.num_gpus + 1)
torch.distributed.scatter_object_list(my_mfgs, multi_mfgs, src=args.num_gpus)
mfgs = mfgs_to_cuda(my_mfgs[0])
prepare_input(mfgs, node_feats, edge_feats, pinned=True, nfeat_buffs=pinned_nfeat_buffs, efeat_buffs=pinned_efeat_buffs)
model.eval()
with torch.no_grad():
if mailbox is not None:
mailbox.prep_input_mails(mfgs[0])
pred_pos, pred_neg = model(mfgs)
if mailbox is not None:
my_root = [None]
multi_root = [None] * (args.num_gpus + 1)
my_ts = [None]
multi_ts = [None] * (args.num_gpus + 1)
my_eid = [None]
multi_eid = [None] * (args.num_gpus + 1)
torch.distributed.scatter_object_list(my_root, multi_root, src=args.num_gpus)
torch.distributed.scatter_object_list(my_ts, multi_ts, src=args.num_gpus)
torch.distributed.scatter_object_list(my_eid, multi_eid, src=args.num_gpus)
eid = my_eid[0]
mem_edge_feats = edge_feats[eid] if edge_feats is not None else None
root_nodes = my_root[0]
ts = my_ts[0]
block = None
if memory_param['deliver_to'] == 'neighbors':
my_block = [None]
multi_block = [None] * (args.num_gpus + 1)
torch.distributed.scatter_object_list(my_block, multi_block, src=args.num_gpus)
block = my_block[0]
mailbox.update_mailbox(model.module.memory_updater.last_updated_nid, model.module.memory_updater.last_updated_memory, root_nodes, ts, mem_edge_feats, block)
mailbox.update_memory(model.module.memory_updater.last_updated_nid, model.module.memory_updater.last_updated_memory, model.module.memory_updater.last_updated_ts)
if memory_param['deliver_to'] == 'neighbors':
torch.distributed.barrier(group=nccl_group)
if args.local_rank == 0:
mailbox.update_next_mail_pos()
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
ap = average_precision_score(y_true, y_pred)
auc = roc_auc_score(y_true, y_pred)
torch.distributed.gather_object(float(ap), None, dst=args.num_gpus)
torch.distributed.gather_object(float(auc), None, dst=args.num_gpus)
else:
# hosting process
train_edge_end = df[df['ext_roll'].gt(0)].index[0]
val_edge_end = df[df['ext_roll'].gt(1)].index[0]
sampler = None
if not ('no_sample' in sample_param and sample_param['no_sample']):
sampler = ParallelSampler(g['indptr'], g['indices'], g['eid'], g['ts'].astype(np.float32),
sample_param['num_thread'], 1, sample_param['layer'], sample_param['neighbor'],
sample_param['strategy']=='recent', sample_param['prop_time'],
sample_param['history'], float(sample_param['duration']))
neg_link_sampler = NegLinkSampler(g['indptr'].shape[0] - 1)
def eval(mode='val'):
if mode == 'val':
eval_df = df[train_edge_end:val_edge_end]
elif mode == 'test':
eval_df = df[val_edge_end:]
elif mode == 'train':
eval_df = df[:train_edge_end]
ap_tot = list()
auc_tot = list()
train_param['bathc_size'] = orig_batch_size
itr_tot = max(len(eval_df) // train_param['batch_size'] // args.num_gpus, 1) * args.num_gpus
train_param['batch_size'] = math.ceil(len(eval_df) / itr_tot)
multi_mfgs = list()
multi_root = list()
multi_ts = list()
multi_eid = list()
multi_block = list()
for _, rows in eval_df.groupby(eval_df.index // train_param['batch_size']):
root_nodes = np.concatenate([rows.src.values, rows.dst.values, neg_link_sampler.sample(len(rows))]).astype(np.int32)
ts = np.concatenate([rows.time.values, rows.time.values, rows.time.values]).astype(np.float32)
if sampler is not None:
if 'no_neg' in sample_param and sample_param['no_neg']:
pos_root_end = root_nodes.shape[0] * 2 // 3
sampler.sample(root_nodes[:pos_root_end], ts[:pos_root_end])
else:
sampler.sample(root_nodes, ts)
ret = sampler.get_ret()
if gnn_param['arch'] != 'identity':
mfgs = to_dgl_blocks(ret, sample_param['history'], cuda=False)
else:
mfgs = node_to_dgl_blocks(root_nodes, ts, cuda=False)
multi_mfgs.append(mfgs)
multi_root.append(root_nodes)
multi_ts.append(ts)
multi_eid.append(rows['Unnamed: 0'].values)
if mailbox is not None and memory_param['deliver_to'] == 'neighbors':
multi_block.append(to_dgl_blocks(ret, sample_param['history'], reverse=True, cuda=False)[0][0])
if len(multi_mfgs) == args.num_gpus:
model_state = [1] * (args.num_gpus + 1)
my_model_state = [None]
torch.distributed.scatter_object_list(my_model_state, model_state, src=args.num_gpus)
multi_mfgs.append(None)
my_mfgs = [None]
torch.distributed.scatter_object_list(my_mfgs, multi_mfgs, src=args.num_gpus)
if mailbox is not None:
multi_root.append(None)
multi_ts.append(None)
multi_eid.append(None)
my_root = [None]
my_ts = [None]
my_eid = [None]
torch.distributed.scatter_object_list(my_root, multi_root, src=args.num_gpus)
torch.distributed.scatter_object_list(my_ts, multi_ts, src=args.num_gpus)
torch.distributed.scatter_object_list(my_eid, multi_eid, src=args.num_gpus)
if memory_param['deliver_to'] == 'neighbors':
multi_block.append(None)
my_block = [None]
torch.distributed.scatter_object_list(my_block, multi_block, src=args.num_gpus)
gathered_ap = [None] * (args.num_gpus + 1)
gathered_auc = [None] * (args.num_gpus + 1)
torch.distributed.gather_object(float(0), gathered_ap, dst=args.num_gpus)
torch.distributed.gather_object(float(0), gathered_auc, dst=args.num_gpus)
ap_tot += gathered_ap[:-1]
auc_tot += gathered_auc[:-1]
multi_mfgs = list()
multi_root = list()
multi_ts = list()
multi_eid = list()
multi_block = list()
pbar.update(1)
ap = float(torch.tensor(ap_tot).mean())
auc = float(torch.tensor(auc_tot).mean())
return ap, auc
best_ap = 0
best_e = 0
tap = 0
tauc = 0
for e in range(train_param['epoch']):
print('Epoch {:d}:'.format(e))
time_sample = 0
time_tot = 0
if sampler is not None:
sampler.reset()
if mailbox is not None:
mailbox.reset()
# training
train_param['bathc_size'] = orig_batch_size
itr_tot = train_edge_end // train_param['batch_size'] // args.num_gpus * args.num_gpus
train_param['batch_size'] = math.ceil(train_edge_end / itr_tot)
multi_mfgs = list()
multi_root = list()
multi_ts = list()
multi_eid = list()
multi_block = list()
group_indexes = list()
group_indexes.append(np.array(df[:train_edge_end].index // train_param['batch_size']))
if 'reorder' in train_param:
# random chunk shceduling
reorder = train_param['reorder']
group_idx = list()
for i in range(reorder):
group_idx += list(range(0 - i, reorder - i))
group_idx = np.repeat(np.array(group_idx), train_param['batch_size'] // reorder)
group_idx = np.tile(group_idx, train_edge_end // train_param['batch_size'] + 1)[:train_edge_end]
group_indexes.append(group_indexes[0] + group_idx)
base_idx = group_indexes[0]
for i in range(1, train_param['reorder']):
additional_idx = np.zeros(train_param['batch_size'] // train_param['reorder'] * i) - 1
group_indexes.append(np.concatenate([additional_idx, base_idx])[:base_idx.shape[0]])
with tqdm(total=itr_tot + max((val_edge_end - train_edge_end) // train_param['batch_size'] // args.num_gpus, 1) * args.num_gpus) as pbar:
for _, rows in df[:train_edge_end].groupby(group_indexes[random.randint(0, len(group_indexes) - 1)]):
t_tot_s = time.time()
root_nodes = np.concatenate([rows.src.values, rows.dst.values, neg_link_sampler.sample(len(rows))]).astype(np.int32)
ts = np.concatenate([rows.time.values, rows.time.values, rows.time.values]).astype(np.float32)
if sampler is not None:
if 'no_neg' in sample_param and sample_param['no_neg']:
pos_root_end = root_nodes.shape[0] * 2 // 3
sampler.sample(root_nodes[:pos_root_end], ts[:pos_root_end])
else:
sampler.sample(root_nodes, ts)
ret = sampler.get_ret()
time_sample += ret[0].sample_time()
if gnn_param['arch'] != 'identity':
mfgs = to_dgl_blocks(ret, sample_param['history'], cuda=False)
else:
mfgs = node_to_dgl_blocks(root_nodes, ts, cuda=False)
multi_mfgs.append(mfgs)
multi_root.append(root_nodes)
multi_ts.append(ts)
multi_eid.append(rows['Unnamed: 0'].values)
if mailbox is not None and memory_param['deliver_to'] == 'neighbors':
multi_block.append(to_dgl_blocks(ret, sample_param['history'], reverse=True, cuda=False)[0][0])
if len(multi_mfgs) == args.num_gpus:
model_state = [0] * (args.num_gpus + 1)
my_model_state = [None]
torch.distributed.scatter_object_list(my_model_state, model_state, src=args.num_gpus)
multi_mfgs.append(None)
my_mfgs = [None]
torch.distributed.scatter_object_list(my_mfgs, multi_mfgs, src=args.num_gpus)
if mailbox is not None:
multi_root.append(None)
multi_ts.append(None)
multi_eid.append(None)
my_root = [None]
my_ts = [None]
my_eid = [None]
torch.distributed.scatter_object_list(my_root, multi_root, src=args.num_gpus)
torch.distributed.scatter_object_list(my_ts, multi_ts, src=args.num_gpus)
torch.distributed.scatter_object_list(my_eid, multi_eid, src=args.num_gpus)
if memory_param['deliver_to'] == 'neighbors':
multi_block.append(None)
my_block = [None]
torch.distributed.scatter_object_list(my_block, multi_block, src=args.num_gpus)
multi_mfgs = list()
multi_root = list()
multi_ts = list()
multi_eid = list()
multi_block = list()
pbar.update(1)
time_tot += time.time() - t_tot_s
print('Training time:',time_tot)
model_state = [5] * (args.num_gpus + 1)
my_model_state = [None]
torch.distributed.scatter_object_list(my_model_state, model_state, src=args.num_gpus)
gathered_loss = [None] * (args.num_gpus + 1)
torch.distributed.gather_object(float(0), gathered_loss, dst=args.num_gpus)
total_loss = np.sum(np.array(gathered_loss) * train_param['batch_size'])
ap, auc = eval('val')
if ap > best_ap:
best_e = e
best_ap = ap
model_state = [4] * (args.num_gpus + 1)
model_state[0] = 2
my_model_state = [None]
torch.distributed.scatter_object_list(my_model_state, model_state, src=args.num_gpus)
# for memory based models, testing after validation is faster
tap, tauc = eval('test')
print('\ttrain loss:{:.4f} val ap:{:4f} val auc:{:4f}'.format(total_loss, ap, auc))
print('\ttotal time:{:.2f}s sample time:{:.2f}s'.format(time_tot, time_sample))
print('Best model at epoch {}.'.format(best_e))
print('\ttest ap:{:4f} test auc:{:4f}'.format(tap, tauc))
# let all process exit
model_state = [-1] * (args.num_gpus + 1)
my_model_state = [None]
torch.distributed.scatter_object_list(my_model_state, model_state, src=args.num_gpus)
\ No newline at end of file
import argparse
import os
import hashlib
parser=argparse.ArgumentParser()
parser.add_argument('--data', type=str, help='dataset name')
parser.add_argument('--config', type=str, default='', help='path to config file')
parser.add_argument('--batch_size', type=int, default=4000)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--dim', type=int, default=100)
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--gpu', type=str, default='0', help='which GPU to use')
parser.add_argument('--model', type=str, default='', help='name of stored model to load')
parser.add_argument('--posneg', default=False, action='store_true', help='for positive negative detection, whether to sample negative nodes')
args=parser.parse_args()
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
if args.data == 'WIKI' or args.data == 'REDDIT':
args.posneg = True
import torch
import time
import random
import dgl
import numpy as np
import pandas as pd
from modules import *
from sampler import *
from utils import *
from tqdm import tqdm
from sklearn.metrics import average_precision_score, f1_score
ldf = pd.read_csv('DATA/{}/labels.csv'.format(args.data))
role = ldf['ext_roll'].values
# train_node_end = ldf[ldf['ext_roll'].gt(0)].index[0]
# val_node_end = ldf[ldf['ext_roll'].gt(1)].index[0]
labels = ldf['label'].values.astype(np.int64)
emb_file_name = hashlib.md5(str(torch.load(args.model, map_location=torch.device('cpu'))).encode('utf-8')).hexdigest() + '.pt'
if not os.path.isdir('embs'):
os.mkdir('embs')
if not os.path.isfile('embs/' + emb_file_name):
print('Generating temporal embeddings..')
node_feats, edge_feats = load_feat(args.data)
g, df = load_graph(args.data)
sample_param, memory_param, gnn_param, train_param = parse_config(args.config)
train_edge_end = df[df['ext_roll'].gt(0)].index[0]
val_edge_end = df[df['ext_roll'].gt(1)].index[0]
gnn_dim_node = 0 if node_feats is None else node_feats.shape[1]
gnn_dim_edge = 0 if edge_feats is None else edge_feats.shape[1]
combine_first = False
if 'combine_neighs' in train_param and train_param['combine_neighs']:
combine_first = True
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param, combined=combine_first).cuda()
mailbox = MailBox(memory_param, g['indptr'].shape[0] - 1, gnn_dim_edge) if memory_param['type'] != 'none' else None
creterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'])
if 'all_on_gpu' in train_param and train_param['all_on_gpu']:
if node_feats is not None:
node_feats = node_feats.cuda()
if edge_feats is not None:
edge_feats = edge_feats.cuda()
if mailbox is not None:
mailbox.move_to_gpu()
sampler = None
if not ('no_sample' in sample_param and sample_param['no_sample']):
sampler = ParallelSampler(g['indptr'], g['indices'], g['eid'], g['ts'].astype(np.float32),
sample_param['num_thread'], 1, sample_param['layer'], sample_param['neighbor'],
sample_param['strategy']=='recent', sample_param['prop_time'],
sample_param['history'], float(sample_param['duration']))
neg_link_sampler = NegLinkSampler(g['indptr'].shape[0] - 1)
model.load_state_dict(torch.load(args.model))
processed_edge_id = 0
def forward_model_to(time):
global processed_edge_id
if processed_edge_id >= len(df):
return
while df.time[processed_edge_id] < time:
rows = df[processed_edge_id:min(processed_edge_id + train_param['batch_size'], len(df))]
if processed_edge_id < train_edge_end:
model.train()
else:
model.eval()
root_nodes = np.concatenate([rows.src.values, rows.dst.values, neg_link_sampler.sample(len(rows))]).astype(np.int32)
ts = np.concatenate([rows.time.values, rows.time.values, rows.time.values]).astype(np.float32)
if sampler is not None:
if 'no_neg' in sample_param and sample_param['no_neg']:
pos_root_end = root_nodes.shape[0] * 2 // 3
sampler.sample(root_nodes[:pos_root_end], ts[:pos_root_end])
else:
sampler.sample(root_nodes, ts)
ret = sampler.get_ret()
if gnn_param['arch'] != 'identity':
mfgs = to_dgl_blocks(ret, sample_param['history'])
else:
mfgs = node_to_dgl_blocks(root_nodes, ts)
mfgs = prepare_input(mfgs, node_feats, edge_feats, combine_first=combine_first)
if mailbox is not None:
mailbox.prep_input_mails(mfgs[0])
with torch.no_grad():
pred_pos, pred_neg = model(mfgs)
if mailbox is not None:
eid = rows['Unnamed: 0'].values
mem_edge_feats = edge_feats[eid] if edge_feats is not None else None
block = None
if memory_param['deliver_to'] == 'neighbors':
block = to_dgl_blocks(ret, sample_param['history'], reverse=True)[0][0]
mailbox.update_mailbox(model.memory_updater.last_updated_nid, model.memory_updater.last_updated_memory, root_nodes, ts, mem_edge_feats, block)
mailbox.update_memory(model.memory_updater.last_updated_nid, model.memory_updater.last_updated_memory, model.memory_updater.last_updated_ts)
processed_edge_id += train_param['batch_size']
if processed_edge_id >= len(df):
return
def get_node_emb(root_nodes, ts):
forward_model_to(ts[-1])
if sampler is not None:
sampler.sample(root_nodes, ts)
ret = sampler.get_ret()
if gnn_param['arch'] != 'identity':
mfgs = to_dgl_blocks(ret, sample_param['history'])
else:
mfgs = node_to_dgl_blocks(root_nodes, ts)
mfgs = prepare_input(mfgs, node_feats, edge_feats, combine_first=combine_first)
if mailbox is not None:
mailbox.prep_input_mails(mfgs[0])
with torch.no_grad():
ret = model.get_emb(mfgs)
return ret.detach().cpu()
emb = list()
for _, rows in tqdm(ldf.groupby(ldf.index // args.batch_size)):
emb.append(get_node_emb(rows.node.values.astype(np.int32), rows.time.values.astype(np.float32)))
emb = torch.cat(emb, dim=0)
torch.save(emb, 'embs/' + emb_file_name)
print('Saved to embs/' + emb_file_name)
else:
print('Loading temporal embeddings from embs/' + emb_file_name)
emb = torch.load('embs/' + emb_file_name)
model = NodeClassificationModel(emb.shape[1], args.dim, labels.max() + 1).cuda()
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
labels = torch.from_numpy(labels).type(torch.int32)
role = torch.from_numpy(role).type(torch.int32)
emb = emb
class NodeEmbMinibatch():
def __init__(self, emb, role, label, batch_size):
self.role = role
self.label = label
self.batch_size = batch_size
self.train_emb = emb[role == 0]
self.val_emb = emb[role == 1]
self.test_emb = emb[role == 2]
self.train_label = label[role == 0]
self.val_label = label[role == 1]
self.test_label = label[role == 2]
self.mode = 0
self.s_idx = 0
def shuffle(self):
perm = torch.randperm(self.train_emb.shape[0])
self.train_emb = self.train_emb[perm]
self.train_label = self.train_label[perm]
def set_mode(self, mode):
if mode == 'train':
self.mode = 0
elif mode == 'val':
self.mode = 1
elif mode == 'test':
self.mode = 2
self.s_idx = 0
def __iter__(self):
return self
def __next__(self):
if self.mode == 0:
emb = self.train_emb
label = self.train_label
elif self.mode == 1:
emb = self.val_emb
label = self.val_label
else:
emb = self.test_emb
label = self.test_label
if self.s_idx >= emb.shape[0]:
raise StopIteration
else:
end = min(self.s_idx + self.batch_size, emb.shape[0])
curr_emb = emb[self.s_idx:end]
curr_label = label[self.s_idx:end]
self.s_idx += self.batch_size
return curr_emb.cuda(), curr_label.cuda()
if args.posneg:
role = role[labels == 1]
emb_neg = emb[labels == 0].cuda()
emb = emb[labels == 1]
labels = torch.ones(emb.shape[0], dtype=torch.int64).cuda()
labels_neg = torch.zeros(emb_neg.shape[0], dtype=torch.int64).cuda()
neg_node_sampler = NegLinkSampler(emb_neg.shape[0])
minibatch = NodeEmbMinibatch(emb, role, labels, args.batch_size)
if not os.path.isdir('models'):
os.mkdir('models')
save_path = 'models/node_' + args.model.split('/')[-1]
best_e = 0
best_acc = 0
for e in range(args.epoch):
minibatch.set_mode('train')
minibatch.shuffle()
model.train()
for emb, label in minibatch:
optimizer.zero_grad()
if args.posneg:
neg_idx = neg_node_sampler.sample(emb.shape[0])
emb = torch.cat([emb, emb_neg[neg_idx]], dim=0)
label = torch.cat([label, labels_neg[neg_idx]], dim=0)
pred = model(emb)
loss = loss_fn(pred, label.long())
loss.backward()
optimizer.step()
minibatch.set_mode('val')
model.eval()
accs = list()
with torch.no_grad():
for emb, label in minibatch:
if args.posneg:
neg_idx = neg_node_sampler.sample(emb.shape[0])
emb = torch.cat([emb, emb_neg[neg_idx]], dim=0)
label = torch.cat([label, labels_neg[neg_idx]], dim=0)
pred = model(emb)
if args.posneg:
acc = average_precision_score(label.cpu(), pred.softmax(dim=1)[:, 1].cpu())
else:
acc = f1_score(label.cpu(), torch.argmax(pred, dim=1).cpu(), average="micro")
accs.append(acc)
acc = float(torch.tensor(accs).mean())
print('Epoch: {}\tVal acc: {:.4f}'.format(e, acc))
if acc > best_acc:
best_e = e
best_acc = acc
torch.save(model.state_dict(), save_path)
print('Loading model at epoch {}...'.format(best_e))
model.load_state_dict(torch.load(save_path))
minibatch.set_mode('test')
model.eval()
accs = list()
with torch.no_grad():
for emb, label in minibatch:
if args.posneg:
neg_idx = neg_node_sampler.sample(emb.shape[0])
emb = torch.cat([emb, emb_neg[neg_idx]], dim=0)
label = torch.cat([label, labels_neg[neg_idx]], dim=0)
pred = model(emb)
if args.posneg:
acc = average_precision_score(label.cpu(), pred.softmax(dim=1)[:, 1].cpu())
else:
acc = f1_score(label.cpu(), torch.argmax(pred, dim=1).cpu(), average="micro")
accs.append(acc)
acc = float(torch.tensor(accs).mean())
print('Testing acc: {:.4f}'.format(acc))
\ No newline at end of file
import os
import sys
from os.path import abspath, join, dirname
import torch
import yaml
import dgl
import time
import pandas as pd
import numpy as np
def load_feat(d, rand_de=0, rand_dn=0):
node_feats = None
if os.path.exists('DATA/{}/node_features.pt'.format(d)):
node_feats = torch.load('DATA/{}/node_features.pt'.format(d))
if node_feats.dtype == torch.bool:
node_feats = node_feats.type(torch.float32)
edge_feats = None
if os.path.exists('DATA/{}/edge_features.pt'.format(d)):
edge_feats = torch.load('DATA/{}/edge_features.pt'.format(d))
if edge_feats.dtype == torch.bool:
edge_feats = edge_feats.type(torch.float32)
if rand_de > 0:
if d == 'LASTFM':
edge_feats = torch.randn(1293103, rand_de)
elif d == 'MOOC':
edge_feats = torch.randn(411749, rand_de)
if rand_dn > 0:
if d == 'LASTFM':
node_feats = torch.randn(1980, rand_dn)
elif d == 'MOOC':
edge_feats = torch.randn(7144, rand_dn)
return node_feats, edge_feats
def load_graph(d):
df = pd.read_csv('DATA/{}/edges.csv'.format(d))
g = np.load('DATA/{}/ext_full.npz'.format(d))
return g, df
def parse_config(f):
conf = yaml.safe_load(open(f, 'r'))
sample_param = conf['sampling'][0]
memory_param = conf['memory'][0]
gnn_param = conf['gnn'][0]
train_param = conf['train'][0]
return sample_param, memory_param, gnn_param, train_param
def to_dgl_blocks(ret, hist, reverse=False, cuda=True):
mfgs = list()
for r in ret:
if not reverse:
b = dgl.create_block((r.col(), r.row()), num_src_nodes=r.dim_in(), num_dst_nodes=r.dim_out())
b.srcdata['ID'] = torch.from_numpy(r.nodes())
b.edata['dt'] = torch.from_numpy(r.dts())[b.num_dst_nodes():]
b.srcdata['ts'] = torch.from_numpy(r.ts())
else:
b = dgl.create_block((r.row(), r.col()), num_src_nodes=r.dim_out(), num_dst_nodes=r.dim_in())
b.dstdata['ID'] = torch.from_numpy(r.nodes())
b.edata['dt'] = torch.from_numpy(r.dts())[b.num_src_nodes():]
b.dstdata['ts'] = torch.from_numpy(r.ts())
b.edata['ID'] = torch.from_numpy(r.eid())
if cuda:
mfgs.append(b.to('cuda:0'))
else:
mfgs.append(b)
mfgs = list(map(list, zip(*[iter(mfgs)] * hist)))
mfgs.reverse()
return mfgs
def node_to_dgl_blocks(root_nodes, ts, cuda=True):
mfgs = list()
b = dgl.create_block(([],[]), num_src_nodes=root_nodes.shape[0], num_dst_nodes=root_nodes.shape[0])
b.srcdata['ID'] = torch.from_numpy(root_nodes)
b.srcdata['ts'] = torch.from_numpy(ts)
if cuda:
mfgs.insert(0, [b.to('cuda:0')])
else:
mfgs.insert(0, [b])
return mfgs
def mfgs_to_cuda(mfgs):
for mfg in mfgs:
for i in range(len(mfg)):
mfg[i] = mfg[i].to('cuda:0')
return mfgs
def batch_data_prepare_input(batch_data,hist,cuda = True,combine_first=False, pinned=False, nfeat_buffs=None, efeat_buffs=None, type = 'identify'):
mfgs = list()
#e_index = torch.cat(batch_data.edge_index,1).view(-1)
#e_id = torch.cat(batch_data.edge_id,1).view(-1)
#e_ts = torch.cat(batch_data.edge_ts,1).view(-1)
#src_ts =
#pre_torch_id = torch.cat((batch_data.meta_data['src_id'],batch_data.meta_data['dst_pos_id'],batch_data.meta_data['dst_neg_id'],e_index),dim = 0)
#pre_torch_ts = torch.cat((batch_data.roots.ts.repeat(3),e_ts),dim = 0)
#pre_eid = torch.cat((torch.zeros(len(batch_data.roots.ts)*3),e_id),dim = 0)
#pre_src_ts = torch.cat((torch.zeros(len(batch_data.root.ts))))
pre_torch = torch.stack((torch.cat(batch_data.edge_index,1).view(-1),torch.cat(batch_data.edge_ts,1).reshape(-1)),dim = 0)
pre_torch = torch.cat((torch.stack((torch.cat((batch_data.meta_data['src_id'],
batch_data.meta_data['dst_pos_id'],batch_data.meta_data['dst_neg_id']),dim = 0),
batch_data.roots.ts.repeat(3)),dim= 0 ),pre_torch),dim = -1)
uniq,ind = pre_torch.unique(dim = 1, return_inverse = True)
rt = len(batch_data.meta_data['src_id'])
batch_data.meta_data['src_id_pos'] = ind[:rt]
batch_data.meta_data['dst_pos_pos'] = ind[rt:2*rt]
batch_data.meta_data['dst_neg_pos'] = ind[2*rt:3*rt]
maxid = uniq.shape[-1]
ind = ind[rt*3:].reshape(2,-1)
lastindex = torch.zeros(uniq.shape[1],device = torch.device('cuda')).long()
id_l = 0
l = 0
last_src_data = None
last_src_ts = None
for i in range(len(batch_data.eids)):
r = l + len(batch_data.eids[i])
if i == 0:
dst_nodes,pos = torch.cat((batch_data.meta_data['src_id_pos'],
batch_data.meta_data['dst_pos_pos'],batch_data.meta_data['dst_neg_pos']),dim = 0).unique(return_inverse=True)
if cuda:
batch_data.meta_data['src_id_pos'] = pos[:rt].cuda()
batch_data.meta_data['dst_pos_pos'] = pos[rt:rt*2].cuda()
batch_data.meta_data['dst_neg_pos'] = pos[rt*2:rt*3].cuda()
else:
batch_data.meta_data['src_id_pos'] = pos[:rt]
batch_data.meta_data['dst_pos_pos'] = pos[rt:rt*2]
batch_data.meta_data['dst_neg_pos'] = pos[rt*2:rt*3]
dst_nodes_index = torch.arange(len(dst_nodes),device=torch.device('cuda')).long()
lastindex[dst_nodes.long()] = dst_nodes_index
id_l += dst_nodes.shape[0]
last_src_data = uniq[0,dst_nodes]
last_src_ts = uniq[1,dst_nodes]
src_nodes,src_nodes_index = ind[0,l:r].unique(return_inverse = True)
#src_nodes_index += id_l
b = dgl.create_block((src_nodes_index + id_l ,lastindex[ind[1,l:r]]),num_src_nodes = id_l + src_nodes.shape[0],num_dst_nodes = len(dst_nodes_index))
lastindex[src_nodes.long()] = torch.arange(id_l,id_l + src_nodes.shape[0],device=torch.device('cuda')).long()
id_l += src_nodes.shape[0]
else:
src_nodes,src_nodes_index = ind[0,l:r].unique(return_inverse = True)
#src_nodes_index += id_l
b = dgl.create_block((src_nodes_index + id_l,lastindex[ind[1,l:r]]),num_src_nodes = id_l + src_nodes.shape[0],num_dst_nodes = id_l)
lastindex[src_nodes.long()] = torch.arange(id_l,id_l + src_nodes.shape[0],device=torch.device('cuda')).long()
id_l += src_nodes.shape[0]
l = r
#b = dgl.create_block((ind[0,l:r],ind[1,l:r]),num_src_nodes= dim_in,num_dst_nodes= dim_out)
last_src_data = torch.cat((last_src_data, uniq[0,src_nodes]),-1)
last_src_ts = torch.cat((last_src_ts, uniq[1,src_nodes]),-1)
b.srcdata['ID'] = last_src_data
b.srcdata['ts'] = last_src_ts
batch_data.edge_ts[i]=batch_data.edge_ts[i].view(2,-1)
b.edata['dt'] = batch_data.edge_ts[i][1,:] - batch_data.edge_ts[i][0,:]
b.edata['ID'] = batch_data.eids[i]
if cuda:
b = b.to('cuda')
batch_data.meta_data['src_id_pos'].to('cuda')
batch_data.meta_data['dst_pos_pos'].to('cuda')
batch_data.meta_data['dst_neg_pos'].to('cuda')
j = 0
if(batch_data.x is not None and i == len(batch_data.eids)-1 ):
if pinned:
idx = b.srcdata['ID'].cpu().long()
torch.index_select(batch_data.x, 0, idx, out=nfeat_buffs[i][:idx.shape[0]])
b.srcdata['h'] = nfeat_buffs[j][:idx.shape[0]].cuda(non_blocking=True)
j += 1
else:
srch = batch_data.x[b.srcdata['ID'].long()].float()
b.srcdata['h'] = srch.cuda() if cuda else srch
j = 0
if(batch_data.edge_attr is not None):
if pinned:
idx = b.edata['ID'].cpu().long()
efeat_buffs[j][idx.shape[0]] = batch_data.edge_attr[idx]
b.edata['f'] = efeat_buffs[i][:idx.shape[0]].cuda(non_blocking=True)
j += 1
else:
srch = batch_data.edge_attr[b.edata['ID'].long()].float()
b.edata['f'] = srch.cuda() if cuda else srch
mfgs.append(b)
mfgs = list(map(list, zip(*[iter(mfgs)] * hist)))
mfgs.reverse()
return mfgs,batch_data.meta_data,uniq[0,:].long()
def prepare_input(mfgs, node_feats, edge_feats, combine_first=False, pinned=False, nfeat_buffs=None, efeat_buffs=None, nids=None, eids=None):
if combine_first:
for i in range(len(mfgs[0])):
if mfgs[0][i].num_src_nodes() > mfgs[0][i].num_dst_nodes():
num_dst = mfgs[0][i].num_dst_nodes()
ts = mfgs[0][i].srcdata['ts'][num_dst:]
nid = mfgs[0][i].srcdata['ID'][num_dst:].float()
nts = torch.stack([ts, nid], dim=1)
unts, idx = torch.unique(nts, dim=0, return_inverse=True)
uts = unts[:, 0]
unid = unts[:, 1]
# import pdb; pdb.set_trace()
b = dgl.create_block((idx + num_dst, mfgs[0][i].edges()[1]), num_src_nodes=unts.shape[0] + num_dst, num_dst_nodes=num_dst, device=torch.device('cuda:0'))
b.srcdata['ts'] = torch.cat([mfgs[0][i].srcdata['ts'][:num_dst], uts], dim=0)
b.srcdata['ID'] = torch.cat([mfgs[0][i].srcdata['ID'][:num_dst], unid], dim=0)
b.edata['dt'] = mfgs[0][i].edata['dt']
b.edata['ID'] = mfgs[0][i].edata['ID']
mfgs[0][i] = b
t_idx = 0
t_cuda = 0
i = 0
if node_feats is not None:
for b in mfgs[0]:
if pinned:
if nids is not None:
idx = nids[i]
else:
idx = b.srcdata['ID'].cpu().long()
torch.index_select(node_feats, 0, idx, out=nfeat_buffs[i][:idx.shape[0]])
b.srcdata['h'] = nfeat_buffs[i][:idx.shape[0]].cuda(non_blocking=True)
i += 1
else:
srch = node_feats[b.srcdata['ID'].long()].float()
b.srcdata['h'] = srch.cuda()
i = 0
if edge_feats is not None:
for mfg in mfgs:
for b in mfg:
if b.num_src_nodes() > b.num_dst_nodes():
if pinned:
if eids is not None:
idx = eids[i]
else:
idx = b.edata['ID'].cpu().long()
torch.index_select(edge_feats, 0, idx, out=efeat_buffs[i][:idx.shape[0]])
b.edata['f'] = [i][:idx.shape[0]].cuda(non_blocking=True)
i += 1
else:
srch = edge_feats[b.edata['ID'].long()].float()
b.edata['f'] = srch.cuda()
return mfgs
def get_ids(mfgs, node_feats, edge_feats):
nids = list()
eids = list()
if node_feats is not None:
for b in mfgs[0]:
nids.append(b.srcdata['ID'].long())
if edge_feats is not None:
for mfg in mfgs:
for b in mfg:
eids.append(b.edata['ID'].long())
return nids, eids
def get_pinned_buffers(sample_param, batch_size, node_feats, edge_feats):
pinned_nfeat_buffs = list()
pinned_efeat_buffs = list()
limit = int(batch_size * 3.3)
if 'neighbor' in sample_param:
for i in sample_param['neighbor']:
limit *= i + 1
if edge_feats is not None:
for _ in range(sample_param['history']):
pinned_efeat_buffs.insert(0, torch.zeros((limit, edge_feats.shape[1]), pin_memory=True))
if node_feats is not None:
for _ in range(sample_param['history']):
pinned_nfeat_buffs.insert(0, torch.zeros((limit, node_feats.shape[1]), pin_memory=True))
return pinned_nfeat_buffs, pinned_efeat_buffs
This source diff could not be displayed because it is too large. You can view the blob instead.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment