mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Remove "using namespace mlx::core" in benchmarks/examples (#1685)
* Remove "using namespace mlx::core" in benchmarks/examples * Fix building example extension * A missing one in comment * Fix building on M chips
This commit is contained in:
parent
f76a49e555
commit
4f9b60dd53
@ -5,35 +5,35 @@
|
||||
#include "mlx/mlx.h"
|
||||
#include "time_utils.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
namespace mx = mlx::core;
|
||||
|
||||
void time_value_and_grad() {
|
||||
auto x = ones({200, 1000});
|
||||
eval(x);
|
||||
auto fn = [](array x) {
|
||||
auto x = mx::ones({200, 1000});
|
||||
mx::eval(x);
|
||||
auto fn = [](mx::array x) {
|
||||
for (int i = 0; i < 20; ++i) {
|
||||
x = log(exp(x));
|
||||
x = mx::log(mx::exp(x));
|
||||
}
|
||||
return sum(x);
|
||||
return mx::sum(x);
|
||||
};
|
||||
|
||||
auto grad_fn = grad(fn);
|
||||
auto grad_fn = mx::grad(fn);
|
||||
auto independent_value_and_grad = [&]() {
|
||||
auto value = fn(x);
|
||||
auto dfdx = grad_fn(x);
|
||||
return std::vector<array>{value, dfdx};
|
||||
return std::vector<mx::array>{value, dfdx};
|
||||
};
|
||||
TIME(independent_value_and_grad);
|
||||
|
||||
auto value_and_grad_fn = value_and_grad(fn);
|
||||
auto value_and_grad_fn = mx::value_and_grad(fn);
|
||||
auto combined_value_and_grad = [&]() {
|
||||
auto [value, dfdx] = value_and_grad_fn(x);
|
||||
return std::vector<array>{value, dfdx};
|
||||
return std::vector<mx::array>{value, dfdx};
|
||||
};
|
||||
TIME(combined_value_and_grad);
|
||||
}
|
||||
|
||||
int main() {
|
||||
std::cout << "Benchmarks for " << default_device() << std::endl;
|
||||
std::cout << "Benchmarks for " << mx::default_device() << std::endl;
|
||||
time_value_and_grad();
|
||||
}
|
||||
|
@ -4,21 +4,21 @@
|
||||
#include "mlx/mlx.h"
|
||||
#include "time_utils.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
namespace mx = mlx::core;
|
||||
|
||||
void time_add_op() {
|
||||
std::vector<int> sizes(1, 1);
|
||||
for (int i = 0; i < 9; ++i) {
|
||||
sizes.push_back(10 * sizes.back());
|
||||
}
|
||||
set_default_device(Device::cpu);
|
||||
set_default_device(mx::Device::cpu);
|
||||
for (auto size : sizes) {
|
||||
auto a = random::uniform({size});
|
||||
auto b = random::uniform({size});
|
||||
eval(a, b);
|
||||
auto a = mx::random::uniform({size});
|
||||
auto b = mx::random::uniform({size});
|
||||
mx::eval(a, b);
|
||||
std::cout << "Size " << size << std::endl;
|
||||
TIMEM("cpu", add, a, b, Device::cpu);
|
||||
TIMEM("gpu", add, a, b, Device::gpu);
|
||||
TIMEM("cpu", mx::add, a, b, mx::Device::cpu);
|
||||
TIMEM("gpu", mx::add, a, b, mx::Device::gpu);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -6,105 +6,105 @@
|
||||
#include "mlx/mlx.h"
|
||||
#include "time_utils.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
namespace mx = mlx::core;
|
||||
|
||||
void time_irregular_binary_ops_1D() {
|
||||
auto device = default_device();
|
||||
auto device = mx::default_device();
|
||||
int size = 1000000;
|
||||
int step = 2;
|
||||
auto a = random::uniform({size});
|
||||
auto b = random::uniform({size});
|
||||
eval(a, b);
|
||||
auto a = mx::random::uniform({size});
|
||||
auto b = mx::random::uniform({size});
|
||||
mx::eval(a, b);
|
||||
a = slice(a, {0}, {size}, {step});
|
||||
b = slice(b, {0}, {size}, {step});
|
||||
TIMEM("1D strided", add, a, b, device);
|
||||
TIMEM("1D strided", mx::add, a, b, device);
|
||||
}
|
||||
|
||||
void time_irregular_binary_ops_2D() {
|
||||
auto device = default_device();
|
||||
auto device = mx::default_device();
|
||||
int size = 2048;
|
||||
auto a = random::uniform({size, size});
|
||||
auto b = random::uniform({size, size});
|
||||
eval(a, b);
|
||||
TIMEM("2D regular", add, a, b, device);
|
||||
auto a = mx::random::uniform({size, size});
|
||||
auto b = mx::random::uniform({size, size});
|
||||
mx::eval(a, b);
|
||||
TIMEM("2D regular", mx::add, a, b, device);
|
||||
|
||||
b = transpose(b);
|
||||
eval(b);
|
||||
TIMEM("2D transpose", add, a, b, device);
|
||||
b = mx::transpose(b);
|
||||
mx::eval(b);
|
||||
TIMEM("2D mx::transpose", mx::add, a, b, device);
|
||||
|
||||
b = random::uniform({size});
|
||||
eval(b);
|
||||
TIMEM("2D broadcast dim 0", add, a, b, device);
|
||||
b = mx::random::uniform({size});
|
||||
mx::eval(b);
|
||||
TIMEM("2D broadcast dim 0", mx::add, a, b, device);
|
||||
|
||||
b = reshape(b, {size, 1});
|
||||
eval(b);
|
||||
TIMEM("2D broadcast dim 1", add, a, b, device);
|
||||
b = mx::reshape(b, {size, 1});
|
||||
mx::eval(b);
|
||||
TIMEM("2D broadcast dim 1", mx::add, a, b, device);
|
||||
}
|
||||
|
||||
void time_irregular_binary_ops_3D() {
|
||||
auto device = default_device();
|
||||
auto device = mx::default_device();
|
||||
int d0 = 32;
|
||||
int d1 = 512;
|
||||
int d2 = 512;
|
||||
auto a = random::uniform({d0, d1, d2});
|
||||
auto b = random::uniform({d0, d1, d2});
|
||||
TIMEM("3D regular", add, a, b, device);
|
||||
auto a = mx::random::uniform({d0, d1, d2});
|
||||
auto b = mx::random::uniform({d0, d1, d2});
|
||||
TIMEM("3D regular", mx::add, a, b, device);
|
||||
|
||||
b = transpose(b, {0, 2, 1});
|
||||
TIMEM("3D transpose", add, a, b, device);
|
||||
b = mx::transpose(b, {0, 2, 1});
|
||||
TIMEM("3D mx::transpose", mx::add, a, b, device);
|
||||
|
||||
b = random::uniform({d1, d2});
|
||||
TIMEM("3D broadcast dim 0", add, a, b, device);
|
||||
b = mx::random::uniform({d1, d2});
|
||||
TIMEM("3D broadcast dim 0", mx::add, a, b, device);
|
||||
|
||||
b = random::uniform({d0, 1, d2});
|
||||
TIMEM("3D broadcast dim 1", add, a, b, device);
|
||||
b = mx::random::uniform({d0, 1, d2});
|
||||
TIMEM("3D broadcast dim 1", mx::add, a, b, device);
|
||||
|
||||
b = random::uniform({d0, d1, 1});
|
||||
TIMEM("3D broadcast dim 2", add, a, b, device);
|
||||
b = mx::random::uniform({d0, d1, 1});
|
||||
TIMEM("3D broadcast dim 2", mx::add, a, b, device);
|
||||
|
||||
b = random::uniform({d2});
|
||||
TIMEM("3D broadcast dims 0, 1", add, a, b, device);
|
||||
b = mx::random::uniform({d2});
|
||||
TIMEM("3D broadcast dims 0, 1", mx::add, a, b, device);
|
||||
|
||||
b = random::uniform({d1, 1});
|
||||
TIMEM("3D broadcast dims 0, 2", add, a, b, device);
|
||||
b = mx::random::uniform({d1, 1});
|
||||
TIMEM("3D broadcast dims 0, 2", mx::add, a, b, device);
|
||||
|
||||
b = random::uniform({d0, 1, 1});
|
||||
TIMEM("3D broadcast dims 1, 2", add, a, b, device);
|
||||
b = mx::random::uniform({d0, 1, 1});
|
||||
TIMEM("3D broadcast dims 1, 2", mx::add, a, b, device);
|
||||
}
|
||||
|
||||
void time_irregular_binary_ops_4D() {
|
||||
auto device = default_device();
|
||||
auto device = mx::default_device();
|
||||
std::vector<int> shape = {8, 8, 512, 512};
|
||||
auto a = random::uniform(shape);
|
||||
auto b = random::uniform(shape);
|
||||
auto a = mx::random::uniform(shape);
|
||||
auto b = mx::random::uniform(shape);
|
||||
|
||||
TIMEM("4D regular", add, a, b, device);
|
||||
TIMEM("4D regular", mx::add, a, b, device);
|
||||
|
||||
b = transpose(b, {0, 1, 3, 2});
|
||||
TIMEM("4D transpose", add, a, b, device);
|
||||
b = mx::transpose(b, {0, 1, 3, 2});
|
||||
TIMEM("4D mx::transpose", mx::add, a, b, device);
|
||||
|
||||
std::string om = "4D broadcast dims ";
|
||||
for (int i = 0; i < shape.size(); ++i) {
|
||||
shape[i] = 1;
|
||||
b = random::uniform(shape);
|
||||
b = mx::random::uniform(shape);
|
||||
std::ostringstream msg;
|
||||
msg << om << i;
|
||||
TIMEM(msg.str(), add, a, b, device);
|
||||
TIMEM(msg.str(), mx::add, a, b, device);
|
||||
|
||||
for (int j = i + 1; j < shape.size(); ++j) {
|
||||
shape[j] = 1;
|
||||
std::ostringstream msg;
|
||||
msg << om << i << ", " << j;
|
||||
b = random::uniform(shape);
|
||||
TIMEM(msg.str(), add, a, b, device);
|
||||
b = mx::random::uniform(shape);
|
||||
TIMEM(msg.str(), mx::add, a, b, device);
|
||||
shape[j] = a.shape(j);
|
||||
|
||||
for (int k = j + 1; k < shape.size(); ++k) {
|
||||
shape[k] = 1;
|
||||
std::ostringstream msg;
|
||||
msg << om << i << ", " << j << ", " << k;
|
||||
b = random::uniform(shape);
|
||||
TIMEM(msg.str(), add, a, b, device);
|
||||
b = mx::random::uniform(shape);
|
||||
TIMEM(msg.str(), mx::add, a, b, device);
|
||||
shape[k] = a.shape(k);
|
||||
}
|
||||
}
|
||||
@ -113,83 +113,83 @@ void time_irregular_binary_ops_4D() {
|
||||
}
|
||||
|
||||
void time_irregular_reshape() {
|
||||
auto device = default_device();
|
||||
auto device = mx::default_device();
|
||||
std::vector<int> shape;
|
||||
auto reshape_fn = [&shape, device](const array& a) {
|
||||
return reshape(a, shape, device);
|
||||
auto reshape_fn = [&shape, device](const mx::array& a) {
|
||||
return mx::reshape(a, shape, device);
|
||||
};
|
||||
|
||||
int size = 64;
|
||||
int d = 2 * size;
|
||||
|
||||
auto a = random::uniform({d, d, d});
|
||||
auto a = mx::random::uniform({d, d, d});
|
||||
|
||||
shape = {8 * size, size, size};
|
||||
TIMEM("3D contiguous", reshape_fn, a);
|
||||
|
||||
a = transpose(a);
|
||||
a = mx::transpose(a);
|
||||
shape = {8 * size, size, size};
|
||||
TIMEM("3D transpose", reshape_fn, a);
|
||||
TIMEM("3D mx::transpose", reshape_fn, a);
|
||||
|
||||
a = transpose(a, {1, 2, 0});
|
||||
a = mx::transpose(a, {1, 2, 0});
|
||||
shape = {8 * size, size, size};
|
||||
TIMEM("3D transpose dims 1 2", reshape_fn, a);
|
||||
TIMEM("3D mx::transpose dims 1 2", reshape_fn, a);
|
||||
|
||||
a = broadcast_to(random::uniform({d, d}), {d, d, d});
|
||||
a = mx::broadcast_to(mx::random::uniform({d, d}), {d, d, d});
|
||||
TIMEM("3D broadcast dim 0", reshape_fn, a);
|
||||
|
||||
a = broadcast_to(random::uniform({d, 1, d}), {d, d, d});
|
||||
a = mx::broadcast_to(mx::random::uniform({d, 1, d}), {d, d, d});
|
||||
TIMEM("3D broadcast dim 1", reshape_fn, a);
|
||||
|
||||
a = broadcast_to(random::uniform({d, d, 1}), {d, d, d});
|
||||
a = mx::broadcast_to(mx::random::uniform({d, d, 1}), {d, d, d});
|
||||
TIMEM("3D broadcast dim 2", reshape_fn, a);
|
||||
|
||||
a = broadcast_to(random::uniform({d}), {d, d, d});
|
||||
a = mx::broadcast_to(mx::random::uniform({d}), {d, d, d});
|
||||
TIMEM("3D broadcast dims 0, 1", reshape_fn, a);
|
||||
|
||||
a = broadcast_to(random::uniform({d, 1}), {d, d, d});
|
||||
a = mx::broadcast_to(mx::random::uniform({d, 1}), {d, d, d});
|
||||
TIMEM("3D broadcast dims 0, 2", reshape_fn, a);
|
||||
|
||||
a = broadcast_to(random::uniform({d, 1, 1}), {d, d, d});
|
||||
a = mx::broadcast_to(mx::random::uniform({d, 1, 1}), {d, d, d});
|
||||
TIMEM("3D broadcast dims 1, 2", reshape_fn, a);
|
||||
|
||||
a = broadcast_to(random::uniform({1, 1, 1}), {d, d, d});
|
||||
a = mx::broadcast_to(mx::random::uniform({1, 1, 1}), {d, d, d});
|
||||
TIMEM("3D broadcast dims 1, 2, 3", reshape_fn, a);
|
||||
}
|
||||
|
||||
void time_irregular_astype_1D() {
|
||||
auto device = default_device();
|
||||
auto device = mx::default_device();
|
||||
int size = 1000000;
|
||||
int step = 2;
|
||||
auto a = random::uniform({size});
|
||||
auto a = mx::random::uniform({size});
|
||||
a = slice(a, {0}, {size}, {step});
|
||||
TIMEM("1D strided", astype, a, int32, device);
|
||||
TIMEM("1D strided", mx::astype, a, mx::int32, device);
|
||||
}
|
||||
|
||||
void time_irregular_astype_2D() {
|
||||
auto device = default_device();
|
||||
auto device = mx::default_device();
|
||||
int size = 2048;
|
||||
std::vector<int> shape = {size, size};
|
||||
|
||||
auto a = random::uniform(shape);
|
||||
TIMEM("2D regular", astype, a, int32, device);
|
||||
auto a = mx::random::uniform(shape);
|
||||
TIMEM("2D regular", mx::astype, a, mx::int32, device);
|
||||
|
||||
a = transpose(a);
|
||||
TIMEM("2D transpose", astype, a, int32, device);
|
||||
a = mx::transpose(a);
|
||||
TIMEM("2D mx::transpose", mx::astype, a, mx::int32, device);
|
||||
|
||||
a = broadcast_to(random::uniform({size}), shape);
|
||||
TIMEM("2D broadcast dim 0", astype, a, int32, device);
|
||||
a = mx::broadcast_to(mx::random::uniform({size}), shape);
|
||||
TIMEM("2D broadcast dim 0", mx::astype, a, mx::int32, device);
|
||||
|
||||
a = broadcast_to(random::uniform({size, 1}), shape);
|
||||
TIMEM("2D broadcast dim 1", astype, a, int32, device);
|
||||
a = mx::broadcast_to(mx::random::uniform({size, 1}), shape);
|
||||
TIMEM("2D broadcast dim 1", mx::astype, a, mx::int32, device);
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
if (argc > 1) {
|
||||
bool use_gpu = !strcmp(argv[1], "gpu");
|
||||
set_default_device(use_gpu ? Device::gpu : Device::cpu);
|
||||
set_default_device(use_gpu ? mx::Device::gpu : mx::Device::cpu);
|
||||
}
|
||||
std::cout << "Benchmarks for " << default_device() << std::endl;
|
||||
std::cout << "Benchmarks for " << mx::default_device() << std::endl;
|
||||
time_irregular_binary_ops_1D();
|
||||
time_irregular_binary_ops_2D();
|
||||
time_irregular_binary_ops_3D();
|
||||
|
@ -3,20 +3,20 @@
|
||||
#include "mlx/mlx.h"
|
||||
#include "time_utils.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
namespace mx = mlx::core;
|
||||
|
||||
void time_creation_ops() {
|
||||
int M = 2000;
|
||||
int N = 500;
|
||||
auto shape = {M, N};
|
||||
auto full_fp32 = [&]() { return full(shape, 3.3f); };
|
||||
auto full_fp32 = [&]() { return mx::full(shape, 3.3f); };
|
||||
TIME(full_fp32);
|
||||
auto zeros_fp32 = [&]() { return zeros(shape, float32); };
|
||||
auto zeros_fp32 = [&]() { return mx::zeros(shape, mx::float32); };
|
||||
TIME(zeros_fp32);
|
||||
auto ones_fp32 = [&]() { return ones(shape, float32); };
|
||||
auto ones_fp32 = [&]() { return mx::ones(shape, mx::float32); };
|
||||
TIME(ones_fp32);
|
||||
|
||||
auto arange_fp32 = [&]() { return arange(0.0, 10.0, 1e-4); };
|
||||
auto arange_fp32 = [&]() { return mx::arange(0.0, 10.0, 1e-4); };
|
||||
TIME(arange_fp32);
|
||||
}
|
||||
|
||||
@ -24,194 +24,196 @@ void time_type_conversions() {
|
||||
int M = 2000;
|
||||
int N = 500;
|
||||
auto shape = {M, N};
|
||||
auto device = default_device();
|
||||
auto device = mx::default_device();
|
||||
|
||||
auto a = zeros(shape, float32);
|
||||
eval(a);
|
||||
TIMEM("float32 to int32", astype, a, int32, device);
|
||||
TIMEM("float32 to uint32", astype, a, uint32, device);
|
||||
auto a = mx::zeros(shape, mx::float32);
|
||||
mx::eval(a);
|
||||
TIMEM("mx::float32 to mx::int32", mx::astype, a, mx::int32, device);
|
||||
TIMEM("mx::float32 to mx::uint32", mx::astype, a, mx::uint32, device);
|
||||
|
||||
a = zeros(shape, int32);
|
||||
eval(a);
|
||||
TIMEM("int32 to float32", astype, a, float32, device);
|
||||
a = mx::zeros(shape, mx::int32);
|
||||
mx::eval(a);
|
||||
TIMEM("mx::int32 to mx::float32", mx::astype, a, mx::float32, device);
|
||||
|
||||
a = zeros(shape, bool_);
|
||||
eval(a);
|
||||
TIMEM("bool to float32", astype, a, float32, device);
|
||||
TIMEM("bool to int32", astype, a, int32, device);
|
||||
TIMEM("bool to uint32", astype, a, uint32, device);
|
||||
a = mx::zeros(shape, mx::bool_);
|
||||
mx::eval(a);
|
||||
TIMEM("bool to mx::float32", mx::astype, a, mx::float32, device);
|
||||
TIMEM("bool to mx::int32", mx::astype, a, mx::int32, device);
|
||||
TIMEM("bool to mx::uint32", mx::astype, a, mx::uint32, device);
|
||||
}
|
||||
|
||||
void time_random_generation() {
|
||||
int M = 2000;
|
||||
int N = 500;
|
||||
|
||||
auto uniform = [&]() { return random::uniform({M, N}, float32); };
|
||||
auto uniform = [&]() { return mx::random::uniform({M, N}, mx::float32); };
|
||||
TIME(uniform);
|
||||
auto normal = [&]() { return random::normal({M, N}, float32); };
|
||||
auto normal = [&]() { return mx::random::normal({M, N}, mx::float32); };
|
||||
TIME(normal);
|
||||
}
|
||||
|
||||
void time_unary_ops() {
|
||||
int M = 2000;
|
||||
int N = 500;
|
||||
auto device = default_device();
|
||||
auto device = mx::default_device();
|
||||
|
||||
auto a = random::normal({M, N});
|
||||
eval(a);
|
||||
auto a = mx::random::normal({M, N});
|
||||
mx::eval(a);
|
||||
TIME(mlx::core::abs, a, device);
|
||||
TIME(negative, a, device);
|
||||
TIME(sign, a, device);
|
||||
TIME(square, a, device);
|
||||
TIME(mx::negative, a, device);
|
||||
TIME(mx::sign, a, device);
|
||||
TIME(mx::square, a, device);
|
||||
TIME(mlx::core::sqrt, a, device);
|
||||
TIME(rsqrt, a, device);
|
||||
TIME(mx::rsqrt, a, device);
|
||||
TIME(mlx::core::exp, a, device);
|
||||
|
||||
a = random::uniform({M, N});
|
||||
a = mx::random::uniform({M, N});
|
||||
TIME(mlx::core::log, a, device);
|
||||
}
|
||||
|
||||
void time_binary_ops() {
|
||||
int M = 1000, N = 100, K = 10;
|
||||
auto condition = random::randint(0, 2, {M, N, K});
|
||||
auto a = random::uniform({M, N, K});
|
||||
auto b = random::uniform({M, N, K});
|
||||
auto device = default_device();
|
||||
eval(a, b);
|
||||
auto condition = mx::random::randint(0, 2, {M, N, K});
|
||||
auto a = mx::random::uniform({M, N, K});
|
||||
auto b = mx::random::uniform({M, N, K});
|
||||
auto device = mx::default_device();
|
||||
mx::eval(a, b);
|
||||
|
||||
TIME(add, a, b, device);
|
||||
TIME(subtract, a, b, device);
|
||||
TIME(multiply, a, b, device);
|
||||
TIME(divide, a, b, device);
|
||||
TIME(maximum, a, b, device);
|
||||
TIME(minimum, a, b, device);
|
||||
TIME(where, condition, a, b, device);
|
||||
TIME(mx::add, a, b, device);
|
||||
TIME(mx::subtract, a, b, device);
|
||||
TIME(mx::multiply, a, b, device);
|
||||
TIME(mx::divide, a, b, device);
|
||||
TIME(mx::maximum, a, b, device);
|
||||
TIME(mx::minimum, a, b, device);
|
||||
TIME(mx::where, condition, a, b, device);
|
||||
|
||||
condition = array({true});
|
||||
b = random::uniform({1});
|
||||
eval(b);
|
||||
TIMEM("scalar", add, a, b, device);
|
||||
TIMEM("vector-scalar", subtract, a, b, device);
|
||||
TIMEM("scalar-vector", subtract, b, a, device);
|
||||
TIMEM("scalar", multiply, a, b, device);
|
||||
TIMEM("vector-scalar", divide, a, b, device);
|
||||
TIMEM("scalar-vector", divide, b, a, device);
|
||||
TIMEM("scalar-vector", where, condition, a, b, device);
|
||||
condition = mx::array({true});
|
||||
b = mx::random::uniform({1});
|
||||
mx::eval(b);
|
||||
TIMEM("scalar", mx::add, a, b, device);
|
||||
TIMEM("vector-scalar", mx::subtract, a, b, device);
|
||||
TIMEM("scalar-vector", mx::subtract, b, a, device);
|
||||
TIMEM("scalar", mx::multiply, a, b, device);
|
||||
TIMEM("vector-scalar", mx::divide, a, b, device);
|
||||
TIMEM("scalar-vector", mx::divide, b, a, device);
|
||||
TIMEM("scalar-vector", mx::where, condition, a, b, device);
|
||||
|
||||
condition = broadcast_to(array({true}), {1000, 100});
|
||||
a = broadcast_to(random::uniform({1}), {1000, 100});
|
||||
b = broadcast_to(random::uniform({1}), {1000, 100});
|
||||
eval(a, b);
|
||||
TIMEM("scalar-scalar broadcast", add, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", subtract, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", multiply, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", divide, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", where, condition, a, b, device);
|
||||
condition = mx::broadcast_to(mx::array({true}), {1000, 100});
|
||||
a = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});
|
||||
b = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});
|
||||
mx::eval(a, b);
|
||||
TIMEM("scalar-scalar broadcast", mx::add, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", mx::subtract, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", mx::multiply, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", mx::divide, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", mx::where, condition, a, b, device);
|
||||
}
|
||||
|
||||
void time_strided_ops() {
|
||||
int M = 50, N = 50, O = 50, P = 50;
|
||||
auto a = random::uniform({M, N, O, P});
|
||||
auto b = random::uniform({M, N, O, P});
|
||||
auto device = default_device();
|
||||
eval(a, b);
|
||||
TIMEM("non-strided", add, a, b, device);
|
||||
a = transpose(a, {1, 0, 2, 3});
|
||||
b = transpose(b, {3, 2, 0, 1});
|
||||
eval(a, b);
|
||||
TIMEM("strided", add, a, b, device);
|
||||
auto a = mx::random::uniform({M, N, O, P});
|
||||
auto b = mx::random::uniform({M, N, O, P});
|
||||
auto device = mx::default_device();
|
||||
mx::eval(a, b);
|
||||
TIMEM("non-strided", mx::add, a, b, device);
|
||||
a = mx::transpose(a, {1, 0, 2, 3});
|
||||
b = mx::transpose(b, {3, 2, 0, 1});
|
||||
mx::eval(a, b);
|
||||
TIMEM("strided", mx::add, a, b, device);
|
||||
}
|
||||
|
||||
void time_comparisons() {
|
||||
int M = 1000, N = 100, K = 10;
|
||||
auto a = random::uniform({M, N, K});
|
||||
auto b = random::uniform({M, N, K});
|
||||
auto device = default_device();
|
||||
eval(a, b);
|
||||
TIME(equal, a, b, device);
|
||||
TIME(greater, a, b, device);
|
||||
TIME(greater_equal, a, b, device);
|
||||
TIME(less, a, b, device);
|
||||
TIME(less_equal, a, b, device);
|
||||
auto a = mx::random::uniform({M, N, K});
|
||||
auto b = mx::random::uniform({M, N, K});
|
||||
auto device = mx::default_device();
|
||||
mx::eval(a, b);
|
||||
TIME(mx::equal, a, b, device);
|
||||
TIME(mx::greater, a, b, device);
|
||||
TIME(mx::greater_equal, a, b, device);
|
||||
TIME(mx::less, a, b, device);
|
||||
TIME(mx::less_equal, a, b, device);
|
||||
}
|
||||
|
||||
void time_matvec() {
|
||||
int M = 2000, N = 200;
|
||||
auto a = random::uniform({M, N});
|
||||
auto b = random::uniform({N});
|
||||
auto c = random::uniform({M});
|
||||
eval(a, b, c);
|
||||
auto matvec = [&]() { return matmul(a, b); };
|
||||
auto a = mx::random::uniform({M, N});
|
||||
auto b = mx::random::uniform({N});
|
||||
auto c = mx::random::uniform({M});
|
||||
mx::eval(a, b, c);
|
||||
auto matvec = [&]() { return mx::matmul(a, b); };
|
||||
TIME(matvec);
|
||||
|
||||
auto matvec_transpose = [&]() { return matmul(transpose(a), c); };
|
||||
auto matvec_transpose = [&]() { return mx::matmul(mx::transpose(a), c); };
|
||||
TIME(matvec_transpose);
|
||||
}
|
||||
|
||||
void time_matmul() {
|
||||
int M = 1000, N = 1000, K = 1000;
|
||||
auto a = random::uniform({M, K});
|
||||
auto b = random::uniform({K, N});
|
||||
auto device = default_device();
|
||||
eval(a, b);
|
||||
TIME(matmul, a, b, device);
|
||||
auto a = mx::random::uniform({M, K});
|
||||
auto b = mx::random::uniform({K, N});
|
||||
auto device = mx::default_device();
|
||||
mx::eval(a, b);
|
||||
TIME(mx::matmul, a, b, device);
|
||||
|
||||
auto transpose_matmul = [&]() { return matmul(transpose(a), b); };
|
||||
auto transpose_matmul = [&]() { return mx::matmul(mx::transpose(a), b); };
|
||||
TIME(transpose_matmul);
|
||||
}
|
||||
|
||||
void time_reductions() {
|
||||
auto a = random::normal({10000, 1000});
|
||||
eval(a);
|
||||
auto sum_all = [&a]() { return sum(a, false); };
|
||||
auto a = mx::random::normal({10000, 1000});
|
||||
mx::eval(a);
|
||||
auto sum_all = [&a]() { return mx::sum(a, false); };
|
||||
TIME(sum_all);
|
||||
|
||||
auto sum_along_0 = [&a]() { return sum(a, 0, false); };
|
||||
auto sum_along_0 = [&a]() { return mx::sum(a, 0, false); };
|
||||
TIME(sum_along_0);
|
||||
|
||||
auto sum_along_1 = [&a]() { return sum(a, 1, false); };
|
||||
auto sum_along_1 = [&a]() { return mx::sum(a, 1, false); };
|
||||
TIME(sum_along_1);
|
||||
|
||||
auto prod_all = [&a]() { return prod(a, false); };
|
||||
auto prod_all = [&a]() { return mx::prod(a, false); };
|
||||
TIME(prod_all);
|
||||
|
||||
auto all_true = [&a]() { return all(a, false); };
|
||||
auto all_true = [&a]() { return mx::all(a, false); };
|
||||
TIME(all_true);
|
||||
|
||||
auto all_along_0 = [&a]() { return all(a, 0, false); };
|
||||
auto all_along_0 = [&a]() { return mx::all(a, 0, false); };
|
||||
TIME(all_along_0);
|
||||
|
||||
auto all_along_1 = [&a]() { return all(a, 1, false); };
|
||||
auto all_along_1 = [&a]() { return mx::all(a, 1, false); };
|
||||
TIME(all_along_1);
|
||||
|
||||
auto any_true = [&a]() { return any(a, false); };
|
||||
auto any_true = [&a]() { return mx::any(a, false); };
|
||||
TIME(any_true);
|
||||
|
||||
auto argmin_along_0 = [&a]() { return argmin(a, 0, false); };
|
||||
auto argmin_along_0 = [&a]() { return mx::argmin(a, 0, false); };
|
||||
TIME(argmin_along_0);
|
||||
|
||||
auto argmin_along_1 = [&a]() { return argmin(a, 1, false); };
|
||||
auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
|
||||
TIME(argmin_along_1);
|
||||
}
|
||||
|
||||
void time_gather_scatter() {
|
||||
auto a = random::normal({1000, 768});
|
||||
eval(a);
|
||||
auto indices = random::randint(0, 1000, {256});
|
||||
eval(indices);
|
||||
auto a = mx::random::normal({1000, 768});
|
||||
mx::eval(a);
|
||||
auto indices = mx::random::randint(0, 1000, {256});
|
||||
mx::eval(indices);
|
||||
|
||||
auto embedding_lookup = [&a, &indices]() { return take(a, indices, 0); };
|
||||
auto embedding_lookup = [&a, &indices]() { return mx::take(a, indices, 0); };
|
||||
TIME(embedding_lookup);
|
||||
|
||||
indices = random::randint(0, 768 * 1000, {256 * 768});
|
||||
eval(indices);
|
||||
indices = mx::random::randint(0, 768 * 1000, {256 * 768});
|
||||
mx::eval(indices);
|
||||
|
||||
auto single_element_lookup = [&a, &indices]() { return take(a, indices); };
|
||||
auto single_element_lookup = [&a, &indices]() {
|
||||
return mx::take(a, indices);
|
||||
};
|
||||
TIME(single_element_lookup);
|
||||
|
||||
indices = random::randint(0, 1000, {256});
|
||||
auto updates = random::normal({256, 1, 768});
|
||||
eval(indices, updates);
|
||||
indices = mx::random::randint(0, 1000, {256});
|
||||
auto updates = mx::random::normal({256, 1, 768});
|
||||
mx::eval(indices, updates);
|
||||
|
||||
auto embedding_update = [&a, &indices, &updates]() {
|
||||
return scatter(a, indices, updates, 0);
|
||||
@ -223,10 +225,10 @@ void time_gather_scatter() {
|
||||
};
|
||||
TIME(embedding_add);
|
||||
|
||||
a = reshape(a, {-1});
|
||||
indices = random::randint(0, 768 * 1000, {768 * 256});
|
||||
updates = random::normal({256 * 768, 1});
|
||||
eval(a, indices, updates);
|
||||
a = mx::reshape(a, {-1});
|
||||
indices = mx::random::randint(0, 768 * 1000, {768 * 256});
|
||||
updates = mx::random::normal({256 * 768, 1});
|
||||
mx::eval(a, indices, updates);
|
||||
|
||||
auto single_element_update = [&a, &indices, &updates]() {
|
||||
return scatter(a, indices, updates, 0);
|
||||
@ -240,21 +242,21 @@ void time_gather_scatter() {
|
||||
}
|
||||
|
||||
void time_divmod() {
|
||||
auto a = random::normal({1000});
|
||||
auto b = random::normal({1000});
|
||||
eval({a, b});
|
||||
auto a = mx::random::normal({1000});
|
||||
auto b = mx::random::normal({1000});
|
||||
mx::eval({a, b});
|
||||
|
||||
auto divmod_fused = [&a, &b]() { return divmod(a, b); };
|
||||
auto divmod_fused = [&a, &b]() { return mx::divmod(a, b); };
|
||||
TIME(divmod_fused);
|
||||
|
||||
auto divmod_separate = [&a, &b]() {
|
||||
return std::vector<array>{floor_divide(a, b), remainder(a, b)};
|
||||
return std::vector<mx::array>{mx::floor_divide(a, b), mx::remainder(a, b)};
|
||||
};
|
||||
TIME(divmod_separate);
|
||||
}
|
||||
|
||||
int main() {
|
||||
std::cout << "Benchmarks for " << default_device() << std::endl;
|
||||
std::cout << "Benchmarks for " << mx::default_device() << std::endl;
|
||||
time_creation_ops();
|
||||
time_type_conversions();
|
||||
time_unary_ops();
|
||||
|
@ -4,19 +4,19 @@
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
namespace mx = mlx::core;
|
||||
|
||||
int main() {
|
||||
if (!distributed::is_available()) {
|
||||
if (!mx::distributed::is_available()) {
|
||||
std::cout << "No communication backend found" << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto global_group = distributed::init();
|
||||
auto global_group = mx::distributed::init();
|
||||
std::cout << global_group.rank() << " / " << global_group.size() << std::endl;
|
||||
|
||||
array x = ones({10});
|
||||
array out = distributed::all_sum(x, global_group);
|
||||
mx::array x = mx::ones({10});
|
||||
mx::array out = mx::distributed::all_sum(x, global_group);
|
||||
|
||||
std::cout << out << std::endl;
|
||||
}
|
||||
|
@ -10,7 +10,7 @@
|
||||
/**
|
||||
* An example of linear regression with MLX.
|
||||
*/
|
||||
using namespace mlx::core;
|
||||
namespace mx = mlx::core;
|
||||
|
||||
int main() {
|
||||
int num_features = 100;
|
||||
@ -19,35 +19,35 @@ int main() {
|
||||
float learning_rate = 0.01;
|
||||
|
||||
// True parameters
|
||||
auto w_star = random::normal({num_features});
|
||||
auto w_star = mx::random::normal({num_features});
|
||||
|
||||
// The input examples (design matrix)
|
||||
auto X = random::normal({num_examples, num_features});
|
||||
auto X = mx::random::normal({num_examples, num_features});
|
||||
|
||||
// Noisy labels
|
||||
auto eps = 1e-2 * random::normal({num_examples});
|
||||
auto y = matmul(X, w_star) + eps;
|
||||
auto eps = 1e-2 * mx::random::normal({num_examples});
|
||||
auto y = mx::matmul(X, w_star) + eps;
|
||||
|
||||
// Initialize random parameters
|
||||
array w = 1e-2 * random::normal({num_features});
|
||||
mx::array w = 1e-2 * mx::random::normal({num_features});
|
||||
|
||||
auto loss_fn = [&](array w) {
|
||||
auto yhat = matmul(X, w);
|
||||
return (0.5f / num_examples) * sum(square(yhat - y));
|
||||
auto loss_fn = [&](mx::array w) {
|
||||
auto yhat = mx::matmul(X, w);
|
||||
return (0.5f / num_examples) * mx::sum(mx::square(yhat - y));
|
||||
};
|
||||
|
||||
auto grad_fn = grad(loss_fn);
|
||||
auto grad_fn = mx::grad(loss_fn);
|
||||
|
||||
auto tic = timer::time();
|
||||
for (int it = 0; it < num_iters; ++it) {
|
||||
auto grad = grad_fn(w);
|
||||
w = w - learning_rate * grad;
|
||||
eval(w);
|
||||
auto grads = grad_fn(w);
|
||||
w = w - learning_rate * grads;
|
||||
mx::eval(w);
|
||||
}
|
||||
auto toc = timer::time();
|
||||
|
||||
auto loss = loss_fn(w);
|
||||
auto error_norm = std::sqrt(sum(square(w - w_star)).item<float>());
|
||||
auto error_norm = std::sqrt(mx::sum(mx::square(w - w_star)).item<float>());
|
||||
auto throughput = num_iters / timer::seconds(toc - tic);
|
||||
std::cout << "Loss " << loss << ", |w - w*| = " << error_norm
|
||||
<< ", Throughput " << throughput << " (it/s)." << std::endl;
|
||||
|
@ -10,7 +10,7 @@
|
||||
/**
|
||||
* An example of logistic regression with MLX.
|
||||
*/
|
||||
using namespace mlx::core;
|
||||
namespace mx = mlx::core;
|
||||
|
||||
int main() {
|
||||
int num_features = 100;
|
||||
@ -19,35 +19,35 @@ int main() {
|
||||
float learning_rate = 0.1;
|
||||
|
||||
// True parameters
|
||||
auto w_star = random::normal({num_features});
|
||||
auto w_star = mx::random::normal({num_features});
|
||||
|
||||
// The input examples
|
||||
auto X = random::normal({num_examples, num_features});
|
||||
auto X = mx::random::normal({num_examples, num_features});
|
||||
|
||||
// Labels
|
||||
auto y = matmul(X, w_star) > 0;
|
||||
auto y = mx::matmul(X, w_star) > 0;
|
||||
|
||||
// Initialize random parameters
|
||||
array w = 1e-2 * random::normal({num_features});
|
||||
mx::array w = 1e-2 * mx::random::normal({num_features});
|
||||
|
||||
auto loss_fn = [&](array w) {
|
||||
auto logits = matmul(X, w);
|
||||
auto loss_fn = [&](mx::array w) {
|
||||
auto logits = mx::matmul(X, w);
|
||||
auto scale = (1.0f / num_examples);
|
||||
return scale * sum(logaddexp(array(0.0f), logits) - y * logits);
|
||||
return scale * mx::sum(mx::logaddexp(mx::array(0.0f), logits) - y * logits);
|
||||
};
|
||||
|
||||
auto grad_fn = grad(loss_fn);
|
||||
auto grad_fn = mx::grad(loss_fn);
|
||||
|
||||
auto tic = timer::time();
|
||||
for (int it = 0; it < num_iters; ++it) {
|
||||
auto grad = grad_fn(w);
|
||||
w = w - learning_rate * grad;
|
||||
eval(w);
|
||||
auto grads = grad_fn(w);
|
||||
w = w - learning_rate * grads;
|
||||
mx::eval(w);
|
||||
}
|
||||
auto toc = timer::time();
|
||||
|
||||
auto loss = loss_fn(w);
|
||||
auto acc = sum((matmul(X, w) > 0) == y) / num_examples;
|
||||
auto acc = mx::sum((mx::matmul(X, w) > 0) == y) / num_examples;
|
||||
auto throughput = num_iters / timer::seconds(toc - tic);
|
||||
std::cout << "Loss " << loss << ", Accuracy, " << acc << ", Throughput "
|
||||
<< throughput << " (it/s)." << std::endl;
|
||||
|
@ -5,27 +5,27 @@
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
namespace mx = mlx::core;
|
||||
|
||||
int main() {
|
||||
// To use Metal debugging and profiling:
|
||||
// 1. Build with the MLX_METAL_DEBUG CMake option (i.e. -DMLX_METAL_DEBUG=ON).
|
||||
// 2. Run with MTL_CAPTURE_ENABLED=1.
|
||||
metal::start_capture("mlx_trace.gputrace");
|
||||
mx::metal::start_capture("mlx_trace.gputrace");
|
||||
|
||||
// Start at index two because the default GPU and CPU streams have indices
|
||||
// zero and one, respectively. This naming matches the label assigned to each
|
||||
// stream's command queue.
|
||||
auto s2 = new_stream(Device::gpu);
|
||||
auto s3 = new_stream(Device::gpu);
|
||||
auto s2 = new_stream(mx::Device::gpu);
|
||||
auto s3 = new_stream(mx::Device::gpu);
|
||||
|
||||
auto a = arange(1.f, 10.f, 1.f, float32, s2);
|
||||
auto b = arange(1.f, 10.f, 1.f, float32, s3);
|
||||
auto x = add(a, a, s2);
|
||||
auto y = add(b, b, s3);
|
||||
auto a = mx::arange(1.f, 10.f, 1.f, mx::float32, s2);
|
||||
auto b = mx::arange(1.f, 10.f, 1.f, mx::float32, s3);
|
||||
auto x = mx::add(a, a, s2);
|
||||
auto y = mx::add(b, b, s3);
|
||||
|
||||
// The multiply will happen on the default stream.
|
||||
std::cout << multiply(x, y) << std::endl;
|
||||
std::cout << mx::multiply(x, y) << std::endl;
|
||||
|
||||
metal::stop_capture();
|
||||
mx::metal::stop_capture();
|
||||
}
|
||||
|
@ -5,11 +5,11 @@
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
namespace mx = mlx::core;
|
||||
|
||||
void array_basics() {
|
||||
// Make a scalar array:
|
||||
array x(1.0);
|
||||
mx::array x(1.0);
|
||||
|
||||
// Get the value out of it:
|
||||
auto s = x.item<float>();
|
||||
@ -29,31 +29,31 @@ void array_basics() {
|
||||
|
||||
// The datatype should be float32:
|
||||
auto dtype = x.dtype();
|
||||
assert(dtype == float32);
|
||||
assert(dtype == mx::float32);
|
||||
|
||||
// Specify the dtype when constructing the array:
|
||||
x = array(1, int32);
|
||||
assert(x.dtype() == int32);
|
||||
x = mx::array(1, mx::int32);
|
||||
assert(x.dtype() == mx::int32);
|
||||
x.item<int>(); // OK
|
||||
// x.item<float>(); // Undefined!
|
||||
|
||||
// Make a multidimensional array:
|
||||
x = array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
|
||||
x = mx::array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
|
||||
// mlx is row-major by default so the first row of this array
|
||||
// is [1.0, 2.0] and the second row is [3.0, 4.0]
|
||||
|
||||
// Make an array of shape {2, 2} filled with ones:
|
||||
auto y = ones({2, 2});
|
||||
auto y = mx::ones({2, 2});
|
||||
|
||||
// Pointwise add x and y:
|
||||
auto z = add(x, y);
|
||||
auto z = mx::add(x, y);
|
||||
|
||||
// Same thing:
|
||||
z = x + y;
|
||||
|
||||
// mlx is lazy by default. At this point `z` only
|
||||
// has a shape and a type but no actual data:
|
||||
assert(z.dtype() == float32);
|
||||
assert(z.dtype() == mx::float32);
|
||||
assert(z.shape(0) == 2);
|
||||
assert(z.shape(1) == 2);
|
||||
|
||||
@ -63,33 +63,33 @@ void array_basics() {
|
||||
// and inputs. When `eval` is called on an array (or arrays), the array and
|
||||
// all of its dependencies are recursively evaluated to produce the result.
|
||||
// Once an array is evaluated, it has data and is detached from its inputs.
|
||||
eval(z);
|
||||
mx::eval(z);
|
||||
|
||||
// Of course the array can still be an input to other operations. You can even
|
||||
// call eval on the array again, this will just be a no-op:
|
||||
eval(z); // no-op
|
||||
// Of course the array can still be an input to other operations. You can
|
||||
// even call eval on the array again, this will just be a no-op:
|
||||
mx::eval(z); // no-op
|
||||
|
||||
// Some functions or methods on arrays implicitly evaluate them. For example
|
||||
// accessing a value in an array or printing the array implicitly evaluate it:
|
||||
z = ones({1});
|
||||
z = mx::ones({1});
|
||||
z.item<float>(); // implicit evaluation
|
||||
|
||||
z = ones({2, 2});
|
||||
z = mx::ones({2, 2});
|
||||
std::cout << z << std::endl; // implicit evaluation
|
||||
}
|
||||
|
||||
void automatic_differentiation() {
|
||||
auto fn = [](array x) { return square(x); };
|
||||
auto fn = [](mx::array x) { return mx::square(x); };
|
||||
|
||||
// Computing the derivative function of a function
|
||||
auto grad_fn = grad(fn);
|
||||
auto grad_fn = mx::grad(fn);
|
||||
// Call grad_fn on the input to get the derivative
|
||||
auto x = array(1.5);
|
||||
auto x = mx::array(1.5);
|
||||
auto dfdx = grad_fn(x);
|
||||
// dfdx is 2 * x
|
||||
|
||||
// Get the second derivative by composing grad with grad
|
||||
auto d2fdx2 = grad(grad(fn))(x);
|
||||
auto d2fdx2 = mx::grad(mx::grad(fn))(x);
|
||||
// d2fdx2 is 2
|
||||
}
|
||||
|
||||
|
@ -19,7 +19,7 @@
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#endif
|
||||
|
||||
namespace mlx::core {
|
||||
namespace my_ext {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Operation Implementation
|
||||
@ -32,24 +32,24 @@ namespace mlx::core {
|
||||
* Follow numpy style broadcasting between x and y
|
||||
* Inputs are upcasted to floats if needed
|
||||
**/
|
||||
array axpby(
|
||||
const array& x, // Input array x
|
||||
const array& y, // Input array y
|
||||
mx::array axpby(
|
||||
const mx::array& x, // Input mx::array x
|
||||
const mx::array& y, // Input mx::array y
|
||||
const float alpha, // Scaling factor for x
|
||||
const float beta, // Scaling factor for y
|
||||
StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
|
||||
mx::StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
|
||||
) {
|
||||
// Promote dtypes between x and y as needed
|
||||
auto promoted_dtype = promote_types(x.dtype(), y.dtype());
|
||||
|
||||
// Upcast to float32 for non-floating point inputs x and y
|
||||
auto out_dtype = issubdtype(promoted_dtype, float32)
|
||||
auto out_dtype = mx::issubdtype(promoted_dtype, mx::float32)
|
||||
? promoted_dtype
|
||||
: promote_types(promoted_dtype, float32);
|
||||
: promote_types(promoted_dtype, mx::float32);
|
||||
|
||||
// Cast x and y up to the determined dtype (on the same stream s)
|
||||
auto x_casted = astype(x, out_dtype, s);
|
||||
auto y_casted = astype(y, out_dtype, s);
|
||||
auto x_casted = mx::astype(x, out_dtype, s);
|
||||
auto y_casted = mx::astype(y, out_dtype, s);
|
||||
|
||||
// Broadcast the shapes of x and y (on the same stream s)
|
||||
auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s);
|
||||
@ -57,12 +57,12 @@ array axpby(
|
||||
|
||||
// Construct the array as the output of the Axpby primitive
|
||||
// with the broadcasted and upcasted arrays as inputs
|
||||
return array(
|
||||
return mx::array(
|
||||
/* const std::vector<int>& shape = */ out_shape,
|
||||
/* Dtype dtype = */ out_dtype,
|
||||
/* std::unique_ptr<Primitive> primitive = */
|
||||
/* mx::Dtype dtype = */ out_dtype,
|
||||
/* std::unique_ptr<mx::Primitive> primitive = */
|
||||
std::make_shared<Axpby>(to_stream(s), alpha, beta),
|
||||
/* const std::vector<array>& inputs = */ broadcasted_inputs);
|
||||
/* const std::vector<mx::array>& inputs = */ broadcasted_inputs);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@ -71,16 +71,16 @@ array axpby(
|
||||
|
||||
template <typename T>
|
||||
void axpby_impl(
|
||||
const array& x,
|
||||
const array& y,
|
||||
array& out,
|
||||
const mx::array& x,
|
||||
const mx::array& y,
|
||||
mx::array& out,
|
||||
float alpha_,
|
||||
float beta_) {
|
||||
// We only allocate memory when we are ready to fill the output
|
||||
// malloc_or_wait synchronously allocates available memory
|
||||
// There may be a wait executed here if the allocation is requested
|
||||
// under memory-pressured conditions
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
// Collect input and output data pointers
|
||||
const T* x_ptr = x.data<T>();
|
||||
@ -94,8 +94,8 @@ void axpby_impl(
|
||||
// Do the element-wise operation for each output
|
||||
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
|
||||
// Map linear indices to offsets in x and y
|
||||
auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
|
||||
auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
|
||||
auto x_offset = mx::elem_to_loc(out_idx, x.shape(), x.strides());
|
||||
auto y_offset = mx::elem_to_loc(out_idx, y.shape(), y.strides());
|
||||
|
||||
// We allocate the output to be contiguous and regularly strided
|
||||
// (defaults to row major) and hence it doesn't need additional mapping
|
||||
@ -105,8 +105,8 @@ void axpby_impl(
|
||||
|
||||
/** Fall back implementation for evaluation on CPU */
|
||||
void Axpby::eval(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
const std::vector<mx::array>& inputs,
|
||||
std::vector<mx::array>& outputs) {
|
||||
// Check the inputs (registered in the op while constructing the out array)
|
||||
assert(inputs.size() == 2);
|
||||
auto& x = inputs[0];
|
||||
@ -114,14 +114,14 @@ void Axpby::eval(
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Dispatch to the correct dtype
|
||||
if (out.dtype() == float32) {
|
||||
if (out.dtype() == mx::float32) {
|
||||
return axpby_impl<float>(x, y, out, alpha_, beta_);
|
||||
} else if (out.dtype() == float16) {
|
||||
return axpby_impl<float16_t>(x, y, out, alpha_, beta_);
|
||||
} else if (out.dtype() == bfloat16) {
|
||||
return axpby_impl<bfloat16_t>(x, y, out, alpha_, beta_);
|
||||
} else if (out.dtype() == complex64) {
|
||||
return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
|
||||
} else if (out.dtype() == mx::float16) {
|
||||
return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_);
|
||||
} else if (out.dtype() == mx::bfloat16) {
|
||||
return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_);
|
||||
} else if (out.dtype() == mx::complex64) {
|
||||
return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_);
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"Axpby is only supported for floating point types.");
|
||||
@ -136,9 +136,9 @@ void Axpby::eval(
|
||||
|
||||
template <typename T>
|
||||
void axpby_impl_accelerate(
|
||||
const array& x,
|
||||
const array& y,
|
||||
array& out,
|
||||
const mx::array& x,
|
||||
const mx::array& y,
|
||||
mx::array& out,
|
||||
float alpha_,
|
||||
float beta_) {
|
||||
// Accelerate library provides catlas_saxpby which does
|
||||
@ -150,10 +150,10 @@ void axpby_impl_accelerate(
|
||||
// The data in the output array is allocated to match the strides in y
|
||||
// such that x, y, and out are contiguous in the same mode and
|
||||
// no transposition is needed
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
// We then copy over the elements using the contiguous vector specialization
|
||||
copy_inplace(y, out, CopyType::Vector);
|
||||
copy_inplace(y, out, mx::CopyType::Vector);
|
||||
|
||||
// Get x and y pointers for catlas_saxpby
|
||||
const T* x_ptr = x.data<T>();
|
||||
@ -175,15 +175,15 @@ void axpby_impl_accelerate(
|
||||
|
||||
/** Evaluate primitive on CPU using accelerate specializations */
|
||||
void Axpby::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
const std::vector<mx::array>& inputs,
|
||||
std::vector<mx::array>& outputs) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& x = inputs[0];
|
||||
auto& y = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Accelerate specialization for contiguous single precision float arrays
|
||||
if (out.dtype() == float32 &&
|
||||
if (out.dtype() == mx::float32 &&
|
||||
((x.flags().row_contiguous && y.flags().row_contiguous) ||
|
||||
(x.flags().col_contiguous && y.flags().col_contiguous))) {
|
||||
axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
|
||||
@ -198,8 +198,8 @@ void Axpby::eval_cpu(
|
||||
|
||||
/** Evaluate primitive on CPU falling back to common backend */
|
||||
void Axpby::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs) {
|
||||
const std::vector<mx::array>& inputs,
|
||||
std::vector<mx::array>& outputs) {
|
||||
eval(inputs, outputs);
|
||||
}
|
||||
|
||||
@ -213,8 +213,8 @@ void Axpby::eval_cpu(
|
||||
|
||||
/** Evaluate primitive on GPU */
|
||||
void Axpby::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
const std::vector<mx::array>& inputs,
|
||||
std::vector<mx::array>& outputs) {
|
||||
// Prepare inputs
|
||||
assert(inputs.size() == 2);
|
||||
auto& x = inputs[0];
|
||||
@ -225,7 +225,7 @@ void Axpby::eval_gpu(
|
||||
// and each stream carries its device identifiers
|
||||
auto& s = stream();
|
||||
// We get the needed metal device using the stream
|
||||
auto& d = metal::device(s.device);
|
||||
auto& d = mx::metal::device(s.device);
|
||||
|
||||
// Prepare to specialize based on contiguity
|
||||
bool contiguous_kernel =
|
||||
@ -235,12 +235,12 @@ void Axpby::eval_gpu(
|
||||
// Allocate output memory with strides based on specialization
|
||||
if (contiguous_kernel) {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(x.data_size() * out.itemsize()),
|
||||
mx::allocator::malloc_or_wait(x.data_size() * out.itemsize()),
|
||||
x.data_size(),
|
||||
x.strides(),
|
||||
x.flags());
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
|
||||
// Resolve name of kernel (corresponds to axpby.metal)
|
||||
@ -302,8 +302,8 @@ void Axpby::eval_gpu(
|
||||
|
||||
/** Fail evaluation on GPU */
|
||||
void Axpby::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& out) {
|
||||
const std::vector<mx::array>& inputs,
|
||||
std::vector<mx::array>& out) {
|
||||
throw std::runtime_error("Axpby has no GPU implementation.");
|
||||
}
|
||||
|
||||
@ -314,9 +314,9 @@ void Axpby::eval_gpu(
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/** The Jacobian-vector product. */
|
||||
std::vector<array> Axpby::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
std::vector<mx::array> Axpby::jvp(
|
||||
const std::vector<mx::array>& primals,
|
||||
const std::vector<mx::array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
// Forward mode diff that pushes along the tangents
|
||||
// The jvp transform on the primitive can built with ops
|
||||
@ -328,8 +328,8 @@ std::vector<array> Axpby::jvp(
|
||||
// scaled by beta
|
||||
if (argnums.size() > 1) {
|
||||
auto scale = argnums[0] == 0 ? alpha_ : beta_;
|
||||
auto scale_arr = array(scale, tangents[0].dtype());
|
||||
return {multiply(scale_arr, tangents[0], stream())};
|
||||
auto scale_arr = mx::array(scale, tangents[0].dtype());
|
||||
return {mx::multiply(scale_arr, tangents[0], stream())};
|
||||
}
|
||||
// If, argnums = {0, 1}, we take contributions from both
|
||||
// which gives us jvp = tangent_x * alpha + tangent_y * beta
|
||||
@ -339,24 +339,24 @@ std::vector<array> Axpby::jvp(
|
||||
}
|
||||
|
||||
/** The vector-Jacobian product. */
|
||||
std::vector<array> Axpby::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
std::vector<mx::array> Axpby::vjp(
|
||||
const std::vector<mx::array>& primals,
|
||||
const std::vector<mx::array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
const std::vector<mx::array>&) {
|
||||
// Reverse mode diff
|
||||
std::vector<array> vjps;
|
||||
std::vector<mx::array> vjps;
|
||||
for (auto arg : argnums) {
|
||||
auto scale = arg == 0 ? alpha_ : beta_;
|
||||
auto scale_arr = array(scale, cotangents[0].dtype());
|
||||
vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
|
||||
auto scale_arr = mx::array(scale, cotangents[0].dtype());
|
||||
vjps.push_back(mx::multiply(scale_arr, cotangents[0], stream()));
|
||||
}
|
||||
return vjps;
|
||||
}
|
||||
|
||||
/** Vectorize primitive along given axis */
|
||||
std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
std::pair<std::vector<mx::array>, std::vector<int>> Axpby::vmap(
|
||||
const std::vector<mx::array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
throw std::runtime_error("Axpby has no vmap implementation.");
|
||||
}
|
||||
@ -367,4 +367,4 @@ bool Axpby::is_equivalent(const Primitive& other) const {
|
||||
return alpha_ == r_other.alpha_ && beta_ == r_other.beta_;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
} // namespace my_ext
|
||||
|
@ -5,7 +5,9 @@
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
namespace mx = mlx::core;
|
||||
|
||||
namespace my_ext {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Operation
|
||||
@ -18,22 +20,22 @@ namespace mlx::core {
|
||||
* Follow numpy style broadcasting between x and y
|
||||
* Inputs are upcasted to floats if needed
|
||||
**/
|
||||
array axpby(
|
||||
const array& x, // Input array x
|
||||
const array& y, // Input array y
|
||||
mx::array axpby(
|
||||
const mx::array& x, // Input array x
|
||||
const mx::array& y, // Input array y
|
||||
const float alpha, // Scaling factor for x
|
||||
const float beta, // Scaling factor for y
|
||||
StreamOrDevice s = {} // Stream on which to schedule the operation
|
||||
mx::StreamOrDevice s = {} // Stream on which to schedule the operation
|
||||
);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Primitive
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
class Axpby : public Primitive {
|
||||
class Axpby : public mx::Primitive {
|
||||
public:
|
||||
explicit Axpby(Stream stream, float alpha, float beta)
|
||||
: Primitive(stream), alpha_(alpha), beta_(beta) {};
|
||||
explicit Axpby(mx::Stream stream, float alpha, float beta)
|
||||
: mx::Primitive(stream), alpha_(alpha), beta_(beta) {};
|
||||
|
||||
/**
|
||||
* A primitive must know how to evaluate itself on the CPU/GPU
|
||||
@ -42,23 +44,25 @@ class Axpby : public Primitive {
|
||||
* To avoid unnecessary allocations, the evaluation function
|
||||
* is responsible for allocating space for the array.
|
||||
*/
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
void eval_cpu(
|
||||
const std::vector<mx::array>& inputs,
|
||||
std::vector<mx::array>& outputs) override;
|
||||
void eval_gpu(
|
||||
const std::vector<mx::array>& inputs,
|
||||
std::vector<mx::array>& outputs) override;
|
||||
|
||||
/** The Jacobian-vector product. */
|
||||
std::vector<array> jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
std::vector<mx::array> jvp(
|
||||
const std::vector<mx::array>& primals,
|
||||
const std::vector<mx::array>& tangents,
|
||||
const std::vector<int>& argnums) override;
|
||||
|
||||
/** The vector-Jacobian product. */
|
||||
std::vector<array> vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
std::vector<mx::array> vjp(
|
||||
const std::vector<mx::array>& primals,
|
||||
const std::vector<mx::array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) override;
|
||||
const std::vector<mx::array>& outputs) override;
|
||||
|
||||
/**
|
||||
* The primitive must know how to vectorize itself across
|
||||
@ -66,8 +70,8 @@ class Axpby : public Primitive {
|
||||
* representing the vectorized computation and the axis which
|
||||
* corresponds to the output vectorized dimension.
|
||||
*/
|
||||
std::pair<std::vector<array>, std::vector<int>> vmap(
|
||||
const std::vector<array>& inputs,
|
||||
std::pair<std::vector<mx::array>, std::vector<int>> vmap(
|
||||
const std::vector<mx::array>& inputs,
|
||||
const std::vector<int>& axes) override;
|
||||
|
||||
/** Print the primitive. */
|
||||
@ -76,14 +80,16 @@ class Axpby : public Primitive {
|
||||
}
|
||||
|
||||
/** Equivalence check **/
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
bool is_equivalent(const mx::Primitive& other) const override;
|
||||
|
||||
private:
|
||||
float alpha_;
|
||||
float beta_;
|
||||
|
||||
/** Fall back implementation for evaluation on CPU */
|
||||
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
|
||||
void eval(
|
||||
const std::vector<mx::array>& inputs,
|
||||
std::vector<mx::array>& outputs);
|
||||
};
|
||||
|
||||
} // namespace mlx::core
|
||||
} // namespace my_ext
|
||||
|
@ -8,14 +8,12 @@
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
NB_MODULE(_ext, m) {
|
||||
m.doc() = "Sample extension for MLX";
|
||||
|
||||
m.def(
|
||||
"axpby",
|
||||
&axpby,
|
||||
&my_ext::axpby,
|
||||
"x"_a,
|
||||
"y"_a,
|
||||
"alpha"_a,
|
||||
|
Loading…
Reference in New Issue
Block a user