Commit 9ca0dca0 by Wenjie Huang

update emma

parent 17da3ed5
@@ -79,16 +79,20 @@ class EmmaAttention(nn.Module):
        ptr, ind, val = adj_t.csr()
        a = src_a[ind] + gather_csr(dst_a, ptr)
        if val is not None:
            assert val.dim() == src_a.dim()
            a = a + val
        a = F.leaky_relu(a, negative_slope=negative_slope)
        with torch.no_grad():
            max_a = torch.full_like(dst_a, -torch.inf)
            max_a = segment_csr(a, ptr, reduce="max", out=max_a)
        exp_a = torch.exp(a - gather_csr(max_a, ptr))
        if val is not None:
            assert val.dim() == 1
            if exp_a.dim() == 1:
                exp_a = exp_a * val
            else:
                exp_a = exp_a * val.unsqueeze(-1)
        sum_exp_a = segment_csr(exp_a, ptr, reduce="sum")
        exp_a = exp_a / gather_csr(sum_exp_a, ptr)
        with torch.no_grad():
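
The hunk above computes a numerically stable softmax over the CSR segments of adj_t (one segment per destination node), optionally rescaling by the edge values. As a standalone reference, here is a minimal sketch of that segment-softmax pattern; it assumes torch_scatter is installed, and the helper name segment_softmax is not from the repository:

import torch
from torch_scatter import segment_csr, gather_csr

def segment_softmax(a: torch.Tensor, ptr: torch.Tensor) -> torch.Tensor:
    # a: per-edge attention logits; ptr: CSR pointer delimiting each node's edges
    max_a = segment_csr(a, ptr, reduce="max")          # per-segment maximum
    exp_a = torch.exp(a - gather_csr(max_a, ptr))      # shift by the max for stability
    sum_exp_a = segment_csr(exp_a, ptr, reduce="sum")  # per-segment normaliser
    return exp_a / gather_csr(sum_exp_a, ptr)

# e.g. two segments covering edges [0:3) and [3:5); each segment sums to 1
a = torch.tensor([0.5, 1.0, -0.2, 2.0, 0.1])
ptr = torch.tensor([0, 3, 5])
print(segment_softmax(a, ptr))
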
@@ -104,7 +108,7 @@ class EmmaAttention(nn.Module):
        dst_a: Tensor,
        adj_t: SparseTensor,
        negative_slope: float = 0.01,
    ) -> Tuple[Tensor, Tensor, Tensor]:
    ) -> Tuple[Tensor, Tensor]:
        adj_t, max_a = EmmaAttention.softmax_gat(
            src_a=src_a, dst_a=dst_a,
            adj_t=adj_t, negative_slope=negative_slope,
@@ -126,10 +130,7 @@ class EmmaAttention(nn.Module):
            )
        x = torch.cat(xs, dim=1).view(-1, *x.shape[1:])
        with torch.no_grad():
            agg_n = torch.ones_like(ind)
            agg_n = segment_csr(agg_n, ptr, reduce="sum")
        return x, max_a, agg_n
        return x, max_a
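
This hunk drops the neighbour-count tensor agg_n from the return value, matching the narrowed Tuple[Tensor, Tensor] annotation above. A caller that still needs those counts can recompute them from the CSR pointers; a minimal sketch of the same computation as the removed lines, with the hypothetical helper name neighbour_counts:

import torch
from torch_sparse import SparseTensor
from torch_scatter import segment_csr

def neighbour_counts(adj_t: SparseTensor) -> torch.Tensor:
    # one count per destination node, i.e. per CSR segment of adj_t
    ptr, ind, _ = adj_t.csr()
    return segment_csr(torch.ones_like(ind), ptr, reduce="sum")
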
class EmmaAttentionFunction(autograd.Function):
    @staticmethod
......