Commit c2ce9bbc by zlj

Add two new Models

parent 4bb7a33b
set(CMAKE_SOURCE_DIR "/mnt/data/")
cmake_minimum_required(VERSION 3.15) cmake_minimum_required(VERSION 3.15)
project(starrygl VERSION 0.1) project(starrygl VERSION 0.1)
......
cmake_minimum_required(VERSION 3.15)
project(starrygl VERSION 0.1)
option(WITH_PYTHON "Link to Python when building" ON)
option(WITH_CUDA "Link to CUDA when building" ON)
option(WITH_METIS "Link to METIS when building" ON)
option(WITH_MTMETIS "Link to multi-threaded METIS when building" ON)
option(WITH_LDG "Link to (multi-threaded optionally) LDG when building" ON)
message("third_party dir is ${CMAKE_SOURCE_DIR}")
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_CUDA_STANDARD 14)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
find_package(OpenMP REQUIRED)
link_libraries(OpenMP::OpenMP_CXX)
find_package(Torch REQUIRED)
include_directories(${TORCH_INCLUDE_DIRS})
add_compile_options(${TORCH_CXX_FLAGS})
if(WITH_PYTHON)
add_definitions(-DWITH_PYTHON)
find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
include_directories(${Python3_INCLUDE_DIRS})
endif()
if(WITH_CUDA)
add_definitions(-DWITH_CUDA)
add_definitions(-DWITH_UVM)
find_package(CUDA REQUIRED)
include_directories(${CUDA_INCLUDE_DIRS})
set(CUDA_LIBRARIES "${CUDA_TOOLKIT_ROOT_DIR}/lib64/libcudart.so")
file(GLOB_RECURSE UVM_SRCS "csrc/uvm/*.cpp")
add_library(uvm_ops SHARED ${UVM_SRCS})
target_link_libraries(uvm_ops PRIVATE ${TORCH_LIBRARIES})
endif()
if(WITH_METIS)
add_definitions(-DWITH_METIS)
set(GKLIB_DIR "${CMAKE_SOURCE_DIR}/third_party/GKlib")
set(METIS_DIR "${CMAKE_SOURCE_DIR}/third_party/METIS")
set(GKLIB_INCLUDE_DIRS "${GKLIB_DIR}/include")
file(GLOB_RECURSE GKLIB_LIBRARIES "${GKLIB_DIR}/lib/lib*.a")
set(METIS_INCLUDE_DIRS "${METIS_DIR}/include")
file(GLOB_RECURSE METIS_LIBRARIES "${METIS_DIR}/lib/lib*.a")
include_directories(${METIS_INCLUDE_DIRS})
add_library(metis_partition SHARED "csrc/partition/metis.cpp")
target_link_libraries(metis_partition PRIVATE ${TORCH_LIBRARIES})
target_link_libraries(metis_partition PRIVATE ${GKLIB_LIBRARIES})
target_link_libraries(metis_partition PRIVATE ${METIS_LIBRARIES})
endif()
if(WITH_MTMETIS)
add_definitions(-DWITH_MTMETIS)
set(MTMETIS_DIR "${CMAKE_SOURCE_DIR}/third_party/mt-metis")
set(MTMETIS_INCLUDE_DIRS "${MTMETIS_DIR}/include")
file(GLOB_RECURSE MTMETIS_LIBRARIES "${MTMETIS_DIR}/lib/lib*.a")
include_directories(${MTMETIS_INCLUDE_DIRS})
add_library(mtmetis_partition SHARED "csrc/partition/mtmetis.cpp")
target_link_libraries(mtmetis_partition PRIVATE ${TORCH_LIBRARIES})
target_link_libraries(mtmetis_partition PRIVATE ${MTMETIS_LIBRARIES})
target_compile_definitions(mtmetis_partition PRIVATE -DMTMETIS_64BIT_VERTICES)
target_compile_definitions(mtmetis_partition PRIVATE -DMTMETIS_64BIT_EDGES)
target_compile_definitions(mtmetis_partition PRIVATE -DMTMETIS_64BIT_WEIGHTS)
target_compile_definitions(mtmetis_partition PRIVATE -DMTMETIS_64BIT_PARTITIONS)
endif()
if (WITH_LDG)
# Imports neighbor-clustering based (e.g. LDG algorithm) graph partitioning implementation
add_definitions(-DWITH_LDG)
# set(LDG_DIR "csrc/partition/neighbor_clustering")
set(LDG_DIR "third_party/ldg_partition")
add_library(ldg_partition SHARED "csrc/partition/ldg.cpp")
target_link_libraries(ldg_partition PRIVATE ${TORCH_LIBRARIES})
# add_subdirectory(${LDG_DIR})
target_include_directories(ldg_partition PRIVATE ${LDG_DIR})
target_link_libraries(ldg_partition PRIVATE ldg-vertex-partition)
endif ()
include_directories("csrc/include")
add_library(${PROJECT_NAME} SHARED csrc/export.cpp)
target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES})
target_compile_definitions(${PROJECT_NAME} PRIVATE -DTORCH_EXTENSION_NAME=lib${PROJECT_NAME})
if(WITH_PYTHON)
find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib")
target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_PYTHON_LIBRARY})
endif()
if (WITH_CUDA)
target_link_libraries(${PROJECT_NAME} PRIVATE uvm_ops)
endif()
if (WITH_METIS)
message(STATUS "Current project '${PROJECT_NAME}' uses METIS graph partitioning algorithm.")
target_link_libraries(${PROJECT_NAME} PRIVATE metis_partition)
endif()
if (WITH_MTMETIS)
message(STATUS "Current project '${PROJECT_NAME}' uses multi-threaded METIS graph partitioning algorithm.")
target_link_libraries(${PROJECT_NAME} PRIVATE mtmetis_partition)
endif()
if (WITH_LDG)
message(STATUS "Current project '${PROJECT_NAME}' uses LDG graph partitioning algorithm.")
target_link_libraries(${PROJECT_NAME} PRIVATE ldg_partition)
endif()
# add libsampler.so
set(SAMLPER_NAME "${PROJECT_NAME}_sampler")
set(BOOST_INCLUDE_DIRS "${CMAKE_SOURCE_DIR}/third_party/boost_1_83_0")
include_directories(${BOOST_INCLUDE_DIRS})
file(GLOB_RECURSE SAMPLER_SRCS "csrc/sampler/*.cpp")
add_library(${SAMLPER_NAME} SHARED ${SAMPLER_SRCS})
target_include_directories(${SAMLPER_NAME} PRIVATE "csrc/sampler/include")
target_compile_options(${SAMLPER_NAME} PRIVATE -O3)
target_link_libraries(${SAMLPER_NAME} PRIVATE ${TORCH_LIBRARIES})
target_compile_definitions(${SAMLPER_NAME} PRIVATE -DTORCH_EXTENSION_NAME=lib${SAMLPER_NAME})
if(WITH_PYTHON)
find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib")
target_link_libraries(${SAMLPER_NAME} PRIVATE ${TORCH_PYTHON_LIBRARY})
endif()
sampling:
- layer: 1
neighbor:
- 10
strategy: 'recent'
prop_time: False
history: 1
duration: 0
num_thread: 32
memory:
- type: 'node'
dim_time: 100
deliver_to: 'self'
mail_combine: 'last'
memory_update: 'gru'
mailbox_size: 1
combine_node_feature: True
dim_out: 100
gnn:
- arch: 'transformer_attention'
use_src_emb: True
use_dst_emb: True
layer: 1
att_head: 2
dim_time: 100
dim_out: 100
train:
- epoch: 50
batch_size: 100
# reorder: 16
lr: 0.0001
dropout: 0.1
att_dropout: 0.2
all_on_gpu: True
\ No newline at end of file
...@@ -18,13 +18,15 @@ memory: ...@@ -18,13 +18,15 @@ memory:
dim_out: 100 dim_out: 100
gnn: gnn:
- arch: 'transformer_attention' - arch: 'transformer_attention'
use_src_emb: False
use_dst_emb: False
layer: 1 layer: 1
att_head: 2 att_head: 2
dim_time: 100 dim_time: 100
dim_out: 100 dim_out: 100
train: train:
- epoch: 5 - epoch: 20
#batch_size: 100 batch_size: 200
# reorder: 16 # reorder: 16
lr: 0.0001 lr: 0.0001
dropout: 0.2 dropout: 0.2
......
sampling:
- layer: 1
neighbor:
- 10
strategy: 'recent'
prop_time: False
history: 1
duration: 0
num_thread: 32
memory:
- type: 'node'
dim_time: 100
deliver_to: 'self'
mail_combine: 'last'
memory_update: 'gru'
mailbox_size: 1
combine_node_feature: True
dim_out: 100
gnn:
- arch: 'transformer_attention'
use_src_emb: True
use_dst_emb: True
layer: 1
att_head: 2
dim_time: 100
dim_out: 100
train:
- epoch: 20
batch_size: 200
# reorder: 16
lr: 0.0001
dropout: 0.2
att_dropout: 0.2
all_on_gpu: True
\ No newline at end of file
#include <torch/all.h> #include <torch/all.h>
#include "neighbor_clustering/vertex_partition/vertex_partition.h" #include "vertex_partition/vertex_partition.h"
#include "neighbor_clustering/vertex_partition/params.h" #include "vertex_partition/params.h"
at::Tensor ldg_partition(at::Tensor edges, at::Tensor ldg_partition(at::Tensor edges,
at::optional<at::Tensor> vertex_weights, at::optional<at::Tensor> vertex_weights,
......
...@@ -3,9 +3,10 @@ ...@@ -3,9 +3,10 @@
mkdir -p build && cd build mkdir -p build && cd build
cmake .. \ cmake .. \
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON \ -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \
-DCMAKE_PREFIX_PATH="/home/zlj/.miniconda3/envs/dgnn/lib/python3.10/site-packages" \ -DCMAKE_PREFIX_PATH="/home/zlj/.miniconda3/envs/dgnn/lib/python3.8/site-packages" \
-DPython3_ROOT_DIR="/home/zlj/.miniconda3/envs/dgnn" \ -DPython3_ROOT_DIR="/home/zlj/.miniconda3/envs/dgnn" \
-DCUDA_TOOLKIT_ROOT_DIR="/home/zlj/local/cuda-12.2" \ -DCUDA_TOOLKIT_ROOT_DIR="/home/zlj/local/cuda-12.2" \
-DWITH_LDG=OFF \
&& make -j32 \ && make -j32 \
&& rm -rf ../starrygl/lib \ && rm -rf ../starrygl/lib \
&& mkdir ../starrygl/lib \ && mkdir ../starrygl/lib \
......
...@@ -21,7 +21,7 @@ class GeneralModel(torch.nn.Module): ...@@ -21,7 +21,7 @@ class GeneralModel(torch.nn.Module):
self.train_param = train_param self.train_param = train_param
if memory_param['type'] == 'node': if memory_param['type'] == 'node':
if memory_param['memory_update'] == 'gru': if memory_param['memory_update'] == 'gru':
self.memory_updater = RNNMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node) self.memory_updater = GRUMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node)
elif memory_param['memory_update'] == 'rnn': elif memory_param['memory_update'] == 'rnn':
self.memory_updater = RNNMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node) self.memory_updater = RNNMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node)
elif memory_param['memory_update'] == 'transformer': elif memory_param['memory_update'] == 'transformer':
...@@ -47,7 +47,7 @@ class GeneralModel(torch.nn.Module): ...@@ -47,7 +47,7 @@ class GeneralModel(torch.nn.Module):
self.edge_predictor = EdgePredictor(gnn_param['dim_out']) self.edge_predictor = EdgePredictor(gnn_param['dim_out'])
if 'combine' in gnn_param and gnn_param['combine'] == 'rnn': if 'combine' in gnn_param and gnn_param['combine'] == 'rnn':
self.combiner = torch.nn.RNN(gnn_param['dim_out'], gnn_param['dim_out']) self.combiner = torch.nn.RNN(gnn_param['dim_out'], gnn_param['dim_out'])
def forward(self, mfgs, metadata = None,neg_samples=1): def forward(self, mfgs, metadata = None,neg_samples=1):
if self.memory_param['type'] == 'node': if self.memory_param['type'] == 'node':
...@@ -68,8 +68,14 @@ class GeneralModel(torch.nn.Module): ...@@ -68,8 +68,14 @@ class GeneralModel(torch.nn.Module):
out = torch.stack(out, dim=0) out = torch.stack(out, dim=0)
out = self.combiner(out)[0][-1, :, :] out = self.combiner(out)[0][-1, :, :]
#metadata需要在前面去重的时候记一下id #metadata需要在前面去重的时候记一下id
if self.gnn_param['use_src_emb'] or self.gnn_param['use_dst_emb']:
self.embedding = out.detach().clone()
else:
self.embedding = None
if metadata is not None: if metadata is not None:
#out = torch.cat((out[metadata['dst_pos_pos']],out[metadata['src_id_pos']],out[metadata['dst_neg_pos']]),0) #out = torch.cat((out[metadata['dst_pos_pos']],out[metadata['src_id_pos']],out[metadata['dst_neg_pos']]),0)
if self.gnn_param['dyrep']:
out = self.memory_updater.last_updated_memory
out = torch.cat((out[metadata['src_pos_index']],out[metadata['dst_pos_index']],out[metadata['src_neg_index']]),0) out = torch.cat((out[metadata['src_pos_index']],out[metadata['dst_pos_index']],out[metadata['src_neg_index']]),0)
return self.edge_predictor(out, neg_samples=neg_samples) return self.edge_predictor(out, neg_samples=neg_samples)
......
import yaml import yaml
import numpy as np
def parse_config(f): def parse_config(f):
conf = yaml.safe_load(open(f, 'r')) conf = yaml.safe_load(open(f, 'r'))
...@@ -7,4 +7,32 @@ def parse_config(f): ...@@ -7,4 +7,32 @@ def parse_config(f):
memory_param = conf['memory'][0] memory_param = conf['memory'][0]
gnn_param = conf['gnn'][0] gnn_param = conf['gnn'][0]
train_param = conf['train'][0] train_param = conf['train'][0]
return sample_param, memory_param, gnn_param, train_param return sample_param, memory_param, gnn_param, train_param
\ No newline at end of file
class EarlyStopMonitor(object):
def __init__(self, max_round=3, higher_better=True, tolerance=1e-10):
self.max_round = max_round
self.num_round = 0
self.epoch_count = 0
self.best_epoch = 0
self.last_best = None
self.higher_better = higher_better
self.tolerance = tolerance
def early_stop_check(self, curr_val):
if not self.higher_better:
curr_val *= -1
if self.last_best is None:
self.last_best = curr_val
elif (curr_val - self.last_best) / np.abs(self.last_best) > self.tolerance:
self.last_best = curr_val
self.num_round = 0
self.best_epoch = self.epoch_count
else:
self.num_round += 1
self.epoch_count += 1
return self.num_round >= self.max_round
\ No newline at end of file
...@@ -68,7 +68,7 @@ class DataSet: ...@@ -68,7 +68,7 @@ class DataSet:
if labels is not None: if labels is not None:
self.labels = labels self.labels = labels
self.len = self.nodes.shape[0] if nodes is not None else self.edges.shape[1] self.len = self.nodes.shape[0] if nodes is not None else self.edges.shape[1]
print(self.edges.shape,self.len)
for k, v in kwargs.items(): for k, v in kwargs.items():
assert isinstance(v,torch.Tensor) and v.shape[0]==self.len assert isinstance(v,torch.Tensor) and v.shape[0]==self.len
setattr(self, k, v.to(device)) setattr(self, k, v.to(device))
......
...@@ -241,7 +241,7 @@ class SharedMailBox(): ...@@ -241,7 +241,7 @@ class SharedMailBox():
def get_update_mail(self,dist_indx_mapper, def get_update_mail(self,dist_indx_mapper,
src,dst,ts,edge_feats, src,dst,ts,edge_feats,
memory): memory,embedding=None,use_src_emb=False,use_dst_emb=False):
if edge_feats is not None: if edge_feats is not None:
edge_feats = edge_feats.to(self.device).to(self.mailbox.dtype) edge_feats = edge_feats.to(self.device).to(self.mailbox.dtype)
src = src.to(self.device) src = src.to(self.device)
...@@ -251,12 +251,14 @@ class SharedMailBox(): ...@@ -251,12 +251,14 @@ class SharedMailBox():
mem_src = memory[src] mem_src = memory[src]
mem_dst = memory[dst] mem_dst = memory[dst]
if embedding is not None:
emb_src = embedding[src]
emb_dst = embedding[dst]
src_mail = torch.cat([emb_src if use_src_emb else mem_src, emb_dst if use_dst_emb else mem_dst], dim=1)
dst_mail = torch.cat([emb_dst if use_src_emb else mem_dst, emb_src if use_dst_emb else mem_src], dim=1)
if edge_feats is not None: if edge_feats is not None:
src_mail = torch.cat([mem_src, mem_dst, edge_feats], dim=1) src_mail = torch.cat([src_mail, edge_feats], dim=1)
dst_mail = torch.cat([mem_dst, mem_src, edge_feats], dim=1) dst_mail = torch.cat([dst_mail, edge_feats], dim=1)
else:
src_mail = torch.cat([mem_src, mem_dst], dim=1)
dst_mail = torch.cat([mem_dst, mem_src], dim=1)
mail = torch.cat([src_mail, dst_mail], dim=1).reshape(-1, src_mail.shape[1]) mail = torch.cat([src_mail, dst_mail], dim=1).reshape(-1, src_mail.shape[1])
mail_ts = torch.cat((ts,ts),-1).to(self.device).to(self.mailbox_ts.dtype) mail_ts = torch.cat((ts,ts),-1).to(self.device).to(self.mailbox_ts.dtype)
unq_index,inv = torch.unique(index,return_inverse = True) unq_index,inv = torch.unique(index,return_inverse = True)
...@@ -266,7 +268,6 @@ class SharedMailBox(): ...@@ -266,7 +268,6 @@ class SharedMailBox():
index = unq_index index = unq_index
return index,mail,mail_ts return index,mail,mail_ts
def get_update_memory(self,index,memory,memory_ts): def get_update_memory(self,index,memory,memory_ts):
unq_index,inv = torch.unique(index,return_inverse = True) unq_index,inv = torch.unique(index,return_inverse = True)
max_ts,idx = torch_scatter.scatter_max(memory_ts,inv,0) max_ts,idx = torch_scatter.scatter_max(memory_ts,inv,0)
......
...@@ -2,8 +2,7 @@ import parser ...@@ -2,8 +2,7 @@ import parser
from torch_sparse import SparseTensor from torch_sparse import SparseTensor
from torch_geometric.data import Data from torch_geometric.data import Data
from torch_geometric.utils import degree from torch_geometric.utils import degree
from starrygl.lib.libstarrygl_ops import metis_partition from starrygl.lib.libstarrygl_sampler import get_norm_temporal
from starrygl.lib.libstarrygl_ops_sampler import get_norm_temporal
import os.path as osp import os.path as osp
import os import os
import shutil import shutil
......
...@@ -9,7 +9,8 @@ from typing import Optional, Tuple ...@@ -9,7 +9,8 @@ from typing import Optional, Tuple
from .base import BaseSampler, NegativeSampling, SampleOutput, SampleType from .base import BaseSampler, NegativeSampling, SampleOutput, SampleType
# from sample_cores import ParallelSampler, get_neighbors, heads_unique # from sample_cores import ParallelSampler, get_neighbors, heads_unique
from starrygl.lib.libstarrygl_ops_sampler import ParallelSampler, get_neighbors import starrygl.lib
from starrygl.lib.libstarrygl_sampler import ParallelSampler, get_neighbors
from torch.distributed.rpc import rpc_async from torch.distributed.rpc import rpc_async
# def outer_sample(graph_name, nodes, ts, fanout_index, with_outer_sample = SampleType.Outer):# 默认此时继续向外采样 # def outer_sample(graph_name, nodes, ts, fanout_index, with_outer_sample = SampleType.Outer):# 默认此时继续向外采样
......
...@@ -5,8 +5,10 @@ from os.path import abspath, join, dirname ...@@ -5,8 +5,10 @@ from os.path import abspath, join, dirname
from starrygl.distributed.context import DistributedContext from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex from starrygl.distributed.utils import DistIndex
from starrygl.module.modules import GeneralModel from starrygl.module.modules import GeneralModel
from pathlib import Path
from starrygl.module.utils import parse_config
from starrygl.module.utils import parse_config, EarlyStopMonitor
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.sample.memory.shared_mailbox import SharedMailBox from starrygl.sample.memory.shared_mailbox import SharedMailBox
from starrygl.sample.sample_core.base import NegativeSampling from starrygl.sample.sample_core.base import NegativeSampling
...@@ -32,10 +34,13 @@ parser = argparse.ArgumentParser( ...@@ -32,10 +34,13 @@ parser = argparse.ArgumentParser(
) )
parser.add_argument('--rank', default=0, type=str, metavar='W', parser.add_argument('--rank', default=0, type=str, metavar='W',
help='name of dataset') help='name of dataset')
parser.add_argument('--patience', type=int, default=5, help='Patience for early stopping')
parser.add_argument('--world_size', default=1, type=int, metavar='W', parser.add_argument('--world_size', default=1, type=int, metavar='W',
help='number of negative samples') help='number of negative samples')
parser.add_argument('--dataname', default=1, type=str, metavar='W', parser.add_argument('--dataname', default=1, type=str, metavar='W',
help='number of negative samples') help='name of dataset')
parser.add_argument('--model', default='TGN', type=str, metavar='W',
help='name of model')
args = parser.parse_args() args = parser.parse_args()
from sklearn.metrics import average_precision_score, roc_auc_score from sklearn.metrics import average_precision_score, roc_auc_score
import torch import torch
...@@ -62,7 +67,7 @@ def seed_everything(seed=42): ...@@ -62,7 +67,7 @@ def seed_everything(seed=42):
seed_everything(1234) seed_everything(1234)
def main(): def main():
use_cuda = True use_cuda = True
sample_param, memory_param, gnn_param, train_param = parse_config('./config/TGN.yml') sample_param, memory_param, gnn_param, train_param = parse_config('./config/{}.yml'.format(args.model))
torch.set_num_threads(12) torch.set_num_threads(12)
ctx = DistributedContext.init(backend="nccl", use_gpu=True) ctx = DistributedContext.init(backend="nccl", use_gpu=True)
device_id = torch.cuda.current_device() device_id = torch.cuda.current_device()
...@@ -70,6 +75,13 @@ def main(): ...@@ -70,6 +75,13 @@ def main():
pdata = partition_load("/mnt/data/part_data/dataset/here/{}".format(args.dataname), algo="metis_for_tgnn") pdata = partition_load("/mnt/data/part_data/dataset/here/{}".format(args.dataname), algo="metis_for_tgnn")
graph = DistributedGraphStore(pdata = pdata) graph = DistributedGraphStore(pdata = pdata)
Path("./saved_models/").mkdir(parents=True, exist_ok=True)
Path("./saved_checkpoints/").mkdir(parents=True, exist_ok=True)
get_checkpoint_path = lambda \
epoch: f'./saved_checkpoints/{args.model}-{args.dataname}-{epoch}.pth'
gnn_param['dyrep'] = True if args.model == 'DyRep' else False
use_src_emb = gnn_param['use_src_emb'] if 'use_src_emb' in gnn_param else False
use_dst_emb = gnn_param['use_dst_emb'] if 'use_dst_emb' in gnn_param else False
sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph,mode = 'full') sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph,mode = 'full')
mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat = pdata.edge_attr.shape[1] if pdata.edge_attr is not None else 0) mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat = pdata.edge_attr.shape[1] if pdata.edge_attr is not None else 0)
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10],graph_data=sample_graph, workers=10,policy = 'recent',graph_name = "wiki_train") sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10],graph_data=sample_graph, workers=10,policy = 'recent',graph_name = "wiki_train")
...@@ -88,7 +100,7 @@ def main(): ...@@ -88,7 +100,7 @@ def main():
trainloader = DistributedDataLoader(graph,train_data,sampler = sampler, trainloader = DistributedDataLoader(graph,train_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES, sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler, neg_sampler=neg_sampler,
batch_size = 1000, batch_size = train_param['batch_size'],
shuffle=False, shuffle=False,
drop_last=True, drop_last=True,
chunk_size = None, chunk_size = None,
...@@ -98,7 +110,7 @@ def main(): ...@@ -98,7 +110,7 @@ def main():
testloader = DistributedDataLoader(graph,test_data,sampler = sampler, testloader = DistributedDataLoader(graph,test_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES, sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler, neg_sampler=neg_sampler,
batch_size = 1000, batch_size = train_param['batch_size'],
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
chunk_size = None, chunk_size = None,
...@@ -108,7 +120,7 @@ def main(): ...@@ -108,7 +120,7 @@ def main():
valloader = DistributedDataLoader(graph,val_data,sampler = sampler, valloader = DistributedDataLoader(graph,val_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES, sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler, neg_sampler=neg_sampler,
batch_size = 1000, batch_size = train_param['batch_size'],
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
chunk_size = None, chunk_size = None,
...@@ -175,9 +187,10 @@ def main(): ...@@ -175,9 +187,10 @@ def main():
last_updated_ts) last_updated_ts)
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper, index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
src,dst,ts,edge_feats, src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory, model.module.memory_updater.last_updated_memory,
) model.module.embedding,use_src_emb,use_dst_emb,
)
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max') mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
...@@ -197,6 +210,8 @@ def main(): ...@@ -197,6 +210,8 @@ def main():
creterion = torch.nn.BCEWithLogitsLoss() creterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr']) optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'])
early_stopper = EarlyStopMonitor(max_round=args.patience)
MODEL_SAVE_PATH = f'./saved_models/{args.model}-{args.dataname}.pth'
for e in range(train_param['epoch']): for e in range(train_param['epoch']):
torch.cuda.synchronize() torch.cuda.synchronize()
write_back_time = 0 write_back_time = 0
...@@ -251,9 +266,10 @@ def main(): ...@@ -251,9 +266,10 @@ def main():
last_updated_memory, last_updated_memory,
last_updated_ts) last_updated_ts)
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper, index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
src,dst,ts,edge_feats, src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory, model.module.memory_updater.last_updated_memory,
) model.module.embedding,use_src_emb,use_dst_emb,
)
start_event = torch.cuda.Event(enable_timing=True) start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True)
start_event.record() start_event.record()
...@@ -269,9 +285,18 @@ def main(): ...@@ -269,9 +285,18 @@ def main():
ap = 0 ap = 0
auc = 0 auc = 0
ap, auc = eval('val') ap, auc = eval('val')
print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f}'.format(total_loss,train_ap, ap, auc)) early_stop = early_stopper.early_stop_check(ap)
print('\ttotal time:{:.2f}s prep time:{:.2f}s'.format(time.time()-epoch_start_time, time_prep)) if early_stop:
print('\t fetch time:{:.2f}s write back time:{:.2f}s'.format(fetch_time,write_back_time)) print("Early stopping at epoch {:d}".format(e))
print(f"Loading the best model at epoch {early_stopper.best_epoch}")
best_model_path = get_checkpoint_path(early_stopper.best_epoch)
model.load_state_dict(torch.load(best_model_path))
break
else:
print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f}'.format(total_loss,train_ap, ap, auc))
print('\ttotal time:{:.2f}s prep time:{:.2f}s'.format(time.time()-epoch_start_time, time_prep))
print('\t fetch time:{:.2f}s write back time:{:.2f}s'.format(fetch_time,write_back_time))
torch.save(model.state_dict(), get_checkpoint_path(e))
model.eval() model.eval()
if mailbox is not None: if mailbox is not None:
...@@ -286,6 +311,7 @@ def main(): ...@@ -286,6 +311,7 @@ def main():
else: else:
print('\ttest AP:{:4f} test AUC:{:4f}'.format(ap, auc)) print('\ttest AP:{:4f} test AUC:{:4f}'.format(ap, auc))
print('test_dataset',test_data.edges.shape[1],'avg_time',avg_time/train_param['epoch']) print('test_dataset',test_data.edges.shape[1],'avg_time',avg_time/train_param['epoch'])
torch.save(model.state_dict(), MODEL_SAVE_PATH)
ctx.shutdown() ctx.shutdown()
if __name__ == "__main__": if __name__ == "__main__":
main() main()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment