Commit a70d1b3b by zlj

add sample module

parent abce8f53
def foo():
    a = 1
    def fa():
        nonlocal a  # needed: without it, `a += 1` makes `a` local and print(a) raises UnboundLocalError
        print(a)
        a += 1
        print(a)
    def fb(a):
        def apply():
            print(a)
        return apply
    fc = lambda: print(a)
    return fa, fb(a), fc

fa, fb, fc = foo()
fa()
fb()
fc()
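# Expected output (assuming the `nonlocal a` fix above), illustrating what each closure captures:
#   fa() -> 1 then 2   (reads and then mutates foo's `a` via nonlocal)
#   fb() -> 1          (apply() closed over fb's parameter, bound when fb(a) was called inside foo)
#   fc() -> 2          (the lambda still sees foo's `a`, which fa() incremented to 2)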
import argparse
import os
import sys
import time
from os.path import abspath, join, dirname

import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group

from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex
from starrygl.module.modules import GeneralModel
from starrygl.module.utils import parse_config
from starrygl.sample.graph_core import DataSet, GraphData, TemporalNeighborSampleGraph
from starrygl.sample.memory.shared_mailbox import SharedMailBox
from starrygl.sample.sample_core.base import NegativeSampling
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
from starrygl.sample.part_utils.partition_tgnn import partition_load
from starrygl.sample.data_loader import DistributedDataLoader
from starrygl.sample.batch_data import SAMPLE_TYPE
"""
test command
python test.py --world_size 2 --rank 0
--world_size', default=4, type=int, metavar='W',
help='number of workers')
parser.add_argument('--rank', default=0, type=int, metavar='W',
help='rank of the worker')
parser.add_argument('--log_interval', type=int, default=10, metavar='N',
help='interval between training status logs')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
help='how much to value future rewards')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed for reproducibility')
parser.add_argument('--num_sampler', type=int, default=10, metavar='S',
help='number of samplers')
parser.add_argument('--queue_size', type=int, default=10, metavar='S',
help='sampler queue size')
"""
parser = argparse.ArgumentParser(
    description="StarryGL distributed sampling test",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--rank', default=0, type=int, metavar='W',
                    help='rank of this worker')
parser.add_argument('--world_size', default=1, type=int, metavar='W',
                    help='number of workers')
args = parser.parse_args()
import random
import dgl
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.rank)
os.environ["RANK"] = str(args.rank)
os.environ["WORLD_SIZE"] = str(args.world_size)
os.environ["LOCAL_RANK"] = str(0)
os.environ["MASTER_ADDR"] = '127.0.0.1'
os.environ["MASTER_PORT"] = '9337'
def seed_everything(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(1234)
def main():
    use_cuda = True
    sample_param, memory_param, gnn_param, train_param = parse_config('./config/TGN.yml')
    torch.set_num_threads(12)
    ctx = DistributedContext.init(backend="nccl", use_gpu=True)
    device_id = torch.cuda.current_device()
    print('use cuda on', device_id)
    pdata = partition_load("./dataset/here/WIKI", algo="metis_for_tgnn")
    graph = GraphData(pdata=pdata)
    #dist.barrier()
    #for i in range(100):
    #    print(i)
    dist.barrier()
    # extract the partition id stored in the top 16 bits of each packed edge id
    idx = ((graph.eids_mapper >> 48).int() & 0xFFFF)
    print((idx == 0).nonzero().shape, (idx == 1).nonzero().shape)
    t1 = time.time()
    """
    fut = []
    for i in range(1000):
        #print(i)
        out = graph.edge_attr.index_select(graph.eids_mapper[(idx == 0) | (idx == 1)].to('cuda'))
        fut.append(out)
        #out.wait()
        #out.value()
        if i > 0 and i % 100 == 0:
            f = torch.futures.collect_all(fut)
            f.wait()
            f.value()
            fut = []
    """
    # partition boundaries in the same packed-id space; the range(3) and the [2]-sized
    # buffers below appear to assume a two-partition run
    partptr = torch.tensor([((i & 0xFFFF) << 48) for i in range(3)], device='cuda')
    for i in range(1000):
        if i % 100 == 0:
            idx = graph.eids_mapper.to('cuda')
            idx, inv = idx.unique(return_inverse=True)
            ind = torch.searchsorted(idx, partptr, right=False)
            send_len = ind[1:] - ind[:-1]  # renamed from `len` to avoid shadowing the builtin
            gatherlen = torch.empty([2], dtype=torch.long, device='cuda')
            dist.all_to_all_single(gatherlen, send_len)
            query_idx = torch.empty([gatherlen.sum()], dtype=torch.long, device='cuda')
            input_s = list(send_len)
            output_s = list(gatherlen)
            # exchange the requested ids, then answer each request with the local rows
            dist.all_to_all_single(query_idx, idx, output_s, input_s)
            input_f = graph.edge_attr.accessor.data[DistIndex(query_idx).loc]
            f = torch.empty([idx.shape[0], graph.edge_attr.accessor.data.shape[1]],
                            dtype=torch.float, device='cuda')
            dist.all_to_all_single(f, input_f, input_s, output_s)
    torch.cuda.synchronize()
    t2 = time.time() - t1
    print(t2)
    #dist.barrier()
    ctx.shutdown()

if __name__ == "__main__":
    main()
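# A minimal sketch (not part of the commit) of the packed-id layout the test above relies on.
# It assumes, based on the `>> 48` / `& 0xFFFF` arithmetic in eids_mapper and partptr, that
# DistIndex keeps the owning partition in the top 16 bits and the local offset in the low 48 bits.
# `pack_dist_index` and `unpack_dist_index` are hypothetical helper names.
import torch

def pack_dist_index(part_id: torch.Tensor, local_offset: torch.Tensor) -> torch.Tensor:
    # mirrors ((i & 0xFFFF) << 48) combined with a local offset
    return ((part_id.long() & 0xFFFF) << 48) | (local_offset.long() & ((1 << 48) - 1))

def unpack_dist_index(packed: torch.Tensor):
    part_id = (packed >> 48) & 0xFFFF        # same expression as the idx computation in main()
    local_offset = packed & ((1 << 48) - 1)  # what DistIndex(...).loc presumably returns
    return part_id, local_offset

packed = pack_dist_index(torch.tensor([0, 1, 1]), torch.tensor([7, 3, 9]))
print(unpack_dist_index(packed))  # (tensor([0, 1, 1]), tensor([7, 3, 9]))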
@@ -118,6 +118,8 @@
partition_save('./dataset/here/'+data_name, data, 1, 'metis_for_tgnn',
               edge_weight_dict=edge_weight_dict)
partition_save('./dataset/here/'+data_name, data, 2, 'metis_for_tgnn',
               edge_weight_dict=edge_weight_dict)
partition_save('./dataset/here/'+data_name, data, 4, 'metis_for_tgnn',
               edge_weight_dict=edge_weight_dict)
#
# partition_save('./dataset/here/'+data_name, data, 4, 'metis_for_tgnn',
#                edge_weight_dict=edge_weight_dict)
ERROR:root:unable to import libstarrygl.so, some features may not be available.
the number of nodes in graph is 1980, the number of edges in graph is 1293103
directory '/home/zlj/starrygl/dataset/here/LASTFM/metis_for_tgnn_1' not empty and cleared
running partition algorithm: metis_for_tgnn
saving partition data: 1/1
running partition algorithm: metis_for_tgnn
saving partition data: 1/2
saving partition data: 2/2
creating directory '/home/zlj/starrygl/dataset/here/LASTFM/metis_for_tgnn_4'
running partition algorithm: metis_for_tgnn
saving partition data: 1/4
saving partition data: 2/4
saving partition data: 3/4
saving partition data: 4/4
@@ -15,6 +15,7 @@ class TensorAccessor:
        self._data = data
        self._ctx = DistributedContext.get_default_context()
        self._rref = rpc.RRef(data)
        self._rref.confirmed_by_owner  # note: attribute is not called, so this line has no effect as written

    @property
    def data(self):
@@ -28,6 +29,20 @@ class TensorAccessor:
    def ctx(self):
        return self._ctx
    def all_gather_index(self, index, input_split) -> Tensor:
        # exchange split sizes first so every rank knows how many indices it will receive
        out_split = torch.empty_like(input_split)
        torch.distributed.all_to_all_single(out_split, input_split)
        input_split = list(input_split)
        output = torch.empty([out_split.sum()], dtype=index.dtype, device=index.device)
        out_split = list(out_split)
        # then exchange the indices themselves
        torch.distributed.all_to_all_single(output, index, out_split, input_split)
        return output, out_split, input_split

    def all_gather_data(self, index, input_split, out_split):
        # answer the received index requests with the locally stored rows
        output = torch.empty([int(Tensor(out_split).sum().item()), *self._data.shape[1:]],
                             dtype=self._data.dtype, device='cuda')  # self._data.device
        torch.distributed.all_to_all_single(output,
                                            self.data[index.to(self.data.device)].to('cuda'),
                                            output_split_sizes=out_split,
                                            input_split_sizes=input_split)
        return output

    def all_gather_rrefs(self) -> List[rpc.RRef]:
        return self.ctx.all_gather_remote_objects(self.rref)
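# A self-contained sketch (not part of this commit) of the two-phase all_to_all pattern that
# all_gather_index/all_gather_data implement: ranks first trade how many indices they will
# request from each other, then trade the indices, and finally answer with the matching rows.
# `exchange_rows`, `local_rows`, `request_idx`, and `send_split` are illustrative names only;
# an initialized process group with an all_to_all-capable backend (e.g. NCCL) is assumed,
# and local_rows/request_idx are assumed to live on the same (CUDA) device.
import torch
import torch.distributed as dist

def exchange_rows(local_rows, request_idx, send_split):
    # send_split[i] = how many entries of request_idx are addressed to rank i
    recv_split = torch.empty_like(send_split)
    dist.all_to_all_single(recv_split, send_split)                     # 1) trade split sizes
    incoming = torch.empty([int(recv_split.sum())], dtype=request_idx.dtype,
                           device=request_idx.device)
    dist.all_to_all_single(incoming, request_idx,
                           recv_split.tolist(), send_split.tolist())   # 2) trade the indices
    reply = torch.empty([int(send_split.sum()), *local_rows.shape[1:]],
                        dtype=local_rows.dtype, device=local_rows.device)
    # 3) each rank answers the requests it received; each rank gets back one row per request it made
    dist.all_to_all_single(reply, local_rows[incoming],
                           send_split.tolist(), recv_split.tolist())
    return reply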
@@ -101,7 +116,6 @@ class DistributedTensor:
    def __init__(self, data: Tensor) -> None:
        self.accessor = TensorAccessor(data)
        self.rrefs = self.accessor.all_gather_rrefs()
        # self.num_parts = len(self.rrefs)

        local_sizes = []
@@ -110,6 +124,7 @@
            local_sizes.append(n)
        self.num_nodes = DistInt(local_sizes)
        self.num_parts = DistInt([1] * len(self.rrefs))
        self.distptr = torch.tensor([((part_ids & 0xFFFF) << 48) for part_ids in range(self.num_parts() + 1)],
                                    device='cuda')  # data.device

    @property
@@ -130,6 +145,17 @@
    def ctx(self):
        return self.accessor.ctx

    def gather_select_index(self, dist_index: Union[Tensor, DistIndex]):
        if isinstance(dist_index, Tensor):
            dist_index = DistIndex(dist_index)
        data = dist_index.data
        # count how many of the requested packed ids fall into each partition's id range
        posptr = torch.searchsorted(data, self.distptr, right=False)
        input_split = posptr[1:] - posptr[:-1]
        return self.accessor.all_gather_index(DistIndex(data).loc, input_split)

    def scatter_data(self, local_index, input_split, out_split):
        return self.accessor.all_gather_data(local_index, input_split=input_split, out_split=out_split)

    def index_select(self, dist_index: Union[Tensor, DistIndex]):
        if isinstance(dist_index, Tensor):
            dist_index = DistIndex(dist_index)
@@ -139,6 +165,10 @@
        futs: List[torch.futures.Future] = []
        for i in range(self.num_parts()):
            #if i != torch.distributed.get_rank():
            #    continue
            #f = torch.futures.Future()
            #f.set_result(self.accessor.data[index[part_idx == i]])
            f = self.accessor.async_index_select(0, index[part_idx == i], self.rrefs[i])
            futs.append(f)
@@ -150,6 +180,7 @@
            result = torch.empty(
                part_idx.size(0), *t.shape[1:], dtype=t.dtype, device=t.device,
            )
            result[part_idx == i] = t
            return result
        return torch.futures.collect_all(futs).then(callback)
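# Minimal usage sketch (assumed, not from this commit): index_select returns a torch future,
# so a caller waits on it before using the gathered rows. `dist_tensor` and `packed_ids`
# are placeholder names.
fut = dist_tensor.index_select(packed_ids)   # kicks off one asynchronous request per partition
rows = fut.wait()                            # wait() blocks and returns the assembled result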
from typing import List, Tuple
import torch
import torch.distributed as dist
from starrygl.distributed.utils import DistributedTensor
from starrygl.module.memorys import MailBox
from starrygl.sample.graph_core import DataSet
from starrygl.sample.graph_core import GraphData
from starrygl.sample.sample_core.base import BaseSampler, NegativeSampling
@@ -40,7 +42,7 @@ def prepare_input(node_feat, edge_feat, mem_embedding, mfgs, dist_nid, dist_eid):
            #print(idx.shape[0], b.srcdata['mem_ts'].shape)
    return mfgs
def to_block(graph: GraphData, data, sample_out, mailbox = None, device = torch.device('cuda')):
def to_block(graph: GraphData, data, sample_out, mailbox: MailBox = None, device = torch.device('cuda')):
    if len(sample_out) > 1:
        sample_out, metadata = sample_out
@@ -55,7 +57,6 @@ def to_block(graph: GraphData, data, sample_out, mailbox = None,device = torch.d
    dist_eid, eid_inv = dist_eid.unique(return_inverse=True)
    src_node = graph.sample_graph['edge_index'][0, eid_tensor*2].to(graph.nids_mapper.device)
    src_ts = None
    edge_feat = graph._get_edge_attr(dist_eid)
    if metadata is None:
        root_node = data.nodes.to(graph.nids_mapper.device)
        root_len = [root_node.shape[0]]
@@ -74,9 +75,24 @@ def to_block(graph: GraphData, data, sample_out, mailbox = None,device = torch.d
    nid_tensor = torch.cat([root_node, src_node], dim=0)
    dist_nid = nid_mapper[nid_tensor].to(device)
    dist_nid, nid_inv = dist_nid.unique(return_inverse=True)
    # fetch edge/node features through the all_to_all gather path when the attribute is a
    # DistributedTensor, otherwise fall back to the local lookup
    if isinstance(graph.edge_attr, DistributedTensor):
        local_index, input_split, output_split = graph.edge_attr.gather_select_index(dist_eid)
        edge_feat = graph.edge_attr.scatter_data(local_index, input_split=input_split, out_split=output_split)
    else:
        edge_feat = graph._get_edge_attr(dist_eid)
        local_index = None
    if isinstance(graph.x, DistributedTensor):
        local_index, input_split, output_split = graph.x.gather_select_index(dist_nid)
        node_feat = graph.x.scatter_data(local_index, input_split=input_split, out_split=output_split)
    else:
        node_feat = graph._get_node_attr(dist_nid)
    if mailbox is not None:
        mem = mailbox._get_memory(dist_nid)
        if torch.distributed.get_world_size() > 1:
            if node_feat is None:
                local_index, input_split, output_split = mailbox.node_memory.gather_select_index(dist_nid)
            mem = mailbox.gather_memory(local_index, input_split, output_split)
        else:
            mem = mailbox.get_memory(dist_nid)
    else:
        mem = None
@@ -120,25 +136,25 @@ def to_block(graph: GraphData, data, sample_out, mailbox = None,device = torch.d
        return data, mfgs, metadata
    data, mfgs, metadata = build_block()
    if dist.get_world_size() > 1:
        if (node_feat is None):
            node_feat = torch.futures.Future()
            node_feat.set_result(None)
        if (edge_feat is None):
            edge_feat = torch.futures.Future()
            edge_feat.set_result(None)
        if (mem is None):
            mem = torch.futures.Future()
            mem.set_result(None)
        def callback(fs, mfgs, dist_nid, dist_eid):
            node_feat, edge_feat, mem_embedding = fs.value()
            node_feat = node_feat.value()
            edge_feat = edge_feat.value()
            mem_embedding = mem_embedding.value()
            return prepare_input(node_feat, edge_feat, mem_embedding, mfgs, dist_nid, dist_eid)
        cal = lambda fut: callback(fs=fut, mfgs=mfgs, dist_nid=dist_nid, dist_eid=dist_eid)
        return data, torch.futures.collect_all([node_feat, edge_feat, mem]).then(cal), metadata
    else:
    #if dist.get_world_size() > 1:
    #    if (node_feat is None):
    #        node_feat = torch.futures.Future()
    #        node_feat.set_result(None)
    #    if (edge_feat is None):
    #        edge_feat = torch.futures.Future()
    #        edge_feat.set_result(None)
    #    if (mem is None):
    #        mem = torch.futures.Future()
    #        mem.set_result(None)
    #    def callback(fs, mfgs, dist_nid, dist_eid):
    #        node_feat, edge_feat, mem_embedding = fs.value()
    #        node_feat = node_feat.value()
    #        edge_feat = edge_feat.value()
    #        mem_embedding = mem_embedding.value()
    #        return prepare_input(node_feat, edge_feat, mem_embedding, mfgs, dist_nid, dist_eid)
    #    cal = lambda fut: callback(fs=fut, mfgs=mfgs, dist_nid=dist_nid, dist_eid=dist_eid)
    #    return data, torch.futures.collect_all([node_feat, edge_feat, mem]).then(cal), metadata
    #else:
    mfgs = prepare_input(node_feat, edge_feat, mem, mfgs, dist_nid, dist_eid)
    #return build_block(node_feat, edge_feat, mem)  # data, mfgs, metadata
    return data, mfgs, metadata
@@ -66,7 +66,8 @@ class DistributedDataLoader:
        if train is True:
            self._get_expected_idx(self.dataset.len)
        else:
            self.expected_idx = int(math.ceil(self.dataset.len / self.batch_size))
            self._get_expected_idx(self.dataset.len, op=dist.ReduceOp.MAX)
            #self.expected_idx = int(math.ceil(self.dataset.len/self.batch_size))

    def __iter__(self):
        if self.chunk_size is None:
@@ -102,18 +103,19 @@
            self.neg_sampler.set_next_pos(self.current_pos)
        return self

    def _get_expected_idx(self, data_size):
    def _get_expected_idx(self, data_size, op=dist.ReduceOp.MIN):
        world_size = dist.get_world_size()
        self.expected_idx = data_size // self.batch_size if self.drop_last is True else int(math.ceil(data_size / self.batch_size))
        if dist.get_world_size() > 1:
            # agree on a common step count across ranks: MIN for training, while evaluation
            # passes op=MAX and pads exhausted ranks with empty batches (see _get_empty)
            num_epochs = torch.tensor([self.expected_idx], dtype=torch.long, device=self.device)
            dist.all_reduce(num_epochs, op=dist.ReduceOp.MIN)
            print(num_epochs)
            dist.all_reduce(num_epochs, op=op)
            self.expected_idx = int(num_epochs.item())

    def _next_data(self):
        if self.current_pos >= self.dataset.len:
            return None
            return self.input_dataset._get_empty()
        if self.current_pos + self.batch_size > self.input_dataset.len:
            if self.drop_last:
@@ -132,7 +134,7 @@
        return next_data

    def __next__(self):
        if (dist.get_world_size() == 1):
        if (dist.get_world_size() > 0):
            if self.recv_idxs < self.expected_idx:
                data = self._next_data()
                batch_data = graph_sample(self.graph,
@@ -165,6 +167,12 @@
                                          next_data, self.neg_sampler,
                                          self.mailbox,
                                          self.device)
                # block on the batch future and return its value directly (synchronous path)
                batch_data[1].wait()
                self.submitted = self.submitted + 1
                self.num_pending = self.num_pending + 1
                self.recv_idxs += 1
                self.num_pending -= 1
                return batch_data[0], batch_data[1].value(), batch_data[2]
                self.result_queue.append(batch_data)
                self.submitted = self.submitted + 1
                self.num_pending = self.num_pending + 1
@@ -23,16 +23,16 @@ class GraphData():
        world_size = dist.get_world_size()
        if hasattr(pdata, 'x') and pdata.x is not None:
            if world_size > 1:
                self.x = DistributedTensor(pdata.x.to(self.device))
                self.x = DistributedTensor(pdata.x.to(self.device).to(torch.float))
            else:
                self.x = pdata.x.to(device).to(torch.float)
        else:
            self.x = None
        if hasattr(pdata, 'edge_attr') and pdata.edge_attr is not None:
            if world_size > 1:
                self.edge_attr = DistributedTensor(pdata.edge_attr.to(self.device))
                self.edge_attr = DistributedTensor(pdata.edge_attr.to('cpu').to(torch.float))
            else:
                self.edge_attr = pdata.edge_attr.to('cuda').to(torch.float)
                self.edge_attr = pdata.edge_attr.to('cpu').to(torch.float)
        else:
            self.edge_attr = None
@@ -43,14 +43,15 @@ class GraphData():
            return self.x[ids]
        else:
            return self.x.index_select(ids)

    def _get_edge_attr(self, ids):
        if self.edge_attr is None:
            return None
        elif dist.get_world_size() == 1:
            return self.edge_attr[ids.to('cpu')].to('cuda')
            return self.edge_attr[ids]
        else:
            return self.edge_attr.index_select(ids)

class DataSet:
    def __init__(self, nodes=None,
                 edges=None,
@@ -69,6 +70,16 @@ class DataSet:
        for k, v in kwargs.items():
            assert isinstance(v, torch.Tensor) and v.shape[0] == self.len
            setattr(self, k, v.to(device))

    def _get_empty(self):
        # placeholder batch used when this rank's data is exhausted but other ranks still iterate
        nodes = torch.empty([], dtype=self.nodes.dtype, device=self.nodes.device) if hasattr(self, 'nodes') else None
        # empty [2, 0] edge index; also fixes the original `self.edge.device` typo
        edges = torch.empty([2, 0], dtype=self.edges.dtype, device=self.edges.device) if hasattr(self, 'edges') else None
        d = DataSet(nodes, edges)
        for k, v in self.__dict__.items():
            if k == 'edges' or k == 'nodes' or k == 'len':
                continue
            else:
                setattr(d, k, torch.empty([]))
        return d

    #@staticmethod
    def get_next(self, indx):
@@ -214,7 +214,7 @@ class SharedMailBox():
            return index, memory, ts

    def _get_memory(self, index):
    def get_memory(self, index):
        if self.num_parts == 1:
            return self.node_memory.accessor.data[index],\
                self.node_memory_ts.accessor.data[index],\
@@ -234,6 +234,9 @@ class SharedMailBox():
            #print(memory.shape[0])
            return memory, memory_ts, mail, mail_ts
        return torch.futures.collect_all([memory, memory_ts, mail, mail_ts]).then(callback)

    def gather_memory(self, index, input_split, out_split):
        # fetch memory, memory timestamps, mail, and mail timestamps with the same
        # all_to_all exchange used for node/edge features
        return self.node_memory.scatter_data(index, input_split, out_split),\
            self.node_memory_ts.scatter_data(index, input_split, out_split),\
            self.mailbox.scatter_data(index, input_split, out_split),\
            self.mailbox_ts.scatter_data(index, input_split, out_split)
@@ -5,3 +5,7 @@ class WorkStreamEvent:
        self.write_memory_stream = torch.cuda.Stream()
        self.fetch_stream = torch.cuda.Stream()
        self.write_mail_stream = torch.cuda.Stream()
        self.event = None

event = WorkStreamEvent()

def get_event():
    return event.event
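# A minimal sketch (assumption, not from this commit) of how a module-level stream/event pair
# like WorkStreamEvent can be used: record an event on a side stream that writes memory, and
# have the default stream wait on it before the next step reads that memory. Variable names
# below are illustrative; a CUDA device is assumed.
import torch

write_stream = torch.cuda.Stream()            # stands in for event.write_memory_stream
buf = torch.zeros(1024, device='cuda')

with torch.cuda.stream(write_stream):
    buf += 1                                  # asynchronous update issued on the side stream
    done = torch.cuda.Event()
    done.record(write_stream)                 # equivalent of stashing the event for get_event()

torch.cuda.current_stream().wait_event(done)  # consumers wait on the event instead of a full synchronize
print(buf.sum().item())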
@@ -82,8 +82,9 @@ def main():
    ctx = DistributedContext.init(backend="nccl", use_gpu=True)
    device_id = torch.cuda.current_device()
    print('use cuda on', device_id)
    pdata = partition_load("./dataset/here/WIKI", algo="metis_for_tgnn")
    pdata = partition_load("./dataset/here/GDELT", algo="metis_for_tgnn")
    graph = GraphData(pdata=pdata)
    sample_graph = TemporalNeighborSampleGraph(sample_graph=pdata.sample_graph, mode='full')
    mailbox = SharedMailBox(pdata.ids.shape[0], memory_param,
                            dim_edge_feat=pdata.edge_attr.shape[1] if pdata.edge_attr is not None else 0)
    sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10],
                              graph_data=sample_graph, workers=10, policy='recent', graph_name="wiki_train")
@@ -102,17 +103,17 @@ def main():
    trainloader = DistributedDataLoader(graph, train_data, sampler=sampler,
                                        sampler_fn=SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
                                        neg_sampler=neg_sampler,
                                        batch_size=2000,
                                        batch_size=1000,
                                        shuffle=False,
                                        drop_last=True,
                                        chunk_size=None,
                                        train=True,
                                        queue_size=100,
                                        queue_size=1000,
                                        mailbox=mailbox)
    testloader = DistributedDataLoader(graph, test_data, sampler=sampler,
                                       sampler_fn=SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
                                       neg_sampler=neg_sampler,
                                       batch_size=2000,
                                       batch_size=1000,
                                       shuffle=False,
                                       drop_last=False,
                                       chunk_size=None,
@@ -122,7 +123,7 @@ def main():
    valloader = DistributedDataLoader(graph, val_data, sampler=sampler,
                                      sampler_fn=SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
                                      neg_sampler=neg_sampler,
                                      batch_size=2000,
                                      batch_size=1000,
                                      shuffle=False,
                                      drop_last=False,
                                      chunk_size=None,
@@ -158,9 +159,9 @@ def main():
        with torch.no_grad():
            total_loss = 0
            signal = torch.tensor([0], dtype=int, device=device)
            for roots, mfgs, metadata in loader:
                signal[0] = 0
                dist.all_reduce(signal, async_op=False)
                pred_pos, pred_neg = model(mfgs, metadata)
                total_loss += creterion(pred_pos, torch.ones_like(pred_pos))
                total_loss += creterion(pred_neg, torch.zeros_like(pred_neg))
@@ -190,32 +191,28 @@ def main():
                        src, dst, ts, edge_feats,
                        model.module.memory_updater.last_updated_memory,
                    )
                    #mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
            if mode == 'val':
                val_losses.append(float(total_loss))
            # spin until every rank reports (via the summed signal) that it has finished its loader
            while (signal[0].item() != dist.get_world_size()):
                signal[0] = 1
                dist.all_reduce(signal, async_op=False)
                if (signal[0].item() == dist.get_world_size()):
                    break
                if mailbox is not None:
                    # ranks that finished early keep joining the mailbox all_to_all with empty
                    # tensors so the collective on the remaining ranks can complete
                    mailbox.set_mailbox_all_to_all(torch.tensor([], device=device).reshape(-1),
                                                   torch.tensor([], device=device).reshape(-1, mailbox.memory_size),
                                                   torch.tensor([], device=device).reshape(-1),
                                                   torch.tensor([], device=device).reshape(-1, mailbox.mailbox.accessor.data.size(2)),
                                                   torch.tensor([], device=device).reshape(-1),
                                                   reduce_Op='max')
            mailbox.set_mailbox_all_to_all(index, memory, memory_ts, mail, mail_ts, reduce_Op='max')
            ap = float(torch.tensor(aps).mean())
            if neg_samples > 1:
                auc_mrr = float(torch.cat(aucs_mrrs).mean())
            else:
                auc_mrr = float(torch.tensor(aucs_mrrs).mean())
            #ap = float(torch.tensor(aps).mean())
            #if neg_samples > 1:
            #    auc_mrr = float(torch.cat(aucs_mrrs).mean())
            #else:
            #    auc_mrr = float(torch.tensor(aucs_mrrs).mean())
            world_size = dist.get_world_size()
            apc = torch.empty([loader.expected_idx * world_size], dtype=torch.float, device='cuda')
            auc_mrr = torch.empty([loader.expected_idx * world_size], dtype=torch.float, device='cuda')
            dist.all_gather_into_tensor(apc, torch.tensor(aps, device='cuda', dtype=torch.float))
            dist.all_gather_into_tensor(auc_mrr, torch.tensor(aucs_mrrs, device='cuda', dtype=torch.float))
            ap = float(torch.tensor(apc).mean())
            auc_mrr = float(torch.tensor(auc_mrr).mean())
        return ap, auc_mrr

    creterion = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'])
    for e in range(train_param['epoch']):
        torch.cuda.synchronize()
        epoch_start_time = time.time()
        train_aps = list()
        print('Epoch {:d}:'.format(e))
@@ -229,14 +226,18 @@ def main():
        model.module.memory_updater.last_updated_ts = None
        for roots, mfgs, metadata in trainloader:
            t_prep_s = time.time()
            optimizer.zero_grad()
            with torch.cuda.stream(train_stream):
                optimizer.zero_grad()
                pred_pos, pred_neg = model(mfgs, metadata)
                loss = creterion(pred_pos, torch.ones_like(pred_pos))
                loss += creterion(pred_neg, torch.zeros_like(pred_neg))
                total_loss += float(loss)
                loss.backward()
                optimizer.step()
                #torch.cuda.synchronize()
                t_prep_s = time.time()
                y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
                y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
@@ -262,7 +263,9 @@ def main():
                        src, dst, ts, edge_feats,
                        model.module.memory_updater.last_updated_memory,
                    )
                    #mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
                    mailbox.set_mailbox_all_to_all(index, memory, memory_ts, mail, mail_ts, reduce_Op='max')
        torch.cuda.synchronize()
        time_prep = time.time() - epoch_start_time
        avg_time += time.time() - epoch_start_time