From 20bb30119559da672688af67c8f773aa304bf274 Mon Sep 17 00:00:00 2001
From: Awni Hannun <awni@apple.com>
Date: Fri, 28 Jun 2024 13:50:42 -0700
Subject: [PATCH] CPU binary reduction + Nits (#1242)

* very minor nits

* reduce binary

* fix test
---
 mlx/backend/accelerate/primitives.cpp | 38 +++++++++++----------------
 mlx/ops.cpp                           |  2 +-
 mlx/random.cpp                        | 10 ++++---
 python/mlx/nn/layers/base.py          |  6 ++---
 python/src/fast.cpp                   |  2 +-
 5 files changed, 26 insertions(+), 32 deletions(-)

diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp
index 8c3615599..778840115 100644
--- a/mlx/backend/accelerate/primitives.cpp
+++ b/mlx/backend/accelerate/primitives.cpp
@@ -101,7 +101,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& b = inputs[1];
 
   if (a.dtype() == float32) {
-    binary(
+    binary_op<float>(
         a,
         b,
         out,
@@ -116,7 +116,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
           vDSP_vadd((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
         });
   } else if (a.dtype() == int32) {
-    binary(
+    binary_op<int>(
         a,
         b,
         out,
@@ -131,7 +131,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
           vDSP_vaddi((const int*)a, 1, (const int*)b, 1, (int*)o, 1, n);
         });
   } else {
-    binary(a, b, out, [](auto x, auto y) { return x + y; });
+    eval(inputs, out);
   }
 }
 
@@ -286,7 +286,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& b = inputs[1];
 
   if (a.dtype() == int32) {
-    binary(
+    binary_op<int>(
         a,
         b,
         out,
@@ -299,7 +299,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
           vDSP_vdivi((const int*)b, 1, (const int*)a, 1, (int*)o, 1, n);
         });
   } else if (a.dtype() == float32) {
-    binary(
+    binary_op<float>(
         a,
         b,
         out,
@@ -314,7 +314,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
           vDSP_vdiv((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
         });
   } else {
-    binary(a, b, out, [](auto x, auto y) { return x / y; });
+    eval(inputs, out);
   }
 }
 
@@ -325,12 +325,8 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
     set_unary_output_data(in, out);
     auto size = in.data_size();
     vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
-  } else if (issubdtype(out.dtype(), inexact)) {
-    unary_fp(in, out, [](auto x) { return std::exp(x); });
   } else {
-    throw std::invalid_argument(
-        "[exp] Cannot exponentiate elements in array"
-        " with non floating point type.");
+    eval(inputs, out);
   }
 }
 
@@ -392,12 +388,8 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
     auto size = in.data_size();
     vvlog1pf(
         out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
-  } else if (issubdtype(out.dtype(), inexact)) {
-    unary_fp(in, out, [](auto x) { return std::log1p(x); });
   } else {
-    throw std::invalid_argument(
-        "[log1p] Cannot compute log of elements in array with"
-        " non floating point type.");
+    eval(inputs, out);
   }
 }
 
@@ -407,7 +399,7 @@ void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& b = inputs[1];
 
   if (a.dtype() == float32) {
-    binary(
+    binary_op<float>(
         a,
         b,
         out,
@@ -422,7 +414,7 @@ void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
           vDSP_vmul((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
         });
   } else {
-    binary(a, b, out, [](auto x, auto y) { return x * y; });
+    eval(inputs, out);
   }
 }
 
@@ -433,7 +425,7 @@ void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
     set_unary_output_data(in, out);
     vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
   } else {
-    unary(in, out, [](auto x) { return -x; });
+    eval(inputs, out);
   }
 }
 
@@ -520,7 +512,7 @@ void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
     auto size = in.data_size();
     vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size);
   } else {
-    unary(in, out, [](auto x) { return x * x; });
+    eval(inputs, out);
   }
 }
 
@@ -546,7 +538,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& b = inputs[1];
 
   if (a.dtype() == float32) {
-    binary(
+    binary_op<float>(
         a,
         b,
         out,
@@ -564,7 +556,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
           vDSP_vsub((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
         });
   } else if (a.dtype() == int32) {
-    binary(
+    binary_op<int>(
         a,
         b,
         out,
@@ -576,7 +568,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
         },
         UseDefaultBinaryOp());
   } else {
-    binary(a, b, out, [](auto x, auto y) { return x - y; });
+    eval(inputs, out);
   }
 }
 
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index d38b8e6f7..5d019edf0 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -76,7 +76,7 @@ std::pair<int, int> extract_quantized_matmul_dims(
   if (w.dtype() != uint32) {
     std::ostringstream msg;
     msg << "[" << tag << "] The weight matrix should be uint32 "
-        << "but received" << w.dtype();
+        << "but received " << w.dtype();
     throw std::invalid_argument(msg.str());
   }
 
diff --git a/mlx/random.cpp b/mlx/random.cpp
index 05405acbb..45ce85763 100644
--- a/mlx/random.cpp
+++ b/mlx/random.cpp
@@ -41,12 +41,13 @@ array bits(
   auto key = key_ ? *key_ : KeySequence::default_().next();
   if (key.dtype() != uint32) {
     std::ostringstream msg;
-    msg << "Expected key type uint32 but received " << key.dtype() << ".";
+    msg << "[bits] Expected key type uint32 but received " << key.dtype()
+        << ".";
     throw std::invalid_argument(msg.str());
   }
   if (key.shape() != std::vector<int>{2}) {
     std::ostringstream msg;
-    msg << "Expected key shape (2) but received " << key.shape() << ".";
+    msg << "[bits] Expected key shape (2) but received " << key.shape() << ".";
     throw std::invalid_argument(msg.str());
   }
 
@@ -110,7 +111,8 @@ array uniform(
     StreamOrDevice s /* = {} */) {
   if (!issubdtype(dtype, floating)) {
     throw std::invalid_argument(
-        "Can only generate uniform numbers with real floating point type.");
+        "[uniform] Can only generate uniform numbers with real "
+        "floating point type.");
   }
 
   auto stream = to_stream(s);
@@ -120,7 +122,7 @@ array uniform(
   auto out_shape = broadcast_shapes(shape, range.shape());
   if (out_shape != shape) {
     std::ostringstream msg;
-    msg << "Cannot generate random values of shape " << shape
+    msg << "[uniform] Cannot generate random values of shape " << shape
         << " from broadcasted shape " << out_shape << ".";
     throw std::invalid_argument(msg.str());
   }
diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py
index 996410526..a756ace5c 100644
--- a/python/mlx/nn/layers/base.py
+++ b/python/mlx/nn/layers/base.py
@@ -157,8 +157,8 @@ class Module(dict):
 
         Args:
             file_or_weights (str or list(tuple(str, mx.array))): The path to
-                the weights ``.npz`` file (``.npz`` or ``.safetensors``) or a list of pairs of parameter names
-                and arrays.
+                the weights ``.npz`` file (``.npz`` or ``.safetensors``) or a list
+                of pairs of parameter names and arrays.
             strict (bool, optional): If ``True`` then checks that the provided
                 weights exactly match the parameters of the model. Otherwise,
                 only the weights actually contained in the model are loaded and
@@ -222,7 +222,7 @@ class Module(dict):
                 if v_new.shape != v.shape:
                     raise ValueError(
                         f"Expected shape {v.shape} but received "
-                        f" shape {v_new.shape} for parameter {k}"
+                        f"shape {v_new.shape} for parameter {k}"
                     )
 
         self.update(tree_unflatten(weights))
diff --git a/python/src/fast.cpp b/python/src/fast.cpp
index f729b76fc..451937b21 100644
--- a/python/src/fast.cpp
+++ b/python/src/fast.cpp
@@ -83,7 +83,7 @@ void init_fast(nb::module_& parent_module) {
       "offset"_a,
       "stream"_a = nb::none(),
       nb::sig(
-          "def rope(a: array, dims: int, *, traditinoal: bool, base: float, scale: float, offset: int, stream: Union[None, Stream, Device] = None) -> array"),
+          "def rope(a: array, dims: int, *, traditional: bool, base: float, scale: float, offset: int, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Apply rotary positional encoding to the input.