Browse Source

!12212 add pb parser

From: @yeyunpeng2020
Reviewed-by: @HilbertDavid
Signed-off-by: @HilbertDavid
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
695052e88c
19 changed files with 616 additions and 40 deletions
  1. +2
    -0
      mindspore/lite/schema/model.fbs
  2. +10
    -2
      mindspore/lite/schema/ops.fbs
  3. +19
    -15
      mindspore/lite/src/ops/batch_matmul.cc
  4. +4
    -5
      mindspore/lite/src/ops/batch_matmul.h
  5. +53
    -0
      mindspore/lite/src/ops/lin_space.cc
  6. +42
    -0
      mindspore/lite/src/ops/lin_space.h
  7. +10
    -1
      mindspore/lite/src/ops/primitive_c.cc
  8. +101
    -0
      mindspore/lite/src/ops/uniform_real.cc
  9. +46
    -0
      mindspore/lite/src/ops/uniform_real.h
  10. +4
    -1
      mindspore/lite/tools/converter/parser/tf/tf_activation_parser.cc
  11. +6
    -0
      mindspore/lite/tools/converter/parser/tf/tf_arithmetic_self_parser.cc
  12. +20
    -12
      mindspore/lite/tools/converter/parser/tf/tf_batch_matmul_parser.cc
  13. +3
    -4
      mindspore/lite/tools/converter/parser/tf/tf_batch_matmul_parser.h
  14. +61
    -0
      mindspore/lite/tools/converter/parser/tf/tf_linspace_parser.cc
  15. +36
    -0
      mindspore/lite/tools/converter/parser/tf/tf_linspace_parser.h
  16. +59
    -0
      mindspore/lite/tools/converter/parser/tf/tf_rank_parser.cc
  17. +36
    -0
      mindspore/lite/tools/converter/parser/tf/tf_rank_parser.h
  18. +68
    -0
      mindspore/lite/tools/converter/parser/tf/tf_uniform_real_parser.cc
  19. +36
    -0
      mindspore/lite/tools/converter/parser/tf/tf_uniform_real_parser.h

+ 2
- 0
mindspore/lite/schema/model.fbs View File

@@ -276,6 +276,8 @@ union PrimitiveType {
StridedSliceGrad,
IsFinite,
BatchMatMul,
LinSpace,
UniformReal
}

enum QuantType: int {


+ 10
- 2
mindspore/lite/schema/ops.fbs View File

@@ -1281,6 +1281,14 @@ table IsFinite {
}

table BatchMatMul {
adj_x : bool = false;
adj_y : bool = false;
transpose_a :bool;
transpose_b :bool;
}

table LinSpace {
}

table UniformReal {
seed : int;
seed2 : int;
}

+ 19
- 15
mindspore/lite/src/ops/batch_matmul.cc View File

@@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/ops/batch_matmul.h"
#include <memory>
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
@@ -22,14 +22,17 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
bool BatchMatMul::GetAdjX() const { return this->primitive_->value.AsBatchMatMul()->adj_x; }

void BatchMatMul::SetAdjX(bool adj_x) { this->primitive_->value.AsBatchMatMul()->adj_x = adj_x; }
bool BatchMatMul::GetTransposeA() const { return this->primitive_->value.AsBatchMatMul()->transpose_a; }

bool BatchMatMul::GetAdjY() const { return this->primitive_->value.AsBatchMatMul()->adj_y; }
bool BatchMatMul::GetTransposeB() const { return this->primitive_->value.AsBatchMatMul()->transpose_b; }

void BatchMatMul::SetAdjY(bool adj_y) { this->primitive_->value.AsBatchMatMul()->adj_y = adj_y; }
void BatchMatMul::SetTransposeA(bool transpose_a) {
this->primitive_->value.AsBatchMatMul()->transpose_a = transpose_a;
}

void BatchMatMul::SetTransposeB(bool transpose_b) {
this->primitive_->value.AsBatchMatMul()->transpose_b = transpose_b;
}
int BatchMatMul::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
@@ -51,31 +54,32 @@ int BatchMatMul::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr>
this->primitive_ = nullptr;
return RET_ERROR;
}
attr->adj_x = GetValue<bool>(prim.GetAttr("adj_x"));
attr->adj_y = GetValue<bool>(prim.GetAttr("adj_y"));
attr->transpose_a = GetValue<bool>(prim.GetAttr("transpose_a"));
attr->transpose_b = GetValue<bool>(prim.GetAttr("transpose_b"));
this->primitive_->value.value = attr;
}
return RET_OK;
}

#else
bool BatchMatMul::GetTransposeA() const { return this->primitive_->value_as_BatchMatMul()->transpose_a(); }
bool BatchMatMul::GetTransposeB() const { return this->primitive_->value_as_BatchMatMul()->transpose_b(); }
int BatchMatMul::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto val_offset = schema::CreateBatchMatMul(*fbb);
auto attr = primitive->value_as_BatchMatMul();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_Add return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateBatchMatMul(*fbb, attr->transpose_a(), attr->transpose_b());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BatchMatMul, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
bool BatchMatMul::GetAdjX() const { return this->primitive_->value_as_BatchMatMul()->adj_x(); }

bool BatchMatMul::GetAdjY() const { return this->primitive_->value_as_BatchMatMul()->adj_y(); }

PrimitiveC *BatchMatMulCreator(const schema::Primitive *primitive) {
return PrimitiveC::NewPrimitiveC<BatchMatMul>(primitive);
}
Registry BatchMatMulRegistry(schema::PrimitiveType_BatchMatMul, BatchMatMulCreator);
#endif

} // namespace lite
} // namespace mindspore

+ 4
- 5
mindspore/lite/src/ops/batch_matmul.h View File

@@ -32,15 +32,14 @@ class BatchMatMul : public PrimitiveC {
MS_DECLARE_PARENT(BatchMatMul, PrimitiveC);
explicit BatchMatMul(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
void SetAdjX(bool adj_x);
void SetAdjY(bool adj_y);
void SetTransposeA(bool transpose_a);
void SetTransposeB(bool transpose_b);
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
bool GetAdjX() const;
bool GetAdjY() const;
bool GetTransposeA() const;
bool GetTransposeB() const;
};
} // namespace lite
} // namespace mindspore

#endif // LITE_MINDSPORE_LITE_C_OPS_BATCH_MATMUL_H_

+ 53
- 0
mindspore/lite/src/ops/lin_space.cc View File

@@ -0,0 +1,53 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/ops/lin_space.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif

namespace mindspore {
namespace lite {
#ifndef PRIMITIVE_WRITEABLE
int LinSpace::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto val_offset = schema::CreateLinSpace(*fbb);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_LinSpace, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
PrimitiveC *LinSpaceCreator(const schema::Primitive *primitive) {
return PrimitiveC::NewPrimitiveC<LinSpace>(primitive);
}
Registry LinSpaceRegistry(schema::PrimitiveType_LinSpace, LinSpaceCreator);
#endif
int LinSpace::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) {
auto input = inputs.front();
MS_ASSERT(input != nullptr);
auto output = outputs.front();
MS_ASSERT(output != nullptr);
output->set_data_type(input->data_type());
output->set_format(input->format());
auto num = inputs.at(2)->data_c();
if (num == nullptr) {
return RET_INFER_INVALID;
}
output->set_shape({reinterpret_cast<int *>(num)[0]});
return RET_OK;
}
} // namespace lite
} // namespace mindspore

+ 42
- 0
mindspore/lite/src/ops/lin_space.h View File

@@ -0,0 +1,42 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <vector>
#include <set>
#include <cmath>

#include "src/ops/primitive_c.h"

#ifndef LITE_MINDSPORE_LITE_C_OPS_LIN_SPACE_H_
#define LITE_MINDSPORE_LITE_C_OPS_LIN_SPACE_H_

namespace mindspore {
namespace lite {
class LinSpace : public PrimitiveC {
public:
LinSpace() = default;
~LinSpace() = default;
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(LinSpace, PrimitiveC);
explicit LinSpace(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) override;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_LIN_SPACE_H_

+ 10
- 1
mindspore/lite/src/ops/primitive_c.cc View File

@@ -170,8 +170,11 @@
#include "src/ops/crop_and_resize.h"
#include "src/ops/nonzero.h"
#include "src/ops/erf.h"
#include "src/ops/is_finite.h"
#include "src/ops/batch_matmul.h"
#include "src/ops/lin_space.h"
#include "src/ops/uniform_real.h"
#include "src/ops/rank.h"
#include "src/ops/is_finite.h"

#ifdef SUPPORT_TRAIN
#include "src/ops/neg_grad.h"
@@ -1047,6 +1050,12 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
return new (std::nothrow) IsFinite(primitive);
case schema::PrimitiveType_BatchMatMul:
return new (std::nothrow) BatchMatMul(primitive);
case schema::PrimitiveType_LinSpace:
return new (std::nothrow) LinSpace(primitive);
case schema::PrimitiveType_UniformReal:
return new (std::nothrow) UniformReal(primitive);
case schema::PrimitiveType_Rank:
return new (std::nothrow) Rank(primitive);
#ifdef SUPPORT_TRAIN
case schema::PrimitiveType_ActivationGrad:
return new (std::nothrow) ActivationGrad(primitive);


+ 101
- 0
mindspore/lite/src/ops/uniform_real.cc View File

@@ -0,0 +1,101 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/ops/uniform_real.h"

#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif

namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int UniformReal::GetSeed() const { return this->primitive_->value.AsUniformReal()->seed; }

int UniformReal::GetSeed2() const { return this->primitive_->value.AsUniformReal()->seed2; }

int UniformReal::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_UniformReal;
}
if (this->primitive_->value.type != schema::PrimitiveType_UniformReal) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::UniformRealT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}

this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
#else
int UniformReal::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_UniformReal();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_UniformReal return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateUniformReal(*fbb, attr->seed(), attr->seed2());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_UniformReal, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}

int UniformReal::GetSeed() const { return this->primitive_->value_as_UniformReal()->seed(); }

int UniformReal::GetSeed2() const { return this->primitive_->value_as_UniformReal()->seed2(); }

PrimitiveC *UniformRealCreator(const schema::Primitive *primitive) {
return PrimitiveC::NewPrimitiveC<UniformReal>(primitive);
}
Registry UniformRealRegistry(schema::PrimitiveType_UniformReal, UniformRealCreator);
#endif

int UniformReal::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
if (!infer_flag()) {
return RET_INFER_INVALID;
}
auto input_data = static_cast<int32_t *>(inputs_[0]->data_c());
if (input_data == nullptr) {
return RET_INFER_INVALID;
}
auto input_num = inputs_[0]->ElementsNum();
std::vector<int> output_shape(input_num);
for (int i = 0; i < input_num; i++) {
output_shape[i] = input_data[i];
}
outputs_[0]->set_shape(output_shape);
outputs_[0]->set_data_type(kNumberTypeFloat32);
return RET_OK;
}
} // namespace lite
} // namespace mindspore

+ 46
- 0
mindspore/lite/src/ops/uniform_real.h View File

@@ -0,0 +1,46 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_UNIFORM_REAL_H_
#define LITE_MINDSPORE_LITE_C_OPS_UNIFORM_REAL_H_

#include <vector>
#include <set>
#include <cmath>
#include <memory>
#include "src/ops/primitive_c.h"

namespace mindspore {
namespace lite {
class UniformReal : public PrimitiveC {
public:
UniformReal() = default;
~UniformReal() = default;
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(UniformReal, PrimitiveC);
explicit UniformReal(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
int GetSeed() const;
int GetSeed2() const;
};
} // namespace lite
} // namespace mindspore

#endif // LITE_MINDSPORE_LITE_C_OPS_UNIFORM_REAL_H_

+ 4
- 1
mindspore/lite/tools/converter/parser/tf/tf_activation_parser.cc View File

@@ -52,6 +52,8 @@ STATUS TFActivationParser::Parse(const tensorflow::NodeDef &tf_op,
attr->type = schema::ActivationType_TANH;
} else if (tf_op.op() == "LeakyRelu") {
attr->type = schema::ActivationType_LEAKY_RELU;
} else if (tf_op.op() == "Selu") {
attr->type = schema::ActivationType_SELU;
} else {
MS_LOG(ERROR) << "unsupported activation type:" << tf_op.op();
return RET_ERROR;
@@ -63,7 +65,7 @@ STATUS TFActivationParser::Parse(const tensorflow::NodeDef &tf_op,
auto attr_leaky_relu = std::make_unique<schema::LeakyReLUT>();
tensorflow::AttrValue attr_value;
if (!TensorFlowUtils::FindAttrValue(tf_op, "alpha", &attr_value)) {
MS_LOG(ERROR) << "The attribute alpha shoud be specified.";
MS_LOG(ERROR) << "The attribute alpha should be specified.";
return RET_ERROR;
}
attr_leaky_relu->negativeSlope = attr_value.f();
@@ -85,5 +87,6 @@ TFNodeRegistrar g_tfRelu6Parser("Relu6", new TFActivationParser());
TFNodeRegistrar g_tfSigmoidParser("Sigmoid", new TFActivationParser());
TFNodeRegistrar g_tfTanhParser("Tanh", new TFActivationParser());
TFNodeRegistrar g_tfLeakyReluParser("LeakyRelu", new TFActivationParser());
TFNodeRegistrar g_tfSeLUParser("Selu", new TFActivationParser());
} // namespace lite
} // namespace mindspore

+ 6
- 0
mindspore/lite/tools/converter/parser/tf/tf_arithmetic_self_parser.cc View File

@@ -58,6 +58,11 @@ STATUS TFArithmeticSelfParser::Parse(const tensorflow::NodeDef &tf_op,
status = CreateOperator<schema::SquareT>(primitive, schema::PrimitiveType_Square);
} else if (tf_op.op() == "Pow") {
status = CreateOperator<schema::PowerT>(primitive, schema::PrimitiveType_Power);
} else if (tf_op.op() == "Abs") {
status = CreateOperator<schema::PowerT>(primitive, schema::PrimitiveType_Abs);
} else {
MS_LOG(ERROR) << "unsupported arithmetic self type:" << tf_op.op();
return RET_ERROR;
}
if (status != RET_OK) {
return status;
@@ -85,5 +90,6 @@ TFNodeRegistrar g_tfFloorParser("Floor", new TFArithmeticSelfParser());
TFNodeRegistrar g_tfLogParser("Log", new TFArithmeticSelfParser());
TFNodeRegistrar g_tfSqrtParser("Sqrt", new TFArithmeticSelfParser());
TFNodeRegistrar g_tfPowParser("Pow", new TFArithmeticSelfParser());
TFNodeRegistrar g_tfAbsParser("Abs", new TFArithmeticSelfParser());
} // namespace lite
} // namespace mindspore

+ 20
- 12
mindspore/lite/tools/converter/parser/tf/tf_batch_matmul_parser.cc View File

@@ -22,29 +22,35 @@

namespace mindspore {
namespace lite {
STATUS TFBatchMatmulParser::Parse(const tensorflow::NodeDef &tf_op,
STATUS TFBatchMatMulParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
MS_LOG(DEBUG) << "TF BatchMatMulParser";
if (primitiveC == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_NULL_PTR;
}

auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "primitive is nullptr";
MS_LOG(ERROR) << "New PrimitiveT failed";
return RET_NULL_PTR;
}
auto attr = std::make_unique<schema::BatchMatMulT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}
tensorflow::AttrValue attr_value;
TensorFlowUtils::FindAttrValue(tf_op, "adj_x", &attr_value);
attr->adj_x = attr_value.b();
attr->adj_y = attr_value.b();

if (!TensorFlowUtils::FindAttrValue(tf_op, "adj_x", &attr_value)) {
MS_LOG(ERROR) << "The begin_mask attr should be specified";
return RET_ERROR;
}
attr->transpose_a = attr_value.b();
if (!TensorFlowUtils::FindAttrValue(tf_op, "adj_y", &attr_value)) {
MS_LOG(ERROR) << "The begin_mask attr should be specified";
return RET_ERROR;
}
attr->transpose_b = attr_value.b();
primitive->value.type = schema::PrimitiveType_BatchMatMul;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
@@ -52,13 +58,15 @@ STATUS TFBatchMatmulParser::Parse(const tensorflow::NodeDef &tf_op,
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_ERROR;
}

*output_size = 1;
for (int i = 0; i < tf_op.input_size(); i++) {
inputs->emplace_back(tf_op.input(i));
for (int i = 0; i < tf_op.input_size(); ++i) {
auto status = AddOpInput(tf_op, i, inputs);
if (status != RET_OK) {
return status;
}
}
return RET_OK;
}
TFNodeRegistrar g_tfBatchMatMulParser("BatchMatMul", new TFBatchMatmulParser());
TFNodeRegistrar g_tfBatchMatMulParser("BatchMatMul", new TFBatchMatMulParser());
} // namespace lite
} // namespace mindspore

+ 3
- 4
mindspore/lite/tools/converter/parser/tf/tf_batch_matmul_parser.h View File

@@ -23,15 +23,14 @@

namespace mindspore {
namespace lite {
class TFBatchMatmulParser : public TFNodeParser {
class TFBatchMatMulParser : public TFNodeParser {
public:
TFBatchMatmulParser() = default;
~TFBatchMatmulParser() override = default;
TFBatchMatMulParser() = default;
~TFBatchMatMulParser() override = default;

STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore

#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BATCH_MATMUL_PARSER_H_

+ 61
- 0
mindspore/lite/tools/converter/parser/tf/tf_linspace_parser.cc View File

@@ -0,0 +1,61 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_linspace_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"

namespace mindspore {
namespace lite {
STATUS TFLinSpaceParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
MS_LOG(DEBUG) << "TF LinSpaceParser";
if (primitiveC == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_NULL_PTR;
}
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "New PrimitiveT failed";
return RET_NULL_PTR;
}
auto attr = std::make_unique<schema::LinSpaceT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}
primitive->value.type = schema::PrimitiveType_LinSpace;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_ERROR;
}
*output_size = 1;
for (int i = 0; i < tf_op.input_size(); ++i) {
auto status = AddOpInput(tf_op, i, inputs);
if (status != RET_OK) {
return status;
}
}
return RET_OK;
}
TFNodeRegistrar g_tfLinSpaceParser("LinSpace", new TFLinSpaceParser());
} // namespace lite
} // namespace mindspore

+ 36
- 0
mindspore/lite/tools/converter/parser/tf/tf_linspace_parser.h View File

@@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_LIN_SPACE_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_LIN_SPACE_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"

namespace mindspore {
namespace lite {
class TFLinSpaceParser : public TFNodeParser {
public:
TFLinSpaceParser() = default;
~TFLinSpaceParser() override = default;

STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_LIN_SPACE_PARSER_H_

+ 59
- 0
mindspore/lite/tools/converter/parser/tf/tf_rank_parser.cc View File

@@ -0,0 +1,59 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_rank_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"

namespace mindspore {
namespace lite {
STATUS TFRankParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
std::vector<std::string> *inputs, int *output_size) {
MS_LOG(DEBUG) << "TF RankParser";
if (primitiveC == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_NULL_PTR;
}
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "New PrimitiveT failed";
return RET_NULL_PTR;
}
auto attr = std::make_unique<schema::RankT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}
primitive->value.type = schema::PrimitiveType_Rank;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_ERROR;
}
*output_size = 1;
auto status = AddOpInput(tf_op, 0, inputs);
if (status != RET_OK) {
return status;
}
return RET_OK;
}
TFNodeRegistrar g_tfRankParser("Rank", new TFRankParser());
} // namespace lite
} // namespace mindspore

+ 36
- 0
mindspore/lite/tools/converter/parser/tf/tf_rank_parser.h View File

@@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RANK_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RANK_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"

namespace mindspore {
namespace lite {
class TFRankParser : public TFNodeParser {
public:
TFRankParser() = default;
~TFRankParser() override = default;

STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RANK_PARSER_H_

+ 68
- 0
mindspore/lite/tools/converter/parser/tf/tf_uniform_real_parser.cc View File

@@ -0,0 +1,68 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_uniform_real_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"

namespace mindspore {
namespace lite {
STATUS TFUniformRealParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
MS_LOG(DEBUG) << "TF UniformRealParser";
if (primitiveC == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_NULL_PTR;
}

auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "New PrimitiveT failed";
return RET_NULL_PTR;
}
auto attr = std::make_unique<schema::UniformRealT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}
tensorflow::AttrValue attr_value;
if (!TensorFlowUtils::FindAttrValue(tf_op, "seed", &attr_value)) {
MS_LOG(ERROR) << "The seed attr should be specified";
return RET_ERROR;
}
attr->seed = attr_value.i();
if (!TensorFlowUtils::FindAttrValue(tf_op, "seed2", &attr_value)) {
MS_LOG(ERROR) << "The seed2 attr should be specified";
return RET_ERROR;
}
attr->seed2 = attr_value.i();
primitive->value.type = schema::PrimitiveType_UniformReal;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_ERROR;
}
*output_size = 1;
auto status = AddOpInput(tf_op, 0, inputs);
return status;
}
TFNodeRegistrar g_tfRandomUniformParser("RandomUniform", new TFUniformRealParser());
} // namespace lite
} // namespace mindspore

+ 36
- 0
mindspore/lite/tools/converter/parser/tf/tf_uniform_real_parser.h View File

@@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_UNIFORM_REAL_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_UNIFORM_REAL_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"

namespace mindspore {
namespace lite {
class TFUniformRealParser : public TFNodeParser {
public:
TFUniformRealParser() = default;
~TFUniformRealParser() override = default;

STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_UNIFORM_REAL_PARSER_H_

Loading…
Cancel
Save