Commit 4ad96131 by zlj

delete some useless code

parent 4fa85d33
import matplotlib.pyplot as plt
import numpy as np
import torch
import os

# Read the results from the experiment output files.
probability_values = [1, 0.1, 0.05, 0.01, 0]  # [0.1,0.05,0.01,0]
data_values = ['WIKI', 'LASTFM', 'WikiTalk', 'StackOverflow', 'GDELT']  # datasets to read from the files
seed = ['13357', '12347', '53473', '54763', '12347', '63377', '53473', '54763']
partition = 'ours_shared'
# Output file layout:
# all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample".out
partitions = 4
topk = 0.01
mem = 'all_update'  # 'historical'
model0 = 'TGN'


def average(l):
    # Mean of a list; returns 0 for an empty list to avoid ZeroDivisionError.
    return sum(l) / len(l) if len(l) > 0 else 0


for p in probability_values:
    for data in data_values:
        ap_list = []
        comm_list = []
        test_time = []
        train_time = []
        total_communication = []
        shared_synchronize = []
        for sd in seed:
            if data == 'WIKI' or data == 'LASTFM':
                model = model0
            else:
                model = model0 + '_large'
            if p == 1:
                file = '../examples-probability-sample/all_{}/{}/{}/{}-{}-{}-{}-recent.out'.format(sd, data, model, partitions, partition, topk, mem)
            else:
                file = '../examples-probability-sample/all_{}/{}/{}/{}-{}-{}-{}-boundery_recent_decay-{}.out'.format(sd, data, model, partitions, partition, topk, mem, p)
            # print(file)
            prefix = "val ap:"
            max_val_ap = 0
            test_ap = 0
            if os.path.exists(file):
                with open(file, 'r') as f:  # use a new name so the path in `file` is not shadowed
                    _total_communication = []
                    _shared_synchronize = []
                    for line in f:
                        # if line.find('Epoch 50:') != -1:
                        #     break
                        if line.find(prefix) != -1:
                            # Keep the test AP from the epoch with the best validation AP.
                            pos = line.find(prefix) + len(prefix)
                            posr = line.find(' ', pos)
                            val_ap = float(line[pos:posr])
                            pos = line.find("test ap ") + len("test ap ")
                            posr = line.find(' ', pos)
                            _test_ap = float(line[pos:posr])
                            if val_ap > max_val_ap:
                                max_val_ap = val_ap
                                test_ap = _test_ap
                        elif line.find('avg_time ') != -1:
                            # Final summary line: record times and the chosen test AP.
                            pl = line.find('avg_time ') + len('avg_time ')
                            pr = line.find(' test time')
                            train_time.append(float(line[pl:pr]))
                            test_time.append(float(line[pr + len(' test time '):]))
                            ap_list.append(test_ap)
                            total_communication.append(average(_total_communication))
                            shared_synchronize.append(average(_shared_synchronize))
                        # Sample lines matched by the two branches below:
                        #   local node number tensor([114453298]) remote node number tensor([20457479]) local edge tensor([1592867807]) remote edgetensor([326448632])
                        #   comm local node number 0 remote node number 0 local edge 0 remote edge0
                        #   memory comm tensor([0]) shared comm tensor([24860])
                        elif line.find('remote node number tensor([') != -1:
                            pl = line.find('remote node number tensor([') + len('remote node number tensor([')
                            pr = line.find('])', pl)
                            _total_communication.append(int(line[pl:pr]))
                        elif line.find('shared comm tensor([') != -1:
                            pl = line.find('shared comm tensor([') + len('shared comm tensor([')
                            pr = line.find('])', pl)
                            _shared_synchronize.append(int(line[pl:pr]))
        if len(ap_list) > 0:
            print('prob {} data {} model {} ap: {} train_time: {} eval time: {} remote volume : {} synchronize volume : {}'.format(
                p, data, model, average(ap_list), average(train_time), average(test_time),
                average(total_communication), average(shared_synchronize)))
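For reference, here is a minimal, self-contained sketch of the line formats the find/slice logic above expects. The sample lines are hypothetical stand-ins, not real log output:

# Hypothetical log lines shaped like the ones this script parses.
sample_val = "val ap:0.9821 test ap 0.9783 \n"
sample_time = "avg_time 12.34 test time 1.56\n"

prefix = "val ap:"
pos = sample_val.find(prefix) + len(prefix)
posr = sample_val.find(' ', pos)
assert float(sample_val[pos:posr]) == 0.9821   # value runs up to the next space

pl = sample_time.find('avg_time ') + len('avg_time ')
pr = sample_time.find(' test time')
assert float(sample_time[pl:pr]) == 12.34
assert float(sample_time[pr + len(' test time '):]) == 1.56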
@@ -11,15 +11,16 @@ node_per="4"
nnodes="3"
node_rank="1"
probability_params=("0.1")
sample_type_params=("boundery_recent_decay")
#sample_type_params=("recent" "boundery_recent_decay") #"boundery_recent_uniform")
#memory_type=("all_update" "p2p" "all_reduce" "historical" "local")
memory_type=("historical")
memory_type=("local")
#"historical")
#memory_type=("local" "all_update" "historical" "all_reduce")
shared_memory_ssim=("0.3")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk")
data_param=("WIKI" "LASTFM" "WikiTalk" "StackOverflow" "GDELT")
data_param=("WIKI" )
#"GDELT")
#data_param=("WIKI" "REDDIT" "LASTFM" "DGraphFin" "WikiTalk" "StackOverflow")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk" "StackOverflow")
@@ -34,7 +35,7 @@ mkdir -p all_"$seed"
for data in "${data_param[@]}"; do
model="APAN_large"
if [ "$data" = "WIKI" ] || [ "$data" = "REDDIT" ] || [ "$data" = "LASTFM" ]; then
model="APAN"
model="TGN"
fi
#model="APAN"
mkdir all_"$seed"/"$data"
@@ -236,7 +236,9 @@ def main():
policy_train = policy
if memory_param['type'] != 'none':
mailbox = SharedMailBox(graph.ids.shape[0], memory_param, dim_edge_feat = graph.efeat.shape[1] if graph.efeat is not None else 0,
shared_nodes_index=graph.shared_nids_list[ctx.memory_group_rank],device = torch.device('cuda:{}'.format(local_rank)),shared_ssim=args.shared_memory_ssim)
shared_nodes_index=graph.shared_nids_list[ctx.memory_group_rank],device = torch.device('cuda:{}'.format(local_rank)),
start_historical=(args.memory_type=='historical'),
shared_ssim=args.shared_memory_ssim)
else:
mailbox = None
@@ -76,12 +76,12 @@ class GeneralModel(torch.nn.Module):
self.dim_edge = dim_edge
self.sample_param = sample_param
self.memory_param = memory_param
self.train_pos_ratio,self.train_neg_ratio = train_ratio
#self.train_pos_ratio,self.train_neg_ratio = train_ratio
if not 'dim_out' in gnn_param:
gnn_param['dim_out'] = memory_param['dim_out']
self.gnn_param = gnn_param
self.train_param = train_param
self.neg_fix_layer = NegFixLayer()
#self.neg_fix_layer = NegFixLayer()
if memory_param['type'] == 'node':
if memory_param['memory_update'] == 'gru':
#if memory_param['async'] == False:
@@ -290,9 +290,9 @@ def to_block(graph,data, sample_out,device = torch.device('cuda'),unique = True)
if sample_out[r].delta_ts().shape[0] > 0:
b.edata['dt'] = sample_out[r].delta_ts().to(device)
b.srcdata['ts'] = block_node_list[1,b.srcnodes()].to(torch.float)
weight = sample_out[r].sample_weight()
if(weight.shape[0] > 0):
b.edata['weight'] = 1/torch.clamp(sample_out[r].sample_weight(),0.0001).to(b.device)
#weight = sample_out[r].sample_weight()
#if(weight.shape[0] > 0):
# b.edata['weight'] = 1/torch.clamp(sample_out[r].sample_weight(),0.0001).to(b.device)
b.edata['__ID'] = e_idx
col = row
col_len += eid_len[r]
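For context, the lines commented out above computed inverse sampling weights with a floor: torch.clamp caps tiny weights at 1e-4 so the inverse cannot blow up. A tiny sketch with made-up values:

import torch

weight = torch.tensor([0.5, 1e-6, 2.0])
# clamp floors the 1e-6 entry at 0.0001 before inverting
inv_w = 1 / torch.clamp(weight, 0.0001)  # tensor([2.0000e+00, 1.0000e+04, 5.0000e-01])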
@@ -105,7 +105,7 @@ class SharedMailBox():
self.is_shared_mask[shared_nodes_index] = torch.arange(self.shared_nodes_index.shape[0],dtype=torch.int,
device=torch.device('cuda:{}'.format(ctx.local_rank)))
if start_historical is not None:
if start_historical:
self.historical_cache = historical_cache.HistoricalCache(self.shared_nodes_index,0,self.node_memory.shape[1],self.node_memory.dtype,self.node_memory.device,threshold=shared_ssim)
self._mem_pin = {}
self._mail_pin = {}
@@ -180,26 +180,12 @@ SharedMailBox():
if self.deliver_to == 'neighbors':
assert block is not None and Reduce_score is None
# print(block.edges().shape)
root = torch.cat([src,dst]).reshape(-1)
#pos = torch.empty(root.max()+1,dtype=torch.long,device=block.device)
#print('edge {} {}\n'.format(block.num_src_nodes(),block.edges()[0].max()))
#print('root is {} {} {} {}\n'.format(root,root.shape,root.max(),block.edges()[0].shape))
#pos_index = torch.arange(root.shape[0],device=root.device,dtype=root.dtype)
pos,idx = torch_scatter.scatter_max(mail_ts,root,0)
#print(block.number_of_edges())
_,idx = torch_scatter.scatter_max(mail_ts,root,0)
mail = torch.cat([mail, mail[idx[block.edges()[0].long()]]],dim=0)
mail_ts = torch.cat([mail_ts, mail_ts[idx[block.edges()[0].long()]]], dim=0)
#print('pos is {} {}\n'.format(pos,block.edges()[0].long()))
#mail = torch.cat([mail, mail[pos[block.edges()[0].long()]]],dim=0)
#mail_ts = torch.cat([mail_ts, mail_ts[pos[block.edges()[0].long()]]], dim=0)
#print(root,block.edges()[1].long())
index = torch.cat([index,block.dstdata['ID'][block.edges()[1].long()]],dim=0)
#print(index)
#mail = torch.cat([mail, mail[block.edges()[0].long()]], dim=0)
#mail_ts = torch.cat([mail_ts, mail_ts[block.edges()[0].long()]], dim=0)
#index = torch.cat([index,block.dstdata['ID'][block.edges()[1].long()]],dim=0)
if Reduce_score is not None:
Reduce_score = torch.cat((Reduce_score,Reduce_score),-1).to(self.device)
if Reduce_score is None:
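In isolation, the scatter_max call above selects, for each root node, the index of its most recent mail, which is then gathered per edge. A minimal sketch assuming torch and torch_scatter are installed (the tensors are made up):

import torch
import torch_scatter

root = torch.tensor([0, 1, 0])           # destination node id of each mail
mail_ts = torch.tensor([5.0, 7.0, 9.0])  # mail timestamps
mail = torch.tensor([[1.0], [2.0], [3.0]])

_, idx = torch_scatter.scatter_max(mail_ts, root, 0)
# idx[n] is the position of node n's newest mail: node 0 -> mail 2, node 1 -> mail 1
newest = mail[idx]                       # tensor([[3.], [2.]])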
@@ -209,18 +195,12 @@ SharedMailBox():
mail = mail[idx]
index = unq_index
else:
uni, inv = torch.unique(index, return_inverse=True)
perm = torch.arange(inv.size(0), dtype=inv.dtype, device=inv.device)
perm = inv.new_empty(uni.size(0)).scatter_(0, inv, perm)
index = index[perm]
mail = mail[perm]
mail_ts = mail_ts[perm]
#unq_index,inv = torch.unique(index,return_inverse = True)
#print(inv.shape,Reduce_score.shape)
#max_score,idx = torch_scatter.scatter_max(Reduce_score,inv,0)
#mail_ts = mail_ts[idx]
#mail = mail[idx]
#index = unq_index
unq_index,inv = torch.unique(index,return_inverse = True)
print(inv.shape,Reduce_score.shape)
max_score,idx = torch_scatter.scatter_max(Reduce_score,inv,0)
mail_ts = mail_ts[idx]
mail = mail[idx]
index = unq_index
#print('mail {} {}\n'.format(index.shape,mail.shape,mail_ts.shape))
return index,mail,mail_ts
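The new dedup path can be exercised on its own: torch.unique groups duplicate indices, and scatter_max keeps, for each unique index, the mail with the highest Reduce_score. A hedged sketch with made-up tensors:

import torch
import torch_scatter

index = torch.tensor([4, 7, 4])              # id 4 received two mails
Reduce_score = torch.tensor([0.2, 0.5, 0.9])
mail = torch.tensor([[1.0], [2.0], [3.0]])
mail_ts = torch.tensor([10.0, 20.0, 30.0])

unq_index, inv = torch.unique(index, return_inverse=True)  # unq_index=[4, 7], inv=[0, 1, 0]
_, idx = torch_scatter.scatter_max(Reduce_score, inv, 0)   # best-scoring mail per unique id
index, mail, mail_ts = unq_index, mail[idx], mail_ts[idx]
# index -> [4, 7]; mail -> [[3.], [2.]] (id 4 keeps the score-0.9 mail)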