From df7eab112370e2bc348b1e2aefc7ede02fa375b7 Mon Sep 17 00:00:00 2001 From: PaddlePaddle-Gardener Date: Thu, 21 Oct 2021 18:08:56 +0800 Subject: [PATCH] mirgate_36420 --- .../distributed/fleet/base/fleet_base.py | 40 ++++- .../fleet/meta_parallel/pipeline_parallel.py | 37 +++-- .../fleet/meta_parallel/pp_utils/utils.py | 13 +- .../distributed/fleet/utils/recompute.py | 15 +- python/paddle/fluid/framework.py | 2 +- .../unittests/hybrid_parallel_pp_fp16.py | 138 ++++++++++++++++++ .../tests/unittests/test_dygraph_recompute.py | 38 ++++- ...test_parallel_dygraph_pipeline_parallel.py | 8 +- 8 files changed, 260 insertions(+), 31 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/hybrid_parallel_pp_fp16.py diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index 544c79a0b3..571199b99b 100644 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -35,6 +35,8 @@ from ..meta_parallel import TensorParallel, model_parallel_random_seed from ..meta_parallel import PipelineParallel, ShardingParallel from ..meta_optimizers import HybridParallelOptimizer from paddle import _C_ops +from paddle.fluid import core +from paddle.fluid.dygraph import to_variable __all__ = [] @@ -1548,26 +1550,52 @@ class Fleet(object): if getattr(optimizer, '_param_groups', None) and isinstance( optimizer._param_groups[0], dict): param_grads = [] + param_grads_fp16 = [] + param_grads_fp32 = [] for group in optimizer._param_groups: for param in group['params']: if param._grad_ivar() is not None: param_grads.append(param._grad_ivar()) + if param._grad_ivar( + ).dtype == core.VarDesc.VarType.FP16: + param_grads_fp16.append(param._grad_ivar()) + else: + param_grads_fp32.append(param._grad_ivar()) else: param_grads = [ param._grad_ivar() for param in optimizer._parameter_list if param._grad_ivar() is not None ] - _C_ops.check_finite_and_unscale(param_grads, self._scale, - param_grads, self._found_inf) - - self._found_inf = paddle.cast(self._found_inf, dtype="int32") + param_grads_fp16 = [ + param._grad_ivar() for param in optimizer._parameter_list + if (param._grad_ivar() is not None) and (param._grad_ivar( + ).dtype == core.VarDesc.VarType.FP16) + ] + param_grads_fp32 = [ + param._grad_ivar() for param in optimizer._parameter_list + if (param._grad_ivar() is not None) and (param._grad_ivar( + ).dtype == core.VarDesc.VarType.FP32) + ] + temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool)) + temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool)) + if len(param_grads_fp16): + _C_ops.check_finite_and_unscale(param_grads_fp16, self._scale, + param_grads_fp16, + temp_found_inf_fp16) + if len(param_grads_fp32): + _C_ops.check_finite_and_unscale(param_grads_fp32, self._scale, + param_grads_fp32, + temp_found_inf_fp32) + self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0 # TODO(shenliang03) Since dp allreduce in the optimizer is # after the gradscaler, check_finite needs to synchronize global # information. In the future, we should use check_group to speed. paddle.distributed.all_reduce( - self._found_inf, op=paddle.distributed.ReduceOp.MAX, group=None) - self._found_inf = paddle.cast(self._found_inf, dtype="bool") + paddle.to_tensor( + [self._found_inf], dtype="int32"), + op=paddle.distributed.ReduceOp.MAX, + group=None) # Only tensor_parallel and pipeline_parallel need to modify scaler if self._hcg.get_parallel_mode() in (ParallelMode.TENSOR_PARALLEL, diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index 9096097397..7c7637a90f 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -145,9 +145,8 @@ class PipelineParallel(MetaParallelBase): p2p.send_backward(input_tensor_grad) self._layers.allreduce_shared_weight_gradients() - - train_loss = self._broadcast_final_loss() - + with paddle.amp.auto_cast(enable=False): + train_loss = self._broadcast_final_loss() return train_loss def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None): @@ -172,7 +171,8 @@ class PipelineParallel(MetaParallelBase): train_loss = self.forward_backward_pipeline(data, scaler) # optimizer - self._optimizer_step() + with paddle.amp.auto_cast(enable=False): + self._optimizer_step() return train_loss @@ -242,12 +242,13 @@ class PipelineParallel(MetaParallelBase): output_tensor, paddle.Tensor ), "Currently, loss_fn should obtain Paddle.Tensor dtype" - if self.accumulate_steps > 1: - output_tensor = output_tensor / self.accumulate_steps + with paddle.amp.auto_cast(enable=False): + if self.accumulate_steps > 1: + output_tensor = output_tensor / self.accumulate_steps - if self.total_loss is None: - self.total_loss = paddle.zeros_like(output_tensor) - self.total_loss += output_tensor.detach() + if self.total_loss is None: + self.total_loss = paddle.zeros_like(output_tensor) + self.total_loss += output_tensor.detach() self.micro_batch_id += 1 return output_tensor @@ -321,13 +322,29 @@ class PipelineParallel(MetaParallelBase): if self.is_last_stage: assert self.total_loss is not None, "train_batch() in last stage should obtain vaild loss" loss = self.total_loss.detach() + is_fp32 = paddle.to_tensor( + 1) if loss.dtype == paddle.float32 else paddle.to_tensor(0) + paddle.distributed.broadcast( + is_fp32, + src=self.global_rank, + use_calc_stream=True, + group=self.pp_group) paddle.distributed.broadcast( loss, src=self.global_rank, use_calc_stream=True, group=self.pp_group) else: - loss = paddle.zeros(shape=[1], dtype="float32") + is_fp32 = paddle.to_tensor(1) + paddle.distributed.broadcast( + is_fp32, + src=self._hcg.get_rank_from_stage(self.num_stages - 1), + use_calc_stream=True, + group=self.pp_group) + loss = paddle.zeros( + shape=[1], + dtype="float32") if is_fp32.numpy()[0] else paddle.zeros( + shape=[1], dtype="float16") paddle.distributed.broadcast( loss, src=self._hcg.get_rank_from_stage(self.num_stages - 1), diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py index 0826609654..7224ba6ded 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py @@ -198,11 +198,14 @@ class _HPRecomputeFunction(PyLayer): # TODO support AMP tracer = framework._dygraph_tracer() - if tracer._amp_level == core.AmpLevel.O0: - ctx.is_fw_autocast = False + ctx.is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True + if tracer._amp_level == core.AmpLevel.O2: + ctx.amp_level = 'O2' + elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0): + ctx.amp_level = 'O1' else: - ctx.is_fw_autocast = True - ctx.amp_mode = 'O1' + raise ValueError("unsupported amp level: {}".format( + tracer._amp_level)) ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list() with paddle.no_grad(): @@ -263,7 +266,7 @@ class _HPRecomputeFunction(PyLayer): enable=ctx.is_fw_autocast, custom_white_list=ctx.amp_white_list, custom_black_list=ctx.amp_black_list, - level=ctx.amp_mode): + level=ctx.amp_level): detached_inputs = detach_variable(tuple(inputs)) outputs = ctx.run_function(*detached_inputs) diff --git a/python/paddle/distributed/fleet/utils/recompute.py b/python/paddle/distributed/fleet/utils/recompute.py index 56a64049b1..2d1db5db94 100644 --- a/python/paddle/distributed/fleet/utils/recompute.py +++ b/python/paddle/distributed/fleet/utils/recompute.py @@ -98,11 +98,14 @@ class RecomputeFunction(PyLayer): # TODO support AMP tracer = framework._dygraph_tracer() - if tracer._amp_level == core.AmpLevel.O0: - ctx.is_fw_autocast = False + ctx.is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True + if tracer._amp_level == core.AmpLevel.O2: + ctx.amp_level = 'O2' + elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0): + ctx.amp_level = 'O1' else: - ctx.is_fw_autocast = True - ctx.amp_mode = 'O1' + raise ValueError("unsupported amp level: {}".format( + tracer._amp_level)) ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list() with paddle.no_grad(): @@ -133,7 +136,7 @@ class RecomputeFunction(PyLayer): enable=ctx.is_fw_autocast, custom_white_list=ctx.amp_white_list, custom_black_list=ctx.amp_black_list, - level=ctx.amp_mode): + level=ctx.amp_level): detached_inputs = detach_variable(tuple(inputs)) outputs = ctx.run_function(*detached_inputs) else: @@ -141,7 +144,7 @@ class RecomputeFunction(PyLayer): enable=ctx.is_fw_autocast, custom_white_list=ctx.amp_white_list, custom_black_list=ctx.amp_black_list, - level=ctx.amp_mode): + level=ctx.amp_level): detached_inputs = detach_variable(tuple(inputs)) outputs = ctx.run_function(*detached_inputs) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 156ba07a4c..60e00238f6 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -6097,7 +6097,7 @@ class ParamBase(core.VarBase): self.need_clip = kwargs.get('need_clip', True) - self.is_distributed = False + self.is_distributed = kwargs.get('is_distributed', False) # self.block = default_main_program().global_block() @property diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_fp16.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_fp16.py new file mode 100644 index 0000000000..571459365a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_fp16.py @@ -0,0 +1,138 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division +from __future__ import print_function + +import unittest +import paddle +import numpy as np +import random +import paddle +import paddle.distributed as dist +import paddle.distributed.fleet as fleet +from hybrid_parallel_pp_layer import AlexNetPipeDesc, AlexNet + + +def set_random_seed(seed, dp_id, rank_id): + """Set random seed for reproducability.""" + random.seed(seed) + np.random.seed(seed + dp_id) + paddle.seed(seed + dp_id) + + +batch_size = 4 +micro_batch_size = 2 + + +class TestDistPPTraning(unittest.TestCase): + def setUp(self): + strategy = fleet.DistributedStrategy() + self.model_parallel_size = 1 + self.data_parallel_size = 1 + self.pipeline_parallel_size = 2 + strategy.hybrid_configs = { + "dp_degree": self.data_parallel_size, + "mp_degree": self.model_parallel_size, + "pp_degree": self.pipeline_parallel_size, + } + strategy.pipeline_configs = { + "accumulate_steps": batch_size // micro_batch_size, + "micro_batch_size": micro_batch_size + } + fleet.init(is_collective=True, strategy=strategy) + + def test_pp_model(self): + hcg = fleet.get_hybrid_communicate_group() + word_size = hcg.get_model_parallel_world_size() + dp_id = hcg.get_data_parallel_rank() + pp_id = hcg.get_stage_id() + rank_id = dist.get_rank() + set_random_seed(1024, dp_id, rank_id) + + #construct model a + model_a = AlexNet(10) + scheduler_a = paddle.optimizer.lr.PiecewiseDecay( + boundaries=[2], values=[0.001, 0.002], verbose=True) + optimizer_a = paddle.optimizer.SGD(learning_rate=scheduler_a, + parameters=model_a.parameters()) + + scaler_a = paddle.amp.GradScaler(init_loss_scaling=2**5) + + # construct model b + model_b = AlexNetPipeDesc(num_stages=self.pipeline_parallel_size) + scheduler_b = paddle.optimizer.lr.PiecewiseDecay( + boundaries=[2], values=[0.001, 0.002], verbose=True) + optimizer_b = paddle.optimizer.SGD(learning_rate=scheduler_b, + parameters=model_b.parameters()) + + param_len = len(model_a.parameters()) + parameters = [] + for param in model_a.parameters(): + parameters.append(param.numpy()) + + for idx, param in enumerate(model_b.parameters()): + param.set_value(parameters[idx + pp_id * (param_len // 2)]) + + model_a, optimizer_a = paddle.amp.decorate( + models=model_a, + optimizers=optimizer_a, + level='O2', + save_dtype='float32') + model_b, optimizer_b = paddle.amp.decorate( + models=model_b, + optimizers=optimizer_b, + level='O2', + save_dtype='float32') + + model_b = fleet.distributed_model(model_b) + optimizer_b = fleet.distributed_optimizer(optimizer_b) + scaler_b = paddle.amp.GradScaler(init_loss_scaling=2**5) + scaler_b = fleet.distributed_scaler(scaler_b) + + # construct reader + train_reader = paddle.batch( + paddle.dataset.mnist.train(), batch_size=batch_size, drop_last=True) + + for step_id, data in enumerate(train_reader()): + x_data = np.array([x[0] for x in data]).astype('float32').reshape( + batch_size, 1, 28, 28) + y_data = np.array([x[1] for x in data]).astype('int64').reshape( + batch_size, 1) + img = paddle.to_tensor(x_data) + label = paddle.to_tensor(y_data) + img.stop_gradient = True + label.stop_gradient = True + + if step_id >= 5: + return True + + with paddle.amp.auto_cast(enable=True, level='O2'): + loss_a = model_a(img, label) + scaler_a.scale(loss_a).backward() + with paddle.amp.auto_cast(enable=False): + scaler_a.minimize(optimizer_a, loss_a) + optimizer_a.clear_grad() + scheduler_a.step() + + loss_b = model_b.train_batch( + [img, label], optimizer_b, scheduler_b, scaler=scaler_b) + + print("loss: ", loss_a.numpy(), loss_b.numpy()) + np.testing.assert_allclose( + loss_a.numpy(), loss_b.numpy(), rtol=5e-3) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_dygraph_recompute.py b/python/paddle/fluid/tests/unittests/test_dygraph_recompute.py index 332603b812..4a4bcd2b81 100644 --- a/python/paddle/fluid/tests/unittests/test_dygraph_recompute.py +++ b/python/paddle/fluid/tests/unittests/test_dygraph_recompute.py @@ -92,7 +92,10 @@ class Naive_fc_net(paddle.nn.Layer): return inputs -def run_model(recompute_block=[], recompute_kwargs={}, enable_autocast=False): +def run_model(recompute_block=[], + recompute_kwargs={}, + enable_autocast=False, + pure_fp16=False): gen = paddle.seed(10) gen.manual_seed(10) np.random.seed(10) @@ -118,7 +121,8 @@ def run_model(recompute_block=[], recompute_kwargs={}, enable_autocast=False): x_data = np.random.randn(batch_size, input_size).astype(np.float32) x = paddle.to_tensor(x_data) # x.stop_gradient = False - with paddle.amp.auto_cast(True): + level = 'O2' if pure_fp16 else 'O1' + with paddle.amp.auto_cast(True, level=level): y_pred = model(x) loss = y_pred.mean() if enable_autocast: @@ -196,6 +200,36 @@ class TestPyLayer(unittest.TestCase): recompute_block=[1, 3], enable_autocast=True) check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + def test_fc_net_with_fp16(self): + def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad): + self.assertEqual(loss_ref, loss) + self.assertEqual(param_ref, param) + self.assertEqual(grad_ref, grad) + + # without recompute + loss_ref, param_ref, grad_ref = run_model( + recompute_block=[], enable_autocast=True, pure_fp16=True) + + # recompute second block + loss, param, grad = run_model( + recompute_block=[1], enable_autocast=True, pure_fp16=True) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # recompute fourth block + loss, param, grad = run_model( + recompute_block=[3], enable_autocast=True, pure_fp16=True) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # recompute second to fourth block + loss, param, grad = run_model( + recompute_block=[1, 2, 3], enable_autocast=True, pure_fp16=True) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # recompute second & fourth block + loss, param, grad = run_model( + recompute_block=[1, 3], enable_autocast=True, pure_fp16=True) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + def test_recompute_kwargs(self): paddle.set_device("gpu") kwargs = {"is_test": False} diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_parallel.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_parallel.py index 7a4f7f9fbd..71c254dabb 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_parallel.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_parallel.py @@ -30,9 +30,12 @@ class TestHybridPipeParallel(TestMultipleGpus): def test_hybrid_parallel_shared_weight(self): self.run_mnist_2gpu('hybrid_parallel_shared_weight.py') - def test_pipeline_parallel(self): + def test_pipeline_parallel_amp(self): self.run_mnist_2gpu('hybrid_parallel_pp_amp.py') + def test_pipeline_parallel_fp16(self): + self.run_mnist_2gpu('hybrid_parallel_pp_fp16.py') + def test_hybrid_parallel_transformer(self): self.run_mnist_2gpu('hybrid_parallel_pp_transformer.py') @@ -42,6 +45,9 @@ class TestHybridPipeParallel(TestMultipleGpus): def test_hybrid_parallel_recompute(self): self.run_mnist_2gpu('hybrid_parallel_pp_recompute.py') + def test_hybrid_parallel_pp_clip_grad(self): + self.run_mnist_2gpu('hybrid_parallel_pp_clip_grad.py') + if __name__ == "__main__": unittest.main() -- Gitee