From 9dd72cd421260ebc0f30e773f6b35fdf87555806 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 17 Oct 2024 13:52:39 -0700 Subject: [PATCH] fix gumbel (#1495) --- python/src/random.cpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/python/src/random.cpp b/python/src/random.cpp index af95d4e6a..538a46aaf 100644 --- a/python/src/random.cpp +++ b/python/src/random.cpp @@ -352,10 +352,10 @@ void init_random(nb::module_& parent_module) { }, "shape"_a = std::vector{}, "dtype"_a.none() = float32, - "stream"_a = nb::none(), "key"_a = nb::none(), + "stream"_a = nb::none(), nb::sig( - "def gumbel(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, stream: Optional[array] = None, key: Union[None, Stream, Device] = None) -> array"), + "def gumbel(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, key: Union[None, Stream, Device] = None, stream: Optional[array] = None) -> array"), R"pbdoc( Sample from the standard Gumbel distribution. @@ -364,11 +364,14 @@ void init_random(nb::module_& parent_module) { Args: shape (list(int)): The shape of the output. + dtype (Dtype, optional): The data type of the output. + Default: ``float32``. key (array, optional): A PRNG key. Default: ``None``. Returns: - array: The :class:`array` with shape ``shape`` and - distributed according to the Gumbel distribution + array: + The :class:`array` with shape ``shape`` and distributed according + to the Gumbel distribution. )pbdoc"); m.def( "categorical",