Library fix and annotation fix

This commit is contained in:
Neelesh Verma 2024-03-19 20:22:31 -04:00
parent 7050ed204b
commit 937cb24a73
No known key found for this signature in database
GPG Key ID: 89578C08FB09F0AB

View File

@ -86,7 +86,7 @@ def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):
fs = list() fs = list()
num_tensors = 0 num_tensors = 0
num_replicas = 1 num_replicas = 1
data_model_shards = math.prod(mesh_config) data_model_shards = np.prod(mesh_config)
if tensor_indices is None: if tensor_indices is None:
iterator = enumerate(shaped_arrays) iterator = enumerate(shaped_arrays)
else: else:
@ -182,7 +182,7 @@ def restore(
state_shapes: Any, state_shapes: Any,
mesh, mesh,
between_hosts_config, between_hosts_config,
params_only, params_only: bool,
state_sharding, state_sharding,
init_state: Optional[Any] = None, init_state: Optional[Any] = None,
) -> Any: ) -> Any:
@ -218,4 +218,4 @@ def restore(
state = multihost_utils.host_local_array_to_global_array(state, mesh, state_sharding) state = multihost_utils.host_local_array_to_global_array(state, mesh, state_sharding)
if params_only: if params_only:
state = state.params state = state.params
return state return state