Browse Source

modify static check

tags/v1.1.0
yvette 5 years ago
parent
commit
4b5050b11b
31 changed files with 271 additions and 264 deletions
  1. +3
    -3
      mindspore/lite/src/common/common.h
  2. +5
    -4
      mindspore/lite/src/common/file_utils.h
  3. +4
    -4
      mindspore/lite/src/common/graph_util.h
  4. +3
    -3
      mindspore/lite/src/common/log_adapter.h
  5. +3
    -3
      mindspore/lite/src/common/string_util.h
  6. +3
    -4
      mindspore/lite/src/common/utils.h
  7. +1
    -1
      mindspore/lite/src/executor.h
  8. +2
    -2
      mindspore/lite/src/inner_context.cc
  9. +2
    -4
      mindspore/lite/src/kernel_registry.cc
  10. +1
    -1
      mindspore/lite/src/kernel_registry.h
  11. +7
    -2
      mindspore/lite/src/lite_kernel.cc
  12. +1
    -1
      mindspore/lite/src/lite_kernel.h
  13. +5
    -1
      mindspore/lite/src/lite_session.cc
  14. +1
    -1
      mindspore/lite/src/lite_session.h
  15. +0
    -2
      mindspore/lite/src/ops/pooling_grad.cc
  16. +165
    -165
      mindspore/lite/src/ops/primitive_c.cc
  17. +6
    -6
      mindspore/lite/src/ops/primitive_c.h
  18. +0
    -1
      mindspore/lite/src/ops/sub.cc
  19. +3
    -3
      mindspore/lite/src/param_value_lite.h
  20. +9
    -9
      mindspore/lite/src/runtime/allocator.cc
  21. +1
    -1
      mindspore/lite/src/runtime/allocator.h
  22. +7
    -4
      mindspore/lite/src/runtime/parallel_executor.cc
  23. +5
    -5
      mindspore/lite/src/runtime/parallel_executor.h
  24. +0
    -1
      mindspore/lite/src/runtime/runtime_api.cc
  25. +3
    -3
      mindspore/lite/src/runtime/runtime_api.h
  26. +8
    -12
      mindspore/lite/src/scheduler.cc
  27. +2
    -2
      mindspore/lite/src/scheduler.h
  28. +5
    -1
      mindspore/lite/src/sub_graph_kernel.cc
  29. +1
    -1
      mindspore/lite/src/sub_graph_kernel.h
  30. +6
    -5
      mindspore/lite/src/tensor.cc
  31. +9
    -9
      mindspore/lite/src/tensor.h

+ 3
- 3
mindspore/lite/src/common/common.h View File

@@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */


#ifndef MINDSPORE_LITE_COMMON_COMMON_H_
#define MINDSPORE_LITE_COMMON_COMMON_H_
#ifndef MINDSPORE_LITE_SRC_COMMON_COMMON_H_
#define MINDSPORE_LITE_SRC_COMMON_COMMON_H_


#include <string> #include <string>
#include "src/tensor.h" #include "src/tensor.h"
@@ -56,4 +56,4 @@ static const schema::Format DEFAULT_FORMAT = schema::Format::Format_NCHW;
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore


#endif // MINDSPORE_LITE_COMMON_COMMON_H_
#endif // MINDSPORE_LITE_SRC_COMMON_COMMON_H_

+ 5
- 4
mindspore/lite/src/common/file_utils.h View File

@@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */


#ifndef MINDSPORE_LITE_COMMON_FILE_UTILS_H_
#define MINDSPORE_LITE_COMMON_FILE_UTILS_H_
#ifndef MINDSPORE_LITE_SRC_COMMON_FILE_UTILS_H_
#define MINDSPORE_LITE_SRC_COMMON_FILE_UTILS_H_


#include <cstdio> #include <cstdio>
#include <cstdlib> #include <cstdlib>
@@ -48,13 +48,14 @@ void WriteToTxt(const std::string &file_path, void *data, size_t element_size) {
out_file.close(); out_file.close();
} }


inline int WriteToBin(const std::string &file_path, void *data, size_t size) {
inline int WriteToBin(const std::string &file_path, void *data, const size_t size) {
std::ofstream out_file; std::ofstream out_file;
out_file.open(file_path.c_str(), std::ios::binary); out_file.open(file_path.c_str(), std::ios::binary);
if (!out_file.good() || !out_file.is_open()) { if (!out_file.good() || !out_file.is_open()) {
return -1; return -1;
} }
out_file.write(reinterpret_cast<char *>(data), size); out_file.write(reinterpret_cast<char *>(data), size);
out_file.close();
return 0; return 0;
} }


@@ -63,4 +64,4 @@ std::string GetAndroidPackagePath();
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore


#endif // MINDSPORE_LITE_COMMON_FILE_UTILS_H_
#endif // MINDSPORE_LITE_SRC_COMMON_FILE_UTILS_H_

+ 4
- 4
mindspore/lite/src/common/graph_util.h View File

@@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */


#ifndef MINDSPORE_LITE_COMMON_GRAPH_UTIL_H_
#define MINDSPORE_LITE_COMMON_GRAPH_UTIL_H_
#ifndef MINDSPORE_LITE_SRC_COMMON_GRAPH_UTIL_H_
#define MINDSPORE_LITE_SRC_COMMON_GRAPH_UTIL_H_


#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
@@ -35,8 +35,8 @@ std::vector<size_t> GetGraphInputNodes(const lite::Model *model);


std::vector<size_t> GetGraphOutputNodes(const lite::Model *model); std::vector<size_t> GetGraphOutputNodes(const lite::Model *model);


std::vector<size_t> GetLinkedPostNodeIdx(const lite::Model *model, const size_t tensor_idx);
std::vector<size_t> GetLinkedPostNodeIdx(const lite::Model *model, size_t tensor_idx);
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore


#endif // MINDSPORE_LITE_COMMON_GRAPH_UTIL_H_
#endif // MINDSPORE_LITE_SRC_COMMON_GRAPH_UTIL_H_

+ 3
- 3
mindspore/lite/src/common/log_adapter.h View File

@@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */


#ifndef MINDSPORE_LITE_COMMON_LOG_ADAPTER_H_
#define MINDSPORE_LITE_COMMON_LOG_ADAPTER_H_
#ifndef MINDSPORE_LITE_SRC_COMMON_LOG_ADAPTER_H_
#define MINDSPORE_LITE_SRC_COMMON_LOG_ADAPTER_H_
#ifdef USE_GLOG #ifdef USE_GLOG
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#else #else
@@ -113,4 +113,4 @@ class LogWriter {
#define MS_ASSERT(f) ((void)0) #define MS_ASSERT(f) ((void)0)
#endif #endif
#endif #endif
#endif // MINDSPORE_LITE_COMMON_LOG_ADAPTER_H_
#endif // MINDSPORE_LITE_SRC_COMMON_LOG_ADAPTER_H_

+ 3
- 3
mindspore/lite/src/common/string_util.h View File

@@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */


#ifndef MINDSPORE_LITE_COMMON_STRING_UTIL_H_
#define MINDSPORE_LITE_COMMON_STRING_UTIL_H_
#ifndef MINDSPORE_LITE_SRC_COMMON_STRING_UTIL_H_
#define MINDSPORE_LITE_SRC_COMMON_STRING_UTIL_H_


#include <vector> #include <vector>
#include <string> #include <string>
@@ -52,4 +52,4 @@ uint64_t StringHash64(const char *s, size_t len);
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore


#endif // MINDSPORE_LITE_COMMON_STRING_UTIL_H_
#endif // MINDSPORE_LITE_SRC_COMMON_STRING_UTIL_H_

+ 3
- 4
mindspore/lite/src/common/utils.h View File

@@ -14,10 +14,9 @@
* limitations under the License. * limitations under the License.
*/ */


#ifndef MINDSPORE_LITE_COMMON_UTILS_H_
#define MINDSPORE_LITE_COMMON_UTILS_H_
#ifndef MINDSPORE_LITE_SRC_COMMON_UTILS_H_
#define MINDSPORE_LITE_SRC_COMMON_UTILS_H_


#include <stdint.h>
#include <ctime> #include <ctime>
#include <cstdint> #include <cstdint>
#include <vector> #include <vector>
@@ -185,4 +184,4 @@ inline Option<bool> GenericParseValue(const std::string &value) {
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore


#endif // MINDSPORE_LITE_COMMON_UTILS_H_
#endif // MINDSPORE_LITE_SRC_COMMON_UTILS_H_

+ 1
- 1
mindspore/lite/src/executor.h View File

@@ -35,7 +35,7 @@ class Executor {
const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr); const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr);


protected: protected:
int CheckInputs(std::vector<Tensor *> &in_tensors);
static int CheckInputs(std::vector<Tensor *> &in_tensors);
}; };
} // namespace mindspore::lite } // namespace mindspore::lite
#endif #endif

+ 2
- 2
mindspore/lite/src/inner_context.cc View File

@@ -52,10 +52,10 @@ int InnerContext::Init() {
} }


InnerContext::~InnerContext() { InnerContext::~InnerContext() {
if (this->thread_pool_ != NULL) {
if (this->thread_pool_ != nullptr) {
DestroyThreadPool(this->thread_pool_); DestroyThreadPool(this->thread_pool_);
free(this->thread_pool_); free(this->thread_pool_);
this->thread_pool_ = NULL;
this->thread_pool_ = nullptr;
} }
} }




+ 2
- 4
mindspore/lite/src/kernel_registry.cc View File

@@ -73,7 +73,7 @@ int KernelRegistry::GetCreatorFuncIndex(const kernel::KernelKey desc) {
return index; return index;
} }


void KernelRegistry::RegKernel(const KernelKey desc, kernel::KernelCreator creator) {
void KernelRegistry::RegKernel(const KernelKey desc, const kernel::KernelCreator creator) {
int index = GetCreatorFuncIndex(desc); int index = GetCreatorFuncIndex(desc);
if (index >= array_size_) { if (index >= array_size_) {
MS_LOG(ERROR) << "invalid kernel key, arch " << desc.arch << ", data_type" << desc.data_type << ",op type " MS_LOG(ERROR) << "invalid kernel key, arch " << desc.arch << ", data_type" << desc.data_type << ",op type "
@@ -97,8 +97,6 @@ void KernelRegistry::RegKernel(const KERNEL_ARCH arch, const TypeId data_type, c


bool KernelRegistry::Merge(const std::unordered_map<KernelKey, KernelCreator> &new_creators) { return false; } bool KernelRegistry::Merge(const std::unordered_map<KernelKey, KernelCreator> &new_creators) { return false; }


const kernel::KernelCreator *KernelRegistry::GetCreatorArrays() { return creator_arrays_; }

kernel::LiteKernel *KernelRegistry::GetKernel(const std::vector<Tensor *> &in_tensors, kernel::LiteKernel *KernelRegistry::GetKernel(const std::vector<Tensor *> &in_tensors,
const std::vector<Tensor *> &out_tensors, const PrimitiveC *primitive, const std::vector<Tensor *> &out_tensors, const PrimitiveC *primitive,
const InnerContext *ctx, const kernel::KernelKey &key) { const InnerContext *ctx, const kernel::KernelKey &key) {
@@ -124,5 +122,5 @@ kernel::LiteKernel *KernelRegistry::GetKernel(const std::vector<Tensor *> &in_te
return nullptr; return nullptr;
} }


KernelRegistry::~KernelRegistry() {}
KernelRegistry::~KernelRegistry() = default;
} // namespace mindspore::lite } // namespace mindspore::lite

+ 1
- 1
mindspore/lite/src/kernel_registry.h View File

@@ -35,7 +35,7 @@ class KernelRegistry {
virtual ~KernelRegistry(); virtual ~KernelRegistry();


static KernelRegistry *GetInstance(); static KernelRegistry *GetInstance();
int Init();
static int Init();
virtual kernel::KernelCreator GetCreator(const kernel::KernelKey &desc); virtual kernel::KernelCreator GetCreator(const kernel::KernelKey &desc);
const kernel::KernelCreator *GetCreatorArrays(); const kernel::KernelCreator *GetCreatorArrays();
int GetCreatorFuncIndex(kernel::KernelKey desc); int GetCreatorFuncIndex(kernel::KernelKey desc);


+ 7
- 2
mindspore/lite/src/lite_kernel.cc View File

@@ -93,7 +93,12 @@ int LiteKernel::PreProcess() {
auto outputs = this->out_tensors(); auto outputs = this->out_tensors();
for (auto *output : outputs) { for (auto *output : outputs) {
MS_ASSERT(output != nullptr); MS_ASSERT(output != nullptr);
output->MallocData();

auto ret = output->MallocData();
if (ret != RET_OK) {
MS_LOG(ERROR) << "MallocData failed";
return ret;
}
} }
return RET_OK; return RET_OK;
} }
@@ -308,5 +313,5 @@ void LiteKernelUtil::InitTensorRefCount(std::vector<kernel::LiteKernel *> &kerne
} }
} }


int LiteKernelUtil::SetInput(LiteKernel &kernelMod, std::vector<lite::Tensor *> inputs) { return -1; }
int LiteKernelUtil::SetInput(LiteKernel &kernelMod, const std::vector<lite::Tensor *> &inputs) { return -1; }
} // namespace mindspore::kernel } // namespace mindspore::kernel

+ 1
- 1
mindspore/lite/src/lite_kernel.h View File

@@ -217,7 +217,7 @@ class LiteKernelUtil {


static void InitTensorRefCount(std::vector<kernel::LiteKernel *> &kernels); static void InitTensorRefCount(std::vector<kernel::LiteKernel *> &kernels);


static int SetInput(LiteKernel &kernelMod, std::vector<lite::Tensor *> inputs);
static int SetInput(LiteKernel &kernelMod, const std::vector<lite::Tensor *> &inputs);
}; };
} // namespace mindspore::kernel } // namespace mindspore::kernel




+ 5
- 1
mindspore/lite/src/lite_session.cc View File

@@ -571,7 +571,11 @@ int LiteSession::Resize(const std::vector<mindspore::tensor::MSTensor *> &inputs
} // namespace lite } // namespace lite


session::LiteSession *session::LiteSession::CreateSession(const lite::Context *context) { session::LiteSession *session::LiteSession::CreateSession(const lite::Context *context) {
auto session = new lite::LiteSession();
auto session = new (std::nothrow) lite::LiteSession();
if (session == nullptr) {
MS_LOG(ERROR) << "create sesssion failed";
return nullptr;
}
auto ret = session->Init(context); auto ret = session->Init(context);
if (ret != mindspore::lite::RET_OK) { if (ret != mindspore::lite::RET_OK) {
MS_LOG(ERROR) << "init sesssion failed"; MS_LOG(ERROR) << "init sesssion failed";


+ 1
- 1
mindspore/lite/src/lite_session.h View File

@@ -66,7 +66,7 @@ class LiteSession : public session::LiteSession {
const std::vector<std::vector<int>> &dims) override; const std::vector<std::vector<int>> &dims) override;


protected: protected:
void ConvertTensorsQuantParam(const schema::Tensor *src_tensor, lite::Tensor *dst_tensor);
static void ConvertTensorsQuantParam(const schema::Tensor *src_tensor, lite::Tensor *dst_tensor);


int ConvertTensorsData(const lite::Model *model, size_t tensor_index, const schema::Tensor *src_tensor, int ConvertTensorsData(const lite::Model *model, size_t tensor_index, const schema::Tensor *src_tensor,
lite::Tensor *dst_tensor); lite::Tensor *dst_tensor);


+ 0
- 2
mindspore/lite/src/ops/pooling_grad.cc View File

@@ -198,11 +198,9 @@ int PoolingGrad::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *>
} }
} }
auto grad_output = outputs_.at(0); auto grad_output = outputs_.at(0);
// todo: fmk type
auto output_shape = input->shape(); auto output_shape = input->shape();
grad_output->set_shape(output_shape); grad_output->set_shape(output_shape);
grad_output->set_data_type(input->data_type()); grad_output->set_data_type(input->data_type());
// todo: temp fix
grad_output->set_format(input->format()); grad_output->set_format(input->format());
return RET_OK; return RET_OK;
} }


+ 165
- 165
mindspore/lite/src/ops/primitive_c.cc View File

@@ -266,8 +266,8 @@ void PrimitiveC::PopulaterInputQuantParam(const Primitive &prim, const std::vect
if (filterMin != nullptr && filterMax != nullptr) { if (filterMin != nullptr && filterMax != nullptr) {
auto filterMinPtr = filterMin->cast<TensorPtr>(); auto filterMinPtr = filterMin->cast<TensorPtr>();
auto filterMaxPtr = filterMax->cast<TensorPtr>(); auto filterMaxPtr = filterMax->cast<TensorPtr>();
float *minBuf = static_cast<float *>(filterMinPtr->data_c());
float *maxBuf = static_cast<float *>(filterMaxPtr->data_c());
auto *minBuf = static_cast<float *>(filterMinPtr->data_c());
auto *maxBuf = static_cast<float *>(filterMaxPtr->data_c());
quantParam.min = FLT_MAX; quantParam.min = FLT_MAX;
quantParam.max = FLT_MIN; quantParam.max = FLT_MIN;
for (int i = 0; i < filterMinPtr->ElementsNum(); ++i) { for (int i = 0; i < filterMinPtr->ElementsNum(); ++i) {
@@ -296,8 +296,8 @@ void PrimitiveC::PopulaterOutputQuantParam(const Primitive &prim, bool narrowRan
if (outputMin != nullptr && outputMax != nullptr) { if (outputMin != nullptr && outputMax != nullptr) {
auto outputMinPtr = outputMin->cast<TensorPtr>(); auto outputMinPtr = outputMin->cast<TensorPtr>();
auto outputMaxPtr = outputMax->cast<TensorPtr>(); auto outputMaxPtr = outputMax->cast<TensorPtr>();
float *minBuf = static_cast<float *>(outputMinPtr->data_c());
float *maxBuf = static_cast<float *>(outputMaxPtr->data_c());
auto *minBuf = static_cast<float *>(outputMinPtr->data_c());
auto *maxBuf = static_cast<float *>(outputMaxPtr->data_c());
quantParam.min = *minBuf; quantParam.min = *minBuf;
quantParam.max = *maxBuf; quantParam.max = *maxBuf;
auto ret = quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, auto ret = quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
@@ -317,14 +317,14 @@ void PrimitiveC::PopulaterOutputQuantParam(const Primitive &prim, bool narrowRan


void PrimitiveC::PopulaterQuantParam(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { void PrimitiveC::PopulaterQuantParam(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
auto narrow_range = prim.GetAttr("narrow_range"); auto narrow_range = prim.GetAttr("narrow_range");
bool narrowRangeQuantParam = narrow_range != nullptr ? GetValue<bool>(narrow_range) : false;
bool narrowRangeQuantParam = narrow_range != nullptr && GetValue<bool>(narrow_range);
auto num_bits = prim.GetAttr("num_bits"); auto num_bits = prim.GetAttr("num_bits");
int32_t numbitsRangeQuantParam = num_bits != nullptr ? GetValue<int64_t>(num_bits) : 8; int32_t numbitsRangeQuantParam = num_bits != nullptr ? GetValue<int64_t>(num_bits) : 8;
PopulaterInputQuantParam(prim, inputs, narrowRangeQuantParam, numbitsRangeQuantParam); PopulaterInputQuantParam(prim, inputs, narrowRangeQuantParam, numbitsRangeQuantParam);
PopulaterOutputQuantParam(prim, narrowRangeQuantParam, numbitsRangeQuantParam); PopulaterOutputQuantParam(prim, narrowRangeQuantParam, numbitsRangeQuantParam);
} }


void PrimitiveC::GetAttrDataFromInput(const AnfNodePtr inputNode, std::vector<int> *data) {
void PrimitiveC::GetAttrDataFromInput(const AnfNodePtr &inputNode, std::vector<int> *data) {
if (inputNode->isa<ValueNode>()) { if (inputNode->isa<ValueNode>()) {
auto valNode = inputNode->cast<ValueNodePtr>(); auto valNode = inputNode->cast<ValueNodePtr>();
MS_ASSERT(valNode != nullptr); MS_ASSERT(valNode != nullptr);
@@ -394,12 +394,12 @@ void PrimitiveC::ClearInputOutputQuantParam() {
output_quant_param_.clear(); output_quant_param_.clear();
} }


void PrimitiveC::AddInputQuantParam(std::vector<schema::QuantParamT> quant_param) {
void PrimitiveC::AddInputQuantParam(const std::vector<schema::QuantParamT> &quant_param) {
this->input_quant_param_.emplace_back(quant_param); this->input_quant_param_.emplace_back(quant_param);
} }
std::vector<std::vector<schema::QuantParamT>> PrimitiveC::input_quant_params() const { return input_quant_param_; } std::vector<std::vector<schema::QuantParamT>> PrimitiveC::input_quant_params() const { return input_quant_param_; }


void PrimitiveC::AddOutputQuantParam(std::vector<schema::QuantParamT> quant_param) {
void PrimitiveC::AddOutputQuantParam(const std::vector<schema::QuantParamT> &quant_param) {
this->output_quant_param_.emplace_back(quant_param); this->output_quant_param_.emplace_back(quant_param);
} }
std::vector<std::vector<schema::QuantParamT>> PrimitiveC::output_quant_params() const { return output_quant_param_; } std::vector<std::vector<schema::QuantParamT>> PrimitiveC::output_quant_params() const { return output_quant_param_; }
@@ -415,7 +415,7 @@ std::shared_ptr<PrimitiveC> GetReturnPrim() {
return nullptr; return nullptr;
} }
return_primitiveT->value.type = schema::PrimitiveType_Return; return_primitiveT->value.type = schema::PrimitiveType_Return;
return_primitiveT->value.value = new schema::ReturnT;
return_primitiveT->value.value = new (std::nothrow) schema::ReturnT;
if (return_primitiveT->value.value == nullptr) { if (return_primitiveT->value.value == nullptr) {
MS_LOG(ERROR) << "new ReturnT failed"; MS_LOG(ERROR) << "new ReturnT failed";
delete (return_primitiveT); delete (return_primitiveT);
@@ -425,13 +425,13 @@ std::shared_ptr<PrimitiveC> GetReturnPrim() {
} }


std::shared_ptr<PrimitiveC> GetMakeTuplePrim() { std::shared_ptr<PrimitiveC> GetMakeTuplePrim() {
auto make_tuple_primitiveT = new schema::PrimitiveT;
auto make_tuple_primitiveT = new (std::nothrow) schema::PrimitiveT;
if (make_tuple_primitiveT == nullptr) { if (make_tuple_primitiveT == nullptr) {
MS_LOG(ERROR) << "new PrimitiveT failed"; MS_LOG(ERROR) << "new PrimitiveT failed";
return nullptr; return nullptr;
} }
make_tuple_primitiveT->value.type = schema::PrimitiveType_MakeTuple; make_tuple_primitiveT->value.type = schema::PrimitiveType_MakeTuple;
make_tuple_primitiveT->value.value = new schema::MakeTupleT;
make_tuple_primitiveT->value.value = new (std::nothrow) schema::MakeTupleT;
if (make_tuple_primitiveT->value.value == nullptr) { if (make_tuple_primitiveT->value.value == nullptr) {
MS_LOG(ERROR) << "new MakeTupleT failed"; MS_LOG(ERROR) << "new MakeTupleT failed";
delete (make_tuple_primitiveT); delete (make_tuple_primitiveT);
@@ -441,13 +441,13 @@ std::shared_ptr<PrimitiveC> GetMakeTuplePrim() {
} }


std::shared_ptr<PrimitiveC> GetTupleGetItemPrim() { std::shared_ptr<PrimitiveC> GetTupleGetItemPrim() {
auto tuple_get_item_primitiveT = new schema::PrimitiveT();
auto tuple_get_item_primitiveT = new (std::nothrow) schema::PrimitiveT();
if (tuple_get_item_primitiveT == nullptr) { if (tuple_get_item_primitiveT == nullptr) {
MS_LOG(ERROR) << "new PrimitiveT failed"; MS_LOG(ERROR) << "new PrimitiveT failed";
return nullptr; return nullptr;
} }
tuple_get_item_primitiveT->value.type = schema::PrimitiveType_TupleGetItem; tuple_get_item_primitiveT->value.type = schema::PrimitiveType_TupleGetItem;
tuple_get_item_primitiveT->value.value = new schema::TupleGetItemT;
tuple_get_item_primitiveT->value.value = new (std::nothrow) schema::TupleGetItemT;
if (tuple_get_item_primitiveT->value.value == nullptr) { if (tuple_get_item_primitiveT->value.value == nullptr) {
MS_LOG(ERROR) << "new TupleGetItemT failed"; MS_LOG(ERROR) << "new TupleGetItemT failed";
delete (tuple_get_item_primitiveT); delete (tuple_get_item_primitiveT);
@@ -642,316 +642,316 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
auto op_type = primitive->value.type; auto op_type = primitive->value.type;
switch (op_type) { switch (op_type) {
case schema::PrimitiveType_SoftMax: case schema::PrimitiveType_SoftMax:
return new SoftMax(primitive);
return new (std::nothrow) SoftMax(primitive);
case schema::PrimitiveType_Activation: case schema::PrimitiveType_Activation:
return new Activation(primitive);
return new (std::nothrow) Activation(primitive);
case schema::PrimitiveType_Conv2D: case schema::PrimitiveType_Conv2D:
return new Conv2D(primitive);
return new (std::nothrow) Conv2D(primitive);
case schema::PrimitiveType_DeConv2D: case schema::PrimitiveType_DeConv2D:
return new DeConv2D(primitive);
return new (std::nothrow) DeConv2D(primitive);
case schema::PrimitiveType_Reduce: case schema::PrimitiveType_Reduce:
return new Reduce(primitive);
return new (std::nothrow) Reduce(primitive);
case schema::PrimitiveType_Pooling: case schema::PrimitiveType_Pooling:
return new Pooling(primitive);
return new (std::nothrow) Pooling(primitive);
case schema::PrimitiveType_ROIPooling: case schema::PrimitiveType_ROIPooling:
return new ROIPooling(primitive);
return new (std::nothrow) ROIPooling(primitive);
case schema::PrimitiveType_DepthwiseConv2D: case schema::PrimitiveType_DepthwiseConv2D:
return new DepthwiseConv2D(primitive);
return new (std::nothrow) DepthwiseConv2D(primitive);
case schema::PrimitiveType_FusedBatchNorm: case schema::PrimitiveType_FusedBatchNorm:
return new FusedBatchNorm(primitive);
return new (std::nothrow) FusedBatchNorm(primitive);
case schema::PrimitiveType_BatchNorm: case schema::PrimitiveType_BatchNorm:
return new BatchNorm(primitive);
return new (std::nothrow) BatchNorm(primitive);
case schema::PrimitiveType_FullConnection: case schema::PrimitiveType_FullConnection:
return new FullConnection(primitive);
return new (std::nothrow) FullConnection(primitive);
case schema::PrimitiveType_Power: case schema::PrimitiveType_Power:
return new Power(primitive);
return new (std::nothrow) Power(primitive);
case schema::PrimitiveType_Pad: case schema::PrimitiveType_Pad:
return new Pad(primitive);
return new (std::nothrow) Pad(primitive);
case schema::PrimitiveType_Range: case schema::PrimitiveType_Range:
return new Range(primitive);
return new (std::nothrow) Range(primitive);
case schema::PrimitiveType_Mul: case schema::PrimitiveType_Mul:
return new Mul(primitive);
return new (std::nothrow) Mul(primitive);
case schema::PrimitiveType_Add: case schema::PrimitiveType_Add:
return new Add(primitive);
return new (std::nothrow) Add(primitive);
case schema::PrimitiveType_Sub: case schema::PrimitiveType_Sub:
return new Sub(primitive);
return new (std::nothrow) Sub(primitive);
case schema::PrimitiveType_Div: case schema::PrimitiveType_Div:
return new Div(primitive);
return new (std::nothrow) Div(primitive);
case schema::PrimitiveType_BiasAdd: case schema::PrimitiveType_BiasAdd:
return new BiasAdd(primitive);
return new (std::nothrow) BiasAdd(primitive);
case schema::PrimitiveType_ExpandDims: case schema::PrimitiveType_ExpandDims:
return new ExpandDims(primitive);
return new (std::nothrow) ExpandDims(primitive);
case schema::PrimitiveType_ArgMax: case schema::PrimitiveType_ArgMax:
return new ArgMax(primitive);
return new (std::nothrow) ArgMax(primitive);
case schema::PrimitiveType_ArgMin: case schema::PrimitiveType_ArgMin:
return new ArgMin(primitive);
return new (std::nothrow) ArgMin(primitive);
case schema::PrimitiveType_Cast: case schema::PrimitiveType_Cast:
return new Cast(primitive);
return new (std::nothrow) Cast(primitive);
case schema::PrimitiveType_Reshape: case schema::PrimitiveType_Reshape:
return new Reshape(primitive);
return new (std::nothrow) Reshape(primitive);
case schema::PrimitiveType_Scale: case schema::PrimitiveType_Scale:
return new Scale(primitive);
return new (std::nothrow) Scale(primitive);
case schema::PrimitiveType_Eltwise: case schema::PrimitiveType_Eltwise:
return new Eltwise(primitive);
return new (std::nothrow) Eltwise(primitive);
case schema::PrimitiveType_Ceil: case schema::PrimitiveType_Ceil:
return new Ceil(primitive);
return new (std::nothrow) Ceil(primitive);
case schema::PrimitiveType_Concat: case schema::PrimitiveType_Concat:
return new Concat(primitive);
return new (std::nothrow) Concat(primitive);
case schema::PrimitiveType_Fill: case schema::PrimitiveType_Fill:
return new Fill(primitive);
return new (std::nothrow) Fill(primitive);
case schema::PrimitiveType_Nhwc2Nchw: case schema::PrimitiveType_Nhwc2Nchw:
return new Nhwc2Nchw(primitive);
return new (std::nothrow) Nhwc2Nchw(primitive);
case schema::PrimitiveType_Nchw2Nhwc: case schema::PrimitiveType_Nchw2Nhwc:
return new Nchw2Nhwc(primitive);
return new (std::nothrow) Nchw2Nhwc(primitive);
case schema::PrimitiveType_Transpose: case schema::PrimitiveType_Transpose:
return new Transpose(primitive);
return new (std::nothrow) Transpose(primitive);
case schema::PrimitiveType_Slice: case schema::PrimitiveType_Slice:
return new Slice(primitive);
return new (std::nothrow) Slice(primitive);
case schema::PrimitiveType_Squeeze: case schema::PrimitiveType_Squeeze:
return new Squeeze(primitive);
return new (std::nothrow) Squeeze(primitive);
case schema::PrimitiveType_Flatten: case schema::PrimitiveType_Flatten:
return new Flatten(primitive);
return new (std::nothrow) Flatten(primitive);
case schema::PrimitiveType_Mean: case schema::PrimitiveType_Mean:
return new Mean(primitive);
return new (std::nothrow) Mean(primitive);
case schema::PrimitiveType_Stack: case schema::PrimitiveType_Stack:
return new Stack(primitive);
return new (std::nothrow) Stack(primitive);
case schema::PrimitiveType_Crop: case schema::PrimitiveType_Crop:
return new Crop(primitive);
return new (std::nothrow) Crop(primitive);
case schema::PrimitiveType_SquaredDifference: case schema::PrimitiveType_SquaredDifference:
return new SquaredDifference(primitive);
return new (std::nothrow) SquaredDifference(primitive);
case schema::PrimitiveType_AddN: case schema::PrimitiveType_AddN:
return new AddN(primitive);
return new (std::nothrow) AddN(primitive);
case schema::PrimitiveType_Abs: case schema::PrimitiveType_Abs:
return new Abs(primitive);
return new (std::nothrow) Abs(primitive);
case schema::PrimitiveType_Sin: case schema::PrimitiveType_Sin:
return new Sin(primitive);
return new (std::nothrow) Sin(primitive);
case schema::PrimitiveType_Cos: case schema::PrimitiveType_Cos:
return new Cos(primitive);
return new (std::nothrow) Cos(primitive);
case schema::PrimitiveType_Log: case schema::PrimitiveType_Log:
return new Log(primitive);
return new (std::nothrow) Log(primitive);
case schema::PrimitiveType_Sqrt: case schema::PrimitiveType_Sqrt:
return new Sqrt(primitive);
return new (std::nothrow) Sqrt(primitive);
case schema::PrimitiveType_Rsqrt: case schema::PrimitiveType_Rsqrt:
return new Rsqrt(primitive);
return new (std::nothrow) Rsqrt(primitive);
case schema::PrimitiveType_Square: case schema::PrimitiveType_Square:
return new Square(primitive);
return new (std::nothrow) Square(primitive);
case schema::PrimitiveType_Exp: case schema::PrimitiveType_Exp:
return new Exp(primitive);
return new (std::nothrow) Exp(primitive);
case schema::PrimitiveType_Gather: case schema::PrimitiveType_Gather:
return new Gather(primitive);
return new (std::nothrow) Gather(primitive);
case schema::PrimitiveType_GatherNd: case schema::PrimitiveType_GatherNd:
return new GatherNd(primitive);
return new (std::nothrow) GatherNd(primitive);
case schema::PrimitiveType_LocalResponseNormalization: case schema::PrimitiveType_LocalResponseNormalization:
return new LocalResponseNormalization(primitive);
return new (std::nothrow) LocalResponseNormalization(primitive);
case schema::PrimitiveType_Maximum: case schema::PrimitiveType_Maximum:
return new Maximum(primitive);
return new (std::nothrow) Maximum(primitive);
case schema::PrimitiveType_Minimum: case schema::PrimitiveType_Minimum:
return new Minimum(primitive);
return new (std::nothrow) Minimum(primitive);
case schema::PrimitiveType_StridedSlice: case schema::PrimitiveType_StridedSlice:
return new StridedSlice(primitive);
return new (std::nothrow) StridedSlice(primitive);
case schema::PrimitiveType_LeakyReLU: case schema::PrimitiveType_LeakyReLU:
return new (std::nothrow) LeakyReLU(primitive); return new (std::nothrow) LeakyReLU(primitive);
case schema::PrimitiveType_PReLU: case schema::PrimitiveType_PReLU:
return new (std::nothrow) PReLU(primitive); return new (std::nothrow) PReLU(primitive);
case schema::PrimitiveType_Round: case schema::PrimitiveType_Round:
return new Round(primitive);
return new (std::nothrow) Round(primitive);
case schema::PrimitiveType_Reverse: case schema::PrimitiveType_Reverse:
return new Reverse(primitive);
return new (std::nothrow) Reverse(primitive);
case schema::PrimitiveType_ReverseSequence: case schema::PrimitiveType_ReverseSequence:
return new ReverseSequence(primitive);
return new (std::nothrow) ReverseSequence(primitive);
case schema::PrimitiveType_LogicalAnd: case schema::PrimitiveType_LogicalAnd:
return new LogicalAnd(primitive);
return new (std::nothrow) LogicalAnd(primitive);
case schema::PrimitiveType_LogicalOr: case schema::PrimitiveType_LogicalOr:
return new LogicalOr(primitive);
return new (std::nothrow) LogicalOr(primitive);
case schema::PrimitiveType_LogicalNot: case schema::PrimitiveType_LogicalNot:
return new LogicalNot(primitive);
return new (std::nothrow) LogicalNot(primitive);
case schema::PrimitiveType_FloorDiv: case schema::PrimitiveType_FloorDiv:
return new FloorDiv(primitive);
return new (std::nothrow) FloorDiv(primitive);
case schema::PrimitiveType_FloorMod: case schema::PrimitiveType_FloorMod:
return new FloorMod(primitive);
return new (std::nothrow) FloorMod(primitive);
case schema::PrimitiveType_Equal: case schema::PrimitiveType_Equal:
return new Equal(primitive);
return new (std::nothrow) Equal(primitive);
case schema::PrimitiveType_NotEqual: case schema::PrimitiveType_NotEqual:
return new NotEqual(primitive);
return new (std::nothrow) NotEqual(primitive);
case schema::PrimitiveType_Less: case schema::PrimitiveType_Less:
return new Less(primitive);
return new (std::nothrow) Less(primitive);
case schema::PrimitiveType_LessEqual: case schema::PrimitiveType_LessEqual:
return new LessEqual(primitive);
return new (std::nothrow) LessEqual(primitive);
case schema::PrimitiveType_Greater: case schema::PrimitiveType_Greater:
return new Greater(primitive);
return new (std::nothrow) Greater(primitive);
case schema::PrimitiveType_GreaterEqual: case schema::PrimitiveType_GreaterEqual:
return new GreaterEqual(primitive);
return new (std::nothrow) GreaterEqual(primitive);
case schema::PrimitiveType_Floor: case schema::PrimitiveType_Floor:
return new Floor(primitive);
return new (std::nothrow) Floor(primitive);
case schema::PrimitiveType_Split: case schema::PrimitiveType_Split:
return new Split(primitive);
return new (std::nothrow) Split(primitive);
case schema::PrimitiveType_OneHot: case schema::PrimitiveType_OneHot:
return new OneHot(primitive);
return new (std::nothrow) OneHot(primitive);
case schema::PrimitiveType_PriorBox: case schema::PrimitiveType_PriorBox:
return new PriorBox(primitive);
return new (std::nothrow) PriorBox(primitive);
case schema::PrimitiveType_SpaceToDepth: case schema::PrimitiveType_SpaceToDepth:
return new SpaceToDepth(primitive);
return new (std::nothrow) SpaceToDepth(primitive);
case schema::PrimitiveType_Tile: case schema::PrimitiveType_Tile:
return new Tile(primitive);
return new (std::nothrow) Tile(primitive);
case schema::PrimitiveType_Resize: case schema::PrimitiveType_Resize:
return new Resize(primitive);
return new (std::nothrow) Resize(primitive);
case schema::PrimitiveType_Unstack: case schema::PrimitiveType_Unstack:
return new Unstack(primitive);
return new (std::nothrow) Unstack(primitive);
case schema::PrimitiveType_Unique: case schema::PrimitiveType_Unique:
return new Unique(primitive);
return new (std::nothrow) Unique(primitive);
case schema::PrimitiveType_TopK: case schema::PrimitiveType_TopK:
return new TopK(primitive);
return new (std::nothrow) TopK(primitive);
case schema::PrimitiveType_MatMul: case schema::PrimitiveType_MatMul:
return new MatMul(primitive);
return new (std::nothrow) MatMul(primitive);
case schema::PrimitiveType_QuantDTypeCast: case schema::PrimitiveType_QuantDTypeCast:
return new QuantDTypeCast(primitive);
return new (std::nothrow) QuantDTypeCast(primitive);
case schema::PrimitiveType_EmbeddingLookup: case schema::PrimitiveType_EmbeddingLookup:
return new EmbeddingLookup(primitive);
return new (std::nothrow) EmbeddingLookup(primitive);
case schema::PrimitiveType_Elu: case schema::PrimitiveType_Elu:
return new Elu(primitive);
return new (std::nothrow) Elu(primitive);
case schema::PrimitiveType_DeDepthwiseConv2D: case schema::PrimitiveType_DeDepthwiseConv2D:
return new DeDepthwiseConv2D(primitive);
return new (std::nothrow) DeDepthwiseConv2D(primitive);
case schema::PrimitiveType_Shape: case schema::PrimitiveType_Shape:
return new Shape(primitive);
return new (std::nothrow) Shape(primitive);
case schema::PrimitiveType_Unsqueeze: case schema::PrimitiveType_Unsqueeze:
return new Unsqueeze(primitive);
return new (std::nothrow) Unsqueeze(primitive);
case schema::PrimitiveType_BatchToSpace: case schema::PrimitiveType_BatchToSpace:
case schema::PrimitiveType_BatchToSpaceND: case schema::PrimitiveType_BatchToSpaceND:
return new BatchToSpace(primitive);
return new (std::nothrow) BatchToSpace(primitive);
case schema::PrimitiveType_SpaceToBatch: case schema::PrimitiveType_SpaceToBatch:
return new SpaceToBatch(primitive);
return new (std::nothrow) SpaceToBatch(primitive);
case schema::PrimitiveType_SpaceToBatchND: case schema::PrimitiveType_SpaceToBatchND:
return new SpaceToBatchND(primitive);
return new (std::nothrow) SpaceToBatchND(primitive);
case schema::PrimitiveType_BroadcastTo: case schema::PrimitiveType_BroadcastTo:
return new BroadcastTo(primitive);
return new (std::nothrow) BroadcastTo(primitive);
case schema::PrimitiveType_DepthToSpace: case schema::PrimitiveType_DepthToSpace:
return new DepthToSpace(primitive);
return new (std::nothrow) DepthToSpace(primitive);
case schema::PrimitiveType_Lstm: case schema::PrimitiveType_Lstm:
return new Lstm(primitive);
return new (std::nothrow) Lstm(primitive);
case schema::PrimitiveType_ZerosLike: case schema::PrimitiveType_ZerosLike:
return new ZerosLike(primitive);
return new (std::nothrow) ZerosLike(primitive);
case schema::PrimitiveType_MakeTuple: case schema::PrimitiveType_MakeTuple:
return new MakeTuple(primitive);
return new (std::nothrow) MakeTuple(primitive);
case schema::PrimitiveType_Where: case schema::PrimitiveType_Where:
return new Where(primitive);
return new (std::nothrow) Where(primitive);
case schema::PrimitiveType_ScatterND: case schema::PrimitiveType_ScatterND:
return new ScatterND(primitive);
return new (std::nothrow) ScatterND(primitive);
case schema::PrimitiveType_ConstantOfShape: case schema::PrimitiveType_ConstantOfShape:
return new ConstantOfShape(primitive);
return new (std::nothrow) ConstantOfShape(primitive);
case schema::PrimitiveType_L2Norm: case schema::PrimitiveType_L2Norm:
return new L2Norm(primitive);
return new (std::nothrow) L2Norm(primitive);
case schema::PrimitiveType_SparseToDense: case schema::PrimitiveType_SparseToDense:
return new SparseToDense(primitive);
return new (std::nothrow) SparseToDense(primitive);
case schema::PrimitiveType_DetectionPostProcess: case schema::PrimitiveType_DetectionPostProcess:
return new DetectionPostProcess(primitive);
return new (std::nothrow) DetectionPostProcess(primitive);
case schema::PrimitiveType_Dropout: case schema::PrimitiveType_Dropout:
return new Dropout(primitive);
return new (std::nothrow) Dropout(primitive);
case schema::PrimitiveType_Neg: case schema::PrimitiveType_Neg:
return new Neg(primitive);
return new (std::nothrow) Neg(primitive);
case schema::PrimitiveType_RealDiv: case schema::PrimitiveType_RealDiv:
return new RealDiv(primitive);
return new (std::nothrow) RealDiv(primitive);
case schema::PrimitiveType_LshProjection: case schema::PrimitiveType_LshProjection:
return new LshProjection(primitive);
return new (std::nothrow) LshProjection(primitive);
case schema::PrimitiveType_HashtableLookup: case schema::PrimitiveType_HashtableLookup:
return new HashtableLookup(primitive);
return new (std::nothrow) HashtableLookup(primitive);
case schema::PrimitiveType_SkipGram: case schema::PrimitiveType_SkipGram:
return new SkipGram(primitive);
return new (std::nothrow) SkipGram(primitive);
case schema::PrimitiveType_Clip: case schema::PrimitiveType_Clip:
return new Clip(primitive);
return new (std::nothrow) Clip(primitive);
case schema::PrimitiveType_CustomPredict: case schema::PrimitiveType_CustomPredict:
return new CustomPredict(primitive);
return new (std::nothrow) CustomPredict(primitive);
case schema::PrimitiveType_CustomNormalize: case schema::PrimitiveType_CustomNormalize:
return new CustomNormalize(primitive);
return new (std::nothrow) CustomNormalize(primitive);
case schema::PrimitiveType_CustomExtractFeatures: case schema::PrimitiveType_CustomExtractFeatures:
return new CustomExtractFeatures(primitive);
return new (std::nothrow) CustomExtractFeatures(primitive);
case schema::PrimitiveType_Upsample: case schema::PrimitiveType_Upsample:
return new Upsample(primitive);
return new (std::nothrow) Upsample(primitive);
case schema::PrimitiveType_LayerNorm: case schema::PrimitiveType_LayerNorm:
return new LayerNorm(primitive);
return new (std::nothrow) LayerNorm(primitive);
case schema::PrimitiveType_NonMaxSuppression: case schema::PrimitiveType_NonMaxSuppression:
return new NonMaxSuppression(primitive);
return new (std::nothrow) NonMaxSuppression(primitive);
case schema::PrimitiveType_Identity: case schema::PrimitiveType_Identity:
return new Identity(primitive);
return new (std::nothrow) Identity(primitive);
case schema::PrimitiveType_Rfft: case schema::PrimitiveType_Rfft:
return new Rfft(primitive);
return new (std::nothrow) Rfft(primitive);
case schema::PrimitiveType_FftReal: case schema::PrimitiveType_FftReal:
return new FftReal(primitive);
return new (std::nothrow) FftReal(primitive);
case schema::PrimitiveType_FftImag: case schema::PrimitiveType_FftImag:
return new FftImag(primitive);
return new (std::nothrow) FftImag(primitive);
case schema::PrimitiveType_AudioSpectrogram: case schema::PrimitiveType_AudioSpectrogram:
return new AudioSpectrogram(primitive);
return new (std::nothrow) AudioSpectrogram(primitive);
case schema::PrimitiveType_Mfcc: case schema::PrimitiveType_Mfcc:
return new Mfcc(primitive);
return new (std::nothrow) Mfcc(primitive);
case schema::PrimitiveType_InstanceNorm: case schema::PrimitiveType_InstanceNorm:
return new InstanceNorm(primitive);
return new (std::nothrow) InstanceNorm(primitive);
case schema::PrimitiveType_While: case schema::PrimitiveType_While:
return new While(primitive);
return new (std::nothrow) While(primitive);
case schema::PrimitiveType_OnnxInt8Quantize: case schema::PrimitiveType_OnnxInt8Quantize:
return new Quant(primitive);
return new (std::nothrow) Quant(primitive);
case schema::PrimitiveType_OnnxInt8Dequantize: case schema::PrimitiveType_OnnxInt8Dequantize:
return new Dequant(primitive);
return new (std::nothrow) Dequant(primitive);


#ifdef SUPPORT_TRAIN #ifdef SUPPORT_TRAIN
case schema::PrimitiveType_ActivationGrad: case schema::PrimitiveType_ActivationGrad:
return new ActivationGrad(primitive);
return new (std::nothrow) ActivationGrad(primitive);
case schema::PrimitiveType_PoolingGrad: case schema::PrimitiveType_PoolingGrad:
return new PoolingGrad(primitive);
return new (std::nothrow) PoolingGrad(primitive);
case schema::PrimitiveType_Conv2DGradFilter: case schema::PrimitiveType_Conv2DGradFilter:
return new Conv2DGradFilter(primitive);
return new (std::nothrow) Conv2DGradFilter(primitive);
case schema::PrimitiveType_Conv2DGradInput: case schema::PrimitiveType_Conv2DGradInput:
return new Conv2DGradInput(primitive);
return new (std::nothrow) Conv2DGradInput(primitive);
case schema::PrimitiveType_GroupConv2DGradInput: case schema::PrimitiveType_GroupConv2DGradInput:
return new GroupConv2DGradInput(primitive);
return new (std::nothrow) GroupConv2DGradInput(primitive);
case schema::PrimitiveType_BiasGrad: case schema::PrimitiveType_BiasGrad:
return new BiasGrad(primitive);
return new (std::nothrow) BiasGrad(primitive);
case schema::PrimitiveType_ApplyMomentum: case schema::PrimitiveType_ApplyMomentum:
return new ApplyMomentum(primitive);
return new (std::nothrow) ApplyMomentum(primitive);
case schema::PrimitiveType_BNGrad: case schema::PrimitiveType_BNGrad:
return new BNGrad(primitive);
return new (std::nothrow) BNGrad(primitive);
case schema::PrimitiveType_AddGrad: case schema::PrimitiveType_AddGrad:
return new ArithmeticGrad(primitive);
return new (std::nothrow) ArithmeticGrad(primitive);
case schema::PrimitiveType_SubGrad: case schema::PrimitiveType_SubGrad:
return new ArithmeticGrad(primitive);
return new (std::nothrow) ArithmeticGrad(primitive);
case schema::PrimitiveType_MulGrad: case schema::PrimitiveType_MulGrad:
return new ArithmeticGrad(primitive);
return new (std::nothrow) ArithmeticGrad(primitive);
case schema::PrimitiveType_DivGrad: case schema::PrimitiveType_DivGrad:
return new ArithmeticGrad(primitive);
return new (std::nothrow) ArithmeticGrad(primitive);
case schema::PrimitiveType_SoftmaxCrossEntropy: case schema::PrimitiveType_SoftmaxCrossEntropy:
return new SoftmaxCrossEntropy(primitive);
return new (std::nothrow) SoftmaxCrossEntropy(primitive);
case schema::PrimitiveType_PowerGrad: case schema::PrimitiveType_PowerGrad:
return new PowerGrad(primitive);
return new (std::nothrow) PowerGrad(primitive);
case schema::PrimitiveType_Depend: case schema::PrimitiveType_Depend:
return new Depend(primitive);
return new (std::nothrow) Depend(primitive);
case schema::PrimitiveType_ControlDepend: case schema::PrimitiveType_ControlDepend:
return new ControlDepend(primitive);
return new (std::nothrow) ControlDepend(primitive);
case schema::PrimitiveType_FlattenGrad: case schema::PrimitiveType_FlattenGrad:
return new FlattenGrad(primitive);
return new (std::nothrow) FlattenGrad(primitive);
case schema::PrimitiveType_NegGrad: case schema::PrimitiveType_NegGrad:
return new NegGrad(primitive);
return new (std::nothrow) NegGrad(primitive);
case schema::PrimitiveType_LogGrad: case schema::PrimitiveType_LogGrad:
return new LogGrad(primitive);
return new (std::nothrow) LogGrad(primitive);
case schema::PrimitiveType_Sgd: case schema::PrimitiveType_Sgd:
return new Sgd(primitive);
return new (std::nothrow) Sgd(primitive);
case schema::PrimitiveType_Adam: case schema::PrimitiveType_Adam:
return new Adam(primitive);
return new (std::nothrow) Adam(primitive);
case schema::PrimitiveType_Assign: case schema::PrimitiveType_Assign:
return new Assign(primitive);
return new (std::nothrow) Assign(primitive);
case schema::PrimitiveType_AssignAdd: case schema::PrimitiveType_AssignAdd:
return new AssignAdd(primitive);
return new (std::nothrow) AssignAdd(primitive);
case schema::PrimitiveType_OnesLike: case schema::PrimitiveType_OnesLike:
return new OnesLike(primitive);
return new (std::nothrow) OnesLike(primitive);
case schema::PrimitiveType_UnsortedSegmentSum: case schema::PrimitiveType_UnsortedSegmentSum:
return new UnsortedSegmentSum(primitive);
return new (std::nothrow) UnsortedSegmentSum(primitive);
case schema::PrimitiveType_BinaryCrossEntropyGrad: case schema::PrimitiveType_BinaryCrossEntropyGrad:
return new BinaryCrossEntropyGrad(primitive);
return new (std::nothrow) BinaryCrossEntropyGrad(primitive);
case schema::PrimitiveType_BinaryCrossEntropy: case schema::PrimitiveType_BinaryCrossEntropy:
return new BinaryCrossEntropy(primitive);
return new (std::nothrow) BinaryCrossEntropy(primitive);
case schema::PrimitiveType_DropoutGrad: case schema::PrimitiveType_DropoutGrad:
return new DropoutGrad(primitive);
return new (std::nothrow) DropoutGrad(primitive);
case schema::PrimitiveType_MaximumGrad: case schema::PrimitiveType_MaximumGrad:
return new MaximumGrad(primitive);
return new (std::nothrow) MaximumGrad(primitive);
case schema::PrimitiveType_MinimumGrad: case schema::PrimitiveType_MinimumGrad:
return new MinimumGrad(primitive);
return new (std::nothrow) MinimumGrad(primitive);
#endif #endif
default: default:
MS_LOG(ERROR) << "Unsupported primitive type in Create : " << schema::EnumNamePrimitiveType(op_type); MS_LOG(ERROR) << "Unsupported primitive type in Create : " << schema::EnumNamePrimitiveType(op_type);


+ 6
- 6
mindspore/lite/src/ops/primitive_c.h View File

@@ -79,9 +79,9 @@ class PrimitiveC : public mindspore::Primitive {


void ClearPrimitiveT(); void ClearPrimitiveT();


bool operator==(const Value &rhs) const {
bool operator==(const Value &rhs) const override {
if (rhs.isa<PrimitiveC>()) { if (rhs.isa<PrimitiveC>()) {
auto other_prim = static_cast<const PrimitiveC &>(rhs);
auto other_prim = dynamic_cast<const PrimitiveC &>(rhs);
auto a = this->primitive_->value.type; auto a = this->primitive_->value.type;
auto b = other_prim.primitive_->value.type; auto b = other_prim.primitive_->value.type;
return a == b; return a == b;
@@ -104,11 +104,11 @@ class PrimitiveC : public mindspore::Primitive {


void ClearInputOutputQuantParam(); void ClearInputOutputQuantParam();


void AddInputQuantParam(std::vector<schema::QuantParamT> quant_param);
void AddInputQuantParam(const std::vector<schema::QuantParamT> &quant_param);


std::vector<std::vector<schema::QuantParamT>> input_quant_params() const; std::vector<std::vector<schema::QuantParamT>> input_quant_params() const;


void AddOutputQuantParam(std::vector<schema::QuantParamT> quant_param);
void AddOutputQuantParam(const std::vector<schema::QuantParamT> &quant_param);


std::vector<std::vector<schema::QuantParamT>> output_quant_params() const; std::vector<std::vector<schema::QuantParamT>> output_quant_params() const;


@@ -126,7 +126,7 @@ class PrimitiveC : public mindspore::Primitive {


static PrimitiveC *Create(mindspore::schema::PrimitiveT *primitive); static PrimitiveC *Create(mindspore::schema::PrimitiveT *primitive);


void GetAttrDataFromInput(const AnfNodePtr inputNode, std::vector<int> *data);
static void GetAttrDataFromInput(const AnfNodePtr &inputNode, std::vector<int> *data);


static std::shared_ptr<PrimitiveC> Create(const Primitive &prim, const std::vector<AnfNodePtr> &inputs, static std::shared_ptr<PrimitiveC> Create(const Primitive &prim, const std::vector<AnfNodePtr> &inputs,
const schema::QuantType &quantType); const schema::QuantType &quantType);
@@ -135,7 +135,7 @@ class PrimitiveC : public mindspore::Primitive {
void PopulaterInputQuantParam(const Primitive &prim, const std::vector<AnfNodePtr> &inputs, void PopulaterInputQuantParam(const Primitive &prim, const std::vector<AnfNodePtr> &inputs,
bool narrowRangeQuantParam, int32_t numbitsRangeQuantParam); bool narrowRangeQuantParam, int32_t numbitsRangeQuantParam);
void PopulaterOutputQuantParam(const Primitive &prim, bool narrowRangeQuantParam, int32_t numbitsRangeQuantParam); void PopulaterOutputQuantParam(const Primitive &prim, bool narrowRangeQuantParam, int32_t numbitsRangeQuantParam);
void CalFloatScopeByMeanAndStddev(const double &mean, const double &stdDev, float *mMin, float *mMax);
static void CalFloatScopeByMeanAndStddev(const double &mean, const double &stdDev, float *mMin, float *mMax);


protected: protected:
virtual int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { return RET_ERROR; } virtual int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { return RET_ERROR; }


+ 0
- 1
mindspore/lite/src/ops/sub.cc View File

@@ -52,7 +52,6 @@ int Sub::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs
this->primitive_ = nullptr; this->primitive_ = nullptr;
return RET_ERROR; return RET_ERROR;
} }
// todo: confirm the activationType
attr->activationType = schema::ActivationType_NO_ACTIVATION; attr->activationType = schema::ActivationType_NO_ACTIVATION;
this->primitive_->value.value = attr; this->primitive_->value.value = attr;
} }


+ 3
- 3
mindspore/lite/src/param_value_lite.h View File

@@ -37,15 +37,15 @@ class ParamValueLite : public Value {
} }
MS_DECLARE_PARENT(ParamValueLite, Value) MS_DECLARE_PARENT(ParamValueLite, Value)
size_t tensor_size() const { return tensor_size_; } size_t tensor_size() const { return tensor_size_; }
void set_tensor_size(size_t size) { tensor_size_ = size; }
void set_tensor_size(const size_t size) { tensor_size_ = size; }
void *tensor_addr() const { return tensor_addr_; } void *tensor_addr() const { return tensor_addr_; }
void set_tensor_addr(void *addr) { tensor_addr_ = addr; } void set_tensor_addr(void *addr) { tensor_addr_ = addr; }


std::vector<int> tensor_shape() const { return tensor_shape_; } std::vector<int> tensor_shape() const { return tensor_shape_; }
void set_tensor_shape(std::vector<int> tensor_shape) { tensor_shape_ = std::move(tensor_shape); }
void set_tensor_shape(const std::vector<int> &tensor_shape) { tensor_shape_ = tensor_shape; }


TypeId tensor_type() const { return type_id_; } TypeId tensor_type() const { return type_id_; }
void set_tensor_type(TypeId type_id) { type_id_ = type_id; }
void set_tensor_type(const TypeId type_id) { type_id_ = type_id; }


int tensor_shape_size() const { int tensor_shape_size() const {
int size = 1; int size = 1;


+ 9
- 9
mindspore/lite/src/runtime/allocator.cc View File

@@ -23,7 +23,7 @@ std::shared_ptr<Allocator> Allocator::Create() {
return std::shared_ptr<Allocator>(new (std::nothrow) DefaultAllocator()); return std::shared_ptr<Allocator>(new (std::nothrow) DefaultAllocator());
} }


DefaultAllocator::DefaultAllocator() {}
DefaultAllocator::DefaultAllocator() = default;


DefaultAllocator::~DefaultAllocator() { Clear(); } DefaultAllocator::~DefaultAllocator() { Clear(); }


@@ -94,13 +94,13 @@ size_t DefaultAllocator::GetTotalSize() {
Lock(); Lock();
size_t totalSize = 0; size_t totalSize = 0;


for (auto it = allocatedList_.begin(); it != allocatedList_.end(); it++) {
auto membuf = it->second;
for (auto &it : allocatedList_) {
auto membuf = it.second;
totalSize += membuf->size; totalSize += membuf->size;
} }


for (auto it = freeList_.begin(); it != freeList_.end(); it++) {
auto membuf = it->second;
for (auto &it : freeList_) {
auto membuf = it.second;
totalSize += membuf->size; totalSize += membuf->size;
} }
UnLock(); UnLock();
@@ -110,13 +110,13 @@ size_t DefaultAllocator::GetTotalSize() {
void DefaultAllocator::Clear() { void DefaultAllocator::Clear() {
Lock(); Lock();


for (auto it = allocatedList_.begin(); it != allocatedList_.end(); it++) {
free(it->second);
for (auto &it : allocatedList_) {
free(it.second);
} }
allocatedList_.clear(); allocatedList_.clear();


for (auto it = freeList_.begin(); it != freeList_.end(); it++) {
free(it->second);
for (auto &it : freeList_) {
free(it.second);
} }
freeList_.clear(); freeList_.clear();
UnLock(); UnLock();


+ 1
- 1
mindspore/lite/src/runtime/allocator.h View File

@@ -34,7 +34,7 @@ struct AllocatorContext {
class Allocator { class Allocator {
public: public:
Allocator() : name("default") {} Allocator() : name("default") {}
virtual ~Allocator() {}
virtual ~Allocator() = default;
virtual void *Malloc(size_t size) = 0; virtual void *Malloc(size_t size) = 0;
virtual void Free(void *ptr) = 0; virtual void Free(void *ptr) = 0;
virtual void SetContext(const AllocatorContext &ctx) {} virtual void SetContext(const AllocatorContext &ctx) {}


+ 7
- 4
mindspore/lite/src/runtime/parallel_executor.cc View File

@@ -31,7 +31,7 @@ int ParallelExecutor::Prepare(const std::vector<mindspore::kernel::LiteKernel *>
} }


static int RunKernel(void *data, int index) { static int RunKernel(void *data, int index) {
ParallelExecutor *executor = reinterpret_cast<ParallelExecutor *>(data);
auto *executor = reinterpret_cast<ParallelExecutor *>(data);
auto kernel = executor->GetReadyKernel(index); auto kernel = executor->GetReadyKernel(index);
auto ret = kernel->Run(); auto ret = kernel->Run();
executor->SetResult(index, ret); executor->SetResult(index, ret);
@@ -65,16 +65,19 @@ int ParallelExecutor::Run(std::vector<Tensor *> &in_tensors, std::vector<Tensor
kernel::LiteKernelUtil::InitTensorRefCount(kernels); kernel::LiteKernelUtil::InitTensorRefCount(kernels);


for (auto kernel : kernels) { for (auto kernel : kernels) {
if (kernel->in_kernels().size() == 0) {
if (kernel->in_kernels().empty()) {
readyKernels.emplace_back(kernel); readyKernels.emplace_back(kernel);
continue; continue;
} }
refCount[kernel] = kernel->in_kernels().size(); refCount[kernel] = kernel->in_kernels().size();
} }
std::vector<kernel::LiteKernel *> newReadyKernels; std::vector<kernel::LiteKernel *> newReadyKernels;
while (readyKernels.size() > 0) {
while (!readyKernels.empty()) {
results.resize(readyKernels.size(), RET_OK); results.resize(readyKernels.size(), RET_OK);
ParallelLaunch(thread_pool_, RunKernel, this, readyKernels.size());
if (0 != ParallelLaunch(thread_pool_, RunKernel, this, readyKernels.size())) {
MS_LOG(ERROR) << "ParallelLaunch failed ";
return RET_ERROR;
}


if (std::find_if(results.begin(), results.end(), [](const int &ret) { return (ret != 0); }) != results.end()) { if (std::find_if(results.begin(), results.end(), [](const int &ret) { return (ret != 0); }) != results.end()) {
return RET_ERROR; return RET_ERROR;


+ 5
- 5
mindspore/lite/src/runtime/parallel_executor.h View File

@@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */


#ifndef MINDSPORE_LITE_PARALLEL_EXECUTOR_H_
#define MINDSPORE_LITE_PARALLEL_EXECUTOR_H_
#ifndef MINDSPORE_LITE_SRC_RUNTIME_PARALLEL_EXECUTOR_H_
#define MINDSPORE_LITE_SRC_RUNTIME_PARALLEL_EXECUTOR_H_


#include <vector> #include <vector>
#include <unordered_map> #include <unordered_map>
@@ -28,7 +28,7 @@ namespace mindspore::lite {
class ParallelExecutor : public Executor { class ParallelExecutor : public Executor {
public: public:
ParallelExecutor() = default; ParallelExecutor() = default;
virtual ~ParallelExecutor();
~ParallelExecutor() override;


int Prepare(const std::vector<kernel::LiteKernel *> &kernels) override; int Prepare(const std::vector<kernel::LiteKernel *> &kernels) override;


@@ -42,8 +42,8 @@ class ParallelExecutor : public Executor {
std::unordered_map<kernel::LiteKernel *, size_t> refCount; std::unordered_map<kernel::LiteKernel *, size_t> refCount;
std::vector<kernel::LiteKernel *> readyKernels; std::vector<kernel::LiteKernel *> readyKernels;
std::vector<int> results; std::vector<int> results;
struct ThreadPool *thread_pool_ = NULL;
struct ThreadPool *thread_pool_ = nullptr;
}; };


} // namespace mindspore::lite } // namespace mindspore::lite
#endif
#endif // MINDSPORE_LITE_SRC_RUNTIME_PARALLEL_EXECUTOR_H_

+ 0
- 1
mindspore/lite/src/runtime/runtime_api.cc View File

@@ -16,7 +16,6 @@


#include "src/runtime/runtime_api.h" #include "src/runtime/runtime_api.h"
#include <mutex> #include <mutex>
#include <string>
#include "src/common/log_adapter.h" #include "src/common/log_adapter.h"


static std::mutex gWorkspaceMutex; static std::mutex gWorkspaceMutex;


+ 3
- 3
mindspore/lite/src/runtime/runtime_api.h View File

@@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#ifndef PREDICT_SRC_RUNTIME_RUNTIME_API_H_
#define PREDICT_SRC_RUNTIME_RUNTIME_API_H_
#ifndef MINDSPORE_LITE_SRC_RUNTIME_RUNTIME_API_H_
#define MINDSPORE_LITE_SRC_RUNTIME_RUNTIME_API_H_
#include <memory> #include <memory>


#ifndef INTERNAL_API_DLL #ifndef INTERNAL_API_DLL
@@ -40,4 +40,4 @@ INTERNAL_API_DLL int LiteBackendRegisterSystemLibSymbol(const char *name, void *
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif
#endif // PREDICT_SRC_RUNTIME_RUNTIME_API_H_
#endif // MINDSPORE_LITE_SRC_RUNTIME_RUNTIME_API_H_

+ 8
- 12
mindspore/lite/src/scheduler.cc View File

@@ -72,10 +72,6 @@ int Scheduler::ReSizeKernels(const std::vector<kernel::LiteKernel *> &kernels) {
return RET_ERROR; return RET_ERROR;
} }
auto sub_graph = reinterpret_cast<kernel::SubGraphKernel *>(kernel); auto sub_graph = reinterpret_cast<kernel::SubGraphKernel *>(kernel);
if (sub_graph == nullptr) {
MS_LOG(ERROR) << "node " << kernel->name() << " is neither a kernel or a sub_graph";
return RET_ERROR;
}
auto ret = sub_graph->ReSize(infer_shape_interrupt); auto ret = sub_graph->ReSize(infer_shape_interrupt);
if (ret == RET_INFER_INVALID) { if (ret == RET_INFER_INVALID) {
MS_LOG(INFO) << "InferShape is interrupted"; MS_LOG(INFO) << "InferShape is interrupted";
@@ -133,7 +129,7 @@ int Scheduler::InferShape(const lite::Model *model, std::vector<Tensor *> *tenso
return RET_OK; return RET_OK;
} }


int Scheduler::BuildKernels(const lite::Model *model, std::vector<Tensor *> *tensors,
int Scheduler::BuildKernels(const lite::Model *model, const std::vector<Tensor *> *tensors,
std::vector<kernel::LiteKernel *> *kernels) { std::vector<kernel::LiteKernel *> *kernels) {
MS_ASSERT(model != nullptr); MS_ASSERT(model != nullptr);
MS_ASSERT(tensors != nullptr); MS_ASSERT(tensors != nullptr);
@@ -244,21 +240,21 @@ kernel::SubGraphKernel *Scheduler::CreateSubGraphKernel(const std::vector<kernel
std::vector<kernel::LiteKernel *> output_kernels = kernel::LiteKernelUtil::SubgraphOutputKernels(kernels); std::vector<kernel::LiteKernel *> output_kernels = kernel::LiteKernelUtil::SubgraphOutputKernels(kernels);
if (type == kernel::kGpuSubGraph) { if (type == kernel::kGpuSubGraph) {
#if SUPPORT_GPU #if SUPPORT_GPU
auto sub_kernel =
new kernel::SubGraphOpenCLKernel(input_tensors, output_tensors, input_kernels, output_kernels, kernels, context_);
auto sub_kernel = new (std::nothrow)
kernel::SubGraphOpenCLKernel(input_tensors, output_tensors, input_kernels, output_kernels, kernels, context_);
return sub_kernel; return sub_kernel;
#else #else
return nullptr; return nullptr;
#endif #endif
} }
if (type == kernel::kCpuFP16SubGraph) { if (type == kernel::kCpuFP16SubGraph) {
auto sub_kernel =
new kernel::CpuFp16SubGraph(input_tensors, output_tensors, input_kernels, output_kernels, kernels, context_);
auto sub_kernel = new (std::nothrow)
kernel::CpuFp16SubGraph(input_tensors, output_tensors, input_kernels, output_kernels, kernels, context_);
return sub_kernel; return sub_kernel;
} }
if (type == kernel::kCpuFP32SubGraph) { if (type == kernel::kCpuFP32SubGraph) {
auto sub_kernel =
new kernel::CpuFp32SubGraph(input_tensors, output_tensors, input_kernels, output_kernels, kernels, context_);
auto sub_kernel = new (std::nothrow)
kernel::CpuFp32SubGraph(input_tensors, output_tensors, input_kernels, output_kernels, kernels, context_);
return sub_kernel; return sub_kernel;
} }
return nullptr; return nullptr;
@@ -344,7 +340,7 @@ void Scheduler::SetKernelTensorDataType(kernel::LiteKernel *kernel) {
} }
} }


kernel::SubGraphType Scheduler::GetKernelSubGraphType(kernel::LiteKernel *kernel) {
kernel::SubGraphType Scheduler::GetKernelSubGraphType(const kernel::LiteKernel *kernel) {
if (kernel == nullptr) { if (kernel == nullptr) {
return kernel::kNotSubGraph; return kernel::kNotSubGraph;
} }


+ 2
- 2
mindspore/lite/src/scheduler.h View File

@@ -38,7 +38,7 @@ class Scheduler {
kernel::LiteKernel *ScheduleNode(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, kernel::LiteKernel *ScheduleNode(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
const mindspore::lite::PrimitiveC *primitive, const Model::Node *cnode); const mindspore::lite::PrimitiveC *primitive, const Model::Node *cnode);


int BuildKernels(const lite::Model *model, std::vector<Tensor *> *tensors,
int BuildKernels(const lite::Model *model, const std::vector<Tensor *> *tensors,
std::vector<kernel::LiteKernel *> *kernels); std::vector<kernel::LiteKernel *> *kernels);


static int InferShape(const lite::Model *model, std::vector<Tensor *> *tensors); static int InferShape(const lite::Model *model, std::vector<Tensor *> *tensors);
@@ -55,7 +55,7 @@ class Scheduler {


static void SetKernelTensorDataType(kernel::LiteKernel *kernel); static void SetKernelTensorDataType(kernel::LiteKernel *kernel);


static kernel::SubGraphType GetKernelSubGraphType(kernel::LiteKernel *kernel);
static kernel::SubGraphType GetKernelSubGraphType(const kernel::LiteKernel *kernel);


protected: protected:
InnerContext *context_ = nullptr; InnerContext *context_ = nullptr;


+ 5
- 1
mindspore/lite/src/sub_graph_kernel.cc View File

@@ -115,7 +115,11 @@ int SubGraphKernel::ReSize(bool is_interrupt) {
std::vector<lite::Tensor *> inputs = kernel->in_tensors(); std::vector<lite::Tensor *> inputs = kernel->in_tensors();
std::vector<lite::Tensor *> outputs = kernel->out_tensors(); std::vector<lite::Tensor *> outputs = kernel->out_tensors();
for (auto &output : outputs) { for (auto &output : outputs) {
output->FreeData();
auto ret = output->FreeData();
if (ret != RET_OK) {
MS_LOG(ERROR) << "FreeData failed";
return RET_ERROR;
}
} }
primitive->set_infer_flag(!is_interrupt); primitive->set_infer_flag(!is_interrupt);
auto ret = primitive->InferShape(inputs, outputs); auto ret = primitive->InferShape(inputs, outputs);


+ 1
- 1
mindspore/lite/src/sub_graph_kernel.h View File

@@ -104,7 +104,7 @@ class CpuSubGraph : public SubGraphKernel {
const std::vector<LiteKernel *> &nodes, const lite::InnerContext *ctx) const std::vector<LiteKernel *> &nodes, const lite::InnerContext *ctx)
: SubGraphKernel(inputs, outputs, in_kernels, out_kernels, nodes, ctx) { : SubGraphKernel(inputs, outputs, in_kernels, out_kernels, nodes, ctx) {
subgraph_type_ = kCpuFP32SubGraph; subgraph_type_ = kCpuFP32SubGraph;
this->executor_ = new mindspore::lite::Executor;
this->executor_ = new (std::nothrow) mindspore::lite::Executor;
} }


~CpuSubGraph() override { delete this->executor_; } ~CpuSubGraph() override { delete this->executor_; }


+ 6
- 5
mindspore/lite/src/tensor.cc View File

@@ -18,6 +18,7 @@
#include <string> #include <string>
#include <utility> #include <utility>
#include <algorithm> #include <algorithm>
#include <functional>
#include "src/tensor.h" #include "src/tensor.h"
#include "securec/include/securec.h" #include "securec/include/securec.h"
#include "include/errorcode.h" #include "include/errorcode.h"
@@ -25,8 +26,8 @@
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
#define kMaxMallocSize 1024 * 1024 * 100 #define kMaxMallocSize 1024 * 1024 * 100
Tensor::Tensor(const TypeId data_type, const std::vector<int> &shape, const schema::Format &format, Category category)
: data_type_(data_type), shape_(shape), format_(format), category_(category) {}
Tensor::Tensor(const TypeId data_type, std::vector<int> shape, const schema::Format &format, Category category)
: data_type_(data_type), shape_(std::move(shape)), format_(format), category_(category) {}


Tensor::Tensor(const Tensor &tensor) { Tensor::Tensor(const Tensor &tensor) {
auto ret = CopyTensor(tensor, true); auto ret = CopyTensor(tensor, true);
@@ -234,7 +235,7 @@ int32_t Tensor::ElementsC4Num() const {
return result; return result;
} }


int Tensor::DimensionSize(size_t index) const {
int Tensor::DimensionSize(const size_t index) const {
int dim_size = -1; int dim_size = -1;
if (index < shape_.size()) { if (index < shape_.size()) {
dim_size = shape_[index]; dim_size = shape_[index];
@@ -277,12 +278,12 @@ std::string Tensor::ToString() const {
return oss.str(); return oss.str();
} }


int Tensor::MallocData(mindspore::lite::Allocator *allocator) {
int Tensor::MallocData(const mindspore::lite::Allocator *allocator) {
if (nullptr != this->data_) { if (nullptr != this->data_) {
return RET_OK; return RET_OK;
} }
if (allocator != nullptr) { if (allocator != nullptr) {
allocator_ = allocator;
allocator_ = const_cast<mindspore::lite::Allocator *>(allocator);
} }
if (allocator_ == nullptr) { if (allocator_ == nullptr) {
this->data_ = malloc(this->Size()); this->data_ = malloc(this->Size());


+ 9
- 9
mindspore/lite/src/tensor.h View File

@@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */


#ifndef MINDSPORE_LITE_SRC_IR_TENSOR_H_
#define MINDSPORE_LITE_SRC_IR_TENSOR_H_
#ifndef MINDSPORE_LITE_SRC_TENSOR_H_
#define MINDSPORE_LITE_SRC_TENSOR_H_


#include <memory> #include <memory>
#include <vector> #include <vector>
@@ -49,18 +49,18 @@ class Tensor : public mindspore::tensor::MSTensor {
}; };
Tensor() = default; Tensor() = default;


Tensor(const TypeId data_type, const std::vector<int> &shape,
const schema::Format &format = schema::Format::Format_NHWC, Category category = VAR);
Tensor(TypeId data_type, std::vector<int> shape, const schema::Format &format = schema::Format::Format_NHWC,
Category category = VAR);


Tensor(const Tensor &tensor); Tensor(const Tensor &tensor);


virtual ~Tensor();
~Tensor() override;


int CopyTensorData(const Tensor &srcTensor); int CopyTensorData(const Tensor &srcTensor);


int CopyTensor(const Tensor &srcTensor, bool copyData = false); int CopyTensor(const Tensor &srcTensor, bool copyData = false);


virtual Tensor &operator=(const Tensor &tensor);
Tensor &operator=(const Tensor &tensor);


virtual bool operator==(const Tensor &tensor); virtual bool operator==(const Tensor &tensor);


@@ -92,7 +92,7 @@ class Tensor : public mindspore::tensor::MSTensor {


mindspore::lite::Allocator *allocator() const { return this->allocator_; } mindspore::lite::Allocator *allocator() const { return this->allocator_; }


int MallocData(mindspore::lite::Allocator *allocator = nullptr);
int MallocData(const mindspore::lite::Allocator *allocator = nullptr);


int FreeData(); int FreeData();


@@ -108,7 +108,7 @@ class Tensor : public mindspore::tensor::MSTensor {


schema::Format format() { return this->format_; } schema::Format format() { return this->format_; }


size_t ref_count() { return this->ref_count_; }
size_t ref_count() const { return this->ref_count_; }


void set_ref_count(size_t ref_count) { this->ref_count_ = ref_count; } void set_ref_count(size_t ref_count) { this->ref_count_ = ref_count; }


@@ -218,4 +218,4 @@ std::vector<tensor::MSTensor *> TensorVectorCast(const std::vector<Tensor *> &sr
} // namespace mindspore } // namespace mindspore


using TensorPtr = std::shared_ptr<mindspore::lite::Tensor>; using TensorPtr = std::shared_ptr<mindspore::lite::Tensor>;
#endif // MINDSPORE_LITE_SRC_IR_TENSOR_H_
#endif // MINDSPORE_LITE_SRC_TENSOR_H_

Loading…
Cancel
Save