mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	| @@ -54,7 +54,7 @@ class RNN(Module): | ||||
|         scale = 1.0 / math.sqrt(hidden_size) | ||||
|         self.hidden_size = hidden_size | ||||
|         self.Wxh = mx.random.uniform( | ||||
|             low=-scale, high=scale, shape=(input_size, hidden_size) | ||||
|             low=-scale, high=scale, shape=(hidden_size, input_size) | ||||
|         ) | ||||
|         self.Whh = mx.random.uniform( | ||||
|             low=-scale, high=scale, shape=(hidden_size, hidden_size) | ||||
| @@ -67,21 +67,21 @@ class RNN(Module): | ||||
|  | ||||
|     def _extra_repr(self): | ||||
|         return ( | ||||
|             f"input_dims={self.Wxh.shape[0]}, " | ||||
|             f"input_dims={self.Wxh.shape[1]}, " | ||||
|             f"hidden_size={self.hidden_size}, " | ||||
|             f"nonlinearity={self.nonlinearity}, bias={self.bias is not None}" | ||||
|         ) | ||||
|  | ||||
|     def __call__(self, x, hidden=None): | ||||
|         if self.bias is not None: | ||||
|             x = mx.addmm(self.bias, x, self.Wxh) | ||||
|             x = mx.addmm(self.bias, x, self.Wxh.T) | ||||
|         else: | ||||
|             x = x @ self.Wxh | ||||
|             x = x @ self.Wxh.T | ||||
|  | ||||
|         all_hidden = [] | ||||
|         for idx in range(x.shape[-2]): | ||||
|             if hidden is not None: | ||||
|                 hidden = x[..., idx, :] + hidden @ self.Whh | ||||
|                 hidden = mx.addmm(x[..., idx, :], hidden, self.Whh.T) | ||||
|             else: | ||||
|                 hidden = x[..., idx, :] | ||||
|             hidden = self.nonlinearity(hidden) | ||||
| @@ -131,10 +131,10 @@ class GRU(Module): | ||||
|         self.hidden_size = hidden_size | ||||
|         scale = 1.0 / math.sqrt(hidden_size) | ||||
|         self.Wx = mx.random.uniform( | ||||
|             low=-scale, high=scale, shape=(input_size, 3 * hidden_size) | ||||
|             low=-scale, high=scale, shape=(3 * hidden_size, input_size) | ||||
|         ) | ||||
|         self.Wh = mx.random.uniform( | ||||
|             low=-scale, high=scale, shape=(hidden_size, 3 * hidden_size) | ||||
|             low=-scale, high=scale, shape=(3 * hidden_size, hidden_size) | ||||
|         ) | ||||
|         self.b = ( | ||||
|             mx.random.uniform(low=-scale, high=scale, shape=(3 * hidden_size,)) | ||||
| @@ -149,15 +149,15 @@ class GRU(Module): | ||||
|  | ||||
|     def _extra_repr(self): | ||||
|         return ( | ||||
|             f"input_dims={self.Wx.shape[0]}, " | ||||
|             f"input_dims={self.Wx.shape[1]}, " | ||||
|             f"hidden_size={self.hidden_size}, bias={self.b is not None}" | ||||
|         ) | ||||
|  | ||||
|     def __call__(self, x, hidden=None): | ||||
|         if self.b is not None: | ||||
|             x = mx.addmm(self.b, x, self.Wx) | ||||
|             x = mx.addmm(self.b, x, self.Wx.T) | ||||
|         else: | ||||
|             x = x @ self.Wx | ||||
|             x = x @ self.Wx.T | ||||
|  | ||||
|         x_rz = x[..., : -self.hidden_size] | ||||
|         x_n = x[..., -self.hidden_size :] | ||||
| @@ -167,7 +167,7 @@ class GRU(Module): | ||||
|         for idx in range(x.shape[-2]): | ||||
|             rz = x_rz[..., idx, :] | ||||
|             if hidden is not None: | ||||
|                 h_proj = hidden @ self.Wh | ||||
|                 h_proj = hidden @ self.Wh.T | ||||
|                 h_proj_rz = h_proj[..., : -self.hidden_size] | ||||
|                 h_proj_n = h_proj[..., -self.hidden_size :] | ||||
|  | ||||
| @@ -240,10 +240,10 @@ class LSTM(Module): | ||||
|         self.hidden_size = hidden_size | ||||
|         scale = 1.0 / math.sqrt(hidden_size) | ||||
|         self.Wx = mx.random.uniform( | ||||
|             low=-scale, high=scale, shape=(input_size, 4 * hidden_size) | ||||
|             low=-scale, high=scale, shape=(4 * hidden_size, input_size) | ||||
|         ) | ||||
|         self.Wh = mx.random.uniform( | ||||
|             low=-scale, high=scale, shape=(hidden_size, 4 * hidden_size) | ||||
|             low=-scale, high=scale, shape=(4 * hidden_size, hidden_size) | ||||
|         ) | ||||
|         self.bias = ( | ||||
|             mx.random.uniform(low=-scale, high=scale, shape=(4 * hidden_size,)) | ||||
| @@ -253,15 +253,15 @@ class LSTM(Module): | ||||
|  | ||||
|     def _extra_repr(self): | ||||
|         return ( | ||||
|             f"input_dims={self.Wx.shape[0]}, " | ||||
|             f"input_dims={self.Wx.shape[1]}, " | ||||
|             f"hidden_size={self.hidden_size}, bias={self.bias is not None}" | ||||
|         ) | ||||
|  | ||||
|     def __call__(self, x, hidden=None, cell=None): | ||||
|         if self.bias is not None: | ||||
|             x = mx.addmm(self.bias, x, self.Wx) | ||||
|             x = mx.addmm(self.bias, x, self.Wx.T) | ||||
|         else: | ||||
|             x = x @ self.Wx | ||||
|             x = x @ self.Wx.T | ||||
|  | ||||
|         all_hidden = [] | ||||
|         all_cell = [] | ||||
| @@ -269,7 +269,7 @@ class LSTM(Module): | ||||
|         for idx in range(x.shape[-2]): | ||||
|             ifgo = x[..., idx, :] | ||||
|             if hidden is not None: | ||||
|                 ifgo = ifgo + hidden @ self.Wh | ||||
|                 ifgo = mx.addmm(ifgo, hidden, self.Wh.T) | ||||
|             i, f, g, o = mx.split(ifgo, 4, axis=-1) | ||||
|  | ||||
|             i = mx.sigmoid(i) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun