From d9d0777c2ea4eb1bf82377375282f5a980d19aab Mon Sep 17 00:00:00 2001
From: Awni Hannun
This operation itself can call other operations within it if needed. So, the simplest way to go about implementing this operation would be to do so in terms of existing operations.
array axpby(
    const array& x, // Input array x
    const array& y, // Input array y
    const float alpha, // Scaling factor for x
    const float beta, // Scaling factor for y
    StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
) {
  // Scale x and y on the provided stream
  auto ax = multiply(array(alpha), x, s);
  auto by = multiply(array(beta), y, s);

  // Add and return
  return add(ax, by, s);
}
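As a quick illustration, a minimal usage sketch of this composed operation could look as follows (the shapes and scalars are illustrative, and it assumes the usual MLX headers with using namespace mlx::core):

// Hypothetical usage sketch of the composed axpby operation
auto x = ones({3, 4});             // array of ones with shape {3, 4}
auto y = ones({3, 4});
auto z = axpby(x, y, 2.0f, 0.5f);  // z = 2 * x + 0.5 * y, evaluated lazily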
However, as we discussed earlier, this is not our goal. The operations themselves do not define how the computation is performed; that is the job of a Primitive, which defines how an array is evaluated on the CPU or GPU, and how it acts under transformations such as vjp and jvp. These words on their own can be a bit abstract, so let's take a step back and go to our example to give ourselves a more concrete image.
class Axpby : public Primitive {
 public:
  explicit Axpby(Stream stream, float alpha, float beta)
      : Primitive(stream), alpha_(alpha), beta_(beta){};

  /**
   * A primitive must know how to evaluate itself on the CPU/GPU
   *
   * To avoid unnecessary allocations, the evaluation function
   * is responsible for allocating space for the array.
   */
  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  /** The Jacobian-vector product. */
  array jvp(
      const std::vector<array>& primals,
      const std::vector<array>& tangents,
      const std::vector<int>& argnums) override;

  /** The vector-Jacobian product. */
  std::vector<array> vjp(
      const std::vector<array>& primals,
      const array& cotan,
      const std::vector<int>& argnums) override;

  /**
   * The primitive must know how to vectorize itself across
   * the given axes. The output is a pair containing the array
   * representing the vectorized computation and the axis which
   * corresponds to the output vectorized dimension.
   */
  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  /** Print the primitive. */
  void print(std::ostream& os) override {
    os << "Axpby";
  }

  /** Equivalence check **/
  bool is_equivalent(const Primitive& other) const override;

 private:
  float alpha_;
  float beta_;

  /** Fall back implementation for evaluation on CPU */
  void eval(const std::vector<array>& inputs, array& out);
};
The Axpby class derives from the base Primitive class. Each array in the computation graph carries the Primitive that computes it, and the array inputs that are passed to the primitive.

Let's re-implement our operation now in terms of our Axpby primitive.
array axpby(
    const array& x, // Input array x
    const array& y, // Input array y
    const float alpha, // Scaling factor for x
    const float beta, // Scaling factor for y
    StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
) {
  // Promote dtypes between x and y as needed
  auto promoted_dtype = promote_types(x.dtype(), y.dtype());

  // Upcast to float32 for non-floating point inputs x and y
  auto out_dtype = is_floating_point(promoted_dtype)
      ? promoted_dtype
      : promote_types(promoted_dtype, float32);

  // Cast x and y up to the determined dtype (on the same stream s)
  auto x_casted = astype(x, out_dtype, s);
  auto y_casted = astype(y, out_dtype, s);

  // Broadcast the shapes of x and y (on the same stream s)
  auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s);
  auto out_shape = broadcasted_inputs[0].shape();

  // Construct the array as the output of the Axpby primitive
  // with the broadcasted and upcasted arrays as inputs
  return array(
      /* const std::vector<int>& shape = */ out_shape,
      /* Dtype dtype = */ out_dtype,
      /* std::unique_ptr<Primitive> primitive = */
      std::make_unique<Axpby>(to_stream(s), alpha, beta),
      /* const std::vector<array>& inputs = */ broadcasted_inputs);
}
This operation now handles the following: it promotes the dtypes of x and y (upcasting to floating point if needed), broadcasts their shapes, and constructs the output array from the Axpby primitive. A short sketch of what this buys us is shown below.
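For example (an illustrative sketch, not taken from the original text; it assumes using namespace mlx::core), mixed dtypes and shapes are handled before the primitive ever runs:

// Hypothetical sketch: dtype promotion and broadcasting in axpby
auto x = ones({3, 4}, int32);   // integer input
auto y = ones({4}, float32);    // floating point input with a different shape
auto z = axpby(x, y, 2.0f, 3.0f);
// promote_types(int32, float32) gives float32, which is already floating
// point, so z has dtype float32; broadcasting {3, 4} with {4} gives {3, 4}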
Let's now implement the evaluation of the primitive; it is the responsibility of these eval functions to allocate memory for the output as needed. Our naive method will go over each element of the output array, find the corresponding input elements of x and y, and perform the operation pointwise. This is captured in the templated function axpby_impl().
template <typename T>
void axpby_impl(
    const array& x,
    const array& y,
    array& out,
    float alpha_,
    float beta_) {
  // We only allocate memory when we are ready to fill the output
  // malloc_or_wait synchronously allocates available memory
  // There may be a wait executed here if the allocation is requested
  // under memory-pressured conditions
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  // Collect input and output data pointers
  const T* x_ptr = x.data<T>();
  const T* y_ptr = y.data<T>();
  T* out_ptr = out.data<T>();

  // Cast alpha and beta to the relevant types
  T alpha = static_cast<T>(alpha_);
  T beta = static_cast<T>(beta_);

  // Do the element-wise operation for each output
  for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
    // Map linear indices to offsets in x and y
    auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
    auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());

    // We allocate the output to be contiguous and regularly strided
    // (defaults to row major) and hence it doesn't need additional mapping
    out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
  }
}
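The only subtle step above is the index mapping. A rough, self-contained sketch of what a helper like elem_to_loc computes is shown here (the real helper lives in the MLX backend; this standalone version is only illustrative):

#include <cstddef>
#include <vector>

// Map a linear (row-major) element index to an offset in a strided array
size_t elem_to_loc_sketch(
    size_t elem,
    const std::vector<int>& shape,
    const std::vector<size_t>& strides) {
  size_t loc = 0;
  // Walk the dimensions from the innermost (fastest varying) outwards
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}

With a contiguous, row-major array the offset is just the linear index itself, which is why the output needs no additional mapping.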
Now, we would like our implementation to be able to do this pointwise operation for all incoming floating point arrays. Accordingly, we add dispatches for float32, float16, bfloat16 and complex64. We throw an error if we encounter an unexpected type.
/** Fall back implementation for evaluation on CPU */
void Axpby::eval(const std::vector<array>& inputs, array& out) {
  // Check the inputs (registered in the op while constructing the out array)
  assert(inputs.size() == 2);
  auto& x = inputs[0];
  auto& y = inputs[1];

  // Dispatch to the correct dtype
  if (out.dtype() == float32) {
    return axpby_impl<float>(x, y, out, alpha_, beta_);
  } else if (out.dtype() == float16) {
    return axpby_impl<float16_t>(x, y, out, alpha_, beta_);
  } else if (out.dtype() == bfloat16) {
    return axpby_impl<bfloat16_t>(x, y, out, alpha_, beta_);
  } else if (out.dtype() == complex64) {
    return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
  } else {
    throw std::runtime_error(
        "Axpby is only supported for floating point types.");
  }
}
We have a fallback implementation! Now, to do what we are really here to do.
Let's write out an implementation that uses Accelerate in the right conditions. It must simply allocate data for the output, copy the elements of y into it, and then call catlas_saxpby() from Accelerate.
template <typename T>
void axpby_impl_accelerate(
    const array& x,
    const array& y,
    array& out,
    float alpha_,
    float beta_) {
  // Accelerate library provides catlas_saxpby which does
  // Y = (alpha * X) + (beta * Y) in place
  // To use it, we first copy the data in y over to the output array

  // The data in the output array is allocated to match the strides in y
  // such that x, y, and out are contiguous in the same mode and
  // no transposition is needed
  out.set_data(
      allocator::malloc_or_wait(y.data_size() * out.itemsize()),
      y.data_size(),
      y.strides(),
      y.flags());

  // We then copy over the elements using the contiguous vector specialization
  copy_inplace(y, out, CopyType::Vector);

  // Get x and y pointers for catlas_saxpby
  const T* x_ptr = x.data<T>();
  T* y_ptr = out.data<T>();

  T alpha = static_cast<T>(alpha_);
  T beta = static_cast<T>(beta_);

  // Call the inplace accelerate operator
  catlas_saxpby(
      /* N = */ out.size(),
      /* ALPHA = */ alpha,
      /* X = */ x_ptr,
      /* INCX = */ 1,
      /* BETA = */ beta,
      /* Y = */ y_ptr,
      /* INCY = */ 1);
}
Great! But what about the inputs that do not fit the criteria for the Accelerate specialization? Luckily, we can always fall back to Axpby::eval(). With this in mind, let's finally implement our Axpby::eval_cpu().
/** Evaluate primitive on CPU using accelerate specializations */
void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& x = inputs[0];
  auto& y = inputs[1];

  // Accelerate specialization for contiguous single precision float arrays
  if (out.dtype() == float32 &&
      ((x.flags().row_contiguous && y.flags().row_contiguous) ||
       (x.flags().col_contiguous && y.flags().col_contiguous))) {
    axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
    return;
  }

  // Fall back to common backend if specializations are not available
  eval(inputs, out);
}
We have now hit a milestone! Just this much is enough to run the operation on a CPU stream.

All GPU kernels in MLX are written using Metal. We launch as many threads as there are elements in the output. Each thread will pick the element it needs from x and y, do the pointwise operation, and then update its assigned element in the output.
template <typename T>
[[kernel]] void axpby_general(
    device const T* x [[buffer(0)]],
    device const T* y [[buffer(1)]],
    device T* out [[buffer(2)]],
    constant const float& alpha [[buffer(3)]],
    constant const float& beta [[buffer(4)]],
    constant const int* shape [[buffer(5)]],
    constant const size_t* x_strides [[buffer(6)]],
    constant const size_t* y_strides [[buffer(7)]],
    constant const int& ndim [[buffer(8)]],
    uint index [[thread_position_in_grid]]) {
  // Convert linear indices to offsets in array
  auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
  auto y_offset = elem_to_loc(index, shape, y_strides, ndim);

  // Do the operation and update the output
  out[index] =
      static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
}
We then need to instantiate this template for all floating point types and give each data type its own kernel instantiation, with a host name such as axpby_general_float32 that Axpby::eval_gpu() can later look up.

instantiate_axpby(float32, float);
instantiate_axpby(float16, half);
instantiate_axpby(bfloat16, bfloat16_t);
instantiate_axpby(complex64, complex64_t);
This kernel will be compiled into a Metal library mlx_ext.metallib; we go over this process in more detail later.
The logic to determine the kernel, set the inputs, resolve the grid dimensions, and dispatch the kernel to the GPU is contained in Axpby::eval_gpu(), as shown below.
/** Evaluate primitive on GPU */
void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
  // Prepare inputs
  assert(inputs.size() == 2);
  auto& x = inputs[0];
  auto& y = inputs[1];

  // Each primitive carries the stream it should execute on
  // and each stream carries its device identifiers
  auto& s = stream();
  // We get the needed metal device using the stream
  auto& d = metal::device(s.device);

  // Allocate output memory
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  // Resolve name of kernel (corresponds to axpby.metal)
  std::ostringstream kname;
  kname << "axpby_" << "general_" << type_to_name(out);

  // Make sure the metal library is available and look for it
  // in the same folder as this executable if needed
  d.register_library("mlx_ext", metal::get_colocated_mtllib_path);

  // Make a kernel from this metal library
  auto kernel = d.get_kernel(kname.str(), "mlx_ext");

  // Prepare to encode kernel
  auto compute_encoder = d.get_command_encoder(s.index);
  compute_encoder->setComputePipelineState(kernel);

  // Kernel parameters are registered with buffer indices corresponding to
  // those in the kernel declaration at axpby.metal
  int ndim = out.ndim();
  size_t nelem = out.size();

  // Encode input arrays to kernel
  set_array_buffer(compute_encoder, x, 0);
  set_array_buffer(compute_encoder, y, 1);

  // Encode output arrays to kernel
  set_array_buffer(compute_encoder, out, 2);

  // Encode alpha and beta
  compute_encoder->setBytes(&alpha_, sizeof(float), 3);
  compute_encoder->setBytes(&beta_, sizeof(float), 4);

  // Encode shape, strides and ndim
  compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
  compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
  compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
  compute_encoder->setBytes(&ndim, sizeof(int), 8);

  // We launch 1 thread for each input and make sure that the number of
  // threads in any given threadgroup is not higher than the max allowed
  size_t tgp_size = std::min(nelem, kernel->maxTotalThreadsPerThreadgroup());

  // Fix the 3D size of each threadgroup (in terms of threads)
  MTL::Size group_dims = MTL::Size(tgp_size, 1, 1);

  // Fix the 3D size of the launch grid (in terms of threads)
  MTL::Size grid_dims = MTL::Size(nelem, 1, 1);

  // Launch the grid with the given number of threads divided among
  // the given threadgroups
  compute_encoder->dispatchThreads(grid_dims, group_dims);
}
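To make the launch arithmetic concrete, here is a small worked example with made-up numbers (not taken from the text above):

// Illustrative numbers only: an output with 4096 elements and a kernel
// limit of 1024 threads per threadgroup
//   tgp_size   = std::min(nelem, maxTotalThreadsPerThreadgroup) -> 1024
//   group_dims = MTL::Size(1024, 1, 1)   // threads per threadgroup
//   grid_dims  = MTL::Size(4096, 1, 1)   // total threads launched
// dispatchThreads then runs 4096 / 1024 = 4 threadgroups, one thread per
// output element.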
We can now call the axpby() operation on both the CPU and the GPU!

Let's now look at the transformations in a Primitive. These transformations can be built on top of our operations, including the one we just defined, which then gives us the following Axpby::jvp() and Axpby::vjp() implementations.
/** The Jacobian-vector product. */
array Axpby::jvp(
    const std::vector<array>& primals,
    const std::vector<array>& tangents,
    const std::vector<int>& argnums) {
  // Forward mode diff that pushes along the tangents
  // The jvp transform on the primitive can be built with ops
  // that are scheduled on the same stream as the primitive

  // If argnums = {0}, the jvp is just the tangent scaled by alpha
  // Similarly, if argnums = {1}, the jvp is just the tangent
  // scaled by beta
  if (argnums.size() == 1) {
    auto scale = argnums[0] == 0 ? alpha_ : beta_;
    auto scale_arr = array(scale, tangents[0].dtype());
    return multiply(scale_arr, tangents[0], stream());
  }
  // If argnums = {0, 1}, we take contributions from both
  // which gives us jvp = tangent_x * alpha + tangent_y * beta
  else {
    return axpby(tangents[0], tangents[1], alpha_, beta_, stream());
  }
}
/** The vector-Jacobian product. */
std::vector<array> Axpby::vjp(
    const std::vector<array>& primals,
    const array& cotan,
    const std::vector<int>& argnums) {
  // Reverse mode diff
  std::vector<array> vjps;
  for (auto arg : argnums) {
    auto scale = arg == 0 ? alpha_ : beta_;
    auto scale_arr = array(scale, cotan.dtype());
    vjps.push_back(multiply(scale_arr, cotan, stream()));
  }
  return vjps;
}
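As a quick sanity check on why both rules reduce to simple scalings: axpby is linear in x and y, so writing dotted symbols for tangents and barred symbols for cotangents,

z = \alpha x + \beta y
\;\Longrightarrow\;
\dot{z} = \alpha\,\dot{x} + \beta\,\dot{y},
\qquad
\bar{x} = \alpha\,\bar{z}, \quad \bar{y} = \beta\,\bar{z}.

This is exactly what jvp() and vjp() above compute.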
Finally, you need not have a transformation fully defined to start using your own Primitive.
/** Vectorize primitive along given axis */
std::pair<array, int> Axpby::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  throw std::runtime_error("Axpby has no vmap implementation.");
}
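The example leaves vmap unimplemented. Purely as a hypothetical sketch (not part of the documented extension), one simple rule covers the case where both inputs are batched along the same axis, since axpby is elementwise:

// Hypothetical vmap sketch, assuming both inputs share the batch axis
std::pair<array, int> Axpby::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  if (axes[0] == axes[1]) {
    // An elementwise op maps over the shared batch axis unchanged
    return {axpby(inputs[0], inputs[1], alpha_, beta_, stream()), axes[0]};
  }
  throw std::runtime_error("Axpby vmap for mismatched axes is not implemented.");
}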
We use PyBind11 to build a Python API for the C++ library. Since bindings for all needed components such as mlx.core.array, mlx.core.stream, etc. are already provided, adding our axpby() becomes very simple!
PYBIND11_MODULE(mlx_sample_extensions, m) {
  m.doc() = "Sample C++ and metal extensions for MLX";

  m.def(
      "axpby",
      &axpby,
      "x"_a,
      "y"_a,
      py::pos_only(),
      "alpha"_a,
      "beta"_a,
      py::kw_only(),
      "stream"_a = py::none(),
      R"pbdoc(
        Scale and sum two vectors element-wise
        ``z = alpha * x + beta * y``

        Returns:
            array: ``alpha * x + beta * y``
      )pbdoc");
}
Most of the complexity in the above example comes from additional bells and whistles.

You can now import the Python package and play with it as you would any other MLX operation!
print(f"c shape: {c.shape}")
print(f"c dtype: {c.dtype}")
print(f"c correctness: {mx.all(c == 6.0).item()}")

Output: