Commit 64861da6 by xxx

manage import

parent c71ae1f5
import os.path as osp
import torch
class GraphData():
    # NOTE: this constructor is shadowed by the second __init__ defined below.
    def __init__(self, path):
        assert path is not None and osp.exists(path), 'path does not exist'
        id, edge_index, data, partptr = torch.load(path)
        # index of the current partition
        self.partition_id = id
        # total number of partitions
        self.partitions = partptr.numel() - 1
        # full-graph structure information
        self.num_nodes = partptr[self.partitions]
        self.num_edges = edge_index[0].numel()
        self.edge_index = edge_index
        # data belonging to this partition (feature vectors and subgraph structure), as a PyG Data object
        self.data = data
        # partition mapping (node-id offsets of each partition)
        self.partptr = partptr
        self.eid = [i for i in range(self.num_edges)]
    def __init__(self, id, edge_index, data, partptr, timestamp=None):
        # index of the current partition
        self.partition_id = id
        # total number of partitions
        self.partitions = partptr.numel() - 1
        # full-graph structure information
        self.num_nodes = partptr[self.partitions]
        if edge_index is not None:
            self.num_edges = edge_index[0].numel()
        else:
            # avoid an undefined attribute when no edge_index is given
            self.num_edges = 0
        self.edge_index = edge_index
        self.edge_ts = timestamp
        # data belonging to this partition (feature vectors and subgraph structure), as a PyG Data object
        self.data = data
        # partition mapping (node-id offsets of each partition)
        self.partptr = partptr
        # edge ids
        self.eid = torch.arange(self.num_edges)
    def select_attr(self, index):
        return torch.index_select(self.data.x, 0, index)
    # number of nodes held by this partition (rows of the local feature matrix)
    def get_part_num(self):
        return self.data.x.size()[0]
    def select_y(self, index):
        return torch.index_select(self.data.y, 0, index)
    # map global node ids to ids local to partition `id`
    def get_localId_by_partitionId(self, id, index):
        # print(index)
        if (id == -1 or id == 0):
            return index
        else:
            return torch.add(index, -self.partptr[id])
    def get_globalId_by_partitionId(self, id, index):
        if (id == -1 or id == 0):
            return index
        else:
            return torch.add(index, self.partptr[id])
    def get_node_num(self):
        return self.num_nodes
    def localId_to_globalId(self, id, partitionId: int = -1):
        '''
        Map a node id inside partition partitionId to its global id.
        '''
        if partitionId == -1:
            partitionId = self.partition_id
        assert id >= self.partptr[partitionId] and id < self.partptr[partitionId + 1]
        ids_before = 0
        if partitionId > 0:
            ids_before = self.partptr[partitionId - 1]
        return id + ids_before
    def get_partitionId_by_globalId(self, id):
        '''
        Return the partition index that contains the given global node id.
        '''
        partitionId = -1
        assert id >= 0 and id < self.num_nodes, 'id out of range'
        for i in range(self.partitions):
            if id >= self.partptr[i] and id < self.partptr[i + 1]:
                partitionId = i
                break
        assert partitionId >= 0, 'no partition found for this id'
        return partitionId
    def get_nodes_by_partitionId(self, id):
        '''
        Return the number of nodes in partition partitionId.
        '''
        assert id >= 0 and id < self.partitions, 'invalid partitionId'
        return int(self.partptr[id + 1] - self.partptr[id])
    def __repr__(self):
        return (f'{self.__class__.__name__}(\n'
                f'  partition_id={self.partition_id}\n'
                f'  data={self.data},\n'
                f'  global_info('
                f'num_nodes={self.num_nodes},'
                f' num_edges={self.num_edges},'
                f' num_parts={self.partitions},'
                f' edge_index=[2,{self.edge_index[0].numel()}])\n'
                f')')
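For orientation, here is a minimal usage sketch of the GraphData partition wrapper above. The partition layout and tensors are made up for illustration, and `data=None` stands in for the PyG `Data` object the real pipeline would supply.

```python
# Hypothetical example: two partitions over five nodes
# (nodes 0-2 in partition 0, nodes 3-4 in partition 1) and three edges.
import torch

partptr = torch.tensor([0, 3, 5])
edge_index = torch.tensor([[0, 1, 3],
                           [1, 2, 4]])

g = GraphData(id=0, edge_index=edge_index, data=None, partptr=partptr)  # data=None is a placeholder

print(g.partitions)                                            # 2
print(int(g.get_node_num()))                                   # 5
print(g.get_partitionId_by_globalId(4))                        # 1
print(g.get_localId_by_partitionId(1, torch.tensor([3, 4])))   # tensor([0, 1])
```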
from enum import Enum
import sys
import argparse
from os.path import abspath, join, dirname
import time
sys.path.insert(0, join(abspath(dirname(__file__))))
class SampleType(Enum):
    Whole = 0
    Inner = 1
    Outer = 2
parser = argparse.ArgumentParser(
    description="RPC Reinforcement Learning Example",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--world_size', default=1, type=int, metavar='W',
                    help='number of workers')
parser.add_argument('--rank', default=0, type=int, metavar='W',
                    help='rank of the worker')
parser.add_argument('--log_interval', type=int, default=10, metavar='N',
                    help='interval between training status logs')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                    help='how much to value future rewards')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed for reproducibility')
parser.add_argument('--num_sampler', type=int, default=1, metavar='S',
                    help='number of samplers')
parser.add_argument('--queue_size', type=int, default=10000, metavar='S',
                    help='sampler queue size')
# parser = distparser.parser.add_subparsers().add_parser("train")  # argparse.ArgumentParser(description='minibatch_gnn_models')
parser.add_argument('--data', type=str, help='dataset name')
parser.add_argument('--config', type=str, help='path to config file')
parser.add_argument('--gpu', type=str, default='0', help='which GPU to use')
parser.add_argument('--model_name', type=str, default='', help='name of stored model')
parser.add_argument('--rand_edge_features', type=int, default=0, help='use random edge features')
parser.add_argument('--rand_node_features', type=int, default=0, help='use random node features')
parser.add_argument('--eval_neg_samples', type=int, default=1, help='how many negative samples to use at inference. Note: this will change the test-set metric from AP+AUC to AP+MRR!')
args = parser.parse_args()
rpc_proxy = None
WORKER_RANK = args.rank
NUM_SAMPLER = args.num_sampler
WORLD_SIZE = args.world_size
QUEUE_SIZE = args.queue_size
MAX_QUEUE_SIZE = 5 * args.queue_size
RPC_NAME = "rpcserver{}"
SAMPLE_TYPE = SampleType.Outer
def _get_worker_rank():
    return WORKER_RANK
def _get_num_sampler():
    return NUM_SAMPLER
def _get_world_size():
    return WORLD_SIZE
def _get_RPC_NAME():
    return RPC_NAME
def _get_queue_size():
    return QUEUE_SIZE
def _get_max_queue_size():
    return MAX_QUEUE_SIZE
def _get_rpc_name():
    return RPC_NAME
\ No newline at end of file
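As a rough sketch, other modules consume this parser module roughly as below. The module name `distparser` is taken from the import removed in the sampler diff further down; note the module parses `sys.argv` at import time, so with no flags the argparse defaults above apply.

```python
# Hypothetical consumer; with no command-line flags the defaults above are used.
import distparser

print(distparser._get_world_size())                               # 1 unless --world_size is passed
print(distparser._get_worker_rank())                              # 0 unless --rank is passed
print(distparser._get_rpc_name().format(distparser.WORKER_RANK))  # "rpcserver0"
print(distparser.SAMPLE_TYPE)                                     # SampleType.Outer
```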
import sys
from os.path import abspath, join, dirname
import time
sys.path.insert(0, join(abspath(dirname(__file__))))
graph_set={}
def _clear_all(barrier=None):
    global graph_set
    for key in graph_set:
        graph = graph_set[key]
        graph._close_graph_in_shame()
        print('clear ', key)
        if (barrier is not None and barrier.wait() == 0):
            graph._unlink_graph_in_shame()
    graph_set = {}
def _set_graph(graph_name, graph_info):
    graph_info._get_graph_from_shm()
    graph_set[graph_name] = graph_info
def _get_graph(graph_name):
    return graph_set[graph_name]
def _del_graph(graph_name):
    graph_set.pop(graph_name)
def _get_size():
    return len(graph_set)
# local_sampler = None
local_sampler = {}
def set_local_sampler(graph_name, sampler):
    local_sampler[graph_name] = sampler
def get_local_sampler(sampler_name):
    assert sampler_name in local_sampler, 'local_sampler does not have sampler_name'
    return local_sampler[sampler_name]
\ No newline at end of file
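A minimal sketch of how this registry might be used. `DummyGraph` is a stand-in for the real shared-memory graph wrapper (which implements `_get_graph_from_shm`, `_close_graph_in_shame` and `_unlink_graph_in_shame` elsewhere in the project), and the module name `graph_store` comes from the import removed in the diff below.

```python
import graph_store

class DummyGraph:
    # Stand-in for the shared-memory graph object graph_store expects.
    def _get_graph_from_shm(self):
        print('attach shared-memory graph')
    def _close_graph_in_shame(self):
        print('close shared-memory handle')
    def _unlink_graph_in_shame(self):
        print('unlink shared-memory segment')

graph_store._set_graph('demo', DummyGraph())  # attaches and registers the graph under 'demo'
g = graph_store._get_graph('demo')
print(graph_store._get_size())                # 1
graph_store._clear_all()                      # closes every handle; unlinking only happens when a barrier is passed
```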
@@ -7,18 +7,16 @@ import torch
 import torch.multiprocessing as mp
 from typing import Optional, Tuple
-import graph_store
-from distparser import SampleType, NUM_SAMPLER
-from base import BaseSampler, NegativeSampling, SampleOutput
+from .base import BaseSampler, NegativeSampling, SampleOutput, SampleType
 # from sample_cores import ParallelSampler, get_neighbors, heads_unique
 from starrygl.lib.libstarrygl_ops_sampler import ParallelSampler, get_neighbors
 from torch.distributed.rpc import rpc_async
-def outer_sample(graph_name, nodes, ts, fanout_index, with_outer_sample = SampleType.Outer):  # by default, keep sampling outward from here
-    local_sampler = graph_store.get_local_sampler(graph_name)
-    assert local_sampler is not None, 'Local_sampler is None!!!'
-    out = local_sampler.sample_from_nodes(nodes, with_outer_sample, ts, fanout_index)
-    return out
+# def outer_sample(graph_name, nodes, ts, fanout_index, with_outer_sample = SampleType.Outer):  # by default, keep sampling outward from here
+#     local_sampler = get_local_sampler(graph_name)
+#     assert local_sampler is not None, 'Local_sampler is None!!!'
+#     out = local_sampler.sample_from_nodes(nodes, with_outer_sample, ts, fanout_index)
+#     return out
 class NeighborSampler(BaseSampler):
     def __init__(
...
@@ -122,13 +122,18 @@ def test():
     sam_edge = 0
     pre = time.time()
+    min_than_ten = 0
+    min_than_ten_sum = 0
+    seed_node_sum = 0
     for _, rows in tqdm(df.groupby(df.index // args.batch_size), total=len(df) // args.batch_size):
         # root_nodes = torch.tensor(np.concatenate([rows.src.values, rows.dst.values, neg_link_sampler.sample(len(rows))])).long()
         # ts = torch.tensor(np.concatenate([rows.time.values, rows.time.values, rows.time.values]).astype(np.float32))
         # outi = sampler.sample_from_nodes(root_nodes, ts=ts)
         edges = torch.tensor(np.stack([rows.src.values, rows.dst.values])).long()
         outi, meta = sampler.sample_from_edges(edges=edges, ets=torch.tensor(rows.time.values).float(), neg_sampling=neg_link_sampler)
+        # min_than_ten += (torch.tensor(tnb.deg)[meta['seed']]<10).sum()
+        # min_than_ten_sum += ((torch.tensor(tnb.deg)[meta['seed']])[torch.tensor(tnb.deg)[meta['seed']]<10]).sum()
+        # seed_node_sum += meta['seed'].size(0)
         tot_time += outi[0].tot_time
         sam_time += outi[0].sample_time
         # print(outi[0].sample_edge_num)
@@ -149,6 +154,10 @@ def test():
     # print('node:', out[23][1].sample_nodes)
     # print('node_ts:', out[23][1].sample_nodes_ts)
     # print('edge_index_list:', out[0][0].edge_index)
+    # print("min_than_ten", min_than_ten)
+    # print("min_than_ten_sum", min_than_ten_sum)
+    # print("seed_node_sum", seed_node_sum)
+    # print("predict edge_num", (seed_node_sum-min_than_ten)*9+min_than_ten_sum)
 if __name__ == "__main__":
     test()
\ No newline at end of file
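The commented-out bookkeeping added in the hunk above counts seed nodes whose temporal degree is below 10 and uses that to predict the number of sampled edges. A stand-alone sketch of the same computation on synthetic tensors follows; `deg` and `seed` here are stand-ins for `tnb.deg` and `meta['seed']`, and the factor 9 mirrors the commented-out print.

```python
import torch

deg = torch.tensor([3, 12, 7, 25, 1])   # per-node neighbor counts, playing the role of tnb.deg
seed = torch.tensor([0, 1, 3, 4])       # seed node ids of one batch, playing the role of meta['seed']

seed_deg = deg[seed]                              # tensor([ 3, 12, 25,  1])
min_than_ten = (seed_deg < 10).sum()              # 2 seeds have fewer than 10 neighbors
min_than_ten_sum = seed_deg[seed_deg < 10].sum()  # their degrees sum to 4
seed_node_sum = seed.size(0)                      # 4 seeds in total

# low-degree seeds contribute all of their edges, the others roughly 9 each
predict_edge_num = (seed_node_sum - min_than_ten) * 9 + min_than_ten_sum
print(int(min_than_ten), int(min_than_ten_sum), seed_node_sum, int(predict_edge_num))  # 2 4 4 22
```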