Commit 38eb626e by xxx

Minor changes to samplers

parent 781f626e
......@@ -43,7 +43,7 @@ parser.add_argument('--rank', default=0, type=int, metavar='W',
parser.add_argument('--patience', type=int, default=5, help='Patience for early stopping')
parser.add_argument('--world_size', default=1, type=int, metavar='W',
help='number of negative samples')
parser.add_argument('--dataname', default=1, type=str, metavar='W',
parser.add_argument('--dataname', default="WIKI", type=str, metavar='W',
help='name of dataset')
parser.add_argument('--model', default='TGN', type=str, metavar='W',
help='name of model')
......@@ -81,7 +81,7 @@ def main():
ctx = DistributedContext.init(backend="nccl", use_gpu=True)
device_id = torch.cuda.current_device()
print('use cuda on',device_id)
pdata = partition_load("/mnt/data/part_data/dataset/here/{}".format(args.dataname), algo="metis_for_tgnn")
pdata = partition_load("/mnt/data/part_data/here/{}".format(args.dataname), algo="metis_for_tgnn")
graph = DistributedGraphStore(pdata = pdata)
Path("./saved_models/").mkdir(parents=True, exist_ok=True)
......@@ -179,7 +179,7 @@ def main():
total_loss = 0
signal = torch.tensor([0],dtype = int,device = device)
with torch.cuda.stream(train_stream):
for roots,mfgs,metadata,sample_time in loader:
for roots,mfgs,metadata in loader:
pred_pos, pred_neg = model(mfgs,metadata)
total_loss += creterion(pred_pos, torch.ones_like(pred_pos))
......@@ -252,8 +252,8 @@ def main():
model.module.memory_updater.last_updated_nid = None
model.module.memory_updater.last_updated_memory = None
model.module.memory_updater.last_updated_ts = None
for roots,mfgs,metadata,sample_time in trainloader:
fetch_time +=sample_time/1000
for roots,mfgs,metadata in trainloader:
# fetch_time +=sample_time/1000
t_prep_s = time.time()
with torch.cuda.stream(train_stream):
......@@ -327,7 +327,7 @@ def main():
else:
print('\ttrain loss:{:.4f} train ap:{:4f} val ap:{:4f} val auc:{:4f}'.format(total_loss,train_ap, ap, auc))
print('\ttotal time:{:.2f}s prep time:{:.2f}s'.format(time.time()-epoch_start_time, time_prep))
print('\t fetch time:{:.2f}s write back time:{:.2f}s'.format(fetch_time,write_back_time))
# print('\t fetch time:{:.2f}s write back time:{:.2f}s'.format(fetch_time,write_back_time))
torch.save(model.state_dict(), get_checkpoint_path(e))
model.eval()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment