Commit ef367556 by zhlj

Fix bugs in JODIE and APAN

parents 4927a9e0 ce4b726f
# Introduction
\ No newline at end of file
@@ -22,7 +22,7 @@ gnn:
 train:
 - epoch: 100
   batch_size: 1000
-  lr: 0.0001
+  lr: 0.0002
   dropout: 0.1
   att_dropout: 0.1
   # all_on_gpu: True
\ No newline at end of file
bash test_all.sh 12347
wait
bash test_all.sh 12357
wait
bash test_all.sh 63457
wait
\ No newline at end of file
@@ -2,41 +2,24 @@
 # Ran TaoBao with 4 GPUs
 # Define array variables
 seed=$1
-addr="192.168.1.106"
+addr="192.168.1.105"
 partition_params=("ours" )
 #"metis" "ldg" "random")
 #("ours" "metis" "ldg" "random")
-partitions="8"
+partitions="4"
 node_per="4"
-nnodes="2"
+nnodes="1"
-<<<<<<< Updated upstream
 node_rank="0"
 probability_params=("0.1")
 sample_type_params=("boundery_recent_decay")
 #sample_type_params=("recent" "boundery_recent_decay") #"boundery_recent_uniform")
 #memory_type=("all_update" "p2p" "all_reduce" "historical" "local")
-memory_type=( "historical")
-#"historical")
+memory_type=("historical")
 #memory_type=("local" "all_update" "historical" "all_reduce")
 shared_memory_ssim=("0.3")
 #data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk")
-<<<<<<< HEAD
-data_param=("GDELT")
-=======
-data_param=("LASTFM")
->>>>>>> 8233776274204f6cf2f8a2eb37022d426d6197d8
-=======
-node_rank="1"
-probability_params=("0.1" "0.05" "0.01" "0")
-sample_type_params=("recent" "boundery_recent_decay")
-#sample_type_params=("recent" "boundery_recent_decay") #"boundery_recent_uniform")
-#memory_type=("all_update" "p2p" "all_reduce" "historical" "local")
-#memory_type=("all_update" "historical" "local")
-memory_type=("historical" "all_update" "local")
-#memory_type=("local" "all_update" "historical" "all_reduce")
-shared_memory_ssim=("0.3" "0.7")
-#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk")
-data_param=("WIKI" "LASTFM" "WikiTalk" "StackOverflow" "GDELT" "TaoBao")
->>>>>>> Stashed changes
+data_param=("WIKI" "LASTFM" "WikiTalk" "StackOverflow" "GDELT")
 #data_param=("WIKI" "REDDIT" "LASTFM" "DGraphFin" "WikiTalk" "StackOverflow")
 #data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk" "StackOverflow")
 #data_param=("REDDIT" "WikiTalk")
@@ -48,9 +31,9 @@ data_param=("WIKI" "LASTFM" "WikiTalk" "StackOverflow" "GDELT" "TaoBao")
 #seed=(( RANDOM % 1000000 + 1 ))
 mkdir -p all_"$seed"
 for data in "${data_param[@]}"; do
-model="APAN"
+model="JODIE"
 if [ "$data" = "WIKI" ] || [ "$data" = "REDDIT" ] || [ "$data" = "LASTFM" ]; then
-model="APAN"
+model="JODIE"
 fi
 #model="APAN"
 mkdir all_"$seed"/"$data"
@@ -89,11 +72,7 @@ for data in "${data_param[@]}"; do
 if [ "$mem" = "historical" ]; then
 for ssim in "${shared_memory_ssim[@]}"; do
 if [ "$partition" = "ours" ]; then
-<<<<<<< Updated upstream
 torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --shared_memory_ssim "$ssim" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample"-"$pro".out &
-=======
-torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --shared_memory_ssim "$ssim" > all/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample"-"$pro".out &
->>>>>>> Stashed changes
 wait
 fi
 done
...
@@ -229,7 +229,8 @@ def main():
     fanout = sample_param['neighbor'] if 'neighbor' in sample_param else [10]
     policy = sample_param['strategy'] if 'strategy' in sample_param else 'recent'
     no_neg = sample_param['no_neg'] if 'no_neg' in sample_param else False
-    if policy != 'recent':
+    print(policy)
+    if policy == 'recent':
         policy_train = args.sample_type#'boundery_recent_decay'
     else:
         policy_train = policy
@@ -480,7 +481,7 @@ def main():
     val_list = []
     loss_list = []
     for e in range(train_param['epoch']):
-        model.module.memory_updater.empty_cache()
+        # model.module.memory_updater.empty_cache()
         tt._zero()
         torch.cuda.synchronize()
         epoch_start_time = time.time()
...
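On the first hunk above: the condition is inverted so that the command-line --sample_type now overrides the config strategy only when the config asks for plain 'recent' sampling; any other configured strategy is kept. A minimal standalone sketch of the corrected selection (the function name and the literal strings are illustrative, not part of the repo):

# Illustrative only: mirrors the fixed policy / policy_train selection above.
def choose_train_policy(config_policy: str, cli_sample_type: str) -> str:
    if config_policy == 'recent':
        return cli_sample_type        # e.g. 'boundery_recent_decay' from --sample_type
    return config_policy              # keep whatever the config file specifies

print(choose_train_policy('recent', 'boundery_recent_decay'))   # boundery_recent_decay
print(choose_train_policy('uniform', 'boundery_recent_decay'))  # uniform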
@@ -336,9 +336,12 @@ class TransformerMemoryUpdater(torch.nn.Module):
     def forward(self, b, param = None):
         Q = self.w_q(b.srcdata['mem']).reshape((b.num_src_nodes(), self.att_h, -1))
         mails = b.srcdata['mem_input'].reshape((b.num_src_nodes(), self.memory_param['mailbox_size'], -1))
+        #print(mails.shape,b.srcdata['mem_input'].shape,b.srcdata['mail_ts'].shape)
         if self.dim_time > 0:
             time_feat = self.time_enc(b.srcdata['ts'][:, None] - b.srcdata['mail_ts']).reshape((b.num_src_nodes(), self.memory_param['mailbox_size'], -1))
+            #print(time_feat.shape)
             mails = torch.cat([mails, time_feat], dim=2)
+        #print(mails.shape)
         K = self.w_k(mails).reshape((b.num_src_nodes(), self.memory_param['mailbox_size'], self.att_h, -1))
         V = self.w_v(mails).reshape((b.num_src_nodes(), self.memory_param['mailbox_size'], self.att_h, -1))
         att = self.att_act((Q[:,None,:,:]*K).sum(dim=3))
@@ -394,7 +397,6 @@ class AsyncMemeoryUpdater(torch.nn.Module):
             self.mailbox.handle_last_async()
             submit_to_queue = False
             if nxt_fetch_func is not None:
-                nxt_fetch_func()
                 submit_to_queue = True
             self.mailbox.set_memory_all_reduce(
                 index,memory,memory_ts,
@@ -404,6 +406,8 @@ class AsyncMemeoryUpdater(torch.nn.Module):
                 wait_submit=submit_to_queue,spread_mail=spread_mail,
                 update_cross_mm=False,
             )
+            if nxt_fetch_func is not None:
+                nxt_fetch_func()
     def local_func(self,index,memory,memory_ts,mail_index,mail,mail_ts,nxt_fetch_func,spread_mail=False):
         if nxt_fetch_func is not None:
@@ -471,6 +475,7 @@ class AsyncMemeoryUpdater(torch.nn.Module):
             shared_ind = self.mailbox.is_shared_mask[DistIndex(b.srcdata['ID'][mask]).loc]
             transition_dense = b.srcdata['his_mem'][mask] + self.filter.get_incretment(shared_ind)
+            #print(transition_dense.shape)
             if not (transition_dense.max().item() == 0):
                 transition_dense -= transition_dense.min()
                 transition_dense /= transition_dense.max()
@@ -514,8 +519,8 @@ class AsyncMemeoryUpdater(torch.nn.Module):
         local_mask = (DistIndex(index).part==torch.distributed.get_rank())
         local_mask_mail = (DistIndex(index0).part==torch.distributed.get_rank())
-        #self.mailbox.set_mailbox_local(DistIndex(index0[local_mask_mail]).loc,mail[local_mask_mail],mail_ts[local_mask_mail],Reduce_Op = 'max')
-        #self.mailbox.set_memory_local(DistIndex(index[local_mask]).loc,memory[local_mask],memory_ts[local_mask], Reduce_Op = 'max')
+        self.mailbox.set_mailbox_local(DistIndex(index0[local_mask_mail]).loc,mail[local_mask_mail],mail_ts[local_mask_mail],Reduce_Op = 'max')
+        self.mailbox.set_memory_local(DistIndex(index[local_mask]).loc,memory[local_mask],memory_ts[local_mask], Reduce_Op = 'max')
         is_deliver=(self.mailbox.deliver_to == 'neighbors')
         self.update_hunk(index,memory,memory_ts,index0,mail,mail_ts,nxt_fetch_func,spread_mail= is_deliver)
...
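On the AsyncMemeoryUpdater hunks above: the next batch's prefetch callback (nxt_fetch_func) is now invoked only after set_memory_all_reduce has been submitted, instead of before it, and the local mailbox/memory writes are re-enabled. A schematic sketch of that call ordering with placeholder names (not the project's API):

# Schematic only: the submit-then-prefetch ordering introduced by the fix above.
def all_reduce_step(submit_memory_update, nxt_fetch_func=None):
    submit_to_queue = nxt_fetch_func is not None
    submit_memory_update(wait_submit=submit_to_queue)  # stand-in for set_memory_all_reduce(...)
    if nxt_fetch_func is not None:
        nxt_fetch_func()  # previously ran before the update was submitted

events = []
all_reduce_step(lambda wait_submit: events.append(('all_reduce', wait_submit)),
                nxt_fetch_func=lambda: events.append('prefetch_next_batch'))
print(events)  # [('all_reduce', True), 'prefetch_next_batch']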
@@ -344,6 +344,7 @@ def to_reversed_block(graph,data, sample_out,device = torch.device('cuda'),uniqu
     else:
         metadata = None
     nid_mapper: torch.Tensor = graph.nids_mapper
+    #print('reverse block {}\n'.format(identity))
     if identity is False:
         assert len(sample_out) == 1
         ret = sample_out[0]
@@ -354,6 +355,8 @@ def to_reversed_block(graph,data, sample_out,device = torch.device('cuda'),uniqu
         dist_eid = torch.tensor([],dtype=torch.long,device=device)
         src_index = ret.src_index().to(device)
     else:
+        #print('is jodie')
+        #print(sample_out)
         src_index = torch.tensor([],dtype=torch.long,device=device)
         dst = torch.tensor([],dtype=torch.long,device=device)
         dist_eid = torch.tensor([],dtype=torch.long,device=device)
@@ -401,6 +404,7 @@ def to_reversed_block(graph,data, sample_out,device = torch.device('cuda'),uniqu
         row_len = root_len
     col = first_block_id[:row_len]
     max_row = col.max().item()+1
+    #print(src_index,dst)
     b = dgl.create_block((col[src_index].to(device),
         torch.arange(dst.shape[0],device=device,dtype=torch.long)),num_src_nodes=first_block_id.max().item()+1,
         num_dst_nodes=dst.shape[0])
@@ -424,6 +428,7 @@ def graph_sample(graph,sampler,sample_fn,data,neg_sampling = None,out_device = t
     t_s = time.time()
     param = {'is_unique':False,'nid_mapper':nid_mapper,'eid_mapper':eid_mapper,'out_device':out_device}
     out = sample_fn(sampler,data,neg_sampling,**param)
+    #print(sampler.policy)
     if reversed is False:
         out,dist_nid,dist_eid = to_block(graph,data,out,out_device)
     else:
...
@@ -18,9 +18,10 @@ class MemoryMoniter:
         #self.memory_ssim.append(self.ssim(pre_memory,now_memory,method = 'cos'))
         #self.nid_list.append(nid)
     def draw(self,degree,data,model,e):
-        torch.save(self.nid_list,'all_args.seed/{}/{}/memorynid_{}.pt'.format(data,model,e))
-        torch.save(self.memorychange,'all_args.seed/{}/{}/memoryF_{}.pt'.format(data,model,e))
-        torch.save(self.memory_ssim,'all_args.seed/{}/{}/memcos_{}.pt'.format(data,model,e))
+        pass
+        #torch.save(self.nid_list,'all_args.seed/{}/{}/memorynid_{}.pt'.format(data,model,e))
+        #torch.save(self.memorychange,'all_args.seed/{}/{}/memoryF_{}.pt'.format(data,model,e))
+        #torch.save(self.memory_ssim,'all_args.seed/{}/{}/memcos_{}.pt'.format(data,model,e))
         # path = './memory/{}/'.format(data)
         # if not os.path.exists(path):
@@ -87,6 +88,7 @@ class MemoryMoniter:
     def set_zero(self):
-        self.memorychange = []
-        self.nid_list =[]
-        self.memory_ssim = []
+        pass
+        #self.memorychange = []
+        #self.nid_list =[]
+        #self.memory_ssim = []
@@ -146,7 +146,7 @@ class SharedMailBox():
             source_ts = max_ts
             source = source[id]
             index = unq_id
+        #print(self.next_mail_pos[index])
         self.mailbox_ts.accessor.data[index, self.next_mail_pos[index]] = source_ts
         self.mailbox.accessor.data[index, self.next_mail_pos[index]] = source
         if self.memory_param['mailbox_size'] > 1:
@@ -180,9 +180,23 @@ class SharedMailBox():
         if self.deliver_to == 'neighbors':
             assert block is not None and Reduce_score is None
-            mail = torch.cat([mail, mail[block.edges()[0].long()]], dim=0)
-            mail_ts = torch.cat([mail_ts, mail_ts[block.edges()[0].long()]], dim=0)
+            # print(block.edges().shape)
+            root = torch.cat([src,dst]).reshape(-1)
+            #pos = torch.empty(root.max()+1,dtype=torch.long,device=block.device)
+            #print('edge {} {}\n'.format(block.num_src_nodes(),block.edges()[0].max()))
+            #print('root is {} {} {} {}\n'.format(root,root.shape,root.max(),block.edges()[0].shape))
+            #pos_index = torch.arange(root.shape[0],device=root.device,dtype=root.dtype)
+            pos,idx = torch_scatter.scatter_max(mail_ts,root,0)
+            mail = torch.cat([mail, mail[idx]],dim=0)
+            mail_ts = torch.cat([mail_ts, mail_ts[idx]], dim=0)
+            #print('pos is {} {}\n'.format(pos,block.edges()[0].long()))
+            #mail = torch.cat([mail, mail[pos[block.edges()[0].long()]]],dim=0)
+            #mail_ts = torch.cat([mail_ts, mail_ts[pos[block.edges()[0].long()]]], dim=0)
+            #print(root,block.edges()[1].long())
             index = torch.cat([index,block.dstdata['ID'][block.edges()[1].long()]],dim=0)
+            #mail = torch.cat([mail, mail[block.edges()[0].long()]], dim=0)
+            #mail_ts = torch.cat([mail_ts, mail_ts[block.edges()[0].long()]], dim=0)
+            #index = torch.cat([index,block.dstdata['ID'][block.edges()[1].long()]],dim=0)
             if Reduce_score is not None:
                 Reduce_score = torch.cat((Reduce_score,Reduce_score),-1).to(self.device)
         if Reduce_score is None:
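For context on the new delivery path above: torch_scatter.scatter_max returns, per group, both the maximum value and the position (argmax) of the element that attained it, so indexing with that argmax keeps one newest mail per destination id. A minimal sketch with toy tensors (the root/mail names mirror the hunk; the data is made up):

import torch
import torch_scatter

# Toy data: 5 mails addressed to 3 destination node ids (0, 1, 2).
mail_ts = torch.tensor([10., 30., 20., 5., 40.])
root    = torch.tensor([0,   1,   0,   2,  1])              # destination id per mail
mail    = torch.arange(5, dtype=torch.float).unsqueeze(1)   # stand-in payloads

# For every destination id: max timestamp and the position of that mail.
max_ts, argmax = torch_scatter.scatter_max(mail_ts, root, dim=0)
print(max_ts)   # tensor([20., 40.,  5.])
print(argmax)   # tensor([2, 4, 3])  -> newest mail per destination
newest_mail, newest_ts = mail[argmax], mail_ts[argmax]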
@@ -192,12 +206,19 @@ class SharedMailBox():
             mail = mail[idx]
             index = unq_index
         else:
-            unq_index,inv = torch.unique(index,return_inverse = True)
+            uni, inv = torch.unique(index, return_inverse=True)
+            perm = torch.arange(inv.size(0), dtype=inv.dtype, device=inv.device)
+            perm = inv.new_empty(uni.size(0)).scatter_(0, inv, perm)
+            index = index[perm]
+            mail = mail[perm]
+            mail_ts = mail_ts[perm]
+            #unq_index,inv = torch.unique(index,return_inverse = True)
             #print(inv.shape,Reduce_score.shape)
-            max_score,idx = torch_scatter.scatter_max(Reduce_score,inv,0)
-            mail_ts = mail_ts[idx]
-            mail = mail[idx]
-            index = unq_index
+            #max_score,idx = torch_scatter.scatter_max(Reduce_score,inv,0)
+            #mail_ts = mail_ts[idx]
+            #mail = mail[idx]
+            #index = unq_index
+            #print('mail {} {}\n'.format(index.shape,mail.shape,mail_ts.shape))
         return index,mail,mail_ts
     def get_update_memory(self,index,memory,memory_ts,embedding):
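The else branch above now deduplicates mails by node id with torch.unique plus an in-place scatter_, instead of reducing by Reduce_score. The idiom maps each unique id to one of its occurrence positions (with duplicate scatter indices the last write typically wins, so which occurrence survives should not be relied on across backends). A self-contained sketch with made-up data:

import torch

index   = torch.tensor([7, 3, 7, 5, 3])                  # node id per mail
mail    = torch.arange(5, dtype=torch.float).unsqueeze(1)
mail_ts = torch.tensor([1., 2., 3., 4., 5.])

uni, inv = torch.unique(index, return_inverse=True)      # uni = [3, 5, 7]
perm = torch.arange(inv.size(0), dtype=inv.dtype, device=inv.device)
# One slot per unique id; duplicate ids overwrite, leaving a single position each.
perm = inv.new_empty(uni.size(0)).scatter_(0, inv, perm)

index, mail, mail_ts = index[perm], mail[perm], mail_ts[perm]
print(index)   # tensor([3, 5, 7]) -- exactly one mail kept per node id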
@@ -205,7 +226,8 @@ class SharedMailBox():
         max_ts,idx = torch_scatter.scatter_max(memory_ts,inv,0)
         ts = max_ts
         index = unq_index
         memory = memory[idx]
+        #print('memory {} {}\n'.format(index.shape,memory.shape,ts.shape))
         return index,memory,ts
     def pack(self,memory=None,memory_ts=None,mail=None,mail_ts=None,index = None,mode=None):
@@ -250,6 +272,7 @@ class SharedMailBox():
             self.set_mailbox_local(DistIndex(gather_id_list).loc,gather_mail,gather_mail_ts,Reduce_Op = reduce_Op)
         else:
             gather_memory,gather_memory_ts = self.unpack(gather_memory)
+            #print(gather_id_list.shape,gather_memory.shape,gather_memory_ts.shape)
             self.set_memory_local(DistIndex(gather_id_list).loc,gather_memory,gather_memory_ts, Reduce_Op = reduce_Op)
     def handle_last_mail(self,reduce_Op=None,):
         if self.last_mail_sync is not None:
@@ -260,6 +283,7 @@ class SharedMailBox():
             if isinstance(gather_memory,list):
                 gather_memory = torch.cat(gather_memory,dim = 0)
             gather_memory,gather_memory_ts = self.unpack(gather_memory)
+            #print(gather_id_list.shape,gather_memory.shape,gather_memory_ts.shape)
             self.set_mailbox_local(DistIndex(gather_id_list).loc,gather_memory,gather_memory_ts, Reduce_Op = reduce_Op)
     def handle_last_async(self,reduce_Op = None):
         self.handle_last_memory(reduce_Op)
@@ -303,6 +327,7 @@ class SharedMailBox():
             return
         index,gather_id_list,mem,gather_memory,input_split,output_split,group,async_op = self.next_wait_mail_job
         self.next_wait_mail_job = None
+        #print(index,gather_id_list)
         handle0 = torch.distributed.all_to_all_single(
             gather_id_list,index,output_split_sizes=output_split,
             input_split_sizes=input_split,group = group,async_op=async_op)
@@ -409,6 +434,7 @@ class SharedMailBox():
         self.update_p2p_mail()
         self.update_p2p_mem()
         self.handle_last_async()
         ctx = DistributedContext.get_default_context()
@@ -483,7 +509,7 @@ class SharedMailBox():
         unq_index,inv = torch.unique(shared_index,return_inverse = True)
         max_ts,idx = torch_scatter.scatter_max(shared_memory_ts,inv,0)
         shared_memory = shared_memory[idx]
-        shared_memory = shared_memory[idx]
+        #shared_memory = shared_memory[idx]
         shared_memory_ts = shared_memory_ts[idx]
         shared_mail_ts = shared_mail_ts[idx]
         shared_mail = shared_mail[idx]
@@ -495,11 +521,13 @@ class SharedMailBox():
             self.set_mailbox_local(self.shared_nodes_index[shared_mail_indx],shared_mail,shared_mail_ts)
         else:
             self.next_wait_gather_memory_job = (shared_list,mem,shared_id_list,shared_memory_ind)
         if not wait_submit:
             self.update_shared()
             self.update_p2p_mail()
             self.update_p2p_mem()
+            self.handle_last_async()
+            self.sychronize_shared()
         #self.historical_cache.add_shared_to_queue(handle0,handle1,shared_id_list,shared_list)
         """
         shared_memory = self.node_memory.accessor.data[self.shared_nodes_index]
...
@@ -317,8 +317,9 @@ class NeighborSampler(BaseSampler):
         """
         if self.no_neg:
             out,metadata = self.sample_from_nodes(seed[:seed.shape[0]//3*2], seed_ts[:seed.shape[0]//3*2], is_unique=False)
         else:
-            out,metadata = self.sample_from_nodes(seed[:seed.shape[0]], seed_ts[:seed.shape[0]//3*2], is_unique=False)
+            out,metadata = self.sample_from_nodes(seed, seed_ts, is_unique=False)
         src_pos_index = torch.arange(0,num_pos,dtype= torch.long,device=out_device)
         dst_pos_index = torch.arange(num_pos,2*num_pos,dtype= torch.long,device=out_device)
         if neg_sampling.is_triplet() or neg_sampling.is_tgbtriplet():
...
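On the sampler fix above: the old else branch passed the full seed but only the first two thirds of seed_ts, so node ids and timestamps went out of step; the fix passes both tensors whole. The no_neg branch keeps the 2/3 slice, which makes sense if the batch is laid out as source positives, destination positives, then negatives (an assumption about the batch layout, not stated in the hunk). A toy sketch of that slicing:

import torch

num_pos = 4
src_pos = torch.arange(0, num_pos)
dst_pos = torch.arange(100, 100 + num_pos)
dst_neg = torch.arange(200, 200 + num_pos)

# Assumed layout: [src_pos | dst_pos | dst_neg]; positives are the first 2/3.
seed = torch.cat([src_pos, dst_pos, dst_neg])
seed_ts = torch.full((seed.shape[0],), 42.0)

no_neg = True
k = seed.shape[0] // 3 * 2
nodes, ts = (seed[:k], seed_ts[:k]) if no_neg else (seed, seed_ts)
print(nodes.shape, ts.shape)   # torch.Size([8]) torch.Size([8]) -- always aligned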