Commit 7f481360 by xxx

Merge branch 'master' into hzq

parents cd3f3cd2 6acf7ed1
install.sh merge=ours
\ No newline at end of file
*.tgz
*.my
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
...@@ -169,8 +171,12 @@ cython_debug/
/third_party
/.vscode
/.history
/.cache
/run_route.py
/dataset
/test_*
/*.ipynb
saved_models/
saved_checkpoints/
\ No newline at end of file
sampling:
  - layer: 1
    neighbor:
      - 10
    strategy: 'recent'
    prop_time: False
    history: 1
    duration: 0
    num_thread: 32
memory:
  - type: 'node'
    dim_time: 100
    deliver_to: 'self'
    mail_combine: 'last'
    memory_update: 'gru'
    mailbox_size: 1
    combine_node_feature: True
    dim_out: 100
gnn:
  - arch: 'transformer_attention'
    use_src_emb: True
    use_dst_emb: True
    layer: 1
    att_head: 2
    dim_time: 100
    dim_out: 100
train:
  - epoch: 50
    batch_size: 100
    # reorder: 16
    lr: 0.0001
    dropout: 0.1
    att_dropout: 0.2
    all_on_gpu: True
\ No newline at end of file
sampling:
  - no_sample: True
    history: 1
memory:
  - type: 'node'
    dim_time: 100
    deliver_to: 'self'
    mail_combine: 'last'
    memory_update: 'rnn'
    mailbox_size: 1
    combine_node_feature: True
    dim_out: 100
gnn:
  - arch: 'identity'
    use_src_emb: False
    use_dst_emb: False
    time_transform: 'JODIE'
train:
  - epoch: 20
    batch_size: 200
    lr: 0.0001
    dropout: 0.1
    all_on_gpu: True
\ No newline at end of file
sampling:
  - layer: 2
    neighbor:
      - 10
      - 10
    strategy: 'uniform'
    prop_time: False
    history: 1
    duration: 0
    num_thread: 32
memory:
  - type: 'none'
    dim_out: 0
gnn:
  - arch: 'transformer_attention'
    layer: 2
    att_head: 2
    dim_time: 100
    dim_out: 100
train:
  - epoch: 100
    batch_size: 600
    lr: 0.0001
    dropout: 0.1
    att_dropout: 0.1
    all_on_gpu: True
\ No newline at end of file
...@@ -18,13 +18,15 @@ memory:
    dim_out: 100
gnn:
  - arch: 'transformer_attention'
    use_src_emb: False
    use_dst_emb: False
    layer: 1
    att_head: 2
    dim_time: 100
    dim_out: 100
train:
  - epoch: 20
    batch_size: 200
    # reorder: 16
    lr: 0.0001
    dropout: 0.2
......
sampling:
  - layer: 1
    neighbor:
      - 10
    strategy: 'recent'
    prop_time: False
    history: 1
    duration: 0
    num_thread: 32
memory:
  - type: 'node'
    dim_time: 100
    deliver_to: 'self'
    mail_combine: 'last'
    memory_update: 'gru'
    mailbox_size: 1
    combine_node_feature: True
    dim_out: 100
gnn:
  - arch: 'transformer_attention'
    use_src_emb: True
    use_dst_emb: True
    layer: 1
    att_head: 2
    dim_time: 100
    dim_out: 100
train:
  - epoch: 20
    batch_size: 200
    # reorder: 16
    lr: 0.0001
    dropout: 0.2
    att_dropout: 0.2
    all_on_gpu: True
\ No newline at end of file
#include <head.h>
#include <sampler.h>
#include <tppr.h>
#include <output.h>
#include <neighbors.h>
#include <temporal_utils.h>
...@@ -88,4 +89,22 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
        .def("reset", &ParallelSampler::reset)
        .def("get_ret", [](const ParallelSampler &ps) { return ps.ret; });
    py::class_<ParallelTppRComputer>(m, "ParallelTppRComputer")
        .def(py::init<TemporalNeighborBlock &, NodeIDType, EdgeIDType, int,
                      int, int, int, vector<float>&, vector<float>& >())
        .def_readonly("ret", &ParallelTppRComputer::ret, py::return_value_policy::reference)
        .def("reset_ret", &ParallelTppRComputer::reset_ret)
        .def("reset_tppr", &ParallelTppRComputer::reset_tppr)
        .def("reset_val_tppr", &ParallelTppRComputer::reset_val_tppr)
        .def("backup_tppr", &ParallelTppRComputer::backup_tppr)
        .def("restore_tppr", &ParallelTppRComputer::restore_tppr)
        .def("restore_val_tppr", &ParallelTppRComputer::restore_val_tppr)
        .def("get_pruned_topk", &ParallelTppRComputer::get_pruned_topk)
        .def("extract_streaming_tppr", &ParallelTppRComputer::extract_streaming_tppr)
        .def("streaming_topk", &ParallelTppRComputer::streaming_topk)
        .def("single_streaming_topk", &ParallelTppRComputer::single_streaming_topk)
        .def("streaming_topk_no_fake", &ParallelTppRComputer::streaming_topk_no_fake)
        .def("compute_val_tppr", &ParallelTppRComputer::compute_val_tppr)
        .def("get_ret", [](const ParallelTppRComputer &ps) { return ps.ret; });
}
\ No newline at end of file
#pragma once
#include <iostream>
#include <algorithm>
#include <torch/extension.h>
#include <omp.h>
#include <time.h>
...@@ -17,6 +18,12 @@ typedef int64_t NodeIDType;
typedef int64_t EdgeIDType;
typedef float WeightType;
typedef float TimeStampType;
typedef tuple<EdgeIDType, NodeIDType, TimeStampType> PPRKeyType; // a (edge id, neighbor node id, timestamp) state
typedef double PPRValueType;
typedef phmap::parallel_flat_hash_map<PPRKeyType, PPRValueType> PPRDictType;
typedef vector<PPRDictType> PPRListDictType;
typedef vector<vector<PPRDictType>> PPRListListDictType;
typedef vector<vector<double>> NormListType;
class TemporalNeighborBlock;
class TemporalGraphBlock;
...@@ -28,6 +35,7 @@ int nodeIdToInOut(NodeIDType nid, int pid, const vector<NodeIDType>& part_ptr);
int nodeIdToPartId(NodeIDType nid, const vector<NodeIDType>& part_ptr);
vector<th::Tensor> divide_nodes_to_part(th::Tensor nodes, const vector<NodeIDType>& part_ptr, int threads);
NodeIDType sample_multinomial(const vector<WeightType>& weights, default_random_engine& e);
vector<int64_t> sample_max(const vector<WeightType>& weights, int k);
...@@ -173,3 +181,17 @@ NodeIDType sample_multinomial(const vector<WeightType>& weights, default_random_
    sample_indice = distance(cumulative_weights.begin(), it);
    return sample_indice;
}
vector<int64_t> sample_max(const vector<WeightType>& weights, int k) {
    vector<int64_t> indices(weights.size());
    for (size_t i = 0; i < weights.size(); ++i) {
        indices[i] = i;
    }
    // Clamp k so the partial sort never runs past the end of the vector.
    if (k > (int)indices.size()) k = indices.size();
    // Use a partial sort (a selection algorithm) to find the indices of the top-k largest weights.
    partial_sort(indices.begin(), indices.begin() + k, indices.end(),
                 [&weights](int64_t a, int64_t b) { return weights[a] > weights[b]; });
    // Return the indices of the top-k largest weights.
    return vector<int64_t>(indices.begin(), indices.begin() + k);
}
\ No newline at end of file
...@@ -287,10 +287,15 @@ void TemporalNeighborBlock::update_edge_weight(
    for(int64_t i=0; i<edge_num; i++){
        // Update the weight of the edge between the node and this neighbor.
        int index;
        if(this->with_eid){
            AT_ASSERTM(this->inverted_index[dst[i]].count(eid_ptr[i])==1, "Unexist Eid --> Col: "+to_string(eid_ptr[i])+"-->"+to_string(dst[i]));
            index = this->inverted_index[dst[i]][eid_ptr[i]];
        }
        else{
            AT_ASSERTM(this->inverted_index[dst[i]].count(src[i])==1, "Unexist Edge Index: "+to_string(src[i])+", "+to_string(dst[i]));
            index = this->inverted_index[dst[i]][src[i]];
        }
        this->edge_weight[dst[i]][index] = ew[i];
    }
}
......
...@@ -11,6 +11,7 @@ class TemporalGraphBlock
    vector<int64_t> src_index;
    vector<NodeIDType> sample_nodes;
    vector<TimeStampType> sample_nodes_ts;
    vector<WeightType> e_weights;
    double sample_time = 0;
    double tot_time = 0;
    int64_t sample_edge_num = 0;
......
...@@ -105,13 +105,13 @@ void ParallelSampler :: neighbor_sample_from_nodes_static_layer(th::Tensor nodes
    // uniform_int_distribution<> u(0, tnb.deg[node]-1);
    // while(temp_s.size()!=fanout && temp_s.size()<tnb.neighbors_set[node].size()){
    for(int i=0;i<fanout;i++){
        // Loop to pick fanout neighbors.
        NodeIDType indice;
        if(policy == "weighted"){ // Weighted sampling: take edge weights into account.
            const vector<WeightType>& ew = tnb.edge_weight[node];
            indice = sample_multinomial(ew, e);
        }
        else if(policy == "uniform"){ // Uniform sampling.
            // indice = u(e);
            indice = rand_r(&loc_seed) % (nei.size());
        }
...@@ -119,7 +119,7 @@ void ParallelSampler :: neighbor_sample_from_nodes_static_layer(th::Tensor nodes
        auto chosen_e_iter = edge.begin() + indice;
        if(part_unique){
            auto rst = temp_s.insert(*chosen_n_iter);
            if(rst.second){ // Not a duplicate.
                eid_threads[tid].emplace_back(*chosen_e_iter);
                node_s_threads[tid].insert(*chosen_n_iter);
                if(!tnb.neighbors_set.empty() && temp_s.size()<fanout && temp_s.size()<tnb.neighbors_set[node].size()) fanout++;
...@@ -229,7 +229,7 @@ void ParallelSampler :: neighbor_sample_from_nodes_with_before_layer(
        }
    }
    else{
        // If there are more candidate neighbor edges than the fanout, randomly pick fanout of them.
        tgb_i[tid].src_index.insert(tgb_i[tid].src_index.end(), fanout, i);
        uniform_int_distribution<> u(0, end_index-1);
        //cout<<end_index<<endl;
......
#pragma once
#include <head.h>
#include <neighbors.h>
#include <output.h>

class ParallelTppRComputer
{
    public:
        TemporalNeighborBlock& tnb;
        NodeIDType num_nodes;
        EdgeIDType num_edges;
        int threads;
        int fanout;     // k, width
        int num_layers; // depth
        int num_tpprs;  // n_tpprs
        vector<float> alpha_list;
        vector<float> beta_list;
        // string policy;
        PPRListListDictType PPR_list;
        PPRListListDictType val_PPR_list;
        NormListType norm_list;
        NormListType val_norm_list;
        vector<vector<TemporalGraphBlock>> ret;

        ParallelTppRComputer(TemporalNeighborBlock& _tnb, NodeIDType _num_nodes, EdgeIDType _num_edges, int _threads,
                             int _fanout, int _num_layers, int _num_tpprs, vector<float>& _alpha_list, vector<float>& _beta_list) :
                             tnb(_tnb), num_nodes(_num_nodes), num_edges(_num_edges), threads(_threads),
                             fanout(_fanout), num_layers(_num_layers), num_tpprs(_num_tpprs), alpha_list(_alpha_list), beta_list(_beta_list)
        {
            omp_set_num_threads(_threads);
            ret.clear();
            ret = vector<vector<TemporalGraphBlock>>(_num_tpprs, vector<TemporalGraphBlock>());
        }

        void reset_ret() {
            for (int i = 0; i < num_tpprs; ++i) {
                ret[i].clear(); // Clear each inner vector.
            }
        }
        void reset_ret_i(int tppr_id) {
            ret[tppr_id].clear(); // Clear the vector at tppr_id.
        }
        void reset_tppr(){
            PPR_list = PPRListListDictType(num_tpprs, PPRListDictType(num_nodes));
            norm_list = NormListType(num_tpprs, vector<double>(num_nodes, 0.0));
        }
        void reset_val_tppr(){
            val_PPR_list = PPRListListDictType(num_tpprs, PPRListDictType(num_nodes));
            val_norm_list = NormListType(num_tpprs, vector<double>(num_nodes, 0.0));
        }
        py::tuple backup_tppr(){
            return py::make_tuple(this->PPR_list, this->norm_list);
        }
        void restore_tppr(PPRListListDictType& input_PPR_list, NormListType& input_norm_list){
            this->PPR_list = input_PPR_list;
            this->norm_list = input_norm_list;
        }
        void restore_val_tppr(PPRListListDictType& input_PPR_list, NormListType& input_norm_list){
            this->val_PPR_list = input_PPR_list;
            this->val_norm_list = input_norm_list;
        }

        PPRDictType compute_s1_s2(NodeIDType s1, NodeIDType s2, int tppr_id, EdgeIDType eid, TimeStampType ts);
        void get_pruned_topk(th::Tensor src_nodes, th::Tensor root_ts, int tppr_id);
        void extract_streaming_tppr(PPRDictType tppr_dict, TimeStampType current_ts, int index0, int position);
        void streaming_topk(th::Tensor src_nodes, th::Tensor root_ts, th::Tensor eids);
        void single_streaming_topk(th::Tensor src_nodes, th::Tensor root_ts, th::Tensor eids, int tppr_id);
        void streaming_topk_no_fake(th::Tensor src_nodes, th::Tensor root_ts, th::Tensor eids);
        void compute_val_tppr(th::Tensor src_nodes, th::Tensor dst_nodes, th::Tensor root_ts, th::Tensor eids);
};
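
// compute_s1_s2 performs one streaming update of the temporal PPR vector of s1 after an
// interaction (s1, s2) with edge id eid at time ts. Reading off the implementation below:
// the running norm of s1 becomes new_norm = last_norm*beta + beta, existing entries of
// s1's vector are decayed by scala_s1 = last_norm/new_norm*beta, mass from s2's vector is
// injected scaled by scala_s2 = beta/new_norm*(1-alpha) (with simpler scale factors when
// s1 has no history yet), a fresh entry for the new edge is seeded, and only the top-k
// (k = fanout) entries by value are kept.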
PPRDictType ParallelTppRComputer :: compute_s1_s2(NodeIDType s1, NodeIDType s2, int tppr_id, EdgeIDType eid, TimeStampType ts){
    float alpha = alpha_list[tppr_id], beta = beta_list[tppr_id];
    vector<double>& norm_list = this->norm_list[tppr_id];
    PPRListDictType& PPR_list = this->PPR_list[tppr_id];
    PPRDictType t_s1_PPR = PPRDictType();
    PPRDictType updated_tppr = PPRDictType();
    float scala_s1 = 0, scala_s2 = 0;
    /***************s1 side*******************/
    if(norm_list[s1]==0){
        scala_s2 = 1-alpha;
    }
    else{
        t_s1_PPR = PPR_list[s1];
        double last_norm = norm_list[s1], new_norm;
        new_norm = last_norm*beta+beta;
        scala_s1 = last_norm/new_norm*beta;
        scala_s2 = beta/new_norm*(1-alpha);
        for (const auto& pair : t_s1_PPR)
            t_s1_PPR[pair.first] = pair.second*scala_s1;
    }
    /**************s2 side*******************/
    if(norm_list[s2]==0){ // s2 has no accumulated history, so only the seed entry is added
        t_s1_PPR[make_tuple(eid, s2, ts)] = alpha!=0 ? scala_s2*alpha : scala_s2;
    }
    else{
        PPRDictType s2_PPR = PPR_list[s2];
        for (const auto& pair : s2_PPR){
            if(t_s1_PPR.count(pair.first)==1)
                t_s1_PPR[pair.first] += pair.second*scala_s2;
            else
                t_s1_PPR[pair.first] = pair.second*scala_s2;
        }
        t_s1_PPR[make_tuple(eid, s2, ts)] = alpha!=0 ? scala_s2*alpha : scala_s2;
    }
    /*********extract the top-k items********/
    int tppr_size = t_s1_PPR.size();
    if(tppr_size<=this->fanout)
        updated_tppr = t_s1_PPR;
    else{
        std::vector<std::pair<PPRKeyType, PPRValueType>> pairs;
        pairs.reserve(t_s1_PPR.size());
        // Copy the key-value pairs into the pairs vector.
        for (const auto& pair : t_s1_PPR) {
            pairs.emplace_back(pair.first, pair.second);
        }
        // Partial sort to obtain the top this->fanout elements by PPR value.
        std::partial_sort(pairs.begin(), pairs.begin() + this->fanout, pairs.end(),
            [](const auto& a, const auto& b) { return a.second > b.second; });
        // Insert the partially sorted key-value pairs into updated_tppr.
        for (int i = 0; i < this->fanout; ++i) {
            const auto& pair = pairs[i];
            updated_tppr[pair.first] = pair.second;
        }
    }
    return updated_tppr;
}
void ParallelTppRComputer :: get_pruned_topk(th::Tensor src_nodes, th::Tensor root_ts, int tppr_id){
    auto src_nodes_data = get_data_ptr<NodeIDType>(src_nodes);
    auto ts_data = get_data_ptr<TimeStampType>(root_ts);
    int64_t n_edges = src_nodes.size(0);
    float alpha = alpha_list[tppr_id], beta = beta_list[tppr_id];
    this->reset_ret_i(tppr_id);
    ret[tppr_id].resize(n_edges); // one result block per query edge; extract_streaming_tppr writes ret[tppr_id][i]
    for(int i=0;i<n_edges;i++)
    {
        NodeIDType target_node = src_nodes_data[i];
        TimeStampType target_timestamp = ts_data[i];
        PPRDictType tppr_dict;
        /*******get dictionary of neighbors*********************/
        vector<tuple<NodeIDType, TimeStampType, PPRValueType>> query_list;
        query_list.push_back(make_tuple(target_node, target_timestamp, 1.0));
        for(int depth=0;depth<this->num_layers;depth++)
        {
            vector<tuple<NodeIDType, TimeStampType, PPRValueType>> new_query_list;
            /*******traverse the query list*********************/
            for(size_t j=0;j<query_list.size();j++)
            {
                NodeIDType query_node = get<0>(query_list[j]);
                TimeStampType query_ts = get<1>(query_list[j]);
                PPRValueType query_weight = get<2>(query_list[j]);
                int end_index = lower_bound(tnb.timestamp[query_node].begin(), tnb.timestamp[query_node].end(), query_ts)-tnb.timestamp[query_node].begin();
                int n_ngh = end_index;
                if(n_ngh==0) continue;
                else
                {
                    double norm = beta/(1-beta)*(1-pow(beta, n_ngh));
                    double weight = alpha!=0 && depth==0 ? query_weight*(1-alpha)*beta/norm*alpha : query_weight*(1-alpha)*beta/norm;
                    for(int z=0;z<min(this->fanout, n_ngh);z++){
                        EdgeIDType eid = tnb.eid[query_node][end_index-z-1];
                        NodeIDType node = tnb.neighbors[query_node][end_index-z-1];
                        // the timestamp here is a neighbor timestamp,
                        // so that it is indeed a temporal random walk
                        TimeStampType timestamp = tnb.timestamp[query_node][end_index-z-1];
                        PPRKeyType state = make_tuple(eid, node, timestamp);
                        // update dict
                        if(tppr_dict.count(state)==1)
                            tppr_dict[state] = tppr_dict[state]+weight;
                        else
                            tppr_dict[state] = weight;
                        // update query list
                        tuple<NodeIDType, TimeStampType, PPRValueType> new_query = make_tuple(node, timestamp, weight);
                        new_query_list.push_back(new_query);
                        // update weight
                        weight = weight*beta;
                    }
                }
            }
            if(new_query_list.empty()) break;
            else query_list = new_query_list;
        }
        /*****sort and get the top-k neighbors********/
        int tppr_size = tppr_dict.size();
        if(tppr_size==0) continue;
        TimeStampType current_timestamp = ts_data[i];
        PPRDictType updated_tppr = PPRDictType();
        if(tppr_size<=this->fanout)
            updated_tppr = tppr_dict;
        else
        {
            std::vector<std::pair<PPRKeyType, PPRValueType>> pairs;
            pairs.reserve(tppr_dict.size());
            // Copy the key-value pairs into the pairs vector.
            for (const auto& pair : tppr_dict) {
                pairs.emplace_back(pair.first, pair.second);
            }
            // Partial sort to obtain the top this->fanout elements by PPR value.
            std::partial_sort(pairs.begin(), pairs.begin() + this->fanout, pairs.end(),
                [](const auto& a, const auto& b) { return a.second > b.second; });
            // Insert the partially sorted key-value pairs into updated_tppr
            // (m is used so the outer loop variable i is not shadowed).
            for (int m = 0; m < this->fanout; ++m) {
                const auto& pair = pairs[m];
                updated_tppr[pair.first] = pair.second;
            }
        }
        // this->PPR_list[tppr_id][target_node] = updated_tppr;
        extract_streaming_tppr(updated_tppr, current_timestamp, tppr_id, i);
    }
}
// category=0-src category=1-dst category=2-fake
void ParallelTppRComputer :: extract_streaming_tppr(PPRDictType tppr_dict, TimeStampType current_ts, int index0, int position){
    ret[index0][position] = TemporalGraphBlock();
    if(!tppr_dict.empty()){
        // Buffers are sized to fanout; if the dict holds fewer than fanout entries,
        // the trailing slots keep their zero-initialized values.
        ret[index0][position].sample_nodes.resize(this->fanout);
        ret[index0][position].eid.resize(this->fanout);
        ret[index0][position].sample_nodes_ts.resize(this->fanout);
        ret[index0][position].e_weights.resize(this->fanout);
        ret[index0][position].delta_ts.resize(this->fanout);
        int j=0;
        for (const auto& pair : tppr_dict){
            auto tuple = pair.first;
            auto weight = pair.second;
            EdgeIDType eid = get<0>(tuple);
            NodeIDType dst = get<1>(tuple);
            TimeStampType ets = get<2>(tuple);
            ret[index0][position].sample_nodes[j]=dst;
            ret[index0][position].eid[j]=eid;
            ret[index0][position].sample_nodes_ts[j]=ets;
            ret[index0][position].e_weights[j]=weight;
            ret[index0][position].delta_ts[j]=current_ts-ets;
            j++;
        }
    }
}
void ParallelTppRComputer :: streaming_topk(th::Tensor src_nodes, th::Tensor root_ts, th::Tensor eids){
    auto src_nodes_data = get_data_ptr<NodeIDType>(src_nodes);
    auto ts_data = get_data_ptr<TimeStampType>(root_ts);
    auto eids_data = get_data_ptr<EdgeIDType>(eids);
    int n_nodes = src_nodes.size(0);
    // src_nodes packs [src | dst | fake] triples, so the batch holds n_nodes/3 positive edges.
    int n_edges = n_nodes/3;
    this->reset_ret();
    for(int index0=0;index0<num_tpprs;index0++){
        float beta = beta_list[index0]; // alpha is consumed inside compute_s1_s2
        ret[index0].resize(n_nodes);
        vector<double>& norm_list = this->norm_list[index0];
        PPRListDictType& PPR_list = this->PPR_list[index0];
        for(int i=0; i<n_edges; i++){
            NodeIDType src = src_nodes_data[i];
            NodeIDType dst = src_nodes_data[i+n_edges];
            NodeIDType fake = src_nodes_data[i+(n_edges<<1)];
            TimeStampType ts = ts_data[i];
            EdgeIDType eid = eids_data[i];
            /******first extract the top-k neighbors and fill the list******/
            extract_streaming_tppr(PPR_list[src], ts, index0, i);
            extract_streaming_tppr(PPR_list[dst], ts, index0, i+n_edges);
            extract_streaming_tppr(PPR_list[fake], ts, index0, i+(n_edges<<1));
            /******then update the PPR values here**************************/
            PPR_list[src] = compute_s1_s2(src, dst, index0, eid, ts);
            norm_list[src] = norm_list[src]*beta+beta;
            if(src!=dst){
                PPR_list[dst] = compute_s1_s2(dst, src, index0, eid, ts);
                norm_list[dst] = norm_list[dst]*beta+beta;
            }
        }
    }
}
void ParallelTppRComputer :: single_streaming_topk(th::Tensor src_nodes, th::Tensor root_ts, th::Tensor eids, int tppr_id){
    auto src_nodes_data = get_data_ptr<NodeIDType>(src_nodes);
    auto ts_data = get_data_ptr<TimeStampType>(root_ts);
    auto eids_data = get_data_ptr<EdgeIDType>(eids);
    int n_nodes = src_nodes.size(0);
    // src_nodes packs [src | dst | fake] triples, so the batch holds n_nodes/3 positive edges.
    int n_edges = n_nodes/3;
    this->reset_ret_i(tppr_id);
    float beta = beta_list[tppr_id]; // alpha is consumed inside compute_s1_s2
    ret[tppr_id].resize(n_nodes);
    vector<double>& norm_list = this->norm_list[tppr_id];
    PPRListDictType& PPR_list = this->PPR_list[tppr_id];
    for(int i=0; i<n_edges; i++){
        NodeIDType src = src_nodes_data[i];
        NodeIDType dst = src_nodes_data[i+n_edges];
        NodeIDType fake = src_nodes_data[i+(n_edges<<1)];
        TimeStampType ts = ts_data[i];
        EdgeIDType eid = eids_data[i];
        /******first extract the top-k neighbors and fill the list******/
        extract_streaming_tppr(PPR_list[src], ts, tppr_id, i);
        extract_streaming_tppr(PPR_list[dst], ts, tppr_id, i+n_edges);
        extract_streaming_tppr(PPR_list[fake], ts, tppr_id, i+(n_edges<<1));
        /******then update the PPR values here**************************/
        PPR_list[src] = compute_s1_s2(src, dst, tppr_id, eid, ts);
        norm_list[src] = norm_list[src]*beta+beta;
        if(src!=dst){
            PPR_list[dst] = compute_s1_s2(dst, src, tppr_id, eid, ts);
            norm_list[dst] = norm_list[dst]*beta+beta;
        }
    }
}
void ParallelTppRComputer :: streaming_topk_no_fake(th::Tensor src_nodes, th::Tensor root_ts, th::Tensor eids){
    auto src_nodes_data = get_data_ptr<NodeIDType>(src_nodes);
    auto ts_data = get_data_ptr<TimeStampType>(root_ts);
    auto eids_data = get_data_ptr<EdgeIDType>(eids);
    int n_nodes = src_nodes.size(0);
    // src_nodes packs [src | dst] pairs, so the batch holds n_nodes/2 positive edges.
    int n_edges = n_nodes/2;
    this->reset_ret();
    for(int index0=0;index0<num_tpprs;index0++){
        float beta = beta_list[index0]; // alpha is consumed inside compute_s1_s2
        ret[index0].resize(n_nodes);
        vector<double>& norm_list = this->norm_list[index0];
        PPRListDictType& PPR_list = this->PPR_list[index0];
        for(int i=0; i<n_edges; i++){
            NodeIDType src = src_nodes_data[i];
            NodeIDType dst = src_nodes_data[i+n_edges];
            TimeStampType ts = ts_data[i];
            EdgeIDType eid = eids_data[i];
            /******first extract the top-k neighbors and fill the list******/
            extract_streaming_tppr(PPR_list[src], ts, index0, i);
            extract_streaming_tppr(PPR_list[dst], ts, index0, i+n_edges);
            /******then update the PPR values here**************************/
            PPR_list[src] = compute_s1_s2(src, dst, index0, eid, ts);
            norm_list[src] = norm_list[src]*beta+beta;
            if(src!=dst){
                PPR_list[dst] = compute_s1_s2(dst, src, index0, eid, ts);
                norm_list[dst] = norm_list[dst]*beta+beta;
            }
        }
    }
}
void ParallelTppRComputer :: compute_val_tppr(th::Tensor src_nodes, th::Tensor dst_nodes, th::Tensor root_ts, th::Tensor eids){
    auto src_nodes_data = get_data_ptr<NodeIDType>(src_nodes);
    auto dst_nodes_data = get_data_ptr<NodeIDType>(dst_nodes);
    auto ts_data = get_data_ptr<TimeStampType>(root_ts);
    auto eids_data = get_data_ptr<EdgeIDType>(eids);
    int n_edges = src_nodes.size(0);
    for(int index0=0;index0<num_tpprs;index0++){
        float beta = beta_list[index0]; // alpha is consumed inside compute_s1_s2
        vector<double>& norm_list = this->norm_list[index0];
        PPRListDictType& PPR_list = this->PPR_list[index0];
        for(int i=0; i<n_edges; i++){
            NodeIDType src = src_nodes_data[i];
            NodeIDType dst = dst_nodes_data[i];
            TimeStampType ts = ts_data[i];
            EdgeIDType eid = eids_data[i];
            PPR_list[src] = compute_s1_s2(src, dst, index0, eid, ts);
            norm_list[src] = norm_list[src]*beta+beta;
            if(src!=dst){
                PPR_list[dst] = compute_s1_s2(dst, src, index0, eid, ts);
                norm_list[dst] = norm_list[dst]*beta+beta;
            }
        }
    }
    this->val_norm_list.assign(this->norm_list.begin(), this->norm_list.end());
    this->val_PPR_list.assign(this->PPR_list.begin(), this->PPR_list.end());
}
\ No newline at end of file
...@@ -114,11 +114,15 @@ edge_weight_dict = {}
edge_weight_dict['edata'] = 2*neg_nums
edge_weight_dict['sample_data'] = 1*neg_nums
edge_weight_dict['neg_data'] = 1
#partition_save('./dataset/here/'+data_name, data, 1, 'metis_for_tgnn',
#                edge_weight_dict=edge_weight_dict)
#partition_save('./dataset/here/'+data_name, data, 2, 'metis_for_tgnn',
#                edge_weight_dict=edge_weight_dict)
#partition_save('./dataset/here/'+data_name, data, 4, 'metis_for_tgnn',
#                edge_weight_dict=edge_weight_dict)
#partition_save('./dataset/here/'+data_name, data, 8, 'metis_for_tgnn',
#                edge_weight_dict=edge_weight_dict)
partition_save('./dataset/here/'+data_name, data, 16, 'metis_for_tgnn',
               edge_weight_dict=edge_weight_dict)
#
# partition_save('./dataset/here/'+data_name, data, 4, 'metis_for_tgnn',
......
"""单机四卡训练,启动指令:torchrun --nproc_per_node 4 --standalone demo_dtdg.py
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch import Tensor
from typing import *
from starrygl.distributed import DistributedContext
from starrygl.data import GraphData
from starrygl.parallel import Route, SequencePipe, LayerPipe
from starrygl.parallel.utils import *
from torch_scatter import scatter
from torch_geometric_temporal.dataset import TwitterTennisDatasetLoader
import math
import logging
logging.getLogger().setLevel(logging.INFO)
def prepare_data(root: str, num_parts):
    dataset = TwitterTennisDatasetLoader().get_dataset()

    x = []
    y = []
    edge_index = []
    edge_times = []
    edge_attr = []
    snapshot_count = 0
    for i, data in enumerate(dataset):
        x.append(data.x[:,None,:])
        y.append(data.y[:,None])
        edge_index.append(data.edge_index)
        edge_times.append(torch.full_like(data.edge_index[0], i)) # use the snapshot id as the timestamp; real timestamps would work as well
        edge_attr.append(data.edge_attr)
        snapshot_count += 1
    x = torch.cat(x, dim=1)
    y = torch.cat(y, dim=1)
    edge_index = torch.cat(edge_index, dim=1)
    edge_times = torch.cat(edge_times, dim=0)
    edge_attr = torch.cat(edge_attr, dim=0)

    g = GraphData(edge_index, num_nodes=x.size(0))
    g.node()["x"] = x
    g.node()["y"] = y
    g.edge()["time"] = edge_times
    g.edge()["attr"] = edge_attr
    g.meta()["num_nodes"] = x.size(0) # global number of nodes
    g.meta()["num_snapshots"] = snapshot_count # number of snapshots

    logging.info(f"GraphData.meta().keys(): {g.meta().keys()}")
    logging.info(f"GraphData.node().keys(): {g.node().keys()}")
    logging.info(f"GraphData.edge().keys(): {g.edge().keys()}")
    g.save_partition(root, num_parts, algorithm="random") # use the random partitioning algorithm
    return g
class SageConv(nn.Module):
    """A basic GraphSAGE convolution layer with mean aggregation.
    """
    def __init__(self, in_feats: int, out_feats: int):
        super().__init__()
        self.linear = nn.Linear(in_feats, out_feats)

    def forward(self, x: Tensor, edge_index: Tensor, edge_attr: Tensor, num_nodes: int):
        assert edge_attr.dim() == 1
        x = scatter(
            src=x[edge_index[0]] * edge_attr[:,None],
            index=edge_index[1],
            reduce="sum",
            dim=0,
            dim_size=num_nodes,
        )
        e = scatter(
            src=edge_attr,
            index=edge_index[1],
            reduce="sum",
            dim=0,
            dim_size=num_nodes,
        ).clamp_min(1) # avoid division by zero on isolated nodes
        x = x / e[:,None]
        return self.linear(x)
class SyncSAGE(nn.Module):
    def __init__(self,
        graph: GraphData,
        hidden_dim: int,
        num_layers: int,
        group: Any,
    ) -> None:
        super().__init__()
        self.graph = graph # graph data
        self.route = graph.to_route(group) # build the communication route over the process group
        self.group = group
        self.num_features = self.graph.node("dst")["x"].size(-1) # each partition is a bipartite graph; only dst nodes carry data
        self.num_snapshots = self.graph.meta()["num_snapshots"]
        self.num_layers = num_layers

        self.layers = nn.ModuleList()
        self.norms = nn.ModuleList()

        last_ch = self.num_features
        for i in range(num_layers):
            self.layers.append(SageConv(last_ch, hidden_dim))
            if i == num_layers - 1:
                break
            self.norms.append(nn.LayerNorm(hidden_dim))
            last_ch = hidden_dim

    def get_snapshot(self, i: int):
        num_nodes = self.graph.node("dst").num_nodes
        x = self.graph.node("dst")["x"][:,i,:]
        time_mask = self.graph.edge()["time"] == i
        edge_attr = self.graph.edge()["attr"][time_mask]
        edge_index = self.graph.edge_index()[:,time_mask]
        return num_nodes, x, edge_index, edge_attr

    def forward(self, snapshot_id: Optional[int] = None):
        if snapshot_id is None: # if no snapshot_id is given, forward every snapshot and concatenate
            xs = []
            for i in range(self.num_snapshots):
                xs.append(self.forward(i).unsqueeze(1))
            return torch.cat(xs, dim=1)

        num_nodes, x, edge_index, edge_attr = self.get_snapshot(snapshot_id)
        for i in range(self.num_layers):
            x = self.route.apply(x) # synchronize dst-node representations onto src nodes
            x = self.layers[i](x, edge_index, edge_attr, num_nodes)
            if i == self.num_layers - 1:
                break
            x = self.norms[i](x)
            x = F.relu(x)
        return x
class AsyncSAGE(LayerPipe):
    """Asynchronous training with snapshot parallelism.
    """
    def __init__(self,
        graph: GraphData,
        hidden_dim: int,
        num_layers: int,
        group: Any,
    ) -> None:
        super().__init__()
        self.graph = graph
        self.route = graph.to_route(group)
        self.group = group
        self.num_features = self.graph.node("dst")["x"].size(-1) # each partition is a bipartite graph; only dst nodes carry data
        self.num_snapshots = self.graph.meta()["num_snapshots"]
        self.num_layers = num_layers

        self.layers = nn.ModuleList()
        self.norms = nn.ModuleList()

        last_ch = self.num_features
        for i in range(num_layers):
            self.layers.append(SageConv(last_ch, hidden_dim))
            if i == num_layers - 1:
                break
            self.norms.append(nn.LayerNorm(hidden_dim))
            last_ch = hidden_dim

    def get_snapshot(self, i: int):
        num_nodes = self.graph.node("dst").num_nodes
        time_mask = self.graph.edge()["time"] == i
        edge_attr = self.graph.edge()["attr"][time_mask]
        edge_index = self.graph.edge_index()[:,time_mask]
        return num_nodes, edge_index, edge_attr

    def get_route(self) -> Route:
        """Must be overridden: the underlying execution engine needs the route used for communication.
        """
        return self.route

    def layer_inputs(self, inputs: Sequence[Tensor] | None = None) -> Sequence[Tensor]:
        """Must be overridden: prepares the input data.
        """
        if self.layer_id == 0: # the first layer consumes the raw graph data
            x = self.graph.node("dst")["x"][:,self.snapshot_id,:] # layer_id and snapshot_id tell which layer and which snapshot are currently being processed
        else:
            x, = inputs # later layers directly consume the previous layer's outputs
        self.register_route(x) # mark the output: marked outputs are synchronized into src nodes through the route, otherwise they remain dst-node tensors
        return (x,)

    def layer_forward(self, inputs: Sequence[Tensor]) -> Sequence[Tensor]:
        """Must be overridden: runs the model.
        """
        num_nodes, edge_index, edge_attr = self.get_snapshot(self.snapshot_id)
        x = self.layers[self.layer_id](*inputs, edge_index, edge_attr, num_nodes)
        return (x,)

    def forward(self):
        """After calling apply(), concatenate all snapshots.
        """
        xs = self.apply(self.num_layers, self.num_snapshots)
        return torch.cat([x.unsqueeze(1) for x, in xs], dim=1)
class SimpleRNN(SequencePipe):
    def __init__(self,
        graph: GraphData,
        hidden_dims: int,
        num_layers: int,
        device: Any,
        group: Any,
    ) -> None:
        super().__init__()
        self.graph = graph
        self.device = device
        self.group = group

        self.num_layers = num_layers
        self.hidden_dims = hidden_dims

        self.gru = nn.GRU(
            input_size = hidden_dims,
            hidden_size = hidden_dims,
            num_layers = num_layers,
            batch_first = True,
        )
        self.out = nn.Linear(hidden_dims, 1)

    def forward(self, inputs, states):
        """The recurrent unit running on each GPU.
        """
        x, = inputs # (N, L, H) (number of nodes, number of snapshots, hidden dim)
        h, = states # (N, L, H)
        h = h.transpose(0, 1).contiguous() # (L, N, H): the GRU expects the layer dimension first and a contiguous tensor
        x, h = self.gru(x, h) # (N, L, H), (L, N, H)
        h = h.transpose(0, 1).contiguous() # (N, L, H)
        x = self.out(x)
        return (x,), (h, )

    def loss_fn(self, inputs, labels) -> Tensor:
        """When fast_backward() is used, override this method to compute each micro-batch's loss.
        """
        x, = inputs
        y, = labels
        loss_scale = x.size(0) / self.graph.node("dst").num_nodes
        return F.mse_loss(x, y) * loss_scale # scale each micro-batch's loss so the final loss matches the full batch

    def get_group(self) -> Any:
        """Must be overridden: identifies the subgroup of processes.
        """
        return self.group

    def get_init_states(self):
        """Must be overridden: returns the per-node initial state template.
        """
        s = torch.zeros(self.num_layers, self.hidden_dims).to(self.device)
        return (s,)
def get_graph(
    root: str,
    sp_group: Any,
    pp_group: Any,
):
    g = GraphData.load_partition(
        root, part_id=dist.get_rank(pp_group), num_parts=dist.get_world_size(pp_group), # load graph snapshots with partition parallelism
        algorithm="random", # load_partition must use the same partitioning algorithm as save_partition
    )

    snap_part_id = dist.get_rank(sp_group)
    num_snap_parts = dist.get_world_size(sp_group)

    num_snapshots = g.meta()["num_snapshots"]
    stride = math.ceil(num_snapshots * 1.0 / num_snap_parts)
    start = snap_part_id * stride
    end = min(num_snapshots, start + stride)

    # keep only the snapshots within [start, end)
    time = g.edge()["time"]
    mask = (start <= time) & (time < end)

    edge_attr = g.edge()["attr"][mask]
    edge_index = g.edge_index()[:,mask]

    sg = GraphData.from_bipartite(
        edge_index=edge_index,
        raw_src_ids=g.node("src")["raw_ids"],
        raw_dst_ids=g.node("dst")["raw_ids"],
    )

    sg.node("dst")["x"] = g.node("dst")["x"][:,start:end].clone()
    sg.node("dst")["y"] = g.node("dst")["y"][:,start:end].clone()
    sg.edge()["time"] = time[mask] - start # snapshot offsets start from 0
    sg.edge()["attr"] = edge_attr
    sg.meta()["num_snapshots"] = end - start
    sg.meta()["num_nodes"] = g.meta()["num_nodes"]

    return sg
def get_negative_route(g: GraphData, num_edges: int, group: Any):
    num_nodes = g.meta()["num_nodes"] # this num_nodes is the global node count

    raw_src_ids = g.node("src")["raw_ids"]
    raw_dst_ids = g.node("dst")["raw_ids"]

    # randomly pick src nodes
    src = torch.randint(num_nodes, size=(num_edges,)).type_as(raw_src_ids)
    # randomly pick dst nodes and map them to global ids
    dst = torch.randint(raw_dst_ids.numel(), size=(num_edges,)).type_as(raw_dst_ids)
    dst = raw_dst_ids[dst]

    edge_index = torch.vstack([src, dst]) # generate the negative edges
    raw_src_ids = src.unique() # deduplicate the src nodes; raw_dst_ids is already deduplicated

    route = GraphData.from_bipartite(
        edge_index=edge_index,
        raw_src_ids=raw_src_ids,
        raw_dst_ids=raw_dst_ids,
    ).to_route(group)
    return route, edge_index
if __name__ == "__main__":
    data_root = "./dataset"
    ctx = DistributedContext.init(backend="nccl", use_gpu=True)
    hybrid_matrix = ctx.get_hybrid_matrix().view(2, 2) # (2, 2)
    """Returns the process matrix: [[0, 1], [2, 3]].
    [0, 1] and [2, 3] each form a pp_group, training the t0-t1 and t1-t2 snapshot segments respectively.
    [0, 2] share the same dst nodes and form an sp_group for pipelined training of the temporal model; likewise [1, 3].
    """
    sp_group, pp_group = ctx.new_hybrid_subgroups(hybrid_matrix)

    if ctx.rank == 0: # ctx.rank and ctx.world_size are properties of the global (default) process group
        prepare_data(data_root, dist.get_world_size(pp_group)) # partitioning into as many parts as the pp_group has processes is enough
    dist.barrier() # barrier so that the other processes only start reading after rank 0 has finished writing

    g = get_graph(data_root, sp_group, pp_group).to(ctx.device)
    ctx.sync_print(f'x: {g.node("dst")["x"].size()}')
    ctx.sync_print(f'y: {g.node("dst")["y"].size()}')
    ctx.sync_print(f'edge_index: {g.edge_index().size()}')
    ctx.sync_print(f'edge_time: {g.edge()["time"].size()}')
    ctx.sync_print(f'edge_attr: {g.edge()["attr"].size()}')

    hidden_dim = 128
    num_layers = 3

    # build GNN models that support partition parallelism
    sync_gnn = SyncSAGE(g, hidden_dim, num_layers, pp_group).to(ctx.device) # plain partition parallelism, relies only on route.apply()
    async_gnn = AsyncSAGE(g, hidden_dim, num_layers, pp_group).to(ctx.device) # snapshot parallelism, overlaps communication and computation via LayerPipe

    # build the timeline-parallel model
    rnn = SimpleRNN(g, hidden_dim, num_layers, ctx.device, sp_group).to(ctx.device)

    # build the optimizer
    params = []
    params.extend(sync_gnn.parameters())
    params.extend(async_gnn.parameters())
    params.extend(rnn.parameters())
    opt = torch.optim.Adam(params)

    # training labels
    y = g.node("dst")["y"].unsqueeze(-1)

    if True:
        """Case 1: train with sync_gnn combined with rnn.
        """
        ctx.main_print("\nCase 1:")
        opt.zero_grad()
        x = sync_gnn.forward() # run the GNN
        h, = rnn.apply(32, x) # run the RNN with micro_batch_size = 32
        loss = F.mse_loss(h, y) # compute the loss

        # backward pass
        loss.backward()

        # merge the gradients
        rnn.all_reduce()
        all_reduce_gradients(sync_gnn)
        all_reduce_buffers(sync_gnn)
        opt.step()
        ctx.sync_print(f"loss: {loss.item():.6f}")

    if True:
        """Case 2: train with async_gnn combined with rnn.
        """
        ctx.main_print("\nCase 2:")
        opt.zero_grad()
        x = async_gnn.forward() # run the GNN
        h, = rnn.apply(32, x) # run the RNN with micro_batch_size = 32
        loss = F.mse_loss(h, y) # compute the loss

        # backward pass
        loss.backward()
        async_gnn.backward() # !!! snapshot parallelism currently requires calling backward manually

        # merge the gradients
        rnn.all_reduce()
        async_gnn.all_reduce()
        opt.step()
        ctx.sync_print(f"loss: {loss.item():.6f}")

    if True:
        """Case 3: train with async_gnn and rnn.fast_backward. fast_backward can only be called on the last RNN stage and saves one forward pass.
        """
        ctx.main_print("\nCase 3:")
        opt.zero_grad()
        x = async_gnn.forward() # run the GNN

        # compute the backward gradients directly while running the RNN
        loss = rnn.fast_backward(32, inputs=(x,), labels=(y,))

        # backward pass
        # loss.backward() # the RNN backward does not need to be called here
        async_gnn.backward() # async_gnn's backward must still be called; with sync_gnn it would not be needed

        # merge the gradients
        rnn.all_reduce()
        async_gnn.all_reduce()
        opt.step()
        ctx.sync_print(f"loss: {loss.item():.6f}")

    if True:
        """Case 4: randomly sample negative edges and their route: keep the dst nodes fixed and pick src nodes at random.
        """
        route, edge_index = get_negative_route(g, num_edges=1000, group=pp_group)
        ctx.sync_print(route, edge_index.size())
\ No newline at end of file
Advanced Data Preprocessing
===========================
.. note::
    A detailed introduction to StarryGL's data management classes, such as GraphData: usage details, the design of the internal index structures, and the low-level operations.
\ No newline at end of file
...@@ -4,4 +4,4 @@ Advanced Concepts
.. toctree::

    sampling_parallel/index
    partition_parallel/index
    timeline_parallel/index
\ No newline at end of file
Distributed Partition Parallel
==============================
.. note::
    Distributed partition-parallel training.
\ No newline at end of file
Distributed Timeline Parallel
=============================
.. note::
    Distributed timeline parallelism.
\ No newline at end of file
Distributed Temporal Sampling
=============================
.. note::
    A training mode based on distributed temporal graph sampling.
\ No newline at end of file
starrygl.sample.cache.fetch_cache
=================================
.. note::
The cache used in feature fetching
.. currentmodule:: starrygl.sample.cache.fetch_cache
.. autoclass:: FetchFeatureCache
    :members:
starrygl.sample.graph_core
==========================
.. note::
Distributed data structures used in sampling-based training
.. currentmodule:: starrygl.sample.graph_core
.. autoclass:: DistributedGraphStore
    :members:

.. autoclass:: DataSet

.. autoclass:: TemporalNeighborSampleGraph
\ No newline at end of file
...@@ -5,4 +5,6 @@ Package References
    distributed
    neighbor_sampler
    memory
    data_loader
    graph_core
    cache
Preparing the Temporal Graph Dataset
====================================

.. note::
    In this tutorial, we will show the preparation process of the temporal graph dataset that can be used by StarryGL, covering the data cleaning and preprocessing steps that start from the raw data and end with data files ready for StarryGL.

Read Raw Data
-------------
Take the Wikipedia dataset as an example; the raw data files are as follows:

- `edges.csv`: the temporal edges of the graph, with columns `src`, `dst`, `time`, and `ext_roll` (the chronological train/val/test split flag used below)
- `node_features.pt`: the node features of the graph
- `edge_features.pt`: the edge features of the graph
Here is an example to read the raw data files:
.. code-block:: python
data_name = args.data_name
df = pd.read_csv('raw_data/'+data_name+'/edges.csv')
if os.path.exists('raw_data/'+data_name+'/node_features.pt'):
    n_feat = torch.load('raw_data/'+data_name+'/node_features.pt')
else:
    n_feat = None
if os.path.exists('raw_data/'+data_name+'/edge_features.pt'):
    e_feat = torch.load('raw_data/'+data_name+'/edge_features.pt')
else:
    e_feat = None
src = torch.from_numpy(np.array(df.src.values)).long()
dst = torch.from_numpy(np.array(df.dst.values)).long()
ts = torch.from_numpy(np.array(df.time.values)).long()
neg_nums = args.num_neg_sample
edge_index = torch.cat((src[np.newaxis, :], dst[np.newaxis, :]), 0)
num_nodes = edge_index.view(-1).max().item()+1
num_edges = edge_index.shape[1]
print('the number of nodes in graph is {}, \
the number of edges in graph is {}'.format(num_nodes, num_edges))
Preprocess Data
---------------
After reading the raw data, we need to preprocess the data to get the data format that can be used by StarryGL. The following code shows the preprocessing process:
.. code-block:: python
sample_graph = {}
sample_src = torch.cat([src.view(-1, 1), dst.view(-1, 1)], dim=1)\
.reshape(1, -1)
sample_dst = torch.cat([dst.view(-1, 1), src.view(-1, 1)], dim=1)\
.reshape(1, -1)
sample_ts = torch.cat([ts.view(-1, 1), ts.view(-1, 1)], dim=1).reshape(-1)
sample_eid = torch.arange(num_edges).view(-1, 1).repeat(1, 2).reshape(-1)
sample_graph['edge_index'] = torch.cat([sample_src, sample_dst], dim=0)
sample_graph['ts'] = sample_ts
sample_graph['eids'] = sample_eid
neg_sampler = NegativeSampling('triplet')
neg_src = neg_sampler.sample(edge_index.shape[1]*neg_nums, num_nodes)
neg_sample = neg_src.reshape(-1, neg_nums)
edge_ts = torch.from_numpy(np.array(ts)).float()
data = Data() #torch_geometric.data.Data()
data.num_nodes = num_nodes
data.num_edges = num_edges
data.edge_index = edge_index
data.edge_ts = edge_ts
data.neg_sample = neg_sample
if n_feat is not None:
    data.x = n_feat
if e_feat is not None:
    data.edge_attr = e_feat
data.train_mask = (torch.from_numpy(np.array(df.ext_roll.values)) == 0)
data.val_mask = (torch.from_numpy(np.array(df.ext_roll.values)) == 1)
data.test_mask = (torch.from_numpy(np.array(df.ext_roll.values)) == 2)
sample_graph['train_mask'] = data.train_mask[sample_eid]
sample_graph['test_mask'] = data.test_mask[sample_eid]
sample_graph['val_mask'] = data.val_mask[sample_eid]
data.sample_graph = sample_graph
data.y = torch.zeros(edge_index.shape[1])
edge_index_dict = {}
edge_index_dict['edata'] = data.edge_index
edge_index_dict['sample_data'] = data.sample_graph['edge_index']
edge_index_dict['neg_data'] = torch.cat([neg_src.view(1, -1),
dst.view(-1, 1).repeat(1, neg_nums).
reshape(1, -1)], dim=0)
data.edge_index_dict = edge_index_dict
edge_weight_dict = {}
edge_weight_dict['edata'] = 2*neg_nums
edge_weight_dict['sample_data'] = 1*neg_nums
edge_weight_dict['neg_data'] = 1
We construct a torch_geometric.data.Data object to store the data. The data object contains the following attributes:
- `num_nodes`: the number of nodes in the graph
- `num_edges`: the number of edges in the graph
- `edge_index`: the edge index of the graph
- `edge_ts`: the timestamp of the edges
- `neg_sample`: the negative samples of the edges
- `x`: the node features of the graph
- `edge_attr`: the edge features of the graph
- `train_mask`: the train mask of the edges
- `val_mask`: the validation mask of the edges
- `test_mask`: the test mask of the edges
- `sample_graph`: the sampled graph
- `edge_index_dict`: the edge index dict used by the partitioner (original, sampled, and negative edges)
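
As a quick sanity check, the sketch below (assuming the variables from the preceding snippet are still in scope, and that `ext_roll` only takes the values 0, 1 and 2) verifies that the masks partition the edges and that the sampled graph stores every event in both directions:

.. code-block:: python

# the three masks are disjoint and together cover every edge
assert int(data.train_mask.sum() + data.val_mask.sum() + data.test_mask.sum()) == data.num_edges
# every temporal event appears twice (once per direction) in the sampled graph
assert data.sample_graph['edge_index'].shape[1] == 2 * data.num_edges
assert data.sample_graph['ts'].numel() == 2 * data.num_edges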
Finally, we can partition the graph and save the data:
.. code-block:: python
partition_save('./dataset/here/'+data_name, data, 16, 'metis_for_tgnn',
edge_weight_dict=edge_weight_dict)
Distributed Training
====================

Preparation For Distributed Environment
---------------------------------------
Before starting training, we need to prepare the environment for distributed training, including the following steps:
1. Initialize the Distributed context
.. code-block:: python
ctx = DistributedContext.init(backend="nccl", use_gpu=True)
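
In practice the training script is launched with one process per GPU, e.g. :code:`torchrun --nproc_per_node 4 --standalone train.py` for a single machine with four GPUs (:code:`train.py` stands in for your entry script).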
2. Load the partitioned dataset
.. code-block:: python
pdata = partition_load("/mnt/data/part_data/dataset/here/{}".format(args.dataname), algo="metis_for_tgnn")
graph = DistributedGraphStore(pdata = pdata)
sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph,mode = 'full')
train_data = torch.masked_select(graph.edge_index,pdata.train_mask.to(graph.edge_index.device)).reshape(2,-1)
train_ts = torch.masked_select(graph.edge_ts,pdata.train_mask.to(graph.edge_index.device))
val_data = torch.masked_select(graph.edge_index,pdata.val_mask.to(graph.edge_index.device)).reshape(2,-1)
val_ts = torch.masked_select(graph.edge_ts,pdata.val_mask.to(graph.edge_index.device))
test_data = torch.masked_select(graph.edge_index,pdata.test_mask.to(graph.edge_index.device)).reshape(2,-1)
test_ts = torch.masked_select(graph.edge_ts,pdata.test_mask.to(graph.edge_index.device))
train_data = DataSet(edges = train_data,ts =train_ts,eids = torch.nonzero(pdata.train_mask).view(-1))
test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.nonzero(pdata.test_mask).view(-1))
val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.nonzero(pdata.val_mask).view(-1))
train_stream = torch.cuda.Stream()
send_stream = torch.cuda.Stream()
scatter_stream = torch.cuda.Stream()
3. Construct Mailbox and sampler
.. code-block:: python
mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat = pdata.edge_attr.shape[1] if pdata.edge_attr is not None else 0)
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=num_layers, fanout=fanout,graph_data=sample_graph, workers=10,policy = policy, graph_name = "wiki_train")
neg_sampler = NegativeSampling('triplet')
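
The :code:`num_layers`, :code:`fanout` and :code:`policy` arguments mirror the :code:`sampling` section of the model configuration file. A minimal sketch of deriving them (assuming :code:`sample_param` was obtained from the :code:`parse_config` helper shown in the model-creation tutorial):

.. code-block:: python

num_layers = sample_param['layer']    # e.g. 1
fanout = sample_param['neighbor']     # e.g. [10], neighbors sampled per layer
policy = sample_param['strategy']     # 'recent' or 'uniform'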
4. Construct the DataLoader
.. code-block:: python
trainloader = DistributedDataLoader(graph,train_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=True,
chunk_size = None,
train=True,
queue_size = 1000,
mailbox = mailbox)
testloader = DistributedDataLoader(graph,test_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=False,
chunk_size = None,
train=False,
queue_size = 100,
mailbox = mailbox)
valloader = DistributedDataLoader(graph,val_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler,
batch_size = train_param['batch_size'],
shuffle=False,
drop_last=False,
chunk_size = None,
train=False,
queue_size = 100,
mailbox = mailbox)
5. `Create the Model <module.rst>`_
6. Construct the optimizer, early stopper and criterion
.. code-block:: python
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'])
early_stopper = EarlyStopMonitor(max_round=args.patience)
7. Start Training
.. code-block:: python
for e in range(train_param['epoch']):
    torch.cuda.synchronize()
    model.train()
    # per-epoch accumulators
    total_loss = 0
    train_aps = list()
    if mailbox is not None:
        mailbox.reset()
        model.module.memory_updater.last_updated_nid = None
        model.module.memory_updater.last_updated_memory = None
        model.module.memory_updater.last_updated_ts = None
    for roots,mfgs,metadata,sample_time in trainloader:
        with torch.cuda.stream(train_stream):
            optimizer.zero_grad()
            pred_pos, pred_neg = model(mfgs,metadata)
            loss = criterion(pred_pos, torch.ones_like(pred_pos))
            loss += criterion(pred_neg, torch.zeros_like(pred_neg))
            total_loss += float(loss)
            loss.backward()
            optimizer.step()
            y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
            y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
            train_aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
            if mailbox is not None:
                src = metadata['src_pos_index']
                dst = metadata['dst_pos_index']
                ts = roots.ts
                if graph.edge_attr is None:
                    edge_feats = None
                elif(graph.edge_attr.device == torch.device('cpu')):
                    edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda')
                else:
                    edge_feats = graph.edge_attr[roots.eids]
                dist_index_mapper = mfgs[0][0].srcdata['ID']
                root_index = torch.cat((src,dst))
                last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
                last_updated_memory = model.module.memory_updater.last_updated_memory[root_index]
                last_updated_ts = model.module.memory_updater.last_updated_ts[root_index]
                index, memory, memory_ts = mailbox.get_update_memory(last_updated_nid,
                                                                     last_updated_memory,
                                                                     last_updated_ts)
                index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
                                                               src,dst,ts,edge_feats,
                                                               model.module.memory_updater.last_updated_memory,
                                                               model.module.embedding,use_src_emb,use_dst_emb,
                                                               )
                mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
    train_ap = float(torch.tensor(train_aps).mean())
    ap, auc = eval('val')
    print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f}'.format(total_loss,train_ap, ap, auc))
8. Define the evaluation function
.. code-block:: python
def eval(mode='val'):
    model.eval()
    aps = list()
    aucs_mrrs = list()
    if mode == 'val':
        loader = valloader
    elif mode == 'test':
        loader = testloader
    elif mode == 'train':
        loader = trainloader
    with torch.no_grad():
        total_loss = 0
        for roots,mfgs,metadata,sample_time in loader:
            pred_pos, pred_neg = model(mfgs,metadata)
            total_loss += criterion(pred_pos, torch.ones_like(pred_pos))
            total_loss += criterion(pred_neg, torch.zeros_like(pred_neg))
            y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
            y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
            aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
            aucs_mrrs.append(roc_auc_score(y_true, y_pred))
            if mailbox is not None:
                src = metadata['src_pos_index']
                dst = metadata['dst_pos_index']
                ts = roots.ts
                if graph.edge_attr is None:
                    edge_feats = None
                elif(graph.edge_attr.device == torch.device('cpu')):
                    edge_feats = graph.edge_attr[roots.eids.to('cpu')].to('cuda')
                else:
                    edge_feats = graph.edge_attr[roots.eids]
                dist_index_mapper = mfgs[0][0].srcdata['ID']
                root_index = torch.cat((src,dst))
                last_updated_nid = model.module.memory_updater.last_updated_nid[root_index]
                last_updated_memory = model.module.memory_updater.last_updated_memory[root_index]
                last_updated_ts = model.module.memory_updater.last_updated_ts[root_index]
                index, memory, memory_ts = mailbox.get_update_memory(last_updated_nid,
                                                                     last_updated_memory,
                                                                     last_updated_ts)
                index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
                                                               src,dst,ts,edge_feats,
                                                               model.module.memory_updater.last_updated_memory,
                                                               model.module.embedding,use_src_emb,use_dst_emb,
                                                               )
                mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
    # return the mean AP and AUC so callers can unpack `ap, auc = eval(...)`
    return float(torch.tensor(aps).mean()), float(torch.tensor(aucs_mrrs).mean())
9. Start Evaluation
.. code-block:: python

if mailbox is not None:
    mailbox.reset()
    model.module.memory_updater.last_updated_nid = None
print("Train eval:", eval('train'))
print("Test eval:", eval('test'))
ap, auc = eval('val')
print('\tval AP:{:4f} val AUC:{:4f}'.format(ap, auc))
...@@ -5,5 +5,4 @@ Tutorials
    intro
    module
    dataset
    distributed
\ No newline at end of file
Introduction to Temporal GNN
==============================================

There are many real-world systems that can be formulated as temporal interaction graphs, such as social networks and citation networks. In these systems, the nodes represent entities and the edges represent interactions between entities. The interactions are usually time-stamped, which means the edges are associated with time. Temporal interaction graphs are dynamic, meaning the graph structure changes over time. For example, in a social network, the friendship between two people may be established or broken at different times; in a citation network, a paper may cite another paper at a different time.

To encapsulate the temporal information present in these graphs and learn dynamic representations, researchers have introduced temporal graph neural networks (GNNs). These networks are capable of modeling both structural and temporal dependencies within the graph. Numerous innovative frameworks have been proposed to date, achieving outstanding performance in specific tasks such as link prediction. Based on the two different ways of representing temporal graphs, we can divide temporal GNNs into two categories:

1. continuous-time temporal GNNs, which model the temporal graph as a sequence of interactions
2. discrete-time temporal GNNs, which model the temporal graph as a sequence of snapshots

However, as the temporal graph expands, potentially encompassing millions of nodes and billions of edges, it becomes increasingly challenging to scale temporal GNN training to accommodate these larger graphs. The reasons are twofold: first, sampling neighbors from a larger graph demands more time; second, chronological training also incurs a higher time cost. To address these challenges, we introduce StarryGL in this tutorial. StarryGL is a distributed temporal GNN framework designed to efficiently navigate the complexities of training larger temporal graphs.
\ No newline at end of file
Creating Temporal GNN Models
============================

Continuous-time Temporal GNN Models
-----------------------------------
To create a continuous-time temporal GNN model, we first need to define a configuration file with the suffix :code:`.yml` that specifies the model structure and parameters. Here we use the configuration file :code:`TGN.yml` for the TGN model as an example:
.. code-block:: yaml
sampling:
  - layer: 1
    neighbor:
      - 10
    strategy: 'recent'
    prop_time: False
    history: 1
    duration: 0
    num_thread: 32
memory:
  - type: 'node'
    dim_time: 100
    deliver_to: 'self'
    mail_combine: 'last'
    memory_update: 'gru'
    mailbox_size: 1
    combine_node_feature: True
    dim_out: 100
gnn:
  - arch: 'transformer_attention'
    use_src_emb: False
    use_dst_emb: False
    layer: 1
    att_head: 2
    dim_time: 100
    dim_out: 100
train:
  - epoch: 20
    batch_size: 200
    # reorder: 16
    lr: 0.0001
    dropout: 0.2
    att_dropout: 0.2
    all_on_gpu: True
The configuration file is composed of four parts: :code:`sampling`, :code:`memory`, :code:`gnn` and :code:`train`. Here are their meanings:
- :code:`sampling`: This part specifies the sampling strategy for the temporal graph. The :code:`layer` field specifies the number of layers in the sampling strategy. The :code:`neighbor` field specifies the number of neighbors to sample for each layer. The :code:`strategy` field specifies the sampling strategy (:code:`recent` or :code:`uniform`). The :code:`prop_time` field specifies whether to propagate the time information. The :code:`history` field specifies the number of historical timestamps to use. The :code:`duration` field specifies the duration of the time window. The :code:`num_thread` field specifies the number of threads to use for sampling.
- :code:`memory`: This part specifies the memory module. The :code:`type` field specifies the type of the memory module (:code:`node` or :code:`none`). The :code:`dim_time` field specifies the dimension of the time embedding. The :code:`deliver_to` field specifies the destination of the messages. The :code:`mail_combine` field specifies how messages are combined. The :code:`memory_update` field specifies how the memory is updated. The :code:`mailbox_size` field specifies the size of the mailbox. The :code:`combine_node_feature` field specifies whether to combine the node features. The :code:`dim_out` field specifies the dimension of the output.
- :code:`gnn`: This part specifies the GNN module. The :code:`arch` field specifies the architecture of the GNN module. The :code:`use_src_emb` field specifies whether to use the source embedding. The :code:`use_dst_emb` field specifies whether to use the destination embedding. The :code:`layer` field specifies the number of layers in the GNN module. The :code:`att_head` field specifies the number of attention heads. The :code:`dim_time` field specifies the dimension of the time embedding. The :code:`dim_out` field specifies the dimension of the output.
- :code:`train`: This part specifies the training parameters. The :code:`epoch` field specifies the number of epochs. The :code:`batch_size` field specifies the batch size. The :code:`lr` field specifies the learning rate. The :code:`dropout` field specifies the dropout rate. The :code:`att_dropout` field specifies the attention dropout rate. The :code:`all_on_gpu` field specifies whether to keep all the data on the GPU.
After defining the configuration file, we can first read the parameters from the configuration file and create the model by constructing a :code:`GeneralModel` object:
.. code-block:: python
def parse_config(f):
    conf = yaml.safe_load(open(f, 'r'))
    sample_param = conf['sampling'][0]
    memory_param = conf['memory'][0]
    gnn_param = conf['gnn'][0]
    train_param = conf['train'][0]
    return sample_param, memory_param, gnn_param, train_param

sample_param, memory_param, gnn_param, train_param = parse_config('./config/{}.yml'.format(args.model))
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param).cuda()
model = DDP(model)
Then a :code:`GeneralModel` object is created and wrapped in DDP. If needed, we can adjust the model's hyperparameters by editing the configuration file. We provide five models for continuous-time temporal GNNs:
- :code:`TGN`: The TGN model proposed in `Temporal Graph Networks for Deep Learning on Dynamic Graphs <https://arxiv.org/abs/2006.10637>`__.
- :code:`DyRep`: The DyRep model proposed in `Representation Learning over Dynamic Graphs <https://arxiv.org/abs/1803.04051>`__.
- :code:`TIGER`: The TIGER model proposed in `TIGER: Temporal Interaction Graph Embedding with Restarts <https://arxiv.org/abs/2302.06057>`__.
- :code:`Jodie`: The JODIE model proposed in `Predicting Dynamic Embedding Trajectory in Temporal Interaction Networks <https://arxiv.org/abs/1908.01207>`__.
- :code:`TGAT`: The TGAT model proposed in `Inductive Representation Learning on Temporal Graphs <https://arxiv.org/abs/2002.07962>`__.
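Assuming the distributed training script shown later in this document (its argument parser accepts :code:`--model`, :code:`--dataname`, :code:`--world_size`, :code:`--rank` and :code:`--patience`), a single-process run could be launched as follows; the script name here is illustrative:

.. code-block:: bash

    python train.py --model TGN --dataname WIKI --world_size 1 --rank 0 --patience 5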
\ No newline at end of file
...@@ -3,11 +3,17 @@ ...@@ -3,11 +3,17 @@
mkdir -p build && cd build mkdir -p build && cd build
cmake .. \ cmake .. \
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON \ -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \
-DCMAKE_PREFIX_PATH="/home/hwj/.miniconda3/envs/sgl/lib/python3.10/site-packages" \ <<<<<<< HEAD
-DPython3_ROOT_DIR="/home/hwj/.miniconda3/envs/sgl" \ -DCMAKE_PREFIX_PATH="/home/zlj/.miniconda3/envs/dgnn/lib/python3.10/site-packages" \
-DCUDA_TOOLKIT_ROOT_DIR="/home/hwj/.local/cuda-11.8" \ -DPython3_ROOT_DIR="/home/zlj/.miniconda3/envs/dgnn" \
-DCUDA_TOOLKIT_ROOT_DIR="/home/zlj/local/cuda-12.2" \
=======
-DCMAKE_PREFIX_PATH=$(python -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())") \
-DPython3_ROOT_DIR=$(python -c "import sys; print(sys.prefix)") \
-DCUDA_TOOLKIT_ROOT_DIR=${CUDA_HOME:-"$(realpath $(dirname $(which nvcc))/../)"} \
>>>>>>> 98ad16d40e5e3e0a7dcdbc4f21dc9e164abc625f
&& make -j32 \ && make -j32 \
&& rm -rf ../starrygl/lib \ && rm -rf ../starrygl/lib \
&& mkdir ../starrygl/lib \ && mkdir ../starrygl/lib \
&& cp lib*.so ../starrygl/lib/ \ && cp lib*.so ../starrygl/lib/ \
&& patchelf --set-rpath '$ORIGIN:$ORIGIN/lib' --force-rpath ../starrygl/lib/*.so && patchelf --set-rpath '$ORIGIN:$ORIGIN/lib' --force-rpath ../starrygl/lib/*.so
\ No newline at end of file
...@@ -3,7 +3,7 @@ torch==2.1.1+cu118 ...@@ -3,7 +3,7 @@ torch==2.1.1+cu118
torchvision==0.16.1+cu118 torchvision==0.16.1+cu118
torchaudio==2.1.1+cu118 torchaudio==2.1.1+cu118
--extra-index-url https://data.pyg.org/whl/torch-2.1.0+cu118.html --find-links https://data.pyg.org/whl/torch-2.1.0+cu118.html
torch_geometric==2.4.0 torch_geometric==2.4.0
pyg_lib==0.3.1+pt21cu118 pyg_lib==0.3.1+pt21cu118
torch_scatter==2.1.2+pt21cu118 torch_scatter==2.1.2+pt21cu118
...@@ -11,6 +11,12 @@ torch_sparse==0.6.18+pt21cu118 ...@@ -11,6 +11,12 @@ torch_sparse==0.6.18+pt21cu118
torch_cluster==1.6.3+pt21cu118 torch_cluster==1.6.3+pt21cu118
torch_spline_conv==1.2.2+pt21cu118 torch_spline_conv==1.2.2+pt21cu118
--find-links https://data.dgl.ai/wheels/cu118/repo.html
dgl==1.1.3+cu118
--find-links https://data.dgl.ai/wheels-test/repo.html
dglgo==0.0.2
ogb ogb
tqdm tqdm
networkx networkx
\ No newline at end of file
...@@ -294,7 +294,7 @@ class DistributedTensor: ...@@ -294,7 +294,7 @@ class DistributedTensor:
index = dist_index.loc index = dist_index.loc
futs: List[torch.futures.Future] = [] futs: List[torch.futures.Future] = []
for i in range(self.num_parts()): for i in range(self.num_parts):
mask = part_idx == i mask = part_idx == i
f = self.accessor.async_index_copy_(0, index[mask], source[mask], self.rrefs[i]) f = self.accessor.async_index_copy_(0, index[mask], source[mask], self.rrefs[i])
futs.append(f) futs.append(f)
...@@ -308,7 +308,7 @@ class DistributedTensor: ...@@ -308,7 +308,7 @@ class DistributedTensor:
index = dist_index.loc index = dist_index.loc
futs: List[torch.futures.Future] = [] futs: List[torch.futures.Future] = []
for i in range(self.num_parts()): for i in range(self.num_parts):
mask = part_idx == i mask = part_idx == i
f = self.accessor.async_index_add_(0, index[mask], source[mask], self.rrefs[i]) f = self.accessor.async_index_add_(0, index[mask], source[mask], self.rrefs[i])
futs.append(f) futs.append(f)
......
import torch import torch
import dgl
from os.path import abspath, join, dirname from os.path import abspath, join, dirname
import sys import sys
sys.path.insert(0, join(abspath(dirname(__file__)))) sys.path.insert(0, join(abspath(dirname(__file__))))
...@@ -47,7 +46,7 @@ class GeneralModel(torch.nn.Module): ...@@ -47,7 +46,7 @@ class GeneralModel(torch.nn.Module):
self.edge_predictor = EdgePredictor(gnn_param['dim_out']) self.edge_predictor = EdgePredictor(gnn_param['dim_out'])
if 'combine' in gnn_param and gnn_param['combine'] == 'rnn': if 'combine' in gnn_param and gnn_param['combine'] == 'rnn':
self.combiner = torch.nn.RNN(gnn_param['dim_out'], gnn_param['dim_out']) self.combiner = torch.nn.RNN(gnn_param['dim_out'], gnn_param['dim_out'])
def forward(self, mfgs, metadata = None,neg_samples=1): def forward(self, mfgs, metadata = None,neg_samples=1):
if self.memory_param['type'] == 'node': if self.memory_param['type'] == 'node':
...@@ -68,8 +67,14 @@ class GeneralModel(torch.nn.Module): ...@@ -68,8 +67,14 @@ class GeneralModel(torch.nn.Module):
out = torch.stack(out, dim=0) out = torch.stack(out, dim=0)
out = self.combiner(out)[0][-1, :, :] out = self.combiner(out)[0][-1, :, :]
# metadata: the ids need to be recorded during the earlier deduplication step # metadata: the ids need to be recorded during the earlier deduplication step
if self.gnn_param['use_src_emb'] or self.gnn_param['use_dst_emb']:
self.embedding = out.detach().clone()
else:
self.embedding = None
if metadata is not None: if metadata is not None:
#out = torch.cat((out[metadata['dst_pos_pos']],out[metadata['src_id_pos']],out[metadata['dst_neg_pos']]),0) #out = torch.cat((out[metadata['dst_pos_pos']],out[metadata['src_id_pos']],out[metadata['dst_neg_pos']]),0)
if self.gnn_param['dyrep']:
out = self.memory_updater.last_updated_memory
out = torch.cat((out[metadata['src_pos_index']],out[metadata['dst_pos_index']],out[metadata['src_neg_index']]),0) out = torch.cat((out[metadata['src_pos_index']],out[metadata['dst_pos_index']],out[metadata['src_neg_index']]),0)
return self.edge_predictor(out, neg_samples=neg_samples) return self.edge_predictor(out, neg_samples=neg_samples)
......
import yaml import yaml
import numpy as np
def parse_config(f): def parse_config(f):
conf = yaml.safe_load(open(f, 'r')) conf = yaml.safe_load(open(f, 'r'))
...@@ -7,4 +7,32 @@ def parse_config(f): ...@@ -7,4 +7,32 @@ def parse_config(f):
memory_param = conf['memory'][0] memory_param = conf['memory'][0]
gnn_param = conf['gnn'][0] gnn_param = conf['gnn'][0]
train_param = conf['train'][0] train_param = conf['train'][0]
return sample_param, memory_param, gnn_param, train_param return sample_param, memory_param, gnn_param, train_param
\ No newline at end of file
class EarlyStopMonitor(object):
def __init__(self, max_round=3, higher_better=True, tolerance=1e-10):
# stop after max_round consecutive epochs without improvement
self.max_round = max_round
self.num_round = 0
self.epoch_count = 0
self.best_epoch = 0
self.last_best = None
# set higher_better=False for metrics that should decrease (e.g. loss)
self.higher_better = higher_better
self.tolerance = tolerance
def early_stop_check(self, curr_val):
# negate the metric so the check below always looks for an increase
if not self.higher_better:
curr_val *= -1
if self.last_best is None:
self.last_best = curr_val
# count as an improvement only if the relative gain exceeds the tolerance
elif (curr_val - self.last_best) / np.abs(self.last_best) > self.tolerance:
self.last_best = curr_val
self.num_round = 0
self.best_epoch = self.epoch_count
else:
self.num_round += 1
self.epoch_count += 1
return self.num_round >= self.max_round
\ No newline at end of file
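A minimal sketch of how this monitor is used, mirroring the validation loop of the training script later in this document (:code:`evaluate` and :code:`num_epochs` are placeholders):

.. code-block:: python

    early_stopper = EarlyStopMonitor(max_round=5)
    for epoch in range(num_epochs):
        val_ap = evaluate()  # placeholder: returns e.g. the validation AP
        if early_stopper.early_stop_check(val_ap):
            print('early stopping at epoch {}, best epoch {}'.format(
                epoch, early_stopper.best_epoch))
            break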
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
from torch import Tensor
from typing import *
from torch_scatter import segment_csr, gather_csr
from torch_sparse import SparseTensor
__all__ = [
"EmmaAttention",
"EmmaSum",
]
class EmmaAttention(nn.Module):
def __init__(self) -> None:
super().__init__()
self.register_buffer(
"his_x",
torch.empty(0),
persistent=False,
)
self.register_buffer(
"his_m",
torch.empty(0),
persistent=False,
)
self.register_buffer(
"inv_w",
torch.empty(0),
persistent=False,
)
self.reset_parameters()
def reset_parameters(self):
self.get_buffer("his_x").zero_()
self.get_buffer("his_m").fill_(-torch.inf)
self.get_buffer("inv_w").zero_()
def forward(self, x: Tensor, max_a: Tensor, agg_n: Tensor):
if self.training:
his_x = self.get_buffer("his_x")
his_m = self.get_buffer("his_m")
inv_w = self.get_buffer("inv_w")
x = EmmaAttentionFunction.apply(
x, max_a, his_x, his_m, agg_n, inv_w)
else:
inv_w = 1.0 / agg_n.data
inv_w[agg_n == 0] = 0.0
self._copy_or_clone("his_x", x)
self._copy_or_clone("his_m", max_a)
self._copy_or_clone("inv_w", inv_w)
return x
def _copy_or_clone(self, name: str, x: Tensor):
_x = self.get_buffer(name)
if _x.size() != x.size():
self.register_buffer(
name, x.data.clone(), persistent=False)
else:
_x.copy_(x.data)
@staticmethod
def softmax_gat(
src_a: Tensor,
dst_a: Tensor,
adj_t: SparseTensor,
negative_slope: float = 0.01,
) -> Tuple[SparseTensor, Tensor]:
assert src_a.dim() in {1, 2}
assert src_a.dim() == dst_a.dim()
ptr, ind, val = adj_t.csr()
a = src_a[ind] + gather_csr(dst_a, ptr)
a = F.leaky_relu(a, negative_slope=negative_slope)
with torch.no_grad():
max_a = torch.full_like(dst_a, -torch.inf)
max_a = segment_csr(a, ptr, reduce="max", out=max_a)
exp_a = torch.exp(a - gather_csr(max_a, ptr))
if val is not None:
assert val.dim() == 1
if exp_a.dim() == 1:
exp_a = exp_a * val
else:
exp_a = exp_a * val.unsqueeze(-1)
sum_exp_a = segment_csr(exp_a, ptr, reduce="sum")
exp_a = exp_a / gather_csr(sum_exp_a, ptr)
with torch.no_grad():
max_a.add_(sum_exp_a.log())
adj_t = SparseTensor(rowptr=ptr, col=ind, value=exp_a)
return adj_t, max_a
@staticmethod
def apply_gat(
x: Tensor,
src_a: Tensor,
dst_a: Tensor,
adj_t: SparseTensor,
negative_slope: float = 0.01,
) -> Tuple[Tensor, Tensor]:
adj_t, max_a = EmmaAttention.softmax_gat(
src_a=src_a, dst_a=dst_a,
adj_t=adj_t, negative_slope=negative_slope,
)
ptr, ind, val = adj_t.csr()
if val.dim() == 1:
assert x.dim() == 2
x = adj_t @ x
elif val.dim() == 2:
assert x.dim() == 3
assert x.size(1) == val.size(1)
xs = []
for i in range(x.size(1)):
xs.append(
SparseTensor(
rowptr=ptr, col=ind, value=val[:,i],
) @ x[:,i,:]
)
x = torch.cat(xs, dim=1).view(-1, *x.shape[1:])
return x, max_a
class EmmaAttentionFunction(autograd.Function):
@staticmethod
def forward(
ctx: autograd.function.FunctionCtx,
x: Tensor,
max_a: Tensor,
his_x: Tensor,
his_m: Tensor,
agg_n: Tensor,
inv_w: Tensor,
):
assert x.dim() in {2, 3}
assert x.dim() == his_x.dim()
assert max_a.dim() == his_m.dim()
# beta down-weights the stored history by the fraction of neighbors
# whose contribution has just been re-aggregated into x
beta = (1.0 - inv_w * agg_n).clamp_(0.0, 1.0)
if x.dim() == 2:
assert max_a.dim() == 1
elif x.dim() == 3:
assert max_a.dim() == 2
beta = beta.unsqueeze_(-1)
# numerically stable merge of two softmax aggregates that were
# normalized with different maxima (log-sum-exp trick)
max_m = torch.max(max_a, his_m)
p = (his_m - max_m).nan_to_num_(0.0).exp_().mul_(beta)
q = (max_a - max_m).nan_to_num_(0.0).exp_()
t = p + q
p.div_(t).unsqueeze_(-1)
q.div_(t).unsqueeze_(-1)
# convex combination of the historical and the fresh aggregate
his_x.mul_(p).add_(x * q)
# fold the new normalizer back into the stored log-max
his_m.copy_(max_m).add_(t.log_())
ctx.save_for_backward(q)
return his_x
@staticmethod
def backward(
ctx: autograd.function.FunctionCtx,
grad: Tensor,
):
q, = ctx.saved_tensors
return grad * q, None, None, None, None, None
class EmmaSum(nn.Module):
def __init__(self) -> None:
super().__init__()
self.register_buffer(
"his_x",
torch.empty(0),
persistent=False,
)
self.register_buffer(
"inv_w",
torch.empty(0),
persistent=False,
)
self.reset_parameters()
def reset_parameters(self):
self.get_buffer("his_x").zero_()
self.get_buffer("inv_w").zero_()
def forward(self, x: Tensor, agg_n: Tensor, aggr: str = "sum"):
assert aggr in {"sum", "mean"}
if self.training:
his_x = self.get_buffer("his_x")
inv_w = self.get_buffer("inv_w")
x = EmmaSumFunction.apply(x, his_x, agg_n, inv_w)
else:
inv_w = 1.0 / agg_n.data
inv_w[agg_n == 0] = 0.0
self._copy_or_clone("his_x", x)
self._copy_or_clone("inv_w", inv_w)
if aggr == "mean":
x = x * inv_w[:,None]
return x
def _copy_or_clone(self, name: str, x: Tensor):
_x = self.get_buffer(name)
if _x.size() != x.size():
self.register_buffer(
name, x.data.clone(), persistent=False)
else:
_x.copy_(x.data)
class EmmaSumFunction(autograd.Function):
@staticmethod
def forward(
ctx: autograd.function.FunctionCtx,
x: Tensor,
his_x: Tensor,
agg_n: Tensor,
inv_w: Tensor,
):
assert x.dim() == 2
assert his_x.dim() == x.dim()
beta = (1.0 - inv_w * agg_n) \
.clamp_(0.0, 1.0).unsqueeze_(-1)
his_x.mul_(beta).add_(x)
# ctx.save_for_backward(inv_w)
return his_x
@staticmethod
def backward(
ctx: autograd.function.FunctionCtx,
grad: Tensor,
):
# inv_w, = ctx.saved_tensors
# return grad * inv_w[:,None], None, None, None
return grad, None, None, None
\ No newline at end of file
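The update implemented by :code:`EmmaAttentionFunction.forward` above merges a historical softmax aggregate with a freshly computed one in a numerically stable way. Writing the stored history as :math:`h` with log-normalizer :math:`m_h`, the fresh aggregate as :math:`x` with row-wise maximum :math:`m_a`, and the staleness weight as :math:`\beta`, the code computes:

.. math::

    m = \max(m_h, m_a), \qquad p = \beta\, e^{m_h - m}, \qquad q = e^{m_a - m},

    h \leftarrow \frac{p}{p+q}\, h + \frac{q}{p+q}\, x, \qquad m_h \leftarrow m + \log(p + q).

The backward pass only needs the weight :math:`q/(p+q)` applied to the incoming gradient, which is exactly what :code:`ctx.save_for_backward(q)` stores.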
...@@ -75,6 +75,12 @@ class LayerPipe(ABC): ...@@ -75,6 +75,12 @@ class LayerPipe(ABC):
models.append((key, val)) models.append((key, val))
return tuple(models) return tuple(models)
def parameters(self):
params: List[nn.Parameter] = []
for name, m in self.get_model():
params.extend(m.parameters())
return params
def register_route(self, *xs: Tensor): def register_route(self, *xs: Tensor):
for t in xs: for t in xs:
t.requires_route = True t.requires_route = True
......
...@@ -55,6 +55,12 @@ class SequencePipe(ABC): ...@@ -55,6 +55,12 @@ class SequencePipe(ABC):
models.append((key, val)) models.append((key, val))
return tuple(models) return tuple(models)
def parameters(self):
params: List[nn.Parameter] = []
for name, m in self.get_model():
params.extend(m.parameters())
return params
def to(self, device: Any): def to(self, device: Any):
for _, net in self.get_model(): for _, net in self.get_model():
net.to(device) net.to(device)
......
...@@ -17,12 +17,31 @@ class FetchFeatureCache: ...@@ -17,12 +17,31 @@ class FetchFeatureCache:
graph: DistributedGraphStore, graph: DistributedGraphStore,
mailbox:SharedMailBox = None, mailbox:SharedMailBox = None,
policy = 'lru'): policy = 'lru'):
"""
method to create a fetch cache instance.
Args:
num_nodes: Total number of nodes in the graph.
num_edges: Total number of edges in the graph.
edge_cache_ratio: The hit rate of cache edges.
node_cache_ratio: The hit rate of cache nodes.
graph: Distributed graph store.
mailbox: used for storing information.
policy: Caching policy, either 'lru' or 'static'.
"""
global _FetchCache global _FetchCache
_FetchCache = FetchFeatureCache(num_nodes, num_edges, _FetchCache = FetchFeatureCache(num_nodes, num_edges,
edge_cache_ratio, node_cache_ratio, edge_cache_ratio, node_cache_ratio,
graph,mailbox,policy) graph,mailbox,policy)
@staticmethod @staticmethod
def getFetchCache(): def getFetchCache():
"""
method to get the existing fetch cache instance.
Returns:
FetchFeatureCache: The existing fetch cache instance.
"""
global _FetchCache global _FetchCache
return _FetchCache return _FetchCache
def __init__(self, num_nodes: int, num_edges: int, def __init__(self, num_nodes: int, num_edges: int,
...@@ -31,6 +50,19 @@ class FetchFeatureCache: ...@@ -31,6 +50,19 @@ class FetchFeatureCache:
mailbox:SharedMailBox = None, mailbox:SharedMailBox = None,
policy = 'lru' policy = 'lru'
): ):
"""
Initializes the FetchFeatureCache instance.
Args:
num_nodes: Total number of nodes in the graph.
num_edges: Total number of edges in the graph.
edge_cache_ratio: The hit rate of cache edges.
node_cache_ratio: The hit rate of cache nodes.
graph: Distributed graph store.
mailbox: used for storing information.
policy: Caching policy, either 'lru' or 'static'.
"""
if policy == 'lru': if policy == 'lru':
init_fn = LRU_cache.LRUCache init_fn = LRU_cache.LRUCache
elif policy == 'static': elif policy == 'static':
...@@ -62,7 +94,17 @@ class FetchFeatureCache: ...@@ -62,7 +94,17 @@ class FetchFeatureCache:
def fetch_feature(self, nid: Optional[torch.Tensor] = None, dist_nid = None, def fetch_feature(self, nid: Optional[torch.Tensor] = None, dist_nid = None,
eid: Optional[torch.Tensor] = None, dist_eid = None eid: Optional[torch.Tensor] = None, dist_eid = None
): ):
"""
Fetches node and edge features along with mailbox memory.
Args:
nid: Node indices to fetch features for.
dist_nid: The remote communication corresponding to nid.
eid: Edge indices to fetch features for.
dist_eid: The remote communication corresponding to eid.
"""
nfeat = None nfeat = None
mem = None mem = None
efeat = None efeat = None
...@@ -147,6 +189,14 @@ class FetchFeatureCache: ...@@ -147,6 +189,14 @@ class FetchFeatureCache:
return nfeat,efeat,mem return nfeat,efeat,mem
def init_cache_with_presample(self,dataloader, num_epoch:int = 10): def init_cache_with_presample(self,dataloader, num_epoch:int = 10):
"""
Initializes the cache with pre-sampled data from the provided dataloader.
Args:
dataloader: The data loader we implement, containing the graph data.
num_epoch: Number of epochs to pre-sample the data.
"""
node_size = self.node_cache.capacity if self.node_cache is not None else 0 node_size = self.node_cache.capacity if self.node_cache is not None else 0
edge_size = self.edge_cache.capacity if self.edge_cache is not None else 0 edge_size = self.edge_cache.capacity if self.edge_cache is not None else 0
node_counts,edge_counts = pre_sample(dataloader=dataloader, node_counts,edge_counts = pre_sample(dataloader=dataloader,
......
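For reference, a sketch of how this cache is typically wired up, based on the methods above and the (commented-out) presampling call in the training script later in this document; the name of the creation helper is an assumption and may differ in the actual source:

.. code-block:: python

    # assumed creation entry point matching the Args documented above
    # (hypothetical name; see the static methods of FetchFeatureCache)
    FetchFeatureCache.create_fetch_cache(num_nodes, num_edges,
                                         edge_cache_ratio=0.1,
                                         node_cache_ratio=0.1,
                                         graph=graph, mailbox=mailbox,
                                         policy='lru')
    cache = FetchFeatureCache.getFetchCache()
    # warm the cache by counting node/edge accesses over pre-sampled epochs
    cache.init_cache_with_presample(trainloader, num_epoch=3)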
...@@ -21,10 +21,54 @@ import math ...@@ -21,10 +21,54 @@ import math
class DistributedDataLoader: class DistributedDataLoader:
''' '''
Args: We will perform feature fetching in the data loader.
data_path: the path of loaded graph ,each part 0 of graph is saved on $path$/rank_0 You can simply define a data loader, and starrygl will assist in fetching node or edge features:
num_replicas: the num of worker
Args:
graph: distributed graph store
data: the graph data
sampler: a parallel sampler like `NeighborSampler` above
sampler_fn: sample type
neg_sampler: negative sampler
batch_size: batch size
mailbox: APAN's mailbox and TGN's memory implemented by starrygl
Examples:
.. code-block:: python
import torch
from starrygl.sample.data_loader import DistributedDataLoader
from starrygl.sample.part_utils.partition_tgnn import partition_load
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.sample.memory.shared_mailbox import SharedMailBox
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
from starrygl.sample.sample_core.base import NegativeSampling
from starrygl.sample.batch_data import SAMPLE_TYPE
pdata = partition_load("PATH/{}".format(dataname), algo="metis_for_tgnn")
graph = DistributedGraphStore(pdata = pdata, uvm_edge = False, uvm_node = False)
sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph,mode = 'full')
mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat=pdata.edge_attr.shape[1] if pdata.edge_attr is not None else 0)
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10], graph_data=sample_graph, workers=15,policy = 'recent',graph_name = "wiki_train")
neg_sampler = NegativeSampling('triplet')
train_data = torch.masked_select(graph.edge_index, pdata.train_mask.to(graph.edge_index.device)).reshape(2, -1)
trainloader = DistributedDataLoader(graph, train_data, sampler=sampler, sampler_fn=SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES, neg_sampler=neg_sampler, batch_size=1000, shuffle=False, drop_last=True, chunk_size=None, train=True, mailbox=mailbox)
Inside the data loader we call `graph_sample`, sourced from `starrygl.sample.batch_data`,
and the `to_block` function in `graph_sample` implements the feature fetching.
If the cache is not used, node and edge features are fetched directly from the graph data;
otherwise `fetch_data` is called to fetch them through the cache.
''' '''
def __init__( def __init__(
...@@ -111,10 +155,10 @@ class DistributedDataLoader: ...@@ -111,10 +155,10 @@ class DistributedDataLoader:
self.expected_idx = data_size // self.batch_size if self.drop_last is True else int(math.ceil(data_size/self.batch_size)) self.expected_idx = data_size // self.batch_size if self.drop_last is True else int(math.ceil(data_size/self.batch_size))
if dist.get_world_size() > 1: if dist.get_world_size() > 1:
num_epochs = torch.tensor([self.expected_idx],dtype = torch.long,device=self.device) num_batchs = torch.tensor([self.expected_idx],dtype = torch.long,device=self.device)
print(num_epochs) print("num_batchs:", num_batchs)
dist.all_reduce(num_epochs, op=op) dist.all_reduce(num_batchs, op=op)
self.expected_idx = int(num_epochs.item()) self.expected_idx = int(num_batchs.item())
def _next_data(self): def _next_data(self):
if self.current_pos >= self.dataset.len: if self.current_pos >= self.dataset.len:
...@@ -148,6 +192,7 @@ class DistributedDataLoader: ...@@ -148,6 +192,7 @@ class DistributedDataLoader:
self.device) self.device)
self.recv_idxs += 1 self.recv_idxs += 1
assert batch_data is not None assert batch_data is not None
torch.cuda.synchronize()
return batch_data return batch_data
else : else :
raise StopIteration raise StopIteration
......
import starrygl
from starrygl.distributed.context import DistributedContext from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex, DistributedTensor from starrygl.distributed.utils import DistIndex, DistributedTensor
from starrygl.sample.graph_core.utils import build_mapper from starrygl.sample.graph_core.utils import build_mapper
...@@ -6,8 +7,22 @@ import torch ...@@ -6,8 +7,22 @@ import torch
import torch.distributed as dist import torch.distributed as dist
from torch_geometric.data import Data from torch_geometric.data import Data
from starrygl.utils.uvm import cudaMemoryAdvise, uvm_advise, uvm_empty, uvm_prefetch, uvm_share
class DistributedGraphStore: class DistributedGraphStore:
'''
Initializes the DistributedGraphStore with distributed graph data.
Args:
pdata: Graph data object containing ids, eids, edge_index, edge_ts, sample_graph, x, and edge_attr.
device: Device to which tensors are moved (default is 'cuda').
uvm_node: If True, enables Unified Virtual Memory (UVM) for node data.
uvm_edge: If True, enables Unified Virtual Memory (UVM) for edge data.
'''
def __init__(self, pdata, device = torch.device('cuda'), def __init__(self, pdata, device = torch.device('cuda'),
uvm_node = False, uvm_node = False,
uvm_edge = False): uvm_edge = False):
...@@ -36,12 +51,12 @@ class DistributedGraphStore: ...@@ -36,12 +51,12 @@ class DistributedGraphStore:
x = pdata.x.to(self.device) x = pdata.x.to(self.device)
else: else:
if self.device.type == 'cuda': if self.device.type == 'cuda':
x = uvm_empty(*pdata.x.size(), x = starrygl.utils.uvm.uvm_empty(*pdata.x.size(),
dtype=pdata.x.dtype, dtype=pdata.x.dtype,
device=ctx.device) device=ctx.device)
uvm_share(x,device = ctx.device) starrygl.utils.uvm.uvm_share(x,device = ctx.device)
uvm_advise(x,cudaMemoryAdvise.cudaMemAdviseSetAccessedBy) starrygl.utils.uvm.uvm_advise(x,starrygl.utils.uvm.cudaMemoryAdvise.cudaMemAdviseSetAccessedBy)
uvm_prefetch(x) starrygl.utils.uvm.uvm_prefetch(x)
if world_size > 1: if world_size > 1:
self.x = DistributedTensor(pdata.x.to(self.device).to(torch.float)) self.x = DistributedTensor(pdata.x.to(self.device).to(torch.float))
else: else:
...@@ -56,12 +71,12 @@ class DistributedGraphStore: ...@@ -56,12 +71,12 @@ class DistributedGraphStore:
edge_attr = pdata.edge_attr.to(self.device) edge_attr = pdata.edge_attr.to(self.device)
else: else:
if self.device.type == 'cuda': if self.device.type == 'cuda':
edge_attr = uvm_empty(*pdata.edge_attr.size(), edge_attr = starrygl.utils.uvm.uvm_empty(*pdata.edge_attr.size(),
dtype=pdata.edge_attr.dtype, dtype=pdata.edge_attr.dtype,
device=ctx.device) device=ctx.device)
uvm_share(edge_attr,device = ctx.device) starrygl.utils.uvm.uvm_share(edge_attr,device = ctx.device)
uvm_advise(edge_attr,cudaMemoryAdvise.cudaMemAdviseSetAccessedBy) starrygl.utils.uvm.uvm_advise(edge_attr,starrygl.utils.uvm.cudaMemoryAdvise.cudaMemAdviseSetAccessedBy)
uvm_prefetch(edge_attr) starrygl.utils.uvm.uvm_prefetch(edge_attr)
if world_size > 1: if world_size > 1:
self.edge_attr = DistributedTensor(edge_attr) self.edge_attr = DistributedTensor(edge_attr)
else: else:
...@@ -70,6 +85,15 @@ class DistributedGraphStore: ...@@ -70,6 +85,15 @@ class DistributedGraphStore:
self.edge_attr = None self.edge_attr = None
def _get_node_attr(self,ids,asyncOp = False): def _get_node_attr(self,ids,asyncOp = False):
'''
Retrieves node attributes for the specified node IDs.
Args:
ids: Node IDs for which to retrieve attributes.
asyncOp: If True, performs asynchronous operation for distributed data.
'''
if self.x is None: if self.x is None:
return None return None
elif dist.get_world_size() == 1: elif dist.get_world_size() == 1:
...@@ -81,6 +105,15 @@ class DistributedGraphStore: ...@@ -81,6 +105,15 @@ class DistributedGraphStore:
return self.x.index_select(ids) return self.x.index_select(ids)
def _get_edge_attr(self,ids,asyncOp = False): def _get_edge_attr(self,ids,asyncOp = False):
'''
Retrieves edge attributes for the specified edge IDs.
Args:
ids: Edge IDs for which to retrieve attributes.
asyncOp: If True, performs asynchronous operation for distributed data.
'''
if self.edge_attr is None: if self.edge_attr is None:
return None return None
elif dist.get_world_size() == 1: elif dist.get_world_size() == 1:
...@@ -93,9 +126,32 @@ class DistributedGraphStore: ...@@ -93,9 +126,32 @@ class DistributedGraphStore:
return self.edge_attr.index_select(ids) return self.edge_attr.index_select(ids)
def _get_dist_index(self,ind,mapper): def _get_dist_index(self,ind,mapper):
'''
Retrieves the distributed index for the specified local index using the provided mapper.
Args:
ind: Local index for which to retrieve the distributed index.
mapper: Mapper providing the distributed index.
'''
return mapper[ind.to(mapper.device)] return mapper[ind.to(mapper.device)]
class DataSet: class DataSet:
'''
Args:
nodes: Tensor representing nodes. If not None, it is moved to the specified device.
edges: Tensor representing edges. If not None, it is moved to the specified device.
labels: Optional parameter for labels.
ts: Tensor representing timestamps. If not None, it is moved to the specified device.
device: Device to which tensors are moved (default is 'cuda').
'''
def __init__(self,nodes = None, def __init__(self,nodes = None,
edges = None, edges = None,
labels = None, labels = None,
...@@ -110,10 +166,15 @@ class DataSet: ...@@ -110,10 +166,15 @@ class DataSet:
if labels is not None: if labels is not None:
self.labels = labels self.labels = labels
self.len = self.nodes.shape[0] if nodes is not None else self.edges.shape[1] self.len = self.nodes.shape[0] if nodes is not None else self.edges.shape[1]
for k, v in kwargs.items(): for k, v in kwargs.items():
assert isinstance(v,torch.Tensor) and v.shape[0]==self.len assert isinstance(v,torch.Tensor) and v.shape[0]==self.len
setattr(self, k, v.to(device)) setattr(self, k, v.to(device))
def _get_empty(self): def _get_empty(self):
'''
Creates an empty dataset with the same device and data types as the current instance.
'''
nodes = torch.empty([0],dtype = self.nodes.dtype,device= self.nodes.device) if hasattr(self,'nodes') else None
edges = torch.empty([2,0],dtype = self.edges.dtype,device= self.edges.device) if hasattr(self,'edges') else None
d = DataSet(nodes,edges) d = DataSet(nodes,edges)
...@@ -126,6 +187,13 @@ class DataSet: ...@@ -126,6 +187,13 @@ class DataSet:
#@staticmethod #@staticmethod
def get_next(self,indx): def get_next(self,indx):
'''
Retrieves the next dataset based on the provided index.
Args:
indx: Index specifying the dataset to retrieve.
'''
nodes = self.nodes[indx] if hasattr(self,'nodes') else None nodes = self.nodes[indx] if hasattr(self,'nodes') else None
edges = self.edges[:,indx] if hasattr(self,'edges') else None edges = self.edges[:,indx] if hasattr(self,'edges') else None
d = DataSet(nodes,edges) d = DataSet(nodes,edges)
...@@ -138,6 +206,10 @@ class DataSet: ...@@ -138,6 +206,10 @@ class DataSet:
#@staticmethod #@staticmethod
def shuffle(self): def shuffle(self):
'''
Shuffles the dataset and returns a new dataset with the same attributes.
'''
indx = torch.randperm(self.len) indx = torch.randperm(self.len)
nodes = self.nodes[indx] if hasattr(self,'nodes') else None nodes = self.nodes[indx] if hasattr(self,'nodes') else None
edges = self.edges[:,indx] if hasattr(self,'edges') else None edges = self.edges[:,indx] if hasattr(self,'edges') else None
...@@ -151,7 +223,7 @@ class DataSet: ...@@ -151,7 +223,7 @@ class DataSet:
class TemporalGraphData(DistributedGraphStore): class TemporalGraphData(DistributedGraphStore):
def __init__(self,pdata,device): def __init__(self,pdata,device):
super(TemporalGraphData,self).__init__(pdata,device) super().__init__(pdata,device)
def _set_temporal_batch_cache(self,size,pin_size): def _set_temporal_batch_cache(self,size,pin_size):
pass pass
def _load_feature_to_cuda(self,ids): def _load_feature_to_cuda(self,ids):
...@@ -161,6 +233,17 @@ class TemporalGraphData(DistributedGraphStore): ...@@ -161,6 +233,17 @@ class TemporalGraphData(DistributedGraphStore):
class TemporalNeighborSampleGraph(DistributedGraphStore): class TemporalNeighborSampleGraph(DistributedGraphStore):
'''
Args:
sample_graph: A dictionary containing graph structure information, including 'edge_index', 'ts' (edge timestamp), and 'eids' (edge identifiers).
mode: Specifies the dataset mode ('train', 'val', 'test', or 'full').
eids_mapper: Optional parameter for edge identifiers mapping.
'''
def __init__(self, sample_graph=None, mode='full', eids_mapper=None): def __init__(self, sample_graph=None, mode='full', eids_mapper=None):
self.edge_index = sample_graph['edge_index'] self.edge_index = sample_graph['edge_index']
self.num_edges = self.edge_index.shape[1] self.num_edges = self.edge_index.shape[1]
......
import starrygl
from typing import Union from typing import Union
from typing import List from typing import List
from typing import Optional from typing import Optional
...@@ -8,9 +9,41 @@ from starrygl.distributed.context import DistributedContext ...@@ -8,9 +9,41 @@ from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex, DistributedTensor from starrygl.distributed.utils import DistIndex, DistributedTensor
import torch.distributed as dist import torch.distributed as dist
from starrygl.utils.uvm import cudaMemoryAdvise, uvm_advise, uvm_empty, uvm_prefetch, uvm_share #from starrygl.utils.uvm import cudaMemoryAdvise
class SharedMailBox(): class SharedMailBox():
'''
We first define our mailbox, including the definitions of the mailbox and memory:
.. code-block:: python
from starrygl.sample.memory.shared_mailbox import SharedMailBox
mailbox = SharedMailBox(num_nodes=num_nodes, memory_param=memory_param, dim_edge_feat=dim_edge_feat)
Args:
num_nodes (int): number of nodes
memory_param (dict): the memory parameters in the YAML file (see TGL)
dim_edge_feat (int): the dim of edge feature
device (torch.device): the device used to store MailBox
uvm (bool): True to store the memory and mailbox in unified virtual memory (UVM), False otherwise
Examples:
.. code-block:: python
from starrygl.sample.part_utils.partition_tgnn import partition_load
from starrygl.sample.memory.shared_mailbox import SharedMailBox
pdata = partition_load("PATH/{}".format(dataname), algo="metis_for_tgnn")
mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat=pdata.edge_attr.shape[1] if pdata.edge_attr is not None else 0)
We then hand the mailbox over to the data loader, as in the example above, so that the relevant memory/mailbox entries can be loaded directly during training.
During training, the `get_update_memory`/`get_update_mail` functions are called constantly to update
the relevant storage, following the idea of TGN.
'''
def __init__(self, def __init__(self,
num_nodes, num_nodes,
memory_param, memory_param,
...@@ -47,18 +80,18 @@ class SharedMailBox(): ...@@ -47,18 +80,18 @@ class SharedMailBox():
if uvm is True: if uvm is True:
ctx = DistributedContext.get_default_context() ctx = DistributedContext.get_default_context()
node_memory = uvm_empty(*node_memory.shape, node_memory = starrygl.utils.uvm.uvm_empty(*node_memory.shape,
dtype=node_memory.dtype, dtype=node_memory.dtype,
device=ctx.device) device=ctx.device)
uvm_share(node_memory,device = ctx.device) starrygl.utils.uvm.uvm_share(node_memory,device = ctx.device)
uvm_advise(node_memory,cudaMemoryAdvise.cudaMemAdviseSetAccessedBy) starrygl.utils.uvm.uvm_advise(node_memory,starrygl.utils.uvm.cudaMemoryAdvise.cudaMemAdviseSetAccessedBy)
uvm_prefetch(node_memory) starrygl.utils.uvm.uvm_prefetch(node_memory)
mailbox = uvm_empty(*mailbox.shape, mailbox = starrygl.utils.uvm.uvm_empty(*mailbox.shape,
dtype=mailbox.dtype, dtype=mailbox.dtype,
device=ctx.device) device=ctx.device)
uvm_share(mailbox,device = ctx.device) starrygl.utils.uvm.uvm_share(mailbox,device = ctx.device)
uvm_advise(mailbox,cudaMemoryAdvise.cudaMemAdviseSetAccessedBy) starrygl.utils.uvm.uvm_advise(mailbox,starrygl.utils.uvm.cudaMemoryAdvise.cudaMemAdviseSetAccessedBy)
uvm_prefetch(mailbox) starrygl.utils.uvm.uvm_prefetch(mailbox)
self.node_memory = DistributedTensor(node_memory) self.node_memory = DistributedTensor(node_memory)
self.node_memory_ts = DistributedTensor(node_memory_ts) self.node_memory_ts = DistributedTensor(node_memory_ts)
self.mailbox = DistributedTensor(mailbox) self.mailbox = DistributedTensor(mailbox)
...@@ -266,7 +299,7 @@ class SharedMailBox(): ...@@ -266,7 +299,7 @@ class SharedMailBox():
def get_update_mail(self,dist_indx_mapper, def get_update_mail(self,dist_indx_mapper,
src,dst,ts,edge_feats, src,dst,ts,edge_feats,
memory): memory,embedding=None,use_src_emb=False,use_dst_emb=False):
if edge_feats is not None: if edge_feats is not None:
edge_feats = edge_feats.to(self.device).to(self.mailbox.dtype) edge_feats = edge_feats.to(self.device).to(self.mailbox.dtype)
src = src.to(self.device) src = src.to(self.device)
...@@ -276,12 +309,14 @@ class SharedMailBox(): ...@@ -276,12 +309,14 @@ class SharedMailBox():
mem_src = memory[src] mem_src = memory[src]
mem_dst = memory[dst] mem_dst = memory[dst]
if embedding is not None:
emb_src = embedding[src]
emb_dst = embedding[dst]
src_mail = torch.cat([emb_src if use_src_emb else mem_src, emb_dst if use_dst_emb else mem_dst], dim=1)
dst_mail = torch.cat([emb_dst if use_src_emb else mem_dst, emb_src if use_dst_emb else mem_src], dim=1)
if edge_feats is not None: if edge_feats is not None:
src_mail = torch.cat([mem_src, mem_dst, edge_feats], dim=1) src_mail = torch.cat([src_mail, edge_feats], dim=1)
dst_mail = torch.cat([mem_dst, mem_src, edge_feats], dim=1) dst_mail = torch.cat([dst_mail, edge_feats], dim=1)
else:
src_mail = torch.cat([mem_src, mem_dst], dim=1)
dst_mail = torch.cat([mem_dst, mem_src], dim=1)
mail = torch.cat([src_mail, dst_mail], dim=1).reshape(-1, src_mail.shape[1]) mail = torch.cat([src_mail, dst_mail], dim=1).reshape(-1, src_mail.shape[1])
mail_ts = torch.cat((ts,ts),-1).to(self.device).to(self.mailbox_ts.dtype) mail_ts = torch.cat((ts,ts),-1).to(self.device).to(self.mailbox_ts.dtype)
unq_index,inv = torch.unique(index,return_inverse = True) unq_index,inv = torch.unique(index,return_inverse = True)
...@@ -291,7 +326,6 @@ class SharedMailBox(): ...@@ -291,7 +326,6 @@ class SharedMailBox():
index = unq_index index = unq_index
return index,mail,mail_ts return index,mail,mail_ts
def get_update_memory(self,index,memory,memory_ts): def get_update_memory(self,index,memory,memory_ts):
unq_index,inv = torch.unique(index,return_inverse = True) unq_index,inv = torch.unique(index,return_inverse = True)
max_ts,idx = torch_scatter.scatter_max(memory_ts,inv,0) max_ts,idx = torch_scatter.scatter_max(memory_ts,inv,0)
......
from torch_sparse import SparseTensor from torch_sparse import SparseTensor
from torch_geometric.data import Data from torch_geometric.data import Data
from torch_geometric.utils import degree from torch_geometric.utils import degree
import os.path as osp import os.path as osp
import os import os
import shutil import shutil
......
import os.path as osp
import torch
class GraphData():
def __init__(self, id, edge_index, data, partptr, timestamp=None):
# index of the current partition
self.partition_id = id
# total number of partitions
self.partitions = partptr.numel() - 1
# full-graph structure data
self.num_nodes = partptr[self.partitions]
if edge_index is not None:
self.num_edges = edge_index[0].numel()
self.edge_index = edge_index
self.edge_ts = timestamp
# data of this partition (feature vectors and subgraph structure), a PyG Data object
self.data = data
# partition mapping
self.partptr = partptr
# edge ids
self.eid = torch.tensor([i for i in range(0, self.num_edges)])
def get_part_num(self):
return self.data.x.size()[0]
def select_attr(self,index):
return torch.index_select(self.data.x,0,index)
def select_y(self,index):
return torch.index_select(self.data.y,0,index)
# convert a global node id to an id local to the given partition
def get_localId_by_partitionId(self,id,index):
#print(index)
if(id == -1 or id == 0):
return index
else:
return torch.add(index,-self.partptr[id])
def get_globalId_by_partitionId(self,id,index):
if(id == -1 or id == 0):
return index
else:
return torch.add(index,self.partptr[id])
def get_node_num(self):
return self.num_nodes
def localId_to_globalId(self,id,partitionId:int = -1):
'''
Map a node id local to partition partitionId to its global id
'''
if partitionId == -1:
partitionId = self.partition_id
assert id >=self.partptr[self.partition_id] and id < self.partptr[self.partition_id+1]
ids_before = 0
if self.partition_id>0:
ids_before = self.partptr[self.partition_id-1]
return id+ids_before
def get_partitionId_by_globalId(self,id):
'''
Get the partition index for a given global id
'''
partitionId = -1
assert id>=0 and id<self.num_nodes,'id out of range'
for i in range(self.partitions):
if id>=self.partptr[i] and id<self.partptr[i+1]:
partitionId = i
break
assert partitionId>=0, 'no partition found for this id'
return partitionId
def get_nodes_by_partitionId(self,id):
'''
Return the number of nodes in the given partition
'''
assert id>=0 and id<self.partitions,'invalid partitionId'
return (int)(self.partptr[id+1]-self.partptr[id])
def __repr__(self):
return (f'{self.__class__.__name__}(\n'
f' partition_id={self.partition_id}\n'
f' data={self.data},\n'
f' global_info('
f'num_nodes={self.num_nodes},'
f' num_edges={self.num_edges},'
f' num_parts={self.partitions},'
f' edge_index=[2,{self.edge_index[0].numel()}])\n'
f')')
import starrygl
import sys import sys
from os.path import abspath, join, dirname from os.path import abspath, join, dirname
sys.path.insert(0, join(abspath(dirname(__file__)))) sys.path.insert(0, join(abspath(dirname(__file__))))
import math import math
import torch import torch
...@@ -9,16 +9,72 @@ from typing import Optional, Tuple ...@@ -9,16 +9,72 @@ from typing import Optional, Tuple
from .base import BaseSampler, NegativeSampling, SampleOutput, SampleType from .base import BaseSampler, NegativeSampling, SampleOutput, SampleType
# from sample_cores import ParallelSampler, get_neighbors, heads_unique # from sample_cores import ParallelSampler, get_neighbors, heads_unique
from starrygl.lib.libstarrygl_sampler import ParallelSampler, get_neighbors
from torch.distributed.rpc import rpc_async
# def outer_sample(graph_name, nodes, ts, fanout_index, with_outer_sample = SampleType.Outer):# 默认此时继续向外采样 from torch.distributed.rpc import rpc_async
# local_sampler = get_local_sampler(graph_name)
# assert local_sampler is not None, 'Local_sampler is None!!!'
# out = local_sampler.sample_from_nodes(nodes, with_outer_sample, ts, fanout_index)
# return out
class NeighborSampler(BaseSampler): class NeighborSampler(BaseSampler):
r'''
Parallel sampling is crucial for scaling model training to large amounts of data. Due to the scale and complexity of graph data, serial sampling wastes significant computing and storage resources, whereas parallel sampling improves efficiency and overall computational speed by sampling from multiple nodes and neighbors simultaneously.
This helps accelerate the training and inference of the model, making it more scalable and practical on large-scale graph data.
Our parallel sampling adopts a hybrid CPU/GPU approach: the entire graph structure is stored on the CPU, the graph structure is sampled on the CPU, and the result is then uploaded to the GPU. Each trainer has its own sampler for parallel training.
We have encapsulated the parallel sampling functions, and you can easily use them as follows:
.. code-block:: python
# First,you need to import Python packages
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
# Then,you can use ours parallel sampler
sampler = NeighborSampler(num_nodes=num_nodes, num_layers=num_layers, fanout=fanout, graph_data=graph_data,
workers=workers, is_distinct = is_distinct, policy = policy, edge_weight= edge_weight, graph_name = graph_name)
Args:
num_nodes (int): the num of all nodes in the graph
num_layers (int): the num of layers to be sampled
fanout (list): the list of max neighbors' number chosen for each layer
graph_data (:class: starrygl.sample.sample_core.neighbor_sampler): the graph data you want to sample
workers (int): the number of threads, default value is 1
is_distinct (bool): True to treat multi-edges as distinct, False otherwise
policy (str): "uniform" or "recent" or "weighted"
edge_weight (torch.Tensor,Optional): the initial weights of edges
graph_name (str): the name of the graph; the graph data should provide edge_index or (neighbors, deg)
Examples:
.. code-block:: python
from starrygl.sample.part_utils.partition_tgnn import partition_load
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
pdata = partition_load("PATH/{}".format(dataname), algo="metis_for_tgnn")
graph = DistributedGraphStore(pdata = pdata,uvm_edge = False,uvm_node = False)
sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph,mode = 'full')
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10],
graph_data=sample_graph, workers=15, policy = 'recent', graph_name = "wiki_train")
If you want to directly call parallel sampling functions, use the following methods:
.. code-block:: python
# the parameter meaning is the same as the `Args` above
from starrygl.lib.libstarrygl_sampler import ParallelSampler, get_neighbors
# get the neighbor information table; row and col come from graph_data.edge_index = (row, col)
tnb = get_neighbors(graph_name, row.contiguous(), col.contiguous(), num_nodes, is_distinct, graph_data.eid, edge_weight, timestamp)
# call parallel sampler
p_sampler = ParallelSampler(self.tnb, num_nodes, graph_data.num_edges, workers, fanout, num_layers, policy)
For complete usage and more details, please refer to `~starrygl.sample.sample_core.neighbor_sampler`
'''
def __init__( def __init__(
self, self,
num_nodes: int, num_nodes: int,
...@@ -68,11 +124,11 @@ class NeighborSampler(BaseSampler): ...@@ -68,11 +124,11 @@ class NeighborSampler(BaseSampler):
row, col = graph_data.edge_index row, col = graph_data.edge_index
if(edge_weight is not None): if(edge_weight is not None):
edge_weight = edge_weight.float().contiguous() edge_weight = edge_weight.float().contiguous()
self.tnb = get_neighbors(graph_name, row.contiguous(), col.contiguous(), num_nodes, is_distinct, eid, edge_weight, timestamp) self.tnb = starrygl.sampler_ops.get_neighbors(graph_name, row.contiguous(), col.contiguous(), num_nodes, is_distinct, eid, edge_weight, timestamp)
else: else:
assert tnb is not None assert tnb is not None
self.tnb = tnb self.tnb = tnb
self.p_sampler = ParallelSampler(self.tnb, num_nodes, graph_data.num_edges, workers, self.p_sampler = starrygl.sampler_ops.ParallelSampler(self.tnb, num_nodes, graph_data.num_edges, workers,
fanout, num_layers, policy) fanout, num_layers, policy)
def _get_sample_info(self): def _get_sample_info(self):
......
...@@ -5,10 +5,16 @@ from os.path import abspath, join, dirname ...@@ -5,10 +5,16 @@ from os.path import abspath, join, dirname
from starrygl.distributed.context import DistributedContext from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex from starrygl.distributed.utils import DistIndex
from starrygl.module.modules import GeneralModel from starrygl.module.modules import GeneralModel
from pathlib import Path
from starrygl.module.utils import parse_config from starrygl.module.utils import parse_config
from starrygl.sample.cache.fetch_cache import FetchFeatureCache from starrygl.sample.cache.fetch_cache import FetchFeatureCache
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.module.utils import EarlyStopMonitor
from starrygl.sample.memory.shared_mailbox import SharedMailBox from starrygl.sample.memory.shared_mailbox import SharedMailBox
from starrygl.sample.sample_core.base import NegativeSampling from starrygl.sample.sample_core.base import NegativeSampling
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
...@@ -34,26 +40,30 @@ parser = argparse.ArgumentParser( ...@@ -34,26 +40,30 @@ parser = argparse.ArgumentParser(
) )
parser.add_argument('--rank', default=0, type=int, metavar='W', parser.add_argument('--rank', default=0, type=int, metavar='W',
help='name of dataset') help='name of dataset')
parser.add_argument('--patience', type=int, default=5, help='Patience for early stopping')
parser.add_argument('--world_size', default=1, type=int, metavar='W', parser.add_argument('--world_size', default=1, type=int, metavar='W',
help='number of negative samples') help='world size (number of processes)')
parser.add_argument('--dataname', default=1, type=str, metavar='W', parser.add_argument('--dataname', default=1, type=str, metavar='W',
help='number of negative samples') help='name of dataset')
parser.add_argument('--model', default='TGN', type=str, metavar='W',
help='name of model')
args = parser.parse_args() args = parser.parse_args()
from sklearn.metrics import average_precision_score, roc_auc_score from sklearn.metrics import average_precision_score, roc_auc_score
import torch import torch
import time import time
import random import random
import dgl
import numpy as np import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
#os.environ['CUDA_VISIBLE_DEVICES'] = str(args.rank)
#os.environ["RANK"] = str(args.rank) os.environ['CUDA_VISIBLE_DEVICES'] = str(args.rank)
#os.environ["WORLD_SIZE"] = str(args.world_size) os.environ["RANK"] = str(args.rank)
#os.environ["LOCAL_RANK"] = str(0) os.environ["WORLD_SIZE"] = str(args.world_size)
os.environ["LOCAL_RANK"] = str(0)
torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
os.environ["MASTER_ADDR"] = '10.214.211.187' os.environ["MASTER_ADDR"] = '10.214.211.186'
os.environ["MASTER_PORT"] = '9337' os.environ["MASTER_PORT"] = '9667'
def seed_everything(seed=42): def seed_everything(seed=42):
random.seed(seed) random.seed(seed)
np.random.seed(seed) np.random.seed(seed)
...@@ -66,14 +76,21 @@ seed_everything(1234) ...@@ -66,14 +76,21 @@ seed_everything(1234)
def main(): def main():
print('main') print('main')
use_cuda = True use_cuda = True
sample_param, memory_param, gnn_param, train_param = parse_config('./config/TGN.yml') sample_param, memory_param, gnn_param, train_param = parse_config('./config/{}.yml'.format(args.model))
torch.set_num_threads(12) torch.set_num_threads(12)
ctx = DistributedContext.init(backend="nccl", use_gpu=True) ctx = DistributedContext.init(backend="nccl", use_gpu=True)
device_id = torch.cuda.current_device() device_id = torch.cuda.current_device()
print('use cuda on',device_id) print('use cuda on',device_id)
pdata = partition_load("/mnt/data/part_data/dataset/here/{}".format(args.dataname), algo="metis_for_tgnn") pdata = partition_load("/mnt/data/part_data/dataset/here/{}".format(args.dataname), algo="metis_for_tgnn")
graph = DistributedGraphStore(pdata = pdata,uvm_edge = False,uvm_node = False) graph = DistributedGraphStore(pdata = pdata)
Path("./saved_models/").mkdir(parents=True, exist_ok=True)
Path("./saved_checkpoints/").mkdir(parents=True, exist_ok=True)
get_checkpoint_path = lambda \
epoch: f'./saved_checkpoints/{args.model}-{args.dataname}-{epoch}.pth'
gnn_param['dyrep'] = True if args.model == 'DyRep' else False
use_src_emb = gnn_param['use_src_emb'] if 'use_src_emb' in gnn_param else False
use_dst_emb = gnn_param['use_dst_emb'] if 'use_dst_emb' in gnn_param else False
sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph,mode = 'full') sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph,mode = 'full')
mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat = pdata.edge_attr.shape[1] if pdata.edge_attr is not None else 0) mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat = pdata.edge_attr.shape[1] if pdata.edge_attr is not None else 0)
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10],graph_data=sample_graph, workers=15,policy = 'recent',graph_name = "wiki_train") sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=1, fanout=[10],graph_data=sample_graph, workers=15,policy = 'recent',graph_name = "wiki_train")
...@@ -83,7 +100,7 @@ def main(): ...@@ -83,7 +100,7 @@ def main():
val_ts = torch.masked_select(graph.edge_ts,pdata.val_mask.to(graph.edge_index.device)) val_ts = torch.masked_select(graph.edge_ts,pdata.val_mask.to(graph.edge_index.device))
test_data = torch.masked_select(graph.edge_index,pdata.test_mask.to(graph.edge_index.device)).reshape(2,-1) test_data = torch.masked_select(graph.edge_index,pdata.test_mask.to(graph.edge_index.device)).reshape(2,-1)
test_ts = torch.masked_select(graph.edge_ts,pdata.test_mask.to(graph.edge_index.device)) test_ts = torch.masked_select(graph.edge_ts,pdata.test_mask.to(graph.edge_index.device))
print(train_data.shape[1],val_data.shape[1],test_data.shape[1]) ##print(train_data.shape[1],val_data.shape[1],test_data.shape[1])
train_data = DataSet(edges = train_data,ts =train_ts,eids = torch.nonzero(pdata.train_mask).view(-1)) train_data = DataSet(edges = train_data,ts =train_ts,eids = torch.nonzero(pdata.train_mask).view(-1))
#if dist.get_rank() == 0: #if dist.get_rank() == 0:
test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.nonzero(pdata.test_mask).view(-1)) test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.nonzero(pdata.test_mask).view(-1))
...@@ -100,7 +117,7 @@ def main(): ...@@ -100,7 +117,7 @@ def main():
trainloader = DistributedDataLoader(graph,train_data,sampler = sampler, trainloader = DistributedDataLoader(graph,train_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES, sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler, neg_sampler=neg_sampler,
batch_size = 1000, batch_size = train_param['batch_size'],
shuffle=False, shuffle=False,
drop_last=True, drop_last=True,
chunk_size = None, chunk_size = None,
...@@ -111,7 +128,7 @@ def main(): ...@@ -111,7 +128,7 @@ def main():
testloader = DistributedDataLoader(graph,test_data,sampler = sampler, testloader = DistributedDataLoader(graph,test_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES, sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler, neg_sampler=neg_sampler,
batch_size = 1000, batch_size = train_param['batch_size'],
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
chunk_size = None, chunk_size = None,
...@@ -121,7 +138,7 @@ def main(): ...@@ -121,7 +138,7 @@ def main():
valloader = DistributedDataLoader(graph,val_data,sampler = sampler, valloader = DistributedDataLoader(graph,val_data,sampler = sampler,
sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES, sampler_fn = SAMPLE_TYPE.SAMPLE_FROM_TEMPORAL_EDGES,
neg_sampler=neg_sampler, neg_sampler=neg_sampler,
batch_size = 1000, batch_size = train_param['batch_size'],
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
chunk_size = None, chunk_size = None,
...@@ -133,7 +150,7 @@ def main(): ...@@ -133,7 +150,7 @@ def main():
#cache.init_cache_with_presample(trainloader,3) #cache.init_cache_with_presample(trainloader,3)
gnn_dim_node = 0 if graph.x is None else pdata.x.shape[1] gnn_dim_node = 0 if graph.x is None else pdata.x.shape[1]
gnn_dim_edge = 0 if graph.edge_attr is None else pdata.edge_attr.shape[1] gnn_dim_edge = 0 if graph.edge_attr is None else pdata.edge_attr.shape[1]
print(gnn_dim_node,gnn_dim_edge) print("gnn_dim_node:", gnn_dim_node, "gnn_dim_edge:", gnn_dim_edge)
avg_time = 0 avg_time = 0
if use_cuda: if use_cuda:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param).cuda() model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param).cuda()
...@@ -141,7 +158,7 @@ def main(): ...@@ -141,7 +158,7 @@ def main():
else: else:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param) model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param)
device = torch.device('cpu') device = torch.device('cpu')
model = DDP(model,find_unused_parameters=True) model = DDP(model,find_unused_parameters=False)
train_stream = torch.cuda.Stream() train_stream = torch.cuda.Stream()
send_stream = torch.cuda.Stream() send_stream = torch.cuda.Stream()
scatter_stream = torch.cuda.Stream() scatter_stream = torch.cuda.Stream()
...@@ -162,7 +179,7 @@ def main(): ...@@ -162,7 +179,7 @@ def main():
total_loss = 0 total_loss = 0
signal = torch.tensor([0],dtype = int,device = device) signal = torch.tensor([0],dtype = int,device = device)
with torch.cuda.stream(train_stream): with torch.cuda.stream(train_stream):
for roots,mfgs,metadata in loader: for roots,mfgs,metadata,sample_time in loader:
pred_pos, pred_neg = model(mfgs,metadata) pred_pos, pred_neg = model(mfgs,metadata)
total_loss += creterion(pred_pos, torch.ones_like(pred_pos)) total_loss += creterion(pred_pos, torch.ones_like(pred_pos))
...@@ -192,9 +209,11 @@ def main(): ...@@ -192,9 +209,11 @@ def main():
last_updated_ts) last_updated_ts)
# #
index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper, index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
src,dst,ts,edge_feats, src,dst,ts,edge_feats,
model.module.memory_updater.last_updated_memory, model.module.memory_updater.last_updated_memory,
) model.module.embedding,use_src_emb,
use_dst_emb,
)
mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max') mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
...@@ -208,14 +227,20 @@ def main(): ...@@ -208,14 +227,20 @@ def main():
auc_mrr = torch.empty([loader.expected_idx*world_size],dtype = torch.float,device = 'cuda') auc_mrr = torch.empty([loader.expected_idx*world_size],dtype = torch.float,device = 'cuda')
dist.all_gather_into_tensor(apc,torch.tensor(aps,device ='cuda',dtype=torch.float)) dist.all_gather_into_tensor(apc,torch.tensor(aps,device ='cuda',dtype=torch.float))
dist.all_gather_into_tensor(auc_mrr,torch.tensor(aucs_mrrs,device ='cuda',dtype=torch.float)) dist.all_gather_into_tensor(auc_mrr,torch.tensor(aucs_mrrs,device ='cuda',dtype=torch.float))
ap = float(torch.tensor(apc).mean()) # ap = float(torch.tensor(apc).mean())
auc_mrr = float(torch.tensor(auc_mrr).mean()) # auc_mrr = float(torch.tensor(auc_mrr).mean())
ap = float(apc.clone().mean())
auc_mrr = float(auc_mrr.clone().mean())
return ap, auc_mrr return ap, auc_mrr
     creterion = torch.nn.BCEWithLogitsLoss()
     optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'])
+    early_stopper = EarlyStopMonitor(max_round=args.patience)
+    MODEL_SAVE_PATH = f'./saved_models/{args.model}-{args.dataname}.pth'
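EarlyStopMonitor and args.patience come from outside this hunk. A plausible sketch of such a monitor, inferred only from how it is called below (early_stop_check(ap) and best_epoch); the repo's real class may differ:

class EarlyStopMonitor:
    def __init__(self, max_round=3, tolerance=1e-10):
        self.max_round = max_round      # rounds without improvement allowed
        self.tolerance = tolerance
        self.num_round = 0
        self.epoch_count = 0
        self.best_epoch = 0
        self.last_best = None

    def early_stop_check(self, curr_val):
        # Returns True once the metric has failed to improve for
        # max_round consecutive epochs.
        if self.last_best is None or curr_val > self.last_best + self.tolerance:
            self.last_best = curr_val
            self.best_epoch = self.epoch_count
            self.num_round = 0
        else:
            self.num_round += 1
        self.epoch_count += 1
        return self.num_round >= self.max_round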
     for e in range(train_param['epoch']):
         torch.cuda.synchronize()
+        write_back_time = 0
+        fetch_time = 0
         epoch_start_time = time.time()
         train_aps = list()
         print('Epoch {:d}:'.format(e))
@@ -227,7 +252,8 @@ def main():
             model.module.memory_updater.last_updated_nid = None
             model.module.memory_updater.last_updated_memory = None
             model.module.memory_updater.last_updated_ts = None
-        for roots,mfgs,metadata in trainloader:
+        for roots,mfgs,metadata,sample_time in trainloader:
+            fetch_time +=sample_time/1000
             t_prep_s = time.time()
             with torch.cuda.stream(train_stream):
@@ -242,9 +268,9 @@ def main():
                 optimizer.step()
                 #torch.cuda.synchronize()
                 t_prep_s = time.time()
-                y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
-                y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
-                train_aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
+                # y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
+                # y_true = torch.cat([torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
+                # train_aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
                 #start_event = torch.cuda.Event(enable_timing=True)
                 #end_event = torch.cuda.Event(enable_timing=True)
                 #start_event.record()
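The per-batch training AP is commented out here because it syncs logits to the CPU and calls sklearn on every batch, which stalls the training stream. When the diagnostic is wanted, it looks like this (reconstructed from the removed lines, with stand-in tensors to keep it self-contained):

import torch
from sklearn.metrics import average_precision_score

pred_pos = torch.randn(100, 1)   # stand-ins for the model's logits
pred_neg = torch.randn(100, 1)
train_aps = []

# Sigmoid the logits, label positives 1 / negatives 0, score on CPU.
# The .cpu() forces a host-device sync per batch, which is why the
# diff removes this from the hot path.
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat([torch.ones(pred_pos.size(0)),
                    torch.zeros(pred_neg.size(0))], dim=0)
train_aps.append(average_precision_score(y_true, y_pred.detach().numpy()))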
@@ -268,14 +294,17 @@ def main():
                     last_updated_memory,
                     last_updated_ts)
                 index, mail, mail_ts = mailbox.get_update_mail(dist_index_mapper,
                     src,dst,ts,edge_feats,
                     model.module.memory_updater.last_updated_memory,
-                    )
+                    model.module.embedding,use_src_emb,use_dst_emb,
+                    )
+            start_event = torch.cuda.Event(enable_timing=True)
+            end_event = torch.cuda.Event(enable_timing=True)
+            start_event.record()
             mailbox.set_mailbox_all_to_all(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max')
-            #end_event.record()
-            #torch.cuda.synchronize()
-            #write_back_time += start_event.elapsed_time(end_event)/1000
+            end_event.record()
+            torch.cuda.synchronize()
+            write_back_time += start_event.elapsed_time(end_event)/1000
         torch.cuda.synchronize()
         time_prep = time.time() - epoch_start_time
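The previously commented-out CUDA-event timing is enabled so write_back_time reflects actual GPU time for the mailbox all-to-all, which time.time() alone would undercount because the kernels launch asynchronously. Event.elapsed_time() returns milliseconds, hence the /1000. A self-contained sketch:

import torch

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()
# Some asynchronous GPU work to be timed.
_ = torch.randn(2048, 2048, device='cuda') @ torch.randn(2048, 2048, device='cuda')
end_event.record()

# elapsed_time is only valid after both events have completed.
torch.cuda.synchronize()
print('GPU time: {:.3f}s'.format(start_event.elapsed_time(end_event) / 1000))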
@@ -288,9 +317,19 @@ def main():
         #if cache.node_cache is not None:
         #    print('hit {}'.format(cache.node_cache.hit_/ cache.node_cache.hit_sum))
         ap, auc = eval('val')
-        print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f}'.format(total_loss,train_ap, ap, auc))
-        print('\ttotal time:{:.2f}s prep time:{:.2f}s'.format(time.time()-epoch_start_time, time_prep))
-        #print('\t fetch time:{:.2f}s write back time:{:.2f}s'.format(fetch_time,write_back_time))
+        early_stop = early_stopper.early_stop_check(ap)
+        if early_stop:
+            print("Early stopping at epoch {:d}".format(e))
+            print(f"Loading the best model at epoch {early_stopper.best_epoch}")
+            best_model_path = get_checkpoint_path(early_stopper.best_epoch)
+            model.load_state_dict(torch.load(best_model_path))
+            break
+        else:
+            print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f}'.format(total_loss,train_ap, ap, auc))
+            print('\ttotal time:{:.2f}s prep time:{:.2f}s'.format(time.time()-epoch_start_time, time_prep))
+            print('\t fetch time:{:.2f}s write back time:{:.2f}s'.format(fetch_time,write_back_time))
+            torch.save(model.state_dict(), get_checkpoint_path(e))
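get_checkpoint_path() is called above but defined outside this hunk. A hypothetical sketch consistent with how it is used (the directory and filename pattern are assumptions, not from this diff):

import os

def get_checkpoint_path(epoch, root='./saved_checkpoints'):
    # Hypothetical helper: one checkpoint file per epoch, so the
    # early stopper can reload the best one by epoch index.
    os.makedirs(root, exist_ok=True)
    return os.path.join(root, 'checkpoint-epoch-{:d}.pth'.format(epoch))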
     model.eval()
     if mailbox is not None:
         mailbox.reset()
@@ -304,6 +343,7 @@ def main():
     else:
         print('\ttest AP:{:4f} test AUC:{:4f}'.format(ap, auc))
     print('test_dataset',test_data.edges.shape[1],'avg_time',avg_time/train_param['epoch'])
+    torch.save(model.state_dict(), MODEL_SAVE_PATH)
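One caveat on this final save: model is still DDP-wrapped, so the saved keys carry a module. prefix. Loading the file into an unwrapped GeneralModel later requires stripping that prefix (or saving model.module.state_dict() instead). A minimal sketch with an illustrative path:

import torch

# Keys saved from a DDP-wrapped model look like 'module.layer.weight'.
state = torch.load('./saved_models/model-dataset.pth', map_location='cpu')
clean_state = {(k[len('module.'):] if k.startswith('module.') else k): v
               for k, v in state.items()}
# plain_model.load_state_dict(clean_state)  # plain_model: unwrapped GeneralModel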
     ctx.shutdown()
 if __name__ == "__main__":
     main()