Commit abb7e9e8 by zlj

test no time encoder

parent 48c10caf
......@@ -24,7 +24,7 @@ gnn:
dim_time: 100
dim_out: 100
train:
- epoch: 100
- epoch: 20
batch_size: 1000
# reorder: 16
lr: 0.0004
......
sampling:
- layer: 1
neighbor:
- 20
- 10
strategy: 'recent'
prop_time: False
history: 1
......@@ -24,10 +24,10 @@ gnn:
use_dst_emb: False
layer: 1
att_head: 2
dim_time: 100
dim_time: 0
dim_out: 100
train:
- epoch: 50
- epoch: 100
batch_size: 3000
# reorder: 16
lr: 0.0004
......
......@@ -256,6 +256,7 @@ void ParallelSampler :: neighbor_sample_from_nodes_with_before_layer(
tgb_i[tid].eid.insert(tgb_i[tid].eid.end(), tnb.eid[node].begin()+start_index, tnb.eid[node].begin()+end_index);
for(int cid = start_index; cid < end_index;cid++){
tgb_i[tid].delta_ts.emplace_back(rtts-tnb.timestamp[node][cid]);
tgb_i[tid].sample_weight.emplace_back(cid-start_index);
}
}
else if(policy == "boundery_recent_uniform"){
......@@ -363,8 +364,8 @@ void ParallelSampler :: neighbor_sample_from_nodes_with_before_layer(
each_begin[i]=size;
size += s;
}
if(policy == "boundery_recent_decay")
ret[cur_layer].sample_weight.resize(size);
//if(policy == "boundery_recent_decay")
ret[cur_layer].sample_weight.resize(size);
ret[cur_layer].eid.resize(size);
ret[cur_layer].src_index.resize(size);
ret[cur_layer].delta_ts.resize(size);
......@@ -374,7 +375,7 @@ void ParallelSampler :: neighbor_sample_from_nodes_with_before_layer(
#pragma omp parallel for schedule(static, 1)
for(int i = 0; i<threads; i++){
//if(policy == "boundery_recent_decay")
// copy(tgb_i[i].sample_weight.begin(), tgb_i[i].sample_weight.end(), ret[cur_layer].sample_weight.begin()+each_begin[i]);
copy(tgb_i[i].sample_weight.begin(), tgb_i[i].sample_weight.end(), ret[cur_layer].sample_weight.begin()+each_begin[i]);
copy(tgb_i[i].eid.begin(), tgb_i[i].eid.end(), ret[cur_layer].eid.begin()+each_begin[i]);
copy(tgb_i[i].src_index.begin(), tgb_i[i].src_index.end(), ret[cur_layer].src_index.begin()+each_begin[i]);
copy(tgb_i[i].delta_ts.begin(), tgb_i[i].delta_ts.end(), ret[cur_layer].delta_ts.begin()+each_begin[i]);
......
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
import matplotlib.pyplot as plt
import numpy as np
import torch

# For each dataset: parse the experiment logs, then plot
#   1) test AP and remote-communication volume vs sampling probability (bars),
#   2) training-loss convergence curves (one line per probability).
# Log layout: all/"$data"/"$partitions"-ours_shared-0.01-"$mem"-"$ssim"-"$sample".out
probability_values = [1, 0.5, 0.1, 0.05, 0.01, 0]
data_values = ['WIKI', 'REDDIT', 'LASTFM', 'DGraphFin', 'WikiTalk']
partition = 'ours_shared'
partitions = 4
topk = 0.01
mem = 'all_update'  # or 'historical'
AP_PREFIX = 'best test AP:'
COMM_TAG = 'remote node number tensor'

for data in data_values:
    ap_list = []
    comm_list = []
    for p in probability_values:
        path = '{}/{}-{}-{}-{}-boundery_recent_decay-{}.out'.format(
            data, partitions, partition, topk, mem, p)
        ap = 0.0
        cnt = 0
        total = 0  # accumulated remote-comm count ('sum' shadowed the builtin)
        with open(path, 'r') as fh:  # don't rebind the path variable as the handle
            for line in fh:
                if line.startswith(AP_PREFIX):
                    # str.lstrip(prefix) strips a *character set*, not a prefix;
                    # slice off the exact prefix instead.
                    ap = float(line[len(AP_PREFIX):].split(' ')[0])
                pos = line.find(COMM_TAG)
                if pos != -1:
                    posr = line.find(']', pos + 2 + len(COMM_TAG))
                    comm = int(line[pos + 2 + len(COMM_TAG):posr])
                    total = total + comm
                    cnt = cnt + 1
        ap_list.append(ap)
        # Average per recorded batch, scaled by the 4 partitions. The original
        # appended comm/cnt*4 (last sample only), disagreeing with the sibling
        # script that divides the accumulated sum by cnt.
        comm_list.append(total / cnt * 4 if cnt > 0 else 0)

    bar_width = 0.4
    # x positions must match the data actually plotted (probabilities, not ssim).
    bars = range(len(probability_values))
    print('{} TestAP={}\n'.format(data, ap_list))
    plt.bar([b for b in bars], ap_list, width=bar_width)
    plt.xticks([b for b in bars], probability_values)
    plt.xlabel('probability')
    plt.ylabel('Test AP')
    plt.title('{}({} partitions)'.format(data, partitions))
    plt.savefig('boundary_AP_{}.png'.format(data))
    plt.clf()

    plt.bar([b for b in bars], comm_list, width=bar_width)
    plt.xticks([b for b in bars], probability_values)
    plt.xlabel('probability')
    plt.ylabel('Communication volume')
    plt.title('{}({} partitions)'.format(data, partitions))
    plt.savefig('boundary_comm_{}.png'.format(data))
    plt.clf()

    # Loss convergence curves, one per probability.
    partition0 = 'ours' if partition == 'ours_shared' else partition
    for p in probability_values:
        path = '{}/loss_{}_{}_{}_0_boundery_recent_decay_{}_all_update_2.pt'.format(
            data, partition0, topk, partitions, float(p))
        val_ap = torch.tensor(torch.load(path))
        epoch = torch.arange(val_ap.shape[0])
        plt.plot(epoch, val_ap, label='probability={}'.format(p))
    plt.xlabel('Epoch')
    plt.ylabel('loss')
    plt.title('{}({} partitions)'.format(data, partitions))
    plt.legend()
    plt.savefig('{}_boundary_Convergence_loss.png'.format(data))
    plt.clf()
import matplotlib.pyplot as plt
import numpy as np
import torch

# Per dataset: bar charts of test AP and shared-memory communication volume over
# the SSIM threshold, plus validation-AP convergence curves.
# ssim == 2 is the sentinel for the purely local (no sharing) run.
ssim_values = [0, 0.1, 0.2, 0.3, 0.4, 2]
data_values = ['WIKI', 'REDDIT', 'LASTFM', 'DGraphFin']
partition = 'ours_shared'
partitions = 4
topk = 0.01
mem = 'historical'
AP_PREFIX = 'best test AP:'
COMM_TAG = 'shared comm tensor'

for data in data_values:
    ap_list = []
    comm_list = []
    for ssim in ssim_values:
        if ssim == 2:
            path = '{}/{}-{}-{}-local-recent.out'.format(data, partitions, partition, topk)
        else:
            path = '{}/{}-{}-{}-{}-{}-recent.out'.format(data, partitions, partition, topk, mem, ssim)
        # Initialize so a log without matches cannot leave these unbound.
        ap = 0.0
        comm = 0
        with open(path, 'r') as fh:  # don't rebind the path variable as the handle
            for line in fh:
                if line.startswith(AP_PREFIX):
                    # Slice off the exact prefix; str.lstrip(prefix) strips a char set.
                    ap = float(line[len(AP_PREFIX):].split(' ')[0])
                pos = line.find(COMM_TAG)
                if pos != -1:
                    comm = int(line[pos + 2 + len(COMM_TAG):len(line) - 3])
        ap_list.append(ap)
        comm_list.append(comm)
    print('{} TestAP={}\n'.format(data, ap_list))

    bar_width = 0.4
    bars = range(len(ssim_values))
    plt.bar([b for b in bars], ap_list, width=bar_width)
    plt.xticks([b for b in bars], ssim_values)
    plt.xlabel('SSIM threshold Values')
    plt.ylabel('Test AP')
    plt.title('{}({} partitions)'.format(data, partitions))
    plt.savefig('ssim_{}_{}.png'.format(data, partitions))
    plt.clf()

    plt.bar([b for b in bars], comm_list, width=bar_width)
    plt.xticks([b for b in bars], ssim_values)
    plt.xlabel('SSIM threshold Values')
    plt.ylabel('Communication volume')
    plt.title('{}({} partitions)'.format(data, partitions))
    plt.savefig('ssim_comm_{}_{}.png'.format(data, partitions))
    plt.clf()

    # Validation-AP convergence curves, one per ssim threshold.
    partition0 = 'ours' if partition == 'ours_shared' else partition
    for ssim in ssim_values:
        if ssim == 2:
            path = '{}/val_{}_{}_{}_0_recent_0.1_local_2.pt'.format(data, partition0, topk, partitions)
        else:
            path = '{}/val_{}_{}_{}_0_recent_0.1_{}_{}.pt'.format(data, partition0, topk, partitions, mem, float(ssim))
        val_ap = torch.tensor(torch.load(path))
        epoch = torch.arange(val_ap.shape[0])
        plt.plot(epoch, val_ap, label='ssim={}'.format(ssim))
    plt.xlabel('Epoch')
    plt.ylabel('Val AP')
    plt.title('{}({} partitions)'.format(data, partitions))
    plt.legend()
    plt.savefig('{}_{}_ssim_Convergence_rate.png'.format(data, partitions))
    plt.clf()
import matplotlib.pyplot as plt
import numpy as np
import torch

# Per dataset (TGN runs): bar charts of test AP and average remote-communication
# volume over sampling probability, plus test-AP convergence curves.
probability_values = [1, 0.5, 0.1, 0.05, 0.01, 0]
data_values = ['WIKI_3', 'LASTFM_3', 'WikiTalk', 'StackOverflow']
partition = 'ours'
partitions = 4
topk = 0
mem = 'all_update'  # or 'historical'
model = 'TGN'
AP_PREFIX = 'best test AP:'
COMM_TAG = 'remote node number tensor'

for data in data_values:
    ap_list = []
    comm_list = []
    for p in probability_values:
        path = '{}/{}/{}-{}-{}-{}-boundery_recent_uniform-{}.out'.format(
            data, model, partitions, partition, topk, mem, p)
        ap = 0.0
        cnt = 0
        total = 0  # accumulated remote-comm count ('sum' shadowed the builtin)
        with open(path, 'r') as fh:  # don't rebind the path variable as the handle
            for line in fh:
                if line.startswith(AP_PREFIX):
                    # Slice off the exact prefix; str.lstrip(prefix) strips a char set.
                    ap = float(line[len(AP_PREFIX):].split(' ')[0])
                pos = line.find(COMM_TAG)
                if pos != -1:
                    posr = line.find(']', pos + 2 + len(COMM_TAG))
                    comm = int(line[pos + 2 + len(COMM_TAG):posr])
                    total = total + comm
                    cnt = cnt + 1
        ap_list.append(ap)
        # Average per recorded batch, scaled by the 4 partitions.
        comm_list.append(total / cnt * 4 if cnt > 0 else 0)

    print('{} TestAP={}\n'.format(data, ap_list))
    bar_width = 0.4
    bars = range(len(probability_values))
    plt.bar([b for b in bars], ap_list, width=bar_width)
    plt.ylim([0.9, 1])
    plt.xticks([b for b in bars], probability_values)
    plt.xlabel('probability')
    plt.ylabel('Test AP')
    plt.title('{}({} partitions)'.format(data, partitions))
    plt.savefig('boundary_AP_{}_{}_{}.png'.format(data, partitions, model))
    plt.clf()

    print(comm_list)
    plt.bar([b for b in bars], comm_list, width=bar_width)
    plt.xticks([b for b in bars], probability_values)
    plt.xlabel('probability')
    plt.ylabel('Communication volume')
    plt.title('{}({} partitions)'.format(data, partitions))
    plt.savefig('boundary_comm_{}_{}_{}.png'.format(data, partitions, model))
    plt.clf()

    # Test-AP convergence curves (first column of the saved tensor), one per probability.
    partition0 = 'ours' if partition == 'ours_shared' else partition
    for p in probability_values:
        path = '{}/{}/test_{}_{}_{}_0_boundery_recent_uniform_{}_all_update_2.pt'.format(
            data, model, partition0, topk, partitions, float(p))
        val_ap = torch.tensor(torch.load(path))[:, 0]
        epoch = torch.arange(val_ap.shape[0])
        plt.plot(epoch, val_ap, label='probability={}'.format(p))
    plt.xlabel('Epoch')
    plt.ylabel('Val AP')
    plt.title('{}({} partitions)'.format(data, partitions))
    plt.legend()
    plt.savefig('{}_{}_{}_boundary_Convergence_rate.png'.format(data, partitions, model))
    plt.clf()
import torch
import matplotlib.pyplot as plt

# Plot per-batch computational load balance: the maximum partition load versus
# the average partition load, from the recorded communication statistics.
partitions = 4
data = 'WikiTalk'
method = 'ours'
local_edge_comm_list = [torch.tensor([]) for _ in range(partitions)]
remote_edge_comm_list = [torch.tensor([]) for _ in range(partitions)]
calculate = [torch.tensor([]) for _ in range(partitions)]
for r in range(partitions):
    comm = torch.load('{}/comm/comm_{}_0_{}_{}_recent_0.1_all_update_2.pt'.format(
        data, method, partitions, r))
    (local_access, remote_access, local_edge_access, remote_edge_access,
     local_comm, remote_comm, local_edge_comm, remote_edge_comm) = comm
    local_edge_comm_list[r] = torch.tensor(local_edge_comm)
    remote_edge_comm_list[r] = torch.tensor(remote_edge_comm)
    # Per-batch load of rank r = local + remote edge communication.
    # The original added remote_edge_comm[r] (indexing the raw loaded object),
    # which mixes in a single element instead of the per-rank tensor.
    calculate[r] = local_edge_comm_list[r] + remote_edge_comm_list[r]

# Compare the first 12 batches across ranks.
cal = torch.stack(calculate)[:, :12]
cal_max, _ = torch.max(cal, dim=0)
cal_avg = torch.sum(cal, dim=0) / partitions  # mean load; was misleadingly named cal_min
x = torch.arange(cal.shape[1])

fig, ax = plt.subplots()
bar_width = 0.35
ax.bar(x, cal_max, bar_width, label='max overhead')
ax.bar(x + bar_width, cal_avg, bar_width, label='avg overhead')
plt.xlabel('Batch Index')
plt.ylabel('Computational Load Balance')
plt.legend()
if method == 'dis_tgl':
    method = 'DisTGL'
plt.title('{}({} {}-partitions)'.format(data, method, partitions))
plt.savefig('{}_{}_balance.png'.format(data, method))
import matplotlib.pyplot as plt
import numpy as np
import torch

# Per dataset (per-model logs): bar charts of test AP / shared communication over
# the SSIM threshold, plus test-AP convergence curves.
# Sentinels: ssim == 2 -> purely local run; ssim == -1 -> all-update run.
ssim_values = [-1, 0.3, 0.5, 0.7, 2]
# NOTE(review): 'LASTFM' appears twice in this list; kept as-is (it only redoes
# one dataset) — confirm whether the duplicate is intended.
data_values = ['WIKI', 'LASTFM', 'WikiTalk', 'REDDIT', 'LASTFM', 'DGraphFin']
partition = 'ours_shared'
partitions = 4
model = 'TGN'
topk = 0.01
mem = 'historical'
AP_PREFIX = 'best test AP:'
COMM_TAG = 'shared comm tensor'

for data in data_values:
    ap_list = []
    comm_list = []
    for ssim in ssim_values:
        if ssim == 2:
            path = '{}/{}/{}-{}-{}-local-recent.out'.format(data, model, partitions, partition, topk)
        elif ssim == -1:
            path = '{}/{}/{}-{}-{}-all_update-recent.out'.format(data, model, partitions, partition, topk)
        else:
            path = '{}/{}/{}-{}-{}-{}-{}-recent.out'.format(data, model, partitions, partition, topk, mem, ssim)
        # Initialize so a log without matches cannot leave these unbound.
        ap = 0.0
        comm = 0
        with open(path, 'r') as fh:  # don't rebind the path variable as the handle
            for line in fh:
                if line.startswith(AP_PREFIX):
                    # Slice off the exact prefix; str.lstrip(prefix) strips a char set.
                    ap = float(line[len(AP_PREFIX):].split(' ')[0])
                pos = line.find(COMM_TAG)
                if pos != -1:
                    comm = int(line[pos + 2 + len(COMM_TAG):len(line) - 3])
        print(ap)
        ap_list.append(ap)
        comm_list.append(comm)
    print('{} TestAP={}\n'.format(data, ap_list))

    bar_width = 0.4
    bars = range(len(ssim_values))
    plt.bar([b for b in bars], ap_list, width=bar_width)
    plt.xticks([b for b in bars], ssim_values)
    plt.xlabel('SSIM threshold Values')
    plt.ylabel('Test AP')
    plt.title('{}({} partitions)'.format(data, partitions))
    plt.savefig('ssim_{}_{}_{}.png'.format(data, partitions, model))
    plt.clf()

    plt.bar([b for b in bars], comm_list, width=bar_width)
    plt.xticks([b for b in bars], ssim_values)
    plt.xlabel('SSIM threshold Values')
    plt.ylabel('Communication volume')
    plt.title('{}({} partitions)'.format(data, partitions))
    plt.savefig('ssim_comm_{}_{}_{}.png'.format(data, partitions, model))
    plt.clf()

    # Test-AP convergence curves (first column of the saved tensor).
    partition0 = 'ours' if partition == 'ours_shared' else partition
    for ssim in ssim_values:
        if ssim == 2:
            path = '{}/{}/test_{}_{}_{}_0_recent_0.1_local_2.pt'.format(data, model, partition0, topk, partitions)
        elif ssim == -1:
            path = '{}/{}/test_{}_{}_{}_0_recent_0.1_all_update_2.pt'.format(data, model, partition0, topk, partitions)
        else:
            path = '{}/{}/test_{}_{}_{}_0_recent_0.1_{}_{}.pt'.format(data, model, partition0, topk, partitions, mem, float(ssim))
        val_ap = torch.tensor(torch.load(path))[:, 0]
        print(val_ap)
        epoch = torch.arange(val_ap.shape[0])
        if ssim == -1:
            plt.plot(epoch, val_ap, label='all-update')
        elif ssim == 2:
            plt.plot(epoch, val_ap, label='local')
        else:
            plt.plot(epoch, val_ap, label='ssim = {}'.format(ssim))
    if data == 'WIKI':
        plt.ylim([0.85, 0.90])
    plt.xlabel('Epoch')
    plt.ylabel('Val AP')
    plt.title('{}({} partitions)'.format(data, partitions))
    plt.legend()
    plt.savefig('{}_{}_{}_ssim_Convergence_rate.png'.format(data, partitions, model))
    plt.clf()
import matplotlib.pyplot as plt
import numpy as np
import torch

# Per dataset: test-AP bar chart over the SSIM threshold and validation-AP
# convergence curves. ssim == 2 is the sentinel for the purely local run.
ssim_values = [0, 0.1, 0.2, 0.3, 0.4, 2]
data_values = ['WIKI', 'REDDIT']
partition = 'ours_shared'
partitions = 4
topk = 0.01
mem = 'historical'
AP_PREFIX = 'best test AP:'

for data in data_values:
    ap_list = []
    for ssim in ssim_values:
        if ssim == 2:
            path = '{}/{}-{}-{}-local-recent.out'.format(data, partitions, partition, topk)
        else:
            path = '{}/{}-{}-{}-{}-{}-recent.out'.format(data, partitions, partition, topk, mem, ssim)
        ap = 0.0  # default if the log contains no matching line
        with open(path, 'r') as fh:  # don't rebind the path variable as the handle
            for line in fh:
                if line.startswith(AP_PREFIX):
                    # Slice off the exact prefix; str.lstrip(prefix) strips a char set.
                    ap = float(line[len(AP_PREFIX):].split(' ')[0])
        ap_list.append(ap)

    bar_width = 0.4
    bars = range(len(ssim_values))
    plt.bar([b for b in bars], ap_list, width=bar_width)
    plt.ylim([0.8, 1])
    plt.xticks([b for b in bars], ssim_values)
    plt.xlabel('SSIM threshold Values')
    plt.ylabel('Test AP')
    plt.title('{}({} partitions)'.format(data, partitions))
    plt.savefig('ssim_{}.png'.format(data))
    plt.clf()

    # Validation-AP convergence curves, one per ssim threshold.
    partition0 = 'ours' if partition == 'ours_shared' else partition
    for ssim in ssim_values:
        if ssim == 2:
            path = '{}/val_{}_{}_{}_0_recent_0.1_local_2.pt'.format(data, partition0, topk, partitions)
        else:
            path = '{}/val_{}_{}_{}_0_recent_0.1_{}_{}.pt'.format(data, partition0, topk, partitions, mem, float(ssim))
        val_ap = torch.tensor(torch.load(path))
        epoch = torch.arange(val_ap.shape[0])
        plt.plot(epoch, val_ap, label='ssim={}'.format(ssim))
    plt.xlabel('Epoch')
    plt.ylabel('Val AP')
    plt.title('{}({} partitions)'.format(data, partitions))
    plt.legend()
    plt.ylim([0.98, 0.99])
    plt.savefig('{}_ssim_Convergence_rate.png'.format(data))
    plt.clf()
import matplotlib.pyplot as plt
import numpy as np
import torch

# Per dataset: bar charts of test AP and shared-memory communication volume over
# the SSIM threshold, plus validation-AP convergence curves.
# ssim == 2 is the sentinel for the purely local (no sharing) run.
ssim_values = [0, 0.1, 0.2, 0.3, 0.4, 2]
data_values = ['WIKI', 'WikiTalk', 'REDDIT', 'LASTFM', 'DGraphFin']
partition = 'ours_shared'
partitions = 4
topk = 0.01
mem = 'historical'
AP_PREFIX = 'best test AP:'
COMM_TAG = 'shared comm tensor'

for data in data_values:
    ap_list = []
    comm_list = []
    for ssim in ssim_values:
        if ssim == 2:
            path = '{}/{}-{}-{}-local-recent.out'.format(data, partitions, partition, topk)
        else:
            path = '{}/{}-{}-{}-{}-{}-recent.out'.format(data, partitions, partition, topk, mem, ssim)
        # Initialize so a log without matches cannot leave these unbound.
        ap = 0.0
        comm = 0
        with open(path, 'r') as fh:  # don't rebind the path variable as the handle
            for line in fh:
                if line.startswith(AP_PREFIX):
                    # Slice off the exact prefix; str.lstrip(prefix) strips a char set.
                    ap = float(line[len(AP_PREFIX):].split(' ')[0])
                pos = line.find(COMM_TAG)
                if pos != -1:
                    comm = int(line[pos + 2 + len(COMM_TAG):len(line) - 3])
        ap_list.append(ap)
        comm_list.append(comm)
    print('{} TestAP={}\n'.format(data, ap_list))

    bar_width = 0.4
    bars = range(len(ssim_values))
    plt.bar([b for b in bars], ap_list, width=bar_width)
    plt.xticks([b for b in bars], ssim_values)
    plt.xlabel('SSIM threshold Values')
    plt.ylabel('Test AP')
    plt.title('{}({} partitions)'.format(data, partitions))
    plt.savefig('ssim_{}_{}.png'.format(data, partitions))
    plt.clf()

    plt.bar([b for b in bars], comm_list, width=bar_width)
    plt.xticks([b for b in bars], ssim_values)
    plt.xlabel('SSIM threshold Values')
    plt.ylabel('Communication volume')
    plt.title('{}({} partitions)'.format(data, partitions))
    plt.savefig('ssim_comm_{}_{}.png'.format(data, partitions))
    plt.clf()

    # Validation-AP convergence curves, one per ssim threshold.
    partition0 = 'ours' if partition == 'ours_shared' else partition
    for ssim in ssim_values:
        if ssim == 2:
            path = '{}/val_{}_{}_{}_0_recent_0.1_local_2.pt'.format(data, partition0, topk, partitions)
        else:
            path = '{}/val_{}_{}_{}_0_recent_0.1_{}_{}.pt'.format(data, partition0, topk, partitions, mem, float(ssim))
        val_ap = torch.tensor(torch.load(path))
        epoch = torch.arange(val_ap.shape[0])
        plt.plot(epoch, val_ap, label='ssim={}'.format(ssim))
    plt.xlabel('Epoch')
    plt.ylabel('Val AP')
    plt.title('{}({} partitions)'.format(data, partitions))
    plt.legend()
    plt.savefig('{}_{}_ssim_Convergence_rate.png'.format(data, partitions))
    plt.clf()
import argparse
import os
import sys
from os.path import abspath, join, dirname
from starrygl.distributed.context import DistributedContext
from starrygl.distributed.utils import DistIndex
from starrygl.evaluation.get_evalute_data import get_link_prediction_data
from starrygl.module.modules import GeneralModel
from pathlib import Path
from pathlib import Path
from starrygl.module.utils import parse_config
from starrygl.sample.cache.fetch_cache import FetchFeatureCache
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.module.utils import parse_config, EarlyStopMonitor
from starrygl.sample.graph_core import DataSet, DistributedGraphStore, TemporalNeighborSampleGraph
from starrygl.sample.memory.shared_mailbox import SharedMailBox
from starrygl.sample.sample_core.EvaluateNegativeSampling import EvaluateNegativeSampling
from starrygl.sample.sample_core.base import NegativeSampling
from starrygl.sample.sample_core.neighbor_sampler import NeighborSampler
from starrygl.sample.part_utils.partition_tgnn import partition_load
import torch
import time
import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os
from starrygl.sample.data_loader import DistributedDataLoader
from starrygl.sample.batch_data import SAMPLE_TYPE
from starrygl.sample.stream_manager import getPipelineManger
# Command-line options for the evaluation run.
parser = argparse.ArgumentParser(
    description="RPC Reinforcement Learning Example",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# NOTE(review): several help strings were copy-pasted from other options;
# they now describe the option they belong to.
parser.add_argument('--rank', default=0, type=int, metavar='W',
                    help='rank of this process')
parser.add_argument('--patience', type=int, default=5,
                    help='Patience for early stopping')
parser.add_argument('--world_size', default=1, type=int, metavar='W',
                    help='total number of processes')
parser.add_argument('--dataname', default=1, type=str, metavar='W',
                    help='name of dataset')
parser.add_argument('--model', default='TGN', type=str, metavar='W',
                    help='name of model')
# The original registered --negative_sample_strategy twice, which makes argparse
# raise ArgumentError at startup; register it exactly once.
parser.add_argument('--negative_sample_strategy', default='random', type=str, metavar='W',
                    help='name of negative sample strategy')
args = parser.parse_args()
from sklearn.metrics import average_precision_score, roc_auc_score
import torch
import time
import random
import dgl
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn.parallel import DistributedDataParallel as DDP
#os.environ['CUDA_VISIBLE_DEVICES'] = str(args.rank)
#os.environ["RANK"] = str(args.rank)
#os.environ["WORLD_SIZE"] = str(args.world_size)
#os.environ["LOCAL_RANK"] = str(0)
# Bind this process to the GPU chosen by the launcher (torchrun sets LOCAL_RANK).
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
# Rendezvous endpoint for torch.distributed initialization.
# NOTE(review): master address/port are hard-coded — consider making them configurable.
os.environ["MASTER_ADDR"] = '10.214.211.187'
os.environ["MASTER_PORT"] = '9337'
def seed_everything(seed=42):
    """Seed every RNG used in training (python, numpy, torch CPU and CUDA)
    and force deterministic cuDNN kernels for reproducible runs."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(1234)
def main():
print('main')
use_cuda = True
sample_param, memory_param, gnn_param, train_param = parse_config('./config/{}.yml'.format(args.model))
torch.set_num_threads(12)
ctx = DistributedContext.init(backend="nccl", use_gpu=True)
device_id = torch.cuda.current_device()
pdata = partition_load("/mnt/data/part_data/evaluate/{}".format(args.dataname), algo="metis_for_tgnn")
graph = DistributedGraphStore(pdata = pdata,uvm_edge = False)
gnn_param['dyrep'] = True if args.model == 'DyRep' else False
use_src_emb = gnn_param['use_src_emb'] if 'use_src_emb' in gnn_param else False
use_dst_emb = gnn_param['use_dst_emb'] if 'use_dst_emb' in gnn_param else False
gnn_dim_node = 0 if graph.x is None else pdata.x.shape[1]
gnn_dim_edge = 0 if graph.edge_attr is None else pdata.edge_attr.shape[1]
print(gnn_dim_node,gnn_dim_edge)
avg_time = 0
MODEL_SAVE_PATH = f'./saved_models/{args.model}-{args.dataname}.pth'
if use_cuda:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param).cuda()
device = torch.device('cuda')
else:
model = GeneralModel(gnn_dim_node, gnn_dim_edge, sample_param, memory_param, gnn_param, train_param)
device = torch.device('cpu')
model.load_state_dict(torch.load(MODEL_SAVE_PATH))
sample_graph = TemporalNeighborSampleGraph(sample_graph = pdata.sample_graph,mode = 'full')
if memory_param['type'] != 'none':
mailbox = SharedMailBox(pdata.ids.shape[0], memory_param, dim_edge_feat = pdata.edge_attr.shape[1] if pdata.edge_attr is not None else 0)
else:
mailbox = None
fanout = []
num_layers = sample_param['layer'] if 'layer' in sample_param else 1
fanout = sample_param['neighbor'] if 'neighbor' in sample_param else [10]
policy = sample_param['strategy'] if 'strategy' in sample_param else 'recent'
sampler = NeighborSampler(num_nodes=graph.num_nodes, num_layers=num_layers, fanout=fanout,graph_data=sample_graph, workers=10,policy = policy, graph_name = "wiki_train")
train_data = torch.masked_select(graph.edge_index,pdata.train_mask.to(device)).reshape(2,-1)
train_ts = torch.masked_select(graph.edge_ts,pdata.train_mask.to(device))
val_data = torch.masked_select(graph.edge_index,pdata.val_mask.to(device)).reshape(2,-1)
val_ts = torch.masked_select(graph.edge_ts,pdata.val_mask.to(device))
test_data = torch.masked_select(graph.edge_index,pdata.test_mask.to(device)).reshape(2,-1)
test_ts = torch.masked_select(graph.edge_ts,pdata.test_mask.to(device))
#print(train_data.shape[1],val_data.shape[1],test_data.shape[1])
train_data = DataSet(edges = train_data,ts =train_ts,eids = torch.nonzero(pdata.train_mask).view(-1))
test_data = DataSet(edges = test_data,ts =test_ts,eids = torch.nonzero(pdata.test_mask).view(-1))
val_data = DataSet(edges = val_data,ts = val_ts,eids = torch.nonzero(pdata.val_mask).view(-1))
new_node_val_data = torch.masked_select(graph.edge_index,pdata.new_node_val_mask.to(device)).reshape(2,-1)
new_node_val_ts = torch.masked_select(graph.edge_ts,pdata.new_node_val_mask.to(device))
new_node_test_data = torch.masked_select(graph.edge_index,pdata.new_node_test_mask.to(device)).reshape(2,-1)
new_node_test_ts = torch.masked_select(graph.edge_ts,pdata.new_node_test_mask.to(device))
new_node_val_data = DataSet(edges = new_node_val_data, ts = new_node_val_ts, edis = torch.nonzero(pdata.new_node_val_mask).view(-1))
new_node_test_data = DataSet(edges = new_node_test_data, ts = new_node_test_ts, edis = torch.nonzero(pdata.new_node_test_mask).view(-1))
if args.negative_sample_strategy != 'random':
val_neg_edge_sampler = EvaluateNegativeSampling(src_node_ids=graph.edge_index[0,:], dst_node_ids=graph.edge_index[1,:],
interact_times=graph.edge_ts, last_observed_time=train_data.ts[-1],
negative_sample_strategy=args.negative_sample_strategy, seed=0)
new_node_val_neg_edge_sampler = EvaluateNegativeSampling(src_node_ids=new_node_val_data.edges[0,:], dst_node_ids=new_node_val_data.edges[1,:],
interact_times=new_node_val_data.ts, last_observed_time=train_data.ts[-1],
negative_sample_strategy=args.negative_sample_strategy, seed=1)
test_neg_edge_sampler = EvaluateNegativeSampling(src_node_ids=graph.edge_index[0,:], dst_node_ids=graph.edge_index[1,:],
interact_times=graph.edge_ts, last_observed_time=val_data.ts[-1],
negative_sample_strategy=args.negative_sample_strategy, seed=2)
new_node_test_neg_edge_sampler = EvaluateNegativeSampling(src_node_ids=new_node_test_data.edges[0,:], dst_node_ids=new_node_test_data.edges[1,:],
interact_times=new_node_test_data.ts, last_observed_time=val_data.ts[-1],
negative_sample_strategy=args.negative_sample_strategy, seed=3)
else:
val_neg_edge_sampler = EvaluateNegativeSampling(src_node_ids=graph.edge_index[0,:], dst_node_ids=graph.edge_index[1,:], seed=0)
new_node_val_neg_edge_sampler = EvaluateNegativeSampling(src_node_ids=new_node_val_data.edges[0,:], dst_node_ids=new_node_val_data.edges[1,:], seed=1)
test_neg_edge_sampler =EvaluateNegativeSampling(src_node_ids=graph.edge_index[0,:], dst_node_ids=graph.edge_index[1,:], seed=2)
new_node_test_neg_edge_sampler = EvaluateNegativeSampling(src_node_ids=new_node_test_data.edges[0,:], dst_node_ids=new_node_test_data.edges[1,:], seed=3)
This source diff could not be displayed because it is too large. You can view the blob instead.
import itertools
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
from dgl.nn import SAGEConv
import numpy as np
class TimeEncode(torch.nn.Module):
    """Sinusoidal time encoder: maps a scalar time delta to a `dim`-dimensional
    vector cos(w * t + b), with frequencies fixed to a geometric series
    1 / 10^linspace(0, 9, dim)."""

    def __init__(self, dim):
        super(TimeEncode, self).__init__()
        self.dim = dim
        self.w = torch.nn.Linear(1, dim)
        # Overwrite the default init with the fixed frequency spectrum.
        freqs = 1 / 10 ** np.linspace(0, 9, dim, dtype=np.float32)
        self.w.weight = torch.nn.Parameter(torch.from_numpy(freqs).reshape(dim, -1))
        self.w.bias = torch.nn.Parameter(torch.zeros(dim))

    def forward(self, t):
        # Flatten to a column vector so the linear layer broadcasts over all deltas.
        column = t.float().reshape((-1, 1))
        return torch.cos(self.w(column))
class GraphSAGE(nn.Module):
    """Two-layer mean-aggregation GraphSAGE encoder."""

    def __init__(self, in_feats, h_feats):
        super(GraphSAGE, self).__init__()
        # First layer projects raw features; second stays in the hidden width.
        self.conv1 = SAGEConv(in_feats, h_feats, "mean")
        self.conv2 = SAGEConv(h_feats, h_feats, "mean")

    def forward(self, g, in_feat):
        hidden = F.relu(self.conv1(g, in_feat))
        return self.conv2(g, hidden)
class TSAGELayer(nn.Module):
    """One temporal GraphSAGE layer: concatenates (optional) source-node
    features, (optional) edge features and a time encoding of each edge's
    delta-t, then applies a mean-aggregation SAGEConv over the MFG block."""

    def __init__(self, in_dim=0, edge_dim=0, time_dim=0, h_feats=0):
        super(TSAGELayer, self).__init__()
        assert in_dim + time_dim != 0 and h_feats != 0
        self.time_dim = time_dim
        self.time_enc = TimeEncode(time_dim)
        # The conv consumes the concatenation of node/edge/time features, so its
        # input width is the sum of all three dims. (The original passed only
        # in_dim, and a stray trailing comma made self.sage a tuple, not a module.)
        self.sage = SAGEConv(in_dim + edge_dim + time_dim, h_feats, "mean")

    def forward(self, b):
        time_f = self.time_enc(b.edata['dt'])
        # Prepend zero rows for the destination nodes so per-edge features line up
        # with the block's source-node ordering (dst nodes first, then one src node
        # per edge). NOTE(review): assumes the TGL-style MFG layout where
        # num_src_nodes == num_dst_nodes + num_edges — confirm against the sampler.
        # The original concatenated along dim=1, which fails unless row counts match.
        time_f = torch.cat((torch.zeros(b.num_dst_nodes(), self.time_dim,
                                        dtype=time_f.dtype, device=time_f.device),
                            time_f), dim=0)
        feats = []
        if 'h' in b.srcdata:  # original read b.srcnode['h'] — a typo for srcdata
            feats.append(b.srcdata['h'])
        if 'f' in b.edata:
            edge_f = torch.cat((torch.zeros(b.num_dst_nodes(), b.edata['f'].shape[1],
                                            dtype=b.edata['f'].dtype,
                                            device=b.edata['f'].device),
                                b.edata['f']), dim=0)
            feats.append(edge_f)
        feats.append(time_f)
        b.srcdata['h'] = torch.cat(feats, dim=1)
        # Original called self.sage(b, b.src_data['h']) — src_data is not a DGL
        # block attribute; srcdata is.
        return F.relu(self.sage(b, b.srcdata['h']))
class TSAGEModel(nn.Module):
    """Stack of TSAGELayer blocks applied to a list of sampled MFG layers.

    `mfgs` is indexed as mfgs[layer][head]; intermediate outputs are written
    into the next layer's srcdata['h'], and the last layer's output is returned.
    """

    def __init__(self, num_layer, node_dim, edge_dim, time_dim, h_dim):
        super(TSAGEModel, self).__init__()
        self.num_layer = num_layer
        layers = []
        for i in range(num_layer):
            # First layer consumes raw node features; later layers consume h_dim.
            in_dim = node_dim if i == 0 else h_dim
            layers.append(TSAGELayer(in_dim, edge_dim, time_dim, h_dim))
        # nn.ModuleList registers the sub-layers; the original kept a plain
        # Python list, hiding their parameters from .parameters()/.to()/state_dict().
        self.layers = nn.ModuleList(layers)

    def forward(self, mfgs):
        for l in range(len(mfgs)):
            for h in range(len(mfgs[l])):
                # Call the layer module directly; the original passed `self` as an
                # extra positional argument, which breaks TSAGELayer.forward(b).
                if l < self.num_layer - 1:
                    mfgs[l + 1][h].srcdata['h'] = self.layers[l](mfgs[l][h])
                else:
                    return self.layers[l](mfgs[l][h])
class DotPredictor(nn.Module):
    """Scores each edge by the dot product of its endpoint embeddings."""

    def forward(self, g, h):
        # local_scope keeps the temporary 'h'/'score' features from leaking
        # into the caller's graph.
        with g.local_scope():
            g.ndata["h"] = h
            # u_dot_v writes a length-1 vector per edge into 'score';
            # squeeze it to a flat per-edge score.
            g.apply_edges(fn.u_dot_v("h", "h", "score"))
            return g.edata["score"][:, 0]
# Thumbnail credits: Link Prediction with Neo4j, Mark Needham
# sphinx_gallery_thumbnail_path = '_static/blitz_4_link_predict.png'
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment