Commit 802c3ca2 by xxx

first commit

parent 34f10818
@@ -136,7 +136,7 @@ def load_from_shared_node_partition(data,node_i,shared_node,sample_add_rev = Tru
     test_mask = (torch.from_numpy(np.array(df.ext_roll.values)) == 2)
     num_node = max(src.max().item(),dst.max().item())+1
     dim_feats = [0, 0, 0, 0, 0, 0, 0, 0]
-    if dist.get_rank() == 0:
+    if ctx.local_rank == 0:
         _node_feats, _edge_feats = load_feat(data, num_node, src.shape[0], 172, 172)
         if _node_feats is not None:
             dim_feats[0] = _node_feats.shape[0]
@@ -165,7 +165,7 @@ def load_from_shared_node_partition(data,node_i,shared_node,sample_add_rev = Tru
     torch.distributed.broadcast_object_list(dim_feats, src=0)
     print('dist rank is {} after node feats defination:'.format(torch.distributed.get_rank()))
     print()
-    if dist.get_rank() > 0:
+    if ctx.local_rank > 0:
         node_feats = None
         edge_feats = None
         if dim_feats[6] == 1:
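The hunks switch the load guard from the global dist.get_rank() to ctx.local_rank, presumably so that one process per node (rather than only the whole job's rank 0) loads the shared features, while the 8-slot dim_feats metadata list is still synchronized across all ranks with torch.distributed.broadcast_object_list. A minimal sketch of that load-on-one-rank / broadcast-the-shapes pattern, with the project-specific ctx object and load_feat helper stubbed out (the random tensor and the dimension slots used below are illustrative assumptions, not the repo's actual code):

import torch
import torch.distributed as dist

def load_dims_and_broadcast():
    # Assumes dist.init_process_group(...) has already been called.
    # Pre-size the metadata slots with zeros on every rank, mirroring
    # the 8-slot dim_feats list in the commit.
    dim_feats = [0] * 8
    if dist.get_rank() == 0:
        # Only one rank reads the (large) feature files; stubbed here
        # with a random tensor in place of the repo's load_feat(...).
        node_feats = torch.randn(1000, 172)
        dim_feats[0] = node_feats.shape[0]
        dim_feats[1] = node_feats.shape[1]
    # broadcast_object_list overwrites the list elements in place on
    # every non-source rank, so after this call all ranks agree on the
    # feature dimensions without shipping the tensors themselves.
    dist.broadcast_object_list(dim_feats, src=0)
    return dim_feats

Note that broadcast_object_list pickles its payload, which is fine for a small list of integers like this but would be the wrong tool for broadcasting the feature tensors themselves.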