From 422614c55895c70ca086b7bb78b4b360fb4c2152 Mon Sep 17 00:00:00 2001
From: changzherui
Date: Sat, 8 May 2021 22:11:38 +0800
Subject: [PATCH] modify dis_load_ckpt for master

---
 mindspore/train/serialization.py | 90 ++++++++++++++++++++++----------
 1 file changed, 63 insertions(+), 27 deletions(-)

diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py
index fa6b6dcc5660..46e5d5802656 100644
--- a/mindspore/train/serialization.py
+++ b/mindspore/train/serialization.py
@@ -22,6 +22,7 @@
 import shutil
 import time
 import copy
 from threading import Thread, Lock
+from collections import defaultdict
 import numpy as np
 import mindspore.nn as nn
@@ -1138,19 +1139,18 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
     return merged_parameter
 
 
-def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None, dec_key=None, dec_mode='AES-GCM'):
+def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None,
+                                train_strategy_filename=None, dec_key=None, dec_mode='AES-GCM'):
     """
     Load checkpoint into net for distributed predication.
 
     Args:
         network (Cell): Network for distributed predication.
-        checkpoint_filenames (list(str)): The name of Checkpoint files
-            in order of rank id.
-        predict_strategy (Optional(dict)): Strategy of predication process, whose key
-            is parameter name, and value is a list or a tuple that the first four
-            elements are [dev_matrix, tensor_map, param_split_shape, field]. If None,
-            it means that the predication process just uses single device.
-            Default: None.
+        checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id.
+        predict_strategy (dict): Strategy of predication process, whose key is parameter name, and value is a list or
+            a tuple that the first four elements are [dev_matrix, tensor_map, param_split_shape, field]. If None,
+            it means that the predication process just uses single device. Default: None.
+        train_strategy_filename (str): Train strategy file. Default: None.
         dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is None, the decryption
                                       is not required. Default: None.
         dec_mode (str): This parameter is valid only when dec_key is not set to None. Specifies the decryption
@@ -1161,35 +1161,34 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
         ValueError: Failed to load checkpoint into net.
""" network = Validator.check_isinstance("network", network, nn.Cell) - - for index, filename in enumerate(checkpoint_filenames): - if not isinstance(filename, str) or not os.path.exists(filename) \ - or filename[-5:] != ".ckpt" or os.path.getsize(filename) == 0: - raise ValueError(f"Please make sure that the {filename} at index {index} is a valid checkpoint file.") - - if not _check_predict_strategy(predict_strategy): - raise ValueError(f"Please make sure that the key of predict_strategy is str, " - f"and the value is a list or a tuple that the first four elements are " - f"dev_matrix (list[int]), tensor_map (list[int]), " - f"param_split_shape (list[int]) and field_size (zero).") + _check_checkpoint_file(checkpoint_filenames) + _check_predict_strategy(predict_strategy) dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes)) dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str) - train_strategy_filename = context.get_auto_parallel_context("strategy_ckpt_load_file") + if train_strategy_filename is None: + train_strategy_filename = context.get_auto_parallel_context("strategy_ckpt_load_file") _train_strategy = build_searched_strategy(train_strategy_filename) train_strategy = _convert_to_list(_train_strategy) train_dev_count = 1 + ckpt_file_len = len(checkpoint_filenames) for dim in train_strategy[list(train_strategy.keys())[0]][0]: train_dev_count *= dim - if train_dev_count != len(checkpoint_filenames): + if train_dev_count != ckpt_file_len: raise ValueError( f"The length of checkpoint_filenames should be equal to the device count of training process. " - f"The length is {len(checkpoint_filenames)} but the device count is {train_dev_count}.") + f"The length is {ckpt_file_len} but the device count is {train_dev_count}.") rank_list = _infer_rank_list(train_strategy, predict_strategy) + param_total_dict = defaultdict(dict) + for file_index, file_name in enumerate(checkpoint_filenames): + ckpt_dict = load_checkpoint(file_name, dec_key, dec_mode) + for param_name, param in ckpt_dict.items(): + param_total_dict[param_name][file_index] = param + param_dict = {} for _, param in network.parameters_and_names(): sliced_params = [] @@ -1197,8 +1196,31 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy= continue param_rank = rank_list[param.name][0] skip_merge_split = rank_list[param.name][1] + shard_stride = train_strategy[param.name][4] + if train_strategy[param.name][5]: + shard_size = ckpt_file_len / shard_stride / train_strategy[param.name][5] + else: + shard_size = 0 for rank in param_rank: - sliced_param = load_checkpoint(checkpoint_filenames[rank], dec_key=dec_key, dec_mode=dec_mode)[param.name] + param_total_list = list(range(0, ckpt_file_len)) + if shard_size > 0: + shard_total_list = [param_total_list[i:i + shard_size] for i in + range(0, ckpt_file_len, shard_size)] + param_total_list = shard_total_list[rank // shard_size] + if shard_stride > 0: + param_stride = [] + # merge pre parameter + param_index = param_total_list[0:param_total_list.index(rank) + 1][::-1][::shard_stride] + param_index.extend(param_total_list[param_total_list.index(rank):][::shard_stride]) + param_index = list(set(param_index)) + param_index.sort() + for rank_num in param_index: + param_stride.append(param_total_dict[param.name][rank_num].data.asnumpy()) + + sliced_param = Parameter(Tensor(np.concatenate(param_stride)), name=param.name) + else: + sliced_param = param_total_dict[param.name][rank] + sliced_params.append(sliced_param) if skip_merge_split: 
             split_param = sliced_params[0]
@@ -1222,19 +1244,33 @@ def _check_predict_strategy(predict_strategy):
         return True
 
     if predict_strategy is None:
-        return True
+        return
 
+    flag = True
     predict_strategy = Validator.check_isinstance("predict_strategy", predict_strategy, dict)
     for key in predict_strategy.keys():
         if not isinstance(key, str) or not isinstance(predict_strategy[key], (list, tuple)) \
                 or len(predict_strategy[key]) < 4:
-            return False
+            flag = False
         dev_matrix, tensor_map, param_split_shape, field_size = predict_strategy[key][:4]
         if not _check_int_list(dev_matrix) or not _check_int_list(tensor_map) or \
                 not (_check_int_list(param_split_shape) or not param_split_shape) or \
                 not (isinstance(field_size, int) and field_size == 0):
-            return False
-    return True
+            flag = False
+
+    if not flag:
+        raise ValueError(f"Please make sure that the key of predict_strategy is str, "
+                         f"and the value is a list or a tuple that the first four elements are "
+                         f"dev_matrix (list[int]), tensor_map (list[int]), "
+                         f"param_split_shape (list[int]) and field_size (zero).")
+
+
+def _check_checkpoint_file(checkpoint_filenames):
+    """Check checkpoint file name."""
+    for index, filename in enumerate(checkpoint_filenames):
+        if not isinstance(filename, str) or not os.path.exists(filename) \
+                or filename[-5:] != ".ckpt" or os.path.getsize(filename) == 0:
+            raise ValueError(f"Please make sure that the {filename} at index {index} is a valid checkpoint file.")
 
 
 def _convert_to_list(strategy):
--
Gitee
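
Reviewer note, not part of the patch: a minimal usage sketch of the revised entry point, assuming an eight-rank training job. The network class, checkpoint paths and strategy file name below are hypothetical placeholders.

```python
# Hedged sketch of calling the patched API: the new train_strategy_filename
# argument replaces the previous reliance on
# context.get_auto_parallel_context("strategy_ckpt_load_file").
# SimpleNet, the checkpoint paths and train_strategy.ckpt are placeholders.
import mindspore.nn as nn
from mindspore.train.serialization import load_distributed_checkpoint


class SimpleNet(nn.Cell):
    """Stand-in for the real distributed prediction network."""

    def __init__(self):
        super().__init__()
        self.dense = nn.Dense(16, 8)

    def construct(self, x):
        return self.dense(x)


network = SimpleNet()

# One checkpoint file per training rank, ordered by rank id.
checkpoint_filenames = [f"./rank_{rank}/model.ckpt" for rank in range(8)]

# predict_strategy=None keeps single-device prediction; the train strategy
# file is now passed explicitly instead of being read from the parallel context.
load_distributed_checkpoint(network, checkpoint_filenames,
                            predict_strategy=None,
                            train_strategy_filename="./train_strategy.ckpt")
```

Pre-loading every rank file once into param_total_dict also means a checkpoint is no longer re-read for each parameter, which the old per-parameter load_checkpoint call did.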