Commit eb2cca25 by zljJoan

delete useless file

parent 1c113297
import numpy as np
from message_worker import sampler_worker
import torch
import torch.distributed.rpc as rpc
import torch.optim as optim
from torch.distributed.rpc import RRef, rpc_async, remote
from torch.distributions import Categorical
#from Sample.Sampler import NeighborSampler
class Agent:
    """RPC driver process: creates one remote ``sampler_worker`` per observer
    rank and broadcasts sampling / graph-loading requests to all of them."""

    def __init__(self, worker_size,OBSERVER_NAME):
        # worker_size: number of observer processes to connect to.
        # OBSERVER_NAME: format string producing each observer's RPC worker
        # name from its rank (e.g. "observer{}").
        self.ob_rrefs = []
        # RRef to this agent so observers can hold a reference back to it.
        self.agent_rref = RRef(self)
        self.rewards = {}
        self.saved_log_probs = {}
        self.worker_size=worker_size
        for ob_rank in range(worker_size):
            ob_info = rpc.get_worker_info(OBSERVER_NAME.format(ob_rank))
            print(OBSERVER_NAME.format(ob_rank),(worker_size,OBSERVER_NAME))
            # Instantiate a sampler_worker owned by the observer process and
            # keep an RRef to it.
            self.ob_rrefs.append(remote(ob_info, sampler_worker,args=(worker_size,OBSERVER_NAME)))
            # Per-observer bookkeeping, keyed by RPC worker id.
            self.rewards[ob_info.id] = []
            self.saved_log_probs[ob_info.id] = []

    def run_sample(self,batch_size,num_samples):
        # Launch an asynchronous sampling round on every observer, then block
        # until all rounds complete.
        # NOTE(review): batch_size and num_samples are never used — the call
        # below hard-codes args=(..., 5) and num_neighbors=3; confirm intent.
        # NOTE(review): NeighborSampler is undefined here — its import is
        # commented out at the top of the file, so this raises NameError.
        futs=[]
        sampler=NeighborSampler()
        for ob_rref in self.ob_rrefs:
            futs.append(
                rpc_async(
                    ob_rref.owner(),
                    # Resolve the remote object's `sampler` method via a
                    # blocking rpc_sync, then invoke it asynchronously.
                    ob_rref.rpc_sync().sampler,
                    args=(self.ob_rrefs,sampler,5),
                    kwargs={"num_neighbors":3}
                )
            )
        # Wait for every observer's sampling RPC to finish.
        for fut in futs:
            fut.wait()

    def loadGraph(self,path):
        # Ask every observer to load the graph stored at `path`; blocks until
        # all observers have finished loading.
        print((path))
        futs=[]
        for ob_rref in self.ob_rrefs:
            futs.append(
                rpc_async(
                    ob_rref.owner(),
                    ob_rref.rpc_sync().loadGraph,
                    args=(path,)
                )
            )
        for fut in futs:
            fut.wait()
\ No newline at end of file
import torch
class Graph:
    """In-memory graph container.

    Each recognized keyword argument (``edge_index``, ``edge_attr``,
    ``node_attr``, ``y``) is stored as an attribute of the same name;
    unrecognized kwargs are ignored.
    """

    def __init__(self, **kwargs):
        # BUG FIX: the original assigned every kwarg to self.nmap, repeatedly
        # clobbering it and never setting the attribute named by the kwarg.
        if 'edge_index' in kwargs:
            self.edge_index = kwargs.get('edge_index')
        if 'edge_attr' in kwargs:
            self.edge_attr = kwargs.get('edge_attr')
        if 'node_attr' in kwargs:
            self.node_attr = kwargs.get('node_attr')
        if 'y' in kwargs:
            self.y = kwargs.get('y')
'''
Attribute glossary:
edge_index  (srcId, dstId) pairs for the full graph
nmap        real_id  -> node_attr_id
emap        real_id  -> edge_attr_id
npart       node_id  -> part_id
epart       edge_id  -> part_id
'''
class DistGraph(Graph):
    """Partitioned graph for distributed use.

    nmap/emap map real node/edge ids to attribute ids; npart/epart map
    node/edge ids to partition ids. Partition data is loaded lazily via
    :meth:`load_graph`.
    """

    def __init__(self, **kwargs):
        # BUG FIX: the original stored emap/npart/epart all into self.nmap;
        # each kwarg now lands on its own attribute.
        if 'nmap' in kwargs:
            self.nmap = kwargs.get('nmap')
        if 'emap' in kwargs:
            self.emap = kwargs.get('emap')
        if 'npart' in kwargs:
            self.npart = kwargs.get('npart')
        if 'epart' in kwargs:
            self.epart = kwargs.get('epart')
        if kwargs:
            # BUG FIX: super(Graph, self).__init__(kwargs) skipped
            # Graph.__init__ entirely (resolving past Graph in the MRO) and
            # passed the dict positionally; forward the kwargs properly.
            super().__init__(**kwargs)

    def get_node_num(self):
        # Number of nodes in the loaded adjacency structure.
        return self.adj.size(0)

    def load_graph(self, path):
        """Load a partitioned graph tuple (adj, data, partptr, perm) that was
        saved with torch.save at `path`."""
        print(" load graph ", path)
        adj, data, partptr, perm = torch.load(path)
        self.adj = adj
        self.data = data  # maybe nmap
        self.partptr = partptr
        self.perm = perm

    def get_part_num(self):
        # NOTE(review): returns the node-feature row count, not the number of
        # partitions its name suggests — confirm intent with callers.
        return self.data.x.size()[0]

    def select_attr(self, index):
        """Gather node-attribute rows for the given node indices."""
        return torch.index_select(self.data.x, 0, index)

    # 返回全局的节点id 所对应的分区
    def get_part(self, index):
        """Return the partition ids corresponding to global node ids."""
        print(index)
        return torch.index_select(self.partptr, 0, index)
