mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-18 00:16:39 +08:00
Compare commits
5 Commits
5f9ab0436c
...
543fad7536
Author | SHA1 | Date | |
---|---|---|---|
![]() |
543fad7536 | ||
![]() |
c35f4d089a | ||
![]() |
8590c0941e | ||
![]() |
095163b8d1 | ||
![]() |
d2e0b0465c |
@ -212,6 +212,29 @@ jobs:
|
||||
METAL_DEBUG_ERROR_MODE=0 \
|
||||
python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
|
||||
|
||||
cuda_build_and_test:
|
||||
machine:
|
||||
image: linux-cuda-12:default
|
||||
resource_class: gpu.nvidia.small.gen2
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Install Python package
|
||||
command: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
|
||||
python -m venv env
|
||||
source env/bin/activate
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
||||
pip install -e ".[dev]"
|
||||
- run:
|
||||
name: Run Python tests
|
||||
command: |
|
||||
source env/bin/activate
|
||||
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
|
||||
|
||||
build_release:
|
||||
parameters:
|
||||
python_version:
|
||||
@ -348,6 +371,7 @@ workflows:
|
||||
parameters:
|
||||
macosx_deployment_target: ["13.5", "14.0"]
|
||||
- linux_build_and_test
|
||||
- cuda_build_and_test
|
||||
- build_documentation
|
||||
|
||||
build_pypi_release:
|
||||
@ -455,6 +479,8 @@ workflows:
|
||||
macosx_deployment_target: ["13.5", "14.0"]
|
||||
- linux_build_and_test:
|
||||
requires: [ hold ]
|
||||
- cuda_build_and_test:
|
||||
requires: [ hold ]
|
||||
nightly_build:
|
||||
when:
|
||||
and:
|
||||
|
@ -42,6 +42,7 @@ option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
|
||||
option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
|
||||
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
|
||||
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
||||
option(USE_SYSTEM_FMT "Use system's provided fmt library" OFF)
|
||||
|
||||
# --------------------- Processor tests -------------------------
|
||||
message(
|
||||
@ -234,12 +235,16 @@ target_include_directories(
|
||||
# Do not add mlx_EXPORTS define for shared library.
|
||||
set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "")
|
||||
|
||||
FetchContent_Declare(
|
||||
fmt
|
||||
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
|
||||
GIT_TAG 10.2.1
|
||||
EXCLUDE_FROM_ALL)
|
||||
FetchContent_MakeAvailable(fmt)
|
||||
if(USE_SYSTEM_FMT)
|
||||
find_package(fmt REQUIRED)
|
||||
else()
|
||||
FetchContent_Declare(
|
||||
fmt
|
||||
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
|
||||
GIT_TAG 10.2.1
|
||||
EXCLUDE_FROM_ALL)
|
||||
FetchContent_MakeAvailable(fmt)
|
||||
endif()
|
||||
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
|
||||
|
||||
if(MLX_BUILD_PYTHON_BINDINGS)
|
||||
|
@ -1,5 +1,6 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
|
107
benchmarks/python/conv_unaligned_bench.py
Normal file
107
benchmarks/python/conv_unaligned_bench.py
Normal file
@ -0,0 +1,107 @@
|
||||
import math
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
N_warmup = 10
|
||||
N_iter_bench = 100
|
||||
N_iter_func = 5
|
||||
|
||||
|
||||
def bench(f, a, b):
|
||||
for i in range(N_warmup):
|
||||
f(a, b)
|
||||
torch.mps.synchronize()
|
||||
|
||||
s = time.perf_counter_ns()
|
||||
for i in range(N_iter_bench):
|
||||
f(a, b)
|
||||
e = time.perf_counter_ns()
|
||||
return (e - s) * 1e-9
|
||||
|
||||
|
||||
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||
def mx_conv_2D(a, b):
|
||||
ys = []
|
||||
for i in range(N_iter_func):
|
||||
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
|
||||
ys.append(y)
|
||||
mx.eval(ys)
|
||||
return ys
|
||||
|
||||
return mx_conv_2D
|
||||
|
||||
|
||||
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||
@torch.no_grad()
|
||||
def pt_conv_2D(a, b):
|
||||
ys = []
|
||||
for i in range(N_iter_func):
|
||||
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
|
||||
ys.append(y)
|
||||
torch.mps.synchronize()
|
||||
return ys
|
||||
|
||||
return pt_conv_2D
|
||||
|
||||
|
||||
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
|
||||
scale = 1.0 / math.sqrt(kH * kH * C)
|
||||
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
|
||||
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
|
||||
np_dtype
|
||||
)
|
||||
|
||||
a_mx = mx.array(a_np)
|
||||
b_mx = mx.array(b_np)
|
||||
|
||||
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
|
||||
b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")
|
||||
|
||||
torch.mps.synchronize()
|
||||
|
||||
f_mx = make_mx_conv_2D(strides, padding, groups)
|
||||
f_pt = make_pt_conv_2D(strides, padding, groups)
|
||||
|
||||
time_torch = bench(f_pt, a_pt, b_pt)
|
||||
time_mlx = bench(f_mx, a_mx, b_mx)
|
||||
|
||||
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
|
||||
out_pt = torch.conv2d(
|
||||
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
|
||||
)
|
||||
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
|
||||
out_pt = out_pt.numpy(force=True)
|
||||
|
||||
atol = 2e-5 if np_dtype == np.float32 else 1e-4
|
||||
|
||||
if not np.allclose(out_pt, out_mx, atol=atol):
|
||||
print(
|
||||
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
|
||||
)
|
||||
|
||||
return time_mlx, time_torch
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
dtype = "float32"
|
||||
shapes = (
|
||||
(4, 32, 32, 21, 3, 3, 128),
|
||||
(4, 32, 32, 21, 3, 3, 37),
|
||||
(4, 32, 32, 370, 3, 3, 370),
|
||||
(4, 32, 32, 370, 7, 7, 128),
|
||||
(2, 320, 640, 21, 7, 7, 21),
|
||||
)
|
||||
for N, H, W, C, kh, kw, O in shapes:
|
||||
time_mlx, time_torch = bench_shape(
|
||||
N, H, W, C, kh, kw, O, (1, 1), (0, 0), 1, dtype
|
||||
)
|
||||
diff = time_torch / time_mlx - 1.0
|
||||
|
||||
print(
|
||||
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kh:2d}, {kw:2d}, {C:3d}), {dtype}, {100. * diff:+5.2f}%"
|
||||
)
|
||||
if time_mlx >= 2.0 * time_torch:
|
||||
print("ATTENTION ^^^^^^^")
|
@ -55,6 +55,9 @@ endif()
|
||||
|
||||
if(MLX_BUILD_CUDA)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda)
|
||||
else()
|
||||
target_sources(mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
|
||||
|
@ -12,6 +12,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
|
||||
|
11
mlx/backend/cuda/cuda.cpp
Normal file
11
mlx/backend/cuda/cuda.cpp
Normal file
@ -0,0 +1,11 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/cuda.h"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
bool is_available() {
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
10
mlx/backend/cuda/cuda.h
Normal file
10
mlx/backend/cuda/cuda.h
Normal file
@ -0,0 +1,10 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
/* Check if the CUDA backend is available. */
|
||||
bool is_available();
|
||||
|
||||
} // namespace mlx::core::cu
|
11
mlx/backend/cuda/no_cuda.cpp
Normal file
11
mlx/backend/cuda/no_cuda.cpp
Normal file
@ -0,0 +1,11 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/cuda.h"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
bool is_available() {
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
@ -391,6 +391,7 @@ void implicit_gemm_conv_2D_general_gpu(
|
||||
// Get channel iteration info
|
||||
int channel_k_iters = ((conv_params.C + bk - 1) / bk);
|
||||
int gemm_k_iters = channel_k_iters;
|
||||
bool align_C = conv_params.C % bk == 0;
|
||||
|
||||
// Fix host side helper params
|
||||
int sign = (conv_params.flip ? -1 : 1);
|
||||
@ -419,14 +420,33 @@ void implicit_gemm_conv_2D_general_gpu(
|
||||
/* const int swizzle_log = */ swizzle_log};
|
||||
|
||||
// Determine kernel
|
||||
std::ostringstream kname;
|
||||
kname << "implicit_gemm_conv_2d_general_" << type_to_name(out) << "_bm" << bm
|
||||
<< "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn;
|
||||
std::string kname;
|
||||
kname.reserve(64);
|
||||
concatenate(
|
||||
kname,
|
||||
"implicit_gemm_conv_2d_general_",
|
||||
type_to_name(out),
|
||||
"_bm",
|
||||
bm,
|
||||
"_bn",
|
||||
bn,
|
||||
"_bk",
|
||||
bk,
|
||||
"_wm",
|
||||
wm,
|
||||
"_wn",
|
||||
wn);
|
||||
std::string hash_name;
|
||||
hash_name.reserve(64);
|
||||
concatenate(hash_name, kname, "_alC_", align_C);
|
||||
metal::MTLFCList func_consts = {
|
||||
{&align_C, MTL::DataType::DataTypeBool, 200},
|
||||
};
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel =
|
||||
get_steel_conv_general_kernel(d, kname.str(), out, bm, bn, bk, wm, wn);
|
||||
auto kernel = get_steel_conv_general_kernel(
|
||||
d, kname, hash_name, func_consts, out, bm, bn, bk, wm, wn);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Deduce grid launch dimensions
|
||||
@ -728,8 +748,10 @@ void dispatch_conv_2D_gpu(
|
||||
|
||||
// Direct to winograd conv
|
||||
bool inp_large =
|
||||
(conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12;
|
||||
(conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 4096;
|
||||
bool channels_large = (conv_params.C + conv_params.O) >= 256;
|
||||
bool out_large =
|
||||
(conv_params.N * conv_params.oS[0] * conv_params.oS[1]) >= 256;
|
||||
if (!conv_params.flip && is_stride_one && is_kdil_one && is_idil_one &&
|
||||
conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
|
||||
conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
|
||||
@ -743,7 +765,7 @@ void dispatch_conv_2D_gpu(
|
||||
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
|
||||
}
|
||||
|
||||
else if (conv_params.C % 16 == 0 && conv_params.O % 16 == 0) {
|
||||
else if ((conv_params.C % 16 == 0 && conv_params.O % 16 == 0) || out_large) {
|
||||
return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params);
|
||||
}
|
||||
|
||||
|
@ -727,6 +727,8 @@ MTL::ComputePipelineState* get_steel_conv_kernel(
|
||||
MTL::ComputePipelineState* get_steel_conv_general_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& hash_name,
|
||||
const metal::MTLFCList& func_consts,
|
||||
const array& out,
|
||||
int bm,
|
||||
int bn,
|
||||
@ -749,7 +751,7 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel(
|
||||
wn);
|
||||
return kernel_source.str();
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_fft_kernel(
|
||||
|
@ -205,6 +205,8 @@ MTL::ComputePipelineState* get_gemv_masked_kernel(
|
||||
MTL::ComputePipelineState* get_steel_conv_general_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& hash_name,
|
||||
const metal::MTLFCList& func_consts,
|
||||
const array& out,
|
||||
int bm,
|
||||
int bn,
|
||||
|
@ -2,6 +2,8 @@
|
||||
|
||||
#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h"
|
||||
|
||||
constant bool align_C [[function_constant(200)]];
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int BM,
|
||||
@ -118,23 +120,58 @@ implicit_gemm_conv_2d_general(
|
||||
// Prepare threadgroup mma operation
|
||||
mma_t mma_op(simd_gid, simd_lid);
|
||||
|
||||
int gemm_k_iterations =
|
||||
base_wh_size * base_ww_size * gemm_params->gemm_k_iterations;
|
||||
if (align_C) {
|
||||
int gemm_k_iterations =
|
||||
base_wh_size * base_ww_size * gemm_params->gemm_k_iterations;
|
||||
|
||||
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Load elements into threadgroup
|
||||
loader_a.load_unsafe();
|
||||
loader_b.load_unsafe();
|
||||
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Load elements into threadgroup
|
||||
loader_a.load_unsafe();
|
||||
loader_b.load_unsafe();
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
}
|
||||
|
||||
else {
|
||||
for (int k = 1; k < gemm_params->gemm_k_iterations; k++) {
|
||||
for (int j = 0; j < base_wh_size * base_ww_size; j++) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Load elements into threadgroup
|
||||
loader_a.load_unsafe();
|
||||
loader_b.load_unsafe();
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
}
|
||||
const short remaining_k = params->C % BK;
|
||||
for (int j = 0; j < base_wh_size * base_ww_size; j++) {
|
||||
// Load elements into threadgroup
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
loader_a.load_safe(remaining_k);
|
||||
loader_b.load_safe(remaining_k);
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Multiply and accumulate threadgroup elements
|
||||
mma_op.mma(As, Bs);
|
||||
// Prepare for next iteration
|
||||
loader_a.next();
|
||||
loader_b.next();
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
@ -137,6 +137,52 @@ struct Conv2DInputBlockLoaderGeneral {
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void load_safe(const short remaining_k) const {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {
|
||||
// Find bounds
|
||||
int n = read_n[i];
|
||||
|
||||
int h_flip = params->flip ? params->wS[0] - weight_h - 1 : weight_h;
|
||||
int w_flip = params->flip ? params->wS[1] - weight_w - 1 : weight_w;
|
||||
|
||||
int ih_dil = read_ih[i] + h_flip * params->kdil[0];
|
||||
int iw_dil = read_iw[i] + w_flip * params->kdil[1];
|
||||
|
||||
int ih = ih_dil / params->idil[0];
|
||||
int iw = iw_dil / params->idil[1];
|
||||
|
||||
size_t offset = ih * params->in_strides[1] + iw * params->in_strides[2];
|
||||
|
||||
// Read from input if in bounds
|
||||
if ((n < params->N) && (ih_dil >= 0 && ih < params->iS[0]) &&
|
||||
(iw_dil >= 0 && iw < params->iS[1])) {
|
||||
if (bj + vec_size <= remaining_k) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; ++j) {
|
||||
dst[is * dst_ld + j] = (src[i])[offset + j];
|
||||
}
|
||||
} else {
|
||||
for (short j = 0; j < vec_size; ++j) {
|
||||
if (bj + j < remaining_k) {
|
||||
dst[is * dst_ld + j] = (src[i])[offset + j];
|
||||
} else {
|
||||
dst[is * dst_ld + j] = T(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Zero pad otherwise
|
||||
else {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; ++j) {
|
||||
dst[is * dst_ld + j] = T(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Iteration helper */
|
||||
METAL_FUNC void next() {
|
||||
weight_w += jump_params->f_wgt_jump_w;
|
||||
@ -262,6 +308,55 @@ struct Conv2DWeightBlockLoaderGeneral {
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void load_safe(const short remaining_k) const {
|
||||
const device T* curr_src = src + weight_h * params->wt_strides[1] +
|
||||
weight_w * params->wt_strides[2];
|
||||
|
||||
if ((start_row + BN <= params->O)) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < BN; i += TROWS) {
|
||||
if (bj + vec_size <= remaining_k) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
dst[i * dst_ld + j] = curr_src[i * src_ld + j];
|
||||
}
|
||||
} else {
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
if (bj + j < remaining_k) {
|
||||
dst[i * dst_ld + j] = curr_src[i * src_ld + j];
|
||||
} else {
|
||||
dst[i * dst_ld + j] = T(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (short i = 0; i < BN; i += TROWS) {
|
||||
if ((start_row + i) < params->O) {
|
||||
if (bj + vec_size <= remaining_k) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
dst[i * dst_ld + j] = curr_src[i * src_ld + j];
|
||||
}
|
||||
} else {
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
if (bj + j < remaining_k) {
|
||||
dst[i * dst_ld + j] = curr_src[i * src_ld + j];
|
||||
} else {
|
||||
dst[i * dst_ld + j] = T(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
dst[i * dst_ld + j] = T(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Iteration helper */
|
||||
METAL_FUNC void next() {
|
||||
weight_w += jump_params->f_wgt_jump_w;
|
||||
|
@ -3,8 +3,11 @@
|
||||
#include <stdexcept>
|
||||
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/fast.h"
|
||||
|
||||
namespace mlx::core::metal {
|
||||
namespace mlx::core {
|
||||
|
||||
namespace metal {
|
||||
|
||||
bool is_available() {
|
||||
return false;
|
||||
@ -19,4 +22,21 @@ device_info() {
|
||||
"[metal::device_info] Cannot get device info without metal backend");
|
||||
};
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
} // namespace metal
|
||||
|
||||
namespace fast {
|
||||
|
||||
MetalKernelFunction metal_kernel(
|
||||
const std::string&,
|
||||
const std::vector<std::string>&,
|
||||
const std::vector<std::string>&,
|
||||
const std::string&,
|
||||
const std::string&,
|
||||
bool ensure_row_contiguous,
|
||||
bool atomic_outputs) {
|
||||
throw std::runtime_error("[metal_kernel] No GPU back-end.");
|
||||
}
|
||||
|
||||
} // namespace fast
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -244,13 +244,15 @@ MTL::ComputePipelineState* get_steel_conv_kernel(
|
||||
MTL::ComputePipelineState* get_steel_conv_general_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& hash_name,
|
||||
const metal::MTLFCList& func_consts,
|
||||
const array&,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int) {
|
||||
return d.get_kernel(kernel_name);
|
||||
return d.get_kernel(kernel_name, hash_name, func_consts);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_fft_kernel(
|
||||
|
@ -2,7 +2,6 @@
|
||||
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/distributed/primitives.h"
|
||||
#include "mlx/fast.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
||||
#define NO_GPU_MULTI(func) \
|
||||
@ -156,18 +155,6 @@ NO_GPU_USE_FALLBACK(RoPE)
|
||||
NO_GPU(ScaledDotProductAttention)
|
||||
NO_GPU_MULTI(AffineQuantize)
|
||||
NO_GPU_MULTI(CustomKernel)
|
||||
|
||||
MetalKernelFunction metal_kernel(
|
||||
const std::string&,
|
||||
const std::vector<std::string>&,
|
||||
const std::vector<std::string>&,
|
||||
const std::string&,
|
||||
const std::string&,
|
||||
bool ensure_row_contiguous,
|
||||
bool atomic_outputs) {
|
||||
throw std::runtime_error("[metal_kernel] No GPU back-end.");
|
||||
}
|
||||
|
||||
} // namespace fast
|
||||
|
||||
namespace distributed {
|
||||
|
@ -3,6 +3,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/cuda/cuda.h"
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/compile.h"
|
||||
#include "mlx/device.h"
|
||||
|
@ -17,10 +17,7 @@
|
||||
#include "python/src/indexing.h"
|
||||
#include "python/src/utils.h"
|
||||
|
||||
#include "mlx/device.h"
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/transforms.h"
|
||||
#include "mlx/utils.h"
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
@ -461,9 +458,12 @@ void init_array(nb::module_& m) {
|
||||
.def(
|
||||
"__dlpack_device__",
|
||||
[](const mx::array& a) {
|
||||
// See
|
||||
// https://github.com/dmlc/dlpack/blob/5c210da409e7f1e51ddf445134a4376fdbd70d7d/include/dlpack/dlpack.h#L74
|
||||
if (mx::metal::is_available()) {
|
||||
// Metal device is available
|
||||
return nb::make_tuple(8, 0);
|
||||
} else if (mx::cu::is_available()) {
|
||||
return nb::make_tuple(13, 0);
|
||||
} else {
|
||||
// CPU device
|
||||
return nb::make_tuple(1, 0);
|
||||
|
@ -58,4 +58,9 @@ void init_device(nb::module_& m) {
|
||||
&mx::set_default_device,
|
||||
"device"_a,
|
||||
R"pbdoc(Set the default device.)pbdoc");
|
||||
m.def(
|
||||
"is_available",
|
||||
&mx::is_available,
|
||||
"device"_a,
|
||||
R"pbdoc(Check if a back-end is available for the given device.)pbdoc");
|
||||
}
|
||||
|
@ -198,7 +198,7 @@ class TestInequality(mlx_tests.MLXTestCase):
|
||||
def test_dlx_device_type(self):
|
||||
a = mx.array([1, 2, 3])
|
||||
device_type, device_id = a.__dlpack_device__()
|
||||
self.assertIn(device_type, [1, 8])
|
||||
self.assertIn(device_type, [1, 8, 13])
|
||||
self.assertEqual(device_id, 0)
|
||||
|
||||
if device_type == 8:
|
||||
|
@ -1173,6 +1173,19 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
|
||||
self.assertTrue(mx.allclose(out, out_2d.squeeze(2)))
|
||||
|
||||
def test_conv2d_unaligned_channels(self):
|
||||
x = mx.random.uniform(shape=(2, 16, 16, 21))
|
||||
w = mx.random.uniform(shape=(32, 3, 3, 21))
|
||||
y = mx.conv2d(x, w, stream=mx.cpu)
|
||||
y_hat = mx.conv2d(x, w)
|
||||
self.assertTrue(mx.allclose(y, y_hat))
|
||||
|
||||
x = mx.random.uniform(shape=(2, 16, 16, 21))
|
||||
w = mx.random.uniform(shape=(21, 3, 3, 21))
|
||||
y = mx.conv2d(x, w, stream=mx.cpu)
|
||||
y_hat = mx.conv2d(x, w)
|
||||
self.assertTrue(mx.allclose(y, y_hat))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -10,7 +10,7 @@ import mlx_tests
|
||||
class TestDefaultDevice(unittest.TestCase):
|
||||
def test_mlx_default_device(self):
|
||||
device = mx.default_device()
|
||||
if mx.metal.is_available():
|
||||
if mx.is_available(mx.gpu):
|
||||
self.assertEqual(device, mx.Device(mx.gpu))
|
||||
self.assertEqual(str(device), "Device(gpu, 0)")
|
||||
self.assertEqual(device, mx.gpu)
|
||||
@ -73,7 +73,7 @@ class TestStream(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(s2.device, mx.default_device())
|
||||
self.assertNotEqual(s1, s2)
|
||||
|
||||
if mx.metal.is_available():
|
||||
if mx.is_available(mx.gpu):
|
||||
s_gpu = mx.default_stream(mx.gpu)
|
||||
self.assertEqual(s_gpu.device, mx.gpu)
|
||||
else:
|
||||
@ -86,7 +86,7 @@ class TestStream(mlx_tests.MLXTestCase):
|
||||
s_cpu = mx.new_stream(mx.cpu)
|
||||
self.assertEqual(s_cpu.device, mx.cpu)
|
||||
|
||||
if mx.metal.is_available():
|
||||
if mx.is_available(mx.gpu):
|
||||
s_gpu = mx.new_stream(mx.gpu)
|
||||
self.assertEqual(s_gpu.device, mx.gpu)
|
||||
else:
|
||||
@ -99,7 +99,7 @@ class TestStream(mlx_tests.MLXTestCase):
|
||||
|
||||
a = mx.add(x, y, stream=mx.default_stream(mx.default_device()))
|
||||
|
||||
if mx.metal.is_available():
|
||||
if mx.is_available(mx.gpu):
|
||||
b = mx.add(x, y, stream=mx.default_stream(mx.gpu))
|
||||
self.assertEqual(a.item(), b.item())
|
||||
s_gpu = mx.new_stream(mx.gpu)
|
||||
|
@ -353,7 +353,7 @@ class TestOptimizers(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 3.0)))
|
||||
|
||||
|
||||
class TestSchedulers(unittest.TestCase):
|
||||
class TestSchedulers(mlx_tests.MLXTestCase):
|
||||
def test_decay_lr(self):
|
||||
for optim_class in optimizers_dict.values():
|
||||
lr_schedule = opt.step_decay(1e-1, 0.9, 1)
|
||||
|
Loading…
Reference in New Issue
Block a user