Remove "using namespace mlx::core" in benchmarks/examples (#1685)

* Remove "using namespace mlx::core" in benchmarks/examples

* Fix building example extension

* Fix a missing one in a comment

* Fix building on M chips
Cheng 2024-12-12 00:08:29 +09:00 committed by GitHub
parent f76a49e555
commit 4f9b60dd53
12 changed files with 373 additions and 367 deletions
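
Every file in this commit follows the same pattern: the file-scope using-directive is replaced with a short namespace alias, and every MLX symbol is qualified through that alias. A minimal sketch of the pattern (not taken from the diff itself; the symbols used below all exist in MLX and appear in the changed files):

#include <iostream>
#include "mlx/mlx.h"

// Alias instead of `using namespace mlx::core;`, so MLX symbols stay qualified.
namespace mx = mlx::core;

int main() {
  auto x = mx::ones({2, 2});              // every call goes through the alias
  auto y = mx::add(x, mx::array(1.0f));
  mx::eval(y);                            // force evaluation of the lazy graph
  std::cout << y << std::endl;
  return 0;
}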


@@ -5,35 +5,35 @@
#include "mlx/mlx.h"
#include "time_utils.h"
-using namespace mlx::core;
+namespace mx = mlx::core;
void time_value_and_grad() {
-auto x = ones({200, 1000});
+auto x = mx::ones({200, 1000});
-eval(x);
+mx::eval(x);
-auto fn = [](array x) {
+auto fn = [](mx::array x) {
for (int i = 0; i < 20; ++i) {
-x = log(exp(x));
+x = mx::log(mx::exp(x));
}
-return sum(x);
+return mx::sum(x);
};
-auto grad_fn = grad(fn);
+auto grad_fn = mx::grad(fn);
auto independent_value_and_grad = [&]() {
auto value = fn(x);
auto dfdx = grad_fn(x);
-return std::vector<array>{value, dfdx};
+return std::vector<mx::array>{value, dfdx};
};
TIME(independent_value_and_grad);
-auto value_and_grad_fn = value_and_grad(fn);
+auto value_and_grad_fn = mx::value_and_grad(fn);
auto combined_value_and_grad = [&]() {
auto [value, dfdx] = value_and_grad_fn(x);
-return std::vector<array>{value, dfdx};
+return std::vector<mx::array>{value, dfdx};
};
TIME(combined_value_and_grad);
}
int main() {
-std::cout << "Benchmarks for " << default_device() << std::endl;
+std::cout << "Benchmarks for " << mx::default_device() << std::endl;
time_value_and_grad();
}


@@ -4,21 +4,21 @@
#include "mlx/mlx.h"
#include "time_utils.h"
-using namespace mlx::core;
+namespace mx = mlx::core;
void time_add_op() {
std::vector<int> sizes(1, 1);
for (int i = 0; i < 9; ++i) {
sizes.push_back(10 * sizes.back());
}
-set_default_device(Device::cpu);
+set_default_device(mx::Device::cpu);
for (auto size : sizes) {
-auto a = random::uniform({size});
+auto a = mx::random::uniform({size});
-auto b = random::uniform({size});
+auto b = mx::random::uniform({size});
-eval(a, b);
+mx::eval(a, b);
std::cout << "Size " << size << std::endl;
-TIMEM("cpu", add, a, b, Device::cpu);
+TIMEM("cpu", mx::add, a, b, mx::Device::cpu);
-TIMEM("gpu", add, a, b, Device::gpu);
+TIMEM("gpu", mx::add, a, b, mx::Device::gpu);
}
}


@@ -6,105 +6,105 @@
#include "mlx/mlx.h"
#include "time_utils.h"
-using namespace mlx::core;
+namespace mx = mlx::core;
void time_irregular_binary_ops_1D() {
-auto device = default_device();
+auto device = mx::default_device();
int size = 1000000;
int step = 2;
-auto a = random::uniform({size});
+auto a = mx::random::uniform({size});
-auto b = random::uniform({size});
+auto b = mx::random::uniform({size});
-eval(a, b);
+mx::eval(a, b);
a = slice(a, {0}, {size}, {step});
b = slice(b, {0}, {size}, {step});
-TIMEM("1D strided", add, a, b, device);
+TIMEM("1D strided", mx::add, a, b, device);
}
void time_irregular_binary_ops_2D() {
-auto device = default_device();
+auto device = mx::default_device();
int size = 2048;
-auto a = random::uniform({size, size});
+auto a = mx::random::uniform({size, size});
-auto b = random::uniform({size, size});
+auto b = mx::random::uniform({size, size});
-eval(a, b);
+mx::eval(a, b);
-TIMEM("2D regular", add, a, b, device);
+TIMEM("2D regular", mx::add, a, b, device);
-b = transpose(b);
+b = mx::transpose(b);
-eval(b);
+mx::eval(b);
-TIMEM("2D transpose", add, a, b, device);
+TIMEM("2D mx::transpose", mx::add, a, b, device);
-b = random::uniform({size});
+b = mx::random::uniform({size});
-eval(b);
+mx::eval(b);
-TIMEM("2D broadcast dim 0", add, a, b, device);
+TIMEM("2D broadcast dim 0", mx::add, a, b, device);
-b = reshape(b, {size, 1});
+b = mx::reshape(b, {size, 1});
-eval(b);
+mx::eval(b);
-TIMEM("2D broadcast dim 1", add, a, b, device);
+TIMEM("2D broadcast dim 1", mx::add, a, b, device);
}
void time_irregular_binary_ops_3D() {
-auto device = default_device();
+auto device = mx::default_device();
int d0 = 32;
int d1 = 512;
int d2 = 512;
-auto a = random::uniform({d0, d1, d2});
+auto a = mx::random::uniform({d0, d1, d2});
-auto b = random::uniform({d0, d1, d2});
+auto b = mx::random::uniform({d0, d1, d2});
-TIMEM("3D regular", add, a, b, device);
+TIMEM("3D regular", mx::add, a, b, device);
-b = transpose(b, {0, 2, 1});
+b = mx::transpose(b, {0, 2, 1});
-TIMEM("3D transpose", add, a, b, device);
+TIMEM("3D mx::transpose", mx::add, a, b, device);
-b = random::uniform({d1, d2});
+b = mx::random::uniform({d1, d2});
-TIMEM("3D broadcast dim 0", add, a, b, device);
+TIMEM("3D broadcast dim 0", mx::add, a, b, device);
-b = random::uniform({d0, 1, d2});
+b = mx::random::uniform({d0, 1, d2});
-TIMEM("3D broadcast dim 1", add, a, b, device);
+TIMEM("3D broadcast dim 1", mx::add, a, b, device);
-b = random::uniform({d0, d1, 1});
+b = mx::random::uniform({d0, d1, 1});
-TIMEM("3D broadcast dim 2", add, a, b, device);
+TIMEM("3D broadcast dim 2", mx::add, a, b, device);
-b = random::uniform({d2});
+b = mx::random::uniform({d2});
-TIMEM("3D broadcast dims 0, 1", add, a, b, device);
+TIMEM("3D broadcast dims 0, 1", mx::add, a, b, device);
-b = random::uniform({d1, 1});
+b = mx::random::uniform({d1, 1});
-TIMEM("3D broadcast dims 0, 2", add, a, b, device);
+TIMEM("3D broadcast dims 0, 2", mx::add, a, b, device);
-b = random::uniform({d0, 1, 1});
+b = mx::random::uniform({d0, 1, 1});
-TIMEM("3D broadcast dims 1, 2", add, a, b, device);
+TIMEM("3D broadcast dims 1, 2", mx::add, a, b, device);
}
void time_irregular_binary_ops_4D() {
-auto device = default_device();
+auto device = mx::default_device();
std::vector<int> shape = {8, 8, 512, 512};
-auto a = random::uniform(shape);
+auto a = mx::random::uniform(shape);
-auto b = random::uniform(shape);
+auto b = mx::random::uniform(shape);
-TIMEM("4D regular", add, a, b, device);
+TIMEM("4D regular", mx::add, a, b, device);
-b = transpose(b, {0, 1, 3, 2});
+b = mx::transpose(b, {0, 1, 3, 2});
-TIMEM("4D transpose", add, a, b, device);
+TIMEM("4D mx::transpose", mx::add, a, b, device);
std::string om = "4D broadcast dims ";
for (int i = 0; i < shape.size(); ++i) {
shape[i] = 1;
-b = random::uniform(shape);
+b = mx::random::uniform(shape);
std::ostringstream msg;
msg << om << i;
-TIMEM(msg.str(), add, a, b, device);
+TIMEM(msg.str(), mx::add, a, b, device);
for (int j = i + 1; j < shape.size(); ++j) {
shape[j] = 1;
std::ostringstream msg;
msg << om << i << ", " << j;
-b = random::uniform(shape);
+b = mx::random::uniform(shape);
-TIMEM(msg.str(), add, a, b, device);
+TIMEM(msg.str(), mx::add, a, b, device);
shape[j] = a.shape(j);
for (int k = j + 1; k < shape.size(); ++k) {
shape[k] = 1;
std::ostringstream msg;
msg << om << i << ", " << j << ", " << k;
-b = random::uniform(shape);
+b = mx::random::uniform(shape);
-TIMEM(msg.str(), add, a, b, device);
+TIMEM(msg.str(), mx::add, a, b, device);
shape[k] = a.shape(k);
}
}
@@ -113,83 +113,83 @@ void time_irregular_binary_ops_4D() {
}
void time_irregular_reshape() {
-auto device = default_device();
+auto device = mx::default_device();
std::vector<int> shape;
-auto reshape_fn = [&shape, device](const array& a) {
+auto reshape_fn = [&shape, device](const mx::array& a) {
-return reshape(a, shape, device);
+return mx::reshape(a, shape, device);
};
int size = 64;
int d = 2 * size;
-auto a = random::uniform({d, d, d});
+auto a = mx::random::uniform({d, d, d});
shape = {8 * size, size, size};
TIMEM("3D contiguous", reshape_fn, a);
-a = transpose(a);
+a = mx::transpose(a);
shape = {8 * size, size, size};
-TIMEM("3D transpose", reshape_fn, a);
+TIMEM("3D mx::transpose", reshape_fn, a);
-a = transpose(a, {1, 2, 0});
+a = mx::transpose(a, {1, 2, 0});
shape = {8 * size, size, size};
-TIMEM("3D transpose dims 1 2", reshape_fn, a);
+TIMEM("3D mx::transpose dims 1 2", reshape_fn, a);
-a = broadcast_to(random::uniform({d, d}), {d, d, d});
+a = mx::broadcast_to(mx::random::uniform({d, d}), {d, d, d});
TIMEM("3D broadcast dim 0", reshape_fn, a);
-a = broadcast_to(random::uniform({d, 1, d}), {d, d, d});
+a = mx::broadcast_to(mx::random::uniform({d, 1, d}), {d, d, d});
TIMEM("3D broadcast dim 1", reshape_fn, a);
-a = broadcast_to(random::uniform({d, d, 1}), {d, d, d});
+a = mx::broadcast_to(mx::random::uniform({d, d, 1}), {d, d, d});
TIMEM("3D broadcast dim 2", reshape_fn, a);
-a = broadcast_to(random::uniform({d}), {d, d, d});
+a = mx::broadcast_to(mx::random::uniform({d}), {d, d, d});
TIMEM("3D broadcast dims 0, 1", reshape_fn, a);
-a = broadcast_to(random::uniform({d, 1}), {d, d, d});
+a = mx::broadcast_to(mx::random::uniform({d, 1}), {d, d, d});
TIMEM("3D broadcast dims 0, 2", reshape_fn, a);
-a = broadcast_to(random::uniform({d, 1, 1}), {d, d, d});
+a = mx::broadcast_to(mx::random::uniform({d, 1, 1}), {d, d, d});
TIMEM("3D broadcast dims 1, 2", reshape_fn, a);
-a = broadcast_to(random::uniform({1, 1, 1}), {d, d, d});
+a = mx::broadcast_to(mx::random::uniform({1, 1, 1}), {d, d, d});
TIMEM("3D broadcast dims 1, 2, 3", reshape_fn, a);
}
void time_irregular_astype_1D() {
-auto device = default_device();
+auto device = mx::default_device();
int size = 1000000;
int step = 2;
-auto a = random::uniform({size});
+auto a = mx::random::uniform({size});
a = slice(a, {0}, {size}, {step});
-TIMEM("1D strided", astype, a, int32, device);
+TIMEM("1D strided", mx::astype, a, mx::int32, device);
}
void time_irregular_astype_2D() {
-auto device = default_device();
+auto device = mx::default_device();
int size = 2048;
std::vector<int> shape = {size, size};
-auto a = random::uniform(shape);
+auto a = mx::random::uniform(shape);
-TIMEM("2D regular", astype, a, int32, device);
+TIMEM("2D regular", mx::astype, a, mx::int32, device);
-a = transpose(a);
+a = mx::transpose(a);
-TIMEM("2D transpose", astype, a, int32, device);
+TIMEM("2D mx::transpose", mx::astype, a, mx::int32, device);
-a = broadcast_to(random::uniform({size}), shape);
+a = mx::broadcast_to(mx::random::uniform({size}), shape);
-TIMEM("2D broadcast dim 0", astype, a, int32, device);
+TIMEM("2D broadcast dim 0", mx::astype, a, mx::int32, device);
-a = broadcast_to(random::uniform({size, 1}), shape);
+a = mx::broadcast_to(mx::random::uniform({size, 1}), shape);
-TIMEM("2D broadcast dim 1", astype, a, int32, device);
+TIMEM("2D broadcast dim 1", mx::astype, a, mx::int32, device);
}
int main(int argc, char** argv) {
if (argc > 1) {
bool use_gpu = !strcmp(argv[1], "gpu");
-set_default_device(use_gpu ? Device::gpu : Device::cpu);
+set_default_device(use_gpu ? mx::Device::gpu : mx::Device::cpu);
}
-std::cout << "Benchmarks for " << default_device() << std::endl;
+std::cout << "Benchmarks for " << mx::default_device() << std::endl;
time_irregular_binary_ops_1D();
time_irregular_binary_ops_2D();
time_irregular_binary_ops_3D();


@@ -3,20 +3,20 @@
#include "mlx/mlx.h"
#include "time_utils.h"
-using namespace mlx::core;
+namespace mx = mlx::core;
void time_creation_ops() {
int M = 2000;
int N = 500;
auto shape = {M, N};
-auto full_fp32 = [&]() { return full(shape, 3.3f); };
+auto full_fp32 = [&]() { return mx::full(shape, 3.3f); };
TIME(full_fp32);
-auto zeros_fp32 = [&]() { return zeros(shape, float32); };
+auto zeros_fp32 = [&]() { return mx::zeros(shape, mx::float32); };
TIME(zeros_fp32);
-auto ones_fp32 = [&]() { return ones(shape, float32); };
+auto ones_fp32 = [&]() { return mx::ones(shape, mx::float32); };
TIME(ones_fp32);
-auto arange_fp32 = [&]() { return arange(0.0, 10.0, 1e-4); };
+auto arange_fp32 = [&]() { return mx::arange(0.0, 10.0, 1e-4); };
TIME(arange_fp32);
}
@@ -24,194 +24,196 @@ void time_type_conversions() {
int M = 2000;
int N = 500;
auto shape = {M, N};
-auto device = default_device();
+auto device = mx::default_device();
-auto a = zeros(shape, float32);
+auto a = mx::zeros(shape, mx::float32);
-eval(a);
+mx::eval(a);
-TIMEM("float32 to int32", astype, a, int32, device);
+TIMEM("mx::float32 to mx::int32", mx::astype, a, mx::int32, device);
-TIMEM("float32 to uint32", astype, a, uint32, device);
+TIMEM("mx::float32 to mx::uint32", mx::astype, a, mx::uint32, device);
-a = zeros(shape, int32);
+a = mx::zeros(shape, mx::int32);
-eval(a);
+mx::eval(a);
-TIMEM("int32 to float32", astype, a, float32, device);
+TIMEM("mx::int32 to mx::float32", mx::astype, a, mx::float32, device);
-a = zeros(shape, bool_);
+a = mx::zeros(shape, mx::bool_);
-eval(a);
+mx::eval(a);
-TIMEM("bool to float32", astype, a, float32, device);
+TIMEM("bool to mx::float32", mx::astype, a, mx::float32, device);
-TIMEM("bool to int32", astype, a, int32, device);
+TIMEM("bool to mx::int32", mx::astype, a, mx::int32, device);
-TIMEM("bool to uint32", astype, a, uint32, device);
+TIMEM("bool to mx::uint32", mx::astype, a, mx::uint32, device);
}
void time_random_generation() {
int M = 2000;
int N = 500;
-auto uniform = [&]() { return random::uniform({M, N}, float32); };
+auto uniform = [&]() { return mx::random::uniform({M, N}, mx::float32); };
TIME(uniform);
-auto normal = [&]() { return random::normal({M, N}, float32); };
+auto normal = [&]() { return mx::random::normal({M, N}, mx::float32); };
TIME(normal);
}
void time_unary_ops() {
int M = 2000;
int N = 500;
-auto device = default_device();
+auto device = mx::default_device();
-auto a = random::normal({M, N});
+auto a = mx::random::normal({M, N});
-eval(a);
+mx::eval(a);
TIME(mlx::core::abs, a, device);
-TIME(negative, a, device);
+TIME(mx::negative, a, device);
-TIME(sign, a, device);
+TIME(mx::sign, a, device);
-TIME(square, a, device);
+TIME(mx::square, a, device);
TIME(mlx::core::sqrt, a, device);
-TIME(rsqrt, a, device);
+TIME(mx::rsqrt, a, device);
TIME(mlx::core::exp, a, device);
-a = random::uniform({M, N});
+a = mx::random::uniform({M, N});
TIME(mlx::core::log, a, device);
}
void time_binary_ops() {
int M = 1000, N = 100, K = 10;
-auto condition = random::randint(0, 2, {M, N, K});
+auto condition = mx::random::randint(0, 2, {M, N, K});
-auto a = random::uniform({M, N, K});
+auto a = mx::random::uniform({M, N, K});
-auto b = random::uniform({M, N, K});
+auto b = mx::random::uniform({M, N, K});
-auto device = default_device();
+auto device = mx::default_device();
-eval(a, b);
+mx::eval(a, b);
-TIME(add, a, b, device);
+TIME(mx::add, a, b, device);
-TIME(subtract, a, b, device);
+TIME(mx::subtract, a, b, device);
-TIME(multiply, a, b, device);
+TIME(mx::multiply, a, b, device);
-TIME(divide, a, b, device);
+TIME(mx::divide, a, b, device);
-TIME(maximum, a, b, device);
+TIME(mx::maximum, a, b, device);
-TIME(minimum, a, b, device);
+TIME(mx::minimum, a, b, device);
-TIME(where, condition, a, b, device);
+TIME(mx::where, condition, a, b, device);
-condition = array({true});
+condition = mx::array({true});
-b = random::uniform({1});
+b = mx::random::uniform({1});
-eval(b);
+mx::eval(b);
-TIMEM("scalar", add, a, b, device);
+TIMEM("scalar", mx::add, a, b, device);
-TIMEM("vector-scalar", subtract, a, b, device);
+TIMEM("vector-scalar", mx::subtract, a, b, device);
-TIMEM("scalar-vector", subtract, b, a, device);
+TIMEM("scalar-vector", mx::subtract, b, a, device);
-TIMEM("scalar", multiply, a, b, device);
+TIMEM("scalar", mx::multiply, a, b, device);
-TIMEM("vector-scalar", divide, a, b, device);
+TIMEM("vector-scalar", mx::divide, a, b, device);
-TIMEM("scalar-vector", divide, b, a, device);
+TIMEM("scalar-vector", mx::divide, b, a, device);
-TIMEM("scalar-vector", where, condition, a, b, device);
+TIMEM("scalar-vector", mx::where, condition, a, b, device);
-condition = broadcast_to(array({true}), {1000, 100});
+condition = mx::broadcast_to(mx::array({true}), {1000, 100});
-a = broadcast_to(random::uniform({1}), {1000, 100});
+a = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});
-b = broadcast_to(random::uniform({1}), {1000, 100});
+b = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});
-eval(a, b);
+mx::eval(a, b);
-TIMEM("scalar-scalar broadcast", add, a, b, device);
+TIMEM("scalar-scalar broadcast", mx::add, a, b, device);
-TIMEM("scalar-scalar broadcast", subtract, a, b, device);
+TIMEM("scalar-scalar broadcast", mx::subtract, a, b, device);
-TIMEM("scalar-scalar broadcast", multiply, a, b, device);
+TIMEM("scalar-scalar broadcast", mx::multiply, a, b, device);
-TIMEM("scalar-scalar broadcast", divide, a, b, device);
+TIMEM("scalar-scalar broadcast", mx::divide, a, b, device);
-TIMEM("scalar-scalar broadcast", where, condition, a, b, device);
+TIMEM("scalar-scalar broadcast", mx::where, condition, a, b, device);
}
void time_strided_ops() {
int M = 50, N = 50, O = 50, P = 50;
-auto a = random::uniform({M, N, O, P});
+auto a = mx::random::uniform({M, N, O, P});
-auto b = random::uniform({M, N, O, P});
+auto b = mx::random::uniform({M, N, O, P});
-auto device = default_device();
+auto device = mx::default_device();
-eval(a, b);
+mx::eval(a, b);
-TIMEM("non-strided", add, a, b, device);
+TIMEM("non-strided", mx::add, a, b, device);
-a = transpose(a, {1, 0, 2, 3});
+a = mx::transpose(a, {1, 0, 2, 3});
-b = transpose(b, {3, 2, 0, 1});
+b = mx::transpose(b, {3, 2, 0, 1});
-eval(a, b);
+mx::eval(a, b);
-TIMEM("strided", add, a, b, device);
+TIMEM("strided", mx::add, a, b, device);
}
void time_comparisons() {
int M = 1000, N = 100, K = 10;
-auto a = random::uniform({M, N, K});
+auto a = mx::random::uniform({M, N, K});
-auto b = random::uniform({M, N, K});
+auto b = mx::random::uniform({M, N, K});
-auto device = default_device();
+auto device = mx::default_device();
-eval(a, b);
+mx::eval(a, b);
-TIME(equal, a, b, device);
+TIME(mx::equal, a, b, device);
-TIME(greater, a, b, device);
+TIME(mx::greater, a, b, device);
-TIME(greater_equal, a, b, device);
+TIME(mx::greater_equal, a, b, device);
-TIME(less, a, b, device);
+TIME(mx::less, a, b, device);
-TIME(less_equal, a, b, device);
+TIME(mx::less_equal, a, b, device);
}
void time_matvec() {
int M = 2000, N = 200;
-auto a = random::uniform({M, N});
+auto a = mx::random::uniform({M, N});
-auto b = random::uniform({N});
+auto b = mx::random::uniform({N});
-auto c = random::uniform({M});
+auto c = mx::random::uniform({M});
-eval(a, b, c);
+mx::eval(a, b, c);
-auto matvec = [&]() { return matmul(a, b); };
+auto matvec = [&]() { return mx::matmul(a, b); };
TIME(matvec);
-auto matvec_transpose = [&]() { return matmul(transpose(a), c); };
+auto matvec_transpose = [&]() { return mx::matmul(mx::transpose(a), c); };
TIME(matvec_transpose);
}
void time_matmul() {
int M = 1000, N = 1000, K = 1000;
-auto a = random::uniform({M, K});
+auto a = mx::random::uniform({M, K});
-auto b = random::uniform({K, N});
+auto b = mx::random::uniform({K, N});
-auto device = default_device();
+auto device = mx::default_device();
-eval(a, b);
+mx::eval(a, b);
-TIME(matmul, a, b, device);
+TIME(mx::matmul, a, b, device);
-auto transpose_matmul = [&]() { return matmul(transpose(a), b); };
+auto transpose_matmul = [&]() { return mx::matmul(mx::transpose(a), b); };
TIME(transpose_matmul);
}
void time_reductions() {
-auto a = random::normal({10000, 1000});
+auto a = mx::random::normal({10000, 1000});
-eval(a);
+mx::eval(a);
-auto sum_all = [&a]() { return sum(a, false); };
+auto sum_all = [&a]() { return mx::sum(a, false); };
TIME(sum_all);
-auto sum_along_0 = [&a]() { return sum(a, 0, false); };
+auto sum_along_0 = [&a]() { return mx::sum(a, 0, false); };
TIME(sum_along_0);
-auto sum_along_1 = [&a]() { return sum(a, 1, false); };
+auto sum_along_1 = [&a]() { return mx::sum(a, 1, false); };
TIME(sum_along_1);
-auto prod_all = [&a]() { return prod(a, false); };
+auto prod_all = [&a]() { return mx::prod(a, false); };
TIME(prod_all);
-auto all_true = [&a]() { return all(a, false); };
+auto all_true = [&a]() { return mx::all(a, false); };
TIME(all_true);
-auto all_along_0 = [&a]() { return all(a, 0, false); };
+auto all_along_0 = [&a]() { return mx::all(a, 0, false); };
TIME(all_along_0);
-auto all_along_1 = [&a]() { return all(a, 1, false); };
+auto all_along_1 = [&a]() { return mx::all(a, 1, false); };
TIME(all_along_1);
-auto any_true = [&a]() { return any(a, false); };
+auto any_true = [&a]() { return mx::any(a, false); };
TIME(any_true);
-auto argmin_along_0 = [&a]() { return argmin(a, 0, false); };
+auto argmin_along_0 = [&a]() { return mx::argmin(a, 0, false); };
TIME(argmin_along_0);
-auto argmin_along_1 = [&a]() { return argmin(a, 1, false); };
+auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
TIME(argmin_along_1);
}
void time_gather_scatter() {
-auto a = random::normal({1000, 768});
+auto a = mx::random::normal({1000, 768});
-eval(a);
+mx::eval(a);
-auto indices = random::randint(0, 1000, {256});
+auto indices = mx::random::randint(0, 1000, {256});
-eval(indices);
+mx::eval(indices);
-auto embedding_lookup = [&a, &indices]() { return take(a, indices, 0); };
+auto embedding_lookup = [&a, &indices]() { return mx::take(a, indices, 0); };
TIME(embedding_lookup);
-indices = random::randint(0, 768 * 1000, {256 * 768});
+indices = mx::random::randint(0, 768 * 1000, {256 * 768});
-eval(indices);
+mx::eval(indices);
-auto single_element_lookup = [&a, &indices]() { return take(a, indices); };
+auto single_element_lookup = [&a, &indices]() {
+return mx::take(a, indices);
+};
TIME(single_element_lookup);
-indices = random::randint(0, 1000, {256});
+indices = mx::random::randint(0, 1000, {256});
-auto updates = random::normal({256, 1, 768});
+auto updates = mx::random::normal({256, 1, 768});
-eval(indices, updates);
+mx::eval(indices, updates);
auto embedding_update = [&a, &indices, &updates]() {
return scatter(a, indices, updates, 0);
@@ -223,10 +225,10 @@ void time_gather_scatter() {
};
TIME(embedding_add);
-a = reshape(a, {-1});
+a = mx::reshape(a, {-1});
-indices = random::randint(0, 768 * 1000, {768 * 256});
+indices = mx::random::randint(0, 768 * 1000, {768 * 256});
-updates = random::normal({256 * 768, 1});
+updates = mx::random::normal({256 * 768, 1});
-eval(a, indices, updates);
+mx::eval(a, indices, updates);
auto single_element_update = [&a, &indices, &updates]() {
return scatter(a, indices, updates, 0);
@@ -240,21 +242,21 @@ void time_gather_scatter() {
}
void time_divmod() {
-auto a = random::normal({1000});
+auto a = mx::random::normal({1000});
-auto b = random::normal({1000});
+auto b = mx::random::normal({1000});
-eval({a, b});
+mx::eval({a, b});
-auto divmod_fused = [&a, &b]() { return divmod(a, b); };
+auto divmod_fused = [&a, &b]() { return mx::divmod(a, b); };
TIME(divmod_fused);
auto divmod_separate = [&a, &b]() {
-return std::vector<array>{floor_divide(a, b), remainder(a, b)};
+return std::vector<mx::array>{mx::floor_divide(a, b), mx::remainder(a, b)};
};
TIME(divmod_separate);
}
int main() {
-std::cout << "Benchmarks for " << default_device() << std::endl;
+std::cout << "Benchmarks for " << mx::default_device() << std::endl;
time_creation_ops();
time_type_conversions();
time_unary_ops();


@@ -4,19 +4,19 @@
#include "mlx/mlx.h"
-using namespace mlx::core;
+namespace mx = mlx::core;
int main() {
-if (!distributed::is_available()) {
+if (!mx::distributed::is_available()) {
std::cout << "No communication backend found" << std::endl;
return 1;
}
-auto global_group = distributed::init();
+auto global_group = mx::distributed::init();
std::cout << global_group.rank() << " / " << global_group.size() << std::endl;
-array x = ones({10});
+mx::array x = mx::ones({10});
-array out = distributed::all_sum(x, global_group);
+mx::array out = mx::distributed::all_sum(x, global_group);
std::cout << out << std::endl;
}


@@ -10,7 +10,7 @@
/**
 * An example of linear regression with MLX.
 */
-using namespace mlx::core;
+namespace mx = mlx::core;
int main() {
int num_features = 100;
@@ -19,35 +19,35 @@ int main() {
float learning_rate = 0.01;
// True parameters
-auto w_star = random::normal({num_features});
+auto w_star = mx::random::normal({num_features});
// The input examples (design matrix)
-auto X = random::normal({num_examples, num_features});
+auto X = mx::random::normal({num_examples, num_features});
// Noisy labels
-auto eps = 1e-2 * random::normal({num_examples});
+auto eps = 1e-2 * mx::random::normal({num_examples});
-auto y = matmul(X, w_star) + eps;
+auto y = mx::matmul(X, w_star) + eps;
// Initialize random parameters
-array w = 1e-2 * random::normal({num_features});
+mx::array w = 1e-2 * mx::random::normal({num_features});
-auto loss_fn = [&](array w) {
+auto loss_fn = [&](mx::array w) {
-auto yhat = matmul(X, w);
+auto yhat = mx::matmul(X, w);
-return (0.5f / num_examples) * sum(square(yhat - y));
+return (0.5f / num_examples) * mx::sum(mx::square(yhat - y));
};
-auto grad_fn = grad(loss_fn);
+auto grad_fn = mx::grad(loss_fn);
auto tic = timer::time();
for (int it = 0; it < num_iters; ++it) {
-auto grad = grad_fn(w);
+auto grads = grad_fn(w);
-w = w - learning_rate * grad;
+w = w - learning_rate * grads;
-eval(w);
+mx::eval(w);
}
auto toc = timer::time();
auto loss = loss_fn(w);
-auto error_norm = std::sqrt(sum(square(w - w_star)).item<float>());
+auto error_norm = std::sqrt(mx::sum(mx::square(w - w_star)).item<float>());
auto throughput = num_iters / timer::seconds(toc - tic);
std::cout << "Loss " << loss << ", |w - w*| = " << error_norm
<< ", Throughput " << throughput << " (it/s)." << std::endl;


@@ -10,7 +10,7 @@
/**
 * An example of logistic regression with MLX.
 */
-using namespace mlx::core;
+namespace mx = mlx::core;
int main() {
int num_features = 100;
@@ -19,35 +19,35 @@ int main() {
float learning_rate = 0.1;
// True parameters
-auto w_star = random::normal({num_features});
+auto w_star = mx::random::normal({num_features});
// The input examples
-auto X = random::normal({num_examples, num_features});
+auto X = mx::random::normal({num_examples, num_features});
// Labels
-auto y = matmul(X, w_star) > 0;
+auto y = mx::matmul(X, w_star) > 0;
// Initialize random parameters
-array w = 1e-2 * random::normal({num_features});
+mx::array w = 1e-2 * mx::random::normal({num_features});
-auto loss_fn = [&](array w) {
+auto loss_fn = [&](mx::array w) {
-auto logits = matmul(X, w);
+auto logits = mx::matmul(X, w);
auto scale = (1.0f / num_examples);
-return scale * sum(logaddexp(array(0.0f), logits) - y * logits);
+return scale * mx::sum(mx::logaddexp(mx::array(0.0f), logits) - y * logits);
};
-auto grad_fn = grad(loss_fn);
+auto grad_fn = mx::grad(loss_fn);
auto tic = timer::time();
for (int it = 0; it < num_iters; ++it) {
-auto grad = grad_fn(w);
+auto grads = grad_fn(w);
-w = w - learning_rate * grad;
+w = w - learning_rate * grads;
-eval(w);
+mx::eval(w);
}
auto toc = timer::time();
auto loss = loss_fn(w);
-auto acc = sum((matmul(X, w) > 0) == y) / num_examples;
+auto acc = mx::sum((mx::matmul(X, w) > 0) == y) / num_examples;
auto throughput = num_iters / timer::seconds(toc - tic);
std::cout << "Loss " << loss << ", Accuracy, " << acc << ", Throughput "
<< throughput << " (it/s)." << std::endl;


@@ -5,27 +5,27 @@
#include "mlx/mlx.h"
-using namespace mlx::core;
+namespace mx = mlx::core;
int main() {
// To use Metal debugging and profiling:
// 1. Build with the MLX_METAL_DEBUG CMake option (i.e. -DMLX_METAL_DEBUG=ON).
// 2. Run with MTL_CAPTURE_ENABLED=1.
-metal::start_capture("mlx_trace.gputrace");
+mx::metal::start_capture("mlx_trace.gputrace");
// Start at index two because the default GPU and CPU streams have indices
// zero and one, respectively. This naming matches the label assigned to each
// stream's command queue.
-auto s2 = new_stream(Device::gpu);
+auto s2 = new_stream(mx::Device::gpu);
-auto s3 = new_stream(Device::gpu);
+auto s3 = new_stream(mx::Device::gpu);
-auto a = arange(1.f, 10.f, 1.f, float32, s2);
+auto a = mx::arange(1.f, 10.f, 1.f, mx::float32, s2);
-auto b = arange(1.f, 10.f, 1.f, float32, s3);
+auto b = mx::arange(1.f, 10.f, 1.f, mx::float32, s3);
-auto x = add(a, a, s2);
+auto x = mx::add(a, a, s2);
-auto y = add(b, b, s3);
+auto y = mx::add(b, b, s3);
// The multiply will happen on the default stream.
-std::cout << multiply(x, y) << std::endl;
+std::cout << mx::multiply(x, y) << std::endl;
-metal::stop_capture();
+mx::metal::stop_capture();
}


@@ -5,11 +5,11 @@
#include "mlx/mlx.h"
-using namespace mlx::core;
+namespace mx = mlx::core;
void array_basics() {
// Make a scalar array:
-array x(1.0);
+mx::array x(1.0);
// Get the value out of it:
auto s = x.item<float>();
@@ -29,31 +29,31 @@ void array_basics() {
// The datatype should be float32:
auto dtype = x.dtype();
-assert(dtype == float32);
+assert(dtype == mx::float32);
// Specify the dtype when constructing the array:
-x = array(1, int32);
+x = mx::array(1, mx::int32);
-assert(x.dtype() == int32);
+assert(x.dtype() == mx::int32);
x.item<int>(); // OK
// x.item<float>(); // Undefined!
// Make a multidimensional array:
-x = array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
+x = mx::array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
// mlx is row-major by default so the first row of this array
// is [1.0, 2.0] and the second row is [3.0, 4.0]
// Make an array of shape {2, 2} filled with ones:
-auto y = ones({2, 2});
+auto y = mx::ones({2, 2});
// Pointwise add x and y:
-auto z = add(x, y);
+auto z = mx::add(x, y);
// Same thing:
z = x + y;
// mlx is lazy by default. At this point `z` only
// has a shape and a type but no actual data:
-assert(z.dtype() == float32);
+assert(z.dtype() == mx::float32);
assert(z.shape(0) == 2);
assert(z.shape(1) == 2);
@@ -63,33 +63,33 @@ void array_basics() {
// and inputs. When `eval` is called on an array (or arrays), the array and
// all of its dependencies are recursively evaluated to produce the result.
// Once an array is evaluated, it has data and is detached from its inputs.
-eval(z);
+mx::eval(z);
-// Of course the array can still be an input to other operations. You can even
-// call eval on the array again, this will just be a no-op:
+// Of course the array can still be an input to other operations. You can
+// even call eval on the array again, this will just be a no-op:
-eval(z); // no-op
+mx::eval(z); // no-op
// Some functions or methods on arrays implicitly evaluate them. For example
// accessing a value in an array or printing the array implicitly evaluate it:
-z = ones({1});
+z = mx::ones({1});
z.item<float>(); // implicit evaluation
-z = ones({2, 2});
+z = mx::ones({2, 2});
std::cout << z << std::endl; // implicit evaluation
}
void automatic_differentiation() {
-auto fn = [](array x) { return square(x); };
+auto fn = [](mx::array x) { return mx::square(x); };
// Computing the derivative function of a function
-auto grad_fn = grad(fn);
+auto grad_fn = mx::grad(fn);
// Call grad_fn on the input to get the derivative
-auto x = array(1.5);
+auto x = mx::array(1.5);
auto dfdx = grad_fn(x);
// dfdx is 2 * x
// Get the second derivative by composing grad with grad
-auto d2fdx2 = grad(grad(fn))(x);
+auto d2fdx2 = mx::grad(mx::grad(fn))(x);
// d2fdx2 is 2
}


@@ -19,7 +19,7 @@
#include "mlx/backend/metal/utils.h"
#endif
-namespace mlx::core {
+namespace my_ext {
///////////////////////////////////////////////////////////////////////////////
// Operation Implementation
@@ -32,24 +32,24 @@ namespace mlx::core {
 * Follow numpy style broadcasting between x and y
 * Inputs are upcasted to floats if needed
 **/
-array axpby(
+mx::array axpby(
-const array& x, // Input array x
+const mx::array& x, // Input mx::array x
-const array& y, // Input array y
+const mx::array& y, // Input mx::array y
const float alpha, // Scaling factor for x
const float beta, // Scaling factor for y
-StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
+mx::StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
) {
// Promote dtypes between x and y as needed
auto promoted_dtype = promote_types(x.dtype(), y.dtype());
// Upcast to float32 for non-floating point inputs x and y
-auto out_dtype = issubdtype(promoted_dtype, float32)
+auto out_dtype = mx::issubdtype(promoted_dtype, mx::float32)
? promoted_dtype
-: promote_types(promoted_dtype, float32);
+: promote_types(promoted_dtype, mx::float32);
// Cast x and y up to the determined dtype (on the same stream s)
-auto x_casted = astype(x, out_dtype, s);
+auto x_casted = mx::astype(x, out_dtype, s);
-auto y_casted = astype(y, out_dtype, s);
+auto y_casted = mx::astype(y, out_dtype, s);
// Broadcast the shapes of x and y (on the same stream s)
auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s);
@@ -57,12 +57,12 @@ array axpby(
// Construct the array as the output of the Axpby primitive
// with the broadcasted and upcasted arrays as inputs
-return array(
+return mx::array(
/* const std::vector<int>& shape = */ out_shape,
-/* Dtype dtype = */ out_dtype,
+/* mx::Dtype dtype = */ out_dtype,
-/* std::unique_ptr<Primitive> primitive = */
+/* std::unique_ptr<mx::Primitive> primitive = */
std::make_shared<Axpby>(to_stream(s), alpha, beta),
-/* const std::vector<array>& inputs = */ broadcasted_inputs);
+/* const std::vector<mx::array>& inputs = */ broadcasted_inputs);
}
///////////////////////////////////////////////////////////////////////////////
@@ -71,16 +71,16 @@ array axpby(
template <typename T>
void axpby_impl(
-const array& x,
+const mx::array& x,
-const array& y,
+const mx::array& y,
-array& out,
+mx::array& out,
float alpha_,
float beta_) {
// We only allocate memory when we are ready to fill the output
// malloc_or_wait synchronously allocates available memory
// There may be a wait executed here if the allocation is requested
// under memory-pressured conditions
-out.set_data(allocator::malloc_or_wait(out.nbytes()));
+out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
// Collect input and output data pointers
const T* x_ptr = x.data<T>();
@@ -94,8 +94,8 @@ void axpby_impl(
// Do the element-wise operation for each output
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
// Map linear indices to offsets in x and y
-auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
+auto x_offset = mx::elem_to_loc(out_idx, x.shape(), x.strides());
-auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
+auto y_offset = mx::elem_to_loc(out_idx, y.shape(), y.strides());
// We allocate the output to be contiguous and regularly strided
// (defaults to row major) and hence it doesn't need additional mapping
@@ -105,8 +105,8 @@ void axpby_impl(
/** Fall back implementation for evaluation on CPU */
void Axpby::eval(
-const std::vector<array>& inputs,
+const std::vector<mx::array>& inputs,
-std::vector<array>& outputs) {
+std::vector<mx::array>& outputs) {
// Check the inputs (registered in the op while constructing the out array)
assert(inputs.size() == 2);
auto& x = inputs[0];
@@ -114,14 +114,14 @@ void Axpby::eval(
auto& out = outputs[0];
// Dispatch to the correct dtype
-if (out.dtype() == float32) {
+if (out.dtype() == mx::float32) {
return axpby_impl<float>(x, y, out, alpha_, beta_);
-} else if (out.dtype() == float16) {
+} else if (out.dtype() == mx::float16) {
-return axpby_impl<float16_t>(x, y, out, alpha_, beta_);
+return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_);
-} else if (out.dtype() == bfloat16) {
+} else if (out.dtype() == mx::bfloat16) {
-return axpby_impl<bfloat16_t>(x, y, out, alpha_, beta_);
+return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_);
-} else if (out.dtype() == complex64) {
+} else if (out.dtype() == mx::complex64) {
-return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
+return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_);
} else {
throw std::runtime_error(
"Axpby is only supported for floating point types.");
@@ -136,9 +136,9 @@ void Axpby::eval(
template <typename T>
void axpby_impl_accelerate(
-const array& x,
+const mx::array& x,
-const array& y,
+const mx::array& y,
-array& out,
+mx::array& out,
float alpha_,
float beta_) {
// Accelerate library provides catlas_saxpby which does
@@ -150,10 +150,10 @@ void axpby_impl_accelerate(
// The data in the output array is allocated to match the strides in y
// such that x, y, and out are contiguous in the same mode and
// no transposition is needed
-out.set_data(allocator::malloc_or_wait(out.nbytes()));
+out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
// We then copy over the elements using the contiguous vector specialization
-copy_inplace(y, out, CopyType::Vector);
+copy_inplace(y, out, mx::CopyType::Vector);
// Get x and y pointers for catlas_saxpby
const T* x_ptr = x.data<T>();
@@ -175,15 +175,15 @@ void axpby_impl_accelerate(
/** Evaluate primitive on CPU using accelerate specializations */
void Axpby::eval_cpu(
-const std::vector<array>& inputs,
+const std::vector<mx::array>& inputs,
-std::vector<array>& outputs) {
+std::vector<mx::array>& outputs) {
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Accelerate specialization for contiguous single precision float arrays
-if (out.dtype() == float32 &&
+if (out.dtype() == mx::float32 &&
((x.flags().row_contiguous && y.flags().row_contiguous) ||
(x.flags().col_contiguous && y.flags().col_contiguous))) {
axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
@@ -198,8 +198,8 @@ void Axpby::eval_cpu(
/** Evaluate primitive on CPU falling back to common backend */
void Axpby::eval_cpu(
-const std::vector<array>& inputs,
+const std::vector<mx::array>& inputs,
-const std::vector<array>& outputs) {
+std::vector<mx::array>& outputs) {
eval(inputs, outputs);
}
@@ -213,8 +213,8 @@ void Axpby::eval_cpu(
/** Evaluate primitive on GPU */
void Axpby::eval_gpu(
-const std::vector<array>& inputs,
+const std::vector<mx::array>& inputs,
-std::vector<array>& outputs) {
+std::vector<mx::array>& outputs) {
// Prepare inputs
assert(inputs.size() == 2);
auto& x = inputs[0];
@@ -225,7 +225,7 @@ void Axpby::eval_gpu(
// and each stream carries its device identifiers
auto& s = stream();
// We get the needed metal device using the stream
-auto& d = metal::device(s.device);
+auto& d = mx::metal::device(s.device);
// Prepare to specialize based on contiguity
bool contiguous_kernel =
@@ -235,12 +235,12 @@ void Axpby::eval_gpu(
// Allocate output memory with strides based on specialization
if (contiguous_kernel) {
out.set_data(
-allocator::malloc_or_wait(x.data_size() * out.itemsize()),
+mx::allocator::malloc_or_wait(x.data_size() * out.itemsize()),
x.data_size(),
x.strides(),
x.flags());
} else {
-out.set_data(allocator::malloc_or_wait(out.nbytes()));
+out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
}
// Resolve name of kernel (corresponds to axpby.metal)
@@ -302,8 +302,8 @@ void Axpby::eval_gpu(
/** Fail evaluation on GPU */
void Axpby::eval_gpu(
-const std::vector<array>& inputs,
+const std::vector<mx::array>& inputs,
-std::vector<array>& out) {
+std::vector<mx::array>& out) {
throw std::runtime_error("Axpby has no GPU implementation.");
}
@@ -314,9 +314,9 @@ void Axpby::eval_gpu(
///////////////////////////////////////////////////////////////////////////////
/** The Jacobian-vector product. */
-std::vector<array> Axpby::jvp(
+std::vector<mx::array> Axpby::jvp(
-const std::vector<array>& primals,
+const std::vector<mx::array>& primals,
-const std::vector<array>& tangents,
+const std::vector<mx::array>& tangents,
const std::vector<int>& argnums) {
// Forward mode diff that pushes along the tangents
// The jvp transform on the primitive can built with ops
@@ -328,8 +328,8 @@ std::vector<array> Axpby::jvp(
// scaled by beta
if (argnums.size() > 1) {
auto scale = argnums[0] == 0 ? alpha_ : beta_;
-auto scale_arr = array(scale, tangents[0].dtype());
+auto scale_arr = mx::array(scale, tangents[0].dtype());
-return {multiply(scale_arr, tangents[0], stream())};
+return {mx::multiply(scale_arr, tangents[0], stream())};
}
// If, argnums = {0, 1}, we take contributions from both
// which gives us jvp = tangent_x * alpha + tangent_y * beta
@@ -339,24 +339,24 @@
}
/** The vector-Jacobian product. */
-std::vector<array> Axpby::vjp(
+std::vector<mx::array> Axpby::vjp(
-const std::vector<array>& primals,
+const std::vector<mx::array>& primals,
-const std::vector<array>& cotangents,
+const std::vector<mx::array>& cotangents,
const std::vector<int>& argnums,
-const std::vector<array>&) {
+const std::vector<mx::array>&) {
// Reverse mode diff
-std::vector<array> vjps;
+std::vector<mx::array> vjps;
for (auto arg : argnums) {
auto scale = arg == 0 ? alpha_ : beta_;
-auto scale_arr = array(scale, cotangents[0].dtype());
+auto scale_arr = mx::array(scale, cotangents[0].dtype());
-vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
+vjps.push_back(mx::multiply(scale_arr, cotangents[0], stream()));
}
return vjps;
}
/** Vectorize primitive along given axis */
-std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
+std::pair<std::vector<mx::array>, std::vector<int>> Axpby::vmap(
-const std::vector<array>& inputs,
+const std::vector<mx::array>& inputs,
const std::vector<int>& axes) {
throw std::runtime_error("Axpby has no vmap implementation.");
}
@@ -367,4 +367,4 @@ bool Axpby::is_equivalent(const Primitive& other) const {
return alpha_ == r_other.alpha_ && beta_ == r_other.beta_;
}
-} // namespace mlx::core
+} // namespace my_ext
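
After this change the extension's operation lives in its own my_ext namespace and refers to MLX types through the mx alias instead of being defined inside mlx::core. A minimal usage sketch, assuming the extension header is included as "axpby/axpby.h" (the include path and build setup are assumptions, not taken from this diff):

#include <iostream>
#include "mlx/mlx.h"
#include "axpby/axpby.h"   // assumed path to the example extension header

namespace mx = mlx::core;

int main() {
  auto x = mx::ones({4});
  auto y = mx::ones({4});
  // axpby computes alpha * x + beta * y with numpy-style broadcasting,
  // as described in the extension's own doc comment.
  auto out = my_ext::axpby(x, y, /* alpha = */ 2.0f, /* beta = */ 0.5f);
  mx::eval(out);
  std::cout << out << std::endl;
  return 0;
}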


@@ -5,7 +5,9 @@
#include "mlx/ops.h"
#include "mlx/primitives.h"
-namespace mlx::core {
+namespace mx = mlx::core;
+namespace my_ext {
///////////////////////////////////////////////////////////////////////////////
// Operation
@@ -18,22 +20,22 @@ namespace mlx::core {
 * Follow numpy style broadcasting between x and y
 * Inputs are upcasted to floats if needed
 **/
-array axpby(
+mx::array axpby(
-const array& x, // Input array x
+const mx::array& x, // Input array x
-const array& y, // Input array y
+const mx::array& y, // Input array y
const float alpha, // Scaling factor for x
const float beta, // Scaling factor for y
-StreamOrDevice s = {} // Stream on which to schedule the operation
+mx::StreamOrDevice s = {} // Stream on which to schedule the operation
);
///////////////////////////////////////////////////////////////////////////////
// Primitive
///////////////////////////////////////////////////////////////////////////////
-class Axpby : public Primitive {
+class Axpby : public mx::Primitive {
public:
-explicit Axpby(Stream stream, float alpha, float beta)
+explicit Axpby(mx::Stream stream, float alpha, float beta)
-: Primitive(stream), alpha_(alpha), beta_(beta) {};
+: mx::Primitive(stream), alpha_(alpha), beta_(beta) {};
/**
 * A primitive must know how to evaluate itself on the CPU/GPU
@@ -42,23 +44,25 @@ class Axpby : public Primitive {
 * To avoid unnecessary allocations, the evaluation function
 * is responsible for allocating space for the array.
 */
-void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
-override;
-void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
-override;
+void eval_cpu(
+const std::vector<mx::array>& inputs,
+std::vector<mx::array>& outputs) override;
+void eval_gpu(
+const std::vector<mx::array>& inputs,
+std::vector<mx::array>& outputs) override;
/** The Jacobian-vector product. */
-std::vector<array> jvp(
+std::vector<mx::array> jvp(
-const std::vector<array>& primals,
+const std::vector<mx::array>& primals,
-const std::vector<array>& tangents,
+const std::vector<mx::array>& tangents,
const std::vector<int>& argnums) override;
/** The vector-Jacobian product. */
-std::vector<array> vjp(
+std::vector<mx::array> vjp(
-const std::vector<array>& primals,
+const std::vector<mx::array>& primals,
-const std::vector<array>& cotangents,
+const std::vector<mx::array>& cotangents,
const std::vector<int>& argnums,
-const std::vector<array>& outputs) override;
+const std::vector<mx::array>& outputs) override;
/**
 * The primitive must know how to vectorize itself across
@@ -66,8 +70,8 @@ class Axpby : public Primitive {
 * representing the vectorized computation and the axis which
 * corresponds to the output vectorized dimension.
 */
-std::pair<std::vector<array>, std::vector<int>> vmap(
+std::pair<std::vector<mx::array>, std::vector<int>> vmap(
-const std::vector<array>& inputs,
+const std::vector<mx::array>& inputs,
const std::vector<int>& axes) override;
/** Print the primitive. */
@@ -76,14 +80,16 @@ class Axpby : public Primitive {
}
/** Equivalence check **/
-bool is_equivalent(const Primitive& other) const override;
+bool is_equivalent(const mx::Primitive& other) const override;
private:
float alpha_;
float beta_;
/** Fall back implementation for evaluation on CPU */
-void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
+void eval(
+const std::vector<mx::array>& inputs,
+std::vector<mx::array>& outputs);
};
-} // namespace mlx::core
+} // namespace my_ext


@@ -8,14 +8,12 @@
namespace nb = nanobind;
using namespace nb::literals;
-using namespace mlx::core;
NB_MODULE(_ext, m) {
m.doc() = "Sample extension for MLX";
m.def(
"axpby",
-&axpby,
+&my_ext::axpby,
"x"_a,
"y"_a,
"alpha"_a,