Commit 595d8907 by Wenjie Huang

add negative sampling demo for DTDG

parent 6e65489d
......@@ -49,6 +49,7 @@ def prepare_data(root: str, num_parts):
g.node()["y"] = y
g.edge()["time"] = edge_times
g.edge()["attr"] = edge_attr
g.meta()["num_nodes"] = x.size(0) # 全局的节点数量
g.meta()["num_snapshots"] = snapshot_count # 快照数量
logging.info(f"GraphData.meta().keys(): {g.meta().keys()}")
......@@ -289,8 +290,6 @@ def get_graph(
sg = GraphData.from_bipartite(
edge_index=edge_index,
num_src_nodes=g.node("src").num_nodes,
num_dst_nodes=g.node("dst").num_nodes,
raw_src_ids=g.node("src")["raw_ids"],
raw_dst_ids=g.node("dst")["raw_ids"],
)
......@@ -299,8 +298,31 @@ def get_graph(
sg.edge()["time"] = time[mask] - start # 快照偏移从0开始
sg.edge()["attr"] = edge_attr
sg.meta()["num_snapshots"] = end - start
sg.meta()["num_nodes"] = g.meta()["num_nodes"]
return sg
def get_negative_route(g: GraphData, num_edges: int, group: Any):
    """Sample random negative edges and build the route for them.

    Source endpoints are drawn uniformly from ``[0, num_nodes)``, where
    ``num_nodes`` is read from ``g.meta()["num_nodes"]``.  Destination
    endpoints are drawn uniformly, with replacement, from the entries of
    ``g.node("dst")["raw_ids"]``, so the set of destination nodes is kept
    as it is in ``g``.

    Args:
        g: Graph providing the ``"num_nodes"`` meta entry and the
            ``"raw_ids"`` tensors of its ``"src"`` and ``"dst"`` nodes.
        num_edges: Number of negative edges to sample.
        group: Forwarded unchanged to ``to_route``.

    Returns:
        A ``(route, edge_index)`` pair: the object returned by
        ``to_route`` and the stacked ``[src, dst]`` tensor of the
        sampled edges.
    """
    # Per the original author's note, this is the global node count.
    total_nodes = g.meta()["num_nodes"]
    src_reference = g.node("src")["raw_ids"]  # only used as dtype/device template
    dst_pool = g.node("dst")["raw_ids"]

    sample_shape = (num_edges,)

    # Random source endpoints, drawn over the whole node-id range.
    neg_src = torch.randint(total_nodes, size=sample_shape)
    neg_src = neg_src.type_as(src_reference)

    # Random positions into the destination pool, then mapped to the ids
    # stored there (global ids, per the original author's note).
    picked = torch.randint(dst_pool.numel(), size=sample_shape)
    picked = picked.type_as(dst_pool)
    neg_dst = dst_pool[picked]

    # Row 0 holds sources, row 1 holds destinations.
    edge_index = torch.vstack([neg_src, neg_dst])

    # Sampled sources may repeat, so deduplicate them; the destination pool
    # is passed through as-is (original note: it is already duplicate-free).
    unique_src = neg_src.unique()

    bipartite = GraphData.from_bipartite(
        edge_index=edge_index,
        raw_src_ids=unique_src,
        raw_dst_ids=dst_pool,
    )
    route = bipartite.to_route(group)
    return route, edge_index
if __name__ == "__main__":
data_root = "./dataset"
......@@ -407,3 +429,9 @@ if __name__ == "__main__":
opt.step()
ctx.sync_print(f"loss: {loss.item():.6f}")
if True:
"""案例4:随机负采样边及其route。保持dst节点不变,随机选择src节点
"""
route, edge_index = get_negative_route(g, num_edges=1000, group=pp_group)
ctx.sync_print(route, edge_index.size())
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment