diff --git a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt index 8e4c091cd01dd3a7ee72957e3e6e3a7661ac8b19..f73327f8248d8a7c9d9cc9357b1812526efc437a 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt @@ -3,14 +3,29 @@ string(REPLACE ".py" "" TEST_INFERENCE_IR_PASSES "${TEST_INFERENCE_IR_PASSES}") file(GLOB TEST_TRT_IR_PASSES RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_trt_*.py") string(REPLACE ".py" "" TEST_TRT_IR_PASSES "${TEST_TRT_IR_PASSES}") + +file(GLOB TEST_TRT_CONVERTER RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_trt_convert_*.py") +string(REPLACE ".py" "" TEST_TRT_CONVERTER "${TEST_TRT_CONVERTER}") + foreach(TEST_INFERENCE_IR_PASS ${TEST_TRT_IR_PASSES}) list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES ${TEST_INFERENCE_IR_PASS}) endforeach() if(WITH_GPU AND TENSORRT_FOUND) + list(REMOVE_ITEM TEST_TRT_IR_PASSES test_trt_multiclass_nms_op) + + foreach(TRT_CONVERT ${TEST_TRT_CONVERTER}) + list(REMOVE_ITEM TEST_TRT_IR_PASSES ${TRT_CONVERT}) + endforeach() + foreach(target ${TEST_TRT_IR_PASSES}) py_test_modules(${target} MODULES ${target}) endforeach() + + foreach(target ${TEST_TRT_CONVERTER}) + py_test_modules(${target} MODULES ${target}) + set_tests_properties(${target} PROPERTIES TIMEOUT 100) + endforeach() endif() file(GLOB TEST_MKLDNN_IR_PASSES RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_mkldnn_*.py") @@ -32,6 +47,12 @@ if(WITH_GPU AND TENSORRT_FOUND) set_tests_properties(test_trt_subgraph_pass PROPERTIES TIMEOUT 120) set_tests_properties(test_trt_activation_pass PROPERTIES TIMEOUT 120) set_tests_properties(test_trt_conv_pass PROPERTIES TIMEOUT 120) -set_tests_properties(test_trt_multiclass_nms_op PROPERTIES TIMEOUT 200) +#set_tests_properties(test_trt_multiclass_nms_op PROPERTIES TIMEOUT 200) set_tests_properties(test_trt_dynamic_shape PROPERTIES TIMEOUT 120) +set_tests_properties(test_trt_pool_op PROPERTIES ENVIRONMENT FLAGS_fraction_of_gpu_memory_to_use=0.1 TIMEOUT 45) +set_tests_properties(test_trt_reduce_mean_op PROPERTIES TIMEOUT 60) +set_tests_properties(test_trt_tile_op PROPERTIES TIMEOUT 60) +set_tests_properties(test_trt_fc_fuse_quant_dequant_pass PROPERTIES TIMEOUT 100) +set_tests_properties(test_trt_conv_quant_dequant_pass PROPERTIES TIMEOUT 100) +set_tests_properties(test_trt_matmul_quant_dequant PROPERTIES TIMEOUT 100) endif() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/auto_scan_test.py b/python/paddle/fluid/tests/unittests/ir/inference/auto_scan_test.py new file mode 100644 index 0000000000000000000000000000000000000000..59729e5637c4e9f03a6f871627743f38eaae8c61 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/auto_scan_test.py @@ -0,0 +1,139 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import unittest +import abc +import os +import logging +import paddle +import paddle.fluid as fluid +from paddle.fluid.initializer import NumpyArrayInitializer +import paddle.fluid.core as core +from paddle import compat as cpt +import paddle.inference as paddle_infer +from typing import Optional, List, Callable, Dict, Any, Set +from program_config import TensorConfig, OpConfig, ProgramConfig, create_fake_model, create_quant_model + +logging.basicConfig(level=logging.INFO, format="%(message)s") + + +class AutoScanTest(unittest.TestCase): + def __init__(self, methodName='runTest'): + paddle.enable_static() + super(AutoScanTest, self).__init__(methodName) + + @abc.abstractmethod + def sample_program_configs(self) -> List[ProgramConfig]: + ''' + Generate all config with the combination of different Input tensor shape and + different Attr values. + ''' + raise NotImplementedError + + @abc.abstractmethod + def sample_predictor_configs(self) -> List[paddle_infer.Config]: + raise NotImplementedError + + def run_test_config(self, model, params, prog_config, pred_config, + feed_data) -> Dict[str, np.ndarray]: + ''' + Test a single case. + ''' + pred_config.set_model_buffer(model, len(model), params, len(params)) + predictor = paddle_infer.create_predictor(pred_config) + + for name, _ in prog_config.inputs.items(): + input_tensor = predictor.get_input_handle(name) + input_tensor.copy_from_cpu(feed_data[name]['shape']) + if feed_data[name]['lod'] is not None: + input_tensor.set_lod(feed_data[name]['lod']) + predictor.run() + result = {} + for out_name, o_name in zip(prog_config.outputs, + predictor.get_output_names()): + result[out_name] = predictor.get_output_handle(o_name).copy_to_cpu() + return result + + def assert_op_size(self, trt_engine_num, paddle_op_num): + cur_path = os.path.dirname(__file__) + last_passed_program = os.path.join( + cur_path, 'transpose_flatten_concat_fuse_pass.pdmodel') + model_bytes = paddle.static.load_from_file(last_passed_program) + pg = paddle.static.deserialize_program(model_bytes) + main_block = pg.desc.block(0) + op_size = main_block.op_size() + op_types = [ + main_block.op(i).type() == 'tensorrt_engine' for i in range(op_size) + ] + trt_engine_size = sum(op_types) + paddle_op_size = op_size - trt_engine_size + self.assertTrue(trt_engine_size == trt_engine_num, + 'trt_engine_num is {}, but got {}!'.format( + trt_engine_size, trt_engine_num)) + self.assertTrue(paddle_op_size == paddle_op_num, + 'paddle_op_num is {}, but got {}!'.format( + paddle_op_size, paddle_op_num)) + + def assert_tensors_near(self, + threshold: float, + tensors: List[Dict[str, np.array]]): + assert len(tensors) > 1 + first = tensors[0] + for group in tensors[1:]: + for key, arr in group.items(): + self.assertTrue( + np.allclose( + first[key], arr, atol=threshold), + "Output has diff between GPU and TensorRT. ") + + def run_test(self, + trt_engine_num: int, + paddle_op_num: int, + threshold=1e-5, + quant=False, + error_msg=None): + for prog_config in self.sample_program_configs(): + model, params = create_fake_model(prog_config) + if quant: + model, params = create_quant_model(model, params) + for batch_size in self.batch_size_set: + feed_data = {} + log_str = ' -- Input tensor info: ' + for name, tensor_config in prog_config.inputs.items(): + tensor_shape = tensor_config.shape.copy() + tensor_shape[0] = batch_size + feed_data[name] = { + 'shape': np.random.random(tensor_shape).astype( + tensor_config.dtype), + 'lod': tensor_config.lod + } + log_str += str({ + name: { + 'shape': tensor_shape, + 'lod': tensor_config.lod + } + }) + logging.info(log_str) + results: List[Dict[str, Tensor]] = [] + for pred_config in self.sample_predictor_configs(): + results.append( + self.run_test_config(model, params, prog_config, + pred_config, feed_data)) + try: + self.assert_tensors_near( + threshold=threshold, tensors=results) + self.assert_op_size(trt_engine_num, paddle_op_num) + except: + logging.info('ERROR OCCURED: ' + error_msg) diff --git a/python/paddle/fluid/tests/unittests/ir/inference/program_config.py b/python/paddle/fluid/tests/unittests/ir/inference/program_config.py new file mode 100644 index 0000000000000000000000000000000000000000..1343e9673667ac7006febc900ee8f7d0917504dc --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/program_config.py @@ -0,0 +1,350 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, List, Callable, Dict, Any, Set +import numpy as np +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle import compat as cpt +from paddle.fluid.initializer import NumpyArrayInitializer +from paddle.fluid.framework import convert_np_dtype_to_dtype_ + +from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass +from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass +from paddle.fluid.framework import IrGraph, IrNode, Operator +from paddle.fluid.executor import global_scope + + +class TensorConfig: + ''' + A config builder for a input or a weight. + + InputVar's shape can be [-1, xxx], batch_size + ''' + + def __init__(self, + shape: [List[int]], + dtype: [str]="float32", + data: Optional[np.array]=None, + lod: [List[List[int]]]=None): + ''' + shape: The shape of the tensor. + dtype: The data type of the tensor. + data: The value of WeightVar. for input, it should be None + ''' + self.shape = shape + self.dtype = dtype + self.data = data + self.lod = lod + + +class OpConfig: + ''' A config builder for generating a Op. ''' + + def __init__(self, + type: str, + inputs: Dict[str, List[str]], + outputs: Dict[str, List[str]], + attrs: Dict[str, Any]): + self.type = type + self.inputs = inputs + self.outputs = outputs + self.attrs = attrs + + +class ProgramConfig: + ''' A config builder for generating a Program. ''' + + def __init__(self, + ops: List[OpConfig], + weights: Dict[str, TensorConfig], + inputs: Dict[str, TensorConfig], + outputs: List[str]): + self.ops = ops + self.weights = weights + self.inputs = inputs + self.outputs = outputs + + +def create_fake_model(program_config): + ''' Create a Paddle model(in memory) according to the given config. ''' + paddle.enable_static() + main_program_desc = core.ProgramDesc() + util_program = fluid.Program() + main_block_desc = main_program_desc.block(0) + + var_desc = main_block_desc.var(cpt.to_bytes("feed")) + var_desc.set_type(core.VarDesc.VarType.FEED_MINIBATCH) + var_desc.set_persistable(True) + + index = 0 + for name, tensor_config in program_config.inputs.items(): + var_desc = main_block_desc.var(cpt.to_bytes(name)) + var_desc.set_type(core.VarDesc.VarType.LOD_TENSOR) + var_desc.set_dtype(convert_np_dtype_to_dtype_(tensor_config.dtype)) + var_desc.set_shape(tensor_config.shape) + var_desc.set_need_check_feed(True) + op_desc = main_block_desc._prepend_op() + op_desc.set_type("feed") + op_desc.set_input('X', ["feed"]) + op_desc.set_output('Out', [name]) + op_desc._set_attr("col", index) + index = index + 1 + + save_var_map = {} + for name, tensor_config in program_config.weights.items(): + var_desc = main_block_desc.var(cpt.to_bytes(name)) + var_desc.set_type(core.VarDesc.VarType.LOD_TENSOR) + var_desc.set_dtype(convert_np_dtype_to_dtype_(tensor_config.dtype)) + var_desc.set_shape(tensor_config.shape) + var_desc.set_persistable(True) + + save_var_map[name] = util_program.global_block().create_parameter( + dtype=tensor_config.dtype, + shape=tensor_config.shape, + type=core.VarDesc.VarType.LOD_TENSOR, + name=name, + initializer=NumpyArrayInitializer(tensor_config.data)) + in_vars = [] + for name in sorted(save_var_map.keys()): + in_vars.append(save_var_map[name]) + + out_var = util_program.global_block().create_var( + type=core.VarDesc.VarType.RAW, name="out_var_0") + out_var.desc.set_persistable(True) + util_program.global_block().append_op( + type='save_combine', + inputs={'X': in_vars}, + outputs={'Y': out_var}, + attrs={'file_path': '', + 'save_to_memory': True}) + for op_config in program_config.ops: + op_desc = main_block_desc.append_op() + op_desc.set_type(op_config.type) + for name, values in op_config.inputs.items(): + op_desc.set_input(name, values) + for name, values in op_config.attrs.items(): + op_desc._set_attr(name, values) + for name, values in op_config.outputs.items(): + op_desc.set_output(name, values) + for v in values: + var_desc = main_block_desc.var(cpt.to_bytes(v)) + var_desc.set_type(core.VarDesc.VarType.LOD_TENSOR) + var_desc.set_dtype( + convert_np_dtype_to_dtype_(tensor_config.dtype)) + op_desc.infer_var_type(main_block_desc) + op_desc.infer_shape(main_block_desc) + + for index, name in enumerate(program_config.outputs): + var_desc = main_block_desc.var(cpt.to_bytes("fetch")) + var_desc.set_type(core.VarDesc.VarType.FETCH_LIST) + var_desc.set_need_check_feed(True) + op_desc = main_block_desc.append_op() + op_desc.set_type("fetch") + op_desc.set_input('X', [name]) + op_desc.set_output('Out', ["fetch"]) + op_desc._set_attr("col", index) + + main_program_desc._set_version() + paddle.fluid.core.save_op_version_info(main_program_desc) + + model = main_program_desc.serialize_to_string() + + util_program._sync_with_cpp() + place = fluid.CPUPlace() + executor = fluid.Executor(place) + scope = fluid.Scope() + with fluid.scope_guard(scope): + executor.run(util_program) + params = scope.find_var("out_var_0").get_bytes() + return model, params + + +def create_quant_model(model, + params, + activation_quantize_type='moving_average_abs_max', + weight_quantize_type='channel_wise_abs_max', + save=False): + place = paddle.CUDAPlace(0) + scope = global_scope() + exe = paddle.static.Executor(place) + [inference_program, feed_target_names, + fetch_targets] = paddle.static.load_inference_model( + path_prefix=None, + executor=exe, + model_filename=model, + params_filename=params) + graph = IrGraph(core.Graph(inference_program.desc), for_test=True) + + out_scale_op_list = [ + "conv2d", + "depthwise_conv2d", + "mul", + "matmul", + "relu", + "leaky_relu", + "relu6", + "sigmoid", + "tanh", + "prelu", + "swish", + "softmax", + "batch_norm", + "layer_norm", + "elementwise_add", + "pool2d", + "reshape2", + "transpose2", + "concat", + "elementwise_mul", + "scale", + "slice", + "hard_swish", + "hard_sigmoid", + "conv2d_transpose", + "gru", + "bilinear_interp", + "nearest_interp", + "trilinear_interp", + "flatten", + "flatten2", + "transpose", + "pad2d", + "reshape", + "layer_norm", + ] + op_real_in_out_name = { + "conv2d": [["Input", "Filter"], ["Output"]], + "depthwise_conv2d": [["Input", "Filter"], ["Output"]], + "conv2d_transpose": [["Input", "Filter"], ["Output"]], + "mul": [["X", "Y"], ["Out"]], + "matmul": [["X", "Y"], ["Out"]], + "pool2d": [["X"], ["Out"]], + "elementwise_add": [["X", "Y"], ["Out"]], + "concat": [["X"], ["Out"]], + "softmax": [["X"], ["Out"]], + "argmax": [["X"], ["Out"]], + "transpose": [["X"], ["Out"]], + "equal": [["X", "Y"], ["Out"]], + "gather": [["X"], ["Out"]], + "greater_equal": [["X", "Y"], ["Out"]], + "greater_than": [["X", "Y"], ["Out"]], + "less_equal": [["X", "Y"], ["Out"]], + "less_than": [["X", "Y"], ["Out"]], + "mean": [["X"], ["Out"]], + "not_equal": [["X", "Y"], ["Out"]], + "reshape": [["X"], ["Out"]], + "reshape2": [["X"], ["Out"]], + "transpose2": [["X"], ["Out"]], + "bilinear_interp": [["X"], ["Out"]], + "nearest_interp": [["X"], ["Out"]], + "trilinear_interp": [["X"], ["Out"]], + "slice": [["Input"], ["Out"]], + "squeeze": [["X"], ["Out"]], + "elementwise_sub": [["X", "Y"], ["Out"]], + "relu": [["X"], ["Out"]], + "relu6": [["X"], ["Out"]], + "leaky_relu": [["X"], ["Out"]], + "prelu": [["X"], ["Out"]], + "tanh": [["X"], ["Out"]], + "swish": [["X"], ["Out"]], + "dropout": [["X"], ["Out"]], + "batch_norm": [["X"], ["Y"]], + "layer_norm": [["X"], ["Y"]], + "sigmoid": [["X"], ["Out"]], + "elementwise_mul": [["X", "Y"], ["Out"]], + "scale": [["X"], ["Out"]], + "hard_swish": [["X"], ["Out"]], + "hard_sigmoid": [["X"], ["Out"]], + "gru": [["Input", "Weight"], ["Hidden"]], + "lstm": [["Input", "Weight"], ["Hidden"]], + "pad2d": [["X"], ["Out"]], + "flatten": [["X"], ["Out"]], + "flatten2": [["X"], ["Out"]], + } + + def _get_op_output_var_names(op): + """ """ + assert isinstance(op, (IrNode, Operator)), \ + "The input op should be IrNode or Operator." + var_names = [] + op_name = op.name() if isinstance(op, IrNode) \ + else op.type + if op_name not in op_real_in_out_name: + return [] + + name_list = op_real_in_out_name[op_name][1] + for name in name_list: + var_name = op.output(name) + if isinstance(var_name, list): + var_names.extend(var_name) + else: + var_names.append(var_name) + return var_names + + transform_pass = QuantizationTransformPass( + scope=scope, + place=place, + activation_quantize_type=activation_quantize_type, + weight_quantize_type=weight_quantize_type) + transform_pass.apply(graph) + + op_nodes = graph.all_op_nodes() + for op_node in op_nodes: + if op_node.name() in out_scale_op_list: + var_names = _get_op_output_var_names(op_node) + for var_name in var_names: + in_node = graph._find_node_by_name(op_node.outputs, var_name) + if in_node.dtype() not in \ + [core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32]: + continue + + op_node.op()._set_attr("out_threshold", 3.0) + + # Freeze graph for inference, but the weight of fc/conv is still float type. + freeze_pass = QuantizationFreezePass( + scope=scope, place=place, weight_quantize_type=weight_quantize_type) + freeze_pass.apply(graph) + + main_program = graph.to_program() + + # modify fake_quantize_moving_average_abs_max(InScale) and fake_channel_wise_dequantize_max_abs(Scales) + op_nodes = graph.all_op_nodes() + for op_node in op_nodes: + if op_node.name() == 'fake_quantize_moving_average_abs_max': + var_name = op_node.input("InScale")[0] + tensor = scope.var(var_name).get_tensor() + tensor.set(np.array([1], dtype=np.float32), place) + elif op_node.name() == 'fake_channel_wise_dequantize_max_abs': + var_name = op_node.input("Scales")[0] + tensor = scope.var(var_name).get_tensor() + tensor.set(np.ones(tensor.shape(), dtype=np.float32), place) + + if save: + fluid.io.save_inference_model( + 'test_inference_model', + feed_target_names, + fetch_targets, + exe, + main_program=main_program) + + feed_vars = [ + main_program.global_block().var(name) for name in feed_target_names + ] + serialized_program = paddle.static.serialize_program( + feed_vars, fetch_targets, program=main_program) + serialized_params = paddle.static.serialize_persistables( + feed_vars, fetch_targets, executor=exe, program=main_program) + return serialized_program, serialized_params diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_conv2d.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_conv2d.py new file mode 100644 index 0000000000000000000000000000000000000000..109eef2038a77e2552ebd9991f27815f7632cae9 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_conv2d.py @@ -0,0 +1,93 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from trt_layer_auto_scan_test import TrtLayerAutoScanTest +from program_config import TensorConfig +import numpy as np +import paddle.inference as paddle_infer + + +class TrtConvertConv2dTest(TrtLayerAutoScanTest): + def setUp(self): + self.ops_config = [{ + "op_type": "conv2d", + "op_inputs": { + "Input": ["input_data"], + "Filter": ["conv2d_weight"] + }, + "op_outputs": { + "Output": ["conv_output_data"] + }, + "op_attrs": { + "data_format": ["NCHW"], + "dilations": [[1, 1]], + "padding_algorithm": ["EXPLICIT"], + "groups": [1], + "paddings": [[0, 3], [3, 1]], + "strides": [[1, 1], [2, 2]], + } + }, { + "op_type": "relu", + "op_inputs": { + "X": ["conv_output_data"] + }, + "op_outputs": { + "Out": ["relu_output_data"] + }, + "op_attrs": {} + }] + self.batch_size_set = [1, 2, 4] + + def update_program_input_and_weight_with_attr(self, op_attr_list): + weight = np.random.randn(24, 3, 3, 3).astype("float32") + filter = TensorConfig(shape=[24, 3, 3, 3], data=weight) + if op_attr_list[0]["data_format"] == "NCHW": + input_data = TensorConfig(shape=[-1, 3, 64, 64]) + else: + input_data = TensorConfig(shape=[-1, 64, 64, 3]) + self.program_weights = {"conv2d_weight": filter} + self.program_inputs = {"input_data": input_data} + self.program_outputs = ["relu_output_data"] + + def test_check_fp32_output(self): + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + # the fused tensorrt engine num is 1, and paddle op num is 2(feed and fetch). + self.run_test(trt_engine_num=1, paddle_op_num=2, threshold=1e-5) + + def test_check_fp16_output(self): + self.trt_param.precision = paddle_infer.PrecisionType.Half + self.run_test(trt_engine_num=1, paddle_op_num=2, threshold=1e-2) + + def test_dynamic_shape_fp32_check_output(self): + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + self.dynamic_shape.min_input_shape = {"input_data": [1, 3, 32, 32]} + self.dynamic_shape.max_input_shape = {"input_data": [4, 3, 64, 64]} + self.dynamic_shape.opt_input_shape = {"input_data": [1, 3, 64, 64]} + self.run_test(trt_engine_num=1, paddle_op_num=2, threshold=1e-5) + + def test_dynamic_shape_fp16_check_output(self): + self.trt_param.precision = paddle_infer.PrecisionType.Half + self.dynamic_shape.min_input_shape = {"input_data": [1, 3, 32, 32]} + self.dynamic_shape.max_input_shape = {"input_data": [4, 3, 64, 64]} + self.dynamic_shape.opt_input_shape = {"input_data": [1, 3, 64, 64]} + self.run_test(trt_engine_num=1, paddle_op_num=2, threshold=1e-2) + + def test_trt_int8_check_output(self): + self.trt_param.precision = paddle_infer.PrecisionType.Int8 + self.run_test( + trt_engine_num=1, paddle_op_num=2, quant=True, threshold=1e-1) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/trt_layer_auto_scan_test.py b/python/paddle/fluid/tests/unittests/ir/inference/trt_layer_auto_scan_test.py new file mode 100644 index 0000000000000000000000000000000000000000..715006771878795674d9391b926be40d2ed27bc1 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/trt_layer_auto_scan_test.py @@ -0,0 +1,179 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import unittest +import itertools +import abc +import logging +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +import paddle.inference as paddle_infer + +from paddle import compat as cpt +from typing import * +from program_config import TensorConfig, OpConfig, ProgramConfig +from auto_scan_test import AutoScanTest + +logging.basicConfig(level=logging.INFO, format="%(message)s") + + +class TrtLayerAutoScanTest(AutoScanTest): + class TensorRTParam: + ''' + TensorRT subgraph engine parameters. + ''' + + def __init__(self, workspace_size, max_batch_size, min_subgraph_size, + precision, use_static, use_calib_mode): + self.workspace_size = workspace_size + self.max_batch_size = max_batch_size + self.min_subgraph_size = min_subgraph_size + self.precision = precision + self.use_static = use_static + self.use_calib_mode = use_calib_mode + + class DynamicShapeParam: + ''' + Prepare TensorRT subgraph engine dynamic shape parameters. + ''' + + def __init__(self, min_input_shape, max_input_shape, optim_input_shape, + disable_trt_plugin_fp16): + self.min_input_shape = min_input_shape + self.max_input_shape = max_input_shape + self.optim_input_shape = optim_input_shape + self.disable_trt_plugin_fp16 = disable_trt_plugin_fp16 + + def __init__(self, methodName='runTest'): + super(TrtLayerAutoScanTest, self).__init__(methodName) + self.trt_param = self.TensorRTParam( + workspace_size=0, + max_batch_size=4, + min_subgraph_size=0, + precision=paddle_infer.PrecisionType.Float32, + use_static=False, + use_calib_mode=False) + self.dynamic_shape = self.DynamicShapeParam({}, {}, {}, False) + + def update_program_input_and_weight_with_attr(self, op_attr_list): + raise NotImplementedError + + @abc.abstractmethod + def sample_program_configs(self): + all_op_attrs_keys = [] + all_op_attrs_values = [] + for op_config in self.ops_config: + all_op_attrs_keys.append(list(op_config["op_attrs"].keys())) + all_op_attrs_values.extend(list(op_config["op_attrs"].values())) + if len(all_op_attrs_values) == 0: + all_op_attrs_values.append([None]) + for attrs_sample in itertools.product(*all_op_attrs_values): + op_attr_list = [] + index = 0 + ops = [] + log_str = 'TEST_CASE: ' + for i in range(len(self.ops_config)): + op_config = self.ops_config[i] + op_attr = dict( + zip( + list(op_config["op_attrs"].keys()), attrs_sample[ + index:index + len(op_config["op_attrs"])])) + + if i != len(self.ops_config) - 1: + log_str += op_config['op_type'] + str(op_attr) + ' + ' + else: + log_str += op_config['op_type'] + str(op_attr) + + op_attr_list.append(op_attr) + index = index + len(op_config["op_attrs"]) + ops.append( + OpConfig( + type=op_config["op_type"], + inputs=op_config["op_inputs"], + outputs=op_config["op_outputs"], + attrs=op_attr)) + + logging.info(log_str) + self.update_program_input_and_weight_with_attr(op_attr_list) + # if no weight need to save, we create a place_holder to help seriazlie params. + if not self.program_weights: + self.program_weights = { + "place_holder_weight": TensorConfig( + shape=[1], data=np.array([1]).astype(np.float32)) + } + program_config = ProgramConfig( + ops=ops, + weights=self.program_weights, + inputs=self.program_inputs, + outputs=self.program_outputs) + yield program_config + + def create_program_config( + self, use_trt=True, + precision_mode=paddle_infer.PrecisionType.Float32): + config = paddle_infer.Config() + config.disable_glog_info() + config.enable_use_gpu(100, 0) + if use_trt: + config.switch_ir_debug() + config.enable_tensorrt_engine( + max_batch_size=self.trt_param.max_batch_size, + workspace_size=self.trt_param.workspace_size, + min_subgraph_size=self.trt_param.min_subgraph_size, + precision_mode=precision_mode, + use_static=self.trt_param.use_static, + use_calib_mode=self.trt_param.use_calib_mode) + if len(self.dynamic_shape.min_input_shape + ) != 0 and self.dynamic_shape.min_input_shape.keys( + ) == self.dynamic_shape.max_input_shape.keys( + ) and self.dynamic_shape.min_input_shape.keys( + ) == self.dynamic_shape.opt_input_shape.keys(): + config.set_trt_dynamic_shape_info( + self.dynamic_shape.min_input_shape, + self.dynamic_shape.max_input_shape, + self.dynamic_shape.opt_input_shape, + self.dynamic_shape.disable_trt_plugin_fp16) + return config + + @abc.abstractmethod + def sample_predictor_configs(self): + def precision_to_str(p): + if p == paddle_infer.PrecisionType.Float32: + return 'float32' + elif p == paddle_infer.PrecisionType.Half: + return 'half' + elif p == paddle_infer.PrecisionType.Int8: + return 'int8' + else: + raise NotImplementedError('not supported type.') + + trt_log_str = '' + if len(self.dynamic_shape.min_input_shape + ) != 0 and self.dynamic_shape.min_input_shape.keys( + ) == self.dynamic_shape.max_input_shape.keys( + ) and self.dynamic_shape.min_input_shape.keys( + ) == self.dynamic_shape.opt_input_shape.keys(): + trt_log_str += 'dynamic_shape ' + else: + trt_log_str += 'static_shape ' + trt_log_str += precision_to_str(self.trt_param.precision) + + logging.info(' --------- gpu inference ---------') + yield self.create_program_config(use_trt=False) + logging.info(' --------- trt ' + trt_log_str + + ' inference ---------') + yield self.create_program_config( + use_trt=True, precision_mode=self.trt_param.precision)