diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc
index 8de575c107c2ed808264134137059e45a60009e2..5680b6dec2b40e7718a50ab02a843b61655abd5d 100644
--- a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc
+++ b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc
@@ -1137,8 +1137,8 @@ Status Affine(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
               InterpolationMode interpolation, uint8_t fill_r, uint8_t fill_g, uint8_t fill_b) {
   try {
     std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
-    if (input_cv->Rank() != 3 || input_cv->shape()[2] != 3) {
-      RETURN_STATUS_UNEXPECTED("Affine: image shape is not <H,W,C> or channel is not 3.");
+    if (input_cv->Rank() == 1 || input_cv->Rank() > 3) {
+      RETURN_STATUS_UNEXPECTED("Affine: image shape is not <H,W,C> or <H,W>.");
     }

     cv::Mat affine_mat(mat);
diff --git a/mindspore/dataset/__init__.py b/mindspore/dataset/__init__.py
index b62f7917e328f74417abc7b67def580d682cb7ff..2df1675f2759fa626988b74113f5ff5563712caa 100644
--- a/mindspore/dataset/__init__.py
+++ b/mindspore/dataset/__init__.py
@@ -13,9 +13,9 @@
 # limitations under the License.
 """
 This module provides APIs to load and process various common datasets such as MNIST,
-CIFAR-10, CIFAR-100, VOC, ImageNet, CelebA, etc. It also supports datasets in standard
-format, including MindRecord, TFRecord, Manifest, etc. Users can also define their own
-datasets with this module.
+CIFAR-10, CIFAR-100, VOC, COCO, ImageNet, CelebA, CLUE, etc. It also supports datasets
+in standard format, including MindRecord, TFRecord, Manifest, etc. Users can also define
+their own datasets with this module.

 Besides, this module provides APIs to sample data while loading.

diff --git a/mindspore/dataset/core/validator_helpers.py b/mindspore/dataset/core/validator_helpers.py
index 9d9f98e7726814d15294df31a910f81dab536a7c..da10eab6a534c76e8503c627eb7456131269f080 100644
--- a/mindspore/dataset/core/validator_helpers.py
+++ b/mindspore/dataset/core/validator_helpers.py
@@ -74,6 +74,14 @@ def check_value(value, valid_range, arg_name=""):
                                                                                  valid_range[1]))


+def check_value_cutoff(value, valid_range, arg_name=""):
+    arg_name = pad_arg_name(arg_name)
+    if value < valid_range[0] or value >= valid_range[1]:
+        raise ValueError(
+            "Input {0}is not within the required interval of [{1}, {2}).".format(arg_name, valid_range[0],
+                                                                                 valid_range[1]))
+
+
 def check_value_normalize_std(value, valid_range, arg_name=""):
     arg_name = pad_arg_name(arg_name)
     if value <= valid_range[0] or value > valid_range[1]:
@@ -404,7 +412,7 @@ def check_tensor_op(param, param_name):

 def check_c_tensor_op(param, param_name):
     """check whether param is a tensor op or a callable Python function but not a py_transform"""
-    if callable(param) and getattr(param, 'parse', True):
+    if callable(param) and str(param).find("py_transform") >= 0:
         raise TypeError("{0} is a py_transform op which is not allow to use.".format(param_name))
     if not isinstance(param, cde.TensorOp) and not callable(param) and not getattr(param, 'parse', None):
         raise TypeError("{0} is neither a c_transform op (TensorOperation) nor a callable pyfunc.".format(param_name))
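
Note on the new validator: unlike check_value, which accepts both endpoints of the range, check_value_cutoff enforces a half-open interval. A minimal standalone sketch of that behaviour follows; pad_arg_name is stubbed here purely for illustration and is not part of this patch.

# Standalone sketch, not the library module itself.
def pad_arg_name(arg_name):
    # assumed stub: the real helper appends a trailing space to non-empty names
    return arg_name + " " if arg_name != "" else arg_name

def check_value_cutoff(value, valid_range, arg_name=""):
    arg_name = pad_arg_name(arg_name)
    if value < valid_range[0] or value >= valid_range[1]:
        raise ValueError(
            "Input {0}is not within the required interval of [{1}, {2}).".format(
                arg_name, valid_range[0], valid_range[1]))

check_value_cutoff(0.0, [0, 50], "cutoff")     # passes: lower bound is inclusive
check_value_cutoff(49.9, [0, 50], "cutoff")    # passes
# check_value_cutoff(50.0, [0, 50], "cutoff")  # raises ValueError: upper bound is exclusive

For AutoContrast this means cutoff values from 0.0 up to, but not including, 50.0 pass validation.
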
diff --git a/mindspore/dataset/text/transforms.py b/mindspore/dataset/text/transforms.py
index 2bdbc82cc53d2a2cd8f406d4d14a41f3f2cb5238..352a7d095acad1df501a105a17af341cf70bc7f6 100644
--- a/mindspore/dataset/text/transforms.py
+++ b/mindspore/dataset/text/transforms.py
@@ -531,7 +531,8 @@ class PythonTokenizer:
         self.random = False

     def __call__(self, in_array):
-        in_array = to_str(in_array)
+        if not isinstance(in_array, str):
+            in_array = to_str(in_array)
         tokens = self.tokenizer(in_array)
         return tokens

diff --git a/mindspore/dataset/vision/c_transforms.py b/mindspore/dataset/vision/c_transforms.py
index aa8786e90ed35143643924e175904bfe39bf3a0e..9ff6ec1401786d9597b4a56682c1d2145014d4b3 100644
--- a/mindspore/dataset/vision/c_transforms.py
+++ b/mindspore/dataset/vision/c_transforms.py
@@ -104,7 +104,8 @@ class AutoContrast(ImageTensorOperation):
     Apply automatic contrast on input image.

     Args:
-        cutoff (float, optional): Percent of pixels to cut off from the histogram (default=0.0).
+        cutoff (float, optional): Percent of pixels to cut off from the histogram,
+            the value must be in the range [0.0, 50.0) (default=0.0).
         ignore (Union[int, sequence], optional): Pixel values to ignore (default=None).

     Examples:
@@ -770,7 +771,7 @@ class RandomCropDecodeResize(ImageTensorOperation):
         if img.ndim != 1 or img.dtype.type is not np.uint8:
             raise TypeError("Input should be an encoded image with uint8 type in 1-D NumPy format, " +
                             "got format:{}, dtype:{}.".format(type(img), img.dtype.type))
-        super().__call__(img=img)
+        return super().__call__(img)


 class RandomCropWithBBox(ImageTensorOperation):
diff --git a/mindspore/dataset/vision/py_transforms.py b/mindspore/dataset/vision/py_transforms.py
index f00203b2f5beeb1843f95632cec1ba7676900045..4b0a7cb9026f8ece30081eae29f2584462de6cc6 100644
--- a/mindspore/dataset/vision/py_transforms.py
+++ b/mindspore/dataset/vision/py_transforms.py
@@ -1031,7 +1031,7 @@ class RandomErasing:

 class Cutout:
     """
-    Randomly cut (mask) out a given number of square patches from the input NumPy image array.
+    Randomly cut (mask) out a given number of square patches from the input NumPy image array of shape (C, H, W).

     Terrance DeVries and Graham W. Taylor 'Improved Regularization of Convolutional Neural Networks with Cutout' 2017
     See https://arxiv.org/pdf/1708.04552.pdf
@@ -1068,6 +1068,9 @@ class Cutout:
         """
         if not isinstance(np_img, np.ndarray):
             raise TypeError("img should be NumPy array. Got {}.".format(type(np_img)))
+        if np_img.ndim != 3:
+            raise TypeError('img dimension should be 3. Got {}.'.format(np_img.ndim))
+
         _, image_h, image_w = np_img.shape
         scale = (self.length * self.length) / (image_h * image_w)
         bounded = False
@@ -1426,7 +1429,8 @@ class AutoContrast:
     Automatically maximize the contrast of the input PIL image.

     Args:
-        cutoff (float, optional): Percent of pixels to cut off from the histogram (default=0.0).
+        cutoff (float, optional): Percent of pixels to cut off from the histogram,
+            the value must be in the range [0.0, 50.0) (default=0.0).
         ignore (Union[int, sequence], optional): Pixel values to ignore (default=None).

     Examples:
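
Note on the Cutout change above: the transform now explicitly requires a 3-D (C, H, W) NumPy array, so HWC images should be converted before Cutout is applied. An illustrative sketch, assuming the usual ToTensor-then-Cutout ordering and Cutout's (length, num_patches) signature:

# Illustrative only; not part of this patch.
import numpy as np
import mindspore.dataset.vision.py_transforms as py_vision

hwc_img = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
chw_img = py_vision.ToTensor()(hwc_img)            # float32 array of shape (3, 224, 224)
patched = py_vision.Cutout(length=32, num_patches=1)(chw_img)
# A 2-D grayscale array now fails the new ndim check with a clear TypeError
# instead of breaking later on the (C, H, W) shape unpacking.
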
diff --git a/mindspore/dataset/vision/py_transforms_util.py b/mindspore/dataset/vision/py_transforms_util.py
index 7e05e5e983bd55915c21e0e18e52fc8451159741..e6a72aefabb73a99a4fb75f90af0f8a21201c75e 100644
--- a/mindspore/dataset/vision/py_transforms_util.py
+++ b/mindspore/dataset/vision/py_transforms_util.py
@@ -56,13 +56,16 @@ def normalize(img, mean, std, pad_channel=False, dtype="float32"):
     Returns:
         img (numpy.ndarray), Normalized image.
     """
+    if not is_numpy(img):
+        raise TypeError("img should be NumPy image. Got {}.".format(type(img)))
+
+    if img.ndim != 3:
+        raise TypeError('img dimension should be 3. Got {}.'.format(img.ndim))
+
     if np.issubdtype(img.dtype, np.integer):
         raise NotImplementedError("Unsupported image datatype: [{}], pls execute [ToTensor] before [Normalize]."
                                   .format(img.dtype))

-    if not is_numpy(img):
-        raise TypeError("img should be NumPy image. Got {}.".format(type(img)))
-
     num_channels = img.shape[0]  # shape is (C, H, W)

     if len(mean) != len(std):
@@ -119,9 +122,11 @@ def hwc_to_chw(img):
     Returns:
         img (numpy.ndarray), Converted image.
     """
-    if is_numpy(img):
-        return img.transpose(2, 0, 1).copy()
-    raise TypeError('img should be NumPy array. Got {}.'.format(type(img)))
+    if not is_numpy(img):
+        raise TypeError('img should be NumPy array. Got {}.'.format(type(img)))
+    if img.ndim != 3:
+        raise TypeError('img dimension should be 3. Got {}.'.format(img.ndim))
+    return img.transpose(2, 0, 1).copy()


 def to_tensor(img, output_type):
@@ -140,7 +145,7 @@ def to_tensor(img, output_type):
         img = np.asarray(img)

     if img.ndim not in (2, 3):
-        raise ValueError("img dimension should be 2 or 3. Got {}.".format(img.ndim))
+        raise TypeError("img dimension should be 2 or 3. Got {}.".format(img.ndim))

     if img.ndim == 2:
         img = img[:, :, None]
@@ -856,8 +861,8 @@ def pad(img, padding, fill_value, padding_mode):

     elif isinstance(padding, (tuple, list)):
         if len(padding) == 2:
-            left = right = padding[0]
-            top = bottom = padding[1]
+            left = top = padding[0]
+            right = bottom = padding[1]
         elif len(padding) == 4:
             left = padding[0]
             top = padding[1]
@@ -877,10 +882,10 @@ def pad(img, padding, fill_value, padding_mode):
     if padding_mode == 'constant':
         if img.mode == 'P':
             palette = img.getpalette()
-            image = ImageOps.expand(img, border=padding, fill=fill_value)
+            image = ImageOps.expand(img, border=(left, top, right, bottom), fill=fill_value)
             image.putpalette(palette)
             return image
-        return ImageOps.expand(img, border=padding, fill=fill_value)
+        return ImageOps.expand(img, border=(left, top, right, bottom), fill=fill_value)

     if img.mode == 'P':
         palette = img.getpalette()
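
Note on the pad() change above: for a 2-element padding sequence, padding[0] now pads the left and top sides and padding[1] the right and bottom, and the border is always handed to PIL as an explicit (left, top, right, bottom) tuple. A small plain-PIL sketch of the resulting geometry (not the library helper itself):

from PIL import Image, ImageOps

img = Image.new("RGB", (100, 80))
left = top = 4          # padding[0]
right = bottom = 10     # padding[1]
padded = ImageOps.expand(img, border=(left, top, right, bottom), fill=0)
assert padded.size == (100 + 4 + 10, 80 + 4 + 10)  # (114, 94)
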
@@ -1254,6 +1259,9 @@ def rgb_to_hsvs(np_rgb_imgs, is_hwc):
     if not is_numpy(np_rgb_imgs):
         raise TypeError("img should be NumPy image. Got {}".format(type(np_rgb_imgs)))

+    if not isinstance(is_hwc, bool):
+        raise TypeError("is_hwc should be bool type. Got {}.".format(type(is_hwc)))
+
     shape_size = len(np_rgb_imgs.shape)

     if not shape_size in (3, 4):
@@ -1322,6 +1330,9 @@ def hsv_to_rgbs(np_hsv_imgs, is_hwc):
     if not is_numpy(np_hsv_imgs):
         raise TypeError("img should be NumPy image. Got {}.".format(type(np_hsv_imgs)))

+    if not isinstance(is_hwc, bool):
+        raise TypeError("is_hwc should be bool type. Got {}.".format(type(is_hwc)))
+
     shape_size = len(np_hsv_imgs.shape)

     if not shape_size in (3, 4):
diff --git a/mindspore/dataset/vision/validators.py b/mindspore/dataset/vision/validators.py
index e2cb8dfed04692ccfc2355647803906f4b406660..5a1b9f6dcbfbe98d24b622da0d1de065bf539225 100644
--- a/mindspore/dataset/vision/validators.py
+++ b/mindspore/dataset/vision/validators.py
@@ -21,7 +21,7 @@ from mindspore._c_dataengine import TensorOp, TensorOperation

 from mindspore.dataset.core.validator_helpers import check_value, check_uint8, FLOAT_MAX_INTEGER, check_pos_float32, \
     check_float32, check_2tuple, check_range, check_positive, INT32_MAX, parse_user_args, type_check, type_check_list, \
-    check_c_tensor_op, UINT8_MAX, check_value_normalize_std
+    check_c_tensor_op, UINT8_MAX, check_value_normalize_std, check_value_cutoff
 from .utils import Inter, Border, ImageBatchFormat


@@ -650,7 +650,7 @@ def check_auto_contrast(method):
     def new_method(self, *args, **kwargs):
         [cutoff, ignore], _ = parse_user_args(method, *args, **kwargs)
         type_check(cutoff, (int, float), "cutoff")
-        check_value(cutoff, [0, 100], "cutoff")
+        check_value_cutoff(cutoff, [0, 50], "cutoff")
         if ignore is not None:
             type_check(ignore, (list, tuple, int), "ignore")
         if isinstance(ignore, int):
diff --git a/tests/ut/python/dataset/test_autocontrast.py b/tests/ut/python/dataset/test_autocontrast.py
index 550450f45fe0aea57fa86f49c6248d1cf5010755..a212b4660548d2aa0483a5933c0110e89be31fb8 100644
--- a/tests/ut/python/dataset/test_autocontrast.py
+++ b/tests/ut/python/dataset/test_autocontrast.py
@@ -270,7 +270,7 @@ def test_auto_contrast_invalid_cutoff_param_c():
         data_set = data_set.map(operations=C.AutoContrast(cutoff=-10.0), input_columns="image")
     except ValueError as error:
         logger.info("Got an exception in DE: {}".format(str(error)))
-        assert "Input cutoff is not within the required interval of (0 to 100)." in str(error)
+        assert "Input cutoff is not within the required interval of [0, 50)." in str(error)
     try:
         data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
         data_set = data_set.map(operations=[C.Decode(),
@@ -280,7 +280,7 @@ def test_auto_contrast_invalid_cutoff_param_c():
         data_set = data_set.map(operations=C.AutoContrast(cutoff=120.0), input_columns="image")
     except ValueError as error:
         logger.info("Got an exception in DE: {}".format(str(error)))
-        assert "Input cutoff is not within the required interval of (0 to 100)." in str(error)
+        assert "Input cutoff is not within the required interval of [0, 50)." in str(error)


 def test_auto_contrast_invalid_ignore_param_py():
@@ -327,7 +327,7 @@ def test_auto_contrast_invalid_cutoff_param_py():
                                     input_columns=["image"])
     except ValueError as error:
         logger.info("Got an exception in DE: {}".format(str(error)))
-        assert "Input cutoff is not within the required interval of (0 to 100)." in str(error)
+        assert "Input cutoff is not within the required interval of [0, 50)." in str(error)
     try:
         data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
         data_set = data_set.map(
@@ -338,7 +338,7 @@ def test_auto_contrast_invalid_cutoff_param_py():
                                     input_columns=["image"])
     except ValueError as error:
         logger.info("Got an exception in DE: {}".format(str(error)))
-        assert "Input cutoff is not within the required interval of (0 to 100)." in str(error)
+        assert "Input cutoff is not within the required interval of [0, 50)." in str(error)


 if __name__ == "__main__":
diff --git a/tests/ut/python/dataset/test_eager_text.py b/tests/ut/python/dataset/test_eager_text.py
new file mode 100644
index 0000000000000000000000000000000000000000..76a26e55040138bdf6a928ccfff69697548105e6
--- /dev/null
+++ b/tests/ut/python/dataset/test_eager_text.py
@@ -0,0 +1,67 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import numpy as np
+import mindspore.dataset.text.transforms as T
+import mindspore.common.dtype as mstype
+from mindspore import log as logger
+
+def test_sliding_window():
+    txt = ["Welcome", "to", "Beijing", "!"]
+    sliding_window = T.SlidingWindow(width=2)
+    txt = sliding_window(txt)
+    logger.info("Result: {}".format(txt))
+
+    expected = [['Welcome', 'to'], ['to', 'Beijing'], ['Beijing', '!']]
+    np.testing.assert_equal(txt, expected)
+
+
+def test_to_number():
+    txt = ["123456"]
+    to_number = T.ToNumber(mstype.int32)
+    txt = to_number(txt)
+    logger.info("Result: {}, type: {}".format(txt, type(txt[0])))
+
+    assert txt == 123456
+
+
+def test_whitespace_tokenizer():
+    txt = "Welcome to Beijing !"
+    txt = T.WhitespaceTokenizer()(txt)
+    logger.info("Tokenize result: {}".format(txt))
+
+    expected = ['Welcome', 'to', 'Beijing', '!']
+    np.testing.assert_equal(txt, expected)
+
+
+def test_python_tokenizer():
+    # whitespace tokenizer
+    def my_tokenizer(line):
+        words = line.split()
+        if not words:
+            return [""]
+        return words
+    txt = "Welcome to Beijing !"
+    txt = T.PythonTokenizer(my_tokenizer)(txt)
+    logger.info("Tokenize result: {}".format(txt))
+
+    expected = ['Welcome', 'to', 'Beijing', '!']
+    np.testing.assert_equal(txt, expected)
+
+
+if __name__ == '__main__':
+    test_sliding_window()
+    test_to_number()
+    test_whitespace_tokenizer()
+    test_python_tokenizer()
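
Note on the new test file: it exercises eager-style usage, where a text transform object is called directly on Python data instead of being passed to Dataset.map(). A condensed sketch of the same call pattern (assumes a working MindSpore installation):

import mindspore.dataset.text.transforms as T

tokens = T.WhitespaceTokenizer()("Welcome to Beijing !")               # expected: 'Welcome', 'to', 'Beijing', '!'
bigrams = T.SlidingWindow(width=2)(["Welcome", "to", "Beijing", "!"])  # expected: sliding pairs of adjacent tokens
print(tokens)
print(bigrams)
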