mirror of
https://github.com/xai-org/grok-1.git
synced 2024-11-24 04:29:53 +03:00
Implemented automated broadcasting in weight rescale when number of model shards is less than number of experts.
This commit is contained in:
parent
7050ed204b
commit
6fd75b4340
47
model.py
47
model.py
@ -44,6 +44,47 @@ class QuantizedWeight8bit:
|
||||
return self.weight.shape
|
||||
|
||||
|
||||
def rescale_quantized_weight(weight: jax.Array, scales: jax.Array) -> jax.Array:
|
||||
"""
|
||||
Automatically handle broadcasting when total
|
||||
number of model shards is less than 8.
|
||||
|
||||
Params:
|
||||
weight: quantized weight array
|
||||
scales: coefficients for restoring weights.
|
||||
|
||||
Returns:
|
||||
Array with same shape as weight and same dtype as scales.
|
||||
"""
|
||||
|
||||
shape_w = weight.shape
|
||||
shape_s = scales.shape
|
||||
|
||||
# Insert new axis at each mismatched axis.
|
||||
shape_w_expanded = []
|
||||
shape_s_expanded = []
|
||||
|
||||
# Insert length_w if matched.
|
||||
# Otherwise, insert (length_s, length_w // length_s) to emulate sharding
|
||||
for length_w, length_s in zip(shape_w, shape_s):
|
||||
if (length_w != length_s) and (length_s > 1):
|
||||
assert length_w % length_s == 0, (length_w, length_s)
|
||||
shape_w_expanded.extend((length_s, length_w // length_s))
|
||||
shape_s_expanded.extend((length_s, 1))
|
||||
else:
|
||||
shape_w_expanded.extend((length_w,))
|
||||
shape_s_expanded.extend((length_s,))
|
||||
|
||||
# Reshape weight along each mismatched axis.
|
||||
w_expanded = weight.reshape(shape_w_expanded)
|
||||
s_expanded = scales.reshape(shape_s_expanded)
|
||||
|
||||
output_expanded = w_expanded.astype(s_expanded.dtype) * s_expanded
|
||||
output = output_expanded.reshape(shape_w)
|
||||
return output
|
||||
|
||||
|
||||
|
||||
tree_util.register_pytree_node(
|
||||
QuantizedWeight8bit,
|
||||
lambda qw: ([qw.weight, qw.scales], ()),
|
||||
@ -330,7 +371,7 @@ class MoELayer(hk.Module):
|
||||
check_rep=False,
|
||||
)
|
||||
def moe_slow_matmul1(input, weight, scales, index, prob):
|
||||
weight = weight * scales
|
||||
weight = rescale_quantized_weight(weight, scales)
|
||||
one_hot_indices = jax.nn.one_hot(index.reshape(-1), 8, axis=0)
|
||||
all_expert_output = jnp.einsum("mk,bkn->bmn", input, weight)
|
||||
output = jnp.einsum("bm,bmn->mn", one_hot_indices, all_expert_output)
|
||||
@ -350,7 +391,7 @@ class MoELayer(hk.Module):
|
||||
check_rep=False,
|
||||
)
|
||||
def moe_slow_matmul2(input, weight, scales, index, prob):
|
||||
weight = weight * scales
|
||||
weight = rescale_quantized_weight(weight, scales)
|
||||
one_hot_indices = jax.nn.one_hot(index.reshape(-1), 8, axis=0)
|
||||
all_expert_output = jnp.einsum("mk,bkn->bmn", input, weight)
|
||||
output = jnp.einsum("bm,bmn->mn", one_hot_indices, all_expert_output)
|
||||
@ -570,7 +611,7 @@ class Linear(hk.Linear):
|
||||
check_rep=False,
|
||||
)
|
||||
def mul(w, s):
|
||||
return w.astype(s.dtype) * s
|
||||
return rescale_quantized_weight(w, s)
|
||||
|
||||
w = mul(w.weight, w.scales)
|
||||
out = jnp.dot(inputs, w.astype(fprop_dtype))
|
||||
|
23
test_modelling.py
Normal file
23
test_modelling.py
Normal file
@ -0,0 +1,23 @@
|
||||
import numpy as np
|
||||
|
||||
from model import rescale_quantized_weight
|
||||
|
||||
|
||||
def test_rescale():
|
||||
weight = np.arange(42).reshape((6, 7)).astype(np.float16)
|
||||
|
||||
# Each row of scales is applied to
|
||||
# three consecutive rows of weight.
|
||||
scales = np.arange(2 * 7).reshape((2, 7)).astype(np.int32)
|
||||
|
||||
rescaled_array = rescale_quantized_weight(weight, scales)
|
||||
assert rescaled_array.shape == weight.shape
|
||||
assert rescaled_array[:, 0].flatten().tolist() == [
|
||||
0 * 0,
|
||||
0 * 7,
|
||||
0 * 14,
|
||||
7 * 21,
|
||||
7 * 28,
|
||||
7 * 35,
|
||||
]
|
||||
assert rescaled_array.dtype == np.int32
|
Loading…
Reference in New Issue
Block a user