import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from . import Net
from torch import optim
from torch.distributed.optim import DistributedOptimizer
from torchvision import datasets,transforms
def call_method(method,rref,*args,**kwargs):
    """Invoke `method` on the object held locally by `rref`, forwarding any
    extra positional and keyword arguments."""
    target = rref.local_value()
    return method(target, *args, **kwargs)
def remote_method(method,rref,*args,**kwargs):
    """Synchronously run ``method(rref.local_value(), *args, **kwargs)`` on
    the node that owns `rref`, returning the result."""
    call_args = [method, rref, *args]
    return rpc.rpc_sync(rref.owner(), call_method, args=call_args, kwargs=kwargs)
class ParameterServer(nn.Module):
    """RPC parameter server: hosts the model and serves forward passes and
    gradient retrieval to remote trainers."""

    def __init__(self, num_gpus):
        super().__init__()
        self.model = Net(num_gpus)
        # Place model input on the first GPU when one is available and
        # requested, otherwise on CPU.
        self.input_device = torch.device("cuda:0" if torch.cuda.is_available() and num_gpus > 0 else "cpu")

    def forward(self, inp):
        inp = inp.to(self.input_device)
        out = self.model(inp)
        # This output is forwarded over RPC, which as of 1.5.0 only accepts CPU tensors.
        # Tensors must be moved in and out of GPU memory due to this.
        out = out.to("cpu")
        return out

    def get_dist_gradients(self, cid):
        """Return the gradients accumulated under distributed autograd
        context `cid`, with both keys and values moved to CPU so the dict
        can cross the RPC boundary."""
        grads = dist_autograd.get_gradients(cid)
        cpu_grads = {}
        for k, v in grads.items():
            k_cpu, v_cpu = k.to("cpu"), v.to("cpu")
            cpu_grads[k_cpu] = v_cpu
        # BUG FIX: the original built cpu_grads but never returned it, so
        # callers always received None.
        return cpu_grads
\ No newline at end of file
from mpi4py import mpi
import numpy
class mpi_worker:
    """MPI-backed worker that serves node attributes for a partitioned graph
    across `worker_num` ranks."""

    def __init__(self, worker_num):
        self.comm = mpi.COMM_WORLD
        # BUG FIX: the rank comes from the communicator; the MPI module
        # itself has no Get_rank().
        self.rank = self.comm.Get_rank()
        # BUG FIX: worker_num is a count, not a callable — the original
        # `worker_num()` raised TypeError.
        self.worker_num = worker_num

    def loadGraph(self, path):
        # NOTE(review): `path` is ignored and an empty list is stored —
        # presumably a stub; get_attr expects an object exposing
        # get_nodepart/get_nodeattr.
        self.graph = []

    def get_attr(self, node_list):
        """Collect attributes for the sampled nodes in `node_list`, pulling
        remote attributes from their owning ranks via MPI broadcast.

        Returns a dict mapping node id -> attribute row.
        """
        sample_info = {}
        broadcast_info = {}
        for rank in range(self.worker_num):
            sample_info[rank] = []
            broadcast_info[rank] = []
        node_feature = {}
        for root in node_list:
            part = self.graph.get_nodepart(root[0])
            if part == self.rank:
                # Root owned locally: record which rank owns each neighbor.
                for nid in root[0, :]:
                    part0 = self.graph.get_nodepart(nid)
                    sample_info[part0].append(nid)
            else:
                # Root owned remotely: stage the attributes of the neighbors
                # we own so they can be broadcast out.
                for nid in root[0, :]:
                    part0 = self.graph.get_nodepart(nid)
                    if part0 == self.rank:
                        broadcast_info[part0].append(self.graph.get_nodeattr(nid))
        for rank in range(self.worker_num):
            # NOTE(review): broadcasts the whole dict rather than
            # broadcast_info[rank] — confirm intended payload.
            data = self.comm.bcast(broadcast_info, rank)
            if sample_info[rank]:
                for ndata in data:
                    nid = ndata[0]
                    # NOTE(review): membership is tested against the dict's
                    # rank keys, not sample_info[rank] — confirm intent.
                    if nid in sample_info:
                        node_feature[nid] = ndata[1, :]
        # BUG FIX: the original built node_feature but never returned it.
        return node_feature
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment