CPU binary reduction + Nits (#1242)

* very minor nits

* reduce binary

* fix test
Awni Hannun 2024-06-28 13:50:42 -07:00 committed by GitHub
parent d6383a1c6a
commit 20bb301195
5 changed files with 26 additions and 32 deletions
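
The "reduce binary" change follows one pattern across the hunks below: each eval_cpu keeps its hand-written Accelerate fast paths only for the dtypes that have one (float32, int32), and every other dtype now falls through to the shared common-backend eval(inputs, out) instead of instantiating another templated kernel per operation. A minimal, self-contained sketch of that dispatch shape, using hypothetical standalone types rather than MLX's array/Dtype (the real fast path hands vDSP routines to binary_op):

#include <cstdio>

// Hypothetical stand-ins for MLX's dtype tag and common-backend fallback.
enum class Dtype { float32, int32, other };

// One templated kernel per fast-path dtype (binary_op<T> in the diff).
// MLX's version also takes vDSP routines; a plain loop stands in here.
template <typename T, typename Op>
void binary_op(const T* a, const T* b, T* out, int n, Op op) {
  for (int i = 0; i < n; ++i)
    out[i] = op(a[i], b[i]);
}

// Shared, type-erased slow path (the common backend's eval(inputs, out)).
void eval_fallback() {
  std::puts("generic common-backend kernel");
}

void add_eval_cpu(Dtype dt, const void* a, const void* b, void* out, int n) {
  if (dt == Dtype::float32) {
    binary_op<float>(
        static_cast<const float*>(a),
        static_cast<const float*>(b),
        static_cast<float*>(out),
        n,
        [](float x, float y) { return x + y; });
  } else if (dt == Dtype::int32) {
    binary_op<int>(
        static_cast<const int*>(a),
        static_cast<const int*>(b),
        static_cast<int*>(out),
        n,
        [](int x, int y) { return x + y; });
  } else {
    // No new template instantiation for the long tail of dtypes.
    eval_fallback();
  }
}

int main() {
  float a[2] = {1.f, 2.f}, b[2] = {3.f, 4.f}, out[2];
  add_eval_cpu(Dtype::float32, a, b, out, 2);
  std::printf("%g %g\n", out[0], out[1]); // prints: 4 6
}

The win is that the long tail of dtypes shares one set of generic kernel instantiations in the common backend rather than duplicating them in the Accelerate backend.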

View File

@@ -101,7 +101,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& b = inputs[1];
   if (a.dtype() == float32) {
-    binary(
+    binary_op<float>(
         a,
         b,
         out,
@@ -116,7 +116,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
           vDSP_vadd((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
         });
   } else if (a.dtype() == int32) {
-    binary(
+    binary_op<int>(
         a,
         b,
         out,
@@ -131,7 +131,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
           vDSP_vaddi((const int*)a, 1, (const int*)b, 1, (int*)o, 1, n);
         });
   } else {
-    binary(a, b, out, [](auto x, auto y) { return x + y; });
+    eval(inputs, out);
   }
 }
@@ -286,7 +286,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& b = inputs[1];
   if (a.dtype() == int32) {
-    binary(
+    binary_op<int>(
         a,
         b,
         out,
@@ -299,7 +299,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
           vDSP_vdivi((const int*)b, 1, (const int*)a, 1, (int*)o, 1, n);
         });
   } else if (a.dtype() == float32) {
-    binary(
+    binary_op<float>(
         a,
         b,
         out,
@@ -314,7 +314,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
           vDSP_vdiv((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
         });
   } else {
-    binary(a, b, out, [](auto x, auto y) { return x / y; });
+    eval(inputs, out);
   }
 }
@@ -325,12 +325,8 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
     set_unary_output_data(in, out);
     auto size = in.data_size();
     vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
-  } else if (issubdtype(out.dtype(), inexact)) {
-    unary_fp(in, out, [](auto x) { return std::exp(x); });
   } else {
-    throw std::invalid_argument(
-        "[exp] Cannot exponentiate elements in array"
-        " with non floating point type.");
+    eval(inputs, out);
   }
 }
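
The float32 branch kept above calls vForce's vvexpf, which takes the element count by pointer (hence the reinterpret_cast of size in the hunk). A standalone usage sketch, macOS only, compiled with -framework Accelerate; this is illustrative, not MLX code:

#include <Accelerate/Accelerate.h>
#include <cstdio>

int main() {
  float in[4] = {0.f, 1.f, 2.f, 3.f};
  float out[4];
  int n = 4;           // vvexpf takes the count by pointer
  vvexpf(out, in, &n); // out[i] = expf(in[i])
  for (int i = 0; i < 4; ++i)
    std::printf("%f\n", out[i]);
}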
@@ -392,12 +388,8 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
     auto size = in.data_size();
     vvlog1pf(
         out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
-  } else if (issubdtype(out.dtype(), inexact)) {
-    unary_fp(in, out, [](auto x) { return std::log1p(x); });
   } else {
-    throw std::invalid_argument(
-        "[log1p] Cannot compute log of elements in array with"
-        " non floating point type.");
+    eval(inputs, out);
   }
 }
@@ -407,7 +399,7 @@ void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& b = inputs[1];
   if (a.dtype() == float32) {
-    binary(
+    binary_op<float>(
         a,
         b,
         out,
@@ -422,7 +414,7 @@ void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
           vDSP_vmul((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
         });
   } else {
-    binary(a, b, out, [](auto x, auto y) { return x * y; });
+    eval(inputs, out);
   }
 }
@@ -433,7 +425,7 @@ void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
     set_unary_output_data(in, out);
     vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
   } else {
-    unary(in, out, [](auto x) { return -x; });
+    eval(inputs, out);
   }
 }
@@ -520,7 +512,7 @@ void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
     auto size = in.data_size();
     vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size);
   } else {
-    unary(in, out, [](auto x) { return x * x; });
+    eval(inputs, out);
   }
 }
@@ -546,7 +538,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& b = inputs[1];
   if (a.dtype() == float32) {
-    binary(
+    binary_op<float>(
         a,
         b,
         out,
@@ -564,7 +556,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
           vDSP_vsub((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
         });
   } else if (a.dtype() == int32) {
-    binary(
+    binary_op<int>(
         a,
         b,
         out,
@@ -576,7 +568,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
         },
         UseDefaultBinaryOp());
   } else {
-    binary(a, b, out, [](auto x, auto y) { return x - y; });
+    eval(inputs, out);
   }
 }
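
One detail worth noting in the Subtract and Divide hunks above: vDSP_vsub and vDSP_vdiv take their operands in reversed order (the subtrahend/divisor comes first), which is why the code passes b before a. A standalone sketch, macOS only, compiled with -framework Accelerate; again illustrative, not MLX code:

#include <Accelerate/Accelerate.h>
#include <cstdio>

int main() {
  const float a[3] = {10.f, 20.f, 30.f};
  const float b[3] = {1.f, 2.f, 3.f};
  float out[3];
  // vDSP_vsub computes C = A - B but takes B first, matching the
  // swapped arguments in the Subtract hunk.
  vDSP_vsub(b, 1, a, 1, out, 1, 3);
  std::printf("%g %g %g\n", out[0], out[1], out[2]); // prints: 9 18 27
}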

View File

@@ -76,7 +76,7 @@ std::pair<int, int> extract_quantized_matmul_dims(
   if (w.dtype() != uint32) {
     std::ostringstream msg;
     msg << "[" << tag << "] The weight matrix should be uint32 "
-        << "but received" << w.dtype();
+        << "but received " << w.dtype();
     throw std::invalid_argument(msg.str());
   }

View File

@@ -41,12 +41,13 @@ array bits(
   auto key = key_ ? *key_ : KeySequence::default_().next();
   if (key.dtype() != uint32) {
     std::ostringstream msg;
-    msg << "Expected key type uint32 but received " << key.dtype() << ".";
+    msg << "[bits] Expected key type uint32 but received " << key.dtype()
+        << ".";
     throw std::invalid_argument(msg.str());
   }
   if (key.shape() != std::vector<int>{2}) {
     std::ostringstream msg;
-    msg << "Expected key shape (2) but received " << key.shape() << ".";
+    msg << "[bits] Expected key shape (2) but received " << key.shape() << ".";
     throw std::invalid_argument(msg.str());
   }
@@ -110,7 +111,8 @@ array uniform(
     StreamOrDevice s /* = {} */) {
   if (!issubdtype(dtype, floating)) {
     throw std::invalid_argument(
-        "Can only generate uniform numbers with real floating point type.");
+        "[uniform] Can only generate uniform numbers with real "
+        "floating point type.");
   }
   auto stream = to_stream(s);
@@ -120,7 +122,7 @@ array uniform(
   auto out_shape = broadcast_shapes(shape, range.shape());
   if (out_shape != shape) {
     std::ostringstream msg;
-    msg << "Cannot generate random values of shape " << shape
+    msg << "[uniform] Cannot generate random values of shape " << shape
         << " from broadcasted shape " << out_shape << ".";
     throw std::invalid_argument(msg.str());
   }

View File

@@ -157,8 +157,8 @@ class Module(dict):
         Args:
             file_or_weights (str or list(tuple(str, mx.array))): The path to
-                the weights ``.npz`` file (``.npz`` or ``.safetensors``) or a list of pairs of parameter names
-                and arrays.
+                the weights ``.npz`` file (``.npz`` or ``.safetensors``) or a list
+                of pairs of parameter names and arrays.
             strict (bool, optional): If ``True`` then checks that the provided
                 weights exactly match the parameters of the model. Otherwise,
                 only the weights actually contained in the model are loaded and
@@ -222,7 +222,7 @@ class Module(dict):
                 if v_new.shape != v.shape:
                     raise ValueError(
                         f"Expected shape {v.shape} but received "
-                        f" shape {v_new.shape} for parameter {k}"
+                        f"shape {v_new.shape} for parameter {k}"
                     )
         self.update(tree_unflatten(weights))

View File

@@ -83,7 +83,7 @@ void init_fast(nb::module_& parent_module) {
       "offset"_a,
       "stream"_a = nb::none(),
       nb::sig(
-          "def rope(a: array, dims: int, *, traditinoal: bool, base: float, scale: float, offset: int, stream: Union[None, Stream, Device] = None) -> array"),
+          "def rope(a: array, dims: int, *, traditional: bool, base: float, scale: float, offset: int, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Apply rotary positional encoding to the input.