Commit 342421ab by zlj

fix feature device

parent e35ddb0f
...@@ -4,19 +4,19 @@ ...@@ -4,19 +4,19 @@
addr="192.168.1.107" addr="192.168.1.107"
partition_params=("ours" "metis" "ldg" "random") partition_params=("ours" "metis" "ldg" "random")
#("ours" "metis" "ldg" "random") #("ours" "metis" "ldg" "random")
partitions="4" partitions="16"
nnodes="1" nnodes="4"
node_rank="0" node_rank="0"
probability_params=("1" "0.5" "0.1" "0.05" "0.01" "0") probability_params=("1" "0.5" "0.1" "0.05" "0.01" "0")
#sample_type_params=("recent" "boundery_recent_decay" "boundery_recent_uniform") sample_type_params=("recent" "boundery_recent_decay" "boundery_recent_uniform")
sample_type_params=("recent") #sample_type_params=("recent")
#memory_type=("all_update" "p2p" "all_reduce" "historical" "local") #memory_type=("all_update" "p2p" "all_reduce" "historical" "local")
#memory_type=("all_update" "local" "historical") #memory_type=("all_update" "local" "historical")
memory_type=("local" "all_update") memory_type=("local" "all_update" "historical" "all_reduce")
shared_memory_ssim=("0" "0.1" "0.2" "0.3" "0.4" ) shared_memory_ssim=("0" "0.1" "0.2" "0.3" "0.4" )
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk") #data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk")
#data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk" "StackOverflow") data_param=("WIKI" "REDDIT" "LASTFM" "WikiTalk" "StackOverflow")
data_param=("REDDIT" "WikiTalk") #data_param=("REDDIT" "WikiTalk")
# 创建输出目录 # 创建输出目录
mkdir -p all mkdir -p all
......
...@@ -185,7 +185,7 @@ def main(): ...@@ -185,7 +185,7 @@ def main():
ctx = DistributedContext.init(backend="nccl", use_gpu=True,memory_group_num=1,cache_use_rpc=True) ctx = DistributedContext.init(backend="nccl", use_gpu=True,memory_group_num=1,cache_use_rpc=True)
torch.set_num_threads(10) torch.set_num_threads(10)
device_id = torch.cuda.current_device() device_id = torch.cuda.current_device()
graph,full_sampler_graph,train_mask,val_mask,test_mask,full_train_mask,cache_route = load_from_speed(args.dataname,seed=123457,top=args.topk,sampler_graph_add_rev=True, feature_device=torch.device('cpu'),partition=args.partition)#torch.device('cpu')) graph,full_sampler_graph,train_mask,val_mask,test_mask,full_train_mask,cache_route = load_from_speed(args.dataname,seed=123457,top=args.topk,sampler_graph_add_rev=True, feature_device=torch.device('cuda:{}'.format(ctx.local_rank)),partition=args.partition)#torch.device('cpu'))
# 确保 CUDA 可用 # 确保 CUDA 可用
if torch.cuda.is_available(): if torch.cuda.is_available():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment