From: @yeyunpeng2020 Reviewed-by: @HilbertDavid Signed-off-by: @HilbertDavidtags/v1.2.0-rc1
| @@ -276,6 +276,8 @@ union PrimitiveType { | |||
| StridedSliceGrad, | |||
| IsFinite, | |||
| BatchMatMul, | |||
| LinSpace, | |||
| UniformReal | |||
| } | |||
| enum QuantType: int { | |||
| @@ -1281,6 +1281,14 @@ table IsFinite { | |||
| } | |||
| table BatchMatMul { | |||
| adj_x : bool = false; | |||
| adj_y : bool = false; | |||
| transpose_a :bool; | |||
| transpose_b :bool; | |||
| } | |||
| table LinSpace { | |||
| } | |||
| table UniformReal { | |||
| seed : int; | |||
| seed2 : int; | |||
| } | |||
| @@ -13,8 +13,8 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/batch_matmul.h" | |||
| #include <memory> | |||
| #ifndef PRIMITIVE_WRITEABLE | |||
| #include "src/ops/ops_register.h" | |||
| #endif | |||
| @@ -22,14 +22,17 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| bool BatchMatMul::GetAdjX() const { return this->primitive_->value.AsBatchMatMul()->adj_x; } | |||
| void BatchMatMul::SetAdjX(bool adj_x) { this->primitive_->value.AsBatchMatMul()->adj_x = adj_x; } | |||
| bool BatchMatMul::GetTransposeA() const { return this->primitive_->value.AsBatchMatMul()->transpose_a; } | |||
| bool BatchMatMul::GetAdjY() const { return this->primitive_->value.AsBatchMatMul()->adj_y; } | |||
| bool BatchMatMul::GetTransposeB() const { return this->primitive_->value.AsBatchMatMul()->transpose_b; } | |||
| void BatchMatMul::SetAdjY(bool adj_y) { this->primitive_->value.AsBatchMatMul()->adj_y = adj_y; } | |||
| void BatchMatMul::SetTransposeA(bool transpose_a) { | |||
| this->primitive_->value.AsBatchMatMul()->transpose_a = transpose_a; | |||
| } | |||
| void BatchMatMul::SetTransposeB(bool transpose_b) { | |||
| this->primitive_->value.AsBatchMatMul()->transpose_b = transpose_b; | |||
| } | |||
| int BatchMatMul::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| @@ -51,31 +54,32 @@ int BatchMatMul::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> | |||
| this->primitive_ = nullptr; | |||
| return RET_ERROR; | |||
| } | |||
| attr->adj_x = GetValue<bool>(prim.GetAttr("adj_x")); | |||
| attr->adj_y = GetValue<bool>(prim.GetAttr("adj_y")); | |||
| attr->transpose_a = GetValue<bool>(prim.GetAttr("transpose_a")); | |||
| attr->transpose_b = GetValue<bool>(prim.GetAttr("transpose_b")); | |||
| this->primitive_->value.value = attr; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| bool BatchMatMul::GetTransposeA() const { return this->primitive_->value_as_BatchMatMul()->transpose_a(); } | |||
| bool BatchMatMul::GetTransposeB() const { return this->primitive_->value_as_BatchMatMul()->transpose_b(); } | |||
| int BatchMatMul::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| MS_ASSERT(nullptr != fbb); | |||
| auto val_offset = schema::CreateBatchMatMul(*fbb); | |||
| auto attr = primitive->value_as_BatchMatMul(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "value_as_Add return nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| auto val_offset = schema::CreateBatchMatMul(*fbb, attr->transpose_a(), attr->transpose_b()); | |||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BatchMatMul, val_offset.o); | |||
| fbb->Finish(prim_offset); | |||
| return RET_OK; | |||
| } | |||
| bool BatchMatMul::GetAdjX() const { return this->primitive_->value_as_BatchMatMul()->adj_x(); } | |||
| bool BatchMatMul::GetAdjY() const { return this->primitive_->value_as_BatchMatMul()->adj_y(); } | |||
| PrimitiveC *BatchMatMulCreator(const schema::Primitive *primitive) { | |||
| return PrimitiveC::NewPrimitiveC<BatchMatMul>(primitive); | |||
| } | |||
| Registry BatchMatMulRegistry(schema::PrimitiveType_BatchMatMul, BatchMatMulCreator); | |||
| #endif | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -32,15 +32,14 @@ class BatchMatMul : public PrimitiveC { | |||
| MS_DECLARE_PARENT(BatchMatMul, PrimitiveC); | |||
| explicit BatchMatMul(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| void SetAdjX(bool adj_x); | |||
| void SetAdjY(bool adj_y); | |||
| void SetTransposeA(bool transpose_a); | |||
| void SetTransposeB(bool transpose_b); | |||
| #else | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| bool GetAdjX() const; | |||
| bool GetAdjY() const; | |||
| bool GetTransposeA() const; | |||
| bool GetTransposeB() const; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_MINDSPORE_LITE_C_OPS_BATCH_MATMUL_H_ | |||
| @@ -0,0 +1,53 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/lin_space.h" | |||
| #ifndef PRIMITIVE_WRITEABLE | |||
| #include "src/ops/ops_register.h" | |||
| #endif | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifndef PRIMITIVE_WRITEABLE | |||
| int LinSpace::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| MS_ASSERT(nullptr != fbb); | |||
| auto val_offset = schema::CreateLinSpace(*fbb); | |||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_LinSpace, val_offset.o); | |||
| fbb->Finish(prim_offset); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *LinSpaceCreator(const schema::Primitive *primitive) { | |||
| return PrimitiveC::NewPrimitiveC<LinSpace>(primitive); | |||
| } | |||
| Registry LinSpaceRegistry(schema::PrimitiveType_LinSpace, LinSpaceCreator); | |||
| #endif | |||
| int LinSpace::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) { | |||
| auto input = inputs.front(); | |||
| MS_ASSERT(input != nullptr); | |||
| auto output = outputs.front(); | |||
| MS_ASSERT(output != nullptr); | |||
| output->set_data_type(input->data_type()); | |||
| output->set_format(input->format()); | |||
| auto num = inputs.at(2)->data_c(); | |||
| if (num == nullptr) { | |||
| return RET_INFER_INVALID; | |||
| } | |||
| output->set_shape({reinterpret_cast<int *>(num)[0]}); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,42 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <vector> | |||
| #include <set> | |||
| #include <cmath> | |||
| #include "src/ops/primitive_c.h" | |||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_LIN_SPACE_H_ | |||
| #define LITE_MINDSPORE_LITE_C_OPS_LIN_SPACE_H_ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class LinSpace : public PrimitiveC { | |||
| public: | |||
| LinSpace() = default; | |||
| ~LinSpace() = default; | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| MS_DECLARE_PARENT(LinSpace, PrimitiveC); | |||
| explicit LinSpace(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| #else | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| int InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_MINDSPORE_LITE_C_OPS_LIN_SPACE_H_ | |||
| @@ -170,8 +170,11 @@ | |||
| #include "src/ops/crop_and_resize.h" | |||
| #include "src/ops/nonzero.h" | |||
| #include "src/ops/erf.h" | |||
| #include "src/ops/is_finite.h" | |||
| #include "src/ops/batch_matmul.h" | |||
| #include "src/ops/lin_space.h" | |||
| #include "src/ops/uniform_real.h" | |||
| #include "src/ops/rank.h" | |||
| #include "src/ops/is_finite.h" | |||
| #ifdef SUPPORT_TRAIN | |||
| #include "src/ops/neg_grad.h" | |||
| @@ -1047,6 +1050,12 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { | |||
| return new (std::nothrow) IsFinite(primitive); | |||
| case schema::PrimitiveType_BatchMatMul: | |||
| return new (std::nothrow) BatchMatMul(primitive); | |||
| case schema::PrimitiveType_LinSpace: | |||
| return new (std::nothrow) LinSpace(primitive); | |||
| case schema::PrimitiveType_UniformReal: | |||
| return new (std::nothrow) UniformReal(primitive); | |||
| case schema::PrimitiveType_Rank: | |||
| return new (std::nothrow) Rank(primitive); | |||
| #ifdef SUPPORT_TRAIN | |||
| case schema::PrimitiveType_ActivationGrad: | |||
| return new (std::nothrow) ActivationGrad(primitive); | |||
| @@ -0,0 +1,101 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/uniform_real.h" | |||
| #ifndef PRIMITIVE_WRITEABLE | |||
| #include "src/ops/ops_register.h" | |||
| #endif | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| int UniformReal::GetSeed() const { return this->primitive_->value.AsUniformReal()->seed; } | |||
| int UniformReal::GetSeed2() const { return this->primitive_->value.AsUniformReal()->seed2; } | |||
| int UniformReal::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||
| if (this->primitive_ == nullptr) { | |||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||
| if (this->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.type = schema::PrimitiveType_UniformReal; | |||
| } | |||
| if (this->primitive_->value.type != schema::PrimitiveType_UniformReal) { | |||
| MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| if (this->primitive_->value.value == nullptr) { | |||
| auto attr = new (std::nothrow) schema::UniformRealT(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| this->primitive_->value.value = attr; | |||
| if (this->primitive_->value.value == nullptr) { | |||
| MS_LOG(ERROR) << "primitive value is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| #else | |||
| int UniformReal::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| MS_ASSERT(nullptr != fbb); | |||
| auto attr = primitive->value_as_UniformReal(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "value_as_UniformReal return nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| auto val_offset = schema::CreateUniformReal(*fbb, attr->seed(), attr->seed2()); | |||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_UniformReal, val_offset.o); | |||
| fbb->Finish(prim_offset); | |||
| return RET_OK; | |||
| } | |||
| int UniformReal::GetSeed() const { return this->primitive_->value_as_UniformReal()->seed(); } | |||
| int UniformReal::GetSeed2() const { return this->primitive_->value_as_UniformReal()->seed2(); } | |||
| PrimitiveC *UniformRealCreator(const schema::Primitive *primitive) { | |||
| return PrimitiveC::NewPrimitiveC<UniformReal>(primitive); | |||
| } | |||
| Registry UniformRealRegistry(schema::PrimitiveType_UniformReal, UniformRealCreator); | |||
| #endif | |||
| int UniformReal::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||
| if (!infer_flag()) { | |||
| return RET_INFER_INVALID; | |||
| } | |||
| auto input_data = static_cast<int32_t *>(inputs_[0]->data_c()); | |||
| if (input_data == nullptr) { | |||
| return RET_INFER_INVALID; | |||
| } | |||
| auto input_num = inputs_[0]->ElementsNum(); | |||
| std::vector<int> output_shape(input_num); | |||
| for (int i = 0; i < input_num; i++) { | |||
| output_shape[i] = input_data[i]; | |||
| } | |||
| outputs_[0]->set_shape(output_shape); | |||
| outputs_[0]->set_data_type(kNumberTypeFloat32); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,46 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_UNIFORM_REAL_H_ | |||
| #define LITE_MINDSPORE_LITE_C_OPS_UNIFORM_REAL_H_ | |||
| #include <vector> | |||
| #include <set> | |||
| #include <cmath> | |||
| #include <memory> | |||
| #include "src/ops/primitive_c.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class UniformReal : public PrimitiveC { | |||
| public: | |||
| UniformReal() = default; | |||
| ~UniformReal() = default; | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| MS_DECLARE_PARENT(UniformReal, PrimitiveC); | |||
| explicit UniformReal(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||
| #else | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||
| int GetSeed() const; | |||
| int GetSeed2() const; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_MINDSPORE_LITE_C_OPS_UNIFORM_REAL_H_ | |||
| @@ -52,6 +52,8 @@ STATUS TFActivationParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| attr->type = schema::ActivationType_TANH; | |||
| } else if (tf_op.op() == "LeakyRelu") { | |||
| attr->type = schema::ActivationType_LEAKY_RELU; | |||
| } else if (tf_op.op() == "Selu") { | |||
| attr->type = schema::ActivationType_SELU; | |||
| } else { | |||
| MS_LOG(ERROR) << "unsupported activation type:" << tf_op.op(); | |||
| return RET_ERROR; | |||
| @@ -63,7 +65,7 @@ STATUS TFActivationParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| auto attr_leaky_relu = std::make_unique<schema::LeakyReLUT>(); | |||
| tensorflow::AttrValue attr_value; | |||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "alpha", &attr_value)) { | |||
| MS_LOG(ERROR) << "The attribute alpha shoud be specified."; | |||
| MS_LOG(ERROR) << "The attribute alpha should be specified."; | |||
| return RET_ERROR; | |||
| } | |||
| attr_leaky_relu->negativeSlope = attr_value.f(); | |||
| @@ -85,5 +87,6 @@ TFNodeRegistrar g_tfRelu6Parser("Relu6", new TFActivationParser()); | |||
| TFNodeRegistrar g_tfSigmoidParser("Sigmoid", new TFActivationParser()); | |||
| TFNodeRegistrar g_tfTanhParser("Tanh", new TFActivationParser()); | |||
| TFNodeRegistrar g_tfLeakyReluParser("LeakyRelu", new TFActivationParser()); | |||
| TFNodeRegistrar g_tfSeLUParser("Selu", new TFActivationParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -58,6 +58,11 @@ STATUS TFArithmeticSelfParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| status = CreateOperator<schema::SquareT>(primitive, schema::PrimitiveType_Square); | |||
| } else if (tf_op.op() == "Pow") { | |||
| status = CreateOperator<schema::PowerT>(primitive, schema::PrimitiveType_Power); | |||
| } else if (tf_op.op() == "Abs") { | |||
| status = CreateOperator<schema::PowerT>(primitive, schema::PrimitiveType_Abs); | |||
| } else { | |||
| MS_LOG(ERROR) << "unsupported arithmetic self type:" << tf_op.op(); | |||
| return RET_ERROR; | |||
| } | |||
| if (status != RET_OK) { | |||
| return status; | |||
| @@ -85,5 +90,6 @@ TFNodeRegistrar g_tfFloorParser("Floor", new TFArithmeticSelfParser()); | |||
| TFNodeRegistrar g_tfLogParser("Log", new TFArithmeticSelfParser()); | |||
| TFNodeRegistrar g_tfSqrtParser("Sqrt", new TFArithmeticSelfParser()); | |||
| TFNodeRegistrar g_tfPowParser("Pow", new TFArithmeticSelfParser()); | |||
| TFNodeRegistrar g_tfAbsParser("Abs", new TFArithmeticSelfParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -22,29 +22,35 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TFBatchMatmulParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| STATUS TFBatchMatMulParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) { | |||
| MS_LOG(DEBUG) << "TF BatchMatMulParser"; | |||
| if (primitiveC == nullptr || output_size == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "primitive is nullptr"; | |||
| MS_LOG(ERROR) << "New PrimitiveT failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto attr = std::make_unique<schema::BatchMatMulT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| MS_LOG(ERROR) << "new attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| tensorflow::AttrValue attr_value; | |||
| TensorFlowUtils::FindAttrValue(tf_op, "adj_x", &attr_value); | |||
| attr->adj_x = attr_value.b(); | |||
| attr->adj_y = attr_value.b(); | |||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "adj_x", &attr_value)) { | |||
| MS_LOG(ERROR) << "The begin_mask attr should be specified"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->transpose_a = attr_value.b(); | |||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "adj_y", &attr_value)) { | |||
| MS_LOG(ERROR) << "The begin_mask attr should be specified"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->transpose_b = attr_value.b(); | |||
| primitive->value.type = schema::PrimitiveType_BatchMatMul; | |||
| primitive->value.value = attr.release(); | |||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||
| @@ -52,13 +58,15 @@ STATUS TFBatchMatmulParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| *output_size = 1; | |||
| for (int i = 0; i < tf_op.input_size(); i++) { | |||
| inputs->emplace_back(tf_op.input(i)); | |||
| for (int i = 0; i < tf_op.input_size(); ++i) { | |||
| auto status = AddOpInput(tf_op, i, inputs); | |||
| if (status != RET_OK) { | |||
| return status; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| TFNodeRegistrar g_tfBatchMatMulParser("BatchMatMul", new TFBatchMatmulParser()); | |||
| TFNodeRegistrar g_tfBatchMatMulParser("BatchMatMul", new TFBatchMatMulParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -23,15 +23,14 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TFBatchMatmulParser : public TFNodeParser { | |||
| class TFBatchMatMulParser : public TFNodeParser { | |||
| public: | |||
| TFBatchMatmulParser() = default; | |||
| ~TFBatchMatmulParser() override = default; | |||
| TFBatchMatMulParser() = default; | |||
| ~TFBatchMatMulParser() override = default; | |||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BATCH_MATMUL_PARSER_H_ | |||
| @@ -0,0 +1,61 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_linspace_parser.h" | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TFLinSpaceParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) { | |||
| MS_LOG(DEBUG) << "TF LinSpaceParser"; | |||
| if (primitiveC == nullptr || output_size == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "New PrimitiveT failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto attr = std::make_unique<schema::LinSpaceT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_LinSpace; | |||
| primitive->value.value = attr.release(); | |||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||
| if (*primitiveC == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| *output_size = 1; | |||
| for (int i = 0; i < tf_op.input_size(); ++i) { | |||
| auto status = AddOpInput(tf_op, i, inputs); | |||
| if (status != RET_OK) { | |||
| return status; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| TFNodeRegistrar g_tfLinSpaceParser("LinSpace", new TFLinSpaceParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,36 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_LIN_SPACE_PARSER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_LIN_SPACE_PARSER_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TFLinSpaceParser : public TFNodeParser { | |||
| public: | |||
| TFLinSpaceParser() = default; | |||
| ~TFLinSpaceParser() override = default; | |||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_LIN_SPACE_PARSER_H_ | |||
| @@ -0,0 +1,59 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_rank_parser.h" | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TFRankParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| MS_LOG(DEBUG) << "TF RankParser"; | |||
| if (primitiveC == nullptr || output_size == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "New PrimitiveT failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto attr = std::make_unique<schema::RankT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_Rank; | |||
| primitive->value.value = attr.release(); | |||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||
| if (*primitiveC == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| *output_size = 1; | |||
| auto status = AddOpInput(tf_op, 0, inputs); | |||
| if (status != RET_OK) { | |||
| return status; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| TFNodeRegistrar g_tfRankParser("Rank", new TFRankParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,36 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RANK_PARSER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RANK_PARSER_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TFRankParser : public TFNodeParser { | |||
| public: | |||
| TFRankParser() = default; | |||
| ~TFRankParser() override = default; | |||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RANK_PARSER_H_ | |||
| @@ -0,0 +1,68 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_uniform_real_parser.h" | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TFUniformRealParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) { | |||
| MS_LOG(DEBUG) << "TF UniformRealParser"; | |||
| if (primitiveC == nullptr || output_size == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "New PrimitiveT failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto attr = std::make_unique<schema::UniformRealT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| tensorflow::AttrValue attr_value; | |||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "seed", &attr_value)) { | |||
| MS_LOG(ERROR) << "The seed attr should be specified"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->seed = attr_value.i(); | |||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "seed2", &attr_value)) { | |||
| MS_LOG(ERROR) << "The seed2 attr should be specified"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->seed2 = attr_value.i(); | |||
| primitive->value.type = schema::PrimitiveType_UniformReal; | |||
| primitive->value.value = attr.release(); | |||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||
| if (*primitiveC == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| *output_size = 1; | |||
| auto status = AddOpInput(tf_op, 0, inputs); | |||
| return status; | |||
| } | |||
| TFNodeRegistrar g_tfRandomUniformParser("RandomUniform", new TFUniformRealParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,36 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_UNIFORM_REAL_PARSER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_UNIFORM_REAL_PARSER_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TFUniformRealParser : public TFNodeParser { | |||
| public: | |||
| TFUniformRealParser() = default; | |||
| ~TFUniformRealParser() override = default; | |||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_UNIFORM_REAL_PARSER_H_ | |||