CPU binary reduction + Nits (#1242)

* very minor nits

* reduce binary

* fix test
Awni Hannun 2024-06-28 13:50:42 -07:00 committed by GitHub
parent d6383a1c6a
commit 20bb301195
5 changed files with 26 additions and 32 deletions

View File

@@ -101,7 +101,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& b = inputs[1];
   if (a.dtype() == float32) {
-    binary(
+    binary_op<float>(
         a,
         b,
         out,
@@ -116,7 +116,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
           vDSP_vadd((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
         });
   } else if (a.dtype() == int32) {
-    binary(
+    binary_op<int>(
         a,
         b,
         out,
@@ -131,7 +131,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
           vDSP_vaddi((const int*)a, 1, (const int*)b, 1, (int*)o, 1, n);
         });
   } else {
-    binary(a, b, out, [](auto x, auto y) { return x + y; });
+    eval(inputs, out);
   }
 }
@@ -286,7 +286,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& b = inputs[1];
   if (a.dtype() == int32) {
-    binary(
+    binary_op<int>(
         a,
         b,
         out,
@@ -299,7 +299,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
           vDSP_vdivi((const int*)b, 1, (const int*)a, 1, (int*)o, 1, n);
         });
   } else if (a.dtype() == float32) {
-    binary(
+    binary_op<float>(
         a,
         b,
         out,
@@ -314,7 +314,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
           vDSP_vdiv((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
         });
   } else {
-    binary(a, b, out, [](auto x, auto y) { return x / y; });
+    eval(inputs, out);
   }
 }
@@ -325,12 +325,8 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
     set_unary_output_data(in, out);
     auto size = in.data_size();
     vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
-  } else if (issubdtype(out.dtype(), inexact)) {
-    unary_fp(in, out, [](auto x) { return std::exp(x); });
   } else {
-    throw std::invalid_argument(
-        "[exp] Cannot exponentiate elements in array"
-        " with non floating point type.");
+    eval(inputs, out);
   }
 }
@@ -392,12 +388,8 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
     auto size = in.data_size();
     vvlog1pf(
         out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
-  } else if (issubdtype(out.dtype(), inexact)) {
-    unary_fp(in, out, [](auto x) { return std::log1p(x); });
   } else {
-    throw std::invalid_argument(
-        "[log1p] Cannot compute log of elements in array with"
-        " non floating point type.");
+    eval(inputs, out);
   }
 }
@@ -407,7 +399,7 @@ void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& b = inputs[1];
   if (a.dtype() == float32) {
-    binary(
+    binary_op<float>(
         a,
         b,
         out,
@@ -422,7 +414,7 @@ void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
           vDSP_vmul((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
         });
   } else {
-    binary(a, b, out, [](auto x, auto y) { return x * y; });
+    eval(inputs, out);
   }
 }
@@ -433,7 +425,7 @@ void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
     set_unary_output_data(in, out);
     vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
   } else {
-    unary(in, out, [](auto x) { return -x; });
+    eval(inputs, out);
   }
 }
@@ -520,7 +512,7 @@ void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
     auto size = in.data_size();
     vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size);
   } else {
-    unary(in, out, [](auto x) { return x * x; });
+    eval(inputs, out);
   }
 }
@@ -546,7 +538,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& b = inputs[1];
   if (a.dtype() == float32) {
-    binary(
+    binary_op<float>(
         a,
         b,
         out,
@@ -564,7 +556,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
           vDSP_vsub((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
         });
   } else if (a.dtype() == int32) {
-    binary(
+    binary_op<int>(
         a,
         b,
         out,
@@ -576,7 +568,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
         },
         UseDefaultBinaryOp());
   } else {
-    binary(a, b, out, [](auto x, auto y) { return x - y; });
+    eval(inputs, out);
   }
 }
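
Taken together, these hunks follow one pattern: inside each dtype branch the type is already known, so this backend calls the dtype-specific template (binary_op<float>, binary_op<int>) directly, and every other dtype falls through to the common eval, whose kernels are instantiated once elsewhere. That is what shrinks the binary: this translation unit stops duplicating templated kernels for types it never accelerates. Below is a minimal self-contained sketch of that fast-path/fallback structure; the names (Array, fast_add_f32, generic_eval) are hypothetical stand-ins for the real MLX types and vDSP calls, not the library API.

#include <cstddef>
#include <vector>

// Hypothetical stand-ins; this only shows the dispatch shape of the diff.
enum class Dtype { float32, int32, other };

struct Array {
  Dtype dtype;
  std::vector<float> data; // simplified; real arrays are type-erased buffers
};

// Accelerated path: stands in for a vDSP routine such as vDSP_vadd.
void fast_add_f32(const float* a, const float* b, float* o, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) {
    o[i] = a[i] + b[i];
  }
}

// Shared generic path: in MLX the common CPU backend already instantiates
// kernels for every dtype, so routing here duplicates nothing.
void generic_eval(const std::vector<Array>& inputs, Array& out) {
  const auto& a = inputs[0];
  const auto& b = inputs[1];
  out.data.resize(a.data.size());
  for (std::size_t i = 0; i < a.data.size(); ++i) {
    out.data[i] = a.data[i] + b.data[i];
  }
}

void add_eval_cpu(const std::vector<Array>& inputs, Array& out) {
  const auto& a = inputs[0];
  const auto& b = inputs[1];
  if (a.dtype == Dtype::float32) {
    // Fast path: the dtype is known here, so call the specialized kernel
    // directly rather than a generic dispatcher that would instantiate
    // code for every dtype in this translation unit.
    out.data.resize(a.data.size());
    fast_add_f32(a.data.data(), b.data.data(), out.data.data(), a.data.size());
  } else {
    // Fallback: reuse the common implementation (the diff's eval(inputs, out))
    // instead of stamping out another templated kernel per dtype.
    generic_eval(inputs, out);
  }
}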

View File

@@ -76,7 +76,7 @@ std::pair<int, int> extract_quantized_matmul_dims(
   if (w.dtype() != uint32) {
     std::ostringstream msg;
     msg << "[" << tag << "] The weight matrix should be uint32 "
-        << "but received" << w.dtype();
+        << "but received " << w.dtype();
     throw std::invalid_argument(msg.str());
   }

View File

@@ -41,12 +41,13 @@ array bits(
   auto key = key_ ? *key_ : KeySequence::default_().next();
   if (key.dtype() != uint32) {
     std::ostringstream msg;
-    msg << "Expected key type uint32 but received " << key.dtype() << ".";
+    msg << "[bits] Expected key type uint32 but received " << key.dtype()
+        << ".";
     throw std::invalid_argument(msg.str());
   }
   if (key.shape() != std::vector<int>{2}) {
     std::ostringstream msg;
-    msg << "Expected key shape (2) but received " << key.shape() << ".";
+    msg << "[bits] Expected key shape (2) but received " << key.shape() << ".";
     throw std::invalid_argument(msg.str());
   }
@@ -110,7 +111,8 @@ array uniform(
     StreamOrDevice s /* = {} */) {
   if (!issubdtype(dtype, floating)) {
     throw std::invalid_argument(
-        "Can only generate uniform numbers with real floating point type.");
+        "[uniform] Can only generate uniform numbers with real "
+        "floating point type.");
   }
   auto stream = to_stream(s);
@@ -120,7 +122,7 @@ array uniform(
   auto out_shape = broadcast_shapes(shape, range.shape());
   if (out_shape != shape) {
     std::ostringstream msg;
-    msg << "Cannot generate random values of shape " << shape
+    msg << "[uniform] Cannot generate random values of shape " << shape
        << " from broadcasted shape " << out_shape << ".";
     throw std::invalid_argument(msg.str());
   }

View File

@@ -157,8 +157,8 @@ class Module(dict):
         Args:
             file_or_weights (str or list(tuple(str, mx.array))): The path to
-                the weights ``.npz`` file (``.npz`` or ``.safetensors``) or a list of pairs of parameter names
-                and arrays.
+                the weights ``.npz`` file (``.npz`` or ``.safetensors``) or a list
+                of pairs of parameter names and arrays.
             strict (bool, optional): If ``True`` then checks that the provided
                 weights exactly match the parameters of the model. Otherwise,
                 only the weights actually contained in the model are loaded and
@@ -222,7 +222,7 @@ class Module(dict):
                     if v_new.shape != v.shape:
                         raise ValueError(
                             f"Expected shape {v.shape} but received "
-                            f" shape {v_new.shape} for parameter {k}"
+                            f"shape {v_new.shape} for parameter {k}"
                         )
         self.update(tree_unflatten(weights))

View File

@@ -83,7 +83,7 @@ void init_fast(nb::module_& parent_module) {
       "offset"_a,
       "stream"_a = nb::none(),
       nb::sig(
-          "def rope(a: array, dims: int, *, traditinoal: bool, base: float, scale: float, offset: int, stream: Union[None, Stream, Device] = None) -> array"),
+          "def rope(a: array, dims: int, *, traditional: bool, base: float, scale: float, offset: int, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
           Apply rotary positional encoding to the input.