Browse Source

!12005 [MS][LITE] add tf parsers

From: @cjh9368
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
027152b6ac
15 changed files with 710 additions and 193 deletions
  1. +3
    -1
      mindspore/lite/schema/model.fbs
  2. +8
    -0
      mindspore/lite/schema/ops.fbs
  3. +81
    -0
      mindspore/lite/src/ops/batch_matmul.cc
  4. +46
    -0
      mindspore/lite/src/ops/batch_matmul.h
  5. +33
    -0
      mindspore/lite/src/ops/is_finite.h
  6. +6
    -1
      mindspore/lite/src/ops/primitive_c.cc
  7. +217
    -150
      mindspore/lite/tools/common/node_util.h
  8. +10
    -12
      mindspore/lite/tools/converter/parser/tf/tf_arithmetic_self_parser.cc
  9. +64
    -0
      mindspore/lite/tools/converter/parser/tf/tf_batch_matmul_parser.cc
  10. +37
    -0
      mindspore/lite/tools/converter/parser/tf/tf_batch_matmul_parser.h
  11. +59
    -0
      mindspore/lite/tools/converter/parser/tf/tf_is_finite_parser.cc
  12. +37
    -0
      mindspore/lite/tools/converter/parser/tf/tf_is_finite_parser.h
  13. +12
    -29
      mindspore/lite/tools/converter/parser/tf/tf_logical_parser.cc
  14. +60
    -0
      mindspore/lite/tools/converter/parser/tf/tf_zeros_like_parser.cc
  15. +37
    -0
      mindspore/lite/tools/converter/parser/tf/tf_zeros_like_parser.h

+ 3
- 1
mindspore/lite/schema/model.fbs View File

@@ -273,7 +273,9 @@ union PrimitiveType {
RandomStandardNormal, RandomStandardNormal,
CropAndResize, CropAndResize,
Erf, Erf,
StridedSliceGrad
StridedSliceGrad,
IsFinite,
BatchMatMul,
} }


enum QuantType: int { enum QuantType: int {


+ 8
- 0
mindspore/lite/schema/ops.fbs View File

@@ -1274,4 +1274,12 @@ table StridedSliceGrad {
} }


table Erf { table Erf {
}

table IsFinite {
}

table BatchMatMul {
adj_x : bool = false;
adj_y : bool = false;
} }

+ 81
- 0
mindspore/lite/src/ops/batch_matmul.cc View File

@@ -0,0 +1,81 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/batch_matmul.h"
#include <memory>
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif

namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
bool BatchMatMul::GetAdjX() const { return this->primitive_->value.AsBatchMatMul()->adj_x; }

void BatchMatMul::SetAdjX(bool adj_x) { this->primitive_->value.AsBatchMatMul()->adj_x = adj_x; }

bool BatchMatMul::GetAdjY() const { return this->primitive_->value.AsBatchMatMul()->adj_y; }

void BatchMatMul::SetAdjY(bool adj_y) { this->primitive_->value.AsBatchMatMul()->adj_y = adj_y; }

int BatchMatMul::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_BatchMatMul;
}
if (this->primitive_->value.type != schema::PrimitiveType_BatchMatMul) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::BatchMatMulT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new FusedBatchMatMulT failed";
delete this->primitive_;
this->primitive_ = nullptr;
return RET_ERROR;
}
attr->adj_x = GetValue<bool>(prim.GetAttr("adj_x"));
attr->adj_y = GetValue<bool>(prim.GetAttr("adj_y"));
this->primitive_->value.value = attr;
}
return RET_OK;
}

#else
int BatchMatMul::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto val_offset = schema::CreateBatchMatMul(*fbb);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BatchMatMul, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
bool BatchMatMul::GetAdjX() const { return this->primitive_->value_as_BatchMatMul()->adj_x(); }

bool BatchMatMul::GetAdjY() const { return this->primitive_->value_as_BatchMatMul()->adj_y(); }

PrimitiveC *BatchMatMulCreator(const schema::Primitive *primitive) {
return PrimitiveC::NewPrimitiveC<BatchMatMul>(primitive);
}
Registry BatchMatMulRegistry(schema::PrimitiveType_BatchMatMul, BatchMatMulCreator);
#endif

} // namespace lite
} // namespace mindspore

+ 46
- 0
mindspore/lite/src/ops/batch_matmul.h View File

@@ -0,0 +1,46 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_BATCH_MATMUL_H_
#define LITE_MINDSPORE_LITE_C_OPS_BATCH_MATMUL_H_

#include <vector>
#include <set>
#include <cmath>
#include "src/ops/primitive_c.h"

namespace mindspore {
namespace lite {
class BatchMatMul : public PrimitiveC {
public:
BatchMatMul() = default;
~BatchMatMul() = default;
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(BatchMatMul, PrimitiveC);
explicit BatchMatMul(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
void SetAdjX(bool adj_x);
void SetAdjY(bool adj_y);
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
bool GetAdjX() const;
bool GetAdjY() const;
};
} // namespace lite
} // namespace mindspore

#endif // LITE_MINDSPORE_LITE_C_OPS_BATCH_MATMUL_H_

+ 33
- 0
mindspore/lite/src/ops/is_finite.h View File

@@ -0,0 +1,33 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/ops/primitive_c.h"

#ifndef LITE_MINDSPORE_LITE_C_OPS_IS_FINITE_H_
#define LITE_MINDSPORE_LITE_C_OPS_IS_FINITE_H_

namespace mindspore {
namespace lite {
class IsFinite : public PrimitiveC {
public:
MS_DECLARE_PARENT(IsFinite, PrimitiveC);
IsFinite() = default;
~IsFinite() = default;
explicit IsFinite(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_IS_FINITE_H_

+ 6
- 1
mindspore/lite/src/ops/primitive_c.cc View File

@@ -170,6 +170,8 @@
#include "src/ops/crop_and_resize.h" #include "src/ops/crop_and_resize.h"
#include "src/ops/nonzero.h" #include "src/ops/nonzero.h"
#include "src/ops/erf.h" #include "src/ops/erf.h"
#include "src/ops/is_finite.h"
#include "src/ops/batch_matmul.h"


#ifdef SUPPORT_TRAIN #ifdef SUPPORT_TRAIN
#include "src/ops/neg_grad.h" #include "src/ops/neg_grad.h"
@@ -665,7 +667,6 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
return NewPrimitiveC<ArgMax>(prim, inputs, quantType); return NewPrimitiveC<ArgMax>(prim, inputs, quantType);
} else if (op_type == "Gelu") { } else if (op_type == "Gelu") {
return NewPrimitiveC<GeLU>(prim, inputs, quantType); return NewPrimitiveC<GeLU>(prim, inputs, quantType);

#ifdef SUPPORT_TRAIN #ifdef SUPPORT_TRAIN
} else if (op_type == "SoftmaxCrossEntropyWithLogits") { } else if (op_type == "SoftmaxCrossEntropyWithLogits") {
return NewPrimitiveC<SoftmaxCrossEntropy>(prim, inputs, quantType); return NewPrimitiveC<SoftmaxCrossEntropy>(prim, inputs, quantType);
@@ -1034,6 +1035,10 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
return new (std::nothrow) NonZero(primitive); return new (std::nothrow) NonZero(primitive);
case schema::PrimitiveType_Erf: case schema::PrimitiveType_Erf:
return new (std::nothrow) Erf(primitive); return new (std::nothrow) Erf(primitive);
case schema::PrimitiveType_IsFinite:
return new (std::nothrow) IsFinite(primitive);
case schema::PrimitiveType_BatchMatMul:
return new (std::nothrow) BatchMatMul(primitive);
#ifdef SUPPORT_TRAIN #ifdef SUPPORT_TRAIN
case schema::PrimitiveType_ActivationGrad: case schema::PrimitiveType_ActivationGrad:
return new (std::nothrow) ActivationGrad(primitive); return new (std::nothrow) ActivationGrad(primitive);


+ 217
- 150
mindspore/lite/tools/common/node_util.h View File

@@ -29,6 +29,18 @@


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
template <typename T>
int CreateOperator(const std::unique_ptr<schema::PrimitiveT> &primitive, schema::PrimitiveType type) {
auto attr = std::make_unique<T>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}
primitive->value.type = type;
primitive->value.value = attr.release();
return RET_OK;
}

using STATUS = int; using STATUS = int;
STATUS BroadCastQuantParam(schema::MetaGraphT *graphT, const std::unique_ptr<schema::CNodeT> &node); STATUS BroadCastQuantParam(schema::MetaGraphT *graphT, const std::unique_ptr<schema::CNodeT> &node);


@@ -92,197 +104,252 @@ STATUS SetFilterDim(schema::TensorT *tensor, kTransFilterType type, int32_t filt
int32_t filterW); int32_t filterW);


template <typename T> template <typename T>
static STATUS TransFilterData(schema::TensorT *tensor, kTransFilterType type, int32_t filterK, int32_t filterC,
int32_t filterH, int32_t filterW) {
MS_ASSERT(tensor != nullptr);
int count = filterH * filterW * filterC * filterK;
if (count <= 0) {
MS_LOG(ERROR) << "Dim size invalid";
return RET_ERROR;
static void TransKHWC2CHWK(int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW, T *srcData, T *dstData) {
T *p1Buff = nullptr;
T *p2Buff = nullptr;
for (int k = 0; k < filterK; ++k) {
for (int h = 0; h < filterH; ++h) {
for (int w = 0; w < filterW; ++w) {
for (int c = 0; c < filterC; ++c) {
p1Buff = srcData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
p2Buff = dstData + ((c * filterK * filterH * filterW) + (h * filterK * filterW) + (w * filterK) + (k));
*p2Buff = *p1Buff;
}
}
}
} }
std::unique_ptr<T[]> buf(new (std::nothrow) T[count]);
if (buf == nullptr) {
MS_LOG(ERROR) << "new buf failed";
return RET_ERROR;
}

template <typename T>
static void TransKHWC2HWCK(int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW, T *srcData, T *dstData) {
T *p1Buff = nullptr;
T *p2Buff = nullptr;
for (int k = 0; k < filterK; ++k) {
for (int h = 0; h < filterH; ++h) {
for (int w = 0; w < filterW; ++w) {
for (int c = 0; c < filterC; ++c) {
p1Buff = srcData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
p2Buff = dstData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
*p2Buff = *p1Buff;
}
}
}
} }
}


void *originWeightDate = tensor->data.data();
T *weightData = static_cast<T *>(originWeightDate);
template <typename T>
static void TransCKHW(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
T *srcData, T *dstData) {
T *p1Buff = nullptr;
T *p2Buff = nullptr;
for (int c = 0; c < filterC; ++c) {
for (int k = 0; k < filterK; ++k) {
for (int h = 0; h < filterH; ++h) {
for (int w = 0; w < filterW; ++w) {
p1Buff = srcData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
if (type == kCKHW2HWCK) {
p2Buff = dstData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
} else if (type == kCKHW2KHWC) {
p2Buff = dstData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
} else {
p2Buff = dstData + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
}
*p2Buff = *p1Buff;
}
}
}
}
}


if (weightData == nullptr) {
MS_LOG(ERROR) << "weightData is nullptr";
return RET_ERROR;
template <typename T>
static void TransKCHW(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
T *srcData, T *dstData) {
T *p1Buff = nullptr;
T *p2Buff = nullptr;
for (int k = 0; k < filterK; ++k) {
for (int c = 0; c < filterC; ++c) {
for (int h = 0; h < filterH; ++h) {
for (int w = 0; w < filterW; ++w) {
p1Buff = srcData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
if (type == kKCHW2HWCK) {
p2Buff = dstData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
} else if (type == kKCHW2KHWC) {
p2Buff = dstData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
} else if (type == kKCHW2CKHW) {
p2Buff = dstData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
} else {
p2Buff = dstData + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
}
*p2Buff = *p1Buff;
}
}
}
}
}

template <typename T>
static void TransCHWK(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
T *srcData, T *dstData) {
T *p1Buff = nullptr;
T *p2Buff = nullptr;
for (int c = 0; c < filterC; ++c) {
for (int h = 0; h < filterH; ++h) {
for (int w = 0; w < filterW; ++w) {
for (int k = 0; k < filterK; ++k) {
p1Buff = srcData + ((c * filterH * filterW * filterK) + (h * filterW * filterK) + (w * filterK) + (k));
if (type == kCHWK2HWCK) {
p2Buff = dstData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
} else {
p2Buff = dstData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
}
*p2Buff = *p1Buff;
}
}
}
} }
}

template <typename T>
static void TransHWCK(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
T *srcData, T *dstData) {
T *p1Buff = nullptr; T *p1Buff = nullptr;
T *p2Buff = nullptr; T *p2Buff = nullptr;
switch (type) {
case kCHWK2HWCK:
case kCHWK2KHWC: {
for (int h = 0; h < filterH; ++h) {
for (int w = 0; w < filterW; ++w) {
for (int c = 0; c < filterC; ++c) { for (int c = 0; c < filterC; ++c) {
for (int h = 0; h < filterH; ++h) {
for (int w = 0; w < filterW; ++w) {
for (int k = 0; k < filterK; ++k) {
p1Buff = weightData + ((c * filterH * filterW * filterK) + (h * filterW * filterK) + (w * filterK) + (k));
if (type == kCHWK2HWCK) {
p2Buff =
buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
} else {
p2Buff =
buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
}
*p2Buff = *p1Buff;
}
for (int k = 0; k < filterK; ++k) {
p1Buff = srcData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
if (type == kHWCK2KCHW) {
p2Buff = dstData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
} else {
p2Buff = dstData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
} }
*p2Buff = *p1Buff;
} }
} }
} break;
case kKHWC2HWCK: {
for (int k = 0; k < filterK; ++k) {
for (int h = 0; h < filterH; ++h) {
for (int w = 0; w < filterW; ++w) {
for (int c = 0; c < filterC; ++c) {
p1Buff = weightData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
p2Buff = buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
*p2Buff = *p1Buff;
}
}
}
}

template <typename T>
static void TransHWKC(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
T *srcData, T *dstData) {
T *p1Buff = nullptr;
T *p2Buff = nullptr;
for (int h = 0; h < filterH; ++h) {
for (int w = 0; w < filterW; ++w) {
for (int c = 0; c < filterC; ++c) {
for (int k = 0; k < filterK; ++k) {
p1Buff = srcData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c));
if (type == kHWKC2KCHW) {
p2Buff = dstData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
} else {
p2Buff = dstData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
}
*p2Buff = *p1Buff;
}
}
}
}
}

template <typename T>
static void TransNHWC(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
T *srcData, T *dstData) {
T *p1Buff = nullptr;
T *p2Buff = nullptr;
for (int k = 0; k < filterK; ++k) {
for (int h = 0; h < filterH; ++h) {
for (int w = 0; w < filterW; ++w) {
for (int c = 0; c < filterC; ++c) {
p1Buff = srcData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c));
if (type == kNHWC2HWCK) {
p2Buff = dstData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
} else if (type == kNHWC2CKHW) {
p2Buff = dstData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
} else {
p2Buff = dstData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
} }
*p2Buff = *p1Buff;
} }
} }
}
}
}

template <typename T>
static STATUS TransFilterData(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
T *srcData, T *dstData) {
switch (type) {
case kCHWK2HWCK:
case kCHWK2KHWC: {
TransCHWK(type, filterK, filterC, filterH, filterW, srcData, dstData);
} break;
case kKHWC2HWCK: {
TransKHWC2HWCK(filterK, filterC, filterH, filterW, srcData, dstData);
} break; } break;
case kKCHW2HWCK: case kKCHW2HWCK:
case kKCHW2CKHW: case kKCHW2CKHW:
case kKCHW2KHWC: case kKCHW2KHWC:
case kKCHW2HWKC: { case kKCHW2HWKC: {
for (int k = 0; k < filterK; ++k) {
for (int c = 0; c < filterC; ++c) {
for (int h = 0; h < filterH; ++h) {
for (int w = 0; w < filterW; ++w) {
p1Buff = weightData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
if (type == kKCHW2HWCK) {
p2Buff =
buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
} else if (type == kKCHW2KHWC) {
p2Buff =
buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
} else if (type == kKCHW2CKHW) {
p2Buff =
buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
} else {
p2Buff =
buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
}
*p2Buff = *p1Buff;
}
}
}
}
TransKCHW(type, filterK, filterC, filterH, filterW, srcData, dstData);
} break; } break;
case kCKHW2HWCK: case kCKHW2HWCK:
case kCKHW2KHWC: case kCKHW2KHWC:
case kCKHW2HWKC: { case kCKHW2HWKC: {
for (int c = 0; c < filterC; ++c) {
for (int k = 0; k < filterK; ++k) {
for (int h = 0; h < filterH; ++h) {
for (int w = 0; w < filterW; ++w) {
p1Buff = weightData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
if (type == kCKHW2HWCK) {
p2Buff =
buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
} else if (type == kCKHW2KHWC) {
p2Buff =
buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
} else {
p2Buff =
buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
}
*p2Buff = *p1Buff;
}
}
}
}
TransCKHW(type, filterK, filterC, filterH, filterW, srcData, dstData);
} break; } break;
case kHWCK2KCHW: case kHWCK2KCHW:
case kHWCK2CKHW: { case kHWCK2CKHW: {
for (int h = 0; h < filterH; ++h) {
for (int w = 0; w < filterW; ++w) {
for (int c = 0; c < filterC; ++c) {
for (int k = 0; k < filterK; ++k) {
p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
if (type == kHWCK2KCHW) {
p2Buff =
buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
} else {
p2Buff =
buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
}
*p2Buff = *p1Buff;
}
}
}
}
TransHWCK(type, filterK, filterC, filterH, filterW, srcData, dstData);
} break; } break;
case kHWKC2KCHW: case kHWKC2KCHW:
case kHWKC2CKHW: { case kHWKC2CKHW: {
for (int h = 0; h < filterH; ++h) {
for (int w = 0; w < filterW; ++w) {
for (int c = 0; c < filterC; ++c) {
for (int k = 0; k < filterK; ++k) {
p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c));
if (type == kHWKC2KCHW) {
p2Buff =
buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
} else {
p2Buff =
buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
}
*p2Buff = *p1Buff;
}
}
}
}
TransHWKC(type, filterK, filterC, filterH, filterW, srcData, dstData);
} break; } break;
case kNHWC2HWCK: case kNHWC2HWCK:
case kNHWC2KCHW: case kNHWC2KCHW:
case kNHWC2CKHW: { case kNHWC2CKHW: {
for (int k = 0; k < filterK; ++k) {
for (int h = 0; h < filterH; ++h) {
for (int w = 0; w < filterW; ++w) {
for (int c = 0; c < filterC; ++c) {
p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c));
if (type == kNHWC2HWCK) {
p2Buff =
buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
} else if (type == kNHWC2CKHW) {
p2Buff =
buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
} else {
p2Buff =
buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
}
*p2Buff = *p1Buff;
}
}
}
}
TransNHWC(type, filterK, filterC, filterH, filterW, srcData, dstData);
} break; } break;
case kKHWC2CHWK: { case kKHWC2CHWK: {
for (int k = 0; k < filterK; ++k) {
for (int h = 0; h < filterH; ++h) {
for (int w = 0; w < filterW; ++w) {
for (int c = 0; c < filterC; ++c) {
p1Buff = weightData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
p2Buff = buf.get() + ((c * filterK * filterH * filterW) + (h * filterK * filterW) + (w * filterK) + (k));
*p2Buff = *p1Buff;
}
}
}
}
TransKHWC2CHWK(filterK, filterC, filterH, filterW, srcData, dstData);
} break; } break;
default: { default: {
MS_LOG(ERROR) << "Unsupported transFilterType: " << type; MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
return RET_ERROR; return RET_ERROR;
} }
} }
return RET_OK;
}

template <typename T>
static STATUS TransFilterData(schema::TensorT *tensor, kTransFilterType type, int32_t filterK, int32_t filterC,
int32_t filterH, int32_t filterW) {
MS_ASSERT(tensor != nullptr);
int count = filterH * filterW * filterC * filterK;
if (count <= 0) {
MS_LOG(ERROR) << "Dim size invalid";
return RET_ERROR;
}
std::unique_ptr<T[]> buf(new (std::nothrow) T[count]);
if (buf == nullptr) {
MS_LOG(ERROR) << "new buf failed";
return RET_ERROR;
}

void *originWeightDate = tensor->data.data();
T *weightData = static_cast<T *>(originWeightDate);

if (weightData == nullptr) {
MS_LOG(ERROR) << "weightData is nullptr";
return RET_ERROR;
}

if (TransFilterData(type, filterK, filterC, filterH, filterW, weightData, buf.get()) != RET_OK) {
MS_LOG(ERROR) << "TransFilterData failed";
return RET_ERROR;
}


auto ret = ::memcpy_s(tensor->data.data(), count * sizeof(T), buf.get(), count * sizeof(T)); auto ret = ::memcpy_s(tensor->data.data(), count * sizeof(T), buf.get(), count * sizeof(T));
if (ret != EOK) { if (ret != EOK) {


+ 10
- 12
mindspore/lite/tools/converter/parser/tf/tf_arithmetic_self_parser.cc View File

@@ -19,22 +19,11 @@
#include <map> #include <map>
#include <vector> #include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h" #include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/common/node_util.h"


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {


template <typename T>
int CreateOperator(const std::unique_ptr<schema::PrimitiveT> &primitive, schema::PrimitiveType type) {
auto attr = std::make_unique<T>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}
primitive->value.type = type;
primitive->value.value = attr.release();
return RET_OK;
}

STATUS TFArithmeticSelfParser::Parse(const tensorflow::NodeDef &tf_op, STATUS TFArithmeticSelfParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) { PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
@@ -61,6 +50,12 @@ STATUS TFArithmeticSelfParser::Parse(const tensorflow::NodeDef &tf_op,
status = CreateOperator<schema::LogT>(primitive, schema::PrimitiveType_Log); status = CreateOperator<schema::LogT>(primitive, schema::PrimitiveType_Log);
} else if (tf_op.op() == "Sqrt") { } else if (tf_op.op() == "Sqrt") {
status = CreateOperator<schema::SqrtT>(primitive, schema::PrimitiveType_Sqrt); status = CreateOperator<schema::SqrtT>(primitive, schema::PrimitiveType_Sqrt);
} else if (tf_op.op() == "Cos") {
status = CreateOperator<schema::CosT>(primitive, schema::PrimitiveType_Cos);
} else if (tf_op.op() == "Sin") {
status = CreateOperator<schema::SinT>(primitive, schema::PrimitiveType_Sin);
} else if (tf_op.op() == "Square") {
status = CreateOperator<schema::SquareT>(primitive, schema::PrimitiveType_Square);
} else if (tf_op.op() == "Pow") { } else if (tf_op.op() == "Pow") {
status = CreateOperator<schema::PowerT>(primitive, schema::PrimitiveType_Power); status = CreateOperator<schema::PowerT>(primitive, schema::PrimitiveType_Power);
} }
@@ -81,6 +76,9 @@ STATUS TFArithmeticSelfParser::Parse(const tensorflow::NodeDef &tf_op,
} }
return status; return status;
} }
TFNodeRegistrar g_tfCosParser("Cos", new TFArithmeticSelfParser());
TFNodeRegistrar g_tfSinParser("Sin", new TFArithmeticSelfParser());
TFNodeRegistrar g_tfSquareParser("Square", new TFArithmeticSelfParser());
TFNodeRegistrar g_tfCeilParser("Ceil", new TFArithmeticSelfParser()); TFNodeRegistrar g_tfCeilParser("Ceil", new TFArithmeticSelfParser());
TFNodeRegistrar g_tfExpParser("Exp", new TFArithmeticSelfParser()); TFNodeRegistrar g_tfExpParser("Exp", new TFArithmeticSelfParser());
TFNodeRegistrar g_tfFloorParser("Floor", new TFArithmeticSelfParser()); TFNodeRegistrar g_tfFloorParser("Floor", new TFArithmeticSelfParser());


