mirror of
https://github.com/xai-org/grok-1.git
synced 2024-11-24 12:39:54 +03:00
Implemented automatic broadcasting in weight rescaling when the number of model shards is less than the number of experts.
This commit is contained in:
parent
7050ed204b
commit
6fd75b4340
47
model.py
47
model.py
@ -44,6 +44,47 @@ class QuantizedWeight8bit:
|
|||||||
return self.weight.shape
|
return self.weight.shape
|
||||||
|
|
||||||
|
|
||||||
|
def rescale_quantized_weight(weight: jax.Array, scales: jax.Array) -> jax.Array:
    """Multiply a quantized weight by its scales, broadcasting block-wise.

    Ordinary broadcasting only works when every axis of ``scales`` either
    matches the corresponding axis of ``weight`` or has length 1. This helper
    additionally supports an axis where ``scales`` has length ``s > 1`` and
    ``weight`` has length ``w`` with ``w % s == 0``: each entry of ``scales``
    along that axis is applied to ``w // s`` consecutive entries of
    ``weight``. (Per the original note, this situation arises when the total
    number of model shards is less than 8.)

    ``scales`` may have fewer dimensions than ``weight``; it is then aligned
    to the trailing axes of ``weight``, as in standard broadcasting.

    Params:
        weight: quantized weight array.
        scales: coefficients for restoring weights.

    Returns:
        Array with same shape as weight and same dtype as scales.

    Raises:
        ValueError: if ``scales`` has more dimensions than ``weight``, or if
            a mismatched axis of ``weight`` is not evenly divisible by the
            corresponding axis of ``scales``.
    """
    shape_w = tuple(weight.shape)
    shape_s = tuple(scales.shape)

    if len(shape_s) > len(shape_w):
        raise ValueError(
            "scales must not have more dimensions than weight: "
            f"weight shape {shape_w}, scales shape {shape_s}"
        )

    # Align scales to the trailing axes of weight so that a lower-rank scales
    # array behaves as it would under plain `weight * scales` broadcasting.
    shape_s = (1,) * (len(shape_w) - len(shape_s)) + shape_s

    # Build reshaped views axis by axis:
    #   - matched axis (or scales axis of length 1): keep the axis as is;
    #   - otherwise split the weight axis into (length_s, length_w // length_s)
    #     and give scales a trailing unit axis, emulating sharding.
    shape_w_expanded = []
    shape_s_expanded = []
    for length_w, length_s in zip(shape_w, shape_s):
        if length_w != length_s and length_s > 1:
            if length_w % length_s != 0:
                raise ValueError(
                    f"weight axis of length {length_w} is not divisible by "
                    f"scales axis of length {length_s}"
                )
            shape_w_expanded.extend((length_s, length_w // length_s))
            shape_s_expanded.extend((length_s, 1))
        else:
            shape_w_expanded.append(length_w)
            shape_s_expanded.append(length_s)

    w_expanded = weight.reshape(shape_w_expanded)
    s_expanded = scales.reshape(shape_s_expanded)

    # Cast to the scales dtype before multiplying so the result carries it.
    output_expanded = w_expanded.astype(s_expanded.dtype) * s_expanded

    # Collapse the split axes back to the original weight shape.
    return output_expanded.reshape(shape_w)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
tree_util.register_pytree_node(
|
tree_util.register_pytree_node(
|
||||||
QuantizedWeight8bit,
|
QuantizedWeight8bit,
|
||||||
lambda qw: ([qw.weight, qw.scales], ()),
|
lambda qw: ([qw.weight, qw.scales], ()),
|
||||||
@ -330,7 +371,7 @@ class MoELayer(hk.Module):
|
|||||||
check_rep=False,
|
check_rep=False,
|
||||||
)
|
)
|
||||||
def moe_slow_matmul1(input, weight, scales, index, prob):
|
def moe_slow_matmul1(input, weight, scales, index, prob):
|
||||||
weight = weight * scales
|
weight = rescale_quantized_weight(weight, scales)
|
||||||
one_hot_indices = jax.nn.one_hot(index.reshape(-1), 8, axis=0)
|
one_hot_indices = jax.nn.one_hot(index.reshape(-1), 8, axis=0)
|
||||||
all_expert_output = jnp.einsum("mk,bkn->bmn", input, weight)
|
all_expert_output = jnp.einsum("mk,bkn->bmn", input, weight)
|
||||||
output = jnp.einsum("bm,bmn->mn", one_hot_indices, all_expert_output)
|
output = jnp.einsum("bm,bmn->mn", one_hot_indices, all_expert_output)
|
||||||
@ -350,7 +391,7 @@ class MoELayer(hk.Module):
|
|||||||
check_rep=False,
|
check_rep=False,
|
||||||
)
|
)
|
||||||
def moe_slow_matmul2(input, weight, scales, index, prob):
|
def moe_slow_matmul2(input, weight, scales, index, prob):
|
||||||
weight = weight * scales
|
weight = rescale_quantized_weight(weight, scales)
|
||||||
one_hot_indices = jax.nn.one_hot(index.reshape(-1), 8, axis=0)
|
one_hot_indices = jax.nn.one_hot(index.reshape(-1), 8, axis=0)
|
||||||
all_expert_output = jnp.einsum("mk,bkn->bmn", input, weight)
|
all_expert_output = jnp.einsum("mk,bkn->bmn", input, weight)
|
||||||
output = jnp.einsum("bm,bmn->mn", one_hot_indices, all_expert_output)
|
output = jnp.einsum("bm,bmn->mn", one_hot_indices, all_expert_output)
|
||||||
@ -570,7 +611,7 @@ class Linear(hk.Linear):
|
|||||||
check_rep=False,
|
check_rep=False,
|
||||||
)
|
)
|
||||||
def mul(w, s):
|
def mul(w, s):
|
||||||
return w.astype(s.dtype) * s
|
return rescale_quantized_weight(w, s)
|
||||||
|
|
||||||
w = mul(w.weight, w.scales)
|
w = mul(w.weight, w.scales)
|
||||||
out = jnp.dot(inputs, w.astype(fprop_dtype))
|
out = jnp.dot(inputs, w.astype(fprop_dtype))
|
||||||
|
23
test_modelling.py
Normal file
23
test_modelling.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from model import rescale_quantized_weight
|
||||||
|
|
||||||
|
|
||||||
|
def test_rescale():
    """Check block-wise rescaling of a (6, 7) weight by (2, 7) scales."""
    weight = np.arange(42).reshape((6, 7)).astype(np.float16)

    # Each row of scales is applied to
    # three consecutive rows of weight.
    scales = np.arange(2 * 7).reshape((2, 7)).astype(np.int32)

    rescaled_array = rescale_quantized_weight(weight, scales)

    # Output keeps the weight's shape and takes the scales' dtype.
    assert rescaled_array.shape == weight.shape
    assert rescaled_array.dtype == np.int32

    # Spot-check the first column against hand-computed products.
    assert rescaled_array[:, 0].flatten().tolist() == [
        0 * 0,
        0 * 7,
        0 * 14,
        7 * 21,
        7 * 28,
        7 * 35,
    ]

    # Verify every element, not just the first column: repeating each scales
    # row three times yields the per-element multiplier for the whole weight.
    expected = weight.astype(np.int32) * np.repeat(scales, 3, axis=0)
    np.testing.assert_array_equal(rescaled_array, expected)
|
Loading…
Reference in New Issue
Block a user