zhlj / BTS-MTGNN / Commits / 99e5b95a

Commit 99e5b95a, authored Nov 04, 2024 by xxx
Parent: b305c21a

    fix bugs and add APAN model

Showing 17 changed files, with 727 additions and 500 deletions (+727 / -500).
config/APAN.yml                                    +0    -3
config/JODIE.yml                                   +2    -1
config/TGN.yml                                     +0    -3
examples-probability-sample/average.sh             +0    -2
examples-probability-sample/boundary_static.py     +114  -0
examples-probability-sample/draw_boundary.py       +3    -3
examples-probability-sample/test_all.sh            +19   -18
examples-probability-sample/train_boundery.py      +16   -7
examples/test_all.sh                               +6    -6
examples/tgbl_coin_train_4.out                     +22   -2
examples/tgbl_comment_train.out                    +0    -59
examples/train_boundery.py                         +17   -10
starrygl/module/historical_cache.py                +1    -1
starrygl/module/memorys.py                         +53   -21
starrygl/sample/memory/_shared_mailbox.py          +223  -185
starrygl/sample/memory/shared_mailbox.py           +245  -177
starrygl/sample/sample_core/neighbor_sampler.py    +6    -2
config/APAN.yml

@@ -3,10 +3,7 @@ sampling:
  neighbor:
    - 10
  strategy: 'recent'
  prop_time: False
  history: 1
  duration: 0
  num_thread: 32
  no_neg: True
memory:
  - type: 'node'
config/JODIE.yml

sampling:
  - no_sample: True
    strategy: 'identity'
    history: 1
memory:
  - type: 'node'
    dim_time: 100
config/TGN.yml

@@ -3,10 +3,7 @@ sampling:
  neighbor:
    - 10
  strategy: 'recent'
  prop_time: False
  history: 1
  duration: 0
  num_thread: 32
memory:
  - type: 'node'
    dim_time: 100
examples-probability-sample/average.sh

bash test_all.sh 13357 > 13357.out
wait
bash test_all.sh 12347 > 12347.out
wait
bash test_all.sh 63377 > 63377.out
examples-probability-sample/boundary_static.py  (new file, 0 → 100644)

import matplotlib.pyplot as plt
import numpy as np
import torch
# Read the file contents
ssim_values = [0, 0.1, 0.2, 0.3, 0.4, 2]  # assumed ssim parameter values
probability_values = [1, 0.1, 0.05, 0.01, 0]
data_values = ['WIKI', 'LASTFM', 'WikiTalk', 'DGraphFin']
# Data read from the files
seed = ['13357', '12347', '63377', '53473', ' 54763']
partition = 'ours'
# Read the data from files; the data is assumed to live in data.txt
#all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample".out
partitions = 4
topk = 0
mem = 'all_update'  #'historical'
model = 'TGN'
for sd in seed:
    for data in data_values:
        ap_list = []
        comm_list = []
        for p in probability_values:
            if data == 'WIKI' or data == 'LASTFM':
                model = 'TGN'
            else:
                model = 'TGN_large'
            if p == 1:
                file = 'all_{}/{}/{}/{}-{}-{}-{}-recent.out'.format(sd, data, model, partitions, partition, topk, mem)
            else:
                file = 'all_{}/{}/{}/{}-{}-{}-{}-boundery_recent_decay-{}.out'.format(sd, data, model, partitions, partition, topk, mem, p)
            prefix = "val ap:"
            max_val_ap = 0
            test_ap = 0
            with open(file, 'r') as file:
                for line in file:
                    if line.find(prefix) != -1:
                        pos = line.find(prefix) + len(prefix)
                        posr = line.find(' ', pos)
                        #print(line[pos:posr])
                        val_ap = float(line[pos:posr])
                        pos = line.find("test ap ") + len("test ap ")
                        posr = line.find(' ', pos)
                        #print(line[pos:posr])
                        _test_ap = float(line[pos:posr])
                        if (val_ap > max_val_ap):
                            max_val_ap = val_ap
                            test_ap = _test_ap
            ap_list.append(test_ap)
        print('data {} seed {} ap: {}'.format(data, sd, ap_list))
# prefix = 'best test AP:'
# cnt = 0
# sum = 0
# with open(file, 'r') as file:
# for line in file:
# if line.startswith(prefix):
# ap = float(line.lstrip(prefix).split(' ')[0])
# pos = line.find('remote node number tensor')
# if(pos!=-1):
# posr = line.find(']',pos+2+len('remote node number tensor'),)
# #print(line,line[pos+2+len('remote node number tensor'):posr])
# comm = int(line[pos+2+len('remote node number tensor'):posr])
# #print()
# sum = sum+comm
# cnt = cnt+1
# #print(comm)
# ap_list.append(ap)
# comm_list.append(sum/cnt*4)
# # Plot the bar chart
# print('{} TestAP={}\n'.format(data,ap_list))
# bar_width = 0.4
# #shared comm tensor
# # Set the bar positions
# bars = range(len(probability_values))
# # Plot the bar chart
# plt.bar([b for b in bars], ap_list, width=bar_width)
# # Plot the bar chart
# plt.ylim([0.9,1])
# plt.xticks([b for b in bars], probability_values)
# plt.xlabel('probability')
# plt.ylabel('Test AP')
# plt.title('{}({} partitions)'.format(data,partitions))
# plt.savefig('boundary_AP_{}_{}_{}.png'.format(data,partitions,model))
# plt.clf()
# print(comm_list)
# plt.bar([b for b in bars], comm_list, width=bar_width)
# # Plot the bar chart
# plt.xticks([b for b in bars], probability_values)
# plt.xlabel('probability')
# plt.ylabel('Communication volume')
# plt.title('{}({} partitions)'.format(data,partitions))
# plt.savefig('boundary_comm_{}_{}_{}.png'.format(data,partitions,model))
# plt.clf()
# if partition == 'ours_shared':
# partition0 = 'ours'
# else:
# partition0=partition
# for p in probability_values:
# file = '{}/{}/test_{}_{}_{}_0_boundery_recent_uniform_{}_all_update_2.pt'.format(data,model,partition0,topk,partitions,float(p))
# val_ap = torch.tensor(torch.load(file))[:,0]
# epoch = torch.arange(val_ap.shape[0])
# # Plot the line chart
# plt.plot(epoch,val_ap, label='probability={}'.format(p))
# plt.xlabel('Epoch')
# plt.ylabel('Val AP')
# plt.title('{}({} partitions)'.format(data,partitions))
# # plt.grid(True)
# plt.legend()
# plt.savefig('{}_{}_{}_boundary_Convergence_rate.png'.format(data,partitions,model))
# plt.clf()
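For reference, a small self-contained example of the parsing logic above applied to a single log line; the line itself is made up here purely to illustrate the expected format.

# Hypothetical log line in the format that boundary_static.py expects to find in the .out files.
line = "train loss:11833.18 train ap:0.987815 val ap:0.960581 val auc:0.965728 test ap 0.921345 \n"
prefix = "val ap:"
pos = line.find(prefix) + len(prefix)
posr = line.find(' ', pos)
val_ap = float(line[pos:posr])             # 0.960581
pos = line.find("test ap ") + len("test ap ")
posr = line.find(' ', pos)
test_ap = float(line[pos:posr])            # 0.921345
print(val_ap, test_ap)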
examples-probability-sample/draw_boundary.py

@@ -4,15 +4,15 @@ import torch
# Read the file contents
ssim_values = [0, 0.1, 0.2, 0.3, 0.4, 2]  # assumed ssim parameter values
probability_values = [1, 0.1, 0.05, 0.01, 0]
data_values = ['WIKI', 'LASTFM', 'WikiTalk', 'DGraphFin']
# Data read from the files
seed = ['13357', '12347', '63377', '53473', ' 54763']
data_values = ['WIKI', 'WikiTalk']
# Data read from the files
seed = ['13357', '12347', '63377', '53473', '54763']
partition = 'ours_shared'
# Read the data from files; the data is assumed to live in data.txt
#all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample".out
partitions = 4
topk = 0.01
mem = 'all_update'  #'historical'
mem = 'local'  #'historical'
model = 'TGN'
for sd in seed:
    for data in data_values:
examples-probability-sample/test_all.sh

@@ -6,19 +6,20 @@ addr="192.168.1.107"
partition_params=("ours") #"metis" "ldg" "random")
#("ours" "metis" "ldg" "random")
partitions="4"
partitions="8"
node_per="4"
nnodes="1"
node_rank="0"
probability_params=("0.1" "0" "0.05" "0.01")
sample_type_params=("boundery_recent_decay" "recent")
nnodes="2"
node_rank="1"
probability_params=("0.1" "0.05")
sample_type_params=("boundery_recent_decay")
#sample_type_params=("recent" "boundery_recent_decay") #"boundery_recent_uniform")
#memory_type=("all_update" "p2p" "all_reduce" "historical" "local")
memory_type=("all_update")
memory_type=("all_update" "historical" "local")
#memory_type=("local" "all_update" "historical" "all_reduce")
shared_memory_ssim=("0.3" "0.7")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk")
data_param=("WIKI" "LASTFM" "WikiTalk" "DGraphFin")
data_param=("StackOverflow")
#data_param=("WIKI" "REDDIT" "LASTFM" "DGraphFin" "WikiTalk" "StackOverflow")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk" "StackOverflow")
#data_param=("REDDIT" "WikiTalk")

@@ -30,9 +31,9 @@ data_param=("WIKI" "LASTFM" "WikiTalk" "DGraphFin")
#seed=(( RANDOM % 1000000 + 1 ))
mkdir -p all_"$seed"
for data in "${data_param[@]}"; do
    model="TGN_large"
    model="JODIE"
    if [ "$data" = "WIKI" ] || [ "$data" = "REDDIT" ] || [ "$data" = "LASTFM" ]; then
        model="TGN"
        model="JODIE"
    fi
    #model="APAN"
    mkdir all_"$seed"/"$data"

@@ -57,8 +58,8 @@ for data in "${data_param[@]}"; do
                wait
            fi
        else
            torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-"$partition"-0-"$mem"-"$sample".out &
            wait
            # torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-"$partition"-0-"$mem"-"$sample".out &
            # wait
            if [ "$partition" = "ours" ] && [ "$mem" != "all_local" ]; then
                torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample".out &
                wait

@@ -80,13 +81,13 @@ for data in "${data_param[@]}"; do
                torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample"-"$pro".out&
                wait
            fi
        else
            torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-"$partition"-0-"$mem"-"$sample"-"$pro".out &
            wait
            if [ "$partition" = "ours" ] && [ "$mem" != "all_local" ]; then
                torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample"-"$pro".out &
                wait
            fi
        # else
        #     torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-"$partition"-0-"$mem"-"$sample"-"$pro".out &
        #     wait
        #     if [ "$partition" = "ours" ] && [ "$mem" != "all_local" ]; then
        #         torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample"-"$pro".out &
        #         wait
        #     fi
        fi
    done
done
examples-probability-sample/train_boundery.py

@@ -126,7 +126,7 @@ def seed_everything(seed=42):
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
seed_everything(args.seed)
total_next_batch = 0
total_forward = 0
total_count_score = 0

@@ -267,9 +267,15 @@ def main():
    if args.local_neg_sample:
        print('dst len {} origin len {}'.format(graph.edge_index[1, mask].unique().shape[0], full_dst.unique().shape[0]))
        train_neg_sampler = LocalNegativeSampling('triplet', amount=args.neg_samples, dst_node_list=graph.edge_index[1, mask].unique())
    else:
        #train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique())
        train_neg_sampler = LocalNegativeSampling('triplet', amount=args.neg_samples, dst_node_list=full_dst.unique(), local_mask=(DistIndex(graph.nids_mapper[full_dst.unique()].to('cpu')).part == dist.get_rank()), prob=args.probability)
        remote_ratio = train_neg_sampler.local_dst.shape[0] / train_neg_sampler.dst_node_list.shape[0]
        #train_ratio_pos = (1 - args.probability) + args.probability * remote_ratio
        #train_ratio_neg = args.probability * (1-remote_ratio)
        train_ratio_pos = 1.0/(1 - args.probability + args.probability * remote_ratio) if ((args.probability < 1) & (args.probability > 0)) else 1
        train_ratio_neg = 1.0/(args.probability * remote_ratio) if ((args.probability < 1) & (args.probability > 0)) else 1
    print(train_neg_sampler.dst_node_list)
    neg_sampler = LocalNegativeSampling('triplet', amount=neg_samples, dst_node_list=full_dst.unique(), seed=args.seed)

@@ -338,10 +344,10 @@ def main():
    print('dim_node {} dim_edge {}\n'.format(gnn_dim_node, gnn_dim_edge))
    avg_time = 0
    if use_cuda:
        model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param, graph.ids.shape[0], mailbox).cuda()
        model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param, graph.ids.shape[0], mailbox, train_ratio=(train_ratio_pos, train_ratio_neg)).cuda()
        device = torch.device('cuda')
    else:
        model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param, graph.ids.shape[0], mailbox)
        model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param, graph.ids.shape[0], mailbox, train_ratio=(train_ratio_pos, train_ratio_neg))
        device = torch.device('cpu')
    model = DDP(model, find_unused_parameters=True)
    def count_parameters(model):

@@ -531,9 +537,12 @@ def main():
        model.train()
        optimizer.zero_grad()
        ones = torch.ones(metadata['dst_neg_index'].shape[0], device=model.device, dtype=torch.float)
        weight = torch.where(DistIndex(mfgs[0][0].srcdata['ID'][metadata['dst_neg_index']]).part == torch.distributed.get_rank(), ones*train_ratio_pos, ones*train_ratio_neg).reshape(-1, 1)
        pred_pos, pred_neg = model(mfgs, metadata, neg_samples=args.neg_samples, async_param=param)
        loss = creterion(pred_pos, torch.ones_like(pred_pos))
        loss += creterion(pred_neg, torch.zeros_like(pred_neg))
        neg_creterion = torch.nn.BCEWithLogitsLoss(weight)
        loss += neg_creterion(pred_neg, torch.zeros_like(pred_neg))
        total_loss += float(loss.item())
        #mailbox.handle_last_async()
        #trainloader.async_feature()

@@ -663,9 +672,9 @@ def main():
        pass
        # print('weight {} {}\n'.format(tt.weight_count_local,tt.weight_count_remote))
        # print('ssim {} {}\n'.format(tt.ssim_local/tt.ssim_cnt,tt.ssim_remote/tt.ssim_cnt))
    torch.save(val_list, 'all_args.seed/{}/{}/val_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname, args.model, args.partition, args.topk, dist.get_world_size(), dist.get_rank(), args.sample_type, args.probability, args.memory_type, args.shared_memory_ssim))
    torch.save(loss_list, 'all_args.seed/{}/{}/loss_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname, args.model, args.partition, args.topk, dist.get_world_size(), dist.get_rank(), args.sample_type, args.probability, args.memory_type, args.shared_memory_ssim))
    torch.save(test_ap_list, 'all_args.seed/{}/{}/test_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname, args.model, args.partition, args.topk, dist.get_world_size(), dist.get_rank(), args.sample_type, args.probability, args.memory_type, args.shared_memory_ssim))
    torch.save(val_list, 'all_{}/{}/{}/val_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.seed, args.dataname, args.model, args.partition, args.topk, dist.get_world_size(), dist.get_rank(), args.sample_type, args.probability, args.memory_type, args.shared_memory_ssim))
    torch.save(loss_list, 'all_{}/{}/{}/loss_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.seed, args.dataname, args.model, args.partition, args.topk, dist.get_world_size(), dist.get_rank(), args.sample_type, args.probability, args.memory_type, args.shared_memory_ssim))
    torch.save(test_ap_list, 'all_{}/{}/{}/test_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.seed, args.dataname, args.model, args.partition, args.topk, dist.get_world_size(), dist.get_rank(), args.sample_type, args.probability, args.memory_type, args.shared_memory_ssim))
    print(avg_time)
    if not early_stop:
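As a quick sanity check on the new ratio formulas above, a standalone numeric sketch; the probability and remote_ratio values are illustrative assumptions, not numbers from the repository.

# Illustrative values only (assumed): args.probability and the sampler's local/total destination ratio.
probability = 0.1
remote_ratio = 0.25   # stands in for train_neg_sampler.local_dst.shape[0] / train_neg_sampler.dst_node_list.shape[0]

if 0 < probability < 1:
    train_ratio_pos = 1.0 / (1 - probability + probability * remote_ratio)   # ~1.081
    train_ratio_neg = 1.0 / (probability * remote_ratio)                     # 40.0
else:
    train_ratio_pos = train_ratio_neg = 1

# These factors later scale the per-sample weights passed to BCEWithLogitsLoss for negative samples.
print(train_ratio_pos, train_ratio_neg)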
examples/test_all.sh

@@ -2,7 +2,7 @@
# Ran TaoBao on 4 GPUs
# Define the array variables
seed=$1
addr="192.168.1.107"
addr="192.168.1.106"
partition_params=("ours") #"metis" "ldg" "random")
#("ours" "metis" "ldg" "random")

@@ -10,13 +10,13 @@ partitions="8"
node_per="4"
nnodes="2"
node_rank="0"
probability_params=("0.1" "0.01" "0.05")
probability_params=("0.1")
sample_type_params=("boundery_recent_decay")
#sample_type_params=("recent" "boundery_recent_decay") #"boundery_recent_uniform")
#memory_type=("all_update" "p2p" "all_reduce" "historical" "local")
memory_type=("all_update")
memory_type=("historical")
#memory_type=("local" "all_update" "historical" "all_reduce")
shared_memory_ssim=("0.3" "0.7")
shared_memory_ssim=("0.3")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk")
<<<<<<< HEAD
data_param=("GDELT")

@@ -34,9 +34,9 @@ data_param=("LASTFM")
#seed=(( RANDOM % 1000000 + 1 ))
mkdir -p all_"$seed"
for data in "${data_param[@]}"; do
    model="TGN_large"
    model="APAN"
    if [ "$data" = "WIKI" ] || [ "$data" = "REDDIT" ] || [ "$data" = "LASTFM" ]; then
        model="TGN"
        model="APAN"
    fi
    #model="APAN"
    mkdir all_"$seed"/"$data"
examples/tgbl_coin_train_4.out
LOCAL RANK 2, RANK2
LOCAL RANK 1, RANK1
LOCAL RANK 3, RANK3
LOCAL RANK 0, RANK0
LOCAL RANK 2, RANK2
LOCAL RANK 3, RANK3
in
in
in
in
local rank is 1 world_size is 4 memory group is 0 memory rank is 1 memory group size is 4
local rank is 0 world_size is 4 memory group is 0 memory rank is 0 memory group size is 4
local rank is 3 world_size is 4 memory group is 0 memory rank is 3 memory group size is 4
[0, 1, 2, 3]
[0, 1, 2, 3][0, 1, 2, 3]
local rank is 2 world_size is 4 memory group is 0 memory rank is 2 memory group size is 4
[0, 1, 2, 3]
use cuda on 0
use cuda on 3
use cuda on 2
use cuda on 1
examples/tgbl_comment_train.out
LOCAL RANK 0, RANK0
use cuda on 0
994790
get_neighbors consume: 6.18508s
Epoch 0:
train loss:12236.7578 train ap:0.986976 val ap:0.934674 val auc:0.946284
total time:630.37s prep time:545.79s
fetch time:0.00s write back time:0.00s
Epoch 1:
train loss:11833.1818 train ap:0.987815 val ap:0.960581 val auc:0.965728
total time:628.44s prep time:542.56s
fetch time:0.00s write back time:0.00s
Epoch 2:
train loss:11622.9559 train ap:0.988244 val ap:0.956752 val auc:0.963083
total time:622.89s prep time:538.77s
fetch time:0.00s write back time:0.00s
Epoch 3:
train loss:11679.1400 train ap:0.988072 val ap:0.929351 val auc:0.943797
total time:681.88s prep time:569.50s
fetch time:0.00s write back time:0.00s
Epoch 4:
train loss:11676.1710 train ap:0.988098 val ap:0.936353 val auc:0.948531
total time:849.98s prep time:741.47s
fetch time:0.00s write back time:0.00s
Epoch 5:
train loss:11745.6001 train ap:0.987897 val ap:0.950828 val auc:0.958958
total time:862.77s prep time:750.90s
fetch time:0.00s write back time:0.00s
Epoch 6:
Early stopping at epoch 6
Loading the best model at epoch 1
0.9248434901237488 0.929413378238678
0.8653780221939087 0.861071765422821
test AP:0.847958 test AUC:0.837159
test_dataset 6647176 avg_time 87.00003329753876
examples/train_boundery.py

@@ -228,16 +228,19 @@ def main():
    num_layers = sample_param['layer'] if 'layer' in sample_param else 1
    fanout = sample_param['neighbor'] if 'neighbor' in sample_param else [10]
    policy = sample_param['strategy'] if 'strategy' in sample_param else 'recent'
    policy_train = args.sample_type #'boundery_recent_decay'
    no_neg = sample_param['no_neg'] if 'no_neg' in sample_param else False
    if policy != 'recent':
        policy_train = args.sample_type #'boundery_recent_decay'
    else:
        policy_train = policy
    if memory_param['type'] != 'none':
        mailbox = SharedMailBox(graph.ids.shape[0], memory_param, dim_edge_feat=graph.efeat.shape[1] if graph.efeat is not None else 0,
                                shared_nodes_index=graph.shared_nids_list[ctx.memory_group_rank],
                                device=torch.device('cuda:{}'.format(local_rank)),
                                cache_route=cache_route, shared_ssim=args.shared_memory_ssim)
                                shared_nodes_index=graph.shared_nids_list[ctx.memory_group_rank],
                                device=torch.device('cuda:{}'.format(local_rank)),
                                shared_ssim=args.shared_memory_ssim)
    else:
        mailbox = None
    sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=num_layers, fanout=fanout,
                              graph_data=sample_graph, workers=10, policy=policy_train, graph_name="train",
                              local_part=dist.get_rank(), edge_part=DistIndex(graph.eids_mapper).part,
                              node_part=DistIndex(graph.nids_mapper).part, probability=args.probability)
    eval_sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=num_layers, fanout=fanout,
                                   graph_data=eval_sample_graph, workers=10, policy=policy_train, graph_name="eval",
                                   local_part=dist.get_rank(), edge_part=DistIndex(graph.eids_mapper).part,
                                   node_part=DistIndex(graph.nids_mapper).part, probability=args.probability)
    sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=num_layers, fanout=fanout,
                              graph_data=sample_graph, workers=10, policy=policy_train, graph_name="train",
                              local_part=dist.get_rank(), edge_part=DistIndex(graph.eids_mapper).part,
                              node_part=DistIndex(graph.nids_mapper).part, probability=args.probability, no_neg=no_neg)
    eval_sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=num_layers, fanout=fanout,
                                   graph_data=eval_sample_graph, workers=10, policy=policy_train, graph_name="eval",
                                   local_part=dist.get_rank(), edge_part=DistIndex(graph.eids_mapper).part,
                                   node_part=DistIndex(graph.nids_mapper).part, probability=args.probability, no_neg=no_neg)
    train_data = torch.masked_select(graph.edge_index, train_mask.to(graph.edge_index.device)).reshape(2, -1)
    train_ts = torch.masked_select(graph.ts, train_mask.to(graph.edge_index.device))

@@ -272,8 +275,10 @@ def main():
        #train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique())
        train_neg_sampler = LocalNegativeSampling('triplet', amount=args.neg_samples, dst_node_list=full_dst.unique(), local_mask=(DistIndex(graph.nids_mapper[full_dst.unique()].to('cpu')).part == dist.get_rank()), prob=args.probability)
        remote_ratio = train_neg_sampler.local_dst.shape[0] / train_neg_sampler.dst_node_list.shape[0]
        train_ratio_pos = (1 - args.probability) + args.probability * remote_ratio
        train_ratio_neg = args.probability * (1 - remote_ratio)
        #train_ratio_pos = (1 - args.probability) + args.probability * remote_ratio
        #train_ratio_neg = args.probability * (1-remote_ratio)
        train_ratio_pos = 1.0/(1 - args.probability + args.probability * remote_ratio) if ((args.probability < 1) & (args.probability > 0)) else 1
        train_ratio_neg = 1.0/(args.probability * remote_ratio) if ((args.probability < 1) & (args.probability > 0)) else 1
    print(train_neg_sampler.dst_node_list)
    neg_sampler = LocalNegativeSampling('triplet', amount=neg_samples, dst_node_list=full_dst.unique(), seed=args.seed)

@@ -402,7 +407,8 @@ def main():
                aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
                aucs_mrrs.append(roc_auc_score(y_true, y_pred))
            mailbox.update_shared()
            mailbox.update_p2p()
            mailbox.update_p2p_mem()
            mailbox.update_p2p_mail()
            """
            if mailbox is not None:
                src = metadata['src_pos_index']

@@ -536,7 +542,7 @@ def main():
            optimizer.zero_grad()
            ones = torch.ones(metadata['dst_neg_index'].shape[0], device=model.device, dtype=torch.float)
            weight = torch.where(DistIndex(mfgs[0][0].srcdata['ID'][metadata['dst_neg_index']]).part == torch.distributed.get_rank(), ones/train_ratio_pos, ones/train_ratio_neg).reshape(-1, 1)
            weight = torch.where(DistIndex(mfgs[0][0].srcdata['ID'][metadata['dst_neg_index']]).part == torch.distributed.get_rank(), ones*train_ratio_pos, ones*train_ratio_neg).reshape(-1, 1)
            pred_pos, pred_neg = model(mfgs, metadata, neg_samples=args.neg_samples, async_param=param)
            loss = creterion(pred_pos, torch.ones_like(pred_pos))
            neg_creterion = torch.nn.BCEWithLogitsLoss(weight)

@@ -554,7 +560,8 @@ def main():
            #train_aps.append(average_precision_score(y_true, y_pred.detach().numpy()))
            #torch.cuda.synchronize()
            mailbox.update_shared()
            mailbox.update_p2p()
            mailbox.update_p2p_mem()
            mailbox.update_p2p_mail()
            #torch.cuda.empty_cache()
            """
starrygl/module/historical_cache.py

@@ -18,7 +18,7 @@ class CachePushRoute():
    def __init__(self, local_num, local_drop_edge, dist_index):
        ctx = DistributedContext.get_default_context()
        dist_drop = dist_index[local_drop_edge]
        id_mask = ~(DistIndex(dist_drop) == ctx.memory_group_rank)
        id_mask = not (DistIndex(dist_drop) == ctx.memory_group_rank)
        remote_id = local_drop_edge[id_mask]
        route = torch.zeros(local_num, dtype=torch.long)
        src_mask = DistIndex(dist_drop[0,:]).part == ctx.memory_group_rank
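For background on the two negation forms that appear in the hunk above, a tiny standalone example of how they behave on a boolean mask tensor; this is general PyTorch behaviour, not code from the repository.

import torch

eq = torch.tensor([True, False, True])   # e.g. the result of an elementwise == comparison
print(~eq)                                # tensor([False,  True, False]) -- elementwise complement
# `not eq` on a tensor with more than one element raises
# "RuntimeError: Boolean value of Tensor with more than one element is ambiguous".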
starrygl/module/memorys.py

@@ -352,31 +352,60 @@ class TransformerMemoryUpdater(torch.nn.Module):
        rst = self.dropout(rst)
        rst = torch.nn.functional.relu(rst)
        return rst
"""
(self,index,memory,memory_ts,
mail_index,mail,mail_ts,
reduce_Op=None,
async_op=True,
set_p2p=False,
mode=None,
filter=None,
wait_submit=True,
spread_mail=True,
update_cross_mm = False)
"""
class AsyncMemeoryUpdater(torch.nn.Module):
    def all_update_func(self, index, memory, memory_ts, mail, mail_ts, nxt_fetch_func):
        self.mailbox.set_memory_all_reduce(index, memory, memory_ts, mail, mail_ts, reduce_Op='max', async_op=False, filter=None, mode='all_reduce', set_remote=True)
        if nxt_fetch_func is not None:
            nxt_fetch_func()
    def p2p_func(self, index, memory, memory_ts, mail, mail_ts, nxt_fetch_func):
        self.mailbox.handle_last_async()
        submit_to_queue = False
        if nxt_fetch_func is not None:
            nxt_fetch_func()
            submit_to_queue = True
        self.mailbox.set_memory_all_reduce(index, memory, memory_ts, mail, mail_ts, reduce_Op='max', async_op=True, filter=None, mode=None, set_remote=True, submit=submit_to_queue)
    def all_reduce_func(self, index, memory, memory_ts, mail, mail_ts, nxt_fetch_func):
        self.mailbox.set_memory_all_reduce(index, memory, memory_ts, mail, mail_ts, reduce_Op='max', async_op=False, filter=None, mode='all_reduce', set_remote=False)
    def all_update_func(self, index, memory, memory_ts, mail_index, mail, mail_ts, nxt_fetch_func, spread_mail=False):
        self.mailbox.set_memory_all_reduce(index, memory, memory_ts, mail_index, mail, mail_ts,
                                           reduce_Op='max', async_op=False,
                                           mode='all_reduce',
                                           wait_submit=False, spread_mail=spread_mail,
                                           update_cross_mm=True,
                                           )
        #self.mailbox.set_memory_all_reduce(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max', async_op = False,filter=None,mode='all_reduce',set_remote=True)
        if nxt_fetch_func is not None:
            nxt_fetch_func()
    def historical_func(self, index, memory, memory_ts, mail, mail_ts, nxt_fetch_func):
    # def p2p_func(self,index,memory,memory_ts,mail,mail_ts,nxt_fetch_func,mail_index = None):
    #     self.mailbox.handle_last_async()
    #     submit_to_queue = False
    #     if nxt_fetch_func is not None:
    #         nxt_fetch_func()
    #         submit_to_queue = True
    #     self.mailbox.set_memory_all_reduce(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max', async_op = True,filter=None,mode=None,set_remote=True,submit = submit_to_queue)
    # def all_reduce_func(self,index,memory,memory_ts,mail,mail_ts,nxt_fetch_func,mail_index = None):
    #     self.mailbox.set_memory_all_reduce(index,memory,memory_ts,mail,mail_ts,reduce_Op = 'max', async_op = False,filter=None,mode='all_reduce',set_remote=False)
    #     if nxt_fetch_func is not None:
    #         nxt_fetch_func()
    def historical_func(self, index, memory, memory_ts, mail_index, mail, mail_ts, nxt_fetch_func, spread_mail=False):
        self.mailbox.sychronize_shared()
        self.mailbox.handle_last_async()
        submit_to_queue = False
        if nxt_fetch_func is not None:
            nxt_fetch_func()
            submit_to_queue = True
        self.mailbox.set_memory_all_reduce(index, memory, memory_ts, mail, mail_ts, reduce_Op='max', async_op=False, filter=None, mode='historical', set_remote=False, submit=submit_to_queue)
    def local_func(self, index, memory, memory_ts, mail, mail_ts, nxt_fetch_func):
        self.mailbox.set_memory_all_reduce(index, memory, memory_ts, mail_index, mail, mail_ts,
                                           reduce_Op='max', async_op=True,
                                           mode='historical',
                                           wait_submit=submit_to_queue, spread_mail=spread_mail,
                                           update_cross_mm=False,
                                           )
    def local_func(self, index, memory, memory_ts, mail_index, mail, mail_ts, nxt_fetch_func, spread_mail=False):
        if nxt_fetch_func is not None:
            nxt_fetch_func()
    def transformer_updater(self, b):

@@ -473,7 +502,7 @@ class AsyncMemeoryUpdater(torch.nn.Module):
                None)
            #print(index.shape[0])
            if param[0]:
                index, mail, mail_ts = self.mailbox.get_update_mail(
                index0, mail, mail_ts = self.mailbox.get_update_mail(
                    b.srcdata['ID'], src, dst, ts,
                    edge_feats, self.last_updated_memory,
                    None, False, False, block=b

@@ -483,9 +512,12 @@ class AsyncMemeoryUpdater(torch.nn.Module):
            self.mailbox.mon.add(index, self.mailbox.node_memory.accessor.data[index], memory)
            ##print(index.shape,memory.shape,memory_ts.shape,mail.shape,mail_ts.shape)
            local_mask = (DistIndex(index).part == torch.distributed.get_rank())
            self.mailbox.set_mailbox_local(DistIndex(index[local_mask]).loc, mail[local_mask], mail_ts[local_mask], Reduce_Op='max')
            self.mailbox.set_memory_local(DistIndex(index[local_mask]).loc, memory[local_mask], memory_ts[local_mask], Reduce_Op='max')
            self.update_hunk(index, memory, memory_ts, mail, mail_ts, nxt_fetch_func)
            local_mask_mail = (DistIndex(index0).part == torch.distributed.get_rank())
            #self.mailbox.set_mailbox_local(DistIndex(index0[local_mask_mail]).loc,mail[local_mask_mail],mail_ts[local_mask_mail],Reduce_Op = 'max')
            #self.mailbox.set_memory_local(DistIndex(index[local_mask]).loc,memory[local_mask],memory_ts[local_mask], Reduce_Op = 'max')
            is_deliver = (self.mailbox.deliver_to == 'neighbors')
            self.update_hunk(index, memory, memory_ts, index0, mail, mail_ts, nxt_fetch_func, spread_mail=is_deliver)
        if self.memory_param['combine_node_feature'] and self.dim_node_feat > 0:
            if self.dim_node_feat == self.dim_hid:
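A minimal sketch of how the per-mode update hunks defined above could be selected; the mapping from the memory_type strings used in the shell scripts to these methods is an assumption for illustration, since the actual dispatch code is outside this hunk.

# Assumed, illustrative dispatch only; AsyncMemeoryUpdater's real selection logic is not shown in this diff.
def pick_update_hunk(updater, memory_type):
    table = {
        'all_update': updater.all_update_func,   # synchronous cross-partition update
        'historical': updater.historical_func,   # shared-memory cache with asynchronous all-gather
        'local': updater.local_func,             # no cross-partition traffic
    }
    return table.get(memory_type, updater.local_func)

# e.g. update_hunk = pick_update_hunk(updater, 'historical')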
starrygl/sample/memory/_shared_mailbox.py
View file @
99e5b95a
...
...
@@ -58,7 +58,7 @@ class SharedMailBox():
ts_dtye
=
torch
.
float32
,
uvm
=
False
,
use_pin
=
False
,
cache_route
=
Non
e
,
start_historical
=
Fals
e
,
shared_ssim
=
2
):
ctx
=
distributed
.
context
.
_get_default_dist_context
()
self
.
device
=
device
...
...
@@ -106,13 +106,16 @@ class SharedMailBox():
device
=
torch
.
device
(
'cuda:{}'
.
format
(
ctx
.
local_rank
)))
print
(
self
.
shared_nodes_index
)
if
cache_route
is
not
None
:
if
start_historical
is
not
None
:
self
.
historical_cache
=
historical_cache
.
HistoricalCache
(
self
.
shared_nodes_index
,
0
,
self
.
node_memory
.
shape
[
1
],
self
.
node_memory
.
dtype
,
self
.
node_memory
.
device
,
threshold
=
shared_ssim
)
self
.
_mem_pin
=
{}
self
.
_mail_pin
=
{}
self
.
use_pin
=
use_pin
self
.
last_memory_sync
=
None
self
.
last_job
=
None
self
.
last_mail_sync
=
None
self
.
next_wait_memory_job
=
None
self
.
next_wait_mail_job
=
None
self
.
next_wait_gather_memory_job
=
None
self
.
mon
=
MemoryMoniter
()
...
...
@@ -124,6 +127,7 @@ class SharedMailBox():
self
.
next_mail_pos
.
zero_
()
self
.
historical_cache
.
empty
()
self
.
last_memory_sync
=
None
self
.
last_mail_sync
=
None
def
set_memory_local
(
self
,
index
,
source
,
source_ts
,
Reduce_Op
=
None
):
...
...
@@ -212,11 +216,8 @@ class SharedMailBox():
mem
=
torch
.
cat
((
mail
,
mail_ts
.
view
(
-
1
,
1
)),
dim
=
1
)
else
:
mem
=
torch
.
cat
((
memory
,
memory_ts
.
view
(
-
1
,
1
)),
dim
=
1
)
#if index is None:
# mem = torch.cat((memory,memory_ts.view(-1,1),mail,mail_ts.view(-1,1)),dim = 1)
#else:
# mem = torch.cat((memory,memory_ts.view(-1,1),mail,mail_ts.view(-1,1),index.to(torch.float32).view(-1,1)),dim = 1)
return
mem
def
unpack
(
self
,
mem
,
mailbox
=
False
):
if
mem
.
shape
[
1
]
==
self
.
node_memory
.
shape
[
1
]
+
1
or
mem
.
shape
[
1
]
==
self
.
mailbox
.
shape
[
2
]
+
1
:
mail
=
mem
[:,:
-
1
]
...
...
@@ -234,8 +235,10 @@ class SharedMailBox():
mail
=
mem
[:,
self
.
node_memory
.
shape
[
1
]
+
1
:
mem
.
shape
[
1
]
-
self
.
mailbox_ts
.
shape
[
1
]]
.
reshape
(
mem
.
shape
[
0
],
self
.
mailbox
.
shape
[
1
],
-
1
)
mail_ts
=
mem
[:,
mem
.
shape
[
1
]
-
self
.
mailbox_ts
.
shape
[
1
]:]
return
memory
,
memory_ts
,
mail
,
mail_ts
def
handle_last_async
(
self
,
reduce_Op
=
None
):
"""
sychronization last async all-to-all M&M communication task
"""
def
handle_last_memory
(
self
,
reduce_Op
=
None
,):
if
self
.
last_memory_sync
is
not
None
:
gather_id_list
,
handle0
,
gather_memory
,
handle1
=
self
.
last_memory_sync
self
.
last_memory_sync
=
None
...
...
@@ -249,56 +252,167 @@ class SharedMailBox():
else
:
gather_memory
,
gather_memory_ts
=
self
.
unpack
(
gather_memory
)
self
.
set_memory_local
(
DistIndex
(
gather_id_list
)
.
loc
,
gather_memory
,
gather_memory_ts
,
Reduce_Op
=
reduce_Op
)
def
handle_last_mail
(
self
,
reduce_Op
=
None
,):
if
self
.
last_mail_sync
is
not
None
:
gather_id_list
,
handle0
,
gather_memory
,
handle1
=
self
.
last_mail_sync
self
.
last_mail_sync
=
None
handle0
.
wait
()
handle1
.
wait
()
if
isinstance
(
gather_memory
,
list
):
gather_memory
=
torch
.
cat
(
gather_memory
,
dim
=
0
)
gather_memory
,
gather_memory_ts
=
self
.
unpack
(
gather_memory
)
self
.
set_mailbox_local
(
DistIndex
(
gather_id_list
)
.
loc
,
gather_memory
,
gather_memory_ts
,
Reduce_Op
=
reduce_Op
)
def
handle_last_async
(
self
,
reduce_Op
=
None
):
self
.
handle_last_memory
(
self
,
reduce_Op
)
self
.
handle_last_mail
(
self
,
reduce_Op
)
"""
sychronization last async all-gather memory communication task
"""
def
sychronize_shared
(
self
):
out
=
self
.
historical_cache
.
synchronize_shared_update
()
if
out
is
not
None
:
shared_index
,
shared_data
,
shared_ts
,
mail
,
mail_ts
=
out
shared_index
,
shared_data
,
shared_ts
=
out
index
=
self
.
shared_nodes_index
[
shared_index
]
self
.
node_memory
.
accessor
.
data
[
index
]
=
shared_data
self
.
node_memory
_ts
.
accessor
.
data
[
index
]
=
shared_ts
self
.
mailbox
.
accessor
.
data
[
index
,
torch
.
max
(
self
.
next_mail_pos
[
index
]
-
1
,
torch
.
tensor
([
0
],
device
=
mail
.
device
))]
=
mail
self
.
mailbox_ts
.
accessor
.
data
[
index
,
torch
.
max
(
self
.
next_mail_pos
[
index
]
-
1
,
torch
.
tensor
([
0
],
device
=
mail_ts
.
device
))]
=
mail_ts
mask
=
(
shared_ts
>
self
.
node_memory_ts
.
accessor
.
data
[
index
])
self
.
node_memory
.
accessor
.
data
[
index
][
mask
]
=
shared_data
[
mask
]
self
.
node_memory_ts
.
accessor
.
data
[
index
][
mask
]
=
shared_ts
[
mask
]
def
update_shared
(
self
):
ctx
=
DistributedContext
.
get_default_context
()
if
self
.
last
_job
is
not
None
:
shared_list
,
mem
,
shared_id_list
,
shared_memory_ind
=
self
.
last
_job
self
.
last
_job
=
None
if
self
.
next_wait_gather_memory
_job
is
not
None
:
shared_list
,
mem
,
shared_id_list
,
shared_memory_ind
=
self
.
self
.
next_wait_gather_memory
_job
self
.
next_wait_gather_memory
_job
=
None
handle0
=
dist
.
all_gather
(
shared_list
,
mem
,
group
=
ctx
.
memory_nccl_group
,
async_op
=
True
)
handle1
=
dist
.
all_gather
(
shared_id_list
,
shared_memory_ind
,
group
=
ctx
.
memory_nccl_group
,
async_op
=
True
)
self
.
historical_cache
.
add_shared_to_queue
(
handle0
,
handle1
,
shared_id_list
,
shared_list
)
def
update_p2p
(
self
):
if
self
.
last
_job
is
None
:
def
update_p2p
_mem
(
self
):
if
self
.
next_wait_gather_memory
_job
is
None
:
return
index
,
gather_id_list
,
mem
,
gather_memory
,
input_split
,
output_split
,
group
,
async_op
=
self
.
last
_job
self
.
last
_job
=
None
index
,
gather_id_list
,
mem
,
gather_memory
,
input_split
,
output_split
,
group
,
async_op
=
self
.
next_wait_gather_memory
_job
self
.
next_wait_gather_memory
_job
=
None
handle0
=
torch
.
distributed
.
all_to_all_single
(
gather_id_list
,
index
,
output_split_sizes
=
output_split
,
input_split_sizes
=
input_split
,
group
=
group
,
async_op
=
async_op
)
#print(input_split,output_split)
handle1
=
torch
.
distributed
.
all_to_all_single
(
gather_memory
,
mem
,
output_split_sizes
=
output_split
,
input_split_sizes
=
input_split
,
group
=
group
,
async_op
=
async_op
)
self
.
last_memory_sync
=
(
gather_id_list
,
handle0
,
gather_memory
,
handle1
)
def
set_memory_all_reduce
(
self
,
index
,
memory
,
memory_ts
,
mail
,
mail_ts
,
reduce_Op
=
None
,
async_op
=
True
,
set_remote
=
False
,
mode
=
None
,
filter
=
None
,
submit
=
True
):
def
update_p2p_mail
(
self
):
if
self
.
next_wait_gather_mail_job
is
None
:
return
index
,
gather_id_list
,
mem
,
gather_memory
,
input_split
,
output_split
,
group
,
async_op
=
self
.
next_wait_gather_mail_job
self
.
next_wait_gather_mail_job
=
None
handle0
=
torch
.
distributed
.
all_to_all_single
(
gather_id_list
,
index
,
output_split_sizes
=
output_split
,
input_split_sizes
=
input_split
,
group
=
group
,
async_op
=
async_op
)
handle1
=
torch
.
distributed
.
all_to_all_single
(
gather_memory
,
mem
,
output_split_sizes
=
output_split
,
input_split_sizes
=
input_split
,
group
=
group
,
async_op
=
async_op
)
self
.
last_mail_sync
=
(
gather_id_list
,
handle0
,
gather_memory
,
handle1
)
"""
submit: take the task request into queue wait for start
"""
def
build_all_to_all_route
(
self
,
index
,
mm
,
mm_ts
,
is_no_redundant
):
ctx
=
DistributedContext
.
get_default_context
()
#print(DistIndex(index).part)
if
set_remote
is
False
:
gather_len_list
=
torch
.
empty
([
self
.
num_parts
],
dtype
=
int
,
device
=
self
.
device
)
ind
=
torch
.
ops
.
torch_sparse
.
ind2ptr
(
DistIndex
(
index
)
.
part
,
self
.
num_parts
)
scatter_len_list
=
ind
[
1
:]
-
ind
[
0
:
-
1
]
torch
.
distributed
.
all_to_all_single
(
gather_len_list
,
scatter_len_list
,
group
=
ctx
.
memory_nccl_group
)
input_split
=
scatter_len_list
.
tolist
()
output_split
=
gather_len_list
.
tolist
()
if
is_no_redundant
:
gather_ts
=
torch
.
empty
(
[
gather_len_list
.
sum
()],
dtype
=
mm_ts
.
dtype
,
device
=
self
.
device
)
gather_id
=
torch
.
empty
(
[
gather_len_list
.
sum
()],
dtype
=
index
.
dtype
,
device
=
self
.
device
)
torch
.
distributed
.
all_to_all_single
(
gather_id
,
index
,
output_split_sizes
=
output_split
,
input_split_sizes
=
input_split
,
group
=
ctx
.
memory_nccl_group
,
)
torch
.
distributed
.
all_to_all_single
(
gather_ts
,
mm_ts
,
output_split_sizes
=
output_split
,
input_split_sizes
=
input_split
,
group
=
ctx
.
memory_nccl_group
,
)
unq_id
,
inv
=
gather_id
.
unique
(
return_inverse
=
True
)
max_ts
,
pos
=
torch_scatter
.
scatter_max
(
gather_ts
,
inv
,
dim
=
0
)
is_used
=
torch
.
zeros
(
gather_ts
.
shape
,
device
=
gather_ts
.
device
,
dtype
=
torch
.
bool
)
is_used
[
pos
]
=
True
send_mm_to_dst
=
torch
.
zeros
([
scatter_len_list
.
sum
()
.
item
()],
device
=
is_used
.
device
,
dtype
=
torch
.
bool
)
torch
.
distributed
.
all_to_all_single
(
send_mm_to_dst
,
is_used
,
output_split
=
input_split
,
input_split
=
output_split
,
group
=
ctx
.
memory_nccl_group
,
)
index
=
index
[
send_mm_to_dst
]
mm
=
mm
[
send_mm_to_dst
]
mm_ts
=
mm_ts
[
send_mm_to_dst
]
gather_len_list
=
torch
.
empty
([
self
.
num_parts
],
dtype
=
int
,
device
=
self
.
device
)
ind
=
torch
.
ops
.
torch_sparse
.
ind2ptr
(
DistIndex
(
index
)
.
part
,
self
.
num_parts
)
scatter_len_list
=
ind
[
1
:]
-
ind
[
0
:
-
1
]
torch
.
distributed
.
all_to_all_single
(
gather_len_list
,
scatter_len_list
,
group
=
ctx
.
memory_nccl_group
)
input_split
=
scatter_len_list
.
tolist
()
output_split
=
gather_len_list
.
tolist
()
gather_id
=
gather_id
[
is_used
]
else
:
gather_id
=
torch
.
empty
([
gather_len_list
.
sum
()],
dtype
=
index
.
dtype
(),
device
=
self
.
device
)
gather_memory
=
torch
.
empty
(
[
gather_len_list
.
sum
(),
mm
.
shape
[
1
]],
dtype
=
mm
.
dtype
,
device
=
self
.
device
)
return
index
,
gather_id
,
mm
,
gather_memory
,
input_split
,
output_split
def
build_all_to_all_async_task
(
self
,
index
,
mm
,
mm_ts
,
is_no_redundant
=
False
,
is_async
=
True
):
ctx
=
DistributedContext
.
get_default_context
()
p2p_async_info
=
self
.
build_all_to_all_route
(
self
,
index
,
mm
,
mm_ts
,
is_no_redundant
,
is_async
=
True
)
if
mm
.
shape
[
1
]
!=
self
.
mailbox
.
shape
[
-
1
]
+
1
:
self
.
next_wait_memory_job
=
(
*
p2p_async_info
,
is_async
,
ctx
.
memory_nccl_group
)
else
:
self
.
next_wait_mail_job
=
(
*
p2p_async_info
,
is_async
,
ctx
.
memory_nccl_group
)
def
set_memory_all_reduce
(
self
,
index
,
memory
,
memory_ts
,
mail_index
,
mail
,
mail_ts
,
reduce_Op
=
None
,
async_op
=
True
,
mode
=
None
,
wait_submit
=
True
,
spread_mail
=
True
,
update_cross_mm
=
False
):
if
self
.
num_parts
==
1
:
return
if
not
spread_mail
and
not
update_cross_mm
:
pass
# self.set_mailbox_local(DistIndex(index).loc,mail,mail_ts,Reduce_Op = reduce_Op)
# self.set_memory_local(DistIndex(index).loc,memory,memory_ts, Reduce_Op = reduce_Op)
else
:
#print(index,memory,memory_ts)
if
async_op
==
True
and
self
.
num_parts
>
1
:
#local_mask = (DistIndex(index).loc == dist.get_rank())
self
.
set_mailbox_all_to_all
(
index
,
memory
,
memory_ts
,
mail
,
mail_ts
,
reduce_Op
=
'max'
,
async_op
=
async_op
,
submit
=
submit
)
if
spread_mail
:
mm
=
torch
.
cat
((
mail
,
mail_ts
.
reshape
(
-
1
,
1
)),
dim
=
1
)
mm_ts
=
mail_ts
self
.
build_all_to_all_async_task
(
mail_index
,
mm
,
mm_ts
,
is_async
=
True
,
is_no_redundant
=
True
)
if
update_cross_mm
:
mm
=
torch
.
cat
((
memory
,
memory_ts
.
reshape
(
-
1
,
1
)),
dim
=
1
)
mm_ts
=
memory_ts
self
.
build_all_to_all_async_task
(
mail_index
,
mm
,
mm_ts
,
is_async
=
True
,
is_no_redundant
=
True
)
else
:
self
.
set_mailbox_all_to_all
(
index
,
memory
,
memory_ts
,
mail
,
mail_ts
,
reduce_Op
=
'max'
,
async_op
=
False
)
if
update_cross_mm
:
mm
=
self
.
pack
(
memory
,
memory_ts
,
mail
,
mail_ts
,
index
)
mm_ts
=
memory_ts
self
.
build_all_to_all_async_task
(
mail_index
,
mm
,
mm_ts
,
is_async
=
True
,
is_no_redundant
=
True
)
if
async_op
is
False
:
self
.
update_p2p_mail
()
self
.
update_p2p_mem
()
self
.
handle_last_async
()
ctx
=
DistributedContext
.
get_default_context
()
if
self
.
shared_nodes_index
is
not
None
and
(
mode
==
'all_reduce'
or
mode
==
'historical'
):
shared_memory_ind
=
self
.
is_shared_mask
[
torch
.
min
(
DistIndex
(
index
)
.
loc
,
torch
.
tensor
([
self
.
num_nodes
-
1
],
device
=
torch
.
device
(
'cuda'
)
))]
shared_memory_ind
=
self
.
is_shared_mask
[
torch
.
min
(
DistIndex
(
index
)
.
loc
,
torch
.
tensor
([
self
.
num_nodes
-
1
],
device
=
index
.
device
))]
mask
=
((
shared_memory_ind
>-
1
)
&
(
DistIndex
(
index
)
.
part
==
ctx
.
memory_group_rank
))
shared_memory_ind
=
shared_memory_ind
[
mask
]
shared_memory
=
memory
[
mask
]
...
...
@@ -306,20 +420,16 @@ class SharedMailBox():
shared_mail
=
mail
[
mask
]
shared_mail_ts
=
mail_ts
[
mask
]
if
mode
==
'historical'
:
#print(shared_memory_ind)
update_index
=
self
.
historical_cache
.
historical_check
(
shared_memory_ind
,
shared_memory
,
shared_memory_ts
)
#print(update_index.sum(),shared_memory_ind.shape)
shared_memory_ind
=
shared_memory_ind
[
update_index
]
shared_memory
=
shared_memory
[
update_index
]
shared_memory_ts
=
shared_memory_ts
[
update_index
]
shared_mail
=
shared_mail
[
update_index
]
shared_mail_ts
=
shared_mail_ts
[
update_index
]
#print(shared_memory_ind)
#mem = self.pack(memory=shared_memory,memory_ts=shared_memory_ts,mail=shared_mail,mail_ts=shared_mail_ts,index=shared_memory_ind,mode=mode)
#mem = self.pack(memory=shared_mail,memory_ts=shared_mail_ts,index=shared_memory_ind,mode=mode)
mem
=
self
.
pack
(
memory
=
shared_memory
,
memory_ts
=
shared_memory_ts
,
index
=
shared_memory_ind
,
mode
=
mode
)
else
:
mem
=
self
.
pack
(
memory
=
shared_memory
,
memory_ts
=
shared_memory_ts
,
mail
=
shared_mail
,
mail_ts
=
shared_mail_ts
,
index
=
shared_memory_ind
,
mode
=
mode
)
self
.
tot_shared_count
+=
shared_memory_ind
.
shape
[
0
]
mem
=
self
.
pack
(
memory
=
shared_memory
,
memory_ts
=
shared_memory_ts
,
mail
=
shared_mail
,
mail_ts
=
shared_mail_ts
,
index
=
shared_memory_ind
,
mode
=
mode
)
broadcast_len
=
torch
.
empty
([
1
],
device
=
mem
.
device
,
dtype
=
torch
.
int
)
broadcast_len
[
0
]
=
shared_memory_ind
.
shape
[
0
]
shared_len
=
[
torch
.
empty
([
1
],
device
=
mem
.
device
,
dtype
=
torch
.
int
)
for
_
in
range
(
ctx
.
memory_group_size
)]
...
...
@@ -327,53 +437,34 @@ class SharedMailBox():
shared_list
=
[
torch
.
empty
([
l
.
item
(),
mem
.
shape
[
1
]],
device
=
mem
.
device
,
dtype
=
mem
.
dtype
)
for
l
in
shared_len
]
shared_id_list
=
[
torch
.
empty
([
l
.
item
()],
device
=
shared_memory_ind
.
device
,
dtype
=
shared_memory_ind
.
dtype
)
for
l
in
shared_len
]
if
mode
==
'all_reduce'
:
start
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
end
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
start
.
record
()
if
async_op
==
True
:
handle0
=
dist
.
all_gather
(
shared_list
,
mem
,
group
=
ctx
.
memory_nccl_group
)
handle1
=
dist
.
all_gather
(
shared_id_list
,
shared_memory_ind
,
group
=
ctx
.
memory_nccl_group
)
self
.
last_memory_sync
=
shared_id_list
,
handle0
,
shared_id_list
,
handle1
else
:
dist
.
all_gather
(
shared_list
,
mem
,
group
=
ctx
.
memory_nccl_group
)
dist
.
all_gather
(
shared_id_list
,
shared_memory_ind
,
group
=
ctx
.
memory_nccl_group
)
mem
=
torch
.
cat
(
shared_list
,
dim
=
0
)
shared_index
=
torch
.
cat
(
shared_id_list
)
#id = shared_index.sort()
#print(mem[id],shared_index[id])
#,shared_memory,shared_memory_ts,
#shared_memory,shared_memory_ts = self.unpack(mem)
shared_memory
,
shared_memory_ts
,
shared_mail
,
shared_mail_ts
=
self
.
unpack
(
mem
)
unq_index
,
inv
=
torch
.
unique
(
shared_index
,
return_inverse
=
True
)
#print(inv.shape,Reduce_score.shape)
max_ts
,
idx
=
torch_scatter
.
scatter_max
(
shared_memory_ts
,
inv
,
0
)
#min_ts,_ = torch_scatter.scatter_min(shared_mail_ts,inv,0)
shared_memory
=
shared_memory
[
idx
]
shared_memory_ts
=
shared_memory_ts
[
idx
]
shared_mail_ts
=
shared_mail_ts
[
idx
]
shared_mail
=
shared_mail
[
idx
]
#shared_mail_ts = torch_scatter.scatter_mean(shared_mail_ts,inv,0)
#shared_mail = torch_scatter.scatter_mean(shared_mail,inv,0)
shared_index
=
unq_index
self
.
set_memory_local
(
self
.
shared_nodes_index
[
shared_index
],
shared_memory
,
shared_memory_ts
)
self
.
historical_cache
.
local_historical_data
[
shared_index
]
=
shared_memory
self
.
historical_cache
.
local_ts
[
shared_index
]
=
shared_memory_ts
self
.
set_mailbox_local
(
self
.
shared_nodes_index
[
shared_index
],
shared_mail
,
shared_mail_ts
)
#if async_op == True:
# handle0 = dist.all_gather(shared_list,mem,group=ctx.memory_nccl_group)
# handle1 = dist.all_gather(shared_id_list,shared_memory_ind,group=ctx.memory_nccl_group)
# self.last_memory_sync = shared_id_list,handle0,shared_id_list,handle1
#else:
dist
.
all_gather
(
shared_list
,
mem
,
group
=
ctx
.
memory_nccl_group
)
dist
.
all_gather
(
shared_id_list
,
shared_memory_ind
,
group
=
ctx
.
memory_nccl_group
)
mem
=
torch
.
cat
(
shared_list
,
dim
=
0
)
shared_index
=
torch
.
cat
(
shared_id_list
)
shared_memory
,
shared_memory_ts
,
shared_mail
,
shared_mail_ts
=
self
.
unpack
(
mem
)
unq_index
,
inv
=
torch
.
unique
(
shared_index
,
return_inverse
=
True
)
max_ts
,
idx
=
torch_scatter
.
scatter_max
(
shared_memory_ts
,
inv
,
0
)
shared_memory
=
shared_memory
[
idx
]
shared_memory_ts
=
shared_memory_ts
[
idx
]
shared_mail_ts
=
shared_mail_ts
[
idx
]
shared_mail
=
shared_mail
[
idx
]
shared_index
=
unq_index
self
.
set_memory_local
(
self
.
shared_nodes_index
[
shared_index
],
shared_memory
,
shared_memory_ts
)
self
.
historical_cache
.
local_historical_data
[
shared_index
]
=
shared_memory
self
.
historical_cache
.
local_ts
[
shared_index
]
=
shared_memory_ts
self
.
set_mailbox_local
(
self
.
shared_nodes_index
[
shared_index
],
shared_mail
,
shared_mail_ts
)
else
:
start
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
end
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
start
.
record
()
#self.historical_cache.synchronize_shared_update(filter)
#mem[:,:-1] = mem[:,:-1] + filter.get_incretment(shared_memory_ind)
#shared_list = [torch.empty([l.item(),mem.shape[1]],device = mem.device,dtype=mem.dtype) for l in shared_len]
#handle0 = dist.all_gather(shared_list,mem,group=ctx.memory_nccl_group,async_op=True)
#shared_id_list = [torch.empty([l.item()],device = shared_memory_ind.device,dtype=shared_memory_ind.dtype) for l in shared_len]
#handle1 = dist.all_gather(shared_id_list,shared_memory_ind,group=ctx.memory_nccl_group,async_op=True)
self
.
last_job
=
(
shared_list
,
mem
,
shared_id_list
,
shared_memory_ind
)
if
~
submit
:
self
.
update_shared
()
self
.
next_wait_gather_memory_job
=
(
shared_list
,
mem
,
shared_id_list
,
shared_memory_ind
)
if
not
wait_submit
:
self
.
update_shared
()
self
.
update_p2p_mail
()
self
.
update_p2p_mem
()
#self.historical_cache.add_shared_to_queue(handle0,handle1,shared_id_list,shared_list)
"""
shared_memory = self.node_memory.accessor.data[self.shared_nodes_index]
...
...
@@ -394,107 +485,54 @@ class SharedMailBox():
def
set_mailbox_all_to_all_empty
(
self
,
index
,
memory
,
memory_ts
,
mail
,
mail_ts
,
reduce_Op
=
None
,
group
=
None
):
if
self
.
num_parts
==
1
:
dist_index
=
DistIndex
(
index
)
part_idx
=
dist_index
.
part
index
=
dist_index
.
loc
self
.
set_mailbox_local
(
index
,
mail
,
mail_ts
)
self
.
set_memory_local
(
index
,
memory
,
memory_ts
)
else
:
gather_len_list
=
torch
.
empty
([
self
.
num_parts
],
dtype
=
int
,
device
=
self
.
device
)
#indic = torch.searchsorted(index,self.partptr,right=False)
indic
=
torch
.
ops
.
torch_sparse
.
ind2ptr
(
DistIndex
(
index
)
.
part
,
self
.
num_parts
)
scatter_len_list
=
indic
[
1
:]
-
indic
[
0
:
-
1
]
torch
.
distributed
.
all_to_all_single
(
gather_len_list
,
scatter_len_list
,
group
=
group
)
input_split
=
scatter_len_list
.
tolist
()
output_split
=
gather_len_list
.
tolist
()
gather_id_list
=
torch
.
empty
(
[
gather_len_list
.
sum
()],
dtype
=
torch
.
long
,
device
=
self
.
device
)
input_split
=
scatter_len_list
.
tolist
()
output_split
=
gather_len_list
.
tolist
()
torch
.
distributed
.
all_to_all_single
(
gather_id_list
,
index
,
output_split_sizes
=
output_split
,
input_split_sizes
=
input_split
,
group
=
group
)
index
=
gather_id_list
gather_memory
=
torch
.
empty
(
[
gather_len_list
.
sum
(),
memory
.
shape
[
1
]],
dtype
=
memory
.
dtype
,
device
=
self
.
device
)
gather_memory_ts
=
torch
.
empty
(
[
gather_len_list
.
sum
()],
dtype
=
memory_ts
.
dtype
,
device
=
self
.
device
)
gather_mail
=
torch
.
empty
(
[
gather_len_list
.
sum
(),
mail
.
shape
[
1
]],
dtype
=
mail
.
dtype
,
device
=
self
.
device
)
gather_mail_ts
=
torch
.
empty
(
[
gather_len_list
.
sum
()],
dtype
=
mail_ts
.
dtype
,
device
=
self
.
device
)
torch
.
distributed
.
all_to_all_single
(
gather_memory
,
memory
,
output_split_sizes
=
output_split
,
input_split_sizes
=
input_split
,
group
=
group
,
async_op
=
True
)
torch
.
distributed
.
all_to_all_single
(
gather_memory_ts
,
memory_ts
,
output_split_sizes
=
output_split
,
input_split_sizes
=
input_split
,
group
=
group
)
torch
.
distributed
.
all_to_all_single
(
gather_mail
,
mail
,
output_split_sizes
=
output_split
,
input_split_sizes
=
input_split
,
group
=
group
)
torch
.
distributed
.
all_to_all_single
(
gather_mail_ts
,
mail_ts
,
output_split_sizes
=
output_split
,
input_split_sizes
=
input_split
,
group
=
group
)
pass
def
set_mailbox_all_to_all
(
self
,
index
,
memory
,
memory_ts
,
mail
,
mail_ts
,
reduce_Op
=
None
,
group
=
None
,
async_op
=
False
,
submit
=
True
):
#futs: List[torch.futures.Future] = []
if
self
.
num_parts
==
1
:
dist_index
=
DistIndex
(
index
)
part_idx
=
dist_index
.
part
index
=
dist_index
.
loc
self
.
set_mailbox_local
(
index
,
mail
,
mail_ts
)
self
.
set_memory_local
(
index
,
memory
,
memory_ts
)
else
:
self
.
tot_comm_count
+=
(
DistIndex
(
index
)
.
part
!=
dist
.
get_rank
())
.
sum
()
gather_len_list
=
torch
.
empty
([
self
.
num_parts
],
dtype
=
int
,
device
=
self
.
device
)
indic
=
torch
.
ops
.
torch_sparse
.
ind2ptr
(
DistIndex
(
index
)
.
part
,
self
.
num_parts
)
scatter_len_list
=
indic
[
1
:]
-
indic
[
0
:
-
1
]
torch
.
distributed
.
all_to_all_single
(
gather_len_list
,
scatter_len_list
,
group
=
group
)
input_split
=
scatter_len_list
.
tolist
()
output_split
=
gather_len_list
.
tolist
()
mem
=
self
.
pack
(
memory
,
memory_ts
,
mail
,
mail_ts
)
gather_memory
=
torch
.
empty
(
[
gather_len_list
.
sum
(),
mem
.
shape
[
1
]],
dtype
=
memory
.
dtype
,
device
=
self
.
device
)
gather_id_list
=
torch
.
empty
([
gather_len_list
.
sum
()],
dtype
=
torch
.
long
,
device
=
self
.
device
)
input_split
=
scatter_len_list
.
tolist
()
output_split
=
gather_len_list
.
tolist
()
if
async_op
==
True
:
self
.
last_job
=
index
,
gather_id_list
,
mem
,
gather_memory
,
input_split
,
output_split
,
group
,
async_op
if
~
submit
:
self
.
update_p2p
()
else
:
torch
.
distributed
.
all_to_all_single
(
gather_id_list
,
index
,
output_split_sizes
=
output_split
,
input_split_sizes
=
input_split
,
group
=
group
,
async_op
=
async_op
)
torch
.
distributed
.
all_to_all_single
(
gather_memory
,
mem
,
output_split_sizes
=
output_split
,
input_split_sizes
=
input_split
,
group
=
group
)
if
gather_memory
.
shape
[
1
]
>
self
.
node_memory
.
shape
[
1
]
+
1
:
gather_memory
,
gather_memory_ts
,
gather_mail
,
gather_mail_ts
=
self
.
unpack
(
gather_memory
)
self
.
set_mailbox_local
(
DistIndex
(
gather_id_list
)
.
loc
,
gather_mail
,
gather_mail_ts
,
Reduce_Op
=
reduce_Op
)
else
:
gather_memory
,
gather_memory_ts
=
self
.
unpack
(
gather_memory
)
self
.
set_memory_local
(
DistIndex
(
gather_id_list
)
.
loc
,
gather_memory
,
gather_memory_ts
,
Reduce_Op
=
reduce_Op
)
pass
        # #futs: List[torch.futures.Future] = []
        # if self.num_parts == 1:
        #     dist_index = DistIndex(index)
        #     index = dist_index.loc
        #     self.set_mailbox_local(index,mail,mail_ts)
        #     self.set_memory_local(index,memory,memory_ts)
        # else:
        #     self.tot_comm_count += (DistIndex(index).part != dist.get_rank()).sum()
        #     gather_len_list = torch.empty([self.num_parts],
        #                                   dtype = int,
        #                                   device = self.device)
        #     indic = torch.ops.torch_sparse.ind2ptr(DistIndex(index).part, self.num_parts)
        #     scatter_len_list = indic[1:] - indic[0:-1]
        #     torch.distributed.all_to_all_single(gather_len_list,scatter_len_list,group = group)
        #     input_split = scatter_len_list.tolist()
        #     output_split = gather_len_list.tolist()
        #     mem = self.pack(memory,memory_ts,mail,mail_ts)
        #     gather_memory = torch.empty(
        #         [gather_len_list.sum(),mem.shape[1]],
        #         dtype = memory.dtype,device = self.device)
        #     gather_id_list = torch.empty([gather_len_list.sum()],dtype = torch.long,device = self.device)
        #     input_split = scatter_len_list.tolist()
        #     output_split = gather_len_list.tolist()
        #     if async_op == True:
        #         self.last_job = index,gather_id_list,mem,gather_memory,input_split,output_split,group,async_op
        #         if not submit:
        #             self.update_p2p()
        #     else:
        #         torch.distributed.all_to_all_single(
        #             gather_id_list,index,output_split_sizes=output_split,
        #             input_split_sizes=input_split,group = group,async_op=async_op)
        #         torch.distributed.all_to_all_single(
        #             gather_memory,mem,
        #             output_split_sizes=output_split,
        #             input_split_sizes=input_split,group = group)
        #         if gather_memory.shape[1] > self.node_memory.shape[1] + 1:
        #             gather_memory,gather_memory_ts,gather_mail,gather_mail_ts = self.unpack(gather_memory)
        #             self.set_mailbox_local(DistIndex(gather_id_list).loc,gather_mail,gather_mail_ts,Reduce_Op = reduce_Op)
        #         else:
        #             gather_memory,gather_memory_ts = self.unpack(gather_memory)
        #         self.set_memory_local(DistIndex(gather_id_list).loc,gather_memory,gather_memory_ts, Reduce_Op = reduce_Op)
    def gather_memory(self,
...
...
starrygl/sample/memory/shared_mailbox.py
View file @
99e5b95a
...
...
@@ -58,7 +58,7 @@ class SharedMailBox():
                 ts_dtye = torch.float32,
                 uvm = False,
                 use_pin = False,
                 cache_route = None,
                 start_historical = False,
                 shared_ssim = 2):
        ctx = distributed.context._get_default_dist_context()
        self.device = device
...
...
@@ -105,14 +105,16 @@ class SharedMailBox():
            self.is_shared_mask[shared_nodes_index] = torch.arange(self.shared_nodes_index.shape[0],
                                                                   dtype=torch.int,
                                                                   device=torch.device('cuda:{}'.format(ctx.local_rank)))
            print(self.shared_nodes_index)
            if cache_route is not None:
            if start_historical is not None:
                self.historical_cache = historical_cache.HistoricalCache(self.shared_nodes_index, 0,
                                                                         self.node_memory.shape[1],
                                                                         self.node_memory.dtype,
                                                                         self.node_memory.device,
                                                                         threshold=shared_ssim)
        self._mem_pin = {}
        self._mail_pin = {}
        self.use_pin = use_pin
        self.last_memory_sync = None
        self.last_job = None
        self.last_mail_sync = None
        self.next_wait_memory_job = None
        self.next_wait_mail_job = None
        self.next_wait_gather_memory_job = None
        self.mon = MemoryMoniter()
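# Editor's note: the fields initialised here replace the single `last_job` slot
# with separate queues (`next_wait_memory_job`, `next_wait_mail_job`,
# `next_wait_gather_memory_job`), so memory updates, mail updates and the
# shared-node all-gather can be submitted and awaited independently (see
# update_p2p_mem / update_p2p_mail / update_shared below). The HistoricalCache
# is built only over the shared nodes, with threshold=shared_ssim presumably
# deciding when a cached shared embedding is stale enough to be re-broadcast.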
...
...
@@ -124,6 +126,7 @@ class SharedMailBox():
        self.next_mail_pos.zero_()
        self.historical_cache.empty()
        self.last_memory_sync = None
        self.last_mail_sync = None
    def set_memory_local(self, index, source, source_ts, Reduce_Op=None):
...
...
@@ -212,11 +215,8 @@ class SharedMailBox():
            mem = torch.cat((mail, mail_ts.view(-1, 1)), dim=1)
        else:
            mem = torch.cat((memory, memory_ts.view(-1, 1)), dim=1)
        #if index is None:
        #    mem = torch.cat((memory,memory_ts.view(-1,1),mail,mail_ts.view(-1,1)),dim = 1)
        #else:
        #    mem = torch.cat((memory,memory_ts.view(-1,1),mail,mail_ts.view(-1,1),index.to(torch.float32).view(-1,1)),dim = 1)
        return mem
    def unpack(self, mem, mailbox=False):
        if mem.shape[1] == self.node_memory.shape[1] + 1 or mem.shape[1] == self.mailbox.shape[2] + 1:
            mail = mem[:, :-1]
...
@@ -234,8 +234,10 @@ class SharedMailBox():
            mail = mem[:, self.node_memory.shape[1] + 1: mem.shape[1] - self.mailbox_ts.shape[1]].reshape(mem.shape[0], self.mailbox.shape[1], -1)
            mail_ts = mem[:, mem.shape[1] - self.mailbox_ts.shape[1]:]
            return memory, memory_ts, mail, mail_ts
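# --- editor's note (not part of this commit) -----------------------------------
# pack/unpack keep the whole node state in one 2-D tensor so that a single
# all_to_all_single can move it. The column layout, as read from the slicing
# above, is [ memory | memory_ts | mail | mail_ts ], and unpack distinguishes
# the cases purely by width. With illustrative sizes dim_memory = 100 and a
# one-slot mailbox of width 372 (numbers are examples only):
#   memory-only blob : 100 + 1           = 101 columns
#   mail-only blob   : 372 + 1           = 373 columns
#   full blob        : 100 + 1 + 372 + 1 = 474 columns > node_memory.shape[1] + 1,
# which is exactly the test set_mailbox_all_to_all uses to decide whether the
# received blob also carries mail.
# --------------------------------------------------------------------------------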
    def handle_last_async(self, reduce_Op=None):
    """
    sychronization last async all-to-all M&M communication task
    """
    def handle_last_memory(self, reduce_Op=None,):
        if self.last_memory_sync is not None:
            gather_id_list, handle0, gather_memory, handle1 = self.last_memory_sync
            self.last_memory_sync = None
...
...
@@ -249,7 +251,23 @@ class SharedMailBox():
            else:
                gather_memory, gather_memory_ts = self.unpack(gather_memory)
            self.set_memory_local(DistIndex(gather_id_list).loc, gather_memory, gather_memory_ts, Reduce_Op=reduce_Op)
    def handle_last_mail(self, reduce_Op=None,):
        if self.last_mail_sync is not None:
            gather_id_list, handle0, gather_memory, handle1 = self.last_mail_sync
            self.last_mail_sync = None
            handle0.wait()
            handle1.wait()
            if isinstance(gather_memory, list):
                gather_memory = torch.cat(gather_memory, dim=0)
            gather_memory, gather_memory_ts = self.unpack(gather_memory)
            self.set_mailbox_local(DistIndex(gather_id_list).loc, gather_memory, gather_memory_ts, Reduce_Op=reduce_Op)
    def handle_last_async(self, reduce_Op=None):
        self.handle_last_memory(reduce_Op)
        self.handle_last_mail(reduce_Op)
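# --- editor's note (not part of this commit) -----------------------------------
# Rough call order for the asynchronous path, as far as it can be read from this
# class alone (a sketch, not an exact trace of the training loop):
#   mailbox.build_all_to_all_async_task(idx, mm, mm_ts)   # enqueue a next_wait_* job
#   mailbox.update_p2p_mem(); mailbox.update_p2p_mail()   # launch async all_to_all, keep handles
#   ...overlap with local compute...
#   mailbox.handle_last_async()                           # wait on handles, unpack, write local state
# --------------------------------------------------------------------------------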
"""
sychronization last async all-gather memory communication task
"""
def
sychronize_shared
(
self
):
out
=
self
.
historical_cache
.
synchronize_shared_update
()
if
out
is
not
None
:
...
...
@@ -258,71 +276,167 @@ class SharedMailBox():
            mask = (shared_ts > self.node_memory_ts.accessor.data[index])
            self.node_memory.accessor.data[index][mask] = shared_data[mask]
            self.node_memory_ts.accessor.data[index][mask] = shared_ts[mask]
            #self.mailbox.accessor.data[index, torch.max(self.next_mail_pos[index]-1,torch.tensor([0],device=mail.device))] = mail
            #self.mailbox_ts.accessor.data[index, torch.max(self.next_mail_pos[index]-1,torch.tensor([0],device=mail_ts.device))] = mail_ts
    def update_shared(self):
        ctx = DistributedContext.get_default_context()
        if self.last_job is not None:
            shared_list, mem, shared_id_list, shared_memory_ind = self.last_job
            self.last_job = None
        if self.next_wait_gather_memory_job is not None:
            shared_list, mem, shared_id_list, shared_memory_ind = self.next_wait_gather_memory_job
            self.next_wait_gather_memory_job = None
            handle0 = dist.all_gather(shared_list, mem, group=ctx.memory_nccl_group, async_op=True)
            handle1 = dist.all_gather(shared_id_list, shared_memory_ind, group=ctx.memory_nccl_group, async_op=True)
            self.historical_cache.add_shared_to_queue(handle0, handle1, shared_id_list, shared_list)
    def update_p2p(self):
        if self.last_job is None:
    def update_p2p_mem(self):
        if self.next_wait_memory_job is None:
            return
        index, gather_id_list, mem, gather_memory, input_split, output_split, group, async_op = self.last_job
        self.last_job = None
        index, gather_id_list, mem, gather_memory, input_split, output_split, group, async_op = self.next_wait_memory_job
        self.next_wait_memory_job = None
        handle0 = torch.distributed.all_to_all_single(gather_id_list, index,
                                                      output_split_sizes=output_split,
                                                      input_split_sizes=input_split,
                                                      group=group, async_op=async_op)
        #print(input_split,output_split)
        handle1 = torch.distributed.all_to_all_single(gather_memory, mem,
                                                      output_split_sizes=output_split,
                                                      input_split_sizes=input_split,
                                                      group=group, async_op=async_op)
        self.last_memory_sync = (gather_id_list, handle0, gather_memory, handle1)
    def update_p2p_mail(self):
        if self.next_wait_mail_job is None:
            return
        index, gather_id_list, mem, gather_memory, input_split, output_split, group, async_op = self.next_wait_mail_job
        self.next_wait_mail_job = None
        handle0 = torch.distributed.all_to_all_single(gather_id_list, index,
                                                      output_split_sizes=output_split,
                                                      input_split_sizes=input_split,
                                                      group=group, async_op=async_op)
        handle1 = torch.distributed.all_to_all_single(gather_memory, mem,
                                                      output_split_sizes=output_split,
                                                      input_split_sizes=input_split,
                                                      group=group, async_op=async_op)
        self.last_mail_sync = (gather_id_list, handle0, gather_memory, handle1)
"""
submit: take the task request into queue wait for start
"""
    def build_all_to_all_route(self, index, mm, mm_ts, is_no_redundant):
        ctx = DistributedContext.get_default_context()
        gather_len_list = torch.empty([self.num_parts], dtype=int, device=self.device)
        ind = torch.ops.torch_sparse.ind2ptr(DistIndex(index).part, self.num_parts)
        scatter_len_list = ind[1:] - ind[0:-1]
        torch.distributed.all_to_all_single(gather_len_list, scatter_len_list, group=ctx.memory_nccl_group)
        input_split = scatter_len_list.tolist()
        output_split = gather_len_list.tolist()
        if is_no_redundant:
            gather_ts = torch.empty([gather_len_list.sum()], dtype=mm_ts.dtype, device=self.device)
            gather_id = torch.empty([gather_len_list.sum()], dtype=index.dtype, device=self.device)
            torch.distributed.all_to_all_single(gather_id, index,
                                                output_split_sizes=output_split,
                                                input_split_sizes=input_split,
                                                group=ctx.memory_nccl_group,)
            torch.distributed.all_to_all_single(gather_ts, mm_ts,
                                                output_split_sizes=output_split,
                                                input_split_sizes=input_split,
                                                group=ctx.memory_nccl_group,)
            unq_id, inv = gather_id.unique(return_inverse=True)
            max_ts, pos = torch_scatter.scatter_max(gather_ts, inv, dim=0)
            is_used = torch.zeros(gather_ts.shape, device=gather_ts.device, dtype=torch.int8)
            is_used[pos] = 1
            send_mm_to_dst = torch.zeros([scatter_len_list.sum().item()], device=is_used.device, dtype=torch.int8)
            torch.distributed.all_to_all_single(send_mm_to_dst, is_used,
                                                output_split_sizes=input_split,
                                                input_split_sizes=output_split,
                                                group=ctx.memory_nccl_group,)
            index = index[send_mm_to_dst > 0]
            mm = mm[send_mm_to_dst > 0]
            mm_ts = mm_ts[send_mm_to_dst > 0]
            gather_len_list = torch.empty([self.num_parts], dtype=int, device=self.device)
            ind = torch.ops.torch_sparse.ind2ptr(DistIndex(index).part, self.num_parts)
            scatter_len_list = ind[1:] - ind[0:-1]
            torch.distributed.all_to_all_single(gather_len_list, scatter_len_list, group=ctx.memory_nccl_group)
            input_split = scatter_len_list.tolist()
            output_split = gather_len_list.tolist()
            gather_id = gather_id[is_used > 0]
        else:
            gather_id = torch.empty([gather_len_list.sum()], dtype=index.dtype, device=self.device)
        gather_memory = torch.empty([gather_len_list.sum(), mm.shape[1]], dtype=mm.dtype, device=self.device)
        return index, gather_id, mm, gather_memory, input_split, output_split
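# --- editor's sketch (not part of this commit) ----------------------------------
# The is_no_redundant branch keeps, for every destination node id, only the row
# with the newest timestamp before the payload is shipped. The core step is
# scatter_max over the inverse indices returned by unique; a toy, locally
# runnable version (no distributed setup required):
#
# import torch
# import torch_scatter
# ids = torch.tensor([7, 3, 7, 3, 9])
# ts  = torch.tensor([1., 5., 4., 2., 3.])
# unq, inv = ids.unique(return_inverse=True)              # unq = [3, 7, 9]
# max_ts, pos = torch_scatter.scatter_max(ts, inv, dim=0)
# # max_ts = [5., 4., 3.]; pos = [1, 2, 4] -> the newest row per node id
# ---------------------------------------------------------------------------------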
    def set_memory_all_reduce(self, index, memory, memory_ts, mail, mail_ts,
                              reduce_Op=None, async_op=True, set_remote=False,
                              mode=None, filter=None, submit=True):
    def build_all_to_all_async_task(self, index, mm, mm_ts, is_no_redundant=False, is_async=True):
        ctx = DistributedContext.get_default_context()
        #print(DistIndex(index).part)
        if set_remote is False:
        p2p_async_info = self.build_all_to_all_route(index, mm, mm_ts, is_no_redundant)
        if mm.shape[1] != self.mailbox.shape[-1] + 1:
            self.next_wait_memory_job = (*p2p_async_info, ctx.memory_nccl_group, is_async)
        else:
            self.next_wait_mail_job = (*p2p_async_info, ctx.memory_nccl_group, is_async)
    def set_memory_all_reduce(self, index, memory, memory_ts, mail_index, mail, mail_ts,
                              reduce_Op=None, async_op=True, mode=None,
                              wait_submit=True, spread_mail=True,
                              update_cross_mm=False, is_no_redundant=True):
        if self.num_parts == 1:
            return
        if not spread_mail and not update_cross_mm:
            pass
            # self.set_mailbox_local(DistIndex(index).loc,mail,mail_ts,Reduce_Op = reduce_Op)
            # self.set_memory_local(DistIndex(index).loc,memory,memory_ts, Reduce_Op = reduce_Op)
        else:
            #print(index,memory,memory_ts)
            if async_op == True and self.num_parts > 1:
                #local_mask = (DistIndex(index).loc == dist.get_rank())
                self.set_mailbox_all_to_all(index, memory, memory_ts, mail, mail_ts,
                                            reduce_Op='max', async_op=async_op, submit=submit)
            if spread_mail:
                mm = torch.cat((mail, mail_ts.reshape(-1, 1)), dim=1)
                mm_ts = mail_ts
                self.build_all_to_all_async_task(mail_index, mm, mm_ts, is_async=True, is_no_redundant=False)
            if update_cross_mm:
                mm = torch.cat((memory, memory_ts.reshape(-1, 1)), dim=1)
                mm_ts = memory_ts
                self.build_all_to_all_async_task(index, mm, mm_ts, is_async=True, is_no_redundant=False)
            else:
                self.set_mailbox_all_to_all(index, memory, memory_ts, mail, mail_ts,
                                            reduce_Op='max', async_op=False)
                if update_cross_mm:
                    mm = self.pack(memory, memory_ts, mail, mail_ts, index)
                    mm_ts = memory_ts
                    self.build_all_to_all_async_task(index, mm, mm_ts, is_async=True, is_no_redundant=False)
        if async_op is False:
            self.update_p2p_mail()
            self.update_p2p_mem()
            self.handle_last_async()
        ctx = DistributedContext.get_default_context()
        if self.shared_nodes_index is not None and (mode == 'all_reduce' or mode == 'historical'):
            shared_memory_ind = self.is_shared_mask[torch.min(DistIndex(index).loc, torch.tensor([self.num_nodes-1], device=torch.device('cuda')))]
            shared_memory_ind = self.is_shared_mask[torch.min(DistIndex(index).loc, torch.tensor([self.num_nodes-1], device=index.device))]
            mask = ((shared_memory_ind > -1) & (DistIndex(index).part == ctx.memory_group_rank))
            shared_memory_ind = shared_memory_ind[mask]
            shared_memory = memory[mask]
            shared_memory_ts = memory_ts[mask]
            if spread_mail:
                shared_mail_indx = self.is_shared_mask[torch.min(DistIndex(mail_index).loc, torch.tensor([self.num_nodes-1], device=index.device))]
                mask = ((shared_mail_indx > -1) & (DistIndex(shared_mail_indx).part == ctx.memory_group_rank))
                shared_mail_indx = shared_mail_indx[mask]
            else:
                shared_mail_indx = shared_memory_ind
            shared_mail = mail[mask]
            shared_mail_ts = mail_ts[mask]
            if mode == 'historical':
                #print(shared_memory_ind)
                update_index = self.historical_cache.historical_check(shared_memory_ind, shared_memory, shared_memory_ts)
                #print(update_index.sum(),shared_memory_ind.shape)
                shared_memory_ind = shared_memory_ind[update_index]
                shared_memory = shared_memory[update_index]
                shared_memory_ts = shared_memory_ts[update_index]
                shared_mail = shared_mail[update_index]
                shared_mail_ts = shared_mail_ts[update_index]
                #print(shared_memory_ind)
                #mem = self.pack(memory=shared_memory,memory_ts=shared_memory_ts,mail=shared_mail,mail_ts=shared_mail_ts,index=shared_memory_ind,mode=mode)
                #mem = self.pack(memory=shared_mail,memory_ts=shared_mail_ts,index=shared_memory_ind,mode=mode)
                mem = self.pack(memory=shared_memory, memory_ts=shared_memory_ts, index=shared_memory_ind, mode=mode)
            else:
                mem = self.pack(memory=shared_memory, memory_ts=shared_memory_ts, mail=shared_mail, mail_ts=shared_mail_ts, index=shared_memory_ind, mode=mode)
            if not spread_mail:
                mem = self.pack(memory=shared_memory, memory_ts=shared_memory_ts, mail=shared_mail, mail_ts=shared_mail_ts, index=shared_memory_ind, mode=mode)
            else:
                mem = self.pack(memory=shared_memory, memory_ts=shared_memory_ts, index=shared_memory_ind, mode=mode)
            self.tot_shared_count += shared_memory_ind.shape[0]
            broadcast_len = torch.empty([1], device=mem.device, dtype=torch.int)
            broadcast_len[0] = shared_memory_ind.shape[0]
...
...
@@ -331,54 +445,61 @@ class SharedMailBox():
            shared_list = [torch.empty([l.item(), mem.shape[1]], device=mem.device, dtype=mem.dtype) for l in shared_len]
            shared_id_list = [torch.empty([l.item()], device=shared_memory_ind.device, dtype=shared_memory_ind.dtype) for l in shared_len]
            if mode == 'all_reduce':
                start = torch.cuda.Event(enable_timing=True)
                end = torch.cuda.Event(enable_timing=True)
                start.record()
                if async_op == True:
                    handle0 = dist.all_gather(shared_list, mem, group=ctx.memory_nccl_group)
                    handle1 = dist.all_gather(shared_id_list, shared_memory_ind, group=ctx.memory_nccl_group)
                    self.last_memory_sync = shared_id_list, handle0, shared_id_list, handle1
                dist.all_gather(shared_list, mem, group=ctx.memory_nccl_group)
                dist.all_gather(shared_id_list, shared_memory_ind, group=ctx.memory_nccl_group)
                mem = torch.cat(shared_list, dim=0)
                shared_index = torch.cat(shared_id_list)
                if spread_mail:
                    shared_memory, shared_memory_ts = self.unpack(mem)
                    unq_index, inv = torch.unique(shared_index, return_inverse=True)
                    max_ts, idx = torch_scatter.scatter_max(shared_memory_ts, inv, 0)
                    shared_memory = shared_memory[idx]
                    shared_memory_ts = shared_memory_ts[idx]
                    shared_index = unq_index
                    self.historical_cache.local_historical_data[shared_index] = shared_memory
                    self.historical_cache.local_ts[shared_index] = shared_memory_ts
                    broadcast_len = torch.empty([1], device=mem.device, dtype=torch.int)
                    broadcast_len[0] = shared_mail_indx.shape[0]
                    shared_len = [torch.empty([1], device=mail.device, dtype=torch.int) for _ in range(ctx.memory_group_size)]
                    dist.all_gather(shared_len, broadcast_len, group=ctx.memory_nccl_group)
                    mail = self.pack(memory=shared_mail, memory_ts=shared_mail_ts, index=shared_mail_indx, mode=mode)
                    shared_mail_list = [torch.empty([l.item(), mail.shape[1]], device=mail.device, dtype=mail.dtype) for l in shared_len]
                    shared_mail_id_list = [torch.empty([l.item()], device=shared_mail_indx.device, dtype=shared_mail_indx.dtype) for l in shared_len]
                    #print(mail.shape)
                    dist.all_gather(shared_mail_list, mail, group=ctx.memory_nccl_group)
                    dist.all_gather(shared_mail_id_list, shared_mail_indx, group=ctx.memory_nccl_group)
                    shared_mail_indx = torch.cat(shared_mail_id_list, dim=0)
                    mail = torch.cat(shared_mail_list, dim=0)
                    shared_mail, shared_mail_ts = self.unpack(mail)
                    unq_index, inv = torch.unique(shared_mail_indx, return_inverse=True)
                    max_ts, idx = torch_scatter.scatter_max(shared_mail_ts, inv, 0)
                    shared_mail = shared_mail[idx]
                    shared_mail_ts = shared_mail_ts[idx]
                    shared_mail_indx = unq_index
                else:
                    dist.all_gather(shared_list, mem, group=ctx.memory_nccl_group)
                    dist.all_gather(shared_id_list, shared_memory_ind, group=ctx.memory_nccl_group)
                    mem = torch.cat(shared_list, dim=0)
                    shared_index = torch.cat(shared_id_list)
                #id = shared_index.sort()
                #print(mem[id],shared_index[id])
                #,shared_memory,shared_memory_ts,
                #shared_memory,shared_memory_ts = self.unpack(mem)
                    shared_memory, shared_memory_ts, shared_mail, shared_mail_ts = self.unpack(mem)
                    #print(shared_memory_ts,shared_mail_ts)
                    unq_index, inv = torch.unique(shared_index, return_inverse=True)
                    #print(inv.shape,Reduce_score.shape)
                    max_ts, idx = torch_scatter.scatter_max(shared_memory_ts, inv, 0)
                    #min_ts,_ = torch_scatter.scatter_min(shared_mail_ts,inv,0)
                    shared_memory = shared_memory[idx]
                    shared_memory = shared_memory[idx]
                    shared_memory_ts = shared_memory_ts[idx]
                    shared_mail_ts = shared_mail_ts[idx]
                    shared_mail = shared_mail[idx]
                    #shared_mail_ts = torch_scatter.scatter_mean(shared_mail_ts,inv,0)
                    #shared_mail = torch_scatter.scatter_mean(shared_mail,inv,0)
                    shared_index = unq_index
                    self.set_memory_local(self.shared_nodes_index[shared_index], shared_memory, shared_memory_ts)
                    self.historical_cache.local_historical_data[shared_index] = shared_memory
                    self.historical_cache.local_ts[shared_index] = shared_memory_ts
                    self.set_mailbox_local(self.shared_nodes_index[shared_index], shared_mail, shared_mail_ts)
                    shared_mail_indx = unq_index
                self.set_memory_local(self.shared_nodes_index[shared_index], shared_memory, shared_memory_ts)
                self.historical_cache.local_historical_data[shared_index] = shared_memory
                self.historical_cache.local_ts[shared_index] = shared_memory_ts
                self.set_mailbox_local(self.shared_nodes_index[shared_mail_indx], shared_mail, shared_mail_ts)
            else:
                start = torch.cuda.Event(enable_timing=True)
                end = torch.cuda.Event(enable_timing=True)
                start.record()
                #self.historical_cache.synchronize_shared_update(filter)
                #mem[:,:-1] = mem[:,:-1] + filter.get_incretment(shared_memory_ind)
                #shared_list = [torch.empty([l.item(),mem.shape[1]],device = mem.device,dtype=mem.dtype) for l in shared_len]
                #handle0 = dist.all_gather(shared_list,mem,group=ctx.memory_nccl_group,async_op=True)
                #shared_id_list = [torch.empty([l.item()],device = shared_memory_ind.device,dtype=shared_memory_ind.dtype) for l in shared_len]
                #handle1 = dist.all_gather(shared_id_list,shared_memory_ind,group=ctx.memory_nccl_group,async_op=True)
                self.last_job = (shared_list, mem, shared_id_list, shared_memory_ind)
                if ~submit:
                    self.update_shared()
                self.next_wait_gather_memory_job = (shared_list, mem, shared_id_list, shared_memory_ind)
                if not wait_submit:
                    self.update_shared()
                    self.update_p2p_mail()
                    self.update_p2p_mem()
                #self.historical_cache.add_shared_to_queue(handle0,handle1,shared_id_list,shared_list)
        """
        shared_memory = self.node_memory.accessor.data[self.shared_nodes_index]
...
...
@@ -399,107 +520,54 @@ class SharedMailBox():
    def set_mailbox_all_to_all_empty(self, index, memory, memory_ts,
                                     mail, mail_ts, reduce_Op=None, group=None):
        if self.num_parts == 1:
            dist_index = DistIndex(index)
            part_idx = dist_index.part
            index = dist_index.loc
            self.set_mailbox_local(index, mail, mail_ts)
            self.set_memory_local(index, memory, memory_ts)
        else:
            gather_len_list = torch.empty([self.num_parts], dtype=int, device=self.device)
            #indic = torch.searchsorted(index,self.partptr,right=False)
            indic = torch.ops.torch_sparse.ind2ptr(DistIndex(index).part, self.num_parts)
            scatter_len_list = indic[1:] - indic[0:-1]
            torch.distributed.all_to_all_single(gather_len_list, scatter_len_list, group=group)
            input_split = scatter_len_list.tolist()
            output_split = gather_len_list.tolist()
            gather_id_list = torch.empty([gather_len_list.sum()], dtype=torch.long, device=self.device)
            input_split = scatter_len_list.tolist()
            output_split = gather_len_list.tolist()
            torch.distributed.all_to_all_single(gather_id_list, index,
                                                output_split_sizes=output_split,
                                                input_split_sizes=input_split, group=group)
            index = gather_id_list
            gather_memory = torch.empty([gather_len_list.sum(), memory.shape[1]],
                                        dtype=memory.dtype, device=self.device)
            gather_memory_ts = torch.empty([gather_len_list.sum()],
                                           dtype=memory_ts.dtype, device=self.device)
            gather_mail = torch.empty([gather_len_list.sum(), mail.shape[1]],
                                      dtype=mail.dtype, device=self.device)
            gather_mail_ts = torch.empty([gather_len_list.sum()],
                                         dtype=mail_ts.dtype, device=self.device)
            torch.distributed.all_to_all_single(gather_memory, memory,
                                                output_split_sizes=output_split,
                                                input_split_sizes=input_split,
                                                group=group, async_op=True)
            torch.distributed.all_to_all_single(gather_memory_ts, memory_ts,
                                                output_split_sizes=output_split,
                                                input_split_sizes=input_split, group=group)
            torch.distributed.all_to_all_single(gather_mail, mail,
                                                output_split_sizes=output_split,
                                                input_split_sizes=input_split, group=group)
            torch.distributed.all_to_all_single(gather_mail_ts, mail_ts,
                                                output_split_sizes=output_split,
                                                input_split_sizes=input_split, group=group)
        pass
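# Editor's note: unlike set_mailbox_all_to_all below, this "empty" variant ships
# memory, memory_ts, mail and mail_ts as four separate all_to_all_single calls
# instead of one packed blob, and it never writes the gathered tensors back, so
# it presumably exists for timing or debugging the communication volume only
# (an inference from the code, not stated in the commit message).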
    def set_mailbox_all_to_all(self, index, memory, memory_ts,
                               mail, mail_ts, reduce_Op=None, group=None,
                               async_op=False, submit=True):
        #futs: List[torch.futures.Future] = []
        if self.num_parts == 1:
            dist_index = DistIndex(index)
            part_idx = dist_index.part
            index = dist_index.loc
            self.set_mailbox_local(index, mail, mail_ts)
            self.set_memory_local(index, memory, memory_ts)
        else:
            self.tot_comm_count += (DistIndex(index).part != dist.get_rank()).sum()
            gather_len_list = torch.empty([self.num_parts], dtype=int, device=self.device)
            indic = torch.ops.torch_sparse.ind2ptr(DistIndex(index).part, self.num_parts)
            scatter_len_list = indic[1:] - indic[0:-1]
            torch.distributed.all_to_all_single(gather_len_list, scatter_len_list, group=group)
            input_split = scatter_len_list.tolist()
            output_split = gather_len_list.tolist()
            mem = self.pack(memory, memory_ts, mail, mail_ts)
            gather_memory = torch.empty([gather_len_list.sum(), mem.shape[1]],
                                        dtype=memory.dtype, device=self.device)
            gather_id_list = torch.empty([gather_len_list.sum()], dtype=torch.long, device=self.device)
            input_split = scatter_len_list.tolist()
            output_split = gather_len_list.tolist()
            if async_op == True:
                self.last_job = index, gather_id_list, mem, gather_memory, input_split, output_split, group, async_op
                if ~submit:
                    self.update_p2p()
            else:
                torch.distributed.all_to_all_single(gather_id_list, index,
                                                    output_split_sizes=output_split,
                                                    input_split_sizes=input_split,
                                                    group=group, async_op=async_op)
                torch.distributed.all_to_all_single(gather_memory, mem,
                                                    output_split_sizes=output_split,
                                                    input_split_sizes=input_split, group=group)
                if gather_memory.shape[1] > self.node_memory.shape[1] + 1:
                    gather_memory, gather_memory_ts, gather_mail, gather_mail_ts = self.unpack(gather_memory)
                    self.set_mailbox_local(DistIndex(gather_id_list).loc, gather_mail, gather_mail_ts, Reduce_Op=reduce_Op)
                else:
                    gather_memory, gather_memory_ts = self.unpack(gather_memory)
                self.set_memory_local(DistIndex(gather_id_list).loc, gather_memory, gather_memory_ts, Reduce_Op=reduce_Op)
        pass
        # #futs: List[torch.futures.Future] = []
        # if self.num_parts == 1:
        #     dist_index = DistIndex(index)
        #     index = dist_index.loc
        #     self.set_mailbox_local(index,mail,mail_ts)
        #     self.set_memory_local(index,memory,memory_ts)
        # else:
        #     self.tot_comm_count += (DistIndex(index).part != dist.get_rank()).sum()
        #     gather_len_list = torch.empty([self.num_parts],
        #                                   dtype = int,
        #                                   device = self.device)
        #     indic = torch.ops.torch_sparse.ind2ptr(DistIndex(index).part, self.num_parts)
        #     scatter_len_list = indic[1:] - indic[0:-1]
        #     torch.distributed.all_to_all_single(gather_len_list,scatter_len_list,group = group)
        #     input_split = scatter_len_list.tolist()
        #     output_split = gather_len_list.tolist()
        #     mem = self.pack(memory,memory_ts,mail,mail_ts)
        #     gather_memory = torch.empty(
        #         [gather_len_list.sum(),mem.shape[1]],
        #         dtype = memory.dtype,device = self.device)
        #     gather_id_list = torch.empty([gather_len_list.sum()],dtype = torch.long,device = self.device)
        #     input_split = scatter_len_list.tolist()
        #     output_split = gather_len_list.tolist()
        #     if async_op == True:
        #         self.last_job = index,gather_id_list,mem,gather_memory,input_split,output_split,group,async_op
        #         if not submit:
        #             self.update_p2p()
        #     else:
        #         torch.distributed.all_to_all_single(
        #             gather_id_list,index,output_split_sizes=output_split,
        #             input_split_sizes=input_split,group = group,async_op=async_op)
        #         torch.distributed.all_to_all_single(
        #             gather_memory,mem,
        #             output_split_sizes=output_split,
        #             input_split_sizes=input_split,group = group)
        #         if gather_memory.shape[1] > self.node_memory.shape[1] + 1:
        #             gather_memory,gather_memory_ts,gather_mail,gather_mail_ts = self.unpack(gather_memory)
        #             self.set_mailbox_local(DistIndex(gather_id_list).loc,gather_mail,gather_mail_ts,Reduce_Op = reduce_Op)
        #         else:
        #             gather_memory,gather_memory_ts = self.unpack(gather_memory)
        #         self.set_memory_local(DistIndex(gather_id_list).loc,gather_memory,gather_memory_ts, Reduce_Op = reduce_Op)
    def gather_memory(self,
...
...
starrygl/sample/sample_core/neighbor_sampler.py
View file @
99e5b95a
...
...
@@ -97,6 +97,7 @@ class NeighborSampler(BaseSampler):
                 node_part = None,
                 edge_part = None,
                 probability = 1,
                 no_neg = False,
                 ) -> None:
        r"""__init__
        Args:
...
...
@@ -122,7 +123,7 @@ class NeighborSampler(BaseSampler):
        self.is_distinct = is_distinct
        assert graph_name is not None
        self.graph_name = graph_name
        self.no_neg = no_neg
        if(tnb is None):
            if(graph_data.edge_ts is not None):
                timestamp, ind = graph_data.edge_ts.sort()
...
...
@@ -314,7 +315,10 @@ class NeighborSampler(BaseSampler):
        else:
            seed, inverse_seed = seed.unique(return_inverse=True)
        """
        out, metadata = self.sample_from_nodes(seed, seed_ts, is_unique=False)
        if self.no_neg:
            out, metadata = self.sample_from_nodes(seed[:seed.shape[0]//3*2], seed_ts[:seed.shape[0]//3*2], is_unique=False)
        else:
            out, metadata = self.sample_from_nodes(seed[:seed.shape[0]], seed_ts[:seed.shape[0]//3*2], is_unique=False)
        src_pos_index = torch.arange(0, num_pos, dtype=torch.long, device=out_device)
        dst_pos_index = torch.arange(num_pos, 2*num_pos, dtype=torch.long, device=out_device)
        if neg_sampling.is_triplet() or neg_sampling.is_tgbtriplet():
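# Editor's note: the src_pos_index / dst_pos_index definitions above suggest the
# seed layout is [positive sources | positive destinations | negatives], so with
# no_neg enabled the sampler only expands the first seed.shape[0] // 3 * 2 seeds,
# i.e. it appears to skip neighbor sampling for the negative destinations
# (a reading of this diff, not confirmed elsewhere).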
...
...