mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-21 04:31:48 +08:00
add groups in conv2d (#1569)
This commit is contained in:
19
mlx/ops.cpp
19
mlx/ops.cpp
@@ -1402,10 +1402,16 @@ array isnan(const array& a, StreamOrDevice s /* = {} */) {
|
||||
}
|
||||
|
||||
array isinf(const array& a, StreamOrDevice s /* = {} */) {
|
||||
if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) {
|
||||
return full(a.shape(), false, bool_, s);
|
||||
}
|
||||
return logical_or(isposinf(a, s), isneginf(a, s), s);
|
||||
}
|
||||
|
||||
array isfinite(const array& a, StreamOrDevice s /* = {} */) {
|
||||
if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) {
|
||||
return full(a.shape(), true, bool_, s);
|
||||
}
|
||||
return logical_not(logical_or(isinf(a, s), isnan(a, s), s), s);
|
||||
}
|
||||
|
||||
@@ -1497,10 +1503,17 @@ array isclose(
|
||||
auto out = less_equal(lhs, rhs, s);
|
||||
|
||||
// Correct the result for infinite values.
|
||||
auto any_inf = logical_or(isinf(a, s), isinf(b, s), s);
|
||||
auto a_pos_inf = isposinf(a, s);
|
||||
auto b_pos_inf = isposinf(b, s);
|
||||
auto a_neg_inf = isneginf(a, s);
|
||||
auto b_neg_inf = isneginf(b, s);
|
||||
auto any_inf = logical_or(
|
||||
logical_or(a_pos_inf, a_neg_inf, s),
|
||||
logical_or(b_pos_inf, b_neg_inf, s),
|
||||
s);
|
||||
auto both_inf = logical_or(
|
||||
logical_and(isposinf(a, s), isposinf(b, s), s),
|
||||
logical_and(isneginf(a, s), isneginf(b, s), s),
|
||||
logical_and(a_pos_inf, b_pos_inf, s),
|
||||
logical_and(a_neg_inf, b_neg_inf, s),
|
||||
s);
|
||||
|
||||
// Convert all elements where either value is infinite to False.
|
||||
|
Reference in New Issue
Block a user