+ 64
- 0
mindspore/lite/tools/converter/parser/tf/tf_batch_matmul_parser.cc View File

@@ -0,0 +1,64 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_batch_matmul_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"

namespace mindspore {
namespace lite {
STATUS TFBatchMatmulParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
if (primitiveC == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_NULL_PTR;
}

auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "primitive is nullptr";
return RET_NULL_PTR;
}
auto attr = std::make_unique<schema::BatchMatMulT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
tensorflow::AttrValue attr_value;
TensorFlowUtils::FindAttrValue(tf_op, "adj_x", &attr_value);
attr->adj_x = attr_value.b();
attr->adj_y = attr_value.b();

primitive->value.type = schema::PrimitiveType_BatchMatMul;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_ERROR;
}

*output_size = 1;
for (int i = 0; i < tf_op.input_size(); i++) {
inputs->emplace_back(tf_op.input(i));
}
return RET_OK;
}
TFNodeRegistrar g_tfBatchMatMulParser("BatchMatMul", new TFBatchMatmulParser());
} // namespace lite
} // namespace mindspore

+ 37
- 0
mindspore/lite/tools/converter/parser/tf/tf_batch_matmul_parser.h View File

@@ -0,0 +1,37 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BATCH_MATMUL_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BATCH_MATMUL_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"

namespace mindspore {
namespace lite {
class TFBatchMatmulParser : public TFNodeParser {
public:
TFBatchMatmulParser() = default;
~TFBatchMatmulParser() override = default;

STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore

#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BATCH_MATMUL_PARSER_H_

+ 59
- 0
mindspore/lite/tools/converter/parser/tf/tf_is_finite_parser.cc View File

@@ -0,0 +1,59 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_is_finite_parser.h"
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/common/node_util.h"

namespace mindspore {
namespace lite {
STATUS TFIsFiniteParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
if (primitiveC == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_NULL_PTR;
}

auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "primitive is nullptr";
return RET_NULL_PTR;
}

int status = CreateOperator<schema::IsFiniteT>(primitive, schema::PrimitiveType_IsFinite);
if (status != RET_OK) {
return status;
}
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_ERROR;
}

*output_size = 1;
for (int i = 0; i < tf_op.input_size(); i++) {
inputs->emplace_back(tf_op.input(i));
}

return RET_OK;
}
TFNodeRegistrar g_tf_is_finite_parser("IsFinite", new TFIsFiniteParser());
} // namespace lite
} // namespace mindspore

+ 37
- 0
mindspore/lite/tools/converter/parser/tf/tf_is_finite_parser.h View File

@@ -0,0 +1,37 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_IS_FINITE_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_IS_FINITE_PARSER_H_

