mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	* Implementation of mlx.random.multivariate_normal (#502) * Update python/src/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/src/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/src/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Updated typo in docstring * Restricted multivariate_normal to float32 * Generic mean and variance shapes * Review edits * Update mlx/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/src/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/src/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/src/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Test for ndim of mean and cov * nits * smaller size for test * fix broadcasted sampling --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com> Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
		@@ -179,6 +179,48 @@ void init_random(nb::module_& parent_module) {
 | 
			
		||||
            array: The output array of random values.
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "multivariate_normal",
 | 
			
		||||
      [](const array& mean,
 | 
			
		||||
         const array& cov,
 | 
			
		||||
         const std::vector<int>& shape,
 | 
			
		||||
         std::optional<Dtype> type,
 | 
			
		||||
         const std::optional<array>& key_,
 | 
			
		||||
         StreamOrDevice s) {
 | 
			
		||||
        auto key = key_ ? key_.value() : default_key().next();
 | 
			
		||||
        return multivariate_normal(
 | 
			
		||||
            mean, cov, shape, type.value_or(float32), key, s);
 | 
			
		||||
      },
 | 
			
		||||
      "mean"_a,
 | 
			
		||||
      "cov"_a,
 | 
			
		||||
      "shape"_a = std::vector<int>{},
 | 
			
		||||
      "dtype"_a.none() = float32,
 | 
			
		||||
      "key"_a = nb::none(),
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
      nb::sig(
 | 
			
		||||
          "def multivariate_normal(mean: array, cov: array, shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        Generate jointly-normal random samples given a mean and covariance.
 | 
			
		||||
 | 
			
		||||
        The matrix ``cov`` must be positive semi-definite. The behavior is
 | 
			
		||||
        undefined if it is not.  The only supported ``dtype`` is ``float32``.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            mean (array): array of shape ``(..., n)``, the mean of the
 | 
			
		||||
              distribution.
 | 
			
		||||
            cov (array): array  of shape ``(..., n, n)``, the covariance
 | 
			
		||||
              matrix of the distribution. The batch shape ``...`` must be
 | 
			
		||||
              broadcast-compatible with that of ``mean``.
 | 
			
		||||
            shape (list(int), optional): The output shape must be
 | 
			
		||||
              broadcast-compatible with ``mean.shape[:-1]`` and ``cov.shape[:-2]``.
 | 
			
		||||
              If empty, the result shape is determined by broadcasting the batch
 | 
			
		||||
              shapes of ``mean`` and ``cov``. Default: ``[]``.
 | 
			
		||||
            dtype (Dtype, optional): The output type. Default: ``float32``.
 | 
			
		||||
            key (array, optional): A PRNG key. Default: ``None``.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            array: The output array of random values.
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "randint",
 | 
			
		||||
      [](const ScalarOrArray& low,
 | 
			
		||||
         const ScalarOrArray& high,
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user