From c6739ba7f39368562bc84682683c277aa0df79f3 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Tue, 17 Sep 2024 06:04:19 -0700
Subject: [PATCH] Faster RNN layers (#1419)

* faster rnn

* use addmm
---
 python/mlx/nn/layers/recurrent.py | 34 +++++++++++++++----------------
 1 file changed, 17 insertions(+), 17 deletions(-)

diff --git a/python/mlx/nn/layers/recurrent.py b/python/mlx/nn/layers/recurrent.py
index 2bb32c711..3ffa7654c 100644
--- a/python/mlx/nn/layers/recurrent.py
+++ b/python/mlx/nn/layers/recurrent.py
@@ -54,7 +54,7 @@ class RNN(Module):
         scale = 1.0 / math.sqrt(hidden_size)
         self.hidden_size = hidden_size
         self.Wxh = mx.random.uniform(
-            low=-scale, high=scale, shape=(input_size, hidden_size)
+            low=-scale, high=scale, shape=(hidden_size, input_size)
         )
         self.Whh = mx.random.uniform(
             low=-scale, high=scale, shape=(hidden_size, hidden_size)
@@ -67,21 +67,21 @@ class RNN(Module):
 
     def _extra_repr(self):
         return (
-            f"input_dims={self.Wxh.shape[0]}, "
+            f"input_dims={self.Wxh.shape[1]}, "
             f"hidden_size={self.hidden_size}, "
             f"nonlinearity={self.nonlinearity}, bias={self.bias is not None}"
         )
 
     def __call__(self, x, hidden=None):
         if self.bias is not None:
-            x = mx.addmm(self.bias, x, self.Wxh)
+            x = mx.addmm(self.bias, x, self.Wxh.T)
         else:
-            x = x @ self.Wxh
+            x = x @ self.Wxh.T
 
         all_hidden = []
         for idx in range(x.shape[-2]):
             if hidden is not None:
-                hidden = x[..., idx, :] + hidden @ self.Whh
+                hidden = mx.addmm(x[..., idx, :], hidden, self.Whh.T)
             else:
                 hidden = x[..., idx, :]
             hidden = self.nonlinearity(hidden)
@@ -131,10 +131,10 @@ class GRU(Module):
         self.hidden_size = hidden_size
         scale = 1.0 / math.sqrt(hidden_size)
         self.Wx = mx.random.uniform(
-            low=-scale, high=scale, shape=(input_size, 3 * hidden_size)
+            low=-scale, high=scale, shape=(3 * hidden_size, input_size)
         )
         self.Wh = mx.random.uniform(
-            low=-scale, high=scale, shape=(hidden_size, 3 * hidden_size)
+            low=-scale, high=scale, shape=(3 * hidden_size, hidden_size)
         )
         self.b = (
             mx.random.uniform(low=-scale, high=scale, shape=(3 * hidden_size,))
@@ -149,15 +149,15 @@ class GRU(Module):
 
     def _extra_repr(self):
         return (
-            f"input_dims={self.Wx.shape[0]}, "
+            f"input_dims={self.Wx.shape[1]}, "
             f"hidden_size={self.hidden_size}, bias={self.b is not None}"
         )
 
     def __call__(self, x, hidden=None):
         if self.b is not None:
-            x = mx.addmm(self.b, x, self.Wx)
+            x = mx.addmm(self.b, x, self.Wx.T)
         else:
-            x = x @ self.Wx
+            x = x @ self.Wx.T
 
         x_rz = x[..., : -self.hidden_size]
         x_n = x[..., -self.hidden_size :]
@@ -167,7 +167,7 @@ class GRU(Module):
         for idx in range(x.shape[-2]):
             rz = x_rz[..., idx, :]
             if hidden is not None:
-                h_proj = hidden @ self.Wh
+                h_proj = hidden @ self.Wh.T
                 h_proj_rz = h_proj[..., : -self.hidden_size]
                 h_proj_n = h_proj[..., -self.hidden_size :]
 
@@ -240,10 +240,10 @@ class LSTM(Module):
         self.hidden_size = hidden_size
         scale = 1.0 / math.sqrt(hidden_size)
         self.Wx = mx.random.uniform(
-            low=-scale, high=scale, shape=(input_size, 4 * hidden_size)
+            low=-scale, high=scale, shape=(4 * hidden_size, input_size)
         )
         self.Wh = mx.random.uniform(
-            low=-scale, high=scale, shape=(hidden_size, 4 * hidden_size)
+            low=-scale, high=scale, shape=(4 * hidden_size, hidden_size)
         )
         self.bias = (
             mx.random.uniform(low=-scale, high=scale, shape=(4 * hidden_size,))
@@ -253,15 +253,15 @@ class LSTM(Module):
 
     def _extra_repr(self):
         return (
-            f"input_dims={self.Wx.shape[0]}, "
+            f"input_dims={self.Wx.shape[1]}, "
             f"hidden_size={self.hidden_size}, bias={self.bias is not None}"
         )
 
     def __call__(self, x, hidden=None, cell=None):
         if self.bias is not None:
-            x = mx.addmm(self.bias, x, self.Wx)
+            x = mx.addmm(self.bias, x, self.Wx.T)
         else:
-            x = x @ self.Wx
+            x = x @ self.Wx.T
 
         all_hidden = []
         all_cell = []
@@ -269,7 +269,7 @@ class LSTM(Module):
         for idx in range(x.shape[-2]):
             ifgo = x[..., idx, :]
             if hidden is not None:
-                ifgo = ifgo + hidden @ self.Wh
+                ifgo = mx.addmm(ifgo, hidden, self.Wh.T)
             i, f, g, o = mx.split(ifgo, 4, axis=-1)
 
             i = mx.sigmoid(i)
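
Note on the change, for readers outside the patch context: the input and
hidden projection weights are now stored transposed, as (hidden_size,
input_size) rather than (input_size, hidden_size), and the separate
"matmul, then add" steps are replaced with mx.addmm, which fuses the matrix
product and the addition into a single call. The sketch below is a minimal
illustration of that pattern, not part of the patch; the sizes are arbitrary
and only an installed mlx is assumed.

    import mlx.core as mx

    batch, input_size, hidden_size = 8, 256, 256
    x = mx.random.normal((batch, input_size))
    W = mx.random.normal((hidden_size, input_size))  # stored transposed, as in the patch
    b = mx.random.normal((hidden_size,))

    # Unfused: a matmul followed by a separate broadcast add.
    y_unfused = x @ W.T + b

    # Fused: mx.addmm(c, a, b) computes (a @ b) + c in one operation.
    y_fused = mx.addmm(b, x, W.T)

    mx.eval(y_unfused, y_fused)
    assert mx.allclose(y_unfused, y_fused).item()

The transposed storage matches the (output_dims, input_dims) layout used by
mlx.nn.Linear, which is why _extra_repr now reads the input size from
shape[1] instead of shape[0]; the .T in the forward pass is a cheap
transposition rather than a data copy.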