Merge pull request !7432 from yonibaehr/exporttags/v1.1.0
| @@ -117,6 +117,14 @@ int HSwish(const float *src, int length, float *dst) { | |||
| return NNACL_OK; | |||
| } | |||
| int HSigmoid(const float *src, int length, float *dst) { | |||
| for (int i = 0; i < length; ++i) { | |||
| float relu6 = MSMIN(MSMAX(src[i] + 3, 0), 6); | |||
| dst[i] = relu6 / 6; | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| int HardTanh(const float *src, int length, float *dst, float min_val, float max_val) { | |||
| if (max_val <= min_val) { | |||
| return NNACL_ERR; | |||
| @@ -13,8 +13,8 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_NNACL_ACTIVATION_H_ | |||
| #define MINDSPORE_LITE_NNACL_ACTIVATION_H_ | |||
| #ifndef MINDSPORE_LITE_NNACL_FP32_ACTIVATION_H_ | |||
| #define MINDSPORE_LITE_NNACL_FP32_ACTIVATION_H_ | |||
| #include <math.h> | |||
| #include "nnacl/op_base.h" | |||
| @@ -36,9 +36,10 @@ int Fp32Relu6(const float *src, int length, float *dst); | |||
| int LRelu(const float *src, int length, float *dst, float alpha); | |||
| int Sigmoid(const float *src, int length, float *dst); | |||
| int Tanh(const float *src, int length, float *dst); | |||
| int HSigmoid(const float *src, int length, float *dst); | |||
| int HSwish(const float *src, int length, float *dst); | |||
| int HardTanh(const float *src, int length, float *dst, float min_val, float max_val); | |||
| #ifdef __cplusplus | |||
| } | |||
| #endif | |||
| #endif // MINDSPORE_LITE_NNACL_ACTIVATION_H_ | |||
| #endif // MINDSPORE_LITE_NNACL_FP32_ACTIVATION_H_ | |||
| @@ -31,14 +31,14 @@ int OneHot(const int *indices, float *output, const OneHotParameter *one_hot_par | |||
| int i, j, k; | |||
| for (i = tid; i < outer_size; i += thread_num) { | |||
| float *output_ptr = output + i * depth * inner_size; | |||
| for (k = 0; k < depth; k++) { | |||
| for (j = 0; j < inner_size; j++) { | |||
| for (k = 0; k < inner_size; k++) { | |||
| int index = indices[i * inner_size + k]; | |||
| for (j = 0; j < depth; j++) { | |||
| *output_ptr = off_value; | |||
| int index = indices[i * inner_size + j]; | |||
| if (index >= depth) { | |||
| return NNACL_ERRCODE_INDEX_OUT_OF_RANGE; | |||
| } | |||
| if (index == k) { | |||
| if (index == j) { | |||
| *output_ptr = on_value; | |||
| } | |||
| output_ptr++; | |||
| @@ -15,27 +15,52 @@ | |||
| */ | |||
| #include "nnacl/fp32_grad/gemm.h" | |||
| #include <string.h> | |||
| static void gemm_not_trana_not_tranb(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb, | |||
| float *mat_c, int ldc) { | |||
| const int block_size = 4; | |||
| int block_mod = N % block_size; | |||
| int block_c4 = N - block_mod; | |||
| static void gemm_nn(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_B, int ldb, float *mat_c, | |||
| int ldc) { | |||
| int i, j, k; | |||
| for (i = 0; i < M; ++i) { | |||
| for (k = 0; k < K; ++k) { | |||
| float a = alpha * mat_a[i * lda + k]; | |||
| for (j = 0; j < N; ++j) { | |||
| mat_c[i * ldc + j] += a * mat_B[k * ldb + j]; | |||
| for (j = 0; j < block_c4; j += block_size) { | |||
| float *b = &mat_b[k * ldb + j]; | |||
| float *c = &mat_c[i * ldc + j]; | |||
| c[0] += a * b[0]; | |||
| c[1] += a * b[1]; | |||
| c[2] += a * b[2]; | |||
| c[3] += a * b[3]; | |||
| } | |||
| for (; j < N; ++j) { | |||
| mat_c[i * ldc + j] += a * mat_b[k * ldb + j]; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| static void gemm_nt(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb, float *mat_c, | |||
| int ldc) { | |||
| static void gemm_not_trana_tranb(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb, | |||
| float *mat_c, int ldc) { | |||
| const int block_size = 4; | |||
| int block_mod = K % block_size; | |||
| int block_c4 = K - block_mod; | |||
| int i, j, k; | |||
| for (i = 0; i < M; ++i) { | |||
| for (j = 0; j < N; ++j) { | |||
| float sum = 0; | |||
| for (k = 0; k < K; ++k) { | |||
| for (k = 0; k < block_c4; k += block_size) { | |||
| float *a = &mat_a[i * lda + k]; | |||
| float *b = &mat_b[j * ldb + k]; | |||
| sum += alpha * a[0] * b[0]; | |||
| sum += alpha * a[1] * b[1]; | |||
| sum += alpha * a[2] * b[2]; | |||
| sum += alpha * a[3] * b[3]; | |||
| } | |||
| for (; k < K; ++k) { | |||
| sum += alpha * mat_a[i * lda + k] * mat_b[j * ldb + k]; | |||
| } | |||
| mat_c[i * ldc + j] += sum; | |||
| @@ -43,23 +68,85 @@ static void gemm_nt(int M, int N, int K, float alpha, float *mat_a, int lda, flo | |||
| } | |||
| } | |||
| static void gemm_tn(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb, float *mat_c, | |||
| int ldc) { | |||
| static void gemm_trana_not_tranb(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb, | |||
| float *mat_c, int ldc) { | |||
| const int block_size = 4; | |||
| int block_mod = N % block_size; | |||
| int block_c4 = N - block_mod; | |||
| int i, j, k; | |||
| for (i = 0; i < M; ++i) { | |||
| for (k = 0; k < K; ++k) { | |||
| float a = alpha * mat_a[k * lda + i]; | |||
| for (j = 0; j < N; ++j) { | |||
| for (j = 0; j < block_c4; j += block_size) { | |||
| float *b = &mat_b[k * ldb + j]; | |||
| float *c = &mat_c[i * ldc + j]; | |||
| c[0] += a * b[0]; | |||
| c[1] += a * b[1]; | |||
| c[2] += a * b[2]; | |||
| c[3] += a * b[3]; | |||
| } | |||
| for (; j < N; ++j) { | |||
| mat_c[i * ldc + j] += a * mat_b[k * ldb + j]; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| static void gemm_tt(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb, float *mat_c, | |||
| int ldc) { | |||
| static void gemm_trana_tranb(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb, | |||
| float *mat_c, int ldc) { | |||
| int i, j, k; | |||
| for (i = 0; i < M; ++i) { | |||
| const int block_size = 4; | |||
| int k_block_mod = K % block_size; | |||
| int k_block_c4 = K - k_block_mod; | |||
| int m_block_mod = M % block_size; | |||
| int m_block_c4 = M - m_block_mod; | |||
| for (i = 0; i < m_block_c4; i += block_size) { | |||
| for (j = 0; j < N; ++j) { | |||
| float sum0 = 0; | |||
| float sum1 = 0; | |||
| float sum2 = 0; | |||
| float sum3 = 0; | |||
| for (k = 0; k < k_block_c4; k += block_size) { | |||
| float *b = &mat_b[j * ldb + k]; | |||
| sum0 += alpha * mat_a[i + k * lda] * b[0]; | |||
| sum0 += alpha * mat_a[i + (k + 1) * lda] * b[1]; | |||
| sum0 += alpha * mat_a[i + (k + 2) * lda] * b[2]; | |||
| sum0 += alpha * mat_a[i + (k + 3) * lda] * b[3]; | |||
| sum1 += alpha * mat_a[i + 1 + k * lda] * b[0]; | |||
| sum1 += alpha * mat_a[i + 1 + (k + 1) * lda] * b[1]; | |||
| sum1 += alpha * mat_a[i + 1 + (k + 2) * lda] * b[2]; | |||
| sum1 += alpha * mat_a[i + 1 + (k + 3) * lda] * b[3]; | |||
| sum2 += alpha * mat_a[i + 2 + k * lda] * b[0]; | |||
| sum2 += alpha * mat_a[i + 2 + (k + 1) * lda] * b[1]; | |||
| sum2 += alpha * mat_a[i + 2 + (k + 2) * lda] * b[2]; | |||
| sum2 += alpha * mat_a[i + 2 + (k + 3) * lda] * b[3]; | |||
| sum3 += alpha * mat_a[i + 3 + k * lda] * b[0]; | |||
| sum3 += alpha * mat_a[i + 3 + (k + 1) * lda] * b[1]; | |||
| sum3 += alpha * mat_a[i + 3 + (k + 2) * lda] * b[2]; | |||
| sum3 += alpha * mat_a[i + 3 + (k + 3) * lda] * b[3]; | |||
| } | |||
| for (; k < K; ++k) { | |||
| float *b = &mat_b[j * ldb + k]; | |||
| sum0 += alpha * mat_a[i + (k * lda)] * b[0]; | |||
| sum1 += alpha * mat_a[i + 1 + (k * lda)] * b[0]; | |||
| sum2 += alpha * mat_a[i + 2 + (k * lda)] * b[0]; | |||
| sum3 += alpha * mat_a[i + 3 + (k * lda)] * b[0]; | |||
| } | |||
| mat_c[i * ldc + j] += sum0; | |||
| mat_c[(i + 1) * ldc + j] += sum1; | |||
| mat_c[(i + 2) * ldc + j] += sum2; | |||
| mat_c[(i + 3) * ldc + j] += sum3; | |||
| } | |||
| } | |||
| // no more block of 4x4 | |||
| for (; i < M; ++i) { | |||
| for (j = 0; j < N; ++j) { | |||
| float sum = 0; | |||
| for (k = 0; k < K; ++k) { | |||
| @@ -74,34 +161,37 @@ static void gemm_tt(int M, int N, int K, float alpha, float *mat_a, int lda, flo | |||
| // M - number of rows of matrix a | |||
| // N - number of cols of matrix b | |||
| // K - number of cols of matrix a | |||
| // lda - fast dim of matrix a | |||
| // ldb - fast dim of matrix b | |||
| // ldc - fast dim of matrix c | |||
| void gemm(int transpose_a, int transpose_b, int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, | |||
| int ldb, float beta, float *mat_c, int ldc) { | |||
| if (beta >= 0.f && beta <= 0.f) { | |||
| for (int i = 0; i < M; ++i) { | |||
| for (int j = 0; j < N; ++j) { | |||
| mat_c[i * ldc + j] = 0; | |||
| } | |||
| } | |||
| memset(mat_c, 0, M * N * sizeof(float)); | |||
| } else if (beta < 1.f || beta > 1.f) { | |||
| for (int i = 0; i < M; ++i) { | |||
| for (int j = 0; j < N; ++j) { | |||
| mat_c[i * ldc + j] *= beta; | |||
| } | |||
| const int block_size = 4; | |||
| const int size = M * N; | |||
| int block_mod = size % block_size; | |||
| int block_c4 = size - block_mod; | |||
| int i; | |||
| for (i = 0; i < block_c4; i += block_size) { | |||
| float *c = &mat_c[i]; | |||
| c[0] *= beta; | |||
| c[1] *= beta; | |||
| c[2] *= beta; | |||
| c[3] *= beta; | |||
| } | |||
| } | |||
| int t; | |||
| for (t = 0; t < M; ++t) { | |||
| if (!transpose_a && !transpose_b) { | |||
| gemm_nn(1, N, K, alpha, mat_a + t * lda, lda, mat_b, ldb, mat_c + t * ldc, ldc); | |||
| } else if (transpose_a && !transpose_b) { | |||
| gemm_tn(1, N, K, alpha, mat_a + t, lda, mat_b, ldb, mat_c + t * ldc, ldc); | |||
| } else if (!transpose_a && transpose_b) { | |||
| gemm_nt(1, N, K, alpha, mat_a + t * lda, lda, mat_b, ldb, mat_c + t * ldc, ldc); | |||
| } else { | |||
| gemm_tt(1, N, K, alpha, mat_a + t, lda, mat_b, ldb, mat_c + t * ldc, ldc); | |||
| for (; i < size; ++i) { | |||
| mat_c[i] *= beta; | |||
| } | |||
| } | |||
| if (transpose_a && transpose_b) { | |||
| gemm_trana_tranb(M, N, K, alpha, mat_a, lda, mat_b, ldb, mat_c, ldc); | |||
| } else if (!transpose_a && !transpose_b) { | |||
| gemm_not_trana_not_tranb(M, N, K, alpha, mat_a, lda, mat_b, ldb, mat_c, ldc); | |||
| } else if (!transpose_a && transpose_b) { | |||
| gemm_not_trana_tranb(M, N, K, alpha, mat_a, lda, mat_b, ldb, mat_c, ldc); | |||
| } else { | |||
| gemm_trana_not_tranb(M, N, K, alpha, mat_a, lda, mat_b, ldb, mat_c, ldc); | |||
| } | |||
| } | |||
| @@ -21,7 +21,6 @@ | |||
| typedef struct ApplyMomentumParameter { | |||
| OpParameter op_parameter_; | |||
| bool use_locking_; | |||
| bool use_nesterov_; | |||
| float grad_scale_; | |||
| } ApplyMomentumParameter; | |||
| @@ -33,4 +32,9 @@ typedef struct SgdParameter { | |||
| float weight_decay_; | |||
| } SgdParameter; | |||
| typedef struct AdamParameter { | |||
| OpParameter op_parameter_; | |||
| bool use_nesterov_; | |||
| } AdamParameter; | |||
| #endif // MINDSPORE_LITE_NNACL_FP32_GRAD_OPTIMIZER_H_ | |||
| @@ -182,7 +182,7 @@ union PrimitiveType { | |||
| Conv2DGradInput, | |||
| PoolingGrad, | |||
| BNGrad, | |||
| BNGradInput, | |||
| Assign, | |||
| ApplyMomentum, | |||
| BiasGrad, | |||
| SoftmaxCrossEntropy, | |||
| @@ -217,6 +217,8 @@ union PrimitiveType { | |||
| FftReal, | |||
| FftImag, | |||
| Sgd, | |||
| Adam, | |||
| GroupConv2DGradInput, | |||
| } | |||
| enum QuantType: int { | |||
| @@ -224,7 +224,29 @@ table Conv2DGradInput { | |||
| dilateH: int; | |||
| hasBias: bool = false; | |||
| activationType: ActivationType = 0; | |||
| }table FusedBatchNorm { | |||
| } | |||
| table GroupConv2DGradInput { | |||
| format: Format = 0; | |||
| group: int; | |||
| channelIn: int; | |||
| channelOut: int; | |||
| kernelW: int; | |||
| kernelH: int; | |||
| strideW: int; | |||
| strideH: int; | |||
| padMode: PadMode; | |||
| padUp: int; | |||
| padDown: int; | |||
| padLeft: int; | |||
| padRight: int; | |||
| dilateW: int; | |||
| dilateH: int; | |||
| hasBias: bool = false; | |||
| activationType: ActivationType = 0; | |||
| } | |||
| table FusedBatchNorm { | |||
| epsilon: float = 0.00001; // eg. epsilon=0.001 | |||
| momentum: float = 0.9; | |||
| spatial: int = 1; | |||
| @@ -901,7 +923,6 @@ table TupleGetItem { | |||
| table ApplyMomentum { | |||
| gradientScale: float; | |||
| useLocking: bool; | |||
| useNesterov: bool; | |||
| } | |||
| @@ -911,6 +932,14 @@ table Sgd { | |||
| useNesterov: bool; | |||
| } | |||
| table Adam { | |||
| useNesterov: bool; | |||
| } | |||
| table Assign { | |||
| } | |||
| table Where{ | |||
| condition: [bool]; | |||
| } | |||
| @@ -50,6 +50,10 @@ int Activation::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> | |||
| attr->type = schema::ActivationType_SIGMOID; | |||
| } else if (prim.name() == "ReLU6") { | |||
| attr->type = schema::ActivationType_RELU6; | |||
| } else if (prim.name() == "HSwish") { | |||
| attr->type = schema::ActivationType_HSWISH; | |||
| } else if (prim.name() == "HSigmoid") { | |||
| attr->type = schema::ActivationType_HSIGMOID; | |||
| } | |||
| this->primitive_->value.value = attr.release(); | |||
| if (this->primitive_->value.value == nullptr) { | |||
| @@ -43,8 +43,12 @@ int ActivationGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodeP | |||
| attr->type = schema::ActivationType_RELU; | |||
| } else if (prim.name() == "SigmoidGrad") { | |||
| attr->type = schema::ActivationType_SIGMOID; | |||
| } else if (prim.name() == "Relu6Grad") { | |||
| } else if (prim.name() == "ReLU6Grad") { | |||
| attr->type = schema::ActivationType_RELU6; | |||
| } else if (prim.name() == "HSigmoidGrad") { | |||
| attr->type = schema::ActivationType_HSIGMOID; | |||
| } else if (prim.name() == "HSwishGrad") { | |||
| attr->type = schema::ActivationType_HSWISH; | |||
| } | |||
| attr->alpha = 0; // alpha; | |||
| this->primitive_->value.value = attr.release(); | |||
| @@ -0,0 +1,91 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/adam.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| bool Adam::GetUseNesterov() const { return this->primitive_->value.AsAdam()->useNesterov; } | |||
| int Adam::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_Adam; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_Adam) { | |||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| auto attr = std::make_unique<schema::AdamT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->useNesterov = GetValue<bool>(prim.GetAttr("use_nesterov")); | |||
| this->primitive_->value.value = attr.release(); | |||
| if (this->primitive_->value.value == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| bool Adam::GetUseNesterov() const { return this->primitive_->value_as_Adam()->useNesterov(); } | |||
| int Adam::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| MS_ASSERT(nullptr != fbb); | |||
| auto attr = primitive->value_as_Adam(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "value_as_Adam return nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| auto val_offset = schema::CreateAdam(*fbb, attr->useNesterov()); | |||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Adam, val_offset.o); | |||
| fbb->Finish(prim_offset); | |||
| return RET_OK; | |||
| } | |||
| #endif | |||
| int Adam::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) { | |||
| if (10 != inputs.size()) { | |||
| MS_LOG(ERROR) << "Adam should have at least 8 input tensors"; | |||
| return RET_ERROR; | |||
| } | |||
| if (inputs[0]->ElementsNum() != inputs[1]->ElementsNum() || inputs[0]->ElementsNum() != inputs[2]->ElementsNum() || | |||
| inputs[0]->ElementsNum() != inputs[9]->ElementsNum() || inputs[3]->ElementsNum() != 1 || | |||
| inputs[4]->ElementsNum() != 1 || inputs[5]->ElementsNum() != 1 || inputs[6]->ElementsNum() != 1 || | |||
| inputs[7]->ElementsNum() != 1 || inputs[8]->ElementsNum() != 1) { | |||
| MS_LOG(ERROR) << "error input data size!"; | |||
| return RET_ERROR; | |||
| } | |||
| if (!outputs.empty()) { | |||
| auto *out = outputs.front(); | |||
| MS_ASSERT(out != nullptr); | |||
| out->set_data_type(inputs[0]->data_type()); | |||
| out->SetFormat(inputs[0]->GetFormat()); | |||
| out->set_shape({1}); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,47 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_OPS_ADAM_H_ | |||
| #define MINDSPORE_LITE_SRC_OPS_ADAM_H_ | |||
| #include <vector> | |||
| #include <set> | |||
| #include <cmath> | |||
| #include <memory> | |||
| #include "src/ops/primitive_c.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class Adam : public PrimitiveC { | |||
| public: | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| MS_DECLARE_PARENT(Adam, PrimitiveC); | |||
| Adam() = default; | |||
| explicit Adam(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| Adam() = default; | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||
| bool GetUseNesterov() const; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_SRC_OPS_ADAM_H_ | |||
| @@ -82,8 +82,11 @@ int AddN::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| output->set_shape(input->shape()); | |||
| // make sure all elements have the same size or 1 (broadcasting) in all dimensions | |||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||
| if (inputs.at(i)->shape() != inputs.at(0)->shape()) { | |||
| if (inputs.at(i)->shape().size() != inputs.at(0)->shape().size()) { | |||
| MS_LOG(ERROR) << "AddN inputs shape is not equal!"; | |||
| return RET_INPUT_TENSOR_ERROR; | |||
| } | |||
| @@ -93,7 +96,22 @@ int AddN::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs | |||
| } | |||
| } | |||
| output->set_shape(input->shape()); | |||
| for (size_t d = 0; d < input->shape().size(); ++d) { | |||
| int max_dim = input->shape().at(d); | |||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||
| if (inputs.at(i)->shape().at(d) > max_dim) { | |||
| max_dim = inputs.at(i)->shape().at(d); | |||
| } | |||
| } | |||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||
| if ((inputs.at(0)->shape().at(d) != max_dim) && (inputs.at(0)->shape().at(d) != 1)) { | |||
| MS_LOG(ERROR) << "AddN inputs shape is not equal!"; | |||
| return RET_INPUT_TENSOR_ERROR; | |||
| } | |||
| } | |||
| output->shape()[d] = max_dim; // set the biggest dimension in the output tensor | |||
| } | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| @@ -18,7 +18,6 @@ namespace mindspore { | |||
| namespace lite { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| float ApplyMomentum::GetGradientScale() const { return this->primitive_->value.AsApplyMomentum()->gradientScale; } | |||
| bool ApplyMomentum::GetUseLocking() const { return this->primitive_->value.AsApplyMomentum()->useLocking; } | |||
| bool ApplyMomentum::GetUseNesterov() const { return this->primitive_->value.AsApplyMomentum()->useNesterov; } | |||
| int ApplyMomentum::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| @@ -41,7 +40,6 @@ int ApplyMomentum::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePt | |||
| return RET_ERROR; | |||
| } | |||
| attr->gradientScale = GetValue<float>(prim.GetAttr("gradient_scale")); | |||
| attr->useLocking = GetValue<bool>(prim.GetAttr("use_locking")); | |||
| attr->useNesterov = GetValue<bool>(prim.GetAttr("use_nesterov")); | |||
| this->primitive_->value.value = attr.release(); | |||
| @@ -54,7 +52,6 @@ int ApplyMomentum::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePt | |||
| } | |||
| #else | |||
| float ApplyMomentum::GetGradientScale() const { return this->primitive_->value_as_ApplyMomentum()->gradientScale(); } | |||
| bool ApplyMomentum::GetUseLocking() const { return this->primitive_->value_as_ApplyMomentum()->useLocking(); } | |||
| bool ApplyMomentum::GetUseNesterov() const { return this->primitive_->value_as_ApplyMomentum()->useNesterov(); } | |||
| int ApplyMomentum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| @@ -65,7 +62,7 @@ int ApplyMomentum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatb | |||
| MS_LOG(ERROR) << "value_as_ApplyMomentum return nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| auto val_offset = schema::CreateApplyMomentum(*fbb, attr->gradientScale(), attr->useLocking(), attr->useNesterov()); | |||
| auto val_offset = schema::CreateApplyMomentum(*fbb, attr->gradientScale(), attr->useNesterov()); | |||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ApplyMomentum, val_offset.o); | |||
| fbb->Finish(prim_offset); | |||
| return RET_OK; | |||
| @@ -40,7 +40,6 @@ class ApplyMomentum : public PrimitiveC { | |||
| #endif | |||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||
| float GetGradientScale() const; | |||
| bool GetUseLocking() const; | |||
| bool GetUseNesterov() const; | |||
| }; | |||
| } // namespace lite | |||
| @@ -0,0 +1,82 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/assign.h" | |||
| #include <memory> | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| int Assign::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_Assign; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_Assign) { | |||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| this->primitive_->value.value = new (std::nothrow) schema::AssignT(); | |||
| if (this->primitive_->value.value == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| int Assign::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| MS_ASSERT(nullptr != fbb); | |||
| auto attr = primitive->value_as_Assign(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "value_as_Assign return nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| auto val_offset = schema::CreateAssign(*fbb); | |||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Assign, val_offset.o); | |||
| fbb->Finish(prim_offset); | |||
| return RET_OK; | |||
| } | |||
| #endif | |||
| int Assign::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) { | |||
| if (2 != inputs.size()) { | |||
| MS_LOG(ERROR) << "Assign should have at least 5 input tensors"; | |||
| return RET_ERROR; | |||
| } | |||
| if (inputs[0]->ElementsNum() != inputs[1]->ElementsNum()) { | |||
| MS_LOG(ERROR) << "error input data size!"; | |||
| return RET_ERROR; | |||
| } | |||
| if (!outputs.empty()) { | |||
| auto *out = outputs.front(); | |||
| MS_ASSERT(out != nullptr); | |||
| out->set_data_type(inputs[0]->data_type()); | |||
| out->SetFormat(inputs[0]->GetFormat()); | |||
| out->set_shape({1}); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,43 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_OPS_ASSIGN_H_ | |||
| #define MINDSPORE_LITE_SRC_OPS_ASSIGN_H_ | |||
| #include <vector> | |||
| #include <set> | |||
| #include <cmath> | |||
| #include "src/ops/primitive_c.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class Assign : public PrimitiveC { | |||
| public: | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| MS_DECLARE_PARENT(Assign, PrimitiveC); | |||
| Assign() = default; | |||
| explicit Assign(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| Assign() = default; | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_SRC_OPS_ASSIGN_H_ | |||
| @@ -45,8 +45,8 @@ int BNGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp | |||
| } | |||
| attr->momentum = GetValue<float>(prim.GetAttr("momentum")); | |||
| // FusedBatchNormGrad does not get this attribute | |||
| if (prim.GetAttr("eps") != nullptr) { | |||
| attr->eps = GetValue<float>(prim.GetAttr("eps")); | |||
| if (prim.GetAttr("epsilon") != nullptr) { | |||
| attr->eps = GetValue<float>(prim.GetAttr("epsilon")); | |||
| } | |||
| this->primitive_->value.value = attr; | |||
| if (this->primitive_->value.value == nullptr) { | |||
| @@ -15,7 +15,7 @@ | |||
| */ | |||
| #include "src/ops/conv2d_grad_input.h" | |||
| #include "src/ops/group_conv2d_grad_input.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| @@ -86,6 +86,9 @@ int Conv2DGradInput::UnPackAttr(const Primitive &prim, const std::vector<AnfNode | |||
| return RET_ERROR; | |||
| } | |||
| attr->group = GetValue<int>(prim.GetAttr("group")); | |||
| if (attr->group > 1) { | |||
| this->primitive_->value.type = schema::PrimitiveType_GroupConv2DGradInput; | |||
| } | |||
| auto format = GetValue<std::string>(prim.GetAttr("data_format")); | |||
| if (format == "NCHW") { | |||
| attr->format = schema::Format_NCHW; | |||
| @@ -26,6 +26,30 @@ void Exp::SetShift(float shift) { this->primitive_->value.AsExp()->shift = shift | |||
| float Exp::GetBase() const { return this->primitive_->value.AsExp()->base; } | |||
| float Exp::GetScale() const { return this->primitive_->value.AsExp()->scale; } | |||
| float Exp::GetShift() const { return this->primitive_->value.AsExp()->shift; } | |||
| int Exp::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_Exp; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_Exp) { | |||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| this->primitive_->value.value = new (std::nothrow) schema::ExpT(); | |||
| if (this->primitive_->value.value == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| int Exp::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| @@ -33,6 +33,7 @@ class Exp : public PrimitiveC { | |||
| void SetBase(float base); | |||
| void SetShift(float shift); | |||
| void SetScale(float scale); | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| Exp() = default; | |||
| @@ -0,0 +1,172 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/group_conv2d_grad_input.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| int GroupConv2DGradInput::GetFormat() const { return this->primitive_->value.AsGroupConv2DGradInput()->format; } | |||
| int GroupConv2DGradInput::GetGroup() const { return this->primitive_->value.AsGroupConv2DGradInput()->group; } | |||
| int GroupConv2DGradInput::GetChannelIn() const { return this->primitive_->value.AsGroupConv2DGradInput()->channelIn; } | |||
| int GroupConv2DGradInput::GetChannelOut() const { return this->primitive_->value.AsGroupConv2DGradInput()->channelOut; } | |||
| int GroupConv2DGradInput::GetKernelW() const { return this->primitive_->value.AsGroupConv2DGradInput()->kernelW; } | |||
| int GroupConv2DGradInput::GetKernelH() const { return this->primitive_->value.AsGroupConv2DGradInput()->kernelH; } | |||
| int GroupConv2DGradInput::GetStrideW() const { return this->primitive_->value.AsGroupConv2DGradInput()->strideW; } | |||
| int GroupConv2DGradInput::GetStrideH() const { return this->primitive_->value.AsGroupConv2DGradInput()->strideH; } | |||
| int GroupConv2DGradInput::GetPadMode() const { return this->primitive_->value.AsGroupConv2DGradInput()->padMode; } | |||
| int GroupConv2DGradInput::GetPadUp() const { return this->primitive_->value.AsGroupConv2DGradInput()->padUp; } | |||
| int GroupConv2DGradInput::GetPadDown() const { return this->primitive_->value.AsGroupConv2DGradInput()->padDown; } | |||
| int GroupConv2DGradInput::GetPadLeft() const { return this->primitive_->value.AsGroupConv2DGradInput()->padLeft; } | |||
| int GroupConv2DGradInput::GetPadRight() const { return this->primitive_->value.AsGroupConv2DGradInput()->padRight; } | |||
| int GroupConv2DGradInput::GetDilateW() const { return this->primitive_->value.AsGroupConv2DGradInput()->dilateW; } | |||
| int GroupConv2DGradInput::GetDilateH() const { return this->primitive_->value.AsGroupConv2DGradInput()->dilateH; } | |||
| bool GroupConv2DGradInput::GetHasBias() const { return this->primitive_->value.AsGroupConv2DGradInput()->hasBias; } | |||
| int GroupConv2DGradInput::GetActivationType() const { | |||
| return this->primitive_->value.AsGroupConv2DGradInput()->activationType; | |||
| } | |||
| void GroupConv2DGradInput::SetFormat(int format) { | |||
| this->primitive_->value.AsGroupConv2DGradInput()->format = (schema::Format)format; | |||
| } | |||
| void GroupConv2DGradInput::SetGroup(int group) { this->primitive_->value.AsGroupConv2DGradInput()->group = group; } | |||
| void GroupConv2DGradInput::SetChannelIn(int channel_in) { | |||
| this->primitive_->value.AsGroupConv2DGradInput()->channelIn = channel_in; | |||
| } | |||
| void GroupConv2DGradInput::SetChannelOut(int channel_out) { | |||
| this->primitive_->value.AsGroupConv2DGradInput()->channelOut = channel_out; | |||
| } | |||
| void GroupConv2DGradInput::SetKernelW(int kernel_w) { | |||
| this->primitive_->value.AsGroupConv2DGradInput()->kernelW = kernel_w; | |||
| } | |||
| void GroupConv2DGradInput::SetKernelH(int kernel_h) { | |||
| this->primitive_->value.AsGroupConv2DGradInput()->kernelH = kernel_h; | |||
| } | |||
| void GroupConv2DGradInput::SetStrideW(int stride_w) { | |||
| this->primitive_->value.AsGroupConv2DGradInput()->strideW = stride_w; | |||
| } | |||
| void GroupConv2DGradInput::SetStrideH(int stride_h) { | |||
| this->primitive_->value.AsGroupConv2DGradInput()->strideH = stride_h; | |||
| } | |||
| void GroupConv2DGradInput::SetPadMode(int pad_mode) { | |||
| this->primitive_->value.AsGroupConv2DGradInput()->padMode = (schema::PadMode)pad_mode; | |||
| } | |||
| void GroupConv2DGradInput::SetPadUp(int pad_up) { this->primitive_->value.AsGroupConv2DGradInput()->padUp = pad_up; } | |||
| void GroupConv2DGradInput::SetPadDown(int pad_down) { | |||
| this->primitive_->value.AsGroupConv2DGradInput()->padDown = pad_down; | |||
| } | |||
| void GroupConv2DGradInput::SetPadLeft(int pad_left) { | |||
| this->primitive_->value.AsGroupConv2DGradInput()->padLeft = pad_left; | |||
| } | |||
| void GroupConv2DGradInput::SetPadRight(int pad_right) { | |||
| this->primitive_->value.AsGroupConv2DGradInput()->padRight = pad_right; | |||
| } | |||
| void GroupConv2DGradInput::SetDilateW(int dilate_w) { | |||
| this->primitive_->value.AsGroupConv2DGradInput()->dilateW = dilate_w; | |||
| } | |||
| void GroupConv2DGradInput::SetDilateH(int dilate_h) { | |||
| this->primitive_->value.AsGroupConv2DGradInput()->dilateH = dilate_h; | |||
| } | |||
| void GroupConv2DGradInput::SetHasBias(bool has_bias) { | |||
| this->primitive_->value.AsGroupConv2DGradInput()->hasBias = has_bias; | |||
| } | |||
| void GroupConv2DGradInput::SetActivationType(int activation_type) { | |||
| this->primitive_->value.AsGroupConv2DGradInput()->activationType = (schema::ActivationType)activation_type; | |||
| } | |||
| #else | |||
| int GroupConv2DGradInput::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| MS_ASSERT(nullptr != fbb); | |||
| auto attr = primitive->value_as_GroupConv2DGradInput(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "value_as_GroupConv2DGradInput return nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| auto val_offset = schema::CreateGroupConv2DGradInput( | |||
| *fbb, attr->format(), attr->group(), attr->channelIn(), attr->channelOut(), attr->kernelW(), attr->kernelH(), | |||
| attr->strideW(), attr->strideH(), attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(), | |||
| attr->padRight(), attr->dilateW(), attr->dilateH(), attr->hasBias(), attr->activationType()); | |||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_GroupConv2DGradInput, val_offset.o); | |||
| fbb->Finish(prim_offset); | |||
| return RET_OK; | |||
| } | |||
| int GroupConv2DGradInput::GetFormat() const { return this->primitive_->value_as_GroupConv2DGradInput()->format(); } | |||
| int GroupConv2DGradInput::GetGroup() const { return this->primitive_->value_as_GroupConv2DGradInput()->group(); } | |||
| int GroupConv2DGradInput::GetChannelIn() const { | |||
| return this->primitive_->value_as_GroupConv2DGradInput()->channelIn(); | |||
| } | |||
| int GroupConv2DGradInput::GetChannelOut() const { | |||
| return this->primitive_->value_as_GroupConv2DGradInput()->channelOut(); | |||
| } | |||
| int GroupConv2DGradInput::GetKernelW() const { return this->primitive_->value_as_GroupConv2DGradInput()->kernelW(); } | |||
| int GroupConv2DGradInput::GetKernelH() const { return this->primitive_->value_as_GroupConv2DGradInput()->kernelH(); } | |||
| int GroupConv2DGradInput::GetStrideW() const { return this->primitive_->value_as_GroupConv2DGradInput()->strideW(); } | |||
| int GroupConv2DGradInput::GetStrideH() const { return this->primitive_->value_as_GroupConv2DGradInput()->strideH(); } | |||
| int GroupConv2DGradInput::GetPadMode() const { return this->primitive_->value_as_GroupConv2DGradInput()->padMode(); } | |||
| int GroupConv2DGradInput::GetPadUp() const { return this->primitive_->value_as_GroupConv2DGradInput()->padUp(); } | |||
| int GroupConv2DGradInput::GetPadDown() const { return this->primitive_->value_as_GroupConv2DGradInput()->padDown(); } | |||
| int GroupConv2DGradInput::GetPadLeft() const { return this->primitive_->value_as_GroupConv2DGradInput()->padLeft(); } | |||
| int GroupConv2DGradInput::GetPadRight() const { return this->primitive_->value_as_GroupConv2DGradInput()->padRight(); } | |||
| int GroupConv2DGradInput::GetDilateW() const { return this->primitive_->value_as_GroupConv2DGradInput()->dilateW(); } | |||
| int GroupConv2DGradInput::GetDilateH() const { return this->primitive_->value_as_GroupConv2DGradInput()->dilateH(); } | |||
| bool GroupConv2DGradInput::GetHasBias() const { return this->primitive_->value_as_GroupConv2DGradInput()->hasBias(); } | |||
| int GroupConv2DGradInput::GetActivationType() const { | |||
| return this->primitive_->value_as_GroupConv2DGradInput()->activationType(); | |||
| } | |||
| #endif | |||
| int GroupConv2DGradInput::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) { | |||
| if (3 != inputs.size()) { | |||
| MS_LOG(ERROR) << "Conv2d Grad Input should have 3 inputs"; | |||
| return RET_ERROR; | |||
| } | |||
| if (1 != outputs.size()) { | |||
| MS_LOG(ERROR) << "Conv2d Grad input should have one output"; | |||
| return RET_ERROR; | |||
| } | |||
| auto *in0 = inputs.at(0); | |||
| auto *in = inputs.at(2); | |||
| MS_ASSERT(out != nullptr); | |||
| std::vector<int> output_shape; | |||
| int *out_shape = reinterpret_cast<int *>(in->MutableData()); | |||
| int new_size = in->ElementsNum(); | |||
| if (in0->GetFormat() == in->GetFormat()) { | |||
| for (int i = 0; i < new_size; i++) output_shape.push_back(out_shape[i]); | |||
| } else { | |||
| if ((in0->GetFormat() == schema::Format_NHWC) && (in->GetFormat() == schema::Format_NCHW)) { | |||
| output_shape.push_back(out_shape[0]); | |||
| output_shape.push_back(out_shape[2]); | |||
| output_shape.push_back(out_shape[3]); | |||
| output_shape.push_back(out_shape[1]); | |||
| } else { | |||
| MS_LOG(ERROR) << "Shape covnert is not supported"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| auto *out = outputs.at(0); | |||
| MS_ASSERT(out != nullptr); | |||
| out->set_shape(output_shape); | |||
| out->set_data_type(in0->data_type()); | |||
| out->SetFormat(in0->GetFormat()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,79 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_OPS_GROUP_CONV2D_GRAD_INPUT_H_ | |||
| #define MINDSPORE_LITE_SRC_OPS_GROUP_CONV2D_GRAD_INPUT_H_ | |||
| #include <vector> | |||
| #include <set> | |||
| #include <cmath> | |||
| #include <memory> | |||
| #include <string> | |||
| #include "src/ops/primitive_c.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class GroupConv2DGradInput : public PrimitiveC { | |||
| public: | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| MS_DECLARE_PARENT(GroupConv2DGradInput, PrimitiveC); | |||
| GroupConv2DGradInput() = default; | |||
| explicit GroupConv2DGradInput(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| void SetFormat(int format); | |||
| void SetGroup(int group); | |||
| void SetChannelIn(int channel_in); | |||
| void SetChannelOut(int channel_out); | |||
| void SetKernelW(int kernel_w); | |||
| void SetKernelH(int kernel_h); | |||
| void SetStrideW(int stride_w); | |||
| void SetStrideH(int stride_h); | |||
| void SetPadMode(int pad_mode); | |||
| void SetPadUp(int pad_up); | |||
| void SetPadDown(int pad_down); | |||
| void SetPadLeft(int pad_left); | |||
| void SetPadRight(int pad_right); | |||
| void SetDilateW(int dilate_w); | |||
| void SetDilateH(int dilate_h); | |||
| void SetHasBias(bool has_bias); | |||
| void SetActivationType(int activation_type); | |||
| #else | |||
| GroupConv2DGradInput() = default; | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||
| int GetFormat() const; | |||
| int GetGroup() const; | |||
| int GetChannelIn() const; | |||
| int GetChannelOut() const; | |||
| int GetKernelW() const; | |||
| int GetKernelH() const; | |||
| int GetStrideW() const; | |||
| int GetStrideH() const; | |||
| int GetPadMode() const; | |||
| int GetPadUp() const; | |||
| int GetPadDown() const; | |||
| int GetPadLeft() const; | |||
| int GetPadRight() const; | |||
| int GetDilateW() const; | |||
| int GetDilateH() const; | |||
| bool GetHasBias() const; | |||
| int GetActivationType() const; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_SRC_OPS_GROUP_CONV2D_GRAD_INPUT_H_ | |||
| @@ -18,7 +18,31 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifndef PRIMITIVE_WRITEABLE | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| int Neg::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_Neg; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_Neg) { | |||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| this->primitive_->value.value = new (std::nothrow) schema::NegT(); | |||
| if (this->primitive_->value.value == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| int Neg::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(primitive != nullptr); | |||
| MS_ASSERT(fbb != nullptr); | |||
| @@ -31,6 +31,7 @@ class Neg : public ArithmeticSelf { | |||
| MS_DECLARE_PARENT(Neg, ArithmeticSelf); | |||
| Neg() = default; | |||
| explicit Neg(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| Neg() = default; | |||
| @@ -23,6 +23,37 @@ int OneHot::GetAxis() const { return this->primitive_->value.AsOneHot()->axis; } | |||
| void OneHot::SetAxis(int axis) { this->primitive_->value.AsOneHot()->axis = axis; } | |||
| int OneHot::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_OneHot; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_OneHot) { | |||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| auto attr = new (std::nothrow) schema::OneHotT(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->axis = -1; | |||
| if (prim.GetAttr("axis") != nullptr) { | |||
| attr->axis = GetValue<int>(prim.GetAttr("axis")); | |||
| } | |||
| this->primitive_->value.value = attr; | |||
| if (this->primitive_->value.value == nullptr) { | |||
| MS_LOG(ERROR) << "primitive value is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| int OneHot::GetAxis() const { return this->primitive_->value_as_OneHot()->axis(); } | |||
| @@ -32,7 +32,7 @@ class OneHot : public PrimitiveC { | |||
| OneHot() = default; | |||
| explicit OneHot(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| void SetAxis(int axis); | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| OneHot() = default; | |||
| @@ -144,6 +144,7 @@ | |||
| #include "src/ops/pooling_grad.h" | |||
| #include "src/ops/conv2d_grad_filter.h" | |||
| #include "src/ops/conv2d_grad_input.h" | |||
| #include "src/ops/group_conv2d_grad_input.h" | |||
| #include "src/ops/power_grad.h" | |||
| #include "src/ops/softmax_cross_entropy.h" | |||
| #include "src/ops/bn_grad.h" | |||
| @@ -152,6 +153,8 @@ | |||
| #include "src/ops/flatten_grad.h" | |||
| #include "src/ops/log_grad.h" | |||
| #include "src/ops/sgd.h" | |||
| #include "src/ops/adam.h" | |||
| #include "src/ops/assign.h" | |||
| #endif | |||
| namespace mindspore { | |||
| @@ -367,7 +370,7 @@ std::shared_ptr<PrimitiveC> NewPrimitiveC(const Primitive &prim, const std::vect | |||
| std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std::vector<AnfNodePtr> &inputs, | |||
| const schema::QuantType &quantType) { | |||
| const auto &op_type = prim.name(); | |||
| if (op_type == "ReLU" || op_type == "ReLU6" || op_type == "Sigmoid") { | |||
| if (op_type == "ReLU" || op_type == "ReLU6" || op_type == "Sigmoid" || op_type == "HSwish" || op_type == "HSigmoid") { | |||
| return NewPrimitiveC<Activation>(prim, inputs, quantType); | |||
| } else if (op_type == "AddN") { | |||
| return NewPrimitiveC<AddN>(prim, inputs, quantType); | |||
| @@ -413,6 +416,10 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||
| return NewPrimitiveC<Reduce>(prim, inputs, quantType); | |||
| } else if (op_type == "Reshape") { | |||
| return NewPrimitiveC<Reshape>(prim, inputs, quantType); | |||
| } else if (op_type == "Slice") { | |||
| return NewPrimitiveC<Slice>(prim, inputs, quantType); | |||
| } else if (op_type == "Squeeze") { | |||
| return NewPrimitiveC<Squeeze>(prim, inputs, quantType); | |||
| } else if (op_type == "TensorAdd") { | |||
| return NewPrimitiveC<Add>(prim, inputs, quantType); | |||
| } else if (op_type == "Transpose") { | |||
| @@ -421,6 +428,10 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||
| return NewPrimitiveC<Elu>(prim, inputs, quantType); | |||
| } else if (op_type == "Log") { | |||
| return NewPrimitiveC<Log>(prim, inputs, quantType); | |||
| } else if (op_type == "Exp") { | |||
| return NewPrimitiveC<Exp>(prim, inputs, quantType); | |||
| } else if (op_type == "Neg") { | |||
| return NewPrimitiveC<Neg>(prim, inputs, quantType); | |||
| } else if (op_type == "DeConv2D") { | |||
| return NewPrimitiveC<DeConv2D>(prim, inputs, quantType); | |||
| } else if (op_type == "tuple_getitem") { | |||
| @@ -435,6 +446,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||
| return NewPrimitiveC<Maximum>(prim, inputs, quantType); | |||
| } else if (op_type == "Split") { | |||
| return NewPrimitiveC<Split>(prim, inputs, quantType); | |||
| } else if (op_type == "OneHot") { | |||
| return NewPrimitiveC<OneHot>(prim, inputs, quantType); | |||
| #ifdef SUPPORT_TRAIN | |||
| } else if (op_type == "SoftmaxCrossEntropyWithLogits") { | |||
| @@ -445,7 +458,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||
| return NewPrimitiveC<ApplyMomentum>(prim, inputs, quantType); | |||
| } else if (op_type == "Depend") { | |||
| return NewPrimitiveC<Depend>(prim, inputs, quantType); | |||
| } else if ((op_type == "ReluGrad" || op_type == "Relu6Grad" || op_type == "SigmoidGrad")) { | |||
| } else if ((op_type == "ReluGrad" || op_type == "ReLU6Grad" || op_type == "SigmoidGrad" || | |||
| op_type == "HSigmoidGrad" || op_type == "HSwishGrad")) { | |||
| return NewPrimitiveC<ActivationGrad>(prim, inputs, quantType); | |||
| } else if ((op_type == "MaxPoolGrad") || (op_type == "MeanPoolGrad")) { | |||
| return NewPrimitiveC<PoolingGrad>(prim, inputs, quantType); | |||
| @@ -465,6 +479,10 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||
| return NewPrimitiveC<PowerGrad>(prim, inputs, quantType); | |||
| } else if (op_type == "SGD") { | |||
| return NewPrimitiveC<Sgd>(prim, inputs, quantType); | |||
| } else if (op_type == "Adam") { | |||
| return NewPrimitiveC<Adam>(prim, inputs, quantType); | |||
| } else if (op_type == "Assign") { | |||
| return NewPrimitiveC<Assign>(prim, inputs, quantType); | |||
| #else | |||
| } else if (op_type == "Conv2DBackpropInput") { | |||
| return NewPrimitiveC<DeConv2D>(prim, inputs, quantType); | |||
| @@ -686,6 +704,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { | |||
| return new Dropout(primitive); | |||
| case schema::PrimitiveType_Neg: | |||
| return new Neg(primitive); | |||
| case schema::PrimitiveType_RealDiv: | |||
| return new RealDiv(primitive); | |||
| case schema::PrimitiveType_LshProjection: | |||
| return new LshProjection(primitive); | |||
| case schema::PrimitiveType_HashtableLookup: | |||
| @@ -710,6 +730,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { | |||
| return new Conv2DGradFilter(primitive); | |||
| case schema::PrimitiveType_Conv2DGradInput: | |||
| return new Conv2DGradInput(primitive); | |||
| case schema::PrimitiveType_GroupConv2DGradInput: | |||
| return new GroupConv2DGradInput(primitive); | |||
| case schema::PrimitiveType_BiasGrad: | |||
| return new BiasGrad(primitive); | |||
| case schema::PrimitiveType_ApplyMomentum: | |||
| @@ -738,8 +760,11 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { | |||
| return new LogGrad(primitive); | |||
| case schema::PrimitiveType_Sgd: | |||
| return new Sgd(primitive); | |||
| case schema::PrimitiveType_Adam: | |||
| return new Adam(primitive); | |||
| case schema::PrimitiveType_Assign: | |||
| return new Assign(primitive); | |||
| #endif | |||
| default: | |||
| MS_LOG(ERROR) << "Unsupported primitive type in Create : " << schema::EnumNamePrimitiveType(op_type); | |||
| break; | |||
| @@ -958,6 +983,8 @@ PrimitiveC *PrimitiveC::Create(const schema::Primitive *primitive) { | |||
| return NewPrimitiveC<DetectionPostProcess>(primitive); | |||
| case schema::PrimitiveType_Dropout: | |||
| return NewPrimitiveC<Dropout>(primitive); | |||
| case schema::PrimitiveType_RealDiv: | |||
| return NewPrimitiveC<RealDiv>(primitive); | |||
| case schema::PrimitiveType_LshProjection: | |||
| return NewPrimitiveC<LshProjection>(primitive); | |||
| case schema::PrimitiveType_HashtableLookup: | |||
| @@ -982,6 +1009,8 @@ PrimitiveC *PrimitiveC::Create(const schema::Primitive *primitive) { | |||
| return NewPrimitiveC<Conv2DGradFilter>(primitive); | |||
| case schema::PrimitiveType_Conv2DGradInput: | |||
| return NewPrimitiveC<Conv2DGradInput>(primitive); | |||
| case schema::PrimitiveType_GroupConv2DGradInput: | |||
| return NewPrimitiveC<GroupConv2DGradInput>(primitive); | |||
| case schema::PrimitiveType_BiasGrad: | |||
| return NewPrimitiveC<BiasGrad>(primitive); | |||
| case schema::PrimitiveType_ApplyMomentum: | |||
| @@ -1004,6 +1033,10 @@ PrimitiveC *PrimitiveC::Create(const schema::Primitive *primitive) { | |||
| return NewPrimitiveC<LogGrad>(primitive); | |||
| case schema::PrimitiveType_Sgd: | |||
| return NewPrimitiveC<Sgd>(primitive); | |||
| case schema::PrimitiveType_Adam: | |||
| return NewPrimitiveC<Adam>(primitive); | |||
| case schema::PrimitiveType_Assign: | |||
| return NewPrimitiveC<Assign>(primitive); | |||
| #endif | |||
| default: | |||
| MS_LOG(ERROR) << "Unsupported primitive type in Create : " << schema::EnumNamePrimitiveType(op_type); | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H_ | |||
| #define MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H_ | |||
| #ifndef MINDSPORE_LITE_SRC_OPS_PRIMITIVE_C_H_ | |||
| #define MINDSPORE_LITE_SRC_OPS_PRIMITIVE_C_H_ | |||
| #include <string> | |||
| #include <set> | |||
| #include <vector> | |||
| @@ -48,7 +48,9 @@ constexpr int kAnfPopulaterTwo = 2; | |||
| constexpr int kAnfPopulaterThree = 3; | |||
| static std::map<std::string, schema::ActivationType> kActivationTypeMap{{"ReLU", schema::ActivationType_RELU}, | |||
| {"ReLU6", schema::ActivationType_RELU6}, | |||
| {"Sigmoid", schema::ActivationType_SIGMOID}}; | |||
| {"Sigmoid", schema::ActivationType_SIGMOID}, | |||
| {"HSwish", schema::ActivationType_HSWISH}, | |||
| {"HSigmoid", schema::ActivationType_HSIGMOID}}; | |||
| class PrimitiveC : public mindspore::Primitive { | |||
| public: | |||
| // Argument primitive is delivered into PrimitiveC and will be deleted in ~PrimitiveC(). | |||
| @@ -213,4 +215,4 @@ class PrimitiveC { | |||
| #endif | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H_ | |||
| #endif // MINDSPORE_LITE_SRC_OPS_PRIMITIVE_C_H_ | |||
| @@ -44,7 +44,14 @@ int RealDiv::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in | |||
| } | |||
| #else | |||
| int RealDiv::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| MS_ASSERT(nullptr != fbb); | |||
| auto val_offset = schema::CreateRank(*fbb); | |||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_RealDiv, val_offset.o); | |||
| fbb->Finish(prim_offset); | |||
| return RET_OK; | |||
| } | |||
| #endif | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -36,10 +36,7 @@ class RealDiv : public Arithmetic { | |||
| #else | |||
| RealDiv() = default; | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override { | |||
| return RET_ERROR; | |||
| } | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| }; | |||
| } // namespace lite | |||
| @@ -35,6 +35,66 @@ void Slice::SetFormat(int format) { this->primitive_->value.AsSlice()->format = | |||
| void Slice::SetBegin(const std::vector<int> &begin) { this->primitive_->value.AsSlice()->begin = begin; } | |||
| void Slice::SetSize(const std::vector<int> &size) { this->primitive_->value.AsSlice()->size = size; } | |||
| int Slice::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_Slice; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_Slice) { | |||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| auto attr = new (std::nothrow) schema::SliceT(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| if (inputs.size() >= kAnfPopulaterThree) { | |||
| auto beginNode = inputs[kAnfPopulaterOne]; | |||
| MS_ASSERT(beginNode != nullptr); | |||
| if (beginNode->isa<ValueNode>()) { | |||
| auto valueNode = beginNode->cast<ValueNodePtr>(); | |||
| MS_ASSERT(valueNode != nullptr); | |||
| auto value = valueNode->value(); | |||
| MS_ASSERT(value != nullptr); | |||
| if (value->isa<ValueTuple>()) { | |||
| auto valTuplPtr = dyn_cast<ValueTuple>(value); | |||
| MS_ASSERT(valTuplPtr != nullptr); | |||
| for (size_t i = 0; i < valTuplPtr->size(); i++) { | |||
| auto elem = dyn_cast<Int32Imm>((*valTuplPtr)[i]); | |||
| MS_ASSERT(elem != nullptr); | |||
| attr->begin.emplace_back(elem->value()); | |||
| } | |||
| } | |||
| } | |||
| auto sizeNode = inputs[kAnfPopulaterTwo]; | |||
| MS_ASSERT(sizeNode != nullptr); | |||
| if (sizeNode->isa<ValueNode>()) { | |||
| auto valueNode = sizeNode->cast<ValueNodePtr>(); | |||
| MS_ASSERT(valueNode != nullptr); | |||
| auto value = valueNode->value(); | |||
| MS_ASSERT(value != nullptr); | |||
| if (value->isa<ValueTuple>()) { | |||
| auto valTuplPtr = dyn_cast<ValueTuple>(value); | |||
| MS_ASSERT(valTuplPtr != nullptr); | |||
| for (size_t i = 0; i < valTuplPtr->size(); i++) { | |||
| auto elem = dyn_cast<Int32Imm>((*valTuplPtr)[i]); | |||
| MS_ASSERT(elem != nullptr); | |||
| attr->size.emplace_back(elem->value()); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| this->primitive_->value.value = attr; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| int Slice::GetFormat() const { return this->primitive_->value_as_Slice()->format(); } | |||
| @@ -46,10 +106,12 @@ std::vector<int> Slice::GetSize() const { | |||
| auto fb_vector = this->primitive_->value_as_Slice()->size(); | |||
| return std::vector<int>(fb_vector->begin(), fb_vector->end()); | |||
| } | |||
| std::vector<int> Slice::GetAxes() const { | |||
| auto fb_vector = this->primitive_->value_as_Slice()->axes(); | |||
| return std::vector<int>(fb_vector->begin(), fb_vector->end()); | |||
| } | |||
| int Slice::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| MS_ASSERT(nullptr != fbb); | |||
| @@ -90,7 +152,7 @@ std::vector<int> Slice::GetPostProcessBegin() const { return this->begin; } | |||
| std::vector<int> Slice::GetPostProcessSize() const { return this->size; } | |||
| int Slice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) { | |||
| MS_ASSERT(this->primitive_ != nullptr); | |||
| if (inputs.size() != kSliceInputNum || outputs.size() != kSliceOutputNum) { | |||
| if (inputs.size() < kSliceInputNum || outputs.size() != kSliceOutputNum) { | |||
| MS_LOG(ERROR) << "input size:" << inputs.size() << ",output size:" << outputs.size(); | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_SLICE_H_ | |||
| #define LITE_MINDSPORE_LITE_C_OPS_SLICE_H_ | |||
| #ifndef MINDSPORE_LITE_SRC_OPS_SLICE_H_ | |||
| #define MINDSPORE_LITE_SRC_OPS_SLICE_H_ | |||
| #include <vector> | |||
| #include <set> | |||
| @@ -35,6 +35,7 @@ class Slice : public PrimitiveC { | |||
| void SetFormat(int format); | |||
| void SetBegin(const std::vector<int> &begin); | |||
| void SetSize(const std::vector<int> &size); | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| Slice() = default; | |||
| @@ -56,4 +57,4 @@ class Slice : public PrimitiveC { | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_MINDSPORE_LITE_C_OPS_SLICE_H_ | |||
| #endif // MINDSPORE_LITE_SRC_OPS_SLICE_H_ | |||
| @@ -23,6 +23,35 @@ std::vector<int> Squeeze::GetAxis() const { return this->primitive_->value.AsSqu | |||
| void Squeeze::SetAxis(const std::vector<int> &axis) { this->primitive_->value.AsSqueeze()->axis = axis; } | |||
| int Squeeze::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_Squeeze; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_Squeeze) { | |||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| auto attr = new (std::nothrow) schema::SqueezeT(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->axis = GetValue<std::vector<int>>(prim.GetAttr("axis")); | |||
| this->primitive_->value.value = attr; | |||
| if (this->primitive_->value.value == nullptr) { | |||
| MS_LOG(ERROR) << "primitive value is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| std::vector<int> Squeeze::GetAxis() const { | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_SQUEEZE_H_ | |||
| #define LITE_MINDSPORE_LITE_C_OPS_SQUEEZE_H_ | |||
| #ifndef MINDSPORE_LITE_SRC_OPS_SQUEEZE_H_ | |||
| #define MINDSPORE_LITE_SRC_OPS_SQUEEZE_H_ | |||
| #include <vector> | |||
| #include <set> | |||
| @@ -33,6 +33,7 @@ class Squeeze : public PrimitiveC { | |||
| Squeeze() = default; | |||
| explicit Squeeze(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| void SetAxis(const std::vector<int> &axis); | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| Squeeze() = default; | |||
| @@ -45,4 +46,4 @@ class Squeeze : public PrimitiveC { | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_MINDSPORE_LITE_C_OPS_SQUEEZE_H_ | |||
| #endif // MINDSPORE_LITE_SRC_OPS_SQUEEZE_H_ | |||
| @@ -24,7 +24,7 @@ std::vector<int> Tile::GetMultiples() const { return this->primitive_->value.AsT | |||
| void Tile::SetMultiples(const std::vector<int> &multiples) { this->primitive_->value.AsTile()->multiples = multiples; } | |||
| std::vector<int> Tile::GetDims() const { return this->primitive_->value.AsTile()->multiples; } | |||
| std::vector<int> Tile::GetDims() const { return this->primitive_->value.AsTile()->dims; } | |||
| void Tile::SetDims(const std::vector<int> &dims) { this->primitive_->value.AsTile()->dims = dims; } | |||
| @@ -42,11 +42,32 @@ int Tile::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &input | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| this->primitive_->value.value = new (std::nothrow) schema::TileT(); | |||
| if (this->primitive_->value.value == nullptr) { | |||
| auto attr = new (std::nothrow) schema::TileT(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| if (inputs.size() == kAnfPopulaterTwo) { | |||
| auto inputNode = inputs[kAnfPopulaterOne]; | |||
| MS_ASSERT(inputNode != nullptr); | |||
| if (inputNode->isa<ValueNode>()) { | |||
| auto valueNode = inputNode->cast<ValueNodePtr>(); | |||
| MS_ASSERT(valueNode != nullptr); | |||
| auto value = valueNode->value(); | |||
| MS_ASSERT(value != nullptr); | |||
| if (value->isa<ValueTuple>()) { | |||
| auto valTuplPtr = dyn_cast<ValueTuple>(value); | |||
| MS_ASSERT(valTuplPtr != nullptr); | |||
| for (size_t i = 0; i < valTuplPtr->size(); i++) { | |||
| auto elem = dyn_cast<Int32Imm>((*valTuplPtr)[i]); | |||
| MS_ASSERT(elem != nullptr); | |||
| attr->multiples.emplace_back(elem->value()); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| this->primitive_->value.value = attr; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -103,15 +124,19 @@ int Tile::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output | |||
| MS_ASSERT(tile_prim != nullptr); | |||
| std::vector<int> out_shape; | |||
| std::vector<int> multiples; | |||
| for (size_t i = 0; i < GetMultiples().size(); ++i) { | |||
| multiples.push_back(GetMultiples()[i]); | |||
| std::vector<int> multiples = GetMultiples(); | |||
| const size_t in_dims = input->shape().size(); | |||
| const size_t delta_dims = in_dims - multiples.size(); | |||
| size_t i = 0; | |||
| for (; i < delta_dims; ++i) { | |||
| int tmp = input->shape()[i]; | |||
| out_shape.push_back(tmp); | |||
| } | |||
| for (size_t i = 0; i < input->shape().size(); ++i) { | |||
| int tmp = input->shape()[i] * multiples[i]; | |||
| for (; i < in_dims; ++i) { | |||
| int tmp = input->shape()[i] * (multiples[i - delta_dims]); | |||
| out_shape.push_back(tmp); | |||
| } | |||
| output->set_shape(out_shape); | |||
| return RET_OK; | |||
| } | |||
| @@ -43,6 +43,7 @@ | |||
| #include "src/ops/add.h" | |||
| #include "src/ops/sub.h" | |||
| #include "src/ops/div.h" | |||
| #include "src/ops/real_div.h" | |||
| #include "src/ops/bias_add.h" | |||
| #include "src/ops/expand_dims.h" | |||
| #include "src/ops/full_connection.h" | |||
| @@ -1680,6 +1681,7 @@ PopulateParameterRegistry::PopulateParameterRegistry() { | |||
| populate_parameter_funcs_[schema::PrimitiveType_Add] = PopulateArithmetic; | |||
| populate_parameter_funcs_[schema::PrimitiveType_Sub] = PopulateArithmetic; | |||
| populate_parameter_funcs_[schema::PrimitiveType_Div] = PopulateArithmetic; | |||
| populate_parameter_funcs_[schema::PrimitiveType_RealDiv] = PopulateArithmetic; | |||
| populate_parameter_funcs_[schema::PrimitiveType_LogicalAnd] = PopulateArithmetic; | |||
| populate_parameter_funcs_[schema::PrimitiveType_LogicalOr] = PopulateArithmetic; | |||
| populate_parameter_funcs_[schema::PrimitiveType_Equal] = PopulateArithmetic; | |||
| @@ -24,6 +24,7 @@ using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_OK; | |||
| using mindspore::schema::ActivationType_HSIGMOID; | |||
| using mindspore::schema::ActivationType_HSWISH; | |||
| using mindspore::schema::ActivationType_LEAKY_RELU; | |||
| using mindspore::schema::ActivationType_RELU; | |||
| @@ -57,6 +58,8 @@ int ActivationCPUKernel::DoActivation(int task_id) { | |||
| error_code = Tanh(input_addr + stride * task_id, count, output_addr + stride * task_id); | |||
| } else if (type_ == schema::ActivationType_HSWISH) { | |||
| error_code = HSwish(input_addr + stride * task_id, count, output_addr + stride * task_id); | |||
| } else if (type_ == schema::ActivationType_HSIGMOID) { | |||
| error_code = HSigmoid(input_addr + stride * task_id, count, output_addr + stride * task_id); | |||
| } else if (type_ == schema::ActivationType_HARD_TANH) { | |||
| error_code = HardTanh(input_addr + stride * task_id, count, output_addr + stride * task_id, min_val_, max_val_); | |||
| } else { | |||
| @@ -60,14 +60,34 @@ int AddNCPUKernel::Run() { | |||
| MS_LOG(ERROR) << "Prepare fail!ret: " << ret; | |||
| return ret; | |||
| } | |||
| elements_num_ = in_tensors_[0]->ElementsNum(); | |||
| elements_num_ = out_tensors_[0]->ElementsNum(); | |||
| auto input0_data = reinterpret_cast<float *>(in_tensors_[0]->MutableData()); | |||
| auto input1_data = reinterpret_cast<float *>(in_tensors_[1]->MutableData()); | |||
| auto output_data = reinterpret_cast<float *>(out_tensors_[0]->MutableData()); | |||
| if (static_cast<int>(elements_num_) < op_parameter_->thread_num_) { | |||
| ElementAdd(input0_data, input1_data, output_data, elements_num_); | |||
| if (in_tensors_[0]->shape() == in_tensors_[1]->shape()) { | |||
| ElementAdd(input0_data, input1_data, output_data, elements_num_); | |||
| } else { | |||
| ArithmeticParameter param; | |||
| param.in_elements_num0_ = in_tensors_[0]->ElementsNum(); | |||
| param.in_elements_num1_ = in_tensors_[1]->ElementsNum(); | |||
| param.out_elements_num_ = out_tensors_[0]->ElementsNum(); | |||
| param.broadcasting_ = true; | |||
| ElementOptAdd(input0_data, input1_data, output_data, elements_num_, ¶m); | |||
| } | |||
| for (size_t i = 2; i < in_tensors_.size(); ++i) { | |||
| ElementAdd(reinterpret_cast<float *>(in_tensors_[i]->MutableData()), output_data, output_data, elements_num_); | |||
| if (in_tensors_[i]->shape() == out_tensors_[0]->shape()) { | |||
| ElementAdd(reinterpret_cast<float *>(in_tensors_[i]->MutableData()), output_data, output_data, elements_num_); | |||
| } else { | |||
| ArithmeticParameter param; | |||
| param.in_elements_num0_ = in_tensors_[i]->ElementsNum(); | |||
| param.in_elements_num1_ = out_tensors_[0]->ElementsNum(); | |||
| param.out_elements_num_ = out_tensors_[0]->ElementsNum(); | |||
| param.broadcasting_ = true; | |||
| ElementOptAdd(reinterpret_cast<float *>(in_tensors_[i]->MutableData()), output_data, output_data, elements_num_, | |||
| ¶m); | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -104,6 +104,7 @@ int ArithmeticCPUKernel::ReSize() { | |||
| } | |||
| break; | |||
| case PrimitiveType_Div: | |||
| case PrimitiveType_RealDiv: | |||
| switch (arithmeticParameter_->activation_type_) { | |||
| case schema::ActivationType_RELU: | |||
| arithmeticParameter_->broadcasting_ = false; | |||
| @@ -304,6 +305,7 @@ REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Add, CpuArithmeticFp32KernelC | |||
| REG_KERNEL(kCPU, kNumberTypeInt, PrimitiveType_Add, CpuArithmeticFp32KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Sub, CpuArithmeticFp32KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Div, CpuArithmeticFp32KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_RealDiv, CpuArithmeticFp32KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LogicalAnd, CpuArithmeticFp32KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LogicalOr, CpuArithmeticFp32KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Maximum, CpuArithmeticFp32KernelCreator) | |||
| @@ -37,6 +37,7 @@ using mindspore::schema::PrimitiveType_Maximum; | |||
| using mindspore::schema::PrimitiveType_Minimum; | |||
| using mindspore::schema::PrimitiveType_Mul; | |||
| using mindspore::schema::PrimitiveType_NotEqual; | |||
| using mindspore::schema::PrimitiveType_RealDiv; | |||
| using mindspore::schema::PrimitiveType_SquaredDifference; | |||
| using mindspore::schema::PrimitiveType_Sub; | |||
| @@ -99,6 +100,7 @@ class ArithmeticCPUKernel : public LiteKernel { | |||
| } | |||
| break; | |||
| case PrimitiveType_Div: | |||
| case PrimitiveType_RealDiv: | |||
| switch (arithmeticParameter_->activation_type_) { | |||
| case schema::ActivationType_RELU: | |||
| arithmetic_run_ = ElementDivRelu; | |||
| @@ -48,7 +48,7 @@ int ActivationGradCPUKernel::DoActivation(int task_id) { | |||
| auto output_addr = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | |||
| int length = in_tensors_.at(0)->ElementsNum(); | |||
| int stride = UP_DIV(length, thread_count_); | |||
| int stride = UP_DIV(length, 1); | |||
| int count = MSMIN(stride, length - stride * task_id); | |||
| auto error_code = RET_OK; | |||
| @@ -63,8 +63,9 @@ int ActivationGradCPUKernel::DoActivation(int task_id) { | |||
| error_code = LReluGrad(yt_addr + stride * task_id, input_addr + stride * task_id, count, | |||
| output_addr + stride * task_id, param_act_grad_->alpha_); | |||
| } else if (param_act_grad_->type_ == schema::ActivationType_SIGMOID) { | |||
| // Sigmoid gets the input tensors in reverse order! | |||
| error_code = | |||
| SigmoidGrad(yt_addr + stride * task_id, input_addr + stride * task_id, count, output_addr + stride * task_id); | |||
| SigmoidGrad(input_addr + stride * task_id, yt_addr + stride * task_id, count, output_addr + stride * task_id); | |||
| } else if (param_act_grad_->type_ == schema::ActivationType_TANH) { | |||
| error_code = | |||
| TanhGrad(yt_addr + stride * task_id, input_addr + stride * task_id, count, output_addr + stride * task_id); | |||
| @@ -27,7 +27,7 @@ class ActivationGradCPUKernel : public LiteKernel { | |||
| explicit ActivationGradCPUKernel(OpParameter *param, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : LiteKernel(param, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) { | |||
| : LiteKernel(param, inputs, outputs, ctx, primitive) { | |||
| param_act_grad_ = reinterpret_cast<ActivationParameter *>(param); | |||
| } | |||
| ~ActivationGradCPUKernel() override = default; | |||
| @@ -38,7 +38,6 @@ class ActivationGradCPUKernel : public LiteKernel { | |||
| int DoActivation(int task_id); | |||
| private: | |||
| int thread_count_; | |||
| ActivationParameter *param_act_grad_; | |||
| }; | |||
| @@ -0,0 +1,118 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/runtime/kernel/arm/fp32_grad/adam.h" | |||
| #include <cmath> | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| #include "src/runtime/kernel/arm/fp32/nchw2nhwc.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_OK; | |||
| using mindspore::schema::PrimitiveType_Adam; | |||
| namespace mindspore::kernel { | |||
| int AdamCPUKernel::ReSize() { return RET_OK; } | |||
| int AdamCPUKernel::Execute(int task_id) { | |||
| auto weight = reinterpret_cast<float *>(in_tensors_[0]->MutableData()); | |||
| auto m = reinterpret_cast<float *>(in_tensors_[1]->MutableData()); | |||
| auto v = reinterpret_cast<float *>(in_tensors_[2]->MutableData()); | |||
| auto beta1_power = reinterpret_cast<float *>(in_tensors_[3]->MutableData())[0]; | |||
| auto beta2_power = reinterpret_cast<float *>(in_tensors_[4]->MutableData())[0]; | |||
| auto learning_rate = reinterpret_cast<float *>(in_tensors_[5]->MutableData())[0]; | |||
| auto beta1 = reinterpret_cast<float *>(in_tensors_[6]->MutableData())[0]; | |||
| auto beta2 = reinterpret_cast<float *>(in_tensors_[7]->MutableData())[0]; | |||
| auto eps = reinterpret_cast<float *>(in_tensors_[8]->MutableData())[0]; | |||
| auto gradient = reinterpret_cast<float *>(in_tensors_[9]->MutableData()); | |||
| size_t elem_num = in_tensors_[0]->ElementsNum(); | |||
| if (adam_param_->use_nesterov_) { // Nadam | |||
| for (size_t i = 0; i < elem_num; ++i) { | |||
| m[i] = (m[i] * beta1) + (gradient[i] * (1.f - beta1)); | |||
| v[i] = (v[i] * beta2) + (gradient[i] * gradient[i] * (1.f - beta2)); | |||
| auto g_hat = gradient[i] / (1 - beta1_power); | |||
| auto m_hat = m[i] / (1 - beta1_power); | |||
| auto v_hat = v[i] / (1 - beta2_power); | |||
| auto m_tag = (1.f - beta1) * g_hat + beta1 * m_hat; | |||
| weight[i] -= learning_rate * m_tag / (sqrtf(v_hat) + eps); | |||
| } | |||
| } else { | |||
| for (size_t i = 0; i < elem_num; ++i) { | |||
| m[i] = (m[i] * beta1) + (gradient[i] * (1.f - beta1)); | |||
| v[i] = (v[i] * beta2) + (gradient[i] * gradient[i] * (1.f - beta2)); | |||
| auto m_hat = m[i] / (1 - beta1_power); | |||
| auto v_hat = v[i] / (1 - beta2_power); | |||
| weight[i] -= learning_rate * m_hat / (sqrtf(v_hat) + eps); | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int AdamRun(void *cdata, int task_id) { | |||
| auto Adam_kernel = reinterpret_cast<AdamCPUKernel *>(cdata); | |||
| auto error_code = Adam_kernel->Execute(task_id); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "Adam run error task_id[" << task_id << "] error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int AdamCPUKernel::Run() { | |||
| auto prepare_ret = Prepare(); | |||
| if (prepare_ret != RET_OK) { | |||
| MS_LOG(ERROR) << "AdamCPUKernel Prepare fail!ret: " << prepare_ret; | |||
| return prepare_ret; | |||
| } | |||
| int error_code = ParallelLaunch(this->context_->thread_pool_, AdamRun, this, 1); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "Adam function error error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int AdamCPUKernel::Init() { return RET_OK; } | |||
| kernel::LiteKernel *CpuAdamFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter, | |||
| const lite::InnerContext *ctx, const kernel::KernelKey &desc, | |||
| const lite::PrimitiveC *primitive) { | |||
| MS_ASSERT(desc.type == schema::PrimitiveType_Adam); | |||
| auto *kernel = new (std::nothrow) AdamCPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||
| MS_ASSERT(kernel != nullptr); | |||
| auto ret = kernel->Init(); | |||
| if (0 != ret) { | |||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||
| delete kernel; | |||
| return nullptr; | |||
| } | |||
| return kernel; | |||
| } | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Adam, CpuAdamFp32KernelCreator) | |||
| } // namespace mindspore::kernel | |||
| @@ -0,0 +1,44 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ADAM_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ADAM_H_ | |||
| #include <vector> | |||
| #include "src/lite_kernel.h" | |||
| #include "nnacl/fp32_grad/optimizer.h" | |||
| namespace mindspore::kernel { | |||
| class AdamCPUKernel : public LiteKernel { | |||
| public: | |||
| explicit AdamCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) { | |||
| adam_param_ = reinterpret_cast<AdamParameter *>(parameter); | |||
| } | |||
| ~AdamCPUKernel() override {} | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| int Execute(int task_id); | |||
| private: | |||
| AdamParameter *adam_param_; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ADAM_H_ | |||
| @@ -79,14 +79,7 @@ int ApplyMomentumCPUKernel::Run() { | |||
| return RET_OK; | |||
| } | |||
| int ApplyMomentumCPUKernel::Init() { | |||
| // Only for test with uninitialized Data | |||
| size_t elem_num = in_tensors_[0]->ElementsNum(); | |||
| auto accumulate = reinterpret_cast<float *>(in_tensors_[1]->MutableData()); | |||
| for (size_t i = 0; i < elem_num; i++) accumulate[i] = 0.0; | |||
| return RET_OK; | |||
| } | |||
| int ApplyMomentumCPUKernel::Init() { return RET_OK; } | |||
| kernel::LiteKernel *CpuApplyMomentumFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, | |||
| @@ -0,0 +1,91 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/runtime/kernel/arm/fp32_grad/assign.h" | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| #include "src/runtime/kernel/arm/fp32/nchw2nhwc.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_OK; | |||
| using mindspore::schema::PrimitiveType_Assign; | |||
| namespace mindspore::kernel { | |||
| int AssignCPUKernel::ReSize() { return RET_OK; } | |||
| int AssignCPUKernel::Execute(int task_id) { | |||
| auto x = reinterpret_cast<float *>(in_tensors_[0]->MutableData()); | |||
| auto y = reinterpret_cast<float *>(in_tensors_[1]->MutableData()); | |||
| size_t size = in_tensors_[0]->Size(); | |||
| memcpy(x, y, size); | |||
| return RET_OK; | |||
| } | |||
| int AssignRun(void *cdata, int task_id) { | |||
| auto Assign_kernel = reinterpret_cast<AssignCPUKernel *>(cdata); | |||
| auto error_code = Assign_kernel->Execute(task_id); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "assign run error task_id[" << task_id << "] error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int AssignCPUKernel::Run() { | |||
| auto prepare_ret = Prepare(); | |||
| if (prepare_ret != RET_OK) { | |||
| MS_LOG(ERROR) << "AssignCPUKernel Prepare fail!ret: " << prepare_ret; | |||
| return prepare_ret; | |||
| } | |||
| int error_code = ParallelLaunch(this->context_->thread_pool_, AssignRun, this, 1); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "Assign function error error_code[" << error_code << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int AssignCPUKernel::Init() { return RET_OK; } | |||
| kernel::LiteKernel *CpuAssignFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter, | |||
| const lite::InnerContext *ctx, const kernel::KernelKey &desc, | |||
| const lite::PrimitiveC *primitive) { | |||
| MS_ASSERT(desc.type == schema::PrimitiveType_Assign); | |||
| auto *kernel = new (std::nothrow) AssignCPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||
| MS_ASSERT(kernel != nullptr); | |||
| auto ret = kernel->Init(); | |||
| if (0 != ret) { | |||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||
| delete kernel; | |||
| return nullptr; | |||
| } | |||
| return kernel; | |||
| } | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Assign, CpuAssignFp32KernelCreator) | |||
| } // namespace mindspore::kernel | |||
| @@ -0,0 +1,39 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ASSIGN_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ASSIGN_H_ | |||
| #include <vector> | |||
| #include "src/lite_kernel.h" | |||
| #include "nnacl/fp32_grad/optimizer.h" | |||
| namespace mindspore::kernel { | |||
| class AssignCPUKernel : public LiteKernel { | |||
| public: | |||
| explicit AssignCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} | |||
| ~AssignCPUKernel() override {} | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| int Execute(int task_id); | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ASSIGN_H_ | |||
| @@ -27,6 +27,7 @@ using mindspore::lite::KernelRegistrar; | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_OK; | |||
| using mindspore::schema::PrimitiveType_Conv2DGradInput; | |||
| using mindspore::schema::PrimitiveType_GroupConv2DGradInput; | |||
| namespace mindspore::kernel { | |||
| int ConvolutionGradInputCPUKernel::Init() { | |||
| @@ -134,7 +135,8 @@ kernel::LiteKernel *CpuConvGradInputFp32KernelCreator(const std::vector<lite::Te | |||
| const kernel::KernelKey &desc, | |||
| const mindspore::lite::PrimitiveC *primitive) { | |||
| MS_ASSERT(opParameter != nullptr); | |||
| MS_ASSERT(desc.type == schema::PrimitiveType_Conv2DGradInput); | |||
| MS_ASSERT(desc.type == schema::PrimitiveType_Conv2DGradInput || | |||
| desc.type == schema::PrimitiveType_GroupConv2DGradInput); | |||
| auto *kernel = new (std::nothrow) ConvolutionGradInputCPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||
| if (kernel == nullptr) { | |||
| @@ -154,4 +156,6 @@ kernel::LiteKernel *CpuConvGradInputFp32KernelCreator(const std::vector<lite::Te | |||
| } | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Conv2DGradInput, CpuConvGradInputFp32KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_GroupConv2DGradInput, CpuConvGradInputFp32KernelCreator) | |||
| } // namespace mindspore::kernel | |||
| @@ -154,4 +154,6 @@ kernel::LiteKernel *CpuSoftmaxCrossEntropyFp32KernelCreator(const std::vector<li | |||
| } | |||
| return kernel; | |||
| } | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SoftmaxCrossEntropy, CpuSoftmaxCrossEntropyFp32KernelCreator) | |||
| } // namespace mindspore::kernel | |||
| @@ -178,5 +178,4 @@ kernel::LiteKernel *CpuSparseSoftmaxCrossEntropyFp32KernelCreator( | |||
| return kernel; | |||
| } | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SoftmaxCrossEntropy, CpuSparseSoftmaxCrossEntropyFp32KernelCreator) | |||
| } // namespace mindspore::kernel | |||
| @@ -41,6 +41,7 @@ TrainModel *TrainModel::Import(const char *model_buf, size_t size) { | |||
| } | |||
| model->buf = reinterpret_cast<char *>(malloc(size)); | |||
| if (model->buf == nullptr) { | |||
| delete model; | |||
| MS_LOG(ERROR) << "new inner model buf fail!"; | |||
| return nullptr; | |||
| } | |||
| @@ -48,6 +49,8 @@ TrainModel *TrainModel::Import(const char *model_buf, size_t size) { | |||
| model->buf_size_ = size; | |||
| auto meta_graph = schema::GetMetaGraph(model->buf); | |||
| if (meta_graph == nullptr) { | |||
| delete model; | |||
| free(model->buf); | |||
| MS_LOG(ERROR) << "meta_graph is nullptr!"; | |||
| return nullptr; | |||
| } | |||
| @@ -24,6 +24,7 @@ | |||
| #include "nnacl/fp32/activation.h" | |||
| #include "src/ops/conv2d_grad_filter.h" | |||
| #include "src/ops/conv2d_grad_input.h" | |||
| #include "src/ops/group_conv2d_grad_input.h" | |||
| #include "nnacl/conv_parameter.h" | |||
| #include "src/ops/power_grad.h" | |||
| #include "nnacl/power_parameter.h" | |||
| @@ -34,6 +35,7 @@ | |||
| #include "src/ops/sgd.h" | |||
| #include "src/ops/bn_grad.h" | |||
| #include "nnacl/fp32_grad/batch_norm.h" | |||
| #include "src/ops/adam.h" | |||
| namespace mindspore::kernel { | |||
| @@ -69,12 +71,29 @@ OpParameter *PopulateApplyMomentumParameter(const mindspore::lite::PrimitiveC *p | |||
| reinterpret_cast<mindspore::lite::ApplyMomentum *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||
| p->grad_scale_ = apply_momentum_primitive->GetGradientScale(); | |||
| p->use_locking_ = apply_momentum_primitive->GetUseLocking(); | |||
| p->use_nesterov_ = apply_momentum_primitive->GetUseNesterov(); | |||
| return reinterpret_cast<OpParameter *>(p); | |||
| } | |||
| OpParameter *PopulateAdamParameter(const mindspore::lite::PrimitiveC *primitive) { | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op."; | |||
| return nullptr; | |||
| } | |||
| AdamParameter *p = reinterpret_cast<AdamParameter *>(malloc(sizeof(AdamParameter))); | |||
| if (p == nullptr) { | |||
| MS_LOG(ERROR) << "new AdamParameter failed."; | |||
| return nullptr; | |||
| } | |||
| p->op_parameter_.type_ = primitive->Type(); | |||
| auto apply_momentum_primitive = | |||
| reinterpret_cast<mindspore::lite::Adam *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||
| p->use_nesterov_ = apply_momentum_primitive->GetUseNesterov(); | |||
| return reinterpret_cast<OpParameter *>(p); | |||
| } | |||
| OpParameter *PopulateSgdParameter(const mindspore::lite::PrimitiveC *primitive) { | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op."; | |||
| @@ -264,6 +283,47 @@ OpParameter *PopulateConvolutionGradInputParameter(const mindspore::lite::Primit | |||
| return reinterpret_cast<OpParameter *>(param); | |||
| } | |||
| OpParameter *PopulateGroupConvolutionGradInputParameter(const mindspore::lite::PrimitiveC *primitive) { | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op."; | |||
| return nullptr; | |||
| } | |||
| ConvParameter *param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter))); | |||
| if (param == nullptr) { | |||
| MS_LOG(ERROR) << "new Param for conv grad filter failed."; | |||
| return nullptr; | |||
| } | |||
| param->op_parameter_.type_ = primitive->Type(); | |||
| auto convg_primitive = | |||
| reinterpret_cast<mindspore::lite::GroupConv2DGradInput *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||
| param->kernel_h_ = convg_primitive->GetKernelH(); | |||
| param->kernel_w_ = convg_primitive->GetKernelW(); | |||
| param->stride_h_ = convg_primitive->GetStrideH(); | |||
| param->stride_w_ = convg_primitive->GetStrideW(); | |||
| param->dilation_h_ = convg_primitive->GetDilateH(); | |||
| param->dilation_w_ = convg_primitive->GetDilateW(); | |||
| param->pad_u_ = convg_primitive->GetPadUp(); | |||
| param->pad_d_ = convg_primitive->GetPadDown(); | |||
| param->pad_l_ = convg_primitive->GetPadLeft(); | |||
| param->pad_r_ = convg_primitive->GetPadRight(); | |||
| param->group_ = convg_primitive->GetGroup(); | |||
| param->act_type_ = ActType_No; | |||
| switch (convg_primitive->GetActivationType()) { | |||
| case schema::ActivationType_RELU: | |||
| param->act_type_ = ActType_Relu; | |||
| break; | |||
| case schema::ActivationType_RELU6: | |||
| param->act_type_ = ActType_Relu6; | |||
| break; | |||
| default: | |||
| break; | |||
| } | |||
| return reinterpret_cast<OpParameter *>(param); | |||
| } | |||
| OpParameter *PopulatePowerGradParameter(const mindspore::lite::PrimitiveC *primitive) { | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op."; | |||
| @@ -327,10 +387,13 @@ void PopulateTrainParameters() { | |||
| ppr->AddPopulateParameterFunc(schema::PrimitiveType_BNGrad, DefaultPopulateParameter); | |||
| ppr->AddPopulateParameterFunc(schema::PrimitiveType_Conv2DGradFilter, PopulateConvolutionGradFilterParameter); | |||
| ppr->AddPopulateParameterFunc(schema::PrimitiveType_Conv2DGradInput, PopulateConvolutionGradInputParameter); | |||
| ppr->AddPopulateParameterFunc(schema::PrimitiveType_GroupConv2DGradInput, PopulateGroupConvolutionGradInputParameter); | |||
| ppr->AddPopulateParameterFunc(schema::PrimitiveType_PoolingGrad, PopulatePoolingGradParameter); | |||
| ppr->AddPopulateParameterFunc(schema::PrimitiveType_PowerGrad, PopulatePowerGradParameter); | |||
| ppr->AddPopulateParameterFunc(schema::PrimitiveType_Sgd, PopulateSgdParameter); | |||
| ppr->AddPopulateParameterFunc(schema::PrimitiveType_BNGrad, PopulateBNGradParameter); | |||
| ppr->AddPopulateParameterFunc(schema::PrimitiveType_Adam, PopulateAdamParameter); | |||
| ppr->AddPopulateParameterFunc(schema::PrimitiveType_Assign, DefaultPopulateParameter); | |||
| } | |||
| } // namespace mindspore::kernel | |||
| @@ -104,23 +104,15 @@ int TrainSession::RunGraph(const session::KernelCallBack &before, const session: | |||
| for (auto ms_tensor : ms_tensors.second) this->outputs_.push_back((static_cast<lite::Tensor *>(ms_tensor))); | |||
| if (train_mode_) return lite::LiteSession::RunGraph(before, after); | |||
| // object is expected to run only inference part of graph | |||
| // prepare a list of kernels till the loss function -- temporary solution | |||
| std::vector<kernel::LiteKernel *> inference_kernels; | |||
| for (auto kernel : this->kernels_) { | |||
| if (IsLossKernel(kernel)) break; | |||
| inference_kernels.push_back(kernel); | |||
| } | |||
| if (this->context_ == nullptr) { | |||
| MS_LOG(ERROR) << "context is null"; | |||
| return lite::RET_NULL_PTR; | |||
| } | |||
| lite::Executor executor; | |||
| if (before == nullptr && after == nullptr) { | |||
| return executor.Run(this->inputs_, this->outputs_, inference_kernels, this->context_->allocator.get()); | |||
| return executor.Run(this->inputs_, this->outputs_, inference_kernels_, this->context_->allocator.get()); | |||
| } else { | |||
| return executor.Run(this->inputs_, this->outputs_, inference_kernels, this->context_->allocator.get(), before, | |||
| return executor.Run(this->inputs_, this->outputs_, inference_kernels_, this->context_->allocator.get(), before, | |||
| after); | |||
| } | |||
| } | |||
| @@ -173,6 +165,38 @@ void TrainSession::Eval() { | |||
| } | |||
| } | |||
| } | |||
| if (inference_kernels_.size() == 0) { | |||
| BuildInferenceKernelsMap(); | |||
| } | |||
| } | |||
| void TrainSession::BuildInferenceKernelsRecursive(kernel::LiteKernel *kernel, std::vector<kernel::LiteKernel *> *v) { | |||
| if (std::find(v->begin(), v->end(), kernel) == v->end()) { // kernel is not in vector | |||
| v->push_back(kernel); | |||
| for (auto in_node : kernel->in_kernels()) { | |||
| BuildInferenceKernelsRecursive(in_node, v); | |||
| } | |||
| } | |||
| } | |||
| void TrainSession::BuildInferenceKernelsMap() { | |||
| std::vector<kernel::LiteKernel *> req_kernels; | |||
| for (auto kernel : this->kernels_) { | |||
| if (IsLossKernel(kernel)) { // For each loss in the system add backward tree | |||
| for (auto in_node : kernel->in_kernels()) { | |||
| BuildInferenceKernelsRecursive(in_node, &req_kernels); | |||
| } | |||
| } | |||
| } | |||
| inference_kernels_.clear(); | |||
| for (auto kernel : this->kernels_) { | |||
| if (std::find(req_kernels.begin(), req_kernels.end(), kernel) != req_kernels.end()) { | |||
| inference_kernels_.push_back(kernel); | |||
| } | |||
| } | |||
| if (inference_kernels_.size() == 0) { | |||
| inference_kernels_ = this->kernels_; | |||
| } | |||
| } | |||
| bool TrainSession::IsLossKernel(kernel::LiteKernel *kernel) { | |||
| @@ -82,12 +82,16 @@ class TrainSession : virtual public session::TrainSession, virtual public lite:: | |||
| protected: | |||
| void AllocWorkSpace(); | |||
| bool IsLossKernel(kernel::LiteKernel *kernel); | |||
| virtual std::vector<CreatorOp> ReplaceOps(); | |||
| virtual void RestoreOps(const std::vector<CreatorOp> &restore); | |||
| bool IsLossKernel(kernel::LiteKernel *kernel); | |||
| virtual void BuildInferenceKernelsMap(); | |||
| virtual void BuildInferenceKernelsRecursive(kernel::LiteKernel *ker, std::vector<kernel::LiteKernel *> *req_kernels); | |||
| TrainModel *model_ = nullptr; | |||
| std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> orig_output_map_; | |||
| std::unordered_map<std::string, mindspore::tensor::MSTensor *> orig_output_tensor_map_; | |||
| std::vector<kernel::LiteKernel *> inference_kernels_; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -13,8 +13,8 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef TESTS_UT_COMMON_UT_COMMON_H_ | |||
| #define TESTS_UT_COMMON_UT_COMMON_H_ | |||
| #ifndef MINDSPORE_LITE_TEST_COMMON_COMMON_TEST_H_ | |||
| #define MINDSPORE_LITE_TEST_COMMON_COMMON_TEST_H_ | |||
| #include <cmath> | |||
| #include <fstream> | |||
| @@ -37,11 +37,11 @@ class CommonTest : public testing::Test { | |||
| void PrintData(std::string name, T *output_data, int size) { | |||
| std::cout << "The " << name << " is as follows:" << std::endl; | |||
| if (typeid(output_data[0]) == typeid(uint8_t) || typeid(output_data[0]) == typeid(int8_t)) { | |||
| for (size_t i = 0; i < std::min(size, 100); i++) { | |||
| for (int i = 0; i < std::min(size, 100); i++) { | |||
| std::cout << static_cast<int>(output_data[i]) << " "; | |||
| } | |||
| } else { | |||
| for (size_t i = 0; i < std::min(size, 100); i++) { | |||
| for (int i = 0; i < std::min(size, 100); i++) { | |||
| std::cout << output_data[i] << " "; | |||
| } | |||
| } | |||
| @@ -58,7 +58,7 @@ class CommonTest : public testing::Test { | |||
| void CompareOutputInt8(int8_t *output_data, int8_t *correct_data, int size, float err_percent) { | |||
| int bias_count = 0; | |||
| for (size_t i = 0; i < size; i++) { | |||
| for (int i = 0; i < size; i++) { | |||
| int8_t diff = abs(output_data[i] - correct_data[i]); | |||
| ASSERT_LE(diff, 1); | |||
| if (diff == 1) { | |||
| @@ -88,4 +88,4 @@ class CommonTest : public testing::Test { | |||
| } | |||
| }; | |||
| } // namespace mindspore | |||
| #endif // TESTS_UT_COMMON_UT_COMMON_H_ | |||
| #endif // MINDSPORE_LITE_TEST_COMMON_COMMON_TEST_H_ | |||
| @@ -250,7 +250,7 @@ TEST_F(NetworkTest, tuning_layer) { | |||
| auto label = std::make_unique<schema::TensorT>(); | |||
| label->nodeType = schema::NodeType::NodeType_ValueNode; | |||
| label->format = schema::Format_NHWC; | |||
| label->dataType = TypeId::kNumberTypeInt32; | |||
| label->dataType = TypeId::kNumberTypeFloat32; | |||
| label->dims = {BATCH_SIZE * NUM_CLASSES}; | |||
| label->offset = -1; | |||
| meta_graph->allTensors.emplace_back(std::move(label)); | |||
| @@ -386,8 +386,10 @@ TEST_F(NetworkTest, tuning_layer) { | |||
| auto labelTensor = inputs.at(1); | |||
| ASSERT_NE(nullptr, labelTensor); | |||
| ASSERT_EQ(BATCH_SIZE * NUM_CLASSES, labelTensor->ElementsNum()); | |||
| auto labels = reinterpret_cast<int *>(labelTensor->MutableData()); | |||
| for (int i = 0; i < BATCH_SIZE; i++) labels[i] = (i * 97) % NUM_CLASSES; | |||
| auto labels = reinterpret_cast<float *>(labelTensor->MutableData()); | |||
| std::fill(labels, labels + labelTensor->ElementsNum(), 0.f); | |||
| for (int i = 0; i < BATCH_SIZE; i++) labels[i * NUM_CLASSES + (i * 97) % NUM_CLASSES] = 1.0; | |||
| ret = session->RunGraph(); | |||
| ASSERT_EQ(lite::RET_OK, ret); | |||
| @@ -576,12 +578,12 @@ TEST_F(NetworkTest, lenetnet) { | |||
| delete context; | |||
| ASSERT_EQ(res, 0); | |||
| } | |||
| #if 0 | |||
| TEST_F(NetworkTest, retina_net) { | |||
| char *buf = nullptr; | |||
| size_t net_size = 0; | |||
| std::string net = "./test_data/nets/retinaface1009.ms"; | |||
| std::string net = "./test_data/nets/retinaface1.ms"; | |||
| ReadFile(net.c_str(), &net_size, &buf); | |||
| // auto model = lite::TrainModel::Import(buf, net_size); | |||
| auto model = lite::Model::Import(buf, net_size); | |||
| @@ -598,26 +600,36 @@ TEST_F(NetworkTest, retina_net) { | |||
| ASSERT_EQ(lite::RET_OK, ret); | |||
| // session->Eval(); | |||
| std::string in = "./test_data/nets/retinaface_input.f32"; | |||
| std::string in = "./test_data/nets/test1.hwc_normalized_f32"; | |||
| std::cout << "----- Output 0 -----" << std::endl; | |||
| std::string out = "./test_data/nets/retinaface_out_0.f32"; | |||
| std::string out = "./test_data/nets/test1_loc.f32"; | |||
| int final_res = 0; | |||
| auto res = runNet(session, in, out, "448", true); | |||
| ASSERT_EQ(res, 0); | |||
| // ASSERT_EQ(res, 0); | |||
| if (res != 0) { | |||
| final_res = res; | |||
| } | |||
| std::cout << "----- Output 1 -----" << std::endl; | |||
| out = "./test_data/nets/retinaface_out_1.f32"; | |||
| out = "./test_data/nets/test1_conf.f32"; | |||
| res = runNet(session, in, out, "435", true); | |||
| ASSERT_EQ(res, 0); | |||
| // ASSERT_EQ(res, 0); | |||
| if (res != 0) { | |||
| final_res |= res; | |||
| } | |||
| std::cout << "----- Output 2 -----" << std::endl; | |||
| out = "./test_data/nets/retinaface_out_2.f32"; | |||
| out = "./test_data/nets/test1_landms.f32"; | |||
| res = runNet(session, in, out, "421", true); | |||
| ASSERT_EQ(res, 0); | |||
| if (res != 0) { | |||
| final_res |= res; | |||
| } | |||
| ASSERT_EQ(final_res, 0); | |||
| delete session; | |||
| delete context; | |||
| } | |||
| #endif | |||
| TEST_F(NetworkTest, mobileface_net) { | |||
| char *buf = nullptr; | |||
| size_t net_size = 0; | |||
| @@ -41,10 +41,11 @@ TEST_F(TestSoftmaxCrossEntropyFp32, SoftmaxCrossEntropyFp32) { | |||
| std::string label_path = "./test_data/operators/sce_fp32_1_l_6.bin"; | |||
| auto ll_labels = reinterpret_cast<int64_t *>(mindspore::lite::ReadFile(label_path.c_str(), &input_size)); | |||
| auto labels = new int[6]; | |||
| for (int i = 0; i < 6; i++) labels[i] = static_cast<int>(ll_labels[i]); | |||
| auto labels = new float[6 * 4]; | |||
| std::fill(labels, labels + 6 * 4, 0.f); | |||
| for (int i = 0; i < 6; i++) labels[i * 4 + ll_labels[i]] = 1.0; | |||
| std::vector<int> dim_l({6}); | |||
| std::vector<int> dim_l({6, 4}); | |||
| lite::Tensor l_tensor(TypeId::kNumberTypeInt32, dim_l); | |||
| l_tensor.SetData(labels); | |||
| @@ -274,9 +274,24 @@ int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> input_anode, s | |||
| auto input_cnode = utils::cast<CNodePtr>(input_anode); | |||
| if (!IsPrimitiveCNode(input_cnode, schema::PrimitiveType_TupleGetItem)) { | |||
| #ifndef SUPPORT_TRAIN | |||
| if (node_id_map_.find(input_name) != node_id_map_.end()) { | |||
| output_cnode->inputIndex.emplace_back(node_id_map_[input_name]); | |||
| } | |||
| #else | |||
| bool found = false; | |||
| if (node_id_map_.find(input_name) != node_id_map_.end()) { | |||
| output_cnode->inputIndex.emplace_back(node_id_map_[input_name]); | |||
| found = true; | |||
| } | |||
| if (found == false) { | |||
| auto input_index_key = input_name + "_o:" + std::to_string(0); | |||
| if (node_id_map_.find(input_index_key) != node_id_map_.end()) { | |||
| output_cnode->inputIndex.emplace_back(node_id_map_[input_index_key]); | |||
| } | |||
| } | |||
| #endif | |||
| } else { | |||
| auto inputs = input_cnode->inputs(); | |||
| if (inputs.size() != 3) { | |||
| @@ -369,6 +384,9 @@ int AnfExporter::ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode, | |||
| auto typePtr = abstractTensor->element()->GetTypeTrack(); | |||
| paramTensor->dataType = typePtr->type_id(); | |||
| paramTensor->dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape(); | |||
| #ifdef SUPPORT_TRAIN | |||
| if (paramTensor->dims.size() == 0) paramTensor->dims = {1}; | |||
| #endif | |||
| paramTensor->nodeType = schema::NodeType::NodeType_ValueNode; | |||
| auto data = value->cast<tensor::TensorPtr>(); | |||
| paramTensor->data.resize(data->Size()); | |||
| @@ -505,7 +523,8 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s | |||
| node_id_map_[name] = meta_graphT->allTensors.size(); | |||
| meta_graphT->allTensors.emplace_back(msTensor); | |||
| if (IsPrimitiveCNode(cnode, schema::PrimitiveType_Conv2D) || | |||
| IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D)) | |||
| IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D) || | |||
| IsPrimitiveCNode(cnode, schema::PrimitiveType_Adam)) | |||
| break; | |||
| #else | |||
| if (tuple->size() == 1) { | |||
| @@ -678,6 +678,13 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output | |||
| outputFuncGraph->set_return(return_node); | |||
| MS_LOG(INFO) << "Construct funcgraph finined, all success."; | |||
| } else { | |||
| #ifdef SUPPORT_TRAIN | |||
| auto ret_node = outputFuncGraph->get_return(); | |||
| if (ret_node) { | |||
| ret_node->add_input(cnode_ptr); | |||
| return true; | |||
| } | |||
| #endif | |||
| const onnx::ValueInfoProto &output_node = importProto.output(0); | |||
| const onnx::TypeProto &output_typeproto = output_node.type(); | |||
| int output_type = output_typeproto.tensor_type().elem_type(); | |||
| @@ -687,7 +694,6 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output | |||
| } | |||
| auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[output_type]); | |||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, output_shape); | |||
| inputs.clear(); | |||
| auto primReturn = std::make_unique<schema::PrimitiveT>(); | |||
| MS_ASSERT(primReturn != nullptr); | |||
| @@ -717,6 +723,7 @@ int AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncG | |||
| } | |||
| MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size(); | |||
| CNodePtr cnode_ptr = nullptr; | |||
| CNodePtr last_cnode_ptr = nullptr; | |||
| int status = RET_OK; | |||
| NoSupportOp::GetInstance()->SetFmkType("MINDIR"); | |||
| for (int i = 0; i < importProto.node_size(); ++i) { | |||
| @@ -734,13 +741,35 @@ int AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncG | |||
| MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: : " << i; | |||
| status = (status == RET_OK ? RET_NULL_PTR : status); | |||
| } | |||
| auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode_ptr->input(0)); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "primitive_c is nullptr"; | |||
| status = RET_ERROR; | |||
| } | |||
| #ifdef SUPPORT_TRAIN | |||
| if (primitive_c->Type() == schema::PrimitiveType_MakeTuple) { | |||
| last_cnode_ptr = cnode_ptr; | |||
| if (!BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr)) { | |||
| MS_LOG(ERROR) << "Build ReturnNode for funcgraph failed"; | |||
| status = RET_ERROR; | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| if (status != RET_OK) { | |||
| return status; | |||
| } | |||
| if (!BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr)) { | |||
| MS_LOG(ERROR) << "Build ReturnNode for funcgraph failed"; | |||
| status = RET_ERROR; | |||
| #ifdef SUPPORT_TRAIN | |||
| if (last_cnode_ptr != cnode_ptr) { | |||
| #else | |||
| { | |||
| #endif | |||
| if (!BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr)) { | |||
| MS_LOG(ERROR) << "Build ReturnNode for funcgraph failed"; | |||
| status = RET_ERROR; | |||
| } | |||
| } | |||
| return status; | |||
| } | |||
| @@ -28,12 +28,14 @@ static const std::vector<schema::PrimitiveType> nhwcOpList = { | |||
| #ifdef SUPPORT_TRAIN | |||
| schema::PrimitiveType_Conv2DGradFilter, | |||
| schema::PrimitiveType_Conv2DGradInput, | |||
| schema::PrimitiveType_GroupConv2DGradInput, | |||
| schema::PrimitiveType_PoolingGrad, | |||
| schema::PrimitiveType_BiasGrad, | |||
| schema::PrimitiveType_BNGrad, | |||
| schema::PrimitiveType_ActivationGrad, | |||
| schema::PrimitiveType_ApplyMomentum, | |||
| schema::PrimitiveType_Sgd, | |||
| schema::PrimitiveType_Adam, | |||
| #endif | |||
| schema::PrimitiveType_Conv2D, | |||
| schema::PrimitiveType_DeConv2D, | |||
| @@ -161,6 +161,7 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) { | |||
| int idx = 0; | |||
| if (GetCNodeTType(**iter) == schema::PrimitiveType_ApplyMomentum) idx = 3; | |||
| if (GetCNodeTType(**iter) == schema::PrimitiveType_Sgd) idx = 1; | |||
| if (GetCNodeTType(**iter) == schema::PrimitiveType_Adam) idx = 9; | |||
| iter = InsertFormatTransNode(graph, iter, kBefore, idx, beforeNodeType, &status); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "InsertNhwc2NchwNode after " << nodeName << "failed"; | |||
| @@ -183,7 +183,9 @@ STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::uni | |||
| auto origin_begin = attr->begin; | |||
| attr->begin = {origin_begin[NCHW_N], origin_begin[NCHW_H], origin_begin[NCHW_W], origin_begin[NCHW_C]}; | |||
| auto origin_end = attr->axes; | |||
| attr->axes = {origin_end[NCHW_N], origin_end[NCHW_H], origin_end[NCHW_W], origin_end[NCHW_C]}; | |||
| if (origin_end.size() >= 4) { | |||
| attr->axes = {origin_end[NCHW_N], origin_end[NCHW_H], origin_end[NCHW_W], origin_end[NCHW_C]}; | |||
| } | |||
| auto origin_stride = attr->size; | |||
| attr->size = {origin_stride[NCHW_N], origin_stride[NCHW_H], origin_stride[NCHW_W], origin_stride[NCHW_C]}; | |||
| } | |||
| @@ -122,6 +122,12 @@ lite::STATUS WeightFormatHardCodePass::HardCodeMS(const AnfNodePtr &conv_node, | |||
| param_value->set_format(schema::Format::Format_CKHW); | |||
| } else if (op_type == schema::PrimitiveType_DeConv2D) { | |||
| param_value->set_format(schema::Format::Format_KCHW); | |||
| #ifdef SUPPORT_TRAIN | |||
| } else if (op_type == schema::PrimitiveType_Conv2DGradInput) { | |||
| param_value->set_format(schema::Format::Format_KCHW); | |||
| } else if (op_type == schema::PrimitiveType_GroupConv2DGradInput) { | |||
| param_value->set_format(schema::Format::Format_CKHW); | |||
| #endif | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) | |||
| << ", node: " << conv_node->fullname_with_scope(); | |||
| @@ -178,6 +184,10 @@ bool WeightFormatHardCodePass::Run(const FuncGraphPtr &graph) { | |||
| auto conv_cnode = node->cast<CNodePtr>(); | |||
| auto type = opt::GetCNodeType(node); | |||
| if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D && | |||
| #ifdef SUPPORT_TRAIN | |||
| ((type != schema::PrimitiveType_Conv2DGradInput) || (fmk_type != FmkType_MS)) && | |||
| ((type != schema::PrimitiveType_GroupConv2DGradInput) || (fmk_type != FmkType_MS)) && | |||
| #endif | |||
| type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) { | |||
| continue; | |||
| } | |||
| @@ -18,27 +18,21 @@ | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| using mindspore::lite::converter::FmkType_CAFFE; | |||
| using mindspore::lite::converter::FmkType_TFLITE; | |||
| using mindspore::lite::converter::FmkType_ONNX; | |||
| using mindspore::lite::converter::FmkType_MS; | |||
| using mindspore::schema::QuantType_WeightQuant; | |||
| using mindspore::schema::QuantType_QUANT_NONE; | |||
| using mindspore::lite::converter::FmkType_ONNX; | |||
| using mindspore::lite::converter::FmkType_TFLITE; | |||
| using mindspore::schema::QuantType_AwareTraining; | |||
| using mindspore::schema::QuantType_PostTraining; | |||
| using mindspore::schema::QuantType_QUANT_NONE; | |||
| using mindspore::schema::QuantType_WeightQuant; | |||
| namespace mindspore::opt { | |||
| namespace { | |||
| constexpr size_t kConvWeightIndex = 2; | |||
| } // namespace | |||
| void WeightFormatTransformPass::SetQuantType(QuantType type) { | |||
| this->quant_type = type; | |||
| } | |||
| void WeightFormatTransformPass::SetFmkType(FmkType type) { | |||
| this->fmk_type = type; | |||
| } | |||
| void WeightFormatTransformPass::SetDstFormat(schema::Format format) { | |||
| this->dst_format = format; | |||
| } | |||
| void WeightFormatTransformPass::SetQuantType(QuantType type) { this->quant_type = type; } | |||
| void WeightFormatTransformPass::SetFmkType(FmkType type) { this->fmk_type = type; } | |||
| void WeightFormatTransformPass::SetDstFormat(schema::Format format) { this->dst_format = format; } | |||
| lite::STATUS WeightFormatTransformPass::ConvWeightFormatTrans(const FuncGraphPtr &graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| auto node_list = TopoSort(graph->get_return()); | |||
| @@ -48,6 +42,9 @@ lite::STATUS WeightFormatTransformPass::ConvWeightFormatTrans(const FuncGraphPtr | |||
| } | |||
| auto type = opt::GetCNodeType(node); | |||
| if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D | |||
| #ifdef SUPPORT_TRAIN | |||
| && type != schema::PrimitiveType_Conv2DGradInput && type != schema::PrimitiveType_GroupConv2DGradInput | |||
| #endif | |||
| && type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) { | |||
| continue; | |||
| } | |||
| @@ -60,8 +57,8 @@ lite::STATUS WeightFormatTransformPass::ConvWeightFormatTrans(const FuncGraphPtr | |||
| MS_LOG(ERROR) << "weight node must param value"; | |||
| return false; | |||
| } | |||
| MS_ASSERT(weight_value->tensor_type() == TypeId::kNumberTypeFloat32 | |||
| || weight_value->tensor_type() == TypeId::kNumberTypeUInt8); | |||
| MS_ASSERT(weight_value->tensor_type() == TypeId::kNumberTypeFloat32 || | |||
| weight_value->tensor_type() == TypeId::kNumberTypeUInt8); | |||
| lite::STATUS status; | |||
| schema::Format weight_dst_format = schema::Format::Format_KHWC; | |||
| if (dst_format != schema::Format::Format_NUM_OF_FORMAT) { | |||