Added clip function (#159)

* Added clip

* Added Python bindings

* Formatting

* Added cpp tests

* Added Python tests

* python bindings work

* rebase

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Cyril Zakka, MD
2023-12-17 22:00:29 -06:00
committed by GitHub
parent ee0c2835c5
commit 8eb56beb3a
6 changed files with 116 additions and 0 deletions

View File

@@ -595,6 +595,24 @@ split(const array& a, int num_splits, StreamOrDevice s /* = {} */) {
return split(a, num_splits, 0, to_stream(s));
}
array clip(
const array& a,
const std::optional<array>& a_min,
const std::optional<array>& a_max,
StreamOrDevice s /* = {} */) {
if (!a_min.has_value() && !a_max.has_value()) {
throw std::invalid_argument("At most one of a_min and a_max may be None");
}
array result = astype(a, a.dtype(), s);
if (a_min.has_value()) {
result = maximum(result, a_min.value(), s);
}
if (a_max.has_value()) {
result = minimum(result, a_max.value(), s);
}
return result;
}
array concatenate(
const std::vector<array>& arrays,
int axis,

View File

@@ -185,6 +185,15 @@ std::vector<array> split(
std::vector<array>
split(const array& a, const std::vector<int>& indices, StreamOrDevice s = {});
/**
* Clip (limit) the values in an array.
*/
array clip(
const array& a,
const std::optional<array>& a_min = std::nullopt,
const std::optional<array>& a_max = std::nullopt,
StreamOrDevice s = {});
/** Concatenate arrays along a given axis. */
array concatenate(
const std::vector<array>& arrays,