Commit e713643f by zhlj

update memory

parent e0711e14
...@@ -9,6 +9,7 @@ __pycache__/ ...@@ -9,6 +9,7 @@ __pycache__/
*.so *.so
# Distribution / packaging # Distribution / packaging
examples/all*
.Python .Python
examples/all* examples/all*
build/ build/
......
...@@ -14,6 +14,8 @@ memory: ...@@ -14,6 +14,8 @@ memory:
deliver_to: 'neighbors' deliver_to: 'neighbors'
mail_combine: 'last' mail_combine: 'last'
memory_update: 'transformer' memory_update: 'transformer'
historical_fix: False
async: True
attention_head: 2 attention_head: 2
mailbox_size: 10 mailbox_size: 10
combine_node_feature: False combine_node_feature: False
...@@ -22,7 +24,7 @@ gnn: ...@@ -22,7 +24,7 @@ gnn:
- arch: 'identity' - arch: 'identity'
train: train:
- epoch: 100 - epoch: 100
batch_size: 600 batch_size: 1000
lr: 0.0001 lr: 0.0001
dropout: 0.1 dropout: 0.1
att_dropout: 0.1 att_dropout: 0.1
......
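The config hunks above add two new memory options, `historical_fix` and `async`. A rough, non-authoritative sketch of how such keys are typically consumed downstream; the dict simply mirrors the keys shown above, and the actual loader is not part of this diff:

```python
# Rough sketch only: the dict mirrors the config keys shown above, including
# the two keys added by this commit; the real loader/consumer is not in this diff.
memory_param = {
    'memory_update': 'transformer',
    'deliver_to': 'neighbors',
    'mail_combine': 'last',
    'historical_fix': False,   # new in this commit
    'async': True,             # new in this commit
    'attention_head': 2,
    'mailbox_size': 10,
    'combine_node_feature': False,
}

# Reading the new flags with defaults keeps configs written before this commit valid.
use_async_update = memory_param.get('async', False)
use_historical_fix = memory_param.get('historical_fix', False)
print(use_async_update, use_historical_fix)
```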
...@@ -7,6 +7,8 @@ memory: ...@@ -7,6 +7,8 @@ memory:
deliver_to: 'self' deliver_to: 'self'
mail_combine: 'last' mail_combine: 'last'
memory_update: 'rnn' memory_update: 'rnn'
historical_fix: False
async: True
mailbox_size: 1 mailbox_size: 1
combine_node_feature: True combine_node_feature: True
dim_out: 100 dim_out: 100
...@@ -16,8 +18,8 @@ gnn: ...@@ -16,8 +18,8 @@ gnn:
use_dst_emb: False use_dst_emb: False
time_transform: 'JODIE' time_transform: 'JODIE'
train: train:
- epoch: 100 - epoch: 250
batch_size: 1000 batch_size: 1000
lr: 0.0001 lr: 0.0002
dropout: 0.1 dropout: 0.1
all_on_gpu: True all_on_gpu: True
\ No newline at end of file
...@@ -27,7 +27,7 @@ gnn: ...@@ -27,7 +27,7 @@ gnn:
dim_time: 100 dim_time: 100
dim_out: 100 dim_out: 100
train: train:
- epoch: 100 - epoch: 200
batch_size: 1000 batch_size: 1000
# reorder: 16 # reorder: 16
lr: 0.0004 lr: 0.0004
......
...@@ -4,11 +4,11 @@ import torch ...@@ -4,11 +4,11 @@ import torch
# Read the file contents # Read the file contents
ssim_values = [0, 0.1, 0.2, 0.3, 0.4, 2] # assumed ssim parameter values ssim_values = [0, 0.1, 0.2, 0.3, 0.4, 2] # assumed ssim parameter values
probability_values = [1,0.5,0.1,0.05,0.01,0] probability_values = [1,0.5,0.1,0.05,0.01,0]
data_values = ['WIKI'] # data read from the files data_values = ['WikiTalk','StackOverflow'] # data read from the files
partition = 'ours_shared' partition = 'ours_shared'
# Read the data from a file; the data is assumed to be stored in data.txt # Read the data from a file; the data is assumed to be stored in data.txt
#all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample".out #all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample".out
partitions=4 partitions=8
topk=0.01 topk=0.01
mem='all_update'#'historical' mem='all_update'#'historical'
for data in data_values: for data in data_values:
...@@ -35,11 +35,12 @@ for data in data_values: ...@@ -35,11 +35,12 @@ for data in data_values:
# Plot the bar chart # Plot the bar chart
print('{} TestAP={}\n'.format(data,ap_list))
bar_width = 0.4 bar_width = 0.4
#shared comm tensor #shared comm tensor
# Set the bar positions # Set the bar positions
bars = range(len(ssim_values)) bars = range(len(probability_values))
# Plot the bar chart # Plot the bar chart
plt.bar([b for b in bars], ap_list, width=bar_width) plt.bar([b for b in bars], ap_list, width=bar_width)
...@@ -49,7 +50,7 @@ for data in data_values: ...@@ -49,7 +50,7 @@ for data in data_values:
plt.xlabel('probability') plt.xlabel('probability')
plt.ylabel('Test AP') plt.ylabel('Test AP')
plt.title('{}({} partitions)'.format(data,partitions)) plt.title('{}({} partitions)'.format(data,partitions))
plt.savefig('boundary_AP_{}.png'.format(data)) plt.savefig('boundary_AP_{}_{}.png'.format(data,partitions))
plt.clf() plt.clf()
plt.bar([b for b in bars], comm_list, width=bar_width) plt.bar([b for b in bars], comm_list, width=bar_width)
...@@ -58,7 +59,7 @@ for data in data_values: ...@@ -58,7 +59,7 @@ for data in data_values:
plt.xlabel('probability') plt.xlabel('probability')
plt.ylabel('Communication volume') plt.ylabel('Communication volume')
plt.title('{}({} partitions)'.format(data,partitions)) plt.title('{}({} partitions)'.format(data,partitions))
plt.savefig('boundary_comm_{}.png'.format(data)) plt.savefig('boundary_comm_{}_{}.png'.format(data,partitions))
plt.clf() plt.clf()
if partition == 'ours_shared': if partition == 'ours_shared':
...@@ -76,5 +77,5 @@ for data in data_values: ...@@ -76,5 +77,5 @@ for data in data_values:
plt.title('{}({} partitions)'.format(data,partitions)) plt.title('{}({} partitions)'.format(data,partitions))
# plt.grid(True) # plt.grid(True)
plt.legend() plt.legend()
plt.savefig('{}_boundary_Convergence_rate.png'.format(data)) plt.savefig('{}_{}_boundary_Convergence_rate.png'.format(data,partitions))
plt.clf() plt.clf()
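For context, the script above sweeps `probability_values`, scrapes each run's `.out` log for the best test AP and the shared communication volume, and plots both. A hedged, self-contained sketch of that scraping step; the sample log lines below are purely illustrative, while the slicing mirrors what is visible in the hunks:

```python
# Hedged sketch of the log scraping: the AP comes from a line starting with
# 'best test AP:' and the communication volume from a line containing
# 'shared comm tensor'. The sample lines are illustrative only.
def parse_log(lines):
    ap, comm = None, None
    for line in lines:
        if line.startswith('best test AP:'):
            ap = float(line[len('best test AP:'):])
        pos = line.find('shared comm tensor')
        if pos != -1:
            comm = int(line[pos + 2 + len('shared comm tensor'):len(line) - 3])
    return ap, comm

sample = ['best test AP:0.9812\n', 'the shared comm tensor: 123456]]\n']
print(parse_log(sample))
```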
...@@ -2,12 +2,13 @@ import matplotlib.pyplot as plt ...@@ -2,12 +2,13 @@ import matplotlib.pyplot as plt
import numpy as np import numpy as np
import torch import torch
# Read the file contents # Read the file contents
ssim_values = [0, 0.5, 1.0, 1.5, 2] # assumed ssim parameter values ssim_values = [-1,0,0.3,0.7,2] # assumed ssim parameter values
data_values = ['WIKI','WikiTalk','REDDIT','LASTFM','DGraphFin'] # data read from the files data_values = ['WIKI','LASTFM','WikiTalk','REDDIT','LASTFM','DGraphFin'] # data read from the files
partition = 'ours_shared' partition = 'ours_shared'
# Read the data from a file; the data is assumed to be stored in data.txt # Read the data from a file; the data is assumed to be stored in data.txt
#all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample".out #all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample".out
partitions=4 partitions=4
model = 'JODIE'
topk=0.01 topk=0.01
mem='historical' mem='historical'
for data in data_values: for data in data_values:
...@@ -15,9 +16,11 @@ for data in data_values: ...@@ -15,9 +16,11 @@ for data in data_values:
comm_list = [] comm_list = []
for ssim in ssim_values: for ssim in ssim_values:
if ssim == 2: if ssim == 2:
file = '{}/{}-{}-{}-local-recent.out'.format(data,partitions,partition,topk) file = '{}/{}/{}-{}-{}-local-recent.out'.format(data,model,partitions,partition,topk)
elif ssim == -1:
file = '{}/{}/{}-{}-{}-all_update-recent.out'.format(data,model,partitions,partition,topk)
else: else:
file = '{}/{}-{}-{}-{}-{}-recent.out'.format(data,partitions,partition,topk,mem,ssim) file = '{}/{}/{}-{}-{}-{}-{}-recent.out'.format(data,model,partitions,partition,topk,mem,ssim)
prefix = 'best test AP:' prefix = 'best test AP:'
with open(file, 'r') as file: with open(file, 'r') as file:
for line in file: for line in file:
...@@ -26,6 +29,7 @@ for data in data_values: ...@@ -26,6 +29,7 @@ for data in data_values:
pos = line.find('shared comm tensor') pos = line.find('shared comm tensor')
if(pos!=-1): if(pos!=-1):
comm = int(line[pos+2+len('shared comm tensor'):len(line)-3]) comm = int(line[pos+2+len('shared comm tensor'):len(line)-3])
print(ap)
ap_list.append(ap) ap_list.append(ap)
comm_list.append(comm) comm_list.append(comm)
print('{} TestAP={}\n'.format(data,ap_list)) print('{} TestAP={}\n'.format(data,ap_list))
...@@ -33,7 +37,7 @@ for data in data_values: ...@@ -33,7 +37,7 @@ for data in data_values:
# Plot the bar chart # Plot the bar chart
bar_width = 0.4 bar_width = 0.4
#shared comm tensor #shared comm tensor
print('{} TestAP={}\n'.format(data,ap_list))
# Set the bar positions # Set the bar positions
bars = range(len(ssim_values)) bars = range(len(ssim_values))
...@@ -43,8 +47,10 @@ for data in data_values: ...@@ -43,8 +47,10 @@ for data in data_values:
plt.xticks([b for b in bars], ssim_values) plt.xticks([b for b in bars], ssim_values)
plt.xlabel('SSIM threshold Values') plt.xlabel('SSIM threshold Values')
plt.ylabel('Test AP') plt.ylabel('Test AP')
#if(data=='WIKI'):
# plt.ylim([0.97,1])
plt.title('{}({} partitions)'.format(data,partitions)) plt.title('{}({} partitions)'.format(data,partitions))
plt.savefig('ssim_{}_{}.png'.format(data,partitions)) plt.savefig('ssim_{}_{}_{}.png'.format(data,partitions,model))
plt.clf() plt.clf()
plt.bar([b for b in bars], comm_list, width=bar_width) plt.bar([b for b in bars], comm_list, width=bar_width)
...@@ -53,7 +59,7 @@ for data in data_values: ...@@ -53,7 +59,7 @@ for data in data_values:
plt.xlabel('SSIM threshold Values') plt.xlabel('SSIM threshold Values')
plt.ylabel('Communication volume') plt.ylabel('Communication volume')
plt.title('{}({} partitions)'.format(data,partitions)) plt.title('{}({} partitions)'.format(data,partitions))
plt.savefig('ssim_comm_{}_{}.png'.format(data,partitions)) plt.savefig('ssim_comm_{}_{}_{}.png'.format(data,partitions,model))
plt.clf() plt.clf()
if partition == 'ours_shared': if partition == 'ours_shared':
...@@ -62,18 +68,28 @@ for data in data_values: ...@@ -62,18 +68,28 @@ for data in data_values:
partition0=partition partition0=partition
for ssim in ssim_values: for ssim in ssim_values:
if ssim == 2: if ssim == 2:
file = '{}/val_{}_{}_{}_0_recent_0.1_local_2.pt'.format(data,partition0,topk,partitions,) file = '{}/{}/test_{}_{}_{}_0_recent_0.1_local_2.pt'.format(data,model,partition0,topk,partitions,)
elif ssim == -1:
file = '{}/{}/test_{}_{}_{}_0_recent_0.1_all_update_2.pt'.format(data,model,partition0,topk,partitions,)
else: else:
file = '{}/val_{}_{}_{}_0_recent_0.1_{}_{}.pt'.format(data,partition0,topk,partitions,mem,float(ssim)) file = '{}/{}/test_{}_{}_{}_0_recent_0.1_{}_{}.pt'.format(data,model,partition0,topk,partitions,mem,float(ssim))
val_ap = torch.tensor(torch.load(file)) val_ap = torch.tensor(torch.load(file))[:,0]
print(val_ap)
epoch = torch.arange(val_ap.shape[0]) epoch = torch.arange(val_ap.shape[0])
# Plot the curve # Plot the curve
print(val_ap) #print(val_ap)
plt.plot(epoch,val_ap, label='ssim={}'.format(ssim)) if ssim == -1:
plt.plot(epoch,val_ap, label='all-update')
elif ssim == 2:
plt.plot(epoch,val_ap, label='local')
else:
plt.plot(epoch,val_ap, label='ssim = {}'.format(ssim))
if(data=='WIKI'):
plt.ylim([0.85,0.90])
plt.xlabel('Epoch') plt.xlabel('Epoch')
plt.ylabel('Val AP') plt.ylabel('Val AP')
plt.title('{}({} partitions)'.format(data,partitions)) plt.title('{}({} partitions)'.format(data,partitions))
# plt.grid(True) # plt.grid(True)
plt.legend() plt.legend()
plt.savefig('{}_{}_ssim_Convergence_rate.png'.format(data,partitions)) plt.savefig('{}_{}_{}_ssim_Convergence_rate.png'.format(data,partitions,model))
plt.clf() plt.clf()
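The new sentinel values make `ssim == 2` stand for purely local memory updates and `ssim == -1` for the all_update baseline, while any other value selects the historical strategy at that SSIM threshold. A small sketch of the resulting file-name scheme, using the exact format strings from the hunk above; the example arguments are illustrative:

```python
# Maps an ssim sentinel value to the result file read by the plotting script.
def result_file(data, model, partitions, partition, topk, mem, ssim):
    if ssim == 2:     # purely local updates
        return '{}/{}/{}-{}-{}-local-recent.out'.format(data, model, partitions, partition, topk)
    if ssim == -1:    # all_update baseline
        return '{}/{}/{}-{}-{}-all_update-recent.out'.format(data, model, partitions, partition, topk)
    # historical strategy with an SSIM threshold
    return '{}/{}/{}-{}-{}-{}-{}-recent.out'.format(data, model, partitions, partition, topk, mem, ssim)

print(result_file('WIKI', 'JODIE', 4, 'ours_shared', 0.01, 'historical', 0.3))
```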
...@@ -41,7 +41,7 @@ pinn_memory = {} ...@@ -41,7 +41,7 @@ pinn_memory = {}
class HistoricalCache: class HistoricalCache:
def __init__(self,cache_index,layer,shape,dtype,device,threshold = 3,time_threshold = None, times_threshold = 5, use_rpc = True, num_threshold = 0): def __init__(self,cache_index,layer,shape,dtype,device,threshold = 3,time_threshold = None, times_threshold = 10, use_rpc = True, num_threshold = 0):
#self.cache_index = cache_index #self.cache_index = cache_index
self.layer = layer self.layer = layer
print(shape) print(shape)
...@@ -88,12 +88,17 @@ class HistoricalCache: ...@@ -88,12 +88,17 @@ class HistoricalCache:
return torch.sum((x -y)**2,dim = 1) return torch.sum((x -y)**2,dim = 1)
def historical_check(self,index,new_data,ts): def historical_check(self,index,new_data,ts):
if self.time_threshold is not None: if self.time_threshold is not None:
return (self.ssim(new_data,self.local_historical_data[index]) > self.threshold | (ts - self.local_ts[index] > self.time_threshold | self.loss_count[index] > self.times_threshold)) mask = (self.ssim(new_data,self.local_historical_data[index]) > self.threshold | (ts - self.local_ts[index] > self.time_threshold | self.loss_count[index] > self.times_threshold))
self.loss_count[index][~mask] += 1
self.loss_count[index][mask] = 0
else: else:
#print('{} {} {} {} \n'.format(index,self.ssim(new_data,self.local_historical_data[index]),new_data,self.local_historical_data[index])) #print('{} {} {} {} \n'.format(index,self.ssim(new_data,self.local_historical_data[index]),new_data,self.local_historical_data[index]))
#print(new_data,self.local_historical_data[index]) #print(new_data,self.local_historical_data[index])
#print(self.ssim(new_data,self.local_historical_data[index]) < self.threshold, (self.loss_count[index] > self.times_threshold)) #print(self.ssim(new_data,self.local_historical_data[index]) < self.threshold, (self.loss_count[index] > self.times_threshold))
return (self.ssim(new_data,self.local_historical_data[index]) > self.threshold) | (self.loss_count[index] > self.times_threshold) mask = (self.ssim(new_data,self.local_historical_data[index]) > self.threshold) | (self.loss_count[index] > self.times_threshold)
self.loss_count[index][~mask] += 1
self.loss_count[index][mask] = 0
return mask
def read_synchronize(self): def read_synchronize(self):
torch.cuda.synchronize(get_stream_set(self.layer)) torch.cuda.synchronize(get_stream_set(self.layer))
......
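The change turns `historical_check` into a counter-based policy: entries whose new memory stays within the SSIM threshold accumulate a miss count and are only flushed once that count exceeds `times_threshold`, while flushed entries reset to zero. A standalone sketch of that logic with explicit parentheses, since `|` binds tighter than `>` in Python; tensor shapes are illustrative:

```python
# Standalone sketch of the counter-based flush policy, assuming 2-D float
# memories and a 1-D integer miss counter per cached entry.
import torch

def historical_check(new_data, cached, loss_count, threshold, times_threshold):
    dist = torch.sum((new_data - cached) ** 2, dim=1)          # the "ssim" distance used above
    mask = (dist > threshold) | (loss_count > times_threshold)
    loss_count[~mask] += 1    # drifted too little: remember that an update was skipped
    loss_count[mask] = 0      # flushed to the historical cache: reset the counter
    return mask

cached = torch.zeros(4, 8)
new_data = cached + torch.tensor([0.0, 0.1, 1.0, 3.0]).unsqueeze(1)
counts = torch.zeros(4, dtype=torch.long)
print(historical_check(new_data, cached, counts, threshold=3.0, times_threshold=10))
```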
...@@ -2,6 +2,8 @@ from os.path import abspath, join, dirname ...@@ -2,6 +2,8 @@ from os.path import abspath, join, dirname
import os import os
import sys import sys
from os.path import abspath, join, dirname from os.path import abspath, join, dirname
from starrygl.distributed.utils import DistIndex
sys.path.insert(0, join(abspath(dirname(__file__)))) sys.path.insert(0, join(abspath(dirname(__file__))))
import torch import torch
import dgl import dgl
...@@ -172,7 +174,7 @@ class MixerMLP(torch.nn.Module): ...@@ -172,7 +174,7 @@ class MixerMLP(torch.nn.Module):
self.block_padding(b) self.block_padding(b)
#return x #return x
class TransfomerAttentionLayer(torch.nn.Module): class TransfomerAttentionLayer(torch.nn.Module):
def __init__(self, dim_node_feat, dim_edge_feat, dim_time, num_head, dropout, att_dropout, dim_out, combined=False): def __init__(self, dim_node_feat, dim_edge_feat, dim_time, num_head, dropout, att_dropout, dim_out, combined=False):
...@@ -286,9 +288,22 @@ class TransfomerAttentionLayer(torch.nn.Module): ...@@ -286,9 +288,22 @@ class TransfomerAttentionLayer(torch.nn.Module):
#att = dgl.ops.e_div_v(b,att_e_sub_max,torch.clamp_min(dgl.ops.copy_e_sum(b,att_e_sub_max),1)) #att = dgl.ops.e_div_v(b,att_e_sub_max,torch.clamp_min(dgl.ops.copy_e_sum(b,att_e_sub_max),1))
att = dgl.ops.edge_softmax(b, self.att_act(torch.sum(Q*K, dim=2))) att = dgl.ops.edge_softmax(b, self.att_act(torch.sum(Q*K, dim=2)))
att = self.att_dropout(att) att = self.att_dropout(att)
tt.weight_count_remote+=torch.sum(att[DistIndex(b.srcdata['ID']).part[b.edges()[0]]!=torch.distributed.get_rank()]**2)
tt.weight_count_local+=torch.sum(att[DistIndex(b.srcdata['ID']).part[b.edges()[0]]==torch.distributed.get_rank()]**2)
V = torch.reshape(V*att[:, :, None], (V.shape[0], -1)) V = torch.reshape(V*att[:, :, None], (V.shape[0], -1))
V_local = V.clone()
V_remote = V.clone()
V_local[DistIndex(b.srcdata['ID']).part[b.edges()[0]]!=torch.distributed.get_rank()] = 0
V_remote[DistIndex(b.srcdata['ID']).part[b.edges()[0]]==torch.distributed.get_rank()] = 0
b.edata['v'] = V b.edata['v'] = V
b.edata['v0'] = V_local
b.edata['v1'] = V_remote
b.update_all(dgl.function.copy_e('v0', 'm0'), dgl.function.sum('m0', 'h0'))
b.update_all(dgl.function.copy_e('v1', 'm1'), dgl.function.sum('m1', 'h1'))
b.update_all(dgl.function.copy_e('v', 'm'), dgl.function.sum('m', 'h')) b.update_all(dgl.function.copy_e('v', 'm'), dgl.function.sum('m', 'h'))
tt.ssim_local+=torch.sum(torch.cosine_similarity(b.dstdata['h'],b.dstdata['h0']))
tt.ssim_remote+=torch.sum(torch.cosine_similarity(b.dstdata['h'],b.dstdata['h1']))
tt.ssim_cnt += b.num_dst_nodes()
#print('dst {}'.format(b.dstdata['h'])) #print('dst {}'.format(b.dstdata['h']))
#b.srcdata['v'] = torch.cat([torch.zeros((b.num_dst_nodes(), V.shape[1]), device=torch.device('cuda:0')), V], dim=0) #b.srcdata['v'] = torch.cat([torch.zeros((b.num_dst_nodes(), V.shape[1]), device=torch.device('cuda:0')), V], dim=0)
#b.update_all(dgl.function.copy_u('v', 'm'), dgl.function.sum('m', 'h')) #b.update_all(dgl.function.copy_u('v', 'm'), dgl.function.sum('m', 'h'))
......
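The added lines accumulate how much attention mass and aggregated signal each destination receives from locally owned versus remote source nodes, keyed by the partition bits of the distributed node ID. A self-contained analogue of that statistic without DGL or the distributed context; a plain edge list and a per-node partition id stand in for the block and `DistIndex(...).part`:

```python
# Split each destination's aggregated message into the part contributed by
# local sources and the part contributed by remote sources, then compare both
# to the full message with cosine similarity (as the new counters do).
import torch

def local_remote_similarity(V, src, dst, src_part, my_rank, num_dst):
    # V: (num_edges, dim) attention-weighted messages, one per edge
    local_edges = src_part[src] == my_rank
    h = torch.zeros(num_dst, V.shape[1]).index_add_(0, dst, V)
    h_local = torch.zeros_like(h).index_add_(0, dst[local_edges], V[local_edges])
    h_remote = h - h_local
    return (torch.cosine_similarity(h, h_local, dim=1),
            torch.cosine_similarity(h, h_remote, dim=1))

V = torch.randn(6, 4)
src = torch.tensor([0, 1, 2, 3, 4, 5]); dst = torch.tensor([0, 0, 0, 1, 1, 1])
src_part = torch.tensor([0, 0, 1, 1, 1, 0])   # which partition owns each source node
print(local_remote_similarity(V, src, dst, src_part, my_rank=0, num_dst=2))
```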
...@@ -331,6 +331,7 @@ class TransformerMemoryUpdater(torch.nn.Module): ...@@ -331,6 +331,7 @@ class TransformerMemoryUpdater(torch.nn.Module):
self.mlp = torch.nn.Linear(dim_out, dim_out) self.mlp = torch.nn.Linear(dim_out, dim_out)
self.dropout = torch.nn.Dropout(train_param['dropout']) self.dropout = torch.nn.Dropout(train_param['dropout'])
self.att_dropout = torch.nn.Dropout(train_param['att_dropout']) self.att_dropout = torch.nn.Dropout(train_param['att_dropout'])
def forward(self, b, param = None): def forward(self, b, param = None):
Q = self.w_q(b.srcdata['mem']).reshape((b.num_src_nodes(), self.att_h, -1)) Q = self.w_q(b.srcdata['mem']).reshape((b.num_src_nodes(), self.att_h, -1))
...@@ -396,7 +397,7 @@ class AsyncMemeoryUpdater(torch.nn.Module): ...@@ -396,7 +397,7 @@ class AsyncMemeoryUpdater(torch.nn.Module):
self.ceil_updater = updater(memory_param, dim_in, dim_hid, dim_time, train_param) self.ceil_updater = updater(memory_param, dim_in, dim_hid, dim_time, train_param)
self.updater = self.transformer_updater self.updater = self.transformer_updater
else: else:
self.updater = updater(dim_in + dim_time, dim_hid) self.ceil_updater = updater(dim_in + dim_time, dim_hid)
self.updater = self.rnn_updater self.updater = self.rnn_updater
self.last_updated_memory = None self.last_updated_memory = None
self.last_updated_ts = None self.last_updated_ts = None
...@@ -419,13 +420,30 @@ class AsyncMemeoryUpdater(torch.nn.Module): ...@@ -419,13 +420,30 @@ class AsyncMemeoryUpdater(torch.nn.Module):
self.update_hunk = self.historical_func self.update_hunk = self.historical_func
elif self.mode == 'local' or self.mode=='all_local': elif self.mode == 'local' or self.mode=='all_local':
self.update_hunk = self.local_func self.update_hunk = self.local_func
if self.mode == 'historical':
self.gamma = torch.nn.Parameter(torch.tensor([0.9]),
requires_grad=True)
else:
self.gamma = 1
def forward(self, mfg, param = None): def forward(self, mfg, param = None):
for b in mfg: for b in mfg:
updated_memory = self.updater(b) mail_input = b.srcdata['mem_input']
updated_memory0 = self.updater(b)
mask = DistIndex(b.srcdata['ID']).is_shared
#incr = updated_memory[mask] - b.srcdata['mem'][mask]
time_feat = self.time_enc(b.srcdata['ts'][mask].reshape(-1,1) - b.srcdata['his_ts'][mask].reshape(-1,1))
his_mem = torch.cat((mail_input[mask],time_feat),dim = 1)
upd0 = torch.zeros_like(updated_memory0)
upd0[mask] = self.ceil_updater(his_mem, b.srcdata['his_mem'][mask])
#updated_memory = torch.where(mask.unsqueeze(1),self.gamma*updated_memory0 + (1-self.gamma)*(b.srcdata['his_mem']),updated_memory0)
updated_memory = torch.where(mask.unsqueeze(1),self.gamma*updated_memory0 + (1-self.gamma)*(upd0),updated_memory0)
print(torch.cosine_similarity(updated_memory[mask],b.srcdata['his_mem'][mask]).sum()/torch.sum(mask))
print(self.gamma)
self.pre_mem = b.srcdata['his_mem']
self.last_updated_ts = b.srcdata['ts'].detach().clone() self.last_updated_ts = b.srcdata['ts'].detach().clone()
self.last_updated_memory = updated_memory.detach().clone() self.last_updated_memory = updated_memory.detach().clone()
self.last_updated_nid = b.srcdata['ID'].detach().clone() self.last_updated_nid = b.srcdata['ID'].detach().clone()
with torch.no_grad(): with torch.no_grad():
if param is not None: if param is not None:
_,src,dst,ts,edge_feats,nxt_fetch_func = param _,src,dst,ts,edge_feats,nxt_fetch_func = param
...@@ -434,12 +452,14 @@ class AsyncMemeoryUpdater(torch.nn.Module): ...@@ -434,12 +452,14 @@ class AsyncMemeoryUpdater(torch.nn.Module):
self.last_updated_memory[indx], self.last_updated_memory[indx],
self.last_updated_ts[indx], self.last_updated_ts[indx],
None) None)
#print(index.shape[0])
if param[0]: if param[0]:
index, mail, mail_ts = self.mailbox.get_update_mail( index, mail, mail_ts = self.mailbox.get_update_mail(
b.srcdata['ID'],src,dst,ts,edge_feats, b.srcdata['ID'],src,dst,ts,edge_feats,
self.last_updated_memory, self.last_updated_memory,
None,False,False, None,False,False,block=b
) )
#print(index.shape[0])
if torch.distributed.get_world_size() == 0: if torch.distributed.get_world_size() == 0:
self.mailbox.mon.add(index,self.mailbox.node_memory.accessor.data[index],memory) self.mailbox.mon.add(index,self.mailbox.node_memory.accessor.data[index],memory)
##print(index.shape,memory.shape,memory_ts.shape,mail.shape,mail_ts.shape) ##print(index.shape,memory.shape,memory_ts.shape,mail.shape,mail_ts.shape)
...@@ -447,14 +467,14 @@ class AsyncMemeoryUpdater(torch.nn.Module): ...@@ -447,14 +467,14 @@ class AsyncMemeoryUpdater(torch.nn.Module):
self.mailbox.set_mailbox_local(DistIndex(index[local_mask]).loc,mail[local_mask],mail_ts[local_mask],Reduce_Op = 'max') self.mailbox.set_mailbox_local(DistIndex(index[local_mask]).loc,mail[local_mask],mail_ts[local_mask],Reduce_Op = 'max')
self.mailbox.set_memory_local(DistIndex(index[local_mask]).loc,memory[local_mask],memory_ts[local_mask], Reduce_Op = 'max') self.mailbox.set_memory_local(DistIndex(index[local_mask]).loc,memory[local_mask],memory_ts[local_mask], Reduce_Op = 'max')
self.update_hunk(index,memory,memory_ts,mail,mail_ts,nxt_fetch_func) self.update_hunk(index,memory,memory_ts,mail,mail_ts,nxt_fetch_func)
if self.memory_param['combine_node_feature'] and self.dim_node_feat > 0: if self.memory_param['combine_node_feature'] and self.dim_node_feat > 0:
if self.dim_node_feat == self.dim_hid: if self.dim_node_feat == self.dim_hid:
b.srcdata['h'] += memory b.srcdata['h'] += updated_memory
else: else:
b.srcdata['h'] = memory + self.node_feat_map(b.srcdata['h']) b.srcdata['h'] = updated_memory + self.node_feat_map(b.srcdata['h'])
else: else:
b.srcdata['h'] = memory b.srcdata['h'] = updated_memory
def empty_cache(self): def empty_cache(self):
pass pass
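For shared nodes, the reworked `forward` now blends the freshly computed memory with an update driven by the cached historical memory, weighted by the learnable `gamma` introduced for the 'historical' mode. A hedged standalone sketch of that blending step, with a plain `GRUCell` standing in for `self.ceil_updater` and random tensors standing in for the block fields:

```python
# Sketch of the gamma-weighted blend for shared nodes; shapes are illustrative.
import torch

dim_mail, dim_time, dim_mem, n = 16, 8, 16, 5
ceil_updater = torch.nn.GRUCell(dim_mail + dim_time, dim_mem)   # stand-in for self.ceil_updater
gamma = torch.nn.Parameter(torch.tensor([0.9]))                  # learnable mixing weight

updated_memory0 = torch.randn(n, dim_mem)   # output of the normal updater
mail_input = torch.randn(n, dim_mail)       # b.srcdata['mem_input']
time_feat = torch.randn(n, dim_time)        # self.time_enc(ts - his_ts)
his_mem = torch.randn(n, dim_mem)           # cached historical memory
is_shared = torch.tensor([True, False, True, False, True])

upd0 = torch.zeros_like(updated_memory0)
upd0[is_shared] = ceil_updater(
    torch.cat((mail_input[is_shared], time_feat[is_shared]), dim=1),
    his_mem[is_shared])
# Shared nodes get a weighted combination of the fresh update and the
# history-driven one; everything else keeps the fresh update unchanged.
updated_memory = torch.where(is_shared.unsqueeze(1),
                             gamma * updated_memory0 + (1 - gamma) * upd0,
                             updated_memory0)
print(updated_memory.shape)
```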
...@@ -86,7 +86,7 @@ class GeneralModel(torch.nn.Module): ...@@ -86,7 +86,7 @@ class GeneralModel(torch.nn.Module):
# else: # else:
# self.memory_updater = HistoricalMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node,updater=updater,learnable=True,num_nodes=num_nodes) # self.memory_updater = HistoricalMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node,updater=updater,learnable=True,num_nodes=num_nodes)
elif memory_param['memory_update'] == 'transformer': elif memory_param['memory_update'] == 'transformer':
self.memory_updater = TransformerMemoryUpdater updater = TransformerMemoryUpdater
self.memory_updater = AsyncMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node, updater=updater, mailbox=mailbox, mode = memory_param['mode'],train_param=train_param) self.memory_updater = AsyncMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node, updater=updater, mailbox=mailbox, mode = memory_param['mode'],train_param=train_param)
else: else:
raise NotImplementedError raise NotImplementedError
......
...@@ -123,7 +123,7 @@ def get_node_all_to_all_route(graph:DistributedGraphStore,mailbox:SharedMailBox, ...@@ -123,7 +123,7 @@ def get_node_all_to_all_route(graph:DistributedGraphStore,mailbox:SharedMailBox,
ind_dict = graph.nfeat.all_to_all_ind2ptr(query_nid_feature,group = group) ind_dict = graph.nfeat.all_to_all_ind2ptr(query_nid_feature,group = group)
memory,memory_ts,mail,mail_ts = mailbox.gather_local_memory(ind_dict['recv_ind'],compute_device=out_device) memory,memory_ts,mail,mail_ts = mailbox.gather_local_memory(ind_dict['recv_ind'],compute_device=out_device)
memory_ts = memory_ts.reshape(-1,1) memory_ts = memory_ts.reshape(-1,1)
mail_ts = mail_ts.reshape(-1,1) mail_ts = mail_ts.reshape(mail_ts.shape[0],-1)
mail = mail.reshape(mail.shape[0],-1) mail = mail.reshape(mail.shape[0],-1)
data.append(torch.cat((memory,memory_ts,mail,mail_ts),dim = 1)) data.append(torch.cat((memory,memory_ts,mail,mail_ts),dim = 1))
if ind_dict is not None: if ind_dict is not None:
...@@ -144,14 +144,14 @@ def get_edge_all_to_all_route(graph:DistributedGraphStore,query_eid_feature,out_ ...@@ -144,14 +144,14 @@ def get_edge_all_to_all_route(graph:DistributedGraphStore,query_eid_feature,out_
def prepare_input(node_feat, edge_feat, mem_embedding,mfgs,dist_nid=None,dist_eid=None): def prepare_input(node_feat, edge_feat, mem_embedding,mfgs,dist_nid=None,dist_eid=None):
for i,mfg in enumerate(mfgs): for i,mfg in enumerate(mfgs):
for b in mfg: for b in mfg:
-            e_idx = b.edata['__ID']
-            idx = b.srcdata['__ID']
-            if dist_eid is not None:
-                b.edata['ID'] = dist_eid[e_idx]
-            if dist_nid is not None:
-                b.srcdata['ID'] = dist_nid[idx]
-            if edge_feat is not None:
-                b.edata['f'] = edge_feat[e_idx]
+            idx = b.srcdata['__ID']
+            if '__ID' in b.edata:
+                e_idx = b.edata['__ID']
+                b.edata['ID'] = dist_eid[e_idx]
+                if edge_feat is not None:
+                    b.edata['f'] = edge_feat[e_idx]
+            if dist_nid is not None:
+                b.srcdata['ID'] = dist_nid[idx]
if i == 0: if i == 0:
if node_feat is not None: if node_feat is not None:
b.srcdata['h'] = node_feat[idx] b.srcdata['h'] = node_feat[idx]
...@@ -336,15 +336,96 @@ def to_block(graph,data, sample_out,device = torch.device('cuda'),unique = True) ...@@ -336,15 +336,96 @@ def to_block(graph,data, sample_out,device = torch.device('cuda'),unique = True)
data,mfgs,metadata = build_block() data,mfgs,metadata = build_block()
return (data,mfgs,metadata),dist_nid,dist_eid return (data,mfgs,metadata),dist_nid,dist_eid
def to_reversed_block(graph,data, sample_out,device = torch.device('cuda'),unique = True,identity=False):
if len(sample_out) > 1:
sample_out,metadata = sample_out
else:
metadata = None
nid_mapper: torch.Tensor = graph.nids_mapper
if identity is False:
assert len(sample_out) == 1
ret = sample_out[0]
eid_len = ret.eid().shape[0]
t0 = time.time()
dst_ts = ret.sample_nodes_ts().to(device)
dst = nid_mapper[ret.sample_nodes()].to(device)
dist_eid = torch.tensor([],dtype=torch.long,device=device)
src_index = ret.src_index().to(device)
else:
src_index = torch.tensor([],dtype=torch.long,device=device)
dst = torch.tensor([],dtype=torch.long,device=device)
dist_eid = torch.tensor([],dtype=torch.long,device=device)
if metadata is None:
root_node = data.nodes.to(graph.nids_mapper.device)
root_len = [root_node.shape[0]]
root_ts = data.ts.to(device)
elif 'seed' in metadata:
root_node = metadata.pop('seed').to(graph.nids_mapper.device)
root_len = root_node.shape[0]
if 'seed_ts' in metadata:
root_ts = metadata.pop('seed_ts').to(device)
for k in metadata:
if isinstance(metadata[k],torch.Tensor):
metadata[k] = metadata[k].to(device)
src_node = root_node
src_ts = root_ts
nid_tensor = torch.cat([root_node],dim = 0)
dist_nid = nid_mapper[nid_tensor].to(device)
CountComm.origin_local = (DistIndex(dist_nid).part == dist.get_rank()).sum().item()
CountComm.origin_remote =(DistIndex(dist_nid).part != dist.get_rank()).sum().item()
dist_nid,nid_inv = dist_nid.unique(return_inverse=True)
"""
Deduplicate nodes that share the same id and the same timestamp and take the index of each first occurrence
"""
block_node_list,unq_id = torch.stack((nid_inv.to(torch.float64),src_ts.to(torch.float64))).unique(dim = 1,return_inverse=True)
first_index,_ = torch_scatter.scatter_min(torch.arange(unq_id.shape[0],device=unq_id.device,dtype=unq_id.dtype),unq_id)
first_mask = torch.zeros(unq_id.shape[0],device = unq_id.device,dtype=torch.bool)
first_mask[first_index] = True
first_index = unq_id[first_mask]
first_block_id = torch.empty(first_index.shape[0],device=unq_id.device,dtype=unq_id.dtype)
first_block_id[first_index] = torch.arange(first_index.shape[0],device=first_index.device,dtype=first_index.dtype)
first_block_id = first_block_id[unq_id].contiguous()
block_node_list = block_node_list[:,first_index]
for k in metadata:
if isinstance(metadata[k],torch.Tensor):
metadata[k] = first_block_id[metadata[k]]
t2 = time.time()
def build_block():
mfgs = list()
col_len = 0
row_len = root_len
col = first_block_id[:row_len]
max_row = col.max().item()+1
b = dgl.create_block((col[src_index].to(device),
torch.arange(dst.shape[0],device=device,dtype=torch.long)),num_src_nodes=first_block_id.max().item()+1,
num_dst_nodes=dst.shape[0])
idx = block_node_list[0,b.srcnodes()].to(torch.long)
b.srcdata['__ID'] = idx
b.srcdata['ts'] = block_node_list[1,b.srcnodes()].to(torch.float)
b.dstdata['ID'] = dst
mfgs.append(b)
mfgs = list(map(list, zip(*[iter(mfgs)])))
mfgs.reverse()
return data,mfgs,metadata
data,mfgs,metadata = build_block()
return (data,mfgs,metadata),dist_nid,dist_eid
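`to_reversed_block` deduplicates sampled (node id, timestamp) pairs by stacking them, taking a column-wise `unique`, and using `scatter_min` to keep each pair's first occurrence. A minimal sketch of that trick in isolation; it requires `torch_scatter`, which the code above already uses, and the toy tensors are illustrative:

```python
# Collapse rows that share both node id and timestamp into one block node,
# remembering which original position represents each unique pair.
import torch
import torch_scatter

nid_inv = torch.tensor([0, 1, 0, 2, 1])
ts      = torch.tensor([5., 7., 5., 7., 9.])

pairs, unq_id = torch.stack((nid_inv.to(torch.float64), ts.to(torch.float64))).unique(
    dim=1, return_inverse=True)
first_index, _ = torch_scatter.scatter_min(
    torch.arange(unq_id.shape[0], dtype=unq_id.dtype), unq_id)

print(pairs)        # 2 x num_unique matrix of (id, ts) columns
print(unq_id)       # maps each original row to its unique column
print(first_index)  # original position that represents each unique column
```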
import concurrent.futures import concurrent.futures
executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
def graph_sample(graph,sampler,sample_fn,data,neg_sampling = None,out_device = torch.device('cuda'),nid_mapper = None,eid_mapper=None): def graph_sample(graph,sampler,sample_fn,data,neg_sampling = None,out_device = torch.device('cuda'),nid_mapper = None,eid_mapper=None,reversed=False):
t_s = time.time() t_s = time.time()
param = {'is_unique':False,'nid_mapper':nid_mapper,'eid_mapper':eid_mapper,'out_device':out_device} param = {'is_unique':False,'nid_mapper':nid_mapper,'eid_mapper':eid_mapper,'out_device':out_device}
out = sample_fn(sampler,data,neg_sampling,**param) out = sample_fn(sampler,data,neg_sampling,**param)
-    out,dist_nid,dist_eid = to_block(graph,data,out,out_device)
+    if reversed is False:
+        out,dist_nid,dist_eid = to_block(graph,data,out,out_device,reversed)
+    else:
+        out,dist_nid,dist_eid = to_reversed_block(graph,data,out,out_device,reversed)
t_e = time.time() t_e = time.time()
#print(t_e-t_s) #print(t_e-t_s)
return out,dist_nid,dist_eid return out,dist_nid,dist_eid
......
...@@ -12,6 +12,13 @@ class time_count: ...@@ -12,6 +12,13 @@ class time_count:
time_sample_and_build = 0 time_sample_and_build = 0
time_memory_fetch = 0 time_memory_fetch = 0
weight_count_remote = 0
weight_count_local = 0
ssim_remote = 0
ssim_cnt = 0
ssim_local = 0
ssim_cnt = 0
@staticmethod @staticmethod
def _zero(): def _zero():
time_count.time_forward = 0 time_count.time_forward = 0
...@@ -34,8 +41,9 @@ class time_count: ...@@ -34,8 +41,9 @@ class time_count:
def start(): def start():
return time.perf_counter(),0 return time.perf_counter(),0
@staticmethod @staticmethod
def elapsed_event(start_event,end_event): def elapsed_event(start_event):
if start_event.isinstance(torch.cuda.Event): if isinstance(start_event,tuple):
start_event,end_event = start_event
end_event.record() end_event.record()
end_event.synchronize() end_event.synchronize()
return start_event.elapsed_time(end_event) return start_event.elapsed_time(end_event)
...@@ -51,4 +59,6 @@ class time_count: ...@@ -51,4 +59,6 @@ class time_count:
time_count.time_local_update, time_count.time_local_update,
time_count.time_memory_sync, time_count.time_memory_sync,
time_count.time_sample_and_build, time_count.time_sample_and_build,
time_count.time_memory_fetch )) time_count.time_memory_fetch ))
\ No newline at end of file
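The `elapsed_event` change lets the timer accept either a pair of CUDA events or a plain `perf_counter` start value. One way such a dual-path helper could look, assuming that intent; this is a sketch, not the repository's `time_count` class:

```python
# Dual CPU/GPU timer sketch: start() returns a (start, end) pair of CUDA events
# when a GPU is available, otherwise a perf_counter float; elapsed() accepts
# either form and returns milliseconds.
import time
import torch

def start():
    if torch.cuda.is_available():
        s, e = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
        s.record()
        return (s, e)
    return time.perf_counter()

def elapsed(start_event):
    if isinstance(start_event, tuple):          # CUDA path
        s, e = start_event
        e.record()
        e.synchronize()
        return s.elapsed_time(e)                # already in milliseconds
    return (time.perf_counter() - start_event) * 1000.0

t = start()
_ = sum(i * i for i in range(10000))
print(elapsed(t))
```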
...@@ -100,8 +100,10 @@ class DistributedDataLoader: ...@@ -100,8 +100,10 @@ class DistributedDataLoader:
cache_mask = None, cache_mask = None,
use_local_feature = True, use_local_feature = True,
probability = 1, probability = 1,
reversed = False,
**kwargs **kwargs
): ):
self.reversed = reversed
self.use_local_feature = use_local_feature self.use_local_feature = use_local_feature
self.local_embedding = local_embedding self.local_embedding = local_embedding
self.chunk_size = chunk_size self.chunk_size = chunk_size
...@@ -223,7 +225,7 @@ class DistributedDataLoader: ...@@ -223,7 +225,7 @@ class DistributedDataLoader:
if self.mode=='train' and self.probability < 1: if self.mode=='train' and self.probability < 1:
mask = ((DistIndex(self.graph.nids_mapper[next_data.edges[0,:].to('cpu')]).part == dist.get_rank())&(DistIndex(self.graph.nids_mapper[next_data.edges[1,:].to('cpu')]).part == dist.get_rank())) mask = ((DistIndex(self.graph.nids_mapper[next_data.edges[0,:].to('cpu')]).part == dist.get_rank())&(DistIndex(self.graph.nids_mapper[next_data.edges[1,:].to('cpu')]).part == dist.get_rank()))
if self.probability > 0: if self.probability > 0:
mask[~mask] = (torch.rand((~mask)) < self.probability) mask[~mask] = (torch.rand((~mask).sum().item()) < self.probability)
next_data = next_data[mask.to(next_data.device)] next_data = next_data[mask.to(next_data.device)]
self.submitted = self.submitted + 1 self.submitted = self.submitted + 1
return next_data return next_data
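The fixed line draws one random coin per cross-partition edge, so the tensor assigned into `mask[~mask]` now has the right length. A standalone sketch of the same filtering rule, with explicit partition ids in place of `DistIndex(...).part`:

```python
# Keep every purely-local edge; keep a cross-partition edge with probability p.
import torch

def filter_edges(src_part, dst_part, my_rank, p):
    local = (src_part == my_rank) & (dst_part == my_rank)
    keep = local.clone()
    if p > 0:
        # one coin per *non-local* edge, hence the (~local).sum() length
        keep[~local] = torch.rand(int((~local).sum())) < p
    return keep

src_part = torch.tensor([0, 0, 1, 1, 0])
dst_part = torch.tensor([0, 1, 1, 0, 0])
print(filter_edges(src_part, dst_part, my_rank=0, p=0.1))
```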
...@@ -239,13 +241,14 @@ class DistributedDataLoader: ...@@ -239,13 +241,14 @@ class DistributedDataLoader:
self.device, self.device,
nid_mapper = self.graph.nids_mapper, nid_mapper = self.graph.nids_mapper,
eid_mapper = self.graph.eids_mapper, eid_mapper = self.graph.eids_mapper,
reversed = self.reversed
) )
self.result_queue.append((fut)) self.result_queue.append((fut))
@torch.no_grad() @torch.no_grad()
def async_feature(self): def async_feature(self):
if(self.recv_idxs >= self.expected_idx): if(self.recv_idxs >= self.expected_idx or self.is_pipeline == False):
return return
is_local = (self.is_train & self.use_local_feature) is_local = (self.is_train & self.use_local_feature)
if(is_local): if(is_local):
...@@ -303,7 +306,9 @@ class DistributedDataLoader: ...@@ -303,7 +306,9 @@ class DistributedDataLoader:
data,self.neg_sampler, data,self.neg_sampler,
self.device, self.device,
nid_mapper = self.graph.nids_mapper, nid_mapper = self.graph.nids_mapper,
eid_mapper = self.graph.eids_mapper) eid_mapper = self.graph.eids_mapper,
reversed = self.reversed
)
root,mfgs,metadata = batch_data root,mfgs,metadata = batch_data
t_sample = tt.elapsed_event(t0) t_sample = tt.elapsed_event(t0)
...@@ -312,6 +317,12 @@ class DistributedDataLoader: ...@@ -312,6 +317,12 @@ class DistributedDataLoader:
edge_feat = get_edge_feature_by_dist(self.graph,dist_eid,is_local,out_device=self.device) edge_feat = get_edge_feature_by_dist(self.graph,dist_eid,is_local,out_device=self.device)
node_feat,mem = get_node_feature_by_dist(self.graph,self.mailbox,dist_nid, is_local,out_device=self.device) node_feat,mem = get_node_feature_by_dist(self.graph,self.mailbox,dist_nid, is_local,out_device=self.device)
prepare_input(node_feat,edge_feat,mem,mfgs,dist_nid,dist_eid) prepare_input(node_feat,edge_feat,mem,mfgs,dist_nid,dist_eid)
batch_data[1][0][0].srcdata['his_mem'] = batch_data[1][0][0].srcdata['mem'].clone()
batch_data[1][0][0].srcdata['his_ts'] = batch_data[1][0][0].srcdata['mail_ts'].clone()
mask = DistIndex(batch_data[1][0][0].srcdata['ID']).is_shared
indx = self.mailbox.is_shared_mask[DistIndex(batch_data[1][0][0].srcdata['ID']).loc[mask]]
batch_data[1][0][0].srcdata['his_mem'][mask] = self.mailbox.historical_cache.local_historical_data[indx]
batch_data[1][0][0].srcdata['his_ts'][mask] = self.mailbox.historical_cache.local_ts[indx].reshape(-1,1)
t_fetch = tt.elapsed_event(t1) t_fetch = tt.elapsed_event(t1)
tt.time_memory_fetch += t_fetch tt.time_memory_fetch += t_fetch
#if(self.mailbox is not None and self.mailbox.historical_cache is not None): #if(self.mailbox is not None and self.mailbox.historical_cache is not None):
...@@ -335,10 +346,15 @@ class DistributedDataLoader: ...@@ -335,10 +346,15 @@ class DistributedDataLoader:
data,self.neg_sampler, data,self.neg_sampler,
self.device, self.device,
nid_mapper = self.graph.nids_mapper, nid_mapper = self.graph.nids_mapper,
eid_mapper = self.graph.eids_mapper) eid_mapper = self.graph.eids_mapper,
reversed = self.reversed
)
edge_feat = get_edge_feature_by_dist(self.graph,dist_eid,is_local,out_device=self.device) edge_feat = get_edge_feature_by_dist(self.graph,dist_eid,is_local,out_device=self.device)
node_feat,mem = get_node_feature_by_dist(self.graph,self.mailbox,dist_nid, is_local,out_device=self.device) node_feat,mem = get_node_feature_by_dist(self.graph,self.mailbox,dist_nid, is_local,out_device=self.device)
prepare_input(node_feat,edge_feat,mem,batch_data[1],dist_nid,dist_eid) prepare_input(node_feat,edge_feat,mem,batch_data[1],dist_nid,dist_eid)
batch_data[1][0][0].srcdata['his_mem'] = batch_data[1][0][0].srcdata['mem'].clone()
batch_data[1][0][0].srcdata['his_ts'] = batch_data[1][0][0].srcdata['mail_ts'].clone()
#if(self.mailbox is not None and self.mailbox.historical_cache is not None): #if(self.mailbox is not None and self.mailbox.historical_cache is not None):
# id = batch_data[1][0][0].srcdata['ID'] # id = batch_data[1][0][0].srcdata['ID']
# mask = DistIndex(id).is_shared # mask = DistIndex(id).is_shared
...@@ -363,14 +379,14 @@ class DistributedDataLoader: ...@@ -363,14 +379,14 @@ class DistributedDataLoader:
node_feat0 = node_feat0[0] node_feat0 = node_feat0[0]
node_feat = node_feat0[:,:self.graph.nfeat.shape[1]] node_feat = node_feat0[:,:self.graph.nfeat.shape[1]]
if self.graph.nfeat.shape[1] < node_feat0.shape[1]: if self.graph.nfeat.shape[1] < node_feat0.shape[1]:
mem = self.mailbox.unpack(node_feat0[:,self.graph.nfeat.shape[1]:]) mem = self.mailbox.unpack(node_feat0[:,self.graph.nfeat.shape[1]:],mailbox = True)
else: else:
mem = None mem = None
elif self.mailbox is not None: elif self.mailbox is not None:
node_feat0[1].wait() node_feat0[1].wait()
node_feat0 = node_feat0[0] node_feat0 = node_feat0[0]
node_feat = None node_feat = None
mem = self.mailbox.unpack(node_feat0) mem = self.mailbox.unpack(node_feat0,mailbox = True)
#print(node_feat.shape,edge_feat.shape,mem[0].shape) #print(node_feat.shape,edge_feat.shape,mem[0].shape)
#node_feat[1].wait() #node_feat[1].wait()
#node_feat = node_feat[0] #node_feat = node_feat[0]
...@@ -394,7 +410,11 @@ class DistributedDataLoader: ...@@ -394,7 +410,11 @@ class DistributedDataLoader:
if(self.mailbox is not None and self.mailbox.historical_cache is not None): if(self.mailbox is not None and self.mailbox.historical_cache is not None):
id = batch_data[1][0][0].srcdata['ID'] id = batch_data[1][0][0].srcdata['ID']
mask = DistIndex(id).is_shared mask = DistIndex(id).is_shared
batch_data[1][0][0].srcdata['mem'][mask] = self.mailbox.node_memory.accessor.data[DistIndex(id).loc[mask]] batch_data[1][0][0].srcdata['his_mem'] = batch_data[1][0][0].srcdata['mem'].clone()
batch_data[1][0][0].srcdata['his_ts'] = batch_data[1][0][0].srcdata['mail_ts'].clone()
indx = self.mailbox.is_shared_mask[DistIndex(batch_data[1][0][0].srcdata['ID']).loc[mask]]
batch_data[1][0][0].srcdata['his_mem'][mask] = self.mailbox.historical_cache.local_historical_data[indx]
batch_data[1][0][0].srcdata['his_ts'][mask] = self.mailbox.historical_cache.local_ts[indx].reshape(-1,1)#self.mailbox.node_memory.accessor.data[DistIndex(id).loc[mask]]
#his_mem = torch.clone(batch_data[1][0][0].srcdata['mem']) #his_mem = torch.clone(batch_data[1][0][0].srcdata['mem'])
#his_ts = self.mailbox.historical_cache.local_ts[DistIndex(id[mask]).loc] #his_ts = self.mailbox.historical_cache.local_ts[DistIndex(id[mask]).loc]
#maxer = his_ts>batch_data[1][0][0].srcdata['mem_ts'][mask] #maxer = his_ts>batch_data[1][0][0].srcdata['mem_ts'][mask]
......
...@@ -17,10 +17,10 @@ class MemoryMoniter: ...@@ -17,10 +17,10 @@ class MemoryMoniter:
self.memorychange.append(self.ssim(pre_memory,now_memory,method = 'F')) self.memorychange.append(self.ssim(pre_memory,now_memory,method = 'F'))
self.memory_ssim.append(self.ssim(pre_memory,now_memory,method = 'cos')) self.memory_ssim.append(self.ssim(pre_memory,now_memory,method = 'cos'))
self.nid_list.append(nid) self.nid_list.append(nid)
def draw(self,degree,data,e): def draw(self,degree,data,model,e):
torch.save(self.nid_list,'all/{}/memorynid_{}.pt'.format(data,e)) torch.save(self.nid_list,'all/{}/{}/memorynid_{}.pt'.format(data,model,e))
torch.save(self.memorychange,'all/{}/memoryF_{}.pt'.format(data,e)) torch.save(self.memorychange,'all/{}/{}/memoryF_{}.pt'.format(data,model,e))
torch.save(self.memory_ssim,'all/{}/memcos_{}.pt'.format(data,e)) torch.save(self.memory_ssim,'all/{}/{}/memcos_{}.pt'.format(data,model,e))
# path = './memory/{}/'.format(data) # path = './memory/{}/'.format(data)
# if not os.path.exists(path): # if not os.path.exists(path):
......
...@@ -98,6 +98,7 @@ class SharedMailBox(): ...@@ -98,6 +98,7 @@ class SharedMailBox():
self.tot_comm_count = 0 self.tot_comm_count = 0
self.tot_shared_count = 0 self.tot_shared_count = 0
self.shared_nodes_index = None self.shared_nodes_index = None
self.deliver_to = memory_param['deliver_to'] if 'deliver_to' in memory_param else 'self'
if shared_nodes_index is not None: if shared_nodes_index is not None:
self.shared_nodes_index = shared_nodes_index.to('cuda:{}'.format(ctx.local_rank)) self.shared_nodes_index = shared_nodes_index.to('cuda:{}'.format(ctx.local_rank))
self.is_shared_mask = torch.zeros(self.num_nodes,dtype=torch.int,device=torch.device('cuda:{}'.format(ctx.local_rank)))-1 self.is_shared_mask = torch.zeros(self.num_nodes,dtype=torch.int,device=torch.device('cuda:{}'.format(ctx.local_rank)))-1
...@@ -152,7 +153,7 @@ class SharedMailBox(): ...@@ -152,7 +153,7 @@ class SharedMailBox():
def get_update_mail(self,dist_indx_mapper, def get_update_mail(self,dist_indx_mapper,
src,dst,ts,edge_feats, src,dst,ts,edge_feats,
memory,embedding=None,use_src_emb=False,use_dst_emb=False, memory,embedding=None,use_src_emb=False,use_dst_emb=False,
deliver_to='self',block = None,Reduce_score=None,): block = None,Reduce_score=None,):
if edge_feats is not None: if edge_feats is not None:
edge_feats = edge_feats.to(self.device).to(self.mailbox.dtype) edge_feats = edge_feats.to(self.device).to(self.mailbox.dtype)
src = src.to(self.device) src = src.to(self.device)
...@@ -172,11 +173,13 @@ class SharedMailBox(): ...@@ -172,11 +173,13 @@ class SharedMailBox():
mail = torch.cat([src_mail, dst_mail], dim=0) mail = torch.cat([src_mail, dst_mail], dim=0)
mail_ts = torch.cat((ts,ts),-1).to(self.device).to(self.mailbox_ts.dtype) mail_ts = torch.cat((ts,ts),-1).to(self.device).to(self.mailbox_ts.dtype)
#print(mail_ts) #print(mail_ts)
-        if deliver_to == 'neighbor':
-            assert block is None and Reduce_score is not None
-            mail = torch.cat([mail, mail[block.edges()[1].long()]], dim=0)
-            mail_ts = torch.cat([mail_ts, mail_ts[block.edges()[1].long()]], dim=0)
-            index = torch.cat([index,dist_indx_mapper[block.edges()[0].long()]],dim=0)
+        #print(self.deliver_to)
+        if self.deliver_to == 'neighbors':
+            assert block is not None and Reduce_score is None
+            mail = torch.cat([mail, mail[block.edges()[0].long()]], dim=0)
+            mail_ts = torch.cat([mail_ts, mail_ts[block.edges()[0].long()]], dim=0)
+            index = torch.cat([index,block.dstdata['ID'][block.edges()[1].long()]],dim=0)
if Reduce_score is not None: if Reduce_score is not None:
Reduce_score = torch.cat((Reduce_score,Reduce_score),-1).to(self.device) Reduce_score = torch.cat((Reduce_score,Reduce_score),-1).to(self.device)
if Reduce_score is None: if Reduce_score is None:
...@@ -214,18 +217,23 @@ class SharedMailBox(): ...@@ -214,18 +217,23 @@ class SharedMailBox():
#else: #else:
# mem = torch.cat((memory,memory_ts.view(-1,1),mail,mail_ts.view(-1,1),index.to(torch.float32).view(-1,1)),dim = 1) # mem = torch.cat((memory,memory_ts.view(-1,1),mail,mail_ts.view(-1,1),index.to(torch.float32).view(-1,1)),dim = 1)
return mem return mem
def unpack(self,mem): def unpack(self,mem,mailbox = False):
if mem.shape[1] == self.node_memory.shape[1] + 1 or mem.shape[1] == self.mailbox.shape[2] + 1 : if mem.shape[1] == self.node_memory.shape[1] + 1 or mem.shape[1] == self.mailbox.shape[2] + 1 :
mail = mem[:,: -1] mail = mem[:,: -1]
mail_ts = mem[:,-1].view(-1) mail_ts = mem[:,-1].view(-1)
return mail,mail_ts return mail,mail_ts
else: elif mailbox is False:
memory = mem[:,:self.node_memory.shape[1]] memory = mem[:,:self.node_memory.shape[1]]
memory_ts = mem[:,self.node_memory.shape[1]].view(-1) memory_ts = mem[:,self.node_memory.shape[1]].view(-1)
mail = mem[:,self.node_memory.shape[1]+1:-1] mail = mem[:,self.node_memory.shape[1]+1:-1]
mail_ts = mem[:,-1].view(-1) mail_ts = mem[:,-1].view(-1)
return memory,memory_ts,mail,mail_ts return memory,memory_ts,mail,mail_ts
else:
memory = mem[:,:self.node_memory.shape[1]]
memory_ts = mem[:,self.node_memory.shape[1]].view(-1)
mail = mem[:,self.node_memory.shape[1]+1:mem.shape[1]-self.mailbox_ts.shape[1]].reshape(mem.shape[0],self.mailbox.shape[1],-1)
mail_ts = mem[:,mem.shape[1]-self.mailbox_ts.shape[1]:]
return memory,memory_ts,mail,mail_ts
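With `mailbox=True`, a packed row is interpreted as `[memory | memory_ts | mailbox_size mails | mailbox_size mail_ts]`, so the mail slice is reshaped back to `(N, mailbox_size, dim_mail)`. A standalone sketch of that pack/unpack round trip, mirroring the slicing in the new branch; the dimensions are illustrative:

```python
# Pack per-node memory, memory timestamp, mailbox and mailbox timestamps into
# one flat row, then slice them back out the way the mailbox=True branch does.
import torch

n, dim_mem, mailbox_size, dim_mail = 4, 8, 10, 6
memory    = torch.randn(n, dim_mem)
memory_ts = torch.rand(n, 1)
mail      = torch.randn(n, mailbox_size, dim_mail)
mail_ts   = torch.rand(n, mailbox_size)

packed = torch.cat((memory, memory_ts, mail.reshape(n, -1), mail_ts), dim=1)

mem_out     = packed[:, :dim_mem]
mem_ts_out  = packed[:, dim_mem].view(-1)
mail_out    = packed[:, dim_mem + 1: packed.shape[1] - mailbox_size].reshape(n, mailbox_size, -1)
mail_ts_out = packed[:, packed.shape[1] - mailbox_size:]

assert torch.equal(mail_out, mail) and torch.equal(mail_ts_out, mail_ts)
print(packed.shape, mail_out.shape)
```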
def handle_last_async(self,reduce_Op = None): def handle_last_async(self,reduce_Op = None):
if self.last_memory_sync is not None: if self.last_memory_sync is not None:
...@@ -247,10 +255,11 @@ class SharedMailBox(): ...@@ -247,10 +255,11 @@ class SharedMailBox():
if out is not None: if out is not None:
shared_index,shared_data,shared_ts,mail,mail_ts = out shared_index,shared_data,shared_ts,mail,mail_ts = out
index = self.shared_nodes_index[shared_index] index = self.shared_nodes_index[shared_index]
-            self.node_memory.accessor.data[index] = shared_data
-            self.node_memory_ts.accessor.data[index] = shared_ts
-            self.mailbox.accessor.data[index, torch.max(self.next_mail_pos[index]-1,torch.tensor([0],device=mail.device))] = mail
-            self.mailbox_ts.accessor.data[index, torch.max(self.next_mail_pos[index]-1,torch.tensor([0],device=mail_ts.device))] = mail_ts
+            mask= (shared_ts > self.node_memory_ts.accessor.data[index])
+            self.node_memory.accessor.data[index][mask] = shared_data[mask]
+            self.node_memory_ts.accessor.data[index][mask] = shared_ts[mask]
+            #self.mailbox.accessor.data[index, torch.max(self.next_mail_pos[index]-1,torch.tensor([0],device=mail.device))] = mail
+            #self.mailbox_ts.accessor.data[index, torch.max(self.next_mail_pos[index]-1,torch.tensor([0],device=mail_ts.device))] = mail_ts
def update_shared(self): def update_shared(self):
ctx = DistributedContext.get_default_context() ctx = DistributedContext.get_default_context()
if self.last_job is not None: if self.last_job is not None:
...@@ -273,6 +282,7 @@ class SharedMailBox(): ...@@ -273,6 +282,7 @@ class SharedMailBox():
output_split_sizes=output_split, output_split_sizes=output_split,
input_split_sizes=input_split,group = group,async_op=async_op) input_split_sizes=input_split,group = group,async_op=async_op)
self.last_memory_sync = (gather_id_list,handle0,gather_memory,handle1) self.last_memory_sync = (gather_id_list,handle0,gather_memory,handle1)
def set_memory_all_reduce(self,index,memory,memory_ts,mail,mail_ts,reduce_Op = None,async_op = True,set_remote = False,mode=None,filter=None,submit = True): def set_memory_all_reduce(self,index,memory,memory_ts,mail,mail_ts,reduce_Op = None,async_op = True,set_remote = False,mode=None,filter=None,submit = True):
ctx = DistributedContext.get_default_context() ctx = DistributedContext.get_default_context()
#print(DistIndex(index).part) #print(DistIndex(index).part)
...@@ -337,6 +347,7 @@ class SharedMailBox(): ...@@ -337,6 +347,7 @@ class SharedMailBox():
#,shared_memory,shared_memory_ts, #,shared_memory,shared_memory_ts,
#shared_memory,shared_memory_ts = self.unpack(mem) #shared_memory,shared_memory_ts = self.unpack(mem)
shared_memory,shared_memory_ts,shared_mail,shared_mail_ts = self.unpack(mem) shared_memory,shared_memory_ts,shared_mail,shared_mail_ts = self.unpack(mem)
#print(shared_memory_ts,shared_mail_ts)
unq_index,inv = torch.unique(shared_index,return_inverse = True) unq_index,inv = torch.unique(shared_index,return_inverse = True)
#print(inv.shape,Reduce_score.shape) #print(inv.shape,Reduce_score.shape)
max_ts,idx = torch_scatter.scatter_max(shared_memory_ts,inv,0) max_ts,idx = torch_scatter.scatter_max(shared_memory_ts,inv,0)
......
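The timestamp guard added above only applies an incoming shared-node memory when it is newer than the local copy, and the later hunk reduces duplicate arrivals for the same node with `scatter_max` on the timestamps. A standalone sketch of the newer-timestamp-wins rule; here the mask is folded into the index so the masked write lands directly in the stored tensor, and the shapes are illustrative:

```python
# Apply remote shared-node memory only where its timestamp beats the local one.
import torch

node_memory    = torch.zeros(6, 4)
node_memory_ts = torch.tensor([5., 5., 5., 5., 5., 5.])

index       = torch.tensor([1, 3, 4])       # shared nodes an update arrived for
shared_data = torch.ones(3, 4)
shared_ts   = torch.tensor([7., 3., 9.])    # only ts > local ts should be applied

newer = shared_ts > node_memory_ts[index]
node_memory[index[newer]]    = shared_data[newer]
node_memory_ts[index[newer]] = shared_ts[newer]
print(node_memory_ts)   # nodes 1 and 4 updated, node 3 kept
```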
...@@ -20,19 +20,23 @@ class LocalNegativeSampling(NegativeSampling): ...@@ -20,19 +20,23 @@ class LocalNegativeSampling(NegativeSampling):
unique: bool = False, unique: bool = False,
src_node_list: torch.Tensor = None, src_node_list: torch.Tensor = None,
dst_node_list: torch.Tensor = None, dst_node_list: torch.Tensor = None,
seed = False local_mask = None,
seed = None
): ):
super(LocalNegativeSampling,self).__init__(mode,amount,unique=unique) super(LocalNegativeSampling,self).__init__(mode,amount,unique=unique)
self.src_node_list = src_node_list.to('cpu') if src_node_list is not None else None self.src_node_list = src_node_list.to('cpu') if src_node_list is not None else None
self.dst_node_list = dst_node_list.to('cpu') if dst_node_list is not None else None self.dst_node_list = dst_node_list.to('cpu') if dst_node_list is not None else None
self.rdm = torch.Generator() self.rdm = torch.Generator()
if seed is True: if seed is not None:
random.seed(seed) random.seed(seed)
seed = random.randint(0,100000) seed = random.randint(0,100000)
print('seed is',seed) print('seed is',seed)
ctx = DistributedContext.get_default_context() ctx = DistributedContext.get_default_context()
self.rdm.manual_seed(seed^ctx.rank) self.rdm.manual_seed(seed^ctx.rank)
self.rdm = torch.Generator() self.rdm = torch.Generator()
self.local_mask = local_mask
if self.local_mask is not None:
self.local_dst = dst_node_list[local_mask]
#self.rdm.manual_seed(42) #self.rdm.manual_seed(42)
#print('dst_nde_list {}\n'.format(dst_node_list)) #print('dst_nde_list {}\n'.format(dst_node_list))
def is_binary(self) -> bool: def is_binary(self) -> bool:
...@@ -53,6 +57,12 @@ class LocalNegativeSampling(NegativeSampling): ...@@ -53,6 +57,12 @@ class LocalNegativeSampling(NegativeSampling):
else: else:
if self.dst_node_list is None: if self.dst_node_list is None:
return torch.randint(num_nodes, (num_samples, ),generator=self.rdm) return torch.randint(num_nodes, (num_samples, ),generator=self.rdm)
elif self.local_mask is not None:
p = torch.rand(size=(num_samples,))
sr = torch.randint(len(self.dst_node_list), (num_samples, ),generator=self.rdm)
sl = torch.randint(len(self.local_dst), (num_samples, ),generator=self.rdm)
return torch.where(p<0.9,sl,sr)
else: else:
return self.dst_node_list[torch.randint(len(self.dst_node_list), (num_samples, ),generator=self.rdm)] s = torch.randint(len(self.dst_node_list), (num_samples, ),generator=self.rdm)
return self.dst_node_list[s]
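With `local_mask` set, roughly 90% of negative destinations are drawn from nodes owned by the local partition and the rest from the full destination list (the 0.9 ratio comes from the hunk above). A standalone sketch of that biased draw; it maps the sampled positions back to node ids the same way the unbiased branch does, and the toy node lists are illustrative:

```python
# Locality-biased negative sampling sketch: prefer destinations owned by the
# local partition with probability p_local, otherwise sample globally.
import torch

dst_node_list = torch.arange(100, 120)          # all candidate destinations
local_mask    = torch.zeros(20, dtype=torch.bool)
local_mask[:5] = True                           # pretend the first 5 are local
local_dst     = dst_node_list[local_mask]
rdm           = torch.Generator().manual_seed(0)

def sample_negatives(num_samples, p_local=0.9):
    p  = torch.rand(num_samples)
    sl = local_dst[torch.randint(len(local_dst), (num_samples,), generator=rdm)]
    sr = dst_node_list[torch.randint(len(dst_node_list), (num_samples,), generator=rdm)]
    return torch.where(p < p_local, sl, sr)

print(sample_negatives(8))
```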
...@@ -227,29 +227,14 @@ class NeighborSampler(BaseSampler): ...@@ -227,29 +227,14 @@ class NeighborSampler(BaseSampler):
sampled_nodes: the node sampled sampled_nodes: the node sampled
sampled_edge_index_list: the edge sampled sampled_edge_index_list: the edge sampled
""" """
-        if(ts is None):
-            self.part_unique = True
-            self.p_sampler.neighbor_sample_from_nodes(nodes.contiguous(), None, self.part_unique)
-            ret = self.p_sampler.get_ret()
-            return ret
-        else:
-            self.p_sampler.neighbor_sample_from_nodes(nodes.contiguous(), ts.contiguous(), None)
-            ret = self.p_sampler.get_ret()
-        metadata = {}
-        if is_unique:
-            self.p_sampler.sample_unique(
-                nodes,ts.float(),nid_mapper,
-                eid_mapper,'cpu' if out_device.type == 'cpu' else str(out_device.index))
-            metadata = {
-                'eid_inv':self.p_sampler.eid_inv,
-                'first_block_id':self.p_sampler.first_block_id,
-                'block_node_list':self.p_sampler.block_node_list,
-                'unq_id':self.p_sampler.unq_id,
-                'dist_nid':self.p_sampler.dist_nid,
-                'dist_eid':self.p_sampler.dist_eid,
-            }
-        #print(nodes.shape[0],ret[0].src_index().max(),ret[0].src_index().min())
-        return ret,metadata
+        if self.policy != 'identity':
+            self.p_sampler.neighbor_sample_from_nodes(nodes.contiguous(), ts.contiguous(), None)
+            ret = self.p_sampler.get_ret()
+        else:
+            ret = None
+        metadata = {}
+        #print(nodes.shape[0],ret[0].src_index().max(),ret[0].src_index().min())
+        return ret,metadata
def sample_from_edges( def sample_from_edges(
self, self,
...@@ -329,13 +314,7 @@ class NeighborSampler(BaseSampler): ...@@ -329,13 +314,7 @@ class NeighborSampler(BaseSampler):
else: else:
seed, inverse_seed = seed.unique(return_inverse=True) seed, inverse_seed = seed.unique(return_inverse=True)
""" """
-        metadata = {}
-        if is_unique:
-            out,metadata = self.sample_from_nodes(seed, seed_ts, is_unique=is_unique,eid_mapper=eid_mapper,nid_mapper=nid_mapper,out_device=out_device)
-            first_block_id = self.p_sampler.first_block_id
-        else:
-            #print('is unique')
-            out,metadata = self.sample_from_nodes(seed, seed_ts, is_unique=False)
+        out,metadata = self.sample_from_nodes(seed, seed_ts, is_unique=False)
src_pos_index = torch.arange(0,num_pos,dtype= torch.long,device=out_device) src_pos_index = torch.arange(0,num_pos,dtype= torch.long,device=out_device)
dst_pos_index = torch.arange(num_pos,2*num_pos,dtype= torch.long,device=out_device) dst_pos_index = torch.arange(num_pos,2*num_pos,dtype= torch.long,device=out_device)
if neg_sampling.is_triplet() or neg_sampling.is_tgbtriplet(): if neg_sampling.is_triplet() or neg_sampling.is_tgbtriplet():
...@@ -344,12 +323,6 @@ class NeighborSampler(BaseSampler): ...@@ -344,12 +323,6 @@ class NeighborSampler(BaseSampler):
else: else:
src_neg_index = torch.arange(2*num_pos,3*num_pos,dtype= torch.long,device=out_device) src_neg_index = torch.arange(2*num_pos,3*num_pos,dtype= torch.long,device=out_device)
dst_neg_index=torch.arange(3*num_pos,seed.shape[0],dtype= torch.long,device=out_device) dst_neg_index=torch.arange(3*num_pos,seed.shape[0],dtype= torch.long,device=out_device)
#if is_unique:
# src_pos_index = first_block_id[src_pos_index].contiguous()
# dst_pos_index = first_block_id[dst_pos_index].contiguous()
# dst_neg_index = first_block_id[dst_neg_index].contiguous()
# src_neg_index = first_block_id[src_neg_index].contiguous()
metadata['seed'] = seed metadata['seed'] = seed
metadata['seed_ts'] = seed_ts metadata['seed_ts'] = seed_ts
metadata['src_pos_index']=src_pos_index metadata['src_pos_index']=src_pos_index
......