You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

function_base.h 1.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. /**
  2. * \file src/function_base.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include <unordered_map>
  13. #include "misc.h"
  14. #include "type_info.h"
  15. // template <typename tensor_type, typename ...Arg>
  16. namespace lite {
  17. class TensorImplDft;
  18. class NetworkImplDft;
  19. namespace {
  20. template <typename class_type>
  21. struct class_type_name {
  22. std::string operator()() { return ""; }
  23. };
  24. #define ADD_STATEMENT(class_name, backend_name) \
  25. template <> \
  26. struct class_type_name<class_name> { \
  27. std::string operator()() { return #backend_name; } \
  28. }
  29. ADD_STATEMENT(TensorImplDft, Dft);
  30. ADD_STATEMENT(NetworkImplDft, Dft);
  31. #undef ADD_STATEMENT
  32. } // namespace
  33. // if it can't find the function, ignore
  34. template <typename type, typename ret_type, typename... Args>
  35. ret_type try_call_func(std::string func_name, Args... args) {
  36. mark_used_variable(func_name);
  37. mark_used_variable(args...);
  38. return nullptr;
  39. }
  40. // if it can't find the function, throw error
  41. template <typename type, typename ret_type, typename... Args>
  42. ret_type call_func(std::string func_name, Args... args) {
  43. mark_used_variable(args...);
  44. auto backend_name = class_type_name<type>()();
  45. auto msg_info = func_name + " is not aviliable in " + backend_name + " backend.";
  46. LITE_THROW(msg_info.c_str());
  47. }
  48. } // namespace lite
  49. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}