mirror of
https://github.com/xai-org/grok-1.git
synced 2024-11-24 04:29:53 +03:00
Library fix and annotation fix
This commit is contained in:
parent
7050ed204b
commit
937cb24a73
@ -86,7 +86,7 @@ def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):
|
|||||||
fs = list()
|
fs = list()
|
||||||
num_tensors = 0
|
num_tensors = 0
|
||||||
num_replicas = 1
|
num_replicas = 1
|
||||||
data_model_shards = math.prod(mesh_config)
|
data_model_shards = np.prod(mesh_config)
|
||||||
if tensor_indices is None:
|
if tensor_indices is None:
|
||||||
iterator = enumerate(shaped_arrays)
|
iterator = enumerate(shaped_arrays)
|
||||||
else:
|
else:
|
||||||
@ -182,7 +182,7 @@ def restore(
|
|||||||
state_shapes: Any,
|
state_shapes: Any,
|
||||||
mesh,
|
mesh,
|
||||||
between_hosts_config,
|
between_hosts_config,
|
||||||
params_only,
|
params_only: bool,
|
||||||
state_sharding,
|
state_sharding,
|
||||||
init_state: Optional[Any] = None,
|
init_state: Optional[Any] = None,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
@ -218,4 +218,4 @@ def restore(
|
|||||||
state = multihost_utils.host_local_array_to_global_array(state, mesh, state_sharding)
|
state = multihost_utils.host_local_array_to_global_array(state, mesh, state_sharding)
|
||||||
if params_only:
|
if params_only:
|
||||||
state = state.params
|
state = state.params
|
||||||
return state
|
return state
|
Loading…
Reference in New Issue
Block a user