mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-17 06:21:12 +08:00
parent
1e0c78b970
commit
b93c4cf378
@ -26,6 +26,7 @@ Operations
|
||||
argsort
|
||||
array_equal
|
||||
broadcast_to
|
||||
ceil
|
||||
concatenate
|
||||
convolve
|
||||
conv1d
|
||||
@ -39,6 +40,7 @@ Operations
|
||||
exp
|
||||
expand_dims
|
||||
eye
|
||||
floor
|
||||
full
|
||||
greater
|
||||
greater_equal
|
||||
|
@ -26,12 +26,14 @@ DEFAULT(ArgReduce)
|
||||
DEFAULT(ArgSort)
|
||||
DEFAULT(AsStrided)
|
||||
DEFAULT(Broadcast)
|
||||
DEFAULT(Ceil)
|
||||
DEFAULT(Concatenate)
|
||||
DEFAULT(Copy)
|
||||
DEFAULT(Equal)
|
||||
DEFAULT(Erf)
|
||||
DEFAULT(ErfInv)
|
||||
DEFAULT(FFT)
|
||||
DEFAULT(Floor)
|
||||
DEFAULT(Gather)
|
||||
DEFAULT(Greater)
|
||||
DEFAULT(GreaterEqual)
|
||||
|
@ -29,6 +29,7 @@ DEFAULT(ArgSort)
|
||||
DEFAULT(AsType)
|
||||
DEFAULT(AsStrided)
|
||||
DEFAULT(Broadcast)
|
||||
DEFAULT(Ceil)
|
||||
DEFAULT(Concatenate)
|
||||
DEFAULT(Convolution)
|
||||
DEFAULT(Copy)
|
||||
@ -41,6 +42,7 @@ DEFAULT(Erf)
|
||||
DEFAULT(ErfInv)
|
||||
DEFAULT(Exp)
|
||||
DEFAULT(FFT)
|
||||
DEFAULT(Floor)
|
||||
DEFAULT(Full)
|
||||
DEFAULT(Gather)
|
||||
DEFAULT(Greater)
|
||||
|
@ -167,6 +167,17 @@ void Broadcast::eval(const std::vector<array>& inputs, array& out) {
|
||||
out.copy_shared_buffer(in, strides, flags, in.data_size());
|
||||
}
|
||||
|
||||
void Ceil::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (not is_integral(in.dtype())) {
|
||||
unary_fp(in, out, [](auto x) { return std::ceil(x); });
|
||||
} else {
|
||||
// No-op integer types
|
||||
out.copy_shared_buffer(in);
|
||||
}
|
||||
}
|
||||
|
||||
void Concatenate::eval(const std::vector<array>& inputs, array& out) {
|
||||
std::vector<int> sizes;
|
||||
sizes.push_back(0);
|
||||
@ -287,6 +298,17 @@ void Exp::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Floor::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (not is_integral(in.dtype())) {
|
||||
unary_fp(in, out, [](auto x) { return std::floor(x); });
|
||||
} else {
|
||||
// No-op integer types
|
||||
out.copy_shared_buffer(in);
|
||||
}
|
||||
}
|
||||
|
||||
void Full::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
@ -43,6 +43,19 @@ struct ArcTanh {
|
||||
template <typename T> T operator()(T x) { return metal::precise::atanh(x); };
|
||||
};
|
||||
|
||||
struct Ceil {
|
||||
template <typename T> T operator()(T x) { return metal::ceil(x); };
|
||||
template <> int8_t operator()(int8_t x) { return x; };
|
||||
template <> int16_t operator()(int16_t x) { return x; };
|
||||
template <> int32_t operator()(int32_t x) { return x; };
|
||||
template <> int64_t operator()(int64_t x) { return x; };
|
||||
template <> uint8_t operator()(uint8_t x) { return x; };
|
||||
template <> uint16_t operator()(uint16_t x) { return x; };
|
||||
template <> uint32_t operator()(uint32_t x) { return x; };
|
||||
template <> uint64_t operator()(uint64_t x) { return x; };
|
||||
template <> bool operator()(bool x) { return x; };
|
||||
};
|
||||
|
||||
struct Cos {
|
||||
template <typename T> T operator()(T x) { return metal::precise::cos(x); };
|
||||
|
||||
@ -83,6 +96,19 @@ struct Exp {
|
||||
}
|
||||
};
|
||||
|
||||
struct Floor {
|
||||
template <typename T> T operator()(T x) { return metal::floor(x); };
|
||||
template <> int8_t operator()(int8_t x) { return x; };
|
||||
template <> int16_t operator()(int16_t x) { return x; };
|
||||
template <> int32_t operator()(int32_t x) { return x; };
|
||||
template <> int64_t operator()(int64_t x) { return x; };
|
||||
template <> uint8_t operator()(uint8_t x) { return x; };
|
||||
template <> uint16_t operator()(uint16_t x) { return x; };
|
||||
template <> uint32_t operator()(uint32_t x) { return x; };
|
||||
template <> uint64_t operator()(uint64_t x) { return x; };
|
||||
template <> bool operator()(bool x) { return x; };
|
||||
};
|
||||
|
||||
struct Log {
|
||||
template <typename T> T operator()(T x) { return metal::precise::log(x); };
|
||||
};
|
||||
@ -253,9 +279,11 @@ instantiate_unary_float(arcsin, ArcSin)
|
||||
instantiate_unary_float(arcsinh, ArcSinh)
|
||||
instantiate_unary_float(arctan, ArcTan)
|
||||
instantiate_unary_float(arctanh, ArcTanh)
|
||||
instantiate_unary_types(ceil, Ceil)
|
||||
instantiate_unary_float(cos, Cos)
|
||||
instantiate_unary_float(cosh, Cosh)
|
||||
instantiate_unary_float(exp, Exp)
|
||||
instantiate_unary_types(floor, Floor)
|
||||
instantiate_unary_float(log, Log)
|
||||
instantiate_unary_float(log2, Log2)
|
||||
instantiate_unary_float(log10, Log10)
|
||||
|
@ -450,6 +450,14 @@ void Minimum::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "min");
|
||||
}
|
||||
|
||||
void Floor::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "floor");
|
||||
}
|
||||
|
||||
void Ceil::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "ceil");
|
||||
}
|
||||
|
||||
void Multiply::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "mul");
|
||||
}
|
||||
|
@ -24,6 +24,7 @@ NO_GPU(ArgSort)
|
||||
NO_GPU(AsType)
|
||||
NO_GPU(AsStrided)
|
||||
NO_GPU(Broadcast)
|
||||
NO_GPU(Ceil)
|
||||
NO_GPU(Concatenate)
|
||||
NO_GPU(Convolution)
|
||||
NO_GPU(Copy)
|
||||
@ -36,6 +37,7 @@ NO_GPU(Erf)
|
||||
NO_GPU(ErfInv)
|
||||
NO_GPU(Exp)
|
||||
NO_GPU(FFT)
|
||||
NO_GPU(Floor)
|
||||
NO_GPU(Full)
|
||||
NO_GPU(Gather)
|
||||
NO_GPU(Greater)
|
||||
|
15
mlx/ops.cpp
15
mlx/ops.cpp
@ -1498,6 +1498,21 @@ array minimum(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||
inputs);
|
||||
}
|
||||
|
||||
array floor(const array& a, StreamOrDevice s /* = {} */) {
|
||||
if (a.dtype() == complex64) {
|
||||
throw std::invalid_argument("[floor] Not supported for complex64.");
|
||||
}
|
||||
return array(
|
||||
a.shape(), a.dtype(), std::make_unique<Floor>(to_stream(s)), {a});
|
||||
}
|
||||
|
||||
/**
 * Element-wise ceiling of `a`.
 *
 * Builds a lazy array with the same shape and dtype as `a`, backed by a
 * `Ceil` primitive on the stream resolved from `s`.
 *
 * Throws std::invalid_argument when `a` is complex64.
 */
