mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 23:51:14 +08:00
feat: add logicalAnd and logicalOR (#386)
* feat: add logicalAnd and logicalOR * run pre-commit * Refactor logical_and and logical_or functions * Add acknowledgement * Add logical AND and logical OR operators * Refactor logical_and and logical_or functions * Add support for logical operators on bool arrays * Update mlx/ops.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update mlx/ops.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Add logical AND and OR operators for arrays and scalars * Refactor vjp and jvp methods in primitives.cpp * Add overloaded operators for logical AND and OR * format --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com> Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
022a944367
commit
73321b8097
@ -7,7 +7,7 @@ with a short description of your contribution(s) below. For example:
|
|||||||
|
|
||||||
MLX was developed with contributions from the following individuals:
|
MLX was developed with contributions from the following individuals:
|
||||||
|
|
||||||
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops.
|
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops.
|
||||||
- Juarez Bochi: Fixed bug in cross attention.
|
- Juarez Bochi: Fixed bug in cross attention.
|
||||||
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
|
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
|
||||||
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer` and safetensor support
|
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer` and safetensor support
|
||||||
|
@ -60,6 +60,8 @@ Operations
|
|||||||
log1p
|
log1p
|
||||||
logaddexp
|
logaddexp
|
||||||
logical_not
|
logical_not
|
||||||
|
logical_and
|
||||||
|
logical_or
|
||||||
logsumexp
|
logsumexp
|
||||||
matmul
|
matmul
|
||||||
max
|
max
|
||||||
|
@ -41,6 +41,8 @@ DEFAULT(Less)
|
|||||||
DEFAULT(LessEqual)
|
DEFAULT(LessEqual)
|
||||||
DEFAULT(Load)
|
DEFAULT(Load)
|
||||||
DEFAULT(LogicalNot)
|
DEFAULT(LogicalNot)
|
||||||
|
DEFAULT(LogicalAnd)
|
||||||
|
DEFAULT(LogicalOr)
|
||||||
DEFAULT(LogAddExp)
|
DEFAULT(LogAddExp)
|
||||||
DEFAULT(NotEqual)
|
DEFAULT(NotEqual)
|
||||||
DEFAULT(Pad)
|
DEFAULT(Pad)
|
||||||
|
@ -57,6 +57,8 @@ DEFAULT(Load)
|
|||||||
DEFAULT(Log)
|
DEFAULT(Log)
|
||||||
DEFAULT(Log1p)
|
DEFAULT(Log1p)
|
||||||
DEFAULT(LogicalNot)
|
DEFAULT(LogicalNot)
|
||||||
|
DEFAULT(LogicalAnd)
|
||||||
|
DEFAULT(LogicalOr)
|
||||||
DEFAULT(LogAddExp)
|
DEFAULT(LogAddExp)
|
||||||
DEFAULT(Maximum)
|
DEFAULT(Maximum)
|
||||||
DEFAULT(Minimum)
|
DEFAULT(Minimum)
|
||||||
|
@ -8,6 +8,7 @@
|
|||||||
|
|
||||||
#include "mlx/allocator.h"
|
#include "mlx/allocator.h"
|
||||||
#include "mlx/backend/common/arange.h"
|
#include "mlx/backend/common/arange.h"
|
||||||
|
#include "mlx/backend/common/binary.h"
|
||||||
#include "mlx/backend/common/copy.h"
|
#include "mlx/backend/common/copy.h"
|
||||||
#include "mlx/backend/common/erf.h"
|
#include "mlx/backend/common/erf.h"
|
||||||
#include "mlx/backend/common/threefry.h"
|
#include "mlx/backend/common/threefry.h"
|
||||||
@ -364,6 +365,20 @@ void LogicalNot::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
unary(in, out, [](auto x) { return !x; });
|
unary(in, out, [](auto x) { return !x; });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void LogicalAnd::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
|
||||||
|
auto& in1 = inputs[0];
|
||||||
|
auto& in2 = inputs[1];
|
||||||
|
binary(in1, in2, out, [](auto x, auto y) { return x && y; });
|
||||||
|
}
|
||||||
|
|
||||||
|
void LogicalOr::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 2); // LogicalOr requires two input arrays
|
||||||
|
auto& in1 = inputs[0];
|
||||||
|
auto& in2 = inputs[1];
|
||||||
|
binary(in1, in2, out, [](auto x, auto y) { return x || y; });
|
||||||
|
}
|
||||||
|
|
||||||
void Negative::eval(const std::vector<array>& inputs, array& out) {
|
void Negative::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
auto& in = inputs[0];
|
auto& in = inputs[0];
|
||||||
|
@ -131,6 +131,16 @@ struct Subtract {
|
|||||||
template <typename T> T operator()(T x, T y) { return x - y; }
|
template <typename T> T operator()(T x, T y) { return x - y; }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct LogicalAnd {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x, T y) { return x && y; };
|
||||||
|
};
|
||||||
|
|
||||||
|
struct LogicalOr {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x, T y) { return x || y; };
|
||||||
|
};
|
||||||
|
|
||||||
template <typename T, typename U, typename Op>
|
template <typename T, typename U, typename Op>
|
||||||
[[kernel]] void binary_op_s2s(
|
[[kernel]] void binary_op_s2s(
|
||||||
device const T* a,
|
device const T* a,
|
||||||
@ -377,3 +387,6 @@ instantiate_binary_all(naneq, float16, half, bool, NaNEqual)
|
|||||||
instantiate_binary_all(naneq, float32, float, bool, NaNEqual)
|
instantiate_binary_all(naneq, float32, float, bool, NaNEqual)
|
||||||
instantiate_binary_all(naneq, bfloat16, bfloat16_t, bool, NaNEqual)
|
instantiate_binary_all(naneq, bfloat16, bfloat16_t, bool, NaNEqual)
|
||||||
instantiate_binary_all(naneq, complex64, complex64_t, bool, NaNEqual)
|
instantiate_binary_all(naneq, complex64, complex64_t, bool, NaNEqual)
|
||||||
|
|
||||||
|
instantiate_binary_all(lor, bool_, bool, bool, LogicalOr)
|
||||||
|
instantiate_binary_all(land, bool_, bool, bool, LogicalAnd)
|
@ -439,6 +439,20 @@ void LogicalNot::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
unary_op(inputs, out, "lnot");
|
unary_op(inputs, out, "lnot");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void LogicalAnd::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
binary_op(
|
||||||
|
inputs,
|
||||||
|
out,
|
||||||
|
"land"); // Assume "land" is the operation identifier for logical AND
|
||||||
|
}
|
||||||
|
|
||||||
|
void LogicalOr::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
binary_op(
|
||||||
|
inputs,
|
||||||
|
out,
|
||||||
|
"lor"); // Assume "lor" is the operation identifier for logical OR
|
||||||
|
}
|
||||||
|
|
||||||
void LogAddExp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void LogAddExp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
binary_op(inputs, out, "lae");
|
binary_op(inputs, out, "lae");
|
||||||
}
|
}
|
||||||
|
@ -48,6 +48,8 @@ NO_GPU(Load)
|
|||||||
NO_GPU(Log)
|
NO_GPU(Log)
|
||||||
NO_GPU(Log1p)
|
NO_GPU(Log1p)
|
||||||
NO_GPU(LogicalNot)
|
NO_GPU(LogicalNot)
|
||||||
|
NO_GPU(LogicalAnd)
|
||||||
|
NO_GPU(LogicalOr)
|
||||||
NO_GPU(LogAddExp)
|
NO_GPU(LogAddExp)
|
||||||
NO_GPU(Matmul)
|
NO_GPU(Matmul)
|
||||||
NO_GPU(Maximum)
|
NO_GPU(Maximum)
|
||||||
|
37
mlx/ops.cpp
37
mlx/ops.cpp
@ -1608,6 +1608,43 @@ array logical_not(const array& a, StreamOrDevice s /* = {} */) {
|
|||||||
{astype(a, bool_, s)});
|
{astype(a, bool_, s)});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array logical_and(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||||
|
// Broadcast arrays to a common shape
|
||||||
|
auto inputs = broadcast_arrays({astype(a, bool_, s), astype(b, bool_, s)}, s);
|
||||||
|
|
||||||
|
return array(
|
||||||
|
inputs[0].shape(),
|
||||||
|
bool_,
|
||||||
|
std::make_unique<LogicalAnd>(to_stream(s)),
|
||||||
|
inputs);
|
||||||
|
}
|
||||||
|
array operator&&(const array& a, const array& b) {
|
||||||
|
// check if a and b are bool arrays
|
||||||
|
if (a.dtype() != bool_ || b.dtype() != bool_) {
|
||||||
|
throw std::invalid_argument("[operator&&] only supported for bool arrays.");
|
||||||
|
}
|
||||||
|
return logical_and(a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
array logical_or(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||||
|
// Broadcast arrays to a common shape
|
||||||
|
auto inputs = broadcast_arrays({astype(a, bool_, s), astype(b, bool_, s)}, s);
|
||||||
|
|
||||||
|
return array(
|
||||||
|
inputs[0].shape(),
|
||||||
|
bool_,
|
||||||
|
std::make_unique<LogicalOr>(to_stream(s)),
|
||||||
|
inputs);
|
||||||
|
}
|
||||||
|
array operator||(const array& a, const array& b) {
|
||||||
|
// check if a and b are bool arrays
|
||||||
|
if (a.dtype() != bool_ || b.dtype() != bool_) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[operator||] is only supported for bool arrays.");
|
||||||
|
}
|
||||||
|
return logical_or(a, b);
|
||||||
|
}
|
||||||
|
|
||||||
array reciprocal(const array& a, StreamOrDevice s /* = {} */) {
|
array reciprocal(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
auto dtype = at_least_float(a.dtype());
|
auto dtype = at_least_float(a.dtype());
|
||||||
return divide(array(1.0f, dtype), a, to_stream(s));
|
return divide(array(1.0f, dtype), a, to_stream(s));
|
||||||
|
@ -668,6 +668,14 @@ array sign(const array& a, StreamOrDevice s = {});
|
|||||||
/** Logical not of an array */
|
/** Logical not of an array */
|
||||||
array logical_not(const array& a, StreamOrDevice s = {});
|
array logical_not(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Logical and of two arrays */
|
||||||
|
array logical_and(const array& a, const array& b, StreamOrDevice s = {});
|
||||||
|
array operator&&(const array& a, const array& b);
|
||||||
|
|
||||||
|
/** Logical or of two arrays */
|
||||||
|
array logical_or(const array& a, const array& b, StreamOrDevice s = {});
|
||||||
|
array operator||(const array& a, const array& b);
|
||||||
|
|
||||||
/** The reciprocal (1/x) of the elements in an array. */
|
/** The reciprocal (1/x) of the elements in an array. */
|
||||||
array reciprocal(const array& a, StreamOrDevice s = {});
|
array reciprocal(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
@ -1338,6 +1338,64 @@ std::pair<array, int> LogicalNot::vmap(
|
|||||||
return {logical_not(inputs[0], stream()), axes[0]};
|
return {logical_not(inputs[0], stream()), axes[0]};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<array> LogicalAnd::vjp(
|
||||||
|
const std::vector<array>& primals,
|
||||||
|
const array& cotan,
|
||||||
|
const std::vector<int>& argnums) {
|
||||||
|
assert(primals.size() == 2);
|
||||||
|
|
||||||
|
return {zeros_like(cotan, stream()), zeros_like(cotan, stream())};
|
||||||
|
}
|
||||||
|
|
||||||
|
array LogicalAnd::jvp(
|
||||||
|
const std::vector<array>& primals,
|
||||||
|
const std::vector<array>& tangents,
|
||||||
|
const std::vector<int>& argnums) {
|
||||||
|
assert(primals.size() == 2);
|
||||||
|
assert(argnums.size() <= 2);
|
||||||
|
|
||||||
|
return zeros_like(primals[0], stream());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<array, int> LogicalAnd::vmap(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<int>& axes) {
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
assert(axes.size() == 2);
|
||||||
|
|
||||||
|
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
|
||||||
|
return {logical_and(a, b, stream()), to_ax};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<array> LogicalOr::vjp(
|
||||||
|
const std::vector<array>& primals,
|
||||||
|
const array& cotan,
|
||||||
|
const std::vector<int>& argnums) {
|
||||||
|
assert(primals.size() == 2);
|
||||||
|
|
||||||
|
return {zeros_like(cotan, stream()), zeros_like(cotan, stream())};
|
||||||
|
}
|
||||||
|
|
||||||
|
array LogicalOr::jvp(
|
||||||
|
const std::vector<array>& primals,
|
||||||
|
const std::vector<array>& tangents,
|
||||||
|
const std::vector<int>& argnums) {
|
||||||
|
assert(primals.size() == 2);
|
||||||
|
assert(argnums.size() <= 2);
|
||||||
|
|
||||||
|
return zeros_like(primals[0], stream());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<array, int> LogicalOr::vmap(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<int>& axes) {
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
assert(axes.size() == 2);
|
||||||
|
|
||||||
|
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
|
||||||
|
return {logical_or(a, b, stream()), to_ax};
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<array> LogAddExp::vjp(
|
std::vector<array> LogAddExp::vjp(
|
||||||
const std::vector<array>& primals,
|
const std::vector<array>& primals,
|
||||||
const array& cotan,
|
const array& cotan,
|
||||||
|
@ -903,6 +903,44 @@ class LogicalNot : public Primitive {
|
|||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class LogicalAnd : public Primitive {
|
||||||
|
public:
|
||||||
|
explicit LogicalAnd(Stream stream) : Primitive(stream){};
|
||||||
|
|
||||||
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
|
|
||||||
|
std::pair<array, int> vmap(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<int>& axes) override;
|
||||||
|
|
||||||
|
DEFINE_GRADS()
|
||||||
|
DEFINE_PRINT(LogicalAnd)
|
||||||
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
|
||||||
|
private:
|
||||||
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
|
};
|
||||||
|
|
||||||
|
class LogicalOr : public Primitive {
|
||||||
|
public:
|
||||||
|
explicit LogicalOr(Stream stream) : Primitive(stream){};
|
||||||
|
|
||||||
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
|
|
||||||
|
std::pair<array, int> vmap(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<int>& axes) override;
|
||||||
|
|
||||||
|
DEFINE_GRADS()
|
||||||
|
DEFINE_PRINT(LogicalOr)
|
||||||
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
|
||||||
|
private:
|
||||||
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
|
};
|
||||||
|
|
||||||
class LogAddExp : public Primitive {
|
class LogAddExp : public Primitive {
|
||||||
public:
|
public:
|
||||||
explicit LogAddExp(Stream stream) : Primitive(stream){};
|
explicit LogAddExp(Stream stream) : Primitive(stream){};
|
||||||
|
@ -764,6 +764,18 @@ void init_array(py::module_& m) {
|
|||||||
return power(a, to_array(v, a.dtype()));
|
return power(a, to_array(v, a.dtype()));
|
||||||
},
|
},
|
||||||
"other"_a)
|
"other"_a)
|
||||||
|
.def(
|
||||||
|
"__and__",
|
||||||
|
[](const array& a, const ScalarOrArray v) {
|
||||||
|
return logical_and(a, to_array(v, a.dtype()));
|
||||||
|
},
|
||||||
|
"other"_a)
|
||||||
|
.def(
|
||||||
|
"__or__",
|
||||||
|
[](const array& a, const ScalarOrArray v) {
|
||||||
|
return logical_or(a, to_array(v, a.dtype()));
|
||||||
|
},
|
||||||
|
"other"_a)
|
||||||
.def(
|
.def(
|
||||||
"flatten",
|
"flatten",
|
||||||
[](const array& a,
|
[](const array& a,
|
||||||
|
@ -670,6 +670,51 @@ void init_ops(py::module_& m) {
|
|||||||
Returns:
|
Returns:
|
||||||
array: The boolean array containing the logical not of ``a``.
|
array: The boolean array containing the logical not of ``a``.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"logical_and",
|
||||||
|
[](const ScalarOrArray& a, const ScalarOrArray& b, StreamOrDevice s) {
|
||||||
|
return logical_and(to_array(a), to_array(b), s);
|
||||||
|
},
|
||||||
|
"a"_a,
|
||||||
|
"b"_a,
|
||||||
|
py::pos_only(),
|
||||||
|
py::kw_only(),
|
||||||
|
"stream"_a = none,
|
||||||
|
R"pbdoc(
|
||||||
|
logical_and(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
|
Element-wise logical and.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a (array): First input array or scalar.
|
||||||
|
b (array): Second input array or scalar.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array: The boolean array containing the logical and of ``a`` and ``b``.
|
||||||
|
)pbdoc");
|
||||||
|
|
||||||
|
m.def(
|
||||||
|
"logical_or",
|
||||||
|
[](const ScalarOrArray& a, const ScalarOrArray& b, StreamOrDevice s) {
|
||||||
|
return logical_or(to_array(a), to_array(b), s);
|
||||||
|
},
|
||||||
|
"a"_a,
|
||||||
|
"b"_a,
|
||||||
|
py::pos_only(),
|
||||||
|
py::kw_only(),
|
||||||
|
"stream"_a = none,
|
||||||
|
R"pbdoc(
|
||||||
|
logical_or(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
|
Element-wise logical or.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a (array): First input array or scalar.
|
||||||
|
b (array): Second input array or scalar.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array: The boolean array containing the logical or of ``a`` and ``b``.
|
||||||
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"logaddexp",
|
"logaddexp",
|
||||||
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
[](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) {
|
||||||
|
@ -610,6 +610,28 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
expected = np.logical_not(a)
|
expected = np.logical_not(a)
|
||||||
self.assertTrue(np.array_equal(result, expected))
|
self.assertTrue(np.array_equal(result, expected))
|
||||||
|
|
||||||
|
def test_logical_and(self):
|
||||||
|
a = mx.array([True, False, True, False])
|
||||||
|
b = mx.array([True, True, False, False])
|
||||||
|
result = mx.logical_and(a, b)
|
||||||
|
expected = np.logical_and(a, b)
|
||||||
|
self.assertTrue(np.array_equal(result, expected))
|
||||||
|
|
||||||
|
# test overloaded operator
|
||||||
|
result = a & b
|
||||||
|
self.assertTrue(np.array_equal(result, expected))
|
||||||
|
|
||||||
|
def test_logical_or(self):
|
||||||
|
a = mx.array([True, False, True, False])
|
||||||
|
b = mx.array([True, True, False, False])
|
||||||
|
result = mx.logical_or(a, b)
|
||||||
|
expected = np.logical_or(a, b)
|
||||||
|
self.assertTrue(np.array_equal(result, expected))
|
||||||
|
|
||||||
|
# test overloaded operator
|
||||||
|
result = a | b
|
||||||
|
self.assertTrue(np.array_equal(result, expected))
|
||||||
|
|
||||||
def test_square(self):
|
def test_square(self):
|
||||||
a = mx.array([0.1, 0.5, 1.0, 10.0])
|
a = mx.array([0.1, 0.5, 1.0, 10.0])
|
||||||
result = mx.square(a)
|
result = mx.square(a)
|
||||||
|
@ -80,6 +80,8 @@ class TestVmap(mlx_tests.MLXTestCase):
|
|||||||
"multiply",
|
"multiply",
|
||||||
"power",
|
"power",
|
||||||
"subtract",
|
"subtract",
|
||||||
|
"logical_or",
|
||||||
|
"logical_and",
|
||||||
]
|
]
|
||||||
for opname in ops:
|
for opname in ops:
|
||||||
with self.subTest(op=opname):
|
with self.subTest(op=opname):
|
||||||
|
@ -795,6 +795,44 @@ TEST_CASE("test arithmetic unary ops") {
|
|||||||
CHECK_EQ(y.item<bool>(), true);
|
CHECK_EQ(y.item<bool>(), true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test logical and
|
||||||
|
{
|
||||||
|
array x(true);
|
||||||
|
array y(true);
|
||||||
|
CHECK_EQ(logical_and(x, y).item<bool>(), true);
|
||||||
|
|
||||||
|
x = array(1.0f);
|
||||||
|
y = array(1.0f);
|
||||||
|
auto z = logical_and(x, y);
|
||||||
|
CHECK_EQ(z.dtype(), bool_);
|
||||||
|
CHECK_EQ(z.item<bool>(), true);
|
||||||
|
|
||||||
|
x = array(0);
|
||||||
|
y = array(1.0f);
|
||||||
|
z = logical_and(x, y);
|
||||||
|
CHECK_EQ(z.dtype(), bool_);
|
||||||
|
CHECK_EQ(z.item<bool>(), false);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test logical or
|
||||||
|
{
|
||||||
|
array x(false);
|
||||||
|
array y(false);
|
||||||
|
CHECK_EQ(logical_or(x, y).item<bool>(), false);
|
||||||
|
|
||||||
|
x = array(1.0f);
|
||||||
|
y = array(1.0f);
|
||||||
|
auto z = logical_or(x, y);
|
||||||
|
CHECK_EQ(z.dtype(), bool_);
|
||||||
|
CHECK_EQ(z.item<bool>(), true);
|
||||||
|
|
||||||
|
x = array(0);
|
||||||
|
y = array(1.0f);
|
||||||
|
z = logical_or(x, y);
|
||||||
|
CHECK_EQ(z.dtype(), bool_);
|
||||||
|
CHECK_EQ(z.item<bool>(), true);
|
||||||
|
}
|
||||||
|
|
||||||
// Test abs
|
// Test abs
|
||||||
{
|
{
|
||||||
array x({-1.0f, 0.0f, 1.0f});
|
array x({-1.0f, 0.0f, 1.0f});
|
||||||
|
Loading…
Reference in New Issue
Block a user