Commit 05e596e1 by zlj

first commit

import numpy as np
from message_worker import message_worker
import torch
import torch.distributed.rpc as rpc
import torch.optim as optim
from torch.distributed.rpc import RRef, rpc_async, remote
from torch.distributions import Categorical
from Sample.Sampler import NeighborSampler
class Agent:
def __init__(self, worker_size,OBSERVER_NAME):
self.ob_rrefs = []
self.agent_rref = RRef(self)
self.rewards = {}
self.saved_log_probs = {}
self.worker_size=worker_size
for ob_rank in range(worker_size):
ob_info = rpc.get_worker_info(OBSERVER_NAME.format(ob_rank))
print(OBSERVER_NAME.format(ob_rank),(worker_size,OBSERVER_NAME))
self.ob_rrefs.append(remote(ob_info, message_worker,args=(worker_size,OBSERVER_NAME)))
self.rewards[ob_info.id] = []
self.saved_log_probs[ob_info.id] = []
    def run_sample(self, num_samples, num_neighbors):
        futs = []
        sampler = NeighborSampler()
        for ob_rref in self.ob_rrefs:
            futs.append(
                rpc_async(
                    ob_rref.owner(),
                    ob_rref.rpc_sync().sampler,
                    args=(self.ob_rrefs, sampler, num_samples),
                    kwargs={"num_neighbors": num_neighbors}
                )
            )
for fut in futs:
fut.wait()
def loadGraph(self,path):
        print(path)
futs=[]
for ob_rref in self.ob_rrefs:
futs.append(
rpc_async(
ob_rref.owner(),
ob_rref.rpc_sync().loadGraph,
args=(path,)
)
)
for fut in futs:
fut.wait()
from typing import Tuple
import torch
import torch_scatter
import torch.multiprocessing as mp
from abc import ABC
class BaseSampler(ABC):
r"""An abstract base class that initializes a graph sampler and provides
:meth:`sample_from_nodes` and :meth:`sample_from_edges` routines.
"""
    def sample_from_node(
        self,
        node: int,
        edge_index: torch.Tensor,
        num_nodes: int,
        **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Performs sampling from the single node specified in :obj:`node`,
        returning the sampled subgraph as a Tuple[torch.Tensor, torch.Tensor].
        Args:
            node: the seed node index
            edge_index: edges in the graph
            num_nodes: the total number of nodes in the graph
            **kwargs: other kwargs
        Returns:
            samples_nodes: the sampled nodes
            edge_index: the sampled edges
        """
raise NotImplementedError
def sample_from_nodes(
self,
nodes:torch.Tensor,
edge_index:torch.Tensor,
num_nodes:int,
**kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Performs sampling from the nodes specified in :obj:`nodes`,
        returning the sampled subgraph as a Tuple[torch.Tensor, torch.Tensor].
        Args:
            nodes: the seed node indices
            edge_index: edges in the graph
            num_nodes: the total number of nodes in the graph
            **kwargs: other kwargs
        Returns:
            samples_nodes: the sampled nodes
            edge_index: the sampled edges
        """
raise NotImplementedError
class NeighborSampler(BaseSampler):
def __init__(self) -> None:
super().__init__()
def sample_from_node(
self,
node: int,
edge_index: torch.Tensor,
num_nodes: int,
fanout: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Performs sampling from the single node specified in :obj:`node`,
        returning the sampled subgraph as a Tuple[torch.Tensor, torch.Tensor].
        Args:
            node: the seed node index
            edge_index: edges in the graph
            num_nodes: the total number of nodes in the graph
            fanout: the maximum number of neighbors chosen
        Returns:
            samples_nodes: the sampled nodes
            edge_index: the sampled edges
        """
row, col = edge_index
deg = torch_scatter.scatter_add(torch.ones_like(row), row, dim=0, dim_size=num_nodes)
neighbors=torch.stack([row[row==node],col[row==node]],dim=0)
print('neighbors: \n', neighbors)
if deg[node]<=fanout:
return torch.unique(neighbors[1], dim=0), neighbors
else:
random_index = torch.multinomial(torch.ones(deg[node]), fanout, replacement=False)# torch.randperm(neighbors.shape[1])[0:fanout]
print("random_index:\n", random_index)
edge_index = neighbors.index_select(dim = 1, index=random_index)
samples_nodes = torch.unique(edge_index.view(-1), dim=0)
return samples_nodes, edge_index
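    # Worked example (toy numbers, not from the repo): with
    # edge_index = torch.tensor([[0, 0, 0], [1, 2, 3]]), num_nodes = 4 and
    # fanout = 2, node 0 has degree 3 > fanout, so two of the edges
    # (0,1), (0,2), (0,3) are kept, drawn uniformly without replacement via
    # torch.multinomial; the sampled node set is then the unique ids
    # appearing in the kept edges (including the seed).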
def sample_from_nodes(
self,
nodes: torch.Tensor,
edge_index: torch.Tensor,
num_nodes: int,
fanout: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Performs sampling from the nodes specified in :obj:`nodes`,
        returning the sampled subgraph as a Tuple[torch.Tensor, torch.Tensor].
        Args:
            nodes: the seed node indices
            edge_index: edges in the graph
            num_nodes: the total number of nodes in the graph
            fanout: the maximum number of neighbors chosen
        Returns:
            samples_nodes: the sampled nodes
            edge_index: the sampled edges
        """
        # use int64 accumulators so torch.cat matches the dtype of edge_index
        samples_nodes = torch.LongTensor([])
        row = torch.LongTensor([])
        col = torch.LongTensor([])
with mp.Pool(processes=torch.get_num_threads()) as p:
results = [p.apply_async(self.sample_from_node,
(node, edge_index, num_nodes, fanout))
for node in nodes]
for result in results:
samples_nodes_i, edge_index_i = result.get()
samples_nodes = torch.unique(torch.cat([samples_nodes, samples_nodes_i]))
row = torch.cat([row, edge_index_i[0]])
col = torch.cat([col, edge_index_i[1]])
        # Single-threaded loop alternative:
# for node in nodes:
# samples_nodes_i,edge_index_i = self.sample_from_node(node, edge_index, num_nodes, fanout)
# samples_nodes=torch.unique(torch.concat([samples_nodes,samples_nodes_i]))
# row=torch.concat([row,edge_index_i[0]])
# col=torch.concat([col,edge_index_i[1]])
return samples_nodes, torch.stack([row, col], dim=0)
    # Alternative that gathers all seed nodes' neighbors directly, without sample_from_node:
# row, col = edge_index
# neighbors1=torch.concat([row[row==nodes[i]] for i in range(0, nodes.shape[0])])
# neighbors2=torch.concat([col[row==nodes[i]] for i in range(0, nodes.shape[0])])
# neighbors=torch.stack([neighbors1, neighbors2], dim=0)
# print('neighbors: \n', neighbors)
import torch
from Sampler import NeighborSampler
edge_index = torch.tensor([[0, 1, 1, 2, 2, 2, 3], [1, 0, 2, 1, 3, 0, 2]])
num_nodes = 4
num_neighbors = 1
# Run the neighbor sampling
sampler=NeighborSampler()
# neighbors, edge_index = sampler.sample_from_node(2, edge_index, num_nodes, num_neighbors)
neighbors, edge_index = sampler.sample_from_nodes(torch.tensor([1,2]), edge_index, num_nodes, num_neighbors)
# Print the result
print('neighbor_nodes_id: \n',neighbors, '\nedge_index: \n',edge_index)
# import torch_scatter
# nodes=torch.Tensor([1,2])
# row, col = edge_index
# deg = torch_scatter.scatter_add(torch.ones_like(row), row, dim=0, dim_size=num_nodes)
# neighbors1=torch.concat([row[row==nodes[i]] for i in range(0, nodes.shape[0])])
# print(neighbors1)
# neighbors2=torch.concat([col[row==nodes[i]] for i in range(0, nodes.shape[0])])
# print(neighbors2)
# neighbors=torch.stack([neighbors1, neighbors2], dim=0)
# print('neighbors: \n', neighbors[0]==1)
import torch
class Graph:
def __init__(self,**kwargs):
        if('edge_index' in kwargs):
            self.edge_index = kwargs.get('edge_index')
        if('edge_attr' in kwargs):
            self.edge_attr = kwargs.get('edge_attr')
        if('node_attr' in kwargs):
            self.node_attr = kwargs.get('node_attr')
        if('y' in kwargs):
            self.y = kwargs.get('y')
'''
edge_index  (srcId, dstId); the slice of the full graph held here
nmap        real_id -> node_attr_id
emap        real_id -> edge_attr_id
npart       node_id -> part_id
epart       edge_id -> part_id
'''
class DistGraph(Graph):
def __init__(self,**kwargs):
        if('nmap' in kwargs):
            self.nmap = kwargs.get('nmap')
        if('emap' in kwargs):
            self.emap = kwargs.get('emap')
        if('npart' in kwargs):
            self.npart = kwargs.get('npart')
        if('epart' in kwargs):
            self.epart = kwargs.get('epart')
        if(kwargs):
            super().__init__(**kwargs)
def get_node_num(self):
return self.adj.size(0)
def load_graph(self,path):
print(" load graph ",path)
adj,data,partptr,perm =torch.load(path)
self.adj=adj
self.data=data #maybe nmap
self.partptr=partptr
self.perm=perm
def get_part_num(self):
return self.data.x.size()[0]
def select_attr(self,index):
return torch.index_select(self.data.x,0,index)
    # return the partition corresponding to a global node id
def get_part(self,index):
print(index)
return torch.index_select(self.partptr,0,index)
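# Minimal load sketch (assumption: a partition file stores the 4-tuple
# (adj, data, partptr, perm) that load_graph unpacks; the toy tensors and
# the /tmp path below are hypothetical, not the real METIS output):
#   adj = torch.tensor([[0, 1], [1, 0]])       # toy adjacency in COO form
#   partptr = torch.tensor([0, 1, 2])          # two partitions of one node each
#   perm = torch.tensor([0, 1])                # identity permutation
#   torch.save((adj, None, partptr, perm), '/tmp/toy_part')
#   g = DistGraph()
#   g.load_graph('/tmp/toy_part')
#   print(g.get_node_num())                    # -> 2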
import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from . import Net
from torch import optim
from torch.distributed.optim import DistributedOptimizer
from torchvision import datasets,transforms
def call_method(method,rref,*args,**kwargs):
return method(rref.local_value(),*args,**kwargs)
def remote_method(method,rref,*args,**kwargs):
args=[method,rref]+list(args)
return rpc.rpc_sync(rref.owner(),call_method,args=args,kwargs=kwargs)
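# Usage sketch (hedged): given `ps_rref = rpc.remote(ps_worker_info, ParameterServer,
# args=(num_gpus,))`, a trainer can invoke a method on the RRef's owner with:
#   out = remote_method(ParameterServer.forward, ps_rref, inp)
# remote_method packs the unbound method and the RRef into an rpc_sync call,
# and call_method re-binds the method on the owner via rref.local_value().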
class ParameterServer(nn.Module):
def __init__(self,num_gpus):
super().__init__()
self.model=Net(num_gpus)
self.input_device = torch.device("cuda:0" if torch.cuda.is_available() and num_gpus > 0 else "cpu")
def forward(self, inp):
inp = inp.to(self.input_device)
out = self.model(inp)
# This output is forwarded over RPC, which as of 1.5.0 only accepts CPU tensors.
# Tensors must be moved in and out of GPU memory due to this.
out = out.to("cpu")
return out
    def get_dist_gradients(self, cid):
        grads = dist_autograd.get_gradients(cid)
        # RPC only accepts CPU tensors (as of 1.5.0), so move gradients to CPU.
        cpu_grads = {}
        for k, v in grads.items():
            k_cpu, v_cpu = k.to("cpu"), v.to("cpu")
            cpu_grads[k_cpu] = v_cpu
        return cpu_grads
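# Training-loop sketch (hedged; follows the PyTorch RPC parameter-server
# tutorial this file mirrors -- `remote_param_rrefs`, `ps_rref`, `x` and `y`
# are assumed, not defined in this file):
#   opt = DistributedOptimizer(optim.SGD, remote_param_rrefs, lr=0.03)
#   with dist_autograd.context() as cid:
#       out = remote_method(ParameterServer.forward, ps_rref, x)
#       dist_autograd.backward(cid, [F.nll_loss(out, y)])
#       opt.step(cid)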
import sys
import argparse
from torch.distributed.rpc import RRef, rpc_async, remote
import torch.distributed.rpc as rpc
import torch
from part.Utils import GraphData as DistGraph
from Sample.Sampler import NeighborSampler
import time
class message_worker:
def __init__(self,worker_size,OBSERVER_NAME):
self.id = rpc.get_worker_info().id
self.rank=self.id-1
self.worker_size=worker_size
    # RPC demo kept from the PyTorch RL example (note: self.env is not defined in this class)
def run_episode(self, agent_rref):
state, ep_reward = self.env.reset(), 0
for _ in range(10000):
# send the state to the agent to get an action
action = agent_rref.rpc_sync().select_action(self.id, state)
# apply the action to the environment, and get the reward
state, reward, done, _ = self.env.step(action)
# report the reward to the agent for training purpose
agent_rref.rpc_sync().report_reward(self.id, reward)
# finishes after the number of self.env._max_episode_steps
if done:
break
def loadGraph(self,path):
self.graph=DistGraph(path+'/rank_{}'.format(self.rank))
#self.graph.load_graph(path+'/rank_{}'.format(self.rank))
    # split the requested node_id_list into per-partition node lists
def split_node_part(self,node_id_list):
part_list=[]
for rank in range(self.worker_size):
group_id=(self.graph.partptr[rank]<=node_id_list) & (node_id_list<self.graph.partptr[rank+1])
part_list.append(torch.masked_select(node_id_list,group_id))
return part_list
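    # Worked example (hypothetical numbers): with partptr = [0, 4, 8] and
    # node_id_list = tensor([1, 5, 2, 7]), rank 0 owns ids in [0, 4) and
    # rank 1 owns ids in [4, 8), so this returns
    # [tensor([1, 2]), tensor([5, 7])].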
def get_random_root(self,num_samples):
return self.graph.get_globalId_by_partitionId(self.rank,torch.randint(high=self.graph.get_part_num(),size=(num_samples,)))
def sampler(self,ob_rrefs,sampler,num_samples,**kwargs):
self.ob_rrefs=ob_rrefs
t1=time.time()
root_list=self.get_random_root(num_samples)
t2=time.time()
        # neighbor node list and edge list
if('num_neighbors' in kwargs):
num_neighbors=kwargs.get('num_neighbors')
neighbors=root_list
#print( self.graph.get_node_num(),num_neighbors)
neighbors, edge_index = sampler.sample_from_nodes(root_list,self.graph.edge_index, self.graph.get_node_num(),num_neighbors)
#print(neighbors)
t3=time.time()
part_feature_list=self.split_node_part(neighbors)
t4=time.time()
node_feature=self.get_attr(part_feature_list,self.ob_rrefs)
print(t2-t1,t3-t2,t4-t3,time.time()-t4)
def get_attr(self,node_feature,ob_rrefs):
futs=[]
part=0
for i in range(self.worker_size):
if(i==self.rank):
continue
futs.append(
rpc_async(
ob_rrefs[i].owner(),
ob_rrefs[i].rpc_sync().get_localattr,
args=(node_feature[i],)
)
)
for fut in futs:
feature_part=fut.wait()
local_feature=self.get_localattr(node_feature[self.rank])
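        # NOTE: the remote feature parts fetched above are not yet merged with
        # local_feature; assembling the full feature tensor remains to be done.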
    '''
#get remote attr
def get_list_attr(self,node_list,ob_rrefs):
feature=torch.zeros(node_list,self.graph.get_attr_size())
sample_info={}
for ob_rank in range(1,self.world_size):
sample_info[ob_rank]=[]
futs=[]
node_feature={}
for i in range(node_list.size()):
for nid in range(node_list[i]):
part_id=self.graph.get_nodepart(nid)
if(self.id!=part_id):
sample_info[part_id].append(nid)
else:
node_feature[nid]=self.graph.get_nodeattr(nid)
for ob_rank in range(1,self.world_size):
if(ob_rank != self.id and sample_info[ob_rank]):
                sample_info[ob_rank]=list(set(sample_info[ob_rank]))
futs.append(
rpc_async(
ob_rrefs[ob_rank-1].owner(),
ob_rrefs[ob_rank-1].rpc_sync().get_localattr,
args=(sample_info[ob_rank])
)
)
for fut in futs:
feature_part=fut.wait()
for f in feature_part:
node_feature[f[0]]=f[1:]
return
'''
def get_localattr(self,node_list):
local_id=self.graph.get_localId_by_partitionId(self.rank,node_list)
#print(self.rank,node_list,local_id)
return self.graph.select_attr(local_id)
if __name__=="__main__":
worker=message_worker(2,"observe")
worker.loadGraph('./part/metis_2')
sampler=NeighborSampler()
worker.sampler([],sampler,5,num_neighbors=3)
from mpi4py import MPI
import numpy
class mpi_worker:
    def __init__(self, worker_num):
        self.comm = MPI.COMM_WORLD
        self.rank = self.comm.Get_rank()
        self.worker_num = worker_num
def loadGraph(self,path):
self.graph=[]
def get_attr(self,node_list):
sample_info={}
broadcast_info={}
for rank in range(self.worker_num):
sample_info[rank]=[]
broadcast_info[rank]=[]
node_feature={}
for root in node_list:
part=self.graph.get_nodepart(root[0])
if(part==self.rank):
for nid in root[0,:]:
part0=self.graph.get_nodepart(nid)
sample_info[part0].append(nid)
else:
for nid in root[0,:]:
part0=self.graph.get_nodepart(nid)
if(part0==self.rank):
broadcast_info[part0].append(self.graph.get_nodeattr(nid))
for rank in range(self.worker_num):
data=self.comm.bcast(broadcast_info,rank)
if(sample_info[rank]):
for ndata in data:
nid=ndata[0]
if(nid in sample_info):
                        node_feature[nid] = ndata[1,:]
        return node_feature
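# Broadcast sketch (hedged; assumes an MPI launch such as
# `mpirun -n 2 python this_file.py`): comm.bcast returns, on every rank, the
# object passed in by the root rank, which is the collective get_attr relies on.
#   from mpi4py import MPI
#   comm = MPI.COMM_WORLD
#   data = comm.bcast({'from': comm.Get_rank()}, root=0)
#   # on every rank: data == {'from': 0}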
import os.path as osp
import torch
class GraphData():
def __init__(self, path):
        assert path is not None and osp.exists(path), 'path does not exist'
        id, edge_index, data, partptr = torch.load(path)
        # index of the current partition
        self.partition_id = id
        # total number of partitions
        self.partitions = partptr.numel() - 1
        # full-graph structure data
        self.num_nodes = partptr[self.partitions]
        self.num_edges = edge_index[0].numel()
        self.edge_index = edge_index
        # this partition's data (feature vectors plus subgraph structure), a pyg Data object
        self.data = data
        # partition mapping
        self.partptr = partptr
    def select_attr(self, index):
        return torch.index_select(self.data.x, 0, index)
    # number of nodes stored in this partition
    def get_part_num(self):
        return self.data.x.size()[0]
    # map global node ids to ids local to partition `id`
def get_localId_by_partitionId(self,id,index):
#print(index)
if(id == -1 or id == 0):
return index
else:
return torch.add(index,-self.partptr[id])
def get_globalId_by_partitionId(self,id,index):
if(id == -1 or id == 0):
return index
else:
return torch.add(index,self.partptr[id])
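    # Worked example (hypothetical partptr): with partptr = [0, 679, 1354, 2029],
    # local index 5 in partition 1 maps to global id 5 + 679 = 684, and
    # get_localId_by_partitionId(1, tensor([684])) maps it back to 5.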
def get_node_num(self):
return self.num_nodes
    def localId_to_globalId(self, id, partitionId: int = -1):
        '''
        Map a node id that is local to partition partitionId to its global id
        '''
        if partitionId == -1:
            partitionId = self.partition_id
        # partition p starts at global offset partptr[p]
        global_id = id + self.partptr[partitionId]
        assert self.partptr[partitionId] <= global_id < self.partptr[partitionId+1]
        return global_id
    def get_partitionId_by_globalId(self, id):
        '''
        Get the partition index that a global node id belongs to
        '''
        partitionId = -1
        assert id >= 0 and id < self.num_nodes, 'id out of range'
        for i in range(self.partitions):
            if id >= self.partptr[i] and id < self.partptr[i+1]:
                partitionId = i
                break
        assert partitionId >= 0, 'no partition contains this id'
        return partitionId
    def get_nodes_by_partitionId(self, id):
        '''
        Return the number of nodes in the partition given by id
        '''
        assert id >= 0 and id < self.partitions, 'invalid partitionId'
        return int(self.partptr[id+1] - self.partptr[id])
def __repr__(self):
return (f'{self.__class__.__name__}(\n'
f' partition_id={self.partition_id}\n'
f' data={self.data},\n'
f' global_info('
f'num_nodes={self.num_nodes},'
f' num_edges={self.num_edges},'
f' num_parts={self.partitions},'
f' edge_index=[2,{self.edge_index[0].numel()}])\n'
f')')
from Utils import GraphData
g = GraphData('./metis_4/rank_0')
'''
GraphData(
partition_id=0
data=Data(x=[679, 1433], edge_index=[2, 2908], y=[679], train_mask=[679], val_mask=[679], test_mask=[679]),
global_info(num_nodes=2029, num_edges=10556, num_parts=4, edge_index=[2,10556])
)
'''
import argparse
from multiprocessing import freeze_support
import os
from itertools import count
import torch.multiprocessing as mp
import torch
import torch.distributed.rpc as rpc
import torch.optim as optim
from torch.distributed.rpc import RRef, rpc_async, remote
from torch.distributions import Categorical
from RpcAgent import Agent
AGENT_NAME='server'
OBSERVER_NAME='obs{}'
def init_worker(rank,world_size):
print(rank,world_size)
os.environ['MASTER_ADDR']='localhost'
os.environ['MASTER_PORT']='29000'
if rank==0:
rpc.init_rpc(AGENT_NAME,rank=rank,world_size=world_size+1)
agent=Agent(world_size,OBSERVER_NAME)
agent.loadGraph('./part/metis_2')
for i in range(10):
agent.run_sample(5,3)
else:
print(OBSERVER_NAME.format(rank-1))
rpc.init_rpc(OBSERVER_NAME.format(rank-1),rank=rank,world_size=world_size+1)
rpc.shutdown()
parser = argparse.ArgumentParser(
description="RPC Reinforcement Learning Example",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--world_size', default=2, type=int, metavar='W',
help='number of workers')
parser.add_argument('--log_interval', type=int, default=10, metavar='N',
help='interval between training status logs')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
help='how much to value future rewards')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed for reproducibility')
args = parser.parse_args()
if __name__ == '__main__':
mp.spawn(
init_worker,
args=(args.world_size, ),
nprocs=args.world_size+1,
join=True
)
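# Run sketch (assumption: this file is the entry point, saved e.g. as main.py):
#   python main.py --world_size 2
# spawns one agent process (rank 0) plus two observer processes; the agent
# loads the './part/metis_2' partitions and runs ten sampling rounds of
# 5 seed nodes with up to 3 neighbors each over RPC.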