.. _random:

Random
======

Random sampling functions in MLX use an implicit global PRNG state by default.
However, all functions take an optional ``key`` keyword argument for when more
fine-grained control or explicit state management is needed.


For example, you can generate random numbers with:

.. code-block:: python

   import mlx.core as mx

   for _ in range(3):
       print(mx.random.uniform())

which will print a sequence of different pseudo-random numbers. Alternatively, you
can explicitly set the key:

.. code-block:: python

   import mlx.core as mx

   # Reusing one key for every call yields identical draws.
   key = mx.random.key(0)
   for _ in range(3):
       print(mx.random.uniform(key=key))

which will yield the same pseudo-random number at each iteration.
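
To draw different values while still managing the state explicitly, a key can
be split into subkeys with ``split``, using one subkey per draw. The snippet
below is a minimal sketch of that idea rather than a prescribed pattern; the
variable names are illustrative only:

.. code-block:: python

   import mlx.core as mx

   key = mx.random.key(0)

   # Derive three subkeys from the parent key (one subkey per row).
   subkeys = mx.random.split(key, num=3)

   # Use each subkey exactly once so every draw gets its own key.
   samples = [mx.random.uniform(key=subkeys[i]).item() for i in range(3)]
   print(samples)

   # Rebuilding the subkeys from the same seed reproduces the same samples.
   subkeys_again = mx.random.split(mx.random.key(0), num=3)
   samples_again = [
       mx.random.uniform(key=subkeys_again[i]).item() for i in range(3)
   ]
   print(samples == samples_again)

Because each draw depends only on the key it receives, the results do not
depend on how many times the global state has been used elsewhere in the
program.
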

Following `JAX's PRNG design <https://jax.readthedocs.io/en/latest/jep/263-prng.html>`_,
we use a splittable version of Threefry, which is a counter-based PRNG.

.. currentmodule:: mlx.core.random

.. autosummary::
   :toctree: _autosummary

   bernoulli
   categorical
   gumbel
   key
   normal
   multivariate_normal
   randint
   seed
   split
   truncated_normal
   uniform

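
As a rough tour of a few of the functions listed above, the sketch below seeds
the implicit global state with ``seed`` and then draws arrays of several
shapes. The seed value, shapes, and probabilities are arbitrary choices made
for illustration:

.. code-block:: python

   import mlx.core as mx

   # Seed the implicit global PRNG state so the run is repeatable.
   mx.random.seed(42)
   first = mx.random.uniform(shape=(3,))

   # Re-seeding with the same value replays the same global sequence.
   mx.random.seed(42)
   second = mx.random.uniform(shape=(3,))
   print(mx.array_equal(first, second).item())

   z = mx.random.normal(shape=(2, 3))         # standard normal draws
   n = mx.random.randint(0, 10, shape=(5,))   # integers in [0, 10)
   b = mx.random.bernoulli(0.25, shape=(4,))  # boolean draws with p = 0.25
   print(z.shape, n.shape, b.shape)

Each of these functions also accepts the ``key`` keyword argument described
earlier, in which case the global state is not consulted.
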