Floor and Ceil (#150)

* Implements Floor and Ceil Ops
This commit is contained in:
Luca Arnaboldi
2023-12-14 19:00:23 +01:00
committed by GitHub
parent 1e0c78b970
commit b93c4cf378
14 changed files with 250 additions and 4 deletions

View File

@@ -26,12 +26,14 @@ DEFAULT(ArgReduce)
DEFAULT(ArgSort)
DEFAULT(AsStrided)
DEFAULT(Broadcast)
DEFAULT(Ceil)
DEFAULT(Concatenate)
DEFAULT(Copy)
DEFAULT(Equal)
DEFAULT(Erf)
DEFAULT(ErfInv)
DEFAULT(FFT)
DEFAULT(Floor)
DEFAULT(Gather)
DEFAULT(Greater)
DEFAULT(GreaterEqual)

View File

@@ -29,6 +29,7 @@ DEFAULT(ArgSort)
DEFAULT(AsType)
DEFAULT(AsStrided)
DEFAULT(Broadcast)
DEFAULT(Ceil)
DEFAULT(Concatenate)
DEFAULT(Convolution)
DEFAULT(Copy)
@@ -41,6 +42,7 @@ DEFAULT(Erf)
DEFAULT(ErfInv)
DEFAULT(Exp)
DEFAULT(FFT)
DEFAULT(Floor)
DEFAULT(Full)
DEFAULT(Gather)
DEFAULT(Greater)

View File

@@ -167,6 +167,17 @@ void Broadcast::eval(const std::vector<array>& inputs, array& out) {
out.copy_shared_buffer(in, strides, flags, in.data_size());
}
void Ceil::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (not is_integral(in.dtype())) {
unary_fp(in, out, [](auto x) { return std::ceil(x); });
} else {
// No-op integer types
out.copy_shared_buffer(in);
}
}
void Concatenate::eval(const std::vector<array>& inputs, array& out) {
std::vector<int> sizes;
sizes.push_back(0);
@@ -287,6 +298,17 @@ void Exp::eval(const std::vector<array>& inputs, array& out) {
}
}
void Floor::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (not is_integral(in.dtype())) {
unary_fp(in, out, [](auto x) { return std::floor(x); });
} else {
// No-op integer types
out.copy_shared_buffer(in);
}
}
void Full::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];

View File

@@ -43,6 +43,19 @@ struct ArcTanh {
template <typename T> T operator()(T x) { return metal::precise::atanh(x); };
};
struct Ceil {
template <typename T> T operator()(T x) { return metal::ceil(x); };
template <> int8_t operator()(int8_t x) { return x; };
template <> int16_t operator()(int16_t x) { return x; };
template <> int32_t operator()(int32_t x) { return x; };
template <> int64_t operator()(int64_t x) { return x; };
template <> uint8_t operator()(uint8_t x) { return x; };
template <> uint16_t operator()(uint16_t x) { return x; };
template <> uint32_t operator()(uint32_t x) { return x; };
template <> uint64_t operator()(uint64_t x) { return x; };
template <> bool operator()(bool x) { return x; };
};
struct Cos {
template <typename T> T operator()(T x) { return metal::precise::cos(x); };
@@ -83,6 +96,19 @@ struct Exp {
}
};
struct Floor {
template <typename T> T operator()(T x) { return metal::floor(x); };
template <> int8_t operator()(int8_t x) { return x; };
template <> int16_t operator()(int16_t x) { return x; };
template <> int32_t operator()(int32_t x) { return x; };
template <> int64_t operator()(int64_t x) { return x; };
template <> uint8_t operator()(uint8_t x) { return x; };
template <> uint16_t operator()(uint16_t x) { return x; };
template <> uint32_t operator()(uint32_t x) { return x; };
template <> uint64_t operator()(uint64_t x) { return x; };
template <> bool operator()(bool x) { return x; };
};
struct Log {
template <typename T> T operator()(T x) { return metal::precise::log(x); };
};
@@ -253,9 +279,11 @@ instantiate_unary_float(arcsin, ArcSin)
instantiate_unary_float(arcsinh, ArcSinh)
instantiate_unary_float(arctan, ArcTan)
instantiate_unary_float(arctanh, ArcTanh)
instantiate_unary_types(ceil, Ceil)
instantiate_unary_float(cos, Cos)
instantiate_unary_float(cosh, Cosh)
instantiate_unary_float(exp, Exp)
instantiate_unary_types(floor, Floor)
instantiate_unary_float(log, Log)
instantiate_unary_float(log2, Log2)
instantiate_unary_float(log10, Log10)

View File

@@ -450,6 +450,14 @@ void Minimum::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "min");
}
void Floor::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "floor");
}
void Ceil::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "ceil");
}
void Multiply::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "mul");
}

View File

