mirror of
https://github.com/xai-org/grok-1.git
synced 2024-11-24 12:39:54 +03:00
24 lines
601 B
Python
24 lines
601 B
Python
|
import numpy as np
|
||
|
|
||
|
from model import rescale_quantized_weight
|
||
|
|
||
|
|
||
|
def test_rescale():
|
||
|
weight = np.arange(42).reshape((6, 7)).astype(np.float16)
|
||
|
|
||
|
# Each row of scales is applied to
|
||
|
# three consecutive rows of weight.
|
||
|
scales = np.arange(2 * 7).reshape((2, 7)).astype(np.int32)
|
||
|
|
||
|
rescaled_array = rescale_quantized_weight(weight, scales)
|
||
|
assert rescaled_array.shape == weight.shape
|
||
|
assert rescaled_array[:, 0].flatten().tolist() == [
|
||
|
0 * 0,
|
||
|
0 * 7,
|
||
|
0 * 14,
|
||
|
7 * 21,
|
||
|
7 * 28,
|
||
|
7 * 35,
|
||
|
]
|
||
|
assert rescaled_array.dtype == np.int32
|