Commit b9ca4758 by zhlj

do break down without pipeline

parent a6fb7bc7
@@ -2,24 +2,26 @@
#Ran TaoBao on 4 GPUs
# Define array variables
seed=$1
addr="192.168.1.107"
addr="192.168.1.105"
partition_params=("ours")
#"metis" "ldg" "random")
#("ours" "metis" "ldg" "random")
partitions="16"
partitions="8"
node_per="4"
nnodes="4"
node_rank="3"
probability_params=("0.01")
nnodes="2"
node_rank="0"
probability_params=("0.1" )
sample_type_params=("boundery_recent_decay")
#sample_type_params=("recent" "boundery_recent_decay") #"boundery_recent_uniform")
#memory_type=("all_update" "p2p" "all_reduce" "historical" "local")
memory_type=("historical")
memory_type=("all_update" "historical" "local")
#"historical")
#memory_type=("local" "all_update" "historical" "all_reduce")
shared_memory_ssim=("0.3" "0.7")
#"historical")
#memory_type=("local" "all_update" "historical" "all_reduce")
shared_memory_ssim=("0.3")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk")
data_param=("WikiTalk" "GDELT" "StackOverflow")
data_param=("GDELT")
#"GDELT")
#data_param=("WIKI" "REDDIT" "LASTFM" "DGraphFin" "WikiTalk" "StackOverflow")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk" "StackOverflow")
@@ -32,9 +34,9 @@ data_param=("WikiTalk" "GDELT" "StackOverflow")
#seed=(( RANDOM % 1000000 + 1 ))
mkdir -p all_"$seed"
for data in "${data_param[@]}"; do
model="APAN_large"
model="TGN_large"
if [ "$data" = "WIKI" ] || [ "$data" = "REDDIT" ] || [ "$data" = "LASTFM" ]; then
model="APAN"
model="TGN"
fi
#model="APAN"
mkdir all_"$seed"/"$data"
@@ -59,7 +61,7 @@ for data in "${data_param[@]}"; do
wait
fi
else
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-"$partition"-0-"$mem"-"$sample".out &
#torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-"$partition"-0-"$mem"-"$sample".out &
wait
if [ "$partition" = "ours" ] && [ "$mem" != "all_local" ]; then
torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-"$partition"-0.01-"$mem"-"$sample".out &
......
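For context, the torchrun invocations above hand the rendezvous information (master address and port, node rank, processes per node) to every worker through environment variables. A minimal sketch of how a launched script such as train_boundery.py typically picks these up, assuming NCCL and env:// rendezvous; this helper is illustrative and not taken from the repository:

import os
import torch
import torch.distributed as dist

def init_distributed():
    # torchrun exports RANK, LOCAL_RANK and WORLD_SIZE for each worker;
    # MASTER_ADDR / MASTER_PORT come from --master-addr / --master-port.
    dist.init_process_group(backend="nccl", init_method="env://")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)  # pin this process to its own GPU
    return dist.get_rank(), dist.get_world_size(), local_rank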
@@ -202,7 +202,7 @@ def main():
else:
graph,full_sampler_graph,train_mask,val_mask,test_mask,full_train_mask,cache_route = load_from_speed(args.dataname,seed=123457,top=args.topk,sampler_graph_add_rev=True, feature_device=torch.device('cuda:{}'.format(ctx.local_rank)),partition=args.partition)#torch.device('cpu'))
if(args.dataname=='GDELT'):
train_param['epoch'] = 10
train_param['epoch'] = 1
#torch.autograd.set_detect_anomaly(True)
# Ensure CUDA is available
if torch.cuda.is_available():
@@ -569,7 +569,7 @@ def main():
#torch.cuda.synchronize()
loss.backward()
optimizer.step()
time_count.time_forward.elapsed_event(t1)
time_count.time_forward+=time_count.elapsed_event(t1)
#torch.cuda.synchronize()
## train aps
#y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
......
@@ -32,12 +32,12 @@ class time_count:
@staticmethod
def start_gpu():
# Uncomment for better breakdown timings
#torch.cuda.synchronize()
#start_event = torch.cuda.Event(enable_timing=True)
#end_event = torch.cuda.Event(enable_timing=True)
#start_event.record()
#return start_event,end_event
return 0,0
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
return start_event,end_event
#return 0,0
@staticmethod
def start():
# return time.perf_counter(),0
@@ -56,6 +56,7 @@ class time_count:
@staticmethod
def print():
print('time_count.time_forward={} time_count.time_backward={} time_count.time_memory_updater={} time_count.time_embedding={} time_count.time_local_update={} time_count.time_memory_sync={} time_count.time_sample_and_build={} time_count.time_memory_fetch={}\n'.format(
time_count.time_forward,
time_count.time_backward,
time_count.time_memory_updater,
time_count.time_embedding,
......
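The time_count change re-enables real CUDA event timing in start_gpu (previously stubbed out to return 0,0) and adds time_forward to the printed per-phase totals. The matching elapsed_event helper is not part of this hunk; the following is a minimal self-contained sketch of how the pair presumably fits together, assuming elapsed_event records and synchronizes the end event and returns milliseconds:

import torch

class time_count:
    # running total in milliseconds, printed by time_count.print()
    time_forward = 0.0

    @staticmethod
    def start_gpu():
        # Start a CUDA event pair for fine-grained breakdown timings.
        torch.cuda.synchronize()
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        return start_event, end_event

    @staticmethod
    def elapsed_event(pair):
        # Assumed counterpart: close the pair from start_gpu and return elapsed ms.
        start_event, end_event = pair
        end_event.record()
        end_event.synchronize()
        return start_event.elapsed_time(end_event)

# usage mirroring the training-loop change above (requires a CUDA device)
t1 = time_count.start_gpu()
x = torch.randn(1024, 1024, device="cuda") @ torch.randn(1024, 1024, device="cuda")  # stand-in GPU work
time_count.time_forward += time_count.elapsed_event(t1)
print(time_count.time_forward)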