| @@ -41,15 +41,15 @@ Status CacheClient::Builder::Build(std::shared_ptr<CacheClient> *out) { | |||||
| } | } | ||||
| Status CacheClient::Builder::SanityCheck() { | Status CacheClient::Builder::SanityCheck() { | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(session_id_ > 0, "session id must be positive"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(cache_mem_sz_ >= 0, "cache memory size must not be negative. (0 implies unlimited"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(num_connections_ > 0, "rpc connections must be positive"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(prefetch_size_ > 0, "prefetch size must be positive"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(!hostname_.empty(), "hostname must not be empty"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(port_ > 1024, "Port must be in range (1025..65535)"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(port_ <= 65535, "Port must be in range (1025..65535)"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(hostname_ == "127.0.0.1", | |||||
| "now cache client has to be on the same host with cache server"); | |||||
| CHECK_FAIL_RETURN_SYNTAX_ERROR(session_id_ > 0, "session id must be positive"); | |||||
| CHECK_FAIL_RETURN_SYNTAX_ERROR(cache_mem_sz_ >= 0, "cache memory size must not be negative. (0 implies unlimited"); | |||||
| CHECK_FAIL_RETURN_SYNTAX_ERROR(num_connections_ > 0, "rpc connections must be positive"); | |||||
| CHECK_FAIL_RETURN_SYNTAX_ERROR(prefetch_size_ > 0, "prefetch size must be positive"); | |||||
| CHECK_FAIL_RETURN_SYNTAX_ERROR(!hostname_.empty(), "hostname must not be empty"); | |||||
| CHECK_FAIL_RETURN_SYNTAX_ERROR(port_ > 1024, "Port must be in range (1025..65535)"); | |||||
| CHECK_FAIL_RETURN_SYNTAX_ERROR(port_ <= 65535, "Port must be in range (1025..65535)"); | |||||
| CHECK_FAIL_RETURN_SYNTAX_ERROR(hostname_ == "127.0.0.1", | |||||
| "now cache client has to be on the same host with cache server"); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -476,6 +476,12 @@ Status BatchOp::Accept(NodePass *p, bool *modified) { | |||||
| return p->RunOnNode(shared_from_base<BatchOp>(), modified); | return p->RunOnNode(shared_from_base<BatchOp>(), modified); | ||||
| } | } | ||||
| // Visitor pre-accept method for NodePass: forwards this node to the pass's BatchOp overload of PreRunOnNode | |||||
| Status BatchOp::PreAccept(NodePass *p, bool *modified) { | |||||
| // Downcast shared pointer to BatchOp, then call the visitor's pre-visit overload and return its status | |||||
| return p->PreRunOnNode(shared_from_base<BatchOp>(), modified); | |||||
| } | |||||
| Status BatchOp::ComputeColMap() { | Status BatchOp::ComputeColMap() { | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(child_.size() == 1, | CHECK_FAIL_RETURN_UNEXPECTED(child_.size() == 1, | ||||
| "Batch has " + std::to_string(child_.size()) + " child/children, expects only 1 child."); | "Batch has " + std::to_string(child_.size()) + " child/children, expects only 1 child."); | ||||
| @@ -199,6 +199,12 @@ class BatchOp : public ParallelOp { | |||||
| // @return - Status of the node visit. | // @return - Status of the node visit. | ||||
| Status Accept(NodePass *p, bool *modified) override; | Status Accept(NodePass *p, bool *modified) override; | ||||
| // Base-class override for NodePass visitor pre-acceptor. | |||||
| // @param p - Pointer to the NodePass to be accepted. | |||||
| // @param modified - Whether this node visit modified the pipeline. | |||||
| // @return - Status of the node visit. | |||||
| Status PreAccept(NodePass *p, bool *modified) override; | |||||
| // Op name getter | // Op name getter | ||||
| // @return Name of the current Op | // @return Name of the current Op | ||||
| std::string Name() const override { return kBatchOp; } | std::string Name() const override { return kBatchOp; } | ||||
| @@ -468,10 +468,16 @@ uint32_t DatasetOp::GenerateCRC(const std::shared_ptr<DatasetOp> &op) { | |||||
| ss_str = std::regex_replace(ss_str, std::regex("device_id.*\n"), ""); | ss_str = std::regex_replace(ss_str, std::regex("device_id.*\n"), ""); | ||||
| // Filter out the operator id field | // Filter out the operator id field | ||||
| ss_str = std::regex_replace(ss_str, std::regex("Parent.*\n"), ""); | |||||
| ss_str = std::regex_replace(ss_str, std::regex("Child.*\n"), ""); | |||||
| ss_str = std::regex_replace(ss_str, std::regex(".*Parent.*\n"), ""); | |||||
| ss_str = std::regex_replace(ss_str, std::regex(".*Child.*\n"), ""); | |||||
| ss_str = std::regex_replace(ss_str, std::regex(R"(\(\s*\d+?\))"), ""); | ss_str = std::regex_replace(ss_str, std::regex(R"(\(\s*\d+?\))"), ""); | ||||
| // Doesn't matter whether there is any parent node above CacheOp or not. | |||||
| ss_str = std::regex_replace(ss_str, std::regex("Number of parents.*\n"), ""); | |||||
| // Filter out shuffle seed from ShuffleOp | |||||
| ss_str = std::regex_replace(ss_str, std::regex("Shuffle seed.*\n"), ""); | |||||
| // Filter out the total repeats and number repeats per epoch field | // Filter out the total repeats and number repeats per epoch field | ||||
| ss_str = std::regex_replace(ss_str, std::regex("Total repeats.*\n"), ""); | ss_str = std::regex_replace(ss_str, std::regex("Total repeats.*\n"), ""); | ||||
| ss_str = std::regex_replace(ss_str, std::regex("Number repeats per epoch.*\n"), ""); | ss_str = std::regex_replace(ss_str, std::regex("Number repeats per epoch.*\n"), ""); | ||||
| @@ -454,6 +454,11 @@ Status NodePass::PreRunOnNode(std::shared_ptr<SkipOp> node, bool *modified) { | |||||
| return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | ||||
| } | } | ||||
| Status NodePass::PreRunOnNode(std::shared_ptr<BatchOp> node, bool *modified) { | |||||
| // Fallback to the base class (DatasetOp) visitor by default | |||||
| return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||||
| } | |||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| Status NodePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) { | Status NodePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) { | ||||
| // Fallback to base class visitor by default | // Fallback to base class visitor by default | ||||
| @@ -305,6 +305,7 @@ class NodePass : public Pass { | |||||
| virtual Status PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified); | virtual Status PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified); | ||||
| virtual Status PreRunOnNode(std::shared_ptr<TakeOp> node, bool *modified); | virtual Status PreRunOnNode(std::shared_ptr<TakeOp> node, bool *modified); | ||||
| virtual Status PreRunOnNode(std::shared_ptr<SkipOp> node, bool *modified); | virtual Status PreRunOnNode(std::shared_ptr<SkipOp> node, bool *modified); | ||||
| virtual Status PreRunOnNode(std::shared_ptr<BatchOp> node, bool *modified); | |||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| virtual Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified); | ||||
| virtual Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified); | virtual Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified); | ||||
| @@ -87,6 +87,16 @@ Status CacheErrorPass::PreRunOnNode(std::shared_ptr<SkipOp> node, bool *modified | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // Returns an error if BatchOp exists under a cache | |||||
| Status CacheErrorPass::PreRunOnNode(std::shared_ptr<BatchOp> node, bool *modified) { | |||||
| if (is_cached_) { | |||||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, | |||||
| "BatchOp is currently not supported as a descendant operator under a cache."); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| #ifdef ENABLE_PYTHON | #ifdef ENABLE_PYTHON | ||||
| // Returns an error if FilterOp exists under a cache | // Returns an error if FilterOp exists under a cache | ||||
| Status CacheErrorPass::PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified) { | Status CacheErrorPass::PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified) { | ||||
| @@ -169,7 +179,8 @@ Status CacheErrorPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) | |||||
| // Because there is no operator in the cache hit stream to consume eoes, caching above repeat causes problem. | // Because there is no operator in the cache hit stream to consume eoes, caching above repeat causes problem. | ||||
| Status CacheErrorPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { | Status CacheErrorPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { | ||||
| if (is_cached_ && is_mappable_) { | if (is_cached_ && is_mappable_) { | ||||
| RETURN_STATUS_UNEXPECTED("Repeat is not supported as a descendant operator under a mappable cache."); | |||||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, | |||||
| "Repeat is not supported as a descendant operator under a mappable cache."); | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -71,6 +71,12 @@ class CacheErrorPass : public NodePass { | |||||
| /// \return Status The error code return | /// \return Status The error code return | ||||
| Status PreRunOnNode(std::shared_ptr<SkipOp> node, bool *modified) override; | Status PreRunOnNode(std::shared_ptr<SkipOp> node, bool *modified) override; | ||||
| /// \brief Returns an error if BatchOp exists under a cache | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The error code return | |||||
| Status PreRunOnNode(std::shared_ptr<BatchOp> node, bool *modified) override; | |||||
| #ifdef ENABLE_PYTHON | #ifdef ENABLE_PYTHON | ||||
| /// \brief Returns an error if FilterOp exists under a cache | /// \brief Returns an error if FilterOp exists under a cache | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| @@ -58,7 +58,7 @@ Status CacheTransformPass::CachePass::PreRunOnNode(std::shared_ptr<CacheOp> node | |||||
| *modified = false; | *modified = false; | ||||
| MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree."; | MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree."; | ||||
| if (is_caching_) { | if (is_caching_) { | ||||
| RETURN_STATUS_UNEXPECTED("Nested cache operations is not supported!"); | |||||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Nested cache operations is not supported!"); | |||||
| } | } | ||||
| is_caching_ = true; | is_caching_ = true; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -102,7 +102,8 @@ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CacheOp> node, b | |||||
| Status CacheTransformPass::CachePass::MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op) { | Status CacheTransformPass::CachePass::MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op) { | ||||
| // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree. | // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree. | ||||
| if (is_caching_ && leaf_op_) { | if (is_caching_ && leaf_op_) { | ||||
| RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache."); | |||||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, | |||||
| "There is currently no support for multiple leaf nodes under cache."); | |||||
| } | } | ||||
| // If we are a leaf in the caching path, then save this leaf. | // If we are a leaf in the caching path, then save this leaf. | ||||
| @@ -117,7 +118,8 @@ Status CacheTransformPass::CachePass::MappableCacheLeafSetup(std::shared_ptr<Dat | |||||
| Status CacheTransformPass::CachePass::NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op) { | Status CacheTransformPass::CachePass::NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op) { | ||||
| // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree. | // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree. | ||||
| if (is_caching_ && leaf_op_) { | if (is_caching_ && leaf_op_) { | ||||
| RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache."); | |||||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, | |||||
| "There is currently no support for multiple leaf nodes under cache."); | |||||
| } | } | ||||
| // Sampler for non mappable dataset only works if there is a downstream cache. Remove it from the leaf | // Sampler for non mappable dataset only works if there is a downstream cache. Remove it from the leaf | ||||
| @@ -215,7 +217,8 @@ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CelebAOp> node, | |||||
| // Perform leaf node cache transform identification | // Perform leaf node cache transform identification | ||||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) { | Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) { | ||||
| if (is_caching_) { | if (is_caching_) { | ||||
| RETURN_STATUS_UNEXPECTED("There is currently no support for MindRecordOp under cache."); | |||||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, | |||||
| "There is currently no support for MindRecordOp under cache."); | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -225,7 +228,8 @@ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MindRecordOp> no | |||||
| // Perform leaf node cache transform identification | // Perform leaf node cache transform identification | ||||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) { | Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) { | ||||
| if (is_caching_) { | if (is_caching_) { | ||||
| RETURN_STATUS_UNEXPECTED("There is currently no support for GeneratorOp under cache."); | |||||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, | |||||
| "There is currently no support for GeneratorOp under cache."); | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -51,6 +51,13 @@ namespace dataset { | |||||
| } \ | } \ | ||||
| } while (false) | } while (false) | ||||
| #define CHECK_FAIL_RETURN_SYNTAX_ERROR(_condition, _e) \ | |||||
| do { \ | |||||
| if (!(_condition)) { \ | |||||
| return Status(StatusCode::kSyntaxError, __LINE__, __FILE__, _e); \ | |||||
| } \ | |||||
| } while (false) | |||||
| #define RETURN_UNEXPECTED_IF_NULL(_ptr) \ | #define RETURN_UNEXPECTED_IF_NULL(_ptr) \ | ||||
| do { \ | do { \ | ||||
| if ((_ptr) == nullptr) { \ | if ((_ptr) == nullptr) { \ | ||||
| @@ -321,6 +321,9 @@ HandleRcExit $? 0 0 | |||||
| PytestCmd "test_cache_nomap.py" "test_cache_nomap_nested_repeat" | PytestCmd "test_cache_nomap.py" "test_cache_nomap_nested_repeat" | ||||
| HandleRcExit $? 0 0 | HandleRcExit $? 0 0 | ||||
| PytestCmd "test_cache_nomap.py" "test_cache_nomap_get_repeat_count" | |||||
| HandleRcExit $? 0 0 | |||||
| for i in $(seq 1 3) | for i in $(seq 1 3) | ||||
| do | do | ||||
| test_name="test_cache_nomap_multiple_cache${i}" | test_name="test_cache_nomap_multiple_cache${i}" | ||||
| @@ -319,7 +319,7 @@ def test_cache_map_failure3(): | |||||
| num_iter = 0 | num_iter = 0 | ||||
| for _ in ds1.create_dict_iterator(): | for _ in ds1.create_dict_iterator(): | ||||
| num_iter += 1 | num_iter += 1 | ||||
| assert "Unexpected error. Expect positive row id: -1" in str(e.value) | |||||
| assert "BatchOp is currently not supported as a descendant operator under a cache" in str(e.value) | |||||
| assert num_iter == 0 | assert num_iter == 0 | ||||
| logger.info('test_cache_failure3 Ended.\n') | logger.info('test_cache_failure3 Ended.\n') | ||||
| @@ -754,11 +754,11 @@ def test_cache_map_parameter_check(): | |||||
| with pytest.raises(RuntimeError) as err: | with pytest.raises(RuntimeError) as err: | ||||
| ds.DatasetCache(session_id=1, size=0, spilling=True, hostname="illegal") | ds.DatasetCache(session_id=1, size=0, spilling=True, hostname="illegal") | ||||
| assert "Unexpected error. now cache client has to be on the same host with cache server" in str(err.value) | |||||
| assert "now cache client has to be on the same host with cache server" in str(err.value) | |||||
| with pytest.raises(RuntimeError) as err: | with pytest.raises(RuntimeError) as err: | ||||
| ds.DatasetCache(session_id=1, size=0, spilling=True, hostname="127.0.0.2") | ds.DatasetCache(session_id=1, size=0, spilling=True, hostname="127.0.0.2") | ||||
| assert "Unexpected error. now cache client has to be on the same host with cache server" in str(err.value) | |||||
| assert "now cache client has to be on the same host with cache server" in str(err.value) | |||||
| with pytest.raises(TypeError) as info: | with pytest.raises(TypeError) as info: | ||||
| ds.DatasetCache(session_id=1, size=0, spilling=True, port="illegal") | ds.DatasetCache(session_id=1, size=0, spilling=True, port="illegal") | ||||
| @@ -1846,6 +1846,44 @@ def test_cache_nomap_nested_repeat(): | |||||
| logger.info('test_cache_nomap_nested_repeat Ended.\n') | logger.info('test_cache_nomap_nested_repeat Ended.\n') | ||||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||||
| def test_cache_nomap_get_repeat_count(): | |||||
| """ | |||||
| Test get_repeat_count() for a pipeline with a repeat op under a cache | |||||
| Cache | |||||
| | | |||||
| Map(decode) | |||||
| | | |||||
| Repeat | |||||
| | | |||||
| TFRecord | |||||
| """ | |||||
| logger.info("Test cache nomap get_repeat_count") | |||||
| if "SESSION_ID" in os.environ: | |||||
| session_id = int(os.environ['SESSION_ID']) | |||||
| else: | |||||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||||
| # This dataset has 3 records in it only | |||||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||||
| ds1 = ds1.repeat(4) | |||||
| decode_op = c_vision.Decode() | |||||
| ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache) | |||||
| repeat_count = ds1.get_repeat_count() | |||||
| logger.info("repeat_count: {}".format(repeat_count)) | |||||
| assert repeat_count == 4 | |||||
| num_iter = 0 | |||||
| for _ in ds1.create_dict_iterator(num_epochs=1): | |||||
| logger.info("get data from dataset") | |||||
| num_iter += 1 | |||||
| assert num_iter == 12 | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| test_cache_nomap_basic1() | test_cache_nomap_basic1() | ||||
| test_cache_nomap_basic2() | test_cache_nomap_basic2() | ||||