# group-wbl/.venv/lib/python3.13/site-packages/onnxruntime/quantization/static_quantize_runner.py

import argparse
import json
import os

import numpy as np
import onnx

import onnxruntime
from onnxruntime.quantization import QuantFormat, QuantType, StaticQuantConfig, quantize
from onnxruntime.quantization.calibrate import CalibrationDataReader, CalibrationMethod


class OnnxModelCalibrationDataReader(CalibrationDataReader):
    def __init__(self, model_path):
        self.model_dir = os.path.dirname(model_path)
        # Calibration inputs follow the ONNX test-data convention: sibling
        # directories named test_data_set_* holding one .pb file per model input.
        data_dirs = [
            os.path.join(self.model_dir, a) for a in os.listdir(self.model_dir) if a.startswith("test_data_set_")
        ]
        model_inputs = onnxruntime.InferenceSession(model_path).get_inputs()
        name2tensors = []
        for data_dir in data_dirs:
            name2tensor = {}
            data_paths = [os.path.join(data_dir, a) for a in sorted(os.listdir(data_dir))]
            data_ndarrays = [self.read_onnx_pb_data(data_path) for data_path in data_paths]
            for model_input, data_ndarray in zip(model_inputs, data_ndarrays, strict=False):
                name2tensor[model_input.name] = data_ndarray
            name2tensors.append(name2tensor)
        assert len(name2tensors) == len(data_dirs)
        assert len(name2tensors[0]) == len(model_inputs)

        self.calibration_data = iter(name2tensors)

    def get_next(self) -> dict:
        """Generate the input-feed dict for an ONNX Runtime InferenceSession run; returns None when exhausted."""
        return next(self.calibration_data, None)

    def read_onnx_pb_data(self, file_pb):
        # Deserialize a TensorProto protobuf file into a numpy array.
        tensor = onnx.TensorProto()
        with open(file_pb, "rb") as f:
            tensor.ParseFromString(f.read())
        return onnx.numpy_helper.to_array(tensor)
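

# A minimal sketch of driving the reader by hand (hypothetical paths), useful for
# checking that calibration data is picked up as expected. It assumes the
# model-zoo-style layout the class above scans for:
#
#   model_dir/
#       model.onnx
#       test_data_set_0/
#           input_0.pb
#           input_1.pb
#       test_data_set_1/
#           input_0.pb
#           ...
#
#   reader = OnnxModelCalibrationDataReader("model_dir/model.onnx")
#   while (feeds := reader.get_next()) is not None:
#       print({name: arr.shape for name, arr in feeds.items()})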


def parse_arguments():
    parser = argparse.ArgumentParser(description="The arguments for static quantization")
    parser.add_argument("-i", "--input_model_path", required=True, help="Path to the input onnx model")
    parser.add_argument(
        "-o", "--output_quantized_model_path", required=True, help="Path to the output quantized onnx model"
    )
    parser.add_argument(
        "--activation_type",
        choices=["qint8", "quint8", "qint16", "quint16", "qint4", "quint4", "qfloat8e4m3fn"],
        default="quint8",
        help="Activation quantization type used",
    )
    parser.add_argument(
        "--weight_type",
        choices=["qint8", "quint8", "qint16", "quint16", "qint4", "quint4", "qfloat8e4m3fn"],
        default="qint8",
        help="Weight quantization type used",
    )
    parser.add_argument("--enable_subgraph", action="store_true", help="If set, subgraphs will be quantized.")
    parser.add_argument(
        "--force_quantize_no_input_check",
        action="store_true",
        help="By default, latent operators such as MaxPool and Transpose are not quantized if their input is"
        " not already quantized. If set, such operators are forced to quantize their input and thus produce"
        " quantized output. This behavior can still be disabled per node via --nodes_to_exclude.",
    )
    parser.add_argument(
        "--matmul_const_b_only",
        action="store_true",
        help="If set, only MatMul nodes with a constant B input will be quantized.",
    )
    parser.add_argument(
        "--add_qdq_pair_to_weight",
        action="store_true",
        help="If set, weights are kept in floating point and a QuantizeLinear/DeQuantizeLinear pair is"
        " inserted for each weight.",
    )
    parser.add_argument(
        "--dedicated_qdq_pair",
        action="store_true",
        help="If set, a dedicated QDQ pair is created for each node.",
    )
    parser.add_argument(
        "--op_types_to_exclude_output_quantization",
        nargs="+",
        default=[],
        help="If specified, the outputs of ops with these op types will not be quantized.",
    )
    parser.add_argument(
        "--calibration_method",
        default="minmax",
        choices=["minmax", "entropy", "percentile", "distribution"],
        help="Calibration method used",
    )
    parser.add_argument("--quant_format", default="qdq", choices=["qdq", "qoperator"], help="Quantization format used")
    parser.add_argument(
        "--calib_tensor_range_symmetric",
        action="store_true",
        help="If set, the final tensor range computed during calibration is made symmetric around the"
        " central point 0.",
    )
    # TODO: --calib_strided_minmax
    # TODO: --calib_moving_average_constant
    # TODO: --calib_max_intermediate_outputs
    parser.add_argument(
        "--calib_moving_average",
        action="store_true",
        help="If set, a moving average of the minimum and maximum values is computed when the MinMax"
        " calibration method is selected.",
    )
    parser.add_argument(
        "--disable_quantize_bias",
        action="store_true",
        help="If set, biases are kept in floating point and no quantization nodes are inserted for them."
        " By default, floating-point biases are quantized by solely inserting a DeQuantizeLinear node.",
    )
    # TODO: Add arguments related to Smooth Quant
    parser.add_argument(
        "--use_qdq_contrib_ops",
        action="store_true",
        help="If set, the inserted QuantizeLinear and DequantizeLinear ops will have the com.microsoft domain,"
        " which forces use of ONNX Runtime's QuantizeLinear and DequantizeLinear contrib op implementations.",
    )
    parser.add_argument(
        "--minimum_real_range",
        type=float,
        default=0.0001,
        help="Enforce a minimum range between rmin and rmax when computing the quantization parameters"
        " (i.e., scale and zero point). If (rmax - rmin) is less than this value, rmax is set to"
        " rmin + minimum_real_range. This is necessary for EPs like QNN that require a minimum"
        " floating-point range when determining quantization parameters.",
    )
    parser.add_argument(
        "--qdq_keep_removable_activations",
        action="store_true",
        help="If set, removable activations (e.g., Clip or Relu) will not be removed and will be"
        " explicitly represented in the QDQ model.",
    )
    parser.add_argument(
        "--qdq_disable_weight_adjust_for_int32_bias",
        action="store_true",
        help="If set, the QDQ quantizer will not adjust the weight's scale when the bias has a scale"
        " (input_scale * weight_scale) that is too small.",
    )
    parser.add_argument("--per_channel", action="store_true", help="Whether to use per-channel quantization")
    parser.add_argument(
        "--nodes_to_quantize",
        nargs="+",
        default=None,
        help="List of node names to quantize. When provided, only the nodes in this list are quantized.",
    )
    parser.add_argument(
        "--nodes_to_exclude",
        nargs="+",
        default=None,
        help="List of node names to exclude from quantization.",
    )
    parser.add_argument(
        "--op_per_channel_axis",
        nargs=2,
        action="append",
        metavar=("OP_TYPE", "PER_CHANNEL_AXIS"),
        default=[],
        help="Set the channel axis for a specific op type, e.g. --op_per_channel_axis MatMul 1. Effective"
        " only when the op type supports per-channel quantization and --per_channel is set. If an op type"
        " supports per-channel quantization but no axis is specified here, its default channel axis is used.",
    )
    parser.add_argument("--tensor_quant_overrides", help="Path to a JSON file with tensor quantization overrides.")
    return parser.parse_args()
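

# Example invocation (hypothetical file names), quantizing activations to uint8
# and weights to int8 with per-channel MinMax calibration:
#
#   python static_quantize_runner.py \
#       -i model_dir/model.onnx \
#       -o model_dir/model.quant.onnx \
#       --activation_type quint8 --weight_type qint8 \
#       --calibration_method minmax --per_channel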


def get_tensor_quant_overrides(file):
    # TODO: Enhance the function to handle more real cases of json file
    if not file:
        return {}
    with open(file) as f:
        quant_override_dict = json.load(f)
    # Convert scalar scale/zero_point entries into numpy arrays, as expected by
    # the TensorQuantOverrides extra option.
    for tensor in quant_override_dict:
        for enc_dict in quant_override_dict[tensor]:
            enc_dict["scale"] = np.array(enc_dict["scale"], dtype=np.float32)
            enc_dict["zero_point"] = np.array(enc_dict["zero_point"])
    return quant_override_dict
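

# The overrides file maps tensor names to a list of encoding dicts holding
# "scale" and "zero_point" entries, mirroring the structure consumed above.
# A minimal example (hypothetical tensor name and values):
#
#   {
#       "conv1_output": [
#           {"scale": 0.0235, "zero_point": 128}
#       ]
#   }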


def main():
    args = parse_arguments()
    data_reader = OnnxModelCalibrationDataReader(model_path=args.input_model_path)

    arg2quant_type = {
        "qint8": QuantType.QInt8,
        "quint8": QuantType.QUInt8,
        "qint16": QuantType.QInt16,
        "quint16": QuantType.QUInt16,
        "qint4": QuantType.QInt4,
        "quint4": QuantType.QUInt4,
        "qfloat8e4m3fn": QuantType.QFLOAT8E4M3FN,
    }
    activation_type = arg2quant_type[args.activation_type]
    weight_type = arg2quant_type[args.weight_type]

    # argparse yields the axis as a string; convert it to int so the quantizer
    # can use it as a channel index.
    qdq_op_type_per_channel_support_to_axis = {op_type: int(axis) for op_type, axis in args.op_per_channel_axis}

    extra_options = {
        "EnableSubgraph": args.enable_subgraph,
        "ForceQuantizeNoInputCheck": args.force_quantize_no_input_check,
        "MatMulConstBOnly": args.matmul_const_b_only,
        "AddQDQPairToWeight": args.add_qdq_pair_to_weight,
        "OpTypesToExcludeOutputQuantization": args.op_types_to_exclude_output_quantization,
        "DedicatedQDQPair": args.dedicated_qdq_pair,
        "QDQOpTypePerChannelSupportToAxis": qdq_op_type_per_channel_support_to_axis,
        "CalibTensorRangeSymmetric": args.calib_tensor_range_symmetric,
        "CalibMovingAverage": args.calib_moving_average,
        "QuantizeBias": not args.disable_quantize_bias,
        "UseQDQContribOps": args.use_qdq_contrib_ops,
        "MinimumRealRange": args.minimum_real_range,
        "QDQKeepRemovableActivations": args.qdq_keep_removable_activations,
        "QDQDisableWeightAdjustForInt32Bias": args.qdq_disable_weight_adjust_for_int32_bias,
        # Load the json file for encoding overrides
        "TensorQuantOverrides": get_tensor_quant_overrides(args.tensor_quant_overrides),
    }

    arg2calib_method = {
        "minmax": CalibrationMethod.MinMax,
        "entropy": CalibrationMethod.Entropy,
        "percentile": CalibrationMethod.Percentile,
        "distribution": CalibrationMethod.Distribution,
    }
    arg2quant_format = {
        "qdq": QuantFormat.QDQ,
        "qoperator": QuantFormat.QOperator,
    }

    sqc = StaticQuantConfig(
        calibration_data_reader=data_reader,
        calibrate_method=arg2calib_method[args.calibration_method],
        quant_format=arg2quant_format[args.quant_format],
        activation_type=activation_type,
        weight_type=weight_type,
        op_types_to_quantize=None,
        nodes_to_quantize=args.nodes_to_quantize,
        nodes_to_exclude=args.nodes_to_exclude,
        per_channel=args.per_channel,
        reduce_range=False,
        use_external_data_format=False,
        calibration_providers=None,  # Use CPUExecutionProvider
        extra_options=extra_options,
    )
    quantize(model_input=args.input_model_path, model_output=args.output_quantized_model_path, quant_config=sqc)


if __name__ == "__main__":
    main()