mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
parent
914409fef9
commit
c6739ba7f3
@ -54,7 +54,7 @@ class RNN(Module):
|
||||
scale = 1.0 / math.sqrt(hidden_size)
|
||||
self.hidden_size = hidden_size
|
||||
self.Wxh = mx.random.uniform(
|
||||
low=-scale, high=scale, shape=(input_size, hidden_size)
|
||||
low=-scale, high=scale, shape=(hidden_size, input_size)
|
||||
)
|
||||
self.Whh = mx.random.uniform(
|
||||
low=-scale, high=scale, shape=(hidden_size, hidden_size)
|
||||
@ -67,21 +67,21 @@ class RNN(Module):
|
||||
|
||||
def _extra_repr(self):
|
||||
return (
|
||||
f"input_dims={self.Wxh.shape[0]}, "
|
||||
f"input_dims={self.Wxh.shape[1]}, "
|
||||
f"hidden_size={self.hidden_size}, "
|
||||
f"nonlinearity={self.nonlinearity}, bias={self.bias is not None}"
|
||||
)
|
||||
|
||||
def __call__(self, x, hidden=None):
|
||||
if self.bias is not None:
|
||||
x = mx.addmm(self.bias, x, self.Wxh)
|
||||
x = mx.addmm(self.bias, x, self.Wxh.T)
|
||||
else:
|
||||
x = x @ self.Wxh
|
||||
x = x @ self.Wxh.T
|
||||
|
||||
all_hidden = []
|
||||
for idx in range(x.shape[-2]):
|
||||
if hidden is not None:
|
||||
hidden = x[..., idx, :] + hidden @ self.Whh
|
||||
hidden = mx.addmm(x[..., idx, :], hidden, self.Whh.T)
|
||||
else:
|
||||
hidden = x[..., idx, :]
|
||||
hidden = self.nonlinearity(hidden)
|
||||
@ -131,10 +131,10 @@ class GRU(Module):
|
||||
self.hidden_size = hidden_size
|
||||
scale = 1.0 / math.sqrt(hidden_size)
|
||||
self.Wx = mx.random.uniform(
|
||||
low=-scale, high=scale, shape=(input_size, 3 * hidden_size)
|
||||
low=-scale, high=scale, shape=(3 * hidden_size, input_size)
|
||||
)
|
||||
self.Wh = mx.random.uniform(
|
||||
low=-scale, high=scale, shape=(hidden_size, 3 * hidden_size)
|
||||
low=-scale, high=scale, shape=(3 * hidden_size, hidden_size)
|
||||
)
|
||||
self.b = (
|
||||
mx.random.uniform(low=-scale, high=scale, shape=(3 * hidden_size,))
|
||||
@ -149,15 +149,15 @@ class GRU(Module):
|
||||
|
||||
def _extra_repr(self):
|
||||
return (
|
||||
f"input_dims={self.Wx.shape[0]}, "
|
||||
f"input_dims={self.Wx.shape[1]}, "
|
||||
f"hidden_size={self.hidden_size}, bias={self.b is not None}"
|
||||
)
|
||||
|
||||
def __call__(self, x, hidden=None):
|
||||
if self.b is not None:
|
||||
x = mx.addmm(self.b, x, self.Wx)
|
||||
x = mx.addmm(self.b, x, self.Wx.T)
|
||||
else:
|
||||
x = x @ self.Wx
|
||||
x = x @ self.Wx.T
|
||||
|
||||
x_rz = x[..., : -self.hidden_size]
|
||||
x_n = x[..., -self.hidden_size :]
|
||||
@ -167,7 +167,7 @@ class GRU(Module):
|
||||
for idx in range(x.shape[-2]):
|
||||
rz = x_rz[..., idx, :]
|
||||
if hidden is not None:
|
||||
h_proj = hidden @ self.Wh
|
||||
h_proj = hidden @ self.Wh.T
|
||||
h_proj_rz = h_proj[..., : -self.hidden_size]
|
||||
h_proj_n = h_proj[..., -self.hidden_size :]
|
||||
|
||||
@ -240,10 +240,10 @@ class LSTM(Module):
|
||||
self.hidden_size = hidden_size
|
||||
scale = 1.0 / math.sqrt(hidden_size)
|
||||
self.Wx = mx.random.uniform(
|
||||
low=-scale, high=scale, shape=(input_size, 4 * hidden_size)
|
||||
low=-scale, high=scale, shape=(4 * hidden_size, input_size)
|
||||
)
|
||||
self.Wh = mx.random.uniform(
|
||||
low=-scale, high=scale, shape=(hidden_size, 4 * hidden_size)
|
||||
low=-scale, high=scale, shape=(4 * hidden_size, hidden_size)
|
||||
)
|
||||
self.bias = (
|
||||
mx.random.uniform(low=-scale, high=scale, shape=(4 * hidden_size,))
|
||||
@ -253,15 +253,15 @@ class LSTM(Module):
|
||||
|
||||
def _extra_repr(self):
|
||||
return (
|
||||
f"input_dims={self.Wx.shape[0]}, "
|
||||
f"input_dims={self.Wx.shape[1]}, "
|
||||
f"hidden_size={self.hidden_size}, bias={self.bias is not None}"
|
||||
)
|
||||
|
||||
def __call__(self, x, hidden=None, cell=None):
|
||||
if self.bias is not None:
|
||||
x = mx.addmm(self.bias, x, self.Wx)
|
||||
x = mx.addmm(self.bias, x, self.Wx.T)
|
||||
else:
|
||||
x = x @ self.Wx
|
||||
x = x @ self.Wx.T
|
||||
|
||||
all_hidden = []
|
||||
all_cell = []
|
||||
@ -269,7 +269,7 @@ class LSTM(Module):
|
||||
for idx in range(x.shape[-2]):
|
||||
ifgo = x[..., idx, :]
|
||||
if hidden is not None:
|
||||
ifgo = ifgo + hidden @ self.Wh
|
||||
ifgo = mx.addmm(ifgo, hidden, self.Wh.T)
|
||||
i, f, g, o = mx.split(ifgo, 4, axis=-1)
|
||||
|
||||
i = mx.sigmoid(i)
|
||||
|
Loading…
Reference in New Issue
Block a user