Commit 36c6fe1

Authored by Kyle Sayers

[Tests] Mock Observers, Static Lifecycle Tests (#482)

* refactor
* reduce diff
* reduce diff
* increase num of required observed dims
* remove attention head
* add tests
* remove attn head
* fix quality
* simplify

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 5121d47 commit 36c6fe1

File tree

2 files changed: +509 −0 lines changed
tests/mock_observer.py

Lines changed: 158 additions & 0 deletions
```python
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple
from weakref import ref

import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
from compressed_tensors.quantization.utils import (
    calculate_qparams,
    generate_gparam,
    strategy_cdiv,
)


class MockMinMaxObserver(torch.nn.Module):
    def __init__(self, base_name: str, args: QuantizationArgs, module: torch.nn.Module):
        super().__init__()
        self.parent = ref(module)
        self.base_name = base_name
        self.args = args

        # used for testing
        self.min_vals = None
        self.max_vals = None

    def get_min_max(self, observed: torch.Tensor):
        min_vals = torch.amin(observed, dim=(0, -1))
        max_vals = torch.amax(observed, dim=(0, -1))

        return min_vals, max_vals

    def forward(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        observed = flatten_for_quantization(observed, self.base_name, self.args)

        self.min_vals, self.max_vals = self.get_min_max(observed)

        scales, zero_points = calculate_qparams(
            min_vals=self.min_vals,
            max_vals=self.max_vals,
            quantization_args=self.args,
            global_scale=getattr(self.parent(), f"{self.base_name}_global_scale", None),
        )

        return scales, zero_points

    def get_global_scale(self, observed: torch.Tensor):
        observed = observed.reshape((1, 1, -1))  # per tensor reshape
        min_vals, max_vals = self.get_min_max(observed)
        global_scale = generate_gparam(min_vals, max_vals)

        return global_scale


def flatten_for_quantization(
    value: torch.Tensor, base_name: str, args: QuantizationArgs
) -> torch.Tensor:
    if base_name == "weight":
        return flatten_weight_for_quantization(value, args)
    elif base_name in ("input", "output"):
        return flatten_activation_for_quantization(value, args)
    elif base_name in ("q", "k", "v"):
        return flatten_attention_for_quantization(value, args)
    else:
        raise ValueError(f"Unknown quantization base name: {base_name}")


def flatten_weight_for_quantization(value: torch.Tensor, args: QuantizationArgs):
    if args.strategy == QuantizationStrategy.TENSOR:
        # (1, 1, num_weight_elems)
        return value.reshape((1, 1, -1))

    if args.strategy == QuantizationStrategy.TOKEN:
        raise ValueError("Token quantization cannot be applied to weights")

    if args.strategy == QuantizationStrategy.CHANNEL:
        # (1, num_rows, 1, num_cols)
        return value.unsqueeze(-2).unsqueeze(0)

    if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
        # (1, num_rows, num_groups, group_size)
        return value.unflatten(-1, (-1, args.group_size)).unsqueeze(0)

    if args.strategy == QuantizationStrategy.BLOCK:
        # (1, num_block_rows, num_block_cols, block_width * block_height)
        block_height, block_width = args.block_structure
        num_rows, num_cols = value.shape
        num_block_rows = strategy_cdiv(num_rows, block_height, args.strategy)
        num_block_cols = strategy_cdiv(num_cols, block_width, args.strategy)
        return (
            value.reshape(
                num_block_rows,
                block_height,
                num_block_cols,
                block_width,
            )
            .transpose(1, 2)
            .flatten(-2, -1)
            .unsqueeze(0)
        )

    assert False, f"Unknown strategy {args.strategy}"


def flatten_activation_for_quantization(value: torch.Tensor, args: QuantizationArgs):
    if args.strategy == QuantizationStrategy.TENSOR:
        # (batch_size * seq_len, 1, hidden_dim)
        return value.reshape((-1, 1, value.size(-1)))

    if args.strategy == QuantizationStrategy.TOKEN:
        # (batch_size, seq_len, hidden_dim)
        # warning: token quantization uses `compute_dynamic_scales_and_zp`
        return value.flatten(2, -1)

    if args.strategy == QuantizationStrategy.CHANNEL:
        raise ValueError("Channel quantization cannot be applied to activations")

    if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
        # (batch_size * seq_len, num_groups, group_size)
        # warning: group activation quantization uses `compute_dynamic_scales_and_zp`
        return value.flatten(0, 1).unflatten(-1, (-1, args.group_size))

    if args.strategy == QuantizationStrategy.BLOCK:
        raise ValueError("Block quantization cannot be applied to activations")

    assert False, f"Unknown strategy {args.strategy}"


def flatten_attention_for_quantization(value: torch.Tensor, args: QuantizationArgs):
    if args.strategy == QuantizationStrategy.TENSOR:
        # (batch_size, seq_len, num_heads, head_dim)
        # -> (batch_size * seq_len, 1, num_heads * head_dim)
        return value.flatten(0, 1).flatten(-2, -1).unsqueeze(-2)

    if args.strategy == QuantizationStrategy.TOKEN:
        raise ValueError("Token quantization cannot be applied to attention")

    if args.strategy == QuantizationStrategy.CHANNEL:
        raise ValueError("Channel quantization cannot be applied to attention")

    if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
        raise ValueError("Group quantization cannot be applied to attention")

    if args.strategy == QuantizationStrategy.BLOCK:
        raise ValueError("Block quantization cannot be applied to attention")

    assert False, f"Unknown strategy {args.strategy}"
```
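To make the flattening conventions concrete: every helper reshapes its input so that each quantization group occupies the last dimension, which is why `get_min_max` can always reduce over `dim=(0, -1)`. Below is a small torch-only sketch (not part of the commit; the 4×8 toy weight and block sizes are illustrative assumptions) that reproduces the weight reshapes and checks the resulting min/max shapes.

```python
import torch

# Hypothetical toy weight: 4 rows x 8 cols, values chosen only to exercise shapes
value = torch.arange(32, dtype=torch.float32).reshape(4, 8)

# TENSOR: one group over all elements -> (1, 1, 32), a single min/max pair
per_tensor = value.reshape((1, 1, -1))
assert torch.amin(per_tensor, dim=(0, -1)).shape == (1,)

# CHANNEL: one group per output row -> (1, 4, 1, 8), one pair per row
per_channel = value.unsqueeze(-2).unsqueeze(0)
assert torch.amin(per_channel, dim=(0, -1)).shape == (4, 1)

# GROUP (group_size=4): -> (1, 4, 2, 4), one pair per (row, group)
grouped = value.unflatten(-1, (-1, 4)).unsqueeze(0)
assert torch.amin(grouped, dim=(0, -1)).shape == (4, 2)

# BLOCK (2x4 blocks): -> (1, 2, 2, 8), one pair per block
blocked = value.reshape(2, 2, 2, 4).transpose(1, 2).flatten(-2, -1).unsqueeze(0)
assert torch.amin(blocked, dim=(0, -1)).shape == (2, 2)
```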

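For context, here is a minimal sketch of how a test might drive the mock observer against a Linear layer's weight. It assumes `QuantizationArgs` accepts `num_bits`, `symmetric`, `strategy`, and `group_size` keyword arguments (field names taken from the compressed-tensors API as I understand it, not from this commit), so treat it as a usage outline rather than a definitive test.

```python
import torch
from compressed_tensors.quantization import QuantizationArgs

from tests.mock_observer import MockMinMaxObserver

# Assumed QuantizationArgs fields; verify against the installed
# compressed-tensors version if these differ.
args = QuantizationArgs(num_bits=4, symmetric=True, strategy="group", group_size=64)

linear = torch.nn.Linear(256, 128, bias=False)
observer = MockMinMaxObserver("weight", args, linear)

scales, zero_points = observer(linear.weight)
# With a (128, 256) weight and group_size=64, the flattened view is
# (1, 128, 4, 64), so min/max (and hence scales) come out per (row, group)
print(scales.shape, zero_points.shape)
```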