Commit e713643f by zhlj

update memory

parent e0711e14
......@@ -9,6 +9,7 @@ __pycache__/
*.so
# Distribution / packaging
examples/all*
.Python
examples/all*
build/
......
......@@ -14,6 +14,8 @@ memory:
deliver_to: 'neighbors'
mail_combine: 'last'
memory_update: 'transformer'
historical_fix: False
async: True
attention_head: 2
mailbox_size: 10
combine_node_feature: False
......@@ -22,7 +24,7 @@ gnn:
- arch: 'identity'
train:
- epoch: 100
batch_size: 600
batch_size: 1000
lr: 0.0001
dropout: 0.1
att_dropout: 0.1
......
......@@ -7,6 +7,8 @@ memory:
deliver_to: 'self'
mail_combine: 'last'
memory_update: 'rnn'
historical_fix: False
async: True
mailbox_size: 1
combine_node_feature: True
dim_out: 100
......@@ -16,8 +18,8 @@ gnn:
use_dst_emb: False
time_transform: 'JODIE'
train:
- epoch: 100
- epoch: 250
batch_size: 1000
lr: 0.0001
lr: 0.0002
dropout: 0.1
all_on_gpu: True
\ No newline at end of file
......@@ -27,7 +27,7 @@ gnn:
dim_time: 100
dim_out: 100
train:
- epoch: 100
- epoch: 200
batch_size: 1000
# reorder: 16
lr: 0.0004
......
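The config hunks above add two new memory flags (`historical_fix`, `async`) and retune epochs, batch size, and learning rate. A minimal sketch of how these sections might be read with safe defaults; the file name, helper, exact YAML layout, and default values are assumptions, not part of this commit:

```python
import yaml  # requires PyYAML

def load_params(path="config.yml"):
    """Read memory/train sections, tolerating both dict and single-element list layouts."""
    with open(path, "r") as f:
        cfg = yaml.safe_load(f)
    mem = cfg["memory"][0] if isinstance(cfg["memory"], list) else cfg["memory"]
    train = cfg["train"][0] if isinstance(cfg["train"], list) else cfg["train"]
    return {
        "historical_fix": mem.get("historical_fix", False),  # new flag in this commit
        "async": mem.get("async", False),                    # new flag in this commit
        "batch_size": train.get("batch_size", 1000),
        "lr": train.get("lr", 1e-4),
        "epoch": train.get("epoch", 100),
    }
```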
......@@ -4,11 +4,11 @@ import torch
# Read the file contents
ssim_values = [0, 0.1, 0.2, 0.3, 0.4, 2] # assume these are your ssim parameter values
probability_values = [1,0.5,0.1,0.05,0.01,0]
data_values = ['WIKI'] # data read from the files
data_values = ['WikiTalk','StackOverflow'] # data read from the files
partition = 'ours_shared'
# Read the data from files; assume the data is stored in data.txt
#all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample".out
partitions=4
partitions=8
topk=0.01
mem='all_update'#'historical'
for data in data_values:
......@@ -35,11 +35,12 @@ for data in data_values:
# Draw the bar chart
print('{} TestAP={}\n'.format(data,ap_list))
bar_width = 0.4
#shared comm tensor
# Set the bar positions
bars = range(len(ssim_values))
bars = range(len(probability_values))
# Draw the bar chart
plt.bar([b for b in bars], ap_list, width=bar_width)
......@@ -49,7 +50,7 @@ for data in data_values:
plt.xlabel('probability')
plt.ylabel('Test AP')
plt.title('{}({} partitions)'.format(data,partitions))
plt.savefig('boundary_AP_{}.png'.format(data))
plt.savefig('boundary_AP_{}_{}.png'.format(data,partitions))
plt.clf()
plt.bar([b for b in bars], comm_list, width=bar_width)
......@@ -58,7 +59,7 @@ for data in data_values:
plt.xlabel('probability')
plt.ylabel('Communication volume')
plt.title('{}({} partitions)'.format(data,partitions))
plt.savefig('boundary_comm_{}.png'.format(data))
plt.savefig('boundary_comm_{}_{}.png'.format(data,partitions))
plt.clf()
if partition == 'ours_shared':
......@@ -76,5 +77,5 @@ for data in data_values:
plt.title('{}({} partitions)'.format(data,partitions))
# plt.grid(True)
plt.legend()
plt.savefig('{}_boundary_Convergence_rate.png'.format(data))
plt.savefig('{}_{}_boundary_Convergence_rate.png'.format(data,partitions))
plt.clf()
......@@ -2,12 +2,13 @@ import matplotlib.pyplot as plt
import numpy as np
import torch
# Read the file contents
ssim_values = [0, 0.5, 1.0, 1.5, 2] # assume these are your ssim parameter values
data_values = ['WIKI','WikiTalk','REDDIT','LASTFM','DGraphFin'] # data read from the files
ssim_values = [-1,0,0.3,0.7,2] # assume these are your ssim parameter values
data_values = ['WIKI','LASTFM','WikiTalk','REDDIT','LASTFM','DGraphFin'] # data read from the files
partition = 'ours_shared'
# Read the data from files; assume the data is stored in data.txt
#all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample".out
partitions=4
model = 'JODIE'
topk=0.01
mem='historical'
for data in data_values:
......@@ -15,9 +16,11 @@ for data in data_values:
comm_list = []
for ssim in ssim_values:
if ssim == 2:
file = '{}/{}-{}-{}-local-recent.out'.format(data,partitions,partition,topk)
file = '{}/{}/{}-{}-{}-local-recent.out'.format(data,model,partitions,partition,topk)
elif ssim == -1:
file = '{}/{}/{}-{}-{}-all_update-recent.out'.format(data,model,partitions,partition,topk)
else:
file = '{}/{}-{}-{}-{}-{}-recent.out'.format(data,partitions,partition,topk,mem,ssim)
file = '{}/{}/{}-{}-{}-{}-{}-recent.out'.format(data,model,partitions,partition,topk,mem,ssim)
prefix = 'best test AP:'
with open(file, 'r') as file:
for line in file:
......@@ -26,6 +29,7 @@ for data in data_values:
pos = line.find('shared comm tensor')
if(pos!=-1):
comm = int(line[pos+2+len('shared comm tensor'):len(line)-3])
print(ap)
ap_list.append(ap)
comm_list.append(comm)
print('{} TestAP={}\n'.format(data,ap_list))
......@@ -33,7 +37,7 @@ for data in data_values:
# Draw the bar chart
bar_width = 0.4
#shared comm tensor
print('{} TestAP={}\n'.format(data,ap_list))
# Set the bar positions
bars = range(len(ssim_values))
......@@ -43,8 +47,10 @@ for data in data_values:
plt.xticks([b for b in bars], ssim_values)
plt.xlabel('SSIM threshold Values')
plt.ylabel('Test AP')
#if(data=='WIKI'):
# plt.ylim([0.97,1])
plt.title('{}({} partitions)'.format(data,partitions))
plt.savefig('ssim_{}_{}.png'.format(data,partitions))
plt.savefig('ssim_{}_{}_{}.png'.format(data,partitions,model))
plt.clf()
plt.bar([b for b in bars], comm_list, width=bar_width)
......@@ -53,7 +59,7 @@ for data in data_values:
plt.xlabel('SSIM threshold Values')
plt.ylabel('Communication volume')
plt.title('{}({} partitions)'.format(data,partitions))
plt.savefig('ssim_comm_{}_{}.png'.format(data,partitions))
plt.savefig('ssim_comm_{}_{}_{}.png'.format(data,partitions,model))
plt.clf()
if partition == 'ours_shared':
......@@ -62,18 +68,28 @@ for data in data_values:
partition0=partition
for ssim in ssim_values:
if ssim == 2:
file = '{}/val_{}_{}_{}_0_recent_0.1_local_2.pt'.format(data,partition0,topk,partitions,)
file = '{}/{}/test_{}_{}_{}_0_recent_0.1_local_2.pt'.format(data,model,partition0,topk,partitions,)
elif ssim == -1:
file = '{}/{}/test_{}_{}_{}_0_recent_0.1_all_update_2.pt'.format(data,model,partition0,topk,partitions,)
else:
file = '{}/val_{}_{}_{}_0_recent_0.1_{}_{}.pt'.format(data,partition0,topk,partitions,mem,float(ssim))
val_ap = torch.tensor(torch.load(file))
file = '{}/{}/test_{}_{}_{}_0_recent_0.1_{}_{}.pt'.format(data,model,partition0,topk,partitions,mem,float(ssim))
val_ap = torch.tensor(torch.load(file))[:,0]
print(val_ap)
epoch = torch.arange(val_ap.shape[0])
# Plot the convergence curve
print(val_ap)
plt.plot(epoch,val_ap, label='ssim={}'.format(ssim))
#print(val_ap)
if ssim == -1:
plt.plot(epoch,val_ap, label='all-update')
elif ssim == 2:
plt.plot(epoch,val_ap, label='local')
else:
plt.plot(epoch,val_ap, label='ssim = {}'.format(ssim))
if(data=='WIKI'):
plt.ylim([0.85,0.90])
plt.xlabel('Epoch')
plt.ylabel('Val AP')
plt.title('{}({} partitions)'.format(data,partitions))
# plt.grid(True)
plt.legend()
plt.savefig('{}_{}_ssim_Convergence_rate.png'.format(data,partitions))
plt.savefig('{}_{}_{}_ssim_Convergence_rate.png'.format(data,partitions,model))
plt.clf()
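Both plotting scripts above scrape the training logs for a `best test AP:` line and a `shared comm tensor` counter. A stand-alone sketch of that parsing, assuming the same log format; the helper name and the whitespace-based trimming are illustrative rather than taken from the commit:

```python
def parse_log(path):
    """Return (best test AP, shared communication volume) from one .out log file."""
    ap, comm = None, None
    with open(path, "r") as f:
        for line in f:
            if line.startswith("best test AP:"):
                ap = float(line[len("best test AP:"):].strip())
            pos = line.find("shared comm tensor")
            if pos != -1:
                # The remainder of the line is assumed to hold an integer count.
                comm = int(line[pos + len("shared comm tensor"):].strip(" :,\n"))
    return ap, comm
```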
......@@ -41,7 +41,7 @@ pinn_memory = {}
class HistoricalCache:
def __init__(self,cache_index,layer,shape,dtype,device,threshold = 3,time_threshold = None, times_threshold = 5, use_rpc = True, num_threshold = 0):
def __init__(self,cache_index,layer,shape,dtype,device,threshold = 3,time_threshold = None, times_threshold = 10, use_rpc = True, num_threshold = 0):
#self.cache_index = cache_index
self.layer = layer
print(shape)
......@@ -88,12 +88,17 @@ class HistoricalCache:
return torch.sum((x -y)**2,dim = 1)
def historical_check(self,index,new_data,ts):
if self.time_threshold is not None:
return (self.ssim(new_data,self.local_historical_data[index]) > self.threshold | (ts - self.local_ts[index] > self.time_threshold | self.loss_count[index] > self.times_threshold))
mask = (self.ssim(new_data,self.local_historical_data[index]) > self.threshold) | (ts - self.local_ts[index] > self.time_threshold) | (self.loss_count[index] > self.times_threshold)
self.loss_count[index][~mask] += 1
self.loss_count[index][mask] = 0
else:
#print('{} {} {} {} \n'.format(index,self.ssim(new_data,self.local_historical_data[index]),new_data,self.local_historical_data[index]))
#print(new_data,self.local_historical_data[index])
#print(self.ssim(new_data,self.local_historical_data[index]) < self.threshold, (self.loss_count[index] > self.times_threshold))
return (self.ssim(new_data,self.local_historical_data[index]) > self.threshold) | (self.loss_count[index] > self.times_threshold)
mask = (self.ssim(new_data,self.local_historical_data[index]) > self.threshold) | (self.loss_count[index] > self.times_threshold)
self.loss_count[index][~mask] += 1
self.loss_count[index][mask] = 0
return mask
def read_synchronize(self):
torch.cuda.synchronize(get_stream_set(self.layer))
......
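The reworked `historical_check` now returns a boolean mask (push when the squared distance to the cached embedding exceeds the threshold, or when the entry has been skipped too many rounds) and maintains a per-index skip counter. A tensor-only sketch of that policy outside the class; the function and argument names are illustrative:

```python
import torch

def should_broadcast(new_data, cached, loss_count, threshold=3.0, times_threshold=10):
    """new_data, cached: (N, D) embeddings; loss_count: (N,) int tensor, updated in place."""
    dist = torch.sum((new_data - cached) ** 2, dim=1)            # same distance as HistoricalCache.ssim
    mask = (dist > threshold) | (loss_count > times_threshold)   # stale, or skipped for too long
    loss_count[~mask] += 1   # kept local: count one more skipped round
    loss_count[mask] = 0     # broadcast: reset the counter
    return mask
```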
......@@ -2,6 +2,8 @@ from os.path import abspath, join, dirname
import os
import sys
from os.path import abspath, join, dirname
from starrygl.distributed.utils import DistIndex
sys.path.insert(0, join(abspath(dirname(__file__))))
import torch
import dgl
......@@ -286,9 +288,22 @@ class TransfomerAttentionLayer(torch.nn.Module):
#att = dgl.ops.e_div_v(b,att_e_sub_max,torch.clamp_min(dgl.ops.copy_e_sum(b,att_e_sub_max),1))
att = dgl.ops.edge_softmax(b, self.att_act(torch.sum(Q*K, dim=2)))
att = self.att_dropout(att)
tt.weight_count_remote+=torch.sum(att[DistIndex(b.srcdata['ID']).part[b.edges()[0]]!=torch.distributed.get_rank()]**2)
tt.weight_count_local+=torch.sum(att[DistIndex(b.srcdata['ID']).part[b.edges()[0]]==torch.distributed.get_rank()]**2)
V = torch.reshape(V*att[:, :, None], (V.shape[0], -1))
V_local = V.clone()
V_remote = V.clone()
V_local[DistIndex(b.srcdata['ID']).part[b.edges()[0]]!=torch.distributed.get_rank()] = 0
V_remote[DistIndex(b.srcdata['ID']).part[b.edges()[0]]==torch.distributed.get_rank()] = 0
b.edata['v'] = V
b.edata['v0'] = V_local
b.edata['v1'] = V_remote
b.update_all(dgl.function.copy_e('v0', 'm0'), dgl.function.sum('m0', 'h0'))
b.update_all(dgl.function.copy_e('v1', 'm1'), dgl.function.sum('m1', 'h1'))
b.update_all(dgl.function.copy_e('v', 'm'), dgl.function.sum('m', 'h'))
tt.ssim_local+=torch.sum(torch.cosine_similarity(b.dstdata['h'],b.dstdata['h0']))
tt.ssim_remote+=torch.sum(torch.cosine_similarity(b.dstdata['h'],b.dstdata['h1']))
tt.ssim_cnt += b.num_dst_nodes()
#print('dst {}'.format(b.dstdata['h']))
#b.srcdata['v'] = torch.cat([torch.zeros((b.num_dst_nodes(), V.shape[1]), device=torch.device('cuda:0')), V], dim=0)
#b.update_all(dgl.function.copy_u('v', 'm'), dgl.function.sum('m', 'h'))
......
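The new instrumentation aggregates the attention-weighted messages three times (all sources, local-only, remote-only) and accumulates the cosine similarity between the full and partial destination embeddings. A plain-`index_add_` sketch of the same measurement without DGL; the edge layout and names are assumptions:

```python
import torch

def local_remote_similarity(V, dst, src_is_local, num_dst):
    """V: (E, D) attention-weighted edge messages, dst: (E,) destination ids,
    src_is_local: (E,) bool, True where the source node lives on this rank."""
    D = V.shape[1]
    h  = torch.zeros(num_dst, D).index_add_(0, dst, V)                             # full aggregation
    h0 = torch.zeros(num_dst, D).index_add_(0, dst, V * src_is_local[:, None])     # local-source part
    h1 = torch.zeros(num_dst, D).index_add_(0, dst, V * (~src_is_local)[:, None])  # remote-source part
    return torch.cosine_similarity(h, h0).sum(), torch.cosine_similarity(h, h1).sum()
```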
......@@ -332,6 +332,7 @@ class TransformerMemoryUpdater(torch.nn.Module):
self.dropout = torch.nn.Dropout(train_param['dropout'])
self.att_dropout = torch.nn.Dropout(train_param['att_dropout'])
def forward(self, b, param = None):
Q = self.w_q(b.srcdata['mem']).reshape((b.num_src_nodes(), self.att_h, -1))
mails = b.srcdata['mem_input'].reshape((b.num_src_nodes(), self.memory_param['mailbox_size'], -1))
......@@ -396,7 +397,7 @@ class AsyncMemeoryUpdater(torch.nn.Module):
self.ceil_updater = updater(memory_param, dim_in, dim_hid, dim_time, train_param)
self.updater = self.transformer_updater
else:
self.updater = updater(dim_in + dim_time, dim_hid)
self.ceil_updater = updater(dim_in + dim_time, dim_hid)
self.updater = self.rnn_updater
self.last_updated_memory = None
self.last_updated_ts = None
......@@ -419,13 +420,30 @@ class AsyncMemeoryUpdater(torch.nn.Module):
self.update_hunk = self.historical_func
elif self.mode == 'local' or self.mode=='all_local':
self.update_hunk = self.local_func
if self.mode == 'historical':
self.gamma = torch.nn.Parameter(torch.tensor([0.9]),
requires_grad=True)
else:
self.gamma = 1
def forward(self, mfg, param = None):
for b in mfg:
updated_memory = self.updater(b)
mail_input = b.srcdata['mem_input']
updated_memory0 = self.updater(b)
mask = DistIndex(b.srcdata['ID']).is_shared
#incr = updated_memory[mask] - b.srcdata['mem'][mask]
time_feat = self.time_enc(b.srcdata['ts'][mask].reshape(-1,1) - b.srcdata['his_ts'][mask].reshape(-1,1))
his_mem = torch.cat((mail_input[mask],time_feat),dim = 1)
upd0 = torch.zeros_like(updated_memory0)
upd0[mask] = self.ceil_updater(his_mem, b.srcdata['his_mem'][mask])
#updated_memory = torch.where(mask.unsqueeze(1),self.gamma*updated_memory0 + (1-self.gamma)*(b.srcdata['his_mem']),updated_memory0)
updated_memory = torch.where(mask.unsqueeze(1),self.gamma*updated_memory0 + (1-self.gamma)*(upd0),updated_memory0)
print(torch.cosine_similarity(updated_memory[mask],b.srcdata['his_mem'][mask]).sum()/torch.sum(mask))
print(self.gamma)
self.pre_mem = b.srcdata['his_mem']
self.last_updated_ts = b.srcdata['ts'].detach().clone()
self.last_updated_memory = updated_memory.detach().clone()
self.last_updated_nid = b.srcdata['ID'].detach().clone()
with torch.no_grad():
if param is not None:
_,src,dst,ts,edge_feats,nxt_fetch_func = param
......@@ -434,12 +452,14 @@ class AsyncMemeoryUpdater(torch.nn.Module):
self.last_updated_memory[indx],
self.last_updated_ts[indx],
None)
#print(index.shape[0])
if param[0]:
index, mail, mail_ts = self.mailbox.get_update_mail(
b.srcdata['ID'],src,dst,ts,edge_feats,
self.last_updated_memory,
None,False,False,
None,False,False,block=b
)
#print(index.shape[0])
if torch.distributed.get_world_size() == 0:
self.mailbox.mon.add(index,self.mailbox.node_memory.accessor.data[index],memory)
##print(index.shape,memory.shape,memory_ts.shape,mail.shape,mail_ts.shape)
......@@ -450,11 +470,11 @@ class AsyncMemeoryUpdater(torch.nn.Module):
if self.memory_param['combine_node_feature'] and self.dim_node_feat > 0:
if self.dim_node_feat == self.dim_hid:
b.srcdata['h'] += memory
b.srcdata['h'] += updated_memory
else:
b.srcdata['h'] = memory + self.node_feat_map(b.srcdata['h'])
b.srcdata['h'] = updated_memory + self.node_feat_map(b.srcdata['h'])
else:
b.srcdata['h'] = memory
b.srcdata['h'] = updated_memory
def empty_cache(self):
pass
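For shared nodes, the updated `forward` now blends the freshly computed memory with an update driven by the cached historical memory, weighted by a learnable `gamma`. A stripped-down sketch of that blend; the GRU cell stands in for the module's own `ceil_updater` and the dimensions are placeholders:

```python
import torch

dim_input, dim_hid = 200, 100                    # assumed: mail + time-encoding dims
cell = torch.nn.GRUCell(dim_input, dim_hid)      # stand-in for self.ceil_updater
gamma = torch.nn.Parameter(torch.tensor([0.9]))  # learnable mixing weight

def blend_memory(updated_memory0, his_mem, his_input, is_shared):
    """is_shared: (N,) bool mask of nodes replicated across partitions."""
    upd0 = torch.zeros_like(updated_memory0)
    upd0[is_shared] = cell(his_input[is_shared], his_mem[is_shared])  # history-driven update
    # Shared nodes mix both updates; purely local nodes keep the fresh one.
    return torch.where(is_shared.unsqueeze(1),
                       gamma * updated_memory0 + (1 - gamma) * upd0,
                       updated_memory0)
```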
......@@ -86,7 +86,7 @@ class GeneralModel(torch.nn.Module):
# else:
# self.memory_updater = HistoricalMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node,updater=updater,learnable=True,num_nodes=num_nodes)
elif memory_param['memory_update'] == 'transformer':
self.memory_updater = TransformerMemoryUpdater
updater = TransformerMemoryUpdater
self.memory_updater = AsyncMemeoryUpdater(memory_param, 2 * memory_param['dim_out'] + dim_edge, memory_param['dim_out'], memory_param['dim_time'], dim_node, updater=updater, mailbox=mailbox, mode = memory_param['mode'],train_param=train_param)
else:
raise NotImplementedError
......
......@@ -123,7 +123,7 @@ def get_node_all_to_all_route(graph:DistributedGraphStore,mailbox:SharedMailBox,
ind_dict = graph.nfeat.all_to_all_ind2ptr(query_nid_feature,group = group)
memory,memory_ts,mail,mail_ts = mailbox.gather_local_memory(ind_dict['recv_ind'],compute_device=out_device)
memory_ts = memory_ts.reshape(-1,1)
mail_ts = mail_ts.reshape(-1,1)
mail_ts = mail_ts.reshape(mail_ts.shape[0],-1)
mail = mail.reshape(mail.shape[0],-1)
data.append(torch.cat((memory,memory_ts,mail,mail_ts),dim = 1))
if ind_dict is not None:
......@@ -144,14 +144,14 @@ def get_edge_all_to_all_route(graph:DistributedGraphStore,query_eid_feature,out_
def prepare_input(node_feat, edge_feat, mem_embedding,mfgs,dist_nid=None,dist_eid=None):
for i,mfg in enumerate(mfgs):
for b in mfg:
e_idx = b.edata['__ID']
idx = b.srcdata['__ID']
if dist_eid is not None:
if '__ID' in b.edata:
e_idx = b.edata['__ID']
b.edata['ID'] = dist_eid[e_idx]
if dist_nid is not None:
b.srcdata['ID'] = dist_nid[idx]
if edge_feat is not None:
b.edata['f'] = edge_feat[e_idx]
if dist_nid is not None:
b.srcdata['ID'] = dist_nid[idx]
if i == 0:
if node_feat is not None:
b.srcdata['h'] = node_feat[idx]
......@@ -336,15 +336,96 @@ def to_block(graph,data, sample_out,device = torch.device('cuda'),unique = True)
data,mfgs,metadata = build_block()
return (data,mfgs,metadata),dist_nid,dist_eid
def to_reversed_block(graph,data, sample_out,device = torch.device('cuda'),unique = True,identity=False):
if len(sample_out) > 1:
sample_out,metadata = sample_out
else:
metadata = None
nid_mapper: torch.Tensor = graph.nids_mapper
if identity is False:
assert len(sample_out) == 1
ret = sample_out[0]
eid_len = ret.eid().shape[0]
t0 = time.time()
dst_ts = ret.sample_nodes_ts().to(device)
dst = nid_mapper[ret.sample_nodes()].to(device)
dist_eid = torch.tensor([],dtype=torch.long,device=device)
src_index = ret.src_index().to(device)
else:
src_index = torch.tensor([],dtype=torch.long,device=device)
dst = torch.tensor([],dtype=torch.long,device=device)
dist_eid = torch.tensor([],dtype=torch.long,device=device)
if metadata is None:
root_node = data.nodes.to(graph.nids_mapper.device)
root_len = [root_node.shape[0]]
root_ts = data.ts.to(device)
elif 'seed' in metadata:
root_node = metadata.pop('seed').to(graph.nids_mapper.device)
root_len = root_node.shape[0]
if 'seed_ts' in metadata:
root_ts = metadata.pop('seed_ts').to(device)
for k in metadata:
if isinstance(metadata[k],torch.Tensor):
metadata[k] = metadata[k].to(device)
src_node = root_node
src_ts = root_ts
nid_tensor = torch.cat([root_node],dim = 0)
dist_nid = nid_mapper[nid_tensor].to(device)
CountComm.origin_local = (DistIndex(dist_nid).part == dist.get_rank()).sum().item()
CountComm.origin_remote =(DistIndex(dist_nid).part != dist.get_rank()).sum().item()
dist_nid,nid_inv = dist_nid.unique(return_inverse=True)
"""
Deduplicate nodes with the same id and the same timestamp and obtain their index
"""
block_node_list,unq_id = torch.stack((nid_inv.to(torch.float64),src_ts.to(torch.float64))).unique(dim = 1,return_inverse=True)
first_index,_ = torch_scatter.scatter_min(torch.arange(unq_id.shape[0],device=unq_id.device,dtype=unq_id.dtype),unq_id)
first_mask = torch.zeros(unq_id.shape[0],device = unq_id.device,dtype=torch.bool)
first_mask[first_index] = True
first_index = unq_id[first_mask]
first_block_id = torch.empty(first_index.shape[0],device=unq_id.device,dtype=unq_id.dtype)
first_block_id[first_index] = torch.arange(first_index.shape[0],device=first_index.device,dtype=first_index.dtype)
first_block_id = first_block_id[unq_id].contiguous()
block_node_list = block_node_list[:,first_index]
for k in metadata:
if isinstance(metadata[k],torch.Tensor):
metadata[k] = first_block_id[metadata[k]]
t2 = time.time()
def build_block():
mfgs = list()
col_len = 0
row_len = root_len
col = first_block_id[:row_len]
max_row = col.max().item()+1
b = dgl.create_block((col[src_index].to(device),
torch.arange(dst.shape[0],device=device,dtype=torch.long)),num_src_nodes=first_block_id.max().item()+1,
num_dst_nodes=dst.shape[0])
idx = block_node_list[0,b.srcnodes()].to(torch.long)
b.srcdata['__ID'] = idx
b.srcdata['ts'] = block_node_list[1,b.srcnodes()].to(torch.float)
b.dstdata['ID'] = dst
mfgs.append(b)
mfgs = list(map(list, zip(*[iter(mfgs)])))
mfgs.reverse()
return data,mfgs,metadata
data,mfgs,metadata = build_block()
return (data,mfgs,metadata),dist_nid,dist_eid
import concurrent.futures
executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
def graph_sample(graph,sampler,sample_fn,data,neg_sampling = None,out_device = torch.device('cuda'),nid_mapper = None,eid_mapper=None):
def graph_sample(graph,sampler,sample_fn,data,neg_sampling = None,out_device = torch.device('cuda'),nid_mapper = None,eid_mapper=None,reversed=False):
t_s = time.time()
param = {'is_unique':False,'nid_mapper':nid_mapper,'eid_mapper':eid_mapper,'out_device':out_device}
out = sample_fn(sampler,data,neg_sampling,**param)
out,dist_nid,dist_eid = to_block(graph,data,out,out_device)
if reversed is False:
out,dist_nid,dist_eid = to_block(graph,data,out,out_device,reversed)
else:
out,dist_nid,dist_eid = to_reversed_block(graph,data,out,out_device,reversed)
t_e = time.time()
#print(t_e-t_s)
return out,dist_nid,dist_eid
......
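`to_reversed_block` deduplicates (node id, timestamp) pairs by stacking them into a 2×N float64 tensor, taking `unique(dim=1, return_inverse=True)`, and using `torch_scatter.scatter_min` to point every occurrence at its first appearance. A self-contained sketch of that step on synthetic inputs (assumes `torch_scatter` is installed):

```python
import torch
import torch_scatter

nid = torch.tensor([3, 5, 3, 5, 7])
ts  = torch.tensor([1., 2., 1., 3., 2.])

pairs, inv = torch.stack((nid.to(torch.float64), ts.to(torch.float64))).unique(dim=1, return_inverse=True)
# For each unique (id, ts) column, locate the position of its first occurrence.
first_index, _ = torch_scatter.scatter_min(torch.arange(inv.shape[0]), inv)
print(pairs)        # 2 x num_unique columns of (id, ts)
print(inv)          # maps each original position to its unique column
print(first_index)  # first occurrence of each unique column in the original ordering
```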
......@@ -12,6 +12,13 @@ class time_count:
time_sample_and_build = 0
time_memory_fetch = 0
weight_count_remote = 0
weight_count_local = 0
ssim_remote = 0
ssim_local = 0
ssim_cnt = 0
@staticmethod
def _zero():
time_count.time_forward = 0
......@@ -34,8 +41,9 @@ class time_count:
def start():
return time.perf_counter(),0
@staticmethod
def elapsed_event(start_event,end_event):
if start_event.isinstance(torch.cuda.Event):
def elapsed_event(start_event):
if isinstance(start_event,tuple):
start_event,end_event = start_event
end_event.record()
end_event.synchronize()
return start_event.elapsed_time(end_event)
......@@ -52,3 +60,5 @@ class time_count:
time_count.time_memory_sync,
time_count.time_sample_and_build,
time_count.time_memory_fetch ))
\ No newline at end of file
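`elapsed_event` now takes a single argument that is either the `(perf_counter, 0)` tuple produced by `start()` or a pair of CUDA events, so the caller no longer passes two events. A minimal sketch of that dual-mode timing, assuming the same convention; the helper names are illustrative:

```python
import time
import torch

def start_timer():
    if torch.cuda.is_available():
        s = torch.cuda.Event(enable_timing=True)
        e = torch.cuda.Event(enable_timing=True)
        s.record()
        return (s, e)
    return time.perf_counter(), 0           # CPU fallback: (start time, dummy)

def elapsed_ms(start_event):
    s, e = start_event
    if isinstance(s, torch.cuda.Event):
        e.record()
        e.synchronize()
        return s.elapsed_time(e)            # milliseconds between the two events
    return (time.perf_counter() - s) * 1e3  # milliseconds on CPU
```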
......@@ -100,8 +100,10 @@ class DistributedDataLoader:
cache_mask = None,
use_local_feature = True,
probability = 1,
reversed = False,
**kwargs
):
self.reversed = reversed
self.use_local_feature = use_local_feature
self.local_embedding = local_embedding
self.chunk_size = chunk_size
......@@ -223,7 +225,7 @@ class DistributedDataLoader:
if self.mode=='train' and self.probability < 1:
mask = ((DistIndex(self.graph.nids_mapper[next_data.edges[0,:].to('cpu')]).part == dist.get_rank())&(DistIndex(self.graph.nids_mapper[next_data.edges[1,:].to('cpu')]).part == dist.get_rank()))
if self.probability > 0:
mask[~mask] = (torch.rand((~mask)) < self.probability)
mask[~mask] = (torch.rand((~mask).sum().item()) < self.probability)
next_data = next_data[mask.to(next_data.device)]
self.submitted = self.submitted + 1
return next_data
......@@ -239,13 +241,14 @@ class DistributedDataLoader:
self.device,
nid_mapper = self.graph.nids_mapper,
eid_mapper = self.graph.eids_mapper,
reversed = self.reversed
)
self.result_queue.append((fut))
@torch.no_grad()
def async_feature(self):
if(self.recv_idxs >= self.expected_idx):
if(self.recv_idxs >= self.expected_idx or self.is_pipeline == False):
return
is_local = (self.is_train & self.use_local_feature)
if(is_local):
......@@ -303,7 +306,9 @@ class DistributedDataLoader:
data,self.neg_sampler,
self.device,
nid_mapper = self.graph.nids_mapper,
eid_mapper = self.graph.eids_mapper)
eid_mapper = self.graph.eids_mapper,
reversed = self.reversed
)
root,mfgs,metadata = batch_data
t_sample = tt.elapsed_event(t0)
......@@ -312,6 +317,12 @@ class DistributedDataLoader:
edge_feat = get_edge_feature_by_dist(self.graph,dist_eid,is_local,out_device=self.device)
node_feat,mem = get_node_feature_by_dist(self.graph,self.mailbox,dist_nid, is_local,out_device=self.device)
prepare_input(node_feat,edge_feat,mem,mfgs,dist_nid,dist_eid)
batch_data[1][0][0].srcdata['his_mem'] = batch_data[1][0][0].srcdata['mem'].clone()
batch_data[1][0][0].srcdata['his_ts'] = batch_data[1][0][0].srcdata['mail_ts'].clone()
mask = DistIndex(batch_data[1][0][0].srcdata['ID']).is_shared
indx = self.mailbox.is_shared_mask[DistIndex(batch_data[1][0][0].srcdata['ID']).loc[mask]]
batch_data[1][0][0].srcdata['his_mem'][mask] = self.mailbox.historical_cache.local_historical_data[indx]
batch_data[1][0][0].srcdata['his_ts'][mask] = self.mailbox.historical_cache.local_ts[indx].reshape(-1,1)
t_fetch = tt.elapsed_event(t1)
tt.time_memory_fetch += t_fetch
#if(self.mailbox is not None and self.mailbox.historical_cache is not None):
......@@ -335,10 +346,15 @@ class DistributedDataLoader:
data,self.neg_sampler,
self.device,
nid_mapper = self.graph.nids_mapper,
eid_mapper = self.graph.eids_mapper)
eid_mapper = self.graph.eids_mapper,
reversed = self.reversed
)
edge_feat = get_edge_feature_by_dist(self.graph,dist_eid,is_local,out_device=self.device)
node_feat,mem = get_node_feature_by_dist(self.graph,self.mailbox,dist_nid, is_local,out_device=self.device)
prepare_input(node_feat,edge_feat,mem,batch_data[1],dist_nid,dist_eid)
batch_data[1][0][0].srcdata['his_mem'] = batch_data[1][0][0].srcdata['mem'].clone()
batch_data[1][0][0].srcdata['his_ts'] = batch_data[1][0][0].srcdata['mail_ts'].clone()
#if(self.mailbox is not None and self.mailbox.historical_cache is not None):
# id = batch_data[1][0][0].srcdata['ID']
# mask = DistIndex(id).is_shared
......@@ -363,14 +379,14 @@ class DistributedDataLoader:
node_feat0 = node_feat0[0]
node_feat = node_feat0[:,:self.graph.nfeat.shape[1]]
if self.graph.nfeat.shape[1] < node_feat0.shape[1]:
mem = self.mailbox.unpack(node_feat0[:,self.graph.nfeat.shape[1]:])
mem = self.mailbox.unpack(node_feat0[:,self.graph.nfeat.shape[1]:],mailbox = True)
else:
mem = None
elif self.mailbox is not None:
node_feat0[1].wait()
node_feat0 = node_feat0[0]
node_feat = None
mem = self.mailbox.unpack(node_feat0)
mem = self.mailbox.unpack(node_feat0,mailbox = True)
#print(node_feat.shape,edge_feat.shape,mem[0].shape)
#node_feat[1].wait()
#node_feat = node_feat[0]
......@@ -394,7 +410,11 @@ class DistributedDataLoader:
if(self.mailbox is not None and self.mailbox.historical_cache is not None):
id = batch_data[1][0][0].srcdata['ID']
mask = DistIndex(id).is_shared
batch_data[1][0][0].srcdata['mem'][mask] = self.mailbox.node_memory.accessor.data[DistIndex(id).loc[mask]]
batch_data[1][0][0].srcdata['his_mem'] = batch_data[1][0][0].srcdata['mem'].clone()
batch_data[1][0][0].srcdata['his_ts'] = batch_data[1][0][0].srcdata['mail_ts'].clone()
indx = self.mailbox.is_shared_mask[DistIndex(batch_data[1][0][0].srcdata['ID']).loc[mask]]
batch_data[1][0][0].srcdata['his_mem'][mask] = self.mailbox.historical_cache.local_historical_data[indx]
batch_data[1][0][0].srcdata['his_ts'][mask] = self.mailbox.historical_cache.local_ts[indx].reshape(-1,1)#self.mailbox.node_memory.accessor.data[DistIndex(id).loc[mask]]
#his_mem = torch.clone(batch_data[1][0][0].srcdata['mem'])
#his_ts = self.mailbox.historical_cache.local_ts[DistIndex(id[mask]).loc]
#maxer = his_ts>batch_data[1][0][0].srcdata['mem_ts'][mask]
......
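During training with `probability < 1`, the loader keeps every edge whose endpoints are both local and retains cross-partition edges only with that probability; the fixed line draws exactly one random number per masked-out edge. A small sketch of the filter on synthetic partition ids:

```python
import torch

def filter_cross_partition(src_part, dst_part, rank, p=0.1, generator=None):
    """Keep local-local edges; keep cross-partition edges with probability p."""
    mask = (src_part == rank) & (dst_part == rank)
    if p > 0:
        # One coin flip per cross-partition edge (this matches the corrected line).
        mask[~mask] = torch.rand((~mask).sum().item(), generator=generator) < p
    return mask

keep = filter_cross_partition(torch.tensor([0, 0, 1, 1]), torch.tensor([0, 1, 0, 1]), rank=0)
```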
......@@ -17,10 +17,10 @@ class MemoryMoniter:
self.memorychange.append(self.ssim(pre_memory,now_memory,method = 'F'))
self.memory_ssim.append(self.ssim(pre_memory,now_memory,method = 'cos'))
self.nid_list.append(nid)
def draw(self,degree,data,e):
torch.save(self.nid_list,'all/{}/memorynid_{}.pt'.format(data,e))
torch.save(self.memorychange,'all/{}/memoryF_{}.pt'.format(data,e))
torch.save(self.memory_ssim,'all/{}/memcos_{}.pt'.format(data,e))
def draw(self,degree,data,model,e):
torch.save(self.nid_list,'all/{}/{}/memorynid_{}.pt'.format(data,model,e))
torch.save(self.memorychange,'all/{}/{}/memoryF_{}.pt'.format(data,model,e))
torch.save(self.memory_ssim,'all/{}/{}/memcos_{}.pt'.format(data,model,e))
# path = './memory/{}/'.format(data)
# if not os.path.exists(path):
......
......@@ -98,6 +98,7 @@ class SharedMailBox():
self.tot_comm_count = 0
self.tot_shared_count = 0
self.shared_nodes_index = None
self.deliver_to = memory_param['deliver_to'] if 'deliver_to' in memory_param else 'self'
if shared_nodes_index is not None:
self.shared_nodes_index = shared_nodes_index.to('cuda:{}'.format(ctx.local_rank))
self.is_shared_mask = torch.zeros(self.num_nodes,dtype=torch.int,device=torch.device('cuda:{}'.format(ctx.local_rank)))-1
......@@ -152,7 +153,7 @@ class SharedMailBox():
def get_update_mail(self,dist_indx_mapper,
src,dst,ts,edge_feats,
memory,embedding=None,use_src_emb=False,use_dst_emb=False,
deliver_to='self',block = None,Reduce_score=None,):
block = None,Reduce_score=None,):
if edge_feats is not None:
edge_feats = edge_feats.to(self.device).to(self.mailbox.dtype)
src = src.to(self.device)
......@@ -172,11 +173,13 @@ class SharedMailBox():
mail = torch.cat([src_mail, dst_mail], dim=0)
mail_ts = torch.cat((ts,ts),-1).to(self.device).to(self.mailbox_ts.dtype)
#print(mail_ts)
if deliver_to == 'neighbor':
assert block is None and Reduce_score is not None
mail = torch.cat([mail, mail[block.edges()[1].long()]], dim=0)
mail_ts = torch.cat([mail_ts, mail_ts[block.edges()[1].long()]], dim=0)
index = torch.cat([index,dist_indx_mapper[block.edges()[0].long()]],dim=0)
#print(self.deliver_to)
if self.deliver_to == 'neighbors':
assert block is not None and Reduce_score is None
mail = torch.cat([mail, mail[block.edges()[0].long()]], dim=0)
mail_ts = torch.cat([mail_ts, mail_ts[block.edges()[0].long()]], dim=0)
index = torch.cat([index,block.dstdata['ID'][block.edges()[1].long()]],dim=0)
if Reduce_score is not None:
Reduce_score = torch.cat((Reduce_score,Reduce_score),-1).to(self.device)
if Reduce_score is None:
......@@ -214,18 +217,23 @@ class SharedMailBox():
#else:
# mem = torch.cat((memory,memory_ts.view(-1,1),mail,mail_ts.view(-1,1),index.to(torch.float32).view(-1,1)),dim = 1)
return mem
def unpack(self,mem):
def unpack(self,mem,mailbox = False):
if mem.shape[1] == self.node_memory.shape[1] + 1 or mem.shape[1] == self.mailbox.shape[2] + 1 :
mail = mem[:,: -1]
mail_ts = mem[:,-1].view(-1)
return mail,mail_ts
else:
elif mailbox is False:
memory = mem[:,:self.node_memory.shape[1]]
memory_ts = mem[:,self.node_memory.shape[1]].view(-1)
mail = mem[:,self.node_memory.shape[1]+1:-1]
mail_ts = mem[:,-1].view(-1)
return memory,memory_ts,mail,mail_ts
else:
memory = mem[:,:self.node_memory.shape[1]]
memory_ts = mem[:,self.node_memory.shape[1]].view(-1)
mail = mem[:,self.node_memory.shape[1]+1:mem.shape[1]-self.mailbox_ts.shape[1]].reshape(mem.shape[0],self.mailbox.shape[1],-1)
mail_ts = mem[:,mem.shape[1]-self.mailbox_ts.shape[1]:]
return memory,memory_ts,mail,mail_ts
def handle_last_async(self,reduce_Op = None):
if self.last_memory_sync is not None:
......@@ -247,10 +255,11 @@ class SharedMailBox():
if out is not None:
shared_index,shared_data,shared_ts,mail,mail_ts = out
index = self.shared_nodes_index[shared_index]
self.node_memory.accessor.data[index] = shared_data
self.node_memory_ts.accessor.data[index] = shared_ts
self.mailbox.accessor.data[index, torch.max(self.next_mail_pos[index]-1,torch.tensor([0],device=mail.device))] = mail
self.mailbox_ts.accessor.data[index, torch.max(self.next_mail_pos[index]-1,torch.tensor([0],device=mail_ts.device))] = mail_ts
mask= (shared_ts > self.node_memory_ts.accessor.data[index])
self.node_memory.accessor.data[index][mask] = shared_data[mask]
self.node_memory_ts.accessor.data[index][mask] = shared_ts[mask]
#self.mailbox.accessor.data[index, torch.max(self.next_mail_pos[index]-1,torch.tensor([0],device=mail.device))] = mail
#self.mailbox_ts.accessor.data[index, torch.max(self.next_mail_pos[index]-1,torch.tensor([0],device=mail_ts.device))] = mail_ts
def update_shared(self):
ctx = DistributedContext.get_default_context()
if self.last_job is not None:
......@@ -273,6 +282,7 @@ class SharedMailBox():
output_split_sizes=output_split,
input_split_sizes=input_split,group = group,async_op=async_op)
self.last_memory_sync = (gather_id_list,handle0,gather_memory,handle1)
def set_memory_all_reduce(self,index,memory,memory_ts,mail,mail_ts,reduce_Op = None,async_op = True,set_remote = False,mode=None,filter=None,submit = True):
ctx = DistributedContext.get_default_context()
#print(DistIndex(index).part)
......@@ -337,6 +347,7 @@ class SharedMailBox():
#,shared_memory,shared_memory_ts,
#shared_memory,shared_memory_ts = self.unpack(mem)
shared_memory,shared_memory_ts,shared_mail,shared_mail_ts = self.unpack(mem)
#print(shared_memory_ts,shared_mail_ts)
unq_index,inv = torch.unique(shared_index,return_inverse = True)
#print(inv.shape,Reduce_score.shape)
max_ts,idx = torch_scatter.scatter_max(shared_memory_ts,inv,0)
......
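`unpack` gains a `mailbox` flag: with it set, the tail of each packed row holds `mailbox_size` stacked mails followed by their timestamps instead of a single mail. A minimal pack/unpack sketch under that layout; the function names and dimensions are illustrative, not the mailbox's API:

```python
import torch

def pack(memory, memory_ts, mail, mail_ts):
    """memory: (N, dim_mem), memory_ts: (N,), mail: (N, S, dim_mail), mail_ts: (N, S)."""
    return torch.cat((memory, memory_ts.view(-1, 1), mail.reshape(mail.shape[0], -1), mail_ts), dim=1)

def unpack(mem, dim_mem, mailbox_size, dim_mail):
    memory    = mem[:, :dim_mem]
    memory_ts = mem[:, dim_mem].view(-1)
    mail      = mem[:, dim_mem + 1: mem.shape[1] - mailbox_size].reshape(mem.shape[0], mailbox_size, dim_mail)
    mail_ts   = mem[:, mem.shape[1] - mailbox_size:]
    return memory, memory_ts, mail, mail_ts

x = pack(torch.zeros(2, 4), torch.zeros(2), torch.zeros(2, 3, 5), torch.zeros(2, 3))
memory, memory_ts, mail, mail_ts = unpack(x, dim_mem=4, mailbox_size=3, dim_mail=5)
```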
......@@ -20,19 +20,23 @@ class LocalNegativeSampling(NegativeSampling):
unique: bool = False,
src_node_list: torch.Tensor = None,
dst_node_list: torch.Tensor = None,
seed = False
local_mask = None,
seed = None
):
super(LocalNegativeSampling,self).__init__(mode,amount,unique=unique)
self.src_node_list = src_node_list.to('cpu') if src_node_list is not None else None
self.dst_node_list = dst_node_list.to('cpu') if dst_node_list is not None else None
self.rdm = torch.Generator()
if seed is True:
if seed is not None:
random.seed(seed)
seed = random.randint(0,100000)
print('seed is',seed)
ctx = DistributedContext.get_default_context()
self.rdm.manual_seed(seed^ctx.rank)
self.rdm = torch.Generator()
self.local_mask = local_mask
if self.local_mask is not None:
self.local_dst = dst_node_list[local_mask]
#self.rdm.manual_seed(42)
#print('dst_nde_list {}\n'.format(dst_node_list))
def is_binary(self) -> bool:
......@@ -53,6 +57,12 @@ class LocalNegativeSampling(NegativeSampling):
else:
if self.dst_node_list is None:
return torch.randint(num_nodes, (num_samples, ),generator=self.rdm)
elif self.local_mask is not None:
p = torch.rand(size=(num_samples,))
sr = torch.randint(len(self.dst_node_list), (num_samples, ),generator=self.rdm)
sl = torch.randint(len(self.local_dst), (num_samples, ),generator=self.rdm)
return torch.where(p<0.9,sl,sr)
else:
return self.dst_node_list[torch.randint(len(self.dst_node_list), (num_samples, ),generator=self.rdm)]
s = torch.randint(len(self.dst_node_list), (num_samples, ),generator=self.rdm)
return self.dst_node_list[s]
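With a `local_mask`, the negative sampler now draws from locally partitioned destinations with probability 0.9 and from the full destination list otherwise. A self-contained sketch of that mixture; unlike the hunk above, it maps the drawn indices back to node ids before returning (the 0.9 ratio comes from the diff, the rest is illustrative):

```python
import torch

def sample_negatives(dst_node_list, local_mask, num_samples, p_local=0.9, rdm=None):
    """Draw negatives: with prob p_local from local destinations, otherwise globally."""
    local_dst = dst_node_list[local_mask]
    p  = torch.rand(num_samples)
    sr = torch.randint(len(dst_node_list), (num_samples,), generator=rdm)  # global draw
    sl = torch.randint(len(local_dst), (num_samples,), generator=rdm)      # local draw
    return torch.where(p < p_local, local_dst[sl], dst_node_list[sr])

neg = sample_negatives(torch.arange(100), torch.arange(100) < 50, num_samples=8)
```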
......@@ -227,27 +227,12 @@ class NeighborSampler(BaseSampler):
sampled_nodes: the node sampled
sampled_edge_index_list: the edge sampled
"""
if(ts is None):
self.part_unique = True
self.p_sampler.neighbor_sample_from_nodes(nodes.contiguous(), None, self.part_unique)
ret = self.p_sampler.get_ret()
return ret
else:
if self.policy != 'identity':
self.p_sampler.neighbor_sample_from_nodes(nodes.contiguous(), ts.contiguous(), None)
ret = self.p_sampler.get_ret()
else:
ret = None
metadata = {}
if is_unique:
self.p_sampler.sample_unique(
nodes,ts.float(),nid_mapper,
eid_mapper,'cpu' if out_device.type == 'cpu' else str(out_device.index))
metadata = {
'eid_inv':self.p_sampler.eid_inv,
'first_block_id':self.p_sampler.first_block_id,
'block_node_list':self.p_sampler.block_node_list,
'unq_id':self.p_sampler.unq_id,
'dist_nid':self.p_sampler.dist_nid,
'dist_eid':self.p_sampler.dist_eid,
}
#print(nodes.shape[0],ret[0].src_index().max(),ret[0].src_index().min())
return ret,metadata
......@@ -329,12 +314,6 @@ class NeighborSampler(BaseSampler):
else:
seed, inverse_seed = seed.unique(return_inverse=True)
"""
metadata = {}
if is_unique:
out,metadata = self.sample_from_nodes(seed, seed_ts, is_unique=is_unique,eid_mapper=eid_mapper,nid_mapper=nid_mapper,out_device=out_device)
first_block_id = self.p_sampler.first_block_id
else:
#print('is unique')
out,metadata = self.sample_from_nodes(seed, seed_ts, is_unique=False)
src_pos_index = torch.arange(0,num_pos,dtype= torch.long,device=out_device)
dst_pos_index = torch.arange(num_pos,2*num_pos,dtype= torch.long,device=out_device)
......@@ -344,12 +323,6 @@ class NeighborSampler(BaseSampler):
else:
src_neg_index = torch.arange(2*num_pos,3*num_pos,dtype= torch.long,device=out_device)
dst_neg_index=torch.arange(3*num_pos,seed.shape[0],dtype= torch.long,device=out_device)
#if is_unique:
# src_pos_index = first_block_id[src_pos_index].contiguous()
# dst_pos_index = first_block_id[dst_pos_index].contiguous()
# dst_neg_index = first_block_id[dst_neg_index].contiguous()
# src_neg_index = first_block_id[src_neg_index].contiguous()
metadata['seed'] = seed
metadata['seed_ts'] = seed_ts
metadata['src_pos_index']=src_pos_index
......