Commit d756e756 by zlj

test

parent cf5a4a40
import matplotlib.pyplot as plt

# Test AP measured at each alpha value, one list per dataset.
alpha_values = [0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.3, 1.5, 1.7, 2]
lastfm = [0.8712, 0.933538, 0.931761, 0.932499, 0.933383, 0.931225, 0.929983, 0.933971, 0.928771, 0.8707]
wikitalk = [0.9710, 0.979627, 0.97997, 0.979484, 0.980269, 0.980758, 0.97979, 0.980233, 0.980004, 0.9734]
stackoverflow = [0.9630, 0.979641, 0.979372, 0.97967, 0.978169, 0.979624, 0.978846, 0.978428, 0.978397, 0.9749]

# One (title, series, colour) entry per subplot, drawn left to right.
_panels = [
    ('LASTFM', lastfm, 'blue'),
    ('WikiTalk', wikitalk, 'green'),
    ('StackOverflow', stackoverflow, 'red'),
]

fig, axs = plt.subplots(1, 3, figsize=(15, 4))
for ax, (title, series, colour) in zip(axs, _panels):
    ax.plot(alpha_values, series, marker='o', linestyle='-', color=colour, linewidth=3)
    ax.set_title(title, fontsize=20, fontweight='bold')
    ax.set_xlabel('α', fontsize=16)
    ax.set_ylabel('Test AP', fontsize=16)
    ax.grid(True)

plt.tight_layout()
plt.subplots_adjust(top=0.92)  # leave headroom above the subplot titles
plt.savefig('alpha.png')
theta,average_forward,average_backward,average_memory_updater,average_embedding,average_local_update,average_memory_sync,average_sample_and_build,average_memory_fetch
0,96856.49113285542,112221.48245840073,40301.897388637066,48822.715317726135,0.0,16000.196441346407,63911.53604484796,101853.45604764223
0.1,95662.69874529839,111533.12768054008,39077.55181486606,48873.47717645764,0.0,14643.937985780834,58304.243553686145,101811.6861515522
0.3,93883.74570572376,111777.802801466,37273.128538405894,48970.05911774039,0.0,12898.560065916181,56696.85652973652,110367.52254390717
0.5,93170.1633874178,112099.55837235451,36740.77984666824,48773.87382101417,0.0,12359.136115914584,58283.52728830576,103249.00511301756
0.7,93020.92460577488,112152.42083537579,36487.8266374588,48792.20610948205,0.0,11997.735760217905,62157.66416819095,102181.84207931758
0.9,93220.3081300497,111953.625908041,36531.27581562996,48914.576190519336,0.0,12060.063199546934,62090.06763002872,103451.30294804573
1.3,92909.71276578904,112195.95699987412,36376.36022908687,48855.23247879148,0.0,12023.359235164524,63666.89264053106,104705.44265072346
1.5,92906.04469819069,112168.77278401851,36401.54488204718,48837.5102509439,0.0,12066.843119865656,56549.59820256233,105085.46720664501
1.7,93191.48518505096,112471.45009524822,36524.07843518257,48860.69497559666,0.0,11936.615039774775,63745.34643083811,104101.63882282973
2,93234.853577137,112044.83792653083,36572.26520168781,48917.23433451653,0.0,12179.351857697964,57281.0512067914,106566.74535493851
theta,average_forward,average_backward,average_memory_updater,average_embedding,average_local_update,average_memory_sync,average_sample_and_build,average_memory_fetch
0.2,1020.0475714838504,837.6078978073597,609.9538529646396,269.77122262626887,0.0,315.69769369095565,430.82478720486165,832.9704289394617
0.4,1063.291701117754,834.1297216558456,642.3754260027408,274.332287170887,0.0,321.6530518513918,477.28159144580366,1129.4822247475386
0.6,1111.8212998402119,853.6988647305966,679.6548159575462,279.53119566351177,0.0,328.1626760321856,518.8979942470789,1429.0898495674132
0.8,1137.628830360174,870.1874030959606,705.8249095845223,279.0993657562137,0.0,343.0778785696626,513.5670813483,1723.660543204546
theta,average_forward,average_backward,average_memory_updater,average_embedding,average_local_update,average_memory_sync,average_sample_and_build,average_memory_fetch
theta,average_forward,average_backward,average_memory_updater,average_embedding,average_local_update,average_memory_sync,average_sample_and_build,average_memory_fetch
0,1038.5108485102653,762.7768550395965,610.9076774358749,273.6194903701544,0.0,311.0628311574459,490.81624430418015,662.4450206398964
0.1,996.7090173840522,768.2794192314147,587.606333398819,263.88492989838124,0.0,306.96758087277414,439.05469593405724,663.4521445333958
0.3,1016.776134121418,749.5895442962646,595.5274738430977,270.1349016159773,0.0,298.80202387571336,490.0341570496559,657.9857880115509
0.5,974.8833702325821,760.3781009674072,572.4515949070453,260.2825967848301,0.0,297.0785448908806,423.19416716098783,667.6366684556008
0.7,973.7973290085793,760.7215455651283,571.4436944365501,260.3819809645414,0.0,293.0186234384775,436.5759696722031,675.245310473442
0.9,972.2916827440262,760.531633090973,564.3444319665432,263.93835043311117,0.0,283.1126416653395,443.14170249700544,672.7872511267663
1.3,922.9985635995865,758.0635960578918,516.0758509039879,263.28917373418807,0.0,235.2485984534025,433.28196505308154,664.3425768435002
1.5,925.2486111044884,769.4904140233994,519.2285070300102,261.8939856499434,0.0,238.7658057242632,440.0552374184132,677.313303130865
1.7,914.4984895467758,768.0550072312355,511.9697087109089,261.20752018988134,0.0,237.16629126369952,421.08992391228674,679.3114461362362
2,926.4393308520317,762.9712475299835,517.6751752018929,264.23539278805254,0.0,239.46505763530732,440.6554836392403,666.9296593427658
import re


def read(file_name):
    """Parse a training log and return the average of each timing metric.

    Scans *file_name* for lines carrying ``time_count.time_*=<value>``
    fields and averages each of the eight metrics over all matches.

    Args:
        file_name: path to the ``.out`` log file to parse.

    Returns:
        dict mapping ``'average_<metric>'`` to the mean of that metric
        (0 when no line in the log matched).
    """
    # Order matters: it must follow the field order in the log line.
    # NOTE(review): local_update is captured as bare digits (\d+) — the log
    # presumably prints it as an integer; confirm against the log format.
    metrics = [
        "forward",
        "backward",
        "memory_updater",
        "embedding",
        "local_update",
        "memory_sync",
        "sample_and_build",
        "memory_fetch",
    ]
    pattern = re.compile(
        r"time_count\.time_forward=(\d+\.\d+)"
        r".*?time_count\.time_backward=(\d+\.\d+)"
        r".*?time_count\.time_memory_updater=(\d+\.\d+)"
        r".*?time_count\.time_embedding=(\d+\.\d+)"
        r".*?time_count\.time_local_update=(\d+)"
        r".*?time_count\.time_memory_sync=(\d+\.\d+)"
        r".*?time_count\.time_sample_and_build=(\d+\.\d+)"
        r".*?time_count\.time_memory_fetch=(\d+\.\d+)"
    )
    with open(file_name, 'r') as file:
        matches = pattern.findall(file.read())
    averages = {}
    for i, name in enumerate(metrics):
        values = [float(match[i]) for match in matches]
        # An empty log yields 0 rather than a ZeroDivisionError.
        averages[f"average_{name}"] = sum(values) / len(values) if values else 0
    return averages
import csv


def write_to_csv(data, theta_list, output_file):
    """Average the timing logs for each theta and write them to a CSV.

    For every theta in *theta_list*, locates the experiment's ``.out`` log
    for dataset *data*, averages its timing metrics with ``read``, and
    appends one row (theta plus the eight averages) to *output_file*.
    """
    # LASTFM runs the small TGN model with 10 neighbors; all other
    # datasets use TGN_large with 20 neighbors.
    small = data == "LASTFM"
    model = "TGN" if small else "TGN_large"
    neighbor = "10" if small else "20"
    # Column order of the output CSV: theta first, then the averages
    # in the same order `read` reports them.
    fieldnames = ['theta'] + [
        f'average_{m}'
        for m in (
            'forward', 'backward', 'memory_updater', 'embedding',
            'local_update', 'memory_sync', 'sample_and_build', 'memory_fetch',
        )
    ]
    with open(output_file, 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        for theta in theta_list:
            print("theta:", theta)
            file_name = f"MemShare_v1/all_12357/{data}/{model}/8-ours-0.1-historical-0.3-boundery_recent_decay-{theta}-{neighbor}.out"
            row = read(file_name)
            row['theta'] = theta  # tag the row with its theta value
            writer.writerow(row)
# Theta values to sweep; each must match an existing experiment log name.
theta_list = ["0.2", "0.4", "0.6", "0.8"]
# Dataset whose logs are averaged into the output CSV.
data = "WikiTalk"
output_file = f"{data}_averages.csv"
# Aggregate the logs and write the per-theta averages.
write_to_csv(data, theta_list, output_file)
...@@ -10,23 +10,23 @@ partitions="8" ...@@ -10,23 +10,23 @@ partitions="8"
node_per="4" node_per="4"
nnodes="2" nnodes="2"
node_rank="0" node_rank="0"
probability_params=("0" "0.0001" "0.001" "0.01" "0.1" "1") probability_params=("0.1")
sample_type_params=("boundery_recent_decay" "recent") sample_type_params=("boundery_recent_decay")
#"boundery_recent_decay") #"boundery_recent_decay")
#sample_type_params=("recent" "boundery_recent_decay") #"boundery_recent_uniform") #sample_type_params=("recent" "boundery_recent_decay") #"boundery_recent_uniform")
#memory_type=("all_update" "p2p" "all_reduce" "historical" "local") #memory_type=("all_update" "p2p" "all_reduce" "historical" "local")
memory_type=("historical") memory_type=("historical")
#"historical") #"historical")
#memory_type=("local" "all_update" "historical" "all_reduce") #memory_type=("local" "all_update" "historical" "all_reduce")
shared_memory_ssim=("0" "0.1" "0.3" "0.5" "0.7" "0.9" "1.3" "1.5" "1.7" "2") shared_memory_ssim=("0.3")
#("0" "0.1" "0.3" "0.5" "0.7" "0.9" "1.3" "1.5" "1.7" "2")
#"historical") #"historical")
#memory_type=("local" "all_update" "historical" "all_reduce") #memory_type=("local" "all_update" "historical" "all_reduce")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk") #data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk")
neighbor_num=( "10" "20") neighbor_num=( "10" "20")
neighbor="10" neighbor="10"
topk_list=("0.1") topk_list=("0.02" "0.04" "0.06" "0.08" "0.1" "0.2" "0.3")
#("0.02" "0.04" "0.06" "0.08" "0.1" "0.2" "0.3") data_param=("GDELT")
data_param=("LASTFM" "WikiTalk" "StackOverflow" "GDELT")
#"GDELT") #"GDELT")
#data_param=("WIKI" "REDDIT" "LASTFM" "DGraphFin" "WikiTalk" "StackOverflow") #data_param=("WIKI" "REDDIT" "LASTFM" "DGraphFin" "WikiTalk" "StackOverflow")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk" "StackOverflow") #data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk" "StackOverflow")
...@@ -41,10 +41,10 @@ mkdir -p all_"$seed" ...@@ -41,10 +41,10 @@ mkdir -p all_"$seed"
for data in "${data_param[@]}"; do for data in "${data_param[@]}"; do
for topk in "${topk_list[@]}"; do for topk in "${topk_list[@]}"; do
model="TGN_large" model="TGN_large"
probability_params=("0" "0.0001" "0.001" "0.01" "0.1" "1") # probability_params=("0" "0.0001" "0.001" "0.01" "0.1" "1")
if [ "$data" = "StackOverflow" ]; then # if [ "$data" = "StackOverflow" ]; then
probability_params=("0.0001" "0.001" ) # probability_params=("0.0001" "0.001" )
fi # fi
neighbor="20" neighbor="20"
if [ "$data" = "WIKI" ] || [ "$data" = "REDDIT" ] || [ "$data" = "LASTFM" ]; then if [ "$data" = "WIKI" ] || [ "$data" = "REDDIT" ] || [ "$data" = "LASTFM" ]; then
model="TGN" model="TGN"
...@@ -83,10 +83,10 @@ for data in "${data_param[@]}"; do ...@@ -83,10 +83,10 @@ for data in "${data_param[@]}"; do
done done
else else
for pro in "${probability_params[@]}"; do for pro in "${probability_params[@]}"; do
shared_memory_ssim=("0.3") # shared_memory_ssim=("0.3")
if [ "$pro" = "0.1" ]; then # if [ "$pro" = "0.1" ]; then
shared_memory_ssim=("0" "0.1" "0.3" "0.5" "0.7" "0.9" "1.3" "1.5" "1.7" "2") # shared_memory_ssim=("0" "0.1" "0.3" "0.5" "0.7" "0.9" "1.3" "1.5" "1.7" "2")
fi # fi
for mem in "${memory_type[@]}"; do for mem in "${memory_type[@]}"; do
if [ "$mem" = "historical" ]; then if [ "$mem" = "historical" ]; then
for ssim in "${shared_memory_ssim[@]}"; do for ssim in "${shared_memory_ssim[@]}"; do
......
...@@ -205,8 +205,8 @@ def main(): ...@@ -205,8 +205,8 @@ def main():
graph,full_sampler_graph,train_mask,val_mask,test_mask,full_train_mask,cache_route = load_from_speed(args.dataname,seed=123457,top=args.topk,sampler_graph_add_rev=True, feature_device=torch.device('cuda:{}'.format(ctx.local_rank)),partition=args.partition)#torch.device('cpu')) graph,full_sampler_graph,train_mask,val_mask,test_mask,full_train_mask,cache_route = load_from_speed(args.dataname,seed=123457,top=args.topk,sampler_graph_add_rev=True, feature_device=torch.device('cuda:{}'.format(ctx.local_rank)),partition=args.partition)#torch.device('cpu'))
if(args.dataname=='GDELT'): if(args.dataname=='GDELT'):
train_param['epoch'] = 10 train_param['epoch'] = 10
if(args.probability > 0.005): #if(args.probability > 0.005):
train_param['epoch'] = 10 #train_param['epoch'] = 1
#torch.autograd.set_detect_anomaly(True) #torch.autograd.set_detect_anomaly(True)
# 确保 CUDA 可用 # 确保 CUDA 可用
if torch.cuda.is_available(): if torch.cuda.is_available():
......
...@@ -406,6 +406,11 @@ class AsyncMemeoryUpdater(torch.nn.Module): ...@@ -406,6 +406,11 @@ class AsyncMemeoryUpdater(torch.nn.Module):
wait_submit=submit_to_queue,spread_mail=spread_mail, wait_submit=submit_to_queue,spread_mail=spread_mail,
update_cross_mm=False, update_cross_mm=False,
) )
# self.mailbox.update_shared()
# self.mailbox.update_p2p_mem()
# self.mailbox.update_p2p_mail()
# self.mailbox.sychronize_shared()
# self.mailbox.handle_last_async()
if nxt_fetch_func is not None: if nxt_fetch_func is not None:
nxt_fetch_func() nxt_fetch_func()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment