Remove "using namespace mlx::core" in benchmarks/examples (#1685)

* Remove "using namespace mlx::core" in benchmarks/examples

* Fix building example extension

* A missing one in comment

* Fix building on M chips
This commit is contained in:
Cheng
2024-12-12 00:08:29 +09:00
committed by GitHub
parent f76a49e555
commit 4f9b60dd53
12 changed files with 373 additions and 367 deletions

View File

@@ -19,7 +19,7 @@
#include "mlx/backend/metal/utils.h"
#endif
namespace mlx::core {
namespace my_ext {
///////////////////////////////////////////////////////////////////////////////
// Operation Implementation
@@ -32,24 +32,24 @@ namespace mlx::core {
* Follow numpy style broadcasting between x and y
* Inputs are upcasted to floats if needed
**/
array axpby(
const array& x, // Input array x
const array& y, // Input array y
mx::array axpby(
const mx::array& x, // Input mx::array x
const mx::array& y, // Input mx::array y
const float alpha, // Scaling factor for x
const float beta, // Scaling factor for y
StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
mx::StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
) {
// Promote dtypes between x and y as needed
auto promoted_dtype = promote_types(x.dtype(), y.dtype());
// Upcast to float32 for non-floating point inputs x and y
auto out_dtype = issubdtype(promoted_dtype, float32)
auto out_dtype = mx::issubdtype(promoted_dtype, mx::float32)
? promoted_dtype
: promote_types(promoted_dtype, float32);
: promote_types(promoted_dtype, mx::float32);
// Cast x and y up to the determined dtype (on the same stream s)
auto x_casted = astype(x, out_dtype, s);
auto y_casted = astype(y, out_dtype, s);
auto x_casted = mx::astype(x, out_dtype, s);
auto y_casted = mx::astype(y, out_dtype, s);
// Broadcast the shapes of x and y (on the same stream s)
auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s);
@@ -57,12 +57,12 @@ array axpby(
// Construct the array as the output of the Axpby primitive
// with the broadcasted and upcasted arrays as inputs
return array(
return mx::array(
/* const std::vector<int>& shape = */ out_shape,
/* Dtype dtype = */ out_dtype,
/* std::unique_ptr<Primitive> primitive = */
/* mx::Dtype dtype = */ out_dtype,
/* std::unique_ptr<mx::Primitive> primitive = */
std::make_shared<Axpby>(to_stream(s), alpha, beta),
/* const std::vector<array>& inputs = */ broadcasted_inputs);
/* const std::vector<mx::array>& inputs = */ broadcasted_inputs);
}
///////////////////////////////////////////////////////////////////////////////
@@ -71,16 +71,16 @@ array axpby(
template <typename T>
void axpby_impl(
const array& x,
const array& y,
array& out,
const mx::array& x,
const mx::array& y,
mx::array& out,
float alpha_,
float beta_) {
// We only allocate memory when we are ready to fill the output
// malloc_or_wait synchronously allocates available memory
// There may be a wait executed here if the allocation is requested
// under memory-pressured conditions
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
// Collect input and output data pointers
const T* x_ptr = x.data<T>();
@@ -94,8 +94,8 @@ void axpby_impl(
// Do the element-wise operation for each output
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
// Map linear indices to offsets in x and y
auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
auto x_offset = mx::elem_to_loc(out_idx, x.shape(), x.strides());
auto y_offset = mx::elem_to_loc(out_idx, y.shape(), y.strides());
// We allocate the output to be contiguous and regularly strided
// (defaults to row major) and hence it doesn't need additional mapping
@@ -105,8 +105,8 @@ void axpby_impl(
/** Fall back implementation for evaluation on CPU */
void Axpby::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) {
// Check the inputs (registered in the op while constructing the out array)
assert(inputs.size() == 2);
auto& x = inputs[0];
@@ -114,14 +114,14 @@ void Axpby::eval(
auto& out = outputs[0];
// Dispatch to the correct dtype
if (out.dtype() == float32) {
if (out.dtype() == mx::float32) {
return axpby_impl<float>(x, y, out, alpha_, beta_);
} else if (out.dtype() == float16) {
return axpby_impl<float16_t>(x, y, out, alpha_, beta_);
} else if (out.dtype() == bfloat16) {
return axpby_impl<bfloat16_t>(x, y, out, alpha_, beta_);
} else if (out.dtype() == complex64) {
return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
} else if (out.dtype() == mx::float16) {
return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_);
} else if (out.dtype() == mx::bfloat16) {
return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_);
} else if (out.dtype() == mx::complex64) {
return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_);
} else {
throw std::runtime_error(
"Axpby is only supported for floating point types.");
@@ -136,9 +136,9 @@ void Axpby::eval(
template <typename T>
void axpby_impl_accelerate(
const array& x,
const array& y,
array& out,
const mx::array& x,
const mx::array& y,
mx::array& out,
float alpha_,
float beta_) {
// Accelerate library provides catlas_saxpby which does
@@ -150,10 +150,10 @@ void axpby_impl_accelerate(
// The data in the output array is allocated to match the strides in y
// such that x, y, and out are contiguous in the same mode and
// no transposition is needed
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
// We then copy over the elements using the contiguous vector specialization
copy_inplace(y, out, CopyType::Vector);
copy_inplace(y, out, mx::CopyType::Vector);
// Get x and y pointers for catlas_saxpby
const T* x_ptr = x.data<T>();
@@ -175,15 +175,15 @@ void axpby_impl_accelerate(
/** Evaluate primitive on CPU using accelerate specializations */
void Axpby::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) {
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Accelerate specialization for contiguous single precision float arrays
if (out.dtype() == float32 &&
if (out.dtype() == mx::float32 &&
((x.flags().row_contiguous && y.flags().row_contiguous) ||
(x.flags().col_contiguous && y.flags().col_contiguous))) {
axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
@@ -198,8 +198,8 @@ void Axpby::eval_cpu(
/** Evaluate primitive on CPU falling back to common backend */
void Axpby::eval_cpu(
const std::vector<array>& inputs,
const std::vector<array>& outputs) {
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) {
eval(inputs, outputs);
}
@@ -213,8 +213,8 @@ void Axpby::eval_cpu(
/** Evaluate primitive on GPU */
void Axpby::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) {
// Prepare inputs
assert(inputs.size() == 2);
auto& x = inputs[0];
@@ -225,7 +225,7 @@ void Axpby::eval_gpu(
// and each stream carries its device identifiers
auto& s = stream();
// We get the needed metal device using the stream
auto& d = metal::device(s.device);
auto& d = mx::metal::device(s.device);
// Prepare to specialize based on contiguity
bool contiguous_kernel =
@@ -235,12 +235,12 @@ void Axpby::eval_gpu(
// Allocate output memory with strides based on specialization
if (contiguous_kernel) {
out.set_data(
allocator::malloc_or_wait(x.data_size() * out.itemsize()),
mx::allocator::malloc_or_wait(x.data_size() * out.itemsize()),
x.data_size(),
x.strides(),
x.flags());
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
}
// Resolve name of kernel (corresponds to axpby.metal)
@@ -302,8 +302,8 @@ void Axpby::eval_gpu(
/** Fail evaluation on GPU */
void Axpby::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& out) {
const std::vector<mx::array>& inputs,
std::vector<mx::array>& out) {
throw std::runtime_error("Axpby has no GPU implementation.");
}
@@ -314,9 +314,9 @@ void Axpby::eval_gpu(
///////////////////////////////////////////////////////////////////////////////
/** The Jacobian-vector product. */
std::vector<array> Axpby::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
std::vector<mx::array> Axpby::jvp(
const std::vector<mx::array>& primals,
const std::vector<mx::array>& tangents,
const std::vector<int>& argnums) {
// Forward mode diff that pushes along the tangents
// The jvp transform on the primitive can built with ops
@@ -328,8 +328,8 @@ std::vector<array> Axpby::jvp(
// scaled by beta
if (argnums.size() > 1) {
auto scale = argnums[0] == 0 ? alpha_ : beta_;
auto scale_arr = array(scale, tangents[0].dtype());
return {multiply(scale_arr, tangents[0], stream())};
auto scale_arr = mx::array(scale, tangents[0].dtype());
return {mx::multiply(scale_arr, tangents[0], stream())};
}
// If, argnums = {0, 1}, we take contributions from both
// which gives us jvp = tangent_x * alpha + tangent_y * beta
@@ -339,24 +339,24 @@ std::vector<array> Axpby::jvp(
}
/** The vector-Jacobian product. */
std::vector<array> Axpby::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
std::vector<mx::array> Axpby::vjp(
const std::vector<mx::array>& primals,
const std::vector<mx::array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<mx::array>&) {
// Reverse mode diff
std::vector<array> vjps;
std::vector<mx::array> vjps;
for (auto arg : argnums) {
auto scale = arg == 0 ? alpha_ : beta_;
auto scale_arr = array(scale, cotangents[0].dtype());
vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
auto scale_arr = mx::array(scale, cotangents[0].dtype());
vjps.push_back(mx::multiply(scale_arr, cotangents[0], stream()));
}
return vjps;
}
/** Vectorize primitive along given axis */
std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
const std::vector<array>& inputs,
std::pair<std::vector<mx::array>, std::vector<int>> Axpby::vmap(
const std::vector<mx::array>& inputs,
const std::vector<int>& axes) {
throw std::runtime_error("Axpby has no vmap implementation.");
}
@@ -367,4 +367,4 @@ bool Axpby::is_equivalent(const Primitive& other) const {
return alpha_ == r_other.alpha_ && beta_ == r_other.beta_;
}
} // namespace mlx::core
} // namespace my_ext

View File

@@ -5,7 +5,9 @@
#include "mlx/ops.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace mx = mlx::core;
namespace my_ext {
///////////////////////////////////////////////////////////////////////////////
// Operation
@@ -18,22 +20,22 @@ namespace mlx::core {
* Follow numpy style broadcasting between x and y
* Inputs are upcasted to floats if needed
**/
array axpby(
const array& x, // Input array x
const array& y, // Input array y
mx::array axpby(
const mx::array& x, // Input array x
const mx::array& y, // Input array y
const float alpha, // Scaling factor for x
const float beta, // Scaling factor for y
StreamOrDevice s = {} // Stream on which to schedule the operation
mx::StreamOrDevice s = {} // Stream on which to schedule the operation
);
///////////////////////////////////////////////////////////////////////////////
// Primitive
///////////////////////////////////////////////////////////////////////////////
class Axpby : public Primitive {
class Axpby : public mx::Primitive {
public:
explicit Axpby(Stream stream, float alpha, float beta)
: Primitive(stream), alpha_(alpha), beta_(beta) {};
explicit Axpby(mx::Stream stream, float alpha, float beta)
: mx::Primitive(stream), alpha_(alpha), beta_(beta) {};
/**
* A primitive must know how to evaluate itself on the CPU/GPU
@@ -42,23 +44,25 @@ class Axpby : public Primitive {
* To avoid unnecessary allocations, the evaluation function
* is responsible for allocating space for the array.
*/
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_cpu(
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) override;
void eval_gpu(
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) override;
/** The Jacobian-vector product. */
std::vector<array> jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
std::vector<mx::array> jvp(
const std::vector<mx::array>& primals,
const std::vector<mx::array>& tangents,
const std::vector<int>& argnums) override;
/** The vector-Jacobian product. */
std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
std::vector<mx::array> vjp(
const std::vector<mx::array>& primals,
const std::vector<mx::array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
const std::vector<mx::array>& outputs) override;
/**
* The primitive must know how to vectorize itself across
@@ -66,8 +70,8 @@ class Axpby : public Primitive {
* representing the vectorized computation and the axis which
* corresponds to the output vectorized dimension.
*/
std::pair<std::vector<array>, std::vector<int>> vmap(
const std::vector<array>& inputs,
std::pair<std::vector<mx::array>, std::vector<int>> vmap(
const std::vector<mx::array>& inputs,
const std::vector<int>& axes) override;
/** Print the primitive. */
@@ -76,14 +80,16 @@ class Axpby : public Primitive {
}
/** Equivalence check **/
bool is_equivalent(const Primitive& other) const override;
bool is_equivalent(const mx::Primitive& other) const override;
private:
float alpha_;
float beta_;
/** Fall back implementation for evaluation on CPU */
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
void eval(
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs);
};
} // namespace mlx::core
} // namespace my_ext