fix a couple bugs (#952)

This commit is contained in:
Awni Hannun
2024-04-02 12:07:41 -07:00
committed by GitHub
parent 1a87dc5ea8
commit 741eb28443
3 changed files with 9 additions and 5 deletions

View File

@@ -186,9 +186,11 @@ class GRU(Module):
n = n + r * h_proj_n
n = mx.tanh(n)
hidden = (1 - z) * n
if hidden is not None:
hidden = hidden + z * hidden
hidden = (1 - z) * n + z * hidden
else:
hidden = (1 - z) * n
all_hidden.append(hidden)
return mx.stack(all_hidden, axis=-2)