You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

op_registry.h 2.1 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef PREDICT_SRC_OP_REGISTRY_H_
  17. #define PREDICT_SRC_OP_REGISTRY_H_
  18. #include <map>
  19. #include <string>
  20. #include <unordered_map>
  21. #include "common/mslog.h"
  22. #include "common/module_registry.h"
  23. #include "src/op.h"
  // Export marker used on the OpRegistry class declaration in this header.
  // Expands to the GCC/Clang attribute visibility("default"), which (per the
  // compiler documentation) keeps the annotated symbols visible outside the
  // shared object even when the build hides symbols by default.
  24. #define MSPREDICT_API __attribute__((visibility("default")))
  25. namespace mindspore {
  26. namespace predict {
  27. class MSPREDICT_API OpRegistry {
  28. public:
  29. OpRegistry();
  30. virtual ~OpRegistry();
  31. static OpRegistry *GetInstance();
  32. virtual OpCreator GetOpCreator(const OpDesc &desc);
  33. const std::map<OpDesc, OpCreator> &GetOpCreators();
  34. void RegOp(OpDesc desc, OpCreator creator);
  35. void RegOp(OP_ARCH arch, OpT type, OpCreator creator);
  36. static bool Merge(const std::unordered_map<OpDesc, OpCreator> &newCreators);
  37. protected:
  38. std::map<OpDesc, OpCreator> creators;
  39. };
  40. template <>
  41. class Module<OpRegistry> : public ModuleBase {
  42. public:
  43. virtual OpRegistry *GetInstance() = 0;
  44. };
  // Name string "op_registry". NOTE(review): presumably the key under which
  // the OpRegistry module is registered with the module registry - confirm
  // against the code that consumes it.
  constexpr char MODULE_REG_NAME_OP_REGISTRY[] = "op_registry";
  46. class OpRegistrar {
  47. public:
  48. OpRegistrar(const OpDesc &desc, OpCreator creator) { OpRegistry::GetInstance()->RegOp(desc, creator); }
  49. OpRegistrar(const OP_ARCH arch, const OpT type, OpCreator creator) {
  50. MS_ASSERT(OpRegistry::GetInstance() != nullptr);
  51. OpRegistry::GetInstance()->RegOp(arch, type, creator);
  52. }
  53. };
  /**
   * Declares a static OpRegistrar named g_<arch><type>OpReg, constructed with
   * (arch, type, creator); constructing it performs the registration.
   *
   * `arch` and `type` are token-pasted into the variable name, so both must be
   * plain identifiers. The expansion already ends with ';', so no extra
   * semicolon is required at the use site.
   */
  #define REG_OP(arch, type, creator) static OpRegistrar g_##arch##type##OpReg(arch, type, creator);
  55. } // namespace predict
  56. } // namespace mindspore
  57. #endif // PREDICT_SRC_OP_REGISTRY_H_