/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include "utils/base_ref_utils.h" #include "include/ms_tensor.h" #include "ir/tensor.h" namespace mindspore { void IterateFindTensor(std::vector> *msTensors, const VectorRef &ref_list) { for (size_t i = 0; i < ref_list.size(); ++i) { if (utils::isa(ref_list[i])) { auto tensor_ptr = utils::cast>(ref_list[i]); MS_EXCEPTION_IF_NULL(tensor_ptr); auto tensor = new inference::Tensor(tensor_ptr); msTensors->emplace_back(std::shared_ptr(tensor)); } else if (utils::isa(ref_list[i])) { auto ref_iter = utils::cast(ref_list[i]); IterateFindTensor(msTensors, ref_iter); } else { MS_LOG(EXCEPTION) << "The output is not a tensor"; } } } std::vector> TransformVectorRefToMultiTensor(const VectorRef &base_ref) { std::vector> msTensors; if (utils::isa(base_ref)) { auto ref_list = utils::cast(base_ref); IterateFindTensor(&msTensors, ref_list); } else if (utils::isa(base_ref)) { auto tensor_ptr = utils::cast>(base_ref); MS_EXCEPTION_IF_NULL(tensor_ptr); auto tensor = new inference::Tensor(tensor_ptr); msTensors.emplace_back(std::shared_ptr(tensor)); } else { MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!"; } return msTensors; } } // namespace mindspore