diff --git a/checkpoint.py b/checkpoint.py index 1c6e878..aa785a1 100644 --- a/checkpoint.py +++ b/checkpoint.py @@ -1,16 +1,4 @@ -# Copyright 2024 X.AI Corp. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. + from __future__ import annotations @@ -213,7 +201,7 @@ def restore( state_sharding = jax.tree_util.tree_map( lambda x: jax.sharding.PartitionSpec() if x is None else x, state_sharding, - is_leaf=lambda x: x is None, + is_leaf=lambda is None, ) state = multihost_utils.host_local_array_to_global_array(state, mesh, state_sharding) if params_only: