Commit 47ea208c by zhlj

demo with tgl

# Configuration files
.idea/
# Data
startGNN_sample/cora/
startGNN_sample/wiki/
startGNN_sample/GDELT/
startGNN_sample/reddit/
startGNN_sample/*/metis_*/
tgl/DATA/
dataset/
bak_startGNN_sample/
*.zip
*.tar.gz
*.npy
*.csv
*.my
# Logs and output
logs/
log/
*.pt
# Byte-compiled / optimized / DLL files
__pycache__/
*/__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# Other
*.png
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python: Current File",
"type": "python",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"args":["--world","1","--rank","0"]
}
]
}
import argparse
import os
import sys
import time
from os.path import abspath, dirname, join

sys.path.insert(0, join(abspath(dirname(__file__)), 'startGNN_sample'))

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP

from startGNN_sample.DistGraphLoader import DataSet, partition_load
from startGNN_sample.DistGraphLoader import DistGraphData
from startGNN_sample.DistGraphLoader import DistributedDataLoader
from startGNN_sample.DistGraphLoader import DistCustomPool
from startGNN_sample.Sample.neighbor_sampler import NeighborSampler
from startGNN_sample.Sample.base import NegativeSampling
from startGNN_sample.shared_mailbox import SharedMailBox, SharedRPCMemoryManager
import startGNN_sample.distparser as distparser
"""
test command
python test.py --world_size 2 --rank 0
--world_size', default=4, type=int, metavar='W',
help='number of workers')
parser.add_argument('--rank', default=0, type=int, metavar='W',
help='rank of the worker')
parser.add_argument('--log_interval', type=int, default=10, metavar='N',
help='interval between training status logs')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
help='how much to value future rewards')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed for reproducibility')
parser.add_argument('--num_sampler', type=int, default=10, metavar='S',
help='number of samplers')
parser.add_argument('--queue_size', type=int, default=10, metavar='S',
help='sampler queue size')
"""
from tgl.modules import *
from tgl.sampler import *
from tgl.utils import *
from sklearn.metrics import average_precision_score, roc_auc_score
import random
import dgl
import numpy as np
import matplotlib.pyplot as plt
import math
def main():
    sample_param, memory_param, gnn_param, train_param = parse_config('tgl/config/TGN.yml')
    use_cuda = True
    torch.set_num_threads(12)
    args = distparser.args
    rank = distparser._get_worker_rank()
    if use_cuda:
        torch.cuda.set_device(rank)
    DistCustomPool.init_distribution('127.0.0.1', 9675, '127.0.0.1', 10023, backend="nccl")
    pdata = partition_load("./startGNN_sample/GDELT", algo="metis")
    graph = DistGraphData(pdata=pdata, edge_index=pdata.edge_index, full_edge=False)
    print(graph.edge_index[0, :].shape[0], graph.partptr)
    value, counts = torch.unique(graph.edge_index[1, :].view(-1), return_counts=True)
    counts, ind = counts.sort(descending=True)
    #print(counts[:500])
    plt.plot(counts)
    plt.savefig('degree-{}.png'.format(rank))
    #print(counts)
    #node = counts
    #plt.hist(node)
    #plt.title('Histogram of Random Data')
    #plt.xlabel('Value')
    #plt.ylabel('Frequency')
    print(counts[:2048])
    ##print(node.shape[1],node.unique(dim=1).shape[1])
    # Add title and axis labels
    #plt.title('Histogram of Random Data')
    #plt.xlabel('Value')
    #plt.ylabel('Frequency')
    DistCustomPool.close_distribution()


if __name__ == "__main__":
    main()
startGNN_sample @ a2522497
Subproject commit a2522497a20e77733d562db176b83475a5224610
{
"configurations": [
{
"name": "linux-gcc-x64",
"includePath": [
"${workspaceFolder}/**"
],
"compilerPath": "/usr/bin/gcc",
"cStandard": "${default}",
"cppStandard": "${default}",
"intelliSenseMode": "linux-gcc-x64",
"compilerArgs": [
""
]
}
],
"version": 4
}
{
"version": "0.2.0",
"configurations": [
/*{
"name": "C/C++ Runner: Debug Session",
"type": "cppdbg",
"request": "launch",
"args": [],
"stopAtEntry": false,
"externalConsole": false,
"cwd": "/home/sxx/zlj/tgl/tgl-main",
"program": "/home/sxx/zlj/tgl/tgl-main/build/Debug/outDebug",
"MIMode": "gdb",
"miDebuggerPath": "gdb",
"setupCommands": [
{
"description": "Enable pretty-printing for gdb",
"text": "-enable-pretty-printing",
"ignoreFailures": true
}
]
},*/
{
"name": "Python: Current File",
"type": "python",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"justMyCode": false,
"args": [
"--data","/WIKI","--config","config/TGN.yml"
]
}
]
}
{
"files.associations": {
"cctype": "cpp",
"cmath": "cpp",
"cstdarg": "cpp",
"cstddef": "cpp",
"cstdio": "cpp",
"cstdlib": "cpp",
"cstring": "cpp",
"ctime": "cpp",
"cwchar": "cpp",
"cwctype": "cpp",
"array": "cpp",
"atomic": "cpp",
"*.tcc": "cpp",
"bitset": "cpp",
"chrono": "cpp",
"cstdint": "cpp",
"unordered_map": "cpp",
"vector": "cpp",
"exception": "cpp",
"algorithm": "cpp",
"functional": "cpp",
"iterator": "cpp",
"memory": "cpp",
"memory_resource": "cpp",
"numeric": "cpp",
"random": "cpp",
"string": "cpp",
"string_view": "cpp",
"system_error": "cpp",
"tuple": "cpp",
"type_traits": "cpp",
"utility": "cpp",
"initializer_list": "cpp",
"iosfwd": "cpp",
"iostream": "cpp",
"istream": "cpp",
"limits": "cpp",
"mutex": "cpp",
"new": "cpp",
"ostream": "cpp",
"stdexcept": "cpp",
"streambuf": "cpp",
"typeinfo": "cpp",
"bit": "cpp",
"__bit_reference": "cpp",
"__bits": "cpp",
"__config": "cpp",
"__debug": "cpp",
"__errc": "cpp",
"__locale": "cpp",
"__split_buffer": "cpp",
"__threading_support": "cpp",
"__tuple": "cpp",
"__verbose_abort": "cpp",
"ios": "cpp",
"locale": "cpp"
},
"C_Cpp_Runner.cCompilerPath": "gcc",
"C_Cpp_Runner.cppCompilerPath": "g++",
"C_Cpp_Runner.debuggerPath": "gdb",
"C_Cpp_Runner.cStandard": "",
"C_Cpp_Runner.cppStandard": "",
"C_Cpp_Runner.msvcBatchPath": "",
"C_Cpp_Runner.useMsvc": false,
"C_Cpp_Runner.warnings": [
"-Wall",
"-Wextra",
"-Wpedantic",
"-Wshadow",
"-Wformat=2",
"-Wconversion",
"-Wnull-dereference",
"-Wsign-conversion"
],
"C_Cpp_Runner.enableWarnings": true,
"C_Cpp_Runner.warningsAsError": false,
"C_Cpp_Runner.compilerArgs": [],
"C_Cpp_Runner.linkerArgs": [],
"C_Cpp_Runner.includePaths": [],
"C_Cpp_Runner.includeSearch": [
"*",
"**/*"
],
"C_Cpp_Runner.excludeSearch": [
"**/build",
"**/build/**",
"**/.*",
"**/.*/**",
"**/.vscode",
"**/.vscode/**"
]
}
## Code of Conduct
This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
opensource-codeofconduct@amazon.com with any additional questions or comments.
# Contributing Guidelines
Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional
documentation, we greatly value feedback and contributions from our community.
Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
information to effectively respond to your bug report or contribution.
## Reporting Bugs/Feature Requests
We welcome you to use the GitHub issue tracker to report bugs or suggest features.
When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already
reported the issue. Please try to include as much information as you can. Details like these are incredibly useful:
* A reproducible test case or series of steps
* The version of our code being used
* Any modifications you've made relevant to the bug
* Anything unusual about your environment or deployment
## Contributing via Pull Requests
Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:
1. You are working against the latest source on the *main* branch.
2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
3. You open an issue to discuss any significant work - we would hate for your time to be wasted.
To send us a pull request, please:
1. Fork the repository.
2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
3. Ensure local tests pass.
4. Commit to your fork using clear commit messages.
5. Send us a pull request, answering any default questions in the pull request interface.
6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.
GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
[creating a pull request](https://help.github.com/articles/creating-a-pull-request/).
## Finding contributions to work on
Looking at the existing issues is a great way to find something to contribute to. As our projects use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start.
## Code of Conduct
This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
opensource-codeofconduct@amazon.com with any additional questions or comments.
## Security issue notifications
If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue.
## Licensing
See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# TGL: A General Framework for Temporal Graph Training on Billion-Scale Graphs
## Overview
This repo is the open-sourced code for our work *TGL: A General Framework for Temporal Graph Training on Billion-Scale Graphs*.
## Requirements
- python >= 3.6.13
- pytorch >= 1.8.1
- pandas >= 1.1.5
- numpy >= 1.19.5
- dgl >= 0.6.1
- pyyaml >= 5.4.1
- tqdm >= 4.61.0
- pybind11 >= 2.6.2
- g++ >= 7.5.0
- openmp >= 201511
Our temporal sampler is implemented in C++. Please compile the sampler first with the following command
> python setup.py build_ext --inplace
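A quick way to check that the extension built correctly is to import the compiled module (named `sampler_core` in `setup.py`):
>python -c "import sampler_core"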
## Dataset
The four datasets used in our paper are available for download from an AWS S3 bucket using the `down.sh` script. The total download size is around 350GB.
To use your own dataset, you need to put the following files in the folder `DATA/<NameOfYourDataset>/`:
1. `edges.csv`: The file that stores the temporal edge information. The CSV should have the header `,src,dst,time,ext_roll`, where the columns are the edge index (starting at zero), source node index (starting at zero), destination node index, timestamp, and extrapolation roll (0 for training edges, 1 for validation edges, 2 for test edges). The CSV should be sorted by time in ascending order (see the example after this list).
2. `ext_full.npz`: The T-CSR representation of the temporal graph. We provide a script that generates this file from `edges.csv`; run it with
>python gen_graph.py --data \<NameOfYourDataset>
3. `edge_features.pt` (optional): A torch tensor that stores the edge features row-wise with shape (num edges, dim edge features). *Note: at least one of `edge_features.pt` or `node_features.pt` must be present.*
4. `node_features.pt` (optional): A torch tensor that stores the node features row-wise with shape (num nodes, dim node features). *Note: at least one of `edge_features.pt` or `node_features.pt` must be present.*
5. `labels.csv` (optional): The file that stores node labels for the dynamic node classification task. The CSV should have the header `,node,time,label,ext_roll`, where the columns are the node label index (starting at zero), node index (starting at zero), timestamp, node label, and extrapolation roll (0 for training node labels, 1 for validation node labels, 2 for test node labels). The CSV should be sorted by time in ascending order.
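As an illustration, the following sketch writes a toy `edges.csv` in the expected layout (the values and the dataset name `MyDataset` are made up; only the column layout matters):
```
import pandas as pd

# three hypothetical temporal edges, already sorted by time
df = pd.DataFrame({
    'src':      [0, 1, 0],
    'dst':      [2, 2, 3],
    'time':     [0.0, 1.5, 3.0],
    'ext_roll': [0, 1, 2],   # 0 = train, 1 = val, 2 = test
})
df.to_csv('DATA/MyDataset/edges.csv')  # the unnamed index column becomes the edge index
```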
## Configuration Files
We provide example configuration files for five temporal GNN methods: JODIE, DySAT, TGAT, TGN and APAN. The configuration files for single-GPU training are located in `/config/`, while the multi-GPU training configuration files are located in `/config/dist/`.
The provided configuration files are all tested to be working. If you want to use your own network architecture, please refer to `/config/readme.yml` for the meaning of each entry in the YAML configuration file. As our framework is still under development, it is possible that some combinations of the configurations will lead to bugs.
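For reference, a configuration file can be loaded with standard YAML tooling. The following sketch (not necessarily the exact implementation of TGL's `parse_config`) shows how the four sections map to Python dictionaries:
```
import yaml

with open('config/TGN.yml', 'r') as f:
    conf = yaml.safe_load(f)
sample_param = conf['sampling'][0]   # e.g. strategy: 'recent'
memory_param = conf['memory'][0]     # e.g. memory_update: 'gru'
gnn_param = conf['gnn'][0]
train_param = conf['train'][0]
```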
## Run
Currently, our framework only supports the extrapolation setting (inference for the future).
### Single GPU Link Prediction
>python train.py --data \<NameOfYourDataset> --config \<PathToConfigFile>
### Multi-GPU Link Prediction
>python -m torch.distributed.launch --nproc_per_node=\<NumberOfGPUs+1> train_dist.py --data \<NameOfYourDataset> --config \<PathToConfigFile> --num_gpus \<NumberOfGPUs>
### Dynamic Node Classification
Currently, TGL only supports performing dynamic node classification using the dynamic node embeddings generated in link prediction.
For single-GPU models, directly run
>python train_node.py --data \<NameOfYourDataset> --config \<PathToConfigFile> --model \<PathToSavedModel>
For multi-GPU models, you need to first generate the dynamic node embedding
>python -m torch.distributed.launch --nproc_per_node=\<NumberOfGPUs+1> extract_node_dist.py --data \<NameOfYourDataset> --config \<PathToConfigFile> --num_gpus \<NumberOfGPUs> --model \<PathToSavedModel>
After generating the node embeddings for multi-GPU models, run
>python train_node.py --data \<NameOfYourDataset> --model \<PathToSavedModel>
## Security
See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information.
## Cite
If you use TGL in a scientific publication, we would appreciate citations to the following paper:
```
@article{zhou2022tgl,
title={{TGL}: A General Framework for Temporal GNN Training on Billion-Scale Graphs},
author={Zhou, Hongkuan and Zheng, Da and Nisa, Israt and Ioannidis, Vasileios and Song, Xiang and Karypis, George},
year = {2022},
journal = {Proc. VLDB Endow.},
volume = {15},
number = {8},
}
```
## License
This project is licensed under the Apache-2.0 License.
sampling:
- layer: 1
neighbor:
- 10
strategy: 'recent'
prop_time: False
history: 1
duration: 0
num_thread: 32
no_neg: True
memory:
- type: 'node'
dim_time: 100
deliver_to: 'neighbors'
mail_combine: 'last'
memory_update: 'transformer'
attention_head: 2
mailbox_size: 10
combine_node_feature: False
dim_out: 100
gnn:
- arch: 'identity'
train:
- epoch: 100
batch_size: 600
lr: 0.0001
dropout: 0.1
att_dropout: 0.1
# all_on_gpu: True
sampling:
- layer: 2
neighbor:
- 10
- 10
strategy: 'uniform'
prop_time: True
history: 3
duration: 10000
num_thread: 32
memory:
- type: 'none'
dim_out: 0
gnn:
- arch: 'transformer_attention'
layer: 2
att_head: 2
dim_time: 0
dim_out: 100
combine: 'rnn'
train:
- epoch: 50
batch_size: 600
lr: 0.0001
dropout: 0.1
att_dropout: 0.1
all_on_gpu: True
sampling:
- no_sample: True
history: 1
memory:
- type: 'node'
dim_time: 100
deliver_to: 'self'
mail_combine: 'last'
memory_update: 'rnn'
mailbox_size: 1
combine_node_feature: True
dim_out: 100
gnn:
- arch: 'identity'
time_transform: 'JODIE'
train:
- epoch: 100
batch_size: 600
lr: 0.0001
dropout: 0.1
all_on_gpu: True
sampling:
- layer: 2
neighbor:
- 10
- 10
strategy: 'uniform'
prop_time: False
history: 1
duration: 0
num_thread: 32
memory:
- type: 'none'
dim_out: 0
gnn:
- arch: 'transformer_attention'
layer: 2
att_head: 2
dim_time: 100
dim_out: 100
train:
- epoch: 10
batch_size: 600
lr: 0.0001
dropout: 0.1
att_dropout: 0.1
all_on_gpu: True
sampling:
- layer: 1
neighbor:
- 10
strategy: 'recent'
prop_time: False
history: 1
duration: 0
num_thread: 32
memory:
- type: 'node'
dim_time: 100
deliver_to: 'self'
mail_combine: 'last'
memory_update: 'gru'
mailbox_size: 1
combine_node_feature: True
dim_out: 100
gnn:
- arch: 'transformer_attention'
layer: 1
att_head: 2
dim_time: 100
dim_out: 100
train:
- epoch: 4
#batch_size: 100
# reorder: 16
lr: 0.0001
dropout: 0.2
att_dropout: 0.2
all_on_gpu: True
sampling:
- layer: <number of layers to sample>
neighbor: <a list of integers indicating how many neighbors are sampled in each layer>
strategy: <'recent' that samples the most recent neighbors, or 'uniform' that uniformly samples neighbors from the past>
prop_time: <False or True that specifies whether to use the timestamp of the root nodes when sampling their multi-hop neighbors>
history: <number of snapshots to sample on>
duration: <length in time of each snapshot, 0 for infinite length (used in non-snapshot-based methods)>
num_thread: <number of threads of the sampler>
memory:
- type: <'node', we only support node memory now>
dim_time: <an integer, the dimension of the time embedding>
deliver_to: <'self' that delivers the mails only to the involved nodes, or 'neighbors' that delivers the mails to their neighbors>
mail_combine: <'last' that uses the latest mail as the input to the memory updater>
memory_update: <'gru' or 'rnn'>
mailbox_size: <an integer, the size of the mailbox for each node>
combine_node_feature: <False or True that specifies whether to combine node features (with the updated memory) as the input to the GNN>
dim_out: <an integer, the dimension of the output node memory>
gnn:
- arch: <'transformer_attention' or 'identity' (no GNN)>
layer: <an integer, number of layers>
att_head: <an integer, number of attention heads>
dim_time: <an integer, the dimension of the time embedding>
dim_out: <an integer, the dimension of the output dynamic node embedding>
train:
- epoch: <an integer, number of epochs to train>
batch_size: <an integer, the batch size (of edges); for multi-gpu training, this is the local batchsize>
reorder: <(optional) an integer, divisible by the batch size, that specifies how many chunks per batch are used in the random chunk scheduling>
lr: <floating point, learning rate>
dropout: <floating point, dropout>
att_dropout: <floating point, dropout for attention>
all_on_gpu: <False or True that decides if the node/edge features and node memory are completely stored on GPU>
#!/bin/bash
wget -P ./DATA/GDELT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/int_train.npz
wget -P ./DATA/GDELT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/int_full.npz
wget -P ./DATA/GDELT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/node_features.pt
wget -P ./DATA/GDELT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/labels.csv
wget -P ./DATA/GDELT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/ext_full.npz
wget -P ./DATA/GDELT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/edges.csv
wget -P ./DATA/GDELT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/edge_features.pt
wget -P ./DATA/LASTFM https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/LASTFM/edges.csv
wget -P ./DATA/LASTFM https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/LASTFM/ext_full.npz
wget -P ./DATA/LASTFM https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/LASTFM/int_full.npz
wget -P ./DATA/LASTFM https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/LASTFM/int_train.npz
wget -P ./DATA/MAG https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MAG/int_train.npz
wget -P ./DATA/MAG https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MAG/labels.csv
wget -P ./DATA/MAG https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MAG/int_full.npz
wget -P ./DATA/MAG https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MAG/ext_full.npz
wget -P ./DATA/MAG https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MAG/edges.csv
wget -P ./DATA/MAG https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MAG/node_features.pt
wget -P ./DATA/MOOC https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MOOC/edges.csv
wget -P ./DATA/MOOC https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MOOC/ext_full.npz
wget -P ./DATA/MOOC https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MOOC/int_full.npz
wget -P ./DATA/MOOC https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/MOOC/int_train.npz
wget -P ./DATA/REDDIT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/REDDIT/edge_features.pt
wget -P ./DATA/REDDIT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/REDDIT/edges.csv
wget -P ./DATA/REDDIT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/REDDIT/ext_full.npz
wget -P ./DATA/REDDIT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/REDDIT/int_full.npz
wget -P ./DATA/REDDIT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/REDDIT/int_train.npz
wget -P ./DATA/REDDIT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/REDDIT/labels.csv
wget -P ./DATA/WIKI https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/WIKI/edge_features.pt
wget -P ./DATA/WIKI https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/WIKI/edges.csv
wget -P ./DATA/WIKI https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/WIKI/ext_full.npz
wget -P ./DATA/WIKI https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/WIKI/int_full.npz
wget -P ./DATA/WIKI https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/WIKI/int_train.npz
wget -P ./DATA/WIKI https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/WIKI/labels.csv
#!/bin/bash
wget -P ./DATA/GDELT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/int_train.npz
wget -P ./DATA/GDELT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/int_full.npz
wget -P ./DATA/GDELT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/node_features.pt
wget -P ./DATA/GDELT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/labels.csv
wget -P ./DATA/GDELT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/ext_full.npz
wget -P ./DATA/GDELT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/edges.csv
wget -P ./DATA/GDELT https://s3.us-west-2.amazonaws.com/dgl-data/dataset/tgl/GDELT/edge_features.pt
import argparse
import itertools
import pandas as pd
import numpy as np
from tqdm import tqdm
parser=argparse.ArgumentParser()
parser.add_argument('--data', type=str, help='dataset name')
parser.add_argument('--add_reverse', default=False, action='store_true')
args=parser.parse_args()
df = pd.read_csv('DATA/{}/edges.csv'.format(args.data))
num_nodes = max(int(df['src'].max()), int(df['dst'].max())) + 1
print('num_nodes: ', num_nodes)
int_train_indptr = np.zeros(num_nodes + 1, dtype=np.int64)
int_train_indices = [[] for _ in range(num_nodes)]
int_train_ts = [[] for _ in range(num_nodes)]
int_train_eid = [[] for _ in range(num_nodes)]
int_full_indptr = np.zeros(num_nodes + 1, dtype=np.int64)
int_full_indices = [[] for _ in range(num_nodes)]
int_full_ts = [[] for _ in range(num_nodes)]
int_full_eid = [[] for _ in range(num_nodes)]
ext_full_indptr = np.zeros(num_nodes + 1, dtype=np.int64)
ext_full_indices = [[] for _ in range(num_nodes)]
ext_full_ts = [[] for _ in range(num_nodes)]
ext_full_eid = [[] for _ in range(num_nodes)]
for idx, row in tqdm(df.iterrows(), total=len(df)):
src = int(row['src'])
dst = int(row['dst'])
if row['int_roll'] == 0:
int_train_indices[src].append(dst)
int_train_ts[src].append(row['time'])
int_train_eid[src].append(idx)
if args.add_reverse:
int_train_indices[dst].append(src)
int_train_ts[dst].append(row['time'])
int_train_eid[dst].append(idx)
# int_train_indptr[src + 1:] += 1
if row['int_roll'] != 3:
int_full_indices[src].append(dst)
int_full_ts[src].append(row['time'])
int_full_eid[src].append(idx)
if args.add_reverse:
int_full_indices[dst].append(src)
int_full_ts[dst].append(row['time'])
int_full_eid[dst].append(idx)
# int_full_indptr[src + 1:] += 1
ext_full_indices[src].append(dst)
ext_full_ts[src].append(row['time'])
ext_full_eid[src].append(idx)
if args.add_reverse:
ext_full_indices[dst].append(src)
ext_full_ts[dst].append(row['time'])
ext_full_eid[dst].append(idx)
# ext_full_indptr[src + 1:] += 1
for i in tqdm(range(num_nodes)):
int_train_indptr[i + 1] = int_train_indptr[i] + len(int_train_indices[i])
int_full_indptr[i + 1] = int_full_indptr[i] + len(int_full_indices[i])
ext_full_indptr[i + 1] = ext_full_indptr[i] + len(ext_full_indices[i])
int_train_indices = np.array(list(itertools.chain(*int_train_indices)))
int_train_ts = np.array(list(itertools.chain(*int_train_ts)))
int_train_eid = np.array(list(itertools.chain(*int_train_eid)))
int_full_indices = np.array(list(itertools.chain(*int_full_indices)))
int_full_ts = np.array(list(itertools.chain(*int_full_ts)))
int_full_eid = np.array(list(itertools.chain(*int_full_eid)))
ext_full_indices = np.array(list(itertools.chain(*ext_full_indices)))
ext_full_ts = np.array(list(itertools.chain(*ext_full_ts)))
ext_full_eid = np.array(list(itertools.chain(*ext_full_eid)))
print('Sorting...')
def tsort(i, indptr, indices, t, eid):
beg = indptr[i]
end = indptr[i + 1]
sidx = np.argsort(t[beg:end])
indices[beg:end] = indices[beg:end][sidx]
t[beg:end] = t[beg:end][sidx]
eid[beg:end] = eid[beg:end][sidx]
for i in tqdm(range(int_train_indptr.shape[0] - 1)):
tsort(i, int_train_indptr, int_train_indices, int_train_ts, int_train_eid)
tsort(i, int_full_indptr, int_full_indices, int_full_ts, int_full_eid)
tsort(i, ext_full_indptr, ext_full_indices, ext_full_ts, ext_full_eid)
# import pdb; pdb.set_trace()
print('saving...')
np.savez('DATA/{}/int_train.npz'.format(args.data), indptr=int_train_indptr, indices=int_train_indices, ts=int_train_ts, eid=int_train_eid)
np.savez('DATA/{}/int_full.npz'.format(args.data), indptr=int_full_indptr, indices=int_full_indices, ts=int_full_ts, eid=int_full_eid)
np.savez('DATA/{}/ext_full.npz'.format(args.data), indptr=ext_full_indptr, indices=ext_full_indices, ts=ext_full_ts, eid=ext_full_eid)
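# Example invocation (dataset name taken from down.sh; adjust to your own data):
#   python gen_graph.py --data WIKI
# Optionally pass --add_reverse to also index each edge in the reverse direction.
# The script writes int_train.npz, int_full.npz and ext_full.npz under DATA/<data>/.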
name: gnn
channels:
- pyg
- pytorch
- conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- anyio=3.6.2=pyhd8ed1ab_0
- appdirs=1.4.4=pyhd3eb1b0_0
- argon2-cffi=21.3.0=pyhd8ed1ab_0
- argon2-cffi-bindings=21.2.0=py38h0a891b7_2
- asttokens=2.2.1=pyhd8ed1ab_0
- attrs=22.2.0=pyh71513ae_0
- backcall=0.2.0=pyh9f0ad1d_0
- backports=1.0=pyhd8ed1ab_3
- backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0
- beautifulsoup4=4.11.2=pyha770c72_0
- blas=1.0=mkl
- bleach=6.0.0=pyhd8ed1ab_0
- brotlipy=0.7.0=py38h0a891b7_1004
- bzip2=1.0.8=h7f98852_4
- ca-certificates=2022.12.7=ha878542_0
- cffi=1.15.0=py38h3931269_0
- charset-normalizer=2.1.1=pyhd8ed1ab_0
- cryptography=37.0.2=py38h2b5fc30_0
- cudatoolkit=11.6.0=hecad31d_10
- decorator=5.1.1=pyhd8ed1ab_0
- defusedxml=0.7.1=pyhd8ed1ab_0
- entrypoints=0.4=pyhd8ed1ab_0
- executing=1.2.0=pyhd8ed1ab_0
- ffmpeg=4.3=hf484d3e_0
- fftw=3.3.9=h27cfd23_1
- flit-core=3.8.0=pyhd8ed1ab_0
- freetype=2.10.4=h0708190_1
- gmp=6.2.1=h58526e2_0
- gnutls=3.6.13=h85f3911_1
- idna=3.4=pyhd8ed1ab_0
- importlib-metadata=6.0.0=pyha770c72_0
- importlib_resources=5.10.2=pyhd8ed1ab_0
- intel-openmp=2021.4.0=h06a4308_3561
- ipykernel=5.5.5=py38hd0cf306_0
- ipython=8.10.0=pyh41d4057_0
- ipython_genutils=0.2.0=py_1
- ipywidgets=8.0.4=pyhd8ed1ab_0
- jedi=0.18.2=pyhd8ed1ab_0
- jinja2=3.1.2=py38h06a4308_0
- joblib=1.1.1=py38h06a4308_0
- jpeg=9e=h166bdaf_1
- jsonschema=4.17.3=pyhd8ed1ab_0
- jupyter_client=7.0.6=pyhd8ed1ab_0
- jupyter_core=5.2.0=py38h578d9bd_0
- jupyter_server=1.23.5=pyhd8ed1ab_0
- jupyterlab_pygments=0.2.2=pyhd8ed1ab_0
- jupyterlab_widgets=3.0.5=pyhd8ed1ab_0
- lame=3.100=h7f98852_1001
- lcms2=2.12=hddcbb42_0
- ld_impl_linux-64=2.38=h1181459_1
- libffi=3.4.2=h6a678d5_6
- libgcc-ng=11.2.0=h1234567_1
- libgfortran-ng=11.2.0=h00389a5_1
- libgfortran5=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libiconv=1.17=h166bdaf_0
- libpng=1.6.37=h21135ba_2
- libsodium=1.0.18=h36c2ea0_1
- libstdcxx-ng=11.2.0=h1234567_1
- libtiff=4.2.0=hf544144_3
- libwebp-base=1.2.2=h7f98852_1
- lz4-c=1.9.3=h9c3ff4c_1
- markupsafe=2.1.1=py38h7f8727e_0
- matplotlib-inline=0.1.6=pyhd8ed1ab_0
- mistune=2.0.5=pyhd8ed1ab_0
- mkl=2021.4.0=h06a4308_640
- mkl-service=2.4.0=py38h95df7f1_0
- mkl_fft=1.3.1=py38h8666266_1
- mkl_random=1.2.2=py38h1abd341_0
- nbclassic=0.5.1=pyhd8ed1ab_0
- nbclient=0.7.2=pyhd8ed1ab_0
- nbconvert=7.2.9=pyhd8ed1ab_0
- nbconvert-core=7.2.9=pyhd8ed1ab_0
- nbconvert-pandoc=7.2.9=pyhd8ed1ab_0
- nbformat=5.7.3=pyhd8ed1ab_0
- ncurses=6.4=h6a678d5_0
- nest-asyncio=1.5.6=pyhd8ed1ab_0
- nettle=3.6=he412f7d_0
- networkx=2.8.4=py38h06a4308_0
- notebook=6.5.2=pyha770c72_1
- notebook-shim=0.2.2=pyhd8ed1ab_0
- numpy=1.23.5=py38h14f4228_0
- numpy-base=1.23.5=py38h31eccc5_0
- olefile=0.46=pyh9f0ad1d_1
- openh264=2.1.1=h780b84a_0
- openjpeg=2.4.0=hb52868f_1
- openssl=1.1.1t=h7f8727e_0
- packaging=22.0=py38h06a4308_0
- pandoc=2.19.2=ha770c72_0
- pandocfilters=1.5.0=pyhd8ed1ab_0
- parso=0.8.3=pyhd8ed1ab_0
- pexpect=4.8.0=pyh1a96a4e_2
- pickleshare=0.7.5=py_1003
- pillow=8.2.0=py38ha0e1e83_1
- pip=22.3.1=py38h06a4308_0
- pkgutil-resolve-name=1.3.10=pyhd8ed1ab_0
- platformdirs=3.0.0=pyhd8ed1ab_0
- pooch=1.4.0=pyhd3eb1b0_0
- prometheus_client=0.16.0=pyhd8ed1ab_0
- prompt-toolkit=3.0.36=pyha770c72_0
- psutil=5.9.0=py38h5eee18b_0
- ptyprocess=0.7.0=pyhd3deb0d_0
- pure_eval=0.2.2=pyhd8ed1ab_0
- pycparser=2.21=pyhd8ed1ab_0
- pyg=2.2.0=py38_torch_1.12.0_cu116
- pygments=2.14.0=pyhd8ed1ab_0
- pyopenssl=22.0.0=pyhd8ed1ab_1
- pyparsing=3.0.9=py38h06a4308_0
- pyrsistent=0.18.0=py38heee7806_0
- pysocks=1.7.1=pyha2e5f31_6
- python=3.8.16=h7a1cb2a_2
- python-dateutil=2.8.2=pyhd8ed1ab_0
- python-fastjsonschema=2.16.2=pyhd8ed1ab_0
- python_abi=3.8=2_cp38
- pytorch=1.12.1=py3.8_cuda11.6_cudnn8.3.2_0
- pytorch-cluster=1.6.0=py38_torch_1.12.0_cu116
- pytorch-mutex=1.0=cuda
- pytorch-scatter=2.1.0=py38_torch_1.12.0_cu116
- pytorch-sparse=0.6.16=py38_torch_1.12.0_cu116
- pyzmq=19.0.2=py38ha71036d_2
- readline=8.2=h5eee18b_0
- requests=2.28.2=pyhd8ed1ab_0
- scikit-learn=1.2.0=py38h6a678d5_0
- scipy=1.10.0=py38h14f4228_0
- send2trash=1.8.0=pyhd8ed1ab_0
- setuptools=65.6.3=py38h06a4308_0
- six=1.16.0=pyh6c4a22f_0
- sniffio=1.3.0=pyhd8ed1ab_0
- soupsieve=2.3.2.post1=pyhd8ed1ab_0
- sqlite=3.40.1=h5082296_0
- stack_data=0.6.2=pyhd8ed1ab_0
- terminado=0.17.1=pyh41d4057_0
- threadpoolctl=2.2.0=pyh0d69192_0
- tinycss2=1.2.1=pyhd8ed1ab_0
- tk=8.6.12=h1ccaba5_0
- torchaudio=0.12.1=py38_cu116
- torchvision=0.13.1=py38_cu116
- tornado=6.1=py38h0a891b7_3
- tqdm=4.64.1=py38h06a4308_0
- traitlets=5.9.0=pyhd8ed1ab_0
- typing-extensions=4.4.0=hd8ed1ab_0
- typing_extensions=4.4.0=pyha770c72_0
- urllib3=1.26.14=pyhd8ed1ab_0
- wcwidth=0.2.6=pyhd8ed1ab_0
- webencodings=0.5.1=py_1
- websocket-client=1.5.1=pyhd8ed1ab_0
- wheel=0.37.1=pyhd3eb1b0_0
- widgetsnbextension=4.0.5=pyhd8ed1ab_0
- xz=5.2.10=h5eee18b_1
- zeromq=4.3.4=h9c3ff4c_1
- zipp=3.13.0=pyhd8ed1ab_0
- zlib=1.2.13=h5eee18b_0
- zstd=1.5.0=ha95c52a_0
- pip:
- alabaster==0.7.13
- autopep8==2.0.2
- babel==2.12.1
- certifi==2023.5.7
- click==8.1.3
- contourpy==1.0.7
- cycler==0.11.0
- dgl==1.0.2
- dglgo==0.0.2
- docutils==0.19
- exceptiongroup==1.1.1
- flameprof==0.4
- fonttools==4.39.3
- hydra-core==1.3.2
- imagesize==1.4.1
- iniconfig==2.0.0
- isort==5.12.0
- kiwisolver==1.4.4
- littleutils==0.2.2
- lltm-cpp==0.0.0
- markdown-it-py==2.2.0
- matmul-cpu-pytorch==0.0.0
- matmul-gpu==0.0.0
- matplotlib==3.7.1
- mdurl==0.1.2
- memray==1.7.0
- numpydoc==1.5.0
- ogb==1.3.5
- omegaconf==2.3.0
- outdated==0.2.2
- pandas==1.5.3
- pluggy==1.0.0
- presample-cores==0.0.0
- pybind11==2.10.4
- pycodestyle==2.10.0
- pydantic==1.10.7
- pytest==7.2.2
- pytz==2022.7.1
- pyyaml==6.0
- rdkit-pypi==2022.9.5
- rich==13.4.1
- ruamel-yaml==0.17.21
- ruamel-yaml-clib==0.2.7
- sample-cores==0.0.0
- snowballstemmer==2.2.0
- sphinx==6.1.3
- sphinxcontrib-applehelp==1.0.4
- sphinxcontrib-devhelp==1.0.2
- sphinxcontrib-htmlhelp==2.0.1
- sphinxcontrib-jsmath==1.0.1
- sphinxcontrib-qthelp==1.0.3
- sphinxcontrib-serializinghtml==1.1.5
- tomli==2.0.1
- torch-geometric-autoscale==0.0.0
- torch-quiver==0.1.1
- typer==0.7.0
prefix: /home/sxx/miniconda3/envs/gnn
import os
import sys
from os.path import abspath, dirname
sys.path.insert(0, abspath(dirname(__file__)))
import torch
import dgl
import math
import numpy as np
class TimeEncode(torch.nn.Module):
def __init__(self, dim):
super(TimeEncode, self).__init__()
self.dim = dim
self.w = torch.nn.Linear(1, dim)
self.w.weight = torch.nn.Parameter((torch.from_numpy(1 / 10 ** np.linspace(0, 9, dim, dtype=np.float32))).reshape(dim, -1))
self.w.bias = torch.nn.Parameter(torch.zeros(dim))
def forward(self, t):
output = torch.cos(self.w(t.float().reshape((-1, 1))))
return output
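# Usage sketch (shapes only, with made-up numbers): given a 1-D tensor of time
# deltas, TimeEncode returns an (N, dim) cosine feature matrix, e.g.
#   enc = TimeEncode(100)
#   feat = enc(torch.tensor([0.0, 1.0, 10.0]))   # feat.shape == (3, 100)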
class EdgePredictor(torch.nn.Module):
def __init__(self, dim_in):
super(EdgePredictor, self).__init__()
self.dim_in = dim_in
self.src_fc = torch.nn.Linear(dim_in, dim_in)
self.dst_fc = torch.nn.Linear(dim_in, dim_in)
self.out_fc = torch.nn.Linear(dim_in, 1)
def forward(self, h, neg_samples=1):
num_edge = h.shape[0] // (neg_samples + 2)
h_src = self.src_fc(h[:num_edge])
h_pos_dst = self.dst_fc(h[num_edge:2 * num_edge])
h_neg_dst = self.dst_fc(h[2 * num_edge:])
h_pos_edge = torch.nn.functional.relu(h_src + h_pos_dst)
h_neg_edge = torch.nn.functional.relu(h_src.tile(neg_samples, 1) + h_neg_dst)
return self.out_fc(h_pos_edge), self.out_fc(h_neg_edge)
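# Input layout (inferred from the slicing above): `h` stacks the embeddings of the
# source nodes, the positive destination nodes, and `neg_samples` groups of negative
# destination nodes, i.e. h = cat([h_src, h_pos_dst, h_neg_dst]) with
# h_src.shape[0] == num_edge, so forward() returns positive logits of shape
# (num_edge, 1) and negative logits of shape (num_edge * neg_samples, 1).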
class TransfomerAttentionLayer(torch.nn.Module):
def __init__(self, dim_node_feat, dim_edge_feat, dim_time, num_head, dropout, att_dropout, dim_out, combined=False):
super(TransfomerAttentionLayer, self).__init__()
self.num_head = num_head
self.dim_node_feat = dim_node_feat
self.dim_edge_feat = dim_edge_feat
self.dim_time = dim_time
self.dim_out = dim_out
self.dropout = torch.nn.Dropout(dropout)
self.att_dropout = torch.nn.Dropout(att_dropout)
self.att_act = torch.nn.LeakyReLU(0.2)
self.combined = combined
if dim_time > 0:
self.time_enc = TimeEncode(dim_time)
if combined:
if dim_node_feat > 0:
self.w_q_n = torch.nn.Linear(dim_node_feat, dim_out)
self.w_k_n = torch.nn.Linear(dim_node_feat, dim_out)
self.w_v_n = torch.nn.Linear(dim_node_feat, dim_out)
if dim_edge_feat > 0:
self.w_k_e = torch.nn.Linear(dim_edge_feat, dim_out)
self.w_v_e = torch.nn.Linear(dim_edge_feat, dim_out)
if dim_time > 0:
self.w_q_t = torch.nn.Linear(dim_time, dim_out)
self.w_k_t = torch.nn.Linear(dim_time, dim_out)
self.w_v_t = torch.nn.Linear(dim_time, dim_out)
else:
if dim_node_feat + dim_time > 0:
self.w_q = torch.nn.Linear(dim_node_feat + dim_time, dim_out)
self.w_k = torch.nn.Linear(dim_node_feat + dim_edge_feat + dim_time, dim_out)
self.w_v = torch.nn.Linear(dim_node_feat + dim_edge_feat + dim_time, dim_out)
self.w_out = torch.nn.Linear(dim_node_feat + dim_out, dim_out)
self.layer_norm = torch.nn.LayerNorm(dim_out)
def forward(self, b):
assert(self.dim_time + self.dim_node_feat + self.dim_edge_feat > 0)
self.device = b.device
if b.num_edges() == 0:
return torch.zeros((b.num_dst_nodes(), self.dim_out), device=self.device)
if self.dim_time > 0:
time_feat = self.time_enc(b.edata['dt'])
zero_time_feat = self.time_enc(torch.zeros(b.num_dst_nodes(), dtype=torch.float32, device=self.device))
if self.combined:
Q = torch.zeros((b.num_edges(), self.dim_out), device=self.device)
K = torch.zeros((b.num_edges(), self.dim_out), device=self.device)
V = torch.zeros((b.num_edges(), self.dim_out), device=self.device)
if self.dim_node_feat > 0:
Q += self.w_q_n(b.srcdata['h'][:b.num_dst_nodes()])[b.edges()[1]]
K += self.w_k_n(b.srcdata['h'][b.num_dst_nodes():])[b.edges()[0] - b.num_dst_nodes()]
V += self.w_v_n(b.srcdata['h'][b.num_dst_nodes():])[b.edges()[0] - b.num_dst_nodes()]
if self.dim_edge_feat > 0:
K += self.w_k_e(b.edata['f'])
V += self.w_v_e(b.edata['f'])
if self.dim_time > 0:
Q += self.w_q_t(zero_time_feat)[b.edges()[1]]
K += self.w_k_t(time_feat)
V += self.w_v_t(time_feat)
Q = torch.reshape(Q, (Q.shape[0], self.num_head, -1))
K = torch.reshape(K, (K.shape[0], self.num_head, -1))
V = torch.reshape(V, (V.shape[0], self.num_head, -1))
att = dgl.ops.edge_softmax(b, self.att_act(torch.sum(Q*K, dim=2)))
att = self.att_dropout(att)
V = torch.reshape(V*att[:, :, None], (V.shape[0], -1))
b.edata['v'] = V
b.update_all(dgl.function.copy_e('v', 'm'), dgl.function.sum('m', 'h'))
else:
if self.dim_time == 0 and self.dim_node_feat == 0:
Q = torch.ones((b.num_edges(), self.dim_out), device=self.device)
K = self.w_k(b.edata['f'])
V = self.w_v(b.edata['f'])
elif self.dim_time == 0 and self.dim_edge_feat == 0:
Q = self.w_q(b.srcdata['h'][:b.num_dst_nodes()])[b.edges()[1]]
K = self.w_k(b.srcdata['h'][b.edges()[0]])
V = self.w_v(b.srcdata['h'][b.edges()[0]])
elif self.dim_time == 0:
Q = self.w_q(b.srcdata['h'][:b.num_dst_nodes()])[b.edges()[1]]
K = self.w_k(torch.cat([b.srcdata['h'][b.edges()[0]], b.edata['f']], dim=1))
V = self.w_v(torch.cat([b.srcdata['h'][b.edges()[0]], b.edata['f']], dim=1))
#K = self.w_k(torch.cat([b.srcdata['h'][b.num_dst_nodes():], b.edata['f']], dim=1))
#V = self.w_v(torch.cat([b.srcdata['h'][b.num_dst_nodes():], b.edata['f']], dim=1))
elif self.dim_node_feat == 0:
Q = self.w_q(zero_time_feat)[b.edges()[1]]
K = self.w_k(torch.cat([b.edata['f'], time_feat], dim=1))
V = self.w_v(torch.cat([b.edata['f'], time_feat], dim=1))
elif self.dim_edge_feat == 0:
Q = self.w_q(torch.cat([b.srcdata['h'][:b.num_dst_nodes()], zero_time_feat], dim=1))[b.edges()[1]]
K = self.w_k(torch.cat([b.srcdata['h'][b.edges()[0]], time_feat], dim=1))
V = self.w_v(torch.cat([b.srcdata['h'][b.edges()[0]], time_feat], dim=1))
#K = self.w_k(torch.cat([b.srcdata['h'][b.num_dst_nodes():], time_feat], dim=1))
#V = self.w_v(torch.cat([b.srcdata['h'][b.num_dst_nodes():], time_feat], dim=1))
else:
Q = self.w_q(torch.cat([b.srcdata['h'][:b.num_dst_nodes()], zero_time_feat], dim=1))[b.edges()[1]]
K = self.w_k(torch.cat([b.srcdata['h'][b.edges()[0]], b.edata['f'], time_feat], dim=1))
V = self.w_v(torch.cat([b.srcdata['h'][b.edges()[0]], b.edata['f'], time_feat], dim=1))
#Q = self.w_q(torch.cat([b.srcdata['h'][:b.num_dst_nodes()], zero_time_feat], dim=1))[b.edges()[1]]
#K = self.w_k(torch.cat([b.srcdata['h'][b.num_dst_nodes():], b.edata['f'], time_feat], dim=1))
#V = self.w_v(torch.cat([b.srcdata['h'][b.num_dst_nodes():], b.edata['f'], time_feat], dim=1))
Q = torch.reshape(Q, (Q.shape[0], self.num_head, -1))
K = torch.reshape(K, (K.shape[0], self.num_head, -1))
V = torch.reshape(V, (V.shape[0], self.num_head, -1))
att = dgl.ops.edge_softmax(b, self.att_act(torch.sum(Q*K, dim=2)))
att = self.att_dropout(att)
V = torch.reshape(V*att[:, :, None], (V.shape[0], -1))
b.edata['v'] = V
b.update_all(dgl.function.copy_e('v', 'm'), dgl.function.sum('m', 'h'))
#b.srcdata['v'] = torch.cat([torch.zeros((b.num_dst_nodes(), V.shape[1]), device=torch.device('cuda:0')), V], dim=0)
#b.update_all(dgl.function.copy_u('v', 'm'), dgl.function.sum('m', 'h'))
if self.dim_node_feat != 0:
rst = torch.cat([b.dstdata['h'], b.srcdata['h'][:b.num_dst_nodes()]], dim=1)
else:
rst = b.dstdata['h']
rst = self.w_out(rst)
rst = torch.nn.functional.relu(self.dropout(rst))
return self.layer_norm(rst)
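# Expected inputs (a sketch of the conventions used above, not an official spec):
# `b` is a DGL message-flow block whose first b.num_dst_nodes() source nodes are the
# destination nodes themselves; it should carry
#   b.srcdata['h'] : (num_src_nodes, dim_node_feat) node features, if dim_node_feat > 0
#   b.edata['f']   : (num_edges, dim_edge_feat) edge features, if dim_edge_feat > 0
#   b.edata['dt']  : (num_edges,) time deltas for the time encoder, if dim_time > 0
# forward() returns a (num_dst_nodes, dim_out) tensor.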
class IdentityNormLayer(torch.nn.Module):
def __init__(self, dim_out):
super(IdentityNormLayer, self).__init__()
self.norm = torch.nn.LayerNorm(dim_out)
def forward(self, b):
return self.norm(b.srcdata['h'])
class JODIETimeEmbedding(torch.nn.Module):
def __init__(self, dim_out):
super(JODIETimeEmbedding, self).__init__()
self.dim_out = dim_out
class NormalLinear(torch.nn.Linear):
# From Jodie code
def reset_parameters(self):
stdv = 1. / math.sqrt(self.weight.size(1))
self.weight.data.normal_(0, stdv)
if self.bias is not None:
self.bias.data.normal_(0, stdv)
self.time_emb = NormalLinear(1, dim_out)
def forward(self, h, mem_ts, ts):
time_diff = (ts - mem_ts) / (ts + 1)
rst = h * (1 + self.time_emb(time_diff.unsqueeze(1)))
return rst
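if __name__ == '__main__':
    # Minimal smoke test (a sketch with made-up sizes, not part of the original file):
    # exercise TimeEncode, EdgePredictor and JODIETimeEmbedding on random tensors.
    dim = 100
    enc = TimeEncode(dim)
    print(enc(torch.tensor([0.0, 1.0, 10.0])).shape)   # torch.Size([3, 100])

    pred = EdgePredictor(dim)
    h = torch.randn(3 * 4, dim)                        # 4 src + 4 pos dst + 4 neg dst
    pos, neg = pred(h, neg_samples=1)
    print(pos.shape, neg.shape)                        # (4, 1) (4, 1)

    jodie = JODIETimeEmbedding(dim)
    out = jodie(torch.randn(4, dim), torch.zeros(4), torch.ones(4))
    print(out.shape)                                   # torch.Size([4, 100])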
import os
import sys
from os.path import abspath, dirname
sys.path.insert(0, abspath(dirname(__file__)))
import torch
import dgl
from memorys import *
from layers import *
class GeneralModel(torch.nn.Module):
def __init__(self, dim_node, dim_edge, sample_param, memory_param, gnn_param, train_param, combined=False):
super(GeneralModel, self).__init__()
self.dim_node = dim_node
self.dim_node_input = dim_node
self.dim_edge = dim_edge
self.sample_param = sample_param
self.memory_param = memory_param
if not 'dim_out' in gnn_param:
gnn_param['dim_out'] = memory_param['dim_out']
self.gnn_param = gnn_param
self.train_param = train_param
if memory_param['type'] == 'node':
if memory_param['memory_update'] == 'gru':
self.memory_updater = GRUMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node)
elif memory_param['memory_update'] == 'rnn':
self.memory_updater = RNNMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node)
elif memory_param['memory_update'] == 'transformer':
self.memory_updater = TransformerMemoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], train_param)
else:
raise NotImplementedError
self.dim_node_input = memory_param['dim_out']
self.layers = torch.nn.ModuleDict()
if gnn_param['arch'] == 'transformer_attention':
for h in range(sample_param['history']):
self.layers['l0h' + str(h)] = TransfomerAttentionLayer(self.dim_node_input, dim_edge, gnn_param['dim_time'], gnn_param['att_head'], train_param['dropout'], train_param['att_dropout'], gnn_param['dim_out'], combined=combined)
for l in range(1, gnn_param['layer']):
for h in range(sample_param['history']):
self.layers['l' + str(l) + 'h' + str(h)] = TransfomerAttentionLayer(gnn_param['dim_out'], dim_edge, gnn_param['dim_time'], gnn_param['att_head'], train_param['dropout'], train_param['att_dropout'], gnn_param['dim_out'], combined=False)
elif gnn_param['arch'] == 'identity':
self.gnn_param['layer'] = 1
for h in range(sample_param['history']):
self.layers['l0h' + str(h)] = IdentityNormLayer(self.dim_node_input)
if 'time_transform' in gnn_param and gnn_param['time_transform'] == 'JODIE':
self.layers['l0h' + str(h) + 't'] = JODIETimeEmbedding(gnn_param['dim_out'])
else:
raise NotImplementedError
self.edge_predictor = EdgePredictor(gnn_param['dim_out'])
if 'combine' in gnn_param and gnn_param['combine'] == 'rnn':
self.combiner = torch.nn.RNN(gnn_param['dim_out'], gnn_param['dim_out'])
def forward(self, mfgs, metadata = None,neg_samples=1):
if self.memory_param['type'] == 'node':
self.memory_updater(mfgs[0])
out = list()
for l in range(self.gnn_param['layer']):
for h in range(self.sample_param['history']):
rst = self.layers['l' + str(l) + 'h' + str(h)](mfgs[l][h])
if 'time_transform' in self.gnn_param and self.gnn_param['time_transform'] == 'JODIE':
rst = self.layers['l0h' + str(h) + 't'](rst, mfgs[l][h].srcdata['mem_ts'], mfgs[l][h].srcdata['ts'])
if l != self.gnn_param['layer'] - 1:
mfgs[l + 1][h].srcdata['h'] = rst
else:
out.append(rst)
if self.sample_param['history'] == 1:
out = out[0]
else:
out = torch.stack(out, dim=0)
out = self.combiner(out)[0][-1, :, :]
# metadata: the ids need to be recorded during the earlier deduplication step
if metadata is not None:
out = torch.cat((out[metadata['src_id_pos']],out[metadata['dst_pos_pos']],out[metadata['dst_neg_pos']]),0)
return self.edge_predictor(out, neg_samples=neg_samples)
def get_emb(self, mfgs):
if self.memory_param['type'] == 'node':
self.memory_updater(mfgs[0])
out = list()
for l in range(self.gnn_param['layer']):
for h in range(self.sample_param['history']):
rst = self.layers['l' + str(l) + 'h' + str(h)](mfgs[l][h])
if 'time_transform' in self.gnn_param and self.gnn_param['time_transform'] == 'JODIE':
rst = self.layers['l0h' + str(h) + 't'](rst, mfgs[l][h].srcdata['mem_ts'], mfgs[l][h].srcdata['ts'])
if l != self.gnn_param['layer'] - 1:
mfgs[l + 1][h].srcdata['h'] = rst
else:
out.append(rst)
if self.sample_param['history'] == 1:
out = out[0]
else:
out = torch.stack(out, dim=0)
out = self.combiner(out)[0][-1, :, :]
return out
class NodeClassificationModel(torch.nn.Module):
def __init__(self, dim_in, dim_hid, num_class):
super(NodeClassificationModel, self).__init__()
self.fc1 = torch.nn.Linear(dim_in, dim_hid)
self.fc2 = torch.nn.Linear(dim_hid, num_class)
def forward(self, x):
x = self.fc1(x)
x = torch.nn.functional.relu(x)
x = self.fc2(x)
return x
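# Usage sketch (hypothetical sizes): the classifier consumes the dynamic node
# embeddings produced by GeneralModel.get_emb(), e.g.
#   clf = NodeClassificationModel(dim_in=100, dim_hid=100, num_class=2)
#   logits = clf(torch.randn(8, 100))   # logits.shape == (8, 2)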
sudo mount -o remount,size=600G /dev/shm
alabaster==0.7.13
antlr4-python3-runtime==4.9.3
anyio @ file:///home/conda/feedstock_root/build_artifacts/anyio_1666191106763/work/dist
appdirs==1.4.4
argon2-cffi @ file:///home/conda/feedstock_root/build_artifacts/argon2-cffi_1640817743617/work
argon2-cffi-bindings @ file:///home/conda/feedstock_root/build_artifacts/argon2-cffi-bindings_1649500309442/work
asttokens==2.2.1
attrs @ file:///home/conda/feedstock_root/build_artifacts/attrs_1671632566681/work
autopep8==2.0.2
Babel==2.12.1
backcall==0.2.0
backports.functools-lru-cache @ file:///home/conda/feedstock_root/build_artifacts/backports.functools_lru_cache_1618230623929/work
beautifulsoup4 @ file:///home/conda/feedstock_root/build_artifacts/beautifulsoup4_1675252249248/work
bleach @ file:///home/conda/feedstock_root/build_artifacts/bleach_1674535352125/work
brotlipy @ file:///home/conda/feedstock_root/build_artifacts/brotlipy_1648854175163/work
certifi==2022.12.7
cffi @ file:///home/conda/feedstock_root/build_artifacts/cffi_1636046063618/work
charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1661170624537/work
click==8.1.3
comm==0.1.1
contourpy==1.0.7
cryptography @ file:///home/conda/feedstock_root/build_artifacts/cryptography_1652967113783/work
cycler==0.11.0
debugpy==1.6.4
decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1641555617451/work
defusedxml @ file:///home/conda/feedstock_root/build_artifacts/defusedxml_1615232257335/work
dgl==1.0.2
dglgo==0.0.2
docutils==0.19
entrypoints @ file:///home/conda/feedstock_root/build_artifacts/entrypoints_1643888246732/work
exceptiongroup==1.1.1
executing==1.2.0
fastgraph==0.2.0
fastjsonschema @ file:///home/conda/feedstock_root/build_artifacts/python-fastjsonschema_1663619548554/work/dist
flameprof==0.4
flit_core @ file:///home/conda/feedstock_root/build_artifacts/flit-core_1667734568827/work/source/flit_core
fonttools==4.39.3
hydra-core==1.3.2
idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1663625384323/work
imagesize==1.4.1
importlib-metadata @ file:///home/conda/feedstock_root/build_artifacts/importlib-metadata_1672612343532/work
importlib-resources @ file:///home/conda/feedstock_root/build_artifacts/importlib_resources_1672681417544/work
iniconfig==2.0.0
ipykernel==6.19.0
ipython==8.7.0
ipython-genutils==0.2.0
ipywidgets @ file:///home/conda/feedstock_root/build_artifacts/ipywidgets_1671720089366/work
isort==5.12.0
jedi==0.18.2
Jinja2 @ file:///croot/jinja2_1666908132255/work
joblib==1.2.0
jsonschema @ file:///home/conda/feedstock_root/build_artifacts/jsonschema-meta_1669810440410/work
jupyter-server @ file:///home/conda/feedstock_root/build_artifacts/jupyter_server_1673574555503/work
jupyter_client==7.4.8
jupyter_core==5.1.0
jupyterlab-pygments @ file:///home/conda/feedstock_root/build_artifacts/jupyterlab_pygments_1649936611996/work
jupyterlab-widgets @ file:///home/conda/feedstock_root/build_artifacts/jupyterlab_widgets_1671722028097/work
kiwisolver==1.4.4
littleutils==0.2.2
lltm-cpp==0.0.0
markdown-it-py==2.2.0
MarkupSafe @ file:///opt/conda/conda-bld/markupsafe_1654597864307/work
matmul-cpu-pytorch==0.0.0
matmul-gpu==0.0.0
matplotlib==3.7.1
matplotlib-inline==0.1.6
mdurl==0.1.2
memray==1.7.0
mistune @ file:///home/conda/feedstock_root/build_artifacts/mistune_1675771498296/work
mkl-fft==1.3.1
mkl-random==1.2.2
mkl-service==2.4.0
nbclassic @ file:///home/conda/feedstock_root/build_artifacts/nbclassic_1675369808718/work
nbclient @ file:///home/conda/feedstock_root/build_artifacts/nbclient_1669795076334/work
nbconvert @ file:///home/conda/feedstock_root/build_artifacts/nbconvert-meta_1674590374792/work
nbformat @ file:///home/conda/feedstock_root/build_artifacts/nbformat_1673560067442/work
nest-asyncio==1.5.6
networkx @ file:///opt/conda/conda-bld/networkx_1657784097507/work
notebook @ file:///home/conda/feedstock_root/build_artifacts/notebook_1667565639349/work
notebook_shim @ file:///home/conda/feedstock_root/build_artifacts/notebook-shim_1667478401171/work
numpy @ file:///croot/numpy_and_numpy_base_1672336185480/work
numpydoc==1.5.0
ogb==1.3.5
olefile @ file:///home/conda/feedstock_root/build_artifacts/olefile_1602866521163/work
omegaconf==2.3.0
outdated==0.2.2
packaging @ file:///croot/packaging_1671697413597/work
pandas==1.5.3
pandocfilters @ file:///home/conda/feedstock_root/build_artifacts/pandocfilters_1631603243851/work
parso==0.8.3
pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1667297516076/work
pickleshare==0.7.5
Pillow @ file:///home/conda/feedstock_root/build_artifacts/pillow_1621287973719/work
pkgutil_resolve_name @ file:///home/conda/feedstock_root/build_artifacts/pkgutil-resolve-name_1633981968097/work
platformdirs==2.6.0
pluggy==1.0.0
pooch @ file:///tmp/build/80754af9/pooch_1623324770023/work
presample-cores==0.0.0
prometheus-client @ file:///home/conda/feedstock_root/build_artifacts/prometheus_client_1674535637125/work
prompt-toolkit==3.0.36
psutil==5.9.4
ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
pure-eval==0.2.2
pybind11==2.10.4
pycodestyle==2.10.0
pycparser @ file:///home/conda/feedstock_root/build_artifacts/pycparser_1636257122734/work
pydantic==1.10.7
Pygments==2.13.0
pyOpenSSL @ file:///home/conda/feedstock_root/build_artifacts/pyopenssl_1663846997386/work
pyparsing @ file:///opt/conda/conda-bld/pyparsing_1661452539315/work
pyrsistent @ file:///tmp/build/80754af9/pyrsistent_1636110947380/work
PySocks @ file:///home/conda/feedstock_root/build_artifacts/pysocks_1661604839144/work
pytest==7.2.2
python-dateutil==2.8.2
pytz==2022.7.1
PyYAML==6.0
pyzmq==24.0.1
rdkit-pypi==2022.9.5
requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1673863902341/work
rich==13.4.1
ruamel.yaml==0.17.21
ruamel.yaml.clib==0.2.7
sample-cores==0.0.0
scikit-learn @ file:///croot/scikit-learn_1673957209982/work
scipy==1.10.0
Send2Trash @ file:///home/conda/feedstock_root/build_artifacts/send2trash_1628511208346/work
six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work
sniffio @ file:///home/conda/feedstock_root/build_artifacts/sniffio_1662051266223/work
snowballstemmer==2.2.0
soupsieve @ file:///home/conda/feedstock_root/build_artifacts/soupsieve_1658207591808/work
Sphinx==6.1.3
sphinxcontrib-applehelp==1.0.4
sphinxcontrib-devhelp==1.0.2
sphinxcontrib-htmlhelp==2.0.1
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==1.0.3
sphinxcontrib-serializinghtml==1.1.5
stack-data==0.6.2
terminado @ file:///home/conda/feedstock_root/build_artifacts/terminado_1670253674810/work
threadpoolctl==3.1.0
tinycss2 @ file:///home/conda/feedstock_root/build_artifacts/tinycss2_1666100256010/work
tomli==2.0.1
torch==1.12.1
torch-cluster @ file:///usr/share/miniconda/envs/test/conda-bld/pytorch-cluster_1656604322802/work
torch-geometric @ file:///usr/share/miniconda/envs/test/conda-bld/pyg_1669879221143/work
torch-geometric-autoscale==0.0.0
torch-quiver==0.1.1
torch-scatter @ file:///usr/share/miniconda/envs/test/conda-bld/pytorch-scatter_1669298837499/work
torch-sparse @ file:///usr/share/miniconda/envs/test/conda-bld/pytorch-sparse_1671714918898/work
torchaudio==0.12.1
torchvision==0.13.1
tornado==6.2
tqdm==4.64.1
traitlets==5.6.0
typer==0.7.0
typing_extensions==4.4.0
urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1673452138552/work
wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1673864653149/work
webencodings==0.5.1
websocket-client @ file:///home/conda/feedstock_root/build_artifacts/websocket-client_1675567828044/work
widgetsnbextension @ file:///home/conda/feedstock_root/build_artifacts/widgetsnbextension_1672066693230/work
zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1675982654259/work
import argparse
import yaml
import torch
import time
import numpy as np
import pandas as pd
from tqdm import tqdm
from sampler_core import ParallelSampler, TemporalGraphBlock
class NegLinkSampler:
def __init__(self, num_nodes):
self.num_nodes = num_nodes
def sample(self, n):
return np.random.randint(self.num_nodes, size=n)
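# Usage sketch: NegLinkSampler draws negative destination candidates uniformly at
# random, e.g. NegLinkSampler(num_nodes=100).sample(5) returns 5 node ids in
# [0, 100). np.random.randint samples with replacement and does not exclude the
# true destinations.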
if __name__ == '__main__':
parser=argparse.ArgumentParser()
parser.add_argument('--data', type=str, help='dataset name')
parser.add_argument('--config', type=str, help='path to config file')
parser.add_argument('--batch_size', type=int, default=600, help='batch size')
parser.add_argument('--num_thread', type=int, default=64, help='number of threads')
args=parser.parse_args()
df = pd.read_csv('DATA/{}/edges.csv'.format(args.data))
g = np.load('DATA/{}/ext_full.npz'.format(args.data))
sample_config = yaml.safe_load(open(args.config, 'r'))['sampling'][0]
# sampling:
# - layer: 2
# neighbor:
# - 10
# - 10
# strategy: 'uniform'
# prop_time: True
# history: 3
# duration: 5
# num_thread: 64
#
# There are two strategies: 'uniform' samples uniformly from past neighbors,
# and 'recent' samples the most recent neighbors.
sampler = ParallelSampler(g['indptr'], g['indices'], g['eid'], g['ts'].astype(np.float32),
args.num_thread, 1, sample_config['layer'], sample_config['neighbor'],
sample_config['strategy']=='recent', sample_config['prop_time'],
sample_config['history'], float(sample_config['duration']))
num_nodes = max(int(df['src'].max()), int(df['dst'].max()))
neg_link_sampler = NegLinkSampler(num_nodes)
tot_time = 0
ptr_time = 0
coo_time = 0
sea_time = 0
sam_time = 0
uni_time = 0
total_nodes = 0
unique_nodes = 0
for _, rows in tqdm(df.groupby(df.index // args.batch_size), total=len(df) // args.batch_size):
root_nodes = np.concatenate([rows.src.values, rows.dst.values, neg_link_sampler.sample(len(rows))]).astype(np.int32)
ts = np.concatenate([rows.time.values, rows.time.values, rows.time.values]).astype(np.float32)
sampler.sample(root_nodes, ts)
ret = sampler.get_ret()
tot_time += ret[0].tot_time()
ptr_time += ret[0].ptr_time()
coo_time += ret[0].coo_time()
sea_time += ret[0].search_time()
sam_time += ret[0].sample_time()
# for i in range(sample_config['history']):
# total_nodes += ret[i].dim_in() - ret[i].dim_out()
# unique_nodes += ret[i].dim_in() - ret[i].dim_out()
# if ret[i].dim_in() > ret[i].dim_out():
# ts = torch.from_numpy(ret[i].ts()[ret[i].dim_out():])
# nid = torch.from_numpy(ret[i].nodes()[ret[i].dim_out():]).float()
# nts = torch.stack([ts,nid],dim=1).cuda()
# uni_t_s = time.time()
# unts, idx = torch.unique(nts, dim=0, return_inverse=True)
# uni_time += time.time() - uni_t_s
# total_nodes += idx.shape[0]
# unique_nodes += unts.shape[0]
print('total time : {:.4f}'.format(tot_time))
print('pointer time: {:.4f}'.format(ptr_time))
print('coo time : {:.4f}'.format(coo_time))
print('search time : {:.4f}'.format(sea_time))
print('sample time : {:.4f}'.format(sam_time))
# print('unique time : {:.4f}'.format(uni_time))
# print('unique per : {:.4f}'.format(unique_nodes / total_nodes))
from glob import glob
from setuptools import setup
from pybind11.setup_helpers import Pybind11Extension
ext_modules = [
Pybind11Extension("sampler_core",
['sampler_core.cpp'],
extra_compile_args = ['-fopenmp'],
extra_link_args = ['-fopenmp'],),
]
setup(
name = "sampler_core",
version = "0.0.1",
author = "Hongkuan Zhou",
author_email = "hongkuaz@usc.edu",
url = "https://tedzhouhk.github.io/about/",
description = "Parallel Sampling for Temporal Graphs",
ext_modules = ext_modules,
)