diff --git a/mindspore/lite/nnacl/fp32/activation.c b/mindspore/lite/nnacl/fp32/activation.c
index a33c8368f4..1124fd67cc 100644
--- a/mindspore/lite/nnacl/fp32/activation.c
+++ b/mindspore/lite/nnacl/fp32/activation.c
@@ -117,6 +117,14 @@ int HSwish(const float *src, int length, float *dst) {
   return NNACL_OK;
 }
 
+int HSigmoid(const float *src, int length, float *dst) {
+  for (int i = 0; i < length; ++i) {
+    float relu6 = MSMIN(MSMAX(src[i] + 3, 0), 6);
+    dst[i] = relu6 / 6;
+  }
+  return NNACL_OK;
+}
+
 int HardTanh(const float *src, int length, float *dst, float min_val, float max_val) {
   if (max_val <= min_val) {
     return NNACL_ERR;
diff --git a/mindspore/lite/nnacl/fp32/activation.h b/mindspore/lite/nnacl/fp32/activation.h
index a387d94c44..bd85832342 100644
--- a/mindspore/lite/nnacl/fp32/activation.h
+++ b/mindspore/lite/nnacl/fp32/activation.h
@@ -13,8 +13,8 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#ifndef MINDSPORE_LITE_NNACL_ACTIVATION_H_
-#define MINDSPORE_LITE_NNACL_ACTIVATION_H_
+#ifndef MINDSPORE_LITE_NNACL_FP32_ACTIVATION_H_
+#define MINDSPORE_LITE_NNACL_FP32_ACTIVATION_H_
 
 #include <math.h>
 #include "nnacl/op_base.h"
@@ -36,9 +36,10 @@ int Fp32Relu6(const float *src, int length, float *dst);
 int LRelu(const float *src, int length, float *dst, float alpha);
 int Sigmoid(const float *src, int length, float *dst);
 int Tanh(const float *src, int length, float *dst);
+int HSigmoid(const float *src, int length, float *dst);
 int HSwish(const float *src, int length, float *dst);
 int HardTanh(const float *src, int length, float *dst, float min_val, float max_val);
 #ifdef __cplusplus
 }
 #endif
-#endif  // MINDSPORE_LITE_NNACL_ACTIVATION_H_
+#endif  // MINDSPORE_LITE_NNACL_FP32_ACTIVATION_H_
diff --git a/mindspore/lite/nnacl/fp32/one_hot.c b/mindspore/lite/nnacl/fp32/one_hot.c
index 74444d7e0e..913f375754 100644
--- a/mindspore/lite/nnacl/fp32/one_hot.c
+++ b/mindspore/lite/nnacl/fp32/one_hot.c
@@ -31,14 +31,14 @@ int OneHot(const int *indices, float *output, const OneHotParameter *one_hot_par
   int i, j, k;
   for (i = tid; i < outer_size; i += thread_num) {
     float *output_ptr = output + i * depth * inner_size;
-    for (k = 0; k < depth; k++) {
-      for (j = 0; j < inner_size; j++) {
+    for (k = 0; k < inner_size; k++) {
+      int index = indices[i * inner_size + k];
+      for (j = 0; j < depth; j++) {
         *output_ptr = off_value;
-        int index = indices[i * inner_size + j];
         if (index >= depth) {
           return NNACL_ERRCODE_INDEX_OUT_OF_RANGE;
         }
-        if (index == k) {
+        if (index == j) {
           *output_ptr = on_value;
         }
         output_ptr++;
diff --git a/mindspore/lite/nnacl/fp32_grad/gemm.c b/mindspore/lite/nnacl/fp32_grad/gemm.c
index 4791ef04ba..0ec25c4141 100644
--- a/mindspore/lite/nnacl/fp32_grad/gemm.c
+++ b/mindspore/lite/nnacl/fp32_grad/gemm.c
@@ -15,27 +15,52 @@
  */
 
 #include "nnacl/fp32_grad/gemm.h"
+#include <string.h>
+
+static void gemm_not_trana_not_tranb(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb,
+                                     float *mat_c, int ldc) {
+  const int block_size = 4;
+  int block_mod = N % block_size;
+  int block_c4 = N - block_mod;
 
-static void gemm_nn(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_B, int ldb, float *mat_c,
-                    int ldc) {
   int i, j, k;
   for (i = 0; i < M; ++i) {
     for (k = 0; k < K; ++k) {
       float a = alpha * mat_a[i * lda + k];
-      for (j = 0; j < N; ++j) {
-        mat_c[i * ldc + j] += a * mat_B[k * ldb + j];
+      for (j = 0; j < block_c4; j += block_size) {
+        float *b = &mat_b[k * ldb + j];
+        float *c = &mat_c[i * ldc + j];
+        c[0] += a * 
b[0]; + c[1] += a * b[1]; + c[2] += a * b[2]; + c[3] += a * b[3]; + } + for (; j < N; ++j) { + mat_c[i * ldc + j] += a * mat_b[k * ldb + j]; } } } } -static void gemm_nt(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb, float *mat_c, - int ldc) { +static void gemm_not_trana_tranb(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb, + float *mat_c, int ldc) { + const int block_size = 4; + int block_mod = K % block_size; + int block_c4 = K - block_mod; + int i, j, k; for (i = 0; i < M; ++i) { for (j = 0; j < N; ++j) { float sum = 0; - for (k = 0; k < K; ++k) { + for (k = 0; k < block_c4; k += block_size) { + float *a = &mat_a[i * lda + k]; + float *b = &mat_b[j * ldb + k]; + sum += alpha * a[0] * b[0]; + sum += alpha * a[1] * b[1]; + sum += alpha * a[2] * b[2]; + sum += alpha * a[3] * b[3]; + } + for (; k < K; ++k) { sum += alpha * mat_a[i * lda + k] * mat_b[j * ldb + k]; } mat_c[i * ldc + j] += sum; @@ -43,23 +68,85 @@ static void gemm_nt(int M, int N, int K, float alpha, float *mat_a, int lda, flo } } -static void gemm_tn(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb, float *mat_c, - int ldc) { +static void gemm_trana_not_tranb(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb, + float *mat_c, int ldc) { + const int block_size = 4; + int block_mod = N % block_size; + int block_c4 = N - block_mod; + int i, j, k; for (i = 0; i < M; ++i) { for (k = 0; k < K; ++k) { float a = alpha * mat_a[k * lda + i]; - for (j = 0; j < N; ++j) { + for (j = 0; j < block_c4; j += block_size) { + float *b = &mat_b[k * ldb + j]; + float *c = &mat_c[i * ldc + j]; + c[0] += a * b[0]; + c[1] += a * b[1]; + c[2] += a * b[2]; + c[3] += a * b[3]; + } + for (; j < N; ++j) { mat_c[i * ldc + j] += a * mat_b[k * ldb + j]; } } } } -static void gemm_tt(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb, float *mat_c, - int ldc) { +static void gemm_trana_tranb(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb, + float *mat_c, int ldc) { int i, j, k; - for (i = 0; i < M; ++i) { + const int block_size = 4; + int k_block_mod = K % block_size; + int k_block_c4 = K - k_block_mod; + + int m_block_mod = M % block_size; + int m_block_c4 = M - m_block_mod; + + for (i = 0; i < m_block_c4; i += block_size) { + for (j = 0; j < N; ++j) { + float sum0 = 0; + float sum1 = 0; + float sum2 = 0; + float sum3 = 0; + + for (k = 0; k < k_block_c4; k += block_size) { + float *b = &mat_b[j * ldb + k]; + sum0 += alpha * mat_a[i + k * lda] * b[0]; + sum0 += alpha * mat_a[i + (k + 1) * lda] * b[1]; + sum0 += alpha * mat_a[i + (k + 2) * lda] * b[2]; + sum0 += alpha * mat_a[i + (k + 3) * lda] * b[3]; + + sum1 += alpha * mat_a[i + 1 + k * lda] * b[0]; + sum1 += alpha * mat_a[i + 1 + (k + 1) * lda] * b[1]; + sum1 += alpha * mat_a[i + 1 + (k + 2) * lda] * b[2]; + sum1 += alpha * mat_a[i + 1 + (k + 3) * lda] * b[3]; + + sum2 += alpha * mat_a[i + 2 + k * lda] * b[0]; + sum2 += alpha * mat_a[i + 2 + (k + 1) * lda] * b[1]; + sum2 += alpha * mat_a[i + 2 + (k + 2) * lda] * b[2]; + sum2 += alpha * mat_a[i + 2 + (k + 3) * lda] * b[3]; + + sum3 += alpha * mat_a[i + 3 + k * lda] * b[0]; + sum3 += alpha * mat_a[i + 3 + (k + 1) * lda] * b[1]; + sum3 += alpha * mat_a[i + 3 + (k + 2) * lda] * b[2]; + sum3 += alpha * mat_a[i + 3 + (k + 3) * lda] * b[3]; + } + for (; k < K; ++k) { + float *b = &mat_b[j * ldb + k]; + sum0 += alpha * mat_a[i + (k * lda)] * b[0]; + sum1 += alpha * mat_a[i + 1 + (k * 
lda)] * b[0]; + sum2 += alpha * mat_a[i + 2 + (k * lda)] * b[0]; + sum3 += alpha * mat_a[i + 3 + (k * lda)] * b[0]; + } + mat_c[i * ldc + j] += sum0; + mat_c[(i + 1) * ldc + j] += sum1; + mat_c[(i + 2) * ldc + j] += sum2; + mat_c[(i + 3) * ldc + j] += sum3; + } + } + // no more block of 4x4 + for (; i < M; ++i) { for (j = 0; j < N; ++j) { float sum = 0; for (k = 0; k < K; ++k) { @@ -74,34 +161,37 @@ static void gemm_tt(int M, int N, int K, float alpha, float *mat_a, int lda, flo // M - number of rows of matrix a // N - number of cols of matrix b // K - number of cols of matrix a - +// lda - fast dim of matrix a +// ldb - fast dim of matrix b +// ldc - fast dim of matrix c void gemm(int transpose_a, int transpose_b, int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb, float beta, float *mat_c, int ldc) { if (beta >= 0.f && beta <= 0.f) { - for (int i = 0; i < M; ++i) { - for (int j = 0; j < N; ++j) { - mat_c[i * ldc + j] = 0; - } - } + memset(mat_c, 0, M * N * sizeof(float)); } else if (beta < 1.f || beta > 1.f) { - for (int i = 0; i < M; ++i) { - for (int j = 0; j < N; ++j) { - mat_c[i * ldc + j] *= beta; - } + const int block_size = 4; + const int size = M * N; + int block_mod = size % block_size; + int block_c4 = size - block_mod; + int i; + for (i = 0; i < block_c4; i += block_size) { + float *c = &mat_c[i]; + c[0] *= beta; + c[1] *= beta; + c[2] *= beta; + c[3] *= beta; } - } - - int t; - - for (t = 0; t < M; ++t) { - if (!transpose_a && !transpose_b) { - gemm_nn(1, N, K, alpha, mat_a + t * lda, lda, mat_b, ldb, mat_c + t * ldc, ldc); - } else if (transpose_a && !transpose_b) { - gemm_tn(1, N, K, alpha, mat_a + t, lda, mat_b, ldb, mat_c + t * ldc, ldc); - } else if (!transpose_a && transpose_b) { - gemm_nt(1, N, K, alpha, mat_a + t * lda, lda, mat_b, ldb, mat_c + t * ldc, ldc); - } else { - gemm_tt(1, N, K, alpha, mat_a + t, lda, mat_b, ldb, mat_c + t * ldc, ldc); + for (; i < size; ++i) { + mat_c[i] *= beta; } } + if (transpose_a && transpose_b) { + gemm_trana_tranb(M, N, K, alpha, mat_a, lda, mat_b, ldb, mat_c, ldc); + } else if (!transpose_a && !transpose_b) { + gemm_not_trana_not_tranb(M, N, K, alpha, mat_a, lda, mat_b, ldb, mat_c, ldc); + } else if (!transpose_a && transpose_b) { + gemm_not_trana_tranb(M, N, K, alpha, mat_a, lda, mat_b, ldb, mat_c, ldc); + } else { + gemm_trana_not_tranb(M, N, K, alpha, mat_a, lda, mat_b, ldb, mat_c, ldc); + } } diff --git a/mindspore/lite/nnacl/fp32_grad/optimizer.h b/mindspore/lite/nnacl/fp32_grad/optimizer.h index 9d03977a8c..0cc52ced60 100644 --- a/mindspore/lite/nnacl/fp32_grad/optimizer.h +++ b/mindspore/lite/nnacl/fp32_grad/optimizer.h @@ -21,7 +21,6 @@ typedef struct ApplyMomentumParameter { OpParameter op_parameter_; - bool use_locking_; bool use_nesterov_; float grad_scale_; } ApplyMomentumParameter; @@ -33,4 +32,9 @@ typedef struct SgdParameter { float weight_decay_; } SgdParameter; +typedef struct AdamParameter { + OpParameter op_parameter_; + bool use_nesterov_; +} AdamParameter; + #endif // MINDSPORE_LITE_NNACL_FP32_GRAD_OPTIMIZER_H_ diff --git a/mindspore/lite/schema/model.fbs b/mindspore/lite/schema/model.fbs index eb47373c9d..4b9245a3ef 100644 --- a/mindspore/lite/schema/model.fbs +++ b/mindspore/lite/schema/model.fbs @@ -182,7 +182,7 @@ union PrimitiveType { Conv2DGradInput, PoolingGrad, BNGrad, - BNGradInput, + Assign, ApplyMomentum, BiasGrad, SoftmaxCrossEntropy, @@ -217,6 +217,8 @@ union PrimitiveType { FftReal, FftImag, Sgd, + Adam, + GroupConv2DGradInput, } enum QuantType: int { diff --git 
a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index db7440a0da..55f7876374 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -224,7 +224,29 @@ table Conv2DGradInput { dilateH: int; hasBias: bool = false; activationType: ActivationType = 0; -}table FusedBatchNorm { +} + +table GroupConv2DGradInput { + format: Format = 0; + group: int; + channelIn: int; + channelOut: int; + kernelW: int; + kernelH: int; + strideW: int; + strideH: int; + padMode: PadMode; + padUp: int; + padDown: int; + padLeft: int; + padRight: int; + dilateW: int; + dilateH: int; + hasBias: bool = false; + activationType: ActivationType = 0; +} + +table FusedBatchNorm { epsilon: float = 0.00001; // eg. epsilon=0.001 momentum: float = 0.9; spatial: int = 1; @@ -901,7 +923,6 @@ table TupleGetItem { table ApplyMomentum { gradientScale: float; - useLocking: bool; useNesterov: bool; } @@ -911,6 +932,14 @@ table Sgd { useNesterov: bool; } +table Adam { + useNesterov: bool; +} + +table Assign { +} + + table Where{ condition: [bool]; } diff --git a/mindspore/lite/src/ops/activation.cc b/mindspore/lite/src/ops/activation.cc index b31b52e5ff..185645ec84 100644 --- a/mindspore/lite/src/ops/activation.cc +++ b/mindspore/lite/src/ops/activation.cc @@ -50,6 +50,10 @@ int Activation::UnPackAttr(const Primitive &prim, const std::vector attr->type = schema::ActivationType_SIGMOID; } else if (prim.name() == "ReLU6") { attr->type = schema::ActivationType_RELU6; + } else if (prim.name() == "HSwish") { + attr->type = schema::ActivationType_HSWISH; + } else if (prim.name() == "HSigmoid") { + attr->type = schema::ActivationType_HSIGMOID; } this->primitive_->value.value = attr.release(); if (this->primitive_->value.value == nullptr) { diff --git a/mindspore/lite/src/ops/activation_grad.cc b/mindspore/lite/src/ops/activation_grad.cc index 4e643d9d8b..f9c45a9bea 100644 --- a/mindspore/lite/src/ops/activation_grad.cc +++ b/mindspore/lite/src/ops/activation_grad.cc @@ -43,8 +43,12 @@ int ActivationGrad::UnPackAttr(const Primitive &prim, const std::vectortype = schema::ActivationType_RELU; } else if (prim.name() == "SigmoidGrad") { attr->type = schema::ActivationType_SIGMOID; - } else if (prim.name() == "Relu6Grad") { + } else if (prim.name() == "ReLU6Grad") { attr->type = schema::ActivationType_RELU6; + } else if (prim.name() == "HSigmoidGrad") { + attr->type = schema::ActivationType_HSIGMOID; + } else if (prim.name() == "HSwishGrad") { + attr->type = schema::ActivationType_HSWISH; } attr->alpha = 0; // alpha; this->primitive_->value.value = attr.release(); diff --git a/mindspore/lite/src/ops/adam.cc b/mindspore/lite/src/ops/adam.cc new file mode 100644 index 0000000000..51eaf4648f --- /dev/null +++ b/mindspore/lite/src/ops/adam.cc @@ -0,0 +1,91 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
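Note on the new activation mappings above: by the usual definitions, hard-swish is x times hard-sigmoid, which is what makes a shared HSigmoid kernel worthwhile. A standalone sketch, with fminf/fmaxf standing in for nnacl's MSMIN/MSMAX macros:

#include <math.h>
#include <stdio.h>

/* Hard-sigmoid as computed by the new HSigmoid kernel: relu6(x + 3) / 6. */
static float hsigmoid(float x) { return fminf(fmaxf(x + 3.0f, 0.0f), 6.0f) / 6.0f; }
/* Hard-swish, by definition x * hard_sigmoid(x). */
static float hswish(float x) { return x * hsigmoid(x); }

int main(void) {
  const float xs[] = {-4.0f, -1.0f, 0.0f, 1.0f, 4.0f};
  for (int i = 0; i < 5; ++i) {
    printf("x=%5.1f  hsigmoid=%7.4f  hswish=%7.4f\n", xs[i], hsigmoid(xs[i]), hswish(xs[i]));
  }
  return 0;
}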
+ */
+#include "src/ops/adam.h"
+namespace mindspore {
+namespace lite {
+#ifdef PRIMITIVE_WRITEABLE
+bool Adam::GetUseNesterov() const { return this->primitive_->value.AsAdam()->useNesterov; }
+int Adam::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
+  if (this->primitive_ == nullptr) {
+    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
+    if (this->primitive_ == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT failed";
+      return RET_ERROR;
+    }
+    this->primitive_->value.type = schema::PrimitiveType_Adam;
+  }
+  if (this->primitive_->value.type != schema::PrimitiveType_Adam) {
+    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
+    return RET_ERROR;
+  }
+  if (this->primitive_->value.value == nullptr) {
+    auto attr = std::make_unique<schema::AdamT>();
+    if (attr == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT value failed";
+      return RET_ERROR;
+    }
+    attr->useNesterov = GetValue<bool>(prim.GetAttr("use_nesterov"));
+
+    this->primitive_->value.value = attr.release();
+    if (this->primitive_->value.value == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT value failed";
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
+#else
+bool Adam::GetUseNesterov() const { return this->primitive_->value_as_Adam()->useNesterov(); }
+int Adam::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
+  MS_ASSERT(nullptr != primitive);
+  MS_ASSERT(nullptr != fbb);
+  auto attr = primitive->value_as_Adam();
+  if (attr == nullptr) {
+    MS_LOG(ERROR) << "value_as_Adam return nullptr";
+    return RET_ERROR;
+  }
+  auto val_offset = schema::CreateAdam(*fbb, attr->useNesterov());
+  auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Adam, val_offset.o);
+  fbb->Finish(prim_offset);
+  return RET_OK;
+}
+#endif
+
+int Adam::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
+  if (10 != inputs.size()) {
+    MS_LOG(ERROR) << "Adam should have 10 input tensors";
+    return RET_ERROR;
+  }
+
+  if (inputs[0]->ElementsNum() != inputs[1]->ElementsNum() || inputs[0]->ElementsNum() != inputs[2]->ElementsNum() ||
+      inputs[0]->ElementsNum() != inputs[9]->ElementsNum() || inputs[3]->ElementsNum() != 1 ||
+      inputs[4]->ElementsNum() != 1 || inputs[5]->ElementsNum() != 1 || inputs[6]->ElementsNum() != 1 ||
+      inputs[7]->ElementsNum() != 1 || inputs[8]->ElementsNum() != 1) {
+    MS_LOG(ERROR) << "error input data size!";
+    return RET_ERROR;
+  }
+  if (!outputs.empty()) {
+    auto *out = outputs.front();
+    MS_ASSERT(out != nullptr);
+    out->set_data_type(inputs[0]->data_type());
+    out->SetFormat(inputs[0]->GetFormat());
+    out->set_shape({1});
+  }
+
+  return RET_OK;
+}
+}  // namespace lite
+}  // namespace mindspore
diff --git a/mindspore/lite/src/ops/adam.h b/mindspore/lite/src/ops/adam.h
new file mode 100644
index 0000000000..6ffee993ad
--- /dev/null
+++ b/mindspore/lite/src/ops/adam.h
@@ -0,0 +1,47 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License. 
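The ten inputs validated by Adam::InferShape above line up with the usual ApplyAdam convention, presumably var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad: indices 0-2 and 9 must share the weight's element count while 3-8 are scalars. The kernel itself is not part of this diff; a minimal sketch of the update those tensors drive, assuming the standard non-Nesterov Adam rule:

#include <math.h>
#include <stddef.h>

/* Illustrative only: TF-style ApplyAdam over flat arrays; lr_t folds in the
 * bias corrections carried by beta1_power and beta2_power. */
void AdamStep(float *var, float *m, float *v, float beta1_power, float beta2_power,
              float lr, float beta1, float beta2, float eps, const float *grad, size_t n) {
  const float lr_t = lr * sqrtf(1.0f - beta2_power) / (1.0f - beta1_power);
  for (size_t i = 0; i < n; ++i) {
    m[i] += (grad[i] - m[i]) * (1.0f - beta1);
    v[i] += (grad[i] * grad[i] - v[i]) * (1.0f - beta2);
    var[i] -= lr_t * m[i] / (sqrtf(v[i]) + eps);
  }
}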
+ */ + +#ifndef MINDSPORE_LITE_SRC_OPS_ADAM_H_ +#define MINDSPORE_LITE_SRC_OPS_ADAM_H_ + +#include +#include +#include +#include + +#include "src/ops/primitive_c.h" + +namespace mindspore { +namespace lite { +class Adam : public PrimitiveC { + public: +#ifdef PRIMITIVE_WRITEABLE + MS_DECLARE_PARENT(Adam, PrimitiveC); + Adam() = default; + explicit Adam(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; +#else + Adam() = default; + + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; +#endif + int InferShape(std::vector inputs_, std::vector outputs_) override; + bool GetUseNesterov() const; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_SRC_OPS_ADAM_H_ diff --git a/mindspore/lite/src/ops/addn.cc b/mindspore/lite/src/ops/addn.cc index 1d9bfda341..65564bb7aa 100644 --- a/mindspore/lite/src/ops/addn.cc +++ b/mindspore/lite/src/ops/addn.cc @@ -82,8 +82,11 @@ int AddN::InferShape(std::vector inputs, std::vector outputs if (!GetInferFlag()) { return RET_OK; } + output->set_shape(input->shape()); + + // make sure all elements have the same size or 1 (broadcasting) in all dimensions for (size_t i = 1; i < inputs.size(); ++i) { - if (inputs.at(i)->shape() != inputs.at(0)->shape()) { + if (inputs.at(i)->shape().size() != inputs.at(0)->shape().size()) { MS_LOG(ERROR) << "AddN inputs shape is not equal!"; return RET_INPUT_TENSOR_ERROR; } @@ -93,7 +96,22 @@ int AddN::InferShape(std::vector inputs, std::vector outputs } } - output->set_shape(input->shape()); + for (size_t d = 0; d < input->shape().size(); ++d) { + int max_dim = input->shape().at(d); + for (size_t i = 1; i < inputs.size(); ++i) { + if (inputs.at(i)->shape().at(d) > max_dim) { + max_dim = inputs.at(i)->shape().at(d); + } + } + for (size_t i = 1; i < inputs.size(); ++i) { + if ((inputs.at(0)->shape().at(d) != max_dim) && (inputs.at(0)->shape().at(d) != 1)) { + MS_LOG(ERROR) << "AddN inputs shape is not equal!"; + return RET_INPUT_TENSOR_ERROR; + } + } + output->shape()[d] = max_dim; // set the biggest dimension in the output tensor + } + return RET_OK; } } // namespace lite diff --git a/mindspore/lite/src/ops/apply_momentum.cc b/mindspore/lite/src/ops/apply_momentum.cc index 14918d9699..fc625ad85c 100644 --- a/mindspore/lite/src/ops/apply_momentum.cc +++ b/mindspore/lite/src/ops/apply_momentum.cc @@ -18,7 +18,6 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE float ApplyMomentum::GetGradientScale() const { return this->primitive_->value.AsApplyMomentum()->gradientScale; } -bool ApplyMomentum::GetUseLocking() const { return this->primitive_->value.AsApplyMomentum()->useLocking; } bool ApplyMomentum::GetUseNesterov() const { return this->primitive_->value.AsApplyMomentum()->useNesterov; } int ApplyMomentum::UnPackAttr(const Primitive &prim, const std::vector &inputs) { @@ -41,7 +40,6 @@ int ApplyMomentum::UnPackAttr(const Primitive &prim, const std::vectorgradientScale = GetValue(prim.GetAttr("gradient_scale")); - attr->useLocking = GetValue(prim.GetAttr("use_locking")); attr->useNesterov = GetValue(prim.GetAttr("use_nesterov")); this->primitive_->value.value = attr.release(); @@ -54,7 +52,6 @@ int ApplyMomentum::UnPackAttr(const Primitive &prim, const std::vectorprimitive_->value_as_ApplyMomentum()->gradientScale(); } -bool ApplyMomentum::GetUseLocking() const { return this->primitive_->value_as_ApplyMomentum()->useLocking(); } bool 
ApplyMomentum::GetUseNesterov() const { return this->primitive_->value_as_ApplyMomentum()->useNesterov(); } int ApplyMomentum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { @@ -65,7 +62,7 @@ int ApplyMomentum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatb MS_LOG(ERROR) << "value_as_ApplyMomentum return nullptr"; return RET_ERROR; } - auto val_offset = schema::CreateApplyMomentum(*fbb, attr->gradientScale(), attr->useLocking(), attr->useNesterov()); + auto val_offset = schema::CreateApplyMomentum(*fbb, attr->gradientScale(), attr->useNesterov()); auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ApplyMomentum, val_offset.o); fbb->Finish(prim_offset); return RET_OK; diff --git a/mindspore/lite/src/ops/apply_momentum.h b/mindspore/lite/src/ops/apply_momentum.h index 4f3d96aef3..12cd392525 100644 --- a/mindspore/lite/src/ops/apply_momentum.h +++ b/mindspore/lite/src/ops/apply_momentum.h @@ -40,7 +40,6 @@ class ApplyMomentum : public PrimitiveC { #endif int InferShape(std::vector inputs_, std::vector outputs_) override; float GetGradientScale() const; - bool GetUseLocking() const; bool GetUseNesterov() const; }; } // namespace lite diff --git a/mindspore/lite/src/ops/assign.cc b/mindspore/lite/src/ops/assign.cc new file mode 100644 index 0000000000..ca59337651 --- /dev/null +++ b/mindspore/lite/src/ops/assign.cc @@ -0,0 +1,82 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
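Dropping useLocking from ApplyMomentum removes only the TF-style locking flag carried in the schema; the arithmetic is untouched. For reference, a sketch of the momentum update the remaining attributes describe (illustrative only, the CPU kernel is outside this diff):

#include <stddef.h>

/* SGD with momentum; the useNesterov variant kept by ApplyMomentum. */
void MomentumStep(float *var, float *accum, const float *grad, float lr,
                  float momentum, bool use_nesterov, size_t n) {
  for (size_t i = 0; i < n; ++i) {
    accum[i] = accum[i] * momentum + grad[i];
    var[i] -= use_nesterov ? lr * (grad[i] + momentum * accum[i]) : lr * accum[i];
  }
}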
+ */
+
+#include "src/ops/assign.h"
+#include 
+
+namespace mindspore {
+namespace lite {
+#ifdef PRIMITIVE_WRITEABLE
+int Assign::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
+  if (this->primitive_ == nullptr) {
+    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
+    if (this->primitive_ == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT failed";
+      return RET_ERROR;
+    }
+    this->primitive_->value.type = schema::PrimitiveType_Assign;
+  }
+  if (this->primitive_->value.type != schema::PrimitiveType_Assign) {
+    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
+    return RET_ERROR;
+  }
+  if (this->primitive_->value.value == nullptr) {
+    this->primitive_->value.value = new (std::nothrow) schema::AssignT();
+    if (this->primitive_->value.value == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT value failed";
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
+#else
+int Assign::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
+  MS_ASSERT(nullptr != primitive);
+  MS_ASSERT(nullptr != fbb);
+  auto attr = primitive->value_as_Assign();
+  if (attr == nullptr) {
+    MS_LOG(ERROR) << "value_as_Assign return nullptr";
+    return RET_ERROR;
+  }
+  auto val_offset = schema::CreateAssign(*fbb);
+  auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Assign, val_offset.o);
+  fbb->Finish(prim_offset);
+  return RET_OK;
+}
+#endif
+
+int Assign::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
+  if (2 != inputs.size()) {
+    MS_LOG(ERROR) << "Assign should have 2 input tensors";
+    return RET_ERROR;
+  }
+
+  if (inputs[0]->ElementsNum() != inputs[1]->ElementsNum()) {
+    MS_LOG(ERROR) << "error input data size!";
+    return RET_ERROR;
+  }
+
+  if (!outputs.empty()) {
+    auto *out = outputs.front();
+    MS_ASSERT(out != nullptr);
+    out->set_data_type(inputs[0]->data_type());
+    out->SetFormat(inputs[0]->GetFormat());
+    out->set_shape({1});
+  }
+  return RET_OK;
+}
+}  // namespace lite
+}  // namespace mindspore
diff --git a/mindspore/lite/src/ops/assign.h b/mindspore/lite/src/ops/assign.h
new file mode 100644
index 0000000000..0316031b9a
--- /dev/null
+++ b/mindspore/lite/src/ops/assign.h
@@ -0,0 +1,43 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License. 
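At runtime Assign reduces to copying the value tensor over the variable, which is why InferShape only checks that the two element counts match (sketch below; names are illustrative and the kernel is not part of this diff). As with Adam, the single-element output emitted by InferShape appears to be a status placeholder rather than real data.

#include <stddef.h>
#include <string.h>

/* Illustrative Assign: overwrite `var` with `value`; InferShape has already
 * verified both tensors hold the same number of elements. */
void AssignStep(float *var, const float *value, size_t n) {
  memcpy(var, value, n * sizeof(float));
}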
+ */ + +#ifndef MINDSPORE_LITE_SRC_OPS_ASSIGN_H_ +#define MINDSPORE_LITE_SRC_OPS_ASSIGN_H_ + +#include +#include +#include +#include "src/ops/primitive_c.h" + +namespace mindspore { +namespace lite { +class Assign : public PrimitiveC { + public: +#ifdef PRIMITIVE_WRITEABLE + MS_DECLARE_PARENT(Assign, PrimitiveC); + Assign() = default; + explicit Assign(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; +#else + Assign() = default; + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; +#endif + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_SRC_OPS_ASSIGN_H_ diff --git a/mindspore/lite/src/ops/bn_grad.cc b/mindspore/lite/src/ops/bn_grad.cc index 8b6ebb321b..d623117327 100644 --- a/mindspore/lite/src/ops/bn_grad.cc +++ b/mindspore/lite/src/ops/bn_grad.cc @@ -45,8 +45,8 @@ int BNGrad::UnPackAttr(const Primitive &prim, const std::vector &inp } attr->momentum = GetValue(prim.GetAttr("momentum")); // FusedBatchNormGrad dows not get this attribute - if (prim.GetAttr("eps") != nullptr) { - attr->eps = GetValue(prim.GetAttr("eps")); + if (prim.GetAttr("epsilon") != nullptr) { + attr->eps = GetValue(prim.GetAttr("epsilon")); } this->primitive_->value.value = attr; if (this->primitive_->value.value == nullptr) { diff --git a/mindspore/lite/src/ops/conv2d_grad_input.cc b/mindspore/lite/src/ops/conv2d_grad_input.cc index e68d9c29ac..a49f62a3e0 100644 --- a/mindspore/lite/src/ops/conv2d_grad_input.cc +++ b/mindspore/lite/src/ops/conv2d_grad_input.cc @@ -15,7 +15,7 @@ */ #include "src/ops/conv2d_grad_input.h" - +#include "src/ops/group_conv2d_grad_input.h" namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -86,6 +86,9 @@ int Conv2DGradInput::UnPackAttr(const Primitive &prim, const std::vectorgroup = GetValue(prim.GetAttr("group")); + if (attr->group > 1) { + this->primitive_->value.type = schema::PrimitiveType_GroupConv2DGradInput; + } auto format = GetValue(prim.GetAttr("data_format")); if (format == "NCHW") { attr->format = schema::Format_NCHW; diff --git a/mindspore/lite/src/ops/exp.cc b/mindspore/lite/src/ops/exp.cc index 4870da0cda..0a1f88ffc1 100644 --- a/mindspore/lite/src/ops/exp.cc +++ b/mindspore/lite/src/ops/exp.cc @@ -26,6 +26,30 @@ void Exp::SetShift(float shift) { this->primitive_->value.AsExp()->shift = shift float Exp::GetBase() const { return this->primitive_->value.AsExp()->base; } float Exp::GetScale() const { return this->primitive_->value.AsExp()->scale; } float Exp::GetShift() const { return this->primitive_->value.AsExp()->shift; } + +int Exp::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitiveT failed"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_Exp; + } + if (this->primitive_->value.type != schema::PrimitiveType_Exp) { + MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + this->primitive_->value.value = new (std::nothrow) schema::ExpT(); + if (this->primitive_->value.value == nullptr) { + MS_LOG(ERROR) << "new primitiveT value failed"; + return RET_ERROR; + } + } + return RET_OK; +} + #else int 
Exp::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { diff --git a/mindspore/lite/src/ops/exp.h b/mindspore/lite/src/ops/exp.h index e0159a1991..b72a5a1c9b 100644 --- a/mindspore/lite/src/ops/exp.h +++ b/mindspore/lite/src/ops/exp.h @@ -33,6 +33,7 @@ class Exp : public PrimitiveC { void SetBase(float base); void SetShift(float shift); void SetScale(float scale); + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; #else Exp() = default; diff --git a/mindspore/lite/src/ops/group_conv2d_grad_input.cc b/mindspore/lite/src/ops/group_conv2d_grad_input.cc new file mode 100644 index 0000000000..82c1606b49 --- /dev/null +++ b/mindspore/lite/src/ops/group_conv2d_grad_input.cc @@ -0,0 +1,172 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/group_conv2d_grad_input.h" + +namespace mindspore { +namespace lite { +#ifdef PRIMITIVE_WRITEABLE +int GroupConv2DGradInput::GetFormat() const { return this->primitive_->value.AsGroupConv2DGradInput()->format; } +int GroupConv2DGradInput::GetGroup() const { return this->primitive_->value.AsGroupConv2DGradInput()->group; } +int GroupConv2DGradInput::GetChannelIn() const { return this->primitive_->value.AsGroupConv2DGradInput()->channelIn; } +int GroupConv2DGradInput::GetChannelOut() const { return this->primitive_->value.AsGroupConv2DGradInput()->channelOut; } +int GroupConv2DGradInput::GetKernelW() const { return this->primitive_->value.AsGroupConv2DGradInput()->kernelW; } +int GroupConv2DGradInput::GetKernelH() const { return this->primitive_->value.AsGroupConv2DGradInput()->kernelH; } +int GroupConv2DGradInput::GetStrideW() const { return this->primitive_->value.AsGroupConv2DGradInput()->strideW; } +int GroupConv2DGradInput::GetStrideH() const { return this->primitive_->value.AsGroupConv2DGradInput()->strideH; } +int GroupConv2DGradInput::GetPadMode() const { return this->primitive_->value.AsGroupConv2DGradInput()->padMode; } +int GroupConv2DGradInput::GetPadUp() const { return this->primitive_->value.AsGroupConv2DGradInput()->padUp; } +int GroupConv2DGradInput::GetPadDown() const { return this->primitive_->value.AsGroupConv2DGradInput()->padDown; } +int GroupConv2DGradInput::GetPadLeft() const { return this->primitive_->value.AsGroupConv2DGradInput()->padLeft; } +int GroupConv2DGradInput::GetPadRight() const { return this->primitive_->value.AsGroupConv2DGradInput()->padRight; } +int GroupConv2DGradInput::GetDilateW() const { return this->primitive_->value.AsGroupConv2DGradInput()->dilateW; } +int GroupConv2DGradInput::GetDilateH() const { return this->primitive_->value.AsGroupConv2DGradInput()->dilateH; } +bool GroupConv2DGradInput::GetHasBias() const { return this->primitive_->value.AsGroupConv2DGradInput()->hasBias; } +int GroupConv2DGradInput::GetActivationType() const { + return this->primitive_->value.AsGroupConv2DGradInput()->activationType; +} + +void GroupConv2DGradInput::SetFormat(int 
format) { + this->primitive_->value.AsGroupConv2DGradInput()->format = (schema::Format)format; +} +void GroupConv2DGradInput::SetGroup(int group) { this->primitive_->value.AsGroupConv2DGradInput()->group = group; } +void GroupConv2DGradInput::SetChannelIn(int channel_in) { + this->primitive_->value.AsGroupConv2DGradInput()->channelIn = channel_in; +} +void GroupConv2DGradInput::SetChannelOut(int channel_out) { + this->primitive_->value.AsGroupConv2DGradInput()->channelOut = channel_out; +} +void GroupConv2DGradInput::SetKernelW(int kernel_w) { + this->primitive_->value.AsGroupConv2DGradInput()->kernelW = kernel_w; +} +void GroupConv2DGradInput::SetKernelH(int kernel_h) { + this->primitive_->value.AsGroupConv2DGradInput()->kernelH = kernel_h; +} +void GroupConv2DGradInput::SetStrideW(int stride_w) { + this->primitive_->value.AsGroupConv2DGradInput()->strideW = stride_w; +} +void GroupConv2DGradInput::SetStrideH(int stride_h) { + this->primitive_->value.AsGroupConv2DGradInput()->strideH = stride_h; +} +void GroupConv2DGradInput::SetPadMode(int pad_mode) { + this->primitive_->value.AsGroupConv2DGradInput()->padMode = (schema::PadMode)pad_mode; +} +void GroupConv2DGradInput::SetPadUp(int pad_up) { this->primitive_->value.AsGroupConv2DGradInput()->padUp = pad_up; } +void GroupConv2DGradInput::SetPadDown(int pad_down) { + this->primitive_->value.AsGroupConv2DGradInput()->padDown = pad_down; +} +void GroupConv2DGradInput::SetPadLeft(int pad_left) { + this->primitive_->value.AsGroupConv2DGradInput()->padLeft = pad_left; +} +void GroupConv2DGradInput::SetPadRight(int pad_right) { + this->primitive_->value.AsGroupConv2DGradInput()->padRight = pad_right; +} +void GroupConv2DGradInput::SetDilateW(int dilate_w) { + this->primitive_->value.AsGroupConv2DGradInput()->dilateW = dilate_w; +} +void GroupConv2DGradInput::SetDilateH(int dilate_h) { + this->primitive_->value.AsGroupConv2DGradInput()->dilateH = dilate_h; +} +void GroupConv2DGradInput::SetHasBias(bool has_bias) { + this->primitive_->value.AsGroupConv2DGradInput()->hasBias = has_bias; +} +void GroupConv2DGradInput::SetActivationType(int activation_type) { + this->primitive_->value.AsGroupConv2DGradInput()->activationType = (schema::ActivationType)activation_type; +} +#else +int GroupConv2DGradInput::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_GroupConv2DGradInput(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_GroupConv2DGradInput return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateGroupConv2DGradInput( + *fbb, attr->format(), attr->group(), attr->channelIn(), attr->channelOut(), attr->kernelW(), attr->kernelH(), + attr->strideW(), attr->strideH(), attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(), + attr->padRight(), attr->dilateW(), attr->dilateH(), attr->hasBias(), attr->activationType()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_GroupConv2DGradInput, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} +int GroupConv2DGradInput::GetFormat() const { return this->primitive_->value_as_GroupConv2DGradInput()->format(); } +int GroupConv2DGradInput::GetGroup() const { return this->primitive_->value_as_GroupConv2DGradInput()->group(); } +int GroupConv2DGradInput::GetChannelIn() const { + return this->primitive_->value_as_GroupConv2DGradInput()->channelIn(); +} +int GroupConv2DGradInput::GetChannelOut() const { + 
return this->primitive_->value_as_GroupConv2DGradInput()->channelOut();
+}
+int GroupConv2DGradInput::GetKernelW() const { return this->primitive_->value_as_GroupConv2DGradInput()->kernelW(); }
+int GroupConv2DGradInput::GetKernelH() const { return this->primitive_->value_as_GroupConv2DGradInput()->kernelH(); }
+int GroupConv2DGradInput::GetStrideW() const { return this->primitive_->value_as_GroupConv2DGradInput()->strideW(); }
+int GroupConv2DGradInput::GetStrideH() const { return this->primitive_->value_as_GroupConv2DGradInput()->strideH(); }
+int GroupConv2DGradInput::GetPadMode() const { return this->primitive_->value_as_GroupConv2DGradInput()->padMode(); }
+int GroupConv2DGradInput::GetPadUp() const { return this->primitive_->value_as_GroupConv2DGradInput()->padUp(); }
+int GroupConv2DGradInput::GetPadDown() const { return this->primitive_->value_as_GroupConv2DGradInput()->padDown(); }
+int GroupConv2DGradInput::GetPadLeft() const { return this->primitive_->value_as_GroupConv2DGradInput()->padLeft(); }
+int GroupConv2DGradInput::GetPadRight() const { return this->primitive_->value_as_GroupConv2DGradInput()->padRight(); }
+int GroupConv2DGradInput::GetDilateW() const { return this->primitive_->value_as_GroupConv2DGradInput()->dilateW(); }
+int GroupConv2DGradInput::GetDilateH() const { return this->primitive_->value_as_GroupConv2DGradInput()->dilateH(); }
+bool GroupConv2DGradInput::GetHasBias() const { return this->primitive_->value_as_GroupConv2DGradInput()->hasBias(); }
+int GroupConv2DGradInput::GetActivationType() const {
+  return this->primitive_->value_as_GroupConv2DGradInput()->activationType();
+}
+
+#endif
+
+int GroupConv2DGradInput::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
+  if (3 != inputs.size()) {
+    MS_LOG(ERROR) << "Conv2d Grad Input should have 3 inputs";
+    return RET_ERROR;
+  }
+  if (1 != outputs.size()) {
+    MS_LOG(ERROR) << "Conv2d Grad Input should have one output";
+    return RET_ERROR;
+  }
+
+  auto *in0 = inputs.at(0);
+  auto *in = inputs.at(2);
+  MS_ASSERT(in != nullptr);
+
+  std::vector<int> output_shape;
+  int *out_shape = reinterpret_cast<int *>(in->MutableData());
+  int new_size = in->ElementsNum();
+  if (in0->GetFormat() == in->GetFormat()) {
+    for (int i = 0; i < new_size; i++) output_shape.push_back(out_shape[i]);
+  } else {
+    if ((in0->GetFormat() == schema::Format_NHWC) && (in->GetFormat() == schema::Format_NCHW)) {
+      output_shape.push_back(out_shape[0]);
+      output_shape.push_back(out_shape[2]);
+      output_shape.push_back(out_shape[3]);
+      output_shape.push_back(out_shape[1]);
+    } else {
+      MS_LOG(ERROR) << "Shape convert is not supported";
+      return RET_ERROR;
+    }
+  }
+
+  auto *out = outputs.at(0);
+  MS_ASSERT(out != nullptr);
+  out->set_shape(output_shape);
+  out->set_data_type(in0->data_type());
+  out->SetFormat(in0->GetFormat());
+
+  return RET_OK;
+}
+}  // namespace lite
+}  // namespace mindspore
diff --git a/mindspore/lite/src/ops/group_conv2d_grad_input.h b/mindspore/lite/src/ops/group_conv2d_grad_input.h
new file mode 100644
index 0000000000..2026a524ef
--- /dev/null
+++ b/mindspore/lite/src/ops/group_conv2d_grad_input.h
@@ -0,0 +1,79 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License. 
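GroupConv2DGradInput::InferShape above reads the target shape from input 2 and, when the two formats differ, permutes NCHW to NHWC by emitting indices 0, 2, 3, 1. A quick self-contained check of that permutation:

#include <stdio.h>

/* The permutation used above: out_shape arrives as NCHW {N, C, H, W} and is
 * emitted as NHWC {N, H, W, C}. */
int main(void) {
  const int nchw[4] = {8, 3, 224, 224};
  const int nhwc[4] = {nchw[0], nchw[2], nchw[3], nchw[1]};
  printf("NCHW {8, 3, 224, 224} -> NHWC {%d, %d, %d, %d}\n", nhwc[0], nhwc[1], nhwc[2], nhwc[3]);
  return 0;
}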
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_OPS_GROUP_CONV2D_GRAD_INPUT_H_ +#define MINDSPORE_LITE_SRC_OPS_GROUP_CONV2D_GRAD_INPUT_H_ + +#include +#include +#include +#include +#include +#include "src/ops/primitive_c.h" + +namespace mindspore { +namespace lite { +class GroupConv2DGradInput : public PrimitiveC { + public: +#ifdef PRIMITIVE_WRITEABLE + MS_DECLARE_PARENT(GroupConv2DGradInput, PrimitiveC); + GroupConv2DGradInput() = default; + explicit GroupConv2DGradInput(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} + void SetFormat(int format); + void SetGroup(int group); + void SetChannelIn(int channel_in); + void SetChannelOut(int channel_out); + void SetKernelW(int kernel_w); + void SetKernelH(int kernel_h); + void SetStrideW(int stride_w); + void SetStrideH(int stride_h); + void SetPadMode(int pad_mode); + void SetPadUp(int pad_up); + void SetPadDown(int pad_down); + void SetPadLeft(int pad_left); + void SetPadRight(int pad_right); + void SetDilateW(int dilate_w); + void SetDilateH(int dilate_h); + void SetHasBias(bool has_bias); + void SetActivationType(int activation_type); +#else + GroupConv2DGradInput() = default; + + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; +#endif + int InferShape(std::vector inputs_, std::vector outputs_) override; + int GetFormat() const; + int GetGroup() const; + int GetChannelIn() const; + int GetChannelOut() const; + int GetKernelW() const; + int GetKernelH() const; + int GetStrideW() const; + int GetStrideH() const; + int GetPadMode() const; + int GetPadUp() const; + int GetPadDown() const; + int GetPadLeft() const; + int GetPadRight() const; + int GetDilateW() const; + int GetDilateH() const; + bool GetHasBias() const; + int GetActivationType() const; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_SRC_OPS_GROUP_CONV2D_GRAD_INPUT_H_ diff --git a/mindspore/lite/src/ops/neg.cc b/mindspore/lite/src/ops/neg.cc index 2645927a7d..90fce187c4 100644 --- a/mindspore/lite/src/ops/neg.cc +++ b/mindspore/lite/src/ops/neg.cc @@ -18,7 +18,31 @@ namespace mindspore { namespace lite { -#ifndef PRIMITIVE_WRITEABLE +#ifdef PRIMITIVE_WRITEABLE +int Neg::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitiveT failed"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_Neg; + } + if (this->primitive_->value.type != schema::PrimitiveType_Neg) { + MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + this->primitive_->value.value = new (std::nothrow) schema::NegT(); + if (this->primitive_->value.value == nullptr) { + MS_LOG(ERROR) << "new primitiveT value failed"; + return RET_ERROR; + } + } + return RET_OK; +} + +#else int Neg::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { MS_ASSERT(primitive != nullptr); 
MS_ASSERT(fbb != nullptr); diff --git a/mindspore/lite/src/ops/neg.h b/mindspore/lite/src/ops/neg.h index 9e461ef10e..e68b5c0c12 100644 --- a/mindspore/lite/src/ops/neg.h +++ b/mindspore/lite/src/ops/neg.h @@ -31,6 +31,7 @@ class Neg : public ArithmeticSelf { MS_DECLARE_PARENT(Neg, ArithmeticSelf); Neg() = default; explicit Neg(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; #else Neg() = default; diff --git a/mindspore/lite/src/ops/one_hot.cc b/mindspore/lite/src/ops/one_hot.cc index 5b47480054..f5235b7c19 100644 --- a/mindspore/lite/src/ops/one_hot.cc +++ b/mindspore/lite/src/ops/one_hot.cc @@ -23,6 +23,37 @@ int OneHot::GetAxis() const { return this->primitive_->value.AsOneHot()->axis; } void OneHot::SetAxis(int axis) { this->primitive_->value.AsOneHot()->axis = axis; } +int OneHot::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitiveT failed"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_OneHot; + } + if (this->primitive_->value.type != schema::PrimitiveType_OneHot) { + MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + auto attr = new (std::nothrow) schema::OneHotT(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new primitiveT value failed"; + return RET_ERROR; + } + attr->axis = -1; + if (prim.GetAttr("axis") != nullptr) { + attr->axis = GetValue(prim.GetAttr("axis")); + } + this->primitive_->value.value = attr; + if (this->primitive_->value.value == nullptr) { + MS_LOG(ERROR) << "primitive value is nullptr"; + return RET_ERROR; + } + } + return RET_OK; +} #else int OneHot::GetAxis() const { return this->primitive_->value_as_OneHot()->axis(); } diff --git a/mindspore/lite/src/ops/one_hot.h b/mindspore/lite/src/ops/one_hot.h index 67d9d12bb1..b92dc49767 100644 --- a/mindspore/lite/src/ops/one_hot.h +++ b/mindspore/lite/src/ops/one_hot.h @@ -32,7 +32,7 @@ class OneHot : public PrimitiveC { OneHot() = default; explicit OneHot(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} void SetAxis(int axis); - + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; #else OneHot() = default; diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index 87662121f4..04dca9b8db 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -144,6 +144,7 @@ #include "src/ops/pooling_grad.h" #include "src/ops/conv2d_grad_filter.h" #include "src/ops/conv2d_grad_input.h" +#include "src/ops/group_conv2d_grad_input.h" #include "src/ops/power_grad.h" #include "src/ops/softmax_cross_entropy.h" #include "src/ops/bn_grad.h" @@ -152,6 +153,8 @@ #include "src/ops/flatten_grad.h" #include "src/ops/log_grad.h" #include "src/ops/sgd.h" +#include "src/ops/adam.h" +#include "src/ops/assign.h" #endif namespace mindspore { @@ -367,7 +370,7 @@ std::shared_ptr NewPrimitiveC(const Primitive &prim, const std::vect std::shared_ptr PrimitiveC::Create(const Primitive &prim, const std::vector &inputs, const schema::QuantType &quantType) { const auto &op_type = prim.name(); - if (op_type == "ReLU" || op_type == "ReLU6" || op_type == "Sigmoid") { + if (op_type == "ReLU" || op_type == "ReLU6" || op_type == "Sigmoid" || op_type == 
"HSwish" || op_type == "HSigmoid") { return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "AddN") { return NewPrimitiveC(prim, inputs, quantType); @@ -413,6 +416,10 @@ std::shared_ptr PrimitiveC::Create(const Primitive &prim, const std: return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "Reshape") { return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "Slice") { + return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "Squeeze") { + return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "TensorAdd") { return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "Transpose") { @@ -421,6 +428,10 @@ std::shared_ptr PrimitiveC::Create(const Primitive &prim, const std: return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "Log") { return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "Exp") { + return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "Neg") { + return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "DeConv2D") { return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "tuple_getitem") { @@ -435,6 +446,8 @@ std::shared_ptr PrimitiveC::Create(const Primitive &prim, const std: return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "Split") { return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "OneHot") { + return NewPrimitiveC(prim, inputs, quantType); #ifdef SUPPORT_TRAIN } else if (op_type == "SoftmaxCrossEntropyWithLogits") { @@ -445,7 +458,8 @@ std::shared_ptr PrimitiveC::Create(const Primitive &prim, const std: return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "Depend") { return NewPrimitiveC(prim, inputs, quantType); - } else if ((op_type == "ReluGrad" || op_type == "Relu6Grad" || op_type == "SigmoidGrad")) { + } else if ((op_type == "ReluGrad" || op_type == "ReLU6Grad" || op_type == "SigmoidGrad" || + op_type == "HSigmoidGrad" || op_type == "HSwishGrad")) { return NewPrimitiveC(prim, inputs, quantType); } else if ((op_type == "MaxPoolGrad") || (op_type == "MeanPoolGrad")) { return NewPrimitiveC(prim, inputs, quantType); @@ -465,6 +479,10 @@ std::shared_ptr PrimitiveC::Create(const Primitive &prim, const std: return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "SGD") { return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "Adam") { + return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "Assign") { + return NewPrimitiveC(prim, inputs, quantType); #else } else if (op_type == "Conv2DBackpropInput") { return NewPrimitiveC(prim, inputs, quantType); @@ -686,6 +704,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { return new Dropout(primitive); case schema::PrimitiveType_Neg: return new Neg(primitive); + case schema::PrimitiveType_RealDiv: + return new RealDiv(primitive); case schema::PrimitiveType_LshProjection: return new LshProjection(primitive); case schema::PrimitiveType_HashtableLookup: @@ -710,6 +730,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { return new Conv2DGradFilter(primitive); case schema::PrimitiveType_Conv2DGradInput: return new Conv2DGradInput(primitive); + case schema::PrimitiveType_GroupConv2DGradInput: + return new GroupConv2DGradInput(primitive); case schema::PrimitiveType_BiasGrad: return new BiasGrad(primitive); case schema::PrimitiveType_ApplyMomentum: @@ -738,8 +760,11 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT 
*primitive) { return new LogGrad(primitive); case schema::PrimitiveType_Sgd: return new Sgd(primitive); + case schema::PrimitiveType_Adam: + return new Adam(primitive); + case schema::PrimitiveType_Assign: + return new Assign(primitive); #endif - default: MS_LOG(ERROR) << "Unsupported primitive type in Create : " << schema::EnumNamePrimitiveType(op_type); break; @@ -958,6 +983,8 @@ PrimitiveC *PrimitiveC::Create(const schema::Primitive *primitive) { return NewPrimitiveC(primitive); case schema::PrimitiveType_Dropout: return NewPrimitiveC(primitive); + case schema::PrimitiveType_RealDiv: + return NewPrimitiveC(primitive); case schema::PrimitiveType_LshProjection: return NewPrimitiveC(primitive); case schema::PrimitiveType_HashtableLookup: @@ -982,6 +1009,8 @@ PrimitiveC *PrimitiveC::Create(const schema::Primitive *primitive) { return NewPrimitiveC(primitive); case schema::PrimitiveType_Conv2DGradInput: return NewPrimitiveC(primitive); + case schema::PrimitiveType_GroupConv2DGradInput: + return NewPrimitiveC(primitive); case schema::PrimitiveType_BiasGrad: return NewPrimitiveC(primitive); case schema::PrimitiveType_ApplyMomentum: @@ -1004,6 +1033,10 @@ PrimitiveC *PrimitiveC::Create(const schema::Primitive *primitive) { return NewPrimitiveC(primitive); case schema::PrimitiveType_Sgd: return NewPrimitiveC(primitive); + case schema::PrimitiveType_Adam: + return NewPrimitiveC(primitive); + case schema::PrimitiveType_Assign: + return NewPrimitiveC(primitive); #endif default: MS_LOG(ERROR) << "Unsupported primitive type in Create : " << schema::EnumNamePrimitiveType(op_type); diff --git a/mindspore/lite/src/ops/primitive_c.h b/mindspore/lite/src/ops/primitive_c.h index f73442ba67..13532a91ef 100644 --- a/mindspore/lite/src/ops/primitive_c.h +++ b/mindspore/lite/src/ops/primitive_c.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H_ -#define MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H_ +#ifndef MINDSPORE_LITE_SRC_OPS_PRIMITIVE_C_H_ +#define MINDSPORE_LITE_SRC_OPS_PRIMITIVE_C_H_ #include #include #include @@ -48,7 +48,9 @@ constexpr int kAnfPopulaterTwo = 2; constexpr int kAnfPopulaterThree = 3; static std::map kActivationTypeMap{{"ReLU", schema::ActivationType_RELU}, {"ReLU6", schema::ActivationType_RELU6}, - {"Sigmoid", schema::ActivationType_SIGMOID}}; + {"Sigmoid", schema::ActivationType_SIGMOID}, + {"HSwish", schema::ActivationType_HSWISH}, + {"HSigmoid", schema::ActivationType_HSIGMOID}}; class PrimitiveC : public mindspore::Primitive { public: // Argument primitive is deliverd into PrimitiveC and will be deleted in ~PrimitiveC(). 
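For reference, the extended kActivationTypeMap above gives converters a single place to resolve an activation name to its schema enum; in miniature (the enum values below are placeholders, not the real schema::ActivationType numbers):

#include <cstdio>
#include <map>
#include <string>

// Stand-in values only, mirroring the shape of kActivationTypeMap.
enum ActivationType { RELU = 1, SIGMOID = 2, RELU6 = 3, HSWISH = 14, HSIGMOID = 15 };

static const std::map<std::string, ActivationType> kMap = {
    {"ReLU", RELU}, {"ReLU6", RELU6}, {"Sigmoid", SIGMOID}, {"HSwish", HSWISH}, {"HSigmoid", HSIGMOID}};

int main() {
  auto it = kMap.find("HSigmoid");
  printf("HSigmoid -> %d\n", it == kMap.end() ? -1 : (int)it->second);
  return 0;
}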
@@ -213,4 +215,4 @@ class PrimitiveC { #endif } // namespace lite } // namespace mindspore -#endif // MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H_ +#endif // MINDSPORE_LITE_SRC_OPS_PRIMITIVE_C_H_ diff --git a/mindspore/lite/src/ops/real_div.cc b/mindspore/lite/src/ops/real_div.cc index 792996d573..0de4bc3427 100644 --- a/mindspore/lite/src/ops/real_div.cc +++ b/mindspore/lite/src/ops/real_div.cc @@ -44,7 +44,14 @@ int RealDiv::UnPackAttr(const Primitive &prim, const std::vector &in } #else - +int RealDiv::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto val_offset = schema::CreateRank(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_RealDiv, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} #endif } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/real_div.h b/mindspore/lite/src/ops/real_div.h index 47a67753fb..d24647633b 100644 --- a/mindspore/lite/src/ops/real_div.h +++ b/mindspore/lite/src/ops/real_div.h @@ -36,10 +36,7 @@ class RealDiv : public Arithmetic { #else RealDiv() = default; - int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override { - return RET_ERROR; - } - + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif }; } // namespace lite diff --git a/mindspore/lite/src/ops/slice.cc b/mindspore/lite/src/ops/slice.cc index d42faae5b7..6441504a2b 100644 --- a/mindspore/lite/src/ops/slice.cc +++ b/mindspore/lite/src/ops/slice.cc @@ -35,6 +35,66 @@ void Slice::SetFormat(int format) { this->primitive_->value.AsSlice()->format = void Slice::SetBegin(const std::vector &begin) { this->primitive_->value.AsSlice()->begin = begin; } void Slice::SetSize(const std::vector &size) { this->primitive_->value.AsSlice()->size = size; } +int Slice::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitiveT failed"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_Slice; + } + if (this->primitive_->value.type != schema::PrimitiveType_Slice) { + MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + auto attr = new (std::nothrow) schema::SliceT(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new primitiveT value failed"; + return RET_ERROR; + } + if (inputs.size() >= kAnfPopulaterThree) { + auto beginNode = inputs[kAnfPopulaterOne]; + MS_ASSERT(beginNode != nullptr); + if (beginNode->isa()) { + auto valueNode = beginNode->cast(); + MS_ASSERT(valueNode != nullptr); + auto value = valueNode->value(); + MS_ASSERT(value != nullptr); + if (value->isa()) { + auto valTuplPtr = dyn_cast(value); + MS_ASSERT(valTuplPtr != nullptr); + for (size_t i = 0; i < valTuplPtr->size(); i++) { + auto elem = dyn_cast((*valTuplPtr)[i]); + MS_ASSERT(elem != nullptr); + attr->begin.emplace_back(elem->value()); + } + } + } + auto sizeNode = inputs[kAnfPopulaterTwo]; + MS_ASSERT(sizeNode != nullptr); + if (sizeNode->isa()) { + auto valueNode = sizeNode->cast(); + MS_ASSERT(valueNode != nullptr); + auto value = valueNode->value(); + MS_ASSERT(value != nullptr); + if (value->isa()) { + auto valTuplPtr = dyn_cast(value); + MS_ASSERT(valTuplPtr != 
nullptr); + for (size_t i = 0; i < valTuplPtr->size(); i++) { + auto elem = dyn_cast((*valTuplPtr)[i]); + MS_ASSERT(elem != nullptr); + attr->size.emplace_back(elem->value()); + } + } + } + } + this->primitive_->value.value = attr; + } + return RET_OK; +} + #else int Slice::GetFormat() const { return this->primitive_->value_as_Slice()->format(); } @@ -46,10 +106,12 @@ std::vector Slice::GetSize() const { auto fb_vector = this->primitive_->value_as_Slice()->size(); return std::vector(fb_vector->begin(), fb_vector->end()); } + std::vector Slice::GetAxes() const { auto fb_vector = this->primitive_->value_as_Slice()->axes(); return std::vector(fb_vector->begin(), fb_vector->end()); } + int Slice::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { MS_ASSERT(nullptr != primitive); MS_ASSERT(nullptr != fbb); @@ -90,7 +152,7 @@ std::vector Slice::GetPostProcessBegin() const { return this->begin; } std::vector Slice::GetPostProcessSize() const { return this->size; } int Slice::InferShape(std::vector inputs, std::vector outputs) { MS_ASSERT(this->primitive_ != nullptr); - if (inputs.size() != kSliceInputNum || outputs.size() != kSliceOutputNum) { + if (inputs.size() < kSliceInputNum || outputs.size() != kSliceOutputNum) { MS_LOG(ERROR) << "input size:" << inputs.size() << ",output size:" << outputs.size(); return RET_PARAM_INVALID; } diff --git a/mindspore/lite/src/ops/slice.h b/mindspore/lite/src/ops/slice.h index 5ad9877b82..eb520be8db 100644 --- a/mindspore/lite/src/ops/slice.h +++ b/mindspore/lite/src/ops/slice.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef LITE_MINDSPORE_LITE_C_OPS_SLICE_H_ -#define LITE_MINDSPORE_LITE_C_OPS_SLICE_H_ +#ifndef MINDSPORE_LITE_SRC_OPS_SLICE_H_ +#define MINDSPORE_LITE_SRC_OPS_SLICE_H_ #include #include @@ -35,6 +35,7 @@ class Slice : public PrimitiveC { void SetFormat(int format); void SetBegin(const std::vector &begin); void SetSize(const std::vector &size); + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; #else Slice() = default; @@ -56,4 +57,4 @@ class Slice : public PrimitiveC { }; } // namespace lite } // namespace mindspore -#endif // LITE_MINDSPORE_LITE_C_OPS_SLICE_H_ +#endif // MINDSPORE_LITE_SRC_OPS_SLICE_H_ diff --git a/mindspore/lite/src/ops/squeeze.cc b/mindspore/lite/src/ops/squeeze.cc index 37b044a005..522e00206d 100644 --- a/mindspore/lite/src/ops/squeeze.cc +++ b/mindspore/lite/src/ops/squeeze.cc @@ -23,6 +23,35 @@ std::vector Squeeze::GetAxis() const { return this->primitive_->value.AsSqu void Squeeze::SetAxis(const std::vector &axis) { this->primitive_->value.AsSqueeze()->axis = axis; } +int Squeeze::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitiveT failed"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_Squeeze; + } + if (this->primitive_->value.type != schema::PrimitiveType_Squeeze) { + MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + auto attr = new (std::nothrow) schema::SqueezeT(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new primitiveT value failed"; + return RET_ERROR; + } + attr->axis = GetValue>(prim.GetAttr("axis")); + this->primitive_->value.value = attr; + if (this->primitive_->value.value == nullptr) { + MS_LOG(ERROR) << 
"primitive value is nullptr"; + return RET_ERROR; + } + } + return RET_OK; +} + #else std::vector Squeeze::GetAxis() const { diff --git a/mindspore/lite/src/ops/squeeze.h b/mindspore/lite/src/ops/squeeze.h index 31ceb5ce6f..6929e5035d 100644 --- a/mindspore/lite/src/ops/squeeze.h +++ b/mindspore/lite/src/ops/squeeze.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef LITE_MINDSPORE_LITE_C_OPS_SQUEEZE_H_ -#define LITE_MINDSPORE_LITE_C_OPS_SQUEEZE_H_ +#ifndef MINDSPORE_LITE_SRC_OPS_SQUEEZE_H_ +#define MINDSPORE_LITE_SRC_OPS_SQUEEZE_H_ #include #include @@ -33,6 +33,7 @@ class Squeeze : public PrimitiveC { Squeeze() = default; explicit Squeeze(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} void SetAxis(const std::vector &axis); + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; #else Squeeze() = default; @@ -45,4 +46,4 @@ class Squeeze : public PrimitiveC { } // namespace lite } // namespace mindspore -#endif // LITE_MINDSPORE_LITE_C_OPS_SQUEEZE_H_ +#endif // MINDSPORE_LITE_SRC_OPS_SQUEEZE_H_ diff --git a/mindspore/lite/src/ops/tile.cc b/mindspore/lite/src/ops/tile.cc index af3b4af4f2..4779384d6f 100644 --- a/mindspore/lite/src/ops/tile.cc +++ b/mindspore/lite/src/ops/tile.cc @@ -24,7 +24,7 @@ std::vector Tile::GetMultiples() const { return this->primitive_->value.AsT void Tile::SetMultiples(const std::vector &multiples) { this->primitive_->value.AsTile()->multiples = multiples; } -std::vector Tile::GetDims() const { return this->primitive_->value.AsTile()->multiples; } +std::vector Tile::GetDims() const { return this->primitive_->value.AsTile()->dims; } void Tile::SetDims(const std::vector &dims) { this->primitive_->value.AsTile()->dims = dims; } @@ -42,11 +42,32 @@ int Tile::UnPackAttr(const Primitive &prim, const std::vector &input return RET_ERROR; } if (this->primitive_->value.value == nullptr) { - this->primitive_->value.value = new (std::nothrow) schema::TileT(); - if (this->primitive_->value.value == nullptr) { + auto attr = new (std::nothrow) schema::TileT(); + + if (attr == nullptr) { MS_LOG(ERROR) << "new primitiveT value failed"; return RET_ERROR; } + if (inputs.size() == kAnfPopulaterTwo) { + auto inputNode = inputs[kAnfPopulaterOne]; + MS_ASSERT(inputNode != nullptr); + if (inputNode->isa()) { + auto valueNode = inputNode->cast(); + MS_ASSERT(valueNode != nullptr); + auto value = valueNode->value(); + MS_ASSERT(value != nullptr); + if (value->isa()) { + auto valTuplPtr = dyn_cast(value); + MS_ASSERT(valTuplPtr != nullptr); + for (size_t i = 0; i < valTuplPtr->size(); i++) { + auto elem = dyn_cast((*valTuplPtr)[i]); + MS_ASSERT(elem != nullptr); + attr->multiples.emplace_back(elem->value()); + } + } + } + } + this->primitive_->value.value = attr; } return RET_OK; } @@ -103,15 +124,19 @@ int Tile::InferShape(std::vector inputs_, std::vector output MS_ASSERT(tile_prim != nullptr); std::vector out_shape; - std::vector multiples; - for (size_t i = 0; i < GetMultiples().size(); ++i) { - multiples.push_back(GetMultiples()[i]); + std::vector multiples = GetMultiples(); + const size_t in_dims = input->shape().size(); + const size_t delta_dims = in_dims - multiples.size(); + + size_t i = 0; + for (; i < delta_dims; ++i) { + int tmp = input->shape()[i]; + out_shape.push_back(tmp); } - for (size_t i = 0; i < input->shape().size(); ++i) { - int tmp = input->shape()[i] * multiples[i]; + for (; i < in_dims; ++i) { + int tmp = input->shape()[i] * (multiples[i - delta_dims]); out_shape.push_back(tmp); } - output->set_shape(out_shape); return 
diff --git a/mindspore/lite/src/populate_parameter.cc b/mindspore/lite/src/populate_parameter.cc index fd7fb83ba1..7509f1ebdb 100644 --- a/mindspore/lite/src/populate_parameter.cc +++ b/mindspore/lite/src/populate_parameter.cc @@ -43,6 +43,7 @@ #include "src/ops/add.h" #include "src/ops/sub.h" #include "src/ops/div.h" +#include "src/ops/real_div.h" #include "src/ops/bias_add.h" #include "src/ops/expand_dims.h" #include "src/ops/full_connection.h" @@ -1680,6 +1681,7 @@ PopulateParameterRegistry::PopulateParameterRegistry() { populate_parameter_funcs_[schema::PrimitiveType_Add] = PopulateArithmetic; populate_parameter_funcs_[schema::PrimitiveType_Sub] = PopulateArithmetic; populate_parameter_funcs_[schema::PrimitiveType_Div] = PopulateArithmetic; + populate_parameter_funcs_[schema::PrimitiveType_RealDiv] = PopulateArithmetic; populate_parameter_funcs_[schema::PrimitiveType_LogicalAnd] = PopulateArithmetic; populate_parameter_funcs_[schema::PrimitiveType_LogicalOr] = PopulateArithmetic; populate_parameter_funcs_[schema::PrimitiveType_Equal] = PopulateArithmetic; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/activation.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/activation.cc index 44a888b9fb..71ecc40501 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/activation.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/activation.cc @@ -24,6 +24,7 @@ using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::lite::KernelRegistrar; using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; +using mindspore::schema::ActivationType_HSIGMOID; using mindspore::schema::ActivationType_HSWISH; using mindspore::schema::ActivationType_LEAKY_RELU; using mindspore::schema::ActivationType_RELU; @@ -57,6 +58,8 @@ int ActivationCPUKernel::DoActivation(int task_id) { error_code = Tanh(input_addr + stride * task_id, count, output_addr + stride * task_id); } else if (type_ == schema::ActivationType_HSWISH) { error_code = HSwish(input_addr + stride * task_id, count, output_addr + stride * task_id); + } else if (type_ == schema::ActivationType_HSIGMOID) { + error_code = HSigmoid(input_addr + stride * task_id, count, output_addr + stride * task_id); } else if (type_ == schema::ActivationType_HARD_TANH) { error_code = HardTanh(input_addr + stride * task_id, count, output_addr + stride * task_id, min_val_, max_val_); } else {
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/addn.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/addn.cc index 4aa4febfa2..d895f05167 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/addn.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/addn.cc @@ -60,14 +60,34 @@ int AddNCPUKernel::Run() { MS_LOG(ERROR) << "Prepare fail!ret: " << ret; return ret; } - elements_num_ = in_tensors_[0]->ElementsNum(); + elements_num_ = out_tensors_[0]->ElementsNum(); auto input0_data = reinterpret_cast<float *>(in_tensors_[0]->MutableData()); auto input1_data = reinterpret_cast<float *>(in_tensors_[1]->MutableData()); auto output_data = reinterpret_cast<float *>(out_tensors_[0]->MutableData()); if (static_cast<int>(elements_num_) < op_parameter_->thread_num_) { - ElementAdd(input0_data, input1_data, output_data, elements_num_); + if (in_tensors_[0]->shape() == in_tensors_[1]->shape()) { + ElementAdd(input0_data, input1_data, output_data, elements_num_); + } else { + ArithmeticParameter param; + param.in_elements_num0_ = in_tensors_[0]->ElementsNum(); + param.in_elements_num1_ = in_tensors_[1]->ElementsNum(); + param.out_elements_num_ = out_tensors_[0]->ElementsNum(); + param.broadcasting_ = true; + ElementOptAdd(input0_data, input1_data, output_data, elements_num_, &param); + } + for (size_t i = 2; i < in_tensors_.size(); ++i) { - ElementAdd(reinterpret_cast<float *>(in_tensors_[i]->MutableData()), output_data, output_data, elements_num_); + if (in_tensors_[i]->shape() == out_tensors_[0]->shape()) { + ElementAdd(reinterpret_cast<float *>(in_tensors_[i]->MutableData()), output_data, output_data, elements_num_); + } else { + ArithmeticParameter param; + param.in_elements_num0_ = in_tensors_[i]->ElementsNum(); + param.in_elements_num1_ = out_tensors_[0]->ElementsNum(); + param.out_elements_num_ = out_tensors_[0]->ElementsNum(); + param.broadcasting_ = true; + ElementOptAdd(reinterpret_cast<float *>(in_tensors_[i]->MutableData()), output_data, output_data, elements_num_, + &param); + } } return RET_OK; }
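AddN now tolerates inputs whose shapes differ from the output: the matched-shape case still takes ElementAdd, while the mismatched case routes through ElementOptAdd with an ArithmeticParameter carrying element counts, which is the broadcast path. A hedged sketch of the semantics that path is expected to implement when one operand has a single element (the real contract lives in nnacl's ElementOptAdd, not here):

/* scalar-vs-tensor broadcast add, illustration only */
void scalar_broadcast_add(const float *in0, int n0, const float *in1, int n1, float *out) {
  if (n0 == 1) {
    for (int i = 0; i < n1; ++i) out[i] = in0[0] + in1[i];  /* in0 broadcast over in1 */
  } else {
    for (int i = 0; i < n0; ++i) out[i] = in0[i] + in1[0];  /* in1 broadcast over in0 */
  }
}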
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc index 5e91a446d2..042b2e39e3 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc @@ -104,6 +104,7 @@ int ArithmeticCPUKernel::ReSize() { } break; case PrimitiveType_Div: + case PrimitiveType_RealDiv: switch (arithmeticParameter_->activation_type_) { case schema::ActivationType_RELU: arithmeticParameter_->broadcasting_ = false; @@ -304,6 +305,7 @@ REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Add, CpuArithmeticFp32KernelC REG_KERNEL(kCPU, kNumberTypeInt, PrimitiveType_Add, CpuArithmeticFp32KernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Sub, CpuArithmeticFp32KernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Div, CpuArithmeticFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_RealDiv, CpuArithmeticFp32KernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LogicalAnd, CpuArithmeticFp32KernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LogicalOr, CpuArithmeticFp32KernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Maximum, CpuArithmeticFp32KernelCreator) diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h index 2c4ed55516..d11d13d2b1 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h @@ -37,6 +37,7 @@ using mindspore::schema::PrimitiveType_Maximum; using mindspore::schema::PrimitiveType_Minimum; using mindspore::schema::PrimitiveType_Mul; using mindspore::schema::PrimitiveType_NotEqual; +using mindspore::schema::PrimitiveType_RealDiv; using mindspore::schema::PrimitiveType_SquaredDifference; using mindspore::schema::PrimitiveType_Sub; @@ -99,6 +100,7 @@ class ArithmeticCPUKernel : public LiteKernel { } break; case PrimitiveType_Div: + case PrimitiveType_RealDiv: switch (arithmeticParameter_->activation_type_) { case schema::ActivationType_RELU: arithmetic_run_ = ElementDivRelu; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.cc index 0d14a25b3c..cb94fab7a0 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.cc @@ -48,7 +48,7 @@ int ActivationGradCPUKernel::DoActivation(int task_id) { auto output_addr = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); int length = in_tensors_.at(0)->ElementsNum(); - int stride = UP_DIV(length, thread_count_); + int stride = UP_DIV(length, 1); int count = MSMIN(stride, length - stride * task_id); auto error_code = RET_OK; @@ -63,8 +63,9 @@ int ActivationGradCPUKernel::DoActivation(int task_id) { error_code = LReluGrad(yt_addr + stride * task_id, input_addr + stride * task_id, count, output_addr + stride * task_id, param_act_grad_->alpha_); } else if (param_act_grad_->type_ == schema::ActivationType_SIGMOID) { + // Sigmoid gets the input tensors in reverse order! error_code = - SigmoidGrad(yt_addr + stride * task_id, input_addr + stride * task_id, count, output_addr + stride * task_id); + SigmoidGrad(input_addr + stride * task_id, yt_addr + stride * task_id, count, output_addr + stride * task_id); } else if (param_act_grad_->type_ == schema::ActivationType_TANH) { error_code = TanhGrad(yt_addr + stride * task_id, input_addr + stride * task_id, count, output_addr + stride * task_id); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.h index e19b526ff1..1d0083231b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.h @@ -27,7 +27,7 @@ class ActivationGradCPUKernel : public LiteKernel { explicit ActivationGradCPUKernel(OpParameter *param, const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive) - : LiteKernel(param, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) { + : LiteKernel(param, inputs, outputs, ctx, primitive) { param_act_grad_ = reinterpret_cast<ActivationParameter *>(param); } ~ActivationGradCPUKernel() override = default; @@ -38,7 +38,6 @@ int DoActivation(int task_id); private: - int thread_count_; ActivationParameter *param_act_grad_; };
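The SigmoidGrad call is the one activation whose operands are intentionally swapped; as the new comment notes, it receives its tensors in reverse order, because the sigmoid derivative is computed from the forward output y rather than the pre-activation input: dx = dy * y * (1 - y). A minimal sketch of that gradient (illustrative signature, not the nnacl one):

/* y: forward sigmoid output, dy: incoming gradient, dx: result */
void sigmoid_grad(const float *y, const float *dy, int n, float *dx) {
  for (int i = 0; i < n; ++i) dx[i] = dy[i] * y[i] * (1.0f - y[i]);
}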
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/adam.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/adam.cc new file mode 100644 index 0000000000..7956e571cc --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/adam.cc @@ -0,0 +1,118 @@ + +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32_grad/adam.h" +#include <cmath> +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" +#include "src/runtime/kernel/arm/fp32/nchw2nhwc.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Adam; + +namespace mindspore::kernel { + +int AdamCPUKernel::ReSize() { return RET_OK; } + +int AdamCPUKernel::Execute(int task_id) { + auto weight = reinterpret_cast<float *>(in_tensors_[0]->MutableData()); + auto m = reinterpret_cast<float *>(in_tensors_[1]->MutableData()); + auto v = reinterpret_cast<float *>(in_tensors_[2]->MutableData()); + auto beta1_power = reinterpret_cast<float *>(in_tensors_[3]->MutableData())[0]; + auto beta2_power = reinterpret_cast<float *>(in_tensors_[4]->MutableData())[0]; + auto learning_rate = reinterpret_cast<float *>(in_tensors_[5]->MutableData())[0]; + auto beta1 = reinterpret_cast<float *>(in_tensors_[6]->MutableData())[0]; + auto beta2 = reinterpret_cast<float *>(in_tensors_[7]->MutableData())[0]; + auto eps = reinterpret_cast<float *>(in_tensors_[8]->MutableData())[0]; + auto gradient = reinterpret_cast<float *>(in_tensors_[9]->MutableData()); + size_t elem_num = in_tensors_[0]->ElementsNum(); + + if (adam_param_->use_nesterov_) { // Nadam + for (size_t i = 0; i < elem_num; ++i) { + m[i] = (m[i] * beta1) + (gradient[i] * (1.f - beta1)); + v[i] = (v[i] * beta2) + (gradient[i] * gradient[i] * (1.f - beta2)); + auto g_hat = gradient[i] / (1 - beta1_power); + auto m_hat = m[i] / (1 - beta1_power); + auto v_hat = v[i] / (1 - beta2_power); + auto m_tag = (1.f - beta1) * g_hat + beta1 * m_hat; + weight[i] -= learning_rate * m_tag / (sqrtf(v_hat) + eps); + } + } else { + for (size_t i = 0; i < elem_num; ++i) { + m[i] = (m[i] * beta1) + (gradient[i] * (1.f - beta1)); + v[i] = (v[i] * beta2) + (gradient[i] * gradient[i] * (1.f - beta2)); + auto m_hat = m[i] / (1 - beta1_power); + auto v_hat = v[i] / (1 - beta2_power); + weight[i] -= learning_rate * m_hat / (sqrtf(v_hat) + eps); + } + } + return RET_OK; +} + +int AdamRun(void *cdata, int task_id) { + auto Adam_kernel = reinterpret_cast<AdamCPUKernel *>(cdata); + auto error_code = Adam_kernel->Execute(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Adam run error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int AdamCPUKernel::Run() { + auto prepare_ret = Prepare(); + if (prepare_ret != RET_OK) { + MS_LOG(ERROR) << "AdamCPUKernel Prepare fail!ret: " << prepare_ret; + return prepare_ret; + } + + int error_code = ParallelLaunch(this->context_->thread_pool_, AdamRun, this, 1); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Adam function error error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int AdamCPUKernel::Init() { return RET_OK; } + +kernel::LiteKernel *CpuAdamFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, + const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter, + const lite::InnerContext *ctx, const kernel::KernelKey &desc, + const lite::PrimitiveC *primitive) { + MS_ASSERT(desc.type == schema::PrimitiveType_Adam); + auto *kernel = new (std::nothrow) AdamCPUKernel(opParameter, inputs, outputs, ctx, primitive); + MS_ASSERT(kernel != nullptr); + + auto ret = kernel->Init(); + if (0 != ret) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); + delete kernel; + return nullptr; + } + + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Adam, CpuAdamFp32KernelCreator) +} // namespace mindspore::kernel
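AdamCPUKernel::Execute is the textbook Adam update with bias correction: m and v are exponential moving averages of the gradient and squared gradient, m_hat = m / (1 - beta1^t) and v_hat = v / (1 - beta2^t) undo the startup bias (beta1_power and beta2_power hold beta1^t and beta2^t), and the weight moves by lr * m_hat / (sqrt(v_hat) + eps); the use_nesterov_ branch additionally blends the bias-corrected gradient into the step. A single-weight sketch of the plain branch:

#include <math.h>
/* one Adam step for a single weight; beta1_power = beta1^t, beta2_power = beta2^t */
void adam_step(float *w, float *m, float *v, float g, float lr, float beta1, float beta2,
               float beta1_power, float beta2_power, float eps) {
  *m = *m * beta1 + g * (1.f - beta1);      /* first-moment EMA */
  *v = *v * beta2 + g * g * (1.f - beta2);  /* second-moment EMA */
  float m_hat = *m / (1.f - beta1_power);   /* bias correction */
  float v_hat = *v / (1.f - beta2_power);
  *w -= lr * m_hat / (sqrtf(v_hat) + eps);
}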
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/adam.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/adam.h new file mode 100644 index 0000000000..66a387c5cf --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/adam.h @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ADAM_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ADAM_H_ + +#include <vector> +#include "src/lite_kernel.h" +#include "nnacl/fp32_grad/optimizer.h" + +namespace mindspore::kernel { +class AdamCPUKernel : public LiteKernel { + public: + explicit AdamCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, + const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, + const mindspore::lite::PrimitiveC *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive) { + adam_param_ = reinterpret_cast<AdamParameter *>(parameter); + } + ~AdamCPUKernel() override {} + int Init() override; + int ReSize() override; + int Run() override; + int Execute(int task_id); + + private: + AdamParameter *adam_param_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ADAM_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/apply_momentum.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/apply_momentum.cc index 527b260fd7..9cc9e85d5a 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/apply_momentum.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/apply_momentum.cc @@ -79,14 +79,7 @@ int ApplyMomentumCPUKernel::Run() { return RET_OK; } -int ApplyMomentumCPUKernel::Init() { - // Only for test with uninitialized Data - size_t elem_num = in_tensors_[0]->ElementsNum(); - auto accumulate = reinterpret_cast<float *>(in_tensors_[1]->MutableData()); - for (size_t i = 0; i < elem_num; i++) accumulate[i] = 0.0; - - return RET_OK; -} +int ApplyMomentumCPUKernel::Init() { return RET_OK; } kernel::LiteKernel *CpuApplyMomentumFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/assign.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/assign.cc new file mode 100644 index 0000000000..862c203b43 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/assign.cc @@ -0,0 +1,91 @@ + +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32_grad/assign.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" +#include "src/runtime/kernel/arm/fp32/nchw2nhwc.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Assign; + +namespace mindspore::kernel { + +int AssignCPUKernel::ReSize() { return RET_OK; } + +int AssignCPUKernel::Execute(int task_id) { + auto x = reinterpret_cast<float *>(in_tensors_[0]->MutableData()); + auto y = reinterpret_cast<float *>(in_tensors_[1]->MutableData()); + size_t size = in_tensors_[0]->Size(); + + memcpy(x, y, size); + return RET_OK; +} + +int AssignRun(void *cdata, int task_id) { + auto Assign_kernel = reinterpret_cast<AssignCPUKernel *>(cdata); + auto error_code = Assign_kernel->Execute(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "assign run error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int AssignCPUKernel::Run() { + auto prepare_ret = Prepare(); + if (prepare_ret != RET_OK) { + MS_LOG(ERROR) << "AssignCPUKernel Prepare fail!ret: " << prepare_ret; + return prepare_ret; + } + + int error_code = ParallelLaunch(this->context_->thread_pool_, AssignRun, this, 1); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Assign function error error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int AssignCPUKernel::Init() { return RET_OK; } + +kernel::LiteKernel *CpuAssignFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, + const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter, + const lite::InnerContext *ctx, const kernel::KernelKey &desc, + const lite::PrimitiveC *primitive) { + MS_ASSERT(desc.type == schema::PrimitiveType_Assign); + auto *kernel = new (std::nothrow) AssignCPUKernel(opParameter, inputs, outputs, ctx, primitive); + MS_ASSERT(kernel != nullptr); + + auto ret = kernel->Init(); + if (0 != ret) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); + delete kernel; + return nullptr; + } + + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Assign, CpuAssignFp32KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/assign.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/assign.h new file mode 100644 index 0000000000..dd2575e62a --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/assign.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ASSIGN_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ASSIGN_H_ + +#include <vector> +#include "src/lite_kernel.h" +#include "nnacl/fp32_grad/optimizer.h" + +namespace mindspore::kernel { +class AssignCPUKernel : public LiteKernel { + public: + explicit AssignCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, + const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, + const mindspore::lite::PrimitiveC *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} + ~AssignCPUKernel() override {} + int Init() override; + int ReSize() override; + int Run() override; + int Execute(int task_id); +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ASSIGN_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.cc index 27ae5958b0..f7b53662a6 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.cc @@ -27,6 +27,7 @@ using mindspore::lite::KernelRegistrar; using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; using mindspore::schema::PrimitiveType_Conv2DGradInput; +using mindspore::schema::PrimitiveType_GroupConv2DGradInput; namespace mindspore::kernel { int ConvolutionGradInputCPUKernel::Init() { @@ -134,7 +135,8 @@ kernel::LiteKernel *CpuConvGradInputFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, - MS_ASSERT(desc.type == schema::PrimitiveType_Conv2DGradInput); + MS_ASSERT(desc.type == schema::PrimitiveType_Conv2DGradInput || + desc.type == schema::PrimitiveType_GroupConv2DGradInput); diff --git a/mindspore/lite/src/train/train_model.cc b/mindspore/lite/src/train/train_model.cc --- a/mindspore/lite/src/train/train_model.cc +++ b/mindspore/lite/src/train/train_model.cc @@ TrainModel *TrainModel::Import(const char *model_buf, size_t size) { model->buf = reinterpret_cast<char *>(malloc(size)); if (model->buf == nullptr) { + delete model; MS_LOG(ERROR) << "new inner model buf fail!"; return nullptr; } @@ -48,6 +49,8 @@ TrainModel *TrainModel::Import(const char *model_buf, size_t size) { model->buf_size_ = size; auto meta_graph = schema::GetMetaGraph(model->buf); if (meta_graph == nullptr) { + free(model->buf); + delete model; MS_LOG(ERROR) << "meta_graph is nullptr!"; return nullptr; } diff --git a/mindspore/lite/src/train/train_populate_parameter.cc b/mindspore/lite/src/train/train_populate_parameter.cc index b2f8e318f3..f856c406aa 100644 --- a/mindspore/lite/src/train/train_populate_parameter.cc +++ b/mindspore/lite/src/train/train_populate_parameter.cc @@ -24,6 +24,7 @@ #include "nnacl/fp32/activation.h" #include "src/ops/conv2d_grad_filter.h" #include "src/ops/conv2d_grad_input.h" +#include "src/ops/group_conv2d_grad_input.h" #include "nnacl/conv_parameter.h" #include "src/ops/power_grad.h" #include "nnacl/power_parameter.h" @@ -34,6 +35,7 @@ #include "src/ops/sgd.h" #include "src/ops/bn_grad.h" #include "nnacl/fp32_grad/batch_norm.h" +#include "src/ops/adam.h" namespace mindspore::kernel { @@ -69,12 +71,29 @@ OpParameter *PopulateApplyMomentumParameter(const mindspore::lite::PrimitiveC *p reinterpret_cast<mindspore::lite::ApplyMomentum *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); p->grad_scale_ = apply_momentum_primitive->GetGradientScale(); - p->use_locking_ = apply_momentum_primitive->GetUseLocking(); p->use_nesterov_ = apply_momentum_primitive->GetUseNesterov(); return reinterpret_cast<OpParameter *>(p); }
+OpParameter *PopulateAdamParameter(const mindspore::lite::PrimitiveC *primitive) { + if (primitive == nullptr) { + MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op."; + return nullptr; + } + AdamParameter *p = reinterpret_cast<AdamParameter *>(malloc(sizeof(AdamParameter))); + if (p == nullptr) { + MS_LOG(ERROR) << "new AdamParameter failed."; + return nullptr; + } + p->op_parameter_.type_ = primitive->Type(); + + auto apply_momentum_primitive = + reinterpret_cast<mindspore::lite::Adam *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); + p->use_nesterov_ = apply_momentum_primitive->GetUseNesterov(); + return reinterpret_cast<OpParameter *>(p); +} + OpParameter *PopulateSgdParameter(const mindspore::lite::PrimitiveC *primitive) { if (primitive == nullptr) { MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op."; @@ -264,6 +283,47 @@ OpParameter *PopulateConvolutionGradInputParameter(const mindspore::lite::Primit return reinterpret_cast<OpParameter *>(param); } +OpParameter *PopulateGroupConvolutionGradInputParameter(const mindspore::lite::PrimitiveC *primitive) { + if (primitive == nullptr) { + MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op."; + return nullptr; + } + + ConvParameter *param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter))); + if (param == nullptr) { + MS_LOG(ERROR) << "new Param for conv grad filter failed."; + return nullptr; + } + param->op_parameter_.type_ = primitive->Type(); + + auto convg_primitive = + reinterpret_cast<mindspore::lite::GroupConv2DGradInput *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); + param->kernel_h_ = convg_primitive->GetKernelH(); + param->kernel_w_ = convg_primitive->GetKernelW(); + param->stride_h_ = convg_primitive->GetStrideH(); + param->stride_w_ = convg_primitive->GetStrideW(); + param->dilation_h_ = convg_primitive->GetDilateH(); + param->dilation_w_ = convg_primitive->GetDilateW(); + param->pad_u_ = convg_primitive->GetPadUp(); + param->pad_d_ = convg_primitive->GetPadDown(); + param->pad_l_ = convg_primitive->GetPadLeft(); + param->pad_r_ = convg_primitive->GetPadRight(); + param->group_ = convg_primitive->GetGroup(); + param->act_type_ = ActType_No; + switch (convg_primitive->GetActivationType()) { + case schema::ActivationType_RELU: + param->act_type_ = ActType_Relu; + break; + case schema::ActivationType_RELU6: + param->act_type_ = ActType_Relu6; + break; + default: + break; + } + + return reinterpret_cast<OpParameter *>(param); +} + OpParameter *PopulatePowerGradParameter(const mindspore::lite::PrimitiveC *primitive) { if (primitive == nullptr) { MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op."; @@ -327,10 +387,13 @@ void PopulateTrainParameters() { ppr->AddPopulateParameterFunc(schema::PrimitiveType_BNGrad, DefaultPopulateParameter); ppr->AddPopulateParameterFunc(schema::PrimitiveType_Conv2DGradFilter, PopulateConvolutionGradFilterParameter); ppr->AddPopulateParameterFunc(schema::PrimitiveType_Conv2DGradInput, PopulateConvolutionGradInputParameter); + ppr->AddPopulateParameterFunc(schema::PrimitiveType_GroupConv2DGradInput, PopulateGroupConvolutionGradInputParameter); ppr->AddPopulateParameterFunc(schema::PrimitiveType_PoolingGrad, PopulatePoolingGradParameter); ppr->AddPopulateParameterFunc(schema::PrimitiveType_PowerGrad, PopulatePowerGradParameter); ppr->AddPopulateParameterFunc(schema::PrimitiveType_Sgd, PopulateSgdParameter); ppr->AddPopulateParameterFunc(schema::PrimitiveType_BNGrad, PopulateBNGradParameter); + ppr->AddPopulateParameterFunc(schema::PrimitiveType_Adam, PopulateAdamParameter); + ppr->AddPopulateParameterFunc(schema::PrimitiveType_Assign, DefaultPopulateParameter); } } // namespace mindspore::kernel
diff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc index 5486d2f221..de4cb9200c 100644 --- a/mindspore/lite/src/train/train_session.cc +++ b/mindspore/lite/src/train/train_session.cc @@ -104,23 +104,15 @@ int TrainSession::RunGraph(const session::KernelCallBack &before, const session: for (auto ms_tensor : ms_tensors.second) this->outputs_.push_back((static_cast<lite::Tensor *>(ms_tensor))); if (train_mode_) return lite::LiteSession::RunGraph(before, after); - // object is expected to run only inference part of graph - // prepare a list of kernels till the loss function -- temporary solution - std::vector<kernel::LiteKernel *> inference_kernels; - for (auto kernel : this->kernels_) { - if (IsLossKernel(kernel)) break; - inference_kernels.push_back(kernel); - } - if (this->context_ == nullptr) { MS_LOG(ERROR) << "context is null"; return lite::RET_NULL_PTR; } lite::Executor executor; if (before == nullptr && after == nullptr) { - return executor.Run(this->inputs_, this->outputs_, inference_kernels, this->context_->allocator.get()); + return executor.Run(this->inputs_, this->outputs_, inference_kernels_, this->context_->allocator.get()); } else { - return executor.Run(this->inputs_, this->outputs_, inference_kernels, this->context_->allocator.get(), before, + return executor.Run(this->inputs_, this->outputs_, inference_kernels_, this->context_->allocator.get(), before, after); } } @@ -173,6 +165,38 @@ void TrainSession::Eval() { } } } + if (inference_kernels_.size() == 0) { + BuildInferenceKernelsMap(); + } +} + +void TrainSession::BuildInferenceKernelsRecursive(kernel::LiteKernel *kernel, std::vector<kernel::LiteKernel *> *v) { + if (std::find(v->begin(), v->end(), kernel) == v->end()) { // kernel is not in vector + v->push_back(kernel); + for (auto in_node : kernel->in_kernels()) { + BuildInferenceKernelsRecursive(in_node, v); + } + } +} + +void TrainSession::BuildInferenceKernelsMap() { + std::vector<kernel::LiteKernel *> req_kernels; + for (auto kernel : this->kernels_) { + if (IsLossKernel(kernel)) { // For each loss in the system add backward tree + for (auto in_node : kernel->in_kernels()) { + BuildInferenceKernelsRecursive(in_node, &req_kernels); + } + } + } + inference_kernels_.clear(); + for (auto kernel : this->kernels_) { + if (std::find(req_kernels.begin(), req_kernels.end(), kernel) != req_kernels.end()) { + inference_kernels_.push_back(kernel); + } + } + if (inference_kernels_.size() == 0) { + inference_kernels_ = this->kernels_; + } } bool TrainSession::IsLossKernel(kernel::LiteKernel *kernel) { diff --git a/mindspore/lite/src/train/train_session.h b/mindspore/lite/src/train/train_session.h index 226497aa71..6cd4c68d53 100644 --- a/mindspore/lite/src/train/train_session.h +++ b/mindspore/lite/src/train/train_session.h @@ -82,12 +82,16 @@ class TrainSession : virtual public session::TrainSession, virtual public lite::LiteSession { protected: void AllocWorkSpace(); + bool IsLossKernel(kernel::LiteKernel *kernel); virtual std::vector<CreatorOp> ReplaceOps(); virtual void RestoreOps(const std::vector<CreatorOp> &restore); - bool IsLossKernel(kernel::LiteKernel *kernel); + virtual void BuildInferenceKernelsMap(); + virtual void BuildInferenceKernelsRecursive(kernel::LiteKernel *ker, std::vector<kernel::LiteKernel *> *req_kernels); + TrainModel *model_ = nullptr; std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> orig_output_map_; std::unordered_map<std::string, mindspore::tensor::MSTensor *> orig_output_tensor_map_; + std::vector<kernel::LiteKernel *> inference_kernels_; }; } // namespace lite } // namespace mindspore
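BuildInferenceKernelsMap walks backward from each loss kernel's producers, so inference_kernels_ ends up holding exactly the kernels the losses depend on (kept in the original kernels_ order), and Eval-mode RunGraph stops before the loss and gradient nodes; when no loss kernel is found it falls back to the full list. A toy sketch of the same depth-first collection, on a stand-in node type rather than the LiteKernel API:

#include <algorithm>
#include <vector>
struct Node { std::vector<Node *> in; };  // stand-in for a kernel and its producers
void Collect(Node *n, std::vector<Node *> *v) {
  if (std::find(v->begin(), v->end(), n) == v->end()) {  // skip nodes already visited
    v->push_back(n);
    for (auto *p : n->in) Collect(p, v);  // recurse into producers
  }
}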
diff --git a/mindspore/lite/test/common/common_test.h b/mindspore/lite/test/common/common_test.h index 1c848ed724..c8bde4d69d 100644 --- a/mindspore/lite/test/common/common_test.h +++ b/mindspore/lite/test/common/common_test.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef TESTS_UT_COMMON_UT_COMMON_H_ -#define TESTS_UT_COMMON_UT_COMMON_H_ +#ifndef MINDSPORE_LITE_TEST_COMMON_COMMON_TEST_H_ +#define MINDSPORE_LITE_TEST_COMMON_COMMON_TEST_H_ #include <cmath> #include <fstream> @@ -37,11 +37,11 @@ class CommonTest : public testing::Test { void PrintData(std::string name, T *output_data, int size) { std::cout << "The " << name << " is as follows:" << std::endl; if (typeid(output_data[0]) == typeid(uint8_t) || typeid(output_data[0]) == typeid(int8_t)) { - for (size_t i = 0; i < std::min(size, 100); i++) { + for (int i = 0; i < std::min(size, 100); i++) { std::cout << static_cast<int>(output_data[i]) << " "; } } else { - for (size_t i = 0; i < std::min(size, 100); i++) { + for (int i = 0; i < std::min(size, 100); i++) { std::cout << output_data[i] << " "; } } @@ -58,7 +58,7 @@ class CommonTest : public testing::Test { void CompareOutputInt8(int8_t *output_data, int8_t *correct_data, int size, float err_percent) { int bias_count = 0; - for (size_t i = 0; i < size; i++) { + for (int i = 0; i < size; i++) { int8_t diff = abs(output_data[i] - correct_data[i]); ASSERT_LE(diff, 1); if (diff == 1) { @@ -88,4 +88,4 @@ } }; } // namespace mindspore -#endif // TESTS_UT_COMMON_UT_COMMON_H_ +#endif // MINDSPORE_LITE_TEST_COMMON_COMMON_TEST_H_ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/network_test.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/network_test.cc index 47823dac44..0c4b6367a9 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/network_test.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/network_test.cc @@ -250,7 +250,7 @@ TEST_F(NetworkTest, tuning_layer) { auto label = std::make_unique<schema::TensorT>(); label->nodeType = schema::NodeType::NodeType_ValueNode; label->format = schema::Format_NHWC; - label->dataType = TypeId::kNumberTypeInt32; + label->dataType = TypeId::kNumberTypeFloat32; label->dims = {BATCH_SIZE * NUM_CLASSES}; label->offset = -1; meta_graph->allTensors.emplace_back(std::move(label)); @@ -386,8 +386,10 @@ TEST_F(NetworkTest, tuning_layer) { auto labelTensor = inputs.at(1); ASSERT_NE(nullptr, labelTensor); ASSERT_EQ(BATCH_SIZE * NUM_CLASSES, labelTensor->ElementsNum()); - auto labels = reinterpret_cast<int *>(labelTensor->MutableData()); - for (int i = 0; i < BATCH_SIZE; i++) labels[i] = (i * 97) % NUM_CLASSES; + + auto labels = reinterpret_cast<float *>(labelTensor->MutableData()); + std::fill(labels, labels + labelTensor->ElementsNum(), 0.f); + for (int i = 0; i < BATCH_SIZE; i++) labels[i * NUM_CLASSES + (i * 97) % NUM_CLASSES] = 1.0; ret = session->RunGraph(); ASSERT_EQ(lite::RET_OK, ret); @@ -576,12 +578,12 @@ TEST_F(NetworkTest, lenetnet) { delete context; ASSERT_EQ(res, 0); } -#if 0 + TEST_F(NetworkTest, retina_net) { char *buf = nullptr; size_t net_size = 0; - std::string net = "./test_data/nets/retinaface1009.ms"; + std::string net = "./test_data/nets/retinaface1.ms"; ReadFile(net.c_str(), &net_size, &buf); // auto model = lite::TrainModel::Import(buf, net_size); auto model = lite::Model::Import(buf, net_size); @@ -598,26 +600,36 @@ TEST_F(NetworkTest, retina_net) { ASSERT_EQ(lite::RET_OK, ret); // session->Eval(); - std::string in = "./test_data/nets/retinaface_input.f32"; + std::string in = "./test_data/nets/test1.hwc_normalized_f32";
std::cout << "----- Output 0 -----" << std::endl; - std::string out = "./test_data/nets/retinaface_out_0.f32"; + std::string out = "./test_data/nets/test1_loc.f32"; + int final_res = 0; auto res = runNet(session, in, out, "448", true); - ASSERT_EQ(res, 0); + // ASSERT_EQ(res, 0); + if (res != 0) { + final_res = res; + } std::cout << "----- Output 1 -----" << std::endl; - out = "./test_data/nets/retinaface_out_1.f32"; + out = "./test_data/nets/test1_conf.f32"; res = runNet(session, in, out, "435", true); - ASSERT_EQ(res, 0); - + // ASSERT_EQ(res, 0); + if (res != 0) { + final_res |= res; + } std::cout << "----- Output 2 -----" << std::endl; - out = "./test_data/nets/retinaface_out_2.f32"; + out = "./test_data/nets/test1_landms.f32"; res = runNet(session, in, out, "421", true); - ASSERT_EQ(res, 0); + if (res != 0) { + final_res |= res; + } + + ASSERT_EQ(final_res, 0); delete session; delete context; } -#endif + TEST_F(NetworkTest, mobileface_net) { char *buf = nullptr; size_t net_size = 0; diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/softmax_crossentropy_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/softmax_crossentropy_fp32_tests.cc index 650d5be587..14c5777f58 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/softmax_crossentropy_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/softmax_crossentropy_fp32_tests.cc @@ -41,10 +41,11 @@ TEST_F(TestSoftmaxCrossEntropyFp32, SoftmaxCrossEntropyFp32) { std::string label_path = "./test_data/operators/sce_fp32_1_l_6.bin"; auto ll_labels = reinterpret_cast<int64_t *>(mindspore::lite::ReadFile(label_path.c_str(), &input_size)); - auto labels = new int[6]; - for (int i = 0; i < 6; i++) labels[i] = static_cast<int>(ll_labels[i]); + auto labels = new float[6 * 4]; + std::fill(labels, labels + 6 * 4, 0.f); + for (int i = 0; i < 6; i++) labels[i * 4 + ll_labels[i]] = 1.0; - std::vector<int> dim_l({6}); + std::vector<int> dim_l({6, 4}); lite::Tensor l_tensor(TypeId::kNumberTypeInt32, dim_l); l_tensor.SetData(labels);
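Both test changes follow the same recipe: the fp32 softmax cross-entropy path now expects dense one-hot float labels of shape {batch, num_classes}, so integer class ids are expanded by zero-filling the buffer and setting one 1.0 per row. As a stand-alone sketch (batch and num_classes are placeholders):

#include <string.h>
/* expand integer class ids into one-hot float labels */
void one_hot_fill(const int *ids, int batch, int num_classes, float *labels) {
  memset(labels, 0, sizeof(float) * batch * num_classes);
  for (int i = 0; i < batch; ++i) {
    labels[i * num_classes + ids[i]] = 1.0f; /* assumes 0 <= ids[i] < num_classes */
  }
}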
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/efficientnet_b0_f.ms b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/efficientnet_b0_f.ms deleted file mode 100644 index a00bbee9fb..0000000000 Binary files a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/efficientnet_b0_f.ms and /dev/null differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/effnetb0_fwd_fuse.ms b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/effnetb0_fwd_fuse.ms deleted file mode 100644 index 3a0d1f2ac4..0000000000 Binary files a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/effnetb0_fwd_fuse.ms and /dev/null differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/retinaface1.ms b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/retinaface1.ms new file mode 100644 index 0000000000..7f175baac1 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/retinaface1.ms differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/test1.hwc_normalized_f32 b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/test1.hwc_normalized_f32 new file mode 100644 index 0000000000..c13726a96b Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/test1.hwc_normalized_f32 differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/test1_conf.f32 b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/test1_conf.f32 new file mode 100644 index 0000000000..416b27f36d Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/test1_conf.f32 differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/test1_landms.f32 b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/test1_landms.f32 new file mode 100644 index 0000000000..52a901f95a Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/test1_landms.f32 differ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/test1_loc.f32 b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/test1_loc.f32 new file mode 100644 index 0000000000..19dd6a6dc1 Binary files /dev/null and b/mindspore/lite/test/ut/src/runtime/kernel/arm/test_data/nets/test1_loc.f32 differ diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc index 405dfddbbb..c4a0c7ddec 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc @@ -274,9 +274,24 @@ int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> input_anode, s auto input_cnode = utils::cast<CNodePtr>(input_anode); if (!IsPrimitiveCNode(input_cnode, schema::PrimitiveType_TupleGetItem)) { +#ifndef SUPPORT_TRAIN + if (node_id_map_.find(input_name) != node_id_map_.end()) { + output_cnode->inputIndex.emplace_back(node_id_map_[input_name]); + } +#else + bool found = false; if (node_id_map_.find(input_name) != node_id_map_.end()) { output_cnode->inputIndex.emplace_back(node_id_map_[input_name]); + found = true; + } + + if (found == false) { + auto input_index_key = input_name + "_o:" + std::to_string(0); + if (node_id_map_.find(input_index_key) != node_id_map_.end()) { + output_cnode->inputIndex.emplace_back(node_id_map_[input_index_key]); + } } +#endif } else { auto inputs = input_cnode->inputs(); if (inputs.size() != 3) { @@ -369,6 +384,9 @@ int AnfExporter::ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode, auto typePtr = abstractTensor->element()->GetTypeTrack(); paramTensor->dataType = typePtr->type_id(); paramTensor->dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape(); +#ifdef SUPPORT_TRAIN + if (paramTensor->dims.size() == 0) paramTensor->dims = {1}; +#endif paramTensor->nodeType = schema::NodeType::NodeType_ValueNode; auto data = value->cast<tensor::TensorPtr>(); paramTensor->data.resize(data->Size()); @@ -505,7 +523,8 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT, fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size()); meta_graphT->allTensors.emplace_back(msTensor); if (IsPrimitiveCNode(cnode, schema::PrimitiveType_Conv2D) || - IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D)) + IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D) || + IsPrimitiveCNode(cnode, schema::PrimitiveType_Adam)) break; #else if (tuple->size() == 1) { diff --git a/mindspore/lite/tools/anf_importer/import_from_protobuf.cc b/mindspore/lite/tools/anf_importer/import_from_protobuf.cc index c1344c2a6d..ddb0abea8b 100644 --- a/mindspore/lite/tools/anf_importer/import_from_protobuf.cc +++ b/mindspore/lite/tools/anf_importer/import_from_protobuf.cc @@ -678,6 +678,13 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output outputFuncGraph->set_return(return_node); MS_LOG(INFO) << "Construct funcgraph finined, all success."; } else { +#ifdef SUPPORT_TRAIN + auto ret_node = outputFuncGraph->get_return(); + if (ret_node) {
+ ret_node->add_input(cnode_ptr); + return true; + } +#endif const onnx::ValueInfoProto &output_node = importProto.output(0); const onnx::TypeProto &output_typeproto = output_node.type(); int output_type = output_typeproto.tensor_type().elem_type(); @@ -687,7 +694,6 @@ } auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[output_type]); auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, output_shape); - inputs.clear(); auto primReturn = std::make_unique<schema::PrimitiveT>(); MS_ASSERT(primReturn != nullptr); @@ -717,6 +723,7 @@ int AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncG } MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size(); CNodePtr cnode_ptr = nullptr; + CNodePtr last_cnode_ptr = nullptr; int status = RET_OK; NoSupportOp::GetInstance()->SetFmkType("MINDIR"); for (int i = 0; i < importProto.node_size(); ++i) { @@ -734,13 +741,35 @@ MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: : " << i; status = (status == RET_OK ? RET_NULL_PTR : status); } + + auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode_ptr->input(0)); + if (primitive_c == nullptr) { + MS_LOG(ERROR) << "primitive_c is nullptr"; + status = RET_ERROR; + } + +#ifdef SUPPORT_TRAIN + if (primitive_c->Type() == schema::PrimitiveType_MakeTuple) { + last_cnode_ptr = cnode_ptr; + if (!BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr)) { + MS_LOG(ERROR) << "Build ReturnNode for funcgraph failed"; + status = RET_ERROR; + } + } +#endif } if (status != RET_OK) { return status; } -  if (!BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr)) { -    MS_LOG(ERROR) << "Build ReturnNode for funcgraph failed"; -    status = RET_ERROR; +#ifdef SUPPORT_TRAIN + if (last_cnode_ptr != cnode_ptr) { +#else + { +#endif + if (!BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr)) { + MS_LOG(ERROR) << "Build ReturnNode for funcgraph failed"; + status = RET_ERROR; + } } return status; } diff --git a/mindspore/lite/tools/common/node_util.cc b/mindspore/lite/tools/common/node_util.cc index 02b9a4d2a7..81c77aee73 100644 --- a/mindspore/lite/tools/common/node_util.cc +++ b/mindspore/lite/tools/common/node_util.cc @@ -28,12 +28,14 @@ static const std::vector<schema::PrimitiveType> nhwcOpList = { #ifdef SUPPORT_TRAIN schema::PrimitiveType_Conv2DGradFilter, schema::PrimitiveType_Conv2DGradInput, + schema::PrimitiveType_GroupConv2DGradInput, schema::PrimitiveType_PoolingGrad, schema::PrimitiveType_BiasGrad, schema::PrimitiveType_BNGrad, schema::PrimitiveType_ActivationGrad, schema::PrimitiveType_ApplyMomentum, schema::PrimitiveType_Sgd, + schema::PrimitiveType_Adam, #endif schema::PrimitiveType_Conv2D, schema::PrimitiveType_DeConv2D, diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc index e67b257eee..15be667c70 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc @@ -161,6 +161,7 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) { int idx = 0; if (GetCNodeTType(**iter) == schema::PrimitiveType_ApplyMomentum) idx = 3; if (GetCNodeTType(**iter) == schema::PrimitiveType_Sgd) idx = 1; + if (GetCNodeTType(**iter) == schema::PrimitiveType_Adam) idx = 9; iter = InsertFormatTransNode(graph, iter, kBefore, idx, beforeNodeType, &status); if (status != RET_OK) { MS_LOG(ERROR) << "InsertNhwc2NchwNode after " << nodeName << "failed";
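The idx values in DoNodeInoutFormatTrans pick which input of each optimizer node receives the inserted layout transpose; for Adam that is input 9, which matches the input order read by AdamCPUKernel::Execute above, where the gradient, the one tensor in activation layout, comes last. For reference:

/* Adam kernel input order (see adam.cc); the transpose must target the gradient */
enum AdamInput { kWeight = 0, kMomentM, kMomentV, kBeta1Power, kBeta2Power,
                 kLearningRate, kBeta1, kBeta2, kEps, kGradient /* = 9 */ };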
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc index 57d36cc54d..94d61e0376 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc @@ -183,7 +183,9 @@ STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::uni auto origin_begin = attr->begin; attr->begin = {origin_begin[NCHW_N], origin_begin[NCHW_H], origin_begin[NCHW_W], origin_begin[NCHW_C]}; auto origin_end = attr->axes; - attr->axes = {origin_end[NCHW_N], origin_end[NCHW_H], origin_end[NCHW_W], origin_end[NCHW_C]}; + if (origin_end.size() >= 4) { + attr->axes = {origin_end[NCHW_N], origin_end[NCHW_H], origin_end[NCHW_W], origin_end[NCHW_C]}; + } auto origin_stride = attr->size; attr->size = {origin_stride[NCHW_N], origin_stride[NCHW_H], origin_stride[NCHW_W], origin_stride[NCHW_C]}; } diff --git a/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc b/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc index 8083585a3c..f46fa88bc0 100644 --- a/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc @@ -122,6 +122,12 @@ lite::STATUS WeightFormatHardCodePass::HardCodeMS(const AnfNodePtr &conv_node, param_value->set_format(schema::Format::Format_CKHW); } else if (op_type == schema::PrimitiveType_DeConv2D) { param_value->set_format(schema::Format::Format_KCHW); +#ifdef SUPPORT_TRAIN + } else if (op_type == schema::PrimitiveType_Conv2DGradInput) { + param_value->set_format(schema::Format::Format_KCHW); + } else if (op_type == schema::PrimitiveType_GroupConv2DGradInput) { + param_value->set_format(schema::Format::Format_CKHW); +#endif } else { MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: " << conv_node->fullname_with_scope(); @@ -178,6 +184,10 @@ bool WeightFormatHardCodePass::Run(const FuncGraphPtr &graph) { auto conv_cnode = node->cast<CNodePtr>(); auto type = opt::GetCNodeType(node); if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D && +#ifdef SUPPORT_TRAIN + ((type != schema::PrimitiveType_Conv2DGradInput) || (fmk_type != FmkType_MS)) && + ((type != schema::PrimitiveType_GroupConv2DGradInput) || (fmk_type != FmkType_MS)) && +#endif type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) { continue; } diff --git a/mindspore/lite/tools/optimizer/graph/weight_format_transform_pass.cc b/mindspore/lite/tools/optimizer/graph/weight_format_transform_pass.cc index 1f7df2ae4f..7c7b2d4c62 100644 --- a/mindspore/lite/tools/optimizer/graph/weight_format_transform_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/weight_format_transform_pass.cc @@ -18,27 +18,21 @@ #include "tools/optimizer/common/gllo_utils.h" using mindspore::lite::converter::FmkType_CAFFE; -using mindspore::lite::converter::FmkType_TFLITE; -using mindspore::lite::converter::FmkType_ONNX; using mindspore::lite::converter::FmkType_MS; -using mindspore::schema::QuantType_WeightQuant; -using mindspore::schema::QuantType_QUANT_NONE; +using mindspore::lite::converter::FmkType_ONNX; +using mindspore::lite::converter::FmkType_TFLITE; using mindspore::schema::QuantType_AwareTraining; using
mindspore::schema::QuantType_PostTraining; +using mindspore::schema::QuantType_QUANT_NONE; +using mindspore::schema::QuantType_WeightQuant; namespace mindspore::opt { namespace { constexpr size_t kConvWeightIndex = 2; } // namespace -void WeightFormatTransformPass::SetQuantType(QuantType type) { - this->quant_type = type; -} -void WeightFormatTransformPass::SetFmkType(FmkType type) { - this->fmk_type = type; -} -void WeightFormatTransformPass::SetDstFormat(schema::Format format) { - this->dst_format = format; -} +void WeightFormatTransformPass::SetQuantType(QuantType type) { this->quant_type = type; } +void WeightFormatTransformPass::SetFmkType(FmkType type) { this->fmk_type = type; } +void WeightFormatTransformPass::SetDstFormat(schema::Format format) { this->dst_format = format; } lite::STATUS WeightFormatTransformPass::ConvWeightFormatTrans(const FuncGraphPtr &graph) { MS_ASSERT(graph != nullptr); auto node_list = TopoSort(graph->get_return()); @@ -48,6 +42,9 @@ lite::STATUS WeightFormatTransformPass::ConvWeightFormatTrans(const FuncGraphPtr } auto type = opt::GetCNodeType(node); if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D +#ifdef SUPPORT_TRAIN + && type != schema::PrimitiveType_Conv2DGradInput && type != schema::PrimitiveType_GroupConv2DGradInput +#endif && type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) { continue; } @@ -60,8 +57,8 @@ lite::STATUS WeightFormatTransformPass::ConvWeightFormatTrans(const FuncGraphPtr MS_LOG(ERROR) << "weight node must param value"; return false; } - MS_ASSERT(weight_value->tensor_type() == TypeId::kNumberTypeFloat32 - || weight_value->tensor_type() == TypeId::kNumberTypeUInt8); + MS_ASSERT(weight_value->tensor_type() == TypeId::kNumberTypeFloat32 || + weight_value->tensor_type() == TypeId::kNumberTypeUInt8); lite::STATUS status; schema::Format weight_dst_format = schema::Format::Format_KHWC; if (dst_format != schema::Format::Format_NUM_OF_FORMAT) {