Commit 66542b28 by zlj

fix useless code in memroy select

parent 05960e1d
...@@ -317,20 +317,14 @@ class DistributedDataLoader: ...@@ -317,20 +317,14 @@ class DistributedDataLoader:
edge_feat = get_edge_feature_by_dist(self.graph,dist_eid,is_local,out_device=self.device) edge_feat = get_edge_feature_by_dist(self.graph,dist_eid,is_local,out_device=self.device)
node_feat,mem = get_node_feature_by_dist(self.graph,self.mailbox,dist_nid, is_local,out_device=self.device) node_feat,mem = get_node_feature_by_dist(self.graph,self.mailbox,dist_nid, is_local,out_device=self.device)
prepare_input(node_feat,edge_feat,mem,mfgs,dist_nid,dist_eid) prepare_input(node_feat,edge_feat,mem,mfgs,dist_nid,dist_eid)
batch_data[1][0][0].srcdata['his_mem'] = batch_data[1][0][0].srcdata['mem'].clone() if(self.mailbox is not None and self.mailbox.historical_cache is not None):
batch_data[1][0][0].srcdata['his_ts'] = batch_data[1][0][0].srcdata['mail_ts'].clone() id = batch_data[1][0][0].srcdata['ID']
mask = DistIndex(batch_data[1][0][0].srcdata['ID']).is_shared mask = DistIndex(id).is_shared
indx = self.mailbox.is_shared_mask[DistIndex(batch_data[1][0][0].srcdata['ID']).loc[mask]] batch_data[1][0][0].srcdata['his_mem'] = batch_data[1][0][0].srcdata['mem'].clone()
batch_data[1][0][0].srcdata['his_mem'][mask] = self.mailbox.historical_cache.local_historical_data[indx] batch_data[1][0][0].srcdata['his_ts'] = batch_data[1][0][0].srcdata['mail_ts'].clone()
batch_data[1][0][0].srcdata['his_ts'][mask] = self.mailbox.historical_cache.local_ts[indx].reshape(-1,1) indx = self.mailbox.is_shared_mask[DistIndex(batch_data[1][0][0].srcdata['ID']).loc[mask]]
t_fetch = tt.elapsed_event(t1) batch_data[1][0][0].srcdata['his_mem'][mask] = self.mailbox.historical_cache.local_historical_data[indx]
tt.time_memory_fetch += t_fetch batch_data[1][0][0].srcdata['his_ts'][mask] = self.mailbox.historical_cache.local_ts[indx].reshape(-1,1)
#if(self.mailbox is not None and self.mailbox.historical_cache is not None):
# id = batch_data[1][0][0].srcdata['ID']
# mask = DistIndex(id).is_shared
# a = DistIndex(id[mask]).loc
# b = self.mailbox.historical_cache.local_historical_data[a]
#batch_data[1][0][0].srcdata['mem'][mask] = self.mailbox.historical_cache.local_historical_data[DistIndex(id[mask]).loc]
self.recv_idxs += 1 self.recv_idxs += 1
assert batch_data is not None assert batch_data is not None
return root,mfgs,metadata return root,mfgs,metadata
...@@ -352,8 +346,9 @@ class DistributedDataLoader: ...@@ -352,8 +346,9 @@ class DistributedDataLoader:
edge_feat = get_edge_feature_by_dist(self.graph,dist_eid,is_local,out_device=self.device) edge_feat = get_edge_feature_by_dist(self.graph,dist_eid,is_local,out_device=self.device)
node_feat,mem = get_node_feature_by_dist(self.graph,self.mailbox,dist_nid, is_local,out_device=self.device) node_feat,mem = get_node_feature_by_dist(self.graph,self.mailbox,dist_nid, is_local,out_device=self.device)
prepare_input(node_feat,edge_feat,mem,batch_data[1],dist_nid,dist_eid) prepare_input(node_feat,edge_feat,mem,batch_data[1],dist_nid,dist_eid)
batch_data[1][0][0].srcdata['his_mem'] = batch_data[1][0][0].srcdata['mem'].clone() if(self.mailbox is not None and self.mailbox.historical_cache is not None):
batch_data[1][0][0].srcdata['his_ts'] = batch_data[1][0][0].srcdata['mail_ts'].clone() batch_data[1][0][0].srcdata['his_mem'] = batch_data[1][0][0].srcdata['mem'].clone()
batch_data[1][0][0].srcdata['his_ts'] = batch_data[1][0][0].srcdata['mail_ts'].clone()
#if(self.mailbox is not None and self.mailbox.historical_cache is not None): #if(self.mailbox is not None and self.mailbox.historical_cache is not None):
# id = batch_data[1][0][0].srcdata['ID'] # id = batch_data[1][0][0].srcdata['ID']
...@@ -414,16 +409,7 @@ class DistributedDataLoader: ...@@ -414,16 +409,7 @@ class DistributedDataLoader:
batch_data[1][0][0].srcdata['his_ts'] = batch_data[1][0][0].srcdata['mail_ts'].clone() batch_data[1][0][0].srcdata['his_ts'] = batch_data[1][0][0].srcdata['mail_ts'].clone()
indx = self.mailbox.is_shared_mask[DistIndex(batch_data[1][0][0].srcdata['ID']).loc[mask]] indx = self.mailbox.is_shared_mask[DistIndex(batch_data[1][0][0].srcdata['ID']).loc[mask]]
batch_data[1][0][0].srcdata['his_mem'][mask] = self.mailbox.historical_cache.local_historical_data[indx] batch_data[1][0][0].srcdata['his_mem'][mask] = self.mailbox.historical_cache.local_historical_data[indx]
batch_data[1][0][0].srcdata['his_ts'][mask] = self.mailbox.historical_cache.local_ts[indx].reshape(-1,1)#self.mailbox.node_memory.accessor.data[DistIndex(id).loc[mask]] batch_data[1][0][0].srcdata['his_ts'][mask] = self.mailbox.historical_cache.local_ts[indx].reshape(-1,1)
#his_mem = torch.clone(batch_data[1][0][0].srcdata['mem'])
#his_ts = self.mailbox.historical_cache.local_ts[DistIndex(id[mask]).loc]
#maxer = his_ts>batch_data[1][0][0].srcdata['mem_ts'][mask]
#his_mem[mask&maxer] = self.mailbox.historical_cache.local_historical_data[DistIndex(id[mask]).loc]
#print((mask&maxer).sum(),mask.sum())
#batch_data[1][0][0].srcdata['mem'] = his_mem
#batch_data[1][0][0].srcdata['mem_ts'][mask & maxer] = his_ts
#batch_data[1][0][0].srcdata['his_mem'] = his_mem
#batch_data[1][0][0].srcdata['his_mem'][mask] = self.mailbox.historical_cache.local_historical_data[DistIndex(id).loc[mask]]#batch_data[1][0][0].srcdata['his_mem'] = self.mailbox.historical_cache.local_historical_data[DistIndex(id).loc]
self.recv_idxs += 1 self.recv_idxs += 1
else: else:
raise StopIteration raise StopIteration
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment