Faster RNN layers (#1419)

* faster rnn

* use addmm
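Both changes follow the same pattern: each weight matrix is now stored transposed, as `(output_dims, input_dims)`, and applied through `W.T`, while the bias add is fused into the matrix multiply with `mx.addmm`. The transposed storage also matches the layout used by layers such as `nn.Linear`. A minimal sketch of the equivalence (shapes are illustrative):

```python
import mlx.core as mx

input_size, hidden_size = 8, 16
x = mx.random.normal((4, input_size))  # a batch of 4 input vectors
b = mx.zeros((hidden_size,))           # bias

# Old layout: W stored as (input_size, hidden_size), applied as x @ W.
W_old = mx.random.normal((input_size, hidden_size))
y_old = x @ W_old + b

# New layout: the same weights stored transposed, (hidden_size, input_size).
W_new = W_old.T

# mx.addmm(c, a, b) computes a @ b + c in one fused call.
y_new = mx.addmm(b, x, W_new.T)

print(mx.allclose(y_old, y_new).item())  # True
```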
Author: Awni Hannun, 2024-09-17 06:04:19 -07:00 (committed by GitHub)
Parent: 914409fef9
Commit: c6739ba7f3


@@ -54,7 +54,7 @@ class RNN(Module):
         scale = 1.0 / math.sqrt(hidden_size)
         self.hidden_size = hidden_size
         self.Wxh = mx.random.uniform(
-            low=-scale, high=scale, shape=(input_size, hidden_size)
+            low=-scale, high=scale, shape=(hidden_size, input_size)
         )
         self.Whh = mx.random.uniform(
             low=-scale, high=scale, shape=(hidden_size, hidden_size)
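With `Wxh` stored as `(hidden_size, input_size)`, the input dimensionality moves to axis 1, which is why `_extra_repr` below switches from `shape[0]` to `shape[1]`. A quick sketch of reading the dims back under the new layout:

```python
import mlx.core as mx

input_size, hidden_size = 8, 16  # illustrative sizes
Wxh = mx.random.uniform(low=-1.0, high=1.0, shape=(hidden_size, input_size))

# Under the transposed layout, axis 0 is the output and axis 1 the input.
print(Wxh.shape[0])  # 16 -> hidden_size
print(Wxh.shape[1])  # 8  -> input_size
```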
@@ -67,21 +67,21 @@ class RNN(Module):
     def _extra_repr(self):
         return (
-            f"input_dims={self.Wxh.shape[0]}, "
+            f"input_dims={self.Wxh.shape[1]}, "
             f"hidden_size={self.hidden_size}, "
             f"nonlinearity={self.nonlinearity}, bias={self.bias is not None}"
         )
 
     def __call__(self, x, hidden=None):
         if self.bias is not None:
-            x = mx.addmm(self.bias, x, self.Wxh)
+            x = mx.addmm(self.bias, x, self.Wxh.T)
         else:
-            x = x @ self.Wxh
+            x = x @ self.Wxh.T
 
         all_hidden = []
         for idx in range(x.shape[-2]):
             if hidden is not None:
-                hidden = x[..., idx, :] + hidden @ self.Whh
+                hidden = mx.addmm(x[..., idx, :], hidden, self.Whh.T)
             else:
                 hidden = x[..., idx, :]
             hidden = self.nonlinearity(hidden)
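Inside the time loop, the elementwise add is likewise folded into the matmul: `mx.addmm(x[..., idx, :], hidden, self.Whh.T)` computes `x_t + hidden @ Whh.T` in a single fused call instead of two. A sketch of one recurrent step (names and shapes are illustrative):

```python
import mlx.core as mx

batch, hidden_size = 4, 16
Whh = mx.random.uniform(low=-1.0, high=1.0, shape=(hidden_size, hidden_size))
x_t = mx.random.normal((batch, hidden_size))  # input projection at step t
hidden = mx.random.normal((batch, hidden_size))

# Unfused: a matmul followed by a separate add.
ref = x_t + hidden @ Whh.T

# Fused: one addmm call.
out = mx.addmm(x_t, hidden, Whh.T)
print(mx.allclose(ref, out).item())  # True
```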
@@ -131,10 +131,10 @@ class GRU(Module):
         self.hidden_size = hidden_size
         scale = 1.0 / math.sqrt(hidden_size)
         self.Wx = mx.random.uniform(
-            low=-scale, high=scale, shape=(input_size, 3 * hidden_size)
+            low=-scale, high=scale, shape=(3 * hidden_size, input_size)
         )
         self.Wh = mx.random.uniform(
-            low=-scale, high=scale, shape=(hidden_size, 3 * hidden_size)
+            low=-scale, high=scale, shape=(3 * hidden_size, hidden_size)
         )
         self.b = (
             mx.random.uniform(low=-scale, high=scale, shape=(3 * hidden_size,))
@@ -149,15 +149,15 @@ class GRU(Module):
     def _extra_repr(self):
         return (
-            f"input_dims={self.Wx.shape[0]}, "
+            f"input_dims={self.Wx.shape[1]}, "
             f"hidden_size={self.hidden_size}, bias={self.b is not None}"
         )
 
     def __call__(self, x, hidden=None):
         if self.b is not None:
-            x = mx.addmm(self.b, x, self.Wx)
+            x = mx.addmm(self.b, x, self.Wx.T)
         else:
-            x = x @ self.Wx
+            x = x @ self.Wx.T
 
         x_rz = x[..., : -self.hidden_size]
         x_n = x[..., -self.hidden_size :]
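Note that the whole input projection is computed once over the full sequence and then sliced into the reset/update part (`x_rz`) and the candidate part (`x_n`); only the small recurrent matmul is left inside the per-step loop. A sketch of the slicing with illustrative shapes:

```python
import mlx.core as mx

batch, seq_len, input_size, hidden_size = 2, 5, 8, 16
Wx = mx.random.uniform(low=-1.0, high=1.0, shape=(3 * hidden_size, input_size))
x = mx.random.normal((batch, seq_len, input_size))

# One matmul projects every time step at once: (batch, seq_len, 3 * hidden_size).
proj = x @ Wx.T
x_rz = proj[..., : -hidden_size]  # reset and update gates, 2 * hidden_size wide
x_n = proj[..., -hidden_size:]    # candidate gate, hidden_size wide
print(x_rz.shape, x_n.shape)      # (2, 5, 32) (2, 5, 16)
```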
@@ -167,7 +167,7 @@ class GRU(Module):
         for idx in range(x.shape[-2]):
             rz = x_rz[..., idx, :]
             if hidden is not None:
-                h_proj = hidden @ self.Wh
+                h_proj = hidden @ self.Wh.T
                 h_proj_rz = h_proj[..., : -self.hidden_size]
                 h_proj_n = h_proj[..., -self.hidden_size :]
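The hidden-state projection follows the same pattern: one matmul against the stacked `Wh` per step, split into the slices the gates consume. A sketch with illustrative shapes:

```python
import mlx.core as mx

batch, hidden_size = 2, 16
Wh = mx.random.uniform(low=-1.0, high=1.0, shape=(3 * hidden_size, hidden_size))
hidden = mx.random.normal((batch, hidden_size))

# One matmul yields all three gate projections of the hidden state.
h_proj = hidden @ Wh.T
h_proj_rz = h_proj[..., : -hidden_size]  # feeds the reset/update gates
h_proj_n = h_proj[..., -hidden_size:]    # feeds the candidate gate
print(h_proj_rz.shape, h_proj_n.shape)   # (2, 32) (2, 16)
```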
@@ -240,10 +240,10 @@ class LSTM(Module):
         self.hidden_size = hidden_size
         scale = 1.0 / math.sqrt(hidden_size)
         self.Wx = mx.random.uniform(
-            low=-scale, high=scale, shape=(input_size, 4 * hidden_size)
+            low=-scale, high=scale, shape=(4 * hidden_size, input_size)
         )
         self.Wh = mx.random.uniform(
-            low=-scale, high=scale, shape=(hidden_size, 4 * hidden_size)
+            low=-scale, high=scale, shape=(4 * hidden_size, hidden_size)
         )
         self.bias = (
             mx.random.uniform(low=-scale, high=scale, shape=(4 * hidden_size,))
@@ -253,15 +253,15 @@ class LSTM(Module):
     def _extra_repr(self):
         return (
-            f"input_dims={self.Wx.shape[0]}, "
+            f"input_dims={self.Wx.shape[1]}, "
             f"hidden_size={self.hidden_size}, bias={self.bias is not None}"
         )
 
     def __call__(self, x, hidden=None, cell=None):
         if self.bias is not None:
-            x = mx.addmm(self.bias, x, self.Wx)
+            x = mx.addmm(self.bias, x, self.Wx.T)
         else:
-            x = x @ self.Wx
+            x = x @ self.Wx.T
 
         all_hidden = []
         all_cell = []
@@ -269,7 +269,7 @@ class LSTM(Module):
         for idx in range(x.shape[-2]):
             ifgo = x[..., idx, :]
             if hidden is not None:
-                ifgo = ifgo + hidden @ self.Wh
+                ifgo = mx.addmm(ifgo, hidden, self.Wh.T)
             i, f, g, o = mx.split(ifgo, 4, axis=-1)
             i = mx.sigmoid(i)
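The LSTM step follows the same recipe: the precomputed input projection `ifgo` picks up the recurrent contribution through one fused `mx.addmm` and is then split into the four gates. A sketch of a single step, assuming illustrative shapes and a zero initial state:

```python
import mlx.core as mx

batch, hidden_size = 2, 16
Wh = mx.random.uniform(low=-1.0, high=1.0, shape=(4 * hidden_size, hidden_size))
ifgo = mx.random.normal((batch, 4 * hidden_size))  # input projection at one step
hidden = mx.zeros((batch, hidden_size))
cell = mx.zeros((batch, hidden_size))

# Fold the recurrent matmul and the add into a single call.
ifgo = mx.addmm(ifgo, hidden, Wh.T)

# Split into input, forget, candidate, and output gates.
i, f, g, o = mx.split(ifgo, 4, axis=-1)
cell = mx.sigmoid(f) * cell + mx.sigmoid(i) * mx.tanh(g)
hidden = mx.sigmoid(o) * mx.tanh(cell)
print(hidden.shape, cell.shape)  # (2, 16) (2, 16)
```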