Commit f69647c0 by zhlj

Initial commit

from distparser import WORLD_SIZE
import torch
class BatchData:
'''
Args:
batch_size: the number of sampled root nodes
nids: the global ids of the sampled nodes and their neighbours
edge_index: the edges between nids; node ids are re-indexed within the BatchData
x: node features
y: node labels
eids: the edge ids of edge_index within the subgraph
train_mask: training mask of the sampled nodes
val_mask: validation mask of the sampled nodes
test_mask: test mask of the sampled nodes
'''
def __init__(self,nids,edge_index,roots=None,x=None,y=None,eids=None,train_mask=None,val_mask=None,test_mask=None):
self.batch_size=roots.size(0)
self.roots=roots
self.nids=nids
self.edge_index=edge_index
self.x=x
self.y=y
self.eids=eids
self.train_mask=train_mask
self.val_mask=val_mask
self.test_mask=test_mask
def _check_with_graph(self,graph1,graph2):
for i,id in enumerate(self.nids):
if(id >= graph1.partptr[0] and id < graph1.partptr[1]):
real_x = graph1.select_attr(graph1.get_localId_by_partitionId(0,torch.tensor(id)))
else:
real_x = graph2.select_attr(graph2.get_localId_by_partitionId(1,torch.tensor(id)))
if not torch.equal(real_x[0],self.x[i]):
print(real_x[0],self.x[i])
print('error')
def __repr__(self):
return "BatchData(batch_size = {}, roots = {}, \
nids = {}, edge_index = {}, x = {}, \
y = {})".format(self.batch_size, repr(self.roots),
repr(self.nids),
repr(self.edge_index),
repr(self.x), repr(self.y))
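# Minimal usage sketch (hypothetical tensors): roots are the seed nodes and
# nids additionally contains their sampled neighbours, so batch_size == roots.size(0).
#   roots = torch.tensor([0, 1])
#   nids = torch.tensor([0, 1, 2, 3])
#   edge_index = torch.tensor([[2, 3], [0, 1]])
#   batch = BatchData(nids, edge_index, roots=roots, x=torch.rand(4, 16), y=torch.zeros(4, dtype=torch.long))
#   print(batch)  # BatchData(batch_size = 2, ...)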
from collections import deque
from enum import Enum
from multiprocessing.connection import Client
import queue
from threading import Thread
import traceback
import torch
from torch.distributed.rpc import RRef, rpc_async, remote
from Sample.neighbor_sampler import get_neighbors
from Sample.neighbor_sampler import NeighborSampler
import time
import torch.distributed as dist
import torch.multiprocessing as mp
import graph_store
import os
import time
from message_worker import _check_future_finish, _get_batch_data, _sample_node_neighbors_server
from rpc_server import RpcCommand, close_rpc, start_rpc_listener,start_rpc_caller
import distparser as parser
import sys
from logger import logger
WORKER_RANK = parser._get_worker_rank()
NUM_SAMPLER = parser._get_num_sampler()
WORLD_SIZE = parser._get_world_size()
QUEUE_SIZE =parser._get_queue_size()
MAX_QUEUE_SIZE = parser._get_max_queue_size()
RPC_NAME = parser._get_RPC_NAME()
class MpCommand(Enum):
"""Enum class for multiprocessing command"""
INIT_RPC = 0 # Not used in the task queue
SET_COLLATE_FN = 1
CALL_BARRIER = 2
DELETE_COLLATE_FN = 3
CALL_COLLATE_FN = 4
CALL_FN_ALL_WORKERS = 5
FINALIZE_POOL = 6
start = time.time()
total_time = 0
def wait_thread(waitfut,data_queue):
global total_time
global start
while(True):
if len(waitfut) > 0:
result = waitfut[0]
if result == -1:
break
future_to_check = result.get("union_args")[-1]
#print('check',time.time()-start)
if(_check_future_finish(future_to_check) is True):
start = time.time()
batch_data = _get_batch_data(result)
total_time = total_time+time.time()-result.get("append_time")
#logger.debug('time wait for answer {}, time for get batch {},total_time {}'.format(start - result.get("append_time"),time.time()-start,total_time))
waitfut.popleft()
if not isinstance(data_queue,deque):
data_queue.put(
(
result.get("data_name"),
batch_data,
)
)
else:
data_queue.append(
batch_data
)
custom_pool = None
rpc_server_prco = None
rpc_server_queue = None
rpc_server_barrier = None
keep_polling = True
class LocalSampler:
def __init__(self):
self.waitfut = deque()
self.data_queue = deque()
self.thread = Thread(target=wait_thread,args=(self.waitfut,self.data_queue))
self.thread.start()
def set_collate_fn(self,dataloader_name,graph_inshm):
self.dataloader_name = dataloader_name
graph_store._set_graph(dataloader_name,graph_inshm)
global rpc_server_queue
rpc_server_queue.put((RpcCommand.SET_GRAPH,(dataloader_name,graph_inshm)))
rpc_server_queue.put((RpcCommand.CALL_BARRIER,tuple()))
global rpc_server_barrier
rpc_server_barrier.wait()
self.results=[]
def sample_next(self, dataloader_name,sampler,input_nodes):
#out = sampler.sample_from_nodes((dataloader_name+"_"+str(WORKER_RANK),input_nodes,None))
#neighbor_nodes = out.node
#sampled_edge_index = torch.cat((out.row.reshape(1,-1),out.col.reshape(1,-1)),0)
t1 = time.time()
neighbor_nodes, sampled_edge_index= sampler.sample_from_nodes(input_nodes)
t2 = time.time()
union_args=_sample_node_neighbors_server(dataloader_name,neighbor_nodes)
t3 = time.time()
#logger.debug('sample {},get union_args {}'.format(t2-t1,t3-t2))
start = time.time()
self.waitfut.append({
"data_name":dataloader_name,
"input_size":input_nodes.size(0),
"nids":neighbor_nodes,
"union_args":union_args,
"edge_index":sampled_edge_index,
"append_time":time.time()}
)
def get_result(self):
if(len(self.data_queue)!=0):
result = self.data_queue[0]
self.data_queue.popleft()
return result
else:
return None
def __del__(self):
dist.barrier()
global rpc_server_queue
rpc_server_queue.put(
(RpcCommand.UNLOAD_GRAPH, (self.dataloader_name,))
)
global rpc_server_barrier
if(keep_polling is True):
rpc_server_barrier.wait()
graph_inshm = graph_store._get_graph(self.dataloader_name)
graph_store._del_graph(self.dataloader_name)
graph_inshm._close_graph_in_shame()
graph_inshm._unlink_graph_in_shame()
self.waitfut.append(-1)
def init_process(sampler_id,rpc_config,comm_config):
"""start the proxy of rpc on the process """
rpc_master_addr,rpc_port,num_worker_threads = rpc_config
start_rpc_caller(rpc_master_addr,
rpc_port,
{
"num_worker_threads": num_worker_threads,
"rpc_name": RPC_NAME,
"rpc_world_size": WORLD_SIZE * (NUM_SAMPLER + 1),
"worker_rank": WORKER_RANK,
"rpc_worker_rank": (NUM_SAMPLER + 1) * WORKER_RANK + sampler_id + 1}
)
print('start work')
global start
try:
data_queue,task_queue,barrier = comm_config
collate_fn_dict = {}
sampler_dict = {}
keep_poll = True
waitfut = deque()
thread = Thread(target=wait_thread,args=(waitfut,data_queue))
thread.start()
while keep_poll or graph_store._get_size() > 0:
if(len(waitfut)>QUEUE_SIZE):
continue
try:
command,args = task_queue.get(timeout=5)
except queue.Empty:
continue
if command == MpCommand.SET_COLLATE_FN:
dataloader_name,graph_inshm,func,sampler_info = args
graph_store._set_graph(dataloader_name,graph_inshm)
collate_fn_dict[dataloader_name] = func
#neighbors,deg = graph_inshm._get_sampler_from_shame()
#sampler = sampler_info #NeighborSampler(*sampler_info,neighbors= neighbors,deg = deg)
graph = graph_store._get_graph(dataloader_name)
row, col = graph.edge_index
tnb = get_neighbors(row.contiguous(), col.contiguous(), graph.num_nodes)
sampler = NeighborSampler(*sampler_info,tnb=tnb)
sampler_dict[dataloader_name] = sampler
elif command == MpCommand.CALL_BARRIER:
barrier.wait()
elif command == MpCommand.DELETE_COLLATE_FN:
(dataloader_name,) = args
del collate_fn_dict[dataloader_name]
del sampler_dict[dataloader_name]
graph_inshm = graph_store._get_graph(dataloader_name)
graph_store._del_graph(dataloader_name)
graph_inshm._close_graph_in_shame()
if(barrier.wait()==0):
graph_inshm._unlink_graph_in_shame()
elif command == MpCommand.CALL_COLLATE_FN:
dataloader_name, collate_args, start_time = args
t1 = time.time()
#out = sampler_dict[dataloader_name].sample_from_nodes((dataloader_name,collate_args,time.time()))
#neighbor_nodes = out.node
#sampled_edge_index = torch.cat((out.row.reshape(1,-1),out.col.reshape(1,-1)),0)
neighbor_nodes, sampled_edge_index= sampler_dict[dataloader_name].sample_from_nodes(collate_args)
t2 = time.time()
union_args=_sample_node_neighbors_server(dataloader_name,neighbor_nodes)
t3 = time.time()
start = time.time()
#logger.debug('wait for input {},sample {},get union_args {}'.format(t1-start_time,t2-t1,t3-t2))
waitfut.append({
"data_name":dataloader_name,
"input_size":collate_args.size(0),
"nids":neighbor_nodes,
"union_args":union_args,
"edge_index":sampled_edge_index,
"append_time":time.time()}
)
#data_queue.put(
# (
# dataloader_name,
# collate_fn_dict[dataloader_name](collate_args),
# )
#)
elif command == MpCommand.CALL_FN_ALL_WORKERS:
func, func_args = args
func(func_args)
elif command == MpCommand.FINALIZE_POOL:
keep_poll = False
graph_store._clear_all(barrier)
close_rpc()
else:
raise Exception("Unknown command")
except Exception as e:
traceback.print_exc()
raise e
class CustomPool:
def __init__(self,rpc_config):
self.max_queue_size = MAX_QUEUE_SIZE
self.num_samplers = NUM_SAMPLER
ctx = mp.get_context("spawn")
self.result_queue = ctx.Queue(self.max_queue_size)
self.results = {}
self.task_queues = []
self.process_list = []
self.current_proc_id = 0
self.cache_result_dict = {}
self.barrier = ctx.Barrier(self.num_samplers)
for rank in range(self.num_samplers):
task_queue = ctx.Queue(self.max_queue_size)
self.task_queues.append(task_queue)
proc = ctx.Process(
target=init_process,
args=(
rank,rpc_config,
(self.result_queue,task_queue,self.barrier),
)
)
proc.daemon=True
proc.start()
self.process_list.append(proc)
self.call_barrier()
def set_collate_fn(self,func,sampler,dataloader_name,graph_inshm):
for i in range(self.num_samplers):
self.task_queues[i].put(
(MpCommand.SET_COLLATE_FN,(dataloader_name,graph_inshm,func,sampler))
)
global rpc_server_queue
rpc_server_queue.put((RpcCommand.SET_GRAPH,(dataloader_name,graph_inshm)))
rpc_server_queue.put((RpcCommand.CALL_BARRIER,tuple()))
global rpc_server_barrier
rpc_server_barrier.wait()
self.call_barrier()
self.results[dataloader_name]=[]
def submit_task(self, dataloader_name, args):
"""Submit task to workers"""
# Round robin
self.task_queues[self.current_proc_id].put(
(MpCommand.CALL_COLLATE_FN, (dataloader_name, args,time.time()))
)
self.current_proc_id = (self.current_proc_id + 1) % self.num_samplers
def submit_task_to_all_workers(self, func, args):
"""Submit task to all workers"""
for i in range(self.num_samplers):
self.task_queues[i].put(
(MpCommand.CALL_FN_ALL_WORKERS, (func, args))
)
def get_result(self, dataloader_name, timeout=1800):
"""Get result from result queue"""
if dataloader_name not in self.results:
raise Exception(
f"Got result from an unknown dataloader {dataloader_name}."
)
while len(self.results[dataloader_name]) == 0:
data_name, data = self.result_queue.get(timeout=timeout)
self.results[data_name].append(data)
return self.results[dataloader_name].pop(0)
def delete_collate_fn(self, dataloader_name):
"""Delete collate function"""
global rpc_server_queue
rpc_server_queue.put(
(RpcCommand.UNLOAD_GRAPH, (dataloader_name,))
)
global rpc_server_barrier
if(keep_polling is True):
rpc_server_barrier.wait()
for i in range(self.num_samplers):
self.task_queues[i].put(
(MpCommand.DELETE_COLLATE_FN, (dataloader_name,))
)
if dataloader_name in self.results:
del self.results[dataloader_name]
def call_barrier(self):
"""Call barrier at all workers"""
for i in range(self.num_samplers):
self.task_queues[i].put((MpCommand.CALL_BARRIER, tuple()))
def close(self):
"""Close worker pool"""
for i in range(self.num_samplers):
self.task_queues[i].put(
(MpCommand.FINALIZE_POOL, tuple()), block=False
)
time.sleep(0.5) # Fix for early python version
def join(self):
"""Join the close process of worker pool"""
for i in range(self.num_samplers):
self.process_list[i].join()
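# Lifecycle sketch for CustomPool (hypothetical names; assumes init_distribution
# has already started the rpc listener, so rpc_server_queue/rpc_server_barrier exist):
#   pool = CustomPool((rpc_master_addr, rpc_port, 16))
#   pool.set_collate_fn(collate_fn, sampler_info, "train_loader", graph_inshm)
#   pool.submit_task("train_loader", seed_nodes)   # round-robin to one sampler process
#   batch = pool.get_result("train_loader")        # blocks until a batch is ready
#   pool.delete_collate_fn("train_loader")
#   pool.close(); pool.join()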
sample_group = None
custom_pool = None
def get_sampler_pool():
global custom_pool
return custom_pool
def get_sample_group():
global sample_group
return sample_group
def init_distribution(master_addr=None,master_port=None,rpc_master_addr=None,rpc_port=None, num_worker_threads = 16,backend = "gloo"):
print('init distribution')
global sample_group
global custom_pool
world_size = parser._get_world_size()
worker_rank = parser._get_worker_rank()
num_sampler = parser._get_num_sampler()
if master_addr is None or master_port is None:
raise Exception(
f"The master address is unknown."
)
if rpc_port is None:
raise Exception(
f"The rpc listener address is unknown."
)
init_method="tcp://{}:{}".format(master_addr,master_port)
sample_group = dist.init_process_group(backend=backend, world_size=world_size, rank=worker_rank,init_method=init_method,group_name='sample-default-group')
if world_size > 1:
ctx = mp.get_context("spawn")
global rpc_server_queue
rpc_server_queue = ctx.Queue(MAX_QUEUE_SIZE)
global rpc_server_barrier
rpc_server_barrier = ctx.Barrier(2)
rpc_server_proc= ctx.Process(target = start_rpc_listener,
args=(rpc_master_addr,
rpc_port,
{"num_worker_threads": num_worker_threads,
"rpc_name": RPC_NAME,
"rpc_world_size": world_size * (num_sampler + 1),
"worker_rank": worker_rank,
"rpc_worker_rank": (num_sampler + 1) * worker_rank
},
(rpc_server_queue,rpc_server_barrier)),
)
rpc_server_proc.start()
if parser._get_num_sampler()>1:
custom_pool = CustomPool((rpc_master_addr,rpc_port,num_worker_threads))
else:
start_rpc_caller(rpc_master_addr,
rpc_port,
{
"num_worker_threads": num_worker_threads,
"rpc_name": RPC_NAME,
"rpc_world_size": WORLD_SIZE * (NUM_SAMPLER + 1),
"worker_rank": WORKER_RANK,
"rpc_worker_rank": (NUM_SAMPLER + 1) * WORKER_RANK + 1}
)
dist.barrier()
else:
custom_pool = None
def close_distribution():
global custom_pool
global rpc_server_queue
global rpc_server_barrier
global keep_polling
if parser._get_world_size()>1 and custom_pool is not None:
dist.barrier()
print('del')
rpc_server_queue.put((RpcCommand.STOP_RPC,tuple()))
rpc_server_barrier.wait()
custom_pool.close()
keep_polling = False
elif parser._get_world_size()>1 and custom_pool is None:
dist.barrier()
rpc_server_queue.put((RpcCommand.STOP_RPC,tuple()))
rpc_server_barrier.wait()
keep_polling = False
close_rpc()
#rpc.init_rpc(OBSERVER_NAME.format(rank),rank=rank,world_size=world_size)
#if(rank==0):
# message_worker.init_sampler_worker()
#dist.barrier()
#print(OBSERVER_NAME.format(rank))
#worker = sampler_worker(world_size,OBSERVER_NAME)
#if(rank==0):
#rrefs =
#master_worker = sampler_worker(world_size,OBSERVER_NAME)
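# Rough driver sketch (assumes the distparser environment, i.e. world size,
# worker rank and sampler count, is already configured):
#   DistCustomPool.init_distribution(master_addr="127.0.0.1", master_port=10000,
#                                    rpc_master_addr="127.0.0.1", rpc_port=10001)
#   ...  # build DistGraphData / DistributedDataLoader and run training
#   DistCustomPool.close_distribution()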
from enum import Enum
import torch
from message_worker import _get_batch_data, _sample_node_neighbors_server, _sample_node_neighbors_single
from part.Utils import GraphData
from typing import Optional
import torch.distributed as dist
import DistCustomPool
from shared_graph_memory import GraphInfoInShm
import distparser as parser
from torch_geometric.data import Data
import os.path as osp
def partition_load(root: str, algo: str = "metis") -> Data:
rank = parser._get_worker_rank()
world_size = parser._get_world_size()
fn = osp.join(root, f"{algo}_{world_size}", f"{rank:03d}")
return torch.load(fn)
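# Usage sketch (assumes the graph was partitioned beforehand into
# <root>/metis_<world_size>/<rank:03d>, e.g. by a metis partitioning step):
#   pdata = partition_load("./dataset/cora", algo="metis")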
class DistGraphData(GraphData):
def __init__(self,pdata = None,edge_index = None,path = None):
if path is not None:
self.rank = parser._get_worker_rank()
path = path + '/rank_' + str(self.rank)
print('load graph ',path)
super(DistGraphData,self).__init__(path)
else:
#dst nodes and their edges are stored in this partition, but src nodes may not be
#local node ids
self.ids = pdata.ids
self.origin_edge_index = edge_index
#feature information
self.data = Data()
self.data.x = pdata.x
self.data.y = pdata.y
self.data.y = self.data.y.reshape(-1)
self.data.train_mask = pdata.train_mask
self.data.test_mask = pdata.test_mask
self.data.val_mask = pdata.val_mask
self.rank = parser._get_worker_rank()
#indices are obtained after communication
world_size = parser._get_world_size()
sample_group = DistCustomPool.get_sample_group()
self.partition_id = self.rank
self.partptr = torch.zeros(world_size + 1).int()
self.num_edge_part = torch.zeros(world_size+1).int()
self.num_nodes = 0
self.partitions = world_size
self.num_parts = self.partitions
self.num_feasures = pdata.x.size()[1]
global_edge = []
if world_size != 1:
for rank in range(world_size):
#dst nodes are in this partition, src may not be; exchange sizes to map dst and src
rev_msg = [len(self.ids),len(edge_index[0,:])]
print(rev_msg)
dist.broadcast_object_list(rev_msg,rank,group=sample_group)
self.num_edge_part[rank + 1] = self.num_edge_part[rank] + rev_msg[1]
self.num_nodes = self.num_nodes + rev_msg[0]
self.partptr[rank + 1] = self.num_nodes
ptr_index = torch.zeros_like(edge_index)
self.edgeindex2ptr(self.ids,edge_index[1,:],ptr_index[1,:],self.rank)
dist.barrier()
for rank in range(world_size):
#remap edge_index to the new global ids
if rank != self.rank:
rev_idx = torch.zeros(self.partptr[rank+1] - self.partptr[rank]).type_as(self.ids)
else:
rev_idx = self.ids.clone()
print('idx',rev_idx)
dist.broadcast(rev_idx,rank,group=sample_group)
self.edgeindex2ptr(rev_idx,edge_index[0,:],ptr_index[0,:],rank)
print(ptr_index)
self.data.edge_index = ptr_index
dist.barrier()
for rank in range(world_size):
#remap edge_index to the new global ids
if rank != self.rank:
rev_idx = torch.zeros(2,self.num_edge_part[rank + 1] - self.num_edge_part[rank]).type_as(ptr_index)
else:
rev_idx = ptr_index.clone()
print('edge',rev_idx)
dist.broadcast(rev_idx,rank,group=sample_group)
global_edge.append(rev_idx)
self.edge_index = torch.cat(global_edge,1)
dist.barrier()
else:
self.num_nodes = len(self.ids)
self.partptr[1] = self.num_nodes
self.edge_index = edge_index
self.data.edge_index = edge_index
self.num_edges = edge_index[0].numel()
def edgeindex2ptr(self,ids,edge_index,ptr_index,rank):
#build the id mapping
dic = dict(zip(ids.tolist(),torch.arange(len(ids)).tolist()))
print(dic)
#values = torch.arange(len(ids))
#indices = ids
#sparse_coo = torch.sparse_coo_tensor(indices.unsqueeze(0),values)
#sparse_csr = sparse_coo.to_sparse_csr
for i in range(len(edge_index)):
id = int(edge_index[i].item())
#print(id)
if id in dic:
print(id)
ptr_index[i] = dic[id] + self.partptr[rank]
print(ptr_index)
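# Worked example for edgeindex2ptr: with ids = [10, 20, 30] owned by partition `rank`
# and an edge_index column of [20, 10, 30], the in-place result is
# ptr_index = [1, 0, 2] + self.partptr[rank], i.e. local positions shifted into the
# global contiguous numbering.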
class DistributedDataLoader:
'''
Args:
data_path: the path of the partitioned graph; the partition for rank i is saved at $path$/rank_i
num_replicas: the number of workers
'''
def __init__(
self,
graph_name,
graph,
data_index_mask = None,
sampler = None,
collate_fn = None,
batch_size: Optional[int]=None,
shuffle:bool = True,
seed:int = 0,
drop_last = False,
**kwargs
):
assert sampler is not None
self.pool = DistCustomPool.get_sampler_pool()
self.graph_name = graph_name
self.batch_size = batch_size
self.num_workers = parser._get_num_sampler()
self.queue_size = parser._get_queue_size()
#if(queue_size is None):
# queue_size = self.num_workers * 4 if self.num_workers > 0 else 4
self.num_pending = 0
# self.num_node_features = graph.num_node_features
self.collate_fn = collate_fn
self.current_pos = 0
self.drop_last = drop_last
self.recv_idxs = 0
self.queue = []
self.shuffle = shuffle
self.is_closed = False
self.sampler = sampler
self.epoch = 0
self.graph = graph
if parser._get_world_size() > 1:
self.graph_inshm = GraphInfoInShm(graph)
if self.pool is None:
self.sampler_info = sampler
else :
self.sampler_info = (sampler.num_nodes,sampler.num_layers,sampler.fanout,sampler.workers)#sampler#._get_sample_info()
#self.graph_inshm._copy_sampler_to_shame(*sampler._get_neighbors_and_deg())
self.shuffle=shuffle
self.seed=seed
self.kwargs=kwargs
self.rank = parser._get_worker_rank()
self.data = torch.arange(start = graph.partptr[self.rank],end = graph.partptr[self.rank+1])
if(data_index_mask is not None):
self.data = torch.masked_select(self.data,data_index_mask)
print('the number of input nodes is ',self.data.size(0))
self._get_expected_idx(drop_last,self.data.size(0))
if(self.pool is not None and parser._get_world_size() > 1 ):
self.pool.set_collate_fn(self.collate_fn,self.sampler_info,self.graph_name,self.graph_inshm)
dist.barrier()
elif parser._get_world_size() > 1 :
self.local_pool = DistCustomPool.LocalSampler()
self.local_pool.set_collate_fn(self.graph_name,self.graph_inshm)
def _get_expected_idx(self,drop_last,data_length):
if parser._get_world_size() > 1:
world_size = parser._get_world_size()
sample_group = DistCustomPool.get_sample_group()
expected_data = data_length
for rank in range(world_size):
length = torch.tensor(data_length)
dist.broadcast(length,rank,group=sample_group)
if(drop_last is True):
expected_data = min(expected_data,int(length.item()))
else:
expected_data = max(expected_data,int(length.item()))
self.expected_idx = expected_data // self.batch_size
if(not self.drop_last and expected_data % self.batch_size != 0):
self.expected_idx += 1
print('expected_index ',self.expected_idx)
else:
self.expected_idx = data_length// self.batch_size
if(not self.drop_last and data_length % self.batch_size != 0):
self.expected_idx += 1
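# Worked example for _get_expected_idx: with batch_size = 5, drop_last = False and
# per-rank data lengths of 13 and 17, every rank adopts the maximum length 17, so
# expected_idx = 17 // 5 + 1 = 4; ranks with fewer seeds wrap around in _next_data.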
def __next__(self):
#print(self.sampler,self.num_samples,self.kwargs)
#self.sampleGraph(self.sampler,self.num_samplers,self.kwargs)
#return self.sampleGraph(self.sampler,self.num_samples,**self.kwargs)
if(self.pool is None and parser._get_world_size() == 1):
if self.recv_idxs < self.expected_idx:
batch_data = self._sample_next_batch()
self.recv_idxs += 1
assert batch_data is not None
return batch_data
else :
raise StopIteration
elif self.pool is not None:
num_reqs = min(self.queue_size - self.num_pending,self.expected_idx - self.submitted)
for _ in range(num_reqs):
self._submit_request_next()
self.submitted = self.submitted + 1
if self.recv_idxs < self.expected_idx:
result = self._get_data_from_queue()
self.recv_idxs += 1
self.num_pending -=1
return result
else :
assert self.num_pending == 0
raise StopIteration
else :
num_reqs = min(self.queue_size - self.num_pending,self.expected_idx - self.submitted)
for _ in range(num_reqs):
result = self.local_pool.get_result()
if result is not None:
self.recv_idxs += 1
self.num_pending -= 1
return result
else:
next_data = self._next_data()
if next_data is None:
continue
assert next_data is not None
self.local_pool.sample_next(self.graph_name,self.sampler_info,next_data)
self.submitted = self.submitted + 1
self.num_pending = self.num_pending + 1
while(self.recv_idxs < self.expected_idx):
result = self.local_pool.get_result()
if result is not None:
self.recv_idxs += 1
self.num_pending -= 1
return result
assert self.num_pending == 0
raise StopIteration
def _get_data_from_queue(self,timeout=1800):
ret = self.pool.get_result(self.graph_name,timeout = timeout)
return ret
def _sample_next_batch(self):
next_data = self._next_data()
if next_data is None:
return None
batch_data = _sample_node_neighbors_single(self.graph,self.graph_name,self.sampler,next_data)
return batch_data
def _submit_request_next(self):
next_data = self._next_data()
if next_data is None:
return
self.pool.submit_task(self.graph_name,next_data)
self.num_pending += 1
def _next_data(self):
if self.current_pos >= len(self.data):
return None
next_idx = None
if self.current_pos + self.batch_size > len(self.data):
if self.drop_last:
return None
else:
next_idx = self.input_data[self.current_pos:]
self.current_pos = 0
else:
next_idx = self.input_data[self.current_pos:self.current_pos + self.batch_size]
self.current_pos += len(next_idx)
return next_idx
def __iter__(self):
if self.shuffle:
indx = torch.randperm(self.data.size(0))
self.input_data = self.data[indx]
else:
self.input_data = self.data
#expected_len = self.expected_idx * self.batch_size
#if(not self.drop_last):
# k = expected_len // len(self.input_data)
# if(expected_len % len(self.input_data) !=0): k = k+1
# if(self.data.shape.numel() == 2):
# self.input_data = self.input_data.repeat(k,1)
# else:
# self.input_data = self.input_data.repeat(k)
# self.input_data = self.input_data[:expected_len]
self.recv_idxs = 0
self.current_pos = 0
self.num_pending = 0
self.submitted = 0
return self
def __del__(self):
if self.pool is not None:
self.pool.delete_collate_fn(self.graph_name)
def set_epoch(self,epoch):
self.epoch = epoch
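# Usage sketch for DistributedDataLoader (hypothetical names; assumes
# init_distribution has been called and a partitioned graph is on disk):
#   graph = DistGraphData(path="./part/metis_2")
#   sampler = NeighborSampler(graph.num_nodes, num_layers=2, fanout=[10, 10],
#                             workers=4, edge_index=graph.edge_index)
#   loader = DistributedDataLoader("train_loader", graph, sampler=sampler,
#                                  batch_size=512, shuffle=True,
#                                  data_index_mask=graph.data.train_mask)
#   for batch in loader:
#       ...  # train on the returned BatchData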
init distribution
! 2 0
get_neighbors consume: 0.00205912s
neighbor_sample_from_nodes consume: 1.9188e-05s
unique consume: 4.72449e-05s
neighbor_sample_from_nodes consume: 1.7761e-05s
unique consume: 3.89579e-05s
neighbor_sample_from_nodes consume: 1.6575e-05s
unique consume: 3.77171e-05s
neighbor_sample_from_nodes consume: 2.5596e-05s
unique consume: 2.6808e-05s
neighbor_sample_from_nodes consume: 2.5437e-05s
unique consume: 2.45729e-05s
neighbor_sample_from_nodes consume: 2.556e-05s
unique consume: 2.5559e-05s
neighbor_sample_from_nodes consume: 3.456e-05s
unique consume: 6.08771e-05s
neighbor_sample_from_nodes consume: 1.70741e-05s
unique consume: 2.65629e-05s
neighbor_sample_from_nodes consume: 1.72061e-05s
unique consume: 2.5715e-05s
neighbor_sample_from_nodes consume: 1.72009e-05s
unique consume: 2.7349e-05s
neighbor_sample_from_nodes consume: 2.53409e-05s
unique consume: 5.2902e-05s
neighbor_sample_from_nodes consume: 5.4088e-05s
unique consume: 4.0817e-05s
neighbor_sample_from_nodes consume: 2.7567e-05s
unique consume: 2.47849e-05s
neighbor_sample_from_nodes consume: 2.7924e-05s
unique consume: 4.4742e-05s
neighbor_sample_from_nodes consume: 1.6885e-05s
unique consume: 2.8905e-05s
neighbor_sample_from_nodes consume: 1.729e-05s
unique consume: 2.45421e-05s
neighbor_sample_from_nodes consume: 1.73061e-05s
unique consume: 2.5179e-05s
neighbor_sample_from_nodes consume: 2.29189e-05s
unique consume: 2.37221e-05s
neighbor_sample_from_nodes consume: 2.54969e-05s
unique consume: 2.2979e-05s
neighbor_sample_from_nodes consume: 2.7542e-05s
unique consume: 2.4079e-05s
neighbor_sample_from_nodes consume: 2.6393e-05s
unique consume: 3.1711e-05s
neighbor_sample_from_nodes consume: 1.73319e-05s
unique consume: 2.51541e-05s
neighbor_sample_from_nodes consume: 1.6988e-05s
unique consume: 2.5345e-05s
neighbor_sample_from_nodes consume: 1.7336e-05s
unique consume: 2.58629e-05s
neighbor_sample_from_nodes consume: 3.4753e-05s
unique consume: 2.4382e-05s
neighbor_sample_from_nodes consume: 2.5076e-05s
unique consume: 2.3602e-05s
neighbor_sample_from_nodes consume: 2.5392e-05s
unique consume: 2.64619e-05s
neighbor_sample_from_nodes consume: 2.7744e-05s
unique consume: 2.5377e-05s
neighbor_sample_from_nodes consume: 1.7647e-05s
unique consume: 2.6718e-05s
neighbor_sample_from_nodes consume: 2.8096e-05s
unique consume: 2.5603e-05s
neighbor_sample_from_nodes consume: 8.59704e-06s
unique consume: 2.3596e-05s
neighbor_sample_from_nodes consume: 1.04799e-05s
unique consume: 2.356e-05s
load graph /home/sxx/zlj/rpc_ps/part/metis_2/rank_0
<bound method BatchData.__repr__ of BatchData(batch_size = 5,roots = <bound method Tensor.__repr__ of tensor([1274, 1357, 1063, 1340, 1317])> , nides = <bound method Tensor.__repr__ of tensor([1274, 1357, 1063, 1340, 1317, 1202, 454, 705, 48, 2113, 559, 69,
1152, 912, 1241, 131, 1278, 44, 249, 2430, 1200])> , edge_index = <method-wrapper '__repr__' of list object at 0x7fe8006d6640> , x= <bound method Tensor.__repr__ of tensor([[0., 0., 0., ..., 0., 1., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 1., 0.]])>, y =<bound method Tensor.__repr__ of tensor([0, 0, 1, 1, 5])>)>
<bound method BatchData.__repr__ of BatchData(batch_size = 5,roots = <bound method Tensor.__repr__ of tensor([1282, 999, 1062, 1300, 1256])> , nides = <bound method Tensor.__repr__ of tensor([1282, 999, 1062, 1300, 1256, 195, 1095, 700, 456, 528, 935, 589,
1060, 706, 561, 933, 813, 352])> , edge_index = <method-wrapper '__repr__' of list object at 0x7fe8006d6c00> , x= <bound method Tensor.__repr__ of tensor([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 1., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]])>, y =<bound method Tensor.__repr__ of tensor([0, 6, 1, 6, 4])>)>
<bound method BatchData.__repr__ of BatchData(batch_size = 5,roots = <bound method Tensor.__repr__ of tensor([1041, 998, 1004, 1368, 1267])> , nides = <bound method Tensor.__repr__ of tensor([1041, 998, 1004, 1368, 1267, 886, 193, 593, 1022, 843, 658, 1102,
1094, 842, 1013, 347, 274, 862, 9, 897, 36])> , edge_index = <method-wrapper '__repr__' of list object at 0x7fe8006d6e40> , x= <bound method Tensor.__repr__ of tensor([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 1.]])>, y =<bound method Tensor.__repr__ of tensor([5, 0, 0, 0, 3])>)>
<bound method BatchData.__repr__ of BatchData(batch_size = 5,roots = <bound method Tensor.__repr__ of tensor([1374, 1316, 1381, 1000, 1061])> , nides = <bound method Tensor.__repr__ of tensor([1374, 1316, 1381, 1000, 1061, 207, 933, 562, 398, 1020, 1256, 253,
563, 387, 564, 458, 248, 1054, 1352, 723])> , edge_index = <method-wrapper '__repr__' of list object at 0x7fe8006e3440> , x= <bound method Tensor.__repr__ of tensor([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]])>, y =<bound method Tensor.__repr__ of tensor([4, 3, 3, 5, 1])>)>
<bound method BatchData.__repr__ of BatchData(batch_size = 5,roots = <bound method Tensor.__repr__ of tensor([1308, 1325, 1349, 1356, 1377])> , nides = <bound method Tensor.__repr__ of tensor([1308, 1325, 1349, 1356, 1377, 670, 1287, 372, 1032, 829, 377, 146,
391, 1099, 1173, 18, 32, 994, 404, 671, 875, 1029])> , edge_index = <method-wrapper '__repr__' of list object at 0x7fe8006e3100> , x= <bound method Tensor.__repr__ of tensor([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]])>, y =<bound method Tensor.__repr__ of tensor([6, 6, 0, 5, 3])>)>
<bound method BatchData.__repr__ of BatchData(batch_size = 5,roots = <bound method Tensor.__repr__ of tensor([ 386, 1045, 1336, 1341, 1277])> , nides = <bound method Tensor.__repr__ of tensor([ 386, 1045, 1336, 1341, 1277, 12, 1217, 301, 1192, 691, 1105, 620,
754, 919, 649, 1284, 1343, 824, 794, 535, 636, 1092])> , edge_index = <method-wrapper '__repr__' of list object at 0x7fe8006e38c0> , x= <bound method Tensor.__repr__ of tensor([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]])>, y =<bound method Tensor.__repr__ of tensor([3, 5, 4, 0, 0])>)>
<bound method BatchData.__repr__ of BatchData(batch_size = 5,roots = <bound method Tensor.__repr__ of tensor([1076, 1009, 1303, 1073, 1334])> , nides = <bound method Tensor.__repr__ of tensor([1076, 1009, 1303, 1073, 1334, 1225, 867, 547, 1217, 608, 827, 1179,
1147, 739, 333, 111, 607, 636])> , edge_index = <method-wrapper '__repr__' of list object at 0x7fe8006d6e80> , x= <bound method Tensor.__repr__ of tensor([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]])>, y =<bound method Tensor.__repr__ of tensor([6, 1, 5, 5, 3])>)>
<bound method BatchData.__repr__ of BatchData(batch_size = 5,roots = <bound method Tensor.__repr__ of tensor([1014, 1029, 1059, 1284, 1294])> , nides = <bound method Tensor.__repr__ of tensor([1014, 1029, 1059, 1284, 1294, 479, 472, 207, 910, 984, 1336, 1304,
600, 197, 12, 723, 1133, 977, 829])> , edge_index = <method-wrapper '__repr__' of list object at 0x7fe8006e5180> , x= <bound method Tensor.__repr__ of tensor([[0., 0., 0., ..., 0., 0., 1.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 1., ..., 0., 0., 0.]])>, y =<bound method Tensor.__repr__ of tensor([1, 6, 1, 0, 2])>)>
<bound method BatchData.__repr__ of BatchData(batch_size = 5,roots = <bound method Tensor.__repr__ of tensor([1032, 1276, 1071, 1309, 1085])> , nides = <bound method Tensor.__repr__ of tensor([1032, 1276, 1071, 1309, 1085, 195, 813, 868, 852, 1303, 1240, 352,
880, 1339, 837, 1068, 1179, 576, 1114, 656, 789])> , edge_index = <method-wrapper '__repr__' of list object at 0x7fe8006e5b80> , x= <bound method Tensor.__repr__ of tensor([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]])>, y =<bound method Tensor.__repr__ of tensor([0, 1, 5, 0, 1])>)>
<bound method BatchData.__repr__ of BatchData(batch_size = 5,roots = <bound method Tensor.__repr__ of tensor([1375, 1023, 1353, 1306, 984])> , nides = <bound method Tensor.__repr__ of tensor([1375, 1023, 1353, 1306, 984, 1189, 2486, 1273, 840, 1161, 621, 841,
1187, 956, 241, 1159, 197, 647, 1056, 973, 2490])> , edge_index = <method-wrapper '__repr__' of list object at 0x7fe8006e3ac0> , x= <bound method Tensor.__repr__ of tensor([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 1.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 1., 0., ..., 0., 0., 0.]])>, y =<bound method Tensor.__repr__ of tensor([3, 0, 3, 3, 1])>)>
<bound method BatchData.__repr__ of BatchData(batch_size = 5,roots = <bound method Tensor.__repr__ of tensor([1078, 1046, 1362, 1069, 1038])> , nides = <bound method Tensor.__repr__ of tensor([1078, 1046, 1362, 1069, 1038, 942, 968, 166, 196, 48, 1243, 126,
1199, 883, 1213, 852, 23, 801, 882, 1240, 289, 380])> , edge_index = <method-wrapper '__repr__' of list object at 0x7fe8006e7140> , x= <bound method Tensor.__repr__ of tensor([[0., 0., 0., ..., 0., 0., 0.],
[0., 1., 0., ..., 0., 0., 0.],
[1., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]])>, y =<bound method Tensor.__repr__ of tensor([5, 1, 1, 1, 6])>)>
<bound method BatchData.__repr__ of BatchData(batch_size = 5,roots = <bound method Tensor.__repr__ of tensor([1291, 1055, 1314, 1001, 1065])> , nides = <bound method Tensor.__repr__ of tensor([1291, 1055, 1314, 1001, 1065, 1143, 1107, 540, 727, 276, 1178, 173,
1568, 17, 786, 539, 869, 640])> , edge_index = <method-wrapper '__repr__' of list object at 0x7fe8006e7080> , x= <bound method Tensor.__repr__ of tensor([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]])>, y =<bound method Tensor.__repr__ of tensor([5, 5, 3, 0, 5])>)>
<bound method BatchData.__repr__ of BatchData(batch_size = 5,roots = <bound method Tensor.__repr__ of tensor([1058, 1299, 1050, 1351, 1068])> , nides = <bound method Tensor.__repr__ of tensor([1058, 1299, 1050, 1351, 1068, 1205, 435, 968, 770, 1103, 9, 123,
1096, 533, 1204, 852, 1143, 1102, 769, 358, 165, 540])> , edge_index = <method-wrapper '__repr__' of list object at 0x7fe8006e7cc0> , x= <bound method Tensor.__repr__ of tensor([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]])>, y =<bound method Tensor.__repr__ of tensor([1, 5, 5, 0, 1])>)>
<bound method BatchData.__repr__ of BatchData(batch_size = 5,roots = <bound method Tensor.__repr__ of tensor([1346, 1369, 1072, 1307, 1359])> , nides = <bound method Tensor.__repr__ of tensor([1346, 1369, 1072, 1307, 1359, 686, 997, 810, 1032, 216, 1252, 1287,
452, 822, 811, 679, 868, 980, 587, 466, 59, 1260])> , edge_index = <method-wrapper '__repr__' of list object at 0x7fe8006e7240> , x= <bound method Tensor.__repr__ of tensor([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]])>, y =<bound method Tensor.__repr__ of tensor([5, 0, 5, 5, 0])>)>
<bound method BatchData.__repr__ of BatchData(batch_size = 5,roots = <bound method Tensor.__repr__ of tensor([1079, 1342, 1060, 1373, 1064])> , nides = <bound method Tensor.__repr__ of tensor([1079, 1342, 1060, 1373, 1064, 557, 943, 837, 847, 772, 549, 813,
922, 1016, 830, 421, 236, 838, 24, 986])> , edge_index = <method-wrapper '__repr__' of list object at 0x7fe8006e7b00> , x= <bound method Tensor.__repr__ of tensor([[0., 0., 1., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]])>, y =<bound method Tensor.__repr__ of tensor([1, 6, 1, 3, 6])>)>
<bound method BatchData.__repr__ of BatchData(batch_size = 2,roots = <bound method Tensor.__repr__ of tensor([1057, 1077])> , nides = <bound method Tensor.__repr__ of tensor([1057, 1077, 690, 1111, 781, 685])> , edge_index = <method-wrapper '__repr__' of list object at 0x7fe8006e80c0> , x= <bound method Tensor.__repr__ of tensor([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]])>, y =<bound method Tensor.__repr__ of tensor([5, 5])>)>
{'num_worker_threads': 16, 'rpc_name': 'rpcserver{}', 'rpc_world_size': 10, 'worker_rank': 0, 'rpc_worker_rank': 0}
127.0.0.1 10001
import torch
from torch import Tensor
from enum import Enum
import math
from abc import ABC
from typing import Optional, Tuple, Union
class NegativeSamplingMode(Enum):
# 'binary': Randomly sample negative edges in the graph.
binary = 'binary'
# 'triplet': Randomly sample negative destination nodes for each positive
# source node.
triplet = 'triplet'
class NegativeSampling:
r"""The negative sampling configuration of a
:class:`~torch_geometric.sampler.BaseSampler` when calling
:meth:`~torch_geometric.sampler.BaseSampler.sample_from_edges`.
Args:
mode (str): The negative sampling mode
(:obj:`"binary"` or :obj:`"triplet"`).
If set to :obj:`"binary"`, will randomly sample negative links
from the graph.
If set to :obj:`"triplet"`, will randomly sample negative
destination nodes for each positive source node.
amount (int or float, optional): The ratio of sampled negative edges to
the number of positive edges. (default: :obj:`1`)
weight (torch.Tensor, optional): A node-level vector determining the
sampling of nodes. Does not necessarily need to sum up to one.
If not given, negative nodes will be sampled uniformly.
(default: :obj:`None`)
"""
mode: NegativeSamplingMode
amount: Union[int, float] = 1
weight: Optional[Tensor] = None
def __init__(
self,
mode: Union[NegativeSamplingMode, str],
amount: Union[int, float] = 1,
weight: Optional[Tensor] = None,
):
self.mode = NegativeSamplingMode(mode)
self.amount = amount
self.weight = weight
if self.amount <= 0:
raise ValueError(f"The attribute 'amount' needs to be positive "
f"for '{self.__class__.__name__}' "
f"(got {self.amount})")
if self.is_triplet():
if self.amount != math.ceil(self.amount):
raise ValueError(f"The attribute 'amount' needs to be an "
f"integer for '{self.__class__.__name__}' "
f"with 'triplet' negative sampling "
f"(got {self.amount}).")
self.amount = math.ceil(self.amount)
def is_binary(self) -> bool:
return self.mode == NegativeSamplingMode.binary
def is_triplet(self) -> bool:
return self.mode == NegativeSamplingMode.triplet
def sample(self, num_samples: int,
num_nodes: Optional[int] = None) -> Tensor:
r"""Generates :obj:`num_samples` negative samples."""
if self.weight is None:
if num_nodes is None:
raise ValueError(
f"Cannot sample negatives in '{self.__class__.__name__}' "
f"without passing the 'num_nodes' argument")
return torch.randint(num_nodes, (num_samples, ))
if num_nodes is not None and self.weight.numel() != num_nodes:
raise ValueError(
f"The 'weight' attribute in '{self.__class__.__name__}' "
f"needs to match the number of nodes {num_nodes} "
f"(got {self.weight.numel()})")
return torch.multinomial(self.weight, num_samples, replacement=True)
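# Example sketch: weighted sampling of negative nodes over a 6-node graph
# (weights are hypothetical and need not sum to one).
#   neg = NegativeSampling('binary', amount=1, weight=torch.tensor([0.3, 0.1, 0.1, 0.1, 0.3, 0.1]))
#   neg.sample(4, num_nodes=6)  # tensor of 4 node ids drawn according to `weight`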
class BaseSampler(ABC):
r"""An abstract base class that initializes a graph sampler and provides
:meth:`_sample_one_layer_from_nodes`
:meth:`_sample_one_layer_from_nodes_parallel`
:meth:`sample_from_nodes` routines.
"""
def sample_from_nodes(
self,
nodes: torch.Tensor,
**kwargs
) -> Tuple[torch.Tensor, list]:
r"""Performs mutilayer sampling from the nodes specified in: nodes
The specific number of layers is determined by parameter: num_layers
returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, list].
Args:
nodes: the list of seed nodes index
**kwargs: other kwargs
Returns:
sampled_nodes: the nodes sampled
sampled_edge_index_list: the edges sampled
"""
raise NotImplementedError
def sample_from_edges(
self,
edges: torch.Tensor,
edge_label: Optional[torch.Tensor] = None,
neg_sampling: Optional[NegativeSampling] = None
) -> Tuple[torch.Tensor, list]:
r"""Performs sampling from the edges specified in :obj:`index`,
returning a sampled subgraph in the specified output format.
Args:
edges: the list of seed edges index
edge_label: the label for the seed edges.
neg_sampling: The negative sampling configuration
Returns:
sampled_nodes: the nodes sampled
sampled_edge_index_list: the edges sampled
metadata: other information
"""
raise NotImplementedError
# def _sample_one_layer_from_nodes(
# self,
# nodes:torch.Tensor,
# **kwargs
# ) -> Tuple[torch.Tensor, torch.Tensor]:
# r"""Performs sampling from the nodes specified in: nodes,
# returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, torch.Tensor].
# Args:
# nodes: the list of seed nodes index
# **kwargs: other kwargs
# Returns:
# sampled_nodes: the nodes sampled
# sampled_edge_index: the edges sampled
# """
# raise NotImplementedError
# def _sample_one_layer_from_nodes_parallel(
# self,
# nodes: torch.Tensor,
# **kwargs
# ) -> Tuple[torch.Tensor, torch.Tensor]:
# r"""Performs sampling paralleled from the nodes specified in: nodes,
# returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, torch.Tensor].
# Args:
# nodes: the list of seed nodes index
# **kwargs: other kwargs
# Returns:
# sampled_nodes: the nodes sampled
# sampled_edge_index: the edges sampled
# """
# raise NotImplementedError
import torch
# import sys
# from os.path import abspath, dirname
# sys.path.insert(0, abspath(dirname(__file__)))
# print(sys.path)
from neighbor_sampler import NeighborSampler
# edge_index = torch.tensor([[0, 1, 1, 2, 2, 2, 3], [1, 0, 2, 1, 3, 0, 2]])
# num_nodes = 4
edge_index = torch.tensor([[0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 5], [1, 0, 2, 4, 1, 3, 0, 2, 5, 3, 5, 0, 2]])
num_nodes = 6
num_neighbors = 2
# Run the neighbor sampling
sampler=NeighborSampler(edge_index=edge_index, num_nodes=num_nodes, num_layers=2, workers=2, fanout=[2, 1])
# neighbor_nodes, sampled_edge_index = sampler._sample_one_layer_from_nodes(nodes=torch.tensor([1,3]), fanout=num_neighbors)
neighbor_nodes, sampled_edge_index = sampler.sample_from_nodes(torch.tensor([1,2,3]))
# Print the result
print('neighbor_nodes_id: \n',neighbor_nodes, '\nedge_index: \n',sampled_edge_index)
# import torch_scatter
# nodes=torch.Tensor([1,2])
# row, col = edge_index
# deg = torch_scatter.scatter_add(torch.ones_like(row), row, dim=0, dim_size=num_nodes)
# neighbors1=torch.concat([row[row==nodes[i]] for i in range(0, nodes.shape[0])])
# print(neighbors1)
# neighbors2=torch.concat([col[row==nodes[i]] for i in range(0, nodes.shape[0])])
# print(neighbors2)
# neighbors=torch.stack([neighbors1, neighbors2], dim=0)
# print('neighbors: \n', neighbors[0]==1)
from base import NegativeSampling
edge_index = torch.tensor([[0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 5], [1, 0, 2, 4, 1, 3, 0, 2, 5, 3, 5, 0, 2]])
num_nodes = 6
# sampler
sampler=NeighborSampler(edge_index=edge_index, num_nodes=num_nodes, num_layers=2, workers=2, fanout=[2, 1])
# negative
weight = torch.tensor([0.3,0.1,0.1,0.1,0.3,0.1])
negative = NegativeSampling('binary', 2, weight)
# negative = NegativeSampling('triplet', 2, weight)
label=torch.tensor([1,2])
seed_edges = torch.tensor([[0,1],
[1,4]])
# result = sampler.sample_from_edges(edges=seed_edges)
result = sampler.sample_from_edges(edges=seed_edges, edge_label=label, neg_sampling=negative)
# Print the result
print('neighbor_nodes_id: \n',result[0], '\nedge_index: \n',result[1], '\nmetadata: \n',result[2])
import sys
from os.path import abspath, join, dirname
sys.path.insert(0, join(abspath(dirname(__file__))))
import math
import torch
import torch.multiprocessing as mp
from typing import Optional, Tuple
from base import BaseSampler, NegativeSampling
from sample_cores import get_neighbors, neighbor_sample_from_nodes, heads_unique, TemporalNeighborBlock
class NeighborSampler(BaseSampler):
def __init__(
self,
num_nodes: int,
num_layers: int,
fanout: list,
workers = 1,
edge_index : Optional[torch.Tensor] = None,
tnb = None,
) -> None:
r"""__init__
Args:
num_nodes: the num of all nodes in the graph
num_layers: the num of layers to be sampled
fanout: the list of max neighbors' number chosen for each layer
workers: the number of threads, default value is 1
edge_index: all edges in the graph
neighbors: all nodes' neighbors
deg: the degree of all nodes
should provide edge_index or (neighbors, deg)
"""
super().__init__()
self.num_layers = num_layers
# the number of threads must not exceed torch's default number of OMP threads
self.workers = min(workers, torch.get_num_threads())
self.fanout = fanout
self.num_nodes = num_nodes
if(edge_index is not None):
row, col = edge_index
self.tnb = get_neighbors(row.contiguous(), col.contiguous(), num_nodes)
else:
assert tnb is not None
self.tnb = tnb
def _get_sample_info(self):
return self.num_nodes,self.num_layers,self.fanout,self.workers
def _get_neighbors_and_deg(self):
return self.neighbors,self.deg
def _set_neighbors_and_deg(self,neighbors,deg):
self.neighbors = neighbors
self.deg = deg
def sample_from_nodes(
self,
nodes: torch.Tensor
) -> Tuple[torch.Tensor, list]:
r"""Performs mutilayer sampling from the nodes specified in: nodes
The specific number of layers is determined by parameter: num_layers
returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, list].
Args:
nodes: the list of seed nodes index
Returns:
sampled_nodes: the node sampled
sampled_edge_index_list: the edge sampled
"""
sampled_edge_index_list = []
sampled_nodes = torch.LongTensor([])
src_nodes = nodes
assert self.workers > 0, 'Workers should be positive integer!!!'
for i in range(0, self.num_layers):
sampled_nodes_i, sampled_edge_index_i = self._sample_one_layer_from_nodes(nodes, self.fanout[i])
sampled_nodes = torch.cat([sampled_nodes, sampled_nodes_i])
nodes = sampled_nodes_i
sampled_edge_index_list.append(sampled_edge_index_i)
sampled_nodes = heads_unique(sampled_nodes, src_nodes, self.workers)
return sampled_nodes, sampled_edge_index_list
def sample_from_edges(
self,
edges: torch.Tensor,
edge_label: Optional[torch.Tensor] = None,
neg_sampling: Optional[NegativeSampling] = None
) -> Tuple[torch.Tensor, list]:
r"""Performs sampling from the edges specified in :obj:`index`,
returning a sampled subgraph in the specified output format.
Args:
edges: the list of seed edges index
edge_label: the label for the seed edges.
neg_sampling: The negative sampling configuration
Returns:
sampled_nodes: the nodes sampled
sampled_edge_index_list: the edges sampled
metadata: other information
"""
src, dst = edges
num_pos = src.numel()
num_neg = 0
if edge_label is None:
edge_label = torch.ones(num_pos)
if neg_sampling is not None:
num_neg = math.ceil(num_pos * neg_sampling.amount)
if neg_sampling.is_binary():
src_neg = neg_sampling.sample(num_neg, self.num_nodes)
src = torch.cat([src, src_neg], dim=0)
dst_neg = neg_sampling.sample(num_neg, self.num_nodes)
dst = torch.cat([dst, dst_neg], dim=0)
if neg_sampling.is_triplet():
dst_neg = neg_sampling.sample(num_neg, self.num_nodes)
dst = torch.cat([dst, dst_neg], dim=0)
seed = torch.cat([src, dst], dim=0)
seed, inverse_seed = seed.unique(return_inverse=True)
sampled_nodes, sampled_edge_index_list = self.sample_from_nodes(seed)
if neg_sampling is None or neg_sampling.is_binary():
edge_label_index = inverse_seed.view(2, -1)
# edge_label_index: the seed edges re-indexed within the sampled nodes; edge_label: labels of the seed links
metadata = {'edge_label_index':edge_label_index, 'edge_label':edge_label}
elif neg_sampling.is_triplet():
src_index = inverse_seed[:num_pos]
dst_pos_index = inverse_seed[num_pos:2 * num_pos]
dst_neg_index = inverse_seed[2 * num_pos:]
dst_neg_index = dst_neg_index.view(num_pos, -1).squeeze(-1)
# src_index: indices of the src nodes within seed
# dst_pos_index: indices of the positive dst nodes within seed
# dst_neg_index: indices of the negative dst nodes within seed
metadata = {'src_index':src_index, 'dst_pos_index':dst_pos_index, 'dst_neg_index':dst_neg_index}
# the front of sampled_nodes holds the deduplicated seed, i.e. the original sampling start points
return sampled_nodes, sampled_edge_index_list, metadata
def _sample_one_layer_from_nodes(
self,
nodes: torch.Tensor,
fanout: int
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Performs sampling from the nodes specified in: nodes,
returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, torch.Tensor].
Args:
nodes: the list of seed nodes index
fanout: the number of max neighbors chosen
Returns:
sampled_nodes: the nodes sampled
sampled_edge_index: the edges sampled
"""
import time
pre = time.time()
row, col, sampled_nodes = neighbor_sample_from_nodes(nodes.contiguous(), self.tnb, fanout, self.workers)
end = time.time()
# print("py sample one layer:", (end-pre))
return sampled_nodes, torch.stack([row,col], dim=0)
# def _sample_one_layer_from_nodes_parallel(
# self,
# nodes: torch.Tensor,
# fanout: int
# ) -> Tuple[torch.Tensor, torch.Tensor]:
# r"""Performs sampling from the nodes specified in: nodes,
# returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, torch.Tensor].
# Args:
# nodes: the list of seed nodes index
# fanout: the number of max neighbor chosen
# Returns:
# sampled_nodes: the node sampled
# sampled_edge_index: the edge sampled
# """
# sampled_nodes=torch.LongTensor([])
# row=torch.LongTensor([])
# col=torch.LongTensor([])
# assert self.workers > 0, 'Workers should be positive integer!!!'
# with mp.Pool(processes=self.workers) as p:
# n=len(nodes)
# if(self.workers>=n):
# results = [p.apply_async(self._sample_one_layer_from_nodes,
# (torch.tensor([node.item()]), fanout))
# for node in nodes]
# else:
# quotient = n//self.workers
# remainder = n%self.workers
# each batch is first assigned quotient nodes, then the remainder is spread evenly over some of the batches
# nodes1 = nodes[0:(quotient+1)*(remainder)].resize_(remainder,quotient+1)# batches that received one extra node
# nodes2 = nodes[(quotient+1)*(remainder):n].resize_(self.workers - remainder,quotient)# batches without an extra node
# results = [p.apply_async(self._sample_one_layer_from_nodes,
# (nodes1[i], fanout))
# for i in range(0, remainder)]
# results.extend([p.apply_async(self._sample_one_layer_from_nodes,
# (nodes2[i], fanout))
# for i in range(0, self.workers - remainder)])
# for result in results:
# sampled_nodes_i, sampled_edge_index_i = result.get()
# sampled_nodes = torch.unique(torch.cat([sampled_nodes, sampled_nodes_i]))
# row = torch.cat([row, sampled_edge_index_i[0]])
# col = torch.cat([col, sampled_edge_index_i[1]])
# return sampled_nodes, torch.stack([row, col], dim=0)
# Alternative that gathers all neighbors directly, without _sample_one_layer_from_nodes:
# row, col = edge_index
# neighbors1=torch.concat([row[row==nodes[i]] for i in range(0, nodes.shape[0])])
# neighbors2=torch.concat([col[row==nodes[i]] for i in range(0, nodes.shape[0])])
# neighbors=torch.stack([neighbors1, neighbors2], dim=0)
# print('neighbors: \n', neighbors)
if __name__=="__main__":
edge_index1 = torch.tensor([[0, 1, 1, 1, 2, 2, 2, 4, 4, 4, 5], # , 3, 3
[1, 0, 2, 4, 1, 3, 0, 3, 5, 0, 2]])# , 2, 5
num_nodes1 = 6
num_neighbors = 2
# Run the neighbor sampling
sampler=NeighborSampler(edge_index=edge_index1, num_nodes=num_nodes1, num_layers=3, workers=4, fanout=[2, 1, 1])
# neighbor_nodes, sampled_edge_index = sampler._sample_one_layer_from_nodes(nodes=torch.tensor([2,1]), fanout=num_neighbors)
neighbor_nodes, sampled_edge_index = sampler.sample_from_nodes(torch.tensor([1,2]))
# Print the result
print('neighbor_nodes_id: \n',neighbor_nodes, '\nedge_index: \n',sampled_edge_index)
import torch
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric import datasets
import time
def load_ogb_dataset(name, data_path):
dataset = PygNodePropPredDataset(name=name, root=data_path)
split_idx = dataset.get_idx_split()
g = dataset[0]
n_node = g.num_nodes
node_data={}
node_data['train_mask'] = torch.zeros(n_node, dtype=torch.bool)
node_data['val_mask'] = torch.zeros(n_node, dtype=torch.bool)
node_data['test_mask'] = torch.zeros(n_node, dtype=torch.bool)
node_data['train_mask'][split_idx["train"]] = True
node_data['val_mask'][split_idx["valid"]] = True
node_data['test_mask'][split_idx["test"]] = True
return g, node_data
g, node_data = load_ogb_dataset('ogbn-products', "/home/hzq/code/gnn/test/NewSample/dataset")
print(g)
from neighbor_sampler import NeighborSampler
pre = time.time()
from neighbor_sampler import get_neighbors
row, col = g.edge_index
tnb = get_neighbors(row.contiguous(), col.contiguous(), g.num_nodes)
sampler = NeighborSampler(g.num_nodes, num_layers=2, fanout=[100,100], workers=4, tnb=tnb)
end = time.time()
print("init time:", end-pre)
# from torch_geometric.sampler import NeighborSampler, NumNeighbors, NodeSamplerInput, SamplerOutput
# pre = time.time()
# num_nei = NumNeighbors([100, 100])
# node_idx = NodeSamplerInput(input_id=None, node=torch.masked_select(torch.arange(g.num_nodes),node_data['train_mask']))
# sampler = NeighborSampler(g, num_nei)
# end = time.time()
# print("init time:", end-pre)
pre = time.time()
node, edge = sampler.sample_from_nodes(torch.masked_select(torch.arange(g.num_nodes),node_data['train_mask']))
# out = sampler.sample_from_nodes(node_idx)
# node = out.node
# edge = [out.row, out.col]
end = time.time()
print(node)
print(edge)
print("sample time", end-pre)
import torch
import torch.multiprocessing as mp
from typing import Optional, Tuple
from .base import BaseSampler, NegativeSampling
from .neighbor_sampler import NeighborSampler
class RandomWalkSampler(BaseSampler):
def __init__(
self,
num_nodes: int,
num_layers: int,
workers = 1,
edge_index : Optional[torch.Tensor] = None,
deg = None,
neighbors = None
) -> None:
r"""__init__
Args:
num_nodes: the num of all nodes in the graph
num_layers: the num of layers to be sampled
fanout: the list of max neighbors' number chosen for each layer
workers: the number of threads, default value is 1
edge_index: all edges in the graph
neighbors: all nodes' neighbors
deg: the degree of all nodes
"""
super().__init__()
if(edge_index is not None):
self.sampler = NeighborSampler(num_nodes, num_layers, [1 for _ in range(num_layers)],
workers, edge_index)
elif(neighbors is not None and deg is not None):
self.sampler = NeighborSampler(num_nodes, num_layers, [1 for _ in range(num_layers)],
workers, neighbors, deg)
else:
raise Exception("Not enough parameters")
self.num_layers = num_layers
# the number of threads must not exceed torch's default number of OMP threads
self.workers = min(workers, torch.get_num_threads())
def sample_from_nodes(
self,
nodes: torch.Tensor
) -> Tuple[torch.Tensor, list]:
r"""Performs mutilayer sampling from the nodes specified in: nodes
The specific number of layers is determined by parameter: num_layers
returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, list].
Args:
nodes: the list of seed nodes index
Returns:
sampled_nodes: the node sampled
sampled_edge_index: the edge sampled
"""
return self.sampler.sample_from_nodes(nodes)
def sample_from_edges(
self,
edges: torch.Tensor,
edge_label: Optional[torch.Tensor] = None,
neg_sampling: Optional[NegativeSampling] = None
) -> Tuple[torch.Tensor, list]:
r"""Performs sampling from the edges specified in :obj:`index`,
returning a sampled subgraph in the specified output format.
Args:
edges: the list of seed edges index
edge_label: the label for the seed edges.
neg_sampling: The negative sampling configuration
Returns:
sampled_nodes: the nodes sampled
sampled_edge_index_list: the edges sampled
"""
return self.sampler.sample_from_edges(edges, edge_label, neg_sampling)
if __name__=="__main__":
edge_index1 = torch.tensor([[0, 1, 1, 1, 2, 2, 2, 4, 4, 4, 5], # , 3, 3
[1, 0, 2, 4, 1, 3, 0, 3, 5, 0, 2]])# , 2, 5
num_nodes1 = 6
# Run the random walk sampling
sampler=RandomWalkSampler(edge_index=edge_index1, num_nodes=num_nodes1, num_layers=3, workers=4)
sampled_nodes, sampled_edge_index = sampler.sample_from_nodes(torch.tensor([1,2]))
# Print the result
print('sampled_nodes_id: \n',sampled_nodes, '\nedge_index: \n',sampled_edge_index)
#include <iostream>
#include <omp.h>
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#include <torch/extension.h>
#include <time.h>
#include <random>
#include <parallel_hashmap/phmap.h>
using namespace std;
namespace py = pybind11;
namespace th = torch;
typedef int64_t NodeIDType;
// typedef int64_t EdgeIDType;
// typedef float TimeStampType;
class TemporalNeighborBlock;
class TemporalGraphBlock;
TemporalNeighborBlock get_neighbors(th::Tensor row, th::Tensor col, int64_t num_nodes);
TemporalGraphBlock* neighbor_sample_from_node(NodeIDType node, vector<NodeIDType> neighbors, int64_t deg, int64_t fanout, int threads);
vector<th::Tensor> neighbor_sample_from_nodes(th::Tensor nodes, TemporalNeighborBlock& tnb, int64_t fanout, int threads);
th::Tensor heads_unique(th::Tensor array, th::Tensor heads, int threads);
template<typename T>
inline py::array vec2npy(const std::vector<T> &vec)
{
// need to let python garbage collector handle C++ vector memory
// see https://github.com/pybind/pybind11/issues/1042
// non-copy value transfer
auto v = new std::vector<T>(vec);
auto capsule = py::capsule(v, [](void *v)
{ delete reinterpret_cast<std::vector<T> *>(v); });
return py::array(v->size(), v->data(), capsule);
// return py::array(vec.size(), vec.data());
}
/*
* NeighborSampler Utils
*/
class TemporalNeighborBlock
{
public:
std::vector<vector<NodeIDType>> neighbors;
std::vector<int64_t> deg;
// int64_t max_deg;
TemporalNeighborBlock(){}
TemporalNeighborBlock(const TemporalNeighborBlock &tnb);
TemporalNeighborBlock(std::vector<vector<NodeIDType>>& neighbors,
std::vector<int64_t> &deg):
neighbors(neighbors), deg(deg){}
py::array get_node_neighbor(NodeIDType node_id){
return vec2npy(neighbors[node_id]);
}
int64_t get_node_deg(NodeIDType node_id){
return deg[node_id];
}
// int64_t get_max_deg(){
// return max_deg;
// }
};
// copy constructor
TemporalNeighborBlock::TemporalNeighborBlock(const TemporalNeighborBlock &tnb){
// clock_t pre = clock();
this->neighbors = tnb.neighbors;
this->deg = tnb.deg;
// this->max_deg = tnb.max_deg;
// clock_t end=clock();
// cout<<"copy tnb time: "<<(end-pre)*1.0/CLOCKS_PER_SEC<<endl;
}
class TemporalGraphBlock
{
public:
std::vector<NodeIDType> row;
std::vector<NodeIDType> col;
std::vector<NodeIDType> nodes;
TemporalGraphBlock(){}
TemporalGraphBlock(const TemporalGraphBlock &tgb);
TemporalGraphBlock(std::vector<NodeIDType> &_row, std::vector<NodeIDType> &_col,
std::vector<NodeIDType> &_nodes):
row(_row), col(_col), nodes(_nodes){}
};
// copy constructor
TemporalGraphBlock::TemporalGraphBlock(const TemporalGraphBlock &tgb){
// clock_t pre = clock();
this->row = tgb.row;
this->col = tgb.col;
this->nodes = tgb.nodes;
// clock_t end=clock();
// cout<<"copy tgb time: "<<(end-pre)*1.0/CLOCKS_PER_SEC<<endl;
}
TemporalNeighborBlock get_neighbors(
th::Tensor row, th::Tensor col, int64_t num_nodes){
AT_ASSERTM(row.is_contiguous(), "Offset tensor must be contiguous");
AT_ASSERTM(col.is_contiguous(), "Offset tensor must be contiguous");
    AT_ASSERTM(row.dim() == 1, "Offset tensor must be one-dimensional");
    AT_ASSERTM(col.dim() == 1, "Offset tensor must be one-dimensional");
auto u = row.data_ptr<NodeIDType>();
auto v = col.data_ptr<NodeIDType>();
int64_t edge_num = row.size(0);
TemporalNeighborBlock tnb = TemporalNeighborBlock();
tnb.deg.resize(num_nodes, 0);
// tnb.max_deg=0;
double start_time = omp_get_wtime();
for(int64_t i=0; i<num_nodes; i++)
tnb.neighbors.push_back(vector<NodeIDType>());
for(int64_t i=0; i<edge_num; i++){
        // record the neighbors of each source node
tnb.neighbors[u[i]].push_back(v[i]);
}
for(int64_t i=0; i<num_nodes; i++){
        // collect the degree of each node
tnb.deg[i] = int64_t(tnb.neighbors[i].size());
// tnb.max_deg = max(tnb.max_deg, tnb.deg[i]);
}
double end_time = omp_get_wtime();
cout<<"get_neighbors consume: "<<end_time-start_time<<"s"<<endl;
return tnb;
}
vector<th::Tensor> neighbor_sample_from_nodes(
th::Tensor nodes, TemporalNeighborBlock& tnb, int64_t fanout, int threads){
py::gil_scoped_release release;
TemporalGraphBlock tgb = TemporalGraphBlock();
vector<th::Tensor> ret;
AT_ASSERTM(nodes.is_contiguous(), "Offset tensor must be contiguous");
    AT_ASSERTM(nodes.dim() == 1, "Offset tensor must be one-dimensional");
auto nodes_data = nodes.data_ptr<NodeIDType>();
vector<vector<NodeIDType>> row_threads, col_threads;
for(int i = 0; i<threads; i++){
row_threads.push_back(vector<NodeIDType>());
col_threads.push_back(vector<NodeIDType>());
}
// double start_time = omp_get_wtime();
#pragma omp parallel for num_threads(threads) default(shared)
for(int64_t i=0; i<nodes.size(0); i++){
NodeIDType node = nodes_data[i];
vector<NodeIDType> nei(tnb.neighbors[node]);
TemporalGraphBlock tgb_i = TemporalGraphBlock();
default_random_engine e(time(0));
uniform_int_distribution<> u(0, tnb.deg[node]-1);
if(tnb.deg[node]>fanout){
            // the degree exceeds fanout, so randomly pick fanout distinct neighbors
phmap::flat_hash_set<NodeIDType> s;
while(s.size()!=fanout){
                // keep drawing until fanout distinct neighbors have been chosen
auto chosen_iter = nei.begin() + u(e);
s.insert(*chosen_iter);
}
tgb_i.col.assign(s.begin(), s.end());
}
else{
tgb_i.col.swap(nei);
}
tgb_i.row.resize(tgb_i.col.size(), node);
// TemporalGraphBlock* tgb_i = neighbor_sample_from_node(node, tnb.neighbors[node], tnb.deg[node], fanout, threads);
int tid = omp_get_thread_num();
row_threads[tid].insert(row_threads[tid].end(),tgb_i.row.begin(),tgb_i.row.end());
col_threads[tid].insert(col_threads[tid].end(),tgb_i.col.begin(),tgb_i.col.end());
}
// double end_time = omp_get_wtime();
// cout<<"neighbor_sample_from_nodes parallel part consume: "<<end_time-start_time<<"s"<<endl;
#pragma omp sections
{
#pragma omp section
for(int i = 0; i<threads; i++)
tgb.row.insert(tgb.row.end(), row_threads[i].begin(), row_threads[i].end());
#pragma omp section
for(int i = 0; i<threads; i++)
tgb.col.insert(tgb.col.end(), col_threads[i].begin(), col_threads[i].end());
}
// for(int i = 0; i<threads; i++){
// tgb.row.insert(tgb.row.end(), row_threads[i].begin(), row_threads[i].end());
// tgb.col.insert(tgb.col.end(), col_threads[i].begin(), col_threads[i].end());
// }
    // deduplicate the sampled nodes; root nodes are not included for now
// start_time = end_time;
// heads_unique(tgb.nodes, nodes, threads);
phmap::parallel_flat_hash_set<NodeIDType> s(tgb.col.begin(), tgb.col.end());
tgb.nodes.assign(s.begin(), s.end());
// end_time = omp_get_wtime();
// cout<<"end unique consume: "<<end_time-start_time<<"s"<<endl;
// start_time = end_time;
ret.push_back(th::tensor(tgb.row));
ret.push_back(th::tensor(tgb.col));
ret.push_back(th::tensor(tgb.nodes));
// end_time = omp_get_wtime();
// cout<<"vec to tensor consume: "<<end_time-start_time<<"s"<<endl;
py::gil_scoped_acquire acquire;
return ret;
}
vector<th::Tensor> neighbor_sample_from_nodes(
th::Tensor nodes, TemporalNeighborBlock& tnb, int64_t fanout, double D,int threads){
py::gil_scoped_release release;
TemporalGraphBlock tgb = TemporalGraphBlock();
vector<th::Tensor> ret;
AT_ASSERTM(nodes.is_contiguous(), "Offset tensor must be contiguous");
    AT_ASSERTM(nodes.dim() == 1, "Offset tensor must be one-dimensional");
auto nodes_data = nodes.data_ptr<NodeIDType>();
vector<vector<NodeIDType>> row_threads, col_threads;
for(int i = 0; i<threads; i++){
row_threads.push_back(vector<NodeIDType>());
col_threads.push_back(vector<NodeIDType>());
}
// double start_time = omp_get_wtime();
#pragma omp parallel for num_threads(threads) default(shared)
for(int64_t i=0; i<nodes.size(0); i++){
NodeIDType node = nodes_data[i];
vector<NodeIDType> nei(tnb.neighbors[node]);
TemporalGraphBlock tgb_i = TemporalGraphBlock();
default_random_engine e(time(0));
uniform_int_distribution<> u(0, tnb.deg[node]-1);
if(tnb.deg[node]>fanout){
            // the degree exceeds fanout, so randomly pick fanout distinct neighbors
phmap::flat_hash_set<NodeIDType> s;
while(s.size()!=fanout){
                // keep drawing until fanout distinct neighbors have been chosen
auto chosen_iter = nei.begin() + u(e);
s.insert(*chosen_iter);
}
tgb_i.col.assign(s.begin(), s.end());
}
else{
tgb_i.col.swap(nei);
}
tgb_i.row.resize(tgb_i.col.size(), node);
// TemporalGraphBlock* tgb_i = neighbor_sample_from_node(node, tnb.neighbors[node], tnb.deg[node], fanout, threads);
int tid = omp_get_thread_num();
row_threads[tid].insert(row_threads[tid].end(),tgb_i.row.begin(),tgb_i.row.end());
col_threads[tid].insert(col_threads[tid].end(),tgb_i.col.begin(),tgb_i.col.end());
}
// double end_time = omp_get_wtime();
// cout<<"neighbor_sample_from_nodes parallel part consume: "<<end_time-start_time<<"s"<<endl;
#pragma omp sections
{
#pragma omp section
for(int i = 0; i<threads; i++)
tgb.row.insert(tgb.row.end(), row_threads[i].begin(), row_threads[i].end());
#pragma omp section
for(int i = 0; i<threads; i++)
tgb.col.insert(tgb.col.end(), col_threads[i].begin(), col_threads[i].end());
}
// for(int i = 0; i<threads; i++){
// tgb.row.insert(tgb.row.end(), row_threads[i].begin(), row_threads[i].end());
// tgb.col.insert(tgb.col.end(), col_threads[i].begin(), col_threads[i].end());
// }
    // deduplicate the sampled nodes; root nodes are not included for now
// start_time = end_time;
// heads_unique(tgb.nodes, nodes, threads);
phmap::parallel_flat_hash_set<NodeIDType> s(tgb.col.begin(), tgb.col.end());
tgb.nodes.assign(s.begin(), s.end());
// end_time = omp_get_wtime();
// cout<<"end unique consume: "<<end_time-start_time<<"s"<<endl;
// start_time = end_time;
ret.push_back(th::tensor(tgb.row));
ret.push_back(th::tensor(tgb.col));
ret.push_back(th::tensor(tgb.nodes));
// end_time = omp_get_wtime();
// cout<<"vec to tensor consume: "<<end_time-start_time<<"s"<<endl;
py::gil_scoped_acquire acquire;
return ret;
}
/*-------------------------------------------------------------------------------------**
**------------Utils--------------------------------------------------------------------**
**-------------------------------------------------------------------------------------*/
TemporalGraphBlock* neighbor_sample_from_node(
NodeIDType node, vector<NodeIDType> neighbors,
int64_t deg, int64_t fanout, int threads){
TemporalGraphBlock* tgb = new TemporalGraphBlock();
srand((int)time(0));
if(deg>fanout){
        // the degree exceeds fanout, so randomly pick fanout neighbors
for(int64_t i=0; i<fanout; i++){
            // pick one of the remaining neighbors; erasing it prevents duplicates
auto chosen_iter = neighbors.begin() + rand()%(deg-i);
tgb->col.push_back(*chosen_iter);
neighbors.erase(chosen_iter);
}
}
else
tgb->col.swap(neighbors);
tgb->row.resize(tgb->col.size(), node);
    // the sampled nodes are not inserted here; the caller merges and deduplicates them
return tgb;
}
th::Tensor heads_unique(th::Tensor array, th::Tensor heads, int threads){
auto array_ptr = array.data_ptr<NodeIDType>();
phmap::parallel_flat_hash_set<NodeIDType> s(array_ptr, array_ptr+array.numel());
if(heads.numel()==0) return th::tensor(vector<NodeIDType>(s.begin(), s.end()));
AT_ASSERTM(heads.is_contiguous(), "Offset tensor must be contiguous");
    AT_ASSERTM(heads.dim() == 1, "Offset tensor must be one-dimensional");
auto heads_ptr = heads.data_ptr<NodeIDType>();
#pragma omp parallel for num_threads(threads)
for(int64_t i=0; i<heads.size(0); i++){
if(s.count(heads_ptr[i])==1){
#pragma omp critical(erase)
s.erase(heads_ptr[i]);
}
}
vector<NodeIDType> ret;
ret.reserve(s.size()+heads.numel());
ret.assign(heads_ptr, heads_ptr+heads.numel());
ret.insert(ret.end(), s.begin(), s.end());
// cout<<"s: "<<s.size()<<" array: "<<array.size()<<endl;
return th::tensor(ret);
}
/*------------Python Bind--------------------------------------------------------------*/
PYBIND11_MODULE(sample_cores, m)
{
m
.def("neighbor_sample_from_nodes",
&neighbor_sample_from_nodes,
py::return_value_policy::reference)
.def("get_neighbors",
&get_neighbors,
py::return_value_policy::reference)
.def("heads_unique",
&heads_unique,
py::return_value_policy::reference);
py::class_<TemporalGraphBlock>(m, "TemporalGraphBlock")
.def(py::init<std::vector<NodeIDType> &, std::vector<NodeIDType> &,
std::vector<NodeIDType> &>())
.def("row", [](const TemporalGraphBlock &tgb) { return vec2npy(tgb.row); })
.def("col", [](const TemporalGraphBlock &tgb) { return vec2npy(tgb.col); })
.def("nodes", [](const TemporalGraphBlock &tgb) { return vec2npy(tgb.nodes); });
py::class_<TemporalNeighborBlock>(m, "TemporalNeighborBlock")
.def(py::init<std::vector<vector<NodeIDType>>&,
std::vector<int64_t> &>())
// .def("get_node_neighbor",&TemporalNeighborBlock::get_node_neighbor)
// .def("get_node_deg", &TemporalNeighborBlock::get_node_deg)
// .def("get_max_deg", &TemporalNeighborBlock::get_max_deg)
// .def_readonly("max_deg", &TemporalNeighborBlock::max_deg, py::return_value_policy::reference)
.def_readonly("neighbors", &TemporalNeighborBlock::neighbors, py::return_value_policy::reference)
.def_readonly("deg", &TemporalNeighborBlock::deg, py::return_value_policy::reference);
}
\ No newline at end of file
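# A hypothetical usage sketch for the compiled `sample_cores` extension defined above,
# assuming it has been built via the setup.py that follows; the toy edge list and the
# fanout/thread values here are illustrative only.
import torch
import sample_cores

# small directed edge list: an edge goes from row[i] to col[i]
row = torch.tensor([0, 1, 1, 2, 2, 3], dtype=torch.int64)
col = torch.tensor([1, 0, 2, 1, 3, 2], dtype=torch.int64)

# build per-node neighbor lists and degrees once for the whole graph
tnb = sample_cores.get_neighbors(row.contiguous(), col.contiguous(), 4)

# sample at most 2 neighbors for each seed node, using 2 OpenMP threads;
# the call returns [sampled_row, sampled_col, sampled_nodes]
seeds = torch.tensor([0, 2], dtype=torch.int64)
sampled_row, sampled_col, sampled_nodes = sample_cores.neighbor_sample_from_nodes(seeds, tnb, 2, 2)
print(sampled_row, sampled_col, sampled_nodes)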
Metadata-Version: 2.1
Name: sample-cores
Version: 0.0.0
sample_cores.cpp
setup.py
sample_cores.egg-info/PKG-INFO
sample_cores.egg-info/SOURCES.txt
sample_cores.egg-info/dependency_links.txt
sample_cores.egg-info/top_level.txt
\ No newline at end of file
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension
setup(
name='sample_cores',
ext_modules=[
CppExtension(
name='sample_cores',
sources=['sample_cores.cpp'],
extra_compile_args=['-fopenmp','-Xlinker',' -export-dynamic'],
include_dirs=["/home/hzq/code/gnn/test/NewSample"],
),
],
cmdclass={
'build_ext': BuildExtension
})
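# A possible way to build the extension in place (a sketch; the include_dirs path above
# is machine-specific and would need adjusting to point at the parallel_hashmap headers):
#   python setup.py build_ext --inplace
# after which `import sample_cores` becomes available from this directory.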
import sys
import argparse
parser = argparse.ArgumentParser(
description="RPC Reinforcement Learning Example",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--world_size', default=1, type=int, metavar='W',
help='number of workers')
parser.add_argument('--rank', default=0, type=int, metavar='W',
help='rank of the worker')
parser.add_argument('--log_interval', type=int, default=10, metavar='N',
help='interval between training status logs')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
help='how much to value future rewards')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed for reproducibility')
parser.add_argument('--num_sampler', type=int, default=1, metavar='S',
help='number of samplers')
parser.add_argument('--queue_size', type=int, default=10000, metavar='S',
help='sampler queue size')
args = parser.parse_args()
rpc_proxy=None
WORKER_RANK = args.rank
NUM_SAMPLER = args.num_sampler
WORLD_SIZE = args.world_size
QUEUE_SIZE = args.queue_size
MAX_QUEUE_SIZE = 5*args.queue_size
RPC_NAME = "rpcserver{}"
def _get_worker_rank():
return WORKER_RANK
def _get_num_sampler():
return NUM_SAMPLER
def _get_world_size():
return WORLD_SIZE
def _get_RPC_NAME():
return RPC_NAME
def _get_queue_size():
return QUEUE_SIZE
def _get_max_queue_size():
return MAX_QUEUE_SIZE
def _get_rpc_name():
return RPC_NAME
\ No newline at end of file
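# Usage sketch (hypothetical invocation): distparser calls parse_args() at import time,
# so any entry script that imports it picks these flags up from the command line, e.g.
#   python test.py --world_size 2 --rank 0 --num_sampler 1 --queue_size 10000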
graph_set={}
def _clear_all(barrier = None):
global graph_set
for key in graph_set:
graph = graph_set[key]
graph._close_graph_in_shame()
print('clear ',key)
if(barrier is not None and barrier.wait()==0):
graph._unlink_graph_in_shame()
graph_set = {}
def _set_graph(graph_name,graph_info):
graph_info._get_graph_from_shm()
graph_set[graph_name]=graph_info
def _get_graph(graph_name):
return graph_set[graph_name]
def _del_graph(graph_name):
graph_set.pop(graph_name)
def _get_size():
return len(graph_set)
\ No newline at end of file
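# A minimal sketch of how this registry is used, assuming a GraphData partition and the
# GraphInfoInShm wrapper defined elsewhere in this commit (names come from those files;
# the path is illustrative):
#
#   from part.Utils import GraphData
#   from shared_graph_memory import GraphInfoInShm
#   import graph_store
#
#   graph = GraphData('./part/metis_1/rank_0')
#   info = GraphInfoInShm(graph)           # copies x / y / edge_index into shared memory
#   graph_store._set_graph('train', info)  # re-attaches the shared buffers and registers them
#   g = graph_store._get_graph('train')    # later lookups by dataloader name
#   graph_store._del_graph('train')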
import logging
import os
import time
class Config():
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
    # console handler: show log output on the console
    ls = logging.StreamHandler()
    ls.setLevel(logging.DEBUG)
    # log record format
    formatter = logging.Formatter('%(asctime)s-%(name)s-%(filename)s-[line:%(lineno)d]''-%(levelname)s: %(message)s')
    # attach the format to the console handler
    ls.setFormatter(formatter)
    # attach the console handler to the logger
    logger.addHandler(ls)
    # create a logs directory under the project directory to hold log files (path is configurable)
    print(os.path.abspath('.') )
    logdir = os.path.join(os.path.abspath('.'), 'logs')
    if not os.path.exists(logdir):
        os.mkdir(logdir)
    # create a .log file named by the current date/time inside the logs directory
    logfile = os.path.join(logdir, time.strftime('%Y-%m-%d %H:%M:%S') + '.log')
    # file handler, writing UTF-8 encoded log records
    lf = logging.FileHandler(filename=logfile, encoding='utf8')
    # log level recorded by the file handler
    lf.setLevel(logging.DEBUG)
    # record format for the file handler
    lf.setFormatter(formatter)
    # attach the file handler to the logger
    logger.addHandler(lf)
def get_config(self):
return self.logger
logger = Config().get_config()
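# Usage sketch: importing `logger` from this module yields the configured root logger, e.g.
#   from logger import logger
#   logger.debug('sampler started, rank %s', 0)
# which writes both to the console and to ./logs/<timestamp>.log as configured above.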
2023-05-11 02:24:19,565-torch.distributed.distributed_c10d-distributed_c10d.py-[line:228]-INFO: Added key: store_based_barrier_key:0 to store for rank: 0
2023-05-11 02:24:19,566-torch.distributed.distributed_c10d-distributed_c10d.py-[line:262]-INFO: Rank 0: Completed store-based barrier for key:store_based_barrier_key:0 with 1 nodes.
2023-05-11 02:24:20,504-torch.nn.parallel.distributed-distributed.py-[line:995]-INFO: Reducer buckets have been rebuilt in this iteration.
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from multiprocessing import Barrier
from multiprocessing.connection import Client
import queue
import sys
import argparse
import traceback
from torch.distributed.rpc import RRef, rpc_async, remote
import torch.distributed.rpc as rpc
import torch
from part.Utils import GraphData
from Sample.neighbor_sampler import NeighborSampler
import time
from typing import Optional
import torch.distributed as dist
import torch.multiprocessing as map
from BatchData import BatchData
import asyncio
import os
from Sample.neighbor_sampler import NeighborSampler
from graph_store import _get_graph
from shared_graph_memory import GraphInfoInShm
import distparser as parser
from logger import logger
WORKER_RANK = parser._get_worker_rank()
NUM_SAMPLER = parser._get_num_sampler()
WORLD_SIZE = parser._get_world_size()
QUEUE_SIZE =parser._get_queue_size()
MAX_QUEUE_SIZE = parser._get_max_queue_size()
RPC_NAME = parser._get_RPC_NAME()
#@rpc.functions.async_execution
def _get_local_attr(data_name,nodes):
graph = _get_graph(data_name)
local_id = graph.get_localId_by_partitionId(parser._get_worker_rank(),nodes)
return graph.select_attr(local_id)
def _request_remote_attr(rank,data_name,nodes):
t1 = time.time()
fut = rpc_async(
parser._get_RPC_NAME().format(rank),
_get_local_attr,
args=(data_name,nodes,)
)
#logger.debug('request {}'.format(time.time()-t1))
return fut
#ThreadPoolExecutor pool
def _request_all_remote_attr(data_name,nodes_list):
worker_size = parser._get_world_size()
worker_rank = parser._get_worker_rank()
futs = []
for rank in range(worker_size):
if(rank == worker_rank):
futs.append(None)
continue
else:
if(nodes_list[rank].size(0) == 0):
futs.append(None)
else:
futs.append(_request_remote_attr(rank,data_name,nodes_list[rank]))
return futs
def _check_future_finish(futs):
check = True
for _,fut in futs:
if fut is not None and fut.done() is False:
check = False
return check
def _split_node_part_and_submit(data_name,node_id_list):
t0 = time.time()
graph = _get_graph(data_name)
worker_size = parser._get_world_size()
local_rank = parser._get_worker_rank()
futs = []
t1 = time.time()
for rank in range(worker_size):
if(rank != local_rank):
part_mask = (graph.partptr[rank]<=node_id_list) & (node_id_list<graph.partptr[rank+1])
part_node = torch.masked_select(node_id_list,part_mask)
if(part_node.size(0) != 0):
futs.append((part_node,_request_remote_attr(rank,data_name,part_node)))
t2 = time.time()
local_mask = (graph.partptr[local_rank]<=node_id_list) & (node_id_list<graph.partptr[local_rank+1])
t3 = time.time()
local_node = torch.masked_select(node_id_list,local_mask)
t4 = time.time()
#logger.debug('size {},split {} {} {}'.format(node_id_list.size(0),t2-t1,t3-t2,t4-t3))
return local_node,futs
#def _split_node_part(data_name,node_id_list):
# graph = _get_graph(data_name)
# part_list=[]
# scatter_list=[]
# worker_size = parser._get_world_size()
# #print(node_id_list)
# for rank in range(worker_size):
# group_id=(graph.partptr[rank]<=node_id_list) & (node_id_list<graph.partptr[rank+1])
# part_list.append(torch.masked_select(node_id_list,group_id))
# scatter_list.append(torch.nonzero(group_id))
# #print(graph.partptr[rank],graph.partptr[rank+1])
# #print(torch.nonzero(group_id))
# #print(node_id_list)
# scatter_index=torch.cat(scatter_list)
# # print('part_list',part_list)
# return part_list,scatter_index
#
#def _union_remote_and_local(data_name,scatter_list,local_nodes,futs):
# worker_rank = parser._get_worker_rank()
# worker_size = parser._get_world_size()
# feature_part = []
# local_feature=_get_local_attr(data_name,local_nodes)
# node_feature_list = []
# for rank in range(worker_size):
# if(rank == worker_rank):
# node_feature_list.append(local_feature)
# elif(futs[rank] is None):
# continue
# else:
# node_feature_list.append(futs[rank].value())
#
# node_feature = torch.cat(node_feature_list)
# attr_size=node_feature.size()
# x=torch.zeros(attr_size)
# #print(scatter_list.size())
# #print(node_feature.size())
# #print(scatter_list.repeat(1,attr_size[1]).size())
# #print(x.size())
# #print(scatter_list)
# x=x.scatter(0,scatter_list.repeat(1,attr_size[1]),node_feature)
# return x
#total_local = 0
#total_remote = 0
#def _get_batch_data(kwargs):
# global total_local
# global total_remote
# data_name = kwargs.get("data_name")
# input_data = kwargs.get("input_data")
# nids = kwargs.get("nids")
# root = nids[:input_data.size(0)]
# scatter_list,local_nodes,futs = kwargs.get("union_args")
# edge_index = kwargs.get("edge_index")
# x=_union_remote_and_local(data_name,scatter_list,local_nodes,futs)
# graph = _get_graph(data_name)
# worker_rank = parser._get_worker_rank()
# local_y_id = graph.get_localId_by_partitionId(worker_rank,root)
# y = graph.select_y(local_y_id)
# total_local = total_local + len(local_nodes)
# total_remote = total_remote + len(x)-len(local_nodes)
# print('total local',total_local,'total remote',total_remote)
# #print(t2-t1,t3-t2,t4-t3,t5-t4,time.time()-t4)
# return BatchData(nids,edge_index,roots=input_data,x=x,y=y)
def _get_batch_data(kwargs):
data_name = kwargs.get("data_name")
input_size = kwargs.get("input_size")
local_nodes,futs = kwargs.get("union_args")
edge_index = kwargs.get("edge_index")
graph = _get_graph(data_name)
nids = [local_nodes]
root = local_nodes[:input_size]
worker_rank = parser._get_worker_rank()
local_y_id = graph.get_localId_by_partitionId(worker_rank,root)
y = graph.select_y(local_y_id)
x = [_get_local_attr(data_name,local_nodes)]
for (part_node,part_feature) in futs:
nids.append(part_node)
x.append(part_feature.value())
nids = torch.cat(nids,0)
x = torch.cat(x,0)
return BatchData(nids,edge_index,roots=root,x=x,y=y)
def _sample_node_neighbors_server(data_name,neighbor_nodes):
'''
sample the struct of the subgraph
'''
print('sample node neighbors server')
t1 = time.time()
local_node,futs = _split_node_part_and_submit(data_name,neighbor_nodes)
t2 = time.time()
#logger.debug('sample server {}'.format(t2-t1))
return local_node,futs
def _sample_node_neighbors_single(graph,graph_name,sampler,input_data):
#out = sampler.sample_from_nodes((graph_name,input_data,time.time()))
nids,edge_index = sampler.sample_from_nodes(input_data.reshape(-1))
#nids = out.node
#edge_index = torch.cat((out.row.reshape(1,-1),out.col.reshape(1,-1)),0)
x= graph.select_attr(nids)
y = graph.select_y(input_data)
return BatchData(nids,edge_index,roots=input_data,x=x,y=y)
\ No newline at end of file
from mimetypes import init
import torch
import torch.nn as nn
import torch.nn.functional as F
from part.Utils import GraphData
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
from torch_geometric.nn import MessagePassing
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as pyg_nn
from BatchData import BatchData
class MyModel(torch.nn.Module):
def __init__(self,graph:GraphData):
super(MyModel, self).__init__()
        self.conv1 = GCNConv(graph.num_feasures, 128) # input = node feature dimension, 128 hidden units
self.conv2 = GCNConv(128, 7)
self.num_layers = 2
def forward(self, x, edge_index):
x = self.conv1(x, edge_index)
x = F.relu(x)
x = self.conv2(x, edge_index)
return F.log_softmax(x, dim=1)
class SAGEConv(MessagePassing):
def __init__(self, in_channels, out_channels, bias=True, aggr='mean'):
        super(SAGEConv, self).__init__(aggr=aggr) # use mean aggregation
        # linear layers
self.w1 = pyg_nn.dense.linear.Linear(in_channels, out_channels, weight_initializer='glorot', bias=bias)
self.w2 = pyg_nn.dense.linear.Linear(in_channels, out_channels, weight_initializer='glorot', bias=bias)
def message(self, x_j):
# x_j [E, in_channels]
        # project the neighbor features
wh = self.w2(x_j) # [E, out_channels]
return wh
def update(self, aggr_out, x):
# aggr_out [num_nodes, out_channels]
        # project the node's own features
wh = self.w1(x)
return aggr_out + wh
def forward(self, x, edge_index):
return self.propagate(edge_index, x=x)
# 3. Define the GraphSAGE network
class GraphSAGE(nn.Module):
def __init__(self, num_node_features, num_classes):
super(GraphSAGE, self).__init__()
# self.gcns = []
# self.gcns.append(SAGEConv(in_channels=num_node_features,out_channels=16))
# self.gcns.append(SAGEConv(in_channels=16, out_channels=num_classes))
self.conv1 = SAGEConv(in_channels=num_node_features,
out_channels=256)
self.conv2 = SAGEConv(in_channels=256,
out_channels=num_classes)
def forward(self, data:BatchData,type):
#print(data.data.train_mask)
##nids = torch.nonzero(data.data.train_mask).reshape(-1)
#x=data.data.x
#edge_indexs = data.edge_index
nids,x, edge_indexs = data.nids,data.x, data.edge_index
nids_dict = dict(zip(nids.tolist(), torch.arange(nids.numel()).tolist()))
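        # nids holds global node ids; build a mapping to their positions in this batch
        # so the sampled edge_index can be rewritten to index directly into x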
edge_index_local = []
edge_index_local.append(torch.tensor([
[nids_dict[int(s.item())] for s in edge_indexs[0][0]],
[nids_dict[int(s.item())] for s in edge_indexs[0][1]]
]) )
edge_index_local.append(torch.tensor([
[nids_dict[int(s.item())] for s in edge_indexs[1][0]],
[nids_dict[int(s.item())] for s in edge_indexs[1][1]]
]) )
x = self.conv1(x, edge_index_local[1])
x = F.relu(x)
if(type == 0):
x = F.dropout(x, training=self.training)
x = self.conv2(x, edge_index_local[1])
print(f'x sizes {x.size()}')
#root_nids = nids
return F.log_softmax(x[:data.batch_size], dim=1)
def inference(self, data):
nids,x, edge_indexs = data.nids,data.x, data.edge_index
# x_global = [x[nids[i]] for i in range(x.shape[0])]
nids_dict = dict(zip(nids.tolist(), torch.arange(nids.numel()).tolist()))
edge_index_local = torch.tensor([
[nids_dict[int(s.item())] for s in edge_indexs[0][0]],
[nids_dict[int(s.item())] for s in edge_indexs[0][1]]])
x = self.conv1(x, edge_index_local)
x = F.relu(x)
x = F.dropout(x, training=self.training)
x = self.conv2(x, edge_index_local)
print(f'x sizes {x.size()}')
root_nids = [nids_dict[root.item()] for root in data.roots]
return F.log_softmax(x[root_nids], dim=1)
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from model import MyModel
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os
import BatchData
from part.Utils import GraphData
def ddp_setup(rank, world_size):
"""
Args:
rank: Unique identifier of each process
world_size: Total number of processes
"""
    # the master address and port coordinate communication between the processes
    #os.environ["MASTER_ADDR"] = "localhost"
    #os.environ["MASTER_PORT"] = "12355"
    # initialise one process per GPU so the processes can discover each other
    # rank is the unique identifier of each process, world_size is the total number of GPUs
    # nccl is NVIDIA's communication backend for distributed training
#init_process_group(backend="gloo", rank=rank, world_size=world_size)
init_method="tcp://{}:{}".format("localhost",12355)
init_process_group(backend="gloo", rank=rank, world_size=world_size,init_method=init_method)
class Trainer:
def __init__(
self,
model: torch.nn.Module,
train_data: DataLoader,
optimizer: torch.optim.Optimizer,
gpu_id: int,
save_every: int,
) -> None:
self.gpu_id = gpu_id
self.model = model.to('cpu')
self.train_data = train_data
self.optimizer = optimizer
self.save_every = save_every
        # wrap the model with DDP so it knows which devices to replicate across
self.model = DDP(model)
self.total_correct = 0
self.count = 0
def _run_batch(self, batchData:BatchData):
graph = GraphData('/home/sxx/zlj/rpc_ps/part/metis_1/rank_0')
self.count = self.count +1
print(f'run epoch in batch data f {self.count}')
l = len(batchData.edge_index)
# print(f'batchData:len:{l},edge_index:{batchData.edge_index}')
self.optimizer.zero_grad()
out = self.model(batchData)
print(f'out size: {out.size()}')
print(f'batchDate.y: {batchData.y}')
# batchData.y = F.one_hot(batchData.y, num_classes=7)
print(f'roots {batchData.roots}')
print(f'y size:{batchData.y.size()}')
# print(f'mask :{batchData.train_mask}')
##loss = F.nll_loss(out, batchData.y)
y = batchData.y#torch.masked_select(graph.data.y,graph.data.train_mask)
loss = F.nll_loss(out, y)
loss.backward()
self.optimizer.step()
print(out.argmax(dim=-1))
self.total_correct += int(out.argmax(dim=-1).eq(y).sum())
##self.total_correct += int(out.argmax(dim=-1).eq(batchData.y).sum())
self.total_count += y.size(0)
print('finish')
def _run_epoch(self, epoch):
self.total_correct = 0
self.total_count = 0
for batchData in self.train_data:
self._run_batch(batchData)
approx_acc = self.total_correct / self.total_count
print(f"=======[GPU{self.gpu_id}] Epoch {epoch} | approx_acc: {approx_acc}=======")
def _save_checkpoint(self, epoch):
        # the model is wrapped by DDP, so its parameters are accessed via model.module
ckp = self.model.module.state_dict()
PATH = "checkpoint.pt"
torch.save(ckp, PATH)
print(f"=======Epoch {epoch} | Training checkpoint saved at {PATH}========")
def train(self, max_epochs: int):
for epoch in range(max_epochs):
self._run_epoch(epoch)
            # only one copy of the checkpoint needs to be saved
if self.gpu_id == 0 and epoch % self.save_every == 0:
self._save_checkpoint(epoch)
# def load_train_objs():
# train_set =
# model =
# optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
# return train_set, model, optimizer
def prepare_dataloader(dataset: Dataset, batch_size: int):
return DataLoader(
dataset,
batch_size=batch_size,
pin_memory=True,
        # shuffling is already handled by the sampler
        shuffle=False,
        # a distributed sampler draws non-overlapping samples on different GPUs
        # sampler=DistributedSampler(dataset)
)
def main(rank: int, world_size: int, save_every: int, total_epochs: int, batch_size: int):
ddp_setup(rank, world_size)
    dataset, model, optimizer = load_train_objs()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
train_data = prepare_dataloader(dataset, batch_size)
trainer = Trainer(model, train_data, optimizer, rank, save_every)
trainer.train(total_epochs)
    # tear down all DDP processes
destroy_process_group()
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description='simple distributed training job')
parser.add_argument('--total_epochs', dest='total_epochs',type=int, help='Total epochs to train the model')
parser.add_argument('--save_every', dest='save_every',type=int, help='How often to save a snapshot')
parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
parser.add_argument('--rank', '-r', dest='rank',type=int, help='Rank of this process')
parser.add_argument('--world-size', dest='world_size',type=int, default=1, help='World size')
args = parser.parse_args()
# world_size = torch.cuda.device_count()
world_size = 1
    # a distributed API that launches several processes at once
# mp.spawn(main, args=(world_size, args.save_every, args.total_epochs, args.batch_size), nprocs=2)
main(args.rank,args.world_size,args.save_every,args.total_epochs,args.batch_size)
import copy
import os.path as osp
import sys
from typing import Optional
import torch
import torch.utils.data
from torch_sparse import SparseTensor, cat
import numpy as np
# def metis(edge_index: LongTensor, num_nodes: int, num_parts: int) -> Tuple[LongTensor, LongTensor]:
# if num_parts <= 1:
# return _nopart(edge_index, num_nodes)
# adj_t = SparseTensor.from_edge_index(edge_index, sparse_sizes=(num_nodes, num_nodes)).to_symmetric()
# rowptr, col, _= adj_t.csr()
# node_parts = torch.ops.torch_sparse.partition(rowptr, col, None, num_parts, num_parts < 8)
# edge_parts = node_parts[edge_index[1]]
# return node_parts, edge_parts
class GPDataset(torch.utils.data.Dataset):
def __init__(self, data, num_parts: int, recursive: bool = False,
save_dir: Optional[str] = None, log: bool = True):
assert data.edge_index is not None
self.num_parts = num_parts
recursive_str = '_recursive' if recursive else ''
filename = f'partition_{num_parts}{recursive_str}.pt'
path = osp.join(save_dir or '', filename)
if save_dir is not None and osp.exists(path):
adj, partptr, perm = torch.load(path)
else:
if log: # pragma: no cover
print('Computing METIS partitioning...', file=sys.stderr)
N, E = data.num_nodes, data.num_edges
adj = SparseTensor(
row=data.edge_index[0], col=data.edge_index[1],
value=torch.arange(E, device=data.edge_index.device),
sparse_sizes=(N, N))
adj, partptr, perm = adj.partition(num_parts, recursive)
# self.global_adj = adj
if save_dir is not None:
torch.save((adj, partptr, perm), path)
if log: # pragma: no cover
print('Done!', file=sys.stderr)
        # permute all node attributes
self.data = self.__permute_data__(data, perm, adj)
self.global_adj = adj
self.partptr = partptr
perm_ = torch.zeros_like(perm)
for i in range(len(perm)):
perm_[perm[i]] = i
self.perm = torch.stack([perm,perm_])
def __permute_data__(self, data, node_idx, adj):
out = copy.copy(data)
for key, value in data.items():
if data.is_node_attr(key):
out[key] = value[node_idx]
row, col, _ = adj.coo()
out.edge_index = torch.stack([row, col], dim=0)
out.adj = adj
return out
def __len__(self):
return self.partptr.numel() - 1
def __getitem__(self, idx):
        # start id and number of node ids of the idx-th partition
start = int(self.partptr[idx])
length = int(self.partptr[idx + 1]) - start
N, E = self.data.num_nodes, self.data.num_edges
data = copy.copy(self.data)
del data.num_nodes
adj, data.adj = data.adj, None
        # slice the adjacency matrix down to this partition
adj = adj.narrow(0, start, length).narrow(1, start, length)
#
edge_idx = adj.storage.value()
for key, item in data:
if isinstance(item, torch.Tensor) and item.size(0) == N:
data[key] = item.narrow(0, start, length)
elif isinstance(item, torch.Tensor) and item.size(0) == E:
data[key] = item[edge_idx]
else:
data[key] = item
row, col, _ = adj.coo()
data.edge_index = torch.stack([row, col], dim=0)
return data
def __repr__(self):
return (f'{self.__class__.__name__}(\n'
f' data={self.data},\n'
f' num_parts={self.num_parts}\n'
f')')
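# A minimal usage sketch for GPDataset, assuming a PyG dataset is available locally
# (the root path and part count are illustrative):
#   from torch_geometric.datasets import Planetoid
#   data = Planetoid('/tmp/Cora', 'Cora')[0]
#   parts = GPDataset(data, num_parts=4, save_dir='./part')  # runs / caches METIS partitioning
#   print(len(parts))   # number of partitions
#   print(parts[0])     # Data object for partition 0, with a relabelled edge_index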
import os.path as osp
import torch
class GraphData():
def __init__(self, path):
        assert path is not None and osp.exists(path),'path does not exist'
id,edge_index,data,partptr =torch.load(path)
        # index of the current partition
        self.partition_id = id
        # total number of partitions
        self.partitions = partptr.numel() - 1
        # number of node features
        self.num_feasures = data.x.size()[1]
        # global graph structure
        self.num_nodes = partptr[self.partitions]
        self.num_edges = edge_index[0].numel()
        self.edge_index = edge_index
        # data of this partition (feature vectors and subgraph structure), a PyG Data object
        self.data = data
        # partition offsets (the node id range of each partition)
        self.partptr = partptr
self.data.y = self.data.y.reshape(-1)
    def select_attr(self,index):
        return torch.index_select(self.data.x,0,index)
    # number of nodes stored in this partition
    def get_part_num(self):
        return self.data.x.size()[0]
    def select_y(self,index):
        return torch.index_select(self.data.y,0,index)
    # map global node ids to ids local to partition `id`
def get_localId_by_partitionId(self,id,index):
if(id == -1 or id == 0):
return index
else:
return torch.add(index,-self.partptr[id])
def get_globalId_by_partitionId(self,id,index):
if(id == -1 or id == 0):
return index
else:
return torch.add(index,self.partptr[id])
def get_node_num(self):
return self.num_nodes
def localId_to_globalId(self,id,partitionId:int = -1):
'''
        Map a node id inside partition partitionId to the global id
'''
if partitionId == -1:
partitionId = self.partition_id
assert id >=self.partptr[self.partition_id] and id < self.partptr[self.partition_id+1]
ids_before = 0
if self.partition_id>0:
ids_before = self.partptr[self.partition_id-1]
return id+ids_before
def get_partitionId_by_globalId(self,id):
'''
        Get the partition index corresponding to a global node id
'''
partitionId = -1
        assert id>=0 and id<self.num_nodes,'id out of range'
for i in range(self.partitions):
if id>=self.partptr[i] and id<self.partptr[i+1]:
partitionId = i
break
        assert partitionId>=0, 'no partition found for this id'
return partitionId
def get_nodes_by_partitionId(self,id):
'''
        Given a partition id, return the number of nodes in that partition
'''
        assert id>=0 and id<self.partitions,'invalid partitionId'
return (int)(self.partptr[id+1]-self.partptr[id])
def __repr__(self):
return (f'{self.__class__.__name__}(\n'
f' partition_id={self.partition_id}\n'
f' data={self.data},\n'
f' global_info('
f'num_nodes={self.num_nodes},'
f' num_edges={self.num_edges},'
f' num_parts={self.partitions},'
f' num_features={self.num_feasures}, '
f' edge_index=[2,{self.edge_index[0].numel()}])\n'
f')')
from Utils import GraphData
from torch_geometric.data import Data
g = GraphData('./rpc_ps/part/metis_4/rank_0')
x= Data()
print(g)
'''
GraphData(
partition_id=0
data=Data(x=[679, 1433], edge_index=[2, 2908], y=[679], train_mask=[679], val_mask=[679], test_mask=[679]),
global_info(num_nodes=2029, num_edges=10556, num_parts=4, edge_index=[2,10556])
)
'''
from enum import Enum
from multiprocessing.connection import Listener
import pickle
import queue
from threading import Thread
from multiprocessing import Pool, Lock, Value
from torch.distributed.rpc import TensorPipeRpcBackendOptions
from torch.distributed import rpc
import graph_store
class RpcCommand(Enum):
SET_GRAPH = 0
CALL_BARRIER = 1
UNLOAD_GRAPH =2
STOP_RPC = 3
"""
class RPCHandler:
def __init__(self):
self._functions= { }
def register_function(self,func):
self._functions[func.__name__] = func
def handle_connection(self,connection):
try:
while True:
func_name,args,kwargs = pickle.loads(connection.recv())
try:
r=self._functions[func_name] (*args,**kwargs)
connection.send(pickle.dumps(r))
except Exception as e:
connection.send(pickle.dumps(e))
except EOFError:
pass
def start_rpc_server(handler,address,queue):
sock = Listener(address)
queue.put(RpcCommand.FINISH_LISTEN)
while True:
client = sock.accept()
now_thread=Thread(target = handler.handle_connection, args=(client,))
now_thread.daemon = True
now_thread.start()
"""
NUM_WORKER_THREADS = 128
def start_rpc_listener(rpc_master,rpc_port,rpc_config,mp_config):
print(rpc_config)
num_worker_threads = rpc_config.get("num_worker_threads", NUM_WORKER_THREADS)
rpc_name = rpc_config.get("rpc_name", "rpcserver{}")
world_size = rpc_config.get("rpc_world_size", 2)
worker_rank = rpc_config.get("worker_rank", 0)
rpc_rank = rpc_config.get("rpc_worker_rank", 0)
rpc_backend_options = TensorPipeRpcBackendOptions()
rpc_backend_options.init_method = "tcp://{}:{}".format(rpc_master,rpc_port)
rpc_backend_options.num_worker_threads = num_worker_threads
task_queue,barrier = mp_config
print(rpc_master,rpc_port)
rpc.init_rpc(
rpc_name.format(worker_rank),
rank=rpc_rank,
world_size=world_size,
rpc_backend_options=rpc_backend_options
)
keep_pooling = True
while(keep_pooling):
try:
command,args = task_queue.get(timeout=5)
except queue.Empty:
continue
if command == RpcCommand.SET_GRAPH:
dataloader_name,graph_inshm= args
graph_store._set_graph(dataloader_name,graph_inshm)
elif command == RpcCommand.UNLOAD_GRAPH:
dataloader_name= args
graph_inshm = graph_store._get_graph(dataloader_name)
graph_store._del_graph(dataloader_name)
graph_inshm._close_graph_in_shame()
print('unload dataloader')
barrier.wait()
elif command == RpcCommand.CALL_BARRIER:
barrier.wait()
elif command == RpcCommand.STOP_RPC:
keep_pooling = False
graph_store._clear_all()
barrier.wait()
close_rpc()
def start_rpc_caller(rpc_master,rpc_port,rpc_config):
print(rpc_config)
num_worker_threads = rpc_config.get("num_worker_threads",NUM_WORKER_THREADS)
rpc_name = rpc_config.get("rpc_name","rpcserver{}")
world_size = rpc_config.get("rpc_world_size", 2)
worker_rank = rpc_config.get("worker_rank", 0)
rpc_rank = rpc_config.get("rpc_worker_rank", 0)
rpc_backend_options = TensorPipeRpcBackendOptions()
rpc_backend_options.init_method = "tcp://{}:{}".format(rpc_master,rpc_port)
rpc_backend_options.num_worker_threads = num_worker_threads
rpc.init_rpc(
rpc_name.format(worker_rank) + "-{}".format(rpc_rank),
rank=rpc_rank,
world_size=world_size,
rpc_backend_options=rpc_backend_options
)
def close_rpc():
rpc.shutdown(True)
"""
class RPCProxy:
def __init__(self, connection):
self._connection = connection
def __getattr__(self, name):
def do_rpc(*args, **kwargs):
self._connection.send(pickle.dumps((name, args, kwargs)))
result = pickle.loads(self._connection.recv())
if isinstance(result, Exception):
raise result
return result
return do_rpc
"""
\ No newline at end of file
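# A rough sketch of how the listener is intended to be driven from another process, based on
# the signatures above (addresses, ports and rpc_config values here are placeholders):
#   import torch.multiprocessing as mp
#   from multiprocessing import Barrier
#   from rpc_server import RpcCommand, start_rpc_listener
#
#   task_queue = mp.Queue()
#   barrier = Barrier(2)   # the driving process waits on the same barrier for synchronising commands
#   rpc_config = {"rpc_name": "rpcserver{}", "rpc_world_size": 2,
#                 "worker_rank": 0, "rpc_worker_rank": 0}
#   p = mp.Process(target=start_rpc_listener,
#                  args=("127.0.0.1", 10011, rpc_config, (task_queue, barrier)))
#   p.start()
#   # ... enqueue (RpcCommand.SET_GRAPH, (name, graph_inshm)) and other commands ...
#   task_queue.put((RpcCommand.STOP_RPC, None))
#   barrier.wait()         # STOP_RPC waits on the barrier before shutting the RPC down
#   p.join()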
import numpy as np
import torch
from multiprocessing import shared_memory
def _copy_to_share_memory(data):
data_array=data.numpy()
shm = shared_memory.SharedMemory(create=True, size=data_array.nbytes)
data_share = np.ndarray(data_array.shape, dtype=data_array.dtype, buffer=shm.buf)
np.copyto(data_share,data_array)
name = shm.name
_close_existing_shm(shm)
#data_share.copy_(data_array) # Copy the original data into shared memory
return name,data_array.shape,data_array.dtype
def _copy_to_shareable_list(data):
shm = shared_memory.ShareableList(data)
name = shm.shm.name
return name
def _get_from_shareable_list(name):
return shared_memory.ShareableList(name=name)
def _get_existing_share_memory(name):
return shared_memory.SharedMemory(name=name)
def _get_from_share_memory(existing_shm,data_shape,data_dtype):
data = np.ndarray(data_shape, dtype=data_dtype, buffer=existing_shm.buf)
return torch.from_numpy(data)
def _close_existing_shm(existing_shm):
if(existing_shm != None):
existing_shm.close()
def _unlink_existing_shm(existing_shm):
if(existing_shm != None):
existing_shm.unlink()
\ No newline at end of file
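# A small round-trip sketch for the helpers above (run on Linux/macOS, where a named shared
# memory segment persists after the creator closes its handle):
#   import torch
#   from share_memory_util import (_copy_to_share_memory, _get_existing_share_memory,
#                                  _get_from_share_memory, _close_existing_shm,
#                                  _unlink_existing_shm)
#
#   name, shape, dtype = _copy_to_share_memory(torch.arange(6).reshape(2, 3))
#   shm = _get_existing_share_memory(name)         # re-attach by name, e.g. in another process
#   t = _get_from_share_memory(shm, shape, dtype)  # zero-copy view over the shared buffer
#   print(t)
#   _close_existing_shm(shm)
#   _unlink_existing_shm(shm)                      # free the segment once every user has closed it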
import torch
from part.Utils import GraphData
from share_memory_util import _close_existing_shm, _copy_to_share_memory, _copy_to_shareable_list, _get_existing_share_memory, _get_from_share_memory, _get_from_shareable_list, _unlink_existing_shm
from Sample.neighbor_sampler import NeighborSampler
class GraphInfoInShm(GraphData):
def __init__(self,graph):
self.partition_id = graph.partition_id
self.partitions = graph.partitions
self.num_nodes = graph.num_nodes
self.num_edges = graph.num_edges
#self.edge_index_info = _copy_to_share_memory(graph.edge_index)
self.partptr = graph.partptr
self.data_x_info=_copy_to_share_memory(graph.data.x)
self.data_y_info=_copy_to_share_memory(graph.data.y)
self.edge_index_info=_copy_to_share_memory(graph.edge_index)
def _get_graph_from_shm(self):
#self.edge_index_shm = _get_existing_share_memory(self.edge_index_info[0])
#self.edge_index = _get_existing_share_memory(self.edge_index_shm,self.edge_index_info[1:])
self.data_x_shm = _get_existing_share_memory(self.data_x_info[0])
self.data_x = _get_from_share_memory(self.data_x_shm,*self.data_x_info[1:])
self.data_y_shm = _get_existing_share_memory(self.data_y_info[0])
self.data_y = _get_from_share_memory(self.data_y_shm,*self.data_y_info[1:])
self.edge_index_shm = _get_existing_share_memory(self.edge_index_info[0])
self.edge_index = _get_from_share_memory(self.edge_index_shm,*self.edge_index_info[1:])
def _get_sampler_from_shame(self):
self.deg_shm = _get_from_shareable_list(*self.deg_info)#_get_existing_share_memory(self.deg_info)
self.neighbors_shm = _get_from_shareable_list(*self.neighbors_info)
self.deg = self.deg_shm
        self.neighbors = self.neighbors_shm
def _copy_sampler_to_shame(self,neighbors,deg):
self.deg_info = _copy_to_shareable_list(deg)#_copy_to_share_memory(deg)
self.neighbors_info = _copy_to_shareable_list(neighbors)#_copy_to_share_memory(neighbors)
def _close_graph_in_shame(self):
#_close_existing_shm(self.edge_index_shm)
#_close_existing_shm(self.deg_shm.shm)
#_close_existing_shm(self.neighbors_shm.shm)
_close_existing_shm(self.data_x_shm)
_close_existing_shm(self.data_y_shm)
_close_existing_shm(self.edge_index_shm)
def _unlink_graph_in_shame(self):
#_unlink_existing_shm(self.edge_index_shm)
_unlink_existing_shm(self.data_x_shm)
_unlink_existing_shm(self.data_y_shm)
_unlink_existing_shm(self.edge_index_shm)
    # number of nodes stored in this partition
def get_part_num(self):
return self.data_x.size()[0]
def select_attr(self,index):
return torch.index_select(self.data_x,0,index)
def select_y(self,index):
return torch.index_select(self.data_y,0,index)
    # return the partition that a global node id belongs to
\ No newline at end of file
import argparse
import os
from DistGraphLoader import partition_load
path1=os.path.abspath('.')
import torch
from Sample.neighbor_sampler import NeighborSampler
from Sample.neighbor_sampler import get_neighbors
from part.Utils import GraphData
from DistGraphLoader import DistGraphData
from DistGraphLoader import DistributedDataLoader
from torch_geometric.data import Data
import distparser
from DistCustomPool import CustomPool
import DistCustomPool
from torch.distributed import rpc
from torch.distributed.rpc import RRef, rpc_async, remote
from torch.distributed.rpc import TensorPipeRpcBackendOptions
import time
from model import GraphSAGE
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os
import BatchData
from part.Utils import GraphData
from torch_geometric.sampler.neighbor_sampler import NeighborSampler as PYGSampler
"""
test command
python test.py --world_size 2 --rank 0
--world_size', default=4, type=int, metavar='W',
help='number of workers')
parser.add_argument('--rank', default=0, type=int, metavar='W',
help='rank of the worker')
parser.add_argument('--log_interval', type=int, default=10, metavar='N',
help='interval between training status logs')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
help='how much to value future rewards')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed for reproducibility')
parser.add_argument('--num_sampler', type=int, default=10, metavar='S',
help='number of samplers')
parser.add_argument('--queue_size', type=int, default=10, metavar='S',
help='sampler queue size')
"""
sage_neighsampler_parameters = {'lr': 0.003,
                                'num_layers': 2,
                                'hidden_channels': 128,
                                'dropout': 0.0,
                                'l2': 5e-7
                                }
class Trainer:
def __init__(
self,
model: torch.nn.Module,
train_data: DataLoader,
test_data:DataLoader,
optimizer: torch.optim.Optimizer,
gpu_id: int,
save_every: int,
) -> None:
self.gpu_id = gpu_id
self.model = model.to('cpu')
self.train_data = train_data
self.test_data = test_data
self.optimizer = optimizer
self.save_every = save_every
        # wrap the model with DDP so it knows which devices to replicate across
self.model = DDP(model)
self.total_correct = 0
self.count = 0
def _run_test(self,batchData:BatchData):
self.count = self.count +1
print(f'run epoch in batch data f {self.count}')
l = len(batchData.edge_index)
# print(f'batchData:len:{l},edge_index:{batchData.edge_index}')
out = self.model(batchData,1)
# print(f'out size: {out.size()}')
# print(f'batchDate.y: {batchData.y}')
# # batchData.y = F.one_hot(batchData.y, num_classes=7)
# print(f'roots {batchData.roots}')
# print(f'y size:{batchData.y.size()}')
# print(f'mask :{batchData.train_mask}')
##loss = F.nll_loss(out, batchData.y)
y = batchData.y#torch.masked_select(graph.data.y,graph.data.train_mask)
loss = F.nll_loss(out, y)
self.total_correct += int(out.argmax(dim=-1).eq(y).sum())
##self.total_correct += int(out.argmax(dim=-1).eq(batchData.y).sum())
self.total_count += y.size(0)
def _run_batch(self, batchData:BatchData):
# graph = GraphData('/home/sxx/zlj/rpc_ps/part/metis_1/rank_0')
self.count = self.count +1
print(f'run epoch in batch data f {self.count}')
l = len(batchData.edge_index)
# print(f'batchData:len:{l},edge_index:{batchData.edge_index}')
self.optimizer.zero_grad()
out = self.model(batchData,0)
# print(f'out size: {out.size()}')
# print(f'batchDate.y: {batchData.y}')
# # batchData.y = F.one_hot(batchData.y, num_classes=7)
# print(f'roots {batchData.roots}')
# print(f'y size:{batchData.y.size()}')
# print(f'mask :{batchData.train_mask}')
##loss = F.nll_loss(out, batchData.y)
y = batchData.y#torch.masked_select(graph.data.y,graph.data.train_mask)
loss = F.nll_loss(out, y)
loss.backward()
self.optimizer.step()
print(out.argmax(dim=-1))
self.total_correct += int(out.argmax(dim=-1).eq(y).sum())
##self.total_correct += int(out.argmax(dim=-1).eq(batchData.y).sum())
self.total_count += y.size(0)
print('finish')
def _run_epoch(self, epoch):
self.total_correct = 0
self.total_count = 0
for batchData in self.train_data:
self._run_batch(batchData)
approx_acc = self.total_correct / self.total_count
print(f"=======[GPU{self.gpu_id}] Epoch {epoch} | approx_acc: {approx_acc}=======")
self.total_correct = 0
self.total_count = 0
for batchData in self.test_data:
self._run_test(batchData)
approx_acc = self.total_correct / self.total_count
print(f"=======[GPU{self.gpu_id}] Epoch {epoch} | test_approx_acc: {approx_acc}=======")
def _save_checkpoint(self, epoch):
        # the model is wrapped by DDP, so its parameters are accessed via model.module
ckp = self.model.module.state_dict()
PATH = "checkpoint.pt"
torch.save(ckp, PATH)
print(f"=======Epoch {epoch} | Training checkpoint saved at {PATH}========")
def train(self, max_epochs: int):
for epoch in range(max_epochs):
self._run_epoch(epoch)
            # only one copy of the checkpoint needs to be saved
if self.gpu_id == 0 and epoch % self.save_every == 0:
self._save_checkpoint(epoch)
@torch.no_grad()
def test(layer_loader, model, data, split_idx, device, no_conv=False):
# data.y is labels of shape (N, )
model.eval()
out = model.inference(data.x, layer_loader, device)
# out = model.inference_all(data)
y_pred = out.exp() # (N,num_classes)
losses = dict()
for key in ['train', 'valid', 'test']:
node_id = split_idx[key]
node_id = node_id.to(device)
losses[key] = F.nll_loss(out[node_id], data.y[node_id]).item()
return losses, y_pred
def main():
parser = distparser.parser.add_subparsers().add_parser("train")#argparse.ArgumentParser(description='minibatch_gnn_models')
parser.add_argument('--device', type=str, default='cpu')
parser.add_argument('--dataset', type=str, default='Cora')
parser.add_argument('--log_steps', type=int, default=10)
parser.add_argument('--model', type=str, default='sage_neighsampler')
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--rank', default=0, type=int, metavar='W',
help='rank of the worker')
parser.add_argument('--world_size', default=1, type=int, metavar='W',
help='number of the gpus')
args = parser.parse_args()
print(args)
#init_method="tcp://{}:{}".format('127.0.0.1',10031)
#dist.init_process_group(backend="gloo", world_size=args.world_size, rank=args.rank,init_method=init_method)
DistCustomPool.init_distribution('127.0.0.1',9673,'127.0.0.1',10011,backend = "gloo")
#graph = DistGraphData('./part/metis_1')
#graph = DistGraphData('/home/sxx/pycode/work/ogbn-products/metis_2')
pdata = partition_load("../../cora", algo="metis")
graph = DistGraphData(pdata = pdata,edge_index= pdata.edge_index)
row, col = graph.edge_index
tnb = get_neighbors(row.contiguous(), col.contiguous(), graph.num_nodes)
sampler = NeighborSampler(graph.num_nodes, num_layers=2, fanout=[10,5], workers=10, tnb=tnb)
loader = DistributedDataLoader('train',graph,graph.data.train_mask,sampler = sampler,batch_size = 100,shuffle=True)
testloader = DistributedDataLoader('test',graph,graph.data.test_mask,sampler = sampler,batch_size = 100,shuffle=True)
#count_node = 0
#count_edge = 0
#count_x_byte = 0
#start_time = time.time()
#cnt = 0
#for batchData in loader:
# cnt = cnt+1
# count_node += batchData.nids.size(0)
# count_x_byte += batchData.x.numel()*batchData.x.element_size()
# for edge_list in batchData.edge_index:
# count_edge += edge_list.size(1)
# #count_edge += batchData.edge_index.size(1)
# dt = time.time() - start_time
# print('{} count node {},count edge {}, node TPS {},edge TPS {}, x size {}, x TPS {} byte'
# .format(cnt,count_node,count_edge,count_node/dt,count_edge/dt,count_x_byte,count_x_byte/dt),batchData.x.size(0),batchData.x.element_size(),batchData.x.numel())
epochs = args.epochs
if args.model == 'sage_neighsampler':
para_dict = sage_neighsampler_parameters
num_classes = 7
num_node_features = graph.num_feasures
model = GraphSAGE(num_node_features,num_classes)
num_node_features = graph.num_feasures
print(f'Model {args.model} initialized')
epochs = args.epochs
optimizer = torch.optim.Adam(model.parameters(), lr=para_dict['lr'])
model.train()
trainer = Trainer(model, loader,testloader, optimizer, graph.rank, 5)
trainer.train(epochs)
#
DistCustomPool.close_distribution()
if __name__ == "__main__":
main()
\ No newline at end of file