Commit e3346a38 by zljJoan

20230204

parent faec7932
class BatchData:
    """Container for one sampled mini-batch of a partitioned graph.

    Args:
        batch_size: the number of sampled root nodes.
        nids: the real (global) ids of sampled nodes and their neighbours.
        edge_index: the edges between nids; node ids are indices into this batch.
        roots: global ids of the batch's root nodes.
        x: nodes' features.
        y: nodes' labels.
        eids: the edge ids in the subgraph of edge_index.
        train_mask / val_mask / test_mask: optional split masks.
    """

    def __init__(self, batch_size, nids, edge_index, roots=None, x=None, y=None,
                 eids=None, train_mask=None, val_mask=None, test_mask=None):
        self.batch_size = batch_size
        self.roots = roots
        self.nids = nids
        self.edge_index = edge_index
        self.x = x
        self.y = y
        self.eids = eids
        self.train_mask = train_mask
        self.val_mask = val_mask
        self.test_mask = test_mask
import argparse
import os
import time
from threading import Lock
import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.distributed.optim import DistributedOptimizer
from torchvision import datasets, transforms
# --------- MNIST Network to train, from pytorch/examples -----
class Net(nn.Module):
    """MNIST classifier (from pytorch/examples), optionally split over two GPUs.

    With ``num_gpus >= 1`` the conv layers live on ``cuda:0``; with
    ``num_gpus > 1`` the dropout/linear tail moves to ``cuda:1``.
    On CPU-only machines everything stays on the CPU.
    """

    def __init__(self, num_gpus=0):
        super(Net, self).__init__()
        print(f"Using {num_gpus} GPUs to train")
        self.num_gpus = num_gpus
        use_cuda = torch.cuda.is_available() and self.num_gpus > 0
        device = torch.device("cuda:0" if use_cuda else "cpu")
        print(f"Putting first 2 convs on {str(device)}")
        # Conv layers go on the first cuda device, or CPU if no cuda device.
        self.conv1 = nn.Conv2d(1, 32, 3, 1).to(device)
        self.conv2 = nn.Conv2d(32, 64, 3, 1).to(device)
        # The rest of the network goes on the 2nd cuda device, if there is one.
        if "cuda" in str(device) and num_gpus > 1:
            device = torch.device("cuda:1")
        print(f"Putting rest of layers on {str(device)}")
        self.dropout1 = nn.Dropout2d(0.25).to(device)
        self.dropout2 = nn.Dropout2d(0.5).to(device)
        self.fc1 = nn.Linear(9216, 128).to(device)
        self.fc2 = nn.Linear(128, 10).to(device)

    def forward(self, x):
        """Run the network; returns per-class log-probabilities."""
        h = F.relu(self.conv1(x))
        h = F.max_pool2d(self.conv2(h), 2)
        h = torch.flatten(self.dropout1(h), 1)
        # Hop to the device hosting the tail of the network, if it differs.
        h = h.to(next(self.fc1.parameters()).device)
        h = self.dropout2(F.relu(self.fc1(h)))
        return F.log_softmax(self.fc2(h), dim=1)
\ No newline at end of file
import numpy as np
from message_worker import message_worker
from message_worker import sampler_worker
import torch
import torch.distributed.rpc as rpc
import torch.optim as optim
from torch.distributed.rpc import RRef, rpc_async, remote
from torch.distributions import Categorical
from Sample.Sampler import NeighborSampler
#from Sample.Sampler import NeighborSampler
class Agent:
......@@ -21,7 +20,7 @@ class Agent:
for ob_rank in range(worker_size):
ob_info = rpc.get_worker_info(OBSERVER_NAME.format(ob_rank))
print(OBSERVER_NAME.format(ob_rank),(worker_size,OBSERVER_NAME))
self.ob_rrefs.append(remote(ob_info, message_worker,args=(worker_size,OBSERVER_NAME)))
self.ob_rrefs.append(remote(ob_info, sampler_worker,args=(worker_size,OBSERVER_NAME)))
self.rewards[ob_info.id] = []
self.saved_log_probs[ob_info.id] = []
......
......@@ -3,91 +3,209 @@ import argparse
from torch.distributed.rpc import RRef, rpc_async, remote
import torch.distributed.rpc as rpc
import torch
from part.Utils import GraphData as DistGraph
from part.Utils import GraphData
from Sample.Sampler import NeighborSampler
import time
from typing import Optional
import torch.distributed as dist
class message_worker:
from BatchData import BatchData
OBSERVER_NAME='server{}'
sampler_workers={}
distributed_loader_list={}
local_sample_worker={}
class sampler_worker:
    """RPC worker that owns one graph partition and serves remote feature lookups.

    One instance lives on each RPC worker. Rank 0 additionally creates the
    peer instances (via ``remote``) and distributes the resulting RRef list
    so that every worker can reach every other worker.

    Args:
        worker_size: total number of workers in the RPC group.
        OBSERVER_NAME: worker-name format string, e.g. ``'server{}'``.
    """

    def __init__(self, worker_size, OBSERVER_NAME):
        self.id = rpc.get_worker_info().id
        # Rank equals the RPC worker id (the old `id - 1` mapping was dropped).
        self.rank = self.id
        self.worker_size = worker_size

    # rpc demo (leftover from the torch RPC reinforcement-learning example)
    def run_episode(self, agent_rref):
        # NOTE(review): self.env is never assigned anywhere in this class, so
        # this method would raise AttributeError if called — dead demo code.
        state, ep_reward = self.env.reset(), 0
        for _ in range(10000):
            # send the state to the agent to get an action
            action = agent_rref.rpc_sync().select_action(self.id, state)
            # apply the action to the environment, and get the reward
            state, reward, done, _ = self.env.step(action)
            # report the reward to the agent for training purpose
            agent_rref.rpc_sync().report_reward(self.id, reward)
            # finishes after the number of self.env._max_episode_steps
            if done:
                break

    def loadGraph(self, path):
        """Load this rank's partition; on rank 0, spawn and wire up all peers."""
        self.graph = DistGraph(path + '/rank_{}'.format(self.rank))
        sampler_workers[self.rank] = self
        if self.rank == 0:
            futs = []
            self.ob_rrefs = [RRef(self)]
            # BUG FIX: the original referenced a bare `worker_size` below,
            # which is not in scope inside this method (NameError at runtime).
            for rank in range(1, self.worker_size):
                ob_info = rpc.get_worker_info(OBSERVER_NAME.format(rank))
                self.ob_rrefs.append(remote(ob_info, sampler_worker,
                                            args=(self.worker_size, OBSERVER_NAME)))
            # Push the complete RRef list to every peer so all-to-all
            # attribute requests become possible.
            for rank in range(1, self.worker_size):
                ob_rref = self.ob_rrefs[rank]
                futs.append(
                    rpc_async(
                        ob_rref.owner(),
                        ob_rref.rpc_sync().init_rref,
                        args=(self.ob_rrefs,)
                    )
                )
            for fut in futs:
                fut.wait()

    def init_rref(self, rrefs):
        """Receive the full RRef list from rank 0 and register this worker locally."""
        self.ob_rrefs = rrefs
        local_sample_worker[self.rank] = self

    def request_attr(self, data_name, node_feature):
        """Fetch feature rows from every *remote* partition.

        Args:
            data_name: key identifying the DistributedDataLoader on each peer.
            node_feature: per-rank list of node-id tensors (index == rank).
        Returns:
            List of feature tensors from the other ranks, in rank order
            (this rank's own entry is skipped).
        """
        print('request', self.rank, node_feature)
        futs = []
        feature_part = []
        ob_rrefs = self.ob_rrefs
        for i in range(self.worker_size):
            if i == self.rank:
                continue
            futs.append(
                rpc_async(
                    ob_rrefs[i].owner(),
                    ob_rrefs[i].rpc_sync().get_localattr,
                    args=(data_name, node_feature[i],)
                )
            )
        for fut in futs:
            feature_part.append(fut.wait())
        print('request end')
        return feature_part

    def get_localattr(self, data_name, node_list):
        """Serve a peer's request: look up features for `node_list` in the local
        DistributedDataLoader registered under (rank, data_name)."""
        print('local', self.rank, node_list)
        for key in distributed_loader_list:
            print('key', key)
        graph_loader = distributed_loader_list[(self.rank, data_name)]
        print('local end')
        return graph_loader.get_localattr(node_list)
# Command-line options shared by all worker processes.
parser = argparse.ArgumentParser(
    description="RPC Reinforcement Learning Example",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--world_size', default=4, type=int, metavar='W',
                    help='number of workers')
parser.add_argument('--log_interval', type=int, default=10, metavar='N',
                    help='interval between training status logs')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                    help='how much to value future rewards')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed for reproducibility')
args = parser.parse_args()
def init_sampler_worker():
    """Create the rank-0 sampler_worker and register it in local_sample_worker."""
    local_sample_worker[0] = sampler_worker(args.world_size, OBSERVER_NAME)
