mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
Remove unnecessary reshape
This commit is contained in:
parent
bfd7e786c3
commit
e1c12343b6
@ -118,7 +118,7 @@ class Bilinear(Module):
|
||||
)
|
||||
|
||||
def __call__(self, x1: mx.array, x2: mx.array) -> mx.array:
|
||||
y = (x1 @ self.weight * x2.reshape(1, *x2.shape)).sum(-1).T
|
||||
y = (x1 @ self.weight * x2).sum(-1).T
|
||||
if "bias" in self:
|
||||
y = y + self.bias
|
||||
return y
|
||||
|
Loading…
Reference in New Issue
Block a user