From 7a19c9eb9ca7a4ffc0f304a870e78c92ef59eb69 Mon Sep 17 00:00:00 2001 From: Yahweh Rapha Bradford <166758746+El-o-heka@users.noreply.github.com> Date: Tue, 7 May 2024 01:51:50 -0400 Subject: [PATCH] Update checkpoint.py --- checkpoint.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/checkpoint.py b/checkpoint.py index 1c6e878..aa785a1 100644 --- a/checkpoint.py +++ b/checkpoint.py @@ -1,16 +1,4 @@ -# Copyright 2024 X.AI Corp. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. + from __future__ import annotations @@ -213,7 +201,7 @@ def restore( state_sharding = jax.tree_util.tree_map( lambda x: jax.sharding.PartitionSpec() if x is None else x, state_sharding, - is_leaf=lambda x: x is None, + is_leaf=lambda is None, ) state = multihost_utils.host_local_array_to_global_array(state, mesh, state_sharding) if params_only: