Browse Source

optimize the interface of model and kernel registry

tags/v0.7.0-beta
zhaizhiqiang 5 years ago
parent
commit
1b900ca44f
45 changed files with 358 additions and 1686 deletions
  1. +1
    -6
      mindspore/lite/include/model.h
  2. +5
    -40
      mindspore/lite/src/CMakeLists.txt
  3. +0
    -53
      mindspore/lite/src/kernel_factory.cc
  4. +0
    -40
      mindspore/lite/src/kernel_factory.h
  5. +20
    -0
      mindspore/lite/src/kernel_registry.cc
  6. +4
    -1
      mindspore/lite/src/kernel_registry.h
  7. +299
    -9
      mindspore/lite/src/model.cc
  8. +0
    -297
      mindspore/lite/src/model_impl.cc
  9. +0
    -53
      mindspore/lite/src/model_impl.h
  10. +1
    -1
      mindspore/lite/src/runtime/kernel/arm/base/arg_min_max_base.cc
  11. +1
    -1
      mindspore/lite/src/runtime/kernel/arm/base/batch_to_space_base.cc
  12. +1
    -1
      mindspore/lite/src/runtime/kernel/arm/base/caffeprelu_base.cc
  13. +1
    -1
      mindspore/lite/src/runtime/kernel/arm/base/concat_base.cc
  14. +1
    -1
      mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc
  15. +1
    -1
      mindspore/lite/src/runtime/kernel/arm/base/crop_base.cc
  16. +1
    -1
      mindspore/lite/src/runtime/kernel/arm/base/depth_to_space_base.cc
  17. +1
    -1
      mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc
  18. +1
    -1
      mindspore/lite/src/runtime/kernel/arm/base/matmul_base.cc
  19. +1
    -1
      mindspore/lite/src/runtime/kernel/arm/base/pad.cc
  20. +1
    -1
      mindspore/lite/src/runtime/kernel/arm/base/pooling_base.cc
  21. +1
    -1
      mindspore/lite/src/runtime/kernel/arm/base/prelu_base.cc
  22. +1
    -1
      mindspore/lite/src/runtime/kernel/arm/base/prior_box.cc
  23. +1
    -1
      mindspore/lite/src/runtime/kernel/arm/base/reshape_base.cc
  24. +1
    -1
      mindspore/lite/src/runtime/kernel/arm/base/softmax_base.cc
  25. +1
    -1
      mindspore/lite/src/runtime/kernel/arm/base/split_base.cc
  26. +1
    -1
      mindspore/lite/src/runtime/kernel/arm/base/squeeze_base.cc
  27. +1
    -1
      mindspore/lite/src/runtime/kernel/arm/fp16/convolution_base_fp16.cc
  28. +1
    -1
      mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc
  29. +1
    -1
      mindspore/lite/src/runtime/kernel/arm/fp32/convolution_slidewindow.cc
  30. +1
    -1
      mindspore/lite/src/runtime/kernel/arm/fp32/transpose.h
  31. +1
    -1
      mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.cc
  32. +4
    -4
      mindspore/lite/src/scheduler.cc
  33. +0
    -59
      mindspore/lite/src/train/base_ref_utils.cc
  34. +0
    -30
      mindspore/lite/src/train/base_ref_utils.h
  35. +0
    -50
      mindspore/lite/src/train/import.hpp
  36. +0
    -86
      mindspore/lite/src/train/lite_kernel_runtime.cc
  37. +0
    -50
      mindspore/lite/src/train/lite_kernel_runtime.h
  38. +0
    -145
      mindspore/lite/src/train/model_impl.cc
  39. +0
    -64
      mindspore/lite/src/train/model_impl.h
  40. +0
    -253
      mindspore/lite/src/train/train_anf_session.cc
  41. +0
    -76
      mindspore/lite/src/train/train_anf_session.h
  42. +0
    -267
      mindspore/lite/src/train/train_session.cc
  43. +0
    -76
      mindspore/lite/src/train/train_session.h
  44. +1
    -3
      mindspore/lite/test/CMakeLists.txt
  45. +2
    -2
      mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc

+ 1
- 6
mindspore/lite/include/model.h View File

@@ -25,12 +25,12 @@
namespace mindspore {
#define MS_API __attribute__((visibility("default")))

namespace lite {
/// \brief ModelImpl defined the implement class of Model in MindSpore Lite.
///
/// \note List public class and interface for reference.
class ModelImpl;

namespace lite {
/// \brief Primitive defined as prototype of operator.
///
/// \note List public class and interface for reference.
@@ -67,11 +67,6 @@ class MS_API Model {
/// \return the pointer of graph defined in flatbuffers.
const schema::MetaGraph *GetMetaGraph() const;

/// \brief Get MindSpore Lite ModelImpl.
///
/// \return the pointer of MindSpore Lite ModelImpl.
ModelImpl *model_impl();

/// \brief Free MetaGraph in MindSpore Lite Model.
void FreeMetaGraph();



+ 5
- 40
mindspore/lite/src/CMakeLists.txt View File

@@ -8,10 +8,8 @@ set(LITE_SRC
${CMAKE_CURRENT_SOURCE_DIR}/ir/tensor.cc
${CMAKE_CURRENT_SOURCE_DIR}/context.cc
${CMAKE_CURRENT_SOURCE_DIR}/executor.cc
${CMAKE_CURRENT_SOURCE_DIR}/kernel_factory.cc
${CMAKE_CURRENT_SOURCE_DIR}/kernel_registry.cc
${CMAKE_CURRENT_SOURCE_DIR}/lite_kernel.cc
${CMAKE_CURRENT_SOURCE_DIR}/model.cc
${CMAKE_CURRENT_SOURCE_DIR}/populate_parameter.cc
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cc
)
@@ -21,44 +19,11 @@ if (SUPPORT_GPU)
list(APPEND LITE_SRC ${CMAKE_CURRENT_SOURCE_DIR}/runtime/kernel/opencl/utils.cc)
endif()

if (SUPPORT_TRAIN)
set(ANF_SRC
${ANF_SRC}
# ${CCSRC_DIR}/common/trans.cc
# ${CCSRC_DIR}/utils/lite/base_ref_utils.cc
# ${CCSRC_DIR}/runtime/kernel/kernel_compiler/kernel_build_info.cc
# ${CCSRC_DIR}/session/lite/anf_runtime_algorithm_extends.cc
# ${CCSRC_DIR}/session/lite/session_basic_extends.cc
# ${CCSRC_DIR}/session/anf_runtime_algorithm.cc
# ${CCSRC_DIR}/session/session_basic.cc
# ${CCSRC_DIR}/session/kernel_graph.cc
# ${CCSRC_DIR}/session/session_factory.cc
# ${CCSRC_DIR}/device/kernel_info.cc
# ${CCSRC_DIR}/device/kernel_runtime.cc
# ${CCSRC_DIR}/device/lite/kernel_runtime_extends.cc
)
set(PASS_SRC)
set(LITE_SRC
${LITE_SRC}
${ANF_SRC}
# ${PASS_SRC}
# ${CMAKE_CURRENT_SOURCE_DIR}/common/anf_importer/anf_importer.cc
# ${CMAKE_CURRENT_SOURCE_DIR}/common/anf_importer/import_from_meta_graph.cc
# ${CMAKE_CURRENT_SOURCE_DIR}/ir/primitive_value.cc
# ${CMAKE_CURRENT_SOURCE_DIR}/train/lite_kernel_runtime.cc
# ${CMAKE_CURRENT_SOURCE_DIR}/train/train_session.cc
# ${CMAKE_CURRENT_SOURCE_DIR}/train/model_impl.cc
${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc # temporary
${CMAKE_CURRENT_SOURCE_DIR}/model_impl.cc # temporary
)
else ()
set(LITE_SRC
${LITE_SRC}
${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc
${CMAKE_CURRENT_SOURCE_DIR}/model_impl.cc
)
endif ()
set(LITE_SRC
${LITE_SRC}
${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc
${CMAKE_CURRENT_SOURCE_DIR}/model.cc
)

if (SUPPORT_GPU)
set(LITE_SRC


+ 0
- 53
mindspore/lite/src/kernel_factory.cc View File

@@ -1,53 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "mindspore/lite/src/kernel_factory.h"
#include "utils/log_adapter.h"
#include "src/populate_parameter.h"
#include "schema/model_generated.h"

using mindspore::kernel::KERNEL_ARCH;
using mindspore::kernel::KernelKey;
using mindspore::kernel::LiteKernel;

namespace mindspore::lite {
KernelFactory::KernelFactory() = default;

KernelFactory::~KernelFactory() = default;

KernelFactory *KernelFactory::GetInstance() {
static KernelFactory instance;
return &instance;
}

LiteKernel *KernelFactory::GetKernel(const std::vector<tensor::Tensor *> &in_tensors,
const std::vector<tensor::Tensor *> &out_tensors, const lite::Primitive *primitive,
const Context *ctx, const kernel::KernelKey &key) {
MS_EXCEPTION_IF_NULL(primitive);
MS_EXCEPTION_IF_NULL(ctx);
auto parameter = kernel::PopulateParameter(primitive);
if (parameter == nullptr) {
MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << schema::EnumNamePrimitiveType(primitive->Type());
return nullptr;
}
auto creator = KernelRegistry::GetInstance()->GetCreator(key);
if (creator != nullptr) {
auto kernel = creator(in_tensors, out_tensors, parameter, ctx, key, primitive);
return kernel;
}
return nullptr;
}
} // namespace mindspore::lite

+ 0
- 40
mindspore/lite/src/kernel_factory.h View File

@@ -1,40 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_SRC_KERNEL_FACTORY_H_
#define MINDSPORE_LITE_SRC_KERNEL_FACTORY_H_

#include <vector>
#include "mindspore/lite/src/lite_kernel.h"
#include "mindspore/lite/src/kernel_registry.h"
#include "mindspore/lite/include/context.h"
#include "mindspore/lite/src/ir/tensor.h"
#include "schema/model_generated.h"

namespace mindspore::lite {
class KernelFactory {
public:
KernelFactory();
virtual ~KernelFactory();

static KernelFactory *GetInstance();
kernel::LiteKernel *GetKernel(const std::vector<tensor::Tensor *> &in_tensors,
const std::vector<tensor::Tensor *> &out_tensors, const lite::Primitive *primitive,
const Context *ctx, const kernel::KernelKey &key);
};
} // namespace mindspore::lite

#endif // MINDSPORE_LITE_SRC_KERNEL_FACTORY_H_

+ 20
- 0
mindspore/lite/src/kernel_registry.cc View File

@@ -16,6 +16,7 @@
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "ir/dtype/type_id.h"
#include "src/populate_parameter.h"
#ifdef ENABLE_ARM64
#include <asm/hwcap.h>
#include "common/utils.h"
@@ -120,4 +121,23 @@ void KernelRegistry::RegKernel(const KERNEL_ARCH arch, const TypeId data_type, c
bool KernelRegistry::Merge(const std::unordered_map<KernelKey, KernelCreator> &newCreators) { return false; }

const kernel::KernelCreator *KernelRegistry::GetCreatorArrays() { return creator_arrays_; }

kernel::LiteKernel *KernelRegistry::GetKernel(const std::vector<tensor::Tensor *> &in_tensors,
const std::vector<tensor::Tensor *> &out_tensors,
const lite::Primitive *primitive, const Context *ctx,
const kernel::KernelKey &key) {
MS_EXCEPTION_IF_NULL(primitive);
MS_EXCEPTION_IF_NULL(ctx);
auto parameter = kernel::PopulateParameter(primitive);
if (parameter == nullptr) {
MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << schema::EnumNamePrimitiveType(primitive->Type());
return nullptr;
}
auto creator = GetCreator(key);
if (creator != nullptr) {
auto kernel = creator(in_tensors, out_tensors, parameter, ctx, key, primitive);
return kernel;
}
return nullptr;
}
} // namespace mindspore::lite

+ 4
- 1
mindspore/lite/src/kernel_registry.h View File

@@ -17,9 +17,9 @@
#ifndef MINDSPORE_LITE_SRC_KERNEL_REGISTRY_H_
#define MINDSPORE_LITE_SRC_KERNEL_REGISTRY_H_

#include <map>
#include <string>
#include <unordered_map>
#include <vector>
#include "src/lite_kernel.h"
#include "schema/model_generated.h"

@@ -39,6 +39,9 @@ class KernelRegistry {
void RegKernel(const kernel::KERNEL_ARCH arch, const TypeId data_type, const schema::PrimitiveType type,
kernel::KernelCreator creator);
bool Merge(const std::unordered_map<kernel::KernelKey, kernel::KernelCreator> &newCreators);
kernel::LiteKernel *GetKernel(const std::vector<tensor::Tensor *> &in_tensors,
const std::vector<tensor::Tensor *> &out_tensors, const lite::Primitive *primitive,
const Context *ctx, const kernel::KernelKey &key);

protected:
kernel::KernelCreator *creator_arrays_ = nullptr;


+ 299
- 9
mindspore/lite/src/model.cc View File

@@ -14,16 +14,310 @@
* limitations under the License.
*/

// #ifdef SUPPORT_TRAIN
// #include "src/train/model_impl.h"
// #else
#include "src/model_impl.h"
// #endif
#include "include/model.h"
#include "utils/log_adapter.h"
#include "src/ops/ops.h"

namespace mindspore::lite {

class ModelImpl {
public:
static ModelImpl *Import(const char *model_buf, size_t size);
ModelImpl() = default;
explicit ModelImpl(const char *model_buf, size_t size) : model_buf_(model_buf), buf_size_(size) {
meta_graph_ = schema::GetMetaGraph(model_buf);
}
virtual ~ModelImpl();
lite::Primitive *GetOp(const std::string &name) const;
const schema::MetaGraph *meta_graph() const;
void FreeMetaGraph();
int BuildOps();

protected:
lite::Primitive *CopyPrimitive(const schema::Primitive *src_prim);

protected:
const char *model_buf_;
size_t buf_size_;
const schema::MetaGraph *meta_graph_ = nullptr;
std::map<std::string, lite::Primitive *> ops_;
};

ModelImpl *ModelImpl::Import(const char *model_buf, size_t size) {
if (model_buf == nullptr) {
MS_LOG(ERROR) << "The model buf is nullptr";
return nullptr;
}
flatbuffers::Verifier verify((const uint8_t *)model_buf, size);
if (!schema::VerifyMetaGraphBuffer(verify)) {
MS_LOG(ERROR) << "The buffer is invalid and fail to create graph.";
return nullptr;
}
auto *inner_model_buf = new (std::nothrow) char[size];
if (inner_model_buf == nullptr) {
MS_LOG(ERROR) << "new model buf fail.";
return nullptr;
}
memcpy(inner_model_buf, model_buf, size);
auto model = new (std::nothrow) ModelImpl(inner_model_buf, size);
if (model == nullptr) {
MS_LOG(ERROR) << "Create modelImpl failed";
return nullptr;
}
auto ret = model->BuildOps();
if (0 != ret) {
MS_LOG(ERROR) << "BuildOps failed";
return nullptr;
}
return model;
}

lite::Primitive *ModelImpl::GetOp(const std::string &name) const {
auto iter = ops_.find(name);
if (iter == ops_.end()) {
return nullptr;
} else {
return iter->second;
}
}

ModelImpl::~ModelImpl() {
delete[](this->model_buf_);
for (auto iter : ops_) {
delete (iter.second);
}
ops_.clear();
}

void ModelImpl::FreeMetaGraph() {
delete[](this->model_buf_);
model_buf_ = nullptr;
}

const schema::MetaGraph *ModelImpl::meta_graph() const { return this->meta_graph_; }

lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *src_prim) {
MS_EXCEPTION_IF_NULL(src_prim);
auto op_type = src_prim->value_type();
switch (op_type) {
case schema::PrimitiveType_SoftMax:
return new lite::SoftMax(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Activation:
return new lite::Activation(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Conv2D:
return new lite::Conv2D(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_DeConv2D:
return new lite::DeConv2D(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Reduce:
return new lite::Reduce(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Pooling:
return new lite::Pooling(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_DepthwiseConv2D:
return new lite::DepthwiseConv2D(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_FusedBatchNorm:
return new lite::FusedBatchNorm(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_BatchNorm:
return new lite::BatchNorm(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_FullConnection:
return new lite::FullConnection(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Power:
return new lite::Power(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Range:
return new lite::Range(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Mul:
return new lite::Mul(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Add:
return new lite::Add(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Sub:
return new lite::Sub(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Div:
return new lite::Div(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_BiasAdd:
return new lite::BiasAdd(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_ExpandDims:
return new lite::ExpandDims(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_ArgMax:
return new lite::ArgMax(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_ArgMin:
return new lite::ArgMin(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Cast:
return new lite::Cast(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Reshape:
return new lite::Reshape(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Scale:
return new lite::Scale(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Eltwise:
return new lite::Eltwise(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Concat:
return new lite::Concat(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Fill:
return new lite::Fill(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Transpose:
return new lite::Transpose(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Slice:
return new lite::Slice(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Squeeze:
return new lite::Squeeze(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Nchw2Nhwc:
return new lite::Nchw2Nhwc(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Nhwc2Nchw:
return new lite::Nhwc2Nchw(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Flatten:
return new lite::Flatten(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Mean:
return new lite::Mean(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Stack:
return new lite::Stack(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Crop:
return new lite::Crop(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_SquaredDifference:
return new lite::SquaredDifference(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_AddN:
return new lite::AddN(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Abs:
return new lite::Abs(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Sin:
return new lite::Sin(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Cos:
return new lite::Cos(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Log:
return new lite::Log(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Sqrt:
return new lite::Sqrt(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Rsqrt:
return new lite::Rsqrt(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Square:
return new lite::Square(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Exp:
return new lite::Exp(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Gather:
return new lite::Gather(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_GatherNd:
return new lite::GatherNd(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_LocalResponseNormalization:
return new lite::LocalResponseNormalization(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Maximum:
return new lite::Maximum(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Minimum:
return new lite::Minimum(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Pad:
return new lite::Pad(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_StridedSlice:
return new lite::StridedSlice(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Prelu:
return new lite::Prelu(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_CaffePReLU:
return new lite::CaffePReLU(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Round:
return new lite::Round(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Reverse:
return new lite::Reverse(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_ReverseSequence:
return new lite::ReverseSequence(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_LogicalAnd:
return new lite::LogicalAnd(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_LogicalOr:
return new lite::LogicalOr(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_LogicalNot:
return new lite::LogicalNot(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_FloorDiv:
return new lite::FloorDiv(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_FloorMod:
return new lite::FloorMod(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Equal:
return new lite::Equal(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_NotEqual:
return new lite::NotEqual(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Less:
return new lite::Less(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_LessEqual:
return new lite::LessEqual(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Greater:
return new lite::Greater(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_GreaterEqual:
return new lite::GreaterEqual(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Floor:
return new lite::Floor(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Ceil:
return new lite::Ceil(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Split:
return new lite::Split(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_OneHot:
return new lite::OneHot(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_SpaceToDepth:
return new lite::SpaceToDepth(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Tile:
return new lite::Tile(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Resize:
return new lite::Resize(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Unstack:
return new lite::Unstack(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Unique:
return new lite::Unique(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_TopK:
return new lite::TopK(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_MatMul:
return new lite::MatMul(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_QuantDTypeCast:
return new lite::QuantDTypeCast(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_EmbeddingLookup:
return new lite::EmbeddingLookup(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Elu:
return new lite::Elu(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_DeDepthwiseConv2D:
return new lite::DeconvDepthwiseConv2D(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Shape:
return new lite::Shape(const_cast<schema::Primitive *>(src_prim));
default:
break;
}
return nullptr;
}

int ModelImpl::BuildOps() {
if (this->meta_graph_ == nullptr) {
MS_LOG(ERROR) << "mete_graph is nullptr";
return -1;
}
MS_EXCEPTION_IF_NULL(meta_graph_->nodes());
for (size_t i = 0; i < meta_graph_->nodes()->size(); i++) {
auto cNode = meta_graph_->nodes()->GetAs<schema::CNode>(i);
auto name = cNode->name()->str();
auto srcPrim = cNode->primitive();

this->ops_[name] = CopyPrimitive(srcPrim);
// flatbuffers::FlatBufferBuilder fbb(1024);
// schema::Conv2DBuilder conv2DBuilder(fbb);
// conv2DBuilder.add_padMode(srcPrim->value_as_Conv2D()->padMode());
// conv2DBuilder.add_channelOut(srcPrim->value_as_Conv2D()->channelOut());
// conv2DBuilder.add_channelIn(srcPrim->value_as_Conv2D()->channelIn());
// conv2DBuilder.add_strideH(srcPrim->value_as_Conv2D()->strideH());
// conv2DBuilder.add_strideW(srcPrim->value_as_Conv2D()->strideW());
// conv2DBuilder.add_dilateH(srcPrim->value_as_Conv2D()->dilateH());
// conv2DBuilder.add_dilateW(srcPrim->value_as_Conv2D()->dilateW());
// conv2DBuilder.add_kernelH(srcPrim->value_as_Conv2D()->kernelH());
// conv2DBuilder.add_kernelW(srcPrim->value_as_Conv2D()->kernelW());
// conv2DBuilder.add_padUp(srcPrim->value_as_Conv2D()->padUp());
// conv2DBuilder.add_padDown(srcPrim->value_as_Conv2D()->padDown());
// conv2DBuilder.add_padLeft(srcPrim->value_as_Conv2D()->padLeft());
// conv2DBuilder.add_padRight(srcPrim->value_as_Conv2D()->padRight());
// conv2DBuilder.add_format(srcPrim->value_as_Conv2D()->format());
// conv2DBuilder.add_group(srcPrim->value_as_Conv2D()->group());
// conv2DBuilder.add_activationType(srcPrim->value_as_Conv2D()->activationType());
// schema::PrimitiveBuilder primBuilder(fbb);
// primBuilder.add_value_type(srcPrim->value_type());
// primBuilder.add_value(conv2DBuilder.Finish());
//
// fbb.Finish(conv2DBuilder.Finish());
// auto buf = fbb.GetBufferPointer();
// auto conv2D = flatbuffers::GetRoot<schema::Conv2D>(buf);
// fbb.Clear();
//
// return const_cast<mindspore::predict::OpDef *>(opDef);
}
return 0;
}

Model *Model::Import(const char *model_buf, size_t size) {
auto model = new Model();
if (model_buf == nullptr) {
@@ -55,8 +349,4 @@ const schema::MetaGraph *Model::GetMetaGraph() const {
return model_impl_->meta_graph();
}

ModelImpl *Model::model_impl() {
MS_EXCEPTION_IF_NULL(model_impl_);
return this->model_impl_;
}
} // namespace mindspore::lite

+ 0
- 297
mindspore/lite/src/model_impl.cc View File

@@ -1,297 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <memory>
#include <string>
#include "src/model_impl.h"
#include "utils/log_adapter.h"

namespace mindspore::lite {
ModelImpl *ModelImpl::Import(const char *model_buf, size_t size) {
if (model_buf == nullptr) {
MS_LOG(ERROR) << "The model buf is nullptr";
return nullptr;
}
flatbuffers::Verifier verify((const uint8_t *)model_buf, size);
if (!schema::VerifyMetaGraphBuffer(verify)) {
MS_LOG(ERROR) << "The buffer is invalid and fail to create graph.";
return nullptr;
}
auto *inner_model_buf = new (std::nothrow) char[size];
if (inner_model_buf == nullptr) {
MS_LOG(ERROR) << "new model buf fail.";
return nullptr;
}
memcpy(inner_model_buf, model_buf, size);
auto model = new (std::nothrow) ModelImpl(inner_model_buf, size);
if (model == nullptr) {
MS_LOG(ERROR) << "Create modelImpl failed";
return nullptr;
}
auto ret = model->BuildOps();
if (0 != ret) {
MS_LOG(ERROR) << "BuildOps failed";
return nullptr;
}
return model;
}

lite::Primitive *ModelImpl::GetOp(const std::string &name) const {
auto iter = ops_.find(name);
if (iter == ops_.end()) {
return nullptr;
} else {
return iter->second;
}
}

ModelImpl::~ModelImpl() {
delete[](this->model_buf_);
for (auto iter : ops_) {
delete (iter.second);
}
ops_.clear();
}

void ModelImpl::FreeMetaGraph() {
delete[](this->model_buf_);
model_buf_ = nullptr;
}

const schema::MetaGraph *ModelImpl::meta_graph() const { return this->meta_graph_; }

lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *src_prim) {
  // Wrap a flatbuffer-schema primitive in the matching lite::Primitive
  // subclass. Returns a heap-allocated primitive owned by the caller
  // (stored into ops_ by BuildOps), or nullptr for an unsupported type.
  MS_EXCEPTION_IF_NULL(src_prim);
  auto op_type = src_prim->value_type();
  // The lite wrappers take a mutable pointer; hoist the cast out of every case.
  auto *prim = const_cast<schema::Primitive *>(src_prim);
  switch (op_type) {
    case schema::PrimitiveType_SoftMax: return new lite::SoftMax(prim);
    case schema::PrimitiveType_Activation: return new lite::Activation(prim);
    case schema::PrimitiveType_Conv2D: return new lite::Conv2D(prim);
    case schema::PrimitiveType_DeConv2D: return new lite::DeConv2D(prim);
    case schema::PrimitiveType_Reduce: return new lite::Reduce(prim);
    case schema::PrimitiveType_Pooling: return new lite::Pooling(prim);
    case schema::PrimitiveType_DepthwiseConv2D: return new lite::DepthwiseConv2D(prim);
    case schema::PrimitiveType_FusedBatchNorm: return new lite::FusedBatchNorm(prim);
    case schema::PrimitiveType_BatchNorm: return new lite::BatchNorm(prim);
    case schema::PrimitiveType_FullConnection: return new lite::FullConnection(prim);
    case schema::PrimitiveType_Power: return new lite::Power(prim);
    case schema::PrimitiveType_Range: return new lite::Range(prim);
    case schema::PrimitiveType_Mul: return new lite::Mul(prim);
    case schema::PrimitiveType_Add: return new lite::Add(prim);
    case schema::PrimitiveType_Sub: return new lite::Sub(prim);
    case schema::PrimitiveType_Div: return new lite::Div(prim);
    case schema::PrimitiveType_BiasAdd: return new lite::BiasAdd(prim);
    case schema::PrimitiveType_ExpandDims: return new lite::ExpandDims(prim);
    case schema::PrimitiveType_ArgMax: return new lite::ArgMax(prim);
    case schema::PrimitiveType_ArgMin: return new lite::ArgMin(prim);
    case schema::PrimitiveType_Cast: return new lite::Cast(prim);
    case schema::PrimitiveType_Reshape: return new lite::Reshape(prim);
    case schema::PrimitiveType_Scale: return new lite::Scale(prim);
    case schema::PrimitiveType_Eltwise: return new lite::Eltwise(prim);
    case schema::PrimitiveType_Concat: return new lite::Concat(prim);
    case schema::PrimitiveType_Fill: return new lite::Fill(prim);
    case schema::PrimitiveType_Transpose: return new lite::Transpose(prim);
    case schema::PrimitiveType_Slice: return new lite::Slice(prim);
    case schema::PrimitiveType_Squeeze: return new lite::Squeeze(prim);
    case schema::PrimitiveType_Nchw2Nhwc: return new lite::Nchw2Nhwc(prim);
    case schema::PrimitiveType_Nhwc2Nchw: return new lite::Nhwc2Nchw(prim);
    case schema::PrimitiveType_Flatten: return new lite::Flatten(prim);
    case schema::PrimitiveType_Mean: return new lite::Mean(prim);
    case schema::PrimitiveType_Stack: return new lite::Stack(prim);
    case schema::PrimitiveType_Crop: return new lite::Crop(prim);
    case schema::PrimitiveType_SquaredDifference: return new lite::SquaredDifference(prim);
    case schema::PrimitiveType_AddN: return new lite::AddN(prim);
    case schema::PrimitiveType_Abs: return new lite::Abs(prim);
    case schema::PrimitiveType_Sin: return new lite::Sin(prim);
    case schema::PrimitiveType_Cos: return new lite::Cos(prim);
    case schema::PrimitiveType_Log: return new lite::Log(prim);
    case schema::PrimitiveType_Sqrt: return new lite::Sqrt(prim);
    case schema::PrimitiveType_Rsqrt: return new lite::Rsqrt(prim);
    case schema::PrimitiveType_Square: return new lite::Square(prim);
    case schema::PrimitiveType_Exp: return new lite::Exp(prim);
    case schema::PrimitiveType_Gather: return new lite::Gather(prim);
    case schema::PrimitiveType_GatherNd: return new lite::GatherNd(prim);
    case schema::PrimitiveType_LocalResponseNormalization: return new lite::LocalResponseNormalization(prim);
    case schema::PrimitiveType_Maximum: return new lite::Maximum(prim);
    case schema::PrimitiveType_Minimum: return new lite::Minimum(prim);
    case schema::PrimitiveType_Pad: return new lite::Pad(prim);
    case schema::PrimitiveType_StridedSlice: return new lite::StridedSlice(prim);
    case schema::PrimitiveType_Prelu: return new lite::Prelu(prim);
    case schema::PrimitiveType_CaffePReLU: return new lite::CaffePReLU(prim);
    case schema::PrimitiveType_Round: return new lite::Round(prim);
    case schema::PrimitiveType_Reverse: return new lite::Reverse(prim);
    case schema::PrimitiveType_ReverseSequence: return new lite::ReverseSequence(prim);
    case schema::PrimitiveType_LogicalAnd: return new lite::LogicalAnd(prim);
    case schema::PrimitiveType_LogicalOr: return new lite::LogicalOr(prim);
    case schema::PrimitiveType_LogicalNot: return new lite::LogicalNot(prim);
    case schema::PrimitiveType_FloorDiv: return new lite::FloorDiv(prim);
    case schema::PrimitiveType_FloorMod: return new lite::FloorMod(prim);
    case schema::PrimitiveType_Equal: return new lite::Equal(prim);
    case schema::PrimitiveType_NotEqual: return new lite::NotEqual(prim);
    case schema::PrimitiveType_Less: return new lite::Less(prim);
    case schema::PrimitiveType_LessEqual: return new lite::LessEqual(prim);
    case schema::PrimitiveType_Greater: return new lite::Greater(prim);
    case schema::PrimitiveType_GreaterEqual: return new lite::GreaterEqual(prim);
    case schema::PrimitiveType_Floor: return new lite::Floor(prim);
    case schema::PrimitiveType_Ceil: return new lite::Ceil(prim);
    case schema::PrimitiveType_Split: return new lite::Split(prim);
    case schema::PrimitiveType_OneHot: return new lite::OneHot(prim);
    case schema::PrimitiveType_SpaceToDepth: return new lite::SpaceToDepth(prim);
    case schema::PrimitiveType_Tile: return new lite::Tile(prim);
    case schema::PrimitiveType_Resize: return new lite::Resize(prim);
    case schema::PrimitiveType_Unstack: return new lite::Unstack(prim);
    case schema::PrimitiveType_Unique: return new lite::Unique(prim);
    case schema::PrimitiveType_TopK: return new lite::TopK(prim);
    case schema::PrimitiveType_MatMul: return new lite::MatMul(prim);
    case schema::PrimitiveType_QuantDTypeCast: return new lite::QuantDTypeCast(prim);
    case schema::PrimitiveType_EmbeddingLookup: return new lite::EmbeddingLookup(prim);
    case schema::PrimitiveType_Elu: return new lite::Elu(prim);
    // Note the name mismatch: schema DeDepthwiseConv2D maps to lite::DeconvDepthwiseConv2D.
    case schema::PrimitiveType_DeDepthwiseConv2D: return new lite::DeconvDepthwiseConv2D(prim);
    case schema::PrimitiveType_Shape: return new lite::Shape(prim);
    default:
      // fix: an unsupported type used to fall through silently; log it so
      // the failing op is identifiable from the error output.
      MS_LOG(ERROR) << "Unsupported primitive type: " << schema::EnumNamePrimitiveType(op_type);
      break;
  }
  return nullptr;
}

int ModelImpl::BuildOps() {
if (this->meta_graph_ == nullptr) {
MS_LOG(ERROR) << "mete_graph is nullptr";
return -1;
}
MS_EXCEPTION_IF_NULL(meta_graph_->nodes());
for (size_t i = 0; i < meta_graph_->nodes()->size(); i++) {
auto cNode = meta_graph_->nodes()->GetAs<schema::CNode>(i);
auto name = cNode->name()->str();
auto srcPrim = cNode->primitive();

this->ops_[name] = CopyPrimitive(srcPrim);
// flatbuffers::FlatBufferBuilder fbb(1024);
// schema::Conv2DBuilder conv2DBuilder(fbb);
// conv2DBuilder.add_padMode(srcPrim->value_as_Conv2D()->padMode());
// conv2DBuilder.add_channelOut(srcPrim->value_as_Conv2D()->channelOut());
// conv2DBuilder.add_channelIn(srcPrim->value_as_Conv2D()->channelIn());
// conv2DBuilder.add_strideH(srcPrim->value_as_Conv2D()->strideH());
// conv2DBuilder.add_strideW(srcPrim->value_as_Conv2D()->strideW());
// conv2DBuilder.add_dilateH(srcPrim->value_as_Conv2D()->dilateH());
// conv2DBuilder.add_dilateW(srcPrim->value_as_Conv2D()->dilateW());
// conv2DBuilder.add_kernelH(srcPrim->value_as_Conv2D()->kernelH());
// conv2DBuilder.add_kernelW(srcPrim->value_as_Conv2D()->kernelW());
// conv2DBuilder.add_padUp(srcPrim->value_as_Conv2D()->padUp());
// conv2DBuilder.add_padDown(srcPrim->value_as_Conv2D()->padDown());
// conv2DBuilder.add_padLeft(srcPrim->value_as_Conv2D()->padLeft());
// conv2DBuilder.add_padRight(srcPrim->value_as_Conv2D()->padRight());
// conv2DBuilder.add_format(srcPrim->value_as_Conv2D()->format());
// conv2DBuilder.add_group(srcPrim->value_as_Conv2D()->group());
// conv2DBuilder.add_activationType(srcPrim->value_as_Conv2D()->activationType());
// schema::PrimitiveBuilder primBuilder(fbb);
// primBuilder.add_value_type(srcPrim->value_type());
// primBuilder.add_value(conv2DBuilder.Finish());
//
// fbb.Finish(conv2DBuilder.Finish());
// auto buf = fbb.GetBufferPointer();
// auto conv2D = flatbuffers::GetRoot<schema::Conv2D>(buf);
// fbb.Clear();
//
// return const_cast<mindspore::predict::OpDef *>(opDef);
}
return 0;
}
} // namespace mindspore::lite

+ 0
- 53
mindspore/lite/src/model_impl.h View File

@@ -1,53 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_SRC_MODEL_IMPL_H_
#define MINDSPORE_LITE_SRC_MODEL_IMPL_H_

#include <map>
#include <memory>
#include <string>
#include "schema/model_generated.h"
#include "src/ops/ops.h"

namespace mindspore {
namespace lite {
class ModelImpl {
public:
static ModelImpl *Import(const char *model_buf, size_t size);
ModelImpl() = default;
explicit ModelImpl(const char *model_buf, size_t size) : model_buf_(model_buf), buf_size_(size) {
meta_graph_ = schema::GetMetaGraph(model_buf);
}
virtual ~ModelImpl();
lite::Primitive *GetOp(const std::string &name) const;
const schema::MetaGraph *meta_graph() const;
void FreeMetaGraph();
int BuildOps();

protected:
lite::Primitive *CopyPrimitive(const schema::Primitive *src_prim);

protected:
const char *model_buf_;
size_t buf_size_;
const schema::MetaGraph *meta_graph_ = nullptr;
std::map<std::string, lite::Primitive *> ops_;
};
} // namespace lite
} // namespace mindspore

#endif // MINDSPORE_LITE_INCLUDE_MODEL_H_

+ 1
- 1
mindspore/lite/src/runtime/kernel/arm/base/arg_min_max_base.cc View File

@@ -19,7 +19,7 @@
#include "src/runtime/kernel/arm/int8/argminmax_int8.h"
#include "src/runtime/kernel/arm/nnacl/arithmetic_common.h"
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "include/context.h"



+ 1
- 1
mindspore/lite/src/runtime/kernel/arm/base/batch_to_space_base.cc View File

@@ -18,7 +18,7 @@
#include "src/runtime/kernel/arm/fp32/batch_to_space.h"
#include "src/runtime/kernel/arm/int8/batch_to_space_int8.h"
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "include/context.h"



+ 1
- 1
mindspore/lite/src/runtime/kernel/arm/base/caffeprelu_base.cc View File

@@ -16,7 +16,7 @@
#include "src/runtime/kernel/arm/base/caffeprelu_base.h"
#include <vector>
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "include/context.h"



+ 1
- 1
mindspore/lite/src/runtime/kernel/arm/base/concat_base.cc View File

@@ -19,7 +19,7 @@
#include "src/runtime/kernel/arm/fp32/concat.h"
#include "src/runtime/kernel/arm/nnacl/fp32/concat.h"
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "include/context.h"



+ 1
- 1
mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc View File

@@ -17,7 +17,7 @@
#include "src/runtime/kernel/arm/base/convolution_base.h"
#include <float.h>
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"

using mindspore::lite::KernelRegistrar;


+ 1
- 1
mindspore/lite/src/runtime/kernel/arm/base/crop_base.cc View File

@@ -18,7 +18,7 @@
#include "src/runtime/kernel/arm/int8/crop_int8.h"
#include "src/runtime/kernel/arm/fp32/crop.h"
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "include/context.h"



+ 1
- 1
mindspore/lite/src/runtime/kernel/arm/base/depth_to_space_base.cc View File

@@ -19,7 +19,7 @@
#include "src/runtime/kernel/arm/int8/depth_to_space_int8.h"
#include "src/runtime/kernel/arm/nnacl/arithmetic_common.h"
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "include/context.h"



+ 1
- 1
mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc View File

@@ -17,7 +17,7 @@
#include "src/runtime/kernel/arm/int8/fullconnection_int8.h"
#include "src/runtime/kernel/arm/fp32/fullconnection.h"
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "include/context.h"



+ 1
- 1
mindspore/lite/src/runtime/kernel/arm/base/matmul_base.cc View File

@@ -16,7 +16,7 @@
#include "src/runtime/kernel/arm/base/matmul_base.h"
#include "src/runtime/kernel/arm/fp32/matmul.h"
#include "src/runtime/kernel/arm/int8/matmul_int8.h"
#include "src/kernel_factory.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "include/context.h"



+ 1
- 1
mindspore/lite/src/runtime/kernel/arm/base/pad.cc View File

@@ -17,7 +17,7 @@
#include "src/runtime/kernel/arm/fp32/pad.h"
#include "src/runtime/kernel/arm/int8/pad_int8.h"
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "include/context.h"



+ 1
- 1
mindspore/lite/src/runtime/kernel/arm/base/pooling_base.cc View File

@@ -18,7 +18,7 @@
#include "src/runtime/kernel/arm/int8/pooling_int8.h"
#include "src/runtime/kernel/arm/fp32/pooling.h"
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "include/context.h"



+ 1
- 1
mindspore/lite/src/runtime/kernel/arm/base/prelu_base.cc View File

@@ -17,7 +17,7 @@
#include <vector>
#include "src/runtime/kernel/arm/int8/prelu_int8.h"
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "include/context.h"



+ 1
- 1
mindspore/lite/src/runtime/kernel/arm/base/prior_box.cc View File

@@ -18,7 +18,7 @@
#include <cmath>
#include "src/runtime/kernel/arm/base/prior_box.h"
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "include/context.h"
#include "src/runtime/runtime_api.h"


+ 1
- 1
mindspore/lite/src/runtime/kernel/arm/base/reshape_base.cc View File

@@ -18,7 +18,7 @@
#include "src/runtime/kernel/arm/int8/reshape_int8.h"
#include "src/runtime/kernel/arm/fp32/reshape.h"
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "include/context.h"



+ 1
- 1
mindspore/lite/src/runtime/kernel/arm/base/softmax_base.cc View File

@@ -20,7 +20,7 @@
#include "src/runtime/kernel/arm/fp32/softmax.h"
#include "src/runtime/kernel/arm/nnacl/fp32/softmax.h"
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"

using mindspore::lite::KernelRegistrar;


+ 1
- 1
mindspore/lite/src/runtime/kernel/arm/base/split_base.cc View File

@@ -18,7 +18,7 @@
#include "src/runtime/kernel/arm/int8/split_int8.h"
#include "src/runtime/kernel/arm/fp32/split.h"
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "include/context.h"



+ 1
- 1
mindspore/lite/src/runtime/kernel/arm/base/squeeze_base.cc View File

@@ -17,7 +17,7 @@
#include <vector>
#include "src/runtime/kernel/arm/int8/squeeze_int8.h"
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "include/context.h"



+ 1
- 1
mindspore/lite/src/runtime/kernel/arm/fp16/convolution_base_fp16.cc View File

@@ -18,7 +18,7 @@
#include "src/runtime/kernel/arm/nnacl/fp16/cast_fp16.h"
#include "src/runtime/kernel/arm/fp16/common_fp16.h"
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"



+ 1
- 1
mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc View File

@@ -22,7 +22,7 @@
#include "src/runtime/kernel/arm/nnacl/fp32/conv.h"
#include "src/runtime/kernel/arm/nnacl/common_func.h"
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"



+ 1
- 1
mindspore/lite/src/runtime/kernel/arm/fp32/convolution_slidewindow.cc View File

@@ -17,7 +17,7 @@
#include "src/runtime/kernel/arm/fp32/convolution_slidewindow.h"
#include "src/runtime/kernel/arm/nnacl/common_func.h"
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"



+ 1
- 1
mindspore/lite/src/runtime/kernel/arm/fp32/transpose.h View File

@@ -20,7 +20,7 @@
#include <vector>
#include "src/lite_kernel.h"

#include "src/kernel_factory.h"
#include "src/kernel_registry.h"

namespace mindspore::kernel {



+ 1
- 1
mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.cc View File

@@ -17,7 +17,7 @@
#include <algorithm>
#include <vector>
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "src/kernel_registry.h"
#include "src/runtime/kernel/arm/fp32_grad/bn_grad.h"
#include "src/runtime/kernel/arm/nnacl/fp32_grad/batch_norm.h"
#include "include/errorcode.h"


+ 4
- 4
mindspore/lite/src/scheduler.cc View File

@@ -18,7 +18,7 @@
#include <vector>
#include <algorithm>
#include "include/errorcode.h"
#include "src/kernel_factory.h"
#include "src/kernel_registry.h"
#include "src/common/graph_util.h"
#include "src/common/utils.h"
#if SUPPORT_GPU
@@ -191,7 +191,7 @@ kernel::LiteKernel *Scheduler::ScheduleNode(const std::vector<tensor::Tensor *>
kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, primitive->Type()};
if (context_->device_ctx_.type == DT_GPU) {
desc.arch = kernel::KERNEL_ARCH::kGPU;
auto *kernel = KernelFactory::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc);
auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc);
if (nullptr != kernel) {
kernel->set_desc(desc);
return kernel;
@@ -203,7 +203,7 @@ kernel::LiteKernel *Scheduler::ScheduleNode(const std::vector<tensor::Tensor *>
if ((context_->float16_priority && data_type == kNumberTypeFloat32) || data_type == kNumberTypeFloat16) {
// check if support fp16
kernel::KernelKey key{desc.arch, kNumberTypeFloat16, desc.type};
kernel = KernelFactory::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, key);
kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, key);
if (kernel != nullptr) {
MS_LOG(DEBUG) << "Get fp16 op success.";
desc.data_type = kNumberTypeFloat16;
@@ -215,7 +215,7 @@ kernel::LiteKernel *Scheduler::ScheduleNode(const std::vector<tensor::Tensor *>
if (data_type == kNumberTypeFloat16) {
desc.data_type = kNumberTypeFloat32;
}
kernel = KernelFactory::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc);
kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc);
if (kernel != nullptr) {
kernel->set_desc(desc);
return kernel;


+ 0
- 59
mindspore/lite/src/train/base_ref_utils.cc View File

@@ -1,59 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/train/base_ref_utils.h"
#include <vector>
#include <memory>
// #include "utils/base_ref_utils.h"
#include "include/ms_tensor.h"
#include "src/ir/tensor.h"

namespace mindspore {
std::vector<std::shared_ptr<tensor::MSTensor>> TransformBaseRefToMSTensor(const BaseRef &base_ref) {
std::vector<std::shared_ptr<tensor::MSTensor>> msTensors;
if (utils::isa<VectorRef>(base_ref)) {
auto ref_list = utils::cast<VectorRef>(base_ref);
for (size_t i = 0; i < ref_list.size(); ++i) {
if (utils::isa<tensor::Tensor>(ref_list[i])) {
auto tensor_ptr = utils::cast<std::shared_ptr<tensor::Tensor>>(ref_list[i]);
MS_EXCEPTION_IF_NULL(tensor_ptr);
auto tensor = new tensor::LiteTensor(new tensor::Tensor(*tensor_ptr));
msTensors.emplace_back(std::shared_ptr<tensor::MSTensor>(tensor));
} else {
MS_LOG(EXCEPTION) << "The output is not a tensor!";
}
}
} else if (utils::isa<tensor::Tensor>(base_ref)) {
auto tensor_ptr = utils::cast<std::shared_ptr<tensor::Tensor>>(base_ref);
MS_EXCEPTION_IF_NULL(tensor_ptr);
auto tensor = new tensor::LiteTensor(new tensor::Tensor(*tensor_ptr));
msTensors.emplace_back(std::shared_ptr<tensor::MSTensor>(tensor));
} else {
MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
}
return msTensors;
}

std::vector<std::vector<std::shared_ptr<tensor::MSTensor>>> TransformVectorRefToMultiTensor(
const VectorRef &vector_ref) {
std::vector<std::vector<std::shared_ptr<tensor::MSTensor>>> multiTensor;
for (size_t i = 0; i < vector_ref.size(); ++i) {
auto tensors = TransformBaseRefToMSTensor(vector_ref[i]);
multiTensor.emplace_back(tensors);
}
return multiTensor;
}
} // namespace mindspore

+ 0
- 30
mindspore/lite/src/train/base_ref_utils.h View File

@@ -1,30 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <vector>
#include <memory>
#include "utils/base_ref.h"
#include "include/ms_tensor.h"

#ifndef MINDSPORE_LITE_SRC_TRAIN_BASE_REF_UTILS_H_
#define MINDSPORE_LITE_SRC_TRAIN_BASE_REF_UTILS_H_
namespace mindspore {
std::vector<std::shared_ptr<tensor::MSTensor>> TransformBaseRefToMSTensor(const BaseRef &base_ref);

std::vector<std::vector<std::shared_ptr<tensor::MSTensor>>> TransformVectorRefToMultiTensor(
const VectorRef &vector_ref);
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_TRAIN_BASE_REF_UTILS_H_

+ 0
- 50
mindspore/lite/src/train/import.hpp View File

@@ -1,50 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include<memory>
#include "src/common/anf_importer/import_from_meta_graph.h"
namespace mindspore::lite::train {
std::shared_ptr<ModelImpl> Import(const char *model_buf, size_t size) {
MS_EXCEPTION_IF_NULL(model_buf);
flatbuffers::Verifier verify((const uint8_t *) model_buf, size);
if (!schema::VerifyMetaGraphBuffer(verify)) {
MS_LOG(ERROR) << "The buffer is invalid and fail to create graph.";
return nullptr;
}
// todo hangangqiang remove when copy primitive done
if (size <= 0) {
MS_LOG(ERROR) << "size is zero";
return nullptr;
}
auto *inner_buf = new char[size];
memcpy(inner_buf, model_buf, size);
auto meta_graph = schema::GetMetaGraph(inner_buf);
auto model = std::make_shared<ModelImpl>(meta_graph);
auto ret = model->BuildOps();
if (0 != ret) {
MS_LOG(ERROR) << "BuildOps failed";
return nullptr;
}
MS_EXCEPTION_IF_NULL(meta_graph);
auto importer = new AnfImporterFromMetaGraph(model);
auto ret2 = importer->Import();
if (0 != ret2) {
MS_LOG(ERROR) << "Import anf_graph from meta_graph failed, ret2: " << ret2;
return nullptr;
}
return model;
}
} // namespace mindspore::lite::train

+ 0
- 86
mindspore/lite/src/train/lite_kernel_runtime.cc View File

@@ -1,86 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/train/lite_kernel_runtime.h"
#include "backend/session/anf_runtime_algorithm.h"
namespace mindspore::lite {
std::vector<CNodePtr> LiteInferKernelRuntime::GetGraphInputs(const std::vector<CNodePtr> &execution_order) {
std::vector<CNodePtr> graph_inputs;
for (const auto &cnode : execution_order) {
bool is_graph_inputs = true;
for (const auto &input : cnode->inputs()) {
if (input->isa<CNode>()) {
is_graph_inputs = false;
break;
}
}
if (is_graph_inputs) {
graph_inputs.emplace_back(cnode);
}
}
return graph_inputs;
}

void LiteInferKernelRuntime::BindInputOutput(const session::KernelGraph *graph,
const std::vector<tensor::Tensor *> &inputs,
std::vector<tensor::Tensor *> *outputs) {
MS_EXCEPTION_IF_NULL(graph);
auto execution_order = graph->execution_order();
auto graph_inputs = GetGraphInputs(execution_order);
int input_count = 0;
for (const auto &graph_input : graph_inputs) {
auto liteKernel = dynamic_cast<kernel::LiteKernel *>(AnfAlgo::GetKernelMod(graph_input));
for (auto input_tensor : liteKernel->GetInputs()) {
if (schema::NodeType_ValueNode == input_tensor->TensorType() && input_tensor->Data() != nullptr) {
continue;
}
input_tensor->SetData(inputs[input_count]->Data());
input_count++;
}
}

auto return_node = graph->get_return();
for (const auto &return_input : return_node->inputs()) {
if (return_input->isa<CNode>()) {
auto liteKernel = dynamic_cast<kernel::LiteKernel *>(AnfAlgo::GetKernelMod(return_input));
auto output_tensors = liteKernel->GetOutputs();
for (auto output_tensor : output_tensors) {
// tensor::TensorPtr output_tensor_ptr(output_tensor);
outputs->push_back(output_tensor);
}
}
}
}

bool LiteInferKernelRuntime::Run(session::KernelGraph *graph, const std::vector<tensor::Tensor *> &inputs,
std::vector<tensor::Tensor *> *outputs) {
MS_EXCEPTION_IF_NULL(graph);
BindInputOutput(graph, inputs, *outputs);
std::vector<kernel::LiteKernel *> kernels;
auto nodes = graph->execution_order();
for (const auto &node : nodes) {
auto liteKernel = dynamic_cast<kernel::LiteKernel *>(AnfAlgo::GetKernelMod(node));
if (liteKernel == nullptr) {
continue;
}
kernels.emplace_back(liteKernel);
}
kernel::LiteKernelUtil::TopologicalSortKernels(kernels);
Executor executor;
auto ret = executor.Run(inputs, *outputs, kernels);
return 0 == ret;
}
} // namespace mindspore::lite

+ 0
- 50
mindspore/lite/src/train/lite_kernel_runtime.h View File

@@ -1,50 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_SRC_TRAIN_LITE_KERNEL_RUNTIME_H_
#define MINDSPORE_LITE_SRC_TRAIN_LITE_KERNEL_RUNTIME_H_

#include <memory>
#include <vector>
#include <string>
#include <unordered_map>
#include "src/runtime/allocator.h"
#include "src/executor.h"
// #include "runtime/device/kernel_runtime.h"
#include "runtime/device/device_address.h"
#include "src/lite_kernel.h"
#include "backend/session/kernel_graph.h"
namespace mindspore::lite {
// Minimal kernel runtime used by the lite training path: binds graph
// inputs/outputs to lite tensors and runs the kernels through the Executor.
class LiteInferKernelRuntime {
 public:
  LiteInferKernelRuntime() = default;
  ~LiteInferKernelRuntime() = default;

  // Executes the kernels of `graph` with the given input tensors and fills
  // `outputs`. Returns true on success.
  bool Run(session::KernelGraph *graph, const std::vector<tensor::Tensor *> &inputs,
           std::vector<tensor::Tensor *> *outputs);

  // Intentionally a no-op: this runtime does not pre-assign kernel addresses.
  void AssignKernelAddress(session::KernelGraph *graph) {}

 protected:
  // Binds input tensors to the graph and collects output tensors.
  // NOTE(review): the .cc appears to pass a dereferenced vector as the third
  // argument of this call — confirm the definition matches this declaration.
  void BindInputOutput(const session::KernelGraph *graph, const std::vector<tensor::Tensor *> &inputs,
                       std::vector<tensor::Tensor *> *outputs);

  // Returns the CNodes acting as graph inputs within the execution order.
  std::vector<CNodePtr> GetGraphInputs(const std::vector<CNodePtr> &execution_order);
};

} // namespace mindspore::lite

#endif // MINDSPORE_LITE_SRC_TRAIN_LITE_KERNEL_RUNTIME_H_

+ 0
- 145
mindspore/lite/src/train/model_impl.cc View File

@@ -1,145 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <string>
#include "src/train/model_impl.h"
#include "ir/func_graph.h"
#include "schema/model_generated.h"
#include "src/common/anf_importer/import_from_meta_graph.h"

namespace mindspore::lite::train {

// Verifies `model_buf` as a flatbuffer MetaGraph, copies it into an internal
// buffer, builds the lite primitives and imports the ANF graph.
// Returns nullptr (with an error log) on verification or build failure.
// NOTE(review): on success `inner_buf` intentionally outlives this call — the
// meta graph points into it; ownership tracking is still a TODO below.
std::shared_ptr<ModelImpl> ModelImpl::Import(const char *model_buf, size_t size) {
  MS_EXCEPTION_IF_NULL(model_buf);
  flatbuffers::Verifier verify((const uint8_t *)model_buf, size);
  if (!schema::VerifyMetaGraphBuffer(verify)) {
    MS_LOG(ERROR) << "The buffer is invalid and fail to create graph.";
    return nullptr;
  }
  // todo hangangqiang remove when copy primitive done
  auto *inner_buf = new char[size];
  memcpy(inner_buf, model_buf, size);
  auto meta_graph = schema::GetMetaGraph(inner_buf);
  auto func_graph_model = std::make_shared<ModelImpl>(meta_graph);
  auto ret = func_graph_model->BuildOps();
  if (0 != ret) {
    MS_LOG(ERROR) << "BuildOps failed";
    delete[] inner_buf;  // fix: do not leak the copied buffer on failure
    return nullptr;
  }
  AnfImporterFromMetaGraph anfImporter(func_graph_model);
  anfImporter.Import();
  return func_graph_model;
}

// Looks up the primitive built for the node with the given name.
// Returns nullptr when no such op was registered by BuildOps().
const lite::Primitive *ModelImpl::GetOp(const std::string &name) const {
  auto found = ops.find(name);
  return (found == ops.end()) ? nullptr : found->second;
}

// Releases the meta graph owned by this model and resets the pointer so a
// second call (or a later GetMetaGraph()) does not touch freed memory.
// NOTE(review): meta_graph obtained via schema::GetMetaGraph points into the
// buffer copied in Import(); deleting it directly looks suspect — confirm
// the intended ownership before relying on this.
void ModelImpl::FreeMetaGraph() {
  delete this->meta_graph;
  this->meta_graph = nullptr;  // fix: avoid dangling pointer / double free
}

// Accessor for the raw flatbuffer meta graph backing this model
// (may be nullptr after FreeMetaGraph()).
const schema::MetaGraph *ModelImpl::GetMetaGraph() const {
  return this->meta_graph;
}

// Factory: wraps a flatbuffer schema::Primitive in the matching lite
// primitive wrapper class. The caller owns the returned raw pointer.
// Returns nullptr for primitive types that have no lite wrapper yet.
// NOTE(review): the const_cast is required because the wrappers take a
// mutable schema::Primitive*; confirm none of them mutate the flatbuffer.
lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *srcPrim) {
  MS_EXCEPTION_IF_NULL(srcPrim);
  auto op_type = srcPrim->value_type();
  switch (op_type) {
    case schema::PrimitiveType_SoftMax:
      return new lite::SoftMax(const_cast<schema::Primitive *>(srcPrim));
    case schema::PrimitiveType_Activation:
      return new lite::Activation(const_cast<schema::Primitive *>(srcPrim));
    case schema::PrimitiveType_Conv2D:
      return new lite::Conv2D(const_cast<schema::Primitive *>(srcPrim));
    case schema::PrimitiveType_Reduce:
      return new lite::Reduce(const_cast<schema::Primitive *>(srcPrim));
    case schema::PrimitiveType_Pooling:
      return new lite::Pooling(const_cast<schema::Primitive *>(srcPrim));
    case schema::PrimitiveType_DepthwiseConv2D:
      return new lite::DepthwiseConv2D(const_cast<schema::Primitive *>(srcPrim));
    case schema::PrimitiveType_FusedBatchNorm:
      return new lite::FusedBatchNorm(const_cast<schema::Primitive *>(srcPrim));
    case schema::PrimitiveType_CaffeBatchNorm:
      return new lite::CaffeBatchNorm(const_cast<schema::Primitive *>(srcPrim));
    case schema::PrimitiveType_FullConnection:
      return new lite::FullConnection(const_cast<schema::Primitive *>(srcPrim));
    case schema::PrimitiveType_Power:
      return new lite::Power(const_cast<schema::Primitive *>(srcPrim));
    case schema::PrimitiveType_Range:
      return new lite::Range(const_cast<schema::Primitive *>(srcPrim));
    case schema::PrimitiveType_Mul:
      return new lite::Mul(const_cast<schema::Primitive *>(srcPrim));
    case schema::PrimitiveType_Add:
      return new lite::Add(const_cast<schema::Primitive *>(srcPrim));
    case schema::PrimitiveType_Sub:
      return new lite::Sub(const_cast<schema::Primitive *>(srcPrim));
    case schema::PrimitiveType_Div:
      return new lite::Div(const_cast<schema::Primitive *>(srcPrim));
    case schema::PrimitiveType_BiasAdd:
      return new lite::BiasAdd(const_cast<schema::Primitive *>(srcPrim));
    case schema::PrimitiveType_ExpandDims:
      return new lite::ExpandDims(const_cast<schema::Primitive *>(srcPrim));
    case schema::PrimitiveType_ArgMax:
      return new lite::ArgMax(const_cast<schema::Primitive *>(srcPrim));
    case schema::PrimitiveType_ArgMin:
      return new lite::ArgMin(const_cast<schema::Primitive *>(srcPrim));
    case schema::PrimitiveType_Cast:
      return new lite::Cast(const_cast<schema::Primitive *>(srcPrim));
    case schema::PrimitiveType_Reshape:
      return new lite::Reshape(const_cast<schema::Primitive *>(srcPrim));
    case schema::PrimitiveType_Scale:
      return new lite::Scale(const_cast<schema::Primitive *>(srcPrim));
    case schema::PrimitiveType_Eltwise:
      return new lite::Eltwise(const_cast<schema::Primitive *>(srcPrim));
    case schema::PrimitiveType_Ceil:
      return new lite::Ceil(const_cast<schema::Primitive *>(srcPrim));
    case schema::PrimitiveType_Concat:
      return new lite::Concat(const_cast<schema::Primitive *>(srcPrim));
    case schema::PrimitiveType_Fill:
      return new lite::Fill(const_cast<schema::Primitive *>(srcPrim));
    case schema::PrimitiveType_Transpose:
      return new lite::Transpose(const_cast<schema::Primitive *>(srcPrim));
    case schema::PrimitiveType_Slice:
      return new lite::Slice(const_cast<schema::Primitive *>(srcPrim));
    case schema::PrimitiveType_Nchw2Nhwc:
      return new lite::Nchw2Nhwc(const_cast<schema::Primitive *>(srcPrim));
    case schema::PrimitiveType_Nhwc2Nchw:
      return new lite::Nhwc2Nchw(const_cast<schema::Primitive *>(srcPrim));
    case schema::PrimitiveType_MatMul:
      return new lite::MatMul(const_cast<schema::Primitive *>(srcPrim));
    default:
      // Unsupported primitive type: caller must handle the nullptr result.
      break;
  }
  return nullptr;
}

int ModelImpl::BuildOps() {
if (this->meta_graph == nullptr) {
MS_LOG(ERROR) << "mete_graph is nullptr";
return -1;
}
for (size_t i = 0; i < meta_graph->nodes()->size(); i++) {
auto cNode = meta_graph->nodes()->GetAs<schema::CNode>(i);
auto name = cNode->name()->str();
auto srcPrim = cNode->primitive();
this->ops[name] = CopyPrimitive(srcPrim);
}
return 0;
}
} // namespace mindspore::lite::train

+ 0
- 64
mindspore/lite/src/train/model_impl.h View File

@@ -1,64 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_SRC_TRAIN_MODEL_IMPL_H_
#define MINDSPORE_LITE_SRC_TRAIN_MODEL_IMPL_H_

#include <string>
#include <map>
#include <memory>
#include <vector>
#include "schema/model_generated.h"
#include "src/ops/ops.h"
#include "ir/func_graph.h"

namespace mindspore::lite {
namespace train {
// Training-side model implementation: owns the flatbuffer meta graph, the
// lite primitives built from it, and the AnfNode/tensor bookkeeping used by
// the importer. Doubles as the FuncGraph handed to the session.
class ModelImpl : public FuncGraph {
 public:
  // Verifies and copies `model_buf`, builds the primitives and imports the
  // ANF graph. Returns nullptr on any failure.
  static std::shared_ptr<ModelImpl> Import(const char *model_buf, size_t size);
  ModelImpl() = default;
  explicit ModelImpl(const schema::MetaGraph *graph) : meta_graph(graph) {}
  ~ModelImpl() override = default;
  // Primitive for node `name`, or nullptr if unknown.
  const lite::Primitive *GetOp(const std::string &name) const;
  const schema::MetaGraph *GetMetaGraph() const;
  void FreeMetaGraph();
  // Builds `ops` from the meta graph; returns 0 on success.
  int BuildOps();

  // Records the tensor indices consumed ([0]) and produced ([1]) by the
  // named CNode.
  // NOTE(review): the two-element array is heap-allocated and never freed —
  // connectivity_ leaks; a value type (std::array) would avoid this.
  void AddCNodeInputOutput(std::string name, const std::vector<int> &input, const std::vector<int> &output) {
    std::vector<int> *tuple = new std::vector<int>[2];
    tuple[0] = input;
    tuple[1] = output;
    connectivity_[name] = tuple;
  }
  // NOTE(review): operator[] inserts nullptr for unknown names; callers must
  // null-check the result.
  std::vector<int> *GetCNodeInputOutputIndices(std::string name) { return connectivity_[name]; }
  void AddAnfNode(int id, AnfNodePtr anf_ptr) { tensors_[id] = anf_ptr; }
  AnfNodePtr GetAnfNode(int id) { return tensors_[id]; }

 protected:
  // Factory for lite primitives; caller owns the returned pointer.
  lite::Primitive *CopyPrimitive(const schema::Primitive *srcPrim);

 protected:
  const schema::MetaGraph *meta_graph = nullptr;  // flatbuffer view into Import's buffer
  std::map<int, AnfNodePtr> tensors_;             // tensor id -> AnfNode
  std::map<std::string, std::vector<int> *> connectivity_;  // node name -> {inputs, outputs}
  std::map<std::string, lite::Primitive *> ops;             // node name -> owned primitive
};
} // namespace train
using ModelImpl = mindspore::lite::train::ModelImpl;
} // namespace mindspore::lite

#endif // MINDSPORE_LITE_SRC_TRAIN_MODEL_IMPL_H_

+ 0
- 253
mindspore/lite/src/train/train_anf_session.cc View File

@@ -1,253 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <algorithm>
#include "src/train/train_anf_session.h"
#include "include/context.h"
#include "mindspore/ccsrc/runtime/device/kernel_info.h"
#include "mindspore/lite/src/train/train_session.h"
#include "mindspore/lite/src/kernel_factory.h"
#include "mindspore/lite/src/param_value_lite.h"
#include "common/utils.h"
#include "mindspore/lite/src/ops/ops.h"
#include "ir/anf.h"
#include "mindspore/lite/src/ir/tensor.h"
#include "abstract/abstract_value.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "src/ir/primitive_value.h"
#include "src/train/model_impl.h"

namespace mindspore {
namespace session {
// Extracts the output dimensions recorded in the node's abstract.
// Returns an empty vector (with a warning) when no abstract is attached;
// raises via MS_LOG(EXCEPTION) when the shape track is not a Shape.
static std::vector<int> GetAnfNodeOutDims(const AnfNodePtr &anfNodePtr) {
  auto node_abstract = anfNodePtr->abstract();
  if (node_abstract == nullptr) {
    MS_LOG(WARNING) << "abstract is nullptr, return empty dims";
    return {};
  }
  auto shape_track = node_abstract->GetShapeTrack();
  if (!shape_track->isa<abstract::Shape>()) {
    MS_LOG(EXCEPTION) << "Not a Shape";
    return {};
  }
  return dyn_cast<abstract::Shape>(shape_track)->shape();
}

// Format of the node's output. Currently hard-coded to NHWC regardless of the
// abstract; only the warning differs when the abstract is missing.
static schema::Format GetAnfNodeFormat(const AnfNodePtr &anfNodePtr) {
  if (anfNodePtr->abstract() == nullptr) {
    MS_LOG(WARNING) << "abstract is nullptr, return schema::Format_NHWC";
  }
  return schema::Format_NHWC;  // XXX TODO -- extract Format from AnfNode
}

// Output type id of the node. Currently hard-coded to float32 when an
// abstract exists; kTypeUnknown (with a warning) otherwise.
static TypeId GetAnfNodeOutTypeId(const AnfNodePtr &anfNodePtr) {
  if (anfNodePtr->abstract() == nullptr) {
    MS_LOG(WARNING) << "abstract is nullptr, return kTypeUnknown";
    return TypeId::kTypeUnknown;
  }
  return TypeId::kNumberTypeFloat32;  // XXX TODO nodeAbstract->GetTypeTrack()->generic_type_id();
}

// Copies the relevant fields of the caller-provided context into a
// session-owned lite::Context.
void TrainANFSession::Init(lite::Context *context) {
  MS_EXCEPTION_IF_NULL(context);
  context_ = std::make_shared<lite::Context>(context->thread_num_, context->allocator, context->device_ctx_);
}

// Returns the lite tensor cached for `anf_node`, creating (and caching) one
// from the node's abstract type/dims on first use.
lite::tensor::Tensor *TrainANFSession::GetTensorForAnfNode(const AnfNodePtr anf_node) {
  auto &cached = tensors_[anf_node];
  if (cached == nullptr) {
    cached = new lite::tensor::Tensor(GetAnfNodeOutTypeId(anf_node),
                                      GetAnfNodeOutDims(anf_node));  //, schema::NodeType_Parameter);
  }
  return cached;
}

// Walks the kernel graph in topological order and records, for every CNode,
// the lite tensors it consumes and produces (kernel_relation_infos_ keyed by
// the node's full name). Output tensors come from the model's node->tensor-id
// table; each input tensor comes from the producing CNode, a Parameter's
// default value, or a ValueNode. Always returns 0.
int TrainANFSession::BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &kernel_graph) {
  auto return_node = kernel_graph->get_return();
  auto node_list = TopoSort(return_node);
  // func_graph_ was set from a train::ModelImpl in CompileGraph; the cast
  // recovers the connectivity tables stored there.
  auto model_imp = std::dynamic_pointer_cast<lite::train::ModelImpl>(func_graph_);
  for (auto &node : node_list) {
    if (!node->isa<CNode>()) {
      continue;
    }
    KernelRelation kernel_relation;
    auto cnode = node->cast<CNodePtr>();
    kernel_relation.node_full_name = cnode->fullname_with_scope();
    kernel_relation.cnode = cnode;
    std::vector<int> *cnode_io_indices = model_imp->GetCNodeInputOutputIndices(cnode->fullname_with_scope());
    if (cnode_io_indices == NULL) {
      MS_LOG(WARNING) << "No IO vectors for " << cnode->fullname_with_scope();
    } else {
      // cnode_io_indices[1] holds the output tensor ids.
      // NOTE(review): signed `i` compared against size_t size() below.
      for (int i = 0; i < cnode_io_indices[1].size(); i++) {
        AnfNodePtr anf_node = model_imp->GetAnfNode(cnode_io_indices[1].data()[i]);
        kernel_relation.output_tensor.push_back(GetTensorForAnfNode(anf_node));
      }
    }
    lite::tensor::Tensor *tensor_ptr = nullptr;
    // Input 0 of a CNode is the primitive, so real inputs start at index 1.
    for (size_t index = 1; index < cnode->inputs().size(); ++index) {
      if (cnode->input(index)->isa<CNode>()) {
        auto input_cnode = cnode->input(index)->cast<CNodePtr>();
        auto input_kernel_relation = kernel_relation_infos_[input_cnode->fullname_with_scope()];
        // todo not support multi-outputs kernel sudo as spilt
        tensor_ptr = input_kernel_relation.output_tensor.front();
      } else if (cnode->input(index)->isa<Parameter>()) {
        auto input_parameter = cnode->input(index)->cast<ParameterPtr>();
        auto para = input_parameter->default_param();
        auto param_value = std::dynamic_pointer_cast<ParamValueLite>(para);
        tensor_ptr = GetTensorForAnfNode(cnode->input(index));
        // Weights carry their data in the parameter's default value.
        if ((param_value != nullptr) && (param_value->tensor_size() != 0)) {
          tensor_ptr->SetData(param_value->tensor_addr());
        }
      } else if (cnode->input(index)->isa<ValueNode>()) {
        auto input_valuenode = cnode->input(index)->cast<ValueNodePtr>();
        tensor_ptr = GetTensorForAnfNode(input_valuenode);
        // todo(yankai)
      } else {
        MS_ASSERT(false);
      }
      kernel_relation.input_tensor.push_back(tensor_ptr);
    }
    kernel_relation_infos_[cnode->fullname_with_scope()] = kernel_relation;
  }
  return 0;
}

// Monotonically increasing graph id shared by all TrainANFSession instances.
GraphId TrainANFSession::graph_sum_ = 0;

// Allocates an empty KernelGraph, assigns it the next free graph id and
// registers it in graphs_.
KernelGraphPtr TrainANFSession::NewKernelGraph() {
  auto graph = std::make_shared<KernelGraph>();
  graph->set_graph_id(graph_sum_);
  graphs_[graph_sum_++] = graph;
  return graph;
}

std::shared_ptr<KernelGraph> TrainANFSession::ConstructKernelGraph(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
auto graph = NewKernelGraph();
graph->set_return(func_graph->get_return());
auto node_list = TopoSort(func_graph->get_return());
std::vector<CNodePtr> cnode_order;
for (const auto &node : node_list) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>()) {
auto cn_node = node->cast<CNodePtr>();
cnode_order.push_back(cn_node);
}
}
graph->set_execution_order(cnode_order);
return graph;
}
// Compiles `func_graph` into an executable kernel graph:
// 1. build a KernelGraph in topological order,
// 2. attach KernelInfo to every node,
// 3. collect per-node input/output tensors,
// 4. build the lite kernels (shape inference + kernel creation).
// Returns the id under which the graph is registered.
GraphId TrainANFSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
  auto graph = ConstructKernelGraph(func_graph);
  func_graph_ = func_graph;
  MS_EXCEPTION_IF_NULL(graph);
  MS_LOG(INFO) << "Set kernel info";
  SetKernelInfo(graph.get());

  (void)BuildKernelInputAndOutputFromFuncGraph(graph);
  MS_LOG(INFO) << "Build kernel";
  auto ret = BuildKernel(graph.get());
  if (0 != ret) {
    MS_LOG(EXCEPTION) << "BuildKernel failed";
  }

  // return the graph id to backend
  // NOTE(review): NewKernelGraph() already registered the graph in graphs_;
  // this re-insert is redundant but harmless.
  auto graph_id = graph->graph_id();
  graphs_[graph_id] = graph;
  MS_LOG(INFO) << "Compile graph " << graph_id << " success";
  return graph_id;
}

// Runs the previously compiled graph `graph_id` with the given input tensors,
// writing results into *outputs. Raises via MS_LOG(EXCEPTION) on failure.
void TrainANFSession::RunGraph(const GraphId &graph_id, const std::vector<lite::tensor::Tensor *> &inputs,
                               std::vector<lite::tensor::Tensor *> *outputs) {
  auto &kernel_graph = graphs_[graph_id];
  MS_EXCEPTION_IF_NULL(kernel_graph);
  MS_LOG(INFO) << "Bind input output address";
  // runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs); -- will be bound in Run
  // Todo : hangangqiang
  // Reorder(&execution_order);
  MS_LOG(INFO) << "Run graph start";
  // NOTE(review): the C-style cast strips const from `inputs`; also confirm
  // the third argument — the sibling TrainSession::RunGraph passes the
  // pointer `outputs`, while this passes the dereferenced vector.
  auto ret = runtime_.Run(kernel_graph.get(), (std::vector<lite::tensor::Tensor *> &)inputs, *outputs);
  if (!ret) {
    MS_LOG(EXCEPTION) << "Run graph failed";
  }
  MS_LOG(INFO) << "Run graph end";
}

// Attaches a fresh device::KernelInfo to every node in execution order so a
// kernel mod can be installed later by BuildKernel().
void TrainANFSession::SetKernelInfo(const KernelGraph *kernel_graph) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  for (const auto &graph_node : kernel_graph->execution_order()) {
    MS_EXCEPTION_IF_NULL(graph_node);
    graph_node->set_kernel_info(std::make_shared<device::KernelInfo>());
  }
}

// For every recorded kernel relation: infer output shapes through the node's
// lite primitive, ask KernelFactory for a CPU/float32 kernel, and attach it
// to the CNode's KernelInfo. Returns 0 on success, non-zero on first failure.
int TrainANFSession::BuildKernel(const KernelGraph *kernel_graph) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  for (auto iter = kernel_relation_infos_.begin(); iter != kernel_relation_infos_.end(); ++iter) {
    std::string kernel_name = iter->first;
    KernelRelation anf_register = iter->second;
    MS_EXCEPTION_IF_NULL(anf_register.cnode);
    // The Return node has no kernel to build.
    if (IsPrimitiveCNode(anf_register.cnode, prim::kPrimReturn)) {
      continue;
    }
    // Input 0 of the CNode is the value node holding the lite primitive.
    auto value_node_prim = anf_register.cnode->input(0);
    MS_EXCEPTION_IF_NULL(value_node_prim);
    auto prim = GetValueNode<std::shared_ptr<lite::PrimitiveValue>>(value_node_prim);
    MS_EXCEPTION_IF_NULL(prim);
    // NOTE(review): C-style cast dropping const — prefer a named cast.
    auto node_primitive = (lite::Primitive *)(prim->GetPrimitive());
    MS_EXCEPTION_IF_NULL(node_primitive);
    auto ret = node_primitive->InferShape(anf_register.input_tensor, anf_register.output_tensor);
    if (0 != ret) {
      MS_LOG(ERROR) << "InferShape failed, node : " << kernel_name;
      return ret;
    }
    // All kernels are currently requested as CPU / float32.
    kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, node_primitive->Type()};

    auto *kernel = lite::KernelFactory::GetInstance()->GetKernel(anf_register.input_tensor, anf_register.output_tensor,
                                                                 node_primitive, context_.get(), desc);
    if (nullptr == kernel) {
      MS_LOG(ERROR) << "Create kernel return nullptr, name: " << kernel_name;
      return -1;
    }
    // Ownership of the raw kernel passes to this shared_ptr, which is stored
    // in the node's KernelInfo.
    std::shared_ptr<kernel::LiteKernel> kernel_mod(kernel);
    kernel_mod->set_name(anf_register.cnode->fullname_with_scope());

    // kernel->train();
    auto kernel_info = dynamic_cast<device::KernelInfo *>(anf_register.cnode->kernel_info());
    MS_EXCEPTION_IF_NULL(kernel_info);
    kernel_info->set_kernel_mod(kernel_mod);  // XXX TODO -- only derived class KernelInfo has this method
  }
  return 0;
}
} // namespace session
} // namespace mindspore

+ 0
- 76
mindspore/lite/src/train/train_anf_session.h View File

@@ -1,76 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_
#define MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_
#include <map>
#include <string>
#include <memory>
#include <vector>
#include <unordered_map>
#include "include/context.h"
#include "backend/session/session_basic.h"
#include "backend/session/kernel_graph.h"
#include "mindspore/lite/src/train/lite_kernel_runtime.h"
// #include "backend/session/session_factory.h"
namespace mindspore {
namespace lite::tensor {
class Tensor;
}
namespace session {
// Per-CNode record linking a graph node to the lite tensors it consumes and
// produces; filled by BuildKernelInputAndOutputFromFuncGraph.
struct KernelRelation {
  std::string node_full_name;  // cnode->fullname_with_scope()
  std::vector<lite::tensor::Tensor *> input_tensor;   // non-owning
  std::vector<lite::tensor::Tensor *> output_tensor;  // non-owning
  CNodePtr cnode;
};

// ANF-based training session: compiles a FuncGraph (imported from a lite
// model) into lite kernels and runs them through LiteInferKernelRuntime.
// NOTE(review): this header's include guard (MINDSPORE_LITE_SRC_TRAIN_
// TRAIN_SESSION_H_) collides with train_session.h — including both silently
// drops one; confirm and rename.
class TrainANFSession {
 public:
  explicit TrainANFSession(lite::Context *context) { Init(context); }
  ~TrainANFSession() = default;

  // Builds the kernel graph, infers shapes and creates all kernels; returns
  // the registered graph id.
  GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph);

  // Executes the compiled graph, filling *outputs.
  void RunGraph(const GraphId &graph_id, const std::vector<lite::tensor::Tensor *> &inputs,
                std::vector<lite::tensor::Tensor *> *outputs);

 protected:
  void Init(lite::Context *context);
  std::shared_ptr<lite::Context> context_ = nullptr;
  std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_;  // graph id -> compiled graph
  static GraphId graph_sum_;  // next free graph id (shared across instances)
  KernelGraphPtr NewKernelGraph();

 private:
  std::shared_ptr<KernelGraph> ConstructKernelGraph(const FuncGraphPtr &func_graph);
  int BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &kernel_graph);

  // Lazily creates and caches a lite tensor per AnfNode.
  lite::tensor::Tensor *GetTensorForAnfNode(const AnfNodePtr anf_node);

  void SetKernelInfo(const KernelGraph *kernel_graph);
  int BuildKernel(const KernelGraph *kernel_graph);
  lite::LiteInferKernelRuntime runtime_;
  std::map<std::string, KernelRelation> kernel_relation_infos_;  // node name -> tensor wiring
  FuncGraphPtr func_graph_ = NULL;
  // NOTE(review): tensors created in GetTensorForAnfNode are never freed.
  std::map<AnfNodePtr, lite::tensor::Tensor *> tensors_;
};
} // namespace session
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_

+ 0
- 267
mindspore/lite/src/train/train_session.cc View File

@@ -1,267 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <algorithm>
#include "include/context.h"
#include "mindspore/ccsrc/runtime/device/kernel_info.h"
#include "mindspore/lite/src/train/train_session.h"
#include "mindspore/lite/src/kernel_factory.h"
#include "mindspore/lite/src/param_value_lite.h"
#include "utils/ms_utils.h"
#include "mindspore/lite/src/ops/ops.h"
#include "ir/anf.h"
#include "mindspore/lite/src/ir/tensor.h"
#include "abstract/abstract_value.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "src/ir/primitive_value.h"
#include "src/train/model_impl.h"

namespace mindspore {
namespace session {
// Extracts the output dimensions recorded in the node's abstract.
// Returns an empty vector (with a warning) when no abstract is attached;
// raises via MS_LOG(EXCEPTION) when the shape track is not a Shape.
static std::vector<int> GetAnfNodeOutDims(const AnfNodePtr &anfNodePtr) {
  auto node_abstract = anfNodePtr->abstract();
  if (node_abstract == nullptr) {
    MS_LOG(WARNING) << "abstract is nullptr, return empty dims";
    return {};
  }
  auto shape_track = node_abstract->GetShapeTrack();
  if (!shape_track->isa<abstract::Shape>()) {
    MS_LOG(EXCEPTION) << "Not a Shape";
    return {};
  }
  return dyn_cast<abstract::Shape>(shape_track)->shape();
}

// Format of the node's output. Currently hard-coded to NHWC regardless of the
// abstract; only the warning differs when the abstract is missing.
static schema::Format GetAnfNodeFormat(const AnfNodePtr &anfNodePtr) {
  if (anfNodePtr->abstract() == nullptr) {
    MS_LOG(WARNING) << "abstract is nullptr, return schema::Format_NHWC";
  }
  return schema::Format_NHWC;  // XXX TODO -- extract Format from AnfNode
}

// Output type id of the node. Currently hard-coded to float32 when an
// abstract exists; kTypeUnknown (with a warning) otherwise.
static TypeId GetAnfNodeOutTypeId(const AnfNodePtr &anfNodePtr) {
  if (anfNodePtr->abstract() == nullptr) {
    MS_LOG(WARNING) << "abstract is nullptr, return kTypeUnknown";
    return TypeId::kTypeUnknown;
  }
  return TypeId::kNumberTypeFloat32;  // XXX TODO nodeAbstract->GetTypeTrack()->generic_type_id();
}

// Copies the relevant fields of the caller-provided context into a
// session-owned lite::Context.
void TrainSession::Init(lite::Context *context) {
  MS_EXCEPTION_IF_NULL(context);
  context_ = std::make_shared<lite::Context>(context->thread_num_, context->allocator, context->device_ctx_);
}

// Returns the lite tensor cached for `anf_node`, creating (and caching) one
// from the node's abstract type/dims on first use.
lite::tensor::Tensor *TrainSession::GetTensorForAnfNode(const AnfNodePtr anf_node) {
  auto &cached = tensors_[anf_node];
  if (cached == nullptr) {
    cached = new lite::tensor::Tensor(GetAnfNodeOutTypeId(anf_node),
                                      GetAnfNodeOutDims(anf_node));  //, schema::NodeType_Parameter);
  }
  return cached;
}

// Walks the kernel graph in topological order and records, for every CNode,
// the lite tensors it consumes and produces (kernel_relation_infos_ keyed by
// the node's full name). Output tensors come from the model's node->tensor-id
// table; each input tensor comes from the producing CNode, a Parameter's
// default value, or a ValueNode. Always returns 0.
int TrainSession::BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &kernel_graph) {
  auto return_node = kernel_graph->get_return();
  auto node_list = TopoSort(return_node);
  // func_graph_ was set from a train::ModelImpl in CompileGraph; the cast
  // recovers the connectivity tables stored there.
  auto model_imp = std::dynamic_pointer_cast<lite::train::ModelImpl>(func_graph_);
  for (auto &node : node_list) {
    if (!node->isa<CNode>()) {
      continue;
    }
    KernelRelation kernel_relation;
    auto cnode = node->cast<CNodePtr>();
    kernel_relation.node_full_name = cnode->fullname_with_scope();
    kernel_relation.cnode = cnode;
    std::vector<int> *cnode_io_indices = model_imp->GetCNodeInputOutputIndices(cnode->fullname_with_scope());
    if (cnode_io_indices == NULL) {
      MS_LOG(WARNING) << "No IO vectors for " << cnode->fullname_with_scope();
    } else {
      // cnode_io_indices[1] holds the output tensor ids.
      // NOTE(review): signed `i` compared against size_t size() below.
      for (int i = 0; i < cnode_io_indices[1].size(); i++) {
        AnfNodePtr anf_node = model_imp->GetAnfNode(cnode_io_indices[1].data()[i]);
        kernel_relation.output_tensor.push_back(GetTensorForAnfNode(anf_node));
      }
    }
    lite::tensor::Tensor *tensor_ptr = nullptr;
    // Input 0 of a CNode is the primitive, so real inputs start at index 1.
    for (size_t index = 1; index < cnode->inputs().size(); ++index) {
      if (cnode->input(index)->isa<CNode>()) {
        auto input_cnode = cnode->input(index)->cast<CNodePtr>();
        auto input_kernel_relation = kernel_relation_infos_[input_cnode->fullname_with_scope()];
        // todo not support multi-outputs kernel sudo as spilt
        tensor_ptr = input_kernel_relation.output_tensor.front();
      } else if (cnode->input(index)->isa<Parameter>()) {
        auto input_parameter = cnode->input(index)->cast<ParameterPtr>();
        auto para = input_parameter->default_param();
        auto param_value = std::dynamic_pointer_cast<ParamValueLite>(para);
        tensor_ptr = GetTensorForAnfNode(cnode->input(index));
        // Weights carry their data in the parameter's default value.
        if ((param_value != nullptr) && (param_value->tensor_size() != 0)) {
          tensor_ptr->SetData(param_value->tensor_addr());
        }
      } else if (cnode->input(index)->isa<ValueNode>()) {
        auto input_valuenode = cnode->input(index)->cast<ValueNodePtr>();
        tensor_ptr = GetTensorForAnfNode(input_valuenode);
        // todo(yankai)
      } else {
        MS_ASSERT(false);
      }
      kernel_relation.input_tensor.push_back(tensor_ptr);
    }
    kernel_relation_infos_[cnode->fullname_with_scope()] = kernel_relation;
  }
  return 0;
}
#if 0
// Dead code kept for reference: SessionBasic-based compilation entry points.
GraphId TrainSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
  auto graph_id = graph_sum_;
  auto graph = SessionBasic::ConstructKernelGraph(lst, outputs);
  MS_EXCEPTION_IF_NULL(graph);

  BuildKernel(graph.get());
  MS_LOG(INFO) << "Assign kernel address";
  runtime_.AssignKernelAddress(graph.get());
  return graph_id;
}

GraphId TrainSession::CompileGraph(const char *model_buf, size_t size) { return 0; }
#else
// Monotonically increasing graph id shared by all TrainSession instances.
GraphId TrainSession::graph_sum_ = 0;

// Allocates an empty KernelGraph, assigns it the next free graph id and
// registers it in graphs_.
KernelGraphPtr TrainSession::NewKernelGraph() {
  auto graph = std::make_shared<KernelGraph>();
  graph->set_graph_id(graph_sum_);
  graphs_[graph_sum_++] = graph;
  return graph;
}

#endif

std::shared_ptr<KernelGraph> TrainSession::ConstructKernelGraph(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
auto graph = NewKernelGraph();
graph->set_return(func_graph->get_return());
auto node_list = TopoSort(func_graph->get_return());
std::vector<CNodePtr> cnode_order;
for (const auto &node : node_list) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>()) {
auto cn_node = node->cast<CNodePtr>();
cnode_order.push_back(cn_node);
}
}
graph->set_execution_order(cnode_order);
return graph;
}
// Compiles `func_graph` into an executable kernel graph:
// 1. build a KernelGraph in topological order,
// 2. attach KernelInfo to every node,
// 3. collect per-node input/output tensors,
// 4. build the lite kernels (shape inference + kernel creation).
// Returns the id under which the graph is registered.
GraphId TrainSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
  auto graph = ConstructKernelGraph(func_graph);
  func_graph_ = func_graph;
  MS_EXCEPTION_IF_NULL(graph);
  MS_LOG(INFO) << "Set kernel info";
  SetKernelInfo(graph.get());

  (void)BuildKernelInputAndOutputFromFuncGraph(graph);
  MS_LOG(INFO) << "Build kernel";
  auto ret = BuildKernel(graph.get());
  if (0 != ret) {
    MS_LOG(EXCEPTION) << "BuildKernel failed";
  }

  // return the graph id to backend
  // NOTE(review): NewKernelGraph() already registered the graph in graphs_;
  // this re-insert is redundant but harmless.
  auto graph_id = graph->graph_id();
  graphs_[graph_id] = graph;
  MS_LOG(INFO) << "Compile graph " << graph_id << " success";
  return graph_id;
}

// Runs the previously compiled graph `graph_id` with the given input tensors,
// writing results into *outputs. Raises via MS_LOG(EXCEPTION) on failure.
void TrainSession::RunGraph(const GraphId &graph_id, const std::vector<lite::tensor::Tensor *> &inputs,
                            std::vector<lite::tensor::Tensor *> *outputs) {
  auto &kernel_graph = graphs_[graph_id];
  MS_EXCEPTION_IF_NULL(kernel_graph);
  MS_LOG(INFO) << "Bind input output address";
  // runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs); -- will be bound in Run
  // Todo : hangangqiang
  // Reorder(&execution_order);
  MS_LOG(INFO) << "Run graph start";
  // NOTE(review): the C-style cast strips const from `inputs`.
  auto ret = runtime_.Run(kernel_graph.get(), (std::vector<lite::tensor::Tensor *> &)inputs, outputs);
  if (!ret) {
    MS_LOG(EXCEPTION) << "Run graph failed";
  }
  MS_LOG(INFO) << "Run graph end";
}

// Attaches a fresh device::KernelInfo to every node in execution order so a
// kernel mod can be installed later by BuildKernel().
void TrainSession::SetKernelInfo(const KernelGraph *kernel_graph) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  for (const auto &graph_node : kernel_graph->execution_order()) {
    MS_EXCEPTION_IF_NULL(graph_node);
    graph_node->set_kernel_info(std::make_shared<device::KernelInfo>());
  }
}

// For every recorded kernel relation: infer output shapes through the node's
// lite primitive, ask KernelFactory for a CPU/float32 kernel, and attach it
// to the CNode's KernelInfo. Returns 0 on success, non-zero on first failure.
int TrainSession::BuildKernel(const KernelGraph *kernel_graph) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  for (auto iter = kernel_relation_infos_.begin(); iter != kernel_relation_infos_.end(); ++iter) {
    std::string kernel_name = iter->first;
    KernelRelation anf_register = iter->second;
    MS_EXCEPTION_IF_NULL(anf_register.cnode);
    // The Return node has no kernel to build.
    if (IsPrimitiveCNode(anf_register.cnode, prim::kPrimReturn)) {
      continue;
    }
    // Input 0 of the CNode is the value node holding the lite primitive.
    auto value_node_prim = anf_register.cnode->input(0);
    MS_EXCEPTION_IF_NULL(value_node_prim);
    auto prim = GetValueNode<std::shared_ptr<lite::PrimitiveValue>>(value_node_prim);
    MS_EXCEPTION_IF_NULL(prim);
    // NOTE(review): C-style cast dropping const — prefer a named cast.
    auto node_primitive = (lite::Primitive *)(prim->GetPrimitive());
    MS_EXCEPTION_IF_NULL(node_primitive);
    auto ret = node_primitive->InferShape(anf_register.input_tensor, anf_register.output_tensor);
    if (0 != ret) {
      MS_LOG(ERROR) << "InferShape failed, node : " << kernel_name;
      return ret;
    }
    // All kernels are currently requested as CPU / float32.
    kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, node_primitive->Type()};

    auto *kernel = lite::KernelFactory::GetInstance()->GetKernel(anf_register.input_tensor, anf_register.output_tensor,
                                                                 node_primitive, context_.get(), desc);
    if (nullptr == kernel) {
      MS_LOG(ERROR) << "Create kernel return nullptr, name: " << kernel_name;
      return -1;
    }
    // Ownership of the raw kernel passes to this shared_ptr, which is stored
    // in the node's KernelInfo.
    std::shared_ptr<kernel::LiteKernel> kernel_mod(kernel);
    kernel_mod->set_name(anf_register.cnode->fullname_with_scope());

    // kernel->train();
    auto kernel_info = dynamic_cast<device::KernelInfo *>(anf_register.cnode->kernel_info());
    MS_EXCEPTION_IF_NULL(kernel_info);
    kernel_info->set_kernel_mod(kernel_mod);  // XXX TODO -- only derived class KernelInfo has this method
  }
  return 0;
}
} // namespace session
} // namespace mindspore

+ 0
- 76
mindspore/lite/src/train/train_session.h View File

@@ -1,76 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_
#define MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_
#include <map>
#include <string>
#include <memory>
#include <vector>
#include <unordered_map>
#include "include/context.h"
#include "backend/session/session_basic.h"
#include "backend/session/kernel_graph.h"
#include "mindspore/lite/src/train/lite_kernel_runtime.h"
// #include "backend/session/session_factory.h"
namespace mindspore {
namespace lite::tensor {
class Tensor;
}
namespace session {
// Bundles a graph node with the runtime tensors that will serve as its
// kernel's inputs and outputs (see BuildKernel, which passes input_tensor /
// output_tensor to the kernel factory and reads cnode for kernel info).
struct KernelRelation {
  std::string node_full_name;                         // presumably cnode->fullname_with_scope() — confirm at fill site
  std::vector<lite::tensor::Tensor *> input_tensor;   // raw pointers; tensor ownership lives elsewhere (tensors_ map) — TODO confirm
  std::vector<lite::tensor::Tensor *> output_tensor;  // raw pointers; same ownership caveat as input_tensor
  CNodePtr cnode;                                     // the ANF graph node this relation describes
};

// Training session for MindSpore Lite: compiles a FuncGraph into a
// KernelGraph, builds one lite kernel per node, and executes the compiled
// graph through LiteInferKernelRuntime.
//
// Not thread-safe; intended to be driven from a single thread (no internal
// synchronization is visible in this declaration).
class TrainSession {
 public:
  // `context` is borrowed (non-owning raw pointer); Init() is expected to
  // copy/wrap what it needs into context_.
  explicit TrainSession(lite::Context *context) { Init(context); }
  ~TrainSession() = default;

  // Lowers `func_graph` into a KernelGraph and returns the id assigned to
  // it (presumably derived from graph_sum_ and stored in graphs_ — confirm
  // in the .cc implementation).
  GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph);

  // Runs the compiled graph identified by `graph_id` with `inputs`,
  // writing results into *outputs. `outputs` must be non-null.
  void RunGraph(const GraphId &graph_id, const std::vector<lite::tensor::Tensor *> &inputs,
                std::vector<lite::tensor::Tensor *> *outputs);

  // void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override
  // {};
 protected:
  void Init(lite::Context *context);
  std::shared_ptr<lite::Context> context_ = nullptr;
  // Compiled graphs keyed by the id handed back from CompileGraph().
  std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_;
  static GraphId graph_sum_;  // running counter — assumed source of new graph ids, TODO confirm
  KernelGraphPtr NewKernelGraph();

 private:
  // GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
  // GraphId CompileGraph(const char *model_buf, size_t size);
  std::shared_ptr<KernelGraph> ConstructKernelGraph(const FuncGraphPtr &func_graph);
  int BuildKernelInputAndOutputFromFuncGraph(const KernelGraphPtr &kernel_graph);

  // Returns the lite tensor cached in tensors_ for `anf_node` (creating it
  // if absent — verify in the .cc implementation).
  lite::tensor::Tensor *GetTensorForAnfNode(const AnfNodePtr anf_node);

  void SetKernelInfo(const KernelGraph *kernel_graph);
  // Builds a lite kernel per node; returns 0 on success, negative on error
  // (see the BuildKernel definition, which returns -1 on kernel-creation
  // failure).
  int BuildKernel(const KernelGraph *kernel_graph);

  lite::LiteInferKernelRuntime runtime_;
  std::map<std::string, KernelRelation> kernel_relation_infos_;  // keyed by node full name
  FuncGraphPtr func_graph_ = nullptr;  // was NULL; nullptr matches the rest of this header
  std::map<AnfNodePtr, lite::tensor::Tensor *> tensors_;  // per-node tensor cache
};
} // namespace session
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_TRAIN_TRAIN_SESSION_H_

+ 1
- 3
mindspore/lite/test/CMakeLists.txt View File

@@ -175,12 +175,10 @@ set(TEST_LITE_SRC
${LITE_DIR}/src/ir/primitive_t_value.cc
${LITE_DIR}/src/context.cc
${LITE_DIR}/src/executor.cc
${LITE_DIR}/src/kernel_factory.cc
${LITE_DIR}/src/kernel_registry.cc
${LITE_DIR}/src/lite_kernel.cc
${LITE_DIR}/src/lite_session.cc
${LITE_DIR}/src/model.cc
${LITE_DIR}/src/model_impl.cc
${LITE_DIR}/src/populate_parameter.cc
${LITE_DIR}/src/scheduler.cc
${LITE_DIR}/src/common/graph_util.cc
@@ -265,7 +263,7 @@ if (SUPPORT_TRAIN)
# ${LITE_DIR}/src/ir/primitive_value.cc
# ${LITE_DIR}/src/train/lite_kernel_runtime.cc
# ${LITE_DIR}/src/train/train_session.cc
# ${LITE_DIR}/src/train/model_impl.cc
# ${LITE_DIR}/src/train/model.cc
${LITE_DIR}/src/lite_session.cc # temporary
)
else()


+ 2
- 2
mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc View File

@@ -20,15 +20,15 @@
#include <algorithm>
#include "schema/inner/model_generated.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "src/kernel_factory.h"
#include "tools/anf_exporter/anf_exporter.h"
#include "src/kernel_registry.h"
#include "src/scheduler.h"
#include "include/context.h"
#include "src/lite_session.h"
#include "src/ir/primitive_t_value.h"
#include "src/populate_parameter.h"

using mindspore::lite::KernelFactory;
using mindspore::lite::KernelRegistry;
using mindspore::lite::tensor::Tensor;
using mindspore::lite::PrimitiveTValue;
namespace mindspore::opt {


Loading…
Cancel
Save