You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensorflow_parser_register.h 5.0 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. /**
  2. * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. // Copyright (c) <2018>, <Huawei Technologies Co., Ltd>
  17. #ifndef PARSER_TENSORFLOW_TENSORFLOW_PARSER_REGISTER_H_
  18. #define PARSER_TENSORFLOW_TENSORFLOW_PARSER_REGISTER_H_
  #include <functional>
  #include <memory>
  #include <string>
  #include <utility>
  #include "common/util.h"
  #include "framework/omg/parser/op_parser.h"
  #include "parser/common/op_def/ir_pb_converter.h"
  #include "parser/common/op_def/operator.h"
  #include "parser/common/acl_graph_parser_util.h"
  #include "parser/common/op_parser_factory.h"
  #include "parser/tensorflow/tensorflow_op_parser.h"
  #include "proto/tensorflow/node_def.pb.h"
  #include "register/register_utils.h"
  31. namespace ge {
  32. class PARSER_FUNC_VISIBILITY TensorflowFinalizeable {
  33. public:
  34. virtual bool Finalize() = 0;
  35. virtual ~TensorflowFinalizeable() {}
  36. };
  37. class PARSER_FUNC_VISIBILITY TensorflowReceiver {
  38. public:
  39. TensorflowReceiver(TensorflowFinalizeable &f) noexcept { f.Finalize(); }
  40. ~TensorflowReceiver() {}
  41. };
  42. namespace tensorflow_parser {
  43. template <typename Param>
  44. class TensorflowParserBuilder;
  45. class PARSER_FUNC_VISIBILITY TensorflowWeightParserBuilder : public TensorflowFinalizeable {
  46. public:
  47. ~TensorflowWeightParserBuilder() override {}
  48. };
  49. template <typename Param>
  50. class TensorflowOpParserAdapter;
  51. template <typename Param>
  52. class PARSER_FUNC_VISIBILITY TensorflowParserBuilder : public TensorflowWeightParserBuilder {
  53. public:
  54. using ParseParamsFn = std::function<domi::Status(const domi::tensorflow::NodeDef *, Param *)>;
  55. explicit TensorflowParserBuilder(const std::string &davinci_optype) : davinci_optype_(davinci_optype) {}
  56. ~TensorflowParserBuilder() override {}
  57. TensorflowParserBuilder &SetParseParamsFn(ParseParamsFn parse_params_fn) {
  58. parse_params_fn_ = parse_params_fn;
  59. return *this;
  60. }
  61. bool Finalize() override {
  62. auto op_parser_adapter = ge::parser::MakeShared<TensorflowOpParserAdapter<Param>>(*this);
  63. if (op_parser_adapter == nullptr) {
  64. GELOGE(FAILED, "Op parser adapter is null.");
  65. }
  66. // register to OpParserFactory
  67. OpParserRegisterar registerar __attribute__((unused)) = OpParserRegisterar(
  68. domi::TENSORFLOW, davinci_optype_, [op_parser_adapter] { return std::shared_ptr<OpParser>(op_parser_adapter); });
  69. return true;
  70. }
  71. private:
  72. std::string davinci_optype_; // op type in davinci model
  73. ParseParamsFn parse_params_fn_;
  74. friend class TensorflowOpParserAdapter<Param>;
  75. };
  76. template <typename Param>
  77. class PARSER_FUNC_VISIBILITY TensorflowOpParserAdapter : public TensorFlowOpParser {
  78. using ParseParamsFn = std::function<domi::Status(const domi::tensorflow::NodeDef *, Param *)>;
  79. public:
  80. explicit TensorflowOpParserAdapter(TensorflowParserBuilder<Param> builder) {
  81. parse_params_fn_ = builder.parse_params_fn_; }
  82. ~TensorflowOpParserAdapter() override {}
  83. Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override {
  84. const domi::tensorflow::NodeDef *node = static_cast<const domi::tensorflow::NodeDef *>(op_src);
  85. GE_CHECK_NOTNULL(node);
  86. std::shared_ptr<Param> param = ge::parser::MakeShared<Param>();
  87. if (param == nullptr) {
  88. GELOGE(domi::FAILED, "Param is null");
  89. return domi::FAILED;
  90. }
  91. ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_dest);
  92. GE_CHK_STATUS_RET(domi::OperatorAutoMapping(op_src, op),
  93. "[Call][AutoMapping] failed.");
  94. op.BreakConnect();
  95. GE_RETURN_IF_ERROR(parse_params_fn_(node, param.get()));
  96. param.get()->Name(node->name());
  97. std::shared_ptr<ParserOperator> op_param = std::static_pointer_cast<ParserOperator>(param);
  98. ConvertToOpDesc(*op_param, op_dest);
  99. return domi::SUCCESS;
  100. }
  101. private:
  102. ParseParamsFn parse_params_fn_;
  103. };
  104. } // namespace tensorflow_parser
  // Registers a TensorFlow op parser during static initialization, e.g.
  //   DOMI_REGISTER_TENSORFLOW_PARSER("OpType", ParamClass).SetParseParamsFn(fn);
  // The expansion ends in a builder temporary; chaining SetParseParamsFn yields the
  // lvalue builder reference that initializes the static TensorflowReceiver, whose
  // constructor runs Finalize().
  // __COUNTER__ gives each expansion a distinct variable name; the UNIQ_HELPER
  // indirection makes __COUNTER__ expand before it is token-pasted.
  // Names are fully qualified so the macro also compiles outside namespace ge.
  #define DOMI_REGISTER_TENSORFLOW_PARSER(name, param_clazz) \
    DOMI_REGISTER_TENSORFLOW_PARSER_UNIQ_HELPER(__COUNTER__, name, param_clazz)
  #define DOMI_REGISTER_TENSORFLOW_PARSER_UNIQ_HELPER(ctr, name, param_clazz) \
    DOMI_REGISTER_TENSORFLOW_PARSER_UNIQ(ctr, name, param_clazz)
  #define DOMI_REGISTER_TENSORFLOW_PARSER_UNIQ(ctr, name, param_clazz) \
    static ::ge::TensorflowReceiver register_tensorflow_parser##ctr __attribute__((unused)) = \
        ::ge::tensorflow_parser::TensorflowParserBuilder<param_clazz>(name)
  112. } // namespace ge
  113. #endif // PARSER_TENSORFLOW_TENSORFLOW_PARSER_REGISTER_H_