Commit b9ca4758 by zhlj

Do time breakdown without pipelining

parent a6fb7bc7
@@ -2,24 +2,26 @@
 # Ran TaoBao on 4 GPUs
 # Define array variables
 seed=$1
-addr="192.168.1.107"
+addr="192.168.1.105"
 partition_params=("ours")
 #"metis" "ldg" "random")
 #("ours" "metis" "ldg" "random")
-partitions="16"
+partitions="8"
 node_per="4"
-nnodes="4"
+nnodes="2"
-node_rank="3"
+node_rank="0"
-probability_params=("0.01")
+probability_params=("0.1" )
 sample_type_params=("boundery_recent_decay")
 #sample_type_params=("recent" "boundery_recent_decay") #"boundery_recent_uniform")
 #memory_type=("all_update" "p2p" "all_reduce" "historical" "local")
-memory_type=("historical")
+memory_type=("all_update" "historical" "local")
#"historical")
#memory_type=("local" "all_update" "historical" "all_reduce")
shared_memory_ssim=("0.3" "0.7")
#"historical") #"historical")
#memory_type=("local" "all_update" "historical" "all_reduce") #memory_type=("local" "all_update" "historical" "all_reduce")
shared_memory_ssim=("0.3")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk") #data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk")
data_param=("WikiTalk" "GDELT" "StackOverflow") data_param=("GDELT")
#"GDELT") #"GDELT")
#data_param=("WIKI" "REDDIT" "LASTFM" "DGraphFin" "WikiTalk" "StackOverflow") #data_param=("WIKI" "REDDIT" "LASTFM" "DGraphFin" "WikiTalk" "StackOverflow")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk" "StackOverflow") #data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk" "StackOverflow")
@@ -32,9 +34,9 @@ data_param=("WikiTalk" "GDELT" "StackOverflow")
 #seed=(( RANDOM % 1000000 + 1 ))
 mkdir -p all_"$seed"
 for data in "${data_param[@]}"; do
-model="APAN_large"
+model="TGN_large"
 if [ "$data" = "WIKI" ] || [ "$data" = "REDDIT" ] || [ "$data" = "LASTFM" ]; then
-model="APAN"
+model="TGN"
 fi
 #model="APAN"
 mkdir all_"$seed"/"$data"
@@ -59,7 +61,7 @@ for data in "${data_param[@]}"; do
 wait
 fi
 else
-torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-"$partition"-0-"$mem"-"$sample".out &
+#torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-"$partition"-0-"$mem"-"$sample".out &
 wait
 if [ "$partition" = "ours" ] && [ "$mem" != "all_local" ]; then
 torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-"$partition"-0.01-"$mem"-"$sample".out &
...
@@ -202,7 +202,7 @@ def main():
 else:
 graph,full_sampler_graph,train_mask,val_mask,test_mask,full_train_mask,cache_route = load_from_speed(args.dataname,seed=123457,top=args.topk,sampler_graph_add_rev=True, feature_device=torch.device('cuda:{}'.format(ctx.local_rank)),partition=args.partition)#torch.device('cpu'))
 if(args.dataname=='GDELT'):
-train_param['epoch'] = 10
+train_param['epoch'] = 1
 #torch.autograd.set_detect_anomaly(True)
 # Make sure CUDA is available
 if torch.cuda.is_available():
@@ -569,7 +569,7 @@ def main():
 #torch.cuda.synchronize()
 loss.backward()
 optimizer.step()
-time_count.time_forward.elapsed_event(t1)
+time_count.time_forward+=time_count.elapsed_event(t1)
 #torch.cuda.synchronize()
 ## train aps
 #y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
...
@@ -32,12 +32,12 @@ class time_count:
 @staticmethod
 def start_gpu():
 # Uncomment for better breakdown timings
-#torch.cuda.synchronize()
+torch.cuda.synchronize()
-#start_event = torch.cuda.Event(enable_timing=True)
+start_event = torch.cuda.Event(enable_timing=True)
-#end_event = torch.cuda.Event(enable_timing=True)
+end_event = torch.cuda.Event(enable_timing=True)
-#start_event.record()
+start_event.record()
-#return start_event,end_event
+return start_event,end_event
-return 0,0
+#return 0,0
 @staticmethod
 def start():
 # return time.perf_counter(),0
@@ -56,6 +56,7 @@ class time_count:
 @staticmethod
 def print():
 print('time_count.time_forward={} time_count.time_backward={} time_count.time_memory_updater={} time_count.time_embedding={} time_count.time_local_update={} time_count.time_memory_sync={} time_count.time_sample_and_build={} time_count.time_memory_fetch={}\n'.format(
+time_count.time_forward,
 time_count.time_backward,
 time_count.time_memory_updater,
 time_count.time_embedding,
...
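
For reference, below is a minimal, self-contained sketch of the CUDA-event timing pattern this commit re-enables in start_gpu(). The repo's elapsed_event() is not shown in this diff, so elapsed_event_sketch() is an assumed counterpart, and start_gpu_sketch() only mirrors the uncommented code above; the explicit torch.cuda.synchronize() calls are what serialize the stream and give a per-phase breakdown "without pipeline".

```python
import torch

def start_gpu_sketch():
    """Like the re-enabled start_gpu(): flush pending work, then record a start event."""
    torch.cuda.synchronize()                      # drop async overlap so the span is isolated
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    return start_event, end_event

def elapsed_event_sketch(events):
    """Assumed counterpart to time_count.elapsed_event(): returns the span in milliseconds."""
    start_event, end_event = events
    end_event.record()
    torch.cuda.synchronize()                      # elapsed_time() needs both events completed
    return start_event.elapsed_time(end_event)    # float, milliseconds

if __name__ == "__main__" and torch.cuda.is_available():
    time_forward = 0.0
    x = torch.randn(4096, 4096, device="cuda")
    t1 = start_gpu_sketch()
    y = x @ x                                     # stands in for forward/backward/step
    time_forward += elapsed_event_sketch(t1)      # same accumulation as in train_boundery.py
    print(f"time_forward = {time_forward:.2f} ms")
```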