grok-1/test_modelling.py

24 lines
601 B
Python
Raw Normal View History

import numpy as np
from model import rescale_quantized_weight
def test_rescale():
weight = np.arange(42).reshape((6, 7)).astype(np.float16)
# Each row of scales is applied to
# three consecutive rows of weight.
scales = np.arange(2 * 7).reshape((2, 7)).astype(np.int32)
rescaled_array = rescale_quantized_weight(weight, scales)
assert rescaled_array.shape == weight.shape
assert rescaled_array[:, 0].flatten().tolist() == [
0 * 0,
0 * 7,
0 * 14,
7 * 21,
7 * 28,
7 * 35,
]
assert rescaled_array.dtype == np.int32