mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Adds round op and primitive (#203)
This commit is contained in:
committed by
GitHub
parent
477397bc98
commit
4d4af12c6f
@@ -47,6 +47,7 @@ DEFAULT(Pad)
|
||||
DEFAULT(Partition)
|
||||
DEFAULT(RandomBits)
|
||||
DEFAULT(Reshape)
|
||||
DEFAULT(Round)
|
||||
DEFAULT(Scatter)
|
||||
DEFAULT(Sigmoid)
|
||||
DEFAULT(Sign)
|
||||
|
||||
@@ -65,6 +65,7 @@ DEFAULT(Power)
|
||||
DEFAULT(RandomBits)
|
||||
DEFAULT(Reduce)
|
||||
DEFAULT(Reshape)
|
||||
DEFAULT(Round)
|
||||
DEFAULT(Scan)
|
||||
DEFAULT(Scatter)
|
||||
DEFAULT(Sigmoid)
|
||||
|
||||
@@ -466,6 +466,17 @@ void Reshape::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Round::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (not is_integral(in.dtype())) {
|
||||
unary_fp(in, out, RoundOp());
|
||||
} else {
|
||||
// No-op integer types
|
||||
out.copy_shared_buffer(in);
|
||||
}
|
||||
}
|
||||
|
||||
void Sigmoid::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
|
||||
@@ -53,6 +53,17 @@ struct SignOp {
|
||||
}
|
||||
};
|
||||
|
||||
struct RoundOp {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::round(x);
|
||||
}
|
||||
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return {std::round(x.real()), std::round(x.imag())};
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename Op>
|
||||
void unary_op(const array& a, array& out, Op op) {
|
||||
const T* a_ptr = a.data<T>();
|
||||
|
||||
@@ -133,6 +133,11 @@ struct Negative {
|
||||
template <typename T> T operator()(T x) { return -x; };
|
||||
};
|
||||
|
||||
struct Round {
|
||||
template <typename T> T operator()(T x) { return metal::round(x); };
|
||||
template <> complex64_t operator()(complex64_t x) { return {metal::round(x.real), metal::round(x.imag)}; };
|
||||
};
|
||||
|
||||
struct Sigmoid {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
@@ -300,6 +305,7 @@ instantiate_unary_float(sqrt, Sqrt)
|
||||
instantiate_unary_float(rsqrt, Rsqrt)
|
||||
instantiate_unary_float(tan, Tan)
|
||||
instantiate_unary_float(tanh, Tanh)
|
||||
instantiate_unary_float(round, Round)
|
||||
|
||||
instantiate_unary_all(abs, complex64, complex64_t, Abs)
|
||||
instantiate_unary_all(cos, complex64, complex64_t, Cos)
|
||||
@@ -310,5 +316,6 @@ instantiate_unary_all(sin, complex64, complex64_t, Sin)
|
||||
instantiate_unary_all(sinh, complex64, complex64_t, Sinh)
|
||||
instantiate_unary_all(tan, complex64, complex64_t, Tan)
|
||||
instantiate_unary_all(tanh, complex64, complex64_t, Tanh)
|
||||
instantiate_unary_all(round, complex64, complex64_t, Round)
|
||||
|
||||
instantiate_unary_all(lnot, bool_, bool, LogicalNot)
|
||||
|
||||
@@ -563,6 +563,17 @@ void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (not is_integral(in.dtype())) {
|
||||
unary_op(inputs, out, "round");
|
||||
} else {
|
||||
// No-op integer types
|
||||
out.copy_shared_buffer(in);
|
||||
}
|
||||
}
|
||||
|
||||
void Sigmoid::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "sigmoid");
|
||||
}
|
||||
|
||||
@@ -61,6 +61,7 @@ NO_GPU(Power)
|
||||
NO_GPU(RandomBits)
|
||||
NO_GPU(Reduce)
|
||||
NO_GPU(Reshape)
|
||||
NO_GPU(Round)
|
||||
NO_GPU(Scan)
|
||||
NO_GPU(Scatter)
|
||||
NO_GPU(Sigmoid)
|
||||
|
||||
15
mlx/ops.cpp
15
mlx/ops.cpp
@@ -1834,6 +1834,21 @@ array stop_gradient(const array& a, StreamOrDevice s /* = {} */) {
|
||||
a.shape(), a.dtype(), std::make_unique<StopGradient>(to_stream(s)), {a});
|
||||
}
|
||||
|
||||
array round(const array& a, int decimals, StreamOrDevice s /* = {} */) {
|
||||
if (decimals == 0) {
|
||||
return array(
|
||||
a.shape(), a.dtype(), std::make_unique<Round>(to_stream(s)), {a});
|
||||
}
|
||||
|
||||
auto dtype = at_least_float(a.dtype());
|
||||
float scale = std::pow(10, decimals);
|
||||
auto result = multiply(a, array(scale, dtype), s);
|
||||
result = round(result, 0, s);
|
||||
result = multiply(result, array(1 / scale, dtype), s);
|
||||
|
||||
return astype(result, a.dtype(), s);
|
||||
}
|
||||
|
||||
array matmul(
|
||||
const array& in_a,
|
||||
const array& in_b,
|
||||
|
||||
@@ -794,6 +794,12 @@ array erfinv(const array& a, StreamOrDevice s = {});
|
||||
/** Stop the flow of gradients. */
|
||||
array stop_gradient(const array& a, StreamOrDevice s = {});
|
||||
|
||||
/** Round a floating point number */
|
||||
array round(const array& a, int decimals, StreamOrDevice s = {});
|
||||
inline array round(const array& a, StreamOrDevice s = {}) {
|
||||
return round(a, 0, s);
|
||||
}
|
||||
|
||||
/** Matrix-matrix multiplication. */
|
||||
array matmul(const array& a, const array& b, StreamOrDevice s = {});
|
||||
|
||||
|
||||
@@ -1888,6 +1888,30 @@ bool Reduce::is_equivalent(const Primitive& other) const {
|
||||
return reduce_type_ == r_other.reduce_type_ && axes_ == r_other.axes_;
|
||||
}
|
||||
|
||||
std::vector<array> Round::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const array& cotan,
|
||||
const std::vector<int>& argnums) {
|
||||
return {jvp(primals, {cotan}, argnums)};
|
||||
}
|
||||
|
||||
array Round::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
return zeros_like(primals[0], stream());
|
||||
}
|
||||
|
||||
std::pair<array, int> Round::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
assert(inputs.size() == 1);
|
||||
assert(axes.size() == 1);
|
||||
return {round(inputs[0], stream()), axes[0]};
|
||||
}
|
||||
|
||||
std::pair<array, int> Scan::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
|
||||
@@ -1206,6 +1206,25 @@ class Reduce : public Primitive {
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Round : public Primitive {
|
||||
public:
|
||||
explicit Round(Stream stream) : Primitive(stream){};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
std::pair<array, int> vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) override;
|
||||
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Round)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Scan : public Primitive {
|
||||
public:
|
||||
enum ReduceType { Max, Min, Sum, Prod };
|
||||
|
||||
Reference in New Issue
Block a user