Commit ef367556 by zhlj

Fix bugs in JODIE and APAN

parents 4927a9e0 ce4b726f
# Introduction
\ No newline at end of file
......@@ -22,7 +22,7 @@ gnn:
train:
- epoch: 100
batch_size: 1000
lr: 0.0001
lr: 0.0002
dropout: 0.1
att_dropout: 0.1
# all_on_gpu: True
\ No newline at end of file
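For context on this hunk, a minimal sketch of how a `train:` block like the one above is typically read on the Python side; the config path and the placeholder model are assumptions for illustration, not this repo's exact API.

```python
import yaml
import torch

# Hypothetical consumer of the train block above.
with open("config/JODIE.yml") as f:      # path is an assumption
    cfg = yaml.safe_load(f)
train_param = cfg["train"][0]            # the YAML stores a one-element list

model = torch.nn.Linear(16, 16)          # placeholder for the actual TGNN
optimizer = torch.optim.Adam(model.parameters(), lr=train_param["lr"])
dropout = torch.nn.Dropout(train_param["dropout"])
```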
bash test_all.sh 12347
wait
bash test_all.sh 12357
wait
bash test_all.sh 63457
wait
\ No newline at end of file
......@@ -2,41 +2,24 @@
#Ran TaoBao on 4 GPUs
# Define array variables
seed=$1
addr="192.168.1.106"
addr="192.168.1.105"
partition_params=("ours" )
#"metis" "ldg" "random")
#("ours" "metis" "ldg" "random")
partitions="8"
partitions="4"
node_per="4"
nnodes="2"
<<<<<<< Updated upstream
nnodes="1"
node_rank="0"
probability_params=("0.1")
sample_type_params=("boundery_recent_decay")
#sample_type_params=("recent" "boundery_recent_decay") #"boundery_recent_uniform")
#memory_type=("all_update" "p2p" "all_reduce" "historical" "local")
memory_type=( "historical")
memory_type=("historical")
#"historical")
#memory_type=("local" "all_update" "historical" "all_reduce")
shared_memory_ssim=("0.3")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk")
<<<<<<< HEAD
data_param=("GDELT")
=======
data_param=("LASTFM")
>>>>>>> 8233776274204f6cf2f8a2eb37022d426d6197d8
=======
node_rank="1"
probability_params=("0.1" "0.05" "0.01" "0")
sample_type_params=("recent" "boundery_recent_decay")
#sample_type_params=("recent" "boundery_recent_decay") #"boundery_recent_uniform")
#memory_type=("all_update" "p2p" "all_reduce" "historical" "local")
#memory_type=("all_update" "historical" "local")
memory_type=("historical" "all_update" "local")
#memory_type=("local" "all_update" "historical" "all_reduce")
shared_memory_ssim=("0.3" "0.7")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk")
data_param=("WIKI" "LASTFM" "WikiTalk" "StackOverflow" "GDELT" "TaoBao")
>>>>>>> Stashed changes
data_param=("WIKI" "LASTFM" "WikiTalk" "StackOverflow" "GDELT")
#data_param=("WIKI" "REDDIT" "LASTFM" "DGraphFin" "WikiTalk" "StackOverflow")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk" "StackOverflow")
#data_param=("REDDIT" "WikiTalk")
......@@ -48,9 +31,9 @@ data_param=("WIKI" "LASTFM" "WikiTalk" "StackOverflow" "GDELT" "TaoBao")
#seed=(( RANDOM % 1000000 + 1 ))
mkdir -p all_"$seed"
for data in "${data_param[@]}"; do
model="APAN"
model="JODIE"
if [ "$data" = "WIKI" ] || [ "$data" = "REDDIT" ] || [ "$data" = "LASTFM" ]; then
model="APAN"
model="JODIE"
fi
#model="APAN"
mkdir all_"$seed"/"$data"
......@@ -89,11 +72,7 @@ for data in "${data_param[@]}"; do
if [ "$mem" = "historical" ]; then
for ssim in "${shared_memory_ssim[@]}"; do
if [ "$partition" = "ours" ]; then
<<<<<<< Updated upstream
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --shared_memory_ssim "$ssim" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample"-"$pro".out &
=======
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --shared_memory_ssim "$ssim" > all/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample"-"$pro".out &
>>>>>>> Stashed changes
wait
fi
done
......
......@@ -229,7 +229,8 @@ def main():
fanout = sample_param['neighbor'] if 'neighbor' in sample_param else [10]
policy = sample_param['strategy'] if 'strategy' in sample_param else 'recent'
no_neg = sample_param['no_neg'] if 'no_neg' in sample_param else False
if policy != 'recent':
print(policy)
if policy == 'recent':
policy_train = args.sample_type#'boundery_recent_decay'
else:
policy_train = policy
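The hunk above flips the condition: when the model config's sampling strategy is 'recent', the CLI `--sample_type` now overrides it, while any other configured strategy is kept. A standalone restatement of the corrected branch (function name is illustrative):

```python
# Mirrors the corrected logic above: a configured 'recent' strategy is
# overridden by the command-line sample type; anything else is kept as-is.
def choose_train_policy(sample_param: dict, cli_sample_type: str) -> str:
    policy = sample_param.get("strategy", "recent")
    if policy == "recent":
        return cli_sample_type          # e.g. 'boundery_recent_decay'
    return policy
```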
......@@ -480,7 +481,7 @@ def main():
val_list = []
loss_list = []
for e in range(train_param['epoch']):
model.module.memory_updater.empty_cache()
# model.module.memory_updater.empty_cache()
tt._zero()
torch.cuda.synchronize()
epoch_start_time = time.time()
......
......@@ -336,9 +336,12 @@ class TransformerMemoryUpdater(torch.nn.Module):
def forward(self, b, param = None):
Q = self.w_q(b.srcdata['mem']).reshape((b.num_src_nodes(), self.att_h, -1))
mails = b.srcdata['mem_input'].reshape((b.num_src_nodes(), self.memory_param['mailbox_size'], -1))
#print(mails.shape,b.srcdata['mem_input'].shape,b.srcdata['mail_ts'].shape)
if self.dim_time > 0:
time_feat = self.time_enc(b.srcdata['ts'][:, None] - b.srcdata['mail_ts']).reshape((b.num_src_nodes(), self.memory_param['mailbox_size'], -1))
#print(time_feat.shape)
mails = torch.cat([mails, time_feat], dim=2)
#print(mails.shape)
K = self.w_k(mails).reshape((b.num_src_nodes(), self.memory_param['mailbox_size'], self.att_h, -1))
V = self.w_v(mails).reshape((b.num_src_nodes(), self.memory_param['mailbox_size'], self.att_h, -1))
att = self.att_act((Q[:,None,:,:]*K).sum(dim=3))
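For readers of this hunk, a self-contained sketch of the attention it computes: queries come from node memory, keys/values from mailbox entries concatenated with a time encoding of `ts - mail_ts`. Shapes, argument passing, and the softmax (standing in for the module's `att_act` plus normalization) are simplified assumptions, not the repo's exact classes.

```python
import torch

def mailbox_attention(mem, mails, ts, mail_ts, time_enc, w_q, w_k, w_v, att_h):
    """Sketch of the update above: queries from node memory, keys/values from
    mailbox entries concatenated with a time encoding of (ts - mail_ts)."""
    N, M = mails.shape[0], mails.shape[1]
    Q = w_q(mem).reshape(N, att_h, -1)                              # (N, H, dh)
    time_feat = time_enc(ts[:, None] - mail_ts).reshape(N, M, -1)   # (N, M, dt)
    mails = torch.cat([mails, time_feat], dim=2)
    K = w_k(mails).reshape(N, M, att_h, -1)
    V = w_v(mails).reshape(N, M, att_h, -1)
    att = torch.softmax((Q[:, None, :, :] * K).sum(dim=3), dim=1)   # (N, M, H)
    return (att[:, :, :, None] * V).sum(dim=1).reshape(N, -1)       # (N, H*dh)
```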
......@@ -394,7 +397,6 @@ class AsyncMemeoryUpdater(torch.nn.Module):
self.mailbox.handle_last_async()
submit_to_queue = False
if nxt_fetch_func is not None:
nxt_fetch_func()
submit_to_queue = True
self.mailbox.set_memory_all_reduce(
index,memory,memory_ts,
......@@ -404,6 +406,8 @@ class AsyncMemeoryUpdater(torch.nn.Module):
wait_submit=submit_to_queue,spread_mail=spread_mail,
update_cross_mm=False,
)
if nxt_fetch_func is not None:
nxt_fetch_func()
def local_func(self,index,memory,memory_ts,mail_index,mail,mail_ts,nxt_fetch_func,spread_mail=False):
if nxt_fetch_func is not None:
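The reordering in the hunk above moves the prefetch callback so it runs after the memory all-reduce has been submitted, letting the next batch's fetch overlap the in-flight collective. A hedged sketch of the ordering only; `set_memory_all_reduce`'s full argument list is elided here and the wrapper name is an assumption.

```python
# Illustrative ordering only; argument details follow the hunk above.
def submit_reduce_then_prefetch(mailbox, reduce_kwargs, nxt_fetch_func=None):
    mailbox.set_memory_all_reduce(**reduce_kwargs)   # queue the collective first
    if nxt_fetch_func is not None:
        nxt_fetch_func()   # then start the next batch's fetch, overlapping comms
```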
......@@ -471,6 +475,7 @@ class AsyncMemeoryUpdater(torch.nn.Module):
shared_ind = self.mailbox.is_shared_mask[DistIndex(b.srcdata['ID'][mask]).loc]
transition_dense = b.srcdata['his_mem'][mask] + self.filter.get_incretment(shared_ind)
#print(transition_dense.shape)
if not (transition_dense.max().item() == 0):
transition_dense -= transition_dense.min()
transition_dense /= transition_dense.max()
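The added check guards the min-max normalization against an all-zero increment tensor, which would otherwise divide by zero. A minimal standalone sketch of the guarded normalization (the function name and the extra constant-tensor guard are additions for illustration):

```python
import torch

def minmax_normalize_guarded(x: torch.Tensor) -> torch.Tensor:
    """Min-max normalize x unless it is all zeros, mirroring the guard above."""
    if x.max().item() == 0:
        return x
    x = x - x.min()
    denom = x.max()
    return x / denom if denom > 0 else x   # extra guard against a constant tensor
```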
......@@ -514,8 +519,8 @@ class AsyncMemeoryUpdater(torch.nn.Module):
local_mask = (DistIndex(index).part==torch.distributed.get_rank())
local_mask_mail = (DistIndex(index0).part==torch.distributed.get_rank())
#self.mailbox.set_mailbox_local(DistIndex(index0[local_mask_mail]).loc,mail[local_mask_mail],mail_ts[local_mask_mail],Reduce_Op = 'max')
#self.mailbox.set_memory_local(DistIndex(index[local_mask]).loc,memory[local_mask],memory_ts[local_mask], Reduce_Op = 'max')
self.mailbox.set_mailbox_local(DistIndex(index0[local_mask_mail]).loc,mail[local_mask_mail],mail_ts[local_mask_mail],Reduce_Op = 'max')
self.mailbox.set_memory_local(DistIndex(index[local_mask]).loc,memory[local_mask],memory_ts[local_mask], Reduce_Op = 'max')
is_deliver=(self.mailbox.deliver_to == 'neighbors')
self.update_hunk(index,memory,memory_ts,index0,mail,mail_ts,nxt_fetch_func,spread_mail= is_deliver)
......
......@@ -344,6 +344,7 @@ def to_reversed_block(graph,data, sample_out,device = torch.device('cuda'),uniqu
else:
metadata = None
nid_mapper: torch.Tensor = graph.nids_mapper
#print('reverse block {}\n'.format(identity))
if identity is False:
assert len(sample_out) == 1
ret = sample_out[0]
......@@ -354,6 +355,8 @@ def to_reversed_block(graph,data, sample_out,device = torch.device('cuda'),uniqu
dist_eid = torch.tensor([],dtype=torch.long,device=device)
src_index = ret.src_index().to(device)
else:
#print('is jodie')
#print(sample_out)
src_index = torch.tensor([],dtype=torch.long,device=device)
dst = torch.tensor([],dtype=torch.long,device=device)
dist_eid = torch.tensor([],dtype=torch.long,device=device)
......@@ -401,6 +404,7 @@ def to_reversed_block(graph,data, sample_out,device = torch.device('cuda'),uniqu
row_len = root_len
col = first_block_id[:row_len]
max_row = col.max().item()+1
#print(src_index,dst)
b = dgl.create_block((col[src_index].to(device),
torch.arange(dst.shape[0],device=device,dtype=torch.long)),num_src_nodes=first_block_id.max().item()+1,
num_dst_nodes=dst.shape[0])
......@@ -424,6 +428,7 @@ def graph_sample(graph,sampler,sample_fn,data,neg_sampling = None,out_device = t
t_s = time.time()
param = {'is_unique':False,'nid_mapper':nid_mapper,'eid_mapper':eid_mapper,'out_device':out_device}
out = sample_fn(sampler,data,neg_sampling,**param)
#print(sampler.policy)
if reversed is False:
out,dist_nid,dist_eid = to_block(graph,data,out,out_device)
else:
......
......@@ -18,9 +18,10 @@ class MemoryMoniter:
#self.memory_ssim.append(self.ssim(pre_memory,now_memory,method = 'cos'))
#self.nid_list.append(nid)
def draw(self,degree,data,model,e):
torch.save(self.nid_list,'all_args.seed/{}/{}/memorynid_{}.pt'.format(data,model,e))
torch.save(self.memorychange,'all_args.seed/{}/{}/memoryF_{}.pt'.format(data,model,e))
torch.save(self.memory_ssim,'all_args.seed/{}/{}/memcos_{}.pt'.format(data,model,e))
pass
#torch.save(self.nid_list,'all_args.seed/{}/{}/memorynid_{}.pt'.format(data,model,e))
#torch.save(self.memorychange,'all_args.seed/{}/{}/memoryF_{}.pt'.format(data,model,e))
#torch.save(self.memory_ssim,'all_args.seed/{}/{}/memcos_{}.pt'.format(data,model,e))
# path = './memory/{}/'.format(data)
# if not os.path.exists(path):
......@@ -87,6 +88,7 @@ class MemoryMoniter:
def set_zero(self):
self.memorychange = []
self.nid_list =[]
self.memory_ssim = []
pass
#self.memorychange = []
#self.nid_list =[]
#self.memory_ssim = []
......@@ -146,7 +146,7 @@ class SharedMailBox():
source_ts = max_ts
source = source[id]
index = unq_id
#print(self.next_mail_pos[index])
self.mailbox_ts.accessor.data[index, self.next_mail_pos[index]] = source_ts
self.mailbox.accessor.data[index, self.next_mail_pos[index]] = source
if self.memory_param['mailbox_size'] > 1:
......@@ -180,9 +180,23 @@ class SharedMailBox():
if self.deliver_to == 'neighbors':
assert block is not None and Reduce_score is None
mail = torch.cat([mail, mail[block.edges()[0].long()]], dim=0)
mail_ts = torch.cat([mail_ts, mail_ts[block.edges()[0].long()]], dim=0)
# print(block.edges().shape)
root = torch.cat([src,dst]).reshape(-1)
#pos = torch.empty(root.max()+1,dtype=torch.long,device=block.device)
#print('edge {} {}\n'.format(block.num_src_nodes(),block.edges()[0].max()))
#print('root is {} {} {} {}\n'.format(root,root.shape,root.max(),block.edges()[0].shape))
#pos_index = torch.arange(root.shape[0],device=root.device,dtype=root.dtype)
pos,idx = torch_scatter.scatter_max(mail_ts,root,0)
mail = torch.cat([mail, mail[idx]],dim=0)
mail_ts = torch.cat([mail_ts, mail_ts[idx]], dim=0)
#print('pos is {} {}\n'.format(pos,block.edges()[0].long()))
#mail = torch.cat([mail, mail[pos[block.edges()[0].long()]]],dim=0)
#mail_ts = torch.cat([mail_ts, mail_ts[pos[block.edges()[0].long()]]], dim=0)
#print(root,block.edges()[1].long())
index = torch.cat([index,block.dstdata['ID'][block.edges()[1].long()]],dim=0)
#mail = torch.cat([mail, mail[block.edges()[0].long()]], dim=0)
#mail_ts = torch.cat([mail_ts, mail_ts[block.edges()[0].long()]], dim=0)
#index = torch.cat([index,block.dstdata['ID'][block.edges()[1].long()]],dim=0)
if Reduce_score is not None:
Reduce_score = torch.cat((Reduce_score,Reduce_score),-1).to(self.device)
if Reduce_score is None:
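The rewritten 'deliver to neighbors' path no longer duplicates one mail per edge; it uses `torch_scatter.scatter_max` over the root node ids to pick, for each root, the index of its most recent mail, and appends only those. A hedged sketch of that idiom (tensor names and the validity filter are illustrative; `torch_scatter` is the same third-party package the file already uses):

```python
import torch
import torch_scatter

def latest_mail_per_root(mail, mail_ts, root):
    """For every root node id, keep only its most recent mail.

    mail: (E, d) messages, mail_ts: (E,) timestamps, root: (E,) owning node ids.
    """
    # scatter_max returns the per-root max timestamp and the argmax index into
    # mail_ts; roots with no mail get an out-of-range argmax, which we drop.
    max_ts, argmax = torch_scatter.scatter_max(mail_ts, root, dim=0)
    valid = argmax < mail.shape[0]
    return mail[argmax[valid]], max_ts[valid]
```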
......@@ -192,12 +206,19 @@ class SharedMailBox():
mail = mail[idx]
index = unq_index
else:
unq_index,inv = torch.unique(index,return_inverse = True)
uni, inv = torch.unique(index, return_inverse=True)
perm = torch.arange(inv.size(0), dtype=inv.dtype, device=inv.device)
perm = inv.new_empty(uni.size(0)).scatter_(0, inv, perm)
index = index[perm]
mail = mail[perm]
mail_ts = mail_ts[perm]
#unq_index,inv = torch.unique(index,return_inverse = True)
#print(inv.shape,Reduce_score.shape)
max_score,idx = torch_scatter.scatter_max(Reduce_score,inv,0)
mail_ts = mail_ts[idx]
mail = mail[idx]
index = unq_index
#max_score,idx = torch_scatter.scatter_max(Reduce_score,inv,0)
#mail_ts = mail_ts[idx]
#mail = mail[idx]
#index = unq_index
#print('mail {} {}\n'.format(index.shape,mail.shape,mail_ts.shape))
return index,mail,mail_ts
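The dedup path above switches from score-based `scatter_max` selection to keeping one representative mail per unique index, via a `scatter_` trick on the inverse mapping returned by `torch.unique`. A standalone sketch of that idiom (variable names follow the hunk; which duplicate survives is backend-dependent):

```python
import torch

def one_mail_per_index(index, mail, mail_ts):
    """Keep a single representative mail per unique index, as in the hunk above."""
    uni, inv = torch.unique(index, return_inverse=True)
    # Scatter each position into a slot per unique value; when several positions
    # share a slot, which write survives is backend-dependent, so this keeps an
    # arbitrary (in practice usually the last) occurrence per unique index.
    perm = torch.arange(inv.size(0), dtype=inv.dtype, device=inv.device)
    perm = inv.new_empty(uni.size(0)).scatter_(0, inv, perm)
    return index[perm], mail[perm], mail_ts[perm]
```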
def get_update_memory(self,index,memory,memory_ts,embedding):
......@@ -206,6 +227,7 @@ class SharedMailBox():
ts = max_ts
index = unq_index
memory = memory[idx]
#print('memory {} {}\n'.format(index.shape,memory.shape,ts.shape))
return index,memory,ts
def pack(self,memory=None,memory_ts=None,mail=None,mail_ts=None,index = None,mode=None):
......@@ -250,6 +272,7 @@ class SharedMailBox():
self.set_mailbox_local(DistIndex(gather_id_list).loc,gather_mail,gather_mail_ts,Reduce_Op = reduce_Op)
else:
gather_memory,gather_memory_ts = self.unpack(gather_memory)
#print(gather_id_list.shape,gather_memory.shape,gather_memory_ts.shape)
self.set_memory_local(DistIndex(gather_id_list).loc,gather_memory,gather_memory_ts, Reduce_Op = reduce_Op)
def handle_last_mail(self,reduce_Op=None,):
if self.last_mail_sync is not None:
......@@ -260,6 +283,7 @@ class SharedMailBox():
if isinstance(gather_memory,list):
gather_memory = torch.cat(gather_memory,dim = 0)
gather_memory,gather_memory_ts = self.unpack(gather_memory)
#print(gather_id_list.shape,gather_memory.shape,gather_memory_ts.shape)
self.set_mailbox_local(DistIndex(gather_id_list).loc,gather_memory,gather_memory_ts, Reduce_Op = reduce_Op)
def handle_last_async(self,reduce_Op = None):
self.handle_last_memory(reduce_Op)
......@@ -303,6 +327,7 @@ class SharedMailBox():
return
index,gather_id_list,mem,gather_memory,input_split,output_split,group,async_op = self.next_wait_mail_job
self.next_wait_mail_job = None
#print(index,gather_id_list)
handle0 = torch.distributed.all_to_all_single(
gather_id_list,index,output_split_sizes=output_split,
input_split_sizes=input_split,group = group,async_op=async_op)
......@@ -409,6 +434,7 @@ class SharedMailBox():
self.update_p2p_mail()
self.update_p2p_mem()
self.handle_last_async()
ctx = DistributedContext.get_default_context()
......@@ -483,7 +509,7 @@ class SharedMailBox():
unq_index,inv = torch.unique(shared_index,return_inverse = True)
max_ts,idx = torch_scatter.scatter_max(shared_memory_ts,inv,0)
shared_memory = shared_memory[idx]
shared_memory = shared_memory[idx]
#shared_memory = shared_memory[idx]
shared_memory_ts = shared_memory_ts[idx]
shared_mail_ts = shared_mail_ts[idx]
shared_mail = shared_mail[idx]
......@@ -495,11 +521,13 @@ class SharedMailBox():
self.set_mailbox_local(self.shared_nodes_index[shared_mail_indx],shared_mail,shared_mail_ts)
else:
self.next_wait_gather_memory_job = (shared_list,mem,shared_id_list,shared_memory_ind)
if not wait_submit:
self.update_shared()
self.update_p2p_mail()
self.update_p2p_mem()
self.handle_last_async()
self.sychronize_shared()
#self.historical_cache.add_shared_to_queue(handle0,handle1,shared_id_list,shared_list)
"""
shared_memory = self.node_memory.accessor.data[self.shared_nodes_index]
......
......@@ -317,8 +317,9 @@ class NeighborSampler(BaseSampler):
"""
if self.no_neg:
out,metadata = self.sample_from_nodes(seed[:seed.shape[0]//3*2], seed_ts[:seed.shape[0]//3*2], is_unique=False)
else:
out,metadata = self.sample_from_nodes(seed[:seed.shape[0]], seed_ts[:seed.shape[0]//3*2], is_unique=False)
out,metadata = self.sample_from_nodes(seed, seed_ts, is_unique=False)
src_pos_index = torch.arange(0,num_pos,dtype= torch.long,device=out_device)
dst_pos_index = torch.arange(num_pos,2*num_pos,dtype= torch.long,device=out_device)
if neg_sampling.is_triplet() or neg_sampling.is_tgbtriplet():
......
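The sampler fix above matters because, judging from the surrounding index arithmetic, the seeds are laid out as [src_pos | dst_pos | dst_neg], so positives occupy the first two thirds. With `no_neg` set, only that slice is sampled; otherwise the full seed tensor now goes through `sample_from_nodes` with matching timestamps, instead of the mismatched slice used before. A small sketch of that slicing assumption (names are illustrative):

```python
import torch

# Assumed seed layout from the surrounding indices: one negative per positive,
# i.e. [src_pos | dst_pos | dst_neg], so positives are the first two thirds.
def select_sampling_seeds(seed: torch.Tensor, seed_ts: torch.Tensor, no_neg: bool):
    if no_neg:
        cut = seed.shape[0] // 3 * 2      # positives only, timestamps kept in sync
        return seed[:cut], seed_ts[:cut]
    return seed, seed_ts                  # negatives are sampled as well
```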