Commit 8a16ee73 by zlj

delete the weight param when negative sampling is not local

parent e893493a
@@ -83,7 +83,7 @@ for data in "${data_param[@]}"; do
                    wait
                fi
            else
-               #torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-"$partition"-0-"$mem"-"$sample"-"$pro".out &
+               torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-"$partition"-0-"$mem"-"$sample"-"$pro".out &
                wait
                if [ "$partition" = "ours" ] && [ "$mem" != "all_local" ]; then
                    torchrun --nnodes "$nnodes" --node_rank "$node_rank" --nproc-per-node "$node_per" --master-addr "$addr" --master-port 9445 train_boundery.py --dataname "$data" --mode "$model" --partition "$partition" --topk 0.1 --sample_type "$sample" --probability "$pro" --memory_type "$mem" --seed "$seed" > all_"$seed"/"$data"/"$model"/"$partitions"-ours_shared-0.01-"$mem"-"$sample"-"$pro".out &
...
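The torchrun invocation uncommented above launches --nproc-per-node workers on each node and hands rank and rendezvous information to the training script through environment variables rather than command-line flags. A minimal sketch of what a worker such as train_boundery.py can read at startup, assuming only the standard torchrun environment (the variable names below are torchrun's, not identifiers from this repository):

import os

# Sketch only: torchrun exports these variables for every worker it spawns.
rank = int(os.environ['RANK'])              # global rank across all nodes
local_rank = int(os.environ['LOCAL_RANK'])  # rank within the current node
world_size = int(os.environ['WORLD_SIZE'])  # nnodes * nproc-per-node
master_addr = os.environ['MASTER_ADDR']     # --master-addr from the launcher
master_port = os.environ['MASTER_PORT']     # --master-port (9445 above)
print(f'worker {rank}/{world_size}, node-local rank {local_rank}, '
      f'rendezvous at {master_addr}:{master_port}')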
@@ -67,7 +67,7 @@ parser.add_argument('--probability', default=1, type=float, metavar='W',
                    help='name of model')
parser.add_argument('--sample_type', default='recent', type=str, metavar='W',
                    help='name of model')
-parser.add_argument('--local_neg_sample', default=False, type=bool, metavar='W',
+parser.add_argument('--local_neg_sample', default=True, type=bool, metavar='W',
                    help='name of model')
parser.add_argument('--shared_memory_ssim', default=2, type=float, metavar='W',
                    help='name of model')
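One caveat about the flag changed above: with type=bool, argparse converts any non-empty string (including "False") to True, so flipping the default in code is effectively the only reliable way to change the value. A minimal sketch of the usual workaround, assuming a small parsing helper (str2bool and the reused flag name are illustrative, not part of this commit):

import argparse

def str2bool(v: str) -> bool:
    # argparse's type=bool treats any non-empty string as True,
    # so an explicit parser is the usual workaround.
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean, got {v!r}')

parser = argparse.ArgumentParser()
parser.add_argument('--local_neg_sample', default=True, type=str2bool,
                    help='sample negative destinations from the local partition only')
args = parser.parse_args(['--local_neg_sample', 'false'])
print(args.local_neg_sample)  # False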
@@ -350,10 +350,10 @@ def main():
    print('dim_node {} dim_edge {}\n'.format(gnn_dim_node,gnn_dim_edge))
    avg_time = 0
    if use_cuda:
-        model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param,graph.ids.shape[0],mailbox,train_ratio=(train_ratio_pos,train_ratio_neg)).cuda()
+        model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param,graph.ids.shape[0],mailbox).cuda()
        device = torch.device('cuda')
    else:
-        model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param,graph.ids.shape[0],mailbox,train_ratio=(train_ratio_pos,train_ratio_neg))
+        model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param,graph.ids.shape[0],mailbox)
        device = torch.device('cpu')
    model = DDP(model,find_unused_parameters=True)
    def count_parameters(model):
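For context on the DDP wrapping kept unchanged in this hunk, here is a minimal sketch of the usual setup, assuming the process group is created from the torchrun environment; ToyModel is a stand-in, not the GeneralModel used here:

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

class ToyModel(torch.nn.Module):          # stand-in for GeneralModel
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(8, 1)
    def forward(self, x):
        return self.lin(x)

use_cuda = torch.cuda.is_available()
dist.init_process_group(backend='nccl' if use_cuda else 'gloo')  # env:// rendezvous from torchrun
model = ToyModel().cuda() if use_cuda else ToyModel()
device = torch.device('cuda') if use_cuda else torch.device('cpu')
# find_unused_parameters=True tells DDP to tolerate parameters that receive no
# gradient in a given forward pass, at the cost of an extra traversal per step.
model = DDP(model, find_unused_parameters=True)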
@@ -552,12 +552,15 @@ def main():
            t2 = time_count.start_gpu()
            optimizer.zero_grad()
            ones = torch.ones(metadata['dst_neg_index'].shape[0],device = model.device,dtype=torch.float)
-            weight = torch.where(DistIndex(mfgs[0][0].srcdata['ID'][metadata['dst_neg_index']]).part == torch.distributed.get_rank(),ones*train_ratio_pos,ones*train_ratio_neg).reshape(-1,1)
            pred_pos, pred_neg = model(mfgs,metadata,neg_samples=args.neg_samples,async_param = param)
            #print(time_count.elapsed_event(t2))
            loss = creterion(pred_pos, torch.ones_like(pred_pos))
-            neg_creterion = torch.nn.BCEWithLogitsLoss(weight)
-            loss += neg_creterion(pred_neg, torch.zeros_like(pred_neg))
+            if args.local_neg_sample is False:
+                weight = torch.where(DistIndex(mfgs[0][0].srcdata['ID'][metadata['dst_neg_index']]).part == torch.distributed.get_rank(),ones*train_ratio_pos,ones*train_ratio_neg).reshape(-1,1)
+                neg_creterion = torch.nn.BCEWithLogitsLoss(weight)
+                loss += neg_creterion(pred_neg, torch.zeros_like(pred_neg))
+            else:
+                loss += creterion(pred_neg, torch.zeros_like(pred_neg))
            total_loss += float(loss.item())
            #mailbox.handle_last_async()
            #trainloader.async_feature()
...
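The new branch above only builds the per-sample weight when negatives are drawn across partitions; with purely local negatives every term gets the same weight, so the plain criterion suffices. A minimal sketch of how a weight tensor passed to BCEWithLogitsLoss rescales each negative sample's loss term (the mask, logits, and ratio values below are invented for illustration):

import torch

train_ratio_pos, train_ratio_neg = 1.0, 0.5       # assumed ratios
pred_neg = torch.randn(8, 1)                       # logits for 8 negative samples
is_local = torch.tensor([1, 0, 1, 1, 0, 0, 1, 0], dtype=torch.bool)  # stand-in for the partition check

ones = torch.ones(is_local.shape[0], dtype=torch.float)
# local negatives weighted by train_ratio_pos, remote ones by train_ratio_neg
weight = torch.where(is_local, ones * train_ratio_pos, ones * train_ratio_neg).reshape(-1, 1)

neg_criterion = torch.nn.BCEWithLogitsLoss(weight=weight)
loss = neg_criterion(pred_neg, torch.zeros_like(pred_neg))
print(loss)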