mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 02:33:21 +08:00
fix clip (#1415)
This commit is contained in:
parent
b3f52c9fbe
commit
d6492b0163
@ -838,7 +838,7 @@ array clip(
|
|||||||
if (!a_min.has_value() && !a_max.has_value()) {
|
if (!a_min.has_value() && !a_max.has_value()) {
|
||||||
throw std::invalid_argument("At most one of a_min and a_max may be None");
|
throw std::invalid_argument("At most one of a_min and a_max may be None");
|
||||||
}
|
}
|
||||||
array result = astype(a, a.dtype(), s);
|
array result = a;
|
||||||
if (a_min.has_value()) {
|
if (a_min.has_value()) {
|
||||||
result = maximum(result, a_min.value(), s);
|
result = maximum(result, a_min.value(), s);
|
||||||
}
|
}
|
||||||
|
@ -2827,10 +2827,10 @@ void init_ops(nb::module_& m) {
|
|||||||
std::optional<array> min_ = std::nullopt;
|
std::optional<array> min_ = std::nullopt;
|
||||||
std::optional<array> max_ = std::nullopt;
|
std::optional<array> max_ = std::nullopt;
|
||||||
if (min) {
|
if (min) {
|
||||||
min_ = to_array(min.value());
|
min_ = to_arrays(a, min.value()).second;
|
||||||
}
|
}
|
||||||
if (max) {
|
if (max) {
|
||||||
max_ = to_array(max.value());
|
max_ = to_arrays(a, max.value()).second;
|
||||||
}
|
}
|
||||||
return clip(a, min_, max_, s);
|
return clip(a, min_, max_, s);
|
||||||
},
|
},
|
||||||
|
@ -2008,6 +2008,22 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
clipped = mx.clip(mx.array(a), mx.array(mins), mx.array(maxs))
|
clipped = mx.clip(mx.array(a), mx.array(mins), mx.array(maxs))
|
||||||
self.assertTrue(np.array_equal(clipped, expected))
|
self.assertTrue(np.array_equal(clipped, expected))
|
||||||
|
|
||||||
|
# Check clip output types
|
||||||
|
a = mx.array([1, 2, 3], mx.int16)
|
||||||
|
out_t = mx.clip(a, a_min=0, a_max=5).dtype
|
||||||
|
self.assertEqual(out_t, mx.int16)
|
||||||
|
|
||||||
|
out_t = mx.clip(a, a_min=0.0, a_max=5).dtype
|
||||||
|
self.assertEqual(out_t, mx.float32)
|
||||||
|
|
||||||
|
a = mx.array([1, 2, 3], mx.float16)
|
||||||
|
out_t = mx.clip(a, a_min=0.0, a_max=5).dtype
|
||||||
|
self.assertEqual(out_t, mx.float16)
|
||||||
|
|
||||||
|
a = mx.array([1, 2, 3], mx.float16)
|
||||||
|
out_t = mx.clip(a, a_min=0.0, a_max=mx.array(1.0)).dtype
|
||||||
|
self.assertEqual(out_t, mx.float32)
|
||||||
|
|
||||||
def test_linspace(self):
|
def test_linspace(self):
|
||||||
# Test default num = 50
|
# Test default num = 50
|
||||||
a = mx.linspace(0, 1)
|
a = mx.linspace(0, 1)
|
||||||
|
Loading…
Reference in New Issue
Block a user