#include <map>
#include <memory>
#include <string>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"

namespace mindspore {
namespace lite {
class TFIsFiniteParser : public TFNodeParser {
public:
TFIsFiniteParser() = default;
~TFIsFiniteParser() override = default;

STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_IS_FINITE_PARSER_H_

+ 12
- 29
mindspore/lite/tools/converter/parser/tf/tf_logical_parser.cc View File

@@ -19,6 +19,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h" #include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/common/node_util.h"


namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
@@ -36,37 +37,19 @@ STATUS TFLogicalParser::Parse(const tensorflow::NodeDef &tf_op,
MS_LOG(ERROR) << "primitive is nullptr"; MS_LOG(ERROR) << "primitive is nullptr";
return RET_NULL_PTR; return RET_NULL_PTR;
} }

int status = RET_ERROR;
if (tf_op.op() == "LogicalAnd") { if (tf_op.op() == "LogicalAnd") {
auto attr = std::make_unique<schema::LogicalAndT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
primitive->value.type = schema::PrimitiveType_LogicalAnd;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
status = CreateOperator<schema::LogicalAndT>(primitive, schema::PrimitiveType_LogicalAnd);
} else if (tf_op.op() == "LogicalOr") { } else if (tf_op.op() == "LogicalOr") {
auto attr = std::make_unique<schema::LogicalOrT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
primitive->value.type = schema::PrimitiveType_LogicalOr;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
status = CreateOperator<schema::LogicalOrT>(primitive, schema::PrimitiveType_LogicalOr);
} else if (tf_op.op() == "LogicalNot") { } else if (tf_op.op() == "LogicalNot") {
auto attr = std::make_unique<schema::LogicalNotT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
primitive->value.type = schema::PrimitiveType_LogicalNot;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
} else {
MS_LOG(ERROR) << tf_op.op() << " is not supported.";
return RET_ERROR;
status = CreateOperator<schema::LogicalNotT>(primitive, schema::PrimitiveType_LogicalNot);
} }
if (status != RET_OK) {
return status;
}
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) { if (*primitiveC == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr"; MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_ERROR; return RET_ERROR;
@@ -79,8 +62,8 @@ STATUS TFLogicalParser::Parse(const tensorflow::NodeDef &tf_op,


return RET_OK; return RET_OK;
} }
TFNodeRegistrar g_tfLogicalAndParser("LogicalAnd", new TFLogicalParser());
TFNodeRegistrar g_tfLogicalOrParser("LogicalOr", new TFLogicalParser());
TFNodeRegistrar g_tfLogicalNotParser("LogicalNot", new TFLogicalParser()); TFNodeRegistrar g_tfLogicalNotParser("LogicalNot", new TFLogicalParser());
TFNodeRegistrar g_tfLogicalOrParser("LogicalOr", new TFLogicalParser());
TFNodeRegistrar g_tfLogicalAndParser("LogicalAnd", new TFLogicalParser());
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 60
- 0
mindspore/lite/tools/converter/parser/tf/tf_zeros_like_parser.cc View File

@@ -0,0 +1,60 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_zeros_like_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"

namespace mindspore {
namespace lite {
STATUS TFZerosLikeParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
if (primitiveC == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_NULL_PTR;
}

auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "primitive is nullptr";
return RET_NULL_PTR;
}
auto attr = std::make_unique<schema::ZerosLikeT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}

primitive->value.type = schema::PrimitiveType_ZerosLike;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_ERROR;
}

*output_size = tf_op.input_size();
for (int i = 0; i < tf_op.input_size(); i++) {
inputs->emplace_back(tf_op.input(i));
}
return RET_OK;
}
TFNodeRegistrar g_tfZerosLikeParser("ZerosLike", new TFZerosLikeParser());
} // namespace lite
} // namespace mindspore

+ 37
- 0
mindspore/lite/tools/converter/parser/tf/tf_zeros_like_parser.h View File

@@ -0,0 +1,37 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ZERO_LIKE_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ZERO_LIKE_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"

namespace mindspore {
namespace lite {
class TFZerosLikeParser : public TFNodeParser {
public:
TFZerosLikeParser() = default;
~TFZerosLikeParser() override = default;

STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore

#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ZERO_LIKE_PARSER_H_

Loading…
Cancel
Save