Browse Source

fix CopyPrimitive and leaky_relu_parser

tags/v0.7.0-beta
sunsuodong 5 years ago
parent
commit
5a7863de32
2 changed files with 10 additions and 3 deletions
  1. +6
    -0
      mindspore/lite/src/model_impl.cc
  2. +4
    -3
      mindspore/lite/tools/converter/parser/tflite/tflite_leaky_relu_parser.cc

+ 6
- 0
mindspore/lite/src/model_impl.cc View File

@@ -166,6 +166,8 @@ lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *srcPrim) {
return new lite::Exp(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_Gather:
return new lite::Gather(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_GatherNd:
return new lite::GatherNd(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_LocalResponseNormalization:
return new lite::LocalResponseNormalization(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_Maximum:
@@ -180,6 +182,8 @@ lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *srcPrim) {
return new lite::Prelu(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_Round:
return new lite::Round(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_Reverse:
return new lite::Reverse(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_ReverseSequence:
return new lite::ReverseSequence(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_LogicalAnd:
@@ -212,6 +216,8 @@ lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *srcPrim) {
return new lite::Split(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_OneHot:
return new lite::OneHot(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_Resize:
return new lite::Resize(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_MatMul:
return new lite::MatMul(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_QuantDTypeCast:


+ 4
- 3
mindspore/lite/tools/converter/parser/tflite/tflite_leaky_relu_parser.cc View File

@@ -26,18 +26,19 @@ STATUS TfliteLeakyReluParser::Parse(const std::unique_ptr<tflite::OperatorT> &tf
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
MS_LOG(DEBUG) << "parse TfliteLeakyReluParser";
std::unique_ptr<schema::LeakyReLUT> attr(new schema::LeakyReLUT());
std::unique_ptr<schema::ActivationT> attr(new schema::ActivationT());

const auto &tflite_attr = tfliteOp->builtin_options.AsLeakyReluOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
attr->negativeSlope = tflite_attr->alpha;
attr->type = schema::ActivationType_LEAKY_RELU;
attr->alpha = tflite_attr->alpha;

if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_LeakyReLU;
op->primitive->value.type = schema::PrimitiveType_Activation;
op->primitive->value.value = attr.release();
}
return RET_OK;


Loading…
Cancel
Save