class DistributedDataLoader:
    """Distributed mini-batch loader over a partitioned graph.

    Each process loads its own partition and registers itself in the
    module-level ``distributed_loader_list`` so peer workers can serve
    cross-partition feature lookups.

    Args:
        data_name: key identifying this dataset across workers.
        data_path: path of the partitioned graph; partition i is saved at
            ``<data_path>/rank_i``.
        sampler: neighbour sampler providing ``sample_from_nodes``.
        num_samples: number of root nodes drawn per batch.
        shuffle, seed: stored but not yet used by the sampling path.
        **kwargs: sampler options; ``num_neighbors`` is expected by
            ``sampleGraph``.
    """

    def __init__(self, data_name, data_path, sampler, num_samples: Optional[int] = None,
                 shuffle: bool = True, seed: int = 0, **kwargs):
        self.data_name = data_name
        self.num_samples = num_samples
        self.sampler = sampler
        self.epoch = 0
        self.id = rpc.get_worker_info().id
        self.rank = self.id
        self.worker_size = args.world_size
        print('load graph', data_path + '/rank_{}'.format(self.rank))
        self.graph = GraphData(data_path + '/rank_{}'.format(self.rank))
        self.shuffle = shuffle
        self.seed = seed
        self.kwargs = kwargs
        distributed_loader_list[(self.rank, self.data_name)] = self
        # All ranks must have registered their loader before anyone samples.
        dist.barrier()

    def __next__(self):
        print(self.sampler, self.num_samples, self.kwargs)
        return self.sampleGraph(self.sampler, self.num_samples, **self.kwargs)

    def set_epoch(self, epoch):
        self.epoch = epoch

    def split_node_part(self, node_id_list):
        """Split global node ids by owning partition.

        Returns:
            part_list: per-rank tensors of the ids owned by that rank.
            scatter_index: positions of those ids in `node_id_list`
                (concatenated in rank order), used to scatter fetched
                features back into original order.
        """
        # BUG FIX: the original (a merged diff) kept a stray
        # `return part_list` inside the loop, making the scatter-index
        # code unreachable after the first iteration.
        part_list = []
        scatter_list = []
        for rank in range(self.worker_size):
            group_id = (self.graph.partptr[rank] <= node_id_list) & \
                       (node_id_list < self.graph.partptr[rank + 1])
            part_list.append(torch.masked_select(node_id_list, group_id))
            scatter_list.append(torch.nonzero(group_id))
        scatter_index = torch.cat(scatter_list)
        return part_list, scatter_index

    def get_random_root(self, num_samples):
        """Draw `num_samples` random local nodes and map them to global ids."""
        print(num_samples)
        return self.graph.get_globalId_by_partitionId(
            self.rank,
            torch.randint(high=self.graph.get_part_num(), size=(num_samples,)))

    def sampleGraph(self, sampler, num_samples, **kwargs):
        """Sample one batch: roots -> neighbours -> features -> labels.

        Returns a BatchData holding the sampled subgraph, re-ordered node
        features `x`, and labels `y` for the root nodes.
        """
        print('sample', num_samples)
        t1 = time.time()
        root_list = self.get_random_root(num_samples)
        print('root_list', root_list)
        t2 = time.time()
        # NOTE(review): num_neighbors is required below; calling without it
        # raises NameError — TODO make it an explicit parameter.
        if 'num_neighbors' in kwargs:
            num_neighbors = kwargs.get('num_neighbors')
        neighbors, edge_index = sampler.sample_from_nodes(
            root_list, self.graph.edge_index, self.graph.get_node_num(), num_neighbors)
        t3 = time.time()
        part_feature_list, scatter_list = self.split_node_part(neighbors)
        t4 = time.time()
        node_feature = self.get_attr(part_feature_list)
        t5 = time.time()
        nids = neighbors
        # Scatter the rank-ordered feature rows back into `nids` order.
        attr_size = node_feature.size()
        x = torch.zeros(attr_size)
        x = x.scatter(0, scatter_list.repeat(1, attr_size[1]), node_feature)
        print(nids)
        print(root_list)
        local_y_id = self.graph.get_localId_by_partitionId(self.rank, root_list)
        y = self.graph.select_y(local_y_id)
        print(t2 - t1, t3 - t2, t4 - t3, t5 - t4, time.time() - t4)
        return BatchData(num_samples, nids, edge_index, roots=root_list, x=x, y=y)

    def get_attr(self, node_part_list):
        """Gather feature rows for every partition, remote ones via RPC.

        `node_part_list` is the per-rank id list from split_node_part; the
        result is the concatenation of all ranks' feature tensors in rank
        order (remote results arrive without this rank's slot, hence the
        index shift for i > rank).
        """
        # The old futs-based implementation was removed: remote fetches now
        # go through the registered sampler_worker's request_attr.
        worker = local_sample_worker[self.rank]
        part_feature = worker.request_attr(self.data_name, node_part_list)
        local_feature = self.get_localattr(node_part_list[self.rank])
        node_feature_list = []
        for i in range(self.worker_size):
            if i < self.rank:
                node_feature_list.append(part_feature[i])
            elif i > self.rank:
                node_feature_list.append(part_feature[i - 1])
            else:
                node_feature_list.append(local_feature)
        return torch.cat(node_feature_list)

    def get_localattr(self, node_list):
        """Translate global ids to local ids and select their feature rows."""
        local_id = self.graph.get_localId_by_partitionId(self.rank, node_list)
        return self.graph.select_attr(local_id)
