Browse Source

!4919 Change Primitive to PrimitiveC

Merge pull request !4919 from yeyunpeng2020/primitve_1
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
4f928a4f7e
100 changed files with 890 additions and 1106 deletions
  1. +1
    -1
      mindspore/lite/include/model.h
  2. +2
    -3
      mindspore/lite/src/CMakeLists.txt
  3. +0
    -40
      mindspore/lite/src/ir/primitive_t_value.cc
  4. +0
    -91
      mindspore/lite/src/ir/primitive_t_value.h
  5. +1
    -1
      mindspore/lite/src/lite_kernel.h
  6. +2
    -1
      mindspore/lite/src/lite_session.cc
  7. +1
    -1
      mindspore/lite/src/lite_session.h
  8. +1
    -2
      mindspore/lite/src/model.cc
  9. +0
    -3
      mindspore/lite/src/ops/CMakeLists.txt
  10. +4
    -1
      mindspore/lite/src/ops/abs.h
  11. +6
    -6
      mindspore/lite/src/ops/activation.cc
  12. +6
    -9
      mindspore/lite/src/ops/activation.h
  13. +3
    -3
      mindspore/lite/src/ops/activation_grad.cc
  14. +7
    -9
      mindspore/lite/src/ops/activation_grad.h
  15. +3
    -3
      mindspore/lite/src/ops/add.cc
  16. +7
    -3
      mindspore/lite/src/ops/add.h
  17. +4
    -4
      mindspore/lite/src/ops/addn.cc
  18. +7
    -9
      mindspore/lite/src/ops/addn.h
  19. +16
    -16
      mindspore/lite/src/ops/argmax.cc
  20. +7
    -9
      mindspore/lite/src/ops/argmax.h
  21. +16
    -16
      mindspore/lite/src/ops/argmin.cc
  22. +7
    -9
      mindspore/lite/src/ops/argmin.h
  23. +1
    -1
      mindspore/lite/src/ops/arithmetic.cc
  24. +7
    -9
      mindspore/lite/src/ops/arithmetic.h
  25. +1
    -1
      mindspore/lite/src/ops/arithmetic_self.cc
  26. +7
    -11
      mindspore/lite/src/ops/arithmetic_self.h
  27. +3
    -3
      mindspore/lite/src/ops/batch_norm.cc
  28. +7
    -9
      mindspore/lite/src/ops/batch_norm.h
  29. +7
    -7
      mindspore/lite/src/ops/batch_to_space.cc
  30. +7
    -9
      mindspore/lite/src/ops/batch_to_space.h
  31. +3
    -3
      mindspore/lite/src/ops/bias_add.cc
  32. +7
    -9
      mindspore/lite/src/ops/bias_add.h
  33. +3
    -3
      mindspore/lite/src/ops/bias_grad.cc
  34. +7
    -9
      mindspore/lite/src/ops/bias_grad.h
  35. +6
    -6
      mindspore/lite/src/ops/bn_grad_input.cc
  36. +7
    -9
      mindspore/lite/src/ops/bn_grad_input.h
  37. +3
    -3
      mindspore/lite/src/ops/broadcast_to.cc
  38. +7
    -9
      mindspore/lite/src/ops/broadcast_to.h
  39. +3
    -3
      mindspore/lite/src/ops/caffe_p_relu.cc
  40. +7
    -10
      mindspore/lite/src/ops/caffe_p_relu.h
  41. +7
    -7
      mindspore/lite/src/ops/cast.cc
  42. +7
    -9
      mindspore/lite/src/ops/cast.h
  43. +8
    -10
      mindspore/lite/src/ops/ceil.h
  44. +6
    -6
      mindspore/lite/src/ops/clip.cc
  45. +7
    -9
      mindspore/lite/src/ops/clip.h
  46. +7
    -7
      mindspore/lite/src/ops/concat.cc
  47. +7
    -9
      mindspore/lite/src/ops/concat.h
  48. +3
    -3
      mindspore/lite/src/ops/constant_of_shape.cc
  49. +7
    -8
      mindspore/lite/src/ops/constant_of_shape.h
  50. +52
    -52
      mindspore/lite/src/ops/conv2d.cc
  51. +7
    -9
      mindspore/lite/src/ops/conv2d.h
  52. +53
    -51
      mindspore/lite/src/ops/conv2d_grad_filter.cc
  53. +7
    -9
      mindspore/lite/src/ops/conv2d_grad_filter.h
  54. +53
    -51
      mindspore/lite/src/ops/conv2d_grad_input.cc
  55. +7
    -9
      mindspore/lite/src/ops/conv2d_grad_input.h
  56. +8
    -10
      mindspore/lite/src/ops/cos.h
  57. +6
    -6
      mindspore/lite/src/ops/crop.cc
  58. +7
    -9
      mindspore/lite/src/ops/crop.h
  59. +52
    -52
      mindspore/lite/src/ops/deconv2d.cc
  60. +7
    -9
      mindspore/lite/src/ops/deconv2d.h
  61. +49
    -49
      mindspore/lite/src/ops/dedepthwise_conv2d.cc
  62. +7
    -9
      mindspore/lite/src/ops/dedepthwise_conv2d.h
  63. +7
    -7
      mindspore/lite/src/ops/depth_to_space.cc
  64. +7
    -9
      mindspore/lite/src/ops/depth_to_space.h
  65. +51
    -49
      mindspore/lite/src/ops/depthwise_conv2d.cc
  66. +7
    -9
      mindspore/lite/src/ops/depthwise_conv2d.h
  67. +43
    -39
      mindspore/lite/src/ops/detection_post_process.cc
  68. +7
    -9
      mindspore/lite/src/ops/detection_post_process.h
  69. +3
    -3
      mindspore/lite/src/ops/div.cc
  70. +7
    -10
      mindspore/lite/src/ops/div.h
  71. +3
    -3
      mindspore/lite/src/ops/dropout.cc
  72. +7
    -9
      mindspore/lite/src/ops/dropout.h
  73. +3
    -3
      mindspore/lite/src/ops/eltwise.cc
  74. +7
    -9
      mindspore/lite/src/ops/eltwise.h
  75. +3
    -3
      mindspore/lite/src/ops/elu.cc
  76. +7
    -9
      mindspore/lite/src/ops/elu.h
  77. +4
    -4
      mindspore/lite/src/ops/embedding_lookup.cc
  78. +7
    -9
      mindspore/lite/src/ops/embedding_lookup.h
  79. +9
    -9
      mindspore/lite/src/ops/embedding_lookup_sparse.cc
  80. +7
    -9
      mindspore/lite/src/ops/embedding_lookup_sparse.h
  81. +7
    -9
      mindspore/lite/src/ops/equal.h
  82. +7
    -9
      mindspore/lite/src/ops/exp.h
  83. +4
    -4
      mindspore/lite/src/ops/expand_dims.cc
  84. +7
    -9
      mindspore/lite/src/ops/expand_dims.h
  85. +6
    -6
      mindspore/lite/src/ops/fake_quant_with_min_max_vars.cc
  86. +7
    -9
      mindspore/lite/src/ops/fake_quant_with_min_max_vars.h
  87. +4
    -4
      mindspore/lite/src/ops/fill.cc
  88. +7
    -9
      mindspore/lite/src/ops/fill.h
  89. +1
    -1
      mindspore/lite/src/ops/flatten.cc
  90. +7
    -9
      mindspore/lite/src/ops/flatten.h
  91. +8
    -10
      mindspore/lite/src/ops/floor.h
  92. +7
    -9
      mindspore/lite/src/ops/floor_div.h
  93. +7
    -9
      mindspore/lite/src/ops/floor_mod.h
  94. +13
    -13
      mindspore/lite/src/ops/full_connection.cc
  95. +7
    -9
      mindspore/lite/src/ops/full_connection.h
  96. +9
    -9
      mindspore/lite/src/ops/fused_batchnorm.cc
  97. +7
    -9
      mindspore/lite/src/ops/fused_batchnorm.h
  98. +7
    -7
      mindspore/lite/src/ops/gather.cc
  99. +7
    -9
      mindspore/lite/src/ops/gather.h
  100. +4
    -4
      mindspore/lite/src/ops/gather_nd.cc

+ 1
- 1
mindspore/lite/include/model.h View File

@@ -20,7 +20,7 @@
#include <string>
#include <vector>
#include <memory>
#include "schema/model_generated.h"
#include "src/ops/primitive_c.h"

