mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
CPU binary reduction + Nits (#1242)
* very minor nits * reduce binary * fix test
This commit is contained in:
parent
d6383a1c6a
commit
20bb301195
@ -101,7 +101,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& b = inputs[1];
|
||||
|
||||
if (a.dtype() == float32) {
|
||||
binary(
|
||||
binary_op<float>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
@ -116,7 +116,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
vDSP_vadd((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
|
||||
});
|
||||
} else if (a.dtype() == int32) {
|
||||
binary(
|
||||
binary_op<int>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
@ -131,7 +131,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
vDSP_vaddi((const int*)a, 1, (const int*)b, 1, (int*)o, 1, n);
|
||||
});
|
||||
} else {
|
||||
binary(a, b, out, [](auto x, auto y) { return x + y; });
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
@ -286,7 +286,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& b = inputs[1];
|
||||
|
||||
if (a.dtype() == int32) {
|
||||
binary(
|
||||
binary_op<int>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
@ -299,7 +299,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
vDSP_vdivi((const int*)b, 1, (const int*)a, 1, (int*)o, 1, n);
|
||||
});
|
||||
} else if (a.dtype() == float32) {
|
||||
binary(
|
||||
binary_op<float>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
@ -314,7 +314,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
vDSP_vdiv((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
|
||||
});
|
||||
} else {
|
||||
binary(a, b, out, [](auto x, auto y) { return x / y; });
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
@ -325,12 +325,8 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
set_unary_output_data(in, out);
|
||||
auto size = in.data_size();
|
||||
vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||
} else if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, [](auto x) { return std::exp(x); });
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[exp] Cannot exponentiate elements in array"
|
||||
" with non floating point type.");
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
@ -392,12 +388,8 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto size = in.data_size();
|
||||
vvlog1pf(
|
||||
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||
} else if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, [](auto x) { return std::log1p(x); });
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[log1p] Cannot compute log of elements in array with"
|
||||
" non floating point type.");
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
@ -407,7 +399,7 @@ void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& b = inputs[1];
|
||||
|
||||
if (a.dtype() == float32) {
|
||||
binary(
|
||||
binary_op<float>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
@ -422,7 +414,7 @@ void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
vDSP_vmul((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
|
||||
});
|
||||
} else {
|
||||
binary(a, b, out, [](auto x, auto y) { return x * y; });
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
@ -433,7 +425,7 @@ void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
set_unary_output_data(in, out);
|
||||
vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
|
||||
} else {
|
||||
unary(in, out, [](auto x) { return -x; });
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
@ -520,7 +512,7 @@ void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto size = in.data_size();
|
||||
vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size);
|
||||
} else {
|
||||
unary(in, out, [](auto x) { return x * x; });
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
@ -546,7 +538,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& b = inputs[1];
|
||||
|
||||
if (a.dtype() == float32) {
|
||||
binary(
|
||||
binary_op<float>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
@ -564,7 +556,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
vDSP_vsub((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
|
||||
});
|
||||
} else if (a.dtype() == int32) {
|
||||
binary(
|
||||
binary_op<int>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
@ -576,7 +568,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
},
|
||||
UseDefaultBinaryOp());
|
||||
} else {
|
||||
binary(a, b, out, [](auto x, auto y) { return x - y; });
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -76,7 +76,7 @@ std::pair<int, int> extract_quantized_matmul_dims(
|
||||
if (w.dtype() != uint32) {
|
||||
std::ostringstream msg;
|
||||
msg << "[" << tag << "] The weight matrix should be uint32 "
|
||||
<< "but received" << w.dtype();
|
||||
<< "but received " << w.dtype();
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
|
@ -41,12 +41,13 @@ array bits(
|
||||
auto key = key_ ? *key_ : KeySequence::default_().next();
|
||||
if (key.dtype() != uint32) {
|
||||
std::ostringstream msg;
|
||||
msg << "Expected key type uint32 but received " << key.dtype() << ".";
|
||||
msg << "[bits] Expected key type uint32 but received " << key.dtype()
|
||||
<< ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (key.shape() != std::vector<int>{2}) {
|
||||
std::ostringstream msg;
|
||||
msg << "Expected key shape (2) but received " << key.shape() << ".";
|
||||
msg << "[bits] Expected key shape (2) but received " << key.shape() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
@ -110,7 +111,8 @@ array uniform(
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (!issubdtype(dtype, floating)) {
|
||||
throw std::invalid_argument(
|
||||
"Can only generate uniform numbers with real floating point type.");
|
||||
"[uniform] Can only generate uniform numbers with real "
|
||||
"floating point type.");
|
||||
}
|
||||
|
||||
auto stream = to_stream(s);
|
||||
@ -120,7 +122,7 @@ array uniform(
|
||||
auto out_shape = broadcast_shapes(shape, range.shape());
|
||||
if (out_shape != shape) {
|
||||
std::ostringstream msg;
|
||||
msg << "Cannot generate random values of shape " << shape
|
||||
msg << "[uniform] Cannot generate random values of shape " << shape
|
||||
<< " from broadcasted shape " << out_shape << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
@ -157,8 +157,8 @@ class Module(dict):
|
||||
|
||||
Args:
|
||||
file_or_weights (str or list(tuple(str, mx.array))): The path to
|
||||
the weights ``.npz`` file (``.npz`` or ``.safetensors``) or a list of pairs of parameter names
|
||||
and arrays.
|
||||
the weights ``.npz`` file (``.npz`` or ``.safetensors``) or a list
|
||||
of pairs of parameter names and arrays.
|
||||
strict (bool, optional): If ``True`` then checks that the provided
|
||||
weights exactly match the parameters of the model. Otherwise,
|
||||
only the weights actually contained in the model are loaded and
|
||||
@ -222,7 +222,7 @@ class Module(dict):
|
||||
if v_new.shape != v.shape:
|
||||
raise ValueError(
|
||||
f"Expected shape {v.shape} but received "
|
||||
f" shape {v_new.shape} for parameter {k}"
|
||||
f"shape {v_new.shape} for parameter {k}"
|
||||
)
|
||||
|
||||
self.update(tree_unflatten(weights))
|
||||
|
@ -83,7 +83,7 @@ void init_fast(nb::module_& parent_module) {
|
||||
"offset"_a,
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def rope(a: array, dims: int, *, traditinoal: bool, base: float, scale: float, offset: int, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def rope(a: array, dims: int, *, traditional: bool, base: float, scale: float, offset: int, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Apply rotary positional encoding to the input.
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user