mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	fix gumbel (#1495)
This commit is contained in:
		| @@ -352,10 +352,10 @@ void init_random(nb::module_& parent_module) { | ||||
|       }, | ||||
|       "shape"_a = std::vector<int>{}, | ||||
|       "dtype"_a.none() = float32, | ||||
|       "stream"_a = nb::none(), | ||||
|       "key"_a = nb::none(), | ||||
|       "stream"_a = nb::none(), | ||||
|       nb::sig( | ||||
|           "def gumbel(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, stream: Optional[array] = None, key: Union[None, Stream, Device] = None) -> array"), | ||||
|           "def gumbel(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, key: Union[None, Stream, Device] = None, stream: Optional[array] = None) -> array"), | ||||
|       R"pbdoc( | ||||
|         Sample from the standard Gumbel distribution. | ||||
|  | ||||
| @@ -364,11 +364,14 @@ void init_random(nb::module_& parent_module) { | ||||
|  | ||||
|         Args: | ||||
|             shape (list(int)): The shape of the output. | ||||
|             dtype (Dtype, optional): The data type of the output. | ||||
|               Default: ``float32``. | ||||
|             key (array, optional): A PRNG key. Default: ``None``. | ||||
|  | ||||
|         Returns: | ||||
|             array: The :class:`array` with shape ``shape`` and | ||||
|                    distributed according to the Gumbel distribution | ||||
|             array: | ||||
|               The :class:`array` with shape ``shape`` and distributed according | ||||
|               to the Gumbel distribution. | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "categorical", | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun