Commit faec7932 by zljJoan

20230204

parent 3e80e491
......@@ -63,6 +63,7 @@ class message_worker:
print(t2-t1,t3-t2,t4-t3,time.time()-t4)
def get_attr(self,node_feature,ob_rrefs):
local_feature=self.get_localattr(node_feature[self.rank])
futs=[]
part=0
for i in range(self.worker_size):
......@@ -78,39 +79,7 @@ class message_worker:
for fut in futs:
feature_part=fut.wait()
local_feature=self.get_localattr(node_feature[self.rank])
''''
#get remote attr
def get_list_attr(self,node_list,ob_rrefs):
feature=torch.zeros(node_list,self.graph.get_attr_size())
sample_info={}
for ob_rank in range(1,self.world_size):
sample_info[ob_rank]=[]
futs=[]
node_feature={}
for i in range(node_list.size()):
for nid in range(node_list[i]):
part_id=self.graph.get_nodepart(nid)
if(self.id!=part_id):
sample_info[part_id].append(nid)
else:
node_feature[nid]=self.graph.get_nodeattr(nid)
for ob_rank in range(1,self.world_size):
if(ob_rank != self.id and sample_info[ob_rank]):
sample_info=set(sample_info[ob_rank]).
futs.append(
rpc_async(
ob_rrefs[ob_rank-1].owner(),
ob_rrefs[ob_rank-1].rpc_sync().get_localattr,
args=(sample_info[ob_rank])
)
)
for fut in futs:
feature_part=fut.wait()
for f in feature_part:
node_feature[f[0]]=f[1:]
return
'''
def get_localattr(self,node_list):
local_id=self.graph.get_localId_by_partitionId(self.rank,node_list)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment