diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py
index e035146225ab2b1bd3eef7093b034a74241f0814..89c9e65ea8cb563527f13e7be5c68133dd5c5342 100644
--- a/mindspore/dataset/engine/datasets.py
+++ b/mindspore/dataset/engine/datasets.py
@@ -2343,8 +2343,10 @@ class ConcatDataset(DatasetOp):
             Number, number of batches.
         """
         if self.dataset_size is None:
-            children_sizes = [c.get_dataset_size() for c in self.children]
-            self.dataset_size = sum(children_sizes)
+            num_rows = 0
+            for _ in self.create_dict_iterator():
+                num_rows += 1
+            self.dataset_size = num_rows
         return self.dataset_size
 
     def use_sampler(self, sampler):
diff --git a/mindspore/dataset/transforms/vision/py_transforms.py b/mindspore/dataset/transforms/vision/py_transforms.py
index 9e4be238c20a6bee160874dbefa5c77ff6c3c642..2ddf6bbc10a9a37fff7e163424572660f3c13509 100644
--- a/mindspore/dataset/transforms/vision/py_transforms.py
+++ b/mindspore/dataset/transforms/vision/py_transforms.py
@@ -1115,7 +1115,8 @@ class RandomAffine:
             - Inter.BICUBIC, means resample method is bicubic interpolation.
 
         fill_value (Union[tuple, int], optional): Optional fill_value to fill the area outside the transform
-            in the output image. Used only in Pillow versions > 5.0.0 (default=0, filling is performed).
+            in the output image. If a tuple is used, it must contain three elements, each within the range [0, 255].
+            Used only in Pillow versions > 5.0.0 (default=0, filling is performed).
 
     Raises:
         ValueError: If degrees is negative.
@@ -1127,6 +1128,7 @@
         TypeError: If translate is specified but is not list or a tuple of length 2.
         TypeError: If scale is not a list or tuple of length 2.
         TypeError: If shear is not a list or tuple of length 2 or 4.
+        TypeError: If fill_value is not a single integer or a 3-tuple.
 
     Examples:
         >>> py_transforms.ComposeOp([py_transforms.Decode(),
diff --git a/tests/ut/python/dataset/test_paddeddataset.py b/tests/ut/python/dataset/test_paddeddataset.py
index a2a24d03dfce16723d5107d6ea893e17fd777ba3..0c25b67b803a2980fef6e6c06c06132bab8afef1 100644
--- a/tests/ut/python/dataset/test_paddeddataset.py
+++ b/tests/ut/python/dataset/test_paddeddataset.py
@@ -225,31 +225,63 @@ def test_imagefolder_padded():
     assert verify_list[9] == 6
 
 def test_imagefolder_padded_with_decode():
-    DATA_DIR = "../data/dataset/testPK/data"
-    data = ds.ImageFolderDatasetV2(DATA_DIR)
+    num_shards = 5
+    count = 0
+    for shard_id in range(num_shards):
+        DATA_DIR = "../data/dataset/testPK/data"
+        data = ds.ImageFolderDatasetV2(DATA_DIR)
 
-    white_io = BytesIO()
-    Image.new('RGB', (224, 224), (255, 255, 255)).save(white_io, 'JPEG')
-    padded_sample = {}
-    padded_sample['image'] = np.array(bytearray(white_io), dtype='uint8')
-    padded_sample['label'] = np.array(-1, np.int32)
+        white_io = BytesIO()
+        Image.new('RGB', (224, 224), (255, 255, 255)).save(white_io, 'JPEG')
+        padded_sample = {}
+        padded_sample['image'] = np.array(bytearray(white_io.getvalue()), dtype='uint8')
+        padded_sample['label'] = np.array(-1, np.int32)
 
-    white_samples = [padded_sample, padded_sample, padded_sample, padded_sample]
-    data2 = ds.PaddedDataset(white_samples)
-    data3 = data + data2
+        white_samples = [padded_sample, padded_sample, padded_sample, padded_sample]
+        data2 = ds.PaddedDataset(white_samples)
+        data3 = data + data2
+        testsampler = ds.DistributedSampler(num_shards=num_shards, shard_id=shard_id, shuffle=False, num_samples=None)
+        data3.use_sampler(testsampler)
+        data3 = data3.map(input_columns="image", operations=V_C.Decode())
+        shard_sample_count = 0
+        for ele in data3.create_dict_iterator():
+            print("label: {}".format(ele['label']))
+            count += 1
+            shard_sample_count += 1
+        assert shard_sample_count in (9, 10)
+    assert count == 48
 
+
+def test_imagefolder_padded_with_decode_and_get_dataset_size():
     num_shards = 5
     count = 0
     for shard_id in range(num_shards):
+        DATA_DIR = "../data/dataset/testPK/data"
+        data = ds.ImageFolderDatasetV2(DATA_DIR)
+
+        white_io = BytesIO()
+        Image.new('RGB', (224, 224), (255, 255, 255)).save(white_io, 'JPEG')
+        padded_sample = {}
+        padded_sample['image'] = np.array(bytearray(white_io.getvalue()), dtype='uint8')
+        padded_sample['label'] = np.array(-1, np.int32)
+
+        white_samples = [padded_sample, padded_sample, padded_sample, padded_sample]
+        data2 = ds.PaddedDataset(white_samples)
+        data3 = data + data2
+
         testsampler = ds.DistributedSampler(num_shards=num_shards, shard_id=shard_id, shuffle=False, num_samples=None)
         data3.use_sampler(testsampler)
-        data3.map(input_columns="image", operations=V_C.Decode())
+        shard_dataset_size = data3.get_dataset_size()
+        data3 = data3.map(input_columns="image", operations=V_C.Decode())
+        shard_sample_count = 0
         for ele in data3.create_dict_iterator():
             print("label: {}".format(ele['label']))
             count += 1
+            shard_sample_count += 1
+        assert shard_sample_count in (9, 10)
+        assert shard_dataset_size == shard_sample_count
     assert count == 48
 
-
 def test_more_shard_padded():
     result_list = []
     for i in range(8):
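For reference, a minimal sketch of the contract the datasets.py change establishes: ConcatDataset.get_dataset_size() now counts rows by iterating the pipeline, so the reported size respects an attached sampler instead of summing the unsharded children sizes. The sketch reuses the data layout and sampler parameters from the tests above; it is illustrative, not part of the patch.

```python
from io import BytesIO
import numpy as np
from PIL import Image
import mindspore.dataset as ds

# Build one white padding sample, exactly as the tests above do.
white_io = BytesIO()
Image.new('RGB', (224, 224), (255, 255, 255)).save(white_io, 'JPEG')
padded_sample = {'image': np.array(bytearray(white_io.getvalue()), dtype='uint8'),
                 'label': np.array(-1, np.int32)}

data = ds.ImageFolderDatasetV2("../data/dataset/testPK/data")  # 44 images in the test data
data2 = ds.PaddedDataset([padded_sample] * 4)
data3 = data + data2  # ConcatDataset, 48 rows in total

sampler = ds.DistributedSampler(num_shards=5, shard_id=0, shuffle=False, num_samples=None)
data3.use_sampler(sampler)

# Previously get_dataset_size() summed the children sizes and returned 48
# regardless of the sampler; it now counts rows through the pipeline and
# therefore matches what the iterator actually yields for this shard (9 or 10).
assert data3.get_dataset_size() == sum(1 for _ in data3.create_dict_iterator())
```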
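Similarly, a hedged sketch of the fill_value contract documented in py_transforms.py; the pipeline mirrors the docstring's own ComposeOp example, and the specific degrees and fill values are illustrative only.

```python
from mindspore.dataset.transforms.vision import py_transforms

# fill_value may be a single int, or a 3-tuple (one value per RGB channel,
# each within [0, 255]); anything else should raise TypeError per the
# documented contract. Filling requires Pillow > 5.0.0 to take effect.
transform = py_transforms.ComposeOp([
    py_transforms.Decode(),
    py_transforms.RandomAffine(degrees=15, fill_value=(128, 128, 128)),
    py_transforms.ToTensor()
])
```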