Commit f69647c0 by zhlj

Initial commit

from distparser import WORLD_SIZE
import torch
class BatchData:
'''
Args:
batch_size: the number of sampled root nodes
nids: the global ids of the sampled nodes and their neighbours
edge_index: the edges between nids; node ids are local indices into nids
x: node features of the sampled nodes
y: node labels of the root nodes
eids: the edge ids of edge_index in the subgraph
train_mask: training mask of the sampled nodes
val_mask: validation mask of the sampled nodes
test_mask: test mask of the sampled nodes
'''
def __init__(self,nids,edge_index,roots=None,x=None,y=None,eids=None,train_mask=None,val_mask=None,test_mask=None):
self.batch_size = roots.size(0) if roots is not None else 0
self.roots=roots
self.nids=nids
self.edge_index=edge_index
self.x=x
self.y=y
self.eids=eids
self.train_mask=train_mask
self.val_mask=val_mask
self.test_mask=test_mask
def _check_with_graph(self,graph1,graph2):
for i,id in enumerate(self.nids):
if(id >= graph1.partptr[0] and id < graph1.partptr[1]):
real_x = graph1.select_attr(graph1.get_localId_by_partitionId(0,torch.tensor(id)))
else:
real_x = graph2.select_attr(graph2.get_localId_by_partitionId(1,torch.tensor(id)))
if not torch.equal(real_x[0], self.x[i]):
print(real_x[0], self.x[i])
print('error')
def __repr__(self):
return "BatchData(batch_size = {},roots = {} , \
nides = {} , edge_index = {} , x= {}, \
y ={})".format(self.batch_size,self.roots.__repr__,
self.nids.__repr__,
self.edge_index.__repr__,
self.train_mask.__repr__,
self.x.__repr__,self.y.__repr__)
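# A minimal usage sketch (not part of the original file; the values are illustrative):
# constructing a BatchData by hand for a tiny sampled subgraph.
#
# import torch
# nids = torch.tensor([10, 42, 7])              # global ids of the sampled nodes
# edge_index = torch.tensor([[0, 1],            # edges expressed as local indices into nids
#                            [1, 2]])
# roots = torch.tensor([10])                    # seed node(s); batch_size == roots.size(0)
# x = torch.randn(3, 16)                        # features of the sampled nodes
# y = torch.tensor([0])                         # labels of the root nodes
# batch = BatchData(nids, edge_index, roots=roots, x=x, y=y)
# print(batch.batch_size)                       # -> 1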
import torch
from torch import Tensor
from enum import Enum
import math
from abc import ABC
from typing import Optional, Tuple, Union
class NegativeSamplingMode(Enum):
# 'binary': Randomly sample negative edges in the graph.
binary = 'binary'
# 'triplet': Randomly sample negative destination nodes for each positive
# source node.
triplet = 'triplet'
class NegativeSampling:
r"""The negative sampling configuration of a
:class:`~torch_geometric.sampler.BaseSampler` when calling
:meth:`~torch_geometric.sampler.BaseSampler.sample_from_edges`.
Args:
mode (str): The negative sampling mode
(:obj:`"binary"` or :obj:`"triplet"`).
If set to :obj:`"binary"`, will randomly sample negative links
from the graph.
If set to :obj:`"triplet"`, will randomly sample negative
destination nodes for each positive source node.
amount (int or float, optional): The ratio of sampled negative edges to
the number of positive edges. (default: :obj:`1`)
weight (torch.Tensor, optional): A node-level vector determining the
sampling of nodes. Does not necessarily need to sum up to one.
If not given, negative nodes will be sampled uniformly.
(default: :obj:`None`)
"""
mode: NegativeSamplingMode
amount: Union[int, float] = 1
weight: Optional[Tensor] = None
def __init__(
self,
mode: Union[NegativeSamplingMode, str],
amount: Union[int, float] = 1,
weight: Optional[Tensor] = None,
):
self.mode = NegativeSamplingMode(mode)
self.amount = amount
self.weight = weight
if self.amount <= 0:
raise ValueError(f"The attribute 'amount' needs to be positive "
f"for '{self.__class__.__name__}' "
f"(got {self.amount})")
if self.is_triplet():
if self.amount != math.ceil(self.amount):
raise ValueError(f"The attribute 'amount' needs to be an "
f"integer for '{self.__class__.__name__}' "
f"with 'triplet' negative sampling "
f"(got {self.amount}).")
self.amount = math.ceil(self.amount)
def is_binary(self) -> bool:
return self.mode == NegativeSamplingMode.binary
def is_triplet(self) -> bool:
return self.mode == NegativeSamplingMode.triplet
def sample(self, num_samples: int,
num_nodes: Optional[int] = None) -> Tensor:
r"""Generates :obj:`num_samples` negative samples."""
if self.weight is None:
if num_nodes is None:
raise ValueError(
f"Cannot sample negatives in '{self.__class__.__name__}' "
f"without passing the 'num_nodes' argument")
return torch.randint(num_nodes, (num_samples, ))
if num_nodes is not None and self.weight.numel() != num_nodes:
raise ValueError(
f"The 'weight' attribute in '{self.__class__.__name__}' "
f"needs to match the number of nodes {num_nodes} "
f"(got {self.weight.numel()})")
return torch.multinomial(self.weight, num_samples, replacement=True)
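# A small usage sketch (added for illustration, not in the original file) of the two
# negative-sampling modes implemented above: uniform sampling needs `num_nodes`,
# weighted sampling draws node ids in proportion to the given per-node weight vector.
#
# neg = NegativeSampling('binary', amount=2)
# print(neg.sample(num_samples=4, num_nodes=10))     # 4 uniform node ids in [0, 10)
# weighted = NegativeSampling('triplet', amount=1,
#                             weight=torch.tensor([0.7, 0.1, 0.1, 0.1]))
# print(weighted.sample(num_samples=4))              # ids drawn proportionally to `weight`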
class BaseSampler(ABC):
r"""An abstract base class that initializes a graph sampler and provides
:meth:`_sample_one_layer_from_nodes`
:meth:`_sample_one_layer_from_nodes_parallel`
:meth:`sample_from_nodes` routines.
"""
def sample_from_nodes(
self,
nodes: torch.Tensor,
**kwargs
) -> Tuple[torch.Tensor, list]:
r"""Performs mutilayer sampling from the nodes specified in: nodes
The specific number of layers is determined by parameter: num_layers
returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, list].
Args:
nodes: the list of seed nodes index
**kwargs: other kwargs
Returns:
sampled_nodes: the nodes sampled
sampled_edge_index_list: the edges sampled
"""
raise NotImplementedError
def sample_from_edges(
self,
edges: torch.Tensor,
edge_label: Optional[torch.Tensor] = None,
neg_sampling: Optional[NegativeSampling] = None
) -> Tuple[torch.Tensor, list]:
r"""Performs sampling from the edges specified in :obj:`index`,
returning a sampled subgraph in the specified output format.
Args:
edges: the list of seed edges index
edge_label: the label for the seed edges.
neg_sampling: The negative sampling configuration
Returns:
sampled_nodes: the nodes sampled
sampled_edge_index_list: the edges sampled
metadata: other information
"""
raise NotImplementedError
# def _sample_one_layer_from_nodes(
# self,
# nodes:torch.Tensor,
# **kwargs
# ) -> Tuple[torch.Tensor, torch.Tensor]:
# r"""Performs sampling from the nodes specified in: nodes,
# returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, torch.Tensor].
# Args:
# nodes: the list of seed nodes index
# **kwargs: other kwargs
# Returns:
# sampled_nodes: the nodes sampled
# sampled_edge_index: the edges sampled
# """
# raise NotImplementedError
# def _sample_one_layer_from_nodes_parallel(
# self,
# nodes: torch.Tensor,
# **kwargs
# ) -> Tuple[torch.Tensor, torch.Tensor]:
# r"""Performs sampling paralleled from the nodes specified in: nodes,
# returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, torch.Tensor].
# Args:
# nodes: the list of seed nodes index
# **kwargs: other kwargs
# Returns:
# sampled_nodes: the nodes sampled
# sampled_edge_index: the edges sampled
# """
# raise NotImplementedError
import torch
# import sys
# from os.path import abspath, dirname
# sys.path.insert(0, abspath(dirname(__file__)))
# print(sys.path)
from neighbor_sampler import NeighborSampler
# edge_index = torch.tensor([[0, 1, 1, 2, 2, 2, 3], [1, 0, 2, 1, 3, 0, 2]])
# num_nodes = 4
edge_index = torch.tensor([[0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 5], [1, 0, 2, 4, 1, 3, 0, 2, 5, 3, 5, 0, 2]])
num_nodes = 6
num_neighbors = 2
# Run the neighbor sampling
sampler=NeighborSampler(edge_index=edge_index, num_nodes=num_nodes, num_layers=2, workers=2, fanout=[2, 1])
# neighbor_nodes, sampled_edge_index = sampler._sample_one_layer_from_nodes(nodes=torch.tensor([1,3]), fanout=num_neighbors)
neighbor_nodes, sampled_edge_index = sampler.sample_from_nodes(torch.tensor([1,2,3]))
# Print the result
print('neighbor_nodes_id: \n',neighbor_nodes, '\nedge_index: \n',sampled_edge_index)
# import torch_scatter
# nodes=torch.Tensor([1,2])
# row, col = edge_index
# deg = torch_scatter.scatter_add(torch.ones_like(row), row, dim=0, dim_size=num_nodes)
# neighbors1=torch.concat([row[row==nodes[i]] for i in range(0, nodes.shape[0])])
# print(neighbors1)
# neighbors2=torch.concat([col[row==nodes[i]] for i in range(0, nodes.shape[0])])
# print(neighbors2)
# neighbors=torch.stack([neighbors1, neighbors2], dim=0)
# print('neighbors: \n', neighbors[0]==1)
from base import NegativeSampling
edge_index = torch.tensor([[0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 5], [1, 0, 2, 4, 1, 3, 0, 2, 5, 3, 5, 0, 2]])
num_nodes = 6
# sampler
sampler=NeighborSampler(edge_index=edge_index, num_nodes=num_nodes, num_layers=2, workers=2, fanout=[2, 1])
# negative
weight = torch.tensor([0.3,0.1,0.1,0.1,0.3,0.1])
negative = NegativeSampling('binary', 2, weight)
# negative = NegativeSampling('triplet', 2, weight)
label=torch.tensor([1,2])
seed_edges = torch.tensor([[0,1],
[1,4]])
# result = sampler.sample_from_edges(edges=seed_edges)
result = sampler.sample_from_edges(edges=seed_edges, edge_label=label, neg_sampling=negative)
# Print the result
print('neighbor_nodes_id: \n',result[0], '\nedge_index: \n',result[1], '\nmetadata: \n',result[2])
import torch
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric import datasets
import time
def load_ogb_dataset(name, data_path):
dataset = PygNodePropPredDataset(name=name, root=data_path)
split_idx = dataset.get_idx_split()
g = dataset[0]
n_node = g.num_nodes
node_data={}
node_data['train_mask'] = torch.zeros(n_node, dtype=torch.bool)
node_data['val_mask'] = torch.zeros(n_node, dtype=torch.bool)
node_data['test_mask'] = torch.zeros(n_node, dtype=torch.bool)
node_data['train_mask'][split_idx["train"]] = True
node_data['val_mask'][split_idx["valid"]] = True
node_data['test_mask'][split_idx["test"]] = True
return g, node_data
g, node_data = load_ogb_dataset('ogbn-products', "/home/hzq/code/gnn/test/NewSample/dataset")
print(g)
from neighbor_sampler import NeighborSampler
pre = time.time()
from neighbor_sampler import get_neighbors
row, col = g.edge_index
tnb = get_neighbors(row.contiguous(), col.contiguous(), g.num_nodes)
sampler = NeighborSampler(g.num_nodes, num_layers=2, fanout=[100,100], workers=4, tnb=tnb)
end = time.time()
print("init time:", end-pre)
# from torch_geometric.sampler import NeighborSampler, NumNeighbors, NodeSamplerInput, SamplerOutput
# pre = time.time()
# num_nei = NumNeighbors([100, 100])
# node_idx = NodeSamplerInput(input_id=None, node=torch.masked_select(torch.arange(g.num_nodes),node_data['train_mask']))
# sampler = NeighborSampler(g, num_nei)
# end = time.time()
# print("init time:", end-pre)
pre = time.time()
node, edge = sampler.sample_from_nodes(torch.masked_select(torch.arange(g.num_nodes),node_data['train_mask']))
# out = sampler.sample_from_nodes(node_idx)
# node = out.node
# edge = [out.row, out.col]
end = time.time()
print(node)
print(edge)
print("sample time", end-pre)
import torch
import torch.multiprocessing as mp
from typing import Optional, Tuple
from .base import BaseSampler, NegativeSampling
from .neighbor_sampler import NeighborSampler
class RandomWalkSampler(BaseSampler):
def __init__(
self,
num_nodes: int,
num_layers: int,
workers = 1,
edge_index : Optional[torch.Tensor] = None,
deg = None,
neighbors = None
) -> None:
r"""__init__
Args:
num_nodes: the num of all nodes in the graph
num_layers: the num of layers to be sampled
fanout: the list of max neighbors' number chosen for each layer
workers: the number of threads, default value is 1
edge_index: all edges in the graph
neighbors: all nodes' neighbors
deg: the degree of all nodes
"""
super().__init__()
if(edge_index is not None):
self.sampler = NeighborSampler(num_nodes, num_layers, [1 for _ in range(num_layers)],
workers, edge_index)
elif(neighbors is not None and deg is not None):
self.sampler = NeighborSampler(num_nodes, num_layers, [1 for _ in range(num_layers)],
workers, neighbors, deg)
else:
raise Exception("Not enough parameters")
self.num_layers = num_layers
# the number of threads must not exceed torch's default OMP thread count
self.workers = min(workers, torch.get_num_threads())
def sample_from_nodes(
self,
nodes: torch.Tensor
) -> Tuple[torch.Tensor, list]:
r"""Performs mutilayer sampling from the nodes specified in: nodes
The specific number of layers is determined by parameter: num_layers
returning a sampled subgraph in the specified output format: Tuple[torch.Tensor, list].
Args:
nodes: the list of seed nodes index
Returns:
sampled_nodes: the node sampled
sampled_edge_index: the edge sampled
"""
return self.sampler.sample_from_nodes(nodes)
def sample_from_edges(
self,
edges: torch.Tensor,
edge_label: Optional[torch.Tensor] = None,
neg_sampling: Optional[NegativeSampling] = None
) -> Tuple[torch.Tensor, list]:
r"""Performs sampling from the edges specified in :obj:`index`,
returning a sampled subgraph in the specified output format.
Args:
edges: the list of seed edges index
edge_label: the label for the seed edges.
neg_sampling: The negative sampling configuration
Returns:
sampled_nodes: the nodes sampled
sampled_edge_index_list: the edges sampled
"""
return self.sampler.sample_from_edges(edges, edge_label, neg_sampling)
if __name__=="__main__":
edge_index1 = torch.tensor([[0, 1, 1, 1, 2, 2, 2, 4, 4, 4, 5], # , 3, 3
[1, 0, 2, 4, 1, 3, 0, 3, 5, 0, 2]])# , 2, 5
num_nodes1 = 6
# Run the random walk sampling
sampler=RandomWalkSampler(edge_index=edge_index1, num_nodes=num_nodes1, num_layers=3, workers=4)
sampled_nodes, sampled_edge_index = sampler.sample_from_nodes(torch.tensor([1,2]))
# Print the result
print('sampled_nodes_id: \n',sampled_nodes, '\nedge_index: \n',sampled_edge_index)
Metadata-Version: 2.1
Name: sample-cores
Version: 0.0.0
sample_cores.cpp
setup.py
sample_cores.egg-info/PKG-INFO
sample_cores.egg-info/SOURCES.txt
sample_cores.egg-info/dependency_links.txt
sample_cores.egg-info/top_level.txt
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension
setup(
name='sample_cores',
ext_modules=[
CppExtension(
name='sample_cores',
sources=['sample_cores.cpp'],
extra_compile_args=['-fopenmp','-Xlinker',' -export-dynamic'],
include_dirs=["/home/hzq/code/gnn/test/NewSample"],
),
],
cmdclass={
'build_ext': BuildExtension
})
import sys
import argparse
parser = argparse.ArgumentParser(
description="RPC Reinforcement Learning Example",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--world_size', default=1, type=int, metavar='W',
help='number of workers')
parser.add_argument('--rank', default=0, type=int, metavar='W',
help='rank of the worker')
parser.add_argument('--log_interval', type=int, default=10, metavar='N',
help='interval between training status logs')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
help='how much to value future rewards')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed for reproducibility')
parser.add_argument('--num_sampler', type=int, default=1, metavar='S',
help='number of samplers')
parser.add_argument('--queue_size', type=int, default=10000, metavar='S',
help='sampler queue size')
args = parser.parse_args()
rpc_proxy=None
WORKER_RANK = args.rank
NUM_SAMPLER = args.num_sampler
WORLD_SIZE = args.world_size
QUEUE_SIZE = args.queue_size
MAX_QUEUE_SIZE = 5*args.queue_size
RPC_NAME = "rpcserver{}"
def _get_worker_rank():
return WORKER_RANK
def _get_num_sampler():
return NUM_SAMPLER
def _get_world_size():
return WORLD_SIZE
def _get_RPC_NAME():
return RPC_NAME
def _get_queue_size():
return QUEUE_SIZE
def _get_max_queue_size():
return MAX_QUEUE_SIZE
def _get_rpc_name():
return RPC_NAME
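# Usage sketch (this module is importable as `distparser`, as other files in this commit do;
# the script name below is illustrative): the command-line flags become module-level
# constants that the rest of the code reads through these accessors, e.g.
#
#   python some_train_script.py --world_size 2 --rank 1 --num_sampler 4
#
# import distparser as parser
# print(parser._get_worker_rank(), parser._get_world_size())       # -> 1 2
# print(parser._get_RPC_NAME().format(parser._get_worker_rank()))  # -> "rpcserver1"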
graph_set={}
def _clear_all(barrier = None):
global graph_set
for key in graph_set:
graph = graph_set[key]
graph._close_graph_in_shame()
print('clear ',key)
if(barrier is not None and barrier.wait()==0):
graph._unlink_graph_in_shame()
graph_set = {}
def _set_graph(graph_name,graph_info):
graph_info._get_graph_from_shm()
graph_set[graph_name]=graph_info
def _get_graph(graph_name):
return graph_set[graph_name]
def _del_graph(graph_name):
graph_set.pop(graph_name)
def _get_size():
return len(graph_set)
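# A hedged usage sketch: a sampler/RPC process registers a shared-memory graph handle under
# a dataloader name and looks it up later. `GraphInfoInShm` is defined in
# shared_graph_memory.py in this commit; the name 'train_loader' is illustrative.
#
# import torch
# from shared_graph_memory import GraphInfoInShm
# graph_inshm = GraphInfoInShm(graph)            # built once from a loaded GraphData partition
# _set_graph('train_loader', graph_inshm)        # attaches the shm buffers and stores the handle
# g = _get_graph('train_loader')
# feats = g.select_attr(torch.tensor([0, 1, 2])) # read node features through shared memory
# _del_graph('train_loader')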
import logging
import os
import time
class Config():
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
# add a console handler (i.e. show the log in the console)
ls = logging.StreamHandler()
ls.setLevel(logging.DEBUG)
# set the log record format
formatter = logging.Formatter('%(asctime)s-%(name)s-%(filename)s-[line:%(lineno)d]''-%(levelname)s: %(message)s')
# attach the format to the console handler so the console prints the logs
ls.setFormatter(formatter)
# attach the console handler to the logger
logger.addHandler(ls)
# first create a logs directory under the project directory to store log files (the path can be customised)
print(os.path.abspath('.'))
logdir = os.path.join(os.path.abspath('.'), 'logs')
if not os.path.exists(logdir):
os.mkdir(logdir)
# then create a .log file in the logs directory, named with the current date and time
logfile = os.path.join(logdir, time.strftime('%Y-%m-%d %H:%M:%S') + '.log')
# add a file handler for the log and set the file encoding
lf = logging.FileHandler(filename=logfile, encoding='utf8')
# set the log level recorded by the file handler
lf.setLevel(logging.DEBUG)
# set the log record format
lf.setFormatter(formatter)
# attach the file handler to the logger
logger.addHandler(lf)
def get_config(self):
return self.logger
logger = Config().get_config()
2023-05-11 02:24:19,565-torch.distributed.distributed_c10d-distributed_c10d.py-[line:228]-INFO: Added key: store_based_barrier_key:0 to store for rank: 0
2023-05-11 02:24:19,566-torch.distributed.distributed_c10d-distributed_c10d.py-[line:262]-INFO: Rank 0: Completed store-based barrier for key:store_based_barrier_key:0 with 1 nodes.
2023-05-11 02:24:20,504-torch.nn.parallel.distributed-distributed.py-[line:995]-INFO: Reducer buckets have been rebuilt in this iteration.
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from multiprocessing import Barrier
from multiprocessing.connection import Client
import queue
import sys
import argparse
import traceback
from torch.distributed.rpc import RRef, rpc_async, remote
import torch.distributed.rpc as rpc
import torch
from part.Utils import GraphData
from Sample.neighbor_sampler import NeighborSampler
import time
from typing import Optional
import torch.distributed as dist
import torch.multiprocessing as map
from BatchData import BatchData
import asyncio
import os
from Sample.neighbor_sampler import NeighborSampler
from graph_store import _get_graph
from shared_graph_memory import GraphInfoInShm
import distparser as parser
from logger import logger
WORKER_RANK = parser._get_worker_rank()
NUM_SAMPLER = parser._get_num_sampler()
WORLD_SIZE = parser._get_world_size()
QUEUE_SIZE =parser._get_queue_size()
MAX_QUEUE_SIZE = parser._get_max_queue_size()
RPC_NAME = parser._get_RPC_NAME()
#@rpc.functions.async_execution
def _get_local_attr(data_name,nodes):
graph = _get_graph(data_name)
local_id = graph.get_localId_by_partitionId(parser._get_worker_rank(),nodes)
return graph.select_attr(local_id)
def _request_remote_attr(rank,data_name,nodes):
t1 = time.time()
fut = rpc_async(
parser._get_RPC_NAME().format(rank),
_get_local_attr,
args=(data_name,nodes,)
)
#logger.debug('request {}'.format(time.time()-t1))
return fut
#ThreadPoolExecutor pool
def _request_all_remote_attr(data_name,nodes_list):
worker_size = parser._get_world_size()
worker_rank = parser._get_worker_rank()
futs = []
for rank in range(worker_size):
if(rank == worker_rank):
futs.append(None)
continue
else:
if(nodes_list[rank].size(0) == 0):
futs.append(None)
else:
futs.append(_request_remote_attr(rank,data_name,nodes_list[rank]))
return futs
def _check_future_finish(futs):
check = True
for _,fut in futs:
if fut is not None and fut.done() is False:
check = False
return check
def _split_node_part_and_submit(data_name,node_id_list):
t0 = time.time()
graph = _get_graph(data_name)
worker_size = parser._get_world_size()
local_rank = parser._get_worker_rank()
futs = []
t1 = time.time()
for rank in range(worker_size):
if(rank != local_rank):
part_mask = (graph.partptr[rank]<=node_id_list) & (node_id_list<graph.partptr[rank+1])
part_node = torch.masked_select(node_id_list,part_mask)
if(part_node.size(0) != 0):
futs.append((part_node,_request_remote_attr(rank,data_name,part_node)))
t2 = time.time()
local_mask = (graph.partptr[local_rank]<=node_id_list) & (node_id_list<graph.partptr[local_rank+1])
t3 = time.time()
local_node = torch.masked_select(node_id_list,local_mask)
t4 = time.time()
#logger.debug('size {},split {} {} {}'.format(node_id_list.size(0),t2-t1,t3-t2,t4-t3))
return local_node,futs
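# A standalone sketch (illustrative values) of the partition-split step above: `partptr`
# holds prefix boundaries over the global node ids, so a node id belongs to partition `rank`
# iff partptr[rank] <= id < partptr[rank + 1].
#
# partptr = torch.tensor([0, 4, 9, 12])                  # 3 partitions over 12 nodes
# node_id_list = torch.tensor([1, 5, 10, 3, 8])
# for rank in range(partptr.numel() - 1):
#     mask = (partptr[rank] <= node_id_list) & (node_id_list < partptr[rank + 1])
#     print(rank, torch.masked_select(node_id_list, mask))
# # -> 0: [1, 3]   1: [5, 8]   2: [10]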
#def _split_node_part(data_name,node_id_list):
# graph = _get_graph(data_name)
# part_list=[]
# scatter_list=[]
# worker_size = parser._get_world_size()
# #print(node_id_list)
# for rank in range(worker_size):
# group_id=(graph.partptr[rank]<=node_id_list) & (node_id_list<graph.partptr[rank+1])
# part_list.append(torch.masked_select(node_id_list,group_id))
# scatter_list.append(torch.nonzero(group_id))
# #print(graph.partptr[rank],graph.partptr[rank+1])
# #print(torch.nonzero(group_id))
# #print(node_id_list)
# scatter_index=torch.cat(scatter_list)
# # print('part_list',part_list)
# return part_list,scatter_index
#
#def _union_remote_and_local(data_name,scatter_list,local_nodes,futs):
# worker_rank = parser._get_worker_rank()
# worker_size = parser._get_world_size()
# feature_part = []
# local_feature=_get_local_attr(data_name,local_nodes)
# node_feature_list = []
# for rank in range(worker_size):
# if(rank == worker_rank):
# node_feature_list.append(local_feature)
# elif(futs[rank] is None):
# continue
# else:
# node_feature_list.append(futs[rank].value())
#
# node_feature = torch.cat(node_feature_list)
# attr_size=node_feature.size()
# x=torch.zeros(attr_size)
# #print(scatter_list.size())
# #print(node_feature.size())
# #print(scatter_list.repeat(1,attr_size[1]).size())
# #print(x.size())
# #print(scatter_list)
# x=x.scatter(0,scatter_list.repeat(1,attr_size[1]),node_feature)
# return x
#total_local = 0
#total_remote = 0
#def _get_batch_data(kwargs):
# global total_local
# global total_remote
# data_name = kwargs.get("data_name")
# input_data = kwargs.get("input_data")
# nids = kwargs.get("nids")
# root = nids[:input_data.size(0)]
# scatter_list,local_nodes,futs = kwargs.get("union_args")
# edge_index = kwargs.get("edge_index")
# x=_union_remote_and_local(data_name,scatter_list,local_nodes,futs)
# graph = _get_graph(data_name)
# worker_rank = parser._get_worker_rank()
# local_y_id = graph.get_localId_by_partitionId(worker_rank,root)
# y = graph.select_y(local_y_id)
# total_local = total_local + len(local_nodes)
# total_remote = total_remote + len(x)-len(local_nodes)
# print('total local',total_local,'total remote',total_remote)
# #print(t2-t1,t3-t2,t4-t3,t5-t4,time.time()-t4)
# return BatchData(nids,edge_index,roots=input_data,x=x,y=y)
def _get_batch_data(kwargs):
data_name = kwargs.get("data_name")
input_size = kwargs.get("input_size")
local_nodes,futs = kwargs.get("union_args")
edge_index = kwargs.get("edge_index")
graph = _get_graph(data_name)
nids = [local_nodes]
root = local_nodes[:input_size]
worker_rank = parser._get_worker_rank()
local_y_id = graph.get_localId_by_partitionId(worker_rank,root)
y = graph.select_y(local_y_id)
x = [_get_local_attr(data_name,local_nodes)]
for (part_node,part_feature) in futs:
nids.append(part_node)
x.append(part_feature.value())
nids = torch.cat(nids,0)
x = torch.cat(x,0)
return BatchData(nids,edge_index,roots=root,x=x,y=y)
def _sample_node_neighbors_server(data_name,neighbor_nodes):
'''
sample the structure of the subgraph
'''
print('sample node neighbors server')
t1 = time.time()
local_node,futs = _split_node_part_and_submit(data_name,neighbor_nodes)
t2 = time.time()
#logger.debug('sample server {}'.format(t2-t1))
return local_node,futs
def _sample_node_neighbors_single(graph,graph_name,sampler,input_data):
#out = sampler.sample_from_nodes((graph_name,input_data,time.time()))
nids,edge_index = sampler.sample_from_nodes(input_data.reshape(-1))
#nids = out.node
#edge_index = torch.cat((out.row.reshape(1,-1),out.col.reshape(1,-1)),0)
x= graph.select_attr(nids)
y = graph.select_y(input_data)
return BatchData(nids,edge_index,roots=input_data,x=x,y=y)
import torch
import torch.nn as nn
import torch.nn.functional as F
from part.Utils import GraphData
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
from torch_geometric.nn import MessagePassing
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as pyg_nn
from BatchData import BatchData
class MyModel(torch.nn.Module):
def __init__(self,graph:GraphData):
super(MyModel, self).__init__()
self.conv1 = GCNConv(graph.num_feasures, 128) # input dim = node feature dimension, 128 hidden units
self.conv2 = GCNConv(128, 7)
self.num_layers = 2
def forward(self, x, edge_index):
x = self.conv1(x, edge_index)
x = F.relu(x)
x = self.conv2(x, edge_index)
return F.log_softmax(x, dim=1)
class SAGEConv(MessagePassing):
def __init__(self, in_channels, out_channels, bias=True, aggr='mean'):
super(SAGEConv, self).__init__(aggr=aggr) # mean aggregation by default
# linear layers for self and neighbour features
self.w1 = pyg_nn.dense.linear.Linear(in_channels, out_channels, weight_initializer='glorot', bias=bias)
self.w2 = pyg_nn.dense.linear.Linear(in_channels, out_channels, weight_initializer='glorot', bias=bias)
def message(self, x_j):
# x_j [E, in_channels]
# map the neighbour features
wh = self.w2(x_j) # [E, out_channels]
return wh
def update(self, aggr_out, x):
# aggr_out [num_nodes, out_channels]
# map the node's own features
wh = self.w1(x)
return aggr_out + wh
def forward(self, x, edge_index):
return self.propagate(edge_index, x=x)
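# A minimal sketch (illustrative shapes, not in the original file) of calling this SAGEConv
# directly: the layer mean-aggregates the w2-mapped neighbour features and adds the
# w1-mapped self features for every node.
#
# conv = SAGEConv(in_channels=16, out_channels=32)
# x = torch.randn(4, 16)
# edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])
# out = conv(x, edge_index)        # -> shape [4, 32]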
# 3. Define the GraphSAGE network
class GraphSAGE(nn.Module):
def __init__(self, num_node_features, num_classes):
super(GraphSAGE, self).__init__()
# self.gcns = []
# self.gcns.append(SAGEConv(in_channels=num_node_features,out_channels=16))
# self.gcns.append(SAGEConv(in_channels=16, out_channels=num_classes))
self.conv1 = SAGEConv(in_channels=num_node_features,
out_channels=256)
self.conv2 = SAGEConv(in_channels=256,
out_channels=num_classes)
def forward(self, data:BatchData,type):
#print(data.data.train_mask)
##nids = torch.nonzero(data.data.train_mask).reshape(-1)
#x=data.data.x
#edge_indexs = data.edge_index
nids,x, edge_indexs = data.nids,data.x, data.edge_index
nids_dict = dict(zip(nids.tolist(), torch.arange(nids.numel()).tolist()))
edge_index_local = []
edge_index_local.append(torch.tensor([
[nids_dict[int(s.item())] for s in edge_indexs[0][0]],
[nids_dict[int(s.item())] for s in edge_indexs[0][1]]
]) )
edge_index_local.append(torch.tensor([
[nids_dict[int(s.item())] for s in edge_indexs[1][0]],
[nids_dict[int(s.item())] for s in edge_indexs[1][1]]
]) )
x = self.conv1(x, edge_index_local[1])
x = F.relu(x)
if(type == 0):
x = F.dropout(x, training=self.training)
x = self.conv2(x, edge_index_local[1])
print(f'x sizes {x.size()}')
#root_nids = nids
return F.log_softmax(x[:data.batch_size], dim=1)
def inference(self, data):
nids,x, edge_indexs = data.nids,data.x, data.edge_index
# x_global = [x[nids[i]] for i in range(x.shape[0])]
nids_dict = dict(zip(nids.tolist(), torch.arange(nids.numel()).tolist()))
edge_index_local = torch.tensor([
[nids_dict[int(s.item())] for s in edge_indexs[0][0]],
[nids_dict[int(s.item())] for s in edge_indexs[0][1]]])
x = self.conv1(x, edge_index_local)
x = F.relu(x)
x = F.dropout(x, training=self.training)
x = self.conv2(x, edge_index_local)
print(f'x sizes {x.size()}')
root_nids = [nids_dict[root.item()] for root in data.roots]
return F.log_softmax(x[root_nids], dim=1)
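# A standalone sketch of the global-to-local edge remapping used in forward()/inference()
# above (values are illustrative): each global node id in `nids` is mapped to its position,
# and the globally-indexed edges are rewritten with those local positions.
#
# nids = torch.tensor([10, 42, 7])
# edge_index_global = torch.tensor([[10, 42], [42, 7]])
# nids_dict = dict(zip(nids.tolist(), range(nids.numel())))
# edge_index_local = torch.tensor([
#     [nids_dict[int(s)] for s in edge_index_global[0]],
#     [nids_dict[int(s)] for s in edge_index_global[1]]])
# # -> tensor([[0, 1], [1, 2]])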
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from model import MyModel
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os
import BatchData
from part.Utils import GraphData
def ddp_setup(rank, world_size):
"""
Args:
rank: Unique identifier of each process
world_size: Total number of processes
"""
# specify the master address and port used to coordinate communication between processes
#os.environ["MASTER_ADDR"] = "localhost"
#os.environ["MASTER_PORT"] = "12355"
# initialise all processes: each GPU owns one process and they can discover each other
# rank is the unique identifier of each process, world_size is the total number of GPUs
# nccl is NVIDIA's backend communication library for distributed communication
#init_process_group(backend="gloo", rank=rank, world_size=world_size)
init_method="tcp://{}:{}".format("localhost",12355)
init_process_group(backend="gloo", rank=rank, world_size=world_size,init_method=init_method)
class Trainer:
def __init__(
self,
model: torch.nn.Module,
train_data: DataLoader,
optimizer: torch.optim.Optimizer,
gpu_id: int,
save_every: int,
) -> None:
self.gpu_id = gpu_id
self.model = model.to('cpu')
self.train_data = train_data
self.optimizer = optimizer
self.save_every = save_every
# wrap the model with DDP so it knows which devices to replicate to
self.model = DDP(model)
self.total_correct = 0
self.count = 0
def _run_batch(self, batchData:BatchData):
graph = GraphData('/home/sxx/zlj/rpc_ps/part/metis_1/rank_0')
self.count = self.count +1
print(f'run epoch in batch data f {self.count}')
l = len(batchData.edge_index)
# print(f'batchData:len:{l},edge_index:{batchData.edge_index}')
self.optimizer.zero_grad()
out = self.model(batchData)
print(f'out size: {out.size()}')
print(f'batchDate.y: {batchData.y}')
# batchData.y = F.one_hot(batchData.y, num_classes=7)
print(f'roots {batchData.roots}')
print(f'y size:{batchData.y.size()}')
# print(f'mask :{batchData.train_mask}')
##loss = F.nll_loss(out, batchData.y)
y = batchData.y#torch.masked_select(graph.data.y,graph.data.train_mask)
loss = F.nll_loss(out, y)
loss.backward()
self.optimizer.step()
print(out.argmax(dim=-1))
self.total_correct += int(out.argmax(dim=-1).eq(y).sum())
##self.total_correct += int(out.argmax(dim=-1).eq(batchData.y).sum())
self.total_count += y.size(0)
print('finish')
def _run_epoch(self, epoch):
self.total_correct = 0
self.total_count = 0
for batchData in self.train_data:
self._run_batch(batchData)
approx_acc = self.total_correct / self.total_count
print(f"=======[GPU{self.gpu_id}] Epoch {epoch} | approx_acc: {approx_acc}=======")
def _save_checkpoint(self, epoch):
# since the model is now wrapped by DDP, its parameters are accessed via model.module
ckp = self.model.module.state_dict()
PATH = "checkpoint.pt"
torch.save(ckp, PATH)
print(f"=======Epoch {epoch} | Training checkpoint saved at {PATH}========")
def train(self, max_epochs: int):
for epoch in range(max_epochs):
self._run_epoch(epoch)
# only one copy of the checkpoint needs to be saved
if self.gpu_id == 0 and epoch % self.save_every == 0:
self._save_checkpoint(epoch)
# def load_train_objs():
# train_set =
# model =
# optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
# return train_set, model, optimizer
def prepare_dataloader(dataset: Dataset, batch_size: int):
return DataLoader(
dataset,
batch_size=batch_size,
pin_memory=True,
# shuffling is already handled by the sampler
shuffle=False,
# the distributed sampler draws non-overlapping samples on different GPUs
# sampler=DistributedSampler(dataset)
)
def main(rank: int, world_size: int, save_every: int, total_epochs: int, batch_size: int):
ddp_setup(rank, world_size)
dataset, model, optimizer = load_train_objs()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
train_data = prepare_dataloader(dataset, batch_size)
trainer = Trainer(model, train_data, optimizer, rank, save_every)
trainer.train(total_epochs)
# destroy the process group for all DDP processes
destroy_process_group()
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description='simple distributed training job')
parser.add_argument('--total_epochs', dest='total_epochs',type=int, help='Total epochs to train the model')
parser.add_argument('--save_every', dest='save_every',type=int, help='How often to save a snapshot')
parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
parser.add_argument('--rank', '-r', dest='rank',type=int, help='Rank of this process')
parser.add_argument('--world-size', dest='world_size',type=int, default=1, help='World size')
args = parser.parse_args()
# world_size = torch.cuda.device_count()
world_size = 1
# a distributed API that launches multiple processes at once
# mp.spawn(main, args=(world_size, args.save_every, args.total_epochs, args.batch_size), nprocs=2)
main(args.rank,args.world_size,args.save_every,args.total_epochs,args.batch_size)
import copy
import os.path as osp
import sys
from typing import Optional
import torch
import torch.utils.data
from torch_sparse import SparseTensor, cat
import numpy as np
# def metis(edge_index: LongTensor, num_nodes: int, num_parts: int) -> Tuple[LongTensor, LongTensor]:
# if num_parts <= 1:
# return _nopart(edge_index, num_nodes)
# adj_t = SparseTensor.from_edge_index(edge_index, sparse_sizes=(num_nodes, num_nodes)).to_symmetric()
# rowptr, col, _= adj_t.csr()
# node_parts = torch.ops.torch_sparse.partition(rowptr, col, None, num_parts, num_parts < 8)
# edge_parts = node_parts[edge_index[1]]
# return node_parts, edge_parts
class GPDataset(torch.utils.data.Dataset):
def __init__(self, data, num_parts: int, recursive: bool = False,
save_dir: Optional[str] = None, log: bool = True):
assert data.edge_index is not None
self.num_parts = num_parts
recursive_str = '_recursive' if recursive else ''
filename = f'partition_{num_parts}{recursive_str}.pt'
path = osp.join(save_dir or '', filename)
if save_dir is not None and osp.exists(path):
adj, partptr, perm = torch.load(path)
else:
if log: # pragma: no cover
print('Computing METIS partitioning...', file=sys.stderr)
N, E = data.num_nodes, data.num_edges
adj = SparseTensor(
row=data.edge_index[0], col=data.edge_index[1],
value=torch.arange(E, device=data.edge_index.device),
sparse_sizes=(N, N))
adj, partptr, perm = adj.partition(num_parts, recursive)
# self.global_adj = adj
if save_dir is not None:
torch.save((adj, partptr, perm), path)
if log: # pragma: no cover
print('Done!', file=sys.stderr)
# permute all node attributes
self.data = self.__permute_data__(data, perm, adj)
self.global_adj = adj
self.partptr = partptr
perm_ = torch.zeros_like(perm)
for i in range(len(perm)):
perm_[perm[i]] = i
self.perm = torch.stack([perm,perm_])
def __permute_data__(self, data, node_idx, adj):
out = copy.copy(data)
for key, value in data.items():
if data.is_node_attr(key):
out[key] = value[node_idx]
row, col, _ = adj.coo()
out.edge_index = torch.stack([row, col], dim=0)
out.adj = adj
return out
def __len__(self):
return self.partptr.numel() - 1
def __getitem__(self, idx):
# start id of the idx-th partition and the number of ids in the partition
start = int(self.partptr[idx])
length = int(self.partptr[idx + 1]) - start
N, E = self.data.num_nodes, self.data.num_edges
data = copy.copy(self.data)
del data.num_nodes
adj, data.adj = data.adj, None
# slice the adjacency matrix to this partition's range
adj = adj.narrow(0, start, length).narrow(1, start, length)
#
edge_idx = adj.storage.value()
for key, item in data:
if isinstance(item, torch.Tensor) and item.size(0) == N:
data[key] = item.narrow(0, start, length)
elif isinstance(item, torch.Tensor) and item.size(0) == E:
data[key] = item[edge_idx]
else:
data[key] = item
row, col, _ = adj.coo()
data.edge_index = torch.stack([row, col], dim=0)
return data
def __repr__(self):
return (f'{self.__class__.__name__}(\n'
f' data={self.data},\n'
f' num_parts={self.num_parts}\n'
f')')
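# A hedged usage sketch: partition a Planetoid graph into 4 parts and read one partition back
# as a torch_geometric Data object (the dataset path below is illustrative).
#
# from torch_geometric.datasets import Planetoid
# dataset = Planetoid(root='./data/Cora', name='Cora')
# parts = GPDataset(dataset[0], num_parts=4, save_dir='./data/Cora')
# print(len(parts))       # -> 4
# print(parts[0])         # Data(...) restricted to the first partition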
import os.path as osp
import torch
class GraphData():
def __init__(self, path):
assert path is not None and osp.exists(path),'path does not exist'
id,edge_index,data,partptr =torch.load(path)
# index of the current partition
self.partition_id = id
# total number of partitions
self.partitions = partptr.numel() - 1
# number of node features
self.num_feasures = data.x.size()[1]
# structure of the full graph
self.num_nodes = partptr[self.partitions]
self.num_edges = edge_index[0].numel()
self.edge_index = edge_index
# data belonging to this partition (feature vectors and subgraph structure), a PyG Data object
self.data = data
# partition mapping (prefix pointers over global node ids)
self.partptr = partptr
self.data.y = self.data.y.reshape(-1)
def select_attr(self,index):
return torch.index_select(self.data.x,0,index)
# number of nodes held by this partition
def get_part_num(self):
return self.data.x.size()[0]
def select_y(self,index):
return torch.index_select(self.data.y,0,index)
# map a global node id to its id local to partition `id`
def get_localId_by_partitionId(self,id,index):
if(id == -1 or id == 0):
return index
else:
return torch.add(index,-self.partptr[id])
def get_globalId_by_partitionId(self,id,index):
if(id == -1 or id == 0):
return index
else:
return torch.add(index,self.partptr[id])
def get_node_num(self):
return self.num_nodes
def localId_to_globalId(self,id,partitionId:int = -1):
'''
map a node id local to partition `partitionId` to its global id
'''
if partitionId == -1:
partitionId = self.partition_id
assert id >=self.partptr[self.partition_id] and id < self.partptr[self.partition_id+1]
ids_before = 0
if self.partition_id>0:
ids_before = self.partptr[self.partition_id-1]
return id+ids_before
def get_partitionId_by_globalId(self,id):
'''
get the index of the partition that contains the given global id
'''
partitionId = -1
assert id>=0 and id<self.num_nodes,'id out of range'
for i in range(self.partitions):
if id>=self.partptr[i] and id<self.partptr[i+1]:
partitionId = i
break
assert partitionId>=0, 'no partition found for this id'
return partitionId
def get_nodes_by_partitionId(self,id):
'''
return the number of nodes in partition `id`
'''
assert id>=0 and id<self.partitions,'invalid partitionId'
return (int)(self.partptr[id+1]-self.partptr[id])
def __repr__(self):
return (f'{self.__class__.__name__}(\n'
f' partition_id={self.partition_id}\n'
f' data={self.data},\n'
f' global_info('
f'num_nodes={self.num_nodes},'
f' num_edges={self.num_edges},'
f' num_parts={self.partitions},'
f' num_features={self.num_feasures}, '
f' edge_index=[2,{self.edge_index[0].numel()}])\n'
f')')
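# A small sketch of the partptr bookkeeping used by this class (illustrative numbers):
# with partptr = [0, 679, 1358, 2029], a global id g lives in partition i iff
# partptr[i] <= g < partptr[i+1], and its id local to that partition is g - partptr[i].
#
# partptr = torch.tensor([0, 679, 1358, 2029])
# g_id = 700
# part = next(i for i in range(partptr.numel() - 1)
#             if partptr[i] <= g_id < partptr[i + 1])     # -> 1
# local = g_id - int(partptr[part])                       # -> 21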
from Utils import GraphData
from torch_geometric.data import Data
g = GraphData('./rpc_ps/part/metis_4/rank_0')
x= Data()
print(g)
'''
GraphData(
partition_id=0
data=Data(x=[679, 1433], edge_index=[2, 2908], y=[679], train_mask=[679], val_mask=[679], test_mask=[679]),
global_info(num_nodes=2029, num_edges=10556, num_parts=4, edge_index=[2,10556])
)
'''
from enum import Enum
from multiprocessing.connection import Listener
import pickle
import queue
from threading import Thread
from multiprocessing import Pool, Lock, Value
from torch.distributed.rpc import TensorPipeRpcBackendOptions
from torch.distributed import rpc
import graph_store
class RpcCommand(Enum):
SET_GRAPH = 0
CALL_BARRIER = 1
UNLOAD_GRAPH =2
STOP_RPC = 3
"""
class RPCHandler:
def __init__(self):
self._functions= { }
def register_function(self,func):
self._functions[func.__name__] = func
def handle_connection(self,connection):
try:
while True:
func_name,args,kwargs = pickle.loads(connection.recv())
try:
r=self._functions[func_name] (*args,**kwargs)
connection.send(pickle.dumps(r))
except Exception as e:
connection.send(pickle.dumps(e))
except EOFError:
pass
def start_rpc_server(handler,address,queue):
sock = Listener(address)
queue.put(RpcCommand.FINISH_LISTEN)
while True:
client = sock.accept()
now_thread=Thread(target = handler.handle_connection, args=(client,))
now_thread.daemon = True
now_thread.start()
"""
NUM_WORKER_THREADS = 128
def start_rpc_listener(rpc_master,rpc_port,rpc_config,mp_config):
print(rpc_config)
num_worker_threads = rpc_config.get("num_worker_threads", NUM_WORKER_THREADS)
rpc_name = rpc_config.get("rpc_name", "rpcserver{}")
world_size = rpc_config.get("rpc_world_size", 2)
worker_rank = rpc_config.get("worker_rank", 0)
rpc_rank = rpc_config.get("rpc_worker_rank", 0)
rpc_backend_options = TensorPipeRpcBackendOptions()
rpc_backend_options.init_method = "tcp://{}:{}".format(rpc_master,rpc_port)
rpc_backend_options.num_worker_threads = num_worker_threads
task_queue,barrier = mp_config
print(rpc_master,rpc_port)
rpc.init_rpc(
rpc_name.format(worker_rank),
rank=rpc_rank,
world_size=world_size,
rpc_backend_options=rpc_backend_options
)
keep_pooling = True
while(keep_pooling):
try:
command,args = task_queue.get(timeout=5)
except queue.Empty:
continue
if command == RpcCommand.SET_GRAPH:
dataloader_name,graph_inshm= args
graph_store._set_graph(dataloader_name,graph_inshm)
elif command == RpcCommand.UNLOAD_GRAPH:
dataloader_name= args
graph_inshm = graph_store._get_graph(dataloader_name)
graph_store._del_graph(dataloader_name)
graph_inshm._close_graph_in_shame()
print('unload dataloader')
barrier.wait()
elif command == RpcCommand.CALL_BARRIER:
barrier.wait()
elif command == RpcCommand.STOP_RPC:
keep_pooling = False
graph_store._clear_all()
barrier.wait()
close_rpc()
def start_rpc_caller(rpc_master,rpc_port,rpc_config):
print(rpc_config)
num_worker_threads = rpc_config.get("num_worker_threads",NUM_WORKER_THREADS)
rpc_name = rpc_config.get("rpc_name","rpcserver{}")
world_size = rpc_config.get("rpc_world_size", 2)
worker_rank = rpc_config.get("worker_rank", 0)
rpc_rank = rpc_config.get("rpc_worker_rank", 0)
rpc_backend_options = TensorPipeRpcBackendOptions()
rpc_backend_options.init_method = "tcp://{}:{}".format(rpc_master,rpc_port)
rpc_backend_options.num_worker_threads = num_worker_threads
rpc.init_rpc(
rpc_name.format(worker_rank) + "-{}".format(rpc_rank),
rank=rpc_rank,
world_size=world_size,
rpc_backend_options=rpc_backend_options
)
def close_rpc():
rpc.shutdown(True)
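# A hedged sketch (assumptions: single machine, one listener at rpc rank 0 and one caller at
# rpc rank 1; the port and config keys mirror the defaults read above) of how the two entry
# points pair up. The queue feeds RpcCommand messages into the listener loop; the barrier
# synchronises graph load/unload.
#
# import torch.multiprocessing as mp
# if __name__ == '__main__':
#     ctx = mp.get_context('spawn')
#     task_queue, barrier = ctx.Queue(), ctx.Barrier(1)
#     listener_cfg = {"rpc_world_size": 2, "worker_rank": 0, "rpc_worker_rank": 0}
#     p = ctx.Process(target=start_rpc_listener,
#                     args=('127.0.0.1', 10011, listener_cfg, (task_queue, barrier)))
#     p.start()
#     caller_cfg = {"rpc_world_size": 2, "worker_rank": 0, "rpc_worker_rank": 1}
#     start_rpc_caller('127.0.0.1', 10011, caller_cfg)
#     task_queue.put((RpcCommand.STOP_RPC, None))   # listener drains the queue and shuts rpc down
#     close_rpc()                                   # caller joins the graceful shutdown
#     p.join()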
"""
class RPCProxy:
def __init__(self, connection):
self._connection = connection
def __getattr__(self, name):
def do_rpc(*args, **kwargs):
self._connection.send(pickle.dumps((name, args, kwargs)))
result = pickle.loads(self._connection.recv())
if isinstance(result, Exception):
raise result
return result
return do_rpc
"""
import numpy as np
import torch
from multiprocessing import shared_memory
def _copy_to_share_memory(data):
data_array=data.numpy()
shm = shared_memory.SharedMemory(create=True, size=data_array.nbytes)
data_share = np.ndarray(data_array.shape, dtype=data_array.dtype, buffer=shm.buf)
np.copyto(data_share,data_array)
name = shm.name
_close_existing_shm(shm)
#data_share.copy_(data_array) # Copy the original data into shared memory
return name,data_array.shape,data_array.dtype
def _copy_to_shareable_list(data):
shm = shared_memory.ShareableList(data)
name = shm.shm.name
return name
def _get_from_shareable_list(name):
return shared_memory.ShareableList(name=name)
def _get_existing_share_memory(name):
return shared_memory.SharedMemory(name=name)
def _get_from_share_memory(existing_shm,data_shape,data_dtype):
data = np.ndarray(data_shape, dtype=data_dtype, buffer=existing_shm.buf)
return torch.from_numpy(data)
def _close_existing_shm(existing_shm):
if(existing_shm != None):
existing_shm.close()
def _unlink_existing_shm(existing_shm):
if(existing_shm != None):
existing_shm.unlink()
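# A round-trip sketch of the helpers above (illustrative tensor): only the (name, shape,
# dtype) triple needs to be passed between processes; any process can re-attach by name and
# view the data without copying.
#
# t = torch.arange(12, dtype=torch.float32).reshape(3, 4)
# name, shape, dtype = _copy_to_share_memory(t)
# shm = _get_existing_share_memory(name)
# t_view = _get_from_share_memory(shm, shape, dtype)   # backed by the shared buffer
# print(torch.equal(t, t_view))                        # -> True
# _close_existing_shm(shm)
# _unlink_existing_shm(shm)                            # free the segment once all users are done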
import torch
from part.Utils import GraphData
from share_memory_util import _close_existing_shm, _copy_to_share_memory, _copy_to_shareable_list, _get_existing_share_memory, _get_from_share_memory, _get_from_shareable_list, _unlink_existing_shm
from Sample.neighbor_sampler import NeighborSampler
class GraphInfoInShm(GraphData):
def __init__(self,graph):
self.partition_id = graph.partition_id
self.partitions = graph.partitions
self.num_nodes = graph.num_nodes
self.num_edges = graph.num_edges
#self.edge_index_info = _copy_to_share_memory(graph.edge_index)
self.partptr = graph.partptr
self.data_x_info=_copy_to_share_memory(graph.data.x)
self.data_y_info=_copy_to_share_memory(graph.data.y)
self.edge_index_info=_copy_to_share_memory(graph.edge_index)
def _get_graph_from_shm(self):
#self.edge_index_shm = _get_existing_share_memory(self.edge_index_info[0])
#self.edge_index = _get_existing_share_memory(self.edge_index_shm,self.edge_index_info[1:])
self.data_x_shm = _get_existing_share_memory(self.data_x_info[0])
self.data_x = _get_from_share_memory(self.data_x_shm,*self.data_x_info[1:])
self.data_y_shm = _get_existing_share_memory(self.data_y_info[0])
self.data_y = _get_from_share_memory(self.data_y_shm,*self.data_y_info[1:])
self.edge_index_shm = _get_existing_share_memory(self.edge_index_info[0])
self.edge_index = _get_from_share_memory(self.edge_index_shm,*self.edge_index_info[1:])
def _get_sampler_from_shame(self):
self.deg_shm = _get_from_shareable_list(self.deg_info)#_get_existing_share_memory(self.deg_info)
self.neighbors_shm = _get_from_shareable_list(self.neighbors_info)
self.deg = self.deg_shm
self.neighbors = self.neighbors_shm
def _copy_sampler_to_shame(self,neighbors,deg):
self.deg_info = _copy_to_shareable_list(deg)#_copy_to_share_memory(deg)
self.neighbors_info = _copy_to_shareable_list(neighbors)#_copy_to_share_memory(neighbors)
def _close_graph_in_shame(self):
#_close_existing_shm(self.edge_index_shm)
#_close_existing_shm(self.deg_shm.shm)
#_close_existing_shm(self.neighbors_shm.shm)
_close_existing_shm(self.data_x_shm)
_close_existing_shm(self.data_y_shm)
_close_existing_shm(self.edge_index_shm)
def _unlink_graph_in_shame(self):
#_unlink_existing_shm(self.edge_index_shm)
_unlink_existing_shm(self.data_x_shm)
_unlink_existing_shm(self.data_y_shm)
_unlink_existing_shm(self.edge_index_shm)
# number of nodes held by this partition
def get_part_num(self):
return self.data_x.size()[0]
def select_attr(self,index):
return torch.index_select(self.data_x,0,index)
def select_y(self,index):
return torch.index_select(self.data_y,0,index)
# returns the partition that a global node id belongs to
import argparse
import os
from DistGraphLoader import partition_load
path1=os.path.abspath('.')
import torch
from Sample.neighbor_sampler import NeighborSampler
from Sample.neighbor_sampler import get_neighbors
from part.Utils import GraphData
from DistGraphLoader import DistGraphData
from DistGraphLoader import DistributedDataLoader
from torch_geometric.data import Data
import distparser
from DistCustomPool import CustomPool
import DistCustomPool
from torch.distributed import rpc
from torch.distributed.rpc import RRef, rpc_async, remote
from torch.distributed.rpc import TensorPipeRpcBackendOptions
import time
from model import GraphSAGE
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os
import BatchData
from part.Utils import GraphData
from torch_geometric.sampler.neighbor_sampler import NeighborSampler as PYGSampler
"""
test command
python test.py --world_size 2 --rank 0
--world_size', default=4, type=int, metavar='W',
help='number of workers')
parser.add_argument('--rank', default=0, type=int, metavar='W',
help='rank of the worker')
parser.add_argument('--log_interval', type=int, default=10, metavar='N',
help='interval between training status logs')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
help='how much to value future rewards')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed for reproducibility')
parser.add_argument('--num_sampler', type=int, default=10, metavar='S',
help='number of samplers')
parser.add_argument('--queue_size', type=int, default=10, metavar='S',
help='sampler queue size')
"""
sage_neighsampler_parameters = {'lr':0.003
, 'num_layers':2
, 'hidden_channels':128
, 'dropout':0.0
, 'l2':5e-7
}
class Trainer:
def __init__(
self,
model: torch.nn.Module,
train_data: DataLoader,
test_data:DataLoader,
optimizer: torch.optim.Optimizer,
gpu_id: int,
save_every: int,
) -> None:
self.gpu_id = gpu_id
self.model = model.to('cpu')
self.train_data = train_data
self.test_data = test_data
self.optimizer = optimizer
self.save_every = save_every
# wrap the model with DDP so it knows which devices to replicate to
self.model = DDP(model)
self.total_correct = 0
self.count = 0
def _run_test(self,batchData:BatchData):
self.count = self.count +1
print(f'run epoch in batch data f {self.count}')
l = len(batchData.edge_index)
# print(f'batchData:len:{l},edge_index:{batchData.edge_index}')
out = self.model(batchData,1)
# print(f'out size: {out.size()}')
# print(f'batchDate.y: {batchData.y}')
# # batchData.y = F.one_hot(batchData.y, num_classes=7)
# print(f'roots {batchData.roots}')
# print(f'y size:{batchData.y.size()}')
# print(f'mask :{batchData.train_mask}')
##loss = F.nll_loss(out, batchData.y)
y = batchData.y#torch.masked_select(graph.data.y,graph.data.train_mask)
loss = F.nll_loss(out, y)
self.total_correct += int(out.argmax(dim=-1).eq(y).sum())
##self.total_correct += int(out.argmax(dim=-1).eq(batchData.y).sum())
self.total_count += y.size(0)
def _run_batch(self, batchData:BatchData):
# graph = GraphData('/home/sxx/zlj/rpc_ps/part/metis_1/rank_0')
self.count = self.count +1
print(f'run epoch in batch data f {self.count}')
l = len(batchData.edge_index)
# print(f'batchData:len:{l},edge_index:{batchData.edge_index}')
self.optimizer.zero_grad()
out = self.model(batchData,0)
# print(f'out size: {out.size()}')
# print(f'batchDate.y: {batchData.y}')
# # batchData.y = F.one_hot(batchData.y, num_classes=7)
# print(f'roots {batchData.roots}')
# print(f'y size:{batchData.y.size()}')
# print(f'mask :{batchData.train_mask}')
##loss = F.nll_loss(out, batchData.y)
y = batchData.y#torch.masked_select(graph.data.y,graph.data.train_mask)
loss = F.nll_loss(out, y)
loss.backward()
self.optimizer.step()
print(out.argmax(dim=-1))
self.total_correct += int(out.argmax(dim=-1).eq(y).sum())
##self.total_correct += int(out.argmax(dim=-1).eq(batchData.y).sum())
self.total_count += y.size(0)
print('finish')
def _run_epoch(self, epoch):
self.total_correct = 0
self.total_count = 0
for batchData in self.train_data:
self._run_batch(batchData)
approx_acc = self.total_correct / self.total_count
print(f"=======[GPU{self.gpu_id}] Epoch {epoch} | approx_acc: {approx_acc}=======")
self.total_correct = 0
self.total_count = 0
for batchData in self.test_data:
self._run_test(batchData)
approx_acc = self.total_correct / self.total_count
print(f"=======[GPU{self.gpu_id}] Epoch {epoch} | test_approx_acc: {approx_acc}=======")
def _save_checkpoint(self, epoch):
# since the model is now wrapped by DDP, its parameters are accessed via model.module
ckp = self.model.module.state_dict()
PATH = "checkpoint.pt"
torch.save(ckp, PATH)
print(f"=======Epoch {epoch} | Training checkpoint saved at {PATH}========")
def train(self, max_epochs: int):
for epoch in range(max_epochs):
self._run_epoch(epoch)
# only one copy of the checkpoint needs to be saved
if self.gpu_id == 0 and epoch % self.save_every == 0:
self._save_checkpoint(epoch)
@torch.no_grad()
def test(layer_loader, model, data, split_idx, device, no_conv=False):
# data.y is labels of shape (N, )
model.eval()
out = model.inference(data.x, layer_loader, device)
# out = model.inference_all(data)
y_pred = out.exp() # (N,num_classes)
losses = dict()
for key in ['train', 'valid', 'test']:
node_id = split_idx[key]
node_id = node_id.to(device)
losses[key] = F.nll_loss(out[node_id], data.y[node_id]).item()
return losses, y_pred
def main():
parser = distparser.parser.add_subparsers().add_parser("train")#argparse.ArgumentParser(description='minibatch_gnn_models')
parser.add_argument('--device', type=str, default='cpu')
parser.add_argument('--dataset', type=str, default='Cora')
parser.add_argument('--log_steps', type=int, default=10)
parser.add_argument('--model', type=str, default='sage_neighsampler')
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--rank', default=0, type=int, metavar='W',
help='rank of the worker')
parser.add_argument('--world_size', default=1, type=int, metavar='W',
help='number of the gpus')
args = parser.parse_args()
print(args)
#init_method="tcp://{}:{}".format('127.0.0.1',10031)
#dist.init_process_group(backend="gloo", world_size=args.world_size, rank=args.rank,init_method=init_method)
DistCustomPool.init_distribution('127.0.0.1',9673,'127.0.0.1',10011,backend = "gloo")
#graph = DistGraphData('./part/metis_1')
#graph = DistGraphData('/home/sxx/pycode/work/ogbn-products/metis_2')
pdata = partition_load("../../cora", algo="metis")
graph = DistGraphData(pdata = pdata,edge_index= pdata.edge_index)
row, col = graph.edge_index
tnb = get_neighbors(row.contiguous(), col.contiguous(), graph.num_nodes)
sampler = NeighborSampler(graph.num_nodes, num_layers=2, fanout=[10,5], workers=10, tnb=tnb)
loader = DistributedDataLoader('train',graph,graph.data.train_mask,sampler = sampler,batch_size = 100,shuffle=True)
testloader = DistributedDataLoader('test',graph,graph.data.test_mask,sampler = sampler,batch_size = 100,shuffle=True)
#count_node = 0
#count_edge = 0
#count_x_byte = 0
#start_time = time.time()
#cnt = 0
#for batchData in loader:
# cnt = cnt+1
# count_node += batchData.nids.size(0)
# count_x_byte += batchData.x.numel()*batchData.x.element_size()
# for edge_list in batchData.edge_index:
# count_edge += edge_list.size(1)
# #count_edge += batchData.edge_index.size(1)
# dt = time.time() - start_time
# print('{} count node {},count edge {}, node TPS {},edge TPS {}, x size {}, x TPS {} byte'
# .format(cnt,count_node,count_edge,count_node/dt,count_edge/dt,count_x_byte,count_x_byte/dt),batchData.x.size(0),batchData.x.element_size(),batchData.x.numel())
epochs = args.epochs
if args.model == 'sage_neighsampler':
para_dict = sage_neighsampler_parameters
num_classes = 7
num_node_features = graph.num_feasures
model = GraphSAGE(num_node_features,num_classes)
num_node_features = graph.num_feasures
print(f'Model {args.model} initialized')
epochs = args.epochs
optimizer = torch.optim.Adam(model.parameters(), lr=para_dict['lr'])
model.train()
trainer = Trainer(model, loader,testloader, optimizer, graph.rank, 5)
trainer.train(epochs)
#
DistCustomPool.close_distribution()
if __name__ == "__main__":
main()