mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Added clip function (#159)
* Added clip * Added Python bindings * Formatting * Added cpp tests * Added Python tests * python bindings work * rebase --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
18
mlx/ops.cpp
18
mlx/ops.cpp
@@ -595,6 +595,24 @@ split(const array& a, int num_splits, StreamOrDevice s /* = {} */) {
|
||||
return split(a, num_splits, 0, to_stream(s));
|
||||
}
|
||||
|
||||
array clip(
|
||||
const array& a,
|
||||
const std::optional<array>& a_min,
|
||||
const std::optional<array>& a_max,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (!a_min.has_value() && !a_max.has_value()) {
|
||||
throw std::invalid_argument("At most one of a_min and a_max may be None");
|
||||
}
|
||||
array result = astype(a, a.dtype(), s);
|
||||
if (a_min.has_value()) {
|
||||
result = maximum(result, a_min.value(), s);
|
||||
}
|
||||
if (a_max.has_value()) {
|
||||
result = minimum(result, a_max.value(), s);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
array concatenate(
|
||||
const std::vector<array>& arrays,
|
||||
int axis,
|
||||
|
||||
@@ -185,6 +185,15 @@ std::vector<array> split(
|
||||
std::vector<array>
|
||||
split(const array& a, const std::vector<int>& indices, StreamOrDevice s = {});
|
||||
|
||||
/**
|
||||
* Clip (limit) the values in an array.
|
||||
*/
|
||||
array clip(
|
||||
const array& a,
|
||||
const std::optional<array>& a_min = std::nullopt,
|
||||
const std::optional<array>& a_max = std::nullopt,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Concatenate arrays along a given axis. */
|
||||
array concatenate(
|
||||
const std::vector<array>& arrays,
|
||||
|
||||
Reference in New Issue
Block a user