From: @cjh9368 Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -273,7 +273,9 @@ union PrimitiveType { | |||
| RandomStandardNormal, | |||
| CropAndResize, | |||
| Erf, | |||
| StridedSliceGrad | |||
| StridedSliceGrad, | |||
| IsFinite, | |||
| BatchMatMul, | |||
| } | |||
| enum QuantType: int { | |||
| @@ -1274,4 +1274,12 @@ table StridedSliceGrad { | |||
| } | |||
| table Erf { | |||
| } | |||
| table IsFinite { | |||
| } | |||
| table BatchMatMul { | |||
| adj_x : bool = false; | |||
| adj_y : bool = false; | |||
| } | |||
| @@ -0,0 +1,81 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/batch_matmul.h" | |||
| #include <memory> | |||
| #ifndef PRIMITIVE_WRITEABLE | |||
| #include "src/ops/ops_register.h" | |||
| #endif | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| bool BatchMatMul::GetAdjX() const { return this->primitive_->value.AsBatchMatMul()->adj_x; } | |||
| void BatchMatMul::SetAdjX(bool adj_x) { this->primitive_->value.AsBatchMatMul()->adj_x = adj_x; } | |||
| bool BatchMatMul::GetAdjY() const { return this->primitive_->value.AsBatchMatMul()->adj_y; } | |||
| void BatchMatMul::SetAdjY(bool adj_y) { this->primitive_->value.AsBatchMatMul()->adj_y = adj_y; } | |||
| int BatchMatMul::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_BatchMatMul; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_BatchMatMul) { | |||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| auto attr = new (std::nothrow) schema::BatchMatMulT(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new FusedBatchMatMulT failed"; | |||
| delete this->primitive_; | |||
| this->primitive_ = nullptr; | |||
| return RET_ERROR; | |||
| } | |||
| attr->adj_x = GetValue<bool>(prim.GetAttr("adj_x")); | |||
| attr->adj_y = GetValue<bool>(prim.GetAttr("adj_y")); | |||
| this->primitive_->value.value = attr; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| int BatchMatMul::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| MS_ASSERT(nullptr != fbb); | |||
| auto val_offset = schema::CreateBatchMatMul(*fbb); | |||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BatchMatMul, val_offset.o); | |||
| fbb->Finish(prim_offset); | |||
| return RET_OK; | |||
| } | |||
| bool BatchMatMul::GetAdjX() const { return this->primitive_->value_as_BatchMatMul()->adj_x(); } | |||
| bool BatchMatMul::GetAdjY() const { return this->primitive_->value_as_BatchMatMul()->adj_y(); } | |||
| PrimitiveC *BatchMatMulCreator(const schema::Primitive *primitive) { | |||
| return PrimitiveC::NewPrimitiveC<BatchMatMul>(primitive); | |||
| } | |||
| Registry BatchMatMulRegistry(schema::PrimitiveType_BatchMatMul, BatchMatMulCreator); | |||
| #endif | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,46 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_BATCH_MATMUL_H_ | |||
| #define LITE_MINDSPORE_LITE_C_OPS_BATCH_MATMUL_H_ | |||
| #include <vector> | |||
| #include <set> | |||
| #include <cmath> | |||
| #include "src/ops/primitive_c.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class BatchMatMul : public PrimitiveC { | |||
| public: | |||
| BatchMatMul() = default; | |||
| ~BatchMatMul() = default; | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| MS_DECLARE_PARENT(BatchMatMul, PrimitiveC); | |||
| explicit BatchMatMul(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| void SetAdjX(bool adj_x); | |||
| void SetAdjY(bool adj_y); | |||
| #else | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| bool GetAdjX() const; | |||
| bool GetAdjY() const; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_MINDSPORE_LITE_C_OPS_BATCH_MATMUL_H_ | |||
| @@ -0,0 +1,33 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/primitive_c.h" | |||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_IS_FINITE_H_ | |||
| #define LITE_MINDSPORE_LITE_C_OPS_IS_FINITE_H_ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class IsFinite : public PrimitiveC { | |||
| public: | |||
| MS_DECLARE_PARENT(IsFinite, PrimitiveC); | |||
| IsFinite() = default; | |||
| ~IsFinite() = default; | |||
| explicit IsFinite(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_MINDSPORE_LITE_C_OPS_IS_FINITE_H_ | |||
| @@ -170,6 +170,8 @@ | |||
| #include "src/ops/crop_and_resize.h" | |||
| #include "src/ops/nonzero.h" | |||
| #include "src/ops/erf.h" | |||
| #include "src/ops/is_finite.h" | |||
| #include "src/ops/batch_matmul.h" | |||
| #ifdef SUPPORT_TRAIN | |||
| #include "src/ops/neg_grad.h" | |||
| @@ -665,7 +667,6 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||
| return NewPrimitiveC<ArgMax>(prim, inputs, quantType); | |||
| } else if (op_type == "Gelu") { | |||
| return NewPrimitiveC<GeLU>(prim, inputs, quantType); | |||
| #ifdef SUPPORT_TRAIN | |||
| } else if (op_type == "SoftmaxCrossEntropyWithLogits") { | |||
| return NewPrimitiveC<SoftmaxCrossEntropy>(prim, inputs, quantType); | |||
| @@ -1034,6 +1035,10 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { | |||
| return new (std::nothrow) NonZero(primitive); | |||
| case schema::PrimitiveType_Erf: | |||
| return new (std::nothrow) Erf(primitive); | |||
| case schema::PrimitiveType_IsFinite: | |||
| return new (std::nothrow) IsFinite(primitive); | |||
| case schema::PrimitiveType_BatchMatMul: | |||
| return new (std::nothrow) BatchMatMul(primitive); | |||
| #ifdef SUPPORT_TRAIN | |||
| case schema::PrimitiveType_ActivationGrad: | |||
| return new (std::nothrow) ActivationGrad(primitive); | |||
| @@ -29,6 +29,18 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| template <typename T> | |||
| int CreateOperator(const std::unique_ptr<schema::PrimitiveT> &primitive, schema::PrimitiveType type) { | |||
| auto attr = std::make_unique<T>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| primitive->value.type = type; | |||
| primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| } | |||
| using STATUS = int; | |||
| STATUS BroadCastQuantParam(schema::MetaGraphT *graphT, const std::unique_ptr<schema::CNodeT> &node); | |||
| @@ -92,197 +104,252 @@ STATUS SetFilterDim(schema::TensorT *tensor, kTransFilterType type, int32_t filt | |||
| int32_t filterW); | |||
| template <typename T> | |||
| static STATUS TransFilterData(schema::TensorT *tensor, kTransFilterType type, int32_t filterK, int32_t filterC, | |||
| int32_t filterH, int32_t filterW) { | |||
| MS_ASSERT(tensor != nullptr); | |||
| int count = filterH * filterW * filterC * filterK; | |||
| if (count <= 0) { | |||
| MS_LOG(ERROR) << "Dim size invalid"; | |||
| return RET_ERROR; | |||
| static void TransKHWC2CHWK(int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW, T *srcData, T *dstData) { | |||
| T *p1Buff = nullptr; | |||
| T *p2Buff = nullptr; | |||
| for (int k = 0; k < filterK; ++k) { | |||
| for (int h = 0; h < filterH; ++h) { | |||
| for (int w = 0; w < filterW; ++w) { | |||
| for (int c = 0; c < filterC; ++c) { | |||
| p1Buff = srcData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); | |||
| p2Buff = dstData + ((c * filterK * filterH * filterW) + (h * filterK * filterW) + (w * filterK) + (k)); | |||
| *p2Buff = *p1Buff; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| std::unique_ptr<T[]> buf(new (std::nothrow) T[count]); | |||
| if (buf == nullptr) { | |||
| MS_LOG(ERROR) << "new buf failed"; | |||
| return RET_ERROR; | |||
| } | |||
| template <typename T> | |||
| static void TransKHWC2HWCK(int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW, T *srcData, T *dstData) { | |||
| T *p1Buff = nullptr; | |||
| T *p2Buff = nullptr; | |||
| for (int k = 0; k < filterK; ++k) { | |||
| for (int h = 0; h < filterH; ++h) { | |||
| for (int w = 0; w < filterW; ++w) { | |||
| for (int c = 0; c < filterC; ++c) { | |||
| p1Buff = srcData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); | |||
| p2Buff = dstData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); | |||
| *p2Buff = *p1Buff; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| void *originWeightDate = tensor->data.data(); | |||
| T *weightData = static_cast<T *>(originWeightDate); | |||
| template <typename T> | |||
| static void TransCKHW(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW, | |||
| T *srcData, T *dstData) { | |||
| T *p1Buff = nullptr; | |||
| T *p2Buff = nullptr; | |||
| for (int c = 0; c < filterC; ++c) { | |||
| for (int k = 0; k < filterK; ++k) { | |||
| for (int h = 0; h < filterH; ++h) { | |||
| for (int w = 0; w < filterW; ++w) { | |||
| p1Buff = srcData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); | |||
| if (type == kCKHW2HWCK) { | |||
| p2Buff = dstData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); | |||
| } else if (type == kCKHW2KHWC) { | |||
| p2Buff = dstData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); | |||
| } else { | |||
| p2Buff = dstData + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c)); | |||
| } | |||
| *p2Buff = *p1Buff; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| if (weightData == nullptr) { | |||
| MS_LOG(ERROR) << "weightData is nullptr"; | |||
| return RET_ERROR; | |||
| template <typename T> | |||
| static void TransKCHW(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW, | |||
| T *srcData, T *dstData) { | |||
| T *p1Buff = nullptr; | |||
| T *p2Buff = nullptr; | |||
| for (int k = 0; k < filterK; ++k) { | |||
| for (int c = 0; c < filterC; ++c) { | |||
| for (int h = 0; h < filterH; ++h) { | |||
| for (int w = 0; w < filterW; ++w) { | |||
| p1Buff = srcData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); | |||
| if (type == kKCHW2HWCK) { | |||
| p2Buff = dstData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); | |||
| } else if (type == kKCHW2KHWC) { | |||
| p2Buff = dstData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); | |||
| } else if (type == kKCHW2CKHW) { | |||
| p2Buff = dstData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); | |||
| } else { | |||
| p2Buff = dstData + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c)); | |||
| } | |||
| *p2Buff = *p1Buff; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| template <typename T> | |||
| static void TransCHWK(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW, | |||
| T *srcData, T *dstData) { | |||
| T *p1Buff = nullptr; | |||
| T *p2Buff = nullptr; | |||
| for (int c = 0; c < filterC; ++c) { | |||
| for (int h = 0; h < filterH; ++h) { | |||
| for (int w = 0; w < filterW; ++w) { | |||
| for (int k = 0; k < filterK; ++k) { | |||
| p1Buff = srcData + ((c * filterH * filterW * filterK) + (h * filterW * filterK) + (w * filterK) + (k)); | |||
| if (type == kCHWK2HWCK) { | |||
| p2Buff = dstData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); | |||
| } else { | |||
| p2Buff = dstData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); | |||
| } | |||
| *p2Buff = *p1Buff; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| template <typename T> | |||
| static void TransHWCK(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW, | |||
| T *srcData, T *dstData) { | |||
| T *p1Buff = nullptr; | |||
| T *p2Buff = nullptr; | |||
| switch (type) { | |||
| case kCHWK2HWCK: | |||
| case kCHWK2KHWC: { | |||
| for (int h = 0; h < filterH; ++h) { | |||
| for (int w = 0; w < filterW; ++w) { | |||
| for (int c = 0; c < filterC; ++c) { | |||
| for (int h = 0; h < filterH; ++h) { | |||
| for (int w = 0; w < filterW; ++w) { | |||
| for (int k = 0; k < filterK; ++k) { | |||
| p1Buff = weightData + ((c * filterH * filterW * filterK) + (h * filterW * filterK) + (w * filterK) + (k)); | |||
| if (type == kCHWK2HWCK) { | |||
| p2Buff = | |||
| buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); | |||
| } else { | |||
| p2Buff = | |||
| buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); | |||
| } | |||
| *p2Buff = *p1Buff; | |||
| } | |||
| for (int k = 0; k < filterK; ++k) { | |||
| p1Buff = srcData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); | |||
| if (type == kHWCK2KCHW) { | |||
| p2Buff = dstData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); | |||
| } else { | |||
| p2Buff = dstData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); | |||
| } | |||
| *p2Buff = *p1Buff; | |||
| } | |||
| } | |||
| } break; | |||
| case kKHWC2HWCK: { | |||
| for (int k = 0; k < filterK; ++k) { | |||
| for (int h = 0; h < filterH; ++h) { | |||
| for (int w = 0; w < filterW; ++w) { | |||
| for (int c = 0; c < filterC; ++c) { | |||
| p1Buff = weightData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); | |||
| p2Buff = buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); | |||
| *p2Buff = *p1Buff; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| template <typename T> | |||
| static void TransHWKC(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW, | |||
| T *srcData, T *dstData) { | |||
| T *p1Buff = nullptr; | |||
| T *p2Buff = nullptr; | |||
| for (int h = 0; h < filterH; ++h) { | |||
| for (int w = 0; w < filterW; ++w) { | |||
| for (int c = 0; c < filterC; ++c) { | |||
| for (int k = 0; k < filterK; ++k) { | |||
| p1Buff = srcData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c)); | |||
| if (type == kHWKC2KCHW) { | |||
| p2Buff = dstData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); | |||
| } else { | |||
| p2Buff = dstData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); | |||
| } | |||
| *p2Buff = *p1Buff; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| template <typename T> | |||
| static void TransNHWC(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW, | |||
| T *srcData, T *dstData) { | |||
| T *p1Buff = nullptr; | |||
| T *p2Buff = nullptr; | |||
| for (int k = 0; k < filterK; ++k) { | |||
| for (int h = 0; h < filterH; ++h) { | |||
| for (int w = 0; w < filterW; ++w) { | |||
| for (int c = 0; c < filterC; ++c) { | |||
| p1Buff = srcData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c)); | |||
| if (type == kNHWC2HWCK) { | |||
| p2Buff = dstData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); | |||
| } else if (type == kNHWC2CKHW) { | |||
| p2Buff = dstData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); | |||
| } else { | |||
| p2Buff = dstData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); | |||
| } | |||
| *p2Buff = *p1Buff; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| template <typename T> | |||
| static STATUS TransFilterData(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW, | |||
| T *srcData, T *dstData) { | |||
| switch (type) { | |||
| case kCHWK2HWCK: | |||
| case kCHWK2KHWC: { | |||
| TransCHWK(type, filterK, filterC, filterH, filterW, srcData, dstData); | |||
| } break; | |||
| case kKHWC2HWCK: { | |||
| TransKHWC2HWCK(filterK, filterC, filterH, filterW, srcData, dstData); | |||
| } break; | |||
| case kKCHW2HWCK: | |||
| case kKCHW2CKHW: | |||
| case kKCHW2KHWC: | |||
| case kKCHW2HWKC: { | |||
| for (int k = 0; k < filterK; ++k) { | |||
| for (int c = 0; c < filterC; ++c) { | |||
| for (int h = 0; h < filterH; ++h) { | |||
| for (int w = 0; w < filterW; ++w) { | |||
| p1Buff = weightData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); | |||
| if (type == kKCHW2HWCK) { | |||
| p2Buff = | |||
| buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); | |||
| } else if (type == kKCHW2KHWC) { | |||
| p2Buff = | |||
| buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); | |||
| } else if (type == kKCHW2CKHW) { | |||
| p2Buff = | |||
| buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); | |||
| } else { | |||
| p2Buff = | |||
| buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c)); | |||
| } | |||
| *p2Buff = *p1Buff; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| TransKCHW(type, filterK, filterC, filterH, filterW, srcData, dstData); | |||
| } break; | |||
| case kCKHW2HWCK: | |||
| case kCKHW2KHWC: | |||
| case kCKHW2HWKC: { | |||
| for (int c = 0; c < filterC; ++c) { | |||
| for (int k = 0; k < filterK; ++k) { | |||
| for (int h = 0; h < filterH; ++h) { | |||
| for (int w = 0; w < filterW; ++w) { | |||
| p1Buff = weightData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); | |||
| if (type == kCKHW2HWCK) { | |||
| p2Buff = | |||
| buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); | |||
| } else if (type == kCKHW2KHWC) { | |||
| p2Buff = | |||
| buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); | |||
| } else { | |||
| p2Buff = | |||
| buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c)); | |||
| } | |||
| *p2Buff = *p1Buff; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| TransCKHW(type, filterK, filterC, filterH, filterW, srcData, dstData); | |||
| } break; | |||
| case kHWCK2KCHW: | |||
| case kHWCK2CKHW: { | |||
| for (int h = 0; h < filterH; ++h) { | |||
| for (int w = 0; w < filterW; ++w) { | |||
| for (int c = 0; c < filterC; ++c) { | |||
| for (int k = 0; k < filterK; ++k) { | |||
| p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); | |||
| if (type == kHWCK2KCHW) { | |||
| p2Buff = | |||
| buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); | |||
| } else { | |||
| p2Buff = | |||
| buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); | |||
| } | |||
| *p2Buff = *p1Buff; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| TransHWCK(type, filterK, filterC, filterH, filterW, srcData, dstData); | |||
| } break; | |||
| case kHWKC2KCHW: | |||
| case kHWKC2CKHW: { | |||
| for (int h = 0; h < filterH; ++h) { | |||
| for (int w = 0; w < filterW; ++w) { | |||
| for (int c = 0; c < filterC; ++c) { | |||
| for (int k = 0; k < filterK; ++k) { | |||
| p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c)); | |||
| if (type == kHWKC2KCHW) { | |||
| p2Buff = | |||
| buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); | |||
| } else { | |||
| p2Buff = | |||
| buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); | |||
| } | |||
| *p2Buff = *p1Buff; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| TransHWKC(type, filterK, filterC, filterH, filterW, srcData, dstData); | |||
| } break; | |||
| case kNHWC2HWCK: | |||
| case kNHWC2KCHW: | |||
| case kNHWC2CKHW: { | |||
| for (int k = 0; k < filterK; ++k) { | |||
| for (int h = 0; h < filterH; ++h) { | |||
| for (int w = 0; w < filterW; ++w) { | |||
| for (int c = 0; c < filterC; ++c) { | |||
| p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c)); | |||
| if (type == kNHWC2HWCK) { | |||
| p2Buff = | |||
| buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); | |||
| } else if (type == kNHWC2CKHW) { | |||
| p2Buff = | |||
| buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); | |||
| } else { | |||
| p2Buff = | |||
| buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); | |||
| } | |||
| *p2Buff = *p1Buff; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| TransNHWC(type, filterK, filterC, filterH, filterW, srcData, dstData); | |||
| } break; | |||
| case kKHWC2CHWK: { | |||
| for (int k = 0; k < filterK; ++k) { | |||
| for (int h = 0; h < filterH; ++h) { | |||
| for (int w = 0; w < filterW; ++w) { | |||
| for (int c = 0; c < filterC; ++c) { | |||
| p1Buff = weightData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); | |||
| p2Buff = buf.get() + ((c * filterK * filterH * filterW) + (h * filterK * filterW) + (w * filterK) + (k)); | |||
| *p2Buff = *p1Buff; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| TransKHWC2CHWK(filterK, filterC, filterH, filterW, srcData, dstData); | |||
| } break; | |||
| default: { | |||
| MS_LOG(ERROR) << "Unsupported transFilterType: " << type; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| template <typename T> | |||
| static STATUS TransFilterData(schema::TensorT *tensor, kTransFilterType type, int32_t filterK, int32_t filterC, | |||
| int32_t filterH, int32_t filterW) { | |||
| MS_ASSERT(tensor != nullptr); | |||
| int count = filterH * filterW * filterC * filterK; | |||
| if (count <= 0) { | |||
| MS_LOG(ERROR) << "Dim size invalid"; | |||
| return RET_ERROR; | |||
| } | |||
| std::unique_ptr<T[]> buf(new (std::nothrow) T[count]); | |||
| if (buf == nullptr) { | |||
| MS_LOG(ERROR) << "new buf failed"; | |||
| return RET_ERROR; | |||
| } | |||
| void *originWeightDate = tensor->data.data(); | |||
| T *weightData = static_cast<T *>(originWeightDate); | |||
| if (weightData == nullptr) { | |||
| MS_LOG(ERROR) << "weightData is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| if (TransFilterData(type, filterK, filterC, filterH, filterW, weightData, buf.get()) != RET_OK) { | |||
| MS_LOG(ERROR) << "TransFilterData failed"; | |||
| return RET_ERROR; | |||
| } | |||
| auto ret = ::memcpy_s(tensor->data.data(), count * sizeof(T), buf.get(), count * sizeof(T)); | |||
| if (ret != EOK) { | |||
| @@ -19,22 +19,11 @@ | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| #include "tools/common/node_util.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| template <typename T> | |||
| int CreateOperator(const std::unique_ptr<schema::PrimitiveT> &primitive, schema::PrimitiveType type) { | |||
| auto attr = std::make_unique<T>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| primitive->value.type = type; | |||
| primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| } | |||
| STATUS TFArithmeticSelfParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) { | |||
| @@ -61,6 +50,12 @@ STATUS TFArithmeticSelfParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| status = CreateOperator<schema::LogT>(primitive, schema::PrimitiveType_Log); | |||
| } else if (tf_op.op() == "Sqrt") { | |||
| status = CreateOperator<schema::SqrtT>(primitive, schema::PrimitiveType_Sqrt); | |||
| } else if (tf_op.op() == "Cos") { | |||
| status = CreateOperator<schema::CosT>(primitive, schema::PrimitiveType_Cos); | |||
| } else if (tf_op.op() == "Sin") { | |||
| status = CreateOperator<schema::SinT>(primitive, schema::PrimitiveType_Sin); | |||
| } else if (tf_op.op() == "Square") { | |||
| status = CreateOperator<schema::SquareT>(primitive, schema::PrimitiveType_Square); | |||
| } else if (tf_op.op() == "Pow") { | |||
| status = CreateOperator<schema::PowerT>(primitive, schema::PrimitiveType_Power); | |||
| } | |||
| @@ -81,6 +76,9 @@ STATUS TFArithmeticSelfParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| } | |||
| return status; | |||
| } | |||
| TFNodeRegistrar g_tfCosParser("Cos", new TFArithmeticSelfParser()); | |||
| TFNodeRegistrar g_tfSinParser("Sin", new TFArithmeticSelfParser()); | |||
| TFNodeRegistrar g_tfSquareParser("Square", new TFArithmeticSelfParser()); | |||
| TFNodeRegistrar g_tfCeilParser("Ceil", new TFArithmeticSelfParser()); | |||
| TFNodeRegistrar g_tfExpParser("Exp", new TFArithmeticSelfParser()); | |||
| TFNodeRegistrar g_tfFloorParser("Floor", new TFArithmeticSelfParser()); | |||
| @@ -0,0 +1,64 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_batch_matmul_parser.h" | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TFBatchMatmulParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) { | |||
| if (primitiveC == nullptr || output_size == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "primitive is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto attr = std::make_unique<schema::BatchMatMulT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| tensorflow::AttrValue attr_value; | |||
| TensorFlowUtils::FindAttrValue(tf_op, "adj_x", &attr_value); | |||
| attr->adj_x = attr_value.b(); | |||
| attr->adj_y = attr_value.b(); | |||
| primitive->value.type = schema::PrimitiveType_BatchMatMul; | |||
| primitive->value.value = attr.release(); | |||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||
| if (*primitiveC == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| *output_size = 1; | |||
| for (int i = 0; i < tf_op.input_size(); i++) { | |||
| inputs->emplace_back(tf_op.input(i)); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| TFNodeRegistrar g_tfBatchMatMulParser("BatchMatMul", new TFBatchMatmulParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,37 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BATCH_MATMUL_PARSER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BATCH_MATMUL_PARSER_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TFBatchMatmulParser : public TFNodeParser { | |||
| public: | |||
| TFBatchMatmulParser() = default; | |||
| ~TFBatchMatmulParser() override = default; | |||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BATCH_MATMUL_PARSER_H_ | |||
| @@ -0,0 +1,59 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_is_finite_parser.h" | |||
| #include <map> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| #include "tools/common/node_util.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TFIsFiniteParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) { | |||
| if (primitiveC == nullptr || output_size == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "primitive is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| int status = CreateOperator<schema::IsFiniteT>(primitive, schema::PrimitiveType_IsFinite); | |||
| if (status != RET_OK) { | |||
| return status; | |||
| } | |||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||
| if (*primitiveC == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| *output_size = 1; | |||
| for (int i = 0; i < tf_op.input_size(); i++) { | |||
| inputs->emplace_back(tf_op.input(i)); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| TFNodeRegistrar g_tf_is_finite_parser("IsFinite", new TFIsFiniteParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,37 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_IS_FINITE_PARSER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_IS_FINITE_PARSER_H_ | |||
| #include <map> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TFIsFiniteParser : public TFNodeParser { | |||
| public: | |||
| TFIsFiniteParser() = default; | |||
| ~TFIsFiniteParser() override = default; | |||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_IS_FINITE_PARSER_H_ | |||
| @@ -19,6 +19,7 @@ | |||
| #include <string> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| #include "tools/common/node_util.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -36,37 +37,19 @@ STATUS TFLogicalParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| MS_LOG(ERROR) << "primitive is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| int status = RET_ERROR; | |||
| if (tf_op.op() == "LogicalAnd") { | |||
| auto attr = std::make_unique<schema::LogicalAndT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_LogicalAnd; | |||
| primitive->value.value = attr.release(); | |||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||
| status = CreateOperator<schema::LogicalAndT>(primitive, schema::PrimitiveType_LogicalAnd); | |||
| } else if (tf_op.op() == "LogicalOr") { | |||
| auto attr = std::make_unique<schema::LogicalOrT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_LogicalOr; | |||
| primitive->value.value = attr.release(); | |||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||
| status = CreateOperator<schema::LogicalOrT>(primitive, schema::PrimitiveType_LogicalOr); | |||
| } else if (tf_op.op() == "LogicalNot") { | |||
| auto attr = std::make_unique<schema::LogicalNotT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_LogicalNot; | |||
| primitive->value.value = attr.release(); | |||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||
| } else { | |||
| MS_LOG(ERROR) << tf_op.op() << " is not supported."; | |||
| return RET_ERROR; | |||
| status = CreateOperator<schema::LogicalNotT>(primitive, schema::PrimitiveType_LogicalNot); | |||
| } | |||
| if (status != RET_OK) { | |||
| return status; | |||
| } | |||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||
| if (*primitiveC == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_ERROR; | |||
| @@ -79,8 +62,8 @@ STATUS TFLogicalParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| return RET_OK; | |||
| } | |||
| TFNodeRegistrar g_tfLogicalAndParser("LogicalAnd", new TFLogicalParser()); | |||
| TFNodeRegistrar g_tfLogicalOrParser("LogicalOr", new TFLogicalParser()); | |||
| TFNodeRegistrar g_tfLogicalNotParser("LogicalNot", new TFLogicalParser()); | |||
| TFNodeRegistrar g_tfLogicalOrParser("LogicalOr", new TFLogicalParser()); | |||
| TFNodeRegistrar g_tfLogicalAndParser("LogicalAnd", new TFLogicalParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,60 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_zeros_like_parser.h" | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TFZerosLikeParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) { | |||
| if (primitiveC == nullptr || output_size == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "primitive is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto attr = std::make_unique<schema::ZerosLikeT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_ZerosLike; | |||
| primitive->value.value = attr.release(); | |||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||
| if (*primitiveC == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| *output_size = tf_op.input_size(); | |||
| for (int i = 0; i < tf_op.input_size(); i++) { | |||
| inputs->emplace_back(tf_op.input(i)); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| TFNodeRegistrar g_tfZerosLikeParser("ZerosLike", new TFZerosLikeParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,37 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ZERO_LIKE_PARSER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ZERO_LIKE_PARSER_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TFZerosLikeParser : public TFNodeParser { | |||
| public: | |||
| TFZerosLikeParser() = default; | |||
| ~TFZerosLikeParser() override = default; | |||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ZERO_LIKE_PARSER_H_ | |||