mirror of https://github.com/ml-explore/mlx.git
@@ -54,7 +54,7 @@ class RNN(Module):
         scale = 1.0 / math.sqrt(hidden_size)
         self.hidden_size = hidden_size
         self.Wxh = mx.random.uniform(
-            low=-scale, high=scale, shape=(input_size, hidden_size)
+            low=-scale, high=scale, shape=(hidden_size, input_size)
         )
         self.Whh = mx.random.uniform(
             low=-scale, high=scale, shape=(hidden_size, hidden_size)
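Note on the layout change (a sketch, not part of the diff): the input projection weight is now stored as (hidden_size, input_size), presumably matching the (output_dims, input_dims) convention used elsewhere in mlx.nn, so the forward pass multiplies by the transpose. A minimal check with hypothetical toy sizes:

    import math
    import mlx.core as mx

    # Hypothetical toy sizes, only to illustrate the new weight layout.
    input_size, hidden_size, batch, seq = 8, 16, 2, 5
    scale = 1.0 / math.sqrt(hidden_size)

    Wxh = mx.random.uniform(low=-scale, high=scale, shape=(hidden_size, input_size))
    x = mx.random.normal((batch, seq, input_size))

    # With the (hidden_size, input_size) layout the projection uses the transpose;
    # the result has the same shape as the old `x @ Wxh` with an
    # (input_size, hidden_size) weight.
    proj = x @ Wxh.T
    print(proj.shape)  # -> (2, 5, 16)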
@@ -67,21 +67,21 @@ class RNN(Module):
 
     def _extra_repr(self):
         return (
-            f"input_dims={self.Wxh.shape[0]}, "
+            f"input_dims={self.Wxh.shape[1]}, "
             f"hidden_size={self.hidden_size}, "
             f"nonlinearity={self.nonlinearity}, bias={self.bias is not None}"
         )
 
     def __call__(self, x, hidden=None):
         if self.bias is not None:
-            x = mx.addmm(self.bias, x, self.Wxh)
+            x = mx.addmm(self.bias, x, self.Wxh.T)
         else:
-            x = x @ self.Wxh
+            x = x @ self.Wxh.T
 
         all_hidden = []
         for idx in range(x.shape[-2]):
             if hidden is not None:
-                hidden = x[..., idx, :] + hidden @ self.Whh
+                hidden = mx.addmm(x[..., idx, :], hidden, self.Whh.T)
             else:
                 hidden = x[..., idx, :]
             hidden = self.nonlinearity(hidden)
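For context (a sketch under assumed toy shapes, not from the diff): mx.addmm(c, a, b) fuses the add with the matrix multiply, so the rewritten recurrent step mx.addmm(x[..., idx, :], hidden, self.Whh.T) should produce the same values as the previous two-op form computed against the transposed weight:

    import math
    import mlx.core as mx

    hidden_size, batch = 16, 2  # hypothetical toy sizes
    scale = 1.0 / math.sqrt(hidden_size)

    Whh = mx.random.uniform(low=-scale, high=scale, shape=(hidden_size, hidden_size))
    x_t = mx.random.normal((batch, hidden_size))     # input projection at step t
    hidden = mx.random.normal((batch, hidden_size))  # previous hidden state

    fused = mx.addmm(x_t, hidden, Whh.T)  # what the diff uses
    unfused = x_t + hidden @ Whh.T        # the equivalent two-op form
    print(mx.allclose(fused, unfused).item())  # True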
@@ -131,10 +131,10 @@ class GRU(Module):
         self.hidden_size = hidden_size
         scale = 1.0 / math.sqrt(hidden_size)
         self.Wx = mx.random.uniform(
-            low=-scale, high=scale, shape=(input_size, 3 * hidden_size)
+            low=-scale, high=scale, shape=(3 * hidden_size, input_size)
         )
         self.Wh = mx.random.uniform(
-            low=-scale, high=scale, shape=(hidden_size, 3 * hidden_size)
+            low=-scale, high=scale, shape=(3 * hidden_size, hidden_size)
         )
         self.b = (
             mx.random.uniform(low=-scale, high=scale, shape=(3 * hidden_size,))
@@ -149,15 +149,15 @@ class GRU(Module):
 
     def _extra_repr(self):
         return (
-            f"input_dims={self.Wx.shape[0]}, "
+            f"input_dims={self.Wx.shape[1]}, "
            f"hidden_size={self.hidden_size}, bias={self.b is not None}"
         )
 
     def __call__(self, x, hidden=None):
         if self.b is not None:
-            x = mx.addmm(self.b, x, self.Wx)
+            x = mx.addmm(self.b, x, self.Wx.T)
         else:
-            x = x @ self.Wx
+            x = x @ self.Wx.T
 
         x_rz = x[..., : -self.hidden_size]
         x_n = x[..., -self.hidden_size :]
@@ -167,7 +167,7 @@ class GRU(Module):
         for idx in range(x.shape[-2]):
             rz = x_rz[..., idx, :]
             if hidden is not None:
-                h_proj = hidden @ self.Wh
+                h_proj = hidden @ self.Wh.T
                 h_proj_rz = h_proj[..., : -self.hidden_size]
                 h_proj_n = h_proj[..., -self.hidden_size :]
 
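For readers following the GRU math (a sketch with hypothetical sizes, not part of the change): the fused hidden weight now has 3 * hidden_size output rows, so after projecting with Wh.T the first 2 * hidden_size columns hold the reset/update (r, z) pre-activations and the last hidden_size columns hold the candidate (n) pre-activation, which is exactly what the slicing above assumes:

    import math
    import mlx.core as mx

    hidden_size, batch = 16, 2  # hypothetical toy sizes
    scale = 1.0 / math.sqrt(hidden_size)

    Wh = mx.random.uniform(low=-scale, high=scale, shape=(3 * hidden_size, hidden_size))
    hidden = mx.random.normal((batch, hidden_size))

    h_proj = hidden @ Wh.T                   # (batch, 3 * hidden_size)
    h_proj_rz = h_proj[..., : -hidden_size]  # reset/update gates, (batch, 2 * hidden_size)
    h_proj_n = h_proj[..., -hidden_size:]    # candidate state, (batch, hidden_size)
    print(h_proj_rz.shape, h_proj_n.shape)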
@@ -240,10 +240,10 @@ class LSTM(Module):
         self.hidden_size = hidden_size
         scale = 1.0 / math.sqrt(hidden_size)
         self.Wx = mx.random.uniform(
-            low=-scale, high=scale, shape=(input_size, 4 * hidden_size)
+            low=-scale, high=scale, shape=(4 * hidden_size, input_size)
         )
         self.Wh = mx.random.uniform(
-            low=-scale, high=scale, shape=(hidden_size, 4 * hidden_size)
+            low=-scale, high=scale, shape=(4 * hidden_size, hidden_size)
         )
         self.bias = (
             mx.random.uniform(low=-scale, high=scale, shape=(4 * hidden_size,))
@@ -253,15 +253,15 @@ class LSTM(Module):
 
     def _extra_repr(self):
         return (
-            f"input_dims={self.Wx.shape[0]}, "
+            f"input_dims={self.Wx.shape[1]}, "
             f"hidden_size={self.hidden_size}, bias={self.bias is not None}"
         )
 
     def __call__(self, x, hidden=None, cell=None):
         if self.bias is not None:
-            x = mx.addmm(self.bias, x, self.Wx)
+            x = mx.addmm(self.bias, x, self.Wx.T)
         else:
-            x = x @ self.Wx
+            x = x @ self.Wx.T
 
         all_hidden = []
         all_cell = []
@@ -269,7 +269,7 @@ class LSTM(Module):
         for idx in range(x.shape[-2]):
             ifgo = x[..., idx, :]
             if hidden is not None:
-                ifgo = ifgo + hidden @ self.Wh
+                ifgo = mx.addmm(ifgo, hidden, self.Wh.T)
             i, f, g, o = mx.split(ifgo, 4, axis=-1)
 
             i = mx.sigmoid(i)
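Similarly for the LSTM step (a sketch with hypothetical sizes, not part of the diff): the packed 4 * hidden_size projection is accumulated with mx.addmm and then split into the input, forget, cell, and output gate pre-activations:

    import math
    import mlx.core as mx

    hidden_size, batch = 16, 2  # hypothetical toy sizes
    scale = 1.0 / math.sqrt(hidden_size)

    Wh = mx.random.uniform(low=-scale, high=scale, shape=(4 * hidden_size, hidden_size))
    ifgo_x = mx.random.normal((batch, 4 * hidden_size))  # stands in for x[..., idx, :]
    hidden = mx.random.normal((batch, hidden_size))

    ifgo = mx.addmm(ifgo_x, hidden, Wh.T)    # fused add + matmul, as in the diff
    i, f, g, o = mx.split(ifgo, 4, axis=-1)  # each gate is (batch, hidden_size)
    print(i.shape, f.shape, g.shape, o.shape)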