Commit d1128852 by zhlj

test new memory

parent e713643f
......@@ -4,19 +4,20 @@ import torch
# 读取文件内容
ssim_values = [0, 0.1, 0.2, 0.3, 0.4, 2] # 假设这是你的 ssim 参数值
probability_values = [1,0.5,0.1,0.05,0.01,0]
data_values = ['WikiTalk','StackOverflow'] # 存储从文件中读取的数据
partition = 'ours_shared'
data_values = ['WIKI_3','LASTFM_3','WikiTalk','StackOverflow'] # 存储从文件中读取的数据
partition = 'ours'
# 从文件中读取数据,假设数据存储在文件 data.txt 中
#all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample".out
partitions=8
topk=0.01
partitions=4
topk=0
mem='all_update'#'historical'
model='TGN'
for data in data_values:
ap_list = []
comm_list = []
for p in probability_values:
file = '{}/{}-{}-{}-{}-boundery_recent_decay-{}.out'.format(data,partitions,partition,topk,mem,p)
file = '{}/{}/{}-{}-{}-{}-boundery_recent_uniform-{}.out'.format(data,model,partitions,partition,topk,mem,p)
prefix = 'best test AP:'
cnt = 0
sum = 0
......@@ -27,11 +28,14 @@ for data in data_values:
pos = line.find('remote node number tensor')
if(pos!=-1):
posr = line.find(']',pos+2+len('remote node number tensor'),)
#print(line,line[pos+2+len('remote node number tensor'):posr])
comm = int(line[pos+2+len('remote node number tensor'):posr])
#print()
sum = sum+comm
cnt = cnt+1
#print(comm)
ap_list.append(ap)
comm_list.append(comm/cnt*4)
comm_list.append(sum/cnt*4)
# 绘制柱状图
......@@ -50,16 +54,16 @@ for data in data_values:
plt.xlabel('probability')
plt.ylabel('Test AP')
plt.title('{}({} partitions)'.format(data,partitions))
plt.savefig('boundary_AP_{}_{}.png'.format(data,partitions))
plt.savefig('boundary_AP_{}_{}_{}.png'.format(data,partitions,model))
plt.clf()
print(comm_list)
plt.bar([b for b in bars], comm_list, width=bar_width)
# 绘制柱状图
plt.xticks([b for b in bars], probability_values)
plt.xlabel('probability')
plt.ylabel('Communication volume')
plt.title('{}({} partitions)'.format(data,partitions))
plt.savefig('boundary_comm_{}_{}.png'.format(data,partitions))
plt.savefig('boundary_comm_{}_{}_{}.png'.format(data,partitions,model))
plt.clf()
if partition == 'ours_shared':
......@@ -67,8 +71,8 @@ for data in data_values:
else:
partition0=partition
for p in probability_values:
file = '{}/val_{}_{}_{}_0_boundery_recent_decay_{}_all_update_2.pt'.format(data,partition0,topk,partitions,float(p))
val_ap = torch.tensor(torch.load(file))
file = '{}/{}/test_{}_{}_{}_0_boundery_recent_uniform_{}_all_update_2.pt'.format(data,model,partition0,topk,partitions,float(p))
val_ap = torch.tensor(torch.load(file))[:,0]
epoch = torch.arange(val_ap.shape[0])
#绘制曲线图
plt.plot(epoch,val_ap, label='probability={}'.format(p))
......@@ -77,5 +81,5 @@ for data in data_values:
plt.title('{}({} partitions)'.format(data,partitions))
# plt.grid(True)
plt.legend()
plt.savefig('{}_{}_boundary_Convergence_rate.png'.format(data,partitions))
plt.savefig('{}_{}_{}_boundary_Convergence_rate.png'.format(data,partitions,model))
plt.clf()
......@@ -2,13 +2,13 @@ import matplotlib.pyplot as plt
import numpy as np
import torch
# 读取文件内容
ssim_values = [-1,0,0.3,0.7,2] # 假设这是你的 ssim 参数值
ssim_values = [-1,0.3,0.5,0.7,2] # 假设这是你的 ssim 参数值
data_values = ['WIKI','LASTFM','WikiTalk','REDDIT','LASTFM','DGraphFin'] # 存储从文件中读取的数据
partition = 'ours_shared'
# 从文件中读取数据,假设数据存储在文件 data.txt 中
#all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample".out
partitions=4
model = 'JODIE'
model = 'TGN'
topk=0.01
mem='historical'
for data in data_values:
......
......@@ -261,8 +261,8 @@ def main():
print('dst len {} origin len {}'.format(graph.edge_index[1,mask].unique().shape[0],full_dst.unique().shape[0]))
train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = graph.edge_index[1,mask].unique())
else:
train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique())
#train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique(),local_mask=(DistIndex(graph.nids_mapper[full_dst.unique()].to('cpu')).part == dist.get_rank()))
#train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique())
train_neg_sampler = LocalNegativeSampling('triplet',amount = args.neg_samples,dst_node_list = full_dst.unique(),local_mask=(DistIndex(graph.nids_mapper[full_dst.unique()].to('cpu')).part == dist.get_rank()))
print(train_neg_sampler.dst_node_list)
neg_sampler = LocalNegativeSampling('triplet',amount= neg_samples,dst_node_list = full_dst.unique(),seed=6773)
......@@ -279,7 +279,7 @@ def main():
is_pipeline=True,
use_local_feature = False,
device = torch.device('cuda:{}'.format(local_rank)),
probability=1,#train_cross_probability,
probability=0.1,#train_cross_probability,
reversed = (gnn_param['arch'] == 'identity')
)
......@@ -454,7 +454,7 @@ def main():
cos = torch.nn.CosineSimilarity(dim=0)
return cos(normalize(x1),normalize(x2)).sum()/x1.size(dim=0)
creterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'])
optimizer = torch.optim.Adam(model.parameters(), lr=train_param['lr'],weight_decay=1e-4)
early_stopper = EarlyStopMonitor(max_round=args.patience)
MODEL_SAVE_PATH = f'../saved_models/{args.model}-{args.dataname}-{dist.get_world_size()}.pth'
epoch_cnt = 0
......
......@@ -423,6 +423,9 @@ class AsyncMemeoryUpdater(torch.nn.Module):
if self.mode == 'historical':
self.gamma = torch.nn.Parameter(torch.tensor([0.9]),
requires_grad=True)
self.filter = Filter(n_nodes=mailbox.shared_nodes_index.shape[0],
memory_dimension=self.dim_hid,
)
else:
self.gamma = 1
def forward(self, mfg, param = None):
......@@ -431,12 +434,36 @@ class AsyncMemeoryUpdater(torch.nn.Module):
updated_memory0 = self.updater(b)
mask = DistIndex(b.srcdata['ID']).is_shared
#incr = updated_memory[mask] - b.srcdata['mem'][mask]
time_feat = self.time_enc(b.srcdata['ts'][mask].reshape(-1,1) - b.srcdata['his_ts'][mask].reshape(-1,1))
his_mem = torch.cat((mail_input[mask],time_feat),dim = 1)
upd0 = torch.zeros_like(updated_memory0)
upd0[mask] = self.ceil_updater(his_mem, b.srcdata['his_mem'][mask])
#updated_memory = torch.where(mask.unsqueeze(1),self.gamma*updated_memory0 + (1-self.gamma)*(b.srcdata['his_mem']),updated_memory0)
#time_feat = self.time_enc(b.srcdata['ts'][mask].reshape(-1,1) - b.srcdata['his_ts'][mask].reshape(-1,1))
#his_mem = torch.cat((mail_input[mask],time_feat),dim = 1)
with torch.no_grad():
upd0 = torch.zeros_like(updated_memory0)
if self.mode == 'historical':
shared_ind = self.mailbox.is_shared_mask[DistIndex(b.srcdata['ID'][mask]).loc]
transition_dense = self.filter.get_incretment(shared_ind)
transition_dense*=2
if not (transition_dense.max().item() == 0):
transition_dense -= transition_dense.min()
transition_dense /=transition_dense.max()
transition_dense = 2*transition_dense - 1
upd0[mask] = b.srcdata['his_mem'][mask] + transition_dense
else:
upd0[mask] = b.srcdata['his_mem'][mask]
#upd0[mask] = self.ceil_updater(his_mem, b.srcdata['his_mem'][mask])
#updated_memory = torch.where(mask.unsqueeze(1),self.gamma*updated_memory0 + (1-self.gamma)*(b.srcdata['his_mem'])
# ,updated_memory0)
updated_memory = torch.where(mask.unsqueeze(1),self.gamma*updated_memory0 + (1-self.gamma)*(upd0),updated_memory0)
with torch.no_grad():
if self.mode == 'historical':
change = upd0[mask] - b.srcdata['his_mem'][mask]
change.detach()
if not (change.max().item() == 0):
change -= change.min()
change /=change.max()
change = 2*change - 1
self.filter.update(shared_ind,change)
#print(transition_dense)
print(torch.cosine_similarity(updated_memory[mask],b.srcdata['his_mem'][mask]).sum()/torch.sum(mask))
print(self.gamma)
self.pre_mem = b.srcdata['his_mem']
......
......@@ -59,9 +59,10 @@ class LocalNegativeSampling(NegativeSampling):
return torch.randint(num_nodes, (num_samples, ),generator=self.rdm)
elif self.local_mask is not None:
p = torch.rand(size=(num_samples,))
sr = torch.randint(len(self.dst_node_list), (num_samples, ),generator=self.rdm)
sl = torch.randint(len(self.local_dst), (num_samples, ),generator=self.rdm)
return torch.where(p<0.9,sl,sr)
sr = self.dst_node_list[torch.randint(len(self.dst_node_list), (num_samples, ),generator=self.rdm)]
sl = self.local_dst[torch.randint(len(self.local_dst), (num_samples, ),generator=self.rdm)]
s=torch.where(p<=1,sl,sr)
return sr
else:
s = torch.randint(len(self.dst_node_list), (num_samples, ),generator=self.rdm)
return self.dst_node_list[s]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment