Mirror of https://github.com/ml-explore/mlx.git
make behaviour of dtype arguments consistent and compliant to numpy (#379)
All functions that take an optional dtype should:

* have a default dtype visible in the generated docs (accomplished via `"dtype"_a = std::optional{float32}`)
* behave identically when `dtype=None` is passed and when no dtype is passed at all

This is important when passing keyword arguments down from a NumPy-style function like:

```
def f(x, dtype=None):
    mx.random.uniform(dtype=dtype)
    # ...
```

NumPy functions behave like this. It also fixes a minor bug in `tri`: #378

Closes #378
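For illustration, here is a minimal, self-contained sketch of the pybind11 pattern this change adopts. It is not the actual MLX binding code: the `Dtype` enum, the `make_array` op, and the `example` module name below are toy stand-ins. The idea is that the lambda takes `std::optional<Dtype>`, resolves `None` via `value_or(float32)`, and the default argument is declared as `std::optional{float32}` so the generated signature still shows `float32` while `None` stays a valid value.

```cpp
// Hypothetical minimal sketch of the dtype-argument pattern in this commit.
// Dtype and make_array are stand-ins, not MLX's real types.
#include <optional>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h> // maps std::optional<T> to Optional[T] / None

namespace py = pybind11;
using namespace py::literals;

enum class Dtype { float32, bfloat16 };
static constexpr Dtype float32 = Dtype::float32;

// Placeholder for an op that needs a concrete dtype.
int make_array(int n, Dtype /*dtype*/) { return n; }

PYBIND11_MODULE(example, m) {
  py::enum_<Dtype>(m, "Dtype")
      .value("float32", Dtype::float32)
      .value("bfloat16", Dtype::bfloat16);

  m.def(
      "make_array",
      [](int n, std::optional<Dtype> dtype) {
        // dtype=None and an omitted dtype both resolve to float32 here,
        // so kwargs can be forwarded NumPy-style without special-casing None.
        return make_array(n, dtype.value_or(float32));
      },
      "n"_a,
      // The default is visible in the generated docs as float32,
      // but None is still accepted.
      "dtype"_a = std::optional{float32});
}
```

A Python caller can then write `example.make_array(3)` or `example.make_array(3, dtype=None)` and get the same result, which mirrors the behaviour the tests below check for `mx.random.*` and `mx.tri`.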
This commit is contained in:
  parent: d8f41a5c0f
  commit: dfdb284e16
@@ -1212,14 +1212,22 @@ void init_ops(py::module_& m) {
       )pbdoc");
   m.def(
       "linspace",
-      [](Scalar start, Scalar stop, int num, Dtype dtype, StreamOrDevice s) {
+      [](Scalar start,
+         Scalar stop,
+         int num,
+         std::optional<Dtype> dtype,
+         StreamOrDevice s) {
         return linspace(
-            scalar_to_double(start), scalar_to_double(stop), num, dtype, s);
+            scalar_to_double(start),
+            scalar_to_double(stop),
+            num,
+            dtype.value_or(float32),
+            s);
       },
       "start"_a,
       "stop"_a,
       "num"_a = 50,
-      "dtype"_a = float32,
+      "dtype"_a = std::optional{float32},
       "stream"_a = none,
       R"pbdoc(
         linspace(start, stop, num: Optional[int] = 50, dtype: Optional[Dtype] = float32, stream: Union[None, Stream, Device] = None) -> array
@@ -1356,11 +1364,11 @@ void init_ops(py::module_& m) {
         }
       },
       "shape"_a,
-      "dtype"_a = std::nullopt,
+      "dtype"_a = std::optional{float32},
       py::kw_only(),
       "stream"_a = none,
       R"pbdoc(
-        zeros(shape: Union[int, List[int]], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array
+        zeros(shape: Union[int, List[int]], dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array
 
         Construct an array of zeros.
 
@@ -1403,11 +1411,11 @@ void init_ops(py::module_& m) {
         }
       },
       "shape"_a,
-      "dtype"_a = std::nullopt,
+      "dtype"_a = std::optional{float32},
       py::kw_only(),
       "stream"_a = none,
       R"pbdoc(
-        ones(shape: Union[int, List[int]], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array
+        ones(shape: Union[int, List[int]], dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array
 
         Construct an array of ones.
 
@@ -1449,11 +1457,11 @@ void init_ops(py::module_& m) {
       "n"_a,
       "m"_a = py::none(),
       "k"_a = 0,
-      "dtype"_a = std::nullopt,
+      "dtype"_a = std::optional{float32},
       py::kw_only(),
       "stream"_a = none,
       R"pbdoc(
-        eye(n: int, m: Optional[int] = None, k: int = 0, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array
+        eye(n: int, m: Optional[int] = None, k: int = 0, dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array
 
         Create an identity matrix or a general diagonal matrix.
 
@@ -1473,11 +1481,11 @@ void init_ops(py::module_& m) {
         return identity(n, dtype.value_or(float32), s);
       },
       "n"_a,
-      "dtype"_a = std::nullopt,
+      "dtype"_a = std::optional{float32},
       py::kw_only(),
       "stream"_a = none,
       R"pbdoc(
-        identity(n: int, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array
+        identity(n: int, dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array
 
         Create a square identity matrix.
 
@@ -1491,13 +1499,17 @@ void init_ops(py::module_& m) {
       )pbdoc");
   m.def(
       "tri",
-      [](int n, std::optional<int> m, int k, Dtype dtype, StreamOrDevice s) {
-        return tri(n, m.value_or(n), k, float32, s);
+      [](int n,
+         std::optional<int> m,
+         int k,
+         std::optional<Dtype> type,
+         StreamOrDevice s) {
+        return tri(n, m.value_or(n), k, type.value_or(float32), s);
       },
       "n"_a,
       "m"_a = none,
       "k"_a = 0,
-      "dtype"_a = float32,
+      "dtype"_a = std::optional{float32},
       py::kw_only(),
       "stream"_a = none,
       R"pbdoc(
@@ -61,15 +61,21 @@ void init_random(py::module_& parent_module) {
       [](const ScalarOrArray& low,
          const ScalarOrArray& high,
          const std::vector<int>& shape,
-         Dtype type,
+         std::optional<Dtype> type,
          const std::optional<array>& key,
          StreamOrDevice s) {
-        return uniform(to_array(low), to_array(high), shape, type, key, s);
+        return uniform(
+            to_array(low),
+            to_array(high),
+            shape,
+            type.value_or(float32),
+            key,
+            s);
       },
       "low"_a = 0,
       "high"_a = 1,
       "shape"_a = std::vector<int>{},
-      "dtype"_a = float32,
+      "dtype"_a = std::optional{float32},
       "key"_a = none,
       "stream"_a = none,
       R"pbdoc(
@@ -92,12 +98,14 @@ void init_random(py::module_& parent_module) {
   m.def(
       "normal",
       [](const std::vector<int>& shape,
-         Dtype type,
+         std::optional<Dtype> type,
          const std::optional<array>& key,
-         StreamOrDevice s) { return normal(shape, type, key, s); },
+         StreamOrDevice s) {
+        return normal(shape, type.value_or(float32), key, s);
+      },
 
       "shape"_a = std::vector<int>{},
-      "dtype"_a = float32,
+      "dtype"_a = std::optional{float32},
       "key"_a = none,
       "stream"_a = none,
       R"pbdoc(
@@ -116,10 +124,11 @@ void init_random(py::module_& parent_module) {
       [](const ScalarOrArray& low,
          const ScalarOrArray& high,
          const std::vector<int>& shape,
-         Dtype type,
+         std::optional<Dtype> type,
          const std::optional<array>& key,
          StreamOrDevice s) {
-        return randint(to_array(low), to_array(high), shape, type, key, s);
+        return randint(
+            to_array(low), to_array(high), shape, type.value_or(int32), key, s);
       },
       "low"_a,
       "high"_a,
@@ -183,21 +192,22 @@ void init_random(py::module_& parent_module) {
       [](const ScalarOrArray& lower_,
          const ScalarOrArray& upper_,
          const std::optional<std::vector<int>> shape_,
-         Dtype dtype,
+         std::optional<Dtype> type,
          const std::optional<array>& key,
          StreamOrDevice s) {
         auto lower = to_array(lower_);
         auto upper = to_array(upper_);
+        auto t = type.value_or(float32);
         if (shape_.has_value()) {
-          return truncated_normal(lower, upper, shape_.value(), dtype, key, s);
+          return truncated_normal(lower, upper, shape_.value(), t, key, s);
         } else {
-          return truncated_normal(lower, upper, dtype, key, s);
+          return truncated_normal(lower, upper, t, key, s);
         }
       },
       "lower"_a,
       "upper"_a,
       "shape"_a = none,
-      "dtype"_a = float32,
+      "dtype"_a = std::optional{float32},
       "key"_a = none,
       "stream"_a = none,
       R"pbdoc(
@@ -221,9 +231,14 @@ void init_random(py::module_& parent_module) {
       )pbdoc");
   m.def(
       "gumbel",
-      &gumbel,
+      [](const std::vector<int>& shape,
+         std::optional<Dtype> type,
+         const std::optional<array>& key,
+         StreamOrDevice s) {
+        return gumbel(shape, type.value_or(float32), key, s);
+      },
       "shape"_a = std::vector<int>{},
-      "dtype"_a = float32,
+      "dtype"_a = std::optional{float32},
       "stream"_a = none,
       "key"_a = none,
       R"pbdoc(
@@ -325,6 +325,8 @@ class TestOps(mlx_tests.MLXTestCase):
         for shape in [[4], [4, 4], [2, 10]]:
             for diag in [-1, 0, 1, -2]:
                 self.assertCmpNumpy(shape, mx.tri, np.tri, k=diag)
+        self.assertEqual(mx.tri(1, 1).dtype, mx.float32)
+        self.assertEqual(mx.tri(1, 1, dtype=mx.bfloat16).dtype, mx.bfloat16)
 
     def test_tril(self):
         for diag in [-1, 0, 1, -2]:
@@ -61,6 +61,8 @@ class TestRandom(mlx_tests.MLXTestCase):
         a = mx.random.uniform(low=-0.1, high=0.1, shape=(1,), dtype=mx.bfloat16)
         self.assertEqual(a.dtype, mx.bfloat16)
 
+        self.assertEqual(mx.random.uniform().dtype, mx.random.uniform(dtype=None).dtype)
+
     def test_normal(self):
         key = mx.random.key(0)
         a = mx.random.normal(key=key)
@@ -78,6 +80,8 @@ class TestRandom(mlx_tests.MLXTestCase):
             a = mx.random.normal(dtype=t)
             self.assertEqual(a.dtype, t)
 
+        self.assertEqual(mx.random.normal().dtype, mx.random.normal(dtype=None).dtype)
+
     def test_randint(self):
         a = mx.random.randint(0, 1, [])
         self.assertEqual(a.shape, [])
@@ -109,6 +113,10 @@ class TestRandom(mlx_tests.MLXTestCase):
         a = mx.random.randint(10, -10, [1000, 1000])
         self.assertTrue(mx.all(a == 10).item())
 
+        self.assertEqual(
+            mx.random.randint(0, 1).dtype, mx.random.randint(0, 1, dtype=None).dtype
+        )
+
     def test_bernoulli(self):
         a = mx.random.bernoulli()
         self.assertEqual(a.shape, [])
@@ -159,6 +167,11 @@ class TestRandom(mlx_tests.MLXTestCase):
         with self.assertRaises(ValueError):
             mx.random.truncated_normal(lower, higher)  # Bad shape
 
+        self.assertEqual(
+            mx.random.truncated_normal(0, 1).dtype,
+            mx.random.truncated_normal(0, 1, dtype=None).dtype,
+        )
+
     def test_gumbel(self):
         samples = mx.random.gumbel(shape=(100, 100))
         self.assertEqual(samples.shape, [100, 100])
@@ -168,6 +181,10 @@ class TestRandom(mlx_tests.MLXTestCase):
         # so this test is pretty conservative
         self.assertTrue(mx.abs(mx.mean(samples) - mean) < 0.2)
 
+        self.assertEqual(
+            mx.random.gumbel((1, 1)).dtype, mx.random.gumbel((1, 1), dtype=None).dtype
+        )
+
     def test_categorical(self):
         logits = mx.zeros((10, 20))
         self.assertEqual(mx.random.categorical(logits, -1).shape, [10])