Faster RNN layers (#1419)

* faster rnn

* use addmm
Awni Hannun 2024-09-17 06:04:19 -07:00 committed by GitHub
parent 914409fef9
commit c6739ba7f3

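All three layers get the same two-part change: the weight matrices are now stored transposed, in the (out_features, in_features) layout that nn.Linear uses, and the per-timestep add-plus-matmul in each recurrence is fused into a single mx.addmm call. A minimal sketch of the fused op (sizes are illustrative):

import mlx.core as mx

x = mx.random.normal(shape=(4, 32))    # (batch, input_size)
W = mx.random.normal(shape=(64, 32))   # (hidden_size, input_size), Linear-style
b = mx.zeros((64,))

# mx.addmm(c, a, b) computes c + a @ b in one fused operation,
# saving the separate broadcast add of the unfused form.
fused = mx.addmm(b, x, W.T)
assert mx.allclose(fused, b + x @ W.T)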

@@ -54,7 +54,7 @@ class RNN(Module):
         scale = 1.0 / math.sqrt(hidden_size)
         self.hidden_size = hidden_size
         self.Wxh = mx.random.uniform(
-            low=-scale, high=scale, shape=(input_size, hidden_size)
+            low=-scale, high=scale, shape=(hidden_size, input_size)
         )
         self.Whh = mx.random.uniform(
             low=-scale, high=scale, shape=(hidden_size, hidden_size)
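
Wxh moves from (input_size, hidden_size) to (hidden_size, input_size), matching the (out_features, in_features) convention of nn.Linear, and is applied as x @ Wxh.T below. A quick shape check with illustrative sizes:

input_size, hidden_size = 32, 64
x = mx.random.normal(shape=(2, 10, input_size))  # (batch, time, input_size)
Wxh = mx.random.uniform(low=-1.0, high=1.0, shape=(hidden_size, input_size))
y = x @ Wxh.T
assert y.shape == (2, 10, hidden_size)  # one matmul projects all timesteps
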
@ -67,21 +67,21 @@ class RNN(Module):
def _extra_repr(self): def _extra_repr(self):
return ( return (
f"input_dims={self.Wxh.shape[0]}, " f"input_dims={self.Wxh.shape[1]}, "
f"hidden_size={self.hidden_size}, " f"hidden_size={self.hidden_size}, "
f"nonlinearity={self.nonlinearity}, bias={self.bias is not None}" f"nonlinearity={self.nonlinearity}, bias={self.bias is not None}"
) )
def __call__(self, x, hidden=None): def __call__(self, x, hidden=None):
if self.bias is not None: if self.bias is not None:
x = mx.addmm(self.bias, x, self.Wxh) x = mx.addmm(self.bias, x, self.Wxh.T)
else: else:
x = x @ self.Wxh x = x @ self.Wxh.T
all_hidden = [] all_hidden = []
for idx in range(x.shape[-2]): for idx in range(x.shape[-2]):
if hidden is not None: if hidden is not None:
hidden = x[..., idx, :] + hidden @ self.Whh hidden = mx.addmm(x[..., idx, :], hidden, self.Whh.T)
else: else:
hidden = x[..., idx, :] hidden = x[..., idx, :]
hidden = self.nonlinearity(hidden) hidden = self.nonlinearity(hidden)
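
With the whole-sequence input projection hoisted out of the loop, each timestep now costs one fused addmm plus the nonlinearity. A sketch of the resulting per-step update, written with a fixed tanh nonlinearity for brevity:

def rnn_step(x_proj_t, hidden, Whh):
    # x_proj_t already holds x_t @ Wxh.T (+ bias) for this timestep
    return mx.tanh(mx.addmm(x_proj_t, hidden, Whh.T))
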
@@ -131,10 +131,10 @@ class GRU(Module):
         self.hidden_size = hidden_size
         scale = 1.0 / math.sqrt(hidden_size)
         self.Wx = mx.random.uniform(
-            low=-scale, high=scale, shape=(input_size, 3 * hidden_size)
+            low=-scale, high=scale, shape=(3 * hidden_size, input_size)
         )
         self.Wh = mx.random.uniform(
-            low=-scale, high=scale, shape=(hidden_size, 3 * hidden_size)
+            low=-scale, high=scale, shape=(3 * hidden_size, hidden_size)
         )
         self.b = (
             mx.random.uniform(low=-scale, high=scale, shape=(3 * hidden_size,))
@@ -149,15 +149,15 @@ class GRU(Module):

     def _extra_repr(self):
         return (
-            f"input_dims={self.Wx.shape[0]}, "
+            f"input_dims={self.Wx.shape[1]}, "
             f"hidden_size={self.hidden_size}, bias={self.b is not None}"
         )

     def __call__(self, x, hidden=None):
         if self.b is not None:
-            x = mx.addmm(self.b, x, self.Wx)
+            x = mx.addmm(self.b, x, self.Wx.T)
         else:
-            x = x @ self.Wx
+            x = x @ self.Wx.T

         x_rz = x[..., : -self.hidden_size]
         x_n = x[..., -self.hidden_size :]
@@ -167,7 +167,7 @@ class GRU(Module):
         for idx in range(x.shape[-2]):
             rz = x_rz[..., idx, :]
             if hidden is not None:
-                h_proj = hidden @ self.Wh
+                h_proj = hidden @ self.Wh.T
                 h_proj_rz = h_proj[..., : -self.hidden_size]
                 h_proj_n = h_proj[..., -self.hidden_size :]

@@ -240,10 +240,10 @@ class LSTM(Module):
         self.hidden_size = hidden_size
         scale = 1.0 / math.sqrt(hidden_size)
         self.Wx = mx.random.uniform(
-            low=-scale, high=scale, shape=(input_size, 4 * hidden_size)
+            low=-scale, high=scale, shape=(4 * hidden_size, input_size)
         )
         self.Wh = mx.random.uniform(
-            low=-scale, high=scale, shape=(hidden_size, 4 * hidden_size)
+            low=-scale, high=scale, shape=(4 * hidden_size, hidden_size)
         )
         self.bias = (
             mx.random.uniform(low=-scale, high=scale, shape=(4 * hidden_size,))
@@ -253,15 +253,15 @@ class LSTM(Module):

     def _extra_repr(self):
         return (
-            f"input_dims={self.Wx.shape[0]}, "
+            f"input_dims={self.Wx.shape[1]}, "
             f"hidden_size={self.hidden_size}, bias={self.bias is not None}"
         )

     def __call__(self, x, hidden=None, cell=None):
         if self.bias is not None:
-            x = mx.addmm(self.bias, x, self.Wx)
+            x = mx.addmm(self.bias, x, self.Wx.T)
         else:
-            x = x @ self.Wx
+            x = x @ self.Wx.T

         all_hidden = []
         all_cell = []
@@ -269,7 +269,7 @@ class LSTM(Module):
         for idx in range(x.shape[-2]):
             ifgo = x[..., idx, :]
             if hidden is not None:
-                ifgo = ifgo + hidden @ self.Wh
+                ifgo = mx.addmm(ifgo, hidden, self.Wh.T)
             i, f, g, o = mx.split(ifgo, 4, axis=-1)
             i = mx.sigmoid(i)
             f = mx.sigmoid(f)
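
The LSTM gets the same fusion, after which the projection is split into the four gates. A sketch of the complete update under the usual LSTM equations (the helper name is illustrative):

def lstm_step(x_proj_t, hidden, cell, Wh):
    ifgo = mx.addmm(x_proj_t, hidden, Wh.T)   # fused add + recurrent matmul
    i, f, g, o = mx.split(ifgo, 4, axis=-1)
    i, f, o = mx.sigmoid(i), mx.sigmoid(f), mx.sigmoid(o)
    g = mx.tanh(g)
    cell = f * cell + i * g                   # c_t = f * c_{t-1} + i * g
    hidden = o * mx.tanh(cell)                # h_t = o * tanh(c_t)
    return hidden, cell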