mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Adds round op and primitive (#203)
This commit is contained in:
parent
477397bc98
commit
4d4af12c6f
@ -34,6 +34,7 @@ Array
|
|||||||
array.prod
|
array.prod
|
||||||
array.reciprocal
|
array.reciprocal
|
||||||
array.reshape
|
array.reshape
|
||||||
|
array.round
|
||||||
array.rsqrt
|
array.rsqrt
|
||||||
array.sin
|
array.sin
|
||||||
array.split
|
array.split
|
||||||
|
@ -73,6 +73,7 @@ Operations
|
|||||||
prod
|
prod
|
||||||
reciprocal
|
reciprocal
|
||||||
reshape
|
reshape
|
||||||
|
round
|
||||||
rsqrt
|
rsqrt
|
||||||
save
|
save
|
||||||
savez
|
savez
|
||||||
|
@ -47,6 +47,7 @@ DEFAULT(Pad)
|
|||||||
DEFAULT(Partition)
|
DEFAULT(Partition)
|
||||||
DEFAULT(RandomBits)
|
DEFAULT(RandomBits)
|
||||||
DEFAULT(Reshape)
|
DEFAULT(Reshape)
|
||||||
|
DEFAULT(Round)
|
||||||
DEFAULT(Scatter)
|
DEFAULT(Scatter)
|
||||||
DEFAULT(Sigmoid)
|
DEFAULT(Sigmoid)
|
||||||
DEFAULT(Sign)
|
DEFAULT(Sign)
|
||||||
|
@ -65,6 +65,7 @@ DEFAULT(Power)
|
|||||||
DEFAULT(RandomBits)
|
DEFAULT(RandomBits)
|
||||||
DEFAULT(Reduce)
|
DEFAULT(Reduce)
|
||||||
DEFAULT(Reshape)
|
DEFAULT(Reshape)
|
||||||
|
DEFAULT(Round)
|
||||||
DEFAULT(Scan)
|
DEFAULT(Scan)
|
||||||
DEFAULT(Scatter)
|
DEFAULT(Scatter)
|
||||||
DEFAULT(Sigmoid)
|
DEFAULT(Sigmoid)
|
||||||
|
@ -466,6 +466,17 @@ void Reshape::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Round::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
auto& in = inputs[0];
|
||||||
|
if (not is_integral(in.dtype())) {
|
||||||
|
unary_fp(in, out, RoundOp());
|
||||||
|
} else {
|
||||||
|
// No-op integer types
|
||||||
|
out.copy_shared_buffer(in);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void Sigmoid::eval(const std::vector<array>& inputs, array& out) {
|
void Sigmoid::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
|
@ -53,6 +53,17 @@ struct SignOp {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct RoundOp {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x) {
|
||||||
|
return std::round(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
complex64_t operator()(complex64_t x) {
|
||||||
|
return {std::round(x.real()), std::round(x.imag())};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template <typename T, typename Op>
|
template <typename T, typename Op>
|
||||||
void unary_op(const array& a, array& out, Op op) {
|
void unary_op(const array& a, array& out, Op op) {
|
||||||
const T* a_ptr = a.data<T>();
|
const T* a_ptr = a.data<T>();
|
||||||
|
@ -133,6 +133,11 @@ struct Negative {
|
|||||||
template <typename T> T operator()(T x) { return -x; };
|
template <typename T> T operator()(T x) { return -x; };
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct Round {
|
||||||
|
template <typename T> T operator()(T x) { return metal::round(x); };
|
||||||
|
template <> complex64_t operator()(complex64_t x) { return {metal::round(x.real), metal::round(x.imag)}; };
|
||||||
|
};
|
||||||
|
|
||||||
struct Sigmoid {
|
struct Sigmoid {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T operator()(T x) {
|
T operator()(T x) {
|
||||||
@ -300,6 +305,7 @@ instantiate_unary_float(sqrt, Sqrt)
|
|||||||
instantiate_unary_float(rsqrt, Rsqrt)
|
instantiate_unary_float(rsqrt, Rsqrt)
|
||||||
instantiate_unary_float(tan, Tan)
|
instantiate_unary_float(tan, Tan)
|
||||||
instantiate_unary_float(tanh, Tanh)
|
instantiate_unary_float(tanh, Tanh)
|
||||||
|
instantiate_unary_float(round, Round)
|
||||||
|
|
||||||
instantiate_unary_all(abs, complex64, complex64_t, Abs)
|
instantiate_unary_all(abs, complex64, complex64_t, Abs)
|
||||||
instantiate_unary_all(cos, complex64, complex64_t, Cos)
|
instantiate_unary_all(cos, complex64, complex64_t, Cos)
|
||||||
@ -310,5 +316,6 @@ instantiate_unary_all(sin, complex64, complex64_t, Sin)
|
|||||||
instantiate_unary_all(sinh, complex64, complex64_t, Sinh)
|
instantiate_unary_all(sinh, complex64, complex64_t, Sinh)
|
||||||
instantiate_unary_all(tan, complex64, complex64_t, Tan)
|
instantiate_unary_all(tan, complex64, complex64_t, Tan)
|
||||||
instantiate_unary_all(tanh, complex64, complex64_t, Tanh)
|
instantiate_unary_all(tanh, complex64, complex64_t, Tanh)
|
||||||
|
instantiate_unary_all(round, complex64, complex64_t, Round)
|
||||||
|
|
||||||
instantiate_unary_all(lnot, bool_, bool, LogicalNot)
|
instantiate_unary_all(lnot, bool_, bool, LogicalNot)
|
||||||
|
@ -563,6 +563,17 @@ void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
if (not is_integral(in.dtype())) {
|
||||||
|
unary_op(inputs, out, "round");
|
||||||
|
} else {
|
||||||
|
// No-op integer types
|
||||||
|
out.copy_shared_buffer(in);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void Sigmoid::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void Sigmoid::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
unary_op(inputs, out, "sigmoid");
|
unary_op(inputs, out, "sigmoid");
|
||||||
}
|
}
|
||||||
|
@ -61,6 +61,7 @@ NO_GPU(Power)
|
|||||||
NO_GPU(RandomBits)
|
NO_GPU(RandomBits)
|
||||||
NO_GPU(Reduce)
|
NO_GPU(Reduce)
|
||||||
NO_GPU(Reshape)
|
NO_GPU(Reshape)
|
||||||
|
NO_GPU(Round)
|
||||||
NO_GPU(Scan)
|
NO_GPU(Scan)
|
||||||
NO_GPU(Scatter)
|
NO_GPU(Scatter)
|
||||||
NO_GPU(Sigmoid)
|
NO_GPU(Sigmoid)
|
||||||
|
15
mlx/ops.cpp
15
mlx/ops.cpp
@ -1834,6 +1834,21 @@ array stop_gradient(const array& a, StreamOrDevice s /* = {} */) {
|
|||||||
a.shape(), a.dtype(), std::make_unique<StopGradient>(to_stream(s)), {a});
|
a.shape(), a.dtype(), std::make_unique<StopGradient>(to_stream(s)), {a});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array round(const array& a, int decimals, StreamOrDevice s /* = {} */) {
|
||||||
|
if (decimals == 0) {
|
||||||
|
return array(
|
||||||
|
a.shape(), a.dtype(), std::make_unique<Round>(to_stream(s)), {a});
|
||||||
|
}
|
||||||
|
|
||||||
|
auto dtype = at_least_float(a.dtype());
|
||||||
|
float scale = std::pow(10, decimals);
|
||||||
|
auto result = multiply(a, array(scale, dtype), s);
|
||||||
|
result = round(result, 0, s);
|
||||||
|
result = multiply(result, array(1 / scale, dtype), s);
|
||||||
|
|
||||||
|
return astype(result, a.dtype(), s);
|
||||||
|
}
|
||||||
|
|
||||||
array matmul(
|
array matmul(
|
||||||
const array& in_a,
|
const array& in_a,
|
||||||
const array& in_b,
|
const array& in_b,
|
||||||
|
@ -794,6 +794,12 @@ array erfinv(const array& a, StreamOrDevice s = {});
|
|||||||
/** Stop the flow of gradients. */
|
/** Stop the flow of gradients. */
|
||||||
array stop_gradient(const array& a, StreamOrDevice s = {});
|
array stop_gradient(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Round a floating point number */
|
||||||
|
array round(const array& a, int decimals, StreamOrDevice s = {});
|
||||||
|
inline array round(const array& a, StreamOrDevice s = {}) {
|
||||||
|
return round(a, 0, s);
|
||||||
|
}
|
||||||
|
|
||||||
/** Matrix-matrix multiplication. */
|
/** Matrix-matrix multiplication. */
|
||||||
array matmul(const array& a, const array& b, StreamOrDevice s = {});
|
array matmul(const array& a, const array& b, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
@ -1888,6 +1888,30 @@ bool Reduce::is_equivalent(const Primitive& other) const {
|
|||||||
return reduce_type_ == r_other.reduce_type_ && axes_ == r_other.axes_;
|
return reduce_type_ == r_other.reduce_type_ && axes_ == r_other.axes_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<array> Round::vjp(
|
||||||
|
const std::vector<array>& primals,
|
||||||
|
const array& cotan,
|
||||||
|
const std::vector<int>& argnums) {
|
||||||
|
return {jvp(primals, {cotan}, argnums)};
|
||||||
|
}
|
||||||
|
|
||||||
|
array Round::jvp(
|
||||||
|
const std::vector<array>& primals,
|
||||||
|
const std::vector<array>& tangents,
|
||||||
|
const std::vector<int>& argnums) {
|
||||||
|
assert(primals.size() == 1);
|
||||||
|
assert(argnums.size() == 1);
|
||||||
|
return zeros_like(primals[0], stream());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<array, int> Round::vmap(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<int>& axes) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
assert(axes.size() == 1);
|
||||||
|
return {round(inputs[0], stream()), axes[0]};
|
||||||
|
}
|
||||||
|
|
||||||
std::pair<array, int> Scan::vmap(
|
std::pair<array, int> Scan::vmap(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const std::vector<int>& axes) {
|
const std::vector<int>& axes) {
|
||||||
|
@ -1206,6 +1206,25 @@ class Reduce : public Primitive {
|
|||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class Round : public Primitive {
|
||||||
|
public:
|
||||||
|
explicit Round(Stream stream) : Primitive(stream){};
|
||||||
|
|
||||||
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
|
|
||||||
|
std::pair<array, int> vmap(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<int>& axes) override;
|
||||||
|
|
||||||
|
DEFINE_GRADS()
|
||||||
|
DEFINE_PRINT(Round)
|
||||||
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
|
||||||
|
private:
|
||||||
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
|
};
|
||||||
|
|
||||||
class Scan : public Primitive {
|
class Scan : public Primitive {
|
||||||
public:
|
public:
|
||||||
enum ReduceType { Max, Min, Sum, Prod };
|
enum ReduceType { Max, Min, Sum, Prod };
|
||||||
|
@ -1148,5 +1148,15 @@ void init_array(py::module_& m) {
|
|||||||
"reverse"_a = false,
|
"reverse"_a = false,
|
||||||
"inclusive"_a = true,
|
"inclusive"_a = true,
|
||||||
"stream"_a = none,
|
"stream"_a = none,
|
||||||
"See :func:`cummin`.");
|
"See :func:`cummin`.")
|
||||||
|
.def(
|
||||||
|
"round",
|
||||||
|
[](const array& a, int decimals, StreamOrDevice s) {
|
||||||
|
return round(a, decimals, s);
|
||||||
|
},
|
||||||
|
py::pos_only(),
|
||||||
|
"decimals"_a = 0,
|
||||||
|
py::kw_only(),
|
||||||
|
"stream"_a = none,
|
||||||
|
"See :func:`round`.");
|
||||||
}
|
}
|
||||||
|
@ -2922,4 +2922,33 @@ void init_ops(py::module_& m) {
|
|||||||
Returns:
|
Returns:
|
||||||
result (array): The output containing elements selected from ``x`` and ``y``.
|
result (array): The output containing elements selected from ``x`` and ``y``.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"round",
|
||||||
|
[](const array& a, int decimals, StreamOrDevice s) {
|
||||||
|
return round(a, decimals, s);
|
||||||
|
},
|
||||||
|
"a"_a,
|
||||||
|
py::pos_only(),
|
||||||
|
"decimals"_a = 0,
|
||||||
|
py::kw_only(),
|
||||||
|
"stream"_a = none,
|
||||||
|
R"pbdoc(
|
||||||
|
round(a: array, /, decimals: int = 0, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
|
Round to the given number of decimals.
|
||||||
|
|
||||||
|
Bascially performs:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
s = 10**decimals
|
||||||
|
x = round(x * s) / s
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a (array): Input array
|
||||||
|
decimals (int): Number of decimal places to round to. (default: 0)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
result (array): An array of the same type as ``a`` rounded to the given number of decimals.
|
||||||
|
)pbdoc");
|
||||||
}
|
}
|
||||||
|
@ -372,7 +372,35 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
self.assertListEqual(mx.ceil(x).tolist(), expected)
|
self.assertListEqual(mx.ceil(x).tolist(), expected)
|
||||||
|
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
mx.floor(mx.array([22 + 3j, 19 + 98j]))
|
mx.ceil(mx.array([22 + 3j, 19 + 98j]))
|
||||||
|
|
||||||
|
def test_round(self):
|
||||||
|
# float
|
||||||
|
x = mx.array(
|
||||||
|
[0.5, -0.5, 1.5, -1.5, -22.03, 19.98, -27, 9, 0.0, -np.inf, np.inf]
|
||||||
|
)
|
||||||
|
expected = [1, -1, 2, -2, -22, 20, -27, 9, 0, -np.inf, np.inf]
|
||||||
|
self.assertListEqual(mx.round(x).tolist(), expected)
|
||||||
|
|
||||||
|
# complex
|
||||||
|
y = mx.round(mx.array([22.2 + 3.6j, 19.5 + 98.2j]))
|
||||||
|
self.assertListEqual(y.tolist(), [22 + 4j, 20 + 98j])
|
||||||
|
|
||||||
|
# decimals
|
||||||
|
y0 = mx.round(mx.array([15, 122], mx.int32), decimals=0)
|
||||||
|
y1 = mx.round(mx.array([15, 122], mx.int32), decimals=-1)
|
||||||
|
y2 = mx.round(mx.array([15, 122], mx.int32), decimals=-2)
|
||||||
|
self.assertEqual(y0.dtype, mx.int32)
|
||||||
|
self.assertEqual(y1.dtype, mx.int32)
|
||||||
|
self.assertEqual(y2.dtype, mx.int32)
|
||||||
|
self.assertListEqual(y0.tolist(), [15, 122])
|
||||||
|
self.assertListEqual(y1.tolist(), [20, 120])
|
||||||
|
self.assertListEqual(y2.tolist(), [0, 100])
|
||||||
|
|
||||||
|
y1 = mx.round(mx.array([1.537, 1.471], mx.float32), decimals=1)
|
||||||
|
y2 = mx.round(mx.array([1.537, 1.471], mx.float32), decimals=2)
|
||||||
|
self.assertTrue(mx.allclose(y1, mx.array([1.5, 1.5])))
|
||||||
|
self.assertTrue(mx.allclose(y2, mx.array([1.54, 1.47])))
|
||||||
|
|
||||||
def test_transpose_noargs(self):
|
def test_transpose_noargs(self):
|
||||||
x = mx.array([[0, 1, 1], [1, 0, 0]])
|
x = mx.array([[0, 1, 1], [1, 0, 0]])
|
||||||
|
@ -862,6 +862,15 @@ TEST_CASE("test arithmetic unary ops") {
|
|||||||
CHECK_THROWS_AS(ceil(x), std::invalid_argument);
|
CHECK_THROWS_AS(ceil(x), std::invalid_argument);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test round
|
||||||
|
{
|
||||||
|
array x({0.5, -0.5, 1.5, -1.5, 2.3, 2.6});
|
||||||
|
CHECK(array_equal(round(x), array({1, -1, 2, -2, 2, 3})).item<bool>());
|
||||||
|
|
||||||
|
x = array({11, 222, 32});
|
||||||
|
CHECK(array_equal(round(x, -1), array({10, 220, 30})).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
// Test exponential
|
// Test exponential
|
||||||
{
|
{
|
||||||
array x(0.0);
|
array x(0.0);
|
||||||
|
Loading…
Reference in New Issue
Block a user