|
|
|
@@ -54,6 +54,14 @@ STATUS TfliteActivationParser::Parse(const std::unique_ptr<tflite::OperatorT> &t |
|
|
|
} else if (std::strcmp(node_name, "Logistic") == 0) { |
|
|
|
MS_LOG(DEBUG) << "parse TfliteLogisticParser"; |
|
|
|
attr->type = schema::ActivationType_SIGMOID; |
|
|
|
} else if (std::strcmp(node_name, "LeakyRelu") == 0) { |
|
|
|
const auto &option = tfliteOp->builtin_options.AsLeakyReluOptions(); |
|
|
|
if (option == nullptr) { |
|
|
|
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; |
|
|
|
return RET_NULL_PTR; |
|
|
|
} |
|
|
|
attr->type = schema::ActivationType_LEAKY_RELU; |
|
|
|
attr->alpha = option->alpha; |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << "wrong activation type"; |
|
|
|
return RET_ERROR; |
|
|
|
@@ -92,36 +100,6 @@ STATUS TflitePreluParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
STATUS TfliteLeakyReluParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, |
|
|
|
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, |
|
|
|
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, |
|
|
|
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, |
|
|
|
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { |
|
|
|
if (op == nullptr) { |
|
|
|
MS_LOG(ERROR) << "op is null"; |
|
|
|
return RET_NULL_PTR; |
|
|
|
} |
|
|
|
op->primitive = std::make_unique<schema::PrimitiveT>(); |
|
|
|
if (op->primitive == nullptr) { |
|
|
|
MS_LOG(ERROR) << "op->primitive is null"; |
|
|
|
return RET_NULL_PTR; |
|
|
|
} |
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "parse TfliteLeakyReluParser"; |
|
|
|
std::unique_ptr<schema::LeakyReLUT> attr(new schema::LeakyReLUT()); |
|
|
|
|
|
|
|
const auto &tflite_attr = tfliteOp->builtin_options.AsLeakyReluOptions(); |
|
|
|
if (tflite_attr == nullptr) { |
|
|
|
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; |
|
|
|
return RET_NULL_PTR; |
|
|
|
} |
|
|
|
attr->negativeSlope = tflite_attr->alpha; |
|
|
|
|
|
|
|
op->primitive->value.type = schema::PrimitiveType_LeakyReLU; |
|
|
|
op->primitive->value.value = attr.release(); |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
TfliteNodeRegister g_TfliteReluParser("Relu", new TfliteReluParser()); |
|
|
|
TfliteNodeRegister g_TfliteRelu6Parser("Relu6", new TfliteRelu6Parser()); |
|
|
|
TfliteNodeRegister g_TfliteTanhParser("Tanh", new TfliteTanhParser()); |
|
|
|
|