diff --git a/checkpoint.py b/checkpoint.py index 1c6e878..33711eb 100644 --- a/checkpoint.py +++ b/checkpoint.py @@ -86,7 +86,7 @@ def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None): fs = list() num_tensors = 0 num_replicas = 1 - data_model_shards = math.prod(mesh_config) + data_model_shards = np.prod(mesh_config) if tensor_indices is None: iterator = enumerate(shaped_arrays) else: @@ -182,7 +182,7 @@ def restore( state_shapes: Any, mesh, between_hosts_config, - params_only, + params_only: bool, state_sharding, init_state: Optional[Any] = None, ) -> Any: @@ -218,4 +218,4 @@ def restore( state = multihost_utils.host_local_array_to_global_array(state, mesh, state_sharding) if params_only: state = state.params - return state + return state \ No newline at end of file