mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 09:33:16 +08:00
Remove "using namespace mlx::core" in benchmarks/examples (#1685)
* Remove "using namespace mlx::core" in benchmarks/examples * Fix building example extension * A missing one in comment * Fix building on M chips
This commit is contained in:
@@ -5,7 +5,9 @@
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
namespace mx = mlx::core;
|
||||
|
||||
namespace my_ext {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Operation
|
||||
@@ -18,22 +20,22 @@ namespace mlx::core {
|
||||
* Follow numpy style broadcasting between x and y
|
||||
* Inputs are upcasted to floats if needed
|
||||
**/
|
||||
array axpby(
|
||||
const array& x, // Input array x
|
||||
const array& y, // Input array y
|
||||
mx::array axpby(
|
||||
const mx::array& x, // Input array x
|
||||
const mx::array& y, // Input array y
|
||||
const float alpha, // Scaling factor for x
|
||||
const float beta, // Scaling factor for y
|
||||
StreamOrDevice s = {} // Stream on which to schedule the operation
|
||||
mx::StreamOrDevice s = {} // Stream on which to schedule the operation
|
||||
);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Primitive
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
class Axpby : public Primitive {
|
||||
class Axpby : public mx::Primitive {
|
||||
public:
|
||||
explicit Axpby(Stream stream, float alpha, float beta)
|
||||
: Primitive(stream), alpha_(alpha), beta_(beta) {};
|
||||
explicit Axpby(mx::Stream stream, float alpha, float beta)
|
||||
: mx::Primitive(stream), alpha_(alpha), beta_(beta) {};
|
||||
|
||||
/**
|
||||
* A primitive must know how to evaluate itself on the CPU/GPU
|
||||
@@ -42,23 +44,25 @@ class Axpby : public Primitive {
|
||||
* To avoid unnecessary allocations, the evaluation function
|
||||
* is responsible for allocating space for the array.
|
||||
*/
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
void eval_cpu(
|
||||
const std::vector<mx::array>& inputs,
|
||||
std::vector<mx::array>& outputs) override;
|
||||
void eval_gpu(
|
||||
const std::vector<mx::array>& inputs,
|
||||
std::vector<mx::array>& outputs) override;
|
||||
|
||||
/** The Jacobian-vector product. */
|
||||
std::vector<array> jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
std::vector<mx::array> jvp(
|
||||
const std::vector<mx::array>& primals,
|
||||
const std::vector<mx::array>& tangents,
|
||||
const std::vector<int>& argnums) override;
|
||||
|
||||
/** The vector-Jacobian product. */
|
||||
std::vector<array> vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
std::vector<mx::array> vjp(
|
||||
const std::vector<mx::array>& primals,
|
||||
const std::vector<mx::array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) override;
|
||||
const std::vector<mx::array>& outputs) override;
|
||||
|
||||
/**
|
||||
* The primitive must know how to vectorize itself across
|
||||
@@ -66,8 +70,8 @@ class Axpby : public Primitive {
|
||||
* representing the vectorized computation and the axis which
|
||||
* corresponds to the output vectorized dimension.
|
||||
*/
|
||||
std::pair<std::vector<array>, std::vector<int>> vmap(
|
||||
const std::vector<array>& inputs,
|
||||
std::pair<std::vector<mx::array>, std::vector<int>> vmap(
|
||||
const std::vector<mx::array>& inputs,
|
||||
const std::vector<int>& axes) override;
|
||||
|
||||
/** Print the primitive. */
|
||||
@@ -76,14 +80,16 @@ class Axpby : public Primitive {
|
||||
}
|
||||
|
||||
/** Equivalence check **/
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
bool is_equivalent(const mx::Primitive& other) const override;
|
||||
|
||||
private:
|
||||
float alpha_;
|
||||
float beta_;
|
||||
|
||||
/** Fall back implementation for evaluation on CPU */
|
||||
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
|
||||
void eval(
|
||||
const std::vector<mx::array>& inputs,
|
||||
std::vector<mx::array>& outputs);
|
||||
};
|
||||
|
||||
} // namespace mlx::core
|
||||
} // namespace my_ext
|
||||
|
Reference in New Issue
Block a user