Implemented automated broadcasting in the weight rescale when the number of model shards is less than the number of experts.

This commit is contained in:
Jacob-Junqi Tian 2024-03-21 17:23:45 -04:00
parent 7050ed204b
commit 6fd75b4340
No known key found for this signature in database
GPG Key ID: 0BCBEDEEBC9054CA
2 changed files with 67 additions and 3 deletions

View File

@ -44,6 +44,47 @@ class QuantizedWeight8bit:
return self.weight.shape return self.weight.shape
def rescale_quantized_weight(weight: jax.Array, scales: jax.Array) -> jax.Array:
    """Restore a quantized weight array by multiplying it with its scales.

    Handles the case where an axis of ``scales`` is shorter than the
    matching axis of ``weight`` (e.g. when the total number of model
    shards is less than 8): that weight axis is reshaped into
    ``(length_s, length_w // length_s)`` groups so each scale slice is
    broadcast over a contiguous group of weight slices.

    Params:
        weight: quantized weight array.
        scales: coefficients for restoring weights. Must have the same
            rank as ``weight``; along each axis its length must either
            match ``weight``, be 1, or evenly divide the weight length.

    Returns:
        Array with same shape as weight and same dtype as scales.
    """
    shape_w = weight.shape
    shape_s = scales.shape
    # zip() below would silently drop trailing axes on a rank mismatch,
    # surfacing later as an opaque reshape error; fail loudly here instead.
    assert len(shape_w) == len(shape_s), (shape_w, shape_s)
    # Build expanded shapes, inserting a new axis at each mismatched axis.
    shape_w_expanded = []
    shape_s_expanded = []
    # Insert length_w if matched (or broadcastable with length_s == 1).
    # Otherwise, insert (length_s, length_w // length_s) to emulate sharding.
    for length_w, length_s in zip(shape_w, shape_s):
        if (length_w != length_s) and (length_s > 1):
            assert length_w % length_s == 0, (length_w, length_s)
            shape_w_expanded.extend((length_s, length_w // length_s))
            shape_s_expanded.extend((length_s, 1))
        else:
            shape_w_expanded.extend((length_w,))
            shape_s_expanded.extend((length_s,))
    # Reshape weight along each mismatched axis so plain broadcasting applies.
    w_expanded = weight.reshape(shape_w_expanded)
    s_expanded = scales.reshape(shape_s_expanded)
    # Multiply in the scales' dtype, then collapse back to the weight shape.
    output_expanded = w_expanded.astype(s_expanded.dtype) * s_expanded
    output = output_expanded.reshape(shape_w)
    return output
tree_util.register_pytree_node( tree_util.register_pytree_node(
QuantizedWeight8bit, QuantizedWeight8bit,
lambda qw: ([qw.weight, qw.scales], ()), lambda qw: ([qw.weight, qw.scales], ()),
@ -330,7 +371,7 @@ class MoELayer(hk.Module):
check_rep=False, check_rep=False,
) )
def moe_slow_matmul1(input, weight, scales, index, prob): def moe_slow_matmul1(input, weight, scales, index, prob):
weight = weight * scales weight = rescale_quantized_weight(weight, scales)
one_hot_indices = jax.nn.one_hot(index.reshape(-1), 8, axis=0) one_hot_indices = jax.nn.one_hot(index.reshape(-1), 8, axis=0)
all_expert_output = jnp.einsum("mk,bkn->bmn", input, weight) all_expert_output = jnp.einsum("mk,bkn->bmn", input, weight)
output = jnp.einsum("bm,bmn->mn", one_hot_indices, all_expert_output) output = jnp.einsum("bm,bmn->mn", one_hot_indices, all_expert_output)
@ -350,7 +391,7 @@ class MoELayer(hk.Module):
check_rep=False, check_rep=False,
) )
def moe_slow_matmul2(input, weight, scales, index, prob): def moe_slow_matmul2(input, weight, scales, index, prob):
weight = weight * scales weight = rescale_quantized_weight(weight, scales)
one_hot_indices = jax.nn.one_hot(index.reshape(-1), 8, axis=0) one_hot_indices = jax.nn.one_hot(index.reshape(-1), 8, axis=0)
all_expert_output = jnp.einsum("mk,bkn->bmn", input, weight) all_expert_output = jnp.einsum("mk,bkn->bmn", input, weight)
output = jnp.einsum("bm,bmn->mn", one_hot_indices, all_expert_output) output = jnp.einsum("bm,bmn->mn", one_hot_indices, all_expert_output)
@ -570,7 +611,7 @@ class Linear(hk.Linear):
check_rep=False, check_rep=False,
) )
def mul(w, s): def mul(w, s):
return w.astype(s.dtype) * s return rescale_quantized_weight(w, s)
w = mul(w.weight, w.scales) w = mul(w.weight, w.scales)
out = jnp.dot(inputs, w.astype(fprop_dtype)) out = jnp.dot(inputs, w.astype(fprop_dtype))

23
test_modelling.py Normal file
View File

@ -0,0 +1,23 @@
import numpy as np
from model import rescale_quantized_weight
def test_rescale():
    """Check grouped broadcasting: each scale row rescales a contiguous
    block of three consecutive weight rows."""
    weight = np.arange(42).reshape((6, 7)).astype(np.float16)
    scales = np.arange(2 * 7).reshape((2, 7)).astype(np.int32)
    result = rescale_quantized_weight(weight, scales)
    # Shape is preserved even though scales has fewer rows than weight.
    assert result.shape == weight.shape
    # First column of weight is [0, 7, ..., 35]; first column of scales
    # is [0, 7], applied to the top and bottom halves respectively.
    expected_first_column = [
        0 * 0,
        0 * 7,
        0 * 14,
        7 * 21,
        7 * 28,
        7 * 35,
    ]
    assert result[:, 0].flatten().tolist() == expected_first_column
    # The output takes the dtype of scales, not of weight.
    assert result.dtype == np.int32