Mirror of https://github.com/ml-explore/mlx.git, synced 2025-12-16 01:49:05 +08:00.
Compare commits: 012fb220a1...main (8 commits)

Commits:
- c2764d1073
- 093a62d2ed
- 1b591ec736
- 47d2505ea9
- bedefed784
- ccaaa7d6df
- f3e5ca5414
- 81dfe5f137
.github/actions/setup-linux/action.yml (20 lines changed):

@@ -10,23 +10,29 @@ inputs:
     description: 'Version of python to set up'
     required: false
     default: '3.10'
+  use-ccache:
+    description: 'Whether to enable ccache'
+    required: false
+    default: 'true'
 
 runs:
   using: "composite"
   steps:
-    - name: Use ccache
-      if: ${{ runner.arch == 'x86_64' }}
-      uses: hendrikmuhs/ccache-action@v1.2
-      with:
-        key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}-py${{ inputs.python-version }}
-        max-size: 1GB
-
     - name: Install common dependencies
       shell: bash
       run: |
         sudo apt-get update
         sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev zip
 
+    - name: Use ccache
+      if: ${{ inputs.use-ccache == 'true' }}
+      uses: hendrikmuhs/ccache-action@v1.2
+      with:
+        key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}
+        max-size: 1GB
+        # ccache-action bug: running "apt-get update" fails on large arm runner.
+        update-package-index: false
+
     - uses: actions/setup-python@v6
       with:
         python-version: ${{ inputs.python-version }}
.github/workflows/nightly.yml (6 lines changed):

@@ -23,14 +23,14 @@ jobs:
           build-backend: ${{ matrix.python-version == '3.10' }}
           arch: "x86_64"
       - name: Upload mlx artifacts
-        uses: actions/upload-artifact@v5
+        uses: actions/upload-artifact@v6
         with:
           name: linux-wheels-${{ matrix.python_version }}
           path: wheelhouse/mlx-*.whl
           retention-days: 7
       - name: Upload mlx-cpu artifacts
         if: matrix.python_version == '3.10'
-        uses: actions/upload-artifact@v5
+        uses: actions/upload-artifact@v6
         with:
           name: mlx-cpu
           path: wheelhouse/mlx_cpu-*.whl
@@ -89,7 +89,7 @@ jobs:
         with:
           toolkit: 'cuda-12.9'
       - name: Upload artifacts
-        uses: actions/upload-artifact@v5
+        uses: actions/upload-artifact@v6
         with:
           name: mlx-cuda
           path: wheelhouse/mlx_cuda-*.whl
.github/workflows/release.yml (22 lines changed):

@@ -57,19 +57,20 @@ jobs:
       - uses: ./.github/actions/setup-linux
         with:
           python-version: ${{ matrix.python_version }}
+          use-ccache: false
       - uses: ./.github/actions/build-linux-release
         with:
           build-backend: ${{ matrix.python-version == '3.10' }}
           arch: ${{ matrix.arch }}
       - name: Upload MLX artifacts
-        uses: actions/upload-artifact@v5
+        uses: actions/upload-artifact@v6
         with:
           overwrite: true
           name: linux-wheels-${{ matrix.python_version }}-${{ matrix.arch }}
           path: wheelhouse/mlx-*.whl
       - name: Upload CPU artifacts
         if: matrix.python_version == '3.10'
-        uses: actions/upload-artifact@v5
+        uses: actions/upload-artifact@v6
         with:
           overwrite: true
           name: mlx-cpu-${{ matrix.arch }}
@@ -113,14 +114,14 @@ jobs:
           macos-target: 15.0
           build-backend: ${{ matrix.python-version == '3.10' }}
       - name: Upload MLX artifacts
-        uses: actions/upload-artifact@v5
+        uses: actions/upload-artifact@v6
         with:
           overwrite: true
           name: mac-wheels-${{ matrix.python-version }}
           path: dist/mlx-*.whl
       - name: Upload Metal artifacts
         if: matrix.python-version == '3.10'
-        uses: actions/upload-artifact@v5
+        uses: actions/upload-artifact@v6
         with:
           overwrite: true
           name: mlx-metal
@@ -141,12 +142,13 @@ jobs:
       - uses: ./.github/actions/setup-linux
         with:
           toolkit: ${{ matrix.toolkit }}
+          use-ccache: false
       - name: Build Python package
         uses: ./.github/actions/build-cuda-release
         with:
           arch: ${{ matrix.arch }}
       - name: Upload artifacts
-        uses: actions/upload-artifact@v5
+        uses: actions/upload-artifact@v6
         with:
           overwrite: true
           name: mlx-cuda
@@ -162,12 +164,12 @@ jobs:
       name: pypi
       url: https://pypi.org/p/mlx
     steps:
-      - uses: actions/download-artifact@v6
+      - uses: actions/download-artifact@v7
         with:
           pattern: linux-wheels-*
           merge-multiple: true
           path: dist
-      - uses: actions/download-artifact@v6
+      - uses: actions/download-artifact@v7
         with:
           pattern: mac-wheels-*
           merge-multiple: true
@@ -189,7 +191,7 @@ jobs:
       name: pypi
       url: https://pypi.org/p/mlx-cuda
     steps:
-      - uses: actions/download-artifact@v6
+      - uses: actions/download-artifact@v7
         with:
           name: mlx-cuda
           path: dist
@@ -210,7 +212,7 @@ jobs:
       name: pypi
       url: https://pypi.org/p/mlx-cpu
     steps:
-      - uses: actions/download-artifact@v6
+      - uses: actions/download-artifact@v7
         with:
           pattern: mlx-cpu-*
           merge-multiple: true
@@ -232,7 +234,7 @@ jobs:
       name: pypi
       url: https://pypi.org/p/mlx-metal
     steps:
-      - uses: actions/download-artifact@v6
+      - uses: actions/download-artifact@v7
         with:
           name: mlx-metal
           path: dist
@@ -291,6 +291,17 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
        num_keys,
        kshape = keys.shape(),
        kstrides = keys.strides()]() mutable {
+    auto copy_remaining = [&](char* cptr, size_t loc, uint32_t v) {
+      if (4 * loc + 4 <= bytes_per_key) {
+        reinterpret_cast<uint32_t*>(cptr)[loc] = v;
+      } else {
+        std::copy(
+            reinterpret_cast<char*>(&v),
+            reinterpret_cast<char*>(&v) + bytes_per_key - 4 * loc,
+            cptr + 4 * loc);
+      }
+    };
+
     size_t out_skip = (bytes_per_key + 4 - 1) / 4;
     auto half_size = out_skip / 2;
     bool even = out_skip % 2 == 0;
@@ -310,18 +321,12 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
       if (count.first < half_size) {
         auto rb = random::threefry2x32_hash(key, count);
         ptr[count.first++] = rb.first;
-        if (bytes_per_key % 4 > 0) {
-          std::copy(
-              reinterpret_cast<char*>(&rb.second),
-              reinterpret_cast<char*>(&rb.second) + bytes_per_key % 4,
-              cptr + 4 * count.second);
-        } else {
-          ptr[count.second] = rb.second;
-        }
+        copy_remaining(cptr, count.second, rb.second);
       }
       if (!even) {
         count.second = 0;
-        ptr[half_size] = random::threefry2x32_hash(key, count).first;
+        copy_remaining(
+            cptr, half_size, random::threefry2x32_hash(key, count).first);
       }
     }
   });
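The new copy_remaining helper centralizes the tail handling when bytes_per_key is not a multiple of four: it is now applied both to the second word of each threefry pair and to the odd final word, which previously was stored with an unconditional full-word write (ptr[half_size] = ...). A minimal Python model of the packing logic (illustrative only, not MLX code; names mirror the C++):

def copy_remaining(out: bytearray, loc: int, v: int, bytes_per_key: int):
    # `loc` counts 4-byte slots; write only the bytes that fall inside the key.
    word = v.to_bytes(4, "little")
    if 4 * loc + 4 <= bytes_per_key:
        out[4 * loc : 4 * loc + 4] = word        # the whole word fits
    else:
        n = bytes_per_key - 4 * loc              # 1-3 trailing bytes
        out[4 * loc : 4 * loc + n] = word[:n]

out = bytearray(7)                      # bytes_per_key = 7 -> slot 1 is partial
copy_remaining(out, 0, 0x01020304, 7)   # writes 4 bytes
copy_remaining(out, 1, 0x05060708, 7)   # writes only the 3 remaining bytes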
@@ -348,6 +348,9 @@ std::pair<std::string, bool> subgraph_to_key(cudaGraph_t graph) {
       key += subkey;
       break;
     }
+    case cudaGraphNodeTypeHost:
+      key += "H";
+      break;
     case cudaGraphNodeTypeMemset:
       key += "M";
       break;
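For context: subgraph_to_key serializes a CUDA graph's node types into a string key, and host-callback nodes previously had no case of their own. A toy sketch of the keying idea (only "M" and "H" come from the diff; the rest is illustrative):

def subgraph_key(node_types):
    # One code per node; distinct codes keep unlike graphs from colliding.
    codes = {"memset": "M", "host": "H"}
    return "".join(codes.get(t, "?") for t in node_types)

# With "H" added, a graph containing a host node keys differently:
assert subgraph_key(["memset", "host"]) != subgraph_key(["memset", "memset"])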
@@ -347,7 +347,7 @@ template <
       MMAFrag_mask_t::load_safe(
           mfrag,
           mask,
-          int(mask_params->M_strides[2]),
+          int64_t(mask_params->M_strides[2]),
           Int<1>{},
           params->qL,
           params->kL,
|||||||
@@ -346,7 +346,7 @@ template <
|
|||||||
MSubTile mfrag;
|
MSubTile mfrag;
|
||||||
mfrag.load_safe(
|
mfrag.load_safe(
|
||||||
mask,
|
mask,
|
||||||
int(mask_params->M_strides[2]),
|
int64_t(mask_params->M_strides[2]),
|
||||||
Int<1>{},
|
Int<1>{},
|
||||||
params->qL,
|
params->qL,
|
||||||
params->kL,
|
params->kL,
|
||||||
|
|||||||
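Both hunks above widen the mask stride passed to load_safe from int to int64_t. A quick illustration of the 32-bit truncation this avoids; the stride value here is invented, but any stride above 2^31 - 1 wraps:

import ctypes

stride = 3_000_000_000                  # stride of a hypothetical very large mask
print(ctypes.c_int32(stride).value)     # -1294967296: wrapped, addressing breaks
print(ctypes.c_int64(stride).value)     # 3000000000: preserved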
@@ -105,17 +105,20 @@ struct BaseMMAFrag<T, 8, 8> {
       LimY lim_y,
       OffX off_x = Int<0>{},
       OffY off_y = Int<0>{}) {
+    src += off_x * str_x + off_y * str_y;
     STEEL_PRAGMA_UNROLL
     for (short i = 0; i < kElemRows; i++) {
       STEEL_PRAGMA_UNROLL
       for (short j = 0; j < kElemCols; j++) {
         if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
-          dst[i * kElemCols + j] =
-              static_cast<T>(src[(off_x + i) * str_x + (off_y + j) * str_y]);
+          dst[i * kElemCols + j] = static_cast<T>(src[0]);
         } else {
           dst[i * kElemCols + j] = T(0);
         }
+        src += str_y;
       }
+      src -= kElemCols * str_y;
+      src += str_x;
     }
   }
 
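The guarded-load rewrite in BaseMMAFrag hoists the base offset out of the loops and then advances the source pointer one stride at a time (str_y per column, rewind plus str_x per row) instead of recomputing (off_x + i) * str_x + (off_y + j) * str_y for every element. A Python model showing the two traversals visit the same elements (the bounds check and zero-fill are omitted):

def load_indexed(src, off_x, off_y, str_x, str_y, rows, cols):
    # original form: full offset recomputed for every element
    return [src[(off_x + i) * str_x + (off_y + j) * str_y]
            for i in range(rows) for j in range(cols)]

def load_stepped(src, off_x, off_y, str_x, str_y, rows, cols):
    # rewritten form: base offset hoisted, then incremental steps
    out, p = [], off_x * str_x + off_y * str_y
    for _ in range(rows):
        for _ in range(cols):
            out.append(src[p])
            p += str_y              # next column
        p -= cols * str_y           # rewind the row
        p += str_x                  # next row
    return out

src = list(range(100))
assert load_indexed(src, 1, 2, 10, 1, 2, 3) == load_stepped(src, 1, 2, 10, 1, 2, 3)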
@@ -880,6 +880,11 @@ std::vector<array> ScaledDotProductAttention::vjp(
 
   std::vector<array> returned_vjps;
   for (int arg : argnums) {
+    if (arg >= 3) {
+      throw std::invalid_argument(
+          "[scale_dot_product_attention] Does not support VJP with respect "
+          "to mask or attention sinks.");
+    }
     returned_vjps.push_back(std::move(vjps[arg]));
   }
   return returned_vjps;
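The new guard rejects VJP requests for argument indices of three or higher, which per the error text correspond to the mask and attention sinks (the first three presumably being queries, keys, and values). The selection logic, rendered in plain Python for clarity only:

def select_vjps(vjps, argnums):
    out = []
    for arg in argnums:
        if arg >= 3:                    # mask / attention sinks: unsupported
            raise ValueError(
                "[scale_dot_product_attention] Does not support VJP "
                "with respect to mask or attention sinks.")
        out.append(vjps[arg])
    return out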
@@ -124,37 +124,53 @@ auto py_value_and_grad(
 
   // Collect the arrays
   std::vector<mx::array> arrays;
+  std::vector<nb::object> array_objects;
+  auto flatten_with_objects = [&arrays, &array_objects](
+                                  auto tree, bool strict) {
+    tree_visit(tree, [&](nb::handle obj) {
+      if (nb::isinstance<mx::array>(obj)) {
+        arrays.push_back(nb::cast<mx::array>(obj));
+        array_objects.push_back(nb::borrow<nb::object>(obj));
+      } else if (strict) {
+        throw std::invalid_argument(
+            "[tree_flatten] The argument should contain only arrays");
+      }
+    });
+  };
+
   std::vector<int> counts(1, 0);
   std::vector<int> gradient_indices;
   for (int i = 0, j = 0; i < args.size(); ++i) {
     bool needs_grad = (j < argnums.size() && argnums[j] == i);
-    auto argsi = tree_flatten(args[i], /* strict = */ needs_grad);
+    auto pre_size = arrays.size();
+    flatten_with_objects(args[i], /* strict = */ needs_grad);
     if (needs_grad) {
       auto old_size = gradient_indices.size();
-      gradient_indices.resize(old_size + argsi.size());
+      auto delta_size = arrays.size() - pre_size;
+      gradient_indices.resize(old_size + delta_size);
       std::iota(
           gradient_indices.begin() + old_size,
           gradient_indices.end(),
-          arrays.size());
+          pre_size);
       j++;
-      counts.push_back(argsi.size());
+      counts.push_back(delta_size);
     }
-    arrays.insert(arrays.end(), argsi.begin(), argsi.end());
   }
   for (auto item : kwargs) {
     bool needs_grad =
         (argnames.find(nb::cast<std::string>(item.first)) != argnames.end());
-    auto argsk = tree_flatten(item.second, /* strict = */ needs_grad);
+    auto pre_size = arrays.size();
+    flatten_with_objects(item.second, /* strict = */ needs_grad);
     if (needs_grad) {
       auto old_size = gradient_indices.size();
-      gradient_indices.resize(old_size + argsk.size());
+      auto delta_size = arrays.size() - pre_size;
+      gradient_indices.resize(old_size + delta_size);
       std::iota(
           gradient_indices.begin() + old_size,
           gradient_indices.end(),
-          arrays.size());
+          pre_size);
-      counts.push_back(argsk.size());
+      counts.push_back(delta_size);
     }
-    arrays.insert(arrays.end(), argsk.begin(), argsk.end());
   }
   std::partial_sum(counts.cbegin(), counts.cend(), counts.begin());
 
@@ -163,7 +179,7 @@ auto py_value_and_grad(
   nb::object py_value_out;
   auto value_and_grads = mx::value_and_grad(
       [&fun,
-       &arrays,
+       &array_objects,
        &args,
        &kwargs,
        &py_value_out,
@@ -183,8 +199,9 @@ auto py_value_and_grad(
         tree_visit_update(tree, [&](nb::handle node) {
           auto replace_arr = nb::cast<mx::array>(node);
           if (replace_arr.id() == a[index].id()) {
-            return nb::cast(arrays[index++]);
+            return array_objects[index++];
           } else {
+            index++;
             return nb::cast(replace_arr);
           }
         });
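Why arrays.size() becomes pre_size in the std::iota calls: flatten_with_objects appends to arrays during the visit, so by the time the indices are generated the vector already contains this argument's entries, and their first index is the size recorded before the visit. (The update visitor in the last hunk now also advances index on non-matching nodes so positions stay aligned with array_objects.) A small plain-Python model of the bookkeeping, where flatten stands in for tree_visit; this is not the nanobind code:

from itertools import accumulate

def flatten(tree):
    # stand-in for tree_visit over nested lists/tuples
    if isinstance(tree, (list, tuple)):
        return [leaf for t in tree for leaf in flatten(t)]
    return [tree]

def collect(args, argnums):
    arrays, gradient_indices, counts = [], [], [0]
    j = 0
    for i, tree in enumerate(args):
        needs_grad = j < len(argnums) and argnums[j] == i
        pre_size = len(arrays)
        arrays += flatten(tree)          # appended during the visit
        if needs_grad:
            delta = len(arrays) - pre_size
            # iota starts at pre_size: the entries are already in `arrays`
            gradient_indices += range(pre_size, pre_size + delta)
            counts.append(delta)
            j += 1
    return arrays, gradient_indices, list(accumulate(counts))  # partial_sum

arrays, idx, counts = collect([[1.0, 2.0], 3.0, [4.0]], argnums=[0, 2])
assert idx == [0, 1, 3] and counts == [0, 2, 3]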
@@ -780,9 +780,21 @@ class TestAutograd(mlx_tests.MLXTestCase):
             return arrs[0]
 
         arrs = [mx.array(1.0)]
-        init_id = id(arrs[0])
+        arr = arrs[0]
         mx.grad(fun)(arrs)
-        self.assertEqual(init_id, id(arrs[0]))
+        self.assertEqual(id(arr), id(arrs[0]))
+
+        def fun(arrs):
+            arrs[1] = sum(arrs)
+            return arrs[1]
+
+        arrs = [mx.array(1.0), mx.array(1.0), mx.array(1.0)]
+        a_0, a_1, a_2 = arrs
+
+        mx.grad(fun)(arrs)
+        self.assertEqual(id(a_0), id(arrs[0]))
+        self.assertNotEqual(id(a_1), id(arrs[1]))
+        self.assertEqual(id(a_2), id(arrs[2]))
 
     def test_grad_with_inplace_update(self):
         def loss_fn(model):
@@ -744,7 +744,6 @@ class TestVmap(mlx_tests.MLXTestCase):
             return Vector([t[0] + 10, t[1] * 10])
 
         x = State(mx.array(1), mx.array(2))
-        print(f"{transform(x)=}")
 
         vmap_transform = mx.vmap(transform)
         vmap_transform_tuple = mx.vmap(transform_tuple)