From ac39c20f41b880afa43f9bedd579fdcdc197a0ed Mon Sep 17 00:00:00 2001 From: liyong Date: Mon, 31 Aug 2020 09:55:59 +0800 Subject: [PATCH] del finish in FileReader --- .../mindrecord/common/shard_pybind.cc | 1 - .../mindrecord/include/shard_reader.h | 4 --- .../minddata/mindrecord/io/shard_reader.cc | 30 +++++++--------- mindspore/mindrecord/filereader.py | 9 ----- mindspore/mindrecord/shardreader.py | 19 +---------- .../cpp/mindrecord/ut_shard_operator_test.cc | 34 +++++++++---------- .../ut/cpp/mindrecord/ut_shard_reader_test.cc | 12 +++---- .../ut/cpp/mindrecord/ut_shard_writer_test.cc | 10 +++--- .../python/mindrecord/test_mindrecord_base.py | 2 +- 9 files changed, 43 insertions(+), 78 deletions(-) diff --git a/mindspore/ccsrc/minddata/mindrecord/common/shard_pybind.cc b/mindspore/ccsrc/minddata/mindrecord/common/shard_pybind.cc index 7c7f79ccfb8d..72492c08575d 100644 --- a/mindspore/ccsrc/minddata/mindrecord/common/shard_pybind.cc +++ b/mindspore/ccsrc/minddata/mindrecord/common/shard_pybind.cc @@ -93,7 +93,6 @@ void BindShardReader(const py::module *m) { .def("get_blob_fields", &ShardReader::GetBlobFields) .def("get_next", (std::vector>, pybind11::object>>(ShardReader::*)()) & ShardReader::GetNextPy) - .def("finish", &ShardReader::Finish) .def("close", &ShardReader::Close); } diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h index 13607aebe3d1..32bf1431ab7d 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h @@ -174,10 +174,6 @@ class ShardReader { ROW_GROUP_BRIEF ReadRowGroupCriteria(int group_id, int shard_id, const std::pair &criteria, const std::vector &columns = std::vector()); - /// \brief join all created threads - /// \return MSRStatus the status of MSRStatus - MSRStatus Finish(); - /// \brief return a batch, given that one is ready /// \return a batch of images and image data std::vector, json>> GetNext(); diff --git a/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc b/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc index 67fee2c13c90..cc0e9decbf3f 100644 --- a/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc +++ b/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc @@ -239,7 +239,19 @@ void ShardReader::FileStreamsOperator() { ShardReader::~ShardReader() { Close(); } void ShardReader::Close() { - (void)Finish(); // interrupt reading and stop threads + { + std::lock_guard lck(mtx_delivery_); + interrupt_ = true; // interrupt reading and stop threads + } + cv_delivery_.notify_all(); + + // Wait for all threads to finish + for (auto &i_thread : thread_set_) { + if (i_thread.joinable()) { + i_thread.join(); + } + } + FileStreamsOperator(); } @@ -759,22 +771,6 @@ bool ResortRowGroups(std::tuple a, std::tuple(a) < std::get<1>(b) || (std::get<1>(a) == std::get<1>(b) && std::get<0>(a) < std::get<0>(b)); } -MSRStatus ShardReader::Finish() { - { - std::lock_guard lck(mtx_delivery_); - interrupt_ = true; - } - cv_delivery_.notify_all(); - - // Wait for all threads to finish - for (auto &i_thread : thread_set_) { - if (i_thread.joinable()) { - i_thread.join(); - } - } - return SUCCESS; -} - int64_t ShardReader::GetNumClasses(const std::string &category_field) { auto shard_count = file_paths_.size(); auto index_fields = shard_header_->GetFields(); diff --git a/mindspore/mindrecord/filereader.py b/mindspore/mindrecord/filereader.py index c97bbd687d00..5c38bbc23d4d 100644 --- a/mindspore/mindrecord/filereader.py +++ b/mindspore/mindrecord/filereader.py @@ -83,15 +83,6 @@ class FileReader: yield populate_data(raw, blob, self._columns, self._header.blob_fields, self._header.schema) iterator = self._reader.get_next() - def finish(self): - """ - Stop reader worker. - - Raises: - MRMFinishError: If failed to finish worker threads. - """ - return self._reader.finish() - def close(self): """Stop reader worker and close File.""" return self._reader.close() diff --git a/mindspore/mindrecord/shardreader.py b/mindspore/mindrecord/shardreader.py index f3fc6ffc35fe..b5f2fbd1cc70 100644 --- a/mindspore/mindrecord/shardreader.py +++ b/mindspore/mindrecord/shardreader.py @@ -17,8 +17,7 @@ This module is to read data from mindrecord. """ import mindspore._c_mindrecord as ms from mindspore import log as logger -from .common.exceptions import MRMOpenError, MRMLaunchError, MRMFinishError - +from .common.exceptions import MRMOpenError, MRMLaunchError __all__ = ['ShardReader'] class ShardReader: @@ -102,22 +101,6 @@ class ShardReader: """ return self._reader.get_header() - def finish(self): - """ - stop the worker threads. - - Returns: - MSRStatus, SUCCESS or FAILED. - - Raises: - MRMFinishError: If failed to finish worker threads. - """ - ret = self._reader.finish() - if ret != ms.MSRStatus.SUCCESS: - logger.error("Failed to finish worker threads.") - raise MRMFinishError - return ret - def close(self): """close MindRecord File.""" self._reader.close() diff --git a/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc b/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc index 7b9186ac37d7..e2e5adfbdd04 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc @@ -73,7 +73,7 @@ TEST_F(TestShardOperator, TestShardSampleBasic) { MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]); i++; } - dataset.Finish(); + dataset.Close(); ASSERT_TRUE(i <= kSampleCount); } @@ -99,7 +99,7 @@ TEST_F(TestShardOperator, TestShardSampleWrongNumber) { MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]); i++; } - dataset.Finish(); + dataset.Close(); ASSERT_TRUE(i <= 5); } @@ -125,7 +125,7 @@ TEST_F(TestShardOperator, TestShardSampleRatio) { MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]); i++; } - dataset.Finish(); + dataset.Close(); ASSERT_TRUE(i <= 10); } @@ -151,7 +151,7 @@ TEST_F(TestShardOperator, TestShardSamplePartition) { MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]); i++; } - dataset.Finish(); + dataset.Close(); ASSERT_TRUE(i <= 10); } @@ -176,7 +176,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerBasic) { << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()) << std::endl; i++; } - dataset.Finish(); + dataset.Close(); ASSERT_TRUE(i == 20); } // namespace mindrecord @@ -202,7 +202,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerNumClass) { << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()) << std::endl; i++; } - dataset.Finish(); + dataset.Close(); ASSERT_TRUE(i == 6); } @@ -238,7 +238,7 @@ TEST_F(TestShardOperator, TestShardCategory) { category_no++; category_no %= static_cast(categories.size()); } - dataset.Finish(); + dataset.Close(); } TEST_F(TestShardOperator, TestShardShuffle) { @@ -262,7 +262,7 @@ TEST_F(TestShardOperator, TestShardShuffle) { << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); i++; } - dataset.Finish(); + dataset.Close(); } TEST_F(TestShardOperator, TestShardSampleShuffle) { @@ -287,7 +287,7 @@ TEST_F(TestShardOperator, TestShardSampleShuffle) { << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); i++; } - dataset.Finish(); + dataset.Close(); ASSERT_LE(i, 35); } @@ -314,7 +314,7 @@ TEST_F(TestShardOperator, TestShardShuffleSample) { << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); i++; } - dataset.Finish(); + dataset.Close(); ASSERT_TRUE(i <= kSampleSize); } @@ -341,7 +341,7 @@ TEST_F(TestShardOperator, TestShardSampleShuffleSample) { << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); i++; } - dataset.Finish(); + dataset.Close(); ASSERT_LE(i, 35); } @@ -373,8 +373,8 @@ TEST_F(TestShardOperator, TestShardShuffleCompare) { auto y = compare_dataset.GetNext(); if ((std::get<1>(x[0]))["file_name"] != (std::get<1>(y[0]))["file_name"]) different = true; } - dataset.Finish(); - compare_dataset.Finish(); + dataset.Close(); + compare_dataset.Close(); ASSERT_TRUE(different); } @@ -409,7 +409,7 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle1) { category_no++; category_no %= static_cast(categories.size()); } - dataset.Finish(); + dataset.Close(); } TEST_F(TestShardOperator, TestShardCategoryShuffle2) { @@ -442,7 +442,7 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle2) { category_no++; category_no %= static_cast(categories.size()); } - dataset.Finish(); + dataset.Close(); } TEST_F(TestShardOperator, TestShardCategorySample) { @@ -477,7 +477,7 @@ TEST_F(TestShardOperator, TestShardCategorySample) { category_no++; category_no %= static_cast(categories.size()); } - dataset.Finish(); + dataset.Close(); ASSERT_EQ(category_no, 0); ASSERT_TRUE(i <= kSampleSize); } @@ -515,7 +515,7 @@ TEST_F(TestShardOperator, TestShardCategorySampleShuffle) { category_no++; category_no %= static_cast(categories.size()); } - dataset.Finish(); + dataset.Close(); ASSERT_EQ(category_no, 0); ASSERT_TRUE(i <= kSampleSize); } diff --git a/tests/ut/cpp/mindrecord/ut_shard_reader_test.cc b/tests/ut/cpp/mindrecord/ut_shard_reader_test.cc index fb0e8470ced6..7b56f5e18f99 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_reader_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_reader_test.cc @@ -67,7 +67,7 @@ TEST_F(TestShardReader, TestShardReaderGeneral) { } } } - dataset.Finish(); + dataset.Close(); } TEST_F(TestShardReader, TestShardReaderSample) { @@ -90,7 +90,7 @@ TEST_F(TestShardReader, TestShardReaderSample) { } } } - dataset.Finish(); + dataset.Close(); dataset.Close(); } @@ -110,7 +110,7 @@ TEST_F(TestShardReader, TestShardReaderEasy) { } } } - dataset.Finish(); + dataset.Close(); } TEST_F(TestShardReader, TestShardReaderColumnNotInIndex) { @@ -131,7 +131,7 @@ TEST_F(TestShardReader, TestShardReaderColumnNotInIndex) { } } } - dataset.Finish(); + dataset.Close(); } TEST_F(TestShardReader, TestShardReaderColumnNotInSchema) { @@ -161,7 +161,7 @@ TEST_F(TestShardReader, TestShardVersion) { } } } - dataset.Finish(); + dataset.Close(); } TEST_F(TestShardReader, TestShardReaderDir) { @@ -192,7 +192,7 @@ TEST_F(TestShardReader, TestShardReaderConsumer) { } } } - dataset.Finish(); + dataset.Close(); } } // namespace mindrecord } // namespace mindspore diff --git a/tests/ut/cpp/mindrecord/ut_shard_writer_test.cc b/tests/ut/cpp/mindrecord/ut_shard_writer_test.cc index a8abe5e98db4..ce60874a7963 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_writer_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_writer_test.cc @@ -74,7 +74,7 @@ TEST_F(TestShardWriter, TestShardWriterOneSample) { } } } - dataset.Finish(); + dataset.Close(); for (int i = 1; i <= 4; i++) { string filename = std::string("./OneSample.shard0") + std::to_string(i); string db_name = std::string("./OneSample.shard0") + std::to_string(i) + ".db"; @@ -775,7 +775,7 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberColumnInIndex) { } } ASSERT_TRUE(count == 10); - dataset.Finish(); + dataset.Close(); for (const auto &filename : file_names) { auto filename_db = filename + ".db"; @@ -858,7 +858,7 @@ TEST_F(TestShardWriter, TestShardNoBlob) { } } ASSERT_TRUE(count == 10); - dataset.Finish(); + dataset.Close(); for (const auto &filename : file_names) { auto filename_db = filename + ".db"; remove(common::SafeCStr(filename_db)); @@ -952,7 +952,7 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberNotColumnInIndex) { } } ASSERT_TRUE(count == 10); - dataset.Finish(); + dataset.Close(); for (const auto &filename : file_names) { auto filename_db = filename + ".db"; remove(common::SafeCStr(filename_db)); @@ -1060,7 +1060,7 @@ TEST_F(TestShardWriter, TestShardWriter10Sample40Shard) { count++; } ASSERT_TRUE(count == 10); - dataset.Finish(); + dataset.Close(); for (const auto &filename : file_names) { auto filename_db = filename + ".db"; remove(common::SafeCStr(filename_db)); diff --git a/tests/ut/python/mindrecord/test_mindrecord_base.py b/tests/ut/python/mindrecord/test_mindrecord_base.py index 72a370e42448..3b25cf73a448 100644 --- a/tests/ut/python/mindrecord/test_mindrecord_base.py +++ b/tests/ut/python/mindrecord/test_mindrecord_base.py @@ -260,7 +260,7 @@ def test_cv_file_reader_partial_tutorial(): count = count + 1 logger.info("#item{}: {}".format(index, x)) if count == 5: - reader.finish() + reader.close() assert count == 5 -- Gitee