zhlj / BTS-MTGNN / Commits
Commit 8adc9ed9, authored Oct 05, 2024 by zlj
fix smooth aggregation
parent 27dbfee5
Showing 11 changed files with 41 additions and 35 deletions
config/TGN.yml                                          +3  -3
config/TGN_large.yml                                    +1  -1
csrc/sampler/include/sampler.h                          +8  -5
starrygl/module/Filter.py                               +1  -1
starrygl/module/historical_cache.py                     +5  -5
starrygl/module/layers.py                               +2  -2
starrygl/module/memorys.py                              +2  -2
starrygl/sample/count_static.py                         +9  -8
starrygl/sample/memory/shared_mailbox.py                +4  -2
starrygl/sample/part_utils/transformer_from_speed.py    +4  -4
starrygl/sample/sample_core/LocalNegSampling.py         +2  -2
config/TGN.yml

 sampling:
   - layer: 1
     neighbor:
-      - 20
+      - 10
     strategy: 'recent'
     prop_time: False
     history: 1
...
@@ -27,10 +27,10 @@ gnn:
   dim_time: 100
   dim_out: 100
 train:
-  - epoch: 200
+  - epoch: 100
     batch_size: 1000
     # reorder: 16
-    lr: 0.0004
+    lr: 0.0008
     dropout: 0.2
     att_dropout: 0.2
     all_on_gpu: True
config/TGN_large.yml

@@ -30,7 +30,7 @@ train:
   - epoch: 50
     batch_size: 3000
     # reorder: 16
-    lr: 0.0005
+    lr: 0.0008
     dropout: 0.2
     att_dropout: 0.2
     all_on_gpu: True
csrc/sampler/include/sampler.h

@@ -300,13 +300,16 @@ void ParallelSampler :: neighbor_sample_from_nodes_with_before_layer(
             cal_cnt = 0;
             for(int cid = end_index - 1; cid >= 0; cid--){
                 cal_cnt++;
-                if(cal_cnt > fanout)break;
+                // if(cal_cnt > fanout)break;
                 int eid = tnb.eid[node][cid];
-                if(part[tnb.eid[node][cid]] != local_part || node_part[tnb.neighbors[node][cid]] != local_part){
-                    double p0 = (double)rand_r(&loc_seeds[tid]) / (RAND_MAX + 1.0);
-                    double ep = boundery_probility * pr[cal_cnt - 1] / sum_p * sum_1;
-                    if(p0 > ep) continue;
+                if((part[tnb.eid[node][cid]] != local_part || node_part[tnb.neighbors[node][cid]] != local_part)){
+                    if(cal_cnt <= fanout){
+                        double p0 = (double)rand_r(&loc_seeds[tid]) / (RAND_MAX + 1.0);
+                        double ep = boundery_probility * pr[cal_cnt - 1] / sum_p * sum_1;
+                        if(p0 > ep) continue;
+                    }
+                    else continue;
                     //cout<<"in"<<endl;
                 }
                 tgb_i[tid].src_index.emplace_back(i);
...
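
As the new control flow reads, the hard break after fanout candidates is gone: the loop now scans the whole recent-neighbor window, keeps local candidates unconditionally, drops cross-partition candidates outside the first fanout positions, and keeps in-window cross-partition candidates with the smoothed probability ep. A minimal Python sketch of that per-candidate acceptance rule; pr, sum_p, sum_1, boundery_probility and fanout mirror the C++ names, and is_remote stands in for the partition test on the edge and neighbor:

    import random

    def keep_candidate(cal_cnt, is_remote, fanout, boundery_probility, pr, sum_p, sum_1, rng=random):
        """Acceptance rule for one candidate; cal_cnt is the 1-based position, newest first."""
        if not is_remote:
            return True                  # local candidates are always kept
        if cal_cnt > fanout:
            return False                 # remote candidates past the fanout window are dropped
        # remote candidates inside the window survive with the smoothed boundary probability
        ep = boundery_probility * pr[cal_cnt - 1] / sum_p * sum_1
        p0 = rng.random()                # plays the role of rand_r(...) / (RAND_MAX + 1.0)
        return p0 <= ep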
starrygl/module/Filter.py

@@ -31,7 +31,7 @@ class Filter(nn.Module):
     def get_incretment(self, node_idxs):
         #print(self.incretment[node_idxs,:].shape,self.count[node_idxs].shape)
-        return self.incretment[node_idxs,:]/torch.clamp(self.count[node_idxs,:],1)
+        return self.incretment[node_idxs,:]#/torch.clamp(self.count[node_idxs,:],1)
     def get_incretment_remote(self, idx):
         remote_tensor = DistributedTensor(self.incretment)
...
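
The net effect of this one-line change is that get_incretment now hands back the raw accumulated increment instead of the count-normalized average. A small self-contained sketch (an illustrative stand-in, not the project's Filter class) that contrasts the two reads:

    import torch

    class IncrementBuffer:
        """Stand-in for the Filter increment buffer, showing the old vs. new read."""
        def __init__(self, num_nodes, dim):
            self.incretment = torch.zeros(num_nodes, dim)
            self.count = torch.zeros(num_nodes, 1)

        def update(self, node_idxs, delta):
            self.incretment[node_idxs] += delta
            self.count[node_idxs] += 1

        def get_incretment_averaged(self, node_idxs):
            # previous behaviour: mean increment, clamped so untouched nodes divide by 1
            return self.incretment[node_idxs, :] / torch.clamp(self.count[node_idxs, :], 1)

        def get_incretment_raw(self, node_idxs):
            # behaviour after this commit: the accumulated increment is returned as-is
            return self.incretment[node_idxs, :]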
starrygl/module/historical_cache.py

@@ -139,8 +139,8 @@ class HistoricalCache:
         if(shared_data.shape[0] == 0):
             return None
         len = self.local_historical_data.shape[1]
-        mail_ts = shared_data[:,-1]
+        # mail_ts = shared_data[:,-1]
-        mail_data = shared_data[:,len+1:-1]
+        # mail_data = shared_data[:,len+1:-1]
         shared_ts = shared_data[:,len]
         shared_mem = shared_data[:,:len]
         #print(shared_index)
...
@@ -150,8 +150,8 @@ class HistoricalCache:
         #shared_data = torch_scatter.scatter_mean(shared_data,inv,0)
         shared_mem = shared_mem[idx]
         shared_ts = shared_ts[idx]
-        mail_data = mail_data[idx]
+        # mail_data = mail_data[idx]
-        mail_ts = mail_ts[idx]
+        # mail_ts = mail_ts[idx]
         shared_index = unq_index
         #print('{} {} {}\n'.format(shared_index,shared_data,shared_ts))
         # if filter is not None:
...
@@ -164,7 +164,7 @@ class HistoricalCache:
         self.local_historical_data[shared_index] = shared_mem
         self.local_ts[shared_index] = shared_ts
         self.last_shared_update_wait = None
-        return shared_index,shared_mem,shared_ts,mail_data,mail_ts
+        return shared_index,shared_mem,shared_ts#,mail_data,mail_ts
...
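
The indexing above implies a fixed row layout for shared_data: the first len columns hold the memory vector, column len holds its timestamp, and the trailing columns used to carry the mail payload and its timestamp. After this commit only the memory part and its timestamp are consumed and returned. A short sketch of that unpacking, with mem_dim standing in for len = self.local_historical_data.shape[1]:

    import torch

    def unpack_shared_rows(shared_data, mem_dim):
        """Column layout implied by the indexing in HistoricalCache (sketch only)."""
        shared_mem = shared_data[:, :mem_dim]   # columns [0, mem_dim): memory vectors
        shared_ts = shared_data[:, mem_dim]     # column mem_dim: memory timestamp
        # previously also read and returned:
        #   mail_data = shared_data[:, mem_dim + 1:-1]
        #   mail_ts   = shared_data[:, -1]
        return shared_mem, shared_ts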
starrygl/module/layers.py

@@ -288,8 +288,8 @@ class TransfomerAttentionLayer(torch.nn.Module):
             #att = dgl.ops.e_div_v(b,att_e_sub_max,torch.clamp_min(dgl.ops.copy_e_sum(b,att_e_sub_max),1))
             att = dgl.ops.edge_softmax(b,self.att_act(torch.sum(Q*K,dim=2)))
             att = self.att_dropout(att)
-            tt.weight_count_remote+=torch.sum(att[DistIndex(b.srcdata['ID']).part[b.edges()[0]]!=torch.distributed.get_rank()]**2)
+            # tt.weight_count_remote+=torch.sum(att[DistIndex(b.srcdata['ID']).part[b.edges()[0]]!=torch.distributed.get_rank()]**2)
-            tt.weight_count_local+=torch.sum(att[DistIndex(b.srcdata['ID']).part[b.edges()[0]]==torch.distributed.get_rank()]**2)
+            # tt.weight_count_local+=torch.sum(att[DistIndex(b.srcdata['ID']).part[b.edges()[0]]==torch.distributed.get_rank()]**2)
             V = torch.reshape(V*att[:,:,None],(V.shape[0],-1))
             V_local = V.clone()
             V_remote = V.clone()
...
starrygl/module/memorys.py

@@ -445,7 +445,7 @@ class AsyncMemeoryUpdater(torch.nn.Module):
                 transition_dense *= 2
                 if not (transition_dense.max().item() == 0):
                     transition_dense -= transition_dense.min()
-                    transition_dense /= transition_dense.max()
+                    transition_dense /= torch.clamp(transition_dense.max(),1)
                     transition_dense = 2*transition_dense - 1
                 upd0[mask] = b.srcdata['his_mem'][mask] + transition_dense
             else:
...
@@ -456,7 +456,7 @@ class AsyncMemeoryUpdater(torch.nn.Module):
             updated_memory = torch.where(mask.unsqueeze(1),self.gamma*updated_memory0 + (1-self.gamma)*(upd0),updated_memory0)
             with torch.no_grad():
                 if self.mode == 'historical':
-                    change = upd0[mask] - b.srcdata['his_mem'][mask]
+                    change = updated_memory[mask] - b.srcdata['his_mem'][mask]
                     change.detach()
                     if not (change.max().item() == 0):
                         change -= change.min()
...
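
Two things change here. First, the rescale of transition_dense clamps the divisor to at least 1, so a tensor whose entries are all equal after subtracting the minimum no longer divides by zero. Second, in historical mode the drift is measured against the gamma-blended updated_memory rather than against upd0. A compact sketch of both, with plain tensors in place of the module's buffers:

    import torch

    def normalize_transition(transition_dense):
        # rescale toward [-1, 1]; clamping the max keeps the division defined when the spread is zero
        if transition_dense.max().item() != 0:
            transition_dense = transition_dense - transition_dense.min()
            transition_dense = transition_dense / torch.clamp(transition_dense.max(), 1)
            transition_dense = 2 * transition_dense - 1
        return transition_dense

    def blend_and_measure(updated_memory0, upd0, his_mem, mask, gamma):
        # gamma-weighted smooth aggregation of the fresh memory and the historical correction
        updated_memory = torch.where(mask.unsqueeze(1),
                                     gamma * updated_memory0 + (1 - gamma) * upd0,
                                     updated_memory0)
        # after this commit the drift is taken from the blended result, not from upd0
        change = updated_memory[mask] - his_mem[mask]
        return updated_memory, change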
starrygl/sample/count_static.py

@@ -42,14 +42,15 @@ class time_count:
         return time.perf_counter(),0
     @staticmethod
     def elapsed_event(start_event):
-        if isinstance(start_event,tuple):
-            start_event,end_event = start_event
-            end_event.record()
-            end_event.synchronize()
-            return start_event.elapsed_time(end_event)
-        else:
-            torch.cuda.synchronize()
-            return time.perf_counter() - start_event
+        #if isinstance(start_event,tuple):
+        #    start_event,end_event = start_event
+        #    end_event.record()
+        #    end_event.synchronize()
+        #    return start_event.elapsed_time(end_event)
+        #else:
+        #    torch.cuda.synchronize()
+        #    return time.perf_counter() - start_event
+        return 0
     @staticmethod
     def print():
         print('time_count.time_forward={} time_count.time_backward={} time_count.time_memory_updater={} time_count.time_embedding={} time_count.time_local_update={} time_count.time_memory_sync={} time_count.time_sample_and_build={} time_count.time_memory_fetch={}\n'.format(
...
starrygl/sample/memory/shared_mailbox.py

@@ -253,7 +253,7 @@ class SharedMailBox():
     def sychronize_shared(self):
         out = self.historical_cache.synchronize_shared_update()
         if out is not None:
-            shared_index,shared_data,shared_ts,mail,mail_ts = out
+            shared_index,shared_data,shared_ts = out
             index = self.shared_nodes_index[shared_index]
             mask = (shared_ts > self.node_memory_ts.accessor.data[index])
             self.node_memory.accessor.data[index][mask] = shared_data[mask]
...
@@ -320,8 +320,10 @@ class SharedMailBox():
             #mem = self.pack(memory=shared_memory,memory_ts=shared_memory_ts,mail=shared_mail,mail_ts=shared_mail_ts,index=shared_memory_ind,mode=mode)
             #mem = self.pack(memory=shared_mail,memory_ts=shared_mail_ts,index=shared_memory_ind,mode=mode)
+            mem = self.pack(memory=shared_memory,memory_ts=shared_memory_ts,index=shared_memory_ind,mode=mode)
+        else:
+            mem = self.pack(memory=shared_memory,memory_ts=shared_memory_ts,mail=shared_mail,mail_ts=shared_mail_ts,index=shared_memory_ind,mode=mode)
         self.tot_shared_count += shared_memory_ind.shape[0]
-        mem = self.pack(memory=shared_memory,memory_ts=shared_memory_ts,mail=shared_mail,mail_ts=shared_mail_ts,index=shared_memory_ind,mode=mode)
         broadcast_len = torch.empty([1],device=mem.device,dtype=torch.int)
         broadcast_len[0] = shared_memory_ind.shape[0]
         shared_len = [torch.empty([1],device=mem.device,dtype=torch.int) for _ in range(ctx.memory_group_size)]
...
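
On the consumer side, sychronize_shared now unpacks the three-element tuple returned by HistoricalCache.synchronize_shared_update and writes back only memory states that are strictly newer than what the local replica holds. A simplified sketch with plain tensors in place of the accessor wrappers; the combined-index write is an assumption made for the sketch, not a copy of the surrounding code:

    import torch

    def apply_shared_update(node_memory, node_memory_ts, shared_nodes_index, out):
        """Sketch of the updated consumer logic in SharedMailBox.sychronize_shared."""
        if out is None:
            return
        shared_index, shared_data, shared_ts = out   # mail/mail_ts are no longer part of the tuple
        index = shared_nodes_index[shared_index]     # cache slot -> global node id
        mask = shared_ts > node_memory_ts[index]     # accept only strictly newer shared states
        node_memory[index[mask]] = shared_data[mask]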
starrygl/sample/part_utils/transformer_from_speed.py

@@ -286,10 +286,10 @@ def load_from_speed(data,seed,top,sampler_graph_add_rev,device=torch.device('cud
         return load_from_shared_node_partition(data,None,None,sample_add_rev=sampler_graph_add_rev,device=device,feature_device=feature_device)
     else:
         if partition == 'ours':
-            fnode_i = '../../SPEED/partition/divided_nodes_seed_t2/{}/{}/{}_{}parts_top{}/output{}.txt'.format(data,seed,data,ctx.memory_group_size,top,ctx.memory_group_rank)
+            fnode_i = '../../SPEED/partition/divided_nodes_seed_starrygl/{}/{}/{}_{}parts_top{}/output{}.txt'.format(data,seed,data,ctx.memory_group_size,top,ctx.memory_group_rank)
-            fnode_share = '../../SPEED/partition/divided_nodes_seed_t2/{}/{}/{}_{}parts_top{}/outputshared.txt'.format(data,seed,data,ctx.memory_group_size,top)
+            fnode_share = '../../SPEED/partition/divided_nodes_seed_starrygl/{}/{}/{}_{}parts_top{}/outputshared.txt'.format(data,seed,data,ctx.memory_group_size,top)
-            reorder = '../../SPEED/partition/divided_nodes_seed_t2/{}/reorder.txt'.format(data)
+            reorder = '../../SPEED/partition/divided_nodes_seed_starrygl/{}/reorder.txt'.format(data)
-            edge_i = '../../SPEED/partition/divided_nodes_seed_t2/{}/{}/{}_{}parts_top{}/edge_output{}.txt'.format(data,seed,data,ctx.memory_group_size,top,ctx.memory_group_rank)
+            edge_i = '../../SPEED/partition/divided_nodes_seed_starrygl/{}/{}/{}_{}parts_top{}/edge_output{}.txt'.format(data,seed,data,ctx.memory_group_size,top,ctx.memory_group_rank)
         elif partition == 'metis':
             fnode_i = '../../SPEED/partition/divided_nodes_metis_test/{}/{}/{}_{}parts_top{}/output{}.txt'.format(data,seed,data,ctx.memory_group_size,top,ctx.memory_group_rank)
             fnode_share = '../../SPEED/partition/divided_nodes_metis_test/{}/{}/{}_{}parts_top{}/outputshared.txt'.format(data,seed,data,ctx.memory_group_size,top)
...
starrygl/sample/sample_core/LocalNegSampling.py

@@ -63,8 +63,8 @@ class LocalNegativeSampling(NegativeSampling):
             p = torch.rand(size=(num_samples,))
             sr = self.dst_node_list[torch.randint(len(self.dst_node_list),(num_samples,),generator=self.rdm)]
             sl = self.local_dst[torch.randint(len(self.local_dst),(num_samples,),generator=self.rdm)]
-            s = torch.where(p<=self.prob,sl,sr)
+            s = torch.where(p<=self.prob,sr,sl)
-            return sr
+            return s
         else:
             s = torch.randint(len(self.dst_node_list),(num_samples,),generator=self.rdm)
             return self.dst_node_list[s]
...
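
As the updated lines read, each negative now comes from the full destination list when p <= self.prob and from the partition-local destinations otherwise, and the mixed tensor is what gets returned (previously the branches were swapped and sr was returned unconditionally). A standalone sketch of the corrected rule, with names mirroring the class attributes and rdm standing in for self.rdm:

    import torch

    def sample_mixed_negatives(dst_node_list, local_dst, num_samples, prob, rdm=None):
        """Sketch of the corrected mixing rule in LocalNegativeSampling.sample."""
        p = torch.rand(size=(num_samples,))
        sr = dst_node_list[torch.randint(len(dst_node_list), (num_samples,), generator=rdm)]
        sl = local_dst[torch.randint(len(local_dst), (num_samples,), generator=rdm)]
        # with probability prob take the global draw, otherwise the partition-local one
        return torch.where(p <= prob, sr, sl)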