Commit 82337762 by zlj

add negative fix weight

parent cc8abec4
...@@ -126,7 +126,7 @@ def seed_everything(seed=42):
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
+seed_everything(args.seed)
total_next_batch = 0
total_forward = 0
total_count_score = 0
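For reference, the three context lines above belong to a standard seeding helper; here is a minimal sketch of the full function, assuming the usual random/numpy/torch RNGs (the complete definition is not shown in this hunk):

import random
import numpy as np
import torch

def seed_everything(seed=42):
    # Seed every RNG the training script relies on (sketch; matches the hunk above).
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False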
...@@ -267,9 +267,13 @@ def main():
if args.local_neg_sample:
    print('dst len {} origin len {}'.format(graph.edge_index[1,mask].unique().shape[0],full_dst.unique().shape[0]))
    train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = graph.edge_index[1,mask].unique())
else:
    #train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique())
    train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique(),local_mask=(DistIndex(graph.nids_mapper[full_dst.unique()].to('cpu')).part == dist.get_rank()),prob=args.probability)
+remote_ratio = train_neg_sampler.local_dst.shape[0] / train_neg_sampler.dst_node_list.shape[0]
+train_ratio_pos = (1 - args.probability) + args.probability * remote_ratio
+train_ratio_neg = args.probability * (1 - remote_ratio)
print(train_neg_sampler.dst_node_list)
neg_sampler = LocalNegativeSampling('triplet',amount= neg_samples,dst_node_list = full_dst.unique(),seed=args.seed)
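The three added lines are expected sampling probabilities. Assuming LocalNegativeSampling draws each negative from the local destination list with probability 1 - args.probability and uniformly from the full list otherwise (an assumption; the sampler's internals are not shown here), the chance that a sampled negative is local or remote works out as in this illustrative sketch:

# Illustrative values: local_frac plays the role of remote_ratio above,
# i.e. the fraction of destination nodes held by this rank.
local_frac = 0.25
p = 0.8                                  # args.probability

p_local = (1 - p) + p * local_frac       # train_ratio_pos
p_remote = p * (1 - local_frac)          # train_ratio_neg
assert abs((p_local + p_remote) - 1.0) < 1e-9   # the two cases partition all samples

The training loop below weights each negative by the inverse of these probabilities, so the skewed local/remote sampling does not bias the negative loss.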
...@@ -338,10 +342,10 @@ def main():
print('dim_node {} dim_edge {}\n'.format(gnn_dim_node,gnn_dim_edge))
avg_time = 0
if use_cuda:
-    model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param,graph.ids.shape[0],mailbox).cuda()
+    model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param,graph.ids.shape[0],mailbox,train_ratio=(train_ratio_pos,train_ratio_neg)).cuda()
    device = torch.device('cuda')
else:
-    model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param,graph.ids.shape[0],mailbox)
+    model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param,graph.ids.shape[0],mailbox,train_ratio=(train_ratio_pos,train_ratio_neg))
    device = torch.device('cpu')
model = DDP(model,find_unused_parameters=True)
def count_parameters(model):
...@@ -531,9 +535,12 @@ def main():
model.train()
optimizer.zero_grad()
+ones = torch.ones(metadata['dst_neg_index'].shape[0],device = model.device,dtype=torch.float)
+weight = torch.where(DistIndex(mfgs[0][0].srcdata['ID'][metadata['dst_neg_index']]).part == torch.distributed.get_rank(),ones/train_ratio_pos,ones/train_ratio_neg).reshape(-1,1)
pred_pos, pred_neg = model(mfgs,metadata,neg_samples=args.neg_samples,async_param = param)
loss = creterion(pred_pos, torch.ones_like(pred_pos))
-loss += creterion(pred_neg, torch.zeros_like(pred_neg))
+neg_creterion = torch.nn.BCEWithLogitsLoss(weight)
+loss += neg_creterion(pred_neg, torch.zeros_like(pred_neg))
total_loss += float(loss.item())
#mailbox.handle_last_async()
#trainloader.async_feature()
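As a standalone illustration of the new loss term: torch.nn.BCEWithLogitsLoss accepts a per-element weight tensor, so locally over-sampled negatives can be down-weighted by the inverse probabilities computed earlier. All values below are made up:

import torch

pred_neg = torch.randn(4, 1)                         # logits for four negative samples
is_local = torch.tensor([True, True, False, True])   # which negatives this rank holds
train_ratio_pos, train_ratio_neg = 0.85, 0.15        # expected local/remote probabilities

ones = torch.ones(4)
weight = torch.where(is_local, ones / train_ratio_pos, ones / train_ratio_neg).reshape(-1, 1)
neg_criterion = torch.nn.BCEWithLogitsLoss(weight=weight)
loss = neg_criterion(pred_neg, torch.zeros_like(pred_neg))   # weighted mean over samples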
...@@ -663,9 +670,9 @@ def main():
pass
# print('weight {} {}\n'.format(tt.weight_count_local,tt.weight_count_remote))
# print('ssim {} {}\n'.format(tt.ssim_local/tt.ssim_cnt,tt.ssim_remote/tt.ssim_cnt))
-torch.save(val_list,'all_args.seed/{}/{}/val_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
+torch.save(val_list,'all_{}/{}/{}/val_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.seed,args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
-torch.save(loss_list,'all_args.seed/{}/{}/loss_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
+torch.save(loss_list,'all_{}/{}/{}/loss_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.seed,args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
-torch.save(test_ap_list,'all_args.seed/{}/{}/test_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
+torch.save(test_ap_list,'all_{}/{}/{}/test_{}_{}_{}_{}_{}_{}_{}_{}.pt'.format(args.seed,args.dataname,args.model,args.partition,args.topk,dist.get_world_size(),dist.get_rank(),args.sample_type,args.probability,args.memory_type,args.shared_memory_ssim))
print(avg_time)
if not early_stop:
......
...@@ -299,13 +299,16 @@ class TransfomerAttentionLayer(torch.nn.Module):
#b.edata['v1'] = V_remote
#b.update_all(dgl.function.copy_e('v0', 'm0'), dgl.function.sum('m0', 'h0'))
#b.update_all(dgl.function.copy_e('v1', 'm1'), dgl.function.sum('m1', 'h1'))
-if 'weight' in b.edata:
+#if 'weight' in b.edata and self.training is True:
-    with torch.no_grad():
+#    with torch.no_grad():
-        weight = b.edata['weight'].reshape(-1,1)#(b.edata['weight']/torch.sum(b.edata['weight']).item()).reshape(-1,1)
+#        weight = b.edata['weight'].reshape(-1,1)#(b.edata['weight']/torch.sum(b.edata['weight']).item()).reshape(-1,1)
+#weight =
#print(weight.max())
-    b.edata['v'] = V*weight
+#    b.edata['v'] = V*weight
-else:
+#else:
+#    weight = b.edata['weight'].reshape(-1,1)
b.edata['v'] = V
+#print(torch.sum(torch.sum(((V-V*weight)**2))))
b.update_all(dgl.function.copy_e('v', 'm'), dgl.function.sum('m', 'h'))
#tt.ssim_local+=torch.sum(torch.cosine_similarity(b.dstdata['h'],b.dstdata['h0']))
#tt.ssim_remote+=torch.sum(torch.cosine_similarity(b.dstdata['h'],b.dstdata['h1']))
......
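For context on the attention hunk above: update_all with copy_e/sum aggregates an edge feature into each destination node, so scaling the edge values by a weight first yields a weighted neighborhood sum. A toy sketch of both variants (graph and shapes are made up):

import dgl
import dgl.function as fn
import torch

g = dgl.graph(([0, 1, 2], [0, 0, 1]))   # three edges into two destination nodes
V = torch.randn(3, 8)                    # per-edge value vectors
weight = torch.rand(3, 1)                # per-edge weights, as in the disabled branch

g.edata['v'] = V * weight                # weighted variant (now commented out above)
g.update_all(fn.copy_e('v', 'm'), fn.sum('m', 'h'))
h_weighted = g.ndata['h']

g.edata['v'] = V                         # unweighted variant currently in effect
g.update_all(fn.copy_e('v', 'm'), fn.sum('m', 'h'))
h_plain = g.ndata['h']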
...@@ -52,20 +52,36 @@ class all_to_all_embedding(torch.autograd.Function):
        grad[dst_pos_index] = grad_pos_dst
        grad[dst_neg_index] = grad_neg_dst
        return grad,None,None
+class NegFixLayer(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input, weight):
+        # Identity in the forward pass; stash the per-sample weight for backward.
+        ctx.save_for_backward(weight)
+        return input
+    @staticmethod
+    def backward(ctx, grad_output):
+        # Rescale each sample's gradient by its inverse weight; no gradient for weight.
+        weight, = ctx.saved_tensors
+        return grad_output/weight,None
class GeneralModel(torch.nn.Module):
-    def __init__(self, dim_node, dim_edge, sample_param, memory_param, gnn_param, train_param, num_nodes = None,mailbox = None,combined=False):
+    def __init__(self, dim_node, dim_edge, sample_param, memory_param, gnn_param, train_param, num_nodes = None,mailbox = None,combined=False,train_ratio = None):
        super(GeneralModel, self).__init__()
        self.dim_node = dim_node
        self.dim_node_input = dim_node
        self.dim_edge = dim_edge
        self.sample_param = sample_param
        self.memory_param = memory_param
+        self.train_pos_ratio,self.train_neg_ratio = train_ratio
        if not 'dim_out' in gnn_param:
            gnn_param['dim_out'] = memory_param['dim_out']
        self.gnn_param = gnn_param
        self.train_param = train_param
+        self.neg_fix_layer = NegFixLayer()
        if memory_param['type'] == 'node':
            if memory_param['memory_update'] == 'gru':
                #if memory_param['async'] == False:
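A self-contained sketch of what NegFixLayer does: the forward pass is an identity, so predictions are unchanged, while the backward pass divides each sample's gradient by its weight. Toy tensors; assumes the class as defined above:

import torch

h_neg = torch.randn(4, 8, requires_grad=True)        # illustrative negative embeddings
weight = torch.tensor([[2.0], [1.0], [4.0], [1.0]])  # per-sample sampling weights

out = NegFixLayer.apply(h_neg, weight)   # forward: out equals h_neg
out.sum().backward()
print(h_neg.grad[:, 0])                  # rows scaled by 1/weight: 0.5, 1.0, 0.25, 1.0

Rescaling gradients instead of activations leaves the predicted scores untouched while still de-biasing the parameter updates.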
...@@ -138,12 +154,24 @@ class GeneralModel(torch.nn.Module):
        h_pos_src = out[metadata['src_pos_index']]
        h_pos_dst = out[metadata['dst_pos_index']]
        h_neg_dst = out[metadata['dst_neg_index']]
        #end.record()
        #end.synchronize()
        #elapsed_time_ms = start.elapsed_time(end)
        #print('time {}\n'.format(elapsed_time_ms))
        #print('pos src {} \n pos dst {} \n neg dst{} \n'.format(h_pos_src, h_pos_dst,h_neg_dst))
        #print('pre predict {}'.format(mfgs[0][0].srcdata['ID']))
+        #if self.training is True:
+        #    with torch.no_grad():
+        #        ones = torch.ones(h_neg_dst.shape[0],device = h_neg_dst.device,dtype=torch.float)
+        #        weight = torch.where(DistIndex(mfgs[0][0].srcdata['ID'][metadata['dst_neg_index']]).part == torch.distributed.get_rank(),ones/self.train_pos_ratio,ones/self.train_neg_ratio).reshape(-1,1)
+        #weight = torch.clip(weigh)
+        #weight = weight/weight.max().item()
+        #print(weight)
+        #weight =
+        #h_neg_dst*weight
+        #    pred = self.edge_predictor(h_pos_src, h_pos_dst, None , self.neg_fix_layer.apply(h_neg_dst,weight), neg_samples=neg_samples, mode = mode)
+        #else:
        pred = self.edge_predictor(h_pos_src, h_pos_dst, None , h_neg_dst, neg_samples=neg_samples, mode = mode)
        t_embedding = tt.elapsed_event(t1)
        tt.time_embedding+=t_embedding
......