mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
parent
1fa0d20a30
commit
e1c9600da3
@ -45,3 +45,4 @@ we use a splittable version of Threefry, which is a counter-based PRNG.
|
|||||||
truncated_normal
|
truncated_normal
|
||||||
uniform
|
uniform
|
||||||
laplace
|
laplace
|
||||||
|
permutation
|
||||||
|
@ -458,4 +458,19 @@ array laplace(
|
|||||||
return samples;
|
return samples;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array permutation(
|
||||||
|
const array& x,
|
||||||
|
int axis /* = 0 */,
|
||||||
|
const std::optional<array>& key /* = std::nullopt */,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
return take(x, permutation(x.shape(axis), key, s), axis, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
array permutation(
|
||||||
|
int x,
|
||||||
|
const std::optional<array>& key /* = std::nullopt */,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
return argsort(bits({x}, key, s), s);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core::random
|
} // namespace mlx::core::random
|
||||||
|
13
mlx/random.h
13
mlx/random.h
@ -254,4 +254,17 @@ inline array laplace(
|
|||||||
return laplace(shape, float32, 0.0, 1.0, key, s);
|
return laplace(shape, float32, 0.0, 1.0, key, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* Randomly permute the elements of x along the given axis. */
|
||||||
|
array permutation(
|
||||||
|
const array& x,
|
||||||
|
int axis = 0,
|
||||||
|
const std::optional<array>& key = std::nullopt,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/* A random permutation of `arange(x)` */
|
||||||
|
array permutation(
|
||||||
|
int x,
|
||||||
|
const std::optional<array>& key = std::nullopt,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
} // namespace mlx::core::random
|
} // namespace mlx::core::random
|
||||||
|
@ -454,6 +454,39 @@ void init_random(nb::module_& parent_module) {
|
|||||||
Returns:
|
Returns:
|
||||||
array: The output array of random values.
|
array: The output array of random values.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"permuation",
|
||||||
|
[](const std::variant<int, array>& x,
|
||||||
|
int axis,
|
||||||
|
const std::optional<array>& key_,
|
||||||
|
StreamOrDevice s) {
|
||||||
|
auto key = key_ ? key_.value() : default_key().next();
|
||||||
|
if (auto pv = std::get_if<int>(&x); pv) {
|
||||||
|
return permutation(*pv, key, s);
|
||||||
|
} else {
|
||||||
|
return permutation(std::get<array>(x), axis, key, s);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"shape"_a = std::vector<int>{},
|
||||||
|
"axis"_a = 0,
|
||||||
|
"key"_a = nb::none(),
|
||||||
|
"stream"_a = nb::none(),
|
||||||
|
nb::sig(
|
||||||
|
"def permutation(x: Union[int, array], axis: int = 0, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||||
|
R"pbdoc(
|
||||||
|
Generate a random permutation or permute the entries of an array.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (int or array, optional): If an integer is provided a random
|
||||||
|
permtuation of ``mx.arange(x)`` is returned. Otherwise the entries
|
||||||
|
of ``x`` along the given axis are randomly permuted.
|
||||||
|
axis (int, optional): The axis to permute along. Default: ``0``.
|
||||||
|
key (array, optional): A PRNG key. Default: ``None``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array:
|
||||||
|
The generated random permutation or randomly permuted input array.
|
||||||
|
)pbdoc");
|
||||||
// Register static Python object cleanup before the interpreter exits
|
// Register static Python object cleanup before the interpreter exits
|
||||||
auto atexit = nb::module_::import_("atexit");
|
auto atexit = nb::module_::import_("atexit");
|
||||||
atexit.attr("register")(nb::cpp_function([]() { default_key().release(); }));
|
atexit.attr("register")(nb::cpp_function([]() { default_key().release(); }));
|
||||||
|
@ -325,6 +325,29 @@ class TestRandom(mlx_tests.MLXTestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
mx.random.categorical(logits, shape=[10, 5], num_samples=5)
|
mx.random.categorical(logits, shape=[10, 5], num_samples=5)
|
||||||
|
|
||||||
|
def test_permutation(self):
|
||||||
|
x = sorted(mx.random.permutation(4).tolist())
|
||||||
|
self.assertEqual([0, 1, 2, 3], x)
|
||||||
|
|
||||||
|
x = mx.array([0, 1, 2, 3])
|
||||||
|
x = sorted(mx.random.permutation(x).tolist())
|
||||||
|
self.assertEqual([0, 1, 2, 3], x)
|
||||||
|
|
||||||
|
x = mx.array([0, 1, 2, 3])
|
||||||
|
x = sorted(mx.random.permutation(x).tolist())
|
||||||
|
|
||||||
|
# 2-D
|
||||||
|
x = mx.arange(16).reshape(4, 4)
|
||||||
|
out = mx.sort(mx.random.permutation(x, axis=0), axis=0)
|
||||||
|
self.assertTrue(mx.array_equal(x, out))
|
||||||
|
out = mx.sort(mx.random.permutation(x, axis=1), axis=1)
|
||||||
|
self.assertTrue(mx.array_equal(x, out))
|
||||||
|
|
||||||
|
# Basically 0 probability this should fail.
|
||||||
|
sorted_x = mx.arange(16384)
|
||||||
|
x = mx.random.permutation(16384)
|
||||||
|
self.assertFalse(mx.array_equal(sorted_x, x))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user