if __name__ == "__main__":
    # Standalone smoke test: one worker over a 2-way metis partition.
    # BUG FIX: removed the duplicated old line `worker=message_worker(2,"observe")`
    # left over from a merged diff; only the sampler_worker instance is used.
    worker = sampler_worker(2, "observe")
    worker.loadGraph('./part/metis_2')
    sampler = NeighborSampler()
    # NOTE(review): sampler_worker defines no `sampler` method — this call
    # looks like residue of an older API and would raise AttributeError.
    worker.sampler([], sampler, 5, num_neighbors=3)
\ No newline at end of file
......@@ -27,7 +27,8 @@ class GraphData():
def select_attr(self, index):
    """Gather the feature rows of ``self.data.x`` at the given local indices."""
    return self.data.x.index_select(0, index)
def select_y(self, index):
    """Gather the label entries of ``self.data.y`` at the given local indices."""
    return self.data.y.index_select(0, index)
#返回全局的节点id 所对应的分区
def get_localId_by_partitionId(self,id,index):
#print(index)
......
import argparse
from multiprocessing import freeze_support
import torch.distributed as dist
import os
from itertools import count
import torch.multiprocessing as mp
......@@ -9,33 +9,18 @@ import torch.optim as optim
from torch.distributed.rpc import RRef, rpc_async, remote
from torch.distributions import Categorical
import torch.multiprocessing as mp
from message_worker import DistributedDataLoader, sampler_worker
from RpcAgent import Agent
AGENT_NAME='server'
OBSERVER_NAME='obs{}'
def init_worker(rank, world_size):
    """Entry point for one spawned process.

    Rank 0 hosts the Agent (which drives sampling); every other rank joins
    as an observer. The RPC world is world_size + 1 because the agent
    occupies an extra slot.
    """
    print(rank, world_size)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29000'
    if rank == 0:
        rpc.init_rpc(AGENT_NAME, rank=rank, world_size=world_size + 1)
        agent = Agent(world_size, OBSERVER_NAME)
        agent.loadGraph('./part/metis_2')
        for _ in range(10):
            agent.run_sample(5, 3)
    else:
        print(OBSERVER_NAME.format(rank - 1))
        rpc.init_rpc(OBSERVER_NAME.format(rank - 1), rank=rank, world_size=world_size + 1)
    rpc.shutdown()
from Sample.Sampler import NeighborSampler
import message_worker
OBSERVER_NAME = 'server{}'

# Command-line options for the launcher.
# BUG FIX: a merged diff left both the old `--world_size` line
# (default=2) and the new one (default=4); only the new line is kept.
parser = argparse.ArgumentParser(
    description="RPC Reinforcement Learning Example",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--world_size', default=4, type=int, metavar='W',
                    help='number of workers')
parser.add_argument('--log_interval', type=int, default=10, metavar='N',
                    help='interval between training status logs')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                    help='how much to value future rewards')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed for reproducibility')
args = parser.parse_args()
def init_distribution(rank, world_size):
    """Initialise both the gloo process group and the RPC framework for this rank.

    Rank 0 additionally creates the bootstrap sampler_worker; the barrier
    makes every rank wait until that registration has happened.
    """
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29000'
    dist.init_process_group(backend="gloo", world_size=world_size, rank=rank)
    rpc.init_rpc(OBSERVER_NAME.format(rank), rank=rank, world_size=world_size)
    if rank == 0:
        message_worker.init_sampler_worker()
    dist.barrier()
    print(OBSERVER_NAME.format(rank))
def close_distribution():
    """Tear down the RPC framework (blocks until all workers agree to stop)."""
    rpc.shutdown()
def main(rank, world_size=2):
    """Per-process entry: join the group, draw one sampled batch, shut down."""
    init_distribution(rank, world_size)
    sampler = NeighborSampler()
    # NOTE(review): metis_4 implies a 4-way partition — callers should pass
    # world_size=4 to match (the spawn block does).
    loader = DistributedDataLoader('test', './part/metis_4', sampler,
                                   num_samples=100, num_neighbors=30)
    print(next(loader))
    close_distribution()
if __name__ == '__main__':
    # Launch one process per partition; nprocs must match the 4-way
    # metis_4 split used inside main().
    # BUG FIX: a merged diff left both the old spawn target (init_worker,
    # args.world_size, nprocs=args.world_size+1) and the new one; only the
    # new invocation is kept.
    mp.spawn(
        main,
        args=(4, ),
        nprocs=4,
        join=True
    )
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment