mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
fix a couple bugs (#952)
This commit is contained in:
@@ -186,9 +186,11 @@ class GRU(Module):
|
||||
n = n + r * h_proj_n
|
||||
n = mx.tanh(n)
|
||||
|
||||
hidden = (1 - z) * n
|
||||
if hidden is not None:
|
||||
hidden = hidden + z * hidden
|
||||
hidden = (1 - z) * n + z * hidden
|
||||
else:
|
||||
hidden = (1 - z) * n
|
||||
|
||||
all_hidden.append(hidden)
|
||||
|
||||
return mx.stack(all_hidden, axis=-2)
|
||||
|
Reference in New Issue
Block a user