diff --git a/benchmarks/cpp/autograd.cpp b/benchmarks/cpp/autograd.cpp
index 3beaa04bc..b4303a840 100644
--- a/benchmarks/cpp/autograd.cpp
+++ b/benchmarks/cpp/autograd.cpp
@@ -5,35 +5,35 @@
 #include "mlx/mlx.h"
 #include "time_utils.h"
 
-using namespace mlx::core;
+namespace mx = mlx::core;
 
 void time_value_and_grad() {
-  auto x = ones({200, 1000});
-  eval(x);
-  auto fn = [](array x) {
+  auto x = mx::ones({200, 1000});
+  mx::eval(x);
+  auto fn = [](mx::array x) {
     for (int i = 0; i < 20; ++i) {
-      x = log(exp(x));
+      x = mx::log(mx::exp(x));
     }
-    return sum(x);
+    return mx::sum(x);
   };
 
-  auto grad_fn = grad(fn);
+  auto grad_fn = mx::grad(fn);
   auto independent_value_and_grad = [&]() {
     auto value = fn(x);
     auto dfdx = grad_fn(x);
-    return std::vector<array>{value, dfdx};
+    return std::vector<mx::array>{value, dfdx};
   };
   TIME(independent_value_and_grad);
 
-  auto value_and_grad_fn = value_and_grad(fn);
+  auto value_and_grad_fn = mx::value_and_grad(fn);
   auto combined_value_and_grad = [&]() {
     auto [value, dfdx] = value_and_grad_fn(x);
-    return std::vector<array>{value, dfdx};
+    return std::vector<mx::array>{value, dfdx};
   };
   TIME(combined_value_and_grad);
 }
 
 int main() {
-  std::cout << "Benchmarks for " << default_device() << std::endl;
+  std::cout << "Benchmarks for " << mx::default_device() << std::endl;
   time_value_and_grad();
 }
diff --git a/benchmarks/cpp/compare_devices.cpp b/benchmarks/cpp/compare_devices.cpp
index eecbc9c4a..9af6c7103 100644
--- a/benchmarks/cpp/compare_devices.cpp
+++ b/benchmarks/cpp/compare_devices.cpp
@@ -4,21 +4,21 @@
 #include "mlx/mlx.h"
 #include "time_utils.h"
 
-using namespace mlx::core;
+namespace mx = mlx::core;
 
 void time_add_op() {
   std::vector<int> sizes(1, 1);
   for (int i = 0; i < 9; ++i) {
     sizes.push_back(10 * sizes.back());
   }
-  set_default_device(Device::cpu);
+  set_default_device(mx::Device::cpu);
   for (auto size : sizes) {
-    auto a = random::uniform({size});
-    auto b = random::uniform({size});
-    eval(a, b);
+    auto a = mx::random::uniform({size});
+    auto b = mx::random::uniform({size});
+    mx::eval(a, b);
     std::cout << "Size " << size << std::endl;
-    TIMEM("cpu", add, a, b, Device::cpu);
-    TIMEM("gpu", add, a, b, Device::gpu);
+    TIMEM("cpu", mx::add, a, b, mx::Device::cpu);
+    TIMEM("gpu", mx::add, a, b, mx::Device::gpu);
   }
 }
diff --git a/benchmarks/cpp/irregular_strides.cpp b/benchmarks/cpp/irregular_strides.cpp
index 50d4c5b73..cda76fed6 100644
--- a/benchmarks/cpp/irregular_strides.cpp
+++ b/benchmarks/cpp/irregular_strides.cpp
@@ -6,105 +6,105 @@
 #include "mlx/mlx.h"
 #include "time_utils.h"
 
-using namespace mlx::core;
+namespace mx = mlx::core;
 
 void time_irregular_binary_ops_1D() {
-  auto device = default_device();
+  auto device = mx::default_device();
   int size = 1000000;
   int step = 2;
-  auto a = random::uniform({size});
-  auto b = random::uniform({size});
-  eval(a, b);
+  auto a = mx::random::uniform({size});
+  auto b = mx::random::uniform({size});
+  mx::eval(a, b);
   a = slice(a, {0}, {size}, {step});
   b = slice(b, {0}, {size}, {step});
-  TIMEM("1D strided", add, a, b, device);
+  TIMEM("1D strided", mx::add, a, b, device);
 }
 
 void time_irregular_binary_ops_2D() {
-  auto device = default_device();
+  auto device = mx::default_device();
   int size = 2048;
-  auto a = random::uniform({size, size});
-  auto b = random::uniform({size, size});
-  eval(a, b);
-  TIMEM("2D regular", add, a, b, device);
+  auto a = mx::random::uniform({size, size});
+  auto b = mx::random::uniform({size, size});
+  mx::eval(a, b);
+  TIMEM("2D regular", mx::add, a, b, device);
 
-  b = transpose(b);
-  eval(b);
-  TIMEM("2D transpose", add, a, b, device);
+  b = mx::transpose(b);
+  mx::eval(b);
+  TIMEM("2D mx::transpose", mx::add, a, b, device);
 
-  b = random::uniform({size});
-  eval(b);
-  TIMEM("2D broadcast dim 0", add, a, b, device);
+  b = mx::random::uniform({size});
+  mx::eval(b);
+  TIMEM("2D broadcast dim 0", mx::add, a, b, device);
 
-  b = reshape(b, {size, 1});
-  eval(b);
-  TIMEM("2D broadcast dim 1", add, a, b, device);
+  b = mx::reshape(b, {size, 1});
+  mx::eval(b);
+  TIMEM("2D broadcast dim 1", mx::add, a, b, device);
 }
 
 void time_irregular_binary_ops_3D() {
-  auto device = default_device();
+  auto device = mx::default_device();
   int d0 = 32;
   int d1 = 512;
   int d2 = 512;
-  auto a = random::uniform({d0, d1, d2});
-  auto b = random::uniform({d0, d1, d2});
-  TIMEM("3D regular", add, a, b, device);
+  auto a = mx::random::uniform({d0, d1, d2});
+  auto b = mx::random::uniform({d0, d1, d2});
+  TIMEM("3D regular", mx::add, a, b, device);
 
-  b = transpose(b, {0, 2, 1});
-  TIMEM("3D transpose", add, a, b, device);
+  b = mx::transpose(b, {0, 2, 1});
+  TIMEM("3D mx::transpose", mx::add, a, b, device);
 
-  b = random::uniform({d1, d2});
-  TIMEM("3D broadcast dim 0", add, a, b, device);
+  b = mx::random::uniform({d1, d2});
+  TIMEM("3D broadcast dim 0", mx::add, a, b, device);
 
-  b = random::uniform({d0, 1, d2});
-  TIMEM("3D broadcast dim 1", add, a, b, device);
+  b = mx::random::uniform({d0, 1, d2});
+  TIMEM("3D broadcast dim 1", mx::add, a, b, device);
 
-  b = random::uniform({d0, d1, 1});
-  TIMEM("3D broadcast dim 2", add, a, b, device);
+  b = mx::random::uniform({d0, d1, 1});
+  TIMEM("3D broadcast dim 2", mx::add, a, b, device);
 
-  b = random::uniform({d2});
-  TIMEM("3D broadcast dims 0, 1", add, a, b, device);
+  b = mx::random::uniform({d2});
+  TIMEM("3D broadcast dims 0, 1", mx::add, a, b, device);
 
-  b = random::uniform({d1, 1});
-  TIMEM("3D broadcast dims 0, 2", add, a, b, device);
+  b = mx::random::uniform({d1, 1});
+  TIMEM("3D broadcast dims 0, 2", mx::add, a, b, device);
 
-  b = random::uniform({d0, 1, 1});
-  TIMEM("3D broadcast dims 1, 2", add, a, b, device);
+  b = mx::random::uniform({d0, 1, 1});
+  TIMEM("3D broadcast dims 1, 2", mx::add, a, b, device);
 }
 
 void time_irregular_binary_ops_4D() {
-  auto device = default_device();
+  auto device = mx::default_device();
   std::vector<int> shape = {8, 8, 512, 512};
-  auto a = random::uniform(shape);
-  auto b = random::uniform(shape);
+  auto a = mx::random::uniform(shape);
+  auto b = mx::random::uniform(shape);
 
-  TIMEM("4D regular", add, a, b, device);
+  TIMEM("4D regular", mx::add, a, b, device);
 
-  b = transpose(b, {0, 1, 3, 2});
-  TIMEM("4D transpose", add, a, b, device);
+  b = mx::transpose(b, {0, 1, 3, 2});
+  TIMEM("4D mx::transpose", mx::add, a, b, device);
 
   std::string om = "4D broadcast dims ";
   for (int i = 0; i < shape.size(); ++i) {
     shape[i] = 1;
-    b = random::uniform(shape);
+    b = mx::random::uniform(shape);
     std::ostringstream msg;
     msg << om << i;
-    TIMEM(msg.str(), add, a, b, device);
+    TIMEM(msg.str(), mx::add, a, b, device);
 
     for (int j = i + 1; j < shape.size(); ++j) {
       shape[j] = 1;
       std::ostringstream msg;
       msg << om << i << ", " << j;
-      b = random::uniform(shape);
-      TIMEM(msg.str(), add, a, b, device);
+      b = mx::random::uniform(shape);
+      TIMEM(msg.str(), mx::add, a, b, device);
       shape[j] = a.shape(j);
 
       for (int k = j + 1; k < shape.size(); ++k) {
         shape[k] = 1;
         std::ostringstream msg;
         msg << om << i << ", " << j << ", " << k;
-        b = random::uniform(shape);
-        TIMEM(msg.str(), add, a, b, device);
+        b = mx::random::uniform(shape);
+        TIMEM(msg.str(), mx::add, a, b, device);
         shape[k] = a.shape(k);
       }
     }
@@ -113,83 +113,83 @@ void time_irregular_binary_ops_4D() {
 }
 
 void time_irregular_reshape() {
-  auto device = default_device();
+  auto device = mx::default_device();
   std::vector<int> shape;
-  auto reshape_fn = [&shape, device](const array& a) {
-    return reshape(a, shape, device);
+  auto reshape_fn = [&shape, device](const mx::array& a) {
+    return mx::reshape(a, shape, device);
   };
   int size = 64;
   int d = 2 * size;
-  auto a = random::uniform({d, d, d});
+  auto a = mx::random::uniform({d, d, d});
 
   shape = {8 * size, size, size};
   TIMEM("3D contiguous", reshape_fn, a);
 
-  a = transpose(a);
+  a = mx::transpose(a);
   shape = {8 * size, size, size};
-  TIMEM("3D transpose", reshape_fn, a);
+  TIMEM("3D mx::transpose", reshape_fn, a);
 
-  a = transpose(a, {1, 2, 0});
+  a = mx::transpose(a, {1, 2, 0});
   shape = {8 * size, size, size};
-  TIMEM("3D transpose dims 1 2", reshape_fn, a);
+  TIMEM("3D mx::transpose dims 1 2", reshape_fn, a);
 
-  a = broadcast_to(random::uniform({d, d}), {d, d, d});
+  a = mx::broadcast_to(mx::random::uniform({d, d}), {d, d, d});
   TIMEM("3D broadcast dim 0", reshape_fn, a);
 
-  a = broadcast_to(random::uniform({d, 1, d}), {d, d, d});
+  a = mx::broadcast_to(mx::random::uniform({d, 1, d}), {d, d, d});
   TIMEM("3D broadcast dim 1", reshape_fn, a);
 
-  a = broadcast_to(random::uniform({d, d, 1}), {d, d, d});
+  a = mx::broadcast_to(mx::random::uniform({d, d, 1}), {d, d, d});
   TIMEM("3D broadcast dim 2", reshape_fn, a);
 
-  a = broadcast_to(random::uniform({d}), {d, d, d});
+  a = mx::broadcast_to(mx::random::uniform({d}), {d, d, d});
   TIMEM("3D broadcast dims 0, 1", reshape_fn, a);
 
-  a = broadcast_to(random::uniform({d, 1}), {d, d, d});
+  a = mx::broadcast_to(mx::random::uniform({d, 1}), {d, d, d});
   TIMEM("3D broadcast dims 0, 2", reshape_fn, a);
 
-  a = broadcast_to(random::uniform({d, 1, 1}), {d, d, d});
+  a = mx::broadcast_to(mx::random::uniform({d, 1, 1}), {d, d, d});
   TIMEM("3D broadcast dims 1, 2", reshape_fn, a);
 
-  a = broadcast_to(random::uniform({1, 1, 1}), {d, d, d});
+  a = mx::broadcast_to(mx::random::uniform({1, 1, 1}), {d, d, d});
   TIMEM("3D broadcast dims 1, 2, 3", reshape_fn, a);
 }
 
 void time_irregular_astype_1D() {
-  auto device = default_device();
+  auto device = mx::default_device();
   int size = 1000000;
   int step = 2;
-  auto a = random::uniform({size});
+  auto a = mx::random::uniform({size});
   a = slice(a, {0}, {size}, {step});
-  TIMEM("1D strided", astype, a, int32, device);
+  TIMEM("1D strided", mx::astype, a, mx::int32, device);
 }
 
 void time_irregular_astype_2D() {
-  auto device = default_device();
+  auto device = mx::default_device();
   int size = 2048;
   std::vector<int> shape = {size, size};
-  auto a = random::uniform(shape);
-  TIMEM("2D regular", astype, a, int32, device);
+  auto a = mx::random::uniform(shape);
+  TIMEM("2D regular", mx::astype, a, mx::int32, device);
 
-  a = transpose(a);
-  TIMEM("2D transpose", astype, a, int32, device);
+  a = mx::transpose(a);
+  TIMEM("2D mx::transpose", mx::astype, a, mx::int32, device);
 
-  a = broadcast_to(random::uniform({size}), shape);
-  TIMEM("2D broadcast dim 0", astype, a, int32, device);
+  a = mx::broadcast_to(mx::random::uniform({size}), shape);
+  TIMEM("2D broadcast dim 0", mx::astype, a, mx::int32, device);
 
-  a = broadcast_to(random::uniform({size, 1}), shape);
-  TIMEM("2D broadcast dim 1", astype, a, int32, device);
+  a = mx::broadcast_to(mx::random::uniform({size, 1}), shape);
+  TIMEM("2D broadcast dim 1", mx::astype, a, mx::int32, device);
 }
 
 int main(int argc, char** argv) {
   if (argc > 1) {
     bool use_gpu = !strcmp(argv[1], "gpu");
-
set_default_device(use_gpu ? Device::gpu : Device::cpu); + set_default_device(use_gpu ? mx::Device::gpu : mx::Device::cpu); } - std::cout << "Benchmarks for " << default_device() << std::endl; + std::cout << "Benchmarks for " << mx::default_device() << std::endl; time_irregular_binary_ops_1D(); time_irregular_binary_ops_2D(); time_irregular_binary_ops_3D(); diff --git a/benchmarks/cpp/single_ops.cpp b/benchmarks/cpp/single_ops.cpp index 4505282f1..5b327be58 100644 --- a/benchmarks/cpp/single_ops.cpp +++ b/benchmarks/cpp/single_ops.cpp @@ -3,20 +3,20 @@ #include "mlx/mlx.h" #include "time_utils.h" -using namespace mlx::core; +namespace mx = mlx::core; void time_creation_ops() { int M = 2000; int N = 500; auto shape = {M, N}; - auto full_fp32 = [&]() { return full(shape, 3.3f); }; + auto full_fp32 = [&]() { return mx::full(shape, 3.3f); }; TIME(full_fp32); - auto zeros_fp32 = [&]() { return zeros(shape, float32); }; + auto zeros_fp32 = [&]() { return mx::zeros(shape, mx::float32); }; TIME(zeros_fp32); - auto ones_fp32 = [&]() { return ones(shape, float32); }; + auto ones_fp32 = [&]() { return mx::ones(shape, mx::float32); }; TIME(ones_fp32); - auto arange_fp32 = [&]() { return arange(0.0, 10.0, 1e-4); }; + auto arange_fp32 = [&]() { return mx::arange(0.0, 10.0, 1e-4); }; TIME(arange_fp32); } @@ -24,194 +24,196 @@ void time_type_conversions() { int M = 2000; int N = 500; auto shape = {M, N}; - auto device = default_device(); + auto device = mx::default_device(); - auto a = zeros(shape, float32); - eval(a); - TIMEM("float32 to int32", astype, a, int32, device); - TIMEM("float32 to uint32", astype, a, uint32, device); + auto a = mx::zeros(shape, mx::float32); + mx::eval(a); + TIMEM("mx::float32 to mx::int32", mx::astype, a, mx::int32, device); + TIMEM("mx::float32 to mx::uint32", mx::astype, a, mx::uint32, device); - a = zeros(shape, int32); - eval(a); - TIMEM("int32 to float32", astype, a, float32, device); + a = mx::zeros(shape, mx::int32); + mx::eval(a); + TIMEM("mx::int32 to mx::float32", mx::astype, a, mx::float32, device); - a = zeros(shape, bool_); - eval(a); - TIMEM("bool to float32", astype, a, float32, device); - TIMEM("bool to int32", astype, a, int32, device); - TIMEM("bool to uint32", astype, a, uint32, device); + a = mx::zeros(shape, mx::bool_); + mx::eval(a); + TIMEM("bool to mx::float32", mx::astype, a, mx::float32, device); + TIMEM("bool to mx::int32", mx::astype, a, mx::int32, device); + TIMEM("bool to mx::uint32", mx::astype, a, mx::uint32, device); } void time_random_generation() { int M = 2000; int N = 500; - auto uniform = [&]() { return random::uniform({M, N}, float32); }; + auto uniform = [&]() { return mx::random::uniform({M, N}, mx::float32); }; TIME(uniform); - auto normal = [&]() { return random::normal({M, N}, float32); }; + auto normal = [&]() { return mx::random::normal({M, N}, mx::float32); }; TIME(normal); } void time_unary_ops() { int M = 2000; int N = 500; - auto device = default_device(); + auto device = mx::default_device(); - auto a = random::normal({M, N}); - eval(a); + auto a = mx::random::normal({M, N}); + mx::eval(a); TIME(mlx::core::abs, a, device); - TIME(negative, a, device); - TIME(sign, a, device); - TIME(square, a, device); + TIME(mx::negative, a, device); + TIME(mx::sign, a, device); + TIME(mx::square, a, device); TIME(mlx::core::sqrt, a, device); - TIME(rsqrt, a, device); + TIME(mx::rsqrt, a, device); TIME(mlx::core::exp, a, device); - a = random::uniform({M, N}); + a = mx::random::uniform({M, N}); TIME(mlx::core::log, a, device); } void 
time_binary_ops() { int M = 1000, N = 100, K = 10; - auto condition = random::randint(0, 2, {M, N, K}); - auto a = random::uniform({M, N, K}); - auto b = random::uniform({M, N, K}); - auto device = default_device(); - eval(a, b); + auto condition = mx::random::randint(0, 2, {M, N, K}); + auto a = mx::random::uniform({M, N, K}); + auto b = mx::random::uniform({M, N, K}); + auto device = mx::default_device(); + mx::eval(a, b); - TIME(add, a, b, device); - TIME(subtract, a, b, device); - TIME(multiply, a, b, device); - TIME(divide, a, b, device); - TIME(maximum, a, b, device); - TIME(minimum, a, b, device); - TIME(where, condition, a, b, device); + TIME(mx::add, a, b, device); + TIME(mx::subtract, a, b, device); + TIME(mx::multiply, a, b, device); + TIME(mx::divide, a, b, device); + TIME(mx::maximum, a, b, device); + TIME(mx::minimum, a, b, device); + TIME(mx::where, condition, a, b, device); - condition = array({true}); - b = random::uniform({1}); - eval(b); - TIMEM("scalar", add, a, b, device); - TIMEM("vector-scalar", subtract, a, b, device); - TIMEM("scalar-vector", subtract, b, a, device); - TIMEM("scalar", multiply, a, b, device); - TIMEM("vector-scalar", divide, a, b, device); - TIMEM("scalar-vector", divide, b, a, device); - TIMEM("scalar-vector", where, condition, a, b, device); + condition = mx::array({true}); + b = mx::random::uniform({1}); + mx::eval(b); + TIMEM("scalar", mx::add, a, b, device); + TIMEM("vector-scalar", mx::subtract, a, b, device); + TIMEM("scalar-vector", mx::subtract, b, a, device); + TIMEM("scalar", mx::multiply, a, b, device); + TIMEM("vector-scalar", mx::divide, a, b, device); + TIMEM("scalar-vector", mx::divide, b, a, device); + TIMEM("scalar-vector", mx::where, condition, a, b, device); - condition = broadcast_to(array({true}), {1000, 100}); - a = broadcast_to(random::uniform({1}), {1000, 100}); - b = broadcast_to(random::uniform({1}), {1000, 100}); - eval(a, b); - TIMEM("scalar-scalar broadcast", add, a, b, device); - TIMEM("scalar-scalar broadcast", subtract, a, b, device); - TIMEM("scalar-scalar broadcast", multiply, a, b, device); - TIMEM("scalar-scalar broadcast", divide, a, b, device); - TIMEM("scalar-scalar broadcast", where, condition, a, b, device); + condition = mx::broadcast_to(mx::array({true}), {1000, 100}); + a = mx::broadcast_to(mx::random::uniform({1}), {1000, 100}); + b = mx::broadcast_to(mx::random::uniform({1}), {1000, 100}); + mx::eval(a, b); + TIMEM("scalar-scalar broadcast", mx::add, a, b, device); + TIMEM("scalar-scalar broadcast", mx::subtract, a, b, device); + TIMEM("scalar-scalar broadcast", mx::multiply, a, b, device); + TIMEM("scalar-scalar broadcast", mx::divide, a, b, device); + TIMEM("scalar-scalar broadcast", mx::where, condition, a, b, device); } void time_strided_ops() { int M = 50, N = 50, O = 50, P = 50; - auto a = random::uniform({M, N, O, P}); - auto b = random::uniform({M, N, O, P}); - auto device = default_device(); - eval(a, b); - TIMEM("non-strided", add, a, b, device); - a = transpose(a, {1, 0, 2, 3}); - b = transpose(b, {3, 2, 0, 1}); - eval(a, b); - TIMEM("strided", add, a, b, device); + auto a = mx::random::uniform({M, N, O, P}); + auto b = mx::random::uniform({M, N, O, P}); + auto device = mx::default_device(); + mx::eval(a, b); + TIMEM("non-strided", mx::add, a, b, device); + a = mx::transpose(a, {1, 0, 2, 3}); + b = mx::transpose(b, {3, 2, 0, 1}); + mx::eval(a, b); + TIMEM("strided", mx::add, a, b, device); } void time_comparisons() { int M = 1000, N = 100, K = 10; - auto a = random::uniform({M, N, K}); - 
auto b = random::uniform({M, N, K}); - auto device = default_device(); - eval(a, b); - TIME(equal, a, b, device); - TIME(greater, a, b, device); - TIME(greater_equal, a, b, device); - TIME(less, a, b, device); - TIME(less_equal, a, b, device); + auto a = mx::random::uniform({M, N, K}); + auto b = mx::random::uniform({M, N, K}); + auto device = mx::default_device(); + mx::eval(a, b); + TIME(mx::equal, a, b, device); + TIME(mx::greater, a, b, device); + TIME(mx::greater_equal, a, b, device); + TIME(mx::less, a, b, device); + TIME(mx::less_equal, a, b, device); } void time_matvec() { int M = 2000, N = 200; - auto a = random::uniform({M, N}); - auto b = random::uniform({N}); - auto c = random::uniform({M}); - eval(a, b, c); - auto matvec = [&]() { return matmul(a, b); }; + auto a = mx::random::uniform({M, N}); + auto b = mx::random::uniform({N}); + auto c = mx::random::uniform({M}); + mx::eval(a, b, c); + auto matvec = [&]() { return mx::matmul(a, b); }; TIME(matvec); - auto matvec_transpose = [&]() { return matmul(transpose(a), c); }; + auto matvec_transpose = [&]() { return mx::matmul(mx::transpose(a), c); }; TIME(matvec_transpose); } void time_matmul() { int M = 1000, N = 1000, K = 1000; - auto a = random::uniform({M, K}); - auto b = random::uniform({K, N}); - auto device = default_device(); - eval(a, b); - TIME(matmul, a, b, device); + auto a = mx::random::uniform({M, K}); + auto b = mx::random::uniform({K, N}); + auto device = mx::default_device(); + mx::eval(a, b); + TIME(mx::matmul, a, b, device); - auto transpose_matmul = [&]() { return matmul(transpose(a), b); }; + auto transpose_matmul = [&]() { return mx::matmul(mx::transpose(a), b); }; TIME(transpose_matmul); } void time_reductions() { - auto a = random::normal({10000, 1000}); - eval(a); - auto sum_all = [&a]() { return sum(a, false); }; + auto a = mx::random::normal({10000, 1000}); + mx::eval(a); + auto sum_all = [&a]() { return mx::sum(a, false); }; TIME(sum_all); - auto sum_along_0 = [&a]() { return sum(a, 0, false); }; + auto sum_along_0 = [&a]() { return mx::sum(a, 0, false); }; TIME(sum_along_0); - auto sum_along_1 = [&a]() { return sum(a, 1, false); }; + auto sum_along_1 = [&a]() { return mx::sum(a, 1, false); }; TIME(sum_along_1); - auto prod_all = [&a]() { return prod(a, false); }; + auto prod_all = [&a]() { return mx::prod(a, false); }; TIME(prod_all); - auto all_true = [&a]() { return all(a, false); }; + auto all_true = [&a]() { return mx::all(a, false); }; TIME(all_true); - auto all_along_0 = [&a]() { return all(a, 0, false); }; + auto all_along_0 = [&a]() { return mx::all(a, 0, false); }; TIME(all_along_0); - auto all_along_1 = [&a]() { return all(a, 1, false); }; + auto all_along_1 = [&a]() { return mx::all(a, 1, false); }; TIME(all_along_1); - auto any_true = [&a]() { return any(a, false); }; + auto any_true = [&a]() { return mx::any(a, false); }; TIME(any_true); - auto argmin_along_0 = [&a]() { return argmin(a, 0, false); }; + auto argmin_along_0 = [&a]() { return mx::argmin(a, 0, false); }; TIME(argmin_along_0); - auto argmin_along_1 = [&a]() { return argmin(a, 1, false); }; + auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); }; TIME(argmin_along_1); } void time_gather_scatter() { - auto a = random::normal({1000, 768}); - eval(a); - auto indices = random::randint(0, 1000, {256}); - eval(indices); + auto a = mx::random::normal({1000, 768}); + mx::eval(a); + auto indices = mx::random::randint(0, 1000, {256}); + mx::eval(indices); - auto embedding_lookup = [&a, &indices]() { return take(a, indices, 0); }; 
+  auto embedding_lookup = [&a, &indices]() { return mx::take(a, indices, 0); };
   TIME(embedding_lookup);
 
-  indices = random::randint(0, 768 * 1000, {256 * 768});
-  eval(indices);
+  indices = mx::random::randint(0, 768 * 1000, {256 * 768});
+  mx::eval(indices);
 
-  auto single_element_lookup = [&a, &indices]() { return take(a, indices); };
+  auto single_element_lookup = [&a, &indices]() {
+    return mx::take(a, indices);
+  };
   TIME(single_element_lookup);
 
-  indices = random::randint(0, 1000, {256});
-  auto updates = random::normal({256, 1, 768});
-  eval(indices, updates);
+  indices = mx::random::randint(0, 1000, {256});
+  auto updates = mx::random::normal({256, 1, 768});
+  mx::eval(indices, updates);
 
   auto embedding_update = [&a, &indices, &updates]() {
     return scatter(a, indices, updates, 0);
@@ -223,10 +225,10 @@ void time_gather_scatter() {
   };
   TIME(embedding_add);
 
-  a = reshape(a, {-1});
-  indices = random::randint(0, 768 * 1000, {768 * 256});
-  updates = random::normal({256 * 768, 1});
-  eval(a, indices, updates);
+  a = mx::reshape(a, {-1});
+  indices = mx::random::randint(0, 768 * 1000, {768 * 256});
+  updates = mx::random::normal({256 * 768, 1});
+  mx::eval(a, indices, updates);
 
   auto single_element_update = [&a, &indices, &updates]() {
     return scatter(a, indices, updates, 0);
@@ -240,21 +242,21 @@
 }
 
 void time_divmod() {
-  auto a = random::normal({1000});
-  auto b = random::normal({1000});
-  eval({a, b});
+  auto a = mx::random::normal({1000});
+  auto b = mx::random::normal({1000});
+  mx::eval({a, b});
 
-  auto divmod_fused = [&a, &b]() { return divmod(a, b); };
+  auto divmod_fused = [&a, &b]() { return mx::divmod(a, b); };
   TIME(divmod_fused);
 
   auto divmod_separate = [&a, &b]() {
-    return std::vector<array>{floor_divide(a, b), remainder(a, b)};
+    return std::vector<mx::array>{mx::floor_divide(a, b), mx::remainder(a, b)};
   };
   TIME(divmod_separate);
 }
 
 int main() {
-  std::cout << "Benchmarks for " << default_device() << std::endl;
+  std::cout << "Benchmarks for " << mx::default_device() << std::endl;
   time_creation_ops();
   time_type_conversions();
   time_unary_ops();
diff --git a/examples/cpp/distributed.cpp b/examples/cpp/distributed.cpp
index 14229c1a2..1cc69d951 100644
--- a/examples/cpp/distributed.cpp
+++ b/examples/cpp/distributed.cpp
@@ -4,19 +4,19 @@
 
 #include "mlx/mlx.h"
 
-using namespace mlx::core;
+namespace mx = mlx::core;
 
 int main() {
-  if (!distributed::is_available()) {
+  if (!mx::distributed::is_available()) {
     std::cout << "No communication backend found" << std::endl;
     return 1;
   }
 
-  auto global_group = distributed::init();
+  auto global_group = mx::distributed::init();
   std::cout << global_group.rank() << " / " << global_group.size() << std::endl;
 
-  array x = ones({10});
-  array out = distributed::all_sum(x, global_group);
+  mx::array x = mx::ones({10});
+  mx::array out = mx::distributed::all_sum(x, global_group);
 
   std::cout << out << std::endl;
 }
diff --git a/examples/cpp/linear_regression.cpp b/examples/cpp/linear_regression.cpp
index f921da15b..ba3578d2f 100644
--- a/examples/cpp/linear_regression.cpp
+++ b/examples/cpp/linear_regression.cpp
@@ -10,7 +10,7 @@
 /**
  * An example of linear regression with MLX.
  */
-using namespace mlx::core;
+namespace mx = mlx::core;
 
 int main() {
   int num_features = 100;
@@ -19,35 +19,35 @@ int main() {
   float learning_rate = 0.01;
 
   // True parameters
-  auto w_star = random::normal({num_features});
+  auto w_star = mx::random::normal({num_features});
 
   // The input examples (design matrix)
-  auto X = random::normal({num_examples, num_features});
+  auto X = mx::random::normal({num_examples, num_features});
 
   // Noisy labels
-  auto eps = 1e-2 * random::normal({num_examples});
-  auto y = matmul(X, w_star) + eps;
+  auto eps = 1e-2 * mx::random::normal({num_examples});
+  auto y = mx::matmul(X, w_star) + eps;
 
   // Initialize random parameters
-  array w = 1e-2 * random::normal({num_features});
+  mx::array w = 1e-2 * mx::random::normal({num_features});
 
-  auto loss_fn = [&](array w) {
-    auto yhat = matmul(X, w);
-    return (0.5f / num_examples) * sum(square(yhat - y));
+  auto loss_fn = [&](mx::array w) {
+    auto yhat = mx::matmul(X, w);
+    return (0.5f / num_examples) * mx::sum(mx::square(yhat - y));
   };
 
-  auto grad_fn = grad(loss_fn);
+  auto grad_fn = mx::grad(loss_fn);
 
   auto tic = timer::time();
   for (int it = 0; it < num_iters; ++it) {
-    auto grad = grad_fn(w);
-    w = w - learning_rate * grad;
-    eval(w);
+    auto grads = grad_fn(w);
+    w = w - learning_rate * grads;
+    mx::eval(w);
   }
   auto toc = timer::time();
 
   auto loss = loss_fn(w);
-  auto error_norm = std::sqrt(sum(square(w - w_star)).item<float>());
+  auto error_norm = std::sqrt(mx::sum(mx::square(w - w_star)).item<float>());
   auto throughput = num_iters / timer::seconds(toc - tic);
   std::cout << "Loss " << loss << ", |w - w*| = " << error_norm
             << ", Throughput " << throughput << " (it/s)." << std::endl;
diff --git a/examples/cpp/logistic_regression.cpp b/examples/cpp/logistic_regression.cpp
index 7f13f35bb..1d373f524 100644
--- a/examples/cpp/logistic_regression.cpp
+++ b/examples/cpp/logistic_regression.cpp
@@ -10,7 +10,7 @@
 /**
  * An example of logistic regression with MLX.
  */
-using namespace mlx::core;
+namespace mx = mlx::core;
 
 int main() {
   int num_features = 100;
@@ -19,35 +19,35 @@ int main() {
   float learning_rate = 0.1;
 
   // True parameters
-  auto w_star = random::normal({num_features});
+  auto w_star = mx::random::normal({num_features});
 
   // The input examples
-  auto X = random::normal({num_examples, num_features});
+  auto X = mx::random::normal({num_examples, num_features});
 
   // Labels
-  auto y = matmul(X, w_star) > 0;
+  auto y = mx::matmul(X, w_star) > 0;
 
   // Initialize random parameters
-  array w = 1e-2 * random::normal({num_features});
+  mx::array w = 1e-2 * mx::random::normal({num_features});
 
-  auto loss_fn = [&](array w) {
-    auto logits = matmul(X, w);
+  auto loss_fn = [&](mx::array w) {
+    auto logits = mx::matmul(X, w);
     auto scale = (1.0f / num_examples);
-    return scale * sum(logaddexp(array(0.0f), logits) - y * logits);
+    return scale * mx::sum(mx::logaddexp(mx::array(0.0f), logits) - y * logits);
   };
 
-  auto grad_fn = grad(loss_fn);
+  auto grad_fn = mx::grad(loss_fn);
 
   auto tic = timer::time();
   for (int it = 0; it < num_iters; ++it) {
-    auto grad = grad_fn(w);
-    w = w - learning_rate * grad;
-    eval(w);
+    auto grads = grad_fn(w);
+    w = w - learning_rate * grads;
+    mx::eval(w);
  }
   auto toc = timer::time();
 
   auto loss = loss_fn(w);
-  auto acc = sum((matmul(X, w) > 0) == y) / num_examples;
+  auto acc = mx::sum((mx::matmul(X, w) > 0) == y) / num_examples;
   auto throughput = num_iters / timer::seconds(toc - tic);
   std::cout << "Loss " << loss << ", Accuracy, " << acc << ", Throughput "
             << throughput << " (it/s)." << std::endl;
diff --git a/examples/cpp/metal_capture.cpp b/examples/cpp/metal_capture.cpp
index d31c49f96..17f830066 100644
--- a/examples/cpp/metal_capture.cpp
+++ b/examples/cpp/metal_capture.cpp
@@ -5,27 +5,27 @@
 
 #include "mlx/mlx.h"
 
-using namespace mlx::core;
+namespace mx = mlx::core;
 
 int main() {
   // To use Metal debugging and profiling:
   // 1. Build with the MLX_METAL_DEBUG CMake option (i.e. -DMLX_METAL_DEBUG=ON).
   // 2. Run with MTL_CAPTURE_ENABLED=1.
-  metal::start_capture("mlx_trace.gputrace");
+  mx::metal::start_capture("mlx_trace.gputrace");
 
   // Start at index two because the default GPU and CPU streams have indices
   // zero and one, respectively. This naming matches the label assigned to each
   // stream's command queue.
-  auto s2 = new_stream(Device::gpu);
-  auto s3 = new_stream(Device::gpu);
+  auto s2 = new_stream(mx::Device::gpu);
+  auto s3 = new_stream(mx::Device::gpu);
 
-  auto a = arange(1.f, 10.f, 1.f, float32, s2);
-  auto b = arange(1.f, 10.f, 1.f, float32, s3);
-  auto x = add(a, a, s2);
-  auto y = add(b, b, s3);
+  auto a = mx::arange(1.f, 10.f, 1.f, mx::float32, s2);
+  auto b = mx::arange(1.f, 10.f, 1.f, mx::float32, s3);
+  auto x = mx::add(a, a, s2);
+  auto y = mx::add(b, b, s3);
 
   // The multiply will happen on the default stream.
-  std::cout << multiply(x, y) << std::endl;
+  std::cout << mx::multiply(x, y) << std::endl;
 
-  metal::stop_capture();
+  mx::metal::stop_capture();
 }
diff --git a/examples/cpp/tutorial.cpp b/examples/cpp/tutorial.cpp
index 25cc85c31..ae2cd3cfb 100644
--- a/examples/cpp/tutorial.cpp
+++ b/examples/cpp/tutorial.cpp
@@ -5,11 +5,11 @@
 
 #include "mlx/mlx.h"
 
-using namespace mlx::core;
+namespace mx = mlx::core;
 
 void array_basics() {
   // Make a scalar array:
-  array x(1.0);
+  mx::array x(1.0);
 
   // Get the value out of it:
   auto s = x.item<float>();
@@ -29,31 +29,31 @@
 
   // The datatype should be float32:
   auto dtype = x.dtype();
-  assert(dtype == float32);
+  assert(dtype == mx::float32);
 
   // Specify the dtype when constructing the array:
-  x = array(1, int32);
-  assert(x.dtype() == int32);
+  x = mx::array(1, mx::int32);
+  assert(x.dtype() == mx::int32);
   x.item<int>(); // OK
   // x.item<float>(); // Undefined!
 
   // Make a multidimensional array:
-  x = array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
+  x = mx::array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
   // mlx is row-major by default so the first row of this array
   // is [1.0, 2.0] and the second row is [3.0, 4.0]
 
   // Make an array of shape {2, 2} filled with ones:
-  auto y = ones({2, 2});
+  auto y = mx::ones({2, 2});
 
   // Pointwise add x and y:
-  auto z = add(x, y);
+  auto z = mx::add(x, y);
 
   // Same thing:
   z = x + y;
 
   // mlx is lazy by default. At this point `z` only
   // has a shape and a type but no actual data:
-  assert(z.dtype() == float32);
+  assert(z.dtype() == mx::float32);
   assert(z.shape(0) == 2);
   assert(z.shape(1) == 2);
 
@@ -63,33 +63,33 @@
   // and inputs. When `eval` is called on an array (or arrays), the array and
   // all of its dependencies are recursively evaluated to produce the result.
   // Once an array is evaluated, it has data and is detached from its inputs.
-  eval(z);
+  mx::eval(z);
 
-  // Of course the array can still be an input to other operations. You can even
-  // call eval on the array again, this will just be a no-op:
-  eval(z); // no-op
+  // Of course the array can still be an input to other operations. You can
+  // even call eval on the array again, this will just be a no-op:
+  mx::eval(z); // no-op
 
   // Some functions or methods on arrays implicitly evaluate them.
For example // accessing a value in an array or printing the array implicitly evaluate it: - z = ones({1}); + z = mx::ones({1}); z.item(); // implicit evaluation - z = ones({2, 2}); + z = mx::ones({2, 2}); std::cout << z << std::endl; // implicit evaluation } void automatic_differentiation() { - auto fn = [](array x) { return square(x); }; + auto fn = [](mx::array x) { return mx::square(x); }; // Computing the derivative function of a function - auto grad_fn = grad(fn); + auto grad_fn = mx::grad(fn); // Call grad_fn on the input to get the derivative - auto x = array(1.5); + auto x = mx::array(1.5); auto dfdx = grad_fn(x); // dfdx is 2 * x // Get the second derivative by composing grad with grad - auto d2fdx2 = grad(grad(fn))(x); + auto d2fdx2 = mx::grad(mx::grad(fn))(x); // d2fdx2 is 2 } diff --git a/examples/extensions/axpby/axpby.cpp b/examples/extensions/axpby/axpby.cpp index 07db2dd0c..70b02fb73 100644 --- a/examples/extensions/axpby/axpby.cpp +++ b/examples/extensions/axpby/axpby.cpp @@ -19,7 +19,7 @@ #include "mlx/backend/metal/utils.h" #endif -namespace mlx::core { +namespace my_ext { /////////////////////////////////////////////////////////////////////////////// // Operation Implementation @@ -32,24 +32,24 @@ namespace mlx::core { * Follow numpy style broadcasting between x and y * Inputs are upcasted to floats if needed **/ -array axpby( - const array& x, // Input array x - const array& y, // Input array y +mx::array axpby( + const mx::array& x, // Input mx::array x + const mx::array& y, // Input mx::array y const float alpha, // Scaling factor for x const float beta, // Scaling factor for y - StreamOrDevice s /* = {} */ // Stream on which to schedule the operation + mx::StreamOrDevice s /* = {} */ // Stream on which to schedule the operation ) { // Promote dtypes between x and y as needed auto promoted_dtype = promote_types(x.dtype(), y.dtype()); // Upcast to float32 for non-floating point inputs x and y - auto out_dtype = issubdtype(promoted_dtype, float32) + auto out_dtype = mx::issubdtype(promoted_dtype, mx::float32) ? 
promoted_dtype - : promote_types(promoted_dtype, float32); + : promote_types(promoted_dtype, mx::float32); // Cast x and y up to the determined dtype (on the same stream s) - auto x_casted = astype(x, out_dtype, s); - auto y_casted = astype(y, out_dtype, s); + auto x_casted = mx::astype(x, out_dtype, s); + auto y_casted = mx::astype(y, out_dtype, s); // Broadcast the shapes of x and y (on the same stream s) auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s); @@ -57,12 +57,12 @@ array axpby( // Construct the array as the output of the Axpby primitive // with the broadcasted and upcasted arrays as inputs - return array( + return mx::array( /* const std::vector& shape = */ out_shape, - /* Dtype dtype = */ out_dtype, - /* std::unique_ptr primitive = */ + /* mx::Dtype dtype = */ out_dtype, + /* std::unique_ptr primitive = */ std::make_shared(to_stream(s), alpha, beta), - /* const std::vector& inputs = */ broadcasted_inputs); + /* const std::vector& inputs = */ broadcasted_inputs); } /////////////////////////////////////////////////////////////////////////////// @@ -71,16 +71,16 @@ array axpby( template void axpby_impl( - const array& x, - const array& y, - array& out, + const mx::array& x, + const mx::array& y, + mx::array& out, float alpha_, float beta_) { // We only allocate memory when we are ready to fill the output // malloc_or_wait synchronously allocates available memory // There may be a wait executed here if the allocation is requested // under memory-pressured conditions - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(mx::allocator::malloc_or_wait(out.nbytes())); // Collect input and output data pointers const T* x_ptr = x.data(); @@ -94,8 +94,8 @@ void axpby_impl( // Do the element-wise operation for each output for (size_t out_idx = 0; out_idx < out.size(); out_idx++) { // Map linear indices to offsets in x and y - auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides()); - auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides()); + auto x_offset = mx::elem_to_loc(out_idx, x.shape(), x.strides()); + auto y_offset = mx::elem_to_loc(out_idx, y.shape(), y.strides()); // We allocate the output to be contiguous and regularly strided // (defaults to row major) and hence it doesn't need additional mapping @@ -105,8 +105,8 @@ void axpby_impl( /** Fall back implementation for evaluation on CPU */ void Axpby::eval( - const std::vector& inputs, - std::vector& outputs) { + const std::vector& inputs, + std::vector& outputs) { // Check the inputs (registered in the op while constructing the out array) assert(inputs.size() == 2); auto& x = inputs[0]; @@ -114,14 +114,14 @@ void Axpby::eval( auto& out = outputs[0]; // Dispatch to the correct dtype - if (out.dtype() == float32) { + if (out.dtype() == mx::float32) { return axpby_impl(x, y, out, alpha_, beta_); - } else if (out.dtype() == float16) { - return axpby_impl(x, y, out, alpha_, beta_); - } else if (out.dtype() == bfloat16) { - return axpby_impl(x, y, out, alpha_, beta_); - } else if (out.dtype() == complex64) { - return axpby_impl(x, y, out, alpha_, beta_); + } else if (out.dtype() == mx::float16) { + return axpby_impl(x, y, out, alpha_, beta_); + } else if (out.dtype() == mx::bfloat16) { + return axpby_impl(x, y, out, alpha_, beta_); + } else if (out.dtype() == mx::complex64) { + return axpby_impl(x, y, out, alpha_, beta_); } else { throw std::runtime_error( "Axpby is only supported for floating point types."); @@ -136,9 +136,9 @@ void Axpby::eval( template void axpby_impl_accelerate( - const 
array& x, - const array& y, - array& out, + const mx::array& x, + const mx::array& y, + mx::array& out, float alpha_, float beta_) { // Accelerate library provides catlas_saxpby which does @@ -150,10 +150,10 @@ void axpby_impl_accelerate( // The data in the output array is allocated to match the strides in y // such that x, y, and out are contiguous in the same mode and // no transposition is needed - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(mx::allocator::malloc_or_wait(out.nbytes())); // We then copy over the elements using the contiguous vector specialization - copy_inplace(y, out, CopyType::Vector); + copy_inplace(y, out, mx::CopyType::Vector); // Get x and y pointers for catlas_saxpby const T* x_ptr = x.data(); @@ -175,15 +175,15 @@ void axpby_impl_accelerate( /** Evaluate primitive on CPU using accelerate specializations */ void Axpby::eval_cpu( - const std::vector& inputs, - std::vector& outputs) { + const std::vector& inputs, + std::vector& outputs) { assert(inputs.size() == 2); auto& x = inputs[0]; auto& y = inputs[1]; auto& out = outputs[0]; // Accelerate specialization for contiguous single precision float arrays - if (out.dtype() == float32 && + if (out.dtype() == mx::float32 && ((x.flags().row_contiguous && y.flags().row_contiguous) || (x.flags().col_contiguous && y.flags().col_contiguous))) { axpby_impl_accelerate(x, y, out, alpha_, beta_); @@ -198,8 +198,8 @@ void Axpby::eval_cpu( /** Evaluate primitive on CPU falling back to common backend */ void Axpby::eval_cpu( - const std::vector& inputs, - const std::vector& outputs) { + const std::vector& inputs, + std::vector& outputs) { eval(inputs, outputs); } @@ -213,8 +213,8 @@ void Axpby::eval_cpu( /** Evaluate primitive on GPU */ void Axpby::eval_gpu( - const std::vector& inputs, - std::vector& outputs) { + const std::vector& inputs, + std::vector& outputs) { // Prepare inputs assert(inputs.size() == 2); auto& x = inputs[0]; @@ -225,7 +225,7 @@ void Axpby::eval_gpu( // and each stream carries its device identifiers auto& s = stream(); // We get the needed metal device using the stream - auto& d = metal::device(s.device); + auto& d = mx::metal::device(s.device); // Prepare to specialize based on contiguity bool contiguous_kernel = @@ -235,12 +235,12 @@ void Axpby::eval_gpu( // Allocate output memory with strides based on specialization if (contiguous_kernel) { out.set_data( - allocator::malloc_or_wait(x.data_size() * out.itemsize()), + mx::allocator::malloc_or_wait(x.data_size() * out.itemsize()), x.data_size(), x.strides(), x.flags()); } else { - out.set_data(allocator::malloc_or_wait(out.nbytes())); + out.set_data(mx::allocator::malloc_or_wait(out.nbytes())); } // Resolve name of kernel (corresponds to axpby.metal) @@ -302,8 +302,8 @@ void Axpby::eval_gpu( /** Fail evaluation on GPU */ void Axpby::eval_gpu( - const std::vector& inputs, - std::vector& out) { + const std::vector& inputs, + std::vector& out) { throw std::runtime_error("Axpby has no GPU implementation."); } @@ -314,9 +314,9 @@ void Axpby::eval_gpu( /////////////////////////////////////////////////////////////////////////////// /** The Jacobian-vector product. 
*/ -std::vector Axpby::jvp( - const std::vector& primals, - const std::vector& tangents, +std::vector Axpby::jvp( + const std::vector& primals, + const std::vector& tangents, const std::vector& argnums) { // Forward mode diff that pushes along the tangents // The jvp transform on the primitive can built with ops @@ -328,8 +328,8 @@ std::vector Axpby::jvp( // scaled by beta if (argnums.size() > 1) { auto scale = argnums[0] == 0 ? alpha_ : beta_; - auto scale_arr = array(scale, tangents[0].dtype()); - return {multiply(scale_arr, tangents[0], stream())}; + auto scale_arr = mx::array(scale, tangents[0].dtype()); + return {mx::multiply(scale_arr, tangents[0], stream())}; } // If, argnums = {0, 1}, we take contributions from both // which gives us jvp = tangent_x * alpha + tangent_y * beta @@ -339,24 +339,24 @@ std::vector Axpby::jvp( } /** The vector-Jacobian product. */ -std::vector Axpby::vjp( - const std::vector& primals, - const std::vector& cotangents, +std::vector Axpby::vjp( + const std::vector& primals, + const std::vector& cotangents, const std::vector& argnums, - const std::vector&) { + const std::vector&) { // Reverse mode diff - std::vector vjps; + std::vector vjps; for (auto arg : argnums) { auto scale = arg == 0 ? alpha_ : beta_; - auto scale_arr = array(scale, cotangents[0].dtype()); - vjps.push_back(multiply(scale_arr, cotangents[0], stream())); + auto scale_arr = mx::array(scale, cotangents[0].dtype()); + vjps.push_back(mx::multiply(scale_arr, cotangents[0], stream())); } return vjps; } /** Vectorize primitive along given axis */ -std::pair, std::vector> Axpby::vmap( - const std::vector& inputs, +std::pair, std::vector> Axpby::vmap( + const std::vector& inputs, const std::vector& axes) { throw std::runtime_error("Axpby has no vmap implementation."); } @@ -367,4 +367,4 @@ bool Axpby::is_equivalent(const Primitive& other) const { return alpha_ == r_other.alpha_ && beta_ == r_other.beta_; } -} // namespace mlx::core +} // namespace my_ext diff --git a/examples/extensions/axpby/axpby.h b/examples/extensions/axpby/axpby.h index a2c34123e..76421b493 100644 --- a/examples/extensions/axpby/axpby.h +++ b/examples/extensions/axpby/axpby.h @@ -5,7 +5,9 @@ #include "mlx/ops.h" #include "mlx/primitives.h" -namespace mlx::core { +namespace mx = mlx::core; + +namespace my_ext { /////////////////////////////////////////////////////////////////////////////// // Operation @@ -18,22 +20,22 @@ namespace mlx::core { * Follow numpy style broadcasting between x and y * Inputs are upcasted to floats if needed **/ -array axpby( - const array& x, // Input array x - const array& y, // Input array y +mx::array axpby( + const mx::array& x, // Input array x + const mx::array& y, // Input array y const float alpha, // Scaling factor for x const float beta, // Scaling factor for y - StreamOrDevice s = {} // Stream on which to schedule the operation + mx::StreamOrDevice s = {} // Stream on which to schedule the operation ); /////////////////////////////////////////////////////////////////////////////// // Primitive /////////////////////////////////////////////////////////////////////////////// -class Axpby : public Primitive { +class Axpby : public mx::Primitive { public: - explicit Axpby(Stream stream, float alpha, float beta) - : Primitive(stream), alpha_(alpha), beta_(beta) {}; + explicit Axpby(mx::Stream stream, float alpha, float beta) + : mx::Primitive(stream), alpha_(alpha), beta_(beta) {}; /** * A primitive must know how to evaluate itself on the CPU/GPU @@ -42,23 +44,25 @@ class Axpby : public 
Primitive { * To avoid unnecessary allocations, the evaluation function * is responsible for allocating space for the array. */ - void eval_cpu(const std::vector& inputs, std::vector& outputs) - override; - void eval_gpu(const std::vector& inputs, std::vector& outputs) - override; + void eval_cpu( + const std::vector& inputs, + std::vector& outputs) override; + void eval_gpu( + const std::vector& inputs, + std::vector& outputs) override; /** The Jacobian-vector product. */ - std::vector jvp( - const std::vector& primals, - const std::vector& tangents, + std::vector jvp( + const std::vector& primals, + const std::vector& tangents, const std::vector& argnums) override; /** The vector-Jacobian product. */ - std::vector vjp( - const std::vector& primals, - const std::vector& cotangents, + std::vector vjp( + const std::vector& primals, + const std::vector& cotangents, const std::vector& argnums, - const std::vector& outputs) override; + const std::vector& outputs) override; /** * The primitive must know how to vectorize itself across @@ -66,8 +70,8 @@ class Axpby : public Primitive { * representing the vectorized computation and the axis which * corresponds to the output vectorized dimension. */ - std::pair, std::vector> vmap( - const std::vector& inputs, + std::pair, std::vector> vmap( + const std::vector& inputs, const std::vector& axes) override; /** Print the primitive. */ @@ -76,14 +80,16 @@ class Axpby : public Primitive { } /** Equivalence check **/ - bool is_equivalent(const Primitive& other) const override; + bool is_equivalent(const mx::Primitive& other) const override; private: float alpha_; float beta_; /** Fall back implementation for evaluation on CPU */ - void eval(const std::vector& inputs, std::vector& outputs); + void eval( + const std::vector& inputs, + std::vector& outputs); }; -} // namespace mlx::core +} // namespace my_ext diff --git a/examples/extensions/bindings.cpp b/examples/extensions/bindings.cpp index bd801b31e..91892fa90 100644 --- a/examples/extensions/bindings.cpp +++ b/examples/extensions/bindings.cpp @@ -8,14 +8,12 @@ namespace nb = nanobind; using namespace nb::literals; -using namespace mlx::core; - NB_MODULE(_ext, m) { m.doc() = "Sample extension for MLX"; m.def( "axpby", - &axpby, + &my_ext::axpby, "x"_a, "y"_a, "alpha"_a,