@@ -24,6 +24,7 @@ NO_GPU(ArgSort)
NO_GPU(AsType)
NO_GPU(AsStrided)
NO_GPU(Broadcast)
NO_GPU(Ceil)
NO_GPU(Concatenate)
NO_GPU(Convolution)
NO_GPU(Copy)
@@ -36,6 +37,7 @@ NO_GPU(Erf)
NO_GPU(ErfInv)
NO_GPU(Exp)
NO_GPU(FFT)
NO_GPU(Floor)
NO_GPU(Full)
NO_GPU(Gather)
NO_GPU(Greater)

View File

@@ -1498,6 +1498,21 @@ array minimum(const array& a, const array& b, StreamOrDevice s /* = {} */) {
inputs);
}
array floor(const array& a, StreamOrDevice s /* = {} */) {
if (a.dtype() == complex64) {
throw std::invalid_argument("[floor] Not supported for complex64.");
}
return array(
a.shape(), a.dtype(), std::make_unique<Floor>(to_stream(s)), {a});
}
array ceil(const array& a, StreamOrDevice s /* = {} */) {
if (a.dtype() == complex64) {
throw std::invalid_argument("[floor] Not supported for complex64.");
}
return array(a.shape(), a.dtype(), std::make_unique<Ceil>(to_stream(s)), {a});
}
array square(const array& a, StreamOrDevice s /* = {} */) {
return array(
a.shape(), a.dtype(), std::make_unique<Square>(to_stream(s)), {a});

View File

@@ -677,6 +677,12 @@ array maximum(const array& a, const array& b, StreamOrDevice s = {});
/** Element-wise minimum between two arrays. */
array minimum(const array& a, const array& b, StreamOrDevice s = {});
/** Floor the element of an array. **/
array floor(const array& a, StreamOrDevice s = {});
/** Ceil the element of an array. **/
array ceil(const array& a, StreamOrDevice s = {});
/** Square the elements of an array. */
array square(const array& a, StreamOrDevice s = {});

View File

@@ -441,6 +441,30 @@ bool Broadcast::is_equivalent(const Primitive& other) const {
return shape_ == b_other.shape_;
}
std::vector<array> Ceil::vjp(
const std::vector<array>& primals,
const array& cotan,
const std::vector<int>& argnums) {
return {jvp(primals, {cotan}, argnums)};
}
array Ceil::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
assert(primals.size() == 1);
assert(argnums.size() == 1);
return zeros_like(primals[0], stream());
}
std::pair<array, int> Ceil::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
assert(inputs.size() == 1);
assert(axes.size() == 1);
return {ceil(inputs[0], stream()), axes[0]};
}
std::vector<array> Concatenate::vjp(
const std::vector<array>& primals,
const array& cotan,
@@ -748,8 +772,7 @@ std::vector<array> Remainder::vjp(
vjps.push_back(cotan);
} else {
auto x_over_y = divide(primals[0], primals[1], stream());
// TODO: Replace with a proper floor when available
x_over_y = astype(x_over_y, int32, stream());
x_over_y = floor(x_over_y, stream());
vjps.push_back(negative(multiply(x_over_y, cotan, stream()), stream()));
}
}
@@ -766,8 +789,7 @@ array Remainder::jvp(
return tangents[i];
} else {
auto x_over_y = divide(primals[0], primals[1], stream());
// TODO: Replace with a proper floor when available
x_over_y = astype(x_over_y, int32, stream());
x_over_y = floor(x_over_y, stream());
return negative(multiply(x_over_y, tangents[i], stream()), stream());
}
};
@@ -976,6 +998,30 @@ array FFT::jvp(
}
}
std::vector<array> Floor::vjp(
const std::vector<array>& primals,
const array& cotan,
const std::vector<int>& argnums) {
return {jvp(primals, {cotan}, argnums)};
}
array Floor::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
assert(primals.size() == 1);
assert(argnums.size() == 1);
return zeros_like(primals[0], stream());
}
std::pair<array, int> Floor::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
assert(inputs.size() == 1);
assert(axes.size() == 1);
return {floor(inputs[0], stream()), axes[0]};
}
std::vector<array> Full::vjp(
const std::vector<array>& primals,
const array& cotan,

View File

@@ -404,6 +404,25 @@ class Broadcast : public Primitive {
void eval(const std::vector<array>& inputs, array& out);
};
class Ceil : public Primitive {
public:
explicit Ceil(Stream stream) : Primitive(stream){};
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
std::pair<array, int> vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) override;
DEFINE_GRADS()
DEFINE_PRINT(Ceil)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class Concatenate : public Primitive {
public:
explicit Concatenate(Stream stream, int axis)
@@ -662,6 +681,25 @@ class FFT : public Primitive {
void eval(const std::vector<array>& inputs, array& out);
};
class Floor : public Primitive {
public:
explicit Floor(Stream stream) : Primitive(stream){};
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
std::pair<array, int> vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) override;
DEFINE_GRADS()
DEFINE_PRINT(Floor)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class Full : public Primitive {
public:
explicit Full(Stream stream) : Primitive(stream){};