| @@ -25,12 +25,12 @@ | |||
| namespace mindspore { | |||
| #define MS_API __attribute__((visibility("default"))) | |||
| namespace lite { | |||
| /// \brief ModelImpl defined the implement class of Model in MindSpore Lite. | |||
| /// | |||
| /// \note List public class and interface for reference. | |||
| class ModelImpl; | |||
| namespace lite { | |||
| /// \brief Primitive defined as prototype of operator. | |||
| /// | |||
| /// \note List public class and interface for reference. | |||
| @@ -67,11 +67,6 @@ class MS_API Model { | |||
| /// \return the pointer of graph defined in flatbuffers. | |||
| const schema::MetaGraph *GetMetaGraph() const; | |||
| /// \brief Get MindSpore Lite ModelImpl. | |||
| /// | |||
| /// \return the pointer of MindSpore Lite ModelImpl. | |||
| ModelImpl *model_impl(); | |||
| /// \brief Free MetaGraph in MindSpore Lite Model. | |||
| void FreeMetaGraph(); | |||
| @@ -8,10 +8,8 @@ set(LITE_SRC | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/ir/tensor.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/context.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/executor.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/kernel_factory.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/kernel_registry.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/lite_kernel.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/model.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/populate_parameter.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cc | |||
| ) | |||
| @@ -21,44 +19,11 @@ if (SUPPORT_GPU) | |||
| list(APPEND LITE_SRC ${CMAKE_CURRENT_SOURCE_DIR}/runtime/kernel/opencl/utils.cc) | |||
| endif() | |||
| if (SUPPORT_TRAIN) | |||
| set(ANF_SRC | |||
| ${ANF_SRC} | |||
| # ${CCSRC_DIR}/common/trans.cc | |||
| # ${CCSRC_DIR}/utils/lite/base_ref_utils.cc | |||
| # ${CCSRC_DIR}/runtime/kernel/kernel_compiler/kernel_build_info.cc | |||
| # ${CCSRC_DIR}/session/lite/anf_runtime_algorithm_extends.cc | |||
| # ${CCSRC_DIR}/session/lite/session_basic_extends.cc | |||
| # ${CCSRC_DIR}/session/anf_runtime_algorithm.cc | |||
| # ${CCSRC_DIR}/session/session_basic.cc | |||
| # ${CCSRC_DIR}/session/kernel_graph.cc | |||
| # ${CCSRC_DIR}/session/session_factory.cc | |||
| # ${CCSRC_DIR}/device/kernel_info.cc | |||
| # ${CCSRC_DIR}/device/kernel_runtime.cc | |||
| # ${CCSRC_DIR}/device/lite/kernel_runtime_extends.cc | |||
| ) | |||
| set(PASS_SRC) | |||
| set(LITE_SRC | |||
| ${LITE_SRC} | |||
| ${ANF_SRC} | |||
| # ${PASS_SRC} | |||
| # ${CMAKE_CURRENT_SOURCE_DIR}/common/anf_importer/anf_importer.cc | |||
| # ${CMAKE_CURRENT_SOURCE_DIR}/common/anf_importer/import_from_meta_graph.cc | |||
| # ${CMAKE_CURRENT_SOURCE_DIR}/ir/primitive_value.cc | |||
| # ${CMAKE_CURRENT_SOURCE_DIR}/train/lite_kernel_runtime.cc | |||
| # ${CMAKE_CURRENT_SOURCE_DIR}/train/train_session.cc | |||
| # ${CMAKE_CURRENT_SOURCE_DIR}/train/model_impl.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc # temporary | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/model_impl.cc # temporary | |||
| ) | |||
| else () | |||
| set(LITE_SRC | |||
| ${LITE_SRC} | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/model_impl.cc | |||
| ) | |||
| endif () | |||
| set(LITE_SRC | |||
| ${LITE_SRC} | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/model.cc | |||
| ) | |||
| if (SUPPORT_GPU) | |||
| set(LITE_SRC | |||
| @@ -1,53 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "mindspore/lite/src/kernel_factory.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "src/populate_parameter.h" | |||
| #include "schema/model_generated.h" | |||
| using mindspore::kernel::KERNEL_ARCH; | |||
| using mindspore::kernel::KernelKey; | |||
| using mindspore::kernel::LiteKernel; | |||
| namespace mindspore::lite { | |||
| KernelFactory::KernelFactory() = default; | |||
| KernelFactory::~KernelFactory() = default; | |||
| KernelFactory *KernelFactory::GetInstance() { | |||
| static KernelFactory instance; | |||
| return &instance; | |||
| } | |||
| LiteKernel *KernelFactory::GetKernel(const std::vector<tensor::Tensor *> &in_tensors, | |||
| const std::vector<tensor::Tensor *> &out_tensors, const lite::Primitive *primitive, | |||
| const Context *ctx, const kernel::KernelKey &key) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| MS_EXCEPTION_IF_NULL(ctx); | |||
| auto parameter = kernel::PopulateParameter(primitive); | |||
| if (parameter == nullptr) { | |||
| MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << schema::EnumNamePrimitiveType(primitive->Type()); | |||
| return nullptr; | |||
| } | |||
| auto creator = KernelRegistry::GetInstance()->GetCreator(key); | |||
| if (creator != nullptr) { | |||
| auto kernel = creator(in_tensors, out_tensors, parameter, ctx, key, primitive); | |||
| return kernel; | |||
| } | |||
| return nullptr; | |||
| } | |||
| } // namespace mindspore::lite | |||
| @@ -1,40 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_KERNEL_FACTORY_H_ | |||
| #define MINDSPORE_LITE_SRC_KERNEL_FACTORY_H_ | |||
| #include <vector> | |||
| #include "mindspore/lite/src/lite_kernel.h" | |||
| #include "mindspore/lite/src/kernel_registry.h" | |||
| #include "mindspore/lite/include/context.h" | |||
| #include "mindspore/lite/src/ir/tensor.h" | |||
| #include "schema/model_generated.h" | |||
| namespace mindspore::lite { | |||
| class KernelFactory { | |||
| public: | |||
| KernelFactory(); | |||
| virtual ~KernelFactory(); | |||
| static KernelFactory *GetInstance(); | |||
| kernel::LiteKernel *GetKernel(const std::vector<tensor::Tensor *> &in_tensors, | |||
| const std::vector<tensor::Tensor *> &out_tensors, const lite::Primitive *primitive, | |||
| const Context *ctx, const kernel::KernelKey &key); | |||
| }; | |||
| } // namespace mindspore::lite | |||
| #endif // MINDSPORE_LITE_SRC_KERNEL_FACTORY_H_ | |||
| @@ -16,6 +16,7 @@ | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "ir/dtype/type_id.h" | |||
| #include "src/populate_parameter.h" | |||
| #ifdef ENABLE_ARM64 | |||
| #include <asm/hwcap.h> | |||
| #include "common/utils.h" | |||
| @@ -120,4 +121,23 @@ void KernelRegistry::RegKernel(const KERNEL_ARCH arch, const TypeId data_type, c | |||
| bool KernelRegistry::Merge(const std::unordered_map<KernelKey, KernelCreator> &newCreators) { return false; } | |||
| const kernel::KernelCreator *KernelRegistry::GetCreatorArrays() { return creator_arrays_; } | |||
| kernel::LiteKernel *KernelRegistry::GetKernel(const std::vector<tensor::Tensor *> &in_tensors, | |||
| const std::vector<tensor::Tensor *> &out_tensors, | |||
| const lite::Primitive *primitive, const Context *ctx, | |||
| const kernel::KernelKey &key) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| MS_EXCEPTION_IF_NULL(ctx); | |||
| auto parameter = kernel::PopulateParameter(primitive); | |||
| if (parameter == nullptr) { | |||
| MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << schema::EnumNamePrimitiveType(primitive->Type()); | |||
| return nullptr; | |||
| } | |||
| auto creator = GetCreator(key); | |||
| if (creator != nullptr) { | |||
| auto kernel = creator(in_tensors, out_tensors, parameter, ctx, key, primitive); | |||
| return kernel; | |||
| } | |||
| return nullptr; | |||
| } | |||
| } // namespace mindspore::lite | |||
| @@ -17,9 +17,9 @@ | |||
| #ifndef MINDSPORE_LITE_SRC_KERNEL_REGISTRY_H_ | |||
| #define MINDSPORE_LITE_SRC_KERNEL_REGISTRY_H_ | |||
| #include <map> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include "src/lite_kernel.h" | |||
| #include "schema/model_generated.h" | |||
| @@ -39,6 +39,9 @@ class KernelRegistry { | |||
| void RegKernel(const kernel::KERNEL_ARCH arch, const TypeId data_type, const schema::PrimitiveType type, | |||
| kernel::KernelCreator creator); | |||
| bool Merge(const std::unordered_map<kernel::KernelKey, kernel::KernelCreator> &newCreators); | |||
| kernel::LiteKernel *GetKernel(const std::vector<tensor::Tensor *> &in_tensors, | |||
| const std::vector<tensor::Tensor *> &out_tensors, const lite::Primitive *primitive, | |||
| const Context *ctx, const kernel::KernelKey &key); | |||
| protected: | |||
| kernel::KernelCreator *creator_arrays_ = nullptr; | |||
| @@ -14,16 +14,310 @@ | |||
| * limitations under the License. | |||
| */ | |||
| // #ifdef SUPPORT_TRAIN | |||
| // #include "src/train/model_impl.h" | |||
| // #else | |||
| #include "src/model_impl.h" | |||
| // #endif | |||
| #include "include/model.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "src/ops/ops.h" | |||
| namespace mindspore::lite { | |||
| class ModelImpl { | |||
| public: | |||
| static ModelImpl *Import(const char *model_buf, size_t size); | |||
| ModelImpl() = default; | |||
| explicit ModelImpl(const char *model_buf, size_t size) : model_buf_(model_buf), buf_size_(size) { | |||
| meta_graph_ = schema::GetMetaGraph(model_buf); | |||
| } | |||
| virtual ~ModelImpl(); | |||
| lite::Primitive *GetOp(const std::string &name) const; | |||
| const schema::MetaGraph *meta_graph() const; | |||
| void FreeMetaGraph(); | |||
| int BuildOps(); | |||
| protected: | |||
| lite::Primitive *CopyPrimitive(const schema::Primitive *src_prim); | |||
| protected: | |||
| const char *model_buf_; | |||
| size_t buf_size_; | |||
| const schema::MetaGraph *meta_graph_ = nullptr; | |||
| std::map<std::string, lite::Primitive *> ops_; | |||
| }; | |||
| ModelImpl *ModelImpl::Import(const char *model_buf, size_t size) { | |||
| if (model_buf == nullptr) { | |||
| MS_LOG(ERROR) << "The model buf is nullptr"; | |||
| return nullptr; | |||
| } | |||
| flatbuffers::Verifier verify((const uint8_t *)model_buf, size); | |||
| if (!schema::VerifyMetaGraphBuffer(verify)) { | |||
| MS_LOG(ERROR) << "The buffer is invalid and fail to create graph."; | |||
| return nullptr; | |||
| } | |||
| auto *inner_model_buf = new (std::nothrow) char[size]; | |||
| if (inner_model_buf == nullptr) { | |||
| MS_LOG(ERROR) << "new model buf fail."; | |||
| return nullptr; | |||
| } | |||
| memcpy(inner_model_buf, model_buf, size); | |||
| auto model = new (std::nothrow) ModelImpl(inner_model_buf, size); | |||
| if (model == nullptr) { | |||
| MS_LOG(ERROR) << "Create modelImpl failed"; | |||
| return nullptr; | |||
| } | |||
| auto ret = model->BuildOps(); | |||
| if (0 != ret) { | |||
| MS_LOG(ERROR) << "BuildOps failed"; | |||
| return nullptr; | |||
| } | |||
| return model; | |||
| } | |||
| lite::Primitive *ModelImpl::GetOp(const std::string &name) const { | |||
| auto iter = ops_.find(name); | |||
| if (iter == ops_.end()) { | |||
| return nullptr; | |||
| } else { | |||
| return iter->second; | |||
| } | |||
| } | |||
| ModelImpl::~ModelImpl() { | |||
| delete[](this->model_buf_); | |||
| for (auto iter : ops_) { | |||
| delete (iter.second); | |||
| } | |||
| ops_.clear(); | |||
| } | |||
| void ModelImpl::FreeMetaGraph() { | |||
| delete[](this->model_buf_); | |||
| model_buf_ = nullptr; | |||
| } | |||
| const schema::MetaGraph *ModelImpl::meta_graph() const { return this->meta_graph_; } | |||
| lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *src_prim) { | |||
| MS_EXCEPTION_IF_NULL(src_prim); | |||
| auto op_type = src_prim->value_type(); | |||
| switch (op_type) { | |||
| case schema::PrimitiveType_SoftMax: | |||
| return new lite::SoftMax(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Activation: | |||
| return new lite::Activation(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Conv2D: | |||
| return new lite::Conv2D(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_DeConv2D: | |||
| return new lite::DeConv2D(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Reduce: | |||
| return new lite::Reduce(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Pooling: | |||
| return new lite::Pooling(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_DepthwiseConv2D: | |||
| return new lite::DepthwiseConv2D(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_FusedBatchNorm: | |||
| return new lite::FusedBatchNorm(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_BatchNorm: | |||
| return new lite::BatchNorm(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_FullConnection: | |||
| return new lite::FullConnection(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Power: | |||
| return new lite::Power(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Range: | |||
| return new lite::Range(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Mul: | |||
| return new lite::Mul(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Add: | |||
| return new lite::Add(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Sub: | |||
| return new lite::Sub(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Div: | |||
| return new lite::Div(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_BiasAdd: | |||
| return new lite::BiasAdd(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_ExpandDims: | |||
| return new lite::ExpandDims(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_ArgMax: | |||
| return new lite::ArgMax(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_ArgMin: | |||
| return new lite::ArgMin(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Cast: | |||
| return new lite::Cast(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Reshape: | |||
| return new lite::Reshape(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Scale: | |||
| return new lite::Scale(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Eltwise: | |||
| return new lite::Eltwise(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Concat: | |||
| return new lite::Concat(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Fill: | |||
| return new lite::Fill(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Transpose: | |||
| return new lite::Transpose(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Slice: | |||
| return new lite::Slice(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Squeeze: | |||
| return new lite::Squeeze(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Nchw2Nhwc: | |||
| return new lite::Nchw2Nhwc(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Nhwc2Nchw: | |||
| return new lite::Nhwc2Nchw(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Flatten: | |||
| return new lite::Flatten(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Mean: | |||
| return new lite::Mean(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Stack: | |||
| return new lite::Stack(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Crop: | |||
| return new lite::Crop(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_SquaredDifference: | |||
| return new lite::SquaredDifference(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_AddN: | |||
| return new lite::AddN(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Abs: | |||
| return new lite::Abs(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Sin: | |||
| return new lite::Sin(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Cos: | |||
| return new lite::Cos(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Log: | |||
| return new lite::Log(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Sqrt: | |||
| return new lite::Sqrt(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Rsqrt: | |||
| return new lite::Rsqrt(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Square: | |||
| return new lite::Square(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Exp: | |||
| return new lite::Exp(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Gather: | |||
| return new lite::Gather(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_GatherNd: | |||
| return new lite::GatherNd(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_LocalResponseNormalization: | |||
| return new lite::LocalResponseNormalization(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Maximum: | |||
| return new lite::Maximum(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Minimum: | |||
| return new lite::Minimum(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Pad: | |||
| return new lite::Pad(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_StridedSlice: | |||
| return new lite::StridedSlice(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Prelu: | |||
| return new lite::Prelu(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_CaffePReLU: | |||
| return new lite::CaffePReLU(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Round: | |||
| return new lite::Round(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Reverse: | |||
| return new lite::Reverse(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_ReverseSequence: | |||
| return new lite::ReverseSequence(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_LogicalAnd: | |||
| return new lite::LogicalAnd(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_LogicalOr: | |||
| return new lite::LogicalOr(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_LogicalNot: | |||
| return new lite::LogicalNot(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_FloorDiv: | |||
| return new lite::FloorDiv(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_FloorMod: | |||
| return new lite::FloorMod(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Equal: | |||
| return new lite::Equal(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_NotEqual: | |||
| return new lite::NotEqual(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Less: | |||
| return new lite::Less(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_LessEqual: | |||
| return new lite::LessEqual(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Greater: | |||
| return new lite::Greater(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_GreaterEqual: | |||
| return new lite::GreaterEqual(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Floor: | |||
| return new lite::Floor(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Ceil: | |||
| return new lite::Ceil(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Split: | |||
| return new lite::Split(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_OneHot: | |||
| return new lite::OneHot(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_SpaceToDepth: | |||
| return new lite::SpaceToDepth(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Tile: | |||
| return new lite::Tile(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Resize: | |||
| return new lite::Resize(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Unstack: | |||
| return new lite::Unstack(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Unique: | |||
| return new lite::Unique(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_TopK: | |||
| return new lite::TopK(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_MatMul: | |||
| return new lite::MatMul(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_QuantDTypeCast: | |||
| return new lite::QuantDTypeCast(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_EmbeddingLookup: | |||
| return new lite::EmbeddingLookup(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Elu: | |||
| return new lite::Elu(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_DeDepthwiseConv2D: | |||
| return new lite::DeconvDepthwiseConv2D(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Shape: | |||
| return new lite::Shape(const_cast<schema::Primitive *>(src_prim)); | |||
| default: | |||
| break; | |||
| } | |||
| return nullptr; | |||
| } | |||
| int ModelImpl::BuildOps() { | |||
| if (this->meta_graph_ == nullptr) { | |||
| MS_LOG(ERROR) << "mete_graph is nullptr"; | |||
| return -1; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(meta_graph_->nodes()); | |||
| for (size_t i = 0; i < meta_graph_->nodes()->size(); i++) { | |||
| auto cNode = meta_graph_->nodes()->GetAs<schema::CNode>(i); | |||
| auto name = cNode->name()->str(); | |||
| auto srcPrim = cNode->primitive(); | |||
| this->ops_[name] = CopyPrimitive(srcPrim); | |||
| // flatbuffers::FlatBufferBuilder fbb(1024); | |||
| // schema::Conv2DBuilder conv2DBuilder(fbb); | |||
| // conv2DBuilder.add_padMode(srcPrim->value_as_Conv2D()->padMode()); | |||
| // conv2DBuilder.add_channelOut(srcPrim->value_as_Conv2D()->channelOut()); | |||
| // conv2DBuilder.add_channelIn(srcPrim->value_as_Conv2D()->channelIn()); | |||
| // conv2DBuilder.add_strideH(srcPrim->value_as_Conv2D()->strideH()); | |||
| // conv2DBuilder.add_strideW(srcPrim->value_as_Conv2D()->strideW()); | |||
| // conv2DBuilder.add_dilateH(srcPrim->value_as_Conv2D()->dilateH()); | |||
| // conv2DBuilder.add_dilateW(srcPrim->value_as_Conv2D()->dilateW()); | |||
| // conv2DBuilder.add_kernelH(srcPrim->value_as_Conv2D()->kernelH()); | |||
| // conv2DBuilder.add_kernelW(srcPrim->value_as_Conv2D()->kernelW()); | |||
| // conv2DBuilder.add_padUp(srcPrim->value_as_Conv2D()->padUp()); | |||
| // conv2DBuilder.add_padDown(srcPrim->value_as_Conv2D()->padDown()); | |||
| // conv2DBuilder.add_padLeft(srcPrim->value_as_Conv2D()->padLeft()); | |||
| // conv2DBuilder.add_padRight(srcPrim->value_as_Conv2D()->padRight()); | |||
| // conv2DBuilder.add_format(srcPrim->value_as_Conv2D()->format()); | |||
| // conv2DBuilder.add_group(srcPrim->value_as_Conv2D()->group()); | |||
| // conv2DBuilder.add_activationType(srcPrim->value_as_Conv2D()->activationType()); | |||
| // schema::PrimitiveBuilder primBuilder(fbb); | |||
| // primBuilder.add_value_type(srcPrim->value_type()); | |||
| // primBuilder.add_value(conv2DBuilder.Finish()); | |||
| // | |||
| // fbb.Finish(conv2DBuilder.Finish()); | |||
| // auto buf = fbb.GetBufferPointer(); | |||
| // auto conv2D = flatbuffers::GetRoot<schema::Conv2D>(buf); | |||
| // fbb.Clear(); | |||
| // | |||
| // return const_cast<mindspore::predict::OpDef *>(opDef); | |||
| } | |||
| return 0; | |||
| } | |||
| Model *Model::Import(const char *model_buf, size_t size) { | |||
| auto model = new Model(); | |||
| if (model_buf == nullptr) { | |||
| @@ -55,8 +349,4 @@ const schema::MetaGraph *Model::GetMetaGraph() const { | |||
| return model_impl_->meta_graph(); | |||
| } | |||
| ModelImpl *Model::model_impl() { | |||
| MS_EXCEPTION_IF_NULL(model_impl_); | |||
| return this->model_impl_; | |||
| } | |||
| } // namespace mindspore::lite | |||
| @@ -1,297 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <memory> | |||
| #include <string> | |||
| #include "src/model_impl.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore::lite { | |||
| ModelImpl *ModelImpl::Import(const char *model_buf, size_t size) { | |||
| if (model_buf == nullptr) { | |||
| MS_LOG(ERROR) << "The model buf is nullptr"; | |||
| return nullptr; | |||
| } | |||
| flatbuffers::Verifier verify((const uint8_t *)model_buf, size); | |||
| if (!schema::VerifyMetaGraphBuffer(verify)) { | |||
| MS_LOG(ERROR) << "The buffer is invalid and fail to create graph."; | |||
| return nullptr; | |||
| } | |||
| auto *inner_model_buf = new (std::nothrow) char[size]; | |||
| if (inner_model_buf == nullptr) { | |||
| MS_LOG(ERROR) << "new model buf fail."; | |||
| return nullptr; | |||
| } | |||
| memcpy(inner_model_buf, model_buf, size); | |||
| auto model = new (std::nothrow) ModelImpl(inner_model_buf, size); | |||
| if (model == nullptr) { | |||
| MS_LOG(ERROR) << "Create modelImpl failed"; | |||
| return nullptr; | |||
| } | |||
| auto ret = model->BuildOps(); | |||
| if (0 != ret) { | |||
| MS_LOG(ERROR) << "BuildOps failed"; | |||
| return nullptr; | |||
| } | |||
| return model; | |||
| } | |||
| lite::Primitive *ModelImpl::GetOp(const std::string &name) const { | |||
| auto iter = ops_.find(name); | |||
| if (iter == ops_.end()) { | |||
| return nullptr; | |||
| } else { | |||
| return iter->second; | |||
| } | |||
| } | |||
| ModelImpl::~ModelImpl() { | |||
| delete[](this->model_buf_); | |||
| for (auto iter : ops_) { | |||
| delete (iter.second); | |||
| } | |||
| ops_.clear(); | |||
| } | |||
| void ModelImpl::FreeMetaGraph() { | |||
| delete[](this->model_buf_); | |||
| model_buf_ = nullptr; | |||
| } | |||
| const schema::MetaGraph *ModelImpl::meta_graph() const { return this->meta_graph_; } | |||
| lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *src_prim) { | |||
| MS_EXCEPTION_IF_NULL(src_prim); | |||
| auto op_type = src_prim->value_type(); | |||
| switch (op_type) { | |||
| case schema::PrimitiveType_SoftMax: | |||
| return new lite::SoftMax(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Activation: | |||
| return new lite::Activation(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Conv2D: | |||
| return new lite::Conv2D(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_DeConv2D: | |||
| return new lite::DeConv2D(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Reduce: | |||
| return new lite::Reduce(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Pooling: | |||
| return new lite::Pooling(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_DepthwiseConv2D: | |||
| return new lite::DepthwiseConv2D(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_FusedBatchNorm: | |||
| return new lite::FusedBatchNorm(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_BatchNorm: | |||
| return new lite::BatchNorm(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_FullConnection: | |||
| return new lite::FullConnection(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Power: | |||
| return new lite::Power(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Range: | |||
| return new lite::Range(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Mul: | |||
| return new lite::Mul(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Add: | |||
| return new lite::Add(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Sub: | |||
| return new lite::Sub(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Div: | |||
| return new lite::Div(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_BiasAdd: | |||
| return new lite::BiasAdd(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_ExpandDims: | |||
| return new lite::ExpandDims(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_ArgMax: | |||
| return new lite::ArgMax(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_ArgMin: | |||
| return new lite::ArgMin(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Cast: | |||
| return new lite::Cast(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Reshape: | |||
| return new lite::Reshape(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Scale: | |||
| return new lite::Scale(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Eltwise: | |||
| return new lite::Eltwise(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Concat: | |||
| return new lite::Concat(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Fill: | |||
| return new lite::Fill(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Transpose: | |||
| return new lite::Transpose(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Slice: | |||
| return new lite::Slice(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Squeeze: | |||
| return new lite::Squeeze(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Nchw2Nhwc: | |||
| return new lite::Nchw2Nhwc(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Nhwc2Nchw: | |||
| return new lite::Nhwc2Nchw(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Flatten: | |||
| return new lite::Flatten(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Mean: | |||
| return new lite::Mean(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Stack: | |||
| return new lite::Stack(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Crop: | |||
| return new lite::Crop(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_SquaredDifference: | |||
| return new lite::SquaredDifference(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_AddN: | |||
| return new lite::AddN(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Abs: | |||
| return new lite::Abs(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Sin: | |||
| return new lite::Sin(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Cos: | |||
| return new lite::Cos(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Log: | |||
| return new lite::Log(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Sqrt: | |||
| return new lite::Sqrt(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Rsqrt: | |||
| return new lite::Rsqrt(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Square: | |||
| return new lite::Square(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Exp: | |||
| return new lite::Exp(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Gather: | |||
| return new lite::Gather(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_GatherNd: | |||
| return new lite::GatherNd(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_LocalResponseNormalization: | |||
| return new lite::LocalResponseNormalization(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Maximum: | |||
| return new lite::Maximum(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Minimum: | |||
| return new lite::Minimum(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Pad: | |||
| return new lite::Pad(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_StridedSlice: | |||
| return new lite::StridedSlice(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Prelu: | |||
| return new lite::Prelu(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_CaffePReLU: | |||
| return new lite::CaffePReLU(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Round: | |||
| return new lite::Round(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Reverse: | |||
| return new lite::Reverse(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_ReverseSequence: | |||
| return new lite::ReverseSequence(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_LogicalAnd: | |||
| return new lite::LogicalAnd(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_LogicalOr: | |||
| return new lite::LogicalOr(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_LogicalNot: | |||
| return new lite::LogicalNot(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_FloorDiv: | |||
| return new lite::FloorDiv(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_FloorMod: | |||
| return new lite::FloorMod(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Equal: | |||
| return new lite::Equal(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_NotEqual: | |||
| return new lite::NotEqual(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Less: | |||
| return new lite::Less(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_LessEqual: | |||
| return new lite::LessEqual(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Greater: | |||
| return new lite::Greater(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_GreaterEqual: | |||
| return new lite::GreaterEqual(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Floor: | |||
| return new lite::Floor(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Ceil: | |||
| return new lite::Ceil(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Split: | |||
| return new lite::Split(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_OneHot: | |||
| return new lite::OneHot(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_SpaceToDepth: | |||
| return new lite::SpaceToDepth(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Tile: | |||
| return new lite::Tile(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Resize: | |||
| return new lite::Resize(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Unstack: | |||
| return new lite::Unstack(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Unique: | |||
| return new lite::Unique(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_TopK: | |||
| return new lite::TopK(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_MatMul: | |||
| return new lite::MatMul(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_QuantDTypeCast: | |||
| return new lite::QuantDTypeCast(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_EmbeddingLookup: | |||
| return new lite::EmbeddingLookup(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Elu: | |||
| return new lite::Elu(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_DeDepthwiseConv2D: | |||
| return new lite::DeconvDepthwiseConv2D(const_cast<schema::Primitive *>(src_prim)); | |||
| case schema::PrimitiveType_Shape: | |||
| return new lite::Shape(const_cast<schema::Primitive *>(src_prim)); | |||
| default: | |||
| break; | |||
| } | |||
| return nullptr; | |||
| } | |||
| int ModelImpl::BuildOps() { | |||
| if (this->meta_graph_ == nullptr) { | |||
| MS_LOG(ERROR) << "mete_graph is nullptr"; | |||
| return -1; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(meta_graph_->nodes()); | |||
| for (size_t i = 0; i < meta_graph_->nodes()->size(); i++) { | |||
| auto cNode = meta_graph_->nodes()->GetAs<schema::CNode>(i); | |||
| auto name = cNode->name()->str(); | |||
| auto srcPrim = cNode->primitive(); | |||
| this->ops_[name] = CopyPrimitive(srcPrim); | |||
| // flatbuffers::FlatBufferBuilder fbb(1024); | |||
| // schema::Conv2DBuilder conv2DBuilder(fbb); | |||
| // conv2DBuilder.add_padMode(srcPrim->value_as_Conv2D()->padMode()); | |||
| // conv2DBuilder.add_channelOut(srcPrim->value_as_Conv2D()->channelOut()); | |||
| // conv2DBuilder.add_channelIn(srcPrim->value_as_Conv2D()->channelIn()); | |||
| // conv2DBuilder.add_strideH(srcPrim->value_as_Conv2D()->strideH()); | |||
| // conv2DBuilder.add_strideW(srcPrim->value_as_Conv2D()->strideW()); | |||
| // conv2DBuilder.add_dilateH(srcPrim->value_as_Conv2D()->dilateH()); | |||
| // conv2DBuilder.add_dilateW(srcPrim->value_as_Conv2D()->dilateW()); | |||
| // conv2DBuilder.add_kernelH(srcPrim->value_as_Conv2D()->kernelH()); | |||
| // conv2DBuilder.add_kernelW(srcPrim->value_as_Conv2D()->kernelW()); | |||
| // conv2DBuilder.add_padUp(srcPrim->value_as_Conv2D()->padUp()); | |||
| // conv2DBuilder.add_padDown(srcPrim->value_as_Conv2D()->padDown()); | |||
| // conv2DBuilder.add_padLeft(srcPrim->value_as_Conv2D()->padLeft()); | |||
| // conv2DBuilder.add_padRight(srcPrim->value_as_Conv2D()->padRight()); | |||
| // conv2DBuilder.add_format(srcPrim->value_as_Conv2D()->format()); | |||
| // conv2DBuilder.add_group(srcPrim->value_as_Conv2D()->group()); | |||
| // conv2DBuilder.add_activationType(srcPrim->value_as_Conv2D()->activationType()); | |||
| // schema::PrimitiveBuilder primBuilder(fbb); | |||
| // primBuilder.add_value_type(srcPrim->value_type()); | |||
| // primBuilder.add_value(conv2DBuilder.Finish()); | |||
| // | |||
| // fbb.Finish(conv2DBuilder.Finish()); | |||
| // auto buf = fbb.GetBufferPointer(); | |||
| // auto conv2D = flatbuffers::GetRoot<schema::Conv2D>(buf); | |||
| // fbb.Clear(); | |||
| // | |||
| // return const_cast<mindspore::predict::OpDef *>(opDef); | |||
| } | |||
| return 0; | |||
| } | |||
| } // namespace mindspore::lite | |||
| @@ -1,53 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_MODEL_IMPL_H_ | |||
| #define MINDSPORE_LITE_SRC_MODEL_IMPL_H_ | |||
| #include <map> | |||
| #include <memory> | |||
| #include <string> | |||
| #include "schema/model_generated.h" | |||
| #include "src/ops/ops.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class ModelImpl { | |||
| public: | |||
| static ModelImpl *Import(const char *model_buf, size_t size); | |||
| ModelImpl() = default; | |||
| explicit ModelImpl(const char *model_buf, size_t size) : model_buf_(model_buf), buf_size_(size) { | |||
| meta_graph_ = schema::GetMetaGraph(model_buf); | |||
| } | |||
| virtual ~ModelImpl(); | |||
| lite::Primitive *GetOp(const std::string &name) const; | |||
| const schema::MetaGraph *meta_graph() const; | |||
| void FreeMetaGraph(); | |||
| int BuildOps(); | |||
| protected: | |||
| lite::Primitive *CopyPrimitive(const schema::Primitive *src_prim); | |||
| protected: | |||
| const char *model_buf_; | |||
| size_t buf_size_; | |||
| const schema::MetaGraph *meta_graph_ = nullptr; | |||
| std::map<std::string, lite::Primitive *> ops_; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_INCLUDE_MODEL_H_ | |||
| @@ -19,7 +19,7 @@ | |||
| #include "src/runtime/kernel/arm/int8/argminmax_int8.h" | |||
| #include "src/runtime/kernel/arm/nnacl/arithmetic_common.h" | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_factory.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "include/context.h" | |||
| @@ -18,7 +18,7 @@ | |||
| #include "src/runtime/kernel/arm/fp32/batch_to_space.h" | |||
| #include "src/runtime/kernel/arm/int8/batch_to_space_int8.h" | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_factory.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "include/context.h" | |||
| @@ -16,7 +16,7 @@ | |||
| #include "src/runtime/kernel/arm/base/caffeprelu_base.h" | |||
| #include <vector> | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_factory.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "include/context.h" | |||
| @@ -19,7 +19,7 @@ | |||
| #include "src/runtime/kernel/arm/fp32/concat.h" | |||
| #include "src/runtime/kernel/arm/nnacl/fp32/concat.h" | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_factory.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "include/context.h" | |||
| @@ -17,7 +17,7 @@ | |||
| #include "src/runtime/kernel/arm/base/convolution_base.h" | |||
| #include <float.h> | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_factory.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| using mindspore::lite::KernelRegistrar; | |||
| @@ -18,7 +18,7 @@ | |||
| #include "src/runtime/kernel/arm/int8/crop_int8.h" | |||
| #include "src/runtime/kernel/arm/fp32/crop.h" | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_factory.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "include/context.h" | |||
| @@ -19,7 +19,7 @@ | |||
| #include "src/runtime/kernel/arm/int8/depth_to_space_int8.h" | |||
| #include "src/runtime/kernel/arm/nnacl/arithmetic_common.h" | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_factory.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "include/context.h" | |||
| @@ -17,7 +17,7 @@ | |||
| #include "src/runtime/kernel/arm/int8/fullconnection_int8.h" | |||
| #include "src/runtime/kernel/arm/fp32/fullconnection.h" | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_factory.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "include/context.h" | |||
| @@ -16,7 +16,7 @@ | |||
| #include "src/runtime/kernel/arm/base/matmul_base.h" | |||
| #include "src/runtime/kernel/arm/fp32/matmul.h" | |||
| #include "src/runtime/kernel/arm/int8/matmul_int8.h" | |||
| #include "src/kernel_factory.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "include/context.h" | |||
| @@ -17,7 +17,7 @@ | |||
| #include "src/runtime/kernel/arm/fp32/pad.h" | |||
| #include "src/runtime/kernel/arm/int8/pad_int8.h" | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_factory.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "include/context.h" | |||
| @@ -18,7 +18,7 @@ | |||
| #include "src/runtime/kernel/arm/int8/pooling_int8.h" | |||
| #include "src/runtime/kernel/arm/fp32/pooling.h" | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_factory.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "include/context.h" | |||
| @@ -17,7 +17,7 @@ | |||
| #include <vector> | |||
| #include "src/runtime/kernel/arm/int8/prelu_int8.h" | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_factory.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "include/context.h" | |||
| @@ -18,7 +18,7 @@ | |||
| #include <cmath> | |||
| #include "src/runtime/kernel/arm/base/prior_box.h" | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_factory.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "include/context.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| @@ -18,7 +18,7 @@ | |||
| #include "src/runtime/kernel/arm/int8/reshape_int8.h" | |||
| #include "src/runtime/kernel/arm/fp32/reshape.h" | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_factory.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "include/context.h" | |||
| @@ -20,7 +20,7 @@ | |||
| #include "src/runtime/kernel/arm/fp32/softmax.h" | |||
| #include "src/runtime/kernel/arm/nnacl/fp32/softmax.h" | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_factory.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| using mindspore::lite::KernelRegistrar; | |||
| @@ -18,7 +18,7 @@ | |||
| #include "src/runtime/kernel/arm/int8/split_int8.h" | |||
| #include "src/runtime/kernel/arm/fp32/split.h" | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_factory.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "include/context.h" | |||
| @@ -17,7 +17,7 @@ | |||
| #include <vector> | |||
| #include "src/runtime/kernel/arm/int8/squeeze_int8.h" | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_factory.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "include/context.h" | |||
| @@ -18,7 +18,7 @@ | |||
| #include "src/runtime/kernel/arm/nnacl/fp16/cast_fp16.h" | |||
| #include "src/runtime/kernel/arm/fp16/common_fp16.h" | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_factory.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| @@ -22,7 +22,7 @@ | |||
| #include "src/runtime/kernel/arm/nnacl/fp32/conv.h" | |||
| #include "src/runtime/kernel/arm/nnacl/common_func.h" | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_factory.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| @@ -17,7 +17,7 @@ | |||
| #include "src/runtime/kernel/arm/fp32/convolution_slidewindow.h" | |||
| #include "src/runtime/kernel/arm/nnacl/common_func.h" | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_factory.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| @@ -20,7 +20,7 @@ | |||
| #include <vector> | |||
| #include "src/lite_kernel.h" | |||
| #include "src/kernel_factory.h" | |||
| #include "src/kernel_registry.h" | |||
| namespace mindspore::kernel { | |||
| @@ -17,7 +17,7 @@ | |||
| #include <algorithm> | |||
| #include <vector> | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_factory.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "src/runtime/kernel/arm/fp32_grad/bn_grad.h" | |||
| #include "src/runtime/kernel/arm/nnacl/fp32_grad/batch_norm.h" | |||
| #include "include/errorcode.h" | |||
| @@ -18,7 +18,7 @@ | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include "include/errorcode.h" | |||
| #include "src/kernel_factory.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "src/common/graph_util.h" | |||
| #include "src/common/utils.h" | |||
| #if SUPPORT_GPU | |||
| @@ -191,7 +191,7 @@ kernel::LiteKernel *Scheduler::ScheduleNode(const std::vector<tensor::Tensor *> | |||
| kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, primitive->Type()}; | |||
| if (context_->device_ctx_.type == DT_GPU) { | |||
| desc.arch = kernel::KERNEL_ARCH::kGPU; | |||
| auto *kernel = KernelFactory::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc); | |||
| auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc); | |||
| if (nullptr != kernel) { | |||
| kernel->set_desc(desc); | |||
| return kernel; | |||
| @@ -203,7 +203,7 @@ kernel::LiteKernel *Scheduler::ScheduleNode(const std::vector<tensor::Tensor *> | |||
| if ((context_->float16_priority && data_type == kNumberTypeFloat32) || data_type == kNumberTypeFloat16) { | |||
| // check if support fp16 | |||
| kernel::KernelKey key{desc.arch, kNumberTypeFloat16, desc.type}; | |||
| kernel = KernelFactory::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, key); | |||
| kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, key); | |||
| if (kernel != nullptr) { | |||
| MS_LOG(DEBUG) << "Get fp16 op success."; | |||
| desc.data_type = kNumberTypeFloat16; | |||
| @@ -215,7 +215,7 @@ kernel::LiteKernel *Scheduler::ScheduleNode(const std::vector<tensor::Tensor *> | |||
| if (data_type == kNumberTypeFloat16) { | |||
| desc.data_type = kNumberTypeFloat32; | |||
| } | |||
| kernel = KernelFactory::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc); | |||
| kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc); | |||
| if (kernel != nullptr) { | |||
| kernel->set_desc(desc); | |||
| return kernel; | |||
| @@ -1,59 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/train/base_ref_utils.h" | |||
| #include <vector> | |||
| #include <memory> | |||
| // #include "utils/base_ref_utils.h" | |||
| #include "include/ms_tensor.h" | |||
| #include "src/ir/tensor.h" | |||
| namespace mindspore { | |||
| std::vector<std::shared_ptr<tensor::MSTensor>> TransformBaseRefToMSTensor(const BaseRef &base_ref) { | |||
| std::vector<std::shared_ptr<tensor::MSTensor>> msTensors; | |||
| if (utils::isa<VectorRef>(base_ref)) { | |||
| auto ref_list = utils::cast<VectorRef>(base_ref); | |||
| for (size_t i = 0; i < ref_list.size(); ++i) { | |||
| if (utils::isa<tensor::Tensor>(ref_list[i])) { | |||
| auto tensor_ptr = utils::cast<std::shared_ptr<tensor::Tensor>>(ref_list[i]); | |||
| MS_EXCEPTION_IF_NULL(tensor_ptr); | |||
| auto tensor = new tensor::LiteTensor(new tensor::Tensor(*tensor_ptr)); | |||
| msTensors.emplace_back(std::shared_ptr<tensor::MSTensor>(tensor)); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "The output is not a tensor!"; | |||
| } | |||
| } | |||
| } else if (utils::isa<tensor::Tensor>(base_ref)) { | |||
| auto tensor_ptr = utils::cast<std::shared_ptr<tensor::Tensor>>(base_ref); | |||
| MS_EXCEPTION_IF_NULL(tensor_ptr); | |||
| auto tensor = new tensor::LiteTensor(new tensor::Tensor(*tensor_ptr)); | |||
| msTensors.emplace_back(std::shared_ptr<tensor::MSTensor>(tensor)); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!"; | |||
| } | |||
| return msTensors; | |||
| } | |||
| std::vector<std::vector<std::shared_ptr<tensor::MSTensor>>> TransformVectorRefToMultiTensor( | |||
| const VectorRef &vector_ref) { | |||
| std::vector<std::vector<std::shared_ptr<tensor::MSTensor>>> multiTensor; | |||
| for (size_t i = 0; i < vector_ref.size(); ++i) { | |||
| auto tensors = TransformBaseRefToMSTensor(vector_ref[i]); | |||
| multiTensor.emplace_back(tensors); | |||
| } | |||
| return multiTensor; | |||
| } | |||
| } // namespace mindspore | |||
| @@ -1,30 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "utils/base_ref.h" | |||
| #include "include/ms_tensor.h" | |||
| #ifndef MINDSPORE_LITE_SRC_TRAIN_BASE_REF_UTILS_H_ | |||
| #define MINDSPORE_LITE_SRC_TRAIN_BASE_REF_UTILS_H_ | |||
| namespace mindspore { | |||
| std::vector<std::shared_ptr<tensor::MSTensor>> TransformBaseRefToMSTensor(const BaseRef &base_ref); | |||
| std::vector<std::vector<std::shared_ptr<tensor::MSTensor>>> TransformVectorRefToMultiTensor( | |||
| const VectorRef &vector_ref); | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_SRC_TRAIN_BASE_REF_UTILS_H_ | |||
| @@ -1,50 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include<memory> | |||
| #include "src/common/anf_importer/import_from_meta_graph.h" | |||
| namespace mindspore::lite::train { | |||
| std::shared_ptr<ModelImpl> Import(const char *model_buf, size_t size) { | |||
| MS_EXCEPTION_IF_NULL(model_buf); | |||
| flatbuffers::Verifier verify((const uint8_t *) model_buf, size); | |||
| if (!schema::VerifyMetaGraphBuffer(verify)) { | |||
| MS_LOG(ERROR) << "The buffer is invalid and fail to create graph."; | |||
| return nullptr; | |||
| } | |||
| // todo hangangqiang remove when copy primitive done | |||
| if (size <= 0) { | |||
| MS_LOG(ERROR) << "size is zero"; | |||
| return nullptr; | |||
| } | |||
| auto *inner_buf = new char[size]; | |||
| memcpy(inner_buf, model_buf, size); | |||
| auto meta_graph = schema::GetMetaGraph(inner_buf); | |||
| auto model = std::make_shared<ModelImpl>(meta_graph); | |||
| auto ret = model->BuildOps(); | |||
| if (0 != ret) { | |||
| MS_LOG(ERROR) << "BuildOps failed"; | |||
| return nullptr; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(meta_graph); | |||
| auto importer = new AnfImporterFromMetaGraph(model); | |||
| auto ret2 = importer->Import(); | |||
| if (0 != ret2) { | |||
| MS_LOG(ERROR) << "Import anf_graph from meta_graph failed, ret2: " << ret2; | |||
| return nullptr; | |||
| } | |||
| return model; | |||
| } | |||
| } // namespace mindspore::lite::train | |||
| @@ -1,86 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/train/lite_kernel_runtime.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| namespace mindspore::lite { | |||
| std::vector<CNodePtr> LiteInferKernelRuntime::GetGraphInputs(const std::vector<CNodePtr> &execution_order) { | |||
| std::vector<CNodePtr> graph_inputs; | |||
| for (const auto &cnode : execution_order) { | |||
| bool is_graph_inputs = true; | |||
| for (const auto &input : cnode->inputs()) { | |||
| if (input->isa<CNode>()) { | |||
| is_graph_inputs = false; | |||
| break; | |||
| } | |||
| } | |||
| if (is_graph_inputs) { | |||
| graph_inputs.emplace_back(cnode); | |||
| } | |||
| } | |||
| return graph_inputs; | |||
| } | |||
| void LiteInferKernelRuntime::BindInputOutput(const session::KernelGraph *graph, | |||
| const std::vector<tensor::Tensor *> &inputs, | |||
| std::vector<tensor::Tensor *> *outputs) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| auto execution_order = graph->execution_order(); | |||
| auto graph_inputs = GetGraphInputs(execution_order); | |||
| int input_count = 0; | |||
| for (const auto &graph_input : graph_inputs) { | |||
| auto liteKernel = dynamic_cast<kernel::LiteKernel *>(AnfAlgo::GetKernelMod(graph_input)); | |||
| for (auto input_tensor : liteKernel->GetInputs()) { | |||
| if (schema::NodeType_ValueNode == input_tensor->TensorType() && input_tensor->Data() != nullptr) { | |||
| continue; | |||
| } | |||
| input_tensor->SetData(inputs[input_count]->Data()); | |||
| input_count++; | |||
| } | |||
| } | |||
| auto return_node = graph->get_return(); | |||
| for (const auto &return_input : return_node->inputs()) { | |||
| if (return_input->isa<CNode>()) { | |||
| auto liteKernel = dynamic_cast<kernel::LiteKernel *>(AnfAlgo::GetKernelMod(return_input)); | |||
| auto output_tensors = liteKernel->GetOutputs(); | |||
| for (auto output_tensor : output_tensors) { | |||
| // tensor::TensorPtr output_tensor_ptr(output_tensor); | |||
| outputs->push_back(output_tensor); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| bool LiteInferKernelRuntime::Run(session::KernelGraph *graph, const std::vector<tensor::Tensor *> &inputs, | |||
| std::vector<tensor::Tensor *> *outputs) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| BindInputOutput(graph, inputs, *outputs); | |||
| std::vector<kernel::LiteKernel *> kernels; | |||
| auto nodes = graph->execution_order(); | |||
| for (const auto &node : nodes) { | |||
| auto liteKernel = dynamic_cast<kernel::LiteKernel *>(AnfAlgo::GetKernelMod(node)); | |||
| if (liteKernel == nullptr) { | |||
| continue; | |||
| } | |||
| kernels.emplace_back(liteKernel); | |||
| } | |||
| kernel::LiteKernelUtil::TopologicalSortKernels(kernels); | |||
| Executor executor; | |||
| auto ret = executor.Run(inputs, *outputs, kernels); | |||
| return 0 == ret; | |||
| } | |||
| } // namespace mindspore::lite | |||
| @@ -1,50 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_TRAIN_LITE_KERNEL_RUNTIME_H_ | |||
| #define MINDSPORE_LITE_SRC_TRAIN_LITE_KERNEL_RUNTIME_H_ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include "src/runtime/allocator.h" | |||
| #include "src/executor.h" | |||
| // #include "runtime/device/kernel_runtime.h" | |||
| #include "runtime/device/device_address.h" | |||
| #include "src/lite_kernel.h" | |||
| #include "backend/session/kernel_graph.h" | |||
| namespace mindspore::lite { | |||
| class LiteInferKernelRuntime { | |||
| public: | |||
| LiteInferKernelRuntime() = default; | |||
| ~LiteInferKernelRuntime() = default; | |||
| bool Run(session::KernelGraph *graph, const std::vector<tensor::Tensor *> &inputs, | |||
| std::vector<tensor::Tensor *> *outputs); | |||
| void AssignKernelAddress(session::KernelGraph *graph) {} | |||
| protected: | |||
| void BindInputOutput(const session::KernelGraph *graph, const std::vector<tensor::Tensor *> &inputs, | |||
| std::vector<tensor::Tensor *> *outputs); | |||
| std::vector<CNodePtr> GetGraphInputs(const std::vector<CNodePtr> &execution_order); | |||
| }; | |||
| } // namespace mindspore::lite | |||
| #endif // MINDSPORE_LITE_SRC_TRAIN_LITE_KERNEL_RUNTIME_H_ | |||
| @@ -1,145 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <string> | |||
| #include "src/train/model_impl.h" | |||
| #include "ir/func_graph.h" | |||
| #include "schema/model_generated.h" | |||
| #include "src/common/anf_importer/import_from_meta_graph.h" | |||
| namespace mindspore::lite::train { | |||
| std::shared_ptr<ModelImpl> ModelImpl::Import(const char *model_buf, size_t size) { | |||
| MS_EXCEPTION_IF_NULL(model_buf); | |||
| flatbuffers::Verifier verify((const uint8_t *)model_buf, size); | |||
| if (!schema::VerifyMetaGraphBuffer(verify)) { | |||
| MS_LOG(ERROR) << "The buffer is invalid and fail to create graph."; | |||
| return nullptr; | |||
| } | |||
| // todo hangangqiang remove when copy primitive done | |||
| auto *inner_buf = new char[size]; | |||
| memcpy(inner_buf, model_buf, size); | |||
| auto meta_graph = schema::GetMetaGraph(inner_buf); | |||
| auto func_graph_model = std::make_shared<ModelImpl>(meta_graph); | |||
| auto ret = func_graph_model->BuildOps(); | |||
| if (0 != ret) { | |||
| MS_LOG(ERROR) << "BuildOps failed"; | |||
| return nullptr; | |||
| } | |||
| AnfImporterFromMetaGraph anfImporter(func_graph_model); | |||
| anfImporter.Import(); | |||
| return func_graph_model; | |||
| } | |||
| const lite::Primitive *ModelImpl::GetOp(const std::string &name) const { | |||
| auto iter = ops.find(name); | |||
| if (iter == ops.end()) { | |||
| return nullptr; | |||
| } else { | |||
| return iter->second; | |||
| } | |||
| } | |||
| void ModelImpl::FreeMetaGraph() { delete this->meta_graph; } | |||
| const schema::MetaGraph *ModelImpl::GetMetaGraph() const { return this->meta_graph; } | |||
| lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *srcPrim) { | |||
| MS_EXCEPTION_IF_NULL(srcPrim); | |||
| auto op_type = srcPrim->value_type(); | |||
| switch (op_type) { | |||
| case schema::PrimitiveType_SoftMax: | |||
| return new lite::SoftMax(const_cast<schema::Primitive *>(srcPrim)); | |||
| case schema::PrimitiveType_Activation: | |||
| return new lite::Activation(const_cast<schema::Primitive *>(srcPrim)); | |||
| case schema::PrimitiveType_Conv2D: | |||
| return new lite::Conv2D(const_cast<schema::Primitive *>(srcPrim)); | |||
| case schema::PrimitiveType_Reduce: | |||
| return new lite::Reduce(const_cast<schema::Primitive *>(srcPrim)); | |||
| case schema::PrimitiveType_Pooling: | |||
| return new lite::Pooling(const_cast<schema::Primitive *>(srcPrim)); | |||
| case schema::PrimitiveType_DepthwiseConv2D: | |||
| return new lite::DepthwiseConv2D(const_cast<schema::Primitive *>(srcPrim)); | |||
| case schema::PrimitiveType_FusedBatchNorm: | |||
| return new lite::FusedBatchNorm(const_cast<schema::Primitive *>(srcPrim)); | |||
| case schema::PrimitiveType_CaffeBatchNorm: | |||
| return new lite::CaffeBatchNorm(const_cast<schema::Primitive *>(srcPrim)); | |||
| case schema::PrimitiveType_FullConnection: | |||
| return new lite::FullConnection(const_cast<schema::Primitive *>(srcPrim)); | |||
| case schema::PrimitiveType_Power: | |||
| return new lite::Power(const_cast<schema::Primitive *>(srcPrim)); | |||
| case schema::PrimitiveType_Range: | |||
| return new lite::Range(const_cast<schema::Primitive *>(srcPrim)); | |||
| case schema::PrimitiveType_Mul: | |||
| return new lite::Mul(const_cast<schema::Primitive *>(srcPrim)); | |||
| case schema::PrimitiveType_Add: | |||
| return new lite::Add(const_cast<schema::Primitive *>(srcPrim)); | |||
| case schema::PrimitiveType_Sub: | |||
| return new lite::Sub(const_cast<schema::Primitive *>(srcPrim)); | |||
| case schema::PrimitiveType_Div: | |||
| return new lite::Div(const_cast<schema::Primitive *>(srcPrim)); | |||
| case schema::PrimitiveType_BiasAdd: | |||
| return new lite::BiasAdd(const_cast<schema::Primitive *>(srcPrim)); | |||
| case schema::PrimitiveType_ExpandDims: | |||
| return new lite::ExpandDims(const_cast<schema::Primitive *>(srcPrim)); | |||
| case schema::PrimitiveType_ArgMax: | |||
| return new lite::ArgMax(const_cast<schema::Primitive *>(srcPrim)); | |||
| case schema::PrimitiveType_ArgMin: | |||
| return new lite::ArgMin(const_cast<schema::Primitive *>(srcPrim)); | |||
| case schema::PrimitiveType_Cast: | |||
| return new lite::Cast(const_cast<schema::Primitive *>(srcPrim)); | |||
| case schema::PrimitiveType_Reshape: | |||
| return new lite::Reshape(const_cast<schema::Primitive *>(srcPrim)); | |||
| case schema::PrimitiveType_Scale: | |||
| return new lite::Scale(const_cast<schema::Primitive *>(srcPrim)); | |||
| case schema::PrimitiveType_Eltwise: | |||
| return new lite::Eltwise(const_cast<schema::Primitive *>(srcPrim)); | |||
| case schema::PrimitiveType_Ceil: | |||
| return new lite::Ceil(const_cast<schema::Primitive *>(srcPrim)); | |||
| case schema::PrimitiveType_Concat: | |||
| return new lite::Concat(const_cast<schema::Primitive *>(srcPrim)); | |||
| case schema::PrimitiveType_Fill: | |||
| return new lite::Fill(const_cast<schema::Primitive *>(srcPrim)); | |||
| case schema::PrimitiveType_Transpose: | |||
| return new lite::Transpose(const_cast<schema::Primitive *>(srcPrim)); | |||
| case schema::PrimitiveType_Slice: | |||
| return new lite::Slice(const_cast<schema::Primitive *>(srcPrim)); | |||
| case schema::PrimitiveType_Nchw2Nhwc: | |||
| return new lite::Nchw2Nhwc(const_cast<schema::Primitive *>(srcPrim)); | |||
| case schema::PrimitiveType_Nhwc2Nchw: | |||
| return new lite::Nhwc2Nchw(const_cast<schema::Primitive *>(srcPrim)); | |||
| case schema::PrimitiveType_MatMul: | |||
| return new lite::MatMul(const_cast<schema::Primitive *>(srcPrim)); | |||
| default: | |||
| break; | |||
| } | |||
| return nullptr; | |||
| } | |||
| int ModelImpl::BuildOps() { | |||
| if (this->meta_graph == nullptr) { | |||
| MS_LOG(ERROR) << "mete_graph is nullptr"; | |||
| return -1; | |||
| } | |||
| for (size_t i = 0; i < meta_graph->nodes()->size(); i++) { | |||
| auto cNode = meta_graph->nodes()->GetAs<schema::CNode>(i); | |||
| auto name = cNode->name()->str(); | |||
| auto srcPrim = cNode->primitive(); | |||
| this->ops[name] = CopyPrimitive(srcPrim); | |||
| } | |||
| return 0; | |||
| } | |||
| } // namespace mindspore::lite::train | |||
| @@ -1,64 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_TRAIN_MODEL_IMPL_H_ | |||
| #define MINDSPORE_LITE_SRC_TRAIN_MODEL_IMPL_H_ | |||
| #include <string> | |||
| #include <map> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "schema/model_generated.h" | |||
| #include "src/ops/ops.h" | |||
| #include "ir/func_graph.h" | |||
| namespace mindspore::lite { | |||
| namespace train { | |||
| class ModelImpl : public FuncGraph { | |||
| public: | |||
| static std::shared_ptr<ModelImpl> Import(const char *model_buf, size_t size); // { return NULL; }; | |||
| ModelImpl() = default; | |||
| explicit ModelImpl(const schema::MetaGraph *graph) : meta_graph(graph) {} | |||
| ~ModelImpl() override = default; | |||
| const lite::Primitive *GetOp(const std::string &name) const; | |||
| const schema::MetaGraph *GetMetaGraph() const; | |||
| void FreeMetaGraph(); | |||
| int BuildOps(); | |||
| void AddCNodeInputOutput(std::string name, const std::vector<int> &input, const std::vector<int> &output) { | |||
| std::vector<int> *tuple = new std::vector<int>[2]; | |||
| tuple[0] = input; | |||
| tuple[1] = output; | |||
| connectivity_[name] = tuple; | |||
| } | |||
| std::vector<int> *GetCNodeInputOutputIndices(std::string name) { return connectivity_[name]; } | |||
| void AddAnfNode(int id, AnfNodePtr anf_ptr) { tensors_[id] = anf_ptr; } | |||
| AnfNodePtr GetAnfNode(int id) { return tensors_[id]; } | |||
| protected: | |||
| lite::Primitive *CopyPrimitive(const schema::Primitive *srcPrim); | |||
| protected: | |||
| const schema::MetaGraph *meta_graph = nullptr; | |||
| std::map<int, AnfNodePtr> tensors_; | |||
| std::map<std::string, std::vector<int> *> connectivity_; | |||
| std::map<std::string, lite::Primitive *> ops; | |||
| }; | |||
| } // namespace train | |||
| using ModelImpl = mindspore::lite::train::ModelImpl; | |||
| } // namespace mindspore::lite | |||
| #endif // MINDSPORE_LITE_SRC_TRAIN_MODEL_IMPL_H_ | |||
| @@ -1,253 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <algorithm> | |||
| #include "src/train/train_anf_session.h" | |||
| #include "include/context.h" | |||
| #include "mindspore/ccsrc/runtime/device/kernel_info.h" | |||
| #include "mindspore/lite/src/train/train_session.h" | |||
| #include "mindspore/lite/src/kernel_factory.h" | |||
| #include "mindspore/lite/src/param_value_lite.h" | |||
| #include "common/utils.h" | |||
| #include "mindspore/lite/src/ops/ops.h" | |||
| #include "ir/anf.h" | |||
| #include "mindspore/lite/src/ir/tensor.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "src/ir/primitive_value.h" | |||
| #include "src/train/model_impl.h" | |||
| namespace mindspore { | |||
| namespace session { | |||
| static std::vector<int> GetAnfNodeOutDims(const AnfNodePtr &anfNodePtr) { | |||
| auto nodeAbstract = anfNodePtr->abstract(); | |||
| if (nodeAbstract != nullptr) { | |||
| auto shape = nodeAbstract->GetShapeTrack(); | |||
| if (!shape->isa<abstract::Shape>()) { | |||
| MS_LOG(EXCEPTION) << "Not a Shape"; | |||
| return {}; | |||
| } | |||
| auto dims = dyn_cast<abstract::Shape>(shape)->shape(); | |||
| return dims; | |||
| } else { | |||
| MS_LOG(WARNING) << "abstract is nullptr, return empty dims"; | |||
| return {}; | |||
| } | |||
| } | |||
| static schema::Format GetAnfNodeFormat(const AnfNodePtr &anfNodePtr) { | |||
| auto nodeAbstract = anfNodePtr->abstract(); | |||
| if (nodeAbstract != nullptr) { | |||
| return schema::Format_NHWC; // XXX TODO -- extract Format from AnfNode | |||
| } else { | |||
| MS_LOG(WARNING) << "abstract is nullptr, return schema::Format_NHWC"; | |||
| return schema::Format_NHWC; | |||
| } | |||
| } | |||
| static TypeId GetAnfNodeOutTypeId(const AnfNodePtr &anfNodePtr) { | |||
| auto nodeAbstract = anfNodePtr->abstract(); | |||
| if (nodeAbstract != nullptr) { | |||
| return TypeId::kNumberTypeFloat32; // XXX TODO nodeAbstract->GetTypeTrack()->generic_type_id(); | |||
| } else { | |||
| MS_LOG(WARNING) << "abstract is nullptr, return kTypeUnknown"; | |||
| return TypeId::kTypeUnknown; | |||
| } | |||
| } | |||
| void TrainANFSession::Init(lite::Context *context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| this->context_ = std::make_shared<lite::Context>(context->thread_num_, context->allocator, context->device_ctx_); | |||
| } | |||
| lite::tensor::Tensor *TrainANFSession::GetTensorForAnfNode(const AnfNodePtr anf_node) { | |||
| lite::tensor::Tensor *out_tensor = tensors_[anf_node]; | |||
| if (out_tensor == NULL) { | |||
| out_tensor = new lite::tensor::Tensor(GetAnfNodeOutTypeId(anf_node), | |||
| GetAnfNodeOutDims(anf_node)); //, schema::NodeType_Parameter); | |||
| tensors_[anf_node] = out_tensor; | |||
| } | |||
| return out_tensor; | |||
| } | |||
| int TrainANFSession::BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &kernel_graph) { | |||
| auto return_node = kernel_graph->get_return(); | |||
| auto node_list = TopoSort(return_node); | |||
| auto model_imp = std::dynamic_pointer_cast<lite::train::ModelImpl>(func_graph_); | |||
| for (auto &node : node_list) { | |||
| if (!node->isa<CNode>()) { | |||
| continue; | |||
| } | |||
| KernelRelation kernel_relation; | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| kernel_relation.node_full_name = cnode->fullname_with_scope(); | |||
| kernel_relation.cnode = cnode; | |||
| std::vector<int> *cnode_io_indices = model_imp->GetCNodeInputOutputIndices(cnode->fullname_with_scope()); | |||
| if (cnode_io_indices == NULL) { | |||
| MS_LOG(WARNING) << "No IO vectors for " << cnode->fullname_with_scope(); | |||
| } else { | |||
| for (int i = 0; i < cnode_io_indices[1].size(); i++) { | |||
| AnfNodePtr anf_node = model_imp->GetAnfNode(cnode_io_indices[1].data()[i]); | |||
| kernel_relation.output_tensor.push_back(GetTensorForAnfNode(anf_node)); | |||
| } | |||
| } | |||
| lite::tensor::Tensor *tensor_ptr = nullptr; | |||
| for (size_t index = 1; index < cnode->inputs().size(); ++index) { | |||
| if (cnode->input(index)->isa<CNode>()) { | |||
| auto input_cnode = cnode->input(index)->cast<CNodePtr>(); | |||
| auto input_kernel_relation = kernel_relation_infos_[input_cnode->fullname_with_scope()]; | |||
| // todo not support multi-outputs kernel sudo as spilt | |||
| tensor_ptr = input_kernel_relation.output_tensor.front(); | |||
| } else if (cnode->input(index)->isa<Parameter>()) { | |||
| auto input_parameter = cnode->input(index)->cast<ParameterPtr>(); | |||
| auto para = input_parameter->default_param(); | |||
| auto param_value = std::dynamic_pointer_cast<ParamValueLite>(para); | |||
| // auto dims = param_value->tensor_shape(); | |||
| // tensor_ptr = new lite::tensor::Tensor(param_value->tensor_type(), dims); // schema::NodeType_ValueNode); | |||
| tensor_ptr = GetTensorForAnfNode(cnode->input(index)); | |||
| if ((param_value != nullptr) && (param_value->tensor_size() != 0)) { | |||
| tensor_ptr->SetData(param_value->tensor_addr()); | |||
| } | |||
| } else if (cnode->input(index)->isa<ValueNode>()) { | |||
| auto input_valuenode = cnode->input(index)->cast<ValueNodePtr>(); | |||
| // tensor_ptr = new lite::tensor::Tensor(GetAnfNodeOutTypeId(input_valuenode), | |||
| // GetAnfNodeOutDims(input_valuenode)); // schema::NodeType_Parameter); | |||
| tensor_ptr = GetTensorForAnfNode(input_valuenode); | |||
| // todo(yankai) | |||
| } else { | |||
| MS_ASSERT(false); | |||
| } | |||
| kernel_relation.input_tensor.push_back(tensor_ptr); | |||
| } | |||
| kernel_relation_infos_[cnode->fullname_with_scope()] = kernel_relation; | |||
| } | |||
| return 0; | |||
| } | |||
| GraphId TrainANFSession::graph_sum_ = 0; | |||
| KernelGraphPtr TrainANFSession::NewKernelGraph() { | |||
| auto graph = std::make_shared<KernelGraph>(); | |||
| graph->set_graph_id(graph_sum_); | |||
| graphs_[graph_sum_++] = graph; | |||
| return graph; | |||
| } | |||
| std::shared_ptr<KernelGraph> TrainANFSession::ConstructKernelGraph(const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| auto graph = NewKernelGraph(); | |||
| graph->set_return(func_graph->get_return()); | |||
| auto node_list = TopoSort(func_graph->get_return()); | |||
| std::vector<CNodePtr> cnode_order; | |||
| for (const auto &node : node_list) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (node->isa<CNode>()) { | |||
| auto cn_node = node->cast<CNodePtr>(); | |||
| cnode_order.push_back(cn_node); | |||
| } | |||
| } | |||
| graph->set_execution_order(cnode_order); | |||
| return graph; | |||
| } | |||
| GraphId TrainANFSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) { | |||
| auto graph = ConstructKernelGraph(func_graph); | |||
| func_graph_ = func_graph; | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_LOG(INFO) << "Set kernel info"; | |||
| SetKernelInfo(graph.get()); | |||
| (void)BuildKernelInputAndOutputFromFuncGraph(graph); | |||
| MS_LOG(INFO) << "Build kernel"; | |||
| auto ret = BuildKernel(graph.get()); | |||
| if (0 != ret) { | |||
| MS_LOG(EXCEPTION) << "BuildKernel failed"; | |||
| } | |||
| // return the graph id to backend | |||
| auto graph_id = graph->graph_id(); | |||
| graphs_[graph_id] = graph; | |||
| MS_LOG(INFO) << "Compile graph " << graph_id << " success"; | |||
| return graph_id; | |||
| } | |||
| void TrainANFSession::RunGraph(const GraphId &graph_id, const std::vector<lite::tensor::Tensor *> &inputs, | |||
| std::vector<lite::tensor::Tensor *> *outputs) { | |||
| auto &kernel_graph = graphs_[graph_id]; | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| MS_LOG(INFO) << "Bind input output address"; | |||
| // runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs); -- will be bound in Run | |||
| // auto execution_order = kernel_graph->execution_order(); | |||
| // Todo : hangangqiang | |||
| // Reorder(&execution_order); | |||
| // kernel_graph->set_execution_order(execution_order); | |||
| MS_LOG(INFO) << "Run graph start"; | |||
| auto ret = runtime_.Run(kernel_graph.get(), (std::vector<lite::tensor::Tensor *> &)inputs, *outputs); | |||
| if (!ret) { | |||
| MS_LOG(EXCEPTION) << "Run graph failed"; | |||
| } | |||
| MS_LOG(INFO) << "Run graph end"; | |||
| } | |||
| void TrainANFSession::SetKernelInfo(const KernelGraph *kernel_graph) { | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| auto &kernel_nodes = kernel_graph->execution_order(); | |||
| for (const auto &kernel_node : kernel_nodes) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| auto kernel_info = std::make_shared<device::KernelInfo>(); | |||
| kernel_node->set_kernel_info(kernel_info); | |||
| } | |||
| } | |||
| int TrainANFSession::BuildKernel(const KernelGraph *kernel_graph) { | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| for (auto iter = kernel_relation_infos_.begin(); iter != kernel_relation_infos_.end(); ++iter) { | |||
| std::string kernel_name = iter->first; | |||
| KernelRelation anf_register = iter->second; | |||
| MS_EXCEPTION_IF_NULL(anf_register.cnode); | |||
| if (IsPrimitiveCNode(anf_register.cnode, prim::kPrimReturn)) { | |||
| continue; | |||
| } | |||
| auto value_node_prim = anf_register.cnode->input(0); | |||
| MS_EXCEPTION_IF_NULL(value_node_prim); | |||
| auto prim = GetValueNode<std::shared_ptr<lite::PrimitiveValue>>(value_node_prim); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| auto node_primitive = (lite::Primitive *)(prim->GetPrimitive()); | |||
| MS_EXCEPTION_IF_NULL(node_primitive); | |||
| auto ret = node_primitive->InferShape(anf_register.input_tensor, anf_register.output_tensor); | |||
| if (0 != ret) { | |||
| MS_LOG(ERROR) << "InferShape failed, node : " << kernel_name; | |||
| return ret; | |||
| } | |||
| kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, node_primitive->Type()}; | |||
| auto *kernel = lite::KernelFactory::GetInstance()->GetKernel(anf_register.input_tensor, anf_register.output_tensor, | |||
| node_primitive, context_.get(), desc); | |||
| if (nullptr == kernel) { | |||
| MS_LOG(ERROR) << "Create kernel return nullptr, name: " << kernel_name; | |||
| return -1; | |||
| } | |||
| std::shared_ptr<kernel::LiteKernel> kernel_mod(kernel); | |||
| kernel_mod->set_name(anf_register.cnode->fullname_with_scope()); | |||
| // kernel->train(); | |||
| auto kernel_info = dynamic_cast<device::KernelInfo *>(anf_register.cnode->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| kernel_info->set_kernel_mod(kernel_mod); // XXX TODO -- only derived class KernelInfo has this method | |||
| } | |||
| return 0; | |||
| } | |||
| } // namespace session | |||
| } // namespace mindspore | |||
| @@ -1,76 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_ | |||
| #define MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_ | |||
| #include <map> | |||
| #include <string> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <unordered_map> | |||
| #include "include/context.h" | |||
| #include "backend/session/session_basic.h" | |||
| #include "backend/session/kernel_graph.h" | |||
| #include "mindspore/lite/src/train/lite_kernel_runtime.h" | |||
| // #include "backend/session/session_factory.h" | |||
| namespace mindspore { | |||
| namespace lite::tensor { | |||
| class Tensor; | |||
| } | |||
| namespace session { | |||
| struct KernelRelation { | |||
| std::string node_full_name; | |||
| std::vector<lite::tensor::Tensor *> input_tensor; | |||
| std::vector<lite::tensor::Tensor *> output_tensor; | |||
| CNodePtr cnode; | |||
| }; | |||
| class TrainANFSession { | |||
| public: | |||
| explicit TrainANFSession(lite::Context *context) { Init(context); } | |||
| ~TrainANFSession() = default; | |||
| GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph); | |||
| void RunGraph(const GraphId &graph_id, const std::vector<lite::tensor::Tensor *> &inputs, | |||
| std::vector<lite::tensor::Tensor *> *outputs); | |||
| // void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override | |||
| // {}; | |||
| protected: | |||
| void Init(lite::Context *context); | |||
| std::shared_ptr<lite::Context> context_ = nullptr; | |||
| std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_; | |||
| static GraphId graph_sum_; | |||
| KernelGraphPtr NewKernelGraph(); | |||
| private: | |||
| // GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; | |||
| // GraphId CompileGraph(const char *model_buf, size_t size); | |||
| std::shared_ptr<KernelGraph> ConstructKernelGraph(const FuncGraphPtr &func_graph); | |||
| int BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &kernel_graph); | |||
| lite::tensor::Tensor *GetTensorForAnfNode(const AnfNodePtr anf_node); | |||
| void SetKernelInfo(const KernelGraph *kernel_graph); | |||
| int BuildKernel(const KernelGraph *kernel_graph); | |||
| lite::LiteInferKernelRuntime runtime_; | |||
| std::map<std::string, KernelRelation> kernel_relation_infos_; | |||
| FuncGraphPtr func_graph_ = NULL; | |||
| std::map<AnfNodePtr, lite::tensor::Tensor *> tensors_; | |||
| }; | |||
| } // namespace session | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_ | |||
| @@ -1,267 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <algorithm> | |||
| #include "include/context.h" | |||
| #include "mindspore/ccsrc/runtime/device/kernel_info.h" | |||
| #include "mindspore/lite/src/train/train_session.h" | |||
| #include "mindspore/lite/src/kernel_factory.h" | |||
| #include "mindspore/lite/src/param_value_lite.h" | |||
| #include "utils/ms_utils.h" | |||
| #include "mindspore/lite/src/ops/ops.h" | |||
| #include "ir/anf.h" | |||
| #include "mindspore/lite/src/ir/tensor.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "src/ir/primitive_value.h" | |||
| #include "src/train/model_impl.h" | |||
| namespace mindspore { | |||
| namespace session { | |||
| static std::vector<int> GetAnfNodeOutDims(const AnfNodePtr &anfNodePtr) { | |||
| auto nodeAbstract = anfNodePtr->abstract(); | |||
| if (nodeAbstract != nullptr) { | |||
| auto shape = nodeAbstract->GetShapeTrack(); | |||
| if (!shape->isa<abstract::Shape>()) { | |||
| MS_LOG(EXCEPTION) << "Not a Shape"; | |||
| return {}; | |||
| } | |||
| auto dims = dyn_cast<abstract::Shape>(shape)->shape(); | |||
| return dims; | |||
| } else { | |||
| MS_LOG(WARNING) << "abstract is nullptr, return empty dims"; | |||
| return {}; | |||
| } | |||
| } | |||
| static schema::Format GetAnfNodeFormat(const AnfNodePtr &anfNodePtr) { | |||
| auto nodeAbstract = anfNodePtr->abstract(); | |||
| if (nodeAbstract != nullptr) { | |||
| return schema::Format_NHWC; // XXX TODO -- extract Format from AnfNode | |||
| } else { | |||
| MS_LOG(WARNING) << "abstract is nullptr, return schema::Format_NHWC"; | |||
| return schema::Format_NHWC; | |||
| } | |||
| } | |||
| static TypeId GetAnfNodeOutTypeId(const AnfNodePtr &anfNodePtr) { | |||
| auto nodeAbstract = anfNodePtr->abstract(); | |||
| if (nodeAbstract != nullptr) { | |||
| return TypeId::kNumberTypeFloat32; // XXX TODO nodeAbstract->GetTypeTrack()->generic_type_id(); | |||
| } else { | |||
| MS_LOG(WARNING) << "abstract is nullptr, return kTypeUnknown"; | |||
| return TypeId::kTypeUnknown; | |||
| } | |||
| } | |||
| void TrainSession::Init(lite::Context *context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| this->context_ = std::make_shared<lite::Context>(context->thread_num_, context->allocator, context->device_ctx_); | |||
| } | |||
| lite::tensor::Tensor *TrainSession::GetTensorForAnfNode(const AnfNodePtr anf_node) { | |||
| lite::tensor::Tensor *out_tensor = tensors_[anf_node]; | |||
| if (out_tensor == NULL) { | |||
| out_tensor = new lite::tensor::Tensor(GetAnfNodeOutTypeId(anf_node), | |||
| GetAnfNodeOutDims(anf_node)); //, schema::NodeType_Parameter); | |||
| tensors_[anf_node] = out_tensor; | |||
| } | |||
| return out_tensor; | |||
| } | |||
| int TrainSession::BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &kernel_graph) { | |||
| auto return_node = kernel_graph->get_return(); | |||
| auto node_list = TopoSort(return_node); | |||
| auto model_imp = std::dynamic_pointer_cast<lite::train::ModelImpl>(func_graph_); | |||
| for (auto &node : node_list) { | |||
| if (!node->isa<CNode>()) { | |||
| continue; | |||
| } | |||
| KernelRelation kernel_relation; | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| kernel_relation.node_full_name = cnode->fullname_with_scope(); | |||
| kernel_relation.cnode = cnode; | |||
| std::vector<int> *cnode_io_indices = model_imp->GetCNodeInputOutputIndices(cnode->fullname_with_scope()); | |||
| if (cnode_io_indices == NULL) { | |||
| MS_LOG(WARNING) << "No IO vectors for " << cnode->fullname_with_scope(); | |||
| } else { | |||
| for (int i = 0; i < cnode_io_indices[1].size(); i++) { | |||
| AnfNodePtr anf_node = model_imp->GetAnfNode(cnode_io_indices[1].data()[i]); | |||
| kernel_relation.output_tensor.push_back(GetTensorForAnfNode(anf_node)); | |||
| } | |||
| } | |||
| lite::tensor::Tensor *tensor_ptr = nullptr; | |||
| for (size_t index = 1; index < cnode->inputs().size(); ++index) { | |||
| if (cnode->input(index)->isa<CNode>()) { | |||
| auto input_cnode = cnode->input(index)->cast<CNodePtr>(); | |||
| auto input_kernel_relation = kernel_relation_infos_[input_cnode->fullname_with_scope()]; | |||
| // todo not support multi-outputs kernel sudo as spilt | |||
| tensor_ptr = input_kernel_relation.output_tensor.front(); | |||
| } else if (cnode->input(index)->isa<Parameter>()) { | |||
| auto input_parameter = cnode->input(index)->cast<ParameterPtr>(); | |||
| auto para = input_parameter->default_param(); | |||
| auto param_value = std::dynamic_pointer_cast<ParamValueLite>(para); | |||
| // auto dims = param_value->tensor_shape(); | |||
| // tensor_ptr = new lite::tensor::Tensor(param_value->tensor_type(), dims); // schema::NodeType_ValueNode); | |||
| tensor_ptr = GetTensorForAnfNode(cnode->input(index)); | |||
| if ((param_value != nullptr) && (param_value->tensor_size() != 0)) { | |||
| tensor_ptr->SetData(param_value->tensor_addr()); | |||
| } | |||
| } else if (cnode->input(index)->isa<ValueNode>()) { | |||
| auto input_valuenode = cnode->input(index)->cast<ValueNodePtr>(); | |||
| // tensor_ptr = new lite::tensor::Tensor(GetAnfNodeOutTypeId(input_valuenode), | |||
| // GetAnfNodeOutDims(input_valuenode)); // schema::NodeType_Parameter); | |||
| tensor_ptr = GetTensorForAnfNode(input_valuenode); | |||
| // todo(yankai) | |||
| } else { | |||
| MS_ASSERT(false); | |||
| } | |||
| kernel_relation.input_tensor.push_back(tensor_ptr); | |||
| } | |||
| kernel_relation_infos_[cnode->fullname_with_scope()] = kernel_relation; | |||
| } | |||
| return 0; | |||
| } | |||
| #if 0 | |||
| GraphId TrainSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { | |||
| auto graph_id = graph_sum_; | |||
| auto graph = SessionBasic::ConstructKernelGraph(lst, outputs); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| BuildKernel(graph.get()); | |||
| MS_LOG(INFO) << "Assign kernel address"; | |||
| runtime_.AssignKernelAddress(graph.get()); | |||
| return graph_id; | |||
| } | |||
| GraphId TrainSession::CompileGraph(const char *model_buf, size_t size) { return 0; } | |||
| #else | |||
| GraphId TrainSession::graph_sum_ = 0; | |||
| KernelGraphPtr TrainSession::NewKernelGraph() { | |||
| auto graph = std::make_shared<KernelGraph>(); | |||
| graph->set_graph_id(graph_sum_); | |||
| graphs_[graph_sum_++] = graph; | |||
| return graph; | |||
| } | |||
| #endif | |||
| std::shared_ptr<KernelGraph> TrainSession::ConstructKernelGraph(const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| auto graph = NewKernelGraph(); | |||
| graph->set_return(func_graph->get_return()); | |||
| auto node_list = TopoSort(func_graph->get_return()); | |||
| std::vector<CNodePtr> cnode_order; | |||
| for (const auto &node : node_list) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (node->isa<CNode>()) { | |||
| auto cn_node = node->cast<CNodePtr>(); | |||
| cnode_order.push_back(cn_node); | |||
| } | |||
| } | |||
| graph->set_execution_order(cnode_order); | |||
| return graph; | |||
| } | |||
| GraphId TrainSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) { | |||
| auto graph = ConstructKernelGraph(func_graph); | |||
| func_graph_ = func_graph; | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_LOG(INFO) << "Set kernel info"; | |||
| SetKernelInfo(graph.get()); | |||
| (void)BuildKernelInputAndOutputFromFuncGraph(graph); | |||
| MS_LOG(INFO) << "Build kernel"; | |||
| auto ret = BuildKernel(graph.get()); | |||
| if (0 != ret) { | |||
| MS_LOG(EXCEPTION) << "BuildKernel failed"; | |||
| } | |||
| // return the graph id to backend | |||
| auto graph_id = graph->graph_id(); | |||
| graphs_[graph_id] = graph; | |||
| MS_LOG(INFO) << "Compile graph " << graph_id << " success"; | |||
| return graph_id; | |||
| } | |||
| void TrainSession::RunGraph(const GraphId &graph_id, const std::vector<lite::tensor::Tensor *> &inputs, | |||
| std::vector<lite::tensor::Tensor *> *outputs) { | |||
| auto &kernel_graph = graphs_[graph_id]; | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| MS_LOG(INFO) << "Bind input output address"; | |||
| // runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs); -- will be bound in Run | |||
| // auto execution_order = kernel_graph->execution_order(); | |||
| // Todo : hangangqiang | |||
| // Reorder(&execution_order); | |||
| // kernel_graph->set_execution_order(execution_order); | |||
| MS_LOG(INFO) << "Run graph start"; | |||
| auto ret = runtime_.Run(kernel_graph.get(), (std::vector<lite::tensor::Tensor *> &)inputs, outputs); | |||
| if (!ret) { | |||
| MS_LOG(EXCEPTION) << "Run graph failed"; | |||
| } | |||
| MS_LOG(INFO) << "Run graph end"; | |||
| } | |||
| void TrainSession::SetKernelInfo(const KernelGraph *kernel_graph) { | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| auto &kernel_nodes = kernel_graph->execution_order(); | |||
| for (const auto &kernel_node : kernel_nodes) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| auto kernel_info = std::make_shared<device::KernelInfo>(); | |||
| kernel_node->set_kernel_info(kernel_info); | |||
| } | |||
| } | |||
| int TrainSession::BuildKernel(const KernelGraph *kernel_graph) { | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| for (auto iter = kernel_relation_infos_.begin(); iter != kernel_relation_infos_.end(); ++iter) { | |||
| std::string kernel_name = iter->first; | |||
| KernelRelation anf_register = iter->second; | |||
| MS_EXCEPTION_IF_NULL(anf_register.cnode); | |||
| if (IsPrimitiveCNode(anf_register.cnode, prim::kPrimReturn)) { | |||
| continue; | |||
| } | |||
| auto value_node_prim = anf_register.cnode->input(0); | |||
| MS_EXCEPTION_IF_NULL(value_node_prim); | |||
| auto prim = GetValueNode<std::shared_ptr<lite::PrimitiveValue>>(value_node_prim); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| auto node_primitive = (lite::Primitive *)(prim->GetPrimitive()); | |||
| MS_EXCEPTION_IF_NULL(node_primitive); | |||
| auto ret = node_primitive->InferShape(anf_register.input_tensor, anf_register.output_tensor); | |||
| if (0 != ret) { | |||
| MS_LOG(ERROR) << "InferShape failed, node : " << kernel_name; | |||
| return ret; | |||
| } | |||
| kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, node_primitive->Type()}; | |||
| auto *kernel = lite::KernelFactory::GetInstance()->GetKernel(anf_register.input_tensor, anf_register.output_tensor, | |||
| node_primitive, context_.get(), desc); | |||
| if (nullptr == kernel) { | |||
| MS_LOG(ERROR) << "Create kernel return nullptr, name: " << kernel_name; | |||
| return -1; | |||
| } | |||
| std::shared_ptr<kernel::LiteKernel> kernel_mod(kernel); | |||
| kernel_mod->set_name(anf_register.cnode->fullname_with_scope()); | |||
| // kernel->train(); | |||
| auto kernel_info = dynamic_cast<device::KernelInfo *>(anf_register.cnode->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| kernel_info->set_kernel_mod(kernel_mod); // XXX TODO -- only derived class KernelInfo has this method | |||
| } | |||
| return 0; | |||
| } | |||
| } // namespace session | |||
| } // namespace mindspore | |||
| @@ -1,76 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_ | |||
| #define MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_ | |||
| #include <map> | |||
| #include <string> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <unordered_map> | |||
| #include "include/context.h" | |||
| #include "backend/session/session_basic.h" | |||
| #include "backend/session/kernel_graph.h" | |||
| #include "mindspore/lite/src/train/lite_kernel_runtime.h" | |||
| // #include "backend/session/session_factory.h" | |||
| namespace mindspore { | |||
| namespace lite::tensor { | |||
| class Tensor; | |||
| } | |||
| namespace session { | |||
| struct KernelRelation { | |||
| std::string node_full_name; | |||
| std::vector<lite::tensor::Tensor *> input_tensor; | |||
| std::vector<lite::tensor::Tensor *> output_tensor; | |||
| CNodePtr cnode; | |||
| }; | |||
| class TrainSession { | |||
| public: | |||
| explicit TrainSession(lite::Context * context) { Init(context); } | |||
| ~TrainSession() = default; | |||
| GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph); | |||
| void RunGraph(const GraphId &graph_id, const std::vector<lite::tensor::Tensor *> &inputs, | |||
| std::vector<lite::tensor::Tensor *> *outputs); | |||
| // void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override | |||
| // {}; | |||
| protected: | |||
| void Init(lite::Context *context); | |||
| std::shared_ptr<lite::Context> context_ = nullptr; | |||
| std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_; | |||
| static GraphId graph_sum_; | |||
| KernelGraphPtr NewKernelGraph(); | |||
| private: | |||
| // GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; | |||
| // GraphId CompileGraph(const char *model_buf, size_t size); | |||
| std::shared_ptr<KernelGraph> ConstructKernelGraph(const FuncGraphPtr &func_graph); | |||
| int BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &kernel_graph); | |||
| lite::tensor::Tensor *GetTensorForAnfNode(const AnfNodePtr anf_node); | |||
| void SetKernelInfo(const KernelGraph *kernel_graph); | |||
| int BuildKernel(const KernelGraph *kernel_graph); | |||
| lite::LiteInferKernelRuntime runtime_; | |||
| std::map<std::string, KernelRelation> kernel_relation_infos_; | |||
| FuncGraphPtr func_graph_ = NULL; | |||
| std::map<AnfNodePtr, lite::tensor::Tensor *> tensors_; | |||
| }; | |||
| } // namespace session | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_ | |||
| @@ -175,12 +175,10 @@ set(TEST_LITE_SRC | |||
| ${LITE_DIR}/src/ir/primitive_t_value.cc | |||
| ${LITE_DIR}/src/context.cc | |||
| ${LITE_DIR}/src/executor.cc | |||
| ${LITE_DIR}/src/kernel_factory.cc | |||
| ${LITE_DIR}/src/kernel_registry.cc | |||
| ${LITE_DIR}/src/lite_kernel.cc | |||
| ${LITE_DIR}/src/lite_session.cc | |||
| ${LITE_DIR}/src/model.cc | |||
| ${LITE_DIR}/src/model_impl.cc | |||
| ${LITE_DIR}/src/populate_parameter.cc | |||
| ${LITE_DIR}/src/scheduler.cc | |||
| ${LITE_DIR}/src/common/graph_util.cc | |||
| @@ -265,7 +263,7 @@ if (SUPPORT_TRAIN) | |||
| # ${LITE_DIR}/src/ir/primitive_value.cc | |||
| # ${LITE_DIR}/src/train/lite_kernel_runtime.cc | |||
| # ${LITE_DIR}/src/train/train_session.cc | |||
| # ${LITE_DIR}/src/train/model_impl.cc | |||
| # ${LITE_DIR}/src/train/model.cc | |||
| ${LITE_DIR}/src/lite_session.cc # temporary | |||
| ) | |||
| else() | |||
| @@ -20,15 +20,15 @@ | |||
| #include <algorithm> | |||
| #include "schema/inner/model_generated.h" | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| #include "src/kernel_factory.h" | |||
| #include "tools/anf_exporter/anf_exporter.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "src/scheduler.h" | |||
| #include "include/context.h" | |||
| #include "src/lite_session.h" | |||
| #include "src/ir/primitive_t_value.h" | |||
| #include "src/populate_parameter.h" | |||
| using mindspore::lite::KernelFactory; | |||
| using mindspore::lite::KernelRegistry; | |||
| using mindspore::lite::tensor::Tensor; | |||
| using mindspore::lite::PrimitiveTValue; | |||
| namespace mindspore::opt { | |||