| @@ -41,15 +41,15 @@ Status CacheClient::Builder::Build(std::shared_ptr<CacheClient> *out) { | |||
| } | |||
| Status CacheClient::Builder::SanityCheck() { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(session_id_ > 0, "session id must be positive"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(cache_mem_sz_ >= 0, "cache memory size must not be negative. (0 implies unlimited"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(num_connections_ > 0, "rpc connections must be positive"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(prefetch_size_ > 0, "prefetch size must be positive"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!hostname_.empty(), "hostname must not be empty"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(port_ > 1024, "Port must be in range (1025..65535)"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(port_ <= 65535, "Port must be in range (1025..65535)"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(hostname_ == "127.0.0.1", | |||
| "now cache client has to be on the same host with cache server"); | |||
| CHECK_FAIL_RETURN_SYNTAX_ERROR(session_id_ > 0, "session id must be positive"); | |||
| CHECK_FAIL_RETURN_SYNTAX_ERROR(cache_mem_sz_ >= 0, "cache memory size must not be negative. (0 implies unlimited)"); | |||
| CHECK_FAIL_RETURN_SYNTAX_ERROR(num_connections_ > 0, "rpc connections must be positive"); | |||
| CHECK_FAIL_RETURN_SYNTAX_ERROR(prefetch_size_ > 0, "prefetch size must be positive"); | |||
| CHECK_FAIL_RETURN_SYNTAX_ERROR(!hostname_.empty(), "hostname must not be empty"); | |||
| CHECK_FAIL_RETURN_SYNTAX_ERROR(port_ > 1024, "Port must be in range (1025..65535)"); | |||
| CHECK_FAIL_RETURN_SYNTAX_ERROR(port_ <= 65535, "Port must be in range (1025..65535)"); | |||
| CHECK_FAIL_RETURN_SYNTAX_ERROR(hostname_ == "127.0.0.1", | |||
| "now cache client has to be on the same host with cache server"); | |||
| return Status::OK(); | |||
| } | |||
| @@ -476,6 +476,12 @@ Status BatchOp::Accept(NodePass *p, bool *modified) { | |||
| return p->RunOnNode(shared_from_base<BatchOp>(), modified); | |||
| } | |||
| // Visitor pre-accept method for NodePass | |||
| Status BatchOp::PreAccept(NodePass *p, bool *modified) { | |||
| // Downcast shared pointer then call visitor | |||
| return p->PreRunOnNode(shared_from_base<BatchOp>(), modified); | |||
| } | |||
| Status BatchOp::ComputeColMap() { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(child_.size() == 1, | |||
| "Batch has " + std::to_string(child_.size()) + " child/children, expects only 1 child."); | |||
| @@ -199,6 +199,12 @@ class BatchOp : public ParallelOp { | |||
| // @return - Status of the node visit. | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| // Base-class override for NodePass visitor acceptor. | |||
| // @param p - Pointer to the NodePass to be accepted. | |||
| // @param modified - Whether this node visit modified the pipeline. | |||
| // @return - Status of the node visit. | |||
| Status PreAccept(NodePass *p, bool *modified) override; | |||
| // Op name getter | |||
| // @return Name of the current Op | |||
| std::string Name() const override { return kBatchOp; } | |||
| @@ -468,10 +468,16 @@ uint32_t DatasetOp::GenerateCRC(const std::shared_ptr<DatasetOp> &op) { | |||
| ss_str = std::regex_replace(ss_str, std::regex("device_id.*\n"), ""); | |||
| // Filter out the operator id field | |||
| ss_str = std::regex_replace(ss_str, std::regex("Parent.*\n"), ""); | |||
| ss_str = std::regex_replace(ss_str, std::regex("Child.*\n"), ""); | |||
| ss_str = std::regex_replace(ss_str, std::regex(".*Parent.*\n"), ""); | |||
| ss_str = std::regex_replace(ss_str, std::regex(".*Child.*\n"), ""); | |||
| ss_str = std::regex_replace(ss_str, std::regex(R"(\(\s*\d+?\))"), ""); | |||
| // Doesn't matter whether there is any parent node above CacheOp or not. | |||
| ss_str = std::regex_replace(ss_str, std::regex("Number of parents.*\n"), ""); | |||
| // Filter out shuffle seed from ShuffleOp | |||
| ss_str = std::regex_replace(ss_str, std::regex("Shuffle seed.*\n"), ""); | |||
| // Filter out the total repeats and number repeats per epoch field | |||
| ss_str = std::regex_replace(ss_str, std::regex("Total repeats.*\n"), ""); | |||
| ss_str = std::regex_replace(ss_str, std::regex("Number repeats per epoch.*\n"), ""); | |||
| @@ -454,6 +454,11 @@ Status NodePass::PreRunOnNode(std::shared_ptr<SkipOp> node, bool *modified) { | |||
| return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| Status NodePass::PreRunOnNode(std::shared_ptr<BatchOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| Status NodePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) { | |||
| // Fallback to base class visitor by default | |||
| @@ -305,6 +305,7 @@ class NodePass : public Pass { | |||
| virtual Status PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified); | |||
| virtual Status PreRunOnNode(std::shared_ptr<TakeOp> node, bool *modified); | |||
| virtual Status PreRunOnNode(std::shared_ptr<SkipOp> node, bool *modified); | |||
| virtual Status PreRunOnNode(std::shared_ptr<BatchOp> node, bool *modified); | |||
| #ifndef ENABLE_ANDROID | |||
| virtual Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified); | |||
| virtual Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified); | |||
| @@ -87,6 +87,16 @@ Status CacheErrorPass::PreRunOnNode(std::shared_ptr<SkipOp> node, bool *modified | |||
| return Status::OK(); | |||
| } | |||
| // Returns an error if BatchOp exists under a cache | |||
| Status CacheErrorPass::PreRunOnNode(std::shared_ptr<BatchOp> node, bool *modified) { | |||
| if (is_cached_) { | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, | |||
| "BatchOp is currently not supported as a descendant operator under a cache."); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| #ifdef ENABLE_PYTHON | |||
| // Returns an error if FilterOp exists under a cache | |||
| Status CacheErrorPass::PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified) { | |||
| @@ -169,7 +179,8 @@ Status CacheErrorPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) | |||
| // Because there is no operator in the cache hit stream to consume eoes, caching above repeat causes problem. | |||
| Status CacheErrorPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { | |||
| if (is_cached_ && is_mappable_) { | |||
| RETURN_STATUS_UNEXPECTED("Repeat is not supported as a descendant operator under a mappable cache."); | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, | |||
| "Repeat is not supported as a descendant operator under a mappable cache."); | |||
| } | |||
| return Status::OK(); | |||
| @@ -71,6 +71,12 @@ class CacheErrorPass : public NodePass { | |||
| /// \return Status The error code return | |||
| Status PreRunOnNode(std::shared_ptr<SkipOp> node, bool *modified) override; | |||
| /// \brief Returns an error if BatchOp exists under a cache | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| Status PreRunOnNode(std::shared_ptr<BatchOp> node, bool *modified) override; | |||
| #ifdef ENABLE_PYTHON | |||
| /// \brief Returns an error if FilterOp exists under a cache | |||
| /// \param[in] node The node being visited | |||
| @@ -58,7 +58,7 @@ Status CacheTransformPass::CachePass::PreRunOnNode(std::shared_ptr<CacheOp> node | |||
| *modified = false; | |||
| MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree."; | |||
| if (is_caching_) { | |||
| RETURN_STATUS_UNEXPECTED("Nested cache operations is not supported!"); | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Nested cache operations is not supported!"); | |||
| } | |||
| is_caching_ = true; | |||
| return Status::OK(); | |||
| @@ -102,7 +102,8 @@ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CacheOp> node, b | |||
| Status CacheTransformPass::CachePass::MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op) { | |||
| // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree. | |||
| if (is_caching_ && leaf_op_) { | |||
| RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache."); | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, | |||
| "There is currently no support for multiple leaf nodes under cache."); | |||
| } | |||
| // If we are a leaf in the caching path, then save this leaf. | |||
| @@ -117,7 +118,8 @@ Status CacheTransformPass::CachePass::MappableCacheLeafSetup(std::shared_ptr<Dat | |||
| Status CacheTransformPass::CachePass::NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op) { | |||
| // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree. | |||
| if (is_caching_ && leaf_op_) { | |||
| RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache."); | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, | |||
| "There is currently no support for multiple leaf nodes under cache."); | |||
| } | |||
| // Sampler for non mappable dataset only works if there is a downstream cache. Remove it from the leaf | |||
| @@ -215,7 +217,8 @@ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CelebAOp> node, | |||
| // Perform leaf node cache transform identification | |||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) { | |||
| if (is_caching_) { | |||
| RETURN_STATUS_UNEXPECTED("There is currently no support for MindRecordOp under cache."); | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, | |||
| "There is currently no support for MindRecordOp under cache."); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| @@ -225,7 +228,8 @@ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MindRecordOp> no | |||
| // Perform leaf node cache transform identification | |||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) { | |||
| if (is_caching_) { | |||
| RETURN_STATUS_UNEXPECTED("There is currently no support for GeneratorOp under cache."); | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, | |||
| "There is currently no support for GeneratorOp under cache."); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| @@ -51,6 +51,13 @@ namespace dataset { | |||
| } \ | |||
| } while (false) | |||
| #define CHECK_FAIL_RETURN_SYNTAX_ERROR(_condition, _e) \ | |||
| do { \ | |||
| if (!(_condition)) { \ | |||
| return Status(StatusCode::kSyntaxError, __LINE__, __FILE__, _e); \ | |||
| } \ | |||
| } while (false) | |||
| #define RETURN_UNEXPECTED_IF_NULL(_ptr) \ | |||
| do { \ | |||
| if ((_ptr) == nullptr) { \ | |||
| @@ -321,6 +321,9 @@ HandleRcExit $? 0 0 | |||
| PytestCmd "test_cache_nomap.py" "test_cache_nomap_nested_repeat" | |||
| HandleRcExit $? 0 0 | |||
| PytestCmd "test_cache_nomap.py" "test_cache_nomap_get_repeat_count" | |||
| HandleRcExit $? 0 0 | |||
| for i in $(seq 1 3) | |||
| do | |||
| test_name="test_cache_nomap_multiple_cache${i}" | |||
| @@ -319,7 +319,7 @@ def test_cache_map_failure3(): | |||
| num_iter = 0 | |||
| for _ in ds1.create_dict_iterator(): | |||
| num_iter += 1 | |||
| assert "Unexpected error. Expect positive row id: -1" in str(e.value) | |||
| assert "BatchOp is currently not supported as a descendant operator under a cache" in str(e.value) | |||
| assert num_iter == 0 | |||
| logger.info('test_cache_failure3 Ended.\n') | |||
| @@ -754,11 +754,11 @@ def test_cache_map_parameter_check(): | |||
| with pytest.raises(RuntimeError) as err: | |||
| ds.DatasetCache(session_id=1, size=0, spilling=True, hostname="illegal") | |||
| assert "Unexpected error. now cache client has to be on the same host with cache server" in str(err.value) | |||
| assert "now cache client has to be on the same host with cache server" in str(err.value) | |||
| with pytest.raises(RuntimeError) as err: | |||
| ds.DatasetCache(session_id=1, size=0, spilling=True, hostname="127.0.0.2") | |||
| assert "Unexpected error. now cache client has to be on the same host with cache server" in str(err.value) | |||
| assert "now cache client has to be on the same host with cache server" in str(err.value) | |||
| with pytest.raises(TypeError) as info: | |||
| ds.DatasetCache(session_id=1, size=0, spilling=True, port="illegal") | |||
| @@ -1846,6 +1846,44 @@ def test_cache_nomap_nested_repeat(): | |||
| logger.info('test_cache_nomap_nested_repeat Ended.\n') | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_nomap_get_repeat_count(): | |||
| """ | |||
| Test get_repeat_count() for a pipeline with cache and nested repeat ops | |||
| Cache | |||
| | | |||
| Map(decode) | |||
| | | |||
| Repeat | |||
| | | |||
| TFRecord | |||
| """ | |||
| logger.info("Test cache nomap get_repeat_count") | |||
| if "SESSION_ID" in os.environ: | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) | |||
| # This dataset has 3 records in it only | |||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| ds1 = ds1.repeat(4) | |||
| decode_op = c_vision.Decode() | |||
| ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache) | |||
| repeat_count = ds1.get_repeat_count() | |||
| logger.info("repeat_count: {}".format(repeat_count)) | |||
| assert repeat_count == 4 | |||
| num_iter = 0 | |||
| for _ in ds1.create_dict_iterator(num_epochs=1): | |||
| logger.info("get data from dataset") | |||
| num_iter += 1 | |||
| assert num_iter == 12 | |||
| if __name__ == '__main__': | |||
| test_cache_nomap_basic1() | |||
| test_cache_nomap_basic2() | |||