namespace mindspore {
#define MS_API __attribute__((visibility("default")))


+ 2
- 3
mindspore/lite/src/CMakeLists.txt View File

@@ -32,11 +32,10 @@ set(ANF_SRC
${ANF_SRC}
${CMAKE_CURRENT_SOURCE_DIR}/ir/meta_tensor_extends.cc
)
add_library(mindspore-lite SHARED ${LITE_SRC} ${ANF_SRC})
file(GLOB_RECURSE C_OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/ops/*.cc)
add_library(mindspore-lite SHARED ${LITE_SRC} ${ANF_SRC} ${C_OPS_SRC})
target_link_libraries(mindspore-lite
cpu_kernel_mid_
c_ops_mid
)

add_subdirectory(runtime/kernel/arm)


+ 0
- 40
mindspore/lite/src/ir/primitive_t_value.cc View File

@@ -1,40 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/ir/primitive_t_value.h"

namespace mindspore::lite {
std::shared_ptr<PrimitiveTValue> GetReturnPrim() {
auto return_primitiveT = new schema::PrimitiveT;
return_primitiveT->value.type = schema::PrimitiveType_Return;
return_primitiveT->value.value = new schema::ReturnT;
return std::make_shared<PrimitiveTValue>(return_primitiveT);
}

std::shared_ptr<PrimitiveTValue> GetMakeTuplePrim() {
auto make_tuple_primitiveT = new schema::PrimitiveT;
make_tuple_primitiveT->value.type = schema::PrimitiveType_MakeTuple;
make_tuple_primitiveT->value.value = new schema::MakeTupleT;
return std::make_shared<PrimitiveTValue>(make_tuple_primitiveT);
}

std::shared_ptr<PrimitiveTValue> GetTupleGetItemPrim() {
auto tuple_get_item_primitiveT = new schema::PrimitiveT();
tuple_get_item_primitiveT->value.type = schema::PrimitiveType_TupleGetItem;
tuple_get_item_primitiveT->value.value = new schema::TupleGetItemT;
return std::make_shared<PrimitiveTValue>(tuple_get_item_primitiveT);
}
} // namespace mindspore::lite

+ 0
- 91
mindspore/lite/src/ir/primitive_t_value.h View File

@@ -1,91 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_SRC_ANF_IMPORTER_PRIMITIVET_H_
#define MINDSPORE_LITE_SRC_ANF_IMPORTER_PRIMITIVET_H_

#include <vector>
#include <memory>
#include "schema/inner/model_generated.h"
#include "ir/value.h"

namespace mindspore::lite {

class PrimitiveTValue : public Value {
public:
explicit PrimitiveTValue(schema::PrimitiveT *primt) : primitive(primt) {}
// not responsible to free primitive, the one created the dynamic memory is responsible to free it.
~PrimitiveTValue() override = default;

MS_DECLARE_PARENT(PrimitiveTValue, Value)

schema::PrimitiveT *GetPrimitiveT() const { return this->primitive; }

void SetPrimitiveT(schema::PrimitiveT *primIn) { this->primitive = primIn; }

bool operator==(const Value &rhs) const override {
if (rhs.isa<PrimitiveTValue>()) {
auto other_prim = static_cast<const PrimitiveTValue &>(rhs);
auto a = this->primitive->value.type;
auto b = other_prim.primitive->value.type;
return a == b;
} else {
return false;
}
}

void SetInputQuantParam(const std::vector<std::vector<schema::QuantParamT>> &input_quant_param) {
this->input_quant_param_ = input_quant_param;
}

void SetOutputQuantParam(const std::vector<std::vector<schema::QuantParamT>> &output_quant_param) {
this->output_quant_param_ = output_quant_param;
}

void ClearInputOutputQuantParam() {
input_quant_param_.clear();
output_quant_param_.clear();
}

void AddInputQuantParam(std::vector<schema::QuantParamT> quant_param) {
this->input_quant_param_.emplace_back(quant_param);
}
std::vector<std::vector<schema::QuantParamT>> GetInputQuantParams() const { return input_quant_param_; }

void AddOutputQuantParam(std::vector<schema::QuantParamT> quant_param) {
this->output_quant_param_.emplace_back(quant_param);
}
std::vector<std::vector<schema::QuantParamT>> GetOutputQuantParams() const { return output_quant_param_; }

void SetQuantType(schema::QuantType quant_type) { this->quant_type_ = quant_type; }

schema::QuantType GetQuantType() const { return quant_type_; }

protected:
schema::PrimitiveT *primitive = nullptr;
std::vector<std::vector<schema::QuantParamT>> input_quant_param_;
std::vector<std::vector<schema::QuantParamT>> output_quant_param_;
schema::QuantType quant_type_{schema::QuantType_QUANT_NONE};
};

std::shared_ptr<PrimitiveTValue> GetReturnPrim();

std::shared_ptr<PrimitiveTValue> GetMakeTuplePrim();

std::shared_ptr<PrimitiveTValue> GetTupleGetItemPrim();
} // namespace mindspore::lite

#endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_PRIMITIVET_H_

+ 1
- 1
mindspore/lite/src/lite_kernel.h View File

@@ -21,11 +21,11 @@
#ifdef ENABLE_ARM
#include <arm_neon.h>
#endif
#include "src/ops/primitive_c.h"
#include "src/runtime/kernel/arm/nnacl/op_base.h"
#include "include/context.h"
#include "src/ir/tensor.h"
#include "include/errorcode.h"
#include "src/ops/primitive_c.h"

#ifdef ENABLE_FP16
using FLOAT_t = float16_t;


+ 2
- 1
mindspore/lite/src/lite_session.cc View File

@@ -14,9 +14,9 @@
* limitations under the License.
*/

#include "src/lite_session.h"
#include <vector>
#include "include/errorcode.h"
#include "src/lite_session.h"
#include "utils/log_adapter.h"
#include "src/scheduler.h"
#include "src/runtime/runtime_api.h"
@@ -76,6 +76,7 @@ int LiteSession::ConvertTensors(const lite::Model *model) {

this->tensors_.emplace_back(dstTensor);
}

return RET_OK;
}



+ 1
- 1
mindspore/lite/src/lite_session.h View File

@@ -21,11 +21,11 @@
#include <vector>
#include <string>
#include <unordered_map>
#include "src/lite_kernel.h"
#include "include/ms_tensor.h"
#include "include/lite_session.h"
#include "include/model.h"
#include "include/context.h"
#include "src/lite_kernel.h"
#include "schema/model_generated.h"
#include "src/executor.h"



+ 1
- 2
mindspore/lite/src/model.cc View File

@@ -14,6 +14,7 @@
* limitations under the License.
*/

#include "include/model.h"
#include "src/ops/unique.h"
#include "src/ops/space_to_batch.h"
#include "src/ops/conv2d.h"
@@ -106,8 +107,6 @@
#include "src/ops/squared_difference.h"
#include "src/ops/ceil.h"
#include "src/ops/round.h"
#include "src/ops/primitive_c.h"
#include "include/model.h"
#include "utils/log_adapter.h"

namespace mindspore::lite {


+ 0
- 3
mindspore/lite/src/ops/CMakeLists.txt View File

@@ -1,3 +0,0 @@
file(GLOB_RECURSE C_OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/*.cc)

add_library(c_ops_mid OBJECT ${C_OPS_SRC})

+ 4
- 1
mindspore/lite/src/ops/abs.h View File

@@ -32,7 +32,10 @@ namespace mindspore {
namespace lite {
class Abs : public ArithmeticSelf {
public:
explicit Abs(OriginPrimitive *primitive) : ArithmeticSelf(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit Abs(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
#endif
explicit Abs(schema::Primitive *primitive) : ArithmeticSelf(primitive) {}
};
} // namespace lite
} // namespace mindspore


+ 6
- 6
mindspore/lite/src/ops/activation.cc View File

@@ -19,16 +19,16 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int Activation::GetType() const { return this->primitive->value.AsActivation()->type; }
float Activation::GetAlpha() const { return this->primitive->value.AsActivation()->alpha; }
int Activation::GetType() const { return this->primitive_->value.AsActivation()->type; }
float Activation::GetAlpha() const { return this->primitive_->value.AsActivation()->alpha; }

void Activation::SetType(int type) { this->primitive->value.AsActivation()->type = (schema::ActivationType)type; }
void Activation::SetAlpha(float alpha) { this->primitive->value.AsActivation()->alpha = alpha; }
void Activation::SetType(int type) { this->primitive_->value.AsActivation()->type = (schema::ActivationType)type; }
void Activation::SetAlpha(float alpha) { this->primitive_->value.AsActivation()->alpha = alpha; }

#else

int Activation::GetType() const { return this->primitive->value_as_Activation()->type(); }
float Activation::GetAlpha() const { return this->primitive->value_as_Activation()->alpha(); }
int Activation::GetType() const { return this->primitive_->value_as_Activation()->type(); }
float Activation::GetAlpha() const { return this->primitive_->value_as_Activation()->alpha(); }

void Activation::SetType(int type) {}
void Activation::SetAlpha(float alpha) {}


+ 6
- 9
mindspore/lite/src/ops/activation.h View File

@@ -13,26 +13,23 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_ACTIVATION_H_
#define LITE_MINDSPORE_LITE_C_OPS_ACTIVATION_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_ACTIVATION_H_
#define LITE_MINDSPORE_LITE_C_OPS_ACTIVATION_H_

namespace mindspore {
namespace lite {
class Activation : public PrimitiveC {
public:
explicit Activation(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit Activation(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit Activation(schema::Primitive *primitive) : PrimitiveC(primitive) {}
int GetType() const;
float GetAlpha() const;
void SetType(int type);


+ 3
- 3
mindspore/lite/src/ops/activation_grad.cc View File

@@ -19,15 +19,15 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int ActivationGrad::GetType() const { return this->primitive->value.AsActivationGrad()->type; }
int ActivationGrad::GetType() const { return this->primitive_->value.AsActivationGrad()->type; }

void ActivationGrad::SetType(int type) {
this->primitive->value.AsActivationGrad()->type = (schema::ActivationGradType)type;
this->primitive_->value.AsActivationGrad()->type = (schema::ActivationGradType)type;
}

#else

int ActivationGrad::GetType() const { return this->primitive->value_as_ActivationGrad()->type(); }
int ActivationGrad::GetType() const { return this->primitive_->value_as_ActivationGrad()->type(); }

void ActivationGrad::SetType(int type) {}
#endif


+ 7
- 9
mindspore/lite/src/ops/activation_grad.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_ACTIVATION_GRAD_H_
#define LITE_MINDSPORE_LITE_C_OPS_ACTIVATION_GRAD_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_ACTIVATION_GRAD_H_
#define LITE_MINDSPORE_LITE_C_OPS_ACTIVATION_GRAD_H_

namespace mindspore {
namespace lite {
class ActivationGrad : public PrimitiveC {
public:
explicit ActivationGrad(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit ActivationGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit ActivationGrad(schema::Primitive *primitive) : PrimitiveC(primitive) {}

int GetType() const;
void SetType(int type);


+ 3
- 3
mindspore/lite/src/ops/add.cc View File

@@ -19,15 +19,15 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int Add::GetActivationType() const { return this->primitive->value.AsAdd()->activationType; }
int Add::GetActivationType() const { return this->primitive_->value.AsAdd()->activationType; }

void Add::SetActivationType(int activation_type) {
this->primitive->value.AsAdd()->activationType = (schema::ActivationType)activation_type;
this->primitive_->value.AsAdd()->activationType = (schema::ActivationType)activation_type;
}

#else

int Add::GetActivationType() const { return this->primitive->value_as_Add()->activationType(); }
int Add::GetActivationType() const { return this->primitive_->value_as_Add()->activationType(); }

void Add::SetActivationType(int activation_type) {}
#endif


+ 7
- 3
mindspore/lite/src/ops/add.h View File

@@ -14,6 +14,9 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_ADD_H_
#define LITE_MINDSPORE_LITE_C_OPS_ADD_H_

#include <vector>
#include <set>
#include <cmath>
@@ -24,14 +27,15 @@
#else
#include "schema/model_generated.h"
#endif
#ifndef LITE_MINDSPORE_LITE_C_OPS_ADD_H_
#define LITE_MINDSPORE_LITE_C_OPS_ADD_H_

namespace mindspore {
namespace lite {
class Add : public Arithmetic {
public:
explicit Add(OriginPrimitive *primitive) : Arithmetic(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit Add(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
#endif
explicit Add(schema::Primitive *primitive) : Arithmetic(primitive) {}

int GetActivationType() const;
void SetActivationType(int activation_type);


+ 4
- 4
mindspore/lite/src/ops/addn.cc View File

@@ -19,13 +19,13 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int AddN::GetN() const { return this->primitive->value.AsAddN()->N; }
int AddN::GetN() const { return this->primitive_->value.AsAddN()->N; }

void AddN::SetN(int n) { this->primitive->value.AsAddN()->N = n; }
void AddN::SetN(int n) { this->primitive_->value.AsAddN()->N = n; }

#else

int AddN::GetN() const { return this->primitive->value_as_AddN()->N(); }
int AddN::GetN() const { return this->primitive_->value_as_AddN()->N(); }

void AddN::SetN(int n) {}
#endif
@@ -34,7 +34,7 @@ namespace {
constexpr int kLeastInputNum = 2;
}
int AddN::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) {
MS_ASSERT(this->primitive != nullptr);
MS_ASSERT(this->primitive_ != nullptr);
auto input = inputs.front();
MS_ASSERT(input != nullptr);
auto output = outputs.front();


+ 7
- 9
mindspore/lite/src/ops/addn.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_ADD_N_H_
#define LITE_MINDSPORE_LITE_C_OPS_ADD_N_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_ADD_N_H_
#define LITE_MINDSPORE_LITE_C_OPS_ADD_N_H_

namespace mindspore {
namespace lite {
class AddN : public PrimitiveC {
public:
explicit AddN(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit AddN(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit AddN(schema::Primitive *primitive) : PrimitiveC(primitive) {}
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetN() const;
void SetN(int n);


+ 16
- 16
mindspore/lite/src/ops/argmax.cc View File

@@ -19,25 +19,25 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int ArgMax::GetAxis() const { return this->primitive->value.AsArgMax()->axis; }
bool ArgMax::GetOutMaxValue() const { return this->primitive->value.AsArgMax()->outMaxValue; }
int ArgMax::GetTopK() const { return this->primitive->value.AsArgMax()->topK; }
bool ArgMax::GetKeepDims() const { return this->primitive->value.AsArgMax()->keepDims; }
int ArgMax::GetAxisType() const { return this->primitive->value.AsArgMax()->axisType; }
int ArgMax::GetAxis() const { return this->primitive_->value.AsArgMax()->axis; }
bool ArgMax::GetOutMaxValue() const { return this->primitive_->value.AsArgMax()->outMaxValue; }
int ArgMax::GetTopK() const { return this->primitive_->value.AsArgMax()->topK; }
bool ArgMax::GetKeepDims() const { return this->primitive_->value.AsArgMax()->keepDims; }
int ArgMax::GetAxisType() const { return this->primitive_->value.AsArgMax()->axisType; }

void ArgMax::SetAxis(int axis) { this->primitive->value.AsArgMax()->axis = axis; }
void ArgMax::SetOutMaxValue(bool out_max_value) { this->primitive->value.AsArgMax()->outMaxValue = out_max_value; }
void ArgMax::SetTopK(int top_k) { this->primitive->value.AsArgMax()->topK = top_k; }
void ArgMax::SetKeepDims(bool keep_dims) { this->primitive->value.AsArgMax()->keepDims = keep_dims; }
void ArgMax::SetAxisType(int axis_type) { this->primitive->value.AsArgMax()->axisType = axis_type; }
void ArgMax::SetAxis(int axis) { this->primitive_->value.AsArgMax()->axis = axis; }
void ArgMax::SetOutMaxValue(bool out_max_value) { this->primitive_->value.AsArgMax()->outMaxValue = out_max_value; }
void ArgMax::SetTopK(int top_k) { this->primitive_->value.AsArgMax()->topK = top_k; }
void ArgMax::SetKeepDims(bool keep_dims) { this->primitive_->value.AsArgMax()->keepDims = keep_dims; }
void ArgMax::SetAxisType(int axis_type) { this->primitive_->value.AsArgMax()->axisType = axis_type; }

#else

int ArgMax::GetAxis() const { return this->primitive->value_as_ArgMax()->axis(); }
bool ArgMax::GetOutMaxValue() const { return this->primitive->value_as_ArgMax()->outMaxValue(); }
int ArgMax::GetTopK() const { return this->primitive->value_as_ArgMax()->topK(); }
bool ArgMax::GetKeepDims() const { return this->primitive->value_as_ArgMax()->keepDims(); }
int ArgMax::GetAxisType() const { return this->primitive->value_as_ArgMax()->axisType(); }
int ArgMax::GetAxis() const { return this->primitive_->value_as_ArgMax()->axis(); }
bool ArgMax::GetOutMaxValue() const { return this->primitive_->value_as_ArgMax()->outMaxValue(); }
int ArgMax::GetTopK() const { return this->primitive_->value_as_ArgMax()->topK(); }
bool ArgMax::GetKeepDims() const { return this->primitive_->value_as_ArgMax()->keepDims(); }
int ArgMax::GetAxisType() const { return this->primitive_->value_as_ArgMax()->axisType(); }

void ArgMax::SetAxis(int axis) {}
void ArgMax::SetOutMaxValue(bool out_max_value) {}
@@ -47,7 +47,7 @@ void ArgMax::SetAxisType(int axis_type) {}
#endif

int ArgMax::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
MS_ASSERT(this->primitive_ != nullptr);
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();


+ 7
- 9
mindspore/lite/src/ops/argmax.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_ARG_MAX_H_
#define LITE_MINDSPORE_LITE_C_OPS_ARG_MAX_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_ARG_MAX_H_
#define LITE_MINDSPORE_LITE_C_OPS_ARG_MAX_H_

namespace mindspore {
namespace lite {
class ArgMax : public PrimitiveC {
public:
explicit ArgMax(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit ArgMax(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit ArgMax(schema::Primitive *primitive) : PrimitiveC(primitive) {}

int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetAxis() const;


+ 16
- 16
mindspore/lite/src/ops/argmin.cc View File

@@ -19,25 +19,25 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int ArgMin::GetAxis() const { return this->primitive->value.AsArgMin()->axis; }
bool ArgMin::GetOutMaxValue() const { return this->primitive->value.AsArgMin()->outMaxValue; }
int ArgMin::GetTopK() const { return this->primitive->value.AsArgMin()->topK; }
bool ArgMin::GetKeepDims() const { return this->primitive->value.AsArgMin()->keepDims; }
int ArgMin::GetAxisType() const { return this->primitive->value.AsArgMin()->axisType; }
int ArgMin::GetAxis() const { return this->primitive_->value.AsArgMin()->axis; }
bool ArgMin::GetOutMaxValue() const { return this->primitive_->value.AsArgMin()->outMaxValue; }
int ArgMin::GetTopK() const { return this->primitive_->value.AsArgMin()->topK; }
bool ArgMin::GetKeepDims() const { return this->primitive_->value.AsArgMin()->keepDims; }
int ArgMin::GetAxisType() const { return this->primitive_->value.AsArgMin()->axisType; }

void ArgMin::SetAxis(int axis) { this->primitive->value.AsArgMin()->axis = axis; }
void ArgMin::SetOutMaxValue(bool out_max_value) { this->primitive->value.AsArgMin()->outMaxValue = out_max_value; }
void ArgMin::SetTopK(int top_k) { this->primitive->value.AsArgMin()->topK = top_k; }
void ArgMin::SetKeepDims(bool keep_dims) { this->primitive->value.AsArgMin()->keepDims = keep_dims; }
void ArgMin::SetAxisType(int axis_type) { this->primitive->value.AsArgMin()->axisType = axis_type; }
void ArgMin::SetAxis(int axis) { this->primitive_->value.AsArgMin()->axis = axis; }
void ArgMin::SetOutMaxValue(bool out_max_value) { this->primitive_->value.AsArgMin()->outMaxValue = out_max_value; }
void ArgMin::SetTopK(int top_k) { this->primitive_->value.AsArgMin()->topK = top_k; }
void ArgMin::SetKeepDims(bool keep_dims) { this->primitive_->value.AsArgMin()->keepDims = keep_dims; }
void ArgMin::SetAxisType(int axis_type) { this->primitive_->value.AsArgMin()->axisType = axis_type; }

#else

int ArgMin::GetAxis() const { return this->primitive->value_as_ArgMin()->axis(); }
bool ArgMin::GetOutMaxValue() const { return this->primitive->value_as_ArgMin()->outMaxValue(); }
int ArgMin::GetTopK() const { return this->primitive->value_as_ArgMin()->topK(); }
bool ArgMin::GetKeepDims() const { return this->primitive->value_as_ArgMin()->keepDims(); }
int ArgMin::GetAxisType() const { return this->primitive->value_as_ArgMin()->axisType(); }
int ArgMin::GetAxis() const { return this->primitive_->value_as_ArgMin()->axis(); }
bool ArgMin::GetOutMaxValue() const { return this->primitive_->value_as_ArgMin()->outMaxValue(); }
int ArgMin::GetTopK() const { return this->primitive_->value_as_ArgMin()->topK(); }
bool ArgMin::GetKeepDims() const { return this->primitive_->value_as_ArgMin()->keepDims(); }
int ArgMin::GetAxisType() const { return this->primitive_->value_as_ArgMin()->axisType(); }

void ArgMin::SetAxis(int axis) {}
void ArgMin::SetOutMaxValue(bool out_max_value) {}
@@ -47,7 +47,7 @@ void ArgMin::SetAxisType(int axis_type) {}
#endif

int ArgMin::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
MS_ASSERT(this->primitive_ != nullptr);
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();


+ 7
- 9
mindspore/lite/src/ops/argmin.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_ARG_MIN_H_
#define LITE_MINDSPORE_LITE_C_OPS_ARG_MIN_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_ARG_MIN_H_
#define LITE_MINDSPORE_LITE_C_OPS_ARG_MIN_H_

namespace mindspore {
namespace lite {
class ArgMin : public PrimitiveC {
public:
explicit ArgMin(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit ArgMin(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit ArgMin(schema::Primitive *primitive) : PrimitiveC(primitive) {}

int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetAxis() const;


+ 1
- 1
mindspore/lite/src/ops/arithmetic.cc View File

@@ -22,7 +22,7 @@
namespace mindspore {
namespace lite {
int Arithmetic::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
MS_ASSERT(this->primitive_ != nullptr);
if (inputs_.size() != kDoubleNum) {
MS_LOG(ERROR) << "The number of input must be " << kDoubleNum;
return RET_INPUT_TENSOR_ERROR;


+ 7
- 9
mindspore/lite/src/ops/arithmetic.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_H_
#define LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_H_
#define LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_H_

namespace mindspore {
namespace lite {
class Arithmetic : public PrimitiveC {
public:
explicit Arithmetic(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit Arithmetic(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit Arithmetic(schema::Primitive *primitive) : PrimitiveC(primitive) {}

int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
bool Broadcasting() { return this->broadcasting_; }


+ 1
- 1
mindspore/lite/src/ops/arithmetic_self.cc View File

@@ -22,7 +22,7 @@ namespace mindspore {
namespace lite {

int ArithmeticSelf::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
MS_ASSERT(this->primitive_ != nullptr);
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();


+ 7
- 11
mindspore/lite/src/ops/arithmetic_self.h View File

@@ -14,24 +14,20 @@
* limitations under the License.
*/

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif
#ifndef LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_SELF_H_
#define LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_SELF_H_

#include <vector>
#include "src/ops/primitive_c.h"

namespace mindspore {
namespace lite {
class ArithmeticSelf : public PrimitiveC {
public:
explicit ArithmeticSelf(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit ArithmeticSelf(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit ArithmeticSelf(schema::Primitive *primitive) : PrimitiveC(primitive) {}

int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
};


+ 3
- 3
mindspore/lite/src/ops/batch_norm.cc View File

@@ -19,13 +19,13 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
float BatchNorm::GetEpsilon() const { return this->primitive->value.AsBatchNorm()->epsilon; }
float BatchNorm::GetEpsilon() const { return this->primitive_->value.AsBatchNorm()->epsilon; }

void BatchNorm::SetEpsilon(float epsilon) { this->primitive->value.AsBatchNorm()->epsilon = epsilon; }
void BatchNorm::SetEpsilon(float epsilon) { this->primitive_->value.AsBatchNorm()->epsilon = epsilon; }

#else

float BatchNorm::GetEpsilon() const { return this->primitive->value_as_BatchNorm()->epsilon(); }
float BatchNorm::GetEpsilon() const { return this->primitive_->value_as_BatchNorm()->epsilon(); }

void BatchNorm::SetEpsilon(float epsilon) {}
#endif


+ 7
- 9
mindspore/lite/src/ops/batch_norm.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_BATCH_NORM_H_
#define LITE_MINDSPORE_LITE_C_OPS_BATCH_NORM_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_BATCH_NORM_H_
#define LITE_MINDSPORE_LITE_C_OPS_BATCH_NORM_H_

namespace mindspore {
namespace lite {
class BatchNorm : public PrimitiveC {
public:
explicit BatchNorm(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit BatchNorm(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit BatchNorm(schema::Primitive *primitive) : PrimitiveC(primitive) {}

float GetEpsilon() const;
void SetEpsilon(float epsilon);


+ 7
- 7
mindspore/lite/src/ops/batch_to_space.cc View File

@@ -23,22 +23,22 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
std::vector<int> BatchToSpace::GetBlockShape() const { return this->primitive->value.AsBatchToSpace()->blockShape; }
std::vector<int> BatchToSpace::GetCrops() const { return this->primitive->value.AsBatchToSpace()->crops; }
std::vector<int> BatchToSpace::GetBlockShape() const { return this->primitive_->value.AsBatchToSpace()->blockShape; }
std::vector<int> BatchToSpace::GetCrops() const { return this->primitive_->value.AsBatchToSpace()->crops; }

void BatchToSpace::SetBlockShape(const std::vector<int> &block_shape) {
this->primitive->value.AsBatchToSpace()->blockShape = block_shape;
this->primitive_->value.AsBatchToSpace()->blockShape = block_shape;
}
void BatchToSpace::SetCrops(const std::vector<int> &crops) { this->primitive->value.AsBatchToSpace()->crops = crops; }
void BatchToSpace::SetCrops(const std::vector<int> &crops) { this->primitive_->value.AsBatchToSpace()->crops = crops; }

#else

std::vector<int> BatchToSpace::GetBlockShape() const {
auto fb_vector = this->primitive->value_as_BatchToSpace()->blockShape();
auto fb_vector = this->primitive_->value_as_BatchToSpace()->blockShape();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
std::vector<int> BatchToSpace::GetCrops() const {
auto fb_vector = this->primitive->value_as_BatchToSpace()->crops();
auto fb_vector = this->primitive_->value_as_BatchToSpace()->crops();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}

@@ -53,7 +53,7 @@ constexpr int kCropsSize = 4;
} // namespace

int BatchToSpace::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) {
MS_ASSERT(this->primitive != nullptr);
MS_ASSERT(this->primitive_ != nullptr);
if (outputs.size() != kBatchToSpaceOutputNum || inputs.size() != kBatchToSpaceInputNum) {
MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size();
return RET_PARAM_INVALID;


+ 7
- 9
mindspore/lite/src/ops/batch_to_space.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_BATCH_TO_SPACE_H_
#define LITE_MINDSPORE_LITE_C_OPS_BATCH_TO_SPACE_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_BATCH_TO_SPACE_H_
#define LITE_MINDSPORE_LITE_C_OPS_BATCH_TO_SPACE_H_

namespace mindspore {
namespace lite {
class BatchToSpace : public PrimitiveC {
public:
explicit BatchToSpace(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit BatchToSpace(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit BatchToSpace(schema::Primitive *primitive) : PrimitiveC(primitive) {}

int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
std::vector<int> GetBlockShape() const;


+ 3
- 3
mindspore/lite/src/ops/bias_add.cc View File

@@ -19,14 +19,14 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
std::vector<int> BiasAdd::GetAxis() const { return this->primitive->value.AsBiasAdd()->axis; }
std::vector<int> BiasAdd::GetAxis() const { return this->primitive_->value.AsBiasAdd()->axis; }

void BiasAdd::SetAxis(const std::vector<int> &axis) { this->primitive->value.AsBiasAdd()->axis = axis; }
void BiasAdd::SetAxis(const std::vector<int> &axis) { this->primitive_->value.AsBiasAdd()->axis = axis; }

#else

std::vector<int> BiasAdd::GetAxis() const {
auto fb_vector = this->primitive->value_as_BiasAdd()->axis();
auto fb_vector = this->primitive_->value_as_BiasAdd()->axis();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}



+ 7
- 9
mindspore/lite/src/ops/bias_add.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_BIAS_ADD_H_
#define LITE_MINDSPORE_LITE_C_OPS_BIAS_ADD_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_BIAS_ADD_H_
#define LITE_MINDSPORE_LITE_C_OPS_BIAS_ADD_H_

namespace mindspore {
namespace lite {
class BiasAdd : public PrimitiveC {
public:
explicit BiasAdd(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit BiasAdd(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit BiasAdd(schema::Primitive *primitive) : PrimitiveC(primitive) {}

std::vector<int> GetAxis() const;
void SetAxis(const std::vector<int> &axis);


+ 3
- 3
mindspore/lite/src/ops/bias_grad.cc View File

@@ -19,14 +19,14 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
std::vector<int> BiasGrad::GetAxis() const { return this->primitive->value.AsBiasGrad()->axis; }
std::vector<int> BiasGrad::GetAxis() const { return this->primitive_->value.AsBiasGrad()->axis; }

void BiasGrad::SetAxis(const std::vector<int> &axis) { this->primitive->value.AsBiasGrad()->axis = axis; }
void BiasGrad::SetAxis(const std::vector<int> &axis) { this->primitive_->value.AsBiasGrad()->axis = axis; }

#else

std::vector<int> BiasGrad::GetAxis() const {
auto fb_vector = this->primitive->value_as_BiasGrad()->axis();
auto fb_vector = this->primitive_->value_as_BiasGrad()->axis();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}



+ 7
- 9
mindspore/lite/src/ops/bias_grad.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_BIAS_GRAD_H_
#define LITE_MINDSPORE_LITE_C_OPS_BIAS_GRAD_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_BIAS_GRAD_H_
#define LITE_MINDSPORE_LITE_C_OPS_BIAS_GRAD_H_

namespace mindspore {
namespace lite {
class BiasGrad : public PrimitiveC {
public:
explicit BiasGrad(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit BiasGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit BiasGrad(schema::Primitive *primitive) : PrimitiveC(primitive) {}

std::vector<int> GetAxis() const;
void SetAxis(const std::vector<int> &axis);


+ 6
- 6
mindspore/lite/src/ops/bn_grad_input.cc View File

@@ -19,16 +19,16 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
float BNGradInput::GetEps() const { return this->primitive->value.AsBNGradInput()->eps; }
int BNGradInput::GetChannels() const { return this->primitive->value.AsBNGradInput()->channels; }
float BNGradInput::GetEps() const { return this->primitive_->value.AsBNGradInput()->eps; }
int BNGradInput::GetChannels() const { return this->primitive_->value.AsBNGradInput()->channels; }

void BNGradInput::SetEps(float eps) { this->primitive->value.AsBNGradInput()->eps = eps; }
void BNGradInput::SetChannels(int channels) { this->primitive->value.AsBNGradInput()->channels = channels; }
void BNGradInput::SetEps(float eps) { this->primitive_->value.AsBNGradInput()->eps = eps; }
void BNGradInput::SetChannels(int channels) { this->primitive_->value.AsBNGradInput()->channels = channels; }

#else

float BNGradInput::GetEps() const { return this->primitive->value_as_BNGradInput()->eps(); }
int BNGradInput::GetChannels() const { return this->primitive->value_as_BNGradInput()->channels(); }
float BNGradInput::GetEps() const { return this->primitive_->value_as_BNGradInput()->eps(); }
int BNGradInput::GetChannels() const { return this->primitive_->value_as_BNGradInput()->channels(); }

void BNGradInput::SetEps(float eps) {}
void BNGradInput::SetChannels(int channels) {}


+ 7
- 9
mindspore/lite/src/ops/bn_grad_input.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_
#define LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_
#define LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_

namespace mindspore {
namespace lite {
class BNGradInput : public PrimitiveC {
public:
explicit BNGradInput(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit BNGradInput(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit BNGradInput(schema::Primitive *primitive) : PrimitiveC(primitive) {}

float GetEps() const;
int GetChannels() const;


+ 3
- 3
mindspore/lite/src/ops/broadcast_to.cc View File

@@ -19,16 +19,16 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
std::vector<int> BroadcastTo::GetDstShape() const { return this->primitive->value.AsBroadcastTo()->dst_shape; }
std::vector<int> BroadcastTo::GetDstShape() const { return this->primitive_->value.AsBroadcastTo()->dst_shape; }

void BroadcastTo::SetDstShape(const std::vector<int> &dst_shape) {
this->primitive->value.AsBroadcastTo()->dst_shape = dst_shape;
this->primitive_->value.AsBroadcastTo()->dst_shape = dst_shape;
}

#else

std::vector<int> BroadcastTo::GetDstShape() const {
auto fb_vector = this->primitive->value_as_BroadcastTo()->dst_shape();
auto fb_vector = this->primitive_->value_as_BroadcastTo()->dst_shape();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}



+ 7
- 9
mindspore/lite/src/ops/broadcast_to.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_BROADCAST_TO_H_
#define LITE_MINDSPORE_LITE_C_OPS_BROADCAST_TO_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_BROADCAST_TO_H_
#define LITE_MINDSPORE_LITE_C_OPS_BROADCAST_TO_H_

namespace mindspore {
namespace lite {
class BroadcastTo : public PrimitiveC {
public:
explicit BroadcastTo(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit BroadcastTo(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit BroadcastTo(schema::Primitive *primitive) : PrimitiveC(primitive) {}

int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
std::vector<int> GetDstShape() const;


+ 3
- 3
mindspore/lite/src/ops/caffe_p_relu.cc View File

@@ -19,15 +19,15 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
bool CaffePReLU::GetChannelShared() const { return this->primitive->value.AsCaffePReLU()->channelShared; }
bool CaffePReLU::GetChannelShared() const { return this->primitive_->value.AsCaffePReLU()->channelShared; }

void CaffePReLU::SetChannelShared(bool channel_shared) {
this->primitive->value.AsCaffePReLU()->channelShared = channel_shared;
this->primitive_->value.AsCaffePReLU()->channelShared = channel_shared;
}

#else

bool CaffePReLU::GetChannelShared() const { return this->primitive->value_as_CaffePReLU()->channelShared(); }
bool CaffePReLU::GetChannelShared() const { return this->primitive_->value_as_CaffePReLU()->channelShared(); }

void CaffePReLU::SetChannelShared(bool channel_shared) {}
#endif


+ 7
- 10
mindspore/lite/src/ops/caffe_p_relu.h View File

@@ -14,26 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_CAFFE_P_RE_L_U_H_
#define LITE_MINDSPORE_LITE_C_OPS_CAFFE_P_RE_L_U_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#include "src/ops/activation.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_CAFFE_P_RE_L_U_H_
#define LITE_MINDSPORE_LITE_C_OPS_CAFFE_P_RE_L_U_H_

namespace mindspore {
namespace lite {
class CaffePReLU : public Activation {
public:
explicit CaffePReLU(OriginPrimitive *primitive) : Activation(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit CaffePReLU(schema::PrimitiveT *primitive) : Activation(primitive) {}
#endif
explicit CaffePReLU(schema::Primitive *primitive) : Activation(primitive) {}

bool GetChannelShared() const;
void SetChannelShared(bool channel_shared);


+ 7
- 7
mindspore/lite/src/ops/cast.cc View File

@@ -19,23 +19,23 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int Cast::GetSrcT() const { return this->primitive->value.AsCast()->srcT; }
int Cast::GetDstT() const { return this->primitive->value.AsCast()->dstT; }
int Cast::GetSrcT() const { return this->primitive_->value.AsCast()->srcT; }
int Cast::GetDstT() const { return this->primitive_->value.AsCast()->dstT; }

void Cast::SetSrcT(int src_t) { this->primitive->value.AsCast()->srcT = src_t; }
void Cast::SetDstT(int dst_t) { this->primitive->value.AsCast()->dstT = dst_t; }
void Cast::SetSrcT(int src_t) { this->primitive_->value.AsCast()->srcT = src_t; }
void Cast::SetDstT(int dst_t) { this->primitive_->value.AsCast()->dstT = dst_t; }

#else

int Cast::GetSrcT() const { return this->primitive->value_as_Cast()->srcT(); }
int Cast::GetDstT() const { return this->primitive->value_as_Cast()->dstT(); }
int Cast::GetSrcT() const { return this->primitive_->value_as_Cast()->srcT(); }
int Cast::GetDstT() const { return this->primitive_->value_as_Cast()->dstT(); }

void Cast::SetSrcT(int src_t) {}
void Cast::SetDstT(int dst_t) {}
#endif

int Cast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
MS_ASSERT(this->primitive_ != nullptr);
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();


+ 7
- 9
mindspore/lite/src/ops/cast.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_CAST_H_
#define LITE_MINDSPORE_LITE_C_OPS_CAST_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_CAST_H_
#define LITE_MINDSPORE_LITE_C_OPS_CAST_H_

namespace mindspore {
namespace lite {
class Cast : public PrimitiveC {
public:
explicit Cast(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit Cast(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit Cast(schema::Primitive *primitive) : PrimitiveC(primitive) {}

int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetSrcT() const;


+ 8
- 10
mindspore/lite/src/ops/ceil.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_CEIL_H_
#define LITE_MINDSPORE_LITE_C_OPS_CEIL_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/arithmetic_self.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_CEIL_H_
#define LITE_MINDSPORE_LITE_C_OPS_CEIL_H_
#include "src/ops/primitive_c.h"

namespace mindspore {
namespace lite {
class Ceil : public ArithmeticSelf {
public:
explicit Ceil(OriginPrimitive *primitive) : ArithmeticSelf(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit Ceil(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
#endif
explicit Ceil(schema::Primitive *primitive) : ArithmeticSelf(primitive) {}
};
} // namespace lite
} // namespace mindspore


+ 6
- 6
mindspore/lite/src/ops/clip.cc View File

@@ -19,16 +19,16 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
float Clip::GetMax() const { return this->primitive->value.AsClip()->max; }
float Clip::GetMin() const { return this->primitive->value.AsClip()->min; }
float Clip::GetMax() const { return this->primitive_->value.AsClip()->max; }
float Clip::GetMin() const { return this->primitive_->value.AsClip()->min; }

void Clip::SetMax(float max) { this->primitive->value.AsClip()->max = max; }
void Clip::SetMin(float min) { this->primitive->value.AsClip()->min = min; }
void Clip::SetMax(float max) { this->primitive_->value.AsClip()->max = max; }
void Clip::SetMin(float min) { this->primitive_->value.AsClip()->min = min; }

#else

float Clip::GetMax() const { return this->primitive->value_as_Clip()->max(); }
float Clip::GetMin() const { return this->primitive->value_as_Clip()->min(); }
float Clip::GetMax() const { return this->primitive_->value_as_Clip()->max(); }
float Clip::GetMin() const { return this->primitive_->value_as_Clip()->min(); }

void Clip::SetMax(float max) {}
void Clip::SetMin(float min) {}


+ 7
- 9
mindspore/lite/src/ops/clip.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_CLIP_H_
#define LITE_MINDSPORE_LITE_C_OPS_CLIP_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_CLIP_H_
#define LITE_MINDSPORE_LITE_C_OPS_CLIP_H_

namespace mindspore {
namespace lite {
class Clip : public PrimitiveC {
public:
explicit Clip(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit Clip(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit Clip(schema::Primitive *primitive) : PrimitiveC(primitive) {}

float GetMax() const;
float GetMin() const;


+ 7
- 7
mindspore/lite/src/ops/concat.cc View File

@@ -21,16 +21,16 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int Concat::GetAxis() const { return this->primitive->value.AsConcat()->axis; }
int Concat::GetN() const { return this->primitive->value.AsConcat()->n; }
int Concat::GetAxis() const { return this->primitive_->value.AsConcat()->axis; }
int Concat::GetN() const { return this->primitive_->value.AsConcat()->n; }

void Concat::SetAxis(int axis) { this->primitive->value.AsConcat()->axis = axis; }
void Concat::SetN(int n) { this->primitive->value.AsConcat()->n = n; }
void Concat::SetAxis(int axis) { this->primitive_->value.AsConcat()->axis = axis; }
void Concat::SetN(int n) { this->primitive_->value.AsConcat()->n = n; }

#else

int Concat::GetAxis() const { return this->primitive->value_as_Concat()->axis(); }
int Concat::GetN() const { return this->primitive->value_as_Concat()->n(); }
int Concat::GetAxis() const { return this->primitive_->value_as_Concat()->axis(); }
int Concat::GetN() const { return this->primitive_->value_as_Concat()->n(); }

void Concat::SetAxis(int axis) {}
void Concat::SetN(int n) {}
@@ -40,7 +40,7 @@ namespace {
constexpr int kConcatOutputNum = 1;
}
int Concat::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
if (this->primitive == nullptr) {
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "primitive is nullptr!";
return RET_PARAM_INVALID;
}


+ 7
- 9
mindspore/lite/src/ops/concat.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_CONCAT_H_
#define LITE_MINDSPORE_LITE_C_OPS_CONCAT_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_CONCAT_H_
#define LITE_MINDSPORE_LITE_C_OPS_CONCAT_H_

namespace mindspore {
namespace lite {
class Concat : public PrimitiveC {
public:
explicit Concat(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit Concat(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit Concat(schema::Primitive *primitive) : PrimitiveC(primitive) {}

int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetAxis() const;


+ 3
- 3
mindspore/lite/src/ops/constant_of_shape.cc View File

@@ -25,13 +25,13 @@ constexpr int kShapeInputNum = 1;
constexpr int kShapeOutputNum = 1;
} // namespace
#ifdef PRIMITIVE_WRITEABLE
float ConstantOfShape::GetValue() const { return this->primitive->value.AsConstantOfShape()->value; }
float ConstantOfShape::GetValue() const { return this->primitive_->value.AsConstantOfShape()->value; }

void ConstantOfShape::SetValue(float value) { this->primitive->value.AsConstantOfShape()->value = value; }
void ConstantOfShape::SetValue(float value) { this->primitive_->value.AsConstantOfShape()->value = value; }

#else

float ConstantOfShape::GetValue() const { return this->primitive->value_as_ConstantOfShape()->value(); }
float ConstantOfShape::GetValue() const { return this->primitive_->value_as_ConstantOfShape()->value(); }

void ConstantOfShape::SetValue(float value) {}
#endif


+ 7
- 8
mindspore/lite/src/ops/constant_of_shape.h View File

@@ -14,24 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_SRC_OPS_CONSTANT_OF_SHAPE_H_
#define LITE_MINDSPORE_LITE_SRC_OPS_CONSTANT_OF_SHAPE_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_SRC_OPS_CONSTANT_OF_SHAPE_H_
#define LITE_MINDSPORE_LITE_SRC_OPS_CONSTANT_OF_SHAPE_H_
namespace mindspore {
namespace lite {
class ConstantOfShape : public PrimitiveC {
public:
explicit ConstantOfShape(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit ConstantOfShape(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit ConstantOfShape(schema::Primitive *primitive) : PrimitiveC(primitive) {}
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
float GetValue() const;
void SetValue(float value);


+ 52
- 52
mindspore/lite/src/ops/conv2d.cc View File

@@ -26,63 +26,63 @@ int Conv2D::PadDown() const { return this->pad_d_; }
int Conv2D::PadLeft() const { return this->pad_l_; }
int Conv2D::PadRight() const { return this->pad_r_; }
#ifdef PRIMITIVE_WRITEABLE
int Conv2D::GetFormat() const { return this->primitive->value.AsConv2D()->format; }
int Conv2D::GetGroup() const { return this->primitive->value.AsConv2D()->group; }
int Conv2D::GetChannelIn() const { return this->primitive->value.AsConv2D()->channelIn; }
int Conv2D::GetChannelOut() const { return this->primitive->value.AsConv2D()->channelOut; }
int Conv2D::GetKernelW() const { return this->primitive->value.AsConv2D()->kernelW; }
int Conv2D::GetKernelH() const { return this->primitive->value.AsConv2D()->kernelH; }
int Conv2D::GetStrideW() const { return this->primitive->value.AsConv2D()->strideW; }
int Conv2D::GetStrideH() const { return this->primitive->value.AsConv2D()->strideH; }
int Conv2D::GetPadMode() const { return this->primitive->value.AsConv2D()->padMode; }
int Conv2D::GetPadUp() const { return this->primitive->value.AsConv2D()->padUp; }
int Conv2D::GetPadDown() const { return this->primitive->value.AsConv2D()->padDown; }
int Conv2D::GetPadLeft() const { return this->primitive->value.AsConv2D()->padLeft; }
int Conv2D::GetPadRight() const { return this->primitive->value.AsConv2D()->padRight; }
int Conv2D::GetDilateW() const { return this->primitive->value.AsConv2D()->dilateW; }
int Conv2D::GetDilateH() const { return this->primitive->value.AsConv2D()->dilateH; }
bool Conv2D::GetHasBias() const { return this->primitive->value.AsConv2D()->hasBias; }
int Conv2D::GetActivationType() const { return this->primitive->value.AsConv2D()->activationType; }
int Conv2D::GetFormat() const { return this->primitive_->value.AsConv2D()->format; }
int Conv2D::GetGroup() const { return this->primitive_->value.AsConv2D()->group; }
int Conv2D::GetChannelIn() const { return this->primitive_->value.AsConv2D()->channelIn; }
int Conv2D::GetChannelOut() const { return this->primitive_->value.AsConv2D()->channelOut; }
int Conv2D::GetKernelW() const { return this->primitive_->value.AsConv2D()->kernelW; }
int Conv2D::GetKernelH() const { return this->primitive_->value.AsConv2D()->kernelH; }
int Conv2D::GetStrideW() const { return this->primitive_->value.AsConv2D()->strideW; }
int Conv2D::GetStrideH() const { return this->primitive_->value.AsConv2D()->strideH; }
int Conv2D::GetPadMode() const { return this->primitive_->value.AsConv2D()->padMode; }
int Conv2D::GetPadUp() const { return this->primitive_->value.AsConv2D()->padUp; }
int Conv2D::GetPadDown() const { return this->primitive_->value.AsConv2D()->padDown; }
int Conv2D::GetPadLeft() const { return this->primitive_->value.AsConv2D()->padLeft; }
int Conv2D::GetPadRight() const { return this->primitive_->value.AsConv2D()->padRight; }
int Conv2D::GetDilateW() const { return this->primitive_->value.AsConv2D()->dilateW; }
int Conv2D::GetDilateH() const { return this->primitive_->value.AsConv2D()->dilateH; }
bool Conv2D::GetHasBias() const { return this->primitive_->value.AsConv2D()->hasBias; }
int Conv2D::GetActivationType() const { return this->primitive_->value.AsConv2D()->activationType; }

void Conv2D::SetFormat(int format) { this->primitive->value.AsConv2D()->format = (schema::Format)format; }
void Conv2D::SetGroup(int group) { this->primitive->value.AsConv2D()->group = group; }
void Conv2D::SetChannelIn(int channel_in) { this->primitive->value.AsConv2D()->channelIn = channel_in; }
void Conv2D::SetChannelOut(int channel_out) { this->primitive->value.AsConv2D()->channelOut = channel_out; }
void Conv2D::SetKernelW(int kernel_w) { this->primitive->value.AsConv2D()->kernelW = kernel_w; }
void Conv2D::SetKernelH(int kernel_h) { this->primitive->value.AsConv2D()->kernelH = kernel_h; }
void Conv2D::SetStrideW(int stride_w) { this->primitive->value.AsConv2D()->strideW = stride_w; }
void Conv2D::SetStrideH(int stride_h) { this->primitive->value.AsConv2D()->strideH = stride_h; }
void Conv2D::SetPadMode(int pad_mode) { this->primitive->value.AsConv2D()->padMode = (schema::PadMode)pad_mode; }
void Conv2D::SetPadUp(int pad_up) { this->primitive->value.AsConv2D()->padUp = pad_up; }
void Conv2D::SetPadDown(int pad_down) { this->primitive->value.AsConv2D()->padDown = pad_down; }
void Conv2D::SetPadLeft(int pad_left) { this->primitive->value.AsConv2D()->padLeft = pad_left; }
void Conv2D::SetPadRight(int pad_right) { this->primitive->value.AsConv2D()->padRight = pad_right; }
void Conv2D::SetDilateW(int dilate_w) { this->primitive->value.AsConv2D()->dilateW = dilate_w; }
void Conv2D::SetDilateH(int dilate_h) { this->primitive->value.AsConv2D()->dilateH = dilate_h; }
void Conv2D::SetHasBias(bool has_bias) { this->primitive->value.AsConv2D()->hasBias = has_bias; }
void Conv2D::SetFormat(int format) { this->primitive_->value.AsConv2D()->format = (schema::Format)format; }
void Conv2D::SetGroup(int group) { this->primitive_->value.AsConv2D()->group = group; }
void Conv2D::SetChannelIn(int channel_in) { this->primitive_->value.AsConv2D()->channelIn = channel_in; }
void Conv2D::SetChannelOut(int channel_out) { this->primitive_->value.AsConv2D()->channelOut = channel_out; }
void Conv2D::SetKernelW(int kernel_w) { this->primitive_->value.AsConv2D()->kernelW = kernel_w; }
void Conv2D::SetKernelH(int kernel_h) { this->primitive_->value.AsConv2D()->kernelH = kernel_h; }
void Conv2D::SetStrideW(int stride_w) { this->primitive_->value.AsConv2D()->strideW = stride_w; }
void Conv2D::SetStrideH(int stride_h) { this->primitive_->value.AsConv2D()->strideH = stride_h; }
void Conv2D::SetPadMode(int pad_mode) { this->primitive_->value.AsConv2D()->padMode = (schema::PadMode)pad_mode; }
void Conv2D::SetPadUp(int pad_up) { this->primitive_->value.AsConv2D()->padUp = pad_up; }
void Conv2D::SetPadDown(int pad_down) { this->primitive_->value.AsConv2D()->padDown = pad_down; }
void Conv2D::SetPadLeft(int pad_left) { this->primitive_->value.AsConv2D()->padLeft = pad_left; }
void Conv2D::SetPadRight(int pad_right) { this->primitive_->value.AsConv2D()->padRight = pad_right; }
void Conv2D::SetDilateW(int dilate_w) { this->primitive_->value.AsConv2D()->dilateW = dilate_w; }
void Conv2D::SetDilateH(int dilate_h) { this->primitive_->value.AsConv2D()->dilateH = dilate_h; }
void Conv2D::SetHasBias(bool has_bias) { this->primitive_->value.AsConv2D()->hasBias = has_bias; }
void Conv2D::SetActivationType(int activation_type) {
this->primitive->value.AsConv2D()->activationType = (schema::ActivationType)activation_type;
this->primitive_->value.AsConv2D()->activationType = (schema::ActivationType)activation_type;
}

#else

int Conv2D::GetFormat() const { return this->primitive->value_as_Conv2D()->format(); }
int Conv2D::GetGroup() const { return this->primitive->value_as_Conv2D()->group(); }
int Conv2D::GetChannelIn() const { return this->primitive->value_as_Conv2D()->channelIn(); }
int Conv2D::GetChannelOut() const { return this->primitive->value_as_Conv2D()->channelOut(); }
int Conv2D::GetKernelW() const { return this->primitive->value_as_Conv2D()->kernelW(); }
int Conv2D::GetKernelH() const { return this->primitive->value_as_Conv2D()->kernelH(); }
int Conv2D::GetStrideW() const { return this->primitive->value_as_Conv2D()->strideW(); }
int Conv2D::GetStrideH() const { return this->primitive->value_as_Conv2D()->strideH(); }
int Conv2D::GetPadMode() const { return this->primitive->value_as_Conv2D()->padMode(); }
int Conv2D::GetPadUp() const { return this->primitive->value_as_Conv2D()->padUp(); }
int Conv2D::GetPadDown() const { return this->primitive->value_as_Conv2D()->padDown(); }
int Conv2D::GetPadLeft() const { return this->primitive->value_as_Conv2D()->padLeft(); }
int Conv2D::GetPadRight() const { return this->primitive->value_as_Conv2D()->padRight(); }
int Conv2D::GetDilateW() const { return this->primitive->value_as_Conv2D()->dilateW(); }
int Conv2D::GetDilateH() const { return this->primitive->value_as_Conv2D()->dilateH(); }
bool Conv2D::GetHasBias() const { return this->primitive->value_as_Conv2D()->hasBias(); }
int Conv2D::GetActivationType() const { return this->primitive->value_as_Conv2D()->activationType(); }
int Conv2D::GetFormat() const { return this->primitive_->value_as_Conv2D()->format(); }
int Conv2D::GetGroup() const { return this->primitive_->value_as_Conv2D()->group(); }
int Conv2D::GetChannelIn() const { return this->primitive_->value_as_Conv2D()->channelIn(); }
int Conv2D::GetChannelOut() const { return this->primitive_->value_as_Conv2D()->channelOut(); }
int Conv2D::GetKernelW() const { return this->primitive_->value_as_Conv2D()->kernelW(); }
int Conv2D::GetKernelH() const { return this->primitive_->value_as_Conv2D()->kernelH(); }
int Conv2D::GetStrideW() const { return this->primitive_->value_as_Conv2D()->strideW(); }
int Conv2D::GetStrideH() const { return this->primitive_->value_as_Conv2D()->strideH(); }
int Conv2D::GetPadMode() const { return this->primitive_->value_as_Conv2D()->padMode(); }
int Conv2D::GetPadUp() const { return this->primitive_->value_as_Conv2D()->padUp(); }
int Conv2D::GetPadDown() const { return this->primitive_->value_as_Conv2D()->padDown(); }
int Conv2D::GetPadLeft() const { return this->primitive_->value_as_Conv2D()->padLeft(); }
int Conv2D::GetPadRight() const { return this->primitive_->value_as_Conv2D()->padRight(); }
int Conv2D::GetDilateW() const { return this->primitive_->value_as_Conv2D()->dilateW(); }
int Conv2D::GetDilateH() const { return this->primitive_->value_as_Conv2D()->dilateH(); }
bool Conv2D::GetHasBias() const { return this->primitive_->value_as_Conv2D()->hasBias(); }
int Conv2D::GetActivationType() const { return this->primitive_->value_as_Conv2D()->activationType(); }

void Conv2D::SetFormat(int format) {}
void Conv2D::SetGroup(int group) {}
@@ -103,7 +103,7 @@ void Conv2D::SetHasBias(bool has_bias) {}
void Conv2D::SetActivationType(int activation_type) {}
#endif
void Conv2D::ConvInferShape(int input_h, int input_w, int *output_h, int *output_w) {
MS_ASSERT(this->primitive != nullptr);
MS_ASSERT(this->primitive_ != nullptr);
int kernel_w = GetKernelW();
int kernel_h = GetKernelH();
int stride_w = GetStrideW();


+ 7
- 9
mindspore/lite/src/ops/conv2d.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_CONV2_D_H_
#define LITE_MINDSPORE_LITE_C_OPS_CONV2_D_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_CONV2_D_H_
#define LITE_MINDSPORE_LITE_C_OPS_CONV2_D_H_

namespace mindspore {
namespace lite {
class Conv2D : public PrimitiveC {
public:
explicit Conv2D(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit Conv2D(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit Conv2D(schema::Primitive *primitive) : PrimitiveC(primitive) {}
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int PadUp() const;
int PadDown() const;


+ 53
- 51
mindspore/lite/src/ops/conv2d_grad_filter.cc View File

@@ -19,72 +19,74 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int Conv2DGradFilter::GetFormat() const { return this->primitive->value.AsConv2DGradFilter()->format; }
int Conv2DGradFilter::GetGroup() const { return this->primitive->value.AsConv2DGradFilter()->group; }
int Conv2DGradFilter::GetChannelIn() const { return this->primitive->value.AsConv2DGradFilter()->channelIn; }
int Conv2DGradFilter::GetChannelOut() const { return this->primitive->value.AsConv2DGradFilter()->channelOut; }
int Conv2DGradFilter::GetKernelW() const { return this->primitive->value.AsConv2DGradFilter()->kernelW; }
int Conv2DGradFilter::GetKernelH() const { return this->primitive->value.AsConv2DGradFilter()->kernelH; }
int Conv2DGradFilter::GetStrideW() const { return this->primitive->value.AsConv2DGradFilter()->strideW; }
int Conv2DGradFilter::GetStrideH() const { return this->primitive->value.AsConv2DGradFilter()->strideH; }
int Conv2DGradFilter::GetPadMode() const { return this->primitive->value.AsConv2DGradFilter()->padMode; }
int Conv2DGradFilter::GetPadUp() const { return this->primitive->value.AsConv2DGradFilter()->padUp; }
int Conv2DGradFilter::GetPadDown() const { return this->primitive->value.AsConv2DGradFilter()->padDown; }
int Conv2DGradFilter::GetPadLeft() const { return this->primitive->value.AsConv2DGradFilter()->padLeft; }
int Conv2DGradFilter::GetPadRight() const { return this->primitive->value.AsConv2DGradFilter()->padRight; }
int Conv2DGradFilter::GetDilateW() const { return this->primitive->value.AsConv2DGradFilter()->dilateW; }
int Conv2DGradFilter::GetDilateH() const { return this->primitive->value.AsConv2DGradFilter()->dilateH; }
bool Conv2DGradFilter::GetHasBias() const { return this->primitive->value.AsConv2DGradFilter()->hasBias; }
int Conv2DGradFilter::GetActivationType() const { return this->primitive->value.AsConv2DGradFilter()->activationType; }
int Conv2DGradFilter::GetFormat() const { return this->primitive_->value.AsConv2DGradFilter()->format; }
int Conv2DGradFilter::GetGroup() const { return this->primitive_->value.AsConv2DGradFilter()->group; }
int Conv2DGradFilter::GetChannelIn() const { return this->primitive_->value.AsConv2DGradFilter()->channelIn; }
int Conv2DGradFilter::GetChannelOut() const { return this->primitive_->value.AsConv2DGradFilter()->channelOut; }
int Conv2DGradFilter::GetKernelW() const { return this->primitive_->value.AsConv2DGradFilter()->kernelW; }
int Conv2DGradFilter::GetKernelH() const { return this->primitive_->value.AsConv2DGradFilter()->kernelH; }
int Conv2DGradFilter::GetStrideW() const { return this->primitive_->value.AsConv2DGradFilter()->strideW; }
int Conv2DGradFilter::GetStrideH() const { return this->primitive_->value.AsConv2DGradFilter()->strideH; }
int Conv2DGradFilter::GetPadMode() const { return this->primitive_->value.AsConv2DGradFilter()->padMode; }
int Conv2DGradFilter::GetPadUp() const { return this->primitive_->value.AsConv2DGradFilter()->padUp; }
int Conv2DGradFilter::GetPadDown() const { return this->primitive_->value.AsConv2DGradFilter()->padDown; }
int Conv2DGradFilter::GetPadLeft() const { return this->primitive_->value.AsConv2DGradFilter()->padLeft; }
int Conv2DGradFilter::GetPadRight() const { return this->primitive_->value.AsConv2DGradFilter()->padRight; }
int Conv2DGradFilter::GetDilateW() const { return this->primitive_->value.AsConv2DGradFilter()->dilateW; }
int Conv2DGradFilter::GetDilateH() const { return this->primitive_->value.AsConv2DGradFilter()->dilateH; }
bool Conv2DGradFilter::GetHasBias() const { return this->primitive_->value.AsConv2DGradFilter()->hasBias; }
int Conv2DGradFilter::GetActivationType() const { return this->primitive_->value.AsConv2DGradFilter()->activationType; }

void Conv2DGradFilter::SetFormat(int format) {
this->primitive->value.AsConv2DGradFilter()->format = (schema::Format)format;
this->primitive_->value.AsConv2DGradFilter()->format = (schema::Format)format;
}
void Conv2DGradFilter::SetGroup(int group) { this->primitive->value.AsConv2DGradFilter()->group = group; }
void Conv2DGradFilter::SetGroup(int group) { this->primitive_->value.AsConv2DGradFilter()->group = group; }
void Conv2DGradFilter::SetChannelIn(int channel_in) {
this->primitive->value.AsConv2DGradFilter()->channelIn = channel_in;
this->primitive_->value.AsConv2DGradFilter()->channelIn = channel_in;
}
void Conv2DGradFilter::SetChannelOut(int channel_out) {
this->primitive->value.AsConv2DGradFilter()->channelOut = channel_out;
this->primitive_->value.AsConv2DGradFilter()->channelOut = channel_out;
}
void Conv2DGradFilter::SetKernelW(int kernel_w) { this->primitive->value.AsConv2DGradFilter()->kernelW = kernel_w; }
void Conv2DGradFilter::SetKernelH(int kernel_h) { this->primitive->value.AsConv2DGradFilter()->kernelH = kernel_h; }
void Conv2DGradFilter::SetStrideW(int stride_w) { this->primitive->value.AsConv2DGradFilter()->strideW = stride_w; }
void Conv2DGradFilter::SetStrideH(int stride_h) { this->primitive->value.AsConv2DGradFilter()->strideH = stride_h; }
void Conv2DGradFilter::SetKernelW(int kernel_w) { this->primitive_->value.AsConv2DGradFilter()->kernelW = kernel_w; }
void Conv2DGradFilter::SetKernelH(int kernel_h) { this->primitive_->value.AsConv2DGradFilter()->kernelH = kernel_h; }
void Conv2DGradFilter::SetStrideW(int stride_w) { this->primitive_->value.AsConv2DGradFilter()->strideW = stride_w; }
void Conv2DGradFilter::SetStrideH(int stride_h) { this->primitive_->value.AsConv2DGradFilter()->strideH = stride_h; }
void Conv2DGradFilter::SetPadMode(int pad_mode) {
this->primitive->value.AsConv2DGradFilter()->padMode = (schema::PadMode)pad_mode;
this->primitive_->value.AsConv2DGradFilter()->padMode = (schema::PadMode)pad_mode;
}
void Conv2DGradFilter::SetPadUp(int pad_up) { this->primitive->value.AsConv2DGradFilter()->padUp = pad_up; }
void Conv2DGradFilter::SetPadDown(int pad_down) { this->primitive->value.AsConv2DGradFilter()->padDown = pad_down; }
void Conv2DGradFilter::SetPadLeft(int pad_left) { this->primitive->value.AsConv2DGradFilter()->padLeft = pad_left; }
void Conv2DGradFilter::SetPadRight(int pad_right) { this->primitive->value.AsConv2DGradFilter()->padRight = pad_right; }
void Conv2DGradFilter::SetDilateW(int dilate_w) { this->primitive->value.AsConv2DGradFilter()->dilateW = dilate_w; }
void Conv2DGradFilter::SetDilateH(int dilate_h) { this->primitive->value.AsConv2DGradFilter()->dilateH = dilate_h; }
void Conv2DGradFilter::SetHasBias(bool has_bias) { this->primitive->value.AsConv2DGradFilter()->hasBias = has_bias; }
void Conv2DGradFilter::SetPadUp(int pad_up) { this->primitive_->value.AsConv2DGradFilter()->padUp = pad_up; }
void Conv2DGradFilter::SetPadDown(int pad_down) { this->primitive_->value.AsConv2DGradFilter()->padDown = pad_down; }
void Conv2DGradFilter::SetPadLeft(int pad_left) { this->primitive_->value.AsConv2DGradFilter()->padLeft = pad_left; }
void Conv2DGradFilter::SetPadRight(int pad_right) {
this->primitive_->value.AsConv2DGradFilter()->padRight = pad_right;
}
void Conv2DGradFilter::SetDilateW(int dilate_w) { this->primitive_->value.AsConv2DGradFilter()->dilateW = dilate_w; }
void Conv2DGradFilter::SetDilateH(int dilate_h) { this->primitive_->value.AsConv2DGradFilter()->dilateH = dilate_h; }
void Conv2DGradFilter::SetHasBias(bool has_bias) { this->primitive_->value.AsConv2DGradFilter()->hasBias = has_bias; }
void Conv2DGradFilter::SetActivationType(int activation_type) {
this->primitive->value.AsConv2DGradFilter()->activationType = (schema::ActivationType)activation_type;
this->primitive_->value.AsConv2DGradFilter()->activationType = (schema::ActivationType)activation_type;
}

#else

int Conv2DGradFilter::GetFormat() const { return this->primitive->value_as_Conv2DGradFilter()->format(); }
int Conv2DGradFilter::GetGroup() const { return this->primitive->value_as_Conv2DGradFilter()->group(); }
int Conv2DGradFilter::GetChannelIn() const { return this->primitive->value_as_Conv2DGradFilter()->channelIn(); }
int Conv2DGradFilter::GetChannelOut() const { return this->primitive->value_as_Conv2DGradFilter()->channelOut(); }
int Conv2DGradFilter::GetKernelW() const { return this->primitive->value_as_Conv2DGradFilter()->kernelW(); }
int Conv2DGradFilter::GetKernelH() const { return this->primitive->value_as_Conv2DGradFilter()->kernelH(); }
int Conv2DGradFilter::GetStrideW() const { return this->primitive->value_as_Conv2DGradFilter()->strideW(); }
int Conv2DGradFilter::GetStrideH() const { return this->primitive->value_as_Conv2DGradFilter()->strideH(); }
int Conv2DGradFilter::GetPadMode() const { return this->primitive->value_as_Conv2DGradFilter()->padMode(); }
int Conv2DGradFilter::GetPadUp() const { return this->primitive->value_as_Conv2DGradFilter()->padUp(); }
int Conv2DGradFilter::GetPadDown() const { return this->primitive->value_as_Conv2DGradFilter()->padDown(); }
int Conv2DGradFilter::GetPadLeft() const { return this->primitive->value_as_Conv2DGradFilter()->padLeft(); }
int Conv2DGradFilter::GetPadRight() const { return this->primitive->value_as_Conv2DGradFilter()->padRight(); }
int Conv2DGradFilter::GetDilateW() const { return this->primitive->value_as_Conv2DGradFilter()->dilateW(); }
int Conv2DGradFilter::GetDilateH() const { return this->primitive->value_as_Conv2DGradFilter()->dilateH(); }
bool Conv2DGradFilter::GetHasBias() const { return this->primitive->value_as_Conv2DGradFilter()->hasBias(); }
int Conv2DGradFilter::GetFormat() const { return this->primitive_->value_as_Conv2DGradFilter()->format(); }
int Conv2DGradFilter::GetGroup() const { return this->primitive_->value_as_Conv2DGradFilter()->group(); }
int Conv2DGradFilter::GetChannelIn() const { return this->primitive_->value_as_Conv2DGradFilter()->channelIn(); }
int Conv2DGradFilter::GetChannelOut() const { return this->primitive_->value_as_Conv2DGradFilter()->channelOut(); }
int Conv2DGradFilter::GetKernelW() const { return this->primitive_->value_as_Conv2DGradFilter()->kernelW(); }
int Conv2DGradFilter::GetKernelH() const { return this->primitive_->value_as_Conv2DGradFilter()->kernelH(); }
int Conv2DGradFilter::GetStrideW() const { return this->primitive_->value_as_Conv2DGradFilter()->strideW(); }
int Conv2DGradFilter::GetStrideH() const { return this->primitive_->value_as_Conv2DGradFilter()->strideH(); }
int Conv2DGradFilter::GetPadMode() const { return this->primitive_->value_as_Conv2DGradFilter()->padMode(); }
int Conv2DGradFilter::GetPadUp() const { return this->primitive_->value_as_Conv2DGradFilter()->padUp(); }
int Conv2DGradFilter::GetPadDown() const { return this->primitive_->value_as_Conv2DGradFilter()->padDown(); }
int Conv2DGradFilter::GetPadLeft() const { return this->primitive_->value_as_Conv2DGradFilter()->padLeft(); }
int Conv2DGradFilter::GetPadRight() const { return this->primitive_->value_as_Conv2DGradFilter()->padRight(); }
int Conv2DGradFilter::GetDilateW() const { return this->primitive_->value_as_Conv2DGradFilter()->dilateW(); }
int Conv2DGradFilter::GetDilateH() const { return this->primitive_->value_as_Conv2DGradFilter()->dilateH(); }
bool Conv2DGradFilter::GetHasBias() const { return this->primitive_->value_as_Conv2DGradFilter()->hasBias(); }
int Conv2DGradFilter::GetActivationType() const {
return this->primitive->value_as_Conv2DGradFilter()->activationType();
return this->primitive_->value_as_Conv2DGradFilter()->activationType();
}

void Conv2DGradFilter::SetFormat(int format) {}


+ 7
- 9
mindspore/lite/src/ops/conv2d_grad_filter.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_CONV2_D_GRAD_FILTER_H_
#define LITE_MINDSPORE_LITE_C_OPS_CONV2_D_GRAD_FILTER_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_CONV2_D_GRAD_FILTER_H_
#define LITE_MINDSPORE_LITE_C_OPS_CONV2_D_GRAD_FILTER_H_

namespace mindspore {
namespace lite {
class Conv2DGradFilter : public PrimitiveC {
public:
explicit Conv2DGradFilter(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit Conv2DGradFilter(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit Conv2DGradFilter(schema::Primitive *primitive) : PrimitiveC(primitive) {}

int GetFormat() const;
int GetGroup() const;


+ 53
- 51
mindspore/lite/src/ops/conv2d_grad_input.cc View File

@@ -19,71 +19,73 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int Conv2DGradInput::GetFormat() const { return this->primitive->value.AsConv2DGradInput()->format; }
int Conv2DGradInput::GetGroup() const { return this->primitive->value.AsConv2DGradInput()->group; }
int Conv2DGradInput::GetChannelIn() const { return this->primitive->value.AsConv2DGradInput()->channelIn; }
int Conv2DGradInput::GetChannelOut() const { return this->primitive->value.AsConv2DGradInput()->channelOut; }
int Conv2DGradInput::GetKernelW() const { return this->primitive->value.AsConv2DGradInput()->kernelW; }
int Conv2DGradInput::GetKernelH() const { return this->primitive->value.AsConv2DGradInput()->kernelH; }
int Conv2DGradInput::GetStrideW() const { return this->primitive->value.AsConv2DGradInput()->strideW; }
int Conv2DGradInput::GetStrideH() const { return this->primitive->value.AsConv2DGradInput()->strideH; }
int Conv2DGradInput::GetPadMode() const { return this->primitive->value.AsConv2DGradInput()->padMode; }
int Conv2DGradInput::GetPadUp() const { return this->primitive->value.AsConv2DGradInput()->padUp; }
int Conv2DGradInput::GetPadDown() const { return this->primitive->value.AsConv2DGradInput()->padDown; }
int Conv2DGradInput::GetPadLeft() const { return this->primitive->value.AsConv2DGradInput()->padLeft; }
int Conv2DGradInput::GetPadRight() const { return this->primitive->value.AsConv2DGradInput()->padRight; }
int Conv2DGradInput::GetDilateW() const { return this->primitive->value.AsConv2DGradInput()->dilateW; }
int Conv2DGradInput::GetDilateH() const { return this->primitive->value.AsConv2DGradInput()->dilateH; }
bool Conv2DGradInput::GetHasBias() const { return this->primitive->value.AsConv2DGradInput()->hasBias; }
int Conv2DGradInput::GetActivationType() const { return this->primitive->value.AsConv2DGradInput()->activationType; }
int Conv2DGradInput::GetFormat() const { return this->primitive_->value.AsConv2DGradInput()->format; }
int Conv2DGradInput::GetGroup() const { return this->primitive_->value.AsConv2DGradInput()->group; }
int Conv2DGradInput::GetChannelIn() const { return this->primitive_->value.AsConv2DGradInput()->channelIn; }
int Conv2DGradInput::GetChannelOut() const { return this->primitive_->value.AsConv2DGradInput()->channelOut; }
int Conv2DGradInput::GetKernelW() const { return this->primitive_->value.AsConv2DGradInput()->kernelW; }
int Conv2DGradInput::GetKernelH() const { return this->primitive_->value.AsConv2DGradInput()->kernelH; }
int Conv2DGradInput::GetStrideW() const { return this->primitive_->value.AsConv2DGradInput()->strideW; }
int Conv2DGradInput::GetStrideH() const { return this->primitive_->value.AsConv2DGradInput()->strideH; }
int Conv2DGradInput::GetPadMode() const { return this->primitive_->value.AsConv2DGradInput()->padMode; }
int Conv2DGradInput::GetPadUp() const { return this->primitive_->value.AsConv2DGradInput()->padUp; }
int Conv2DGradInput::GetPadDown() const { return this->primitive_->value.AsConv2DGradInput()->padDown; }
int Conv2DGradInput::GetPadLeft() const { return this->primitive_->value.AsConv2DGradInput()->padLeft; }
int Conv2DGradInput::GetPadRight() const { return this->primitive_->value.AsConv2DGradInput()->padRight; }
int Conv2DGradInput::GetDilateW() const { return this->primitive_->value.AsConv2DGradInput()->dilateW; }
int Conv2DGradInput::GetDilateH() const { return this->primitive_->value.AsConv2DGradInput()->dilateH; }
bool Conv2DGradInput::GetHasBias() const { return this->primitive_->value.AsConv2DGradInput()->hasBias; }
int Conv2DGradInput::GetActivationType() const { return this->primitive_->value.AsConv2DGradInput()->activationType; }

void Conv2DGradInput::SetFormat(int format) {
this->primitive->value.AsConv2DGradInput()->format = (schema::Format)format;
this->primitive_->value.AsConv2DGradInput()->format = (schema::Format)format;
}
void Conv2DGradInput::SetGroup(int group) { this->primitive->value.AsConv2DGradInput()->group = group; }
void Conv2DGradInput::SetGroup(int group) { this->primitive_->value.AsConv2DGradInput()->group = group; }
void Conv2DGradInput::SetChannelIn(int channel_in) {
this->primitive->value.AsConv2DGradInput()->channelIn = channel_in;
this->primitive_->value.AsConv2DGradInput()->channelIn = channel_in;
}
void Conv2DGradInput::SetChannelOut(int channel_out) {
this->primitive->value.AsConv2DGradInput()->channelOut = channel_out;
this->primitive_->value.AsConv2DGradInput()->channelOut = channel_out;
}
void Conv2DGradInput::SetKernelW(int kernel_w) { this->primitive->value.AsConv2DGradInput()->kernelW = kernel_w; }
void Conv2DGradInput::SetKernelH(int kernel_h) { this->primitive->value.AsConv2DGradInput()->kernelH = kernel_h; }
void Conv2DGradInput::SetStrideW(int stride_w) { this->primitive->value.AsConv2DGradInput()->strideW = stride_w; }
void Conv2DGradInput::SetStrideH(int stride_h) { this->primitive->value.AsConv2DGradInput()->strideH = stride_h; }
void Conv2DGradInput::SetKernelW(int kernel_w) { this->primitive_->value.AsConv2DGradInput()->kernelW = kernel_w; }
void Conv2DGradInput::SetKernelH(int kernel_h) { this->primitive_->value.AsConv2DGradInput()->kernelH = kernel_h; }
void Conv2DGradInput::SetStrideW(int stride_w) { this->primitive_->value.AsConv2DGradInput()->strideW = stride_w; }
void Conv2DGradInput::SetStrideH(int stride_h) { this->primitive_->value.AsConv2DGradInput()->strideH = stride_h; }
void Conv2DGradInput::SetPadMode(int pad_mode) {
this->primitive->value.AsConv2DGradInput()->padMode = (schema::PadMode)pad_mode;
this->primitive_->value.AsConv2DGradInput()->padMode = (schema::PadMode)pad_mode;
}
void Conv2DGradInput::SetPadUp(int pad_up) { this->primitive->value.AsConv2DGradInput()->padUp = pad_up; }
void Conv2DGradInput::SetPadDown(int pad_down) { this->primitive->value.AsConv2DGradInput()->padDown = pad_down; }
void Conv2DGradInput::SetPadLeft(int pad_left) { this->primitive->value.AsConv2DGradInput()->padLeft = pad_left; }
void Conv2DGradInput::SetPadRight(int pad_right) { this->primitive->value.AsConv2DGradInput()->padRight = pad_right; }
void Conv2DGradInput::SetDilateW(int dilate_w) { this->primitive->value.AsConv2DGradInput()->dilateW = dilate_w; }
void Conv2DGradInput::SetDilateH(int dilate_h) { this->primitive->value.AsConv2DGradInput()->dilateH = dilate_h; }
void Conv2DGradInput::SetHasBias(bool has_bias) { this->primitive->value.AsConv2DGradInput()->hasBias = has_bias; }
void Conv2DGradInput::SetPadUp(int pad_up) { this->primitive_->value.AsConv2DGradInput()->padUp = pad_up; }
void Conv2DGradInput::SetPadDown(int pad_down) { this->primitive_->value.AsConv2DGradInput()->padDown = pad_down; }
void Conv2DGradInput::SetPadLeft(int pad_left) { this->primitive_->value.AsConv2DGradInput()->padLeft = pad_left; }
void Conv2DGradInput::SetPadRight(int pad_right) { this->primitive_->value.AsConv2DGradInput()->padRight = pad_right; }
void Conv2DGradInput::SetDilateW(int dilate_w) { this->primitive_->value.AsConv2DGradInput()->dilateW = dilate_w; }
void Conv2DGradInput::SetDilateH(int dilate_h) { this->primitive_->value.AsConv2DGradInput()->dilateH = dilate_h; }
void Conv2DGradInput::SetHasBias(bool has_bias) { this->primitive_->value.AsConv2DGradInput()->hasBias = has_bias; }
void Conv2DGradInput::SetActivationType(int activation_type) {
this->primitive->value.AsConv2DGradInput()->activationType = (schema::ActivationType)activation_type;
this->primitive_->value.AsConv2DGradInput()->activationType = (schema::ActivationType)activation_type;
}

#else

int Conv2DGradInput::GetFormat() const { return this->primitive->value_as_Conv2DGradInput()->format(); }
int Conv2DGradInput::GetGroup() const { return this->primitive->value_as_Conv2DGradInput()->group(); }
int Conv2DGradInput::GetChannelIn() const { return this->primitive->value_as_Conv2DGradInput()->channelIn(); }
int Conv2DGradInput::GetChannelOut() const { return this->primitive->value_as_Conv2DGradInput()->channelOut(); }
int Conv2DGradInput::GetKernelW() const { return this->primitive->value_as_Conv2DGradInput()->kernelW(); }
int Conv2DGradInput::GetKernelH() const { return this->primitive->value_as_Conv2DGradInput()->kernelH(); }
int Conv2DGradInput::GetStrideW() const { return this->primitive->value_as_Conv2DGradInput()->strideW(); }
int Conv2DGradInput::GetStrideH() const { return this->primitive->value_as_Conv2DGradInput()->strideH(); }
int Conv2DGradInput::GetPadMode() const { return this->primitive->value_as_Conv2DGradInput()->padMode(); }
int Conv2DGradInput::GetPadUp() const { return this->primitive->value_as_Conv2DGradInput()->padUp(); }
int Conv2DGradInput::GetPadDown() const { return this->primitive->value_as_Conv2DGradInput()->padDown(); }
int Conv2DGradInput::GetPadLeft() const { return this->primitive->value_as_Conv2DGradInput()->padLeft(); }
int Conv2DGradInput::GetPadRight() const { return this->primitive->value_as_Conv2DGradInput()->padRight(); }
int Conv2DGradInput::GetDilateW() const { return this->primitive->value_as_Conv2DGradInput()->dilateW(); }
int Conv2DGradInput::GetDilateH() const { return this->primitive->value_as_Conv2DGradInput()->dilateH(); }
bool Conv2DGradInput::GetHasBias() const { return this->primitive->value_as_Conv2DGradInput()->hasBias(); }
int Conv2DGradInput::GetActivationType() const { return this->primitive->value_as_Conv2DGradInput()->activationType(); }
int Conv2DGradInput::GetFormat() const { return this->primitive_->value_as_Conv2DGradInput()->format(); }
int Conv2DGradInput::GetGroup() const { return this->primitive_->value_as_Conv2DGradInput()->group(); }
int Conv2DGradInput::GetChannelIn() const { return this->primitive_->value_as_Conv2DGradInput()->channelIn(); }
int Conv2DGradInput::GetChannelOut() const { return this->primitive_->value_as_Conv2DGradInput()->channelOut(); }
int Conv2DGradInput::GetKernelW() const { return this->primitive_->value_as_Conv2DGradInput()->kernelW(); }
int Conv2DGradInput::GetKernelH() const { return this->primitive_->value_as_Conv2DGradInput()->kernelH(); }
int Conv2DGradInput::GetStrideW() const { return this->primitive_->value_as_Conv2DGradInput()->strideW(); }
int Conv2DGradInput::GetStrideH() const { return this->primitive_->value_as_Conv2DGradInput()->strideH(); }
int Conv2DGradInput::GetPadMode() const { return this->primitive_->value_as_Conv2DGradInput()->padMode(); }
int Conv2DGradInput::GetPadUp() const { return this->primitive_->value_as_Conv2DGradInput()->padUp(); }
int Conv2DGradInput::GetPadDown() const { return this->primitive_->value_as_Conv2DGradInput()->padDown(); }
int Conv2DGradInput::GetPadLeft() const { return this->primitive_->value_as_Conv2DGradInput()->padLeft(); }
int Conv2DGradInput::GetPadRight() const { return this->primitive_->value_as_Conv2DGradInput()->padRight(); }
int Conv2DGradInput::GetDilateW() const { return this->primitive_->value_as_Conv2DGradInput()->dilateW(); }
int Conv2DGradInput::GetDilateH() const { return this->primitive_->value_as_Conv2DGradInput()->dilateH(); }
bool Conv2DGradInput::GetHasBias() const { return this->primitive_->value_as_Conv2DGradInput()->hasBias(); }
int Conv2DGradInput::GetActivationType() const {
return this->primitive_->value_as_Conv2DGradInput()->activationType();
}

void Conv2DGradInput::SetFormat(int format) {}
void Conv2DGradInput::SetGroup(int group) {}


+ 7
- 9
mindspore/lite/src/ops/conv2d_grad_input.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_CONV2_D_GRAD_INPUT_H_
#define LITE_MINDSPORE_LITE_C_OPS_CONV2_D_GRAD_INPUT_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_CONV2_D_GRAD_INPUT_H_
#define LITE_MINDSPORE_LITE_C_OPS_CONV2_D_GRAD_INPUT_H_

namespace mindspore {
namespace lite {
class Conv2DGradInput : public PrimitiveC {
public:
explicit Conv2DGradInput(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit Conv2DGradInput(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit Conv2DGradInput(schema::Primitive *primitive) : PrimitiveC(primitive) {}

int GetFormat() const;
int GetGroup() const;


+ 8
- 10
mindspore/lite/src/ops/cos.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_COS_H_
#define LITE_MINDSPORE_LITE_C_OPS_COS_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/arithmetic_self.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_COS_H_
#define LITE_MINDSPORE_LITE_C_OPS_COS_H_
#include "src/ops/primitive_c.h"

namespace mindspore {
namespace lite {
class Cos : public ArithmeticSelf {
public:
explicit Cos(OriginPrimitive *primitive) : ArithmeticSelf(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit Cos(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
#endif
explicit Cos(schema::Primitive *primitive) : ArithmeticSelf(primitive) {}
};
} // namespace lite
} // namespace mindspore


+ 6
- 6
mindspore/lite/src/ops/crop.cc View File

@@ -19,17 +19,17 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
long Crop::GetAxis() const { return this->primitive->value.AsCrop()->axis; }
std::vector<long> Crop::GetOffsets() const { return this->primitive->value.AsCrop()->offsets; }
long Crop::GetAxis() const { return this->primitive_->value.AsCrop()->axis; }
std::vector<long> Crop::GetOffsets() const { return this->primitive_->value.AsCrop()->offsets; }

void Crop::SetAxis(long axis) { this->primitive->value.AsCrop()->axis = axis; }
void Crop::SetOffsets(const std::vector<long> &offsets) { this->primitive->value.AsCrop()->offsets = offsets; }
void Crop::SetAxis(long axis) { this->primitive_->value.AsCrop()->axis = axis; }
void Crop::SetOffsets(const std::vector<long> &offsets) { this->primitive_->value.AsCrop()->offsets = offsets; }

#else

long Crop::GetAxis() const { return this->primitive->value_as_Crop()->axis(); }
long Crop::GetAxis() const { return this->primitive_->value_as_Crop()->axis(); }
std::vector<long> Crop::GetOffsets() const {
auto fb_vector = this->primitive->value_as_Crop()->offsets();
auto fb_vector = this->primitive_->value_as_Crop()->offsets();
return std::vector<long>(fb_vector->begin(), fb_vector->end());
}



+ 7
- 9
mindspore/lite/src/ops/crop.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_CROP_H_
#define LITE_MINDSPORE_LITE_C_OPS_CROP_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_CROP_H_
#define LITE_MINDSPORE_LITE_C_OPS_CROP_H_

namespace mindspore {
namespace lite {
class Crop : public PrimitiveC {
public:
explicit Crop(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit Crop(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit Crop(schema::Primitive *primitive) : PrimitiveC(primitive) {}

int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
long GetAxis() const;


+ 52
- 52
mindspore/lite/src/ops/deconv2d.cc View File

@@ -19,63 +19,63 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int DeConv2D::GetFormat() const { return this->primitive->value.AsDeConv2D()->format; }
int DeConv2D::GetGroup() const { return this->primitive->value.AsDeConv2D()->group; }
int DeConv2D::GetChannelIn() const { return this->primitive->value.AsDeConv2D()->channelIn; }
int DeConv2D::GetChannelOut() const { return this->primitive->value.AsDeConv2D()->channelOut; }
int DeConv2D::GetKernelW() const { return this->primitive->value.AsDeConv2D()->kernelW; }
int DeConv2D::GetKernelH() const { return this->primitive->value.AsDeConv2D()->kernelH; }
int DeConv2D::GetStrideW() const { return this->primitive->value.AsDeConv2D()->strideW; }
int DeConv2D::GetStrideH() const { return this->primitive->value.AsDeConv2D()->strideH; }
int DeConv2D::GetPadMode() const { return this->primitive->value.AsDeConv2D()->padMode; }
int DeConv2D::GetPadUp() const { return this->primitive->value.AsDeConv2D()->padUp; }
int DeConv2D::GetPadDown() const { return this->primitive->value.AsDeConv2D()->padDown; }
int DeConv2D::GetPadLeft() const { return this->primitive->value.AsDeConv2D()->padLeft; }
int DeConv2D::GetPadRight() const { return this->primitive->value.AsDeConv2D()->padRight; }
int DeConv2D::GetDilateW() const { return this->primitive->value.AsDeConv2D()->dilateW; }
int DeConv2D::GetDilateH() const { return this->primitive->value.AsDeConv2D()->dilateH; }
bool DeConv2D::GetHasBias() const { return this->primitive->value.AsDeConv2D()->hasBias; }
int DeConv2D::GetActivationType() const { return this->primitive->value.AsDeConv2D()->activationType; }
int DeConv2D::GetFormat() const { return this->primitive_->value.AsDeConv2D()->format; }
int DeConv2D::GetGroup() const { return this->primitive_->value.AsDeConv2D()->group; }
int DeConv2D::GetChannelIn() const { return this->primitive_->value.AsDeConv2D()->channelIn; }
int DeConv2D::GetChannelOut() const { return this->primitive_->value.AsDeConv2D()->channelOut; }
int DeConv2D::GetKernelW() const { return this->primitive_->value.AsDeConv2D()->kernelW; }
int DeConv2D::GetKernelH() const { return this->primitive_->value.AsDeConv2D()->kernelH; }
int DeConv2D::GetStrideW() const { return this->primitive_->value.AsDeConv2D()->strideW; }
int DeConv2D::GetStrideH() const { return this->primitive_->value.AsDeConv2D()->strideH; }
int DeConv2D::GetPadMode() const { return this->primitive_->value.AsDeConv2D()->padMode; }
int DeConv2D::GetPadUp() const { return this->primitive_->value.AsDeConv2D()->padUp; }
int DeConv2D::GetPadDown() const { return this->primitive_->value.AsDeConv2D()->padDown; }
int DeConv2D::GetPadLeft() const { return this->primitive_->value.AsDeConv2D()->padLeft; }
int DeConv2D::GetPadRight() const { return this->primitive_->value.AsDeConv2D()->padRight; }
int DeConv2D::GetDilateW() const { return this->primitive_->value.AsDeConv2D()->dilateW; }
int DeConv2D::GetDilateH() const { return this->primitive_->value.AsDeConv2D()->dilateH; }
bool DeConv2D::GetHasBias() const { return this->primitive_->value.AsDeConv2D()->hasBias; }
int DeConv2D::GetActivationType() const { return this->primitive_->value.AsDeConv2D()->activationType; }

void DeConv2D::SetFormat(int format) { this->primitive->value.AsDeConv2D()->format = (schema::Format)format; }
void DeConv2D::SetGroup(int group) { this->primitive->value.AsDeConv2D()->group = group; }
void DeConv2D::SetChannelIn(int channel_in) { this->primitive->value.AsDeConv2D()->channelIn = channel_in; }
void DeConv2D::SetChannelOut(int channel_out) { this->primitive->value.AsDeConv2D()->channelOut = channel_out; }
void DeConv2D::SetKernelW(int kernel_w) { this->primitive->value.AsDeConv2D()->kernelW = kernel_w; }
void DeConv2D::SetKernelH(int kernel_h) { this->primitive->value.AsDeConv2D()->kernelH = kernel_h; }
void DeConv2D::SetStrideW(int stride_w) { this->primitive->value.AsDeConv2D()->strideW = stride_w; }
void DeConv2D::SetStrideH(int stride_h) { this->primitive->value.AsDeConv2D()->strideH = stride_h; }
void DeConv2D::SetPadMode(int pad_mode) { this->primitive->value.AsDeConv2D()->padMode = (schema::PadMode)pad_mode; }
void DeConv2D::SetPadUp(int pad_up) { this->primitive->value.AsDeConv2D()->padUp = pad_up; }
void DeConv2D::SetPadDown(int pad_down) { this->primitive->value.AsDeConv2D()->padDown = pad_down; }
void DeConv2D::SetPadLeft(int pad_left) { this->primitive->value.AsDeConv2D()->padLeft = pad_left; }
void DeConv2D::SetPadRight(int pad_right) { this->primitive->value.AsDeConv2D()->padRight = pad_right; }
void DeConv2D::SetDilateW(int dilate_w) { this->primitive->value.AsDeConv2D()->dilateW = dilate_w; }
void DeConv2D::SetDilateH(int dilate_h) { this->primitive->value.AsDeConv2D()->dilateH = dilate_h; }
void DeConv2D::SetHasBias(bool has_bias) { this->primitive->value.AsDeConv2D()->hasBias = has_bias; }
void DeConv2D::SetFormat(int format) { this->primitive_->value.AsDeConv2D()->format = (schema::Format)format; }
void DeConv2D::SetGroup(int group) { this->primitive_->value.AsDeConv2D()->group = group; }
void DeConv2D::SetChannelIn(int channel_in) { this->primitive_->value.AsDeConv2D()->channelIn = channel_in; }
void DeConv2D::SetChannelOut(int channel_out) { this->primitive_->value.AsDeConv2D()->channelOut = channel_out; }
void DeConv2D::SetKernelW(int kernel_w) { this->primitive_->value.AsDeConv2D()->kernelW = kernel_w; }
void DeConv2D::SetKernelH(int kernel_h) { this->primitive_->value.AsDeConv2D()->kernelH = kernel_h; }
void DeConv2D::SetStrideW(int stride_w) { this->primitive_->value.AsDeConv2D()->strideW = stride_w; }
void DeConv2D::SetStrideH(int stride_h) { this->primitive_->value.AsDeConv2D()->strideH = stride_h; }
void DeConv2D::SetPadMode(int pad_mode) { this->primitive_->value.AsDeConv2D()->padMode = (schema::PadMode)pad_mode; }
void DeConv2D::SetPadUp(int pad_up) { this->primitive_->value.AsDeConv2D()->padUp = pad_up; }
void DeConv2D::SetPadDown(int pad_down) { this->primitive_->value.AsDeConv2D()->padDown = pad_down; }
void DeConv2D::SetPadLeft(int pad_left) { this->primitive_->value.AsDeConv2D()->padLeft = pad_left; }
void DeConv2D::SetPadRight(int pad_right) { this->primitive_->value.AsDeConv2D()->padRight = pad_right; }
void DeConv2D::SetDilateW(int dilate_w) { this->primitive_->value.AsDeConv2D()->dilateW = dilate_w; }
void DeConv2D::SetDilateH(int dilate_h) { this->primitive_->value.AsDeConv2D()->dilateH = dilate_h; }
void DeConv2D::SetHasBias(bool has_bias) { this->primitive_->value.AsDeConv2D()->hasBias = has_bias; }
void DeConv2D::SetActivationType(int activation_type) {
this->primitive->value.AsDeConv2D()->activationType = (schema::ActivationType)activation_type;
this->primitive_->value.AsDeConv2D()->activationType = (schema::ActivationType)activation_type;
}

#else

int DeConv2D::GetFormat() const { return this->primitive->value_as_DeConv2D()->format(); }
int DeConv2D::GetGroup() const { return this->primitive->value_as_DeConv2D()->group(); }
int DeConv2D::GetChannelIn() const { return this->primitive->value_as_DeConv2D()->channelIn(); }
int DeConv2D::GetChannelOut() const { return this->primitive->value_as_DeConv2D()->channelOut(); }
int DeConv2D::GetKernelW() const { return this->primitive->value_as_DeConv2D()->kernelW(); }
int DeConv2D::GetKernelH() const { return this->primitive->value_as_DeConv2D()->kernelH(); }
int DeConv2D::GetStrideW() const { return this->primitive->value_as_DeConv2D()->strideW(); }
int DeConv2D::GetStrideH() const { return this->primitive->value_as_DeConv2D()->strideH(); }
int DeConv2D::GetPadMode() const { return this->primitive->value_as_DeConv2D()->padMode(); }
int DeConv2D::GetPadUp() const { return this->primitive->value_as_DeConv2D()->padUp(); }
int DeConv2D::GetPadDown() const { return this->primitive->value_as_DeConv2D()->padDown(); }
int DeConv2D::GetPadLeft() const { return this->primitive->value_as_DeConv2D()->padLeft(); }
int DeConv2D::GetPadRight() const { return this->primitive->value_as_DeConv2D()->padRight(); }
int DeConv2D::GetDilateW() const { return this->primitive->value_as_DeConv2D()->dilateW(); }
int DeConv2D::GetDilateH() const { return this->primitive->value_as_DeConv2D()->dilateH(); }
bool DeConv2D::GetHasBias() const { return this->primitive->value_as_DeConv2D()->hasBias(); }
int DeConv2D::GetActivationType() const { return this->primitive->value_as_DeConv2D()->activationType(); }
int DeConv2D::GetFormat() const { return this->primitive_->value_as_DeConv2D()->format(); }
int DeConv2D::GetGroup() const { return this->primitive_->value_as_DeConv2D()->group(); }
int DeConv2D::GetChannelIn() const { return this->primitive_->value_as_DeConv2D()->channelIn(); }
int DeConv2D::GetChannelOut() const { return this->primitive_->value_as_DeConv2D()->channelOut(); }
int DeConv2D::GetKernelW() const { return this->primitive_->value_as_DeConv2D()->kernelW(); }
int DeConv2D::GetKernelH() const { return this->primitive_->value_as_DeConv2D()->kernelH(); }
int DeConv2D::GetStrideW() const { return this->primitive_->value_as_DeConv2D()->strideW(); }
int DeConv2D::GetStrideH() const { return this->primitive_->value_as_DeConv2D()->strideH(); }
int DeConv2D::GetPadMode() const { return this->primitive_->value_as_DeConv2D()->padMode(); }
int DeConv2D::GetPadUp() const { return this->primitive_->value_as_DeConv2D()->padUp(); }
int DeConv2D::GetPadDown() const { return this->primitive_->value_as_DeConv2D()->padDown(); }
int DeConv2D::GetPadLeft() const { return this->primitive_->value_as_DeConv2D()->padLeft(); }
int DeConv2D::GetPadRight() const { return this->primitive_->value_as_DeConv2D()->padRight(); }
int DeConv2D::GetDilateW() const { return this->primitive_->value_as_DeConv2D()->dilateW(); }
int DeConv2D::GetDilateH() const { return this->primitive_->value_as_DeConv2D()->dilateH(); }
bool DeConv2D::GetHasBias() const { return this->primitive_->value_as_DeConv2D()->hasBias(); }
int DeConv2D::GetActivationType() const { return this->primitive_->value_as_DeConv2D()->activationType(); }

void DeConv2D::SetFormat(int format) {}
void DeConv2D::SetGroup(int group) {}
@@ -96,7 +96,7 @@ void DeConv2D::SetHasBias(bool has_bias) {}
void DeConv2D::SetActivationType(int activation_type) {}
#endif
int DeConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
MS_ASSERT(this->primitive_ != nullptr);
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto weight = inputs_.at(1);


+ 7
- 9
mindspore/lite/src/ops/deconv2d.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_DE_CONV2_D_H_
#define LITE_MINDSPORE_LITE_C_OPS_DE_CONV2_D_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_DE_CONV2_D_H_
#define LITE_MINDSPORE_LITE_C_OPS_DE_CONV2_D_H_

namespace mindspore {
namespace lite {
class DeConv2D : public PrimitiveC {
public:
explicit DeConv2D(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit DeConv2D(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit DeConv2D(schema::Primitive *primitive) : PrimitiveC(primitive) {}

int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetFormat() const;


+ 49
- 49
mindspore/lite/src/ops/dedepthwise_conv2d.cc View File

@@ -19,77 +19,77 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int DeDepthwiseConv2D::GetFormat() const { return this->primitive->value.AsDeDepthwiseConv2D()->format; }
int DeDepthwiseConv2D::GetChannelIn() const { return this->primitive->value.AsDeDepthwiseConv2D()->channelIn; }
int DeDepthwiseConv2D::GetFormat() const { return this->primitive_->value.AsDeDepthwiseConv2D()->format; }
int DeDepthwiseConv2D::GetChannelIn() const { return this->primitive_->value.AsDeDepthwiseConv2D()->channelIn; }
int DeDepthwiseConv2D::GetChannelMultiplier() const {
return this->primitive->value.AsDeDepthwiseConv2D()->channelMultiplier;
return this->primitive_->value.AsDeDepthwiseConv2D()->channelMultiplier;
}
int DeDepthwiseConv2D::GetKernelW() const { return this->primitive->value.AsDeDepthwiseConv2D()->kernelW; }
int DeDepthwiseConv2D::GetKernelH() const { return this->primitive->value.AsDeDepthwiseConv2D()->kernelH; }
int DeDepthwiseConv2D::GetStrideW() const { return this->primitive->value.AsDeDepthwiseConv2D()->strideW; }
int DeDepthwiseConv2D::GetStrideH() const { return this->primitive->value.AsDeDepthwiseConv2D()->strideH; }
int DeDepthwiseConv2D::GetPadMode() const { return this->primitive->value.AsDeDepthwiseConv2D()->padMode; }
int DeDepthwiseConv2D::GetPadUp() const { return this->primitive->value.AsDeDepthwiseConv2D()->padUp; }
int DeDepthwiseConv2D::GetPadDown() const { return this->primitive->value.AsDeDepthwiseConv2D()->padDown; }
int DeDepthwiseConv2D::GetPadLeft() const { return this->primitive->value.AsDeDepthwiseConv2D()->padLeft; }
int DeDepthwiseConv2D::GetPadRight() const { return this->primitive->value.AsDeDepthwiseConv2D()->padRight; }
int DeDepthwiseConv2D::GetDilateW() const { return this->primitive->value.AsDeDepthwiseConv2D()->dilateW; }
int DeDepthwiseConv2D::GetDilateH() const { return this->primitive->value.AsDeDepthwiseConv2D()->dilateH; }
bool DeDepthwiseConv2D::GetHasBias() const { return this->primitive->value.AsDeDepthwiseConv2D()->hasBias; }
int DeDepthwiseConv2D::GetKernelW() const { return this->primitive_->value.AsDeDepthwiseConv2D()->kernelW; }
int DeDepthwiseConv2D::GetKernelH() const { return this->primitive_->value.AsDeDepthwiseConv2D()->kernelH; }
int DeDepthwiseConv2D::GetStrideW() const { return this->primitive_->value.AsDeDepthwiseConv2D()->strideW; }
int DeDepthwiseConv2D::GetStrideH() const { return this->primitive_->value.AsDeDepthwiseConv2D()->strideH; }
int DeDepthwiseConv2D::GetPadMode() const { return this->primitive_->value.AsDeDepthwiseConv2D()->padMode; }
int DeDepthwiseConv2D::GetPadUp() const { return this->primitive_->value.AsDeDepthwiseConv2D()->padUp; }
int DeDepthwiseConv2D::GetPadDown() const { return this->primitive_->value.AsDeDepthwiseConv2D()->padDown; }
int DeDepthwiseConv2D::GetPadLeft() const { return this->primitive_->value.AsDeDepthwiseConv2D()->padLeft; }
int DeDepthwiseConv2D::GetPadRight() const { return this->primitive_->value.AsDeDepthwiseConv2D()->padRight; }
int DeDepthwiseConv2D::GetDilateW() const { return this->primitive_->value.AsDeDepthwiseConv2D()->dilateW; }
int DeDepthwiseConv2D::GetDilateH() const { return this->primitive_->value.AsDeDepthwiseConv2D()->dilateH; }
bool DeDepthwiseConv2D::GetHasBias() const { return this->primitive_->value.AsDeDepthwiseConv2D()->hasBias; }
int DeDepthwiseConv2D::GetActivationType() const {
return this->primitive->value.AsDeDepthwiseConv2D()->activationType;
return this->primitive_->value.AsDeDepthwiseConv2D()->activationType;
}

void DeDepthwiseConv2D::SetFormat(int format) {
this->primitive->value.AsDeDepthwiseConv2D()->format = (schema::Format)format;
this->primitive_->value.AsDeDepthwiseConv2D()->format = (schema::Format)format;
}
void DeDepthwiseConv2D::SetChannelIn(int channel_in) {
this->primitive->value.AsDeDepthwiseConv2D()->channelIn = channel_in;
this->primitive_->value.AsDeDepthwiseConv2D()->channelIn = channel_in;
}
void DeDepthwiseConv2D::SetChannelMultiplier(int channel_multiplier) {
this->primitive->value.AsDeDepthwiseConv2D()->channelMultiplier = channel_multiplier;
this->primitive_->value.AsDeDepthwiseConv2D()->channelMultiplier = channel_multiplier;
}
void DeDepthwiseConv2D::SetKernelW(int kernel_w) { this->primitive->value.AsDeDepthwiseConv2D()->kernelW = kernel_w; }
void DeDepthwiseConv2D::SetKernelH(int kernel_h) { this->primitive->value.AsDeDepthwiseConv2D()->kernelH = kernel_h; }
void DeDepthwiseConv2D::SetStrideW(int stride_w) { this->primitive->value.AsDeDepthwiseConv2D()->strideW = stride_w; }
void DeDepthwiseConv2D::SetStrideH(int stride_h) { this->primitive->value.AsDeDepthwiseConv2D()->strideH = stride_h; }
void DeDepthwiseConv2D::SetKernelW(int kernel_w) { this->primitive_->value.AsDeDepthwiseConv2D()->kernelW = kernel_w; }
void DeDepthwiseConv2D::SetKernelH(int kernel_h) { this->primitive_->value.AsDeDepthwiseConv2D()->kernelH = kernel_h; }
void DeDepthwiseConv2D::SetStrideW(int stride_w) { this->primitive_->value.AsDeDepthwiseConv2D()->strideW = stride_w; }
void DeDepthwiseConv2D::SetStrideH(int stride_h) { this->primitive_->value.AsDeDepthwiseConv2D()->strideH = stride_h; }
void DeDepthwiseConv2D::SetPadMode(int pad_mode) {
this->primitive->value.AsDeDepthwiseConv2D()->padMode = (schema::PadMode)pad_mode;
this->primitive_->value.AsDeDepthwiseConv2D()->padMode = (schema::PadMode)pad_mode;
}
void DeDepthwiseConv2D::SetPadUp(int pad_up) { this->primitive->value.AsDeDepthwiseConv2D()->padUp = pad_up; }
void DeDepthwiseConv2D::SetPadDown(int pad_down) { this->primitive->value.AsDeDepthwiseConv2D()->padDown = pad_down; }
void DeDepthwiseConv2D::SetPadLeft(int pad_left) { this->primitive->value.AsDeDepthwiseConv2D()->padLeft = pad_left; }
void DeDepthwiseConv2D::SetPadUp(int pad_up) { this->primitive_->value.AsDeDepthwiseConv2D()->padUp = pad_up; }
void DeDepthwiseConv2D::SetPadDown(int pad_down) { this->primitive_->value.AsDeDepthwiseConv2D()->padDown = pad_down; }
void DeDepthwiseConv2D::SetPadLeft(int pad_left) { this->primitive_->value.AsDeDepthwiseConv2D()->padLeft = pad_left; }
void DeDepthwiseConv2D::SetPadRight(int pad_right) {
this->primitive->value.AsDeDepthwiseConv2D()->padRight = pad_right;
this->primitive_->value.AsDeDepthwiseConv2D()->padRight = pad_right;
}
void DeDepthwiseConv2D::SetDilateW(int dilate_w) { this->primitive->value.AsDeDepthwiseConv2D()->dilateW = dilate_w; }
void DeDepthwiseConv2D::SetDilateH(int dilate_h) { this->primitive->value.AsDeDepthwiseConv2D()->dilateH = dilate_h; }
void DeDepthwiseConv2D::SetHasBias(bool has_bias) { this->primitive->value.AsDeDepthwiseConv2D()->hasBias = has_bias; }
void DeDepthwiseConv2D::SetDilateW(int dilate_w) { this->primitive_->value.AsDeDepthwiseConv2D()->dilateW = dilate_w; }
void DeDepthwiseConv2D::SetDilateH(int dilate_h) { this->primitive_->value.AsDeDepthwiseConv2D()->dilateH = dilate_h; }
void DeDepthwiseConv2D::SetHasBias(bool has_bias) { this->primitive_->value.AsDeDepthwiseConv2D()->hasBias = has_bias; }
void DeDepthwiseConv2D::SetActivationType(int activation_type) {
this->primitive->value.AsDeDepthwiseConv2D()->activationType = (schema::ActivationType)activation_type;
this->primitive_->value.AsDeDepthwiseConv2D()->activationType = (schema::ActivationType)activation_type;
}

#else

int DeDepthwiseConv2D::GetFormat() const { return this->primitive->value_as_DeDepthwiseConv2D()->format(); }
int DeDepthwiseConv2D::GetChannelIn() const { return this->primitive->value_as_DeDepthwiseConv2D()->channelIn(); }
int DeDepthwiseConv2D::GetFormat() const { return this->primitive_->value_as_DeDepthwiseConv2D()->format(); }
int DeDepthwiseConv2D::GetChannelIn() const { return this->primitive_->value_as_DeDepthwiseConv2D()->channelIn(); }
int DeDepthwiseConv2D::GetChannelMultiplier() const {
return this->primitive->value_as_DeDepthwiseConv2D()->channelMultiplier();
return this->primitive_->value_as_DeDepthwiseConv2D()->channelMultiplier();
}
int DeDepthwiseConv2D::GetKernelW() const { return this->primitive->value_as_DeDepthwiseConv2D()->kernelW(); }
int DeDepthwiseConv2D::GetKernelH() const { return this->primitive->value_as_DeDepthwiseConv2D()->kernelH(); }
int DeDepthwiseConv2D::GetStrideW() const { return this->primitive->value_as_DeDepthwiseConv2D()->strideW(); }
int DeDepthwiseConv2D::GetStrideH() const { return this->primitive->value_as_DeDepthwiseConv2D()->strideH(); }
int DeDepthwiseConv2D::GetPadMode() const { return this->primitive->value_as_DeDepthwiseConv2D()->padMode(); }
int DeDepthwiseConv2D::GetPadUp() const { return this->primitive->value_as_DeDepthwiseConv2D()->padUp(); }
int DeDepthwiseConv2D::GetPadDown() const { return this->primitive->value_as_DeDepthwiseConv2D()->padDown(); }
int DeDepthwiseConv2D::GetPadLeft() const { return this->primitive->value_as_DeDepthwiseConv2D()->padLeft(); }
int DeDepthwiseConv2D::GetPadRight() const { return this->primitive->value_as_DeDepthwiseConv2D()->padRight(); }
int DeDepthwiseConv2D::GetDilateW() const { return this->primitive->value_as_DeDepthwiseConv2D()->dilateW(); }
int DeDepthwiseConv2D::GetDilateH() const { return this->primitive->value_as_DeDepthwiseConv2D()->dilateH(); }
bool DeDepthwiseConv2D::GetHasBias() const { return this->primitive->value_as_DeDepthwiseConv2D()->hasBias(); }
int DeDepthwiseConv2D::GetKernelW() const { return this->primitive_->value_as_DeDepthwiseConv2D()->kernelW(); }
int DeDepthwiseConv2D::GetKernelH() const { return this->primitive_->value_as_DeDepthwiseConv2D()->kernelH(); }
int DeDepthwiseConv2D::GetStrideW() const { return this->primitive_->value_as_DeDepthwiseConv2D()->strideW(); }
int DeDepthwiseConv2D::GetStrideH() const { return this->primitive_->value_as_DeDepthwiseConv2D()->strideH(); }
int DeDepthwiseConv2D::GetPadMode() const { return this->primitive_->value_as_DeDepthwiseConv2D()->padMode(); }
int DeDepthwiseConv2D::GetPadUp() const { return this->primitive_->value_as_DeDepthwiseConv2D()->padUp(); }
int DeDepthwiseConv2D::GetPadDown() const { return this->primitive_->value_as_DeDepthwiseConv2D()->padDown(); }
int DeDepthwiseConv2D::GetPadLeft() const { return this->primitive_->value_as_DeDepthwiseConv2D()->padLeft(); }
int DeDepthwiseConv2D::GetPadRight() const { return this->primitive_->value_as_DeDepthwiseConv2D()->padRight(); }
int DeDepthwiseConv2D::GetDilateW() const { return this->primitive_->value_as_DeDepthwiseConv2D()->dilateW(); }
int DeDepthwiseConv2D::GetDilateH() const { return this->primitive_->value_as_DeDepthwiseConv2D()->dilateH(); }
bool DeDepthwiseConv2D::GetHasBias() const { return this->primitive_->value_as_DeDepthwiseConv2D()->hasBias(); }
int DeDepthwiseConv2D::GetActivationType() const {
return this->primitive->value_as_DeDepthwiseConv2D()->activationType();
return this->primitive_->value_as_DeDepthwiseConv2D()->activationType();
}

void DeDepthwiseConv2D::SetFormat(int format) {}
@@ -119,7 +119,7 @@ int DeDepthwiseConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
MS_LOG(ERROR) << "output number is invalid";
return 1;
}
MS_ASSERT(this->primitive != nullptr);
MS_ASSERT(this->primitive_ != nullptr);
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto weight = inputs_.at(1);


+ 7
- 9
mindspore/lite/src/ops/dedepthwise_conv2d.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_DE_DEPTHWISE_CONV2_D_H_
#define LITE_MINDSPORE_LITE_C_OPS_DE_DEPTHWISE_CONV2_D_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_DE_DEPTHWISE_CONV2_D_H_
#define LITE_MINDSPORE_LITE_C_OPS_DE_DEPTHWISE_CONV2_D_H_

namespace mindspore {
namespace lite {
class DeDepthwiseConv2D : public PrimitiveC {
public:
explicit DeDepthwiseConv2D(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit DeDepthwiseConv2D(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit DeDepthwiseConv2D(schema::Primitive *primitive) : PrimitiveC(primitive) {}

int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetFormat() const;


+ 7
- 7
mindspore/lite/src/ops/depth_to_space.cc View File

@@ -19,16 +19,16 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int DepthToSpace::GetBlockSize() const { return this->primitive->value.AsDepthToSpace()->blockSize; }
int DepthToSpace::GetFormat() const { return this->primitive->value.AsDepthToSpace()->format; }
int DepthToSpace::GetBlockSize() const { return this->primitive_->value.AsDepthToSpace()->blockSize; }
int DepthToSpace::GetFormat() const { return this->primitive_->value.AsDepthToSpace()->format; }

void DepthToSpace::SetBlockSize(int block_size) { this->primitive->value.AsDepthToSpace()->blockSize = block_size; }
void DepthToSpace::SetFormat(int format) { this->primitive->value.AsDepthToSpace()->format = (schema::Format)format; }
void DepthToSpace::SetBlockSize(int block_size) { this->primitive_->value.AsDepthToSpace()->blockSize = block_size; }
void DepthToSpace::SetFormat(int format) { this->primitive_->value.AsDepthToSpace()->format = (schema::Format)format; }

#else

int DepthToSpace::GetBlockSize() const { return this->primitive->value_as_DepthToSpace()->blockSize(); }
int DepthToSpace::GetFormat() const { return this->primitive->value_as_DepthToSpace()->format(); }
int DepthToSpace::GetBlockSize() const { return this->primitive_->value_as_DepthToSpace()->blockSize(); }
int DepthToSpace::GetFormat() const { return this->primitive_->value_as_DepthToSpace()->format(); }

void DepthToSpace::SetBlockSize(int block_size) {}
void DepthToSpace::SetFormat(int format) {}
@@ -39,7 +39,7 @@ constexpr int kDepthToSpaceInputNum = 1;
} // namespace

int DepthToSpace::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) {
MS_ASSERT(this->primitive != nullptr);
MS_ASSERT(this->primitive_ != nullptr);
if (outputs.size() != kDepthToSpaceOutputNum || inputs.size() != kDepthToSpaceInputNum) {
MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size();
return RET_PARAM_INVALID;


+ 7
- 9
mindspore/lite/src/ops/depth_to_space.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_DEPTH_TO_SPACE_H_
#define LITE_MINDSPORE_LITE_C_OPS_DEPTH_TO_SPACE_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_DEPTH_TO_SPACE_H_
#define LITE_MINDSPORE_LITE_C_OPS_DEPTH_TO_SPACE_H_

namespace mindspore {
namespace lite {
class DepthToSpace : public PrimitiveC {
public:
explicit DepthToSpace(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit DepthToSpace(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit DepthToSpace(schema::Primitive *primitive) : PrimitiveC(primitive) {}

int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetBlockSize() const;


+ 51
- 49
mindspore/lite/src/ops/depthwise_conv2d.cc View File

@@ -19,72 +19,74 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int DepthwiseConv2D::GetFormat() const { return this->primitive->value.AsDepthwiseConv2D()->format; }
int DepthwiseConv2D::GetChannelIn() const { return this->primitive->value.AsDepthwiseConv2D()->channelIn; }
int DepthwiseConv2D::GetFormat() const { return this->primitive_->value.AsDepthwiseConv2D()->format; }
int DepthwiseConv2D::GetChannelIn() const { return this->primitive_->value.AsDepthwiseConv2D()->channelIn; }
int DepthwiseConv2D::GetChannelMultiplier() const {
return this->primitive->value.AsDepthwiseConv2D()->channelMultiplier;
return this->primitive_->value.AsDepthwiseConv2D()->channelMultiplier;
}
int DepthwiseConv2D::GetKernelW() const { return this->primitive->value.AsDepthwiseConv2D()->kernelW; }
int DepthwiseConv2D::GetKernelH() const { return this->primitive->value.AsDepthwiseConv2D()->kernelH; }
int DepthwiseConv2D::GetStrideW() const { return this->primitive->value.AsDepthwiseConv2D()->strideW; }
int DepthwiseConv2D::GetStrideH() const { return this->primitive->value.AsDepthwiseConv2D()->strideH; }
int DepthwiseConv2D::GetPadMode() const { return this->primitive->value.AsDepthwiseConv2D()->padMode; }
int DepthwiseConv2D::GetPadUp() const { return this->primitive->value.AsDepthwiseConv2D()->padUp; }
int DepthwiseConv2D::GetPadDown() const { return this->primitive->value.AsDepthwiseConv2D()->padDown; }
int DepthwiseConv2D::GetPadLeft() const { return this->primitive->value.AsDepthwiseConv2D()->padLeft; }
int DepthwiseConv2D::GetPadRight() const { return this->primitive->value.AsDepthwiseConv2D()->padRight; }
int DepthwiseConv2D::GetDilateW() const { return this->primitive->value.AsDepthwiseConv2D()->dilateW; }
int DepthwiseConv2D::GetDilateH() const { return this->primitive->value.AsDepthwiseConv2D()->dilateH; }
bool DepthwiseConv2D::GetHasBias() const { return this->primitive->value.AsDepthwiseConv2D()->hasBias; }
int DepthwiseConv2D::GetActivationType() const { return this->primitive->value.AsDepthwiseConv2D()->activationType; }
int DepthwiseConv2D::GetKernelW() const { return this->primitive_->value.AsDepthwiseConv2D()->kernelW; }
int DepthwiseConv2D::GetKernelH() const { return this->primitive_->value.AsDepthwiseConv2D()->kernelH; }
int DepthwiseConv2D::GetStrideW() const { return this->primitive_->value.AsDepthwiseConv2D()->strideW; }
int DepthwiseConv2D::GetStrideH() const { return this->primitive_->value.AsDepthwiseConv2D()->strideH; }
int DepthwiseConv2D::GetPadMode() const { return this->primitive_->value.AsDepthwiseConv2D()->padMode; }
int DepthwiseConv2D::GetPadUp() const { return this->primitive_->value.AsDepthwiseConv2D()->padUp; }
int DepthwiseConv2D::GetPadDown() const { return this->primitive_->value.AsDepthwiseConv2D()->padDown; }
int DepthwiseConv2D::GetPadLeft() const { return this->primitive_->value.AsDepthwiseConv2D()->padLeft; }
int DepthwiseConv2D::GetPadRight() const { return this->primitive_->value.AsDepthwiseConv2D()->padRight; }
int DepthwiseConv2D::GetDilateW() const { return this->primitive_->value.AsDepthwiseConv2D()->dilateW; }
int DepthwiseConv2D::GetDilateH() const { return this->primitive_->value.AsDepthwiseConv2D()->dilateH; }
bool DepthwiseConv2D::GetHasBias() const { return this->primitive_->value.AsDepthwiseConv2D()->hasBias; }
int DepthwiseConv2D::GetActivationType() const { return this->primitive_->value.AsDepthwiseConv2D()->activationType; }

void DepthwiseConv2D::SetFormat(int format) {
this->primitive->value.AsDepthwiseConv2D()->format = (schema::Format)format;
this->primitive_->value.AsDepthwiseConv2D()->format = (schema::Format)format;
}
void DepthwiseConv2D::SetChannelIn(int channel_in) {
this->primitive->value.AsDepthwiseConv2D()->channelIn = channel_in;
this->primitive_->value.AsDepthwiseConv2D()->channelIn = channel_in;
}
void DepthwiseConv2D::SetChannelMultiplier(int channel_multiplier) {
this->primitive->value.AsDepthwiseConv2D()->channelMultiplier = channel_multiplier;
this->primitive_->value.AsDepthwiseConv2D()->channelMultiplier = channel_multiplier;
}
// Kernel-size and stride setters (PRIMITIVE_WRITEABLE branch). The old
// definitions writing through the removed `primitive` member were duplicates
// of these and have been dropped (they were redefinitions of the same
// functions, an ODR violation).
void DepthwiseConv2D::SetKernelW(int kernel_w) { this->primitive_->value.AsDepthwiseConv2D()->kernelW = kernel_w; }
void DepthwiseConv2D::SetKernelH(int kernel_h) { this->primitive_->value.AsDepthwiseConv2D()->kernelH = kernel_h; }
void DepthwiseConv2D::SetStrideW(int stride_w) { this->primitive_->value.AsDepthwiseConv2D()->strideW = stride_w; }
void DepthwiseConv2D::SetStrideH(int stride_h) { this->primitive_->value.AsDepthwiseConv2D()->strideH = stride_h; }
// Sets the padding mode; narrows the int to schema::PadMode. Removed the stale
// duplicate write through the removed `primitive` member.
void DepthwiseConv2D::SetPadMode(int pad_mode) {
  this->primitive_->value.AsDepthwiseConv2D()->padMode = (schema::PadMode)pad_mode;
}
// Padding, dilation and bias setters (PRIMITIVE_WRITEABLE branch). The
// duplicate old definitions writing through the removed `primitive` member
// (redefinitions of the same functions) have been dropped.
void DepthwiseConv2D::SetPadUp(int pad_up) { this->primitive_->value.AsDepthwiseConv2D()->padUp = pad_up; }
void DepthwiseConv2D::SetPadDown(int pad_down) { this->primitive_->value.AsDepthwiseConv2D()->padDown = pad_down; }
void DepthwiseConv2D::SetPadLeft(int pad_left) { this->primitive_->value.AsDepthwiseConv2D()->padLeft = pad_left; }
void DepthwiseConv2D::SetPadRight(int pad_right) { this->primitive_->value.AsDepthwiseConv2D()->padRight = pad_right; }
void DepthwiseConv2D::SetDilateW(int dilate_w) { this->primitive_->value.AsDepthwiseConv2D()->dilateW = dilate_w; }
void DepthwiseConv2D::SetDilateH(int dilate_h) { this->primitive_->value.AsDepthwiseConv2D()->dilateH = dilate_h; }
void DepthwiseConv2D::SetHasBias(bool has_bias) { this->primitive_->value.AsDepthwiseConv2D()->hasBias = has_bias; }
// Sets the fused activation; narrows the int to schema::ActivationType.
// Removed the stale duplicate write through the removed `primitive` member.
void DepthwiseConv2D::SetActivationType(int activation_type) {
  this->primitive_->value.AsDepthwiseConv2D()->activationType = (schema::ActivationType)activation_type;
}

#else

// Read-only (flatbuffer) branch: attributes come from the generated
// schema::Primitive accessors. Dropped the stale duplicate definitions that
// read through the removed `primitive` member.
int DepthwiseConv2D::GetFormat() const { return this->primitive_->value_as_DepthwiseConv2D()->format(); }
int DepthwiseConv2D::GetChannelIn() const { return this->primitive_->value_as_DepthwiseConv2D()->channelIn(); }
int DepthwiseConv2D::GetChannelMultiplier() const {
  return this->primitive_->value_as_DepthwiseConv2D()->channelMultiplier();
}
// Read-only (flatbuffer) branch getters. The trailing block of old definitions
// (GetKernelW..GetActivationType reading through the removed `primitive`
// member) duplicated every function below — redefinitions that cannot compile
// — and has been removed.
int DepthwiseConv2D::GetKernelW() const { return this->primitive_->value_as_DepthwiseConv2D()->kernelW(); }
int DepthwiseConv2D::GetKernelH() const { return this->primitive_->value_as_DepthwiseConv2D()->kernelH(); }
int DepthwiseConv2D::GetStrideW() const { return this->primitive_->value_as_DepthwiseConv2D()->strideW(); }
int DepthwiseConv2D::GetStrideH() const { return this->primitive_->value_as_DepthwiseConv2D()->strideH(); }
int DepthwiseConv2D::GetPadMode() const { return this->primitive_->value_as_DepthwiseConv2D()->padMode(); }
int DepthwiseConv2D::GetPadUp() const { return this->primitive_->value_as_DepthwiseConv2D()->padUp(); }
int DepthwiseConv2D::GetPadDown() const { return this->primitive_->value_as_DepthwiseConv2D()->padDown(); }
int DepthwiseConv2D::GetPadLeft() const { return this->primitive_->value_as_DepthwiseConv2D()->padLeft(); }
int DepthwiseConv2D::GetPadRight() const { return this->primitive_->value_as_DepthwiseConv2D()->padRight(); }
int DepthwiseConv2D::GetDilateW() const { return this->primitive_->value_as_DepthwiseConv2D()->dilateW(); }
int DepthwiseConv2D::GetDilateH() const { return this->primitive_->value_as_DepthwiseConv2D()->dilateH(); }
bool DepthwiseConv2D::GetHasBias() const { return this->primitive_->value_as_DepthwiseConv2D()->hasBias(); }
int DepthwiseConv2D::GetActivationType() const {
  return this->primitive_->value_as_DepthwiseConv2D()->activationType();
}

// No-ops: in the read-only (flatbuffer) build the underlying schema::Primitive
// is immutable, so the setters intentionally do nothing.
void DepthwiseConv2D::SetFormat(int format) {}
void DepthwiseConv2D::SetChannelIn(int channel_in) {}
@@ -113,7 +115,7 @@ int DepthwiseConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
MS_LOG(ERROR) << "output number is invalid";
return 1;
}
MS_ASSERT(this->primitive != nullptr);
MS_ASSERT(this->primitive_ != nullptr);
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto weight = inputs_.at(1);


+ 7
- 9
mindspore/lite/src/ops/depthwise_conv2d.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_DEPTHWISE_CONV2_D_H_
#define LITE_MINDSPORE_LITE_C_OPS_DEPTHWISE_CONV2_D_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_DEPTHWISE_CONV2_D_H_
#define LITE_MINDSPORE_LITE_C_OPS_DEPTHWISE_CONV2_D_H_

namespace mindspore {
namespace lite {
class DepthwiseConv2D : public PrimitiveC {
public:
explicit DepthwiseConv2D(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit DepthwiseConv2D(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit DepthwiseConv2D(schema::Primitive *primitive) : PrimitiveC(primitive) {}

int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetFormat() const;


+ 43
- 39
mindspore/lite/src/ops/detection_post_process.cc View File

@@ -19,100 +19,104 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int DetectionPostProcess::GetFormat() const { return this->primitive->value.AsDetectionPostProcess()->format; }
int DetectionPostProcess::GetInputSize() const { return this->primitive->value.AsDetectionPostProcess()->inputSize; }
float DetectionPostProcess::GetHScale() const { return this->primitive->value.AsDetectionPostProcess()->hScale; }
float DetectionPostProcess::GetWScale() const { return this->primitive->value.AsDetectionPostProcess()->wScale; }
float DetectionPostProcess::GetXScale() const { return this->primitive->value.AsDetectionPostProcess()->xScale; }
float DetectionPostProcess::GetYScale() const { return this->primitive->value.AsDetectionPostProcess()->yScale; }
int DetectionPostProcess::GetFormat() const { return this->primitive_->value.AsDetectionPostProcess()->format; }
int DetectionPostProcess::GetInputSize() const { return this->primitive_->value.AsDetectionPostProcess()->inputSize; }
float DetectionPostProcess::GetHScale() const { return this->primitive_->value.AsDetectionPostProcess()->hScale; }
float DetectionPostProcess::GetWScale() const { return this->primitive_->value.AsDetectionPostProcess()->wScale; }
float DetectionPostProcess::GetXScale() const { return this->primitive_->value.AsDetectionPostProcess()->xScale; }
float DetectionPostProcess::GetYScale() const { return this->primitive_->value.AsDetectionPostProcess()->yScale; }
float DetectionPostProcess::GetNmsIouThreshold() const {
return this->primitive->value.AsDetectionPostProcess()->NmsIouThreshold;
return this->primitive_->value.AsDetectionPostProcess()->NmsIouThreshold;
}
float DetectionPostProcess::GetNmsScoreThreshold() const {
return this->primitive->value.AsDetectionPostProcess()->NmsScoreThreshold;
return this->primitive_->value.AsDetectionPostProcess()->NmsScoreThreshold;
}
long DetectionPostProcess::GetMaxDetections() const {
return this->primitive->value.AsDetectionPostProcess()->MaxDetections;
return this->primitive_->value.AsDetectionPostProcess()->MaxDetections;
}
long DetectionPostProcess::GetDetectionsPreClass() const {
return this->primitive->value.AsDetectionPostProcess()->DetectionsPreClass;
return this->primitive_->value.AsDetectionPostProcess()->DetectionsPreClass;
}
long DetectionPostProcess::GetMaxClassesPreDetection() const {
return this->primitive->value.AsDetectionPostProcess()->MaxClassesPreDetection;
return this->primitive_->value.AsDetectionPostProcess()->MaxClassesPreDetection;
}
long DetectionPostProcess::GetNumClasses() const {
return this->primitive_->value.AsDetectionPostProcess()->NumClasses;
}
long DetectionPostProcess::GetNumClasses() const { return this->primitive->value.AsDetectionPostProcess()->NumClasses; }
bool DetectionPostProcess::GetUseRegularNms() const {
return this->primitive->value.AsDetectionPostProcess()->UseRegularNms;
return this->primitive_->value.AsDetectionPostProcess()->UseRegularNms;
}

// PRIMITIVE_WRITEABLE branch setters: write through to the mutable
// schema::PrimitiveT union. Enum-typed fields narrow the int argument with a
// C-style cast, matching the surrounding ops. Stale duplicate writes through
// the removed `primitive` member have been dropped.
void DetectionPostProcess::SetFormat(int format) {
  this->primitive_->value.AsDetectionPostProcess()->format = (schema::Format)format;
}
void DetectionPostProcess::SetInputSize(int input_size) {
  this->primitive_->value.AsDetectionPostProcess()->inputSize = input_size;
}
void DetectionPostProcess::SetHScale(float h_scale) {
  this->primitive_->value.AsDetectionPostProcess()->hScale = h_scale;
}
void DetectionPostProcess::SetWScale(float w_scale) {
  this->primitive_->value.AsDetectionPostProcess()->wScale = w_scale;
}
void DetectionPostProcess::SetXScale(float x_scale) {
  this->primitive_->value.AsDetectionPostProcess()->xScale = x_scale;
}
void DetectionPostProcess::SetYScale(float y_scale) {
  this->primitive_->value.AsDetectionPostProcess()->yScale = y_scale;
}
void DetectionPostProcess::SetNmsIouThreshold(float nms_iou_threshold) {
  this->primitive_->value.AsDetectionPostProcess()->NmsIouThreshold = nms_iou_threshold;
}
void DetectionPostProcess::SetNmsScoreThreshold(float nms_score_threshold) {
  this->primitive_->value.AsDetectionPostProcess()->NmsScoreThreshold = nms_score_threshold;
}
void DetectionPostProcess::SetMaxDetections(long max_detections) {
  // Bug fix: this previously assigned to MaxClassesPreDetection (copy/paste
  // slip), so SetMaxDetections never updated the field that GetMaxDetections
  // reads.
  this->primitive_->value.AsDetectionPostProcess()->MaxDetections = max_detections;
}
void DetectionPostProcess::SetDetectionsPreClass(long detections_pre_class) {
  this->primitive_->value.AsDetectionPostProcess()->DetectionsPreClass = detections_pre_class;
}
void DetectionPostProcess::SetMaxClassesPreDetection(long max_classes_pre_detection) {
  this->primitive_->value.AsDetectionPostProcess()->MaxClassesPreDetection = max_classes_pre_detection;
}
void DetectionPostProcess::SetNumClasses(long num_classes) {
  this->primitive_->value.AsDetectionPostProcess()->NumClasses = num_classes;
}
void DetectionPostProcess::SetUseRegularNms(bool use_regular_nms) {
  this->primitive_->value.AsDetectionPostProcess()->UseRegularNms = use_regular_nms;
}

#else

int DetectionPostProcess::GetFormat() const { return this->primitive->value_as_DetectionPostProcess()->format(); }
int DetectionPostProcess::GetInputSize() const { return this->primitive->value_as_DetectionPostProcess()->inputSize(); }
float DetectionPostProcess::GetHScale() const { return this->primitive->value_as_DetectionPostProcess()->hScale(); }
float DetectionPostProcess::GetWScale() const { return this->primitive->value_as_DetectionPostProcess()->wScale(); }
float DetectionPostProcess::GetXScale() const { return this->primitive->value_as_DetectionPostProcess()->xScale(); }
float DetectionPostProcess::GetYScale() const { return this->primitive->value_as_DetectionPostProcess()->yScale(); }
int DetectionPostProcess::GetFormat() const { return this->primitive_->value_as_DetectionPostProcess()->format(); }
int DetectionPostProcess::GetInputSize() const {
return this->primitive_->value_as_DetectionPostProcess()->inputSize();
}
float DetectionPostProcess::GetHScale() const { return this->primitive_->value_as_DetectionPostProcess()->hScale(); }
float DetectionPostProcess::GetWScale() const { return this->primitive_->value_as_DetectionPostProcess()->wScale(); }
float DetectionPostProcess::GetXScale() const { return this->primitive_->value_as_DetectionPostProcess()->xScale(); }
float DetectionPostProcess::GetYScale() const { return this->primitive_->value_as_DetectionPostProcess()->yScale(); }
float DetectionPostProcess::GetNmsIouThreshold() const {
return this->primitive->value_as_DetectionPostProcess()->NmsIouThreshold();
return this->primitive_->value_as_DetectionPostProcess()->NmsIouThreshold();
}
float DetectionPostProcess::GetNmsScoreThreshold() const {
return this->primitive->value_as_DetectionPostProcess()->NmsScoreThreshold();
return this->primitive_->value_as_DetectionPostProcess()->NmsScoreThreshold();
}
long DetectionPostProcess::GetMaxDetections() const {
return this->primitive->value_as_DetectionPostProcess()->MaxDetections();
return this->primitive_->value_as_DetectionPostProcess()->MaxDetections();
}
long DetectionPostProcess::GetDetectionsPreClass() const {
return this->primitive->value_as_DetectionPostProcess()->DetectionsPreClass();
return this->primitive_->value_as_DetectionPostProcess()->DetectionsPreClass();
}
long DetectionPostProcess::GetMaxClassesPreDetection() const {
return this->primitive->value_as_DetectionPostProcess()->MaxClassesPreDetection();
return this->primitive_->value_as_DetectionPostProcess()->MaxClassesPreDetection();
}
long DetectionPostProcess::GetNumClasses() const {
return this->primitive->value_as_DetectionPostProcess()->NumClasses();
return this->primitive_->value_as_DetectionPostProcess()->NumClasses();
}
bool DetectionPostProcess::GetUseRegularNms() const {
return this->primitive->value_as_DetectionPostProcess()->UseRegularNms();
return this->primitive_->value_as_DetectionPostProcess()->UseRegularNms();
}

void DetectionPostProcess::SetFormat(int format) {}


+ 7
- 9
mindspore/lite/src/ops/detection_post_process.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_DETECTION_POST_PROCESS_H_
#define LITE_MINDSPORE_LITE_C_OPS_DETECTION_POST_PROCESS_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_DETECTION_POST_PROCESS_H_
#define LITE_MINDSPORE_LITE_C_OPS_DETECTION_POST_PROCESS_H_

namespace mindspore {
namespace lite {
class DetectionPostProcess : public PrimitiveC {
public:
explicit DetectionPostProcess(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit DetectionPostProcess(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit DetectionPostProcess(schema::Primitive *primitive) : PrimitiveC(primitive) {}

int GetFormat() const;
int GetInputSize() const;


+ 3
- 3
mindspore/lite/src/ops/div.cc View File

@@ -19,15 +19,15 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
// PRIMITIVE_WRITEABLE branch: fused-activation accessor pair for Div. Stale
// duplicate lines using the removed `primitive` member have been dropped.
int Div::GetActivationType() const { return this->primitive_->value.AsDiv()->activationType; }

void Div::SetActivationType(int activation_type) {
  this->primitive_->value.AsDiv()->activationType = (schema::ActivationType)activation_type;
}

#else

// Read-only (flatbuffer) branch; the setter is a no-op because the buffer is
// immutable. Stale duplicate getter using the removed `primitive` member
// dropped.
int Div::GetActivationType() const { return this->primitive_->value_as_Div()->activationType(); }

void Div::SetActivationType(int activation_type) {}
#endif


+ 7
- 10
mindspore/lite/src/ops/div.h View File

@@ -14,26 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_DIV_H_
#define LITE_MINDSPORE_LITE_C_OPS_DIV_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/arithmetic.h"

#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_DIV_H_
#define LITE_MINDSPORE_LITE_C_OPS_DIV_H_

namespace mindspore {
namespace lite {
class Div : public Arithmetic {
public:
explicit Div(OriginPrimitive *primitive) : Arithmetic(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit Div(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
#endif
explicit Div(schema::Primitive *primitive) : Arithmetic(primitive) {}

int GetActivationType() const;
void SetActivationType(int activation_type);


+ 3
- 3
mindspore/lite/src/ops/dropout.cc View File

@@ -19,13 +19,13 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
// PRIMITIVE_WRITEABLE branch: dropout ratio accessor pair. Stale duplicate
// lines using the removed `primitive` member have been dropped.
float Dropout::GetRatio() const { return this->primitive_->value.AsDropout()->ratio; }

void Dropout::SetRatio(float ratio) { this->primitive_->value.AsDropout()->ratio = ratio; }

#else

// Read-only (flatbuffer) branch; setter is a no-op. Stale duplicate getter
// using the removed `primitive` member dropped.
float Dropout::GetRatio() const { return this->primitive_->value_as_Dropout()->ratio(); }

void Dropout::SetRatio(float ratio) {}
#endif


+ 7
- 9
mindspore/lite/src/ops/dropout.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_DROPOUT_H_
#define LITE_MINDSPORE_LITE_C_OPS_DROPOUT_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_DROPOUT_H_
#define LITE_MINDSPORE_LITE_C_OPS_DROPOUT_H_

namespace mindspore {
namespace lite {
class Dropout : public PrimitiveC {
public:
explicit Dropout(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit Dropout(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit Dropout(schema::Primitive *primitive) : PrimitiveC(primitive) {}

float GetRatio() const;
void SetRatio(float ratio);


+ 3
- 3
mindspore/lite/src/ops/eltwise.cc View File

@@ -19,13 +19,13 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
// PRIMITIVE_WRITEABLE branch: eltwise-mode accessor pair; the setter narrows
// the int to schema::EltwiseMode. Stale duplicate lines using the removed
// `primitive` member have been dropped.
int Eltwise::GetMode() const { return this->primitive_->value.AsEltwise()->mode; }

void Eltwise::SetMode(int mode) { this->primitive_->value.AsEltwise()->mode = (schema::EltwiseMode)mode; }

#else

// Read-only (flatbuffer) branch; setter is a no-op. Stale duplicate getter
// using the removed `primitive` member dropped.
int Eltwise::GetMode() const { return this->primitive_->value_as_Eltwise()->mode(); }

void Eltwise::SetMode(int mode) {}
#endif


+ 7
- 9
mindspore/lite/src/ops/eltwise.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_ELTWISE_H_
#define LITE_MINDSPORE_LITE_C_OPS_ELTWISE_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_ELTWISE_H_
#define LITE_MINDSPORE_LITE_C_OPS_ELTWISE_H_

namespace mindspore {
namespace lite {
class Eltwise : public PrimitiveC {
public:
explicit Eltwise(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit Eltwise(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit Eltwise(schema::Primitive *primitive) : PrimitiveC(primitive) {}

int GetMode() const;
void SetMode(int mode);


+ 3
- 3
mindspore/lite/src/ops/elu.cc View File

@@ -19,13 +19,13 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
// PRIMITIVE_WRITEABLE branch: ELU alpha accessor pair. Stale duplicate lines
// using the removed `primitive` member have been dropped.
float Elu::GetAlpha() const { return this->primitive_->value.AsElu()->alpha; }

void Elu::SetAlpha(float alpha) { this->primitive_->value.AsElu()->alpha = alpha; }

#else

// Read-only (flatbuffer) branch; setter is a no-op. Stale duplicate getter
// using the removed `primitive` member dropped.
float Elu::GetAlpha() const { return this->primitive_->value_as_Elu()->alpha(); }

void Elu::SetAlpha(float alpha) {}
#endif


+ 7
- 9
mindspore/lite/src/ops/elu.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_ELU_H_
#define LITE_MINDSPORE_LITE_C_OPS_ELU_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_ELU_H_
#define LITE_MINDSPORE_LITE_C_OPS_ELU_H_

namespace mindspore {
namespace lite {
class Elu : public PrimitiveC {
public:
explicit Elu(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit Elu(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit Elu(schema::Primitive *primitive) : PrimitiveC(primitive) {}

float GetAlpha() const;
void SetAlpha(float alpha);


+ 4
- 4
mindspore/lite/src/ops/embedding_lookup.cc View File

@@ -19,19 +19,19 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
// PRIMITIVE_WRITEABLE branch: maxNorm accessor pair. Stale duplicate lines
// using the removed `primitive` member have been dropped.
float EmbeddingLookup::GetMaxNorm() const { return this->primitive_->value.AsEmbeddingLookup()->maxNorm; }

void EmbeddingLookup::SetMaxNorm(float max_norm) { this->primitive_->value.AsEmbeddingLookup()->maxNorm = max_norm; }

#else

// Read-only (flatbuffer) branch; setter is a no-op. Stale duplicate getter
// using the removed `primitive` member dropped.
float EmbeddingLookup::GetMaxNorm() const { return this->primitive_->value_as_EmbeddingLookup()->maxNorm(); }

void EmbeddingLookup::SetMaxNorm(float max_norm) {}
#endif

int EmbeddingLookup::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
MS_ASSERT(this->primitive_ != nullptr);
if (inputs_.size() < kDoubleNum) {
MS_LOG(ERROR) << "Embedding Lookup should have at least two inputs";
return RET_INPUT_TENSOR_ERROR;


+ 7
- 9
mindspore/lite/src/ops/embedding_lookup.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_EMBEDDING_LOOKUP_H_
#define LITE_MINDSPORE_LITE_C_OPS_EMBEDDING_LOOKUP_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_EMBEDDING_LOOKUP_H_
#define LITE_MINDSPORE_LITE_C_OPS_EMBEDDING_LOOKUP_H_

namespace mindspore {
namespace lite {
class EmbeddingLookup : public PrimitiveC {
public:
explicit EmbeddingLookup(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit EmbeddingLookup(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit EmbeddingLookup(schema::Primitive *primitive) : PrimitiveC(primitive) {}

int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
float GetMaxNorm() const;


+ 9
- 9
mindspore/lite/src/ops/embedding_lookup_sparse.cc View File

@@ -20,35 +20,35 @@ namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
// PRIMITIVE_WRITEABLE branch: sparse-lookup attribute accessors. Vector
// attributes are returned by value (a copy of the schema::PrimitiveT fields).
// Stale duplicate lines using the removed `primitive` member have been
// dropped.
std::vector<int> EmbeddingLookupSparse::GetSpIds() const {
  return this->primitive_->value.AsEmbeddingLookupSparse()->spIds;
}
std::vector<float> EmbeddingLookupSparse::GetSpWeights() const {
  return this->primitive_->value.AsEmbeddingLookupSparse()->spWeights;
}
float EmbeddingLookupSparse::GetMaxNortm() const { return this->primitive_->value.AsEmbeddingLookupSparse()->maxNortm; }

void EmbeddingLookupSparse::SetSpIds(const std::vector<int> &sp_ids) {
  this->primitive_->value.AsEmbeddingLookupSparse()->spIds = sp_ids;
}
void EmbeddingLookupSparse::SetSpWeights(const std::vector<float> &sp_weights) {
  this->primitive_->value.AsEmbeddingLookupSparse()->spWeights = sp_weights;
}
void EmbeddingLookupSparse::SetMaxNortm(float max_nortm) {
  this->primitive_->value.AsEmbeddingLookupSparse()->maxNortm = max_nortm;
}

#else

// Read-only (flatbuffer) branch: copies the flatbuffer vectors into
// std::vector. NOTE(review): spIds()/spWeights() are dereferenced without a
// null check — assumes those fields are always present in the serialized
// model; confirm against the schema. Stale duplicate lines using the removed
// `primitive` member have been dropped.
std::vector<int> EmbeddingLookupSparse::GetSpIds() const {
  auto fb_vector = this->primitive_->value_as_EmbeddingLookupSparse()->spIds();
  return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
std::vector<float> EmbeddingLookupSparse::GetSpWeights() const {
  auto fb_vector = this->primitive_->value_as_EmbeddingLookupSparse()->spWeights();
  return std::vector<float>(fb_vector->begin(), fb_vector->end());
}
float EmbeddingLookupSparse::GetMaxNortm() const {
  return this->primitive_->value_as_EmbeddingLookupSparse()->maxNortm();
}

void EmbeddingLookupSparse::SetSpIds(const std::vector<int> &sp_ids) {}


+ 7
- 9
mindspore/lite/src/ops/embedding_lookup_sparse.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_EMBEDDING_LOOKUP_SPARSE_H_
#define LITE_MINDSPORE_LITE_C_OPS_EMBEDDING_LOOKUP_SPARSE_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_EMBEDDING_LOOKUP_SPARSE_H_
#define LITE_MINDSPORE_LITE_C_OPS_EMBEDDING_LOOKUP_SPARSE_H_

namespace mindspore {
namespace lite {
class EmbeddingLookupSparse : public PrimitiveC {
public:
explicit EmbeddingLookupSparse(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit EmbeddingLookupSparse(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit EmbeddingLookupSparse(schema::Primitive *primitive) : PrimitiveC(primitive) {}

std::vector<int> GetSpIds() const;
std::vector<float> GetSpWeights() const;


+ 7
- 9
mindspore/lite/src/ops/equal.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_EQUAL_H_
#define LITE_MINDSPORE_LITE_C_OPS_EQUAL_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/arithmetic.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_EQUAL_H_
#define LITE_MINDSPORE_LITE_C_OPS_EQUAL_H_

namespace mindspore {
namespace lite {
// Element-wise equality op. All behavior is inherited from Arithmetic; only
// the per-build-flavor constructors differ. The stale constructor taking the
// removed OriginPrimitive type has been dropped.
class Equal : public Arithmetic {
 public:
#ifdef PRIMITIVE_WRITEABLE
  explicit Equal(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
#endif
  explicit Equal(schema::Primitive *primitive) : Arithmetic(primitive) {}
};
} // namespace lite
} // namespace mindspore


+ 7
- 9
mindspore/lite/src/ops/exp.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_EXP_H_
#define LITE_MINDSPORE_LITE_C_OPS_EXP_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/arithmetic_self.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_EXP_H_
#define LITE_MINDSPORE_LITE_C_OPS_EXP_H_

namespace mindspore {
namespace lite {
// Element-wise exponential op. All behavior is inherited from ArithmeticSelf;
// only the per-build-flavor constructors differ. The stale constructor taking
// the removed OriginPrimitive type has been dropped.
class Exp : public ArithmeticSelf {
 public:
#ifdef PRIMITIVE_WRITEABLE
  explicit Exp(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
#endif
  explicit Exp(schema::Primitive *primitive) : ArithmeticSelf(primitive) {}
};
} // namespace lite
} // namespace mindspore


+ 4
- 4
mindspore/lite/src/ops/expand_dims.cc View File

@@ -19,19 +19,19 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int ExpandDims::GetDim() const { return this->primitive->value.AsExpandDims()->dim; }
int ExpandDims::GetDim() const { return this->primitive_->value.AsExpandDims()->dim; }

void ExpandDims::SetDim(int dim) { this->primitive->value.AsExpandDims()->dim = dim; }
void ExpandDims::SetDim(int dim) { this->primitive_->value.AsExpandDims()->dim = dim; }

#else

int ExpandDims::GetDim() const { return this->primitive->value_as_ExpandDims()->dim(); }
int ExpandDims::GetDim() const { return this->primitive_->value_as_ExpandDims()->dim(); }

void ExpandDims::SetDim(int dim) {}
#endif

int ExpandDims::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
MS_ASSERT(this->primitive_ != nullptr);
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();


+ 7
- 9
mindspore/lite/src/ops/expand_dims.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_EXPAND_DIMS_H_
#define LITE_MINDSPORE_LITE_C_OPS_EXPAND_DIMS_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_EXPAND_DIMS_H_
#define LITE_MINDSPORE_LITE_C_OPS_EXPAND_DIMS_H_

namespace mindspore {
namespace lite {
class ExpandDims : public PrimitiveC {
public:
explicit ExpandDims(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit ExpandDims(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit ExpandDims(schema::Primitive *primitive) : PrimitiveC(primitive) {}

int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetDim() const;


+ 6
- 6
mindspore/lite/src/ops/fake_quant_with_min_max_vars.cc View File

@@ -20,24 +20,24 @@ namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
bool FakeQuantWithMinMaxVars::GetNarrowRange() const {
return this->primitive->value.AsFakeQuantWithMinMaxVars()->narrowRange;
return this->primitive_->value.AsFakeQuantWithMinMaxVars()->narrowRange;
}
int FakeQuantWithMinMaxVars::GetNumBits() const { return this->primitive->value.AsFakeQuantWithMinMaxVars()->numBits; }
int FakeQuantWithMinMaxVars::GetNumBits() const { return this->primitive_->value.AsFakeQuantWithMinMaxVars()->numBits; }

void FakeQuantWithMinMaxVars::SetNarrowRange(bool narrow_range) {
this->primitive->value.AsFakeQuantWithMinMaxVars()->narrowRange = narrow_range;
this->primitive_->value.AsFakeQuantWithMinMaxVars()->narrowRange = narrow_range;
}
void FakeQuantWithMinMaxVars::SetNumBits(int num_bits) {
this->primitive->value.AsFakeQuantWithMinMaxVars()->numBits = num_bits;
this->primitive_->value.AsFakeQuantWithMinMaxVars()->numBits = num_bits;
}

#else

bool FakeQuantWithMinMaxVars::GetNarrowRange() const {
return this->primitive->value_as_FakeQuantWithMinMaxVars()->narrowRange();
return this->primitive_->value_as_FakeQuantWithMinMaxVars()->narrowRange();
}
int FakeQuantWithMinMaxVars::GetNumBits() const {
return this->primitive->value_as_FakeQuantWithMinMaxVars()->numBits();
return this->primitive_->value_as_FakeQuantWithMinMaxVars()->numBits();
}

void FakeQuantWithMinMaxVars::SetNarrowRange(bool narrow_range) {}


+ 7
- 9
mindspore/lite/src/ops/fake_quant_with_min_max_vars.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_FAKE_QUANT_WITH_MIN_MAX_VARS_H_
#define LITE_MINDSPORE_LITE_C_OPS_FAKE_QUANT_WITH_MIN_MAX_VARS_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_FAKE_QUANT_WITH_MIN_MAX_VARS_H_
#define LITE_MINDSPORE_LITE_C_OPS_FAKE_QUANT_WITH_MIN_MAX_VARS_H_

namespace mindspore {
namespace lite {
class FakeQuantWithMinMaxVars : public PrimitiveC {
public:
explicit FakeQuantWithMinMaxVars(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit FakeQuantWithMinMaxVars(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit FakeQuantWithMinMaxVars(schema::Primitive *primitive) : PrimitiveC(primitive) {}

bool GetNarrowRange() const;
int GetNumBits() const;


+ 4
- 4
mindspore/lite/src/ops/fill.cc View File

@@ -19,14 +19,14 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
std::vector<int> Fill::GetDims() const { return this->primitive->value.AsFill()->dims; }
std::vector<int> Fill::GetDims() const { return this->primitive_->value.AsFill()->dims; }

void Fill::SetDims(const std::vector<int> &dims) { this->primitive->value.AsFill()->dims = dims; }
void Fill::SetDims(const std::vector<int> &dims) { this->primitive_->value.AsFill()->dims = dims; }

#else

std::vector<int> Fill::GetDims() const {
auto fb_vector = this->primitive->value_as_Fill()->dims();
auto fb_vector = this->primitive_->value_as_Fill()->dims();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}

@@ -34,7 +34,7 @@ void Fill::SetDims(const std::vector<int> &dims) {}
#endif

int Fill::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
MS_ASSERT(this->primitive_ != nullptr);
auto input = inputs_.front();
auto output = outputs_.front();
if (input == nullptr || output == nullptr) {


+ 7
- 9
mindspore/lite/src/ops/fill.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_FILL_H_
#define LITE_MINDSPORE_LITE_C_OPS_FILL_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_FILL_H_
#define LITE_MINDSPORE_LITE_C_OPS_FILL_H_

namespace mindspore {
namespace lite {
class Fill : public PrimitiveC {
public:
explicit Fill(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit Fill(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit Fill(schema::Primitive *primitive) : PrimitiveC(primitive) {}

int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
std::vector<int> GetDims() const;


+ 1
- 1
mindspore/lite/src/ops/flatten.cc View File

@@ -20,7 +20,7 @@ namespace mindspore {
namespace lite {

int Flatten::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
MS_ASSERT(this->primitive_ != nullptr);
auto input = inputs_.front();
auto output = outputs_.front();
if (input == nullptr || output == nullptr) {


+ 7
- 9
mindspore/lite/src/ops/flatten.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_FLATTEN_H_
#define LITE_MINDSPORE_LITE_C_OPS_FLATTEN_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_FLATTEN_H_
#define LITE_MINDSPORE_LITE_C_OPS_FLATTEN_H_

namespace mindspore {
namespace lite {
class Flatten : public PrimitiveC {
public:
explicit Flatten(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit Flatten(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit Flatten(schema::Primitive *primitive) : PrimitiveC(primitive) {}

int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
};


+ 8
- 10
mindspore/lite/src/ops/floor.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_FLOOR_H_
#define LITE_MINDSPORE_LITE_C_OPS_FLOOR_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/arithmetic_self.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_FLOOR_H_
#define LITE_MINDSPORE_LITE_C_OPS_FLOOR_H_
#include "src/ops/primitive_c.h"

namespace mindspore {
namespace lite {
class Floor : public ArithmeticSelf {
public:
explicit Floor(OriginPrimitive *primitive) : ArithmeticSelf(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit Floor(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
#endif
explicit Floor(schema::Primitive *primitive) : ArithmeticSelf(primitive) {}
};
} // namespace lite
} // namespace mindspore


+ 7
- 9
mindspore/lite/src/ops/floor_div.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_FLOOR_DIV_H_
#define LITE_MINDSPORE_LITE_C_OPS_FLOOR_DIV_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/arithmetic.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_FLOOR_DIV_H_
#define LITE_MINDSPORE_LITE_C_OPS_FLOOR_DIV_H_

namespace mindspore {
namespace lite {
class FloorDiv : public Arithmetic {
public:
explicit FloorDiv(OriginPrimitive *primitive) : Arithmetic(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit FloorDiv(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
#endif
explicit FloorDiv(schema::Primitive *primitive) : Arithmetic(primitive) {}
};
} // namespace lite
} // namespace mindspore


+ 7
- 9
mindspore/lite/src/ops/floor_mod.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_FLOOR_MOD_H_
#define LITE_MINDSPORE_LITE_C_OPS_FLOOR_MOD_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/arithmetic.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_FLOOR_MOD_H_
#define LITE_MINDSPORE_LITE_C_OPS_FLOOR_MOD_H_

namespace mindspore {
namespace lite {
class FloorMod : public Arithmetic {
public:
explicit FloorMod(OriginPrimitive *primitive) : Arithmetic(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit FloorMod(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
#endif
explicit FloorMod(schema::Primitive *primitive) : Arithmetic(primitive) {}
};
} // namespace lite
} // namespace mindspore


+ 13
- 13
mindspore/lite/src/ops/full_connection.cc View File

@@ -19,23 +19,23 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
bool FullConnection::GetHasBias() const { return this->primitive->value.AsFullConnection()->hasBias; }
int FullConnection::GetAxis() const { return this->primitive->value.AsFullConnection()->axis; }
bool FullConnection::GetUseAxis() const { return this->primitive->value.AsFullConnection()->useAxis; }
int FullConnection::GetActivationType() const { return this->primitive->value.AsFullConnection()->activationType; }
bool FullConnection::GetHasBias() const { return this->primitive_->value.AsFullConnection()->hasBias; }
int FullConnection::GetAxis() const { return this->primitive_->value.AsFullConnection()->axis; }
bool FullConnection::GetUseAxis() const { return this->primitive_->value.AsFullConnection()->useAxis; }
int FullConnection::GetActivationType() const { return this->primitive_->value.AsFullConnection()->activationType; }

void FullConnection::SetHasBias(bool has_bias) { this->primitive->value.AsFullConnection()->hasBias = has_bias; }
void FullConnection::SetAxis(int axis) { this->primitive->value.AsFullConnection()->axis = axis; }
void FullConnection::SetUseAxis(bool use_axis) { this->primitive->value.AsFullConnection()->useAxis = use_axis; }
void FullConnection::SetHasBias(bool has_bias) { this->primitive_->value.AsFullConnection()->hasBias = has_bias; }
void FullConnection::SetAxis(int axis) { this->primitive_->value.AsFullConnection()->axis = axis; }
void FullConnection::SetUseAxis(bool use_axis) { this->primitive_->value.AsFullConnection()->useAxis = use_axis; }
void FullConnection::SetActivationType(int activationType) {
this->primitive->value.AsFullConnection()->activationType = (schema::ActivationType)activationType;
this->primitive_->value.AsFullConnection()->activationType = (schema::ActivationType)activationType;
}
#else

bool FullConnection::GetHasBias() const { return this->primitive->value_as_FullConnection()->hasBias(); }
int FullConnection::GetAxis() const { return this->primitive->value_as_FullConnection()->axis(); }
bool FullConnection::GetUseAxis() const { return this->primitive->value_as_FullConnection()->useAxis(); }
int FullConnection::GetActivationType() const { return this->primitive->value_as_FullConnection()->activationType(); }
bool FullConnection::GetHasBias() const { return this->primitive_->value_as_FullConnection()->hasBias(); }
int FullConnection::GetAxis() const { return this->primitive_->value_as_FullConnection()->axis(); }
bool FullConnection::GetUseAxis() const { return this->primitive_->value_as_FullConnection()->useAxis(); }
int FullConnection::GetActivationType() const { return this->primitive_->value_as_FullConnection()->activationType(); }

void FullConnection::SetHasBias(bool has_bias) {}
void FullConnection::SetAxis(int axis) {}
@@ -44,7 +44,7 @@ void FullConnection::SetActivationType(int activationType) {}
#endif
int FullConnection::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
std::vector<lite::tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
MS_ASSERT(this->primitive_ != nullptr);
auto input0 = inputs_.front();
MS_ASSERT(input0 != nullptr);
auto input1 = inputs_[1];


+ 7
- 9
mindspore/lite/src/ops/full_connection.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_FULL_CONNECTION_H_
#define LITE_MINDSPORE_LITE_C_OPS_FULL_CONNECTION_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_FULL_CONNECTION_H_
#define LITE_MINDSPORE_LITE_C_OPS_FULL_CONNECTION_H_

namespace mindspore {
namespace lite {
class FullConnection : public PrimitiveC {
public:
explicit FullConnection(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit FullConnection(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit FullConnection(schema::Primitive *primitive) : PrimitiveC(primitive) {}

int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
bool GetHasBias() const;


+ 9
- 9
mindspore/lite/src/ops/fused_batchnorm.cc View File

@@ -19,19 +19,19 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
float FusedBatchNorm::GetEpsilon() const { return this->primitive->value.AsFusedBatchNorm()->epsilon; }
float FusedBatchNorm::GetMomentum() const { return this->primitive->value.AsFusedBatchNorm()->momentum; }
int FusedBatchNorm::GetSpatial() const { return this->primitive->value.AsFusedBatchNorm()->spatial; }
float FusedBatchNorm::GetEpsilon() const { return this->primitive_->value.AsFusedBatchNorm()->epsilon; }
float FusedBatchNorm::GetMomentum() const { return this->primitive_->value.AsFusedBatchNorm()->momentum; }
int FusedBatchNorm::GetSpatial() const { return this->primitive_->value.AsFusedBatchNorm()->spatial; }

void FusedBatchNorm::SetEpsilon(float epsilon) { this->primitive->value.AsFusedBatchNorm()->epsilon = epsilon; }
void FusedBatchNorm::SetMomentum(float momentum) { this->primitive->value.AsFusedBatchNorm()->momentum = momentum; }
void FusedBatchNorm::SetSpatial(int spatial) { this->primitive->value.AsFusedBatchNorm()->spatial = spatial; }
void FusedBatchNorm::SetEpsilon(float epsilon) { this->primitive_->value.AsFusedBatchNorm()->epsilon = epsilon; }
void FusedBatchNorm::SetMomentum(float momentum) { this->primitive_->value.AsFusedBatchNorm()->momentum = momentum; }
void FusedBatchNorm::SetSpatial(int spatial) { this->primitive_->value.AsFusedBatchNorm()->spatial = spatial; }

#else

float FusedBatchNorm::GetEpsilon() const { return this->primitive->value_as_FusedBatchNorm()->epsilon(); }
float FusedBatchNorm::GetMomentum() const { return this->primitive->value_as_FusedBatchNorm()->momentum(); }
int FusedBatchNorm::GetSpatial() const { return this->primitive->value_as_FusedBatchNorm()->spatial(); }
float FusedBatchNorm::GetEpsilon() const { return this->primitive_->value_as_FusedBatchNorm()->epsilon(); }
float FusedBatchNorm::GetMomentum() const { return this->primitive_->value_as_FusedBatchNorm()->momentum(); }
int FusedBatchNorm::GetSpatial() const { return this->primitive_->value_as_FusedBatchNorm()->spatial(); }

void FusedBatchNorm::SetEpsilon(float epsilon) {}
void FusedBatchNorm::SetMomentum(float momentum) {}


+ 7
- 9
mindspore/lite/src/ops/fused_batchnorm.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_FUSED_BATCH_NORM_H_
#define LITE_MINDSPORE_LITE_C_OPS_FUSED_BATCH_NORM_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_FUSED_BATCH_NORM_H_
#define LITE_MINDSPORE_LITE_C_OPS_FUSED_BATCH_NORM_H_

namespace mindspore {
namespace lite {
class FusedBatchNorm : public PrimitiveC {
public:
explicit FusedBatchNorm(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit FusedBatchNorm(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit FusedBatchNorm(schema::Primitive *primitive) : PrimitiveC(primitive) {}

float GetEpsilon() const;
float GetMomentum() const;


+ 7
- 7
mindspore/lite/src/ops/gather.cc View File

@@ -22,23 +22,23 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int Gather::GetAxis() const { return this->primitive->value.AsGather()->axis; }
int Gather::GetBatchDims() const { return this->primitive->value.AsGather()->batchDims; }
int Gather::GetAxis() const { return this->primitive_->value.AsGather()->axis; }
int Gather::GetBatchDims() const { return this->primitive_->value.AsGather()->batchDims; }

void Gather::SetAxis(int axis) { this->primitive->value.AsGather()->axis = axis; }
void Gather::SetBatchDims(int batch_dims) { this->primitive->value.AsGather()->batchDims = batch_dims; }
void Gather::SetAxis(int axis) { this->primitive_->value.AsGather()->axis = axis; }
void Gather::SetBatchDims(int batch_dims) { this->primitive_->value.AsGather()->batchDims = batch_dims; }

#else

int Gather::GetAxis() const { return this->primitive->value_as_Gather()->axis(); }
int Gather::GetBatchDims() const { return this->primitive->value_as_Gather()->batchDims(); }
int Gather::GetAxis() const { return this->primitive_->value_as_Gather()->axis(); }
int Gather::GetBatchDims() const { return this->primitive_->value_as_Gather()->batchDims(); }

void Gather::SetAxis(int axis) {}
void Gather::SetBatchDims(int batch_dims) {}
#endif

int Gather::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
MS_ASSERT(this->primitive_ != nullptr);
if (inputs_.size() != kDoubleNum) {
MS_LOG(ERROR) << "Gather should have two inputs";
return RET_INPUT_TENSOR_ERROR;


+ 7
- 9
mindspore/lite/src/ops/gather.h View File

@@ -14,25 +14,23 @@
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_GATHER_H_
#define LITE_MINDSPORE_LITE_C_OPS_GATHER_H_

#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif

#ifndef LITE_MINDSPORE_LITE_C_OPS_GATHER_H_
#define LITE_MINDSPORE_LITE_C_OPS_GATHER_H_

namespace mindspore {
namespace lite {
class Gather : public PrimitiveC {
public:
explicit Gather(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit Gather(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit Gather(schema::Primitive *primitive) : PrimitiveC(primitive) {}

int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetAxis() const;


+ 4
- 4
mindspore/lite/src/ops/gather_nd.cc View File

@@ -19,19 +19,19 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int GatherNd::GetBatchDims() const { return this->primitive->value.AsGatherNd()->batchDims; }
int GatherNd::GetBatchDims() const { return this->primitive_->value.AsGatherNd()->batchDims; }

void GatherNd::SetBatchDims(int batch_dims) { this->primitive->value.AsGatherNd()->batchDims = batch_dims; }
void GatherNd::SetBatchDims(int batch_dims) { this->primitive_->value.AsGatherNd()->batchDims = batch_dims; }

#else

int GatherNd::GetBatchDims() const { return this->primitive->value_as_GatherNd()->batchDims(); }
int GatherNd::GetBatchDims() const { return this->primitive_->value_as_GatherNd()->batchDims(); }

void GatherNd::SetBatchDims(int batch_dims) {}
#endif

int GatherNd::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
MS_ASSERT(this->primitive_ != nullptr);
if (inputs_.size() != kDoubleNum) {
MS_LOG(ERROR) << "GatherNd should have two inputs";
return RET_INPUT_TENSOR_ERROR;


Some files were not shown because too many files changed in this diff

Loading…
Cancel
Save