Commit f2ee424a by Wenjie Huang
parents 17306a86 38eb626e
install.sh merge=ours
\ No newline at end of file
*.tgz
*.my
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
...
#include <head.h>
#include <sampler.h>
#include <tppr.h>
#include <output.h>
#include <neighbors.h>
#include <temporal_utils.h>
@@ -88,4 +89,22 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
.def("reset", &ParallelSampler::reset)
.def("get_ret", [](const ParallelSampler &ps) { return ps.ret; });
py::class_<ParallelTppRComputer>(m, "ParallelTppRComputer")
.def(py::init<TemporalNeighborBlock &, NodeIDType, EdgeIDType, int,
int, int, int, vector<float>&, vector<float>& >())
.def_readonly("ret", &ParallelTppRComputer::ret, py::return_value_policy::reference)
.def("reset_ret", &ParallelTppRComputer::reset_ret)
.def("reset_tppr", &ParallelTppRComputer::reset_tppr)
.def("reset_val_tppr", &ParallelTppRComputer::reset_val_tppr)
.def("backup_tppr", &ParallelTppRComputer::backup_tppr)
.def("restore_tppr", &ParallelTppRComputer::restore_tppr)
.def("restore_val_tppr", &ParallelTppRComputer::restore_val_tppr)
.def("get_pruned_topk", &ParallelTppRComputer::get_pruned_topk)
.def("extract_streaming_tppr", &ParallelTppRComputer::extract_streaming_tppr)
.def("streaming_topk", &ParallelTppRComputer::streaming_topk)
.def("single_streaming_topk", &ParallelTppRComputer::single_streaming_topk)
.def("streaming_topk_no_fake", &ParallelTppRComputer::streaming_topk_no_fake)
.def("compute_val_tppr", &ParallelTppRComputer::compute_val_tppr)
.def("get_ret", [](const ParallelTppRComputer &ps) { return ps.ret; });
}
\ No newline at end of file
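For orientation, a hedged sketch of driving the new ParallelTppRComputer binding from Python. The module path follows the starrygl.lib.libstarrygl_sampler import that appears later in this commit; the names of the four int parameters and the two float lists are guesses from the init<> signature above, and every value is a placeholder.

# Sketch only; parameter names are assumptions, and tnb must be a
# TemporalNeighborBlock produced by the sampler side of this extension.
from starrygl.lib.libstarrygl_sampler import ParallelTppRComputer

tppr = ParallelTppRComputer(tnb, num_nodes, num_edges,
                            threads, fanout, num_layers, topk,  # four ints, order per init<>
                            alpha_list, beta_list)              # the two vector<float>& arguments
tppr.reset_tppr()     # clear accumulated t-PPR state
tppr.reset_ret()      # clear the per-call result buffers
ret = tppr.get_ret()  # copy of ParallelTppRComputer::ret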
#pragma once
#include <iostream>
#include <algorithm>
#include <torch/extension.h>
#include <omp.h>
#include <time.h>
@@ -17,6 +18,12 @@ typedef int64_t NodeIDType;
typedef int64_t EdgeIDType;
typedef float WeightType;
typedef float TimeStampType;
typedef tuple<NodeIDType, EdgeIDType, TimeStampType> PPRKeyType;
typedef double PPRValueType;
typedef phmap::parallel_flat_hash_map<PPRKeyType, PPRValueType> PPRDictType;
typedef vector<PPRDictType> PPRListDictType;
typedef vector<vector<PPRDictType>> PPRListListDictType;
typedef vector<vector<double>> NormListType;
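Read as Python type hints, the new aliases describe nested hash tables of temporal-PPR mass. The interpretation of the outer list levels (e.g. per thread, then per node) is an assumption from the surrounding code, not something the typedefs themselves pin down:

# Python mirror of the typedefs above; the nesting interpretation is an assumption.
from typing import Dict, List, Tuple

PPRKey = Tuple[int, int, float]        # (NodeIDType, EdgeIDType, TimeStampType)
PPRValue = float                       # PPRValueType (a C++ double)
PPRDict = Dict[PPRKey, PPRValue]       # one hash table of PPR mass
PPRListDict = List[PPRDict]            # e.g. one table per node
PPRListListDict = List[List[PPRDict]]  # e.g. per thread, then per node
NormList = List[List[float]]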
class TemporalNeighborBlock;
class TemporalGraphBlock;
@@ -28,6 +35,7 @@ int nodeIdToInOut(NodeIDType nid, int pid, const vector<NodeIDType>& part_ptr);
int nodeIdToPartId(NodeIDType nid, const vector<NodeIDType>& part_ptr);
vector<th::Tensor> divide_nodes_to_part(th::Tensor nodes, const vector<NodeIDType>& part_ptr, int threads);
NodeIDType sample_multinomial(const vector<WeightType>& weights, default_random_engine& e);
vector<int64_t> sample_max(const vector<WeightType>& weights, int k);
@@ -173,3 +181,17 @@ NodeIDType sample_multinomial(const vector<WeightType>& weights, default_random_
    sample_indice = distance(cumulative_weights.begin(), it);
    return sample_indice;
}
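The tail of sample_multinomial above is a standard inverse-CDF draw: build cumulative weights, draw u uniformly up to the total, and take the first index whose cumulative weight reaches u. A self-contained Python rendering of the same idea (an illustration, not the committed code):

import bisect
import random
from itertools import accumulate

def sample_multinomial(weights, rng=random):
    # Inverse-CDF sampling: P(i) is proportional to weights[i].
    cum = list(accumulate(weights))
    u = rng.uniform(0.0, cum[-1])
    # First index whose cumulative weight is >= u, like the distance() call above.
    return bisect.bisect_left(cum, u)

print(sample_multinomial([0.1, 0.7, 0.2]))  # prints 1 about 70% of the time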
vector<int64_t> sample_max(const vector<WeightType>& weights, int k) {
    vector<int64_t> indices(weights.size());
    for (size_t i = 0; i < weights.size(); ++i) {
        indices[i] = i;
    }
    // Clamp k so partial_sort never reads past the end of the index array.
    k = min<int>(k, (int)weights.size());
    // Partial sort (selection) so the first k slots hold the indices of the k largest weights.
    partial_sort(indices.begin(), indices.begin() + k, indices.end(),
        [&weights](int64_t a, int64_t b) { return weights[a] > weights[b]; });
    // Return the indices of the top-k weights.
    return vector<int64_t>(indices.begin(), indices.begin() + k);
}
\ No newline at end of file
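sample_max is a plain top-k-by-weight selection. For reference, the same selection in Python with numpy (an illustration, not part of the commit):

import numpy as np

def sample_max(weights: np.ndarray, k: int) -> np.ndarray:
    # Clamp k like the C++ version, then take the k largest weights.
    k = min(k, len(weights))
    # argpartition gathers the k largest in the first k slots (unordered)...
    top = np.argpartition(-weights, k - 1)[:k]
    # ...and a final argsort orders them descending, like std::partial_sort.
    return top[np.argsort(-weights[top])]

print(sample_max(np.array([0.3, 0.9, 0.1, 0.5]), 2))  # [1 3]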
@@ -287,10 +287,15 @@ void TemporalNeighborBlock::update_edge_weight(
    for(int64_t i=0; i<edge_num; i++){
        // Update the weight of the edge between the node and this neighbor.
        int index;
        if(this->with_eid){
            AT_ASSERTM(this->inverted_index[dst[i]].count(eid_ptr[i])==1, "Unexist Eid --> Col: "+to_string(eid_ptr[i])+"-->"+to_string(dst[i]));
            index = this->inverted_index[dst[i]][eid_ptr[i]];
        }
        else{
            AT_ASSERTM(this->inverted_index[dst[i]].count(src[i])==1, "Unexist Edge Index: "+to_string(src[i])+", "+to_string(dst[i]));
            index = this->inverted_index[dst[i]][src[i]];
        }
        this->edge_weight[dst[i]][index] = ew[i];
    }
}
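What the reworked branch above encodes: when edges carry explicit ids (with_eid), the per-destination inverted index is keyed by edge id, otherwise by source node, and the assertion now checks the key that is actually used. A Python paraphrase of the lookup (illustrative only; the containers stand in for TemporalNeighborBlock members):

def update_one_edge(inverted_index, edge_weight, with_eid, src, dst, eid, w):
    # Pick the key the inverted index was built with.
    key = eid if with_eid else src
    assert key in inverted_index[dst], f"Unexist edge {key} -> {dst}"
    edge_weight[dst][inverted_index[dst][key]] = w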
...
@@ -11,6 +11,7 @@ class TemporalGraphBlock
    vector<int64_t> src_index;
    vector<NodeIDType> sample_nodes;
    vector<TimeStampType> sample_nodes_ts;
    vector<WeightType> e_weights;
    double sample_time = 0;
    double tot_time = 0;
    int64_t sample_edge_num = 0;
...
@@ -105,13 +105,13 @@ void ParallelSampler :: neighbor_sample_from_nodes_static_layer(th::Tensor nodes
    // uniform_int_distribution<> u(0, tnb.deg[node]-1);
    // while(temp_s.size()!=fanout && temp_s.size()<tnb.neighbors_set[node].size()){
    for(int i=0;i<fanout;i++){
        // Pick fanout neighbors, one per iteration.
        NodeIDType indice;
        if(policy == "weighted"){ // sample in proportion to edge weights
            const vector<WeightType>& ew = tnb.edge_weight[node];
            indice = sample_multinomial(ew, e);
        }
        else if(policy == "uniform"){ // uniform sampling
            // indice = u(e);
            indice = rand_r(&loc_seed) % (nei.size());
        }
@@ -119,7 +119,7 @@ void ParallelSampler :: neighbor_sample_from_nodes_static_layer(th::Tensor nodes
        auto chosen_e_iter = edge.begin() + indice;
        if(part_unique){
            auto rst = temp_s.insert(*chosen_n_iter);
            if(rst.second){ // newly inserted, i.e. not a repeated neighbor
                eid_threads[tid].emplace_back(*chosen_e_iter);
                node_s_threads[tid].insert(*chosen_n_iter);
                if(!tnb.neighbors_set.empty() && temp_s.size()<fanout && temp_s.size()<tnb.neighbors_set[node].size()) fanout++;
@@ -229,7 +229,7 @@ void ParallelSampler :: neighbor_sample_from_nodes_with_before_layer(
            }
        }
        else{
            // More candidate edges than the fanout: randomly choose fanout neighbors.
            tgb_i[tid].src_index.insert(tgb_i[tid].src_index.end(), fanout, i);
            uniform_int_distribution<> u(0, end_index-1);
            // cout<<end_index<<endl;
...
@@ -69,7 +69,7 @@ After defining the configuration file, we can firstly read the parameters from t
Then a :code:`GeneralModel` object is created. If needed, we can adjust the model's parameters by modifying the contents of the configuration file. Here we provide 5 models for continuous-time temporal GNNs (a short usage sketch follows the list):

- :code:`TGN`: The TGN model proposed in `Temporal Graph Networks for Deep Learning on Dynamic Graphs <https://arxiv.org/abs/2006.10637>`__.
- :code:`DyRep`: The DyRep model proposed in `Representation Learning over Dynamic Graphs <https://arxiv.org/abs/1803.04051>`__.
- :code:`TIGER`: The TIGER model proposed in `TIGER: Temporal Interaction Graph Embedding with Restarts <https://arxiv.org/abs/2302.06057>`__.
- :code:`Jodie`: The Jodie model proposed in `Predicting Dynamic Embedding Trajectory in Temporal Interaction Networks <https://arxiv.org/abs/1908.01207>`__.
- :code:`TGAT`: The TGAT model proposed in `Inductive Representation Learning on Temporal Graphs <https://arxiv.org/abs/2002.07962>`__.
\ No newline at end of file
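A short usage sketch of the flow this page describes, assuming a TGL-style parse_config helper that returns the four parameter dicts used by the training script later in this commit; the config path is hypothetical:

# Sketch only: parse_config is an assumed helper; the four dicts and the
# GeneralModel constructor follow the training script in this commit.
sample_param, memory_param, gnn_param, train_param = parse_config('./config/TGN.yml')
model = GeneralModel(gnn_dim_node, gnn_dim_edge,
                     sample_param, memory_param, gnn_param, train_param)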
@@ -11,3 +11,8 @@ cmake .. \
&& mkdir ../starrygl/lib \
&& cp lib*.so ../starrygl/lib/ \
&& patchelf --set-rpath '$ORIGIN:$ORIGIN/lib' --force-rpath ../starrygl/lib/*.so
# -DCMAKE_PREFIX_PATH="/home/zlj/.miniconda3/envs/dgnn/lib/python3.10/site-packages" \
# -DPython3_ROOT_DIR="/home/zlj/.miniconda3/envs/dgnn" \
# -DCUDA_TOOLKIT_ROOT_DIR="/home/zlj/local/cuda-12.2" \
\ No newline at end of file
import torch
import dgl
from os.path import abspath, join, dirname
import sys
sys.path.insert(0, join(abspath(dirname(__file__))))
...
@@ -155,10 +155,10 @@ class DistributedDataLoader:
        self.expected_idx = data_size // self.batch_size if self.drop_last is True else int(math.ceil(data_size/self.batch_size))
        if dist.get_world_size() > 1:
            num_batchs = torch.tensor([self.expected_idx], dtype=torch.long, device=self.device)
            print("num_batchs:", num_batchs)
            dist.all_reduce(num_batchs, op=op)
            self.expected_idx = int(num_batchs.item())

    def _next_data(self):
        if self.current_pos >= self.dataset.len:
...
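The num_batchs all_reduce above exists so every rank loops the same number of times; otherwise a rank with fewer local batches would stop early and leave collectives unmatched. A minimal sketch of the pattern (op is assumed to be dist.ReduceOp.MAX; running it requires an initialized process group):

import torch
import torch.distributed as dist

def reconcile_batch_count(expected_idx: int, device: torch.device) -> int:
    # Each rank proposes its local count; MAX keeps the largest, so no rank
    # exits the loop while others still issue collective calls.
    n = torch.tensor([expected_idx], dtype=torch.long, device=device)
    dist.all_reduce(n, op=dist.ReduceOp.MAX)
    return int(n.item())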
import os.path as osp
import torch

class GraphData():
    def __init__(self, path):
        assert path is not None and osp.exists(path), 'path does not exist'
        id, edge_index, data, partptr = torch.load(path)
        # index of the current partition
        self.partition_id = id
        # total number of partitions
        self.partitions = partptr.numel() - 1
        # structure of the full graph
        self.num_nodes = partptr[self.partitions]
        self.num_edges = edge_index[0].numel()
        self.edge_index = edge_index
        # data held by this partition (feature vectors and subgraph structure),
        # stored as a PyG Data object
        self.data = data
        # partition mapping: partptr[i] is the first global node id of partition i
        self.partptr = partptr
        self.eid = [i for i in range(self.num_edges)]

    # NOTE: this second __init__ overrides the one above, so only this
    # signature is actually usable for construction.
    def __init__(self, id, edge_index, data, partptr, timestamp=None):
        # index of the current partition
        self.partition_id = id
        # total number of partitions
        self.partitions = partptr.numel() - 1
        # structure of the full graph
        self.num_nodes = partptr[self.partitions]
        if edge_index is not None:
            self.num_edges = edge_index[0].numel()
        self.edge_index = edge_index
        self.edge_ts = timestamp
        # data held by this partition (feature vectors and subgraph structure),
        # stored as a PyG Data object
        self.data = data
        # partition mapping: partptr[i] is the first global node id of partition i
        self.partptr = partptr
        # edge ids
        self.eid = torch.tensor([i for i in range(0, self.num_edges)])

    def select_attr(self, index):
        return torch.index_select(self.data.x, 0, index)

    # number of nodes held by this partition
    def get_part_num(self):
        return self.data.x.size()[0]

    # (duplicate definition, identical to select_attr above)
    def select_attr(self, index):
        return torch.index_select(self.data.x, 0, index)

    def select_y(self, index):
        return torch.index_select(self.data.y, 0, index)

    # map ids between the global space and a partition-local space
    def get_localId_by_partitionId(self, id, index):
        # print(index)
        if (id == -1 or id == 0):
            return index
        else:
            return torch.add(index, -self.partptr[id])

    def get_globalId_by_partitionId(self, id, index):
        if (id == -1 or id == 0):
            return index
        else:
            return torch.add(index, self.partptr[id])

    def get_node_num(self):
        return self.num_nodes

    def localId_to_globalId(self, id, partitionId: int = -1):
        '''
        Map a node id inside partition partitionId to a global id.
        '''
        if partitionId == -1:
            partitionId = self.partition_id
        assert id >= self.partptr[self.partition_id] and id < self.partptr[self.partition_id + 1]
        ids_before = 0
        if self.partition_id > 0:
            ids_before = self.partptr[self.partition_id - 1]
        return id + ids_before

    def get_partitionId_by_globalId(self, id):
        '''
        Get the index of the partition that a global id falls into.
        '''
        partitionId = -1
        assert id >= 0 and id < self.num_nodes, 'id out of range'
        for i in range(self.partitions):
            if id >= self.partptr[i] and id < self.partptr[i + 1]:
                partitionId = i
                break
        assert partitionId >= 0, 'no partition contains this id'
        return partitionId

    def get_nodes_by_partitionId(self, id):
        '''
        Return the number of nodes in partition id.
        '''
        assert id >= 0 and id < self.partitions, 'invalid partitionId'
        return (int)(self.partptr[id + 1] - self.partptr[id])

    def __repr__(self):
        return (f'{self.__class__.__name__}(\n'
                f'  partition_id={self.partition_id}\n'
                f'  data={self.data},\n'
                f'  global_info('
                f'num_nodes={self.num_nodes},'
                f' num_edges={self.num_edges},'
                f' num_parts={self.partitions},'
                f' edge_index=[2,{self.edge_index[0].numel()}])\n'
                f')')
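A hedged usage sketch for GraphData. Because Python keeps only the last __init__ defined in the class body, construction must use the (id, edge_index, data, partptr, timestamp) signature; all values below are placeholders:

from types import SimpleNamespace  # stands in for the PyG Data object the class normally stores

# Two partitions over 4 nodes: nodes 0-1 in partition 0, nodes 2-3 in partition 1.
partptr = torch.tensor([0, 2, 4])
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
data = SimpleNamespace(x=torch.randn(2, 8), y=torch.zeros(2, dtype=torch.long))

g = GraphData(0, edge_index, data, partptr, timestamp=torch.tensor([0., 1., 2.]))
local_ids = g.get_localId_by_partitionId(1, torch.tensor([2, 3]))  # -> [0, 1]
print(g.get_partitionId_by_globalId(3))  # -> 1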
@@ -8,7 +8,7 @@ import torch.multiprocessing as mp
from typing import Optional, Tuple
from .base import BaseSampler, NegativeSampling, SampleOutput, SampleType
from starrygl.lib.libstarrygl_sampler import ParallelSampler, get_neighbors, heads_unique
from torch.distributed.rpc import rpc_async
...
@@ -62,7 +62,7 @@ def test():
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    g_data, df = load_reddit_dataset()
    print(g_data)
    # for worker in [1,2,3,4,5,6,7,8,9,10,20,30]:
    # import random
...
@@ -70,7 +70,7 @@ def test():
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    g_data = load_reddit_dataset()
    print(g_data)
from .neighbor_sampler import NeighborSampler, get_neighbors
...
@@ -43,7 +43,7 @@ parser.add_argument('--rank', default=0, type=int, metavar='W',
parser.add_argument('--patience', type=int, default=5, help='Patience for early stopping')
parser.add_argument('--world_size', default=1, type=int, metavar='W',
                    help='world size (number of distributed processes)')
parser.add_argument('--dataname', default="WIKI", type=str, metavar='W',
                    help='name of dataset')
parser.add_argument('--model', default='TGN', type=str, metavar='W',
                    help='name of model')
@@ -52,17 +52,18 @@ from sklearn.metrics import average_precision_score, roc_auc_score
import torch
import time
import random
import dgl
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn.parallel import DistributedDataParallel as DDP

os.environ['CUDA_VISIBLE_DEVICES'] = str(args.rank)
os.environ["RANK"] = str(args.rank)
os.environ["WORLD_SIZE"] = str(args.world_size)
os.environ["LOCAL_RANK"] = str(0)
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
os.environ["MASTER_ADDR"] = '10.214.211.186'
os.environ["MASTER_PORT"] = '9667'

def seed_everything(seed=42):
    random.seed(seed)
    np.random.seed(seed)
@@ -80,7 +81,7 @@ def main():
    ctx = DistributedContext.init(backend="nccl", use_gpu=True)
    device_id = torch.cuda.current_device()
    print('use cuda on', device_id)
    pdata = partition_load("/mnt/data/part_data/here/{}".format(args.dataname), algo="metis_for_tgnn")
    graph = DistributedGraphStore(pdata=pdata)
    Path("./saved_models/").mkdir(parents=True, exist_ok=True)
@@ -149,7 +150,7 @@ def main():
    #cache.init_cache_with_presample(trainloader,3)
    gnn_dim_node = 0 if graph.x is None else pdata.x.shape[1]
    gnn_dim_edge = 0 if graph.edge_attr is None else pdata.edge_attr.shape[1]
    print("gnn_dim_node:", gnn_dim_node, "gnn_dim_edge:", gnn_dim_edge)
    avg_time = 0
    if use_cuda:
        model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param).cuda()
@@ -157,7 +158,7 @@ def main():
    else:
        model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param)
        device = torch.device('cpu')
    model = DDP(model, find_unused_parameters=False)
    train_stream = torch.cuda.Stream()
    send_stream = torch.cuda.Stream()
    scatter_stream = torch.cuda.Stream()
@@ -178,7 +179,7 @@ def main():
        total_loss = 0
        signal = torch.tensor([0], dtype=int, device=device)
        with torch.cuda.stream(train_stream):
            for roots, mfgs, metadata in loader:
                pred_pos, pred_neg = model(mfgs, metadata)
                total_loss += creterion(pred_pos, torch.ones_like(pred_pos))
@@ -226,8 +227,11 @@ def main():
        auc_mrr = torch.empty([loader.expected_idx*world_size], dtype=torch.float, device='cuda')
        dist.all_gather_into_tensor(apc, torch.tensor(aps, device='cuda', dtype=torch.float))
        dist.all_gather_into_tensor(auc_mrr, torch.tensor(aucs_mrrs, device='cuda', dtype=torch.float))
        # ap = float(torch.tensor(apc).mean())
        # auc_mrr = float(torch.tensor(auc_mrr).mean())
        ap = float(apc.clone().mean())
        auc_mrr = float(auc_mrr.clone().mean())
        return ap, auc_mrr

    creterion = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'])
@@ -248,8 +252,8 @@ def main():
        model.module.memory_updater.last_updated_nid = None
        model.module.memory_updater.last_updated_memory = None
        model.module.memory_updater.last_updated_ts = None
        for roots, mfgs, metadata in trainloader:
            # fetch_time += sample_time/1000
            t_prep_s = time.time()
            with torch.cuda.stream(train_stream):
@@ -264,9 +268,9 @@ def main():
            optimizer.step()
            #torch.cuda.synchronize()
            t_prep_s = time.time()
            # y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
            # y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
            # train_aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
            #start_event = torch.cuda.Event(enable_timing=True)
            #end_event = torch.cuda.Event(enable_timing=True)
            #start_event.record()
@@ -323,7 +327,7 @@ def main():
        else:
            print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f}'.format(total_loss, train_ap, ap, auc))
        print('\ttotal time:{:.2f}s prep time:{:.2f}s'.format(time.time()-epoch_start_time, time_prep))
        # print('\t fetch time:{:.2f}s write back time:{:.2f}s'.format(fetch_time, write_back_time))
        torch.save(model.state_dict(), get_checkpoint_path(e))
        model.eval()
...