Commit a1d8044f by zhlj

fix alpha policy

parent eacb2444
...@@ -21,7 +21,7 @@ class Filter(nn.Module):
"""
# Treat filter as parameter so that it is saved and loaded together with the model
self.count = torch.zeros((self.n_nodes),1).to(self.device)
- self.incretment = torch.zeros((self.n_nodes, self.memory_dimension)).to(self.device)
+ self.incretment = torch.zeros((self.n_nodes, self.memory_dimension),dtype=torch.float32).to(self.device)
def get_count(self, node_idxs):
...@@ -37,7 +37,7 @@ class Filter(nn.Module):
def update(self, node_idxs, incret):
self.count[node_idxs, :] = self.count[node_idxs, :] + 1
- self.incretment[node_idxs, :] = self.incretment[node_idxs, :] + incret
+ self.incretment[node_idxs, :] = (self.incretment[node_idxs, :] + incret).to(self.incretment.dtype)
def clear(self):
self.count.zero_()
......
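The Filter hunk above pins `incretment` to float32 and casts the accumulated sum back to the buffer's dtype before the indexed write. Indexed writes in PyTorch generally do not type-promote, so the cast keeps the assignment consistent when the increment arrives in a different precision. A minimal sketch with made-up shapes and dtypes (not the project's actual values):

```python
import torch

# Illustrative buffer/increment dtypes only; the real Filter uses its own sizes.
buf = torch.zeros(4, 8, dtype=torch.float16)      # stand-in for self.incretment
incr = torch.randn(2, 8, dtype=torch.float32)     # increment in a wider dtype
idx = torch.tensor([1, 3])

# The sum promotes to float32; casting back to buf.dtype keeps the indexed
# write (`buf[idx, :] = ...`) dtype-consistent with the destination buffer.
buf[idx, :] = (buf[idx, :] + incr).to(buf.dtype)
print(buf.dtype, buf[idx].float().abs().sum() > 0)
```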
...@@ -393,6 +393,7 @@ class AsyncMemeoryUpdater(torch.nn.Module):
# if nxt_fetch_func is not None:
# nxt_fetch_func()
def historical_func(self,index,memory,memory_ts,mail_index,mail,mail_ts,nxt_fetch_func,spread_mail=False):
+ self.ada_param.update_memory_update_time(self.ada_param.last_start_event_memory_update)
self.mailbox.sychronize_shared()
self.mailbox.handle_last_async()
#if self.ada_param.training is True:
...@@ -409,7 +410,7 @@ class AsyncMemeoryUpdater(torch.nn.Module):
wait_submit=submit_to_queue,spread_mail=spread_mail,
update_cross_mm=False,
)
- self.ada_param.update_memory_update_time(self.ada_param.last_start_event_memory_update)
if nxt_fetch_func is not None:
nxt_fetch_func()
...@@ -422,7 +423,9 @@ class AsyncMemeoryUpdater(torch.nn.Module):
if self.dim_time > 0:
#print(b.srcdata['ts'].shape,b.srcdata['mem_ts'].shape)
time_feat = self.time_enc(b.srcdata['ts'] - b.srcdata['mem_ts'])
+ #print(b.srcdata['mem_input'].dtype,b.srcdata['mem_ts'].dtype,b.srcdata['ts'].dtype,time_feat.dtype)
b.srcdata['mem_input'] = torch.cat([b.srcdata['mem_input'], time_feat], dim=1)
return self.ceil_updater(b.srcdata['mem_input'], b.srcdata['mem'])
def __init__(self, memory_param, dim_in, dim_hid, dim_time, dim_node_feat,updater,mode = None,mailbox = None,train_param=None, ada_param = None):
super(AsyncMemeoryUpdater, self).__init__()
...@@ -468,32 +471,23 @@ class AsyncMemeoryUpdater(torch.nn.Module):
self.gamma = 1
def forward(self, mfg, param = None):
for b in mfg:
- #print(b.srcdata['ID'].shape[0])
updated_memory0 = self.updater(b)
mask = DistIndex(b.srcdata['ID']).is_shared
- #incr = updated_memory[mask] - b.srcdata['mem'][mask]
- #time_feat = self.time_enc(b.srcdata['ts'][mask].reshape(-1,1) - b.srcdata['his_ts'][mask].reshape(-1,1))
- #his_mem = torch.cat((mail_input[mask],time_feat),dim = 1)
with torch.no_grad():
upd0 = torch.zeros_like(updated_memory0)
- #print(upd0.dtype)
if self.mode == 'historical':
shared_ind = self.mailbox.is_shared_mask[DistIndex(b.srcdata['ID'][mask]).loc]
transition_dense = b.srcdata['his_mem'][mask] + self.filter.get_incretment(shared_ind)
- #print(transition_dense.shape)
if not (transition_dense.max().item() == 0):
transition_dense -= transition_dense.min()
transition_dense /= transition_dense.max()
transition_dense = 2*transition_dense - 1
- upd0[mask] = transition_dense#b.srcdata['his_mem'][mask] + transition_dense
+ upd0[mask] = transition_dense.to(upd0.dtype)#b.srcdata['his_mem'][mask] + transition_dense
- #print(self.gamma)
+ updated_memory = torch.where(mask.unsqueeze(1),torch.sigmoid(self.gamma)*updated_memory0 + (1-torch.sigmoid(self.gamma))*(upd0),updated_memory0)
- #print('tran {} {} {}\n'.format(transition_dense.max().item(),upd0[mask].max().item(),b.srcdata['his_mem'][mask].max().item()))
else:
- upd0[mask] = updated_memory0[mask]
+ #upd0[mask] = updated_memory0[mask]
- #upd0[mask] = self.ceil_updater(his_mem, b.srcdata['his_mem'][mask])
+ updated_memory = updated_memory0
- #updated_memory = torch.where(mask.unsqueeze(1),self.gamma*updated_memory0 + (1-self.gamma)*(b.srcdata['his_mem'])
- # ,updated_memory0)
- updated_memory = torch.where(mask.unsqueeze(1),torch.sigmoid(self.gamma)*updated_memory0 + (1-torch.sigmoid(self.gamma))*(upd0),updated_memory0)
with torch.no_grad():
if self.mode == 'historical':
change = updated_memory[mask] - b.srcdata['his_mem'][mask]
...@@ -518,12 +512,11 @@ class AsyncMemeoryUpdater(torch.nn.Module):
None,False,False,block=b
)
#print(index.shape[0])
- if torch.distributed.get_world_size() == 0:
+ #if torch.distributed.get_world_size() == 0:
- self.mailbox.mon.add(index,self.mailbox.node_memory.accessor.data[index],memory)
+ # self.mailbox.mon.add(index,self.mailbox.node_memory.accessor.data[index],memory)
##print(index.shape,memory.shape,memory_ts.shape,mail.shape,mail_ts.shape)
local_mask = (DistIndex(index).part==torch.distributed.get_rank())
local_mask_mail = (DistIndex(index0).part==torch.distributed.get_rank())
self.mailbox.set_mailbox_local(DistIndex(index0[local_mask_mail]).loc,mail[local_mask_mail],mail_ts[local_mask_mail],Reduce_Op = 'max')
self.mailbox.set_memory_local(DistIndex(index[local_mask]).loc,memory[local_mask],memory_ts[local_mask], Reduce_Op = 'max')
is_deliver=(self.mailbox.deliver_to == 'neighbors')
......
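For reference, the reworked `forward` above now applies the sigmoid gate only inside the `historical` branch and falls back to the raw updater output otherwise. A minimal sketch of that gating pattern, with illustrative names and shapes rather than the module's real API:

```python
import torch

def blend(local_upd: torch.Tensor, hist: torch.Tensor,
          mask: torch.Tensor, gamma: torch.Tensor) -> torch.Tensor:
    g = torch.sigmoid(gamma)                      # gate in (0, 1)
    mixed = g * local_upd + (1.0 - g) * hist      # convex combination of the two estimates
    return torch.where(mask.unsqueeze(1), mixed, local_upd)

local_upd = torch.randn(5, 16)                    # stand-in for updated_memory0
hist = torch.zeros_like(local_upd)                # stand-in for upd0
hist[:2] = torch.tanh(torch.randn(2, 16))         # pretend-normalized history in [-1, 1]
mask = torch.tensor([True, True, False, False, False])  # shared-node mask
out = blend(local_upd, hist, mask, torch.tensor(1.0))
```

Only rows where `mask` is true receive the blended value; all other rows pass the updater output through unchanged, matching the `torch.where` in the hunk.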
...@@ -159,9 +159,11 @@ class AdaParameter:
self.beta = self.beta * average_gnn_aggregate/average_fetch * (1 + self.wait_threshold)
average_memory_sync_time = self.average_memory_sync/self.count_memory_sync
average_memory_update_time = self.average_memory_update/self.count_memory_update
- self.alpha = self.alpha-math.log(average_memory_update_time/average_memory_sync_time * (1 + self.wait_threshold))
+ self.alpha = self.alpha - math.log(average_memory_update_time*(1+self.wait_threshold)) + math.log(average_memory_sync_time)
+ print(self.alpha)
self.beta = max(min(self.beta, self.max_beta),self.min_beta)
self.alpha = max(min(self.alpha, self.max_alpha),self.min_alpha)
ctx = DistributedContext.get_default_context()
beta_comm=torch.tensor([self.beta])
torch.distributed.all_reduce(beta_comm,group=ctx.gloo_group)
...@@ -169,8 +171,8 @@ class AdaParameter:
alpha_comm=torch.tensor([self.alpha])
torch.distributed.all_reduce(alpha_comm,group=ctx.gloo_group)
self.alpha = alpha_comm[0].item()/ctx.world_size
- #print('gnn aggregate {} fetch {} memory sync {} memory update {}'.format(average_gnn_aggregate,average_fetch,average_memory_sync_time,average_memory_update_time))
+ print('gnn aggregate {} fetch {} memory sync {} memory update {}'.format(average_gnn_aggregate,average_fetch,average_memory_sync_time,average_memory_update_time))
- #print('beta is {} alpha is {}\n'.format(self.beta,self.alpha))
+ print('beta is {} alpha is {}\n'.format(self.beta,self.alpha))
#self.reset_time()
#log(2-a1 ) = log(2-a2) * t1/t2 * (1 + wait_threshold)
#2-a1 = 2-a2 ^(t1/t2 * (1 + wait_threshold))
......
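The headline change (`fix alpha policy`) is the alpha update in `AdaParameter`: alpha now decreases by `log(update_time * (1 + wait_threshold))` and increases by `log(sync_time)`, so it drifts upward when memory sync dominates memory update time, and is then clamped and averaged across ranks. A rough standalone sketch of just the local step; the clamp bounds and threshold below are placeholders, the real values live in the class:

```python
import math

def update_alpha(alpha, avg_update_t, avg_sync_t,
                 wait_threshold=0.1, min_alpha=0.1, max_alpha=1.0):
    # New policy from the hunk: difference of logs, then clamp to [min, max].
    alpha = alpha - math.log(avg_update_t * (1 + wait_threshold)) + math.log(avg_sync_t)
    return max(min(alpha, max_alpha), min_alpha)

# alpha grows when sync time exceeds update time (scaled by the threshold).
print(update_alpha(0.5, avg_update_t=2e-3, avg_sync_t=3e-3))
```

After this local step, the diff all-reduces alpha over the gloo group and divides by the world size, so all ranks end up with the arithmetic mean of their local alphas.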
...@@ -186,26 +186,7 @@ def to_block(graph,data, sample_out,device = torch.device('cuda'),unique = True)
sample_out,metadata = sample_out
else:
metadata = None
- #to_block(metadata['src_pos_index'],metadata['dst_pos_index'],metadata['dst_neg_index'],
- # metadata['seed'],metadata['seed_ts'],graph.nids_mapper,graph.eids_mapper,#device.type if "cpu" else str(device.index))
- #root_len = len(metadata.pop('seed'))
- #eid_inv = metadata.pop('eid_inv').clone()
- #print('data {} {}\n'.format(data.edges,data.ts))
- #first_block_id = metadata.pop('first_block_id').clone()
- #print('first_block_id {}\n'.format(first_block_id))
- #block_node_list = metadata.pop('block_node_list').clone()
- #print('block_node_list {}\n'.format(block_node_list))
- #unq_id = metadata.pop('unq_id').clone()
- #print('unq id {}'.format(unq_id))
- #dist_nid = metadata.pop('dist_nid').clone().to(device)
- #dist_eid = metadata.pop('dist_eid').clone().to(device)
- #print('dist nid {} dist eid {}\n'.format(dist_nid,dist_eid))
- #print('block node list edge {} {}'.format(
- # graph.ids[DistIndex(dist_nid[block_node_list[0,#unq_id]]).loc.to('cpu')],block_node_list[1,unq_id]
eid_len = [ret.eid().shape[0] for ret in sample_out ]
- # print(sample_out)
- t0 = time.time()
eid = [ret.eid() for ret in sample_out]
dst = [ret.sample_nodes() for ret in sample_out]
dst_ts = [ret.sample_nodes_ts() for ret in sample_out]
...@@ -237,13 +218,7 @@ def to_block(graph,data, sample_out,device = torch.device('cuda'),unique = True)
metadata[k] = metadata[k].to(device)
nid_tensor = torch.cat([root_node,src_node],dim = 0)
dist_nid = nid_mapper[nid_tensor].to(device)
- #print(CountComm.origin_local,CountComm.origin_remote)
- #for p in range(dist.get_world_size()):
- # print((DistIndex(dist_nid).part == p).sum().item())
- #CountComm.origin_local = (DistIndex(dist_nid).part == dist.get_rank()).sum().item()
- #CountComm.origin_remote =(DistIndex(dist_nid).part != dist.get_rank()).sum().item()
dist_nid,nid_inv = dist_nid.unique(return_inverse=True)
- #print('nid_tensor {} \n nid {}\n'.format(nid_tensor,dist_nid))
"""
Deduplicate nodes that share the same id and the same timestamp to obtain their index
...@@ -258,13 +233,11 @@ def to_block(graph,data, sample_out,device = torch.device('cuda'),unique = True)
first_block_id = torch.empty(first_index.shape[0],device=unq_id.device,dtype=unq_id.dtype)
first_block_id[first_index] = torch.arange(first_index.shape[0],device=first_index.device,dtype=first_index.dtype)
first_block_id = first_block_id[unq_id].contiguous()
- block_node_list = block_node_list[:,first_index]
+ block_node_list = block_node_list[:,first_index].contiguous()
#print('first block id {}\n unq id {} \n block_node_list {}\n'.format(first_block_id,unq_id,block_node_list))
for k in metadata:
if isinstance(metadata[k],torch.Tensor):
- #print('{}:{}\n'.format(k,metadata[k]))
metadata[k] = first_block_id[metadata[k]]
- #print('{}:{}\n'.format(k,metadata[k]))
t2 = time.time()
def build_block():
...@@ -290,9 +263,6 @@ def to_block(graph,data, sample_out,device = torch.device('cuda'),unique = True)
if sample_out[r].delta_ts().shape[0] > 0:
b.edata['dt'] = sample_out[r].delta_ts().to(device)
b.srcdata['ts'] = block_node_list[1,b.srcnodes()].to(torch.float)
- #weight = sample_out[r].sample_weight()
- #if(weight.shape[0] > 0):
- # b.edata['weight'] = 1/torch.clamp(sample_out[r].sample_weight(),0.0001).to(b.device)
b.edata['__ID'] = e_idx
col = row
col_len += eid_len[r]
...@@ -348,9 +318,6 @@ def to_reversed_block(graph,data, sample_out,device = torch.device('cuda'),uniqu
if identity is False:
assert len(sample_out) == 1
ret = sample_out[0]
- eid_len = ret.eid().shape[0]
- t0 = time.time()
- dst_ts = ret.sample_nodes_ts().to(device)
dst = nid_mapper[ret.sample_nodes()].to(device)
dist_eid = torch.tensor([],dtype=torch.long,device=device)
src_index = ret.src_index().to(device)
......
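Besides the debug-code removal, `to_block` now calls `.contiguous()` after gathering columns of `block_node_list`. A tiny, self-contained illustration of the idea (shapes made up): `.contiguous()` returns the tensor unchanged if the gather already produced a compact layout, and otherwise copies it into one, which is a cheap way to guarantee a dense row-major layout before the tensor is indexed again or handed to code that assumes one.

```python
import torch

# Made-up stand-ins for block_node_list / first_index from the hunk above.
block_node_list = torch.arange(20).reshape(2, 10)
first_index = torch.tensor([0, 3, 7])

picked = block_node_list[:, first_index]
print(picked.is_contiguous())        # layout depends on how the gather was performed
picked = picked.contiguous()         # guarantees a compact row-major layout either way
print(picked.is_contiguous())        # True
```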
...@@ -33,11 +33,11 @@ class time_count:
def start_gpu():
# Uncomment for better breakdown timings
#torch.cuda.synchronize()
- #start_event = torch.cuda.Event(enable_timing=True)
+ start_event = torch.cuda.Event(enable_timing=True)
- #end_event = torch.cuda.Event(enable_timing=True)
+ end_event = torch.cuda.Event(enable_timing=True)
- #start_event.record()
+ start_event.record()
- #return start_event,end_event
+ return start_event,end_event
- return 0,0
+ #return 0,0
@staticmethod
def start():
# return time.perf_counter(),0
......
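The last hunk re-enables CUDA event based timing in `time_count.start_gpu`. For context, this is the usual way such an event pair is consumed later on; the snippet below is a self-contained illustration, not the project's own readout helper:

```python
import torch

if torch.cuda.is_available():
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()                                   # mark start on the current stream
    torch.relu(torch.randn(1024, 1024, device='cuda'))     # stand-in for the timed work
    end_event.record()                                     # mark end on the same stream
    torch.cuda.synchronize()                               # wait so elapsed_time is valid
    print(f"{start_event.elapsed_time(end_event):.3f} ms") # elapsed GPU time in milliseconds
```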