
fix bugs in instance_norm and pad operators

tags/v1.1.0
zengxianglong committed 5 years ago
parent commit: 86ae625dc7

7 changed files with 43 additions and 24 deletions
  1. mindspore/lite/src/ops/instance_norm.cc (+9 -0)
  2. mindspore/lite/src/ops/pad.cc (+1 -2)
  3. mindspore/lite/src/ops/populate/instance_norm_populate.cc (+2 -2)
  4. mindspore/lite/src/ops/populate/pad_populate.cc (+13 -15)
  5. mindspore/lite/src/ops/primitive_c.cc (+3 -0)
  6. mindspore/lite/src/runtime/kernel/arm/fp32/pad.cc (+13 -4)
  7. mindspore/lite/tools/common/node_util.cc (+2 -1)

mindspore/lite/src/ops/instance_norm.cc (+9 -0)

@@ -16,6 +16,11 @@
 #include "src/ops/instance_norm.h"
 #include <memory>
+
+#ifndef PRIMITIVE_WRITEABLE
+#include "src/ops/ops_register.h"
+#endif
 
 namespace mindspore {
 namespace lite {
 #ifdef PRIMITIVE_WRITEABLE
@@ -60,6 +65,10 @@ int InstanceNorm::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbu
 }
 float InstanceNorm::GetEpsilon() const { return this->primitive_->value_as_InstanceNorm()->epsilon(); }
 
+PrimitiveC *InstanceNormCreator(const schema::Primitive *primitive) {
+  return PrimitiveC::NewPrimitiveC<InstanceNorm>(primitive);
+}
+Registry InstanceNormRegistry(schema::PrimitiveType_InstanceNorm, InstanceNormCreator);
 #endif
 }  // namespace lite
 }  // namespace mindspore
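
What this registration does: in builds without PRIMITIVE_WRITEABLE, ops are constructed through a global creator registry rather than the hand-written factory in primitive_c.cc, and InstanceNorm was missing its entry. Below is a minimal, self-contained sketch of that self-registration pattern; the names (CreatorMap, Registry, Creator) are simplified stand-ins, not the actual MindSpore Lite API.

    // Sketch of type-keyed creator self-registration (hypothetical names).
    // A global Registry object's constructor runs during static
    // initialization and inserts the creator into the map, so adding an op
    // does not require editing any central switch statement.
    #include <functional>
    #include <map>

    enum PrimitiveType { PrimitiveType_InstanceNorm = 1 };
    struct PrimitiveC { virtual ~PrimitiveC() = default; };
    using Creator = std::function<PrimitiveC *()>;

    std::map<int, Creator> &CreatorMap() {
      static std::map<int, Creator> m;  // function-local static avoids init-order issues
      return m;
    }

    struct Registry {
      Registry(int type, Creator creator) { CreatorMap()[type] = std::move(creator); }
    };

    struct InstanceNorm : PrimitiveC {};

    // Mirrors `Registry InstanceNormRegistry(...)` in the hunk above:
    Registry InstanceNormRegistry(PrimitiveType_InstanceNorm,
                                  [] { return new InstanceNorm; });

    int main() {
      // By the time main runs, static initialization has filled the map.
      return CreatorMap().count(PrimitiveType_InstanceNorm) == 1 ? 0 : 1;
    }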

mindspore/lite/src/ops/pad.cc (+1 -2)

@@ -87,11 +87,10 @@ int Pad::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs)
 }
 
   std::vector<int> paddings;
-  if (GetPaddingMode() == static_cast<int>(schema::PaddingMode_CONSTANT)) {
+  if (inputs.size() == 1) {
     paddings = GetPaddings();
   } else {
     // mirror pad
-    MS_ASSERT(inputs.size() == 2);
     auto paddings_tensor = inputs.at(1);
     int rank = static_cast<int>(inputs.front()->shape().size());
     MS_ASSERT(paddings_tensor->ElementsNum() == 2 * rank);
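
The bug here: InferShape chose between attribute paddings and a paddings input based on the padding mode, so a CONSTANT-mode Pad whose paddings arrive as a second tensor (or a mirror pad carrying attribute paddings) took the wrong branch. Keying on inputs.size() matches whichever data is actually present. A simplified sketch of the corrected dispatch, with plain vectors standing in for the real Tensor type:

    #include <cassert>
    #include <vector>

    // Paddings may be stored on the node (single-input Pad) or supplied as a
    // second input tensor; either convention can occur with any padding mode.
    std::vector<int> ResolvePaddings(const std::vector<std::vector<int>> &inputs,
                                     const std::vector<int> &attr_paddings) {
      if (inputs.size() == 1) {
        return attr_paddings;      // paddings baked into the op's attributes
      }
      assert(inputs.size() == 2);  // paddings provided as the second input
      return inputs.at(1);
    }

    int main() {
      std::vector<std::vector<int>> two_inputs = {{1, 3, 224, 224},
                                                  {0, 0, 1, 1, 2, 2, 0, 0}};
      return ResolvePaddings(two_inputs, {}).size() == 8 ? 0 : 1;
    }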


mindspore/lite/src/ops/populate/instance_norm_populate.cc (+2 -2)

@@ -21,7 +21,7 @@
 
 namespace mindspore {
 namespace lite {
-OpParameter *PopulateInstanceNorm(const mindspore::lite::PrimitiveC *primitive) {
+OpParameter *PopulateInstanceNormParameter(const mindspore::lite::PrimitiveC *primitive) {
   const auto param =
     reinterpret_cast<mindspore::lite::InstanceNorm *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
   InstanceNormParameter *instance_norm_param =
@@ -37,6 +37,6 @@ OpParameter *PopulateInstanceNorm(const mindspore::lite::PrimitiveC *primitive)
   return reinterpret_cast<OpParameter *>(instance_norm_param);
 }
 
-Registry InstanceNormParameterRegistry(schema::PrimitiveType_L2Norm, PopulateInstanceNorm);
+Registry InstanceNormParameterRegistry(schema::PrimitiveType_InstanceNorm, PopulateInstanceNormParameter);
 }  // namespace lite
 }  // namespace mindspore
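
Two fixes in one file: the populate function gets a consistent name, and, more importantly, it is re-registered under PrimitiveType_InstanceNorm instead of PrimitiveType_L2Norm. Populate registries are keyed by primitive type, so under the old key InstanceNorm nodes found no parameter populater at all, and the stray entry could collide with L2Norm's own registration. A toy illustration of that lookup, using hypothetical simplified types:

    #include <iostream>
    #include <unordered_map>

    enum Type { Type_L2Norm = 1, Type_InstanceNorm = 2 };
    using Populater = const char *(*)();
    const char *PopulateInstanceNormParameter() { return "InstanceNormParameter"; }

    int main() {
      std::unordered_map<int, Populater> registry;
      registry[Type_InstanceNorm] = PopulateInstanceNormParameter;  // fixed key
      auto it = registry.find(Type_InstanceNorm);  // lookup uses the node's type
      std::cout << (it == registry.end() ? "not found" : it->second()) << "\n";
      return 0;
    }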

mindspore/lite/src/ops/populate/pad_populate.cc (+13 -15)

@@ -32,23 +32,21 @@ OpParameter *PopulatePadParameter(const mindspore::lite::PrimitiveC *primitive)
   pad_param->op_parameter_.type_ = primitive->Type();
   auto pad_node = reinterpret_cast<mindspore::lite::Pad *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
   pad_param->pad_mode_ = pad_node->GetPaddingMode();
-  if (pad_param->pad_mode_ == static_cast<int>(schema::PaddingMode_CONSTANT)) {
-    pad_param->constant_value_ = pad_node->GetConstantValue();
-    auto size = pad_node->GetPaddings().size();
-    if (size > MAX_PAD_SIZE) {
-      MS_LOG(ERROR) << "Invalid padding size: " << size;
-      free(pad_param);
-      return nullptr;
-    }
+  pad_param->constant_value_ = pad_node->GetConstantValue();
+  auto size = pad_node->GetPaddings().size();
+  if (size > MAX_PAD_SIZE) {
+    MS_LOG(ERROR) << "Invalid padding size: " << size;
+    free(pad_param);
+    return nullptr;
+  }
 
-    for (size_t i = 0; i < MAX_PAD_SIZE - size; ++i) {
-      pad_param->paddings_[i] = 0;
-    }
-    for (size_t i = 0; i < size; i++) {
-      pad_param->paddings_[MAX_PAD_SIZE - size + i] = pad_node->GetPaddings()[i];
-    }
-    pad_param->padding_length = MAX_PAD_SIZE;
+  for (size_t i = 0; i < MAX_PAD_SIZE - size; ++i) {
+    pad_param->paddings_[i] = 0;
+  }
+  for (size_t i = 0; i < size; i++) {
+    pad_param->paddings_[MAX_PAD_SIZE - size + i] = pad_node->GetPaddings()[i];
   }
+  pad_param->padding_length = MAX_PAD_SIZE;
 
   return reinterpret_cast<OpParameter *>(pad_param);
 }
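
Besides the de-indent, this hunk makes the paddings and constant-value setup run for every padding mode rather than only PaddingMode_CONSTANT, which appears to be what left mirror-pad parameters unpopulated. The loops right-align the 2*rank padding values into the fixed-size array and zero the leading slots. A runnable sketch of that alignment, assuming MAX_PAD_SIZE is 8 (four dimensions, two sides each):

    #include <cstdio>
    #include <vector>

    constexpr size_t MAX_PAD_SIZE = 8;  // assumption: 4 dims x (before, after)

    int main() {
      std::vector<int> paddings = {1, 1, 2, 2};  // rank-2 input: 2 * rank values
      int out[MAX_PAD_SIZE];
      size_t size = paddings.size();
      for (size_t i = 0; i < MAX_PAD_SIZE - size; ++i) out[i] = 0;  // zero leading dims
      for (size_t i = 0; i < size; ++i) out[MAX_PAD_SIZE - size + i] = paddings[i];
      for (int v : out) std::printf("%d ", v);  // prints: 0 0 0 0 1 1 2 2
      std::printf("\n");
      return 0;
    }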


mindspore/lite/src/ops/primitive_c.cc (+3 -0)

@@ -143,6 +143,7 @@
 #include "src/ops/audio_spectrogram.h"
 #include "src/ops/mfcc.h"
 #include "src/ops/identity.h"
+#include "src/ops/instance_norm.h"
 
 #ifdef SUPPORT_TRAIN
 #include "src/ops/neg_grad.h"
@@ -790,6 +791,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
       return new AudioSpectrogram(primitive);
     case schema::PrimitiveType_Mfcc:
       return new Mfcc(primitive);
+    case schema::PrimitiveType_InstanceNorm:
+      return new InstanceNorm(primitive);
 
 #ifdef SUPPORT_TRAIN
     case schema::PrimitiveType_ActivationGrad:
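
Note that this switch in PrimitiveC::Create is the PRIMITIVE_WRITEABLE counterpart of the Registry entry added in instance_norm.cc above: writeable (converter-side) builds construct ops through this explicit factory, while schema-only runtime builds go through the creator registry, so the new op needs both hooks to load in every build variant.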


mindspore/lite/src/runtime/kernel/arm/fp32/pad.cc (+13 -4)

@@ -213,11 +213,20 @@ void PadCPUKernel::CalculateStrides() {
 }
 
 int PadCPUKernel::HandleMirrorPad() {
-  auto ret = CopyPaddingFromInput();
-  if (ret != RET_OK) {
-    return ret;
+  if (in_tensors_.size() == 1) {
+    auto input_shape = in_tensors_.at(0)->shape();
+    int rank = static_cast<int>(input_shape.size());
+    auto ret = ExtendShape(in_, DEFAULT_PAD_NDIMS, input_shape.data(), rank);
+    if (ret != RET_OK) {
+      return ret;
+    }
+  } else {
+    auto ret = CopyPaddingFromInput();
+    if (ret != RET_OK) {
+      return ret;
+    }
   }
-  ret = CheckPaddings(pad_param_->paddings_, DEFAULT_PAD_NDIMS, in_, pad_param_->pad_mode_);
+  auto ret = CheckPaddings(pad_param_->paddings_, DEFAULT_PAD_NDIMS, in_, pad_param_->pad_mode_);
   if (ret != RET_OK) {
     return ret;
   }
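
HandleMirrorPad previously assumed the paddings always arrive as a second input tensor; for a single-input node they are already in pad_param_ (filled by PopulatePadParameter above), so only the input shape must be extended to DEFAULT_PAD_NDIMS before CheckPaddings runs. Below is a sketch of what an ExtendShape-style helper plausibly does; the real function's signature and fill value are assumptions here:

    #include <cstdio>

    constexpr int DEFAULT_PAD_NDIMS = 4;  // assumption: the pad kernel works in 4D

    // Right-align a rank-n shape into a fixed ndims buffer, padding the
    // leading dimensions with 1 so element counts and strides are preserved.
    void ExtendShapeSketch(int *out, int ndims, const int *shape, int rank) {
      for (int i = 0; i < ndims - rank; ++i) out[i] = 1;
      for (int i = 0; i < rank; ++i) out[ndims - rank + i] = shape[i];
    }

    int main() {
      int shape[] = {32, 32};  // rank-2 input
      int in[DEFAULT_PAD_NDIMS];
      ExtendShapeSketch(in, DEFAULT_PAD_NDIMS, shape, 2);
      for (int v : in) std::printf("%d ", v);  // prints: 1 1 32 32
      std::printf("\n");
      return 0;
    }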


mindspore/lite/tools/common/node_util.cc (+2 -1)

@@ -46,7 +46,8 @@ static const std::vector<schema::PrimitiveType> nhwcOpList = {
   schema::PrimitiveType_BatchNorm,
   schema::PrimitiveType_FusedBatchNorm,
   schema::PrimitiveType_PReLU,
-  schema::PrimitiveType_BiasAdd};
+  schema::PrimitiveType_BiasAdd,
+  schema::PrimitiveType_InstanceNorm};
 
 static const std::vector<schema::PrimitiveType> nhwcOpDualInputList = {
 #ifdef SUPPORT_TRAIN
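
Adding InstanceNorm to nhwcOpList marks it as an op that computes in NHWC layout, so the converter's layout passes treat it like BatchNorm or PReLU, inserting NCHW/NHWC transposes where needed. A lookup against such a list is a simple membership test; a sketch follows, with std::find standing in for whatever helper the converter actually uses:

    #include <algorithm>
    #include <vector>

    enum PrimitiveType { PrimitiveType_BiasAdd = 1, PrimitiveType_InstanceNorm = 2 };

    static const std::vector<PrimitiveType> nhwcOpList = {PrimitiveType_BiasAdd,
                                                          PrimitiveType_InstanceNorm};

    bool IsNhwcOp(PrimitiveType type) {
      return std::find(nhwcOpList.begin(), nhwcOpList.end(), type) != nhwcOpList.end();
    }

    int main() { return IsNhwcOp(PrimitiveType_InstanceNorm) ? 0 : 1; }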

