mirror of
https://github.com/xai-org/grok-1.git
synced 2024-11-09 05:10:40 +03:00
Library fix and annotation fix
This commit is contained in:
parent
7050ed204b
commit
937cb24a73
@ -86,7 +86,7 @@ def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):
|
||||
fs = list()
|
||||
num_tensors = 0
|
||||
num_replicas = 1
|
||||
data_model_shards = math.prod(mesh_config)
|
||||
data_model_shards = np.prod(mesh_config)
|
||||
if tensor_indices is None:
|
||||
iterator = enumerate(shaped_arrays)
|
||||
else:
|
||||
@ -182,7 +182,7 @@ def restore(
|
||||
state_shapes: Any,
|
||||
mesh,
|
||||
between_hosts_config,
|
||||
params_only,
|
||||
params_only: bool,
|
||||
state_sharding,
|
||||
init_state: Optional[Any] = None,
|
||||
) -> Any:
|
||||
@ -218,4 +218,4 @@ def restore(
|
||||
state = multihost_utils.host_local_array_to_global_array(state, mesh, state_sharding)
|
||||
if params_only:
|
||||
state = state.params
|
||||
return state
|
||||
return state
|
Loading…
Reference in New Issue
Block a user