Commit 83125594 by Wenjie Huang

add example for layerpipe
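In outline, the new example builds a `SageGNN` layer pipeline over graph snapshots and feeds the stacked outputs into a `SimpleRNN` sequence pipeline. Condensed from the training loop in the diff below (all names are defined there), each epoch looks roughly like:

```python
for ep in range(1, 100 + 1):
    opt.zero_grad()
    xs = gnn.apply(num_layers, num_snapshots)            # pipelined forward, one output tuple per snapshot
    x = torch.cat([x.unsqueeze(1) for x, in xs], dim=1)  # stack snapshots: (N, S, H)
    loss = rnn.fast_backward(32, (x,), labels)           # fused forward+backward through the sequence pipe
    rnn.all_reduce()                                     # gradient all-reduce for the RNN
    gnn.backward()                                       # pipelined backward through the layer pipe
    gnn.all_reduce()                                     # gradient all-reduce for the GNN
    opt.step()
```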

parent e6c06ee9
@@ -128,7 +128,7 @@ class LayerPipeRuntime:
         if op == "sync":
             xs = self.ready_bw[(layer_i - 1, snap_i, 1)].values() if layer_i > 0 else None
             with self.program._switch_layer(layer_i, snap_i):
-                xs = self.program.layer_inputs(None)
+                xs = self.program.layer_inputs(xs)
                 route = self.program.get_route()
                 self.ready_bw[(layer_i, snap_i, 0)] = LayerRoute(route, *xs)
         elif op == "comp":
@@ -15,6 +15,7 @@ from starrygl.parallel.utils import *
 import torch_geometric.nn as pyg_nn
 import torch_geometric.datasets as pyg_datasets
 import torch_geometric.utils as pyg_utils
+from torch_scatter import scatter_mean
 import logging
 logging.getLogger().setLevel(logging.INFO)
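For reference, `scatter_mean` mean-aggregates rows of a source tensor by a destination index, which is exactly how the new `SageConv` below pools neighbor messages. A minimal toy check (graph and feature sizes are made up here):

```python
import torch
from torch_scatter import scatter_mean

# edges: 0->1, 2->1, 1->0 (edge_index[0] = sources, edge_index[1] = destinations)
edge_index = torch.tensor([[0, 2, 1],
                           [1, 1, 0]])
x = torch.tensor([[1.0], [2.0], [3.0]])
msg = x[edge_index[0]]  # gather one message per edge from its source node
out = scatter_mean(msg, edge_index[1], dim=0, dim_size=3)
print(out)  # node 0 gets x[1] = 3.0, node 1 gets mean(x[0], x[2]) = 1.5, node 2 gets 0.0
```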
@@ -41,71 +42,69 @@ def prepare_data(root: str, num_parts, part_algo: str = "metis"):
     g.save_partition(root, num_parts, algorithm=part_algo)
     return g
 
-class SimpleConv(pyg_nn.MessagePassing):
+class SageConv(nn.Module):
     def __init__(self, in_feats: int, out_feats: int):
-        super().__init__(aggr="mean")
-        self.linear = nn.Linear(in_feats, out_feats)
-
-    def forward(self, x: Tensor, edge_index: Tensor, route: Route):
-        dst_len = x.size(0)
-        x = route.apply(x) # exchange features
-        return self.propagate(edge_index, x=x)[:dst_len]
-
-    def message(self, x_j: Tensor):
-        return x_j
-
-    def update(self, x: Tensor):
-        return F.relu(self.linear(x))
+        super().__init__()
+        self.weight = nn.Parameter(torch.empty(out_feats, in_feats))
+        self.bias = nn.Parameter(torch.empty(out_feats))
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.xavier_normal_(self.weight)
+        nn.init.zeros_(self.bias)
+
+    def forward(self, x: Tensor, edge_index: Tensor, num_nodes: int):
+        x = F.linear(x, self.weight)
+        x = x[edge_index[0]]
+        x = scatter_mean(x, edge_index[1], dim=0, dim_size=num_nodes)
+        return x + self.bias
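Note the design shift: the old `SimpleConv` called `route.apply(x)` inside its `forward`, while the new `SageConv` is purely local; the cross-partition feature exchange now happens once per layer in `SageGNN.layer_inputs` below. A quick shape check of the new conv, with made-up sizes and assuming the class as defined above:

```python
import torch

conv = SageConv(in_feats=8, out_feats=16)
x = torch.randn(5, 8)                      # 5 nodes, 8 input features
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 4]])  # a small chain graph
out = conv(x, edge_index, num_nodes=5)
print(out.shape)                           # torch.Size([5, 16])
```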
-class SimpleGNN(nn.Module):
+class SageGNN(LayerPipe):
     def __init__(self,
-        num_features: int,
-        hidden_dims: int,
+        graph: GraphData,
+        hidden_dim: int,
         num_layers: int,
+        num_snapshots: int,
+        group: Any,
     ) -> None:
         super().__init__()
-        self.layers = nn.ModuleList()
-        for i in range(num_layers):
-            in_ch = hidden_dims if i > 0 else num_features
-            out_ch = hidden_dims
-            self.layers.append(SimpleConv(in_ch, out_ch))
-
-    def forward(self, x: Tensor, edge_index: Tensor, route: Route):
-        for layer in self.layers:
-            x = layer(x, edge_index, route)
-        return x
-
-class SimpleGNNPipe(LayerPipe):
-    def __init__(self,
-        num_features: int,
-        hidden_dims: int,
-        num_layers: int,
-        features: Tensor,
-        edge_index: Tensor,
-        route: Route,
-    ) -> None:
-        super().__init__()
-        self.features = features
-        self.edge_index = edge_index
-        self.route = route
-        self.net = SimpleGNN(num_features, hidden_dims, num_layers)
+        self.graph = graph
+        self.route = graph.to_route(group)
+        self.group = group
+        self.num_layers = num_layers
+        self.num_snapshots = num_snapshots
+
+        num_features = self.graph.node("dst")["x"].size(1)
+        self.hidden_dim = hidden_dim
+        self.layers = nn.ModuleList()
+        for i in range(num_layers):
+            out_ch = in_ch = hidden_dim
+            if i == 0:
+                in_ch = num_features
+            self.layers.append(SageConv(in_ch, out_ch))
 
     def get_route(self) -> Route:
         return self.route
 
     def layer_inputs(self, inputs: Sequence[Tensor] | None = None) -> Sequence[Tensor]:
         if self.layer_id == 0:
-            x = self.features
+            x = self.graph.node("dst")["x"]
         else:
             x, = inputs
         x = self.route.apply(x)
         self.register_route(x)
         return (x,)
 
     def layer_forward(self, inputs: Sequence[Tensor]) -> Sequence[Tensor]:
         x, = inputs
-        x = self.net.layers[self.layer_id](x, self.edge_index, self.route)
+        edge_index = self.graph.edge_index()
+        x = self.layers[self.layer_id](x, edge_index, self.route.dst_len)
         return (x,)
 
-class SimpleRNN(SequencePipe, nn.Module):
+class SimpleRNN(SequencePipe):
     def __init__(self,
         num_classes: int,
         hidden_dims: int,
@@ -135,11 +134,19 @@ class SimpleRNN(SequencePipe, nn.Module):
         h = h.transpose(0, 1).contiguous() # (L, N, H)
         x, h = self.gru(x, h) # (N, L, H), (L, N, H)
         h = h.transpose(0, 1).contiguous() # (N, L, H)
         x = self.out(x)
         return (x,), (h,)
 
     def loss_fn(self, inputs, labels) -> Tensor:
         x, = inputs
-        return x.square().mean()
+        mask, y = labels
+        x = x[mask, -1]
+        if x.numel() > 0:
+            y = y[mask]
+            return F.cross_entropy(x, y)
+        else:
+            return x.mul(0.0).sum()
+
+    def get_group(self) -> Any:
+        return self.group
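The rewritten `loss_fn` takes the prediction at the last snapshot for masked training nodes; the `else` branch matters because a partition may hold no training nodes, and `x.mul(0.0).sum()` returns a zero loss that still participates in autograd, so every pipeline stage produces gradients. A small sketch of that empty-mask case (toy tensors):

```python
import torch

x = torch.randn(4, 3, 2, requires_grad=True)  # (N, S, C): 4 nodes, 3 snapshots, 2 classes
mask = torch.zeros(4, dtype=torch.bool)       # no training nodes on this partition
sel = x[mask, -1]                             # empty (0, 2) tensor
loss = sel.mul(0.0).sum()                     # zero loss, but still connected to the graph
loss.backward()                               # backward succeeds; grads are all zeros
print(x.grad.abs().sum())                     # tensor(0.)
```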
@@ -170,17 +177,26 @@ if __name__ == "__main__":
         data_root,
         dist.get_rank(pp_group), dist.get_world_size(pp_group),
     ).to(ctx.device)
 
-    route = g.to_route(pp_group) # only on subgroup
-    num_features = g.node("dst")["x"].size(-1)
-    num_classes = g.meta()["num_classes"]
-    hidden_dims = 128
+    hidden_dim = 128
     num_layers = 3
+    num_snapshots = 200
+    num_classes = g.meta()["num_classes"]
 
-    gnn = SimpleGNN(num_features, hidden_dims, num_layers).to(ctx.device)
-    rnn = SimpleRNN(num_classes, hidden_dims, num_layers, device=ctx.device, group=sp_group).to(ctx.device)
+    gnn = SageGNN(g, hidden_dim, num_layers, num_snapshots, group=pp_group).to(ctx.device)
+    rnn = SimpleRNN(num_classes, hidden_dim, num_layers, device=ctx.device, group=sp_group).to(ctx.device)
 
-    opt = torch.optim.Adam([p for p in gnn.parameters()] + [p for p in rnn.parameters()])
+    params = []
+    for _, net in gnn.get_model():
+        params.extend(net.parameters())
+    for _, net in rnn.get_model():
+        params.extend(net.parameters())
+    opt = torch.optim.Adam(params)
+
+    labels = (
+        g.node("dst")["train_mask"],
+        g.node("dst")["y"],
+    )
 
     for ep in range(1, 100+1):
-        seq_len = 200
@@ -188,26 +204,17 @@ if __name__ == "__main__":
         opt.zero_grad()
 
-        for _ in range(seq_len): # snapshot parallel between partition parallel subgroups
-            z = gnn(
-                x = g.node("dst")["x"],
-                edge_index = g.edge_index(),
-                route = route,
-            )
-            xs.append(z.unsqueeze(1))
-        x = torch.cat(xs, dim=1) # (N, S, H)
+        xs = gnn.apply(num_layers, num_snapshots)
+        x = torch.cat([x.unsqueeze(1) for x, in xs], dim=1) # (N, S, H)
 
         # loss = rnn.apply(32, x)[0].square().mean()
         # loss.backward() # sequence and pipeline parallel on each graph nodes
-        loss = rnn.fast_backward(32, (x,), (g.node("dst")["train_mask"],))
+        loss = rnn.fast_backward(32, (x,), labels)
+        rnn.all_reduce()
 
-        # all reduce
-        all_reduce_gradients(rnn)
-        all_reduce_buffers(rnn)
-        all_reduce_gradients(gnn)
-        all_reduce_buffers(gnn)
+        gnn.backward()
+        gnn.all_reduce()
 
         opt.step()
         ctx.sync_print(loss)
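`gnn.apply` yields one output tuple per snapshot, so the comprehension unpacks each single-element tuple (`for x, in xs`) and stacks along a new snapshot axis before the RNN consumes it. A toy shape check with hypothetical sizes N=4, S=3, H=2:

```python
import torch

xs = [(torch.randn(4, 2),) for _ in range(3)]        # stand-in for gnn.apply(...): one (N, H) tuple per snapshot
x = torch.cat([x.unsqueeze(1) for x, in xs], dim=1)  # insert snapshot dim, then concatenate
print(x.shape)                                       # torch.Size([4, 3, 2]) == (N, S, H)
```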