array ceil(const array& a, StreamOrDevice s /* = {} */) {
  if (a.dtype() == complex64) {
    // The message must name this op: it was previously tagged "[floor]",
    // which pointed users at the wrong call site.
    throw std::invalid_argument("[ceil] Not supported for complex64.");
  }
  return array(a.shape(), a.dtype(), std::make_unique<Ceil>(to_stream(s)), {a});
}
|
||||
|
||||
array square(const array& a, StreamOrDevice s /* = {} */) {
|
||||
return array(
|
||||
a.shape(), a.dtype(), std::make_unique<Square>(to_stream(s)), {a});
|
||||
|
@ -677,6 +677,12 @@ array maximum(const array& a, const array& b, StreamOrDevice s = {});
|
||||
/** Element-wise minimum between two arrays. */
|
||||
array minimum(const array& a, const array& b, StreamOrDevice s = {});
|
||||
|
||||
/** Floor the elements of an array. */
|
||||
array floor(const array& a, StreamOrDevice s = {});
|
||||
|
||||
/** Ceil the elements of an array. */
|
||||
array ceil(const array& a, StreamOrDevice s = {});
|
||||
|
||||
/** Square the elements of an array. */
|
||||
array square(const array& a, StreamOrDevice s = {});
|
||||
|
||||
|
@ -441,6 +441,30 @@ bool Broadcast::is_equivalent(const Primitive& other) const {
|
||||
return shape_ == b_other.shape_;
|
||||
}
|
||||
|
||||
std::vector<array> Ceil::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const array& cotan,
|
||||
const std::vector<int>& argnums) {
|
||||
return {jvp(primals, {cotan}, argnums)};
|
||||
}
|
||||
|
||||
array Ceil::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
return zeros_like(primals[0], stream());
|
||||
}
|
||||
|
||||
std::pair<array, int> Ceil::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
assert(inputs.size() == 1);
|
||||
assert(axes.size() == 1);
|
||||
return {ceil(inputs[0], stream()), axes[0]};
|
||||
}
|
||||
|
||||
std::vector<array> Concatenate::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const array& cotan,
|
||||
@ -748,8 +772,7 @@ std::vector<array> Remainder::vjp(
|
||||
vjps.push_back(cotan);
|
||||
} else {
|
||||
auto x_over_y = divide(primals[0], primals[1], stream());
|
||||
// TODO: Replace with a proper floor when available
|
||||
x_over_y = astype(x_over_y, int32, stream());
|
||||
x_over_y = floor(x_over_y, stream());
|
||||
vjps.push_back(negative(multiply(x_over_y, cotan, stream()), stream()));
|
||||
}
|
||||
}
|
||||
@ -766,8 +789,7 @@ array Remainder::jvp(
|
||||
return tangents[i];
|
||||
} else {
|
||||
auto x_over_y = divide(primals[0], primals[1], stream());
|
||||
// TODO: Replace with a proper floor when available
|
||||
x_over_y = astype(x_over_y, int32, stream());
|
||||
x_over_y = floor(x_over_y, stream());
|
||||
return negative(multiply(x_over_y, tangents[i], stream()), stream());
|
||||
}
|
||||
};
|
||||
@ -976,6 +998,30 @@ array FFT::jvp(
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<array> Floor::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const array& cotan,
|
||||
const std::vector<int>& argnums) {
|
||||
return {jvp(primals, {cotan}, argnums)};
|
||||
}
|
||||
|
||||
array Floor::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(argnums.size() == 1);
|
||||
return zeros_like(primals[0], stream());
|
||||
}
|
||||
|
||||
std::pair<array, int> Floor::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
assert(inputs.size() == 1);
|
||||
assert(axes.size() == 1);
|
||||
return {floor(inputs[0], stream()), axes[0]};
|
||||
}
|
||||
|
||||
std::vector<array> Full::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const array& cotan,
|
||||
|
@ -404,6 +404,25 @@ class Broadcast : public Primitive {
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Ceil : public Primitive {
|
||||
public:
|
||||
explicit Ceil(Stream stream) : Primitive(stream){};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
std::pair<array, int> vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) override;
|
||||
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Ceil)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Concatenate : public Primitive {
|
||||
public:
|
||||
explicit Concatenate(Stream stream, int axis)
|
||||
@ -662,6 +681,25 @@ class FFT : public Primitive {
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Floor : public Primitive {
|
||||
public:
|
||||
explicit Floor(Stream stream) : Primitive(stream){};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
std::pair<array, int> vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) override;
|
||||
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Floor)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Full : public Primitive {
|
||||
public:
|
||||
explicit Full(Stream stream) : Primitive(stream){};
|
||||
|
@ -1555,6 +1555,42 @@ void init_ops(py::module_& m) {
|
||||
Returns:
|
||||
array: The max of ``a`` and ``b``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"floor",
|
||||
&mlx::core::floor,
|
||||
"a"_a,
|
||||
py::pos_only(),
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
floor(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise floor.
|
||||
|
||||
Args:
|
||||
a (array): Input array.
|
||||
|
||||
Returns:
|
||||
array: The floor of ``a``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"ceil",
|
||||
&mlx::core::ceil,
|
||||
"a"_a,
|
||||
py::pos_only(),
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
ceil(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Element-wise ceil.
|
||||
|
||||
Args:
|
||||
a (array): Input array.
|
||||
|
||||
Returns:
|
||||
array: The ceil of ``a``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"transpose",
|
||||
[](const array& a,
|
||||
|
@ -334,6 +334,22 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
expected = [1, -5, 10]
|
||||
self.assertListEqual(mx.maximum(x, y).tolist(), expected)
|
||||
|
||||
def test_floor(self):
|
||||
x = mx.array([-22.03, 19.98, -27, 9, 0.0, -np.inf, np.inf])
|
||||
expected = [-23, 19, -27, 9, 0, -np.inf, np.inf]
|
||||
self.assertListEqual(mx.floor(x).tolist(), expected)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.floor(mx.array([22 + 3j, 19 + 98j]))
|
||||
|
||||
def test_ceil(self):
    """Check mx.ceil on finite, integral-valued and infinite floats, and on complex input."""
    x = mx.array([-22.03, 19.98, -27, 9, 0.0, -np.inf, np.inf])
    expected = [-22, 20, -27, 9, 0, -np.inf, np.inf]
    self.assertListEqual(mx.ceil(x).tolist(), expected)

    # Complex input must be rejected by ceil itself; this previously called
    # mx.floor, leaving ceil's own error path untested.
    with self.assertRaises(ValueError):
        mx.ceil(mx.array([22 + 3j, 19 + 98j]))
|
||||
|
||||
def test_transpose_noargs(self):
|
||||
x = mx.array([[0, 1, 1], [1, 0, 0]])
|
||||
|
||||
|
@ -773,6 +773,29 @@ TEST_CASE("test arithmetic unary ops") {
|
||||
|
||||
constexpr float neginf = -std::numeric_limits<float>::infinity();
|
||||
|
||||
// Test floor and ceil
|
||||
{
|
||||
array x(1.0f);
|
||||
CHECK_EQ(floor(x).item<float>(), 1.0f);
|
||||
CHECK_EQ(ceil(x).item<float>(), 1.0f);
|
||||
|
||||
x = array(1.5f);
|
||||
CHECK_EQ(floor(x).item<float>(), 1.0f);
|
||||
CHECK_EQ(ceil(x).item<float>(), 2.0f);
|
||||
|
||||
x = array(-1.5f);
|
||||
CHECK_EQ(floor(x).item<float>(), -2.0f);
|
||||
CHECK_EQ(ceil(x).item<float>(), -1.0f);
|
||||
|
||||
x = array(neginf);
|
||||
CHECK_EQ(floor(x).item<float>(), neginf);
|
||||
CHECK_EQ(ceil(x).item<float>(), neginf);
|
||||
|
||||
x = array(std::complex<float>(1.0f, 1.0f));
|
||||
CHECK_THROWS_AS(floor(x), std::invalid_argument);
|
||||
CHECK_THROWS_AS(ceil(x), std::invalid_argument);
|
||||
}
|
||||
|
||||
// Test exponential
|
||||
{
|
||||
array x(0.0);
|
||||
|
Loading…
Reference in New Issue
Block a user