Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)

Compare commits: 1ff2b713b6 ... 910b3e3299

2 commits: 910b3e3299, 50fa315d18
.github/workflows/nightly.yml (vendored, 3 changes)
@@ -50,6 +50,7 @@ jobs:
       - uses: ./.github/actions/build-linux

   build_mac_release:
+    if: github.repository == 'ml-explore/mlx'
     strategy:
       matrix:
         python-version: ["3.10", "3.13"]
@@ -65,6 +66,7 @@ jobs:
       - uses: ./.github/actions/build-macos

   build_cuda_with_tests:
+    if: github.repository == 'ml-explore/mlx'
     runs-on: gpu-t4-4-core
     steps:
       - uses: actions/checkout@v5
@@ -74,6 +76,7 @@ jobs:
       - uses: ./.github/actions/build-cuda

   build_cuda_release:
+    if: github.repository == 'ml-explore/mlx'
     runs-on: ubuntu-22-large
     steps:
       - uses: actions/checkout@v5
.github/workflows/pull_request.yml (vendored, 3 changes)
@@ -21,6 +21,7 @@ jobs:
       - uses: ./.github/actions/build-linux

   mac_build_and_test:
+    if: github.repository == 'ml-explore/mlx'
     runs-on: [self-hosted, macos]
     needs: check_lint
     steps:
@@ -29,6 +30,7 @@ jobs:
       - uses: ./.github/actions/build-macos

   cuda_build_and_test:
+    if: github.repository == 'ml-explore/mlx'
     runs-on: gpu-t4-4-core
     needs: check_lint
     steps:
@@ -39,6 +41,7 @@ jobs:
       - uses: ./.github/actions/build-cuda

   build_documentation:
+    if: github.repository == 'ml-explore/mlx'
     runs-on: [self-hosted, macos]
     needs: check_lint
     steps:
.github/workflows/release.yml (vendored, 4 changes)
@@ -20,6 +20,7 @@ jobs:
         run: echo "Publishing setup complete"

   build_documentation:
+    if: github.repository == 'ml-explore/mlx'
     runs-on: [self-hosted, macos]
     steps:
       - uses: actions/checkout@v5
@@ -40,6 +41,7 @@ jobs:
         uses: actions/deploy-pages@v4

   build_linux_release:
+    if: github.repository == 'ml-explore/mlx'
     strategy:
       matrix:
         python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
@@ -68,6 +70,7 @@ jobs:
           path: wheelhouse/mlx_cpu-*.whl

   build_mac_release:
+    if: github.repository == 'ml-explore/mlx'
     strategy:
       matrix:
         python-version: ["3.10", "3.11", "3.12", "3.13"]
@@ -96,6 +99,7 @@ jobs:
           path: dist/mlx_metal-*.whl

   build_cuda_release:
+    if: github.repository == 'ml-explore/mlx'
     runs-on: ubuntu-22-large
     env:
       PYPI_RELEASE: 1
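The added line is identical across all three workflows: guarding each job with `if: github.repository == 'ml-explore/mlx'` makes it a no-op on forks and mirrors, so nightly, pull-request, and release jobs only consume runner time on the upstream repository. (This is the standard GitHub Actions pattern for that purpose; the reading is inferred from the diff itself, not from an accompanying PR description.) The remaining hunks touch the CPU and Metal `AddMM` implementations and the BLAS tests.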
@@ -2,6 +2,8 @@

 #include <cstring>
 #include "mlx/array.h"
+#include "mlx/backend/cpu/binary.h"
+#include "mlx/backend/cpu/binary_ops.h"
 #include "mlx/backend/cpu/copy.h"
 #include "mlx/backend/cpu/encoder.h"
 #include "mlx/backend/cpu/gemm.h"
@@ -135,15 +137,58 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
     return;
   }

+  // Handle empty matrix case (K=0)
+  if (inputs[0].shape(-1) == 0) {
+    auto& c = inputs[2];
+    if (beta_ == 1.0f) {
+      CopyType ctype = c.data_size() == 1
+          ? CopyType::Scalar
+          : (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
+      copy_cpu(c, out, ctype, stream());
+    } else {
+      array beta_scalar = array(beta_, c.dtype());
+      auto bopt = get_binary_op_type(c, beta_scalar);
+      set_binary_op_output_data(c, beta_scalar, out, bopt);
+      auto& encoder = cpu::get_command_encoder(stream());
+      encoder.set_input_array(c);
+      encoder.set_input_array(beta_scalar);
+      encoder.set_output_array(out);
+      encoder.dispatch([c = array::unsafe_weak_copy(c),
+                        beta_scalar = array::unsafe_weak_copy(beta_scalar),
+                        out = array::unsafe_weak_copy(out),
+                        bopt]() mutable {
+        switch (out.dtype()) {
+          case float16:
+            binary_op<float16_t, detail::Multiply>(c, beta_scalar, out, bopt);
+            break;
+          case float32:
+            binary_op<float, detail::Multiply>(c, beta_scalar, out, bopt);
+            break;
+          case float64:
+            binary_op<double, detail::Multiply>(c, beta_scalar, out, bopt);
+            break;
+          case bfloat16:
+            binary_op<bfloat16_t, detail::Multiply>(c, beta_scalar, out, bopt);
+            break;
+          case complex64:
+            binary_op<complex64_t, detail::Multiply>(c, beta_scalar, out, bopt);
+            break;
+          default:
+            throw std::runtime_error(
+                "[AddMM::eval_cpu] Unsupported dtype for beta scaling");
+        }
+      });
+      encoder.add_temporary(std::move(beta_scalar));
+    }
+    return;
+  }
+
   // Fill output with C
   auto& c = inputs[2];
   CopyType ctype = c.data_size() == 1
       ? CopyType::Scalar
       : (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
   copy_cpu(c, out, ctype, stream());
-  if (inputs[0].shape(-1) == 0) {
-    return;
-  }
   matmul_general(inputs[0], inputs[1], out, stream(), alpha_, beta_);
 }
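Read against the removed lines, the old CPU path copied `c` into `out` and returned as soon as K = 0, so a non-unit `beta` was silently dropped. The new branch multiplies `c` by `beta` instead, keeping a plain copy as the `beta == 1` fast path. A minimal sketch of the fixed semantics, reusing shapes from the test diff further down:

```python
import mlx.core as mx

# With K = 0 the product a @ b is all zeros, so
# addmm(c, a, b, alpha, beta) must reduce to beta * c.
a = mx.ones((2, 0))
b = mx.ones((0, 2))
c = mx.ones((2, 2))

out = mx.addmm(c, a, b, alpha=0.5, beta=2.0)
print(mx.allclose(out, 2.0 * c).item())  # True once the K=0 path scales by beta
```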
@@ -8,6 +8,7 @@
 #include "mlx/backend/common/broadcasting.h"
 #include "mlx/backend/common/matmul.h"
 #include "mlx/backend/gpu/copy.h"
+#include "mlx/backend/metal/binary.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/kernels.h"
 #include "mlx/backend/metal/kernels/defines.h"
@@ -925,19 +926,27 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     return;
   }

-  // Copy c into out and return
+  auto& s = stream();
+  auto& d = metal::device(s.device);
+
+  // Handle empty matrix case (K=0)
   if (inputs[0].shape(-1) == 0) {
-    copy_gpu(
-        inputs[2],
-        out,
-        inputs[2].flags().row_contiguous ? CopyType::Vector : CopyType::General,
-        stream());
+    auto& c = inputs[2];
+    if (beta_ == 1.0f) {
+      copy_gpu(
+          c,
+          out,
+          c.flags().row_contiguous ? CopyType::Vector : CopyType::General,
+          s);
+    } else {
+      array beta_scalar = array(beta_, c.dtype());
+      binary_op_gpu({c, beta_scalar}, out, "Multiply", s);
+      d.add_temporary(std::move(beta_scalar), s.index);
+    }
     return;
   }

   out.set_data(allocator::malloc(out.nbytes()));
-  auto& s = stream();
-  auto& d = metal::device(s.device);

   auto& a_pre = inputs[0];
   auto& b_pre = inputs[1];
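The Metal change mirrors the CPU one: `s` and `d` are hoisted above the early return so the `beta != 1` branch can enqueue a `Multiply` kernel and register `beta_scalar` via `add_temporary`, keeping it alive until the command buffer runs. A quick cross-backend check, assuming a Metal-capable machine (the `stream=` keyword steers an op to a device):

```python
import mlx.core as mx

a = mx.ones((2, 0))
b = mx.ones((0, 2))
c = mx.ones((2, 2))

# Both backends should now agree on the beta-scaled K=0 result.
cpu_out = mx.addmm(c, a, b, beta=3.0, stream=mx.cpu)
gpu_out = mx.addmm(c, a, b, beta=3.0, stream=mx.gpu)
print(mx.allclose(cpu_out, gpu_out).item())  # expected: True
```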
@@ -785,11 +785,46 @@ class TestBlas(mlx_tests.MLXTestCase):
         self.assertEqual(out.item(), 1.0)
         self.assertEqual(out.shape, ())

-        a = mx.zeros(shape=(5, 0))
-        b = mx.zeros(shape=(0, 5))
-        c = mx.random.uniform(shape=(5, 5))
-        out = mx.addmm(c, a, b)
-        self.assertTrue(mx.allclose(out, c))
+        a = mx.ones((2, 0))
+        b = mx.ones((0, 2))
+        c = mx.ones((2, 2))
+        test_cases = [
+            (0.0, 1.0),
+            (0.0, 2.0),
+            (0.0, 0.5),
+            (0.0, 0.0),
+            (1.0, 2.0),
+        ]
+
+        for alpha, beta in test_cases:
+            with self.subTest(alpha=alpha, beta=beta):
+                result = mx.addmm(c, a, b, alpha=alpha, beta=beta)
+                expected = c * beta  # a @ b = 0 for empty matrices
+                self.assertTrue(mx.allclose(result, expected))
+
+        shapes_tests = [
+            ((3, 0), (0, 3), (3, 3)),
+            ((5, 0), (0, 5), (5, 5)),
+            ((1, 0), (0, 10), (1, 10)),
+            ((10, 0), (0, 1), (10, 1)),
+        ]
+
+        for shape_a, shape_b, shape_c in shapes_tests:
+            with self.subTest(shape_a=shape_a, shape_b=shape_b, shape_c=shape_c):
+                a = mx.ones(shape_a)
+                b = mx.ones(shape_b)
+                c = mx.ones(shape_c)
+                result = mx.addmm(c, a, b, alpha=0.5, beta=2.0)
+                expected = c * 2.0
+                self.assertTrue(mx.allclose(result, expected))
+
+        a = mx.ones((2, 5, 0))
+        b = mx.ones((2, 0, 5))
+        c = mx.ones((2, 5, 5))
+        result = mx.addmm(c, a, b, alpha=0.0, beta=3.0)
+        expected = c * 3.0
+        self.assertTrue(mx.allclose(result, expected))
+
     def test_block_masked_matmul(self):
         def ref_block_masked_mm(
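The rewritten test pins down the regression: the old assertion `mx.allclose(out, c)` only covered the default `beta=1.0`, which is exactly the case where ignoring `beta` goes unnoticed. The new version sweeps several (alpha, beta) pairs and K=0 output shapes via `subTest`, plus one batched case, each expecting `beta * c`.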