mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
CPU binary reduction + Nits (#1242)
* very minor nits * reduce binary * fix test
This commit is contained in:
parent
d6383a1c6a
commit
20bb301195
@ -101,7 +101,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
|
|
||||||
if (a.dtype() == float32) {
|
if (a.dtype() == float32) {
|
||||||
binary(
|
binary_op<float>(
|
||||||
a,
|
a,
|
||||||
b,
|
b,
|
||||||
out,
|
out,
|
||||||
@ -116,7 +116,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
vDSP_vadd((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
|
vDSP_vadd((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
|
||||||
});
|
});
|
||||||
} else if (a.dtype() == int32) {
|
} else if (a.dtype() == int32) {
|
||||||
binary(
|
binary_op<int>(
|
||||||
a,
|
a,
|
||||||
b,
|
b,
|
||||||
out,
|
out,
|
||||||
@ -131,7 +131,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
vDSP_vaddi((const int*)a, 1, (const int*)b, 1, (int*)o, 1, n);
|
vDSP_vaddi((const int*)a, 1, (const int*)b, 1, (int*)o, 1, n);
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
binary(a, b, out, [](auto x, auto y) { return x + y; });
|
eval(inputs, out);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -286,7 +286,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
|
|
||||||
if (a.dtype() == int32) {
|
if (a.dtype() == int32) {
|
||||||
binary(
|
binary_op<int>(
|
||||||
a,
|
a,
|
||||||
b,
|
b,
|
||||||
out,
|
out,
|
||||||
@ -299,7 +299,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
vDSP_vdivi((const int*)b, 1, (const int*)a, 1, (int*)o, 1, n);
|
vDSP_vdivi((const int*)b, 1, (const int*)a, 1, (int*)o, 1, n);
|
||||||
});
|
});
|
||||||
} else if (a.dtype() == float32) {
|
} else if (a.dtype() == float32) {
|
||||||
binary(
|
binary_op<float>(
|
||||||
a,
|
a,
|
||||||
b,
|
b,
|
||||||
out,
|
out,
|
||||||
@ -314,7 +314,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
vDSP_vdiv((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
|
vDSP_vdiv((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
binary(a, b, out, [](auto x, auto y) { return x / y; });
|
eval(inputs, out);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -325,12 +325,8 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
set_unary_output_data(in, out);
|
set_unary_output_data(in, out);
|
||||||
auto size = in.data_size();
|
auto size = in.data_size();
|
||||||
vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||||
} else if (issubdtype(out.dtype(), inexact)) {
|
|
||||||
unary_fp(in, out, [](auto x) { return std::exp(x); });
|
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
eval(inputs, out);
|
||||||
"[exp] Cannot exponentiate elements in array"
|
|
||||||
" with non floating point type.");
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -392,12 +388,8 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
auto size = in.data_size();
|
auto size = in.data_size();
|
||||||
vvlog1pf(
|
vvlog1pf(
|
||||||
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||||
} else if (issubdtype(out.dtype(), inexact)) {
|
|
||||||
unary_fp(in, out, [](auto x) { return std::log1p(x); });
|
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument(
|
eval(inputs, out);
|
||||||
"[log1p] Cannot compute log of elements in array with"
|
|
||||||
" non floating point type.");
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -407,7 +399,7 @@ void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
|
|
||||||
if (a.dtype() == float32) {
|
if (a.dtype() == float32) {
|
||||||
binary(
|
binary_op<float>(
|
||||||
a,
|
a,
|
||||||
b,
|
b,
|
||||||
out,
|
out,
|
||||||
@ -422,7 +414,7 @@ void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
vDSP_vmul((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
|
vDSP_vmul((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
binary(a, b, out, [](auto x, auto y) { return x * y; });
|
eval(inputs, out);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -433,7 +425,7 @@ void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
set_unary_output_data(in, out);
|
set_unary_output_data(in, out);
|
||||||
vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
|
vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
|
||||||
} else {
|
} else {
|
||||||
unary(in, out, [](auto x) { return -x; });
|
eval(inputs, out);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -520,7 +512,7 @@ void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
auto size = in.data_size();
|
auto size = in.data_size();
|
||||||
vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size);
|
vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size);
|
||||||
} else {
|
} else {
|
||||||
unary(in, out, [](auto x) { return x * x; });
|
eval(inputs, out);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -546,7 +538,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
|
|
||||||
if (a.dtype() == float32) {
|
if (a.dtype() == float32) {
|
||||||
binary(
|
binary_op<float>(
|
||||||
a,
|
a,
|
||||||
b,
|
b,
|
||||||
out,
|
out,
|
||||||
@ -564,7 +556,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
vDSP_vsub((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
|
vDSP_vsub((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
|
||||||
});
|
});
|
||||||
} else if (a.dtype() == int32) {
|
} else if (a.dtype() == int32) {
|
||||||
binary(
|
binary_op<int>(
|
||||||
a,
|
a,
|
||||||
b,
|
b,
|
||||||
out,
|
out,
|
||||||
@ -576,7 +568,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
},
|
},
|
||||||
UseDefaultBinaryOp());
|
UseDefaultBinaryOp());
|
||||||
} else {
|
} else {
|
||||||
binary(a, b, out, [](auto x, auto y) { return x - y; });
|
eval(inputs, out);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -76,7 +76,7 @@ std::pair<int, int> extract_quantized_matmul_dims(
|
|||||||
if (w.dtype() != uint32) {
|
if (w.dtype() != uint32) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[" << tag << "] The weight matrix should be uint32 "
|
msg << "[" << tag << "] The weight matrix should be uint32 "
|
||||||
<< "but received" << w.dtype();
|
<< "but received " << w.dtype();
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -41,12 +41,13 @@ array bits(
|
|||||||
auto key = key_ ? *key_ : KeySequence::default_().next();
|
auto key = key_ ? *key_ : KeySequence::default_().next();
|
||||||
if (key.dtype() != uint32) {
|
if (key.dtype() != uint32) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "Expected key type uint32 but received " << key.dtype() << ".";
|
msg << "[bits] Expected key type uint32 but received " << key.dtype()
|
||||||
|
<< ".";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
if (key.shape() != std::vector<int>{2}) {
|
if (key.shape() != std::vector<int>{2}) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "Expected key shape (2) but received " << key.shape() << ".";
|
msg << "[bits] Expected key shape (2) but received " << key.shape() << ".";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -110,7 +111,8 @@ array uniform(
|
|||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
if (!issubdtype(dtype, floating)) {
|
if (!issubdtype(dtype, floating)) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"Can only generate uniform numbers with real floating point type.");
|
"[uniform] Can only generate uniform numbers with real "
|
||||||
|
"floating point type.");
|
||||||
}
|
}
|
||||||
|
|
||||||
auto stream = to_stream(s);
|
auto stream = to_stream(s);
|
||||||
@ -120,7 +122,7 @@ array uniform(
|
|||||||
auto out_shape = broadcast_shapes(shape, range.shape());
|
auto out_shape = broadcast_shapes(shape, range.shape());
|
||||||
if (out_shape != shape) {
|
if (out_shape != shape) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "Cannot generate random values of shape " << shape
|
msg << "[uniform] Cannot generate random values of shape " << shape
|
||||||
<< " from broadcasted shape " << out_shape << ".";
|
<< " from broadcasted shape " << out_shape << ".";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
@ -157,8 +157,8 @@ class Module(dict):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_or_weights (str or list(tuple(str, mx.array))): The path to
|
file_or_weights (str or list(tuple(str, mx.array))): The path to
|
||||||
the weights ``.npz`` file (``.npz`` or ``.safetensors``) or a list of pairs of parameter names
|
the weights ``.npz`` file (``.npz`` or ``.safetensors``) or a list
|
||||||
and arrays.
|
of pairs of parameter names and arrays.
|
||||||
strict (bool, optional): If ``True`` then checks that the provided
|
strict (bool, optional): If ``True`` then checks that the provided
|
||||||
weights exactly match the parameters of the model. Otherwise,
|
weights exactly match the parameters of the model. Otherwise,
|
||||||
only the weights actually contained in the model are loaded and
|
only the weights actually contained in the model are loaded and
|
||||||
@ -222,7 +222,7 @@ class Module(dict):
|
|||||||
if v_new.shape != v.shape:
|
if v_new.shape != v.shape:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Expected shape {v.shape} but received "
|
f"Expected shape {v.shape} but received "
|
||||||
f" shape {v_new.shape} for parameter {k}"
|
f"shape {v_new.shape} for parameter {k}"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.update(tree_unflatten(weights))
|
self.update(tree_unflatten(weights))
|
||||||
|
@ -83,7 +83,7 @@ void init_fast(nb::module_& parent_module) {
|
|||||||
"offset"_a,
|
"offset"_a,
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
nb::sig(
|
nb::sig(
|
||||||
"def rope(a: array, dims: int, *, traditinoal: bool, base: float, scale: float, offset: int, stream: Union[None, Stream, Device] = None) -> array"),
|
"def rope(a: array, dims: int, *, traditional: bool, base: float, scale: float, offset: int, stream: Union[None, Stream, Device] = None) -> array"),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Apply rotary positional encoding to the input.
|
Apply rotary positional encoding to the input.
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user