mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
parent
4c1dfa58b7
commit
71de73a668
@ -25,7 +25,7 @@ option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
|
|||||||
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
||||||
|
|
||||||
if(NOT MLX_VERSION)
|
if(NOT MLX_VERSION)
|
||||||
set(MLX_VERSION 0.23.0)
|
set(MLX_VERSION 0.23.1)
|
||||||
endif()
|
endif()
|
||||||
add_compile_definitions("MLX_VERSION=${MLX_VERSION}")
|
add_compile_definitions("MLX_VERSION=${MLX_VERSION}")
|
||||||
|
|
||||||
|
@ -533,45 +533,6 @@ void implicit_gemm_conv_2D_general_gpu(
|
|||||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
|
|
||||||
void winograd_conv_2D_fused_gpu(
|
|
||||||
const Stream& s,
|
|
||||||
metal::Device& d,
|
|
||||||
const array& in,
|
|
||||||
const array& wt,
|
|
||||||
array out,
|
|
||||||
const MLXConvParams<2>& conv_params,
|
|
||||||
std::vector<array>& copies_w) {
|
|
||||||
int O_c = conv_params.O;
|
|
||||||
int C_c = conv_params.C;
|
|
||||||
|
|
||||||
int N_tiles_n = conv_params.N;
|
|
||||||
int N_tiles_h = (conv_params.oS[0] + 1) / 2;
|
|
||||||
int N_tiles_w = (conv_params.oS[1] + 1) / 2;
|
|
||||||
int N_tiles = N_tiles_n * N_tiles_h * N_tiles_w;
|
|
||||||
|
|
||||||
int bc = 32;
|
|
||||||
int wm = 4;
|
|
||||||
int wn = 1;
|
|
||||||
std::ostringstream kname;
|
|
||||||
kname << "winograd_conv_2d_fused_" << type_to_name(out) << "_flip"
|
|
||||||
<< conv_params.flip;
|
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
|
||||||
auto kernel = d.get_kernel(kname.str());
|
|
||||||
compute_encoder.set_compute_pipeline_state(kernel);
|
|
||||||
|
|
||||||
compute_encoder.set_input_array(in, 0);
|
|
||||||
compute_encoder.set_input_array(wt, 1);
|
|
||||||
compute_encoder.set_output_array(out, 2);
|
|
||||||
|
|
||||||
compute_encoder.set_bytes(conv_params, 3);
|
|
||||||
|
|
||||||
MTL::Size group_dims = MTL::Size(8, 8, 2);
|
|
||||||
MTL::Size grid_dims =
|
|
||||||
MTL::Size(O_c / 8, (N_tiles_h * N_tiles_w) / 8, N_tiles_n);
|
|
||||||
|
|
||||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
|
||||||
}
|
|
||||||
|
|
||||||
void winograd_conv_2D_gpu(
|
void winograd_conv_2D_gpu(
|
||||||
const Stream& s,
|
const Stream& s,
|
||||||
metal::Device& d,
|
metal::Device& d,
|
||||||
@ -580,6 +541,67 @@ void winograd_conv_2D_gpu(
|
|||||||
array out,
|
array out,
|
||||||
const MLXConvParams<2>& conv_params,
|
const MLXConvParams<2>& conv_params,
|
||||||
std::vector<array>& copies_w) {
|
std::vector<array>& copies_w) {
|
||||||
|
Shape padded_shape = {
|
||||||
|
conv_params.N,
|
||||||
|
conv_params.iS[0] + 2 * conv_params.pad[0],
|
||||||
|
conv_params.iS[1] + 2 * conv_params.pad[1],
|
||||||
|
conv_params.C};
|
||||||
|
|
||||||
|
padded_shape[1] = 6 * ((padded_shape[1] - 2 + 5) / 6) + 2;
|
||||||
|
padded_shape[2] = 6 * ((padded_shape[2] - 2 + 5) / 6) + 2;
|
||||||
|
|
||||||
|
array in_padded(std::move(padded_shape), in.dtype(), nullptr, {});
|
||||||
|
|
||||||
|
// Fill with zeros
|
||||||
|
array zero_arr = array(0, in.dtype());
|
||||||
|
fill_gpu(zero_arr, in_padded, s);
|
||||||
|
copies_w.push_back(zero_arr);
|
||||||
|
|
||||||
|
// Pick input slice from padded
|
||||||
|
size_t data_offset = conv_params.pad[0] * in_padded.strides()[1] +
|
||||||
|
conv_params.pad[1] * in_padded.strides()[2];
|
||||||
|
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
|
||||||
|
in_padded_slice.copy_shared_buffer(
|
||||||
|
in_padded,
|
||||||
|
in_padded.strides(),
|
||||||
|
in_padded.flags(),
|
||||||
|
in_padded_slice.size(),
|
||||||
|
data_offset);
|
||||||
|
|
||||||
|
// Copy input values into the slice
|
||||||
|
copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s);
|
||||||
|
|
||||||
|
copies_w.push_back(in_padded_slice);
|
||||||
|
copies_w.push_back(in_padded);
|
||||||
|
|
||||||
|
MLXConvParams<2> conv_params_updated{
|
||||||
|
/* const int N = */ static_cast<int>(in_padded.shape(0)),
|
||||||
|
/* const int C = */ static_cast<int>(in_padded.shape(3)),
|
||||||
|
/* const int O = */ static_cast<int>(wt.shape(0)),
|
||||||
|
/* const int iS[NDIM] = */
|
||||||
|
{static_cast<int>(in_padded.shape(1)),
|
||||||
|
static_cast<int>(in_padded.shape(2))},
|
||||||
|
/* const int wS[NDIM] = */
|
||||||
|
{static_cast<int>(wt.shape(1)), static_cast<int>(wt.shape(2))},
|
||||||
|
/* const int oS[NDIM] = */
|
||||||
|
{static_cast<int>(out.shape(1)), static_cast<int>(out.shape(2))},
|
||||||
|
/* const int str[NDIM] = */ {1, 1},
|
||||||
|
/* const int pad[NDIM] = */ {0, 0},
|
||||||
|
/* const int kdil[NDIM] = */ {1, 1},
|
||||||
|
/* const int idil[NDIM] = */ {1, 1},
|
||||||
|
/* const size_t in_strides[NDIM + 2] = */
|
||||||
|
{in_padded.strides()[0],
|
||||||
|
in_padded.strides()[1],
|
||||||
|
in_padded.strides()[2],
|
||||||
|
in_padded.strides()[3]},
|
||||||
|
/* const size_t wt_strides[NDIM + 2] = */
|
||||||
|
{wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},
|
||||||
|
/* const size_t out_strides[NDIM + 2] = */
|
||||||
|
{out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},
|
||||||
|
/* const int groups = */ 1,
|
||||||
|
/* const bool flip = */ false,
|
||||||
|
};
|
||||||
|
|
||||||
int O_c = conv_params.O;
|
int O_c = conv_params.O;
|
||||||
int C_c = conv_params.C;
|
int C_c = conv_params.C;
|
||||||
|
|
||||||
@ -598,7 +620,7 @@ void winograd_conv_2D_gpu(
|
|||||||
int bo = 4;
|
int bo = 4;
|
||||||
std::ostringstream kname;
|
std::ostringstream kname;
|
||||||
kname << "winograd_conv_2d_weight_transform_" << type_to_name(out) << "_bc"
|
kname << "winograd_conv_2d_weight_transform_" << type_to_name(out) << "_bc"
|
||||||
<< bc << "_flip" << conv_params.flip;
|
<< bc;
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
auto kernel = d.get_kernel(kname.str());
|
auto kernel = d.get_kernel(kname.str());
|
||||||
compute_encoder.set_compute_pipeline_state(kernel);
|
compute_encoder.set_compute_pipeline_state(kernel);
|
||||||
@ -631,10 +653,10 @@ void winograd_conv_2D_gpu(
|
|||||||
auto kernel = d.get_kernel(kname.str());
|
auto kernel = d.get_kernel(kname.str());
|
||||||
compute_encoder.set_compute_pipeline_state(kernel);
|
compute_encoder.set_compute_pipeline_state(kernel);
|
||||||
|
|
||||||
compute_encoder.set_input_array(in, 0);
|
compute_encoder.set_input_array(in_padded, 0);
|
||||||
compute_encoder.set_output_array(inp_wg, 1);
|
compute_encoder.set_output_array(inp_wg, 1);
|
||||||
|
|
||||||
compute_encoder.set_bytes(conv_params, 2);
|
compute_encoder.set_bytes(conv_params_updated, 2);
|
||||||
|
|
||||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||||
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
|
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
|
||||||
@ -681,7 +703,7 @@ void winograd_conv_2D_gpu(
|
|||||||
compute_encoder.set_input_array(out_wg, 0);
|
compute_encoder.set_input_array(out_wg, 0);
|
||||||
compute_encoder.set_output_array(out, 1);
|
compute_encoder.set_output_array(out, 1);
|
||||||
|
|
||||||
compute_encoder.set_bytes(conv_params, 2);
|
compute_encoder.set_bytes(conv_params_updated, 2);
|
||||||
|
|
||||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||||
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
|
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
|
||||||
@ -745,18 +767,14 @@ void conv_2D_gpu(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Direct to winograd conv
|
// Direct to winograd conv
|
||||||
bool img_large =
|
bool inp_large =
|
||||||
(conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12;
|
(conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12;
|
||||||
bool channels_large = (conv_params.C + conv_params.O) >= 256;
|
bool channels_large = (conv_params.C + conv_params.O) >= 256;
|
||||||
if (conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
|
if (!flip && is_stride_one && is_kdil_one && is_idil_one &&
|
||||||
conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && is_stride_one &&
|
conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
|
||||||
is_kdil_one && is_idil_one) {
|
conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
|
||||||
if (img_large && channels_large) {
|
channels_large) {
|
||||||
return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
|
return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
|
||||||
}
|
|
||||||
if (conv_params.N <= 1) {
|
|
||||||
return winograd_conv_2D_fused_gpu(s, d, in, wt, out, conv_params, copies);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Direct to implicit gemm conv
|
// Direct to implicit gemm conv
|
||||||
@ -858,40 +876,8 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
wt = arr_copy;
|
wt = arr_copy;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for 1x1 conv
|
|
||||||
auto is_one = [](int x) { return x == 1; };
|
|
||||||
auto is_zero = [](int x) { return x == 0; };
|
|
||||||
if (groups_ == 1 && (wt.shape(0) * wt.shape(-1) == wt.size()) &&
|
|
||||||
std::all_of(wt.shape().begin() + 1, wt.shape().end() - 1, is_one) &&
|
|
||||||
std::all_of(kernel_strides_.begin(), kernel_strides_.end(), is_one) &&
|
|
||||||
std::all_of(input_dilation_.begin(), input_dilation_.end(), is_one) &&
|
|
||||||
std::all_of(kernel_dilation_.begin(), kernel_dilation_.end(), is_one) &&
|
|
||||||
std::all_of(padding_.begin(), padding_.end(), is_zero)) {
|
|
||||||
std::vector<array> empty_copies;
|
|
||||||
steel_matmul_regular(
|
|
||||||
s,
|
|
||||||
d,
|
|
||||||
/*a = */ in,
|
|
||||||
/*b = */ wt,
|
|
||||||
/*c = */ out,
|
|
||||||
/*M = */ in.size() / in.shape(-1),
|
|
||||||
/*N = */ wt.shape(0),
|
|
||||||
/*K = */ in.shape(-1),
|
|
||||||
/*batch_size_out = */ 1,
|
|
||||||
/*lda = */ in.shape(-1),
|
|
||||||
/*ldb = */ wt.shape(-1),
|
|
||||||
/*ldd = */ wt.shape(0),
|
|
||||||
/*transpose_a = */ false,
|
|
||||||
/*transpose_b = */ true,
|
|
||||||
/*batch_shape = */ {1},
|
|
||||||
/*batch_strides = */ {1},
|
|
||||||
/*A_batch_stride = */ 0,
|
|
||||||
/*B_batch_stride = */ 0,
|
|
||||||
/*matrix_stride_out = */ 0,
|
|
||||||
/*copies = */ empty_copies);
|
|
||||||
}
|
|
||||||
// 3D conv
|
// 3D conv
|
||||||
else if (out.ndim() == 5) {
|
if (out.ndim() == 5) {
|
||||||
conv_3D_gpu(
|
conv_3D_gpu(
|
||||||
s,
|
s,
|
||||||
d,
|
d,
|
||||||
|
@ -326,13 +326,7 @@ constant constexpr const float WinogradTransforms<6, 3, 8>::wt_transform[8][8];
|
|||||||
constant constexpr const float WinogradTransforms<6, 3, 8>::in_transform[8][8];
|
constant constexpr const float WinogradTransforms<6, 3, 8>::in_transform[8][8];
|
||||||
constant constexpr const float WinogradTransforms<6, 3, 8>::out_transform[8][8];
|
constant constexpr const float WinogradTransforms<6, 3, 8>::out_transform[8][8];
|
||||||
|
|
||||||
template <
|
template <typename T, int BC = 32, int BO = 4, int M = 6, int R = 3>
|
||||||
typename T,
|
|
||||||
int BC = 32,
|
|
||||||
int BO = 4,
|
|
||||||
bool do_flip = false,
|
|
||||||
int M = 6,
|
|
||||||
int R = 3>
|
|
||||||
[[kernel, max_total_threads_per_threadgroup(BO * 32)]] void
|
[[kernel, max_total_threads_per_threadgroup(BO * 32)]] void
|
||||||
winograd_conv_2d_weight_transform(
|
winograd_conv_2d_weight_transform(
|
||||||
const device T* wt_in [[buffer(0)]],
|
const device T* wt_in [[buffer(0)]],
|
||||||
@ -379,12 +373,7 @@ winograd_conv_2d_weight_transform(
|
|||||||
for (int kh = 0; kh < R; ++kh) {
|
for (int kh = 0; kh < R; ++kh) {
|
||||||
for (int kw = 0; kw < R; ++kw) {
|
for (int kw = 0; kw < R; ++kw) {
|
||||||
for (int kc = simd_lane_id; kc < BC; kc += 32) {
|
for (int kc = simd_lane_id; kc < BC; kc += 32) {
|
||||||
if (do_flip) {
|
Ws[simd_group_id][kh][kw][kc] = wt_in[kh * R * C + kw * C + kc];
|
||||||
Ws[simd_group_id][R - 1 - kh][R - 1 - kw][kc] =
|
|
||||||
wt_in[kh * R * C + kw * C + kc];
|
|
||||||
} else {
|
|
||||||
Ws[simd_group_id][kh][kw][kc] = wt_in[kh * R * C + kw * C + kc];
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -409,10 +398,10 @@ winograd_conv_2d_weight_transform(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#define instantiate_winograd_conv_2d_weight_tr_base_2(name, itype, bc, f) \
|
#define instantiate_winograd_conv_2d_weight_transform_base(name, itype, bc) \
|
||||||
template [[host_name("winograd_conv_2d_weight_transform_" #name "_bc" #bc \
|
template [[host_name("winograd_conv_2d_weight_transform_" #name \
|
||||||
"_flip" #f)]] [[kernel]] void \
|
"_bc" #bc)]] [[kernel]] void \
|
||||||
winograd_conv_2d_weight_transform<itype, bc, 4, f>( \
|
winograd_conv_2d_weight_transform<itype, bc>( \
|
||||||
const device itype* wt_in [[buffer(0)]], \
|
const device itype* wt_in [[buffer(0)]], \
|
||||||
device itype* wt_out [[buffer(1)]], \
|
device itype* wt_out [[buffer(1)]], \
|
||||||
const constant int& C [[buffer(2)]], \
|
const constant int& C [[buffer(2)]], \
|
||||||
@ -421,10 +410,6 @@ winograd_conv_2d_weight_transform(
|
|||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||||
uint simd_lane_id [[thread_index_in_simdgroup]]);
|
uint simd_lane_id [[thread_index_in_simdgroup]]);
|
||||||
|
|
||||||
#define instantiate_winograd_conv_2d_weight_transform_base(name, itype, bc) \
|
|
||||||
instantiate_winograd_conv_2d_weight_tr_base_2(name, itype, bc, 0) \
|
|
||||||
instantiate_winograd_conv_2d_weight_tr_base_2(name, itype, bc, 1)
|
|
||||||
|
|
||||||
template <typename T, int BC, int WM, int WN, int M = 6, int R = 3>
|
template <typename T, int BC, int WM, int WN, int M = 6, int R = 3>
|
||||||
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
|
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
|
||||||
winograd_conv_2d_input_transform(
|
winograd_conv_2d_input_transform(
|
||||||
@ -460,17 +445,10 @@ winograd_conv_2d_input_transform(
|
|||||||
// Resolve input tile
|
// Resolve input tile
|
||||||
constexpr int TH = (A / WM);
|
constexpr int TH = (A / WM);
|
||||||
constexpr int TW = (A / WN);
|
constexpr int TW = (A / WN);
|
||||||
const int kh = TH * (simd_group_id / WN);
|
int kh = TH * (simd_group_id / WN);
|
||||||
const int kw = TW * (simd_group_id % WN);
|
int kw = TW * (simd_group_id % WN);
|
||||||
const int bh = M * tid.y + kh - params.pad[1];
|
int bh = M * tid.y + kh;
|
||||||
const int bw = M * tid.x + kw - params.pad[0];
|
int bw = M * tid.x + kw;
|
||||||
|
|
||||||
const bool is_edge_w_lo = bw < 0;
|
|
||||||
const bool is_edge_h_lo = bh < 0;
|
|
||||||
const bool is_edge_w_hi = bw + (TW - 1) >= params.iS[0];
|
|
||||||
const bool is_edge_h_hi = bh + (TH - 1) >= params.iS[1];
|
|
||||||
const bool is_edge =
|
|
||||||
is_edge_w_lo || is_edge_h_lo || is_edge_w_hi || is_edge_h_hi;
|
|
||||||
|
|
||||||
// Move to the correct input tile
|
// Move to the correct input tile
|
||||||
inp_in += tid.z * params.in_strides[0] + bh * params.in_strides[1] +
|
inp_in += tid.z * params.in_strides[0] + bh * params.in_strides[1] +
|
||||||
@ -506,21 +484,8 @@ winograd_conv_2d_input_transform(
|
|||||||
for (int h = 0; h < TH; h++) {
|
for (int h = 0; h < TH; h++) {
|
||||||
for (int w = 0; w < TW; w++) {
|
for (int w = 0; w < TW; w++) {
|
||||||
const device T* in_ptr = inp_in + jump_in[h][w];
|
const device T* in_ptr = inp_in + jump_in[h][w];
|
||||||
if (is_edge) {
|
for (int c = simd_lane_id; c < BC; c += 32) {
|
||||||
if (((bh + h) < 0 || (bh + h) >= params.iS[1]) ||
|
Is[kh + h][kw + w][c] = in_ptr[c];
|
||||||
((bw + w) < 0 || (bw + w) >= params.iS[0])) {
|
|
||||||
for (int c = simd_lane_id; c < BC; c += 32) {
|
|
||||||
Is[kh + h][kw + w][c] = T(0);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (int c = simd_lane_id; c < BC; c += 32) {
|
|
||||||
Is[kh + h][kw + w][c] = in_ptr[c];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (int c = simd_lane_id; c < BC; c += 32) {
|
|
||||||
Is[kh + h][kw + w][c] = in_ptr[c];
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -687,371 +652,3 @@ winograd_conv_2d_output_transform(
|
|||||||
instantiate_winograd_conv_2d(float32, float);
|
instantiate_winograd_conv_2d(float32, float);
|
||||||
instantiate_winograd_conv_2d(bfloat16, bfloat16_t);
|
instantiate_winograd_conv_2d(bfloat16, bfloat16_t);
|
||||||
instantiate_winograd_conv_2d(float16, half); // clang-format on
|
instantiate_winograd_conv_2d(float16, half); // clang-format on
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/steel/attn/mma.h"
|
|
||||||
|
|
||||||
template <
|
|
||||||
typename T,
|
|
||||||
bool do_flip = false,
|
|
||||||
int WM = 4,
|
|
||||||
int WN = 1,
|
|
||||||
typename AccumType = float>
|
|
||||||
[[kernel]] void winograd_fused(
|
|
||||||
const device T* input [[buffer(0)]],
|
|
||||||
const device T* weight [[buffer(1)]],
|
|
||||||
device T* output [[buffer(2)]],
|
|
||||||
const constant MLXConvParams<2>& params [[buffer(3)]],
|
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
|
||||||
uint3 lid [[thread_position_in_threadgroup]],
|
|
||||||
uint3 tgp_per_grid [[threadgroups_per_grid]],
|
|
||||||
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
|
|
||||||
ushort simd_lane_id [[thread_index_in_simdgroup]]) {
|
|
||||||
using namespace mlx::steel;
|
|
||||||
|
|
||||||
(void)tgp_per_grid;
|
|
||||||
|
|
||||||
// Winograd F(n x n, r x r)
|
|
||||||
// n x n output window
|
|
||||||
constexpr short FN = 2;
|
|
||||||
// r x r filter size
|
|
||||||
constexpr short FR = 3;
|
|
||||||
// a x a input window, a = n + r - 1
|
|
||||||
constexpr short FA = 4;
|
|
||||||
|
|
||||||
constexpr short kFragSize = 8; // MMA frag size
|
|
||||||
|
|
||||||
constexpr short BT = 8; // Tile block size
|
|
||||||
constexpr short BO = 8; // Output channel block size
|
|
||||||
constexpr short BC = 8; // Input channel block size
|
|
||||||
|
|
||||||
// clang-format off
|
|
||||||
static_assert(BT % (1 * kFragSize) == 0 &&
|
|
||||||
BO % (1 * kFragSize) == 0 &&
|
|
||||||
BC % kFragSize == 0,
|
|
||||||
"Matmuls sizes must be compatible with fragments");
|
|
||||||
// clang-format on
|
|
||||||
|
|
||||||
// Prepare for matmul
|
|
||||||
|
|
||||||
// Warp tile sizes for matmul
|
|
||||||
constexpr short TM = (FA * FA * BT) / (WM * kFragSize);
|
|
||||||
constexpr short TN = (BO) / (WN * kFragSize);
|
|
||||||
constexpr short TK = (BC) / (kFragSize);
|
|
||||||
|
|
||||||
// Warp primitives
|
|
||||||
using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
|
|
||||||
|
|
||||||
// Warp tiles sizes for matmul
|
|
||||||
MMATile<AccumType, 1, TK, MMAFrag_acc_t> Itile;
|
|
||||||
MMATile<AccumType, TK, TN, MMAFrag_acc_t> Wtile;
|
|
||||||
MMATile<AccumType, 1, TN, MMAFrag_acc_t> Otile[TM];
|
|
||||||
|
|
||||||
for (int im = 0; im < 4; im++) {
|
|
||||||
Otile[im].clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Threadgroup memory for Weights and Inputs
|
|
||||||
constexpr short BS = BT > BO ? BT : BO;
|
|
||||||
threadgroup T Wt[FA * FA * BC * BO];
|
|
||||||
threadgroup T It[FA * FA * BS * BS];
|
|
||||||
|
|
||||||
// Get thread position in tile
|
|
||||||
short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
|
|
||||||
const short sm = simd_coord.y;
|
|
||||||
const short sn = simd_coord.x;
|
|
||||||
|
|
||||||
static_assert(FA * FA * BT == 32 * WM * WN, "Each thread loads one pixel.");
|
|
||||||
const int thr_idx = simd_group_id * 32 + simd_lane_id;
|
|
||||||
const int thr_t = thr_idx / (FA * FA);
|
|
||||||
const int thr_hw = thr_idx % (FA * FA);
|
|
||||||
const int thr_h = thr_hw / FA;
|
|
||||||
const int thr_w = thr_hw % FA;
|
|
||||||
|
|
||||||
// Get batch, tile, and output idx for warp
|
|
||||||
const int b_idx = tid.z;
|
|
||||||
const int t_idx = BT * tid.y + thr_t;
|
|
||||||
const int o_idx = BO * tid.x + thr_t;
|
|
||||||
|
|
||||||
// Divide tile into h, w tile
|
|
||||||
uniform<int> oWu = make_uniform(params.oS[1]);
|
|
||||||
uniform<int> tWu = (oWu + make_uniform(FN - 1)) / make_uniform(FN);
|
|
||||||
|
|
||||||
const int oH_idx = FN * (t_idx / tWu);
|
|
||||||
const int oW_idx = FN * (t_idx % tWu);
|
|
||||||
const int iH_idx = oH_idx + thr_h - params.pad[0];
|
|
||||||
const int iW_idx = oW_idx + thr_w - params.pad[1];
|
|
||||||
|
|
||||||
// Move to correct location
|
|
||||||
|
|
||||||
// clang-format off
|
|
||||||
input += b_idx * params.in_strides[0] + // N
|
|
||||||
iH_idx * params.in_strides[1] + // H
|
|
||||||
iW_idx * params.in_strides[2]; // W
|
|
||||||
|
|
||||||
weight += o_idx * params.wt_strides[0] + // O
|
|
||||||
thr_h * params.wt_strides[1] + // H
|
|
||||||
thr_w * params.wt_strides[2]; // W
|
|
||||||
// clang-format on
|
|
||||||
|
|
||||||
// Do edge check prep for input
|
|
||||||
const bool is_edge_w_lo = iH_idx < 0;
|
|
||||||
const bool is_edge_h_lo = iW_idx < 0;
|
|
||||||
const bool is_edge_w_hi = iH_idx >= params.iS[0];
|
|
||||||
const bool is_edge_h_hi = iW_idx >= params.iS[1];
|
|
||||||
const bool is_edge =
|
|
||||||
is_edge_w_lo || is_edge_h_lo || is_edge_w_hi || is_edge_h_hi;
|
|
||||||
|
|
||||||
// Iterate over C
|
|
||||||
for (int c = 0; c < params.C; c += BC) {
|
|
||||||
#define tmp_load_wt_idx(o, h, w, c) h* FA* BC* BO + w* BC* BO + c* BO + o
|
|
||||||
#define tmp_load_in_idx(t, h, w, c) h* FA* BS* BC + w* BS* BC + t* BC + c
|
|
||||||
|
|
||||||
#define tmp_trns_wt_idx(o, h, w, c) h* FA* BC* BO + w* BC* BO + c* BO + o
|
|
||||||
#define tmp_trns_in_idx(t, h, w, c) h* FA* BS* BC + w* BS* BC + t* BC + c
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
// Load weight
|
|
||||||
if (thr_h < FR && thr_w < FR && thr_t < BO) {
|
|
||||||
for (int ic = 0; ic < BC; ic++) {
|
|
||||||
if (do_flip) {
|
|
||||||
Wt[tmp_load_wt_idx(thr_t, FR - 1 - thr_h, FR - 1 - thr_w, ic)] =
|
|
||||||
weight[c + ic];
|
|
||||||
} else {
|
|
||||||
Wt[tmp_load_wt_idx(thr_t, thr_h, thr_w, ic)] = weight[c + ic];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load input
|
|
||||||
if (is_edge) {
|
|
||||||
for (int ic = 0; ic < BC; ic++) {
|
|
||||||
It[tmp_load_in_idx(thr_t, thr_h, thr_w, ic)] = T(0);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (int ic = 0; ic < BC; ic++) {
|
|
||||||
It[tmp_load_in_idx(thr_t, thr_h, thr_w, ic)] = input[c + ic];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
// Transform weight
|
|
||||||
if (lid.z == 0) {
|
|
||||||
const short ic = lid.y;
|
|
||||||
const short io = lid.x;
|
|
||||||
|
|
||||||
T tmp_0[4][4];
|
|
||||||
T tmp_1[4][4];
|
|
||||||
|
|
||||||
for (int ii = 0; ii < 3; ++ii) {
|
|
||||||
for (int jj = 0; jj < 3; ++jj) {
|
|
||||||
tmp_0[ii][jj] = Wt[tmp_load_wt_idx(io, ii, jj, ic)];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//////////////////////////////////////////////
|
|
||||||
|
|
||||||
tmp_1[0][0] = tmp_0[0][0];
|
|
||||||
tmp_1[0][1] = tmp_0[0][1];
|
|
||||||
tmp_1[0][2] = tmp_0[0][2];
|
|
||||||
|
|
||||||
tmp_1[1][0] = T(0.5) * (tmp_0[0][0] + tmp_0[1][0] + tmp_0[2][0]);
|
|
||||||
tmp_1[1][1] = T(0.5) * (tmp_0[0][1] + tmp_0[1][1] + tmp_0[2][1]);
|
|
||||||
tmp_1[1][2] = T(0.5) * (tmp_0[0][2] + tmp_0[1][2] + tmp_0[2][2]);
|
|
||||||
|
|
||||||
tmp_1[2][0] = tmp_1[1][0] - tmp_0[1][0];
|
|
||||||
tmp_1[2][1] = tmp_1[1][1] - tmp_0[1][1];
|
|
||||||
tmp_1[2][2] = tmp_1[1][2] - tmp_0[1][2];
|
|
||||||
|
|
||||||
tmp_1[3][0] = tmp_0[2][0];
|
|
||||||
tmp_1[3][1] = tmp_0[2][1];
|
|
||||||
tmp_1[3][2] = tmp_0[2][2];
|
|
||||||
|
|
||||||
//////////////////////////////////////////////
|
|
||||||
tmp_0[0][0] = tmp_1[0][0];
|
|
||||||
tmp_0[1][0] = tmp_1[1][0];
|
|
||||||
tmp_0[2][0] = tmp_1[2][0];
|
|
||||||
tmp_0[3][0] = tmp_1[3][0];
|
|
||||||
|
|
||||||
tmp_0[0][1] = T(0.5) * (tmp_1[0][0] + tmp_1[0][1] + tmp_1[0][2]);
|
|
||||||
tmp_0[1][1] = T(0.5) * (tmp_1[1][0] + tmp_1[1][1] + tmp_1[1][2]);
|
|
||||||
tmp_0[2][1] = T(0.5) * (tmp_1[2][0] + tmp_1[2][1] + tmp_1[2][2]);
|
|
||||||
tmp_0[3][1] = T(0.5) * (tmp_1[3][0] + tmp_1[3][1] + tmp_1[3][2]);
|
|
||||||
|
|
||||||
tmp_0[0][2] = tmp_0[0][1] - tmp_1[0][1];
|
|
||||||
tmp_0[1][2] = tmp_0[1][1] - tmp_1[1][1];
|
|
||||||
tmp_0[2][2] = tmp_0[2][1] - tmp_1[2][1];
|
|
||||||
tmp_0[3][2] = tmp_0[3][1] - tmp_1[3][1];
|
|
||||||
|
|
||||||
tmp_0[0][3] = tmp_1[0][2];
|
|
||||||
tmp_0[1][3] = tmp_1[1][2];
|
|
||||||
tmp_0[2][3] = tmp_1[2][2];
|
|
||||||
tmp_0[3][3] = tmp_1[3][2];
|
|
||||||
|
|
||||||
for (int ii = 0; ii < 4; ++ii) {
|
|
||||||
for (int jj = 0; jj < 4; ++jj) {
|
|
||||||
Wt[tmp_trns_wt_idx(io, ii, jj, ic)] = tmp_0[ii][jj];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Transform input
|
|
||||||
else {
|
|
||||||
const short it = lid.y;
|
|
||||||
const short ic = lid.x;
|
|
||||||
|
|
||||||
T tmp_0[4][4];
|
|
||||||
T tmp_1[4][4];
|
|
||||||
|
|
||||||
for (int ii = 0; ii < 4; ++ii) {
|
|
||||||
for (int jj = 0; jj < 4; ++jj) {
|
|
||||||
tmp_0[ii][jj] = It[tmp_load_in_idx(it, ii, jj, ic)];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//////////////////////////////////////////////
|
|
||||||
|
|
||||||
tmp_1[0][0] = tmp_0[0][0] - tmp_0[2][0];
|
|
||||||
tmp_1[0][1] = tmp_0[0][1] - tmp_0[2][1];
|
|
||||||
tmp_1[0][2] = tmp_0[0][2] - tmp_0[2][2];
|
|
||||||
tmp_1[0][3] = tmp_0[0][3] - tmp_0[2][3];
|
|
||||||
|
|
||||||
tmp_1[1][0] = tmp_0[1][0] + tmp_0[2][0];
|
|
||||||
tmp_1[1][1] = tmp_0[1][1] + tmp_0[2][1];
|
|
||||||
tmp_1[1][2] = tmp_0[1][2] + tmp_0[2][2];
|
|
||||||
tmp_1[1][3] = tmp_0[1][3] + tmp_0[2][3];
|
|
||||||
|
|
||||||
tmp_1[2][0] = tmp_0[2][0] - tmp_0[1][0];
|
|
||||||
tmp_1[2][1] = tmp_0[2][1] - tmp_0[1][1];
|
|
||||||
tmp_1[2][2] = tmp_0[2][2] - tmp_0[1][2];
|
|
||||||
tmp_1[2][3] = tmp_0[2][3] - tmp_0[1][3];
|
|
||||||
|
|
||||||
tmp_1[3][0] = tmp_0[1][0] - tmp_0[3][0];
|
|
||||||
tmp_1[3][1] = tmp_0[1][1] - tmp_0[3][1];
|
|
||||||
tmp_1[3][2] = tmp_0[1][2] - tmp_0[3][2];
|
|
||||||
tmp_1[3][3] = tmp_0[1][3] - tmp_0[3][3];
|
|
||||||
|
|
||||||
//////////////////////////////////////////////
|
|
||||||
tmp_0[0][0] = tmp_1[0][0] - tmp_1[0][2];
|
|
||||||
tmp_0[1][0] = tmp_1[1][0] - tmp_1[1][2];
|
|
||||||
tmp_0[2][0] = tmp_1[2][0] - tmp_1[2][2];
|
|
||||||
tmp_0[3][0] = tmp_1[3][0] - tmp_1[3][2];
|
|
||||||
|
|
||||||
tmp_0[0][1] = tmp_1[0][1] + tmp_1[0][2];
|
|
||||||
tmp_0[1][1] = tmp_1[1][1] + tmp_1[1][2];
|
|
||||||
tmp_0[2][1] = tmp_1[2][1] + tmp_1[2][2];
|
|
||||||
tmp_0[3][1] = tmp_1[3][1] + tmp_1[3][2];
|
|
||||||
|
|
||||||
tmp_0[0][2] = tmp_1[0][2] - tmp_1[0][1];
|
|
||||||
tmp_0[1][2] = tmp_1[1][2] - tmp_1[1][1];
|
|
||||||
tmp_0[2][2] = tmp_1[2][2] - tmp_1[2][1];
|
|
||||||
tmp_0[3][2] = tmp_1[3][2] - tmp_1[3][1];
|
|
||||||
|
|
||||||
tmp_0[0][3] = tmp_1[0][1] - tmp_1[0][3];
|
|
||||||
tmp_0[1][3] = tmp_1[1][1] - tmp_1[1][3];
|
|
||||||
tmp_0[2][3] = tmp_1[2][1] - tmp_1[2][3];
|
|
||||||
tmp_0[3][3] = tmp_1[3][1] - tmp_1[3][3];
|
|
||||||
|
|
||||||
for (int ii = 0; ii < 4; ++ii) {
|
|
||||||
for (int jj = 0; jj < 4; ++jj) {
|
|
||||||
It[tmp_trns_in_idx(it, ii, jj, ic)] = tmp_0[ii][jj];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
// Do matmul
|
|
||||||
for (int im = 0; im < 4; im++) {
|
|
||||||
simdgroup_barrier(mem_flags::mem_none);
|
|
||||||
Itile.template load<T, 1, 1, BS, 1>(
|
|
||||||
&It[simd_group_id * FA * BS * BS + im * BS * BS + sm * BS + sn]);
|
|
||||||
simdgroup_barrier(mem_flags::mem_none);
|
|
||||||
Wtile.template load<T, 1, 1, BO, 1>(
|
|
||||||
&Wt[simd_group_id * FA * BC * BO + im * BC * BO + sm * BO + sn]);
|
|
||||||
simdgroup_barrier(mem_flags::mem_none);
|
|
||||||
tile_matmad(Otile[im], Itile, Wtile, Otile[im]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Transform and write output
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
for (int im = 0; im < 4; im++) {
|
|
||||||
Otile[im].template store<T, 1, 1, BS, 1>(
|
|
||||||
&It[simd_group_id * FA * BS * BS + im * BS * BS + sm * BS + sn]);
|
|
||||||
}
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
if (lid.z == 0) {
|
|
||||||
const short it = lid.y;
|
|
||||||
const short io = lid.x;
|
|
||||||
|
|
||||||
T tmp_0[4][4];
|
|
||||||
T tmp_1[2][4];
|
|
||||||
T tmp_2[2][2];
|
|
||||||
|
|
||||||
for (int ii = 0; ii < 4; ++ii) {
|
|
||||||
for (int jj = 0; jj < 4; ++jj) {
|
|
||||||
tmp_0[ii][jj] = It[tmp_trns_in_idx(it, ii, jj, io)];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
tmp_1[0][0] = tmp_0[0][0] + tmp_0[1][0] + tmp_0[2][0];
|
|
||||||
tmp_1[0][1] = tmp_0[0][1] + tmp_0[1][1] + tmp_0[2][1];
|
|
||||||
tmp_1[0][2] = tmp_0[0][2] + tmp_0[1][2] + tmp_0[2][2];
|
|
||||||
tmp_1[0][3] = tmp_0[0][3] + tmp_0[1][3] + tmp_0[2][3];
|
|
||||||
|
|
||||||
tmp_1[1][0] = tmp_0[1][0] - tmp_0[2][0] - tmp_0[3][0];
|
|
||||||
tmp_1[1][1] = tmp_0[1][1] - tmp_0[2][1] - tmp_0[3][1];
|
|
||||||
tmp_1[1][2] = tmp_0[1][2] - tmp_0[2][2] - tmp_0[3][2];
|
|
||||||
tmp_1[1][3] = tmp_0[1][3] - tmp_0[2][3] - tmp_0[3][3];
|
|
||||||
|
|
||||||
tmp_2[0][0] = tmp_1[0][0] + tmp_1[0][1] + tmp_1[0][2];
|
|
||||||
tmp_2[1][0] = tmp_1[1][0] + tmp_1[1][1] + tmp_1[1][2];
|
|
||||||
|
|
||||||
tmp_2[0][1] = tmp_1[0][1] - tmp_1[0][2] - tmp_1[0][3];
|
|
||||||
tmp_2[1][1] = tmp_1[1][1] - tmp_1[1][2] - tmp_1[1][3];
|
|
||||||
|
|
||||||
const int oH_i = FN * ((BT * tid.y + it) / tWu);
|
|
||||||
const int oW_i = FN * ((BT * tid.y + it) % tWu);
|
|
||||||
|
|
||||||
// clang-format off
|
|
||||||
output += b_idx * params.out_strides[0] + // N
|
|
||||||
oH_i * params.out_strides[1] + // H
|
|
||||||
oW_i * params.out_strides[2] + // W
|
|
||||||
BO * tid.x; // C
|
|
||||||
|
|
||||||
// clang-format on
|
|
||||||
|
|
||||||
output[0 * params.out_strides[1] + 0 * params.out_strides[2] + io] =
|
|
||||||
tmp_2[0][0];
|
|
||||||
output[0 * params.out_strides[1] + 1 * params.out_strides[2] + io] =
|
|
||||||
tmp_2[0][1];
|
|
||||||
output[1 * params.out_strides[1] + 0 * params.out_strides[2] + io] =
|
|
||||||
tmp_2[1][0];
|
|
||||||
output[1 * params.out_strides[1] + 1 * params.out_strides[2] + io] =
|
|
||||||
tmp_2[1][1];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// clang-format off
|
|
||||||
#define instantiate_winograd_conv_2d_fused(name, itype, f) \
|
|
||||||
template [[host_name("winograd_conv_2d_fused_" #name "_flip" #f)]] \
|
|
||||||
[[kernel]] void winograd_fused<itype, f>( \
|
|
||||||
const device itype* input [[buffer(0)]], \
|
|
||||||
const device itype* weight [[buffer(1)]], \
|
|
||||||
device itype* output [[buffer(2)]], \
|
|
||||||
const constant MLXConvParams<2>& params [[buffer(3)]], \
|
|
||||||
uint3 tid [[threadgroup_position_in_grid]], \
|
|
||||||
uint3 lid [[thread_position_in_threadgroup]], \
|
|
||||||
uint3 tgp_per_grid [[threadgroups_per_grid]], \
|
|
||||||
ushort simd_group_id [[simdgroup_index_in_threadgroup]], \
|
|
||||||
ushort simd_lane_id [[thread_index_in_simdgroup]]);
|
|
||||||
|
|
||||||
#define instantiate_winograd_conv_2d_fused_2(name, itype) \
|
|
||||||
instantiate_winograd_conv_2d_fused(name, itype, 0) \
|
|
||||||
instantiate_winograd_conv_2d_fused(name, itype, 1)
|
|
||||||
|
|
||||||
instantiate_winograd_conv_2d_fused_2(float32, float);
|
|
||||||
instantiate_winograd_conv_2d_fused_2(float16, float16_t);
|
|
||||||
instantiate_winograd_conv_2d_fused_2(bfloat16, bfloat16_t);
|
|
||||||
// clang-format on
|
|
||||||
|
@ -3882,7 +3882,7 @@ array conv_general(
|
|||||||
|
|
||||||
return array(
|
return array(
|
||||||
std::move(out_shape),
|
std::move(out_shape),
|
||||||
out_type,
|
in.dtype(),
|
||||||
std::make_shared<Convolution>(
|
std::make_shared<Convolution>(
|
||||||
to_stream(s),
|
to_stream(s),
|
||||||
stride,
|
stride,
|
||||||
|
@ -341,7 +341,7 @@ class TestConv(mlx_tests.MLXTestCase):
|
|||||||
atol, rtol = 1e-1, 1e-3
|
atol, rtol = 1e-1, 1e-3
|
||||||
else:
|
else:
|
||||||
atol, rtol = 1e-5, 1e-6
|
atol, rtol = 1e-5, 1e-6
|
||||||
self.assertTrue(np.allclose(out_pt, out_mx, atol=atol, rtol=rtol))
|
self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))
|
||||||
|
|
||||||
for dtype in ("float32", "bfloat16"):
|
for dtype in ("float32", "bfloat16"):
|
||||||
for N, C, O in (
|
for N, C, O in (
|
||||||
@ -1042,6 +1042,14 @@ class TestConv(mlx_tests.MLXTestCase):
|
|||||||
self.assertTrue(mx.allclose(expected[0], grads[0]))
|
self.assertTrue(mx.allclose(expected[0], grads[0]))
|
||||||
self.assertTrue(mx.allclose(expected[1], grads[1]))
|
self.assertTrue(mx.allclose(expected[1], grads[1]))
|
||||||
|
|
||||||
|
def test_repeated_conv(self):
|
||||||
|
x = mx.random.normal((1, 3, 3, 320))
|
||||||
|
w = mx.random.normal((320, 3, 3, 320))
|
||||||
|
for i in range(8):
|
||||||
|
y1 = mx.conv2d(x, w, (1, 1), (1, 1), (1, 1), 1)
|
||||||
|
y2 = mx.conv2d(x, w, (1, 1), (1, 1), (1, 1), 1)
|
||||||
|
self.assertTrue(mx.allclose(y1, y2))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
2
setup.py
2
setup.py
@ -173,7 +173,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="mlx",
|
name="mlx",
|
||||||
version=get_version("0.23.0"),
|
version=get_version("0.23.1"),
|
||||||
author="MLX Contributors",
|
author="MLX Contributors",
|
||||||
author_email="mlx@group.apple.com",
|
author_email="mlx@group.apple.com",
|
||||||
description="A framework for machine learning on Apple silicon.",
|
description="A framework for machine learning on Apple silicon.",
|
||||||
|
Loading…
Reference in New Issue
Block a user