Commit 506058f4 by zhlj

fist

parent 911d3eb8
......@@ -67,7 +67,7 @@ parser.add_argument('--probability', default=1, type=float, metavar='W',
help='name of model')
parser.add_argument('--sample_type', default='recent', type=str, metavar='W',
help='name of model')
parser.add_argument('--local_neg_sample', default=False, type=bool, metavar='W',
parser.add_argument('--local_neg_sample', default='local', type=str, metavar='W',
help='name of model')
parser.add_argument('--shared_memory_ssim', default=2, type=float, metavar='W',
help='name of model')
......@@ -275,10 +275,13 @@ def main():
neg_samples = args.eval_neg_samples
mask = DistIndex(graph.nids_mapper[graph.edge_index[1,:]].to('cpu')).part == dist.get_rank()
if args.local_neg_sample:
if args.local_neg_sample == 'local':
print('dst len {} origin len {}'.format(graph.edge_index[1,mask].unique().shape[0],full_dst.unique().shape[0]))
train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = graph.edge_index[1,mask].unique())
elif args.local_neg_sample == 'all':
train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique(),local_mask=(DistIndex(graph.nids_mapper[full_dst.unique()].to('cpu')).part == dist.get_rank()),prob=1)
train_ratio_pos = 1
train_ratio_neg = 1
else:
#train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique())
train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique(),local_mask=(DistIndex(graph.nids_mapper[full_dst.unique()].to('cpu')).part == dist.get_rank()),prob=args.probability)
......@@ -300,7 +303,7 @@ def main():
mode='train',
queue_size = 200,
mailbox = mailbox,
is_pipeline=True,
is_pipeline=False,
use_local_feature = False,
device = torch.device('cuda:{}'.format(local_rank)),
probability=args.probability,
......@@ -561,7 +564,7 @@ def main():
#print(time_count.elapsed_event(t2))
loss = creterion(pred_pos, torch.ones_like(pred_pos))
if args.local_neg_sample is False:
if args.local_neg_sample != 'local':
weight = torch.where(DistIndex(mfgs[0][0].srcdata['ID'][metadata['dst_neg_index']]).part == torch.distributed.get_rank(),ones*train_ratio_pos,ones*train_ratio_neg).reshape(-1,1)
neg_creterion = torch.nn.BCEWithLogitsLoss(weight)
loss += neg_creterion(pred_neg, torch.zeros_like(pred_neg))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment