From: @cjh9368 (tag: v1.2.0-rc1)
schema/ops.fbs:
@@ -273,7 +273,9 @@ union PrimitiveType {
     RandomStandardNormal,
     CropAndResize,
     Erf,
-    StridedSliceGrad
+    StridedSliceGrad,
+    IsFinite,
+    BatchMatMul,
 }

 enum QuantType: int {
@@ -1274,4 +1274,12 @@ table StridedSliceGrad {
 }

 table Erf {
+}
+
+table IsFinite {
+}
+
+table BatchMatMul {
+    adj_x : bool = false;
+    adj_y : bool = false;
 }
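
The two new schema entries follow the usual BatchMatMul convention: adj_x and adj_y transpose the last two dimensions of the corresponding input before the batched multiplication. As a point of reference, here is a minimal standalone sketch of that semantics (plain C++ over flat row-major [batch, rows, cols] buffers; the function name, float element type, and shape handling are illustrative, not the runtime kernel):

#include <vector>

// Reference semantics for BatchMatMul with adj_x / adj_y (illustrative only).
// x is [batch, m, k] (or [batch, k, m] when adj_x), y is [batch, k, n]
// (or [batch, n, k] when adj_y); the result is [batch, m, n].
std::vector<float> BatchMatMulRef(const std::vector<float> &x, const std::vector<float> &y,
                                  int batch, int m, int k, int n, bool adj_x, bool adj_y) {
  std::vector<float> out(static_cast<size_t>(batch) * m * n, 0.0f);
  for (int b = 0; b < batch; ++b) {
    for (int i = 0; i < m; ++i) {
      for (int j = 0; j < n; ++j) {
        float acc = 0.0f;
        for (int p = 0; p < k; ++p) {
          // Pick the element of x at (b, i, p), honouring the optional transpose.
          float xv = adj_x ? x[(b * k + p) * m + i] : x[(b * m + i) * k + p];
          // Pick the element of y at (b, p, j), honouring the optional transpose.
          float yv = adj_y ? y[(b * n + j) * k + p] : y[(b * k + p) * n + j];
          acc += xv * yv;
        }
        out[(b * m + i) * n + j] = acc;
      }
    }
  }
  return out;
}
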
src/ops/batch_matmul.cc (new file):
@@ -0,0 +1,81 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "src/ops/batch_matmul.h"
#include <memory>

#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif

namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
bool BatchMatMul::GetAdjX() const { return this->primitive_->value.AsBatchMatMul()->adj_x; }
void BatchMatMul::SetAdjX(bool adj_x) { this->primitive_->value.AsBatchMatMul()->adj_x = adj_x; }
bool BatchMatMul::GetAdjY() const { return this->primitive_->value.AsBatchMatMul()->adj_y; }
void BatchMatMul::SetAdjY(bool adj_y) { this->primitive_->value.AsBatchMatMul()->adj_y = adj_y; }

int BatchMatMul::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
  if (this->primitive_ == nullptr) {
    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
    if (this->primitive_ == nullptr) {
      MS_LOG(ERROR) << "new primitiveT failed";
      return RET_ERROR;
    }
    this->primitive_->value.type = schema::PrimitiveType_BatchMatMul;
  }
  if (this->primitive_->value.type != schema::PrimitiveType_BatchMatMul) {
    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
    return RET_ERROR;
  }
  if (this->primitive_->value.value == nullptr) {
    auto attr = new (std::nothrow) schema::BatchMatMulT();
    if (attr == nullptr) {
| MS_LOG(ERROR) << "new FusedBatchMatMulT failed"; | |||||
      delete this->primitive_;
      this->primitive_ = nullptr;
      return RET_ERROR;
    }
    attr->adj_x = GetValue<bool>(prim.GetAttr("adj_x"));
    attr->adj_y = GetValue<bool>(prim.GetAttr("adj_y"));
    this->primitive_->value.value = attr;
  }
  return RET_OK;
}
#else
int BatchMatMul::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
  MS_ASSERT(nullptr != primitive);
  MS_ASSERT(nullptr != fbb);
  auto val_offset = schema::CreateBatchMatMul(*fbb);
  auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BatchMatMul, val_offset.o);
  fbb->Finish(prim_offset);
  return RET_OK;
}
bool BatchMatMul::GetAdjX() const { return this->primitive_->value_as_BatchMatMul()->adj_x(); }
bool BatchMatMul::GetAdjY() const { return this->primitive_->value_as_BatchMatMul()->adj_y(); }

PrimitiveC *BatchMatMulCreator(const schema::Primitive *primitive) {
  return PrimitiveC::NewPrimitiveC<BatchMatMul>(primitive);
}
Registry BatchMatMulRegistry(schema::PrimitiveType_BatchMatMul, BatchMatMulCreator);
#endif
}  // namespace lite
}  // namespace mindspore
src/ops/batch_matmul.h (new file):
@@ -0,0 +1,46 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef LITE_MINDSPORE_LITE_C_OPS_BATCH_MATMUL_H_
#define LITE_MINDSPORE_LITE_C_OPS_BATCH_MATMUL_H_

#include <vector>
#include <set>
#include <cmath>
#include "src/ops/primitive_c.h"

namespace mindspore {
namespace lite {
class BatchMatMul : public PrimitiveC {
 public:
  BatchMatMul() = default;
  ~BatchMatMul() = default;
#ifdef PRIMITIVE_WRITEABLE
  MS_DECLARE_PARENT(BatchMatMul, PrimitiveC);
  explicit BatchMatMul(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
  void SetAdjX(bool adj_x);
  void SetAdjY(bool adj_y);
#else
  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
  bool GetAdjX() const;
  bool GetAdjY() const;
};
}  // namespace lite
}  // namespace mindspore

#endif  // LITE_MINDSPORE_LITE_C_OPS_BATCH_MATMUL_H_
src/ops/is_finite.h (new file):
@@ -0,0 +1,33 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "src/ops/primitive_c.h"

#ifndef LITE_MINDSPORE_LITE_C_OPS_IS_FINITE_H_
#define LITE_MINDSPORE_LITE_C_OPS_IS_FINITE_H_

namespace mindspore {
namespace lite {
class IsFinite : public PrimitiveC {
 public:
  MS_DECLARE_PARENT(IsFinite, PrimitiveC);
  IsFinite() = default;
  ~IsFinite() = default;
  explicit IsFinite(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
};
}  // namespace lite
}  // namespace mindspore

#endif  // LITE_MINDSPORE_LITE_C_OPS_IS_FINITE_H_
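
IsFinite carries no attributes, so the class above only wires the primitive into the framework. Its elementwise behaviour matches TensorFlow's IsFinite: every element maps to true unless it is NaN or +/-infinity. A hedged reference sketch (float input and std::vector containers are illustrative; this is not the runtime kernel):

#include <cmath>
#include <vector>

// Elementwise finiteness check: false for NaN and +/-Inf, true otherwise.
std::vector<bool> IsFiniteRef(const std::vector<float> &input) {
  std::vector<bool> output(input.size());
  for (size_t i = 0; i < input.size(); ++i) {
    output[i] = std::isfinite(input[i]);
  }
  return output;
}
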
src/ops/primitive_c.cc:
@@ -170,6 +170,8 @@
 #include "src/ops/crop_and_resize.h"
 #include "src/ops/nonzero.h"
 #include "src/ops/erf.h"
+#include "src/ops/is_finite.h"
+#include "src/ops/batch_matmul.h"

 #ifdef SUPPORT_TRAIN
 #include "src/ops/neg_grad.h"
@@ -665,7 +667,6 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
     return NewPrimitiveC<ArgMax>(prim, inputs, quantType);
   } else if (op_type == "Gelu") {
     return NewPrimitiveC<GeLU>(prim, inputs, quantType);
 #ifdef SUPPORT_TRAIN
   } else if (op_type == "SoftmaxCrossEntropyWithLogits") {
     return NewPrimitiveC<SoftmaxCrossEntropy>(prim, inputs, quantType);
@@ -1034,6 +1035,10 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
       return new (std::nothrow) NonZero(primitive);
     case schema::PrimitiveType_Erf:
       return new (std::nothrow) Erf(primitive);
+    case schema::PrimitiveType_IsFinite:
+      return new (std::nothrow) IsFinite(primitive);
+    case schema::PrimitiveType_BatchMatMul:
+      return new (std::nothrow) BatchMatMul(primitive);
 #ifdef SUPPORT_TRAIN
     case schema::PrimitiveType_ActivationGrad:
       return new (std::nothrow) ActivationGrad(primitive);
tools/common/node_util.h:
@@ -29,6 +29,18 @@
 namespace mindspore {
 namespace lite {
+template <typename T>
+int CreateOperator(const std::unique_ptr<schema::PrimitiveT> &primitive, schema::PrimitiveType type) {
+  auto attr = std::make_unique<T>();
+  if (attr == nullptr) {
+    MS_LOG(ERROR) << "new attr failed";
+    return RET_NULL_PTR;
+  }
+  primitive->value.type = type;
+  primitive->value.value = attr.release();
+  return RET_OK;
+}
+
 using STATUS = int;
 STATUS BroadCastQuantParam(schema::MetaGraphT *graphT, const std::unique_ptr<schema::CNodeT> &node);
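
CreateOperator centralizes the attribute-allocation boilerplate that each TF node parser used to repeat: it default-constructs the op's attribute table, tags the PrimitiveT union with the matching type, and hands ownership of the attribute to the primitive. A hedged usage fragment mirroring the parser call sites later in this diff (the surrounding Parse function, error codes, and schema types come from that context):

// Inside a TF node parser, once the PrimitiveT wrapper exists:
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
  return RET_NULL_PTR;
}
// One call sets the union tag and attaches a default-constructed attr table.
int status = CreateOperator<schema::CosT>(primitive, schema::PrimitiveType_Cos);
if (status != RET_OK) {
  return status;
}
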
@@ -92,197 +104,252 @@ STATUS SetFilterDim(schema::TensorT *tensor, kTransFilterType type, int32_t filt
                     int32_t filterW);

 template <typename T>
-static STATUS TransFilterData(schema::TensorT *tensor, kTransFilterType type, int32_t filterK, int32_t filterC,
-                              int32_t filterH, int32_t filterW) {
-  MS_ASSERT(tensor != nullptr);
-  int count = filterH * filterW * filterC * filterK;
-  if (count <= 0) {
-    MS_LOG(ERROR) << "Dim size invalid";
-    return RET_ERROR;
+static void TransKHWC2CHWK(int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW, T *srcData, T *dstData) {
+  T *p1Buff = nullptr;
+  T *p2Buff = nullptr;
+  for (int k = 0; k < filterK; ++k) {
+    for (int h = 0; h < filterH; ++h) {
+      for (int w = 0; w < filterW; ++w) {
+        for (int c = 0; c < filterC; ++c) {
+          p1Buff = srcData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
+          p2Buff = dstData + ((c * filterK * filterH * filterW) + (h * filterK * filterW) + (w * filterK) + (k));
+          *p2Buff = *p1Buff;
+        }
+      }
+    }
   }
-  std::unique_ptr<T[]> buf(new (std::nothrow) T[count]);
-  if (buf == nullptr) {
-    MS_LOG(ERROR) << "new buf failed";
-    return RET_ERROR;
+template <typename T>
+static void TransKHWC2HWCK(int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW, T *srcData, T *dstData) {
+  T *p1Buff = nullptr;
+  T *p2Buff = nullptr;
+  for (int k = 0; k < filterK; ++k) {
+    for (int h = 0; h < filterH; ++h) {
+      for (int w = 0; w < filterW; ++w) {
+        for (int c = 0; c < filterC; ++c) {
+          p1Buff = srcData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
+          p2Buff = dstData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
+          *p2Buff = *p1Buff;
+        }
+      }
+    }
   }
+}
-  void *originWeightDate = tensor->data.data();
-  T *weightData = static_cast<T *>(originWeightDate);
+template <typename T>
+static void TransCKHW(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
+                      T *srcData, T *dstData) {
+  T *p1Buff = nullptr;
+  T *p2Buff = nullptr;
+  for (int c = 0; c < filterC; ++c) {
+    for (int k = 0; k < filterK; ++k) {
+      for (int h = 0; h < filterH; ++h) {
+        for (int w = 0; w < filterW; ++w) {
+          p1Buff = srcData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
+          if (type == kCKHW2HWCK) {
+            p2Buff = dstData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
+          } else if (type == kCKHW2KHWC) {
+            p2Buff = dstData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
+          } else {
+            p2Buff = dstData + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
+          }
+          *p2Buff = *p1Buff;
+        }
+      }
+    }
+  }
+}
-  if (weightData == nullptr) {
-    MS_LOG(ERROR) << "weightData is nullptr";
-    return RET_ERROR;
+template <typename T>
+static void TransKCHW(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
+                      T *srcData, T *dstData) {
+  T *p1Buff = nullptr;
+  T *p2Buff = nullptr;
+  for (int k = 0; k < filterK; ++k) {
+    for (int c = 0; c < filterC; ++c) {
+      for (int h = 0; h < filterH; ++h) {
+        for (int w = 0; w < filterW; ++w) {
+          p1Buff = srcData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
+          if (type == kKCHW2HWCK) {
+            p2Buff = dstData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
+          } else if (type == kKCHW2KHWC) {
+            p2Buff = dstData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
+          } else if (type == kKCHW2CKHW) {
+            p2Buff = dstData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
+          } else {
+            p2Buff = dstData + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
+          }
+          *p2Buff = *p1Buff;
+        }
+      }
+    }
+  }
+}
+template <typename T>
+static void TransCHWK(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
+                      T *srcData, T *dstData) {
+  T *p1Buff = nullptr;
+  T *p2Buff = nullptr;
+  for (int c = 0; c < filterC; ++c) {
+    for (int h = 0; h < filterH; ++h) {
+      for (int w = 0; w < filterW; ++w) {
+        for (int k = 0; k < filterK; ++k) {
+          p1Buff = srcData + ((c * filterH * filterW * filterK) + (h * filterW * filterK) + (w * filterK) + (k));
+          if (type == kCHWK2HWCK) {
+            p2Buff = dstData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
+          } else {
+            p2Buff = dstData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
+          }
+          *p2Buff = *p1Buff;
+        }
+      }
+    }
   }
+}
+template <typename T>
+static void TransHWCK(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
+                      T *srcData, T *dstData) {
   T *p1Buff = nullptr;
   T *p2Buff = nullptr;
-  switch (type) {
-    case kCHWK2HWCK:
-    case kCHWK2KHWC: {
+  for (int h = 0; h < filterH; ++h) {
+    for (int w = 0; w < filterW; ++w) {
       for (int c = 0; c < filterC; ++c) {
-        for (int h = 0; h < filterH; ++h) {
-          for (int w = 0; w < filterW; ++w) {
-            for (int k = 0; k < filterK; ++k) {
-              p1Buff = weightData + ((c * filterH * filterW * filterK) + (h * filterW * filterK) + (w * filterK) + (k));
-              if (type == kCHWK2HWCK) {
-                p2Buff =
-                  buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
-              } else {
-                p2Buff =
-                  buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
-              }
-              *p2Buff = *p1Buff;
-            }
+        for (int k = 0; k < filterK; ++k) {
+          p1Buff = srcData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
+          if (type == kHWCK2KCHW) {
+            p2Buff = dstData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
+          } else {
+            p2Buff = dstData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
          }
+          *p2Buff = *p1Buff;
        }
      }
-    } break;
-    case kKHWC2HWCK: {
-      for (int k = 0; k < filterK; ++k) {
-        for (int h = 0; h < filterH; ++h) {
-          for (int w = 0; w < filterW; ++w) {
-            for (int c = 0; c < filterC; ++c) {
-              p1Buff = weightData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
-              p2Buff = buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
-              *p2Buff = *p1Buff;
-            }
-          }
-        }
-      }
+template <typename T>
+static void TransHWKC(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
+                      T *srcData, T *dstData) {
+  T *p1Buff = nullptr;
+  T *p2Buff = nullptr;
+  for (int h = 0; h < filterH; ++h) {
+    for (int w = 0; w < filterW; ++w) {
+      for (int c = 0; c < filterC; ++c) {
+        for (int k = 0; k < filterK; ++k) {
+          p1Buff = srcData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c));
+          if (type == kHWKC2KCHW) {
+            p2Buff = dstData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
+          } else {
+            p2Buff = dstData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
+          }
+          *p2Buff = *p1Buff;
+        }
+      }
+    }
+  }
+}
+template <typename T>
+static void TransNHWC(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
+                      T *srcData, T *dstData) {
+  T *p1Buff = nullptr;
+  T *p2Buff = nullptr;
+  for (int k = 0; k < filterK; ++k) {
+    for (int h = 0; h < filterH; ++h) {
+      for (int w = 0; w < filterW; ++w) {
+        for (int c = 0; c < filterC; ++c) {
+          p1Buff = srcData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c));
+          if (type == kNHWC2HWCK) {
+            p2Buff = dstData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
+          } else if (type == kNHWC2CKHW) {
+            p2Buff = dstData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
+          } else {
+            p2Buff = dstData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
          }
+          *p2Buff = *p1Buff;
        }
      }
+    }
+  }
+}
+template <typename T>
+static STATUS TransFilterData(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
+                              T *srcData, T *dstData) {
+  switch (type) {
+    case kCHWK2HWCK:
+    case kCHWK2KHWC: {
+      TransCHWK(type, filterK, filterC, filterH, filterW, srcData, dstData);
+    } break;
+    case kKHWC2HWCK: {
+      TransKHWC2HWCK(filterK, filterC, filterH, filterW, srcData, dstData);
    } break;
    case kKCHW2HWCK:
    case kKCHW2CKHW:
    case kKCHW2KHWC:
    case kKCHW2HWKC: {
-      for (int k = 0; k < filterK; ++k) {
-        for (int c = 0; c < filterC; ++c) {
-          for (int h = 0; h < filterH; ++h) {
-            for (int w = 0; w < filterW; ++w) {
-              p1Buff = weightData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
-              if (type == kKCHW2HWCK) {
-                p2Buff =
-                  buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
-              } else if (type == kKCHW2KHWC) {
-                p2Buff =
-                  buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
-              } else if (type == kKCHW2CKHW) {
-                p2Buff =
-                  buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
-              } else {
-                p2Buff =
-                  buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
-              }
-              *p2Buff = *p1Buff;
-            }
-          }
-        }
-      }
+      TransKCHW(type, filterK, filterC, filterH, filterW, srcData, dstData);
    } break;
    case kCKHW2HWCK:
    case kCKHW2KHWC:
    case kCKHW2HWKC: {
-      for (int c = 0; c < filterC; ++c) {
-        for (int k = 0; k < filterK; ++k) {
-          for (int h = 0; h < filterH; ++h) {
-            for (int w = 0; w < filterW; ++w) {
-              p1Buff = weightData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
-              if (type == kCKHW2HWCK) {
-                p2Buff =
-                  buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
-              } else if (type == kCKHW2KHWC) {
-                p2Buff =
-                  buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
-              } else {
-                p2Buff =
-                  buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
-              }
-              *p2Buff = *p1Buff;
-            }
-          }
-        }
-      }
+      TransCKHW(type, filterK, filterC, filterH, filterW, srcData, dstData);
    } break;
    case kHWCK2KCHW:
    case kHWCK2CKHW: {
-      for (int h = 0; h < filterH; ++h) {
-        for (int w = 0; w < filterW; ++w) {
-          for (int c = 0; c < filterC; ++c) {
-            for (int k = 0; k < filterK; ++k) {
-              p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
-              if (type == kHWCK2KCHW) {
-                p2Buff =
-                  buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
-              } else {
-                p2Buff =
-                  buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
-              }
-              *p2Buff = *p1Buff;
-            }
-          }
-        }
-      }
+      TransHWCK(type, filterK, filterC, filterH, filterW, srcData, dstData);
    } break;
    case kHWKC2KCHW:
    case kHWKC2CKHW: {
-      for (int h = 0; h < filterH; ++h) {
-        for (int w = 0; w < filterW; ++w) {
-          for (int c = 0; c < filterC; ++c) {
-            for (int k = 0; k < filterK; ++k) {
-              p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c));
-              if (type == kHWKC2KCHW) {
-                p2Buff =
-                  buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
-              } else {
-                p2Buff =
-                  buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
-              }
-              *p2Buff = *p1Buff;
-            }
-          }
-        }
-      }
+      TransHWKC(type, filterK, filterC, filterH, filterW, srcData, dstData);
    } break;
    case kNHWC2HWCK:
    case kNHWC2KCHW:
    case kNHWC2CKHW: {
-      for (int k = 0; k < filterK; ++k) {
-        for (int h = 0; h < filterH; ++h) {
-          for (int w = 0; w < filterW; ++w) {
-            for (int c = 0; c < filterC; ++c) {
-              p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c));
-              if (type == kNHWC2HWCK) {
-                p2Buff =
-                  buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
-              } else if (type == kNHWC2CKHW) {
-                p2Buff =
-                  buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
-              } else {
-                p2Buff =
-                  buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
-              }
-              *p2Buff = *p1Buff;
-            }
-          }
-        }
-      }
+      TransNHWC(type, filterK, filterC, filterH, filterW, srcData, dstData);
    } break;
    case kKHWC2CHWK: {
-      for (int k = 0; k < filterK; ++k) {
-        for (int h = 0; h < filterH; ++h) {
-          for (int w = 0; w < filterW; ++w) {
-            for (int c = 0; c < filterC; ++c) {
-              p1Buff = weightData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
-              p2Buff = buf.get() + ((c * filterK * filterH * filterW) + (h * filterK * filterW) + (w * filterK) + (k));
-              *p2Buff = *p1Buff;
-            }
-          }
-        }
-      }
+      TransKHWC2CHWK(filterK, filterC, filterH, filterW, srcData, dstData);
    } break;
    default: {
      MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
      return RET_ERROR;
    }
  }
+  return RET_OK;
+}
+template <typename T>
+static STATUS TransFilterData(schema::TensorT *tensor, kTransFilterType type, int32_t filterK, int32_t filterC,
+                              int32_t filterH, int32_t filterW) {
+  MS_ASSERT(tensor != nullptr);
+  int count = filterH * filterW * filterC * filterK;
+  if (count <= 0) {
+    MS_LOG(ERROR) << "Dim size invalid";
+    return RET_ERROR;
+  }
+  std::unique_ptr<T[]> buf(new (std::nothrow) T[count]);
+  if (buf == nullptr) {
+    MS_LOG(ERROR) << "new buf failed";
+    return RET_ERROR;
+  }
+  void *originWeightDate = tensor->data.data();
+  T *weightData = static_cast<T *>(originWeightDate);
+  if (weightData == nullptr) {
+    MS_LOG(ERROR) << "weightData is nullptr";
+    return RET_ERROR;
+  }
+  if (TransFilterData(type, filterK, filterC, filterH, filterW, weightData, buf.get()) != RET_OK) {
+    MS_LOG(ERROR) << "TransFilterData failed";
+    return RET_ERROR;
+  }
  auto ret = ::memcpy_s(tensor->data.data(), count * sizeof(T), buf.get(), count * sizeof(T));
  if (ret != EOK) {
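
The Trans* helpers above all reduce to the same pattern: walk the filter in its source layout and recompute the flat offset for the destination layout. A standalone sketch of that index arithmetic for one pair of layouts, KHWC to HWCK (float elements and the free-function form are illustrative; the real helpers are templated and cover many more layout pairs):

#include <vector>

// Copy a filter stored as [K, H, W, C] into a buffer ordered [H, W, C, K].
std::vector<float> KHWC2HWCKRef(const std::vector<float> &src, int K, int H, int W, int C) {
  std::vector<float> dst(src.size());
  for (int k = 0; k < K; ++k) {
    for (int h = 0; h < H; ++h) {
      for (int w = 0; w < W; ++w) {
        for (int c = 0; c < C; ++c) {
          int src_idx = ((k * H + h) * W + w) * C + c;  // KHWC offset
          int dst_idx = ((h * W + w) * C + c) * K + k;  // HWCK offset
          dst[dst_idx] = src[src_idx];
        }
      }
    }
  }
  return dst;
}
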
tools/converter/parser/tf/tf_arithmetic_self_parser.cc:
@@ -19,22 +19,11 @@
 #include <map>
 #include <vector>
 #include "tools/converter/parser/tf/tf_node_parser_registry.h"
+#include "tools/common/node_util.h"

 namespace mindspore {
 namespace lite {
-template <typename T>
-int CreateOperator(const std::unique_ptr<schema::PrimitiveT> &primitive, schema::PrimitiveType type) {
-  auto attr = std::make_unique<T>();
-  if (attr == nullptr) {
-    MS_LOG(ERROR) << "new attr failed";
-    return RET_NULL_PTR;
-  }
-  primitive->value.type = type;
-  primitive->value.value = attr.release();
-  return RET_OK;
-}
 STATUS TFArithmeticSelfParser::Parse(const tensorflow::NodeDef &tf_op,
                                      const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
                                      PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
@@ -61,6 +50,12 @@ STATUS TFArithmeticSelfParser::Parse(const tensorflow::NodeDef &tf_op,
     status = CreateOperator<schema::LogT>(primitive, schema::PrimitiveType_Log);
   } else if (tf_op.op() == "Sqrt") {
     status = CreateOperator<schema::SqrtT>(primitive, schema::PrimitiveType_Sqrt);
+  } else if (tf_op.op() == "Cos") {
+    status = CreateOperator<schema::CosT>(primitive, schema::PrimitiveType_Cos);
+  } else if (tf_op.op() == "Sin") {
+    status = CreateOperator<schema::SinT>(primitive, schema::PrimitiveType_Sin);
+  } else if (tf_op.op() == "Square") {
+    status = CreateOperator<schema::SquareT>(primitive, schema::PrimitiveType_Square);
   } else if (tf_op.op() == "Pow") {
     status = CreateOperator<schema::PowerT>(primitive, schema::PrimitiveType_Power);
   }
@@ -81,6 +76,9 @@ STATUS TFArithmeticSelfParser::Parse(const tensorflow::NodeDef &tf_op,
 }
   return status;
 }
+TFNodeRegistrar g_tfCosParser("Cos", new TFArithmeticSelfParser());
+TFNodeRegistrar g_tfSinParser("Sin", new TFArithmeticSelfParser());
+TFNodeRegistrar g_tfSquareParser("Square", new TFArithmeticSelfParser());
 TFNodeRegistrar g_tfCeilParser("Ceil", new TFArithmeticSelfParser());
 TFNodeRegistrar g_tfExpParser("Exp", new TFArithmeticSelfParser());
 TFNodeRegistrar g_tfFloorParser("Floor", new TFArithmeticSelfParser());
tools/converter/parser/tf/tf_batch_matmul_parser.cc (new file):
@@ -0,0 +1,64 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tools/converter/parser/tf/tf_batch_matmul_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"

namespace mindspore {
namespace lite {
STATUS TFBatchMatmulParser::Parse(const tensorflow::NodeDef &tf_op,
                                  const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
                                  PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
  if (primitiveC == nullptr || output_size == nullptr) {
    MS_LOG(ERROR) << "primitiveC is nullptr";
    return RET_NULL_PTR;
  }
  auto primitive = std::make_unique<schema::PrimitiveT>();
  if (primitive == nullptr) {
    MS_LOG(ERROR) << "primitive is nullptr";
    return RET_NULL_PTR;
  }
  auto attr = std::make_unique<schema::BatchMatMulT>();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "new op failed";
    return RET_NULL_PTR;
  }
  tensorflow::AttrValue attr_value;
  TensorFlowUtils::FindAttrValue(tf_op, "adj_x", &attr_value);
  attr->adj_x = attr_value.b();
  TensorFlowUtils::FindAttrValue(tf_op, "adj_y", &attr_value);
  attr->adj_y = attr_value.b();
  primitive->value.type = schema::PrimitiveType_BatchMatMul;
  primitive->value.value = attr.release();
  *primitiveC = PrimitiveC::Create(primitive.release());
  if (*primitiveC == nullptr) {
    MS_LOG(ERROR) << "primitiveC is nullptr";
    return RET_ERROR;
  }
  *output_size = 1;
  for (int i = 0; i < tf_op.input_size(); i++) {
    inputs->emplace_back(tf_op.input(i));
  }
  return RET_OK;
}
TFNodeRegistrar g_tfBatchMatMulParser("BatchMatMul", new TFBatchMatmulParser());
}  // namespace lite
}  // namespace mindspore
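
For reference, the kind of NodeDef this parser consumes can be built with the standard protobuf-generated accessors for tensorflow::NodeDef; the node and input names below are hypothetical:

// Hypothetical TF graph node carrying the two attributes the parser reads.
tensorflow::NodeDef node;
node.set_op("BatchMatMul");
node.set_name("bmm_example");
node.add_input("x");
node.add_input("y");
(*node.mutable_attr())["adj_x"].set_b(false);
(*node.mutable_attr())["adj_y"].set_b(true);
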
tools/converter/parser/tf/tf_batch_matmul_parser.h (new file):
@@ -0,0 +1,37 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BATCH_MATMUL_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BATCH_MATMUL_PARSER_H_

#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"

namespace mindspore {
namespace lite {
class TFBatchMatmulParser : public TFNodeParser {
 public:
  TFBatchMatmulParser() = default;
  ~TFBatchMatmulParser() override = default;

  STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
               PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
}  // namespace lite
}  // namespace mindspore
#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BATCH_MATMUL_PARSER_H_
tools/converter/parser/tf/tf_is_finite_parser.cc (new file):
@@ -0,0 +1,59 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tools/converter/parser/tf/tf_is_finite_parser.h"
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/common/node_util.h"

namespace mindspore {
namespace lite {
STATUS TFIsFiniteParser::Parse(const tensorflow::NodeDef &tf_op,
                               const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
                               PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
  if (primitiveC == nullptr || output_size == nullptr) {
    MS_LOG(ERROR) << "primitiveC is nullptr";
    return RET_NULL_PTR;
  }
  auto primitive = std::make_unique<schema::PrimitiveT>();
  if (primitive == nullptr) {
    MS_LOG(ERROR) << "primitive is nullptr";
    return RET_NULL_PTR;
  }
  int status = CreateOperator<schema::IsFiniteT>(primitive, schema::PrimitiveType_IsFinite);
  if (status != RET_OK) {
    return status;
  }
  *primitiveC = PrimitiveC::Create(primitive.release());
  if (*primitiveC == nullptr) {
    MS_LOG(ERROR) << "primitiveC is nullptr";
    return RET_ERROR;
  }
  *output_size = 1;
  for (int i = 0; i < tf_op.input_size(); i++) {
    inputs->emplace_back(tf_op.input(i));
  }
  return RET_OK;
}
TFNodeRegistrar g_tf_is_finite_parser("IsFinite", new TFIsFiniteParser());
}  // namespace lite
}  // namespace mindspore
tools/converter/parser/tf/tf_is_finite_parser.h (new file):
@@ -0,0 +1,37 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_IS_FINITE_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_IS_FINITE_PARSER_H_

#include <map>
#include <memory>
#include <string>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"

namespace mindspore {
namespace lite {
class TFIsFiniteParser : public TFNodeParser {
 public:
  TFIsFiniteParser() = default;
  ~TFIsFiniteParser() override = default;

  STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
               PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
}  // namespace lite
}  // namespace mindspore
#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_IS_FINITE_PARSER_H_
tools/converter/parser/tf/tf_logical_parser.cc:
@@ -19,6 +19,7 @@
 #include <string>
 #include <vector>
 #include "tools/converter/parser/tf/tf_node_parser_registry.h"
+#include "tools/common/node_util.h"

 namespace mindspore {
 namespace lite {
@@ -36,37 +37,19 @@ STATUS TFLogicalParser::Parse(const tensorflow::NodeDef &tf_op,
     MS_LOG(ERROR) << "primitive is nullptr";
     return RET_NULL_PTR;
   }
+  int status = RET_ERROR;
   if (tf_op.op() == "LogicalAnd") {
-    auto attr = std::make_unique<schema::LogicalAndT>();
-    if (attr == nullptr) {
-      MS_LOG(ERROR) << "new op failed";
-      return RET_NULL_PTR;
-    }
-    primitive->value.type = schema::PrimitiveType_LogicalAnd;
-    primitive->value.value = attr.release();
-    *primitiveC = PrimitiveC::Create(primitive.release());
+    status = CreateOperator<schema::LogicalAndT>(primitive, schema::PrimitiveType_LogicalAnd);
   } else if (tf_op.op() == "LogicalOr") {
-    auto attr = std::make_unique<schema::LogicalOrT>();
-    if (attr == nullptr) {
-      MS_LOG(ERROR) << "new op failed";
-      return RET_NULL_PTR;
-    }
-    primitive->value.type = schema::PrimitiveType_LogicalOr;
-    primitive->value.value = attr.release();
-    *primitiveC = PrimitiveC::Create(primitive.release());
+    status = CreateOperator<schema::LogicalOrT>(primitive, schema::PrimitiveType_LogicalOr);
   } else if (tf_op.op() == "LogicalNot") {
-    auto attr = std::make_unique<schema::LogicalNotT>();
-    if (attr == nullptr) {
-      MS_LOG(ERROR) << "new op failed";
-      return RET_NULL_PTR;
-    }
-    primitive->value.type = schema::PrimitiveType_LogicalNot;
-    primitive->value.value = attr.release();
-    *primitiveC = PrimitiveC::Create(primitive.release());
-  } else {
-    MS_LOG(ERROR) << tf_op.op() << " is not supported.";
-    return RET_ERROR;
+    status = CreateOperator<schema::LogicalNotT>(primitive, schema::PrimitiveType_LogicalNot);
   }
+  if (status != RET_OK) {
+    return status;
+  }
+  *primitiveC = PrimitiveC::Create(primitive.release());
   if (*primitiveC == nullptr) {
     MS_LOG(ERROR) << "primitiveC is nullptr";
     return RET_ERROR;
@@ -79,8 +62,8 @@ STATUS TFLogicalParser::Parse(const tensorflow::NodeDef &tf_op,
   return RET_OK;
 }
+TFNodeRegistrar g_tfLogicalAndParser("LogicalAnd", new TFLogicalParser());
+TFNodeRegistrar g_tfLogicalOrParser("LogicalOr", new TFLogicalParser());
 TFNodeRegistrar g_tfLogicalNotParser("LogicalNot", new TFLogicalParser());
-TFNodeRegistrar g_tfLogicalOrParser("LogicalOr", new TFLogicalParser());
-TFNodeRegistrar g_tfLogicalAndParser("LogicalAnd", new TFLogicalParser());
 }  // namespace lite
 }  // namespace mindspore
tools/converter/parser/tf/tf_zeros_like_parser.cc (new file):
@@ -0,0 +1,60 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tools/converter/parser/tf/tf_zeros_like_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"

namespace mindspore {
namespace lite {
STATUS TFZerosLikeParser::Parse(const tensorflow::NodeDef &tf_op,
                                const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
                                PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
  if (primitiveC == nullptr || output_size == nullptr) {
    MS_LOG(ERROR) << "primitiveC is nullptr";
    return RET_NULL_PTR;
  }
  auto primitive = std::make_unique<schema::PrimitiveT>();
  if (primitive == nullptr) {
    MS_LOG(ERROR) << "primitive is nullptr";
    return RET_NULL_PTR;
  }
  auto attr = std::make_unique<schema::ZerosLikeT>();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "new op failed";
    return RET_NULL_PTR;
  }
  primitive->value.type = schema::PrimitiveType_ZerosLike;
  primitive->value.value = attr.release();
  *primitiveC = PrimitiveC::Create(primitive.release());
  if (*primitiveC == nullptr) {
    MS_LOG(ERROR) << "primitiveC is nullptr";
    return RET_ERROR;
  }
  *output_size = tf_op.input_size();
  for (int i = 0; i < tf_op.input_size(); i++) {
    inputs->emplace_back(tf_op.input(i));
  }
  return RET_OK;
}
TFNodeRegistrar g_tfZerosLikeParser("ZerosLike", new TFZerosLikeParser());
}  // namespace lite
}  // namespace mindspore
tools/converter/parser/tf/tf_zeros_like_parser.h (new file):
@@ -0,0 +1,37 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ZERO_LIKE_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ZERO_LIKE_PARSER_H_

#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"

namespace mindspore {
namespace lite {
class TFZerosLikeParser : public TFNodeParser {
 public:
  TFZerosLikeParser() = default;
  ~TFZerosLikeParser() override = default;

  STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
               PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
}  // namespace lite
}  // namespace mindspore
#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ZERO_LIKE_PARSER_H_