mlx/examples/extensions/axpby/axpby.h
Awni Hannun c4230747a1
redesign for faster cpu/gpu synch (#1869)
* redesign for faster cpu/gpu synch

* load + more async CPU

* use command encoder API and move more ops to use it

* make fence back-end generic + CPU only fence

* faster build

* fix async eval

* fixes + handle temporaries

* fix / improve cpu conv

* remove unused status, fix siblings

* fix extensions

* fix

* fix no cpu build

* format

* comments

* fix perf regression, remove unnecessary abort

* fix events, task limit cpu

* fix waiting

* fix donation / temporaries in normalization
2025-03-06 19:23:38 -08:00

91 lines
2.7 KiB
C++

// Copyright © 2023-2025 Apple Inc.
#pragma once
#include "mlx/ops.h"
#include "mlx/primitives.h"
namespace mx = mlx::core;
namespace my_ext {
///////////////////////////////////////////////////////////////////////////////
// Operation
///////////////////////////////////////////////////////////////////////////////
/**
 * Scale and sum two arrays element-wise:
 *
 *     z = alpha * x + beta * y
 *
 * Follows numpy-style broadcasting between x and y.
 * Inputs are upcast to floats if needed.
 *
 * @param x     Input array x.
 * @param y     Input array y.
 * @param alpha Scaling factor for x.
 * @param beta  Scaling factor for y.
 * @param s     Stream on which to schedule the operation.
 * @return The array z described by the formula above.
 */
mx::array axpby(
    const mx::array& x,
    const mx::array& y,
    float alpha,
    float beta,
    mx::StreamOrDevice s = {});
///////////////////////////////////////////////////////////////////////////////
// Primitive
///////////////////////////////////////////////////////////////////////////////
/**
 * Primitive carrying the two scaling factors, alpha and beta.
 *
 * NOTE(review): presumably this is the primitive scheduled by axpby();
 * the member functions declared here are defined outside this header,
 * so confirm against the implementation file.
 */
class Axpby : public mx::Primitive {
 public:
  /**
   * @param stream Forwarded to the mx::Primitive base class.
   * @param alpha  Scaling factor, stored in alpha_.
   * @param beta   Scaling factor, stored in beta_.
   */
  explicit Axpby(mx::Stream stream, float alpha, float beta)
      : mx::Primitive(stream), alpha_(alpha), beta_(beta) {}

  /**
   * A primitive must know how to evaluate itself on the CPU/GPU
   * for the given inputs and populate the output array.
   *
   * To avoid unnecessary allocations, the evaluation function
   * is responsible for allocating space for the array.
   */
  void eval_cpu(
      const std::vector<mx::array>& inputs,
      std::vector<mx::array>& outputs) override;
  void eval_gpu(
      const std::vector<mx::array>& inputs,
      std::vector<mx::array>& outputs) override;

  /** The Jacobian-vector product. */
  std::vector<mx::array> jvp(
      const std::vector<mx::array>& primals,
      const std::vector<mx::array>& tangents,
      const std::vector<int>& argnums) override;

  /** The vector-Jacobian product. */
  std::vector<mx::array> vjp(
      const std::vector<mx::array>& primals,
      const std::vector<mx::array>& cotangents,
      const std::vector<int>& argnums,
      const std::vector<mx::array>& outputs) override;

  /**
   * The primitive must know how to vectorize itself across
   * the given axes. The output is a pair containing the array
   * representing the vectorized computation and the axis which
   * corresponds to the output vectorized dimension.
   */
  std::pair<std::vector<mx::array>, std::vector<int>> vmap(
      const std::vector<mx::array>& inputs,
      const std::vector<int>& axes) override;

  /** Print the primitive: writes the name "Axpby" to os. */
  void print(std::ostream& os) override {
    os << "Axpby";
  }

  /**
   * Equivalence check against another primitive.
   *
   * NOTE(review): defined outside this header; presumably compares
   * alpha_ and beta_ -- confirm against the implementation.
   */
  bool is_equivalent(const mx::Primitive& other) const override;

 private:
  float alpha_; // Scaling factor passed to the constructor as alpha
  float beta_; // Scaling factor passed to the constructor as beta
};
} // namespace my_ext