You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

context.h 2.6 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_INCLUDE_API_CONTEXT_H
  17. #define MINDSPORE_INCLUDE_API_CONTEXT_H
#include <any>
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include "include/api/types.h"
  23. namespace mindspore {
  24. constexpr auto kDeviceTypeAscend310 = "Ascend310";
  25. constexpr auto kDeviceTypeAscend910 = "Ascend910";
  26. struct MS_API Context {
  27. virtual ~Context() = default;
  28. std::map<std::string, std::any> params;
  29. };
  30. struct MS_API GlobalContext : public Context {
  31. static std::shared_ptr<Context> GetGlobalContext();
  32. static void SetGlobalDeviceTarget(const std::string &device_target);
  33. static std::string GetGlobalDeviceTarget();
  34. static void SetGlobalDeviceID(const uint32_t &device_id);
  35. static uint32_t GetGlobalDeviceID();
  36. };
  37. struct MS_API ModelContext : public Context {
  38. static void SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path);
  39. static std::string GetInsertOpConfigPath(const std::shared_ptr<Context> &context);
  40. static void SetInputFormat(const std::shared_ptr<Context> &context, const std::string &format);
  41. static std::string GetInputFormat(const std::shared_ptr<Context> &context);
  42. static void SetInputShape(const std::shared_ptr<Context> &context, const std::string &shape);
  43. static std::string GetInputShape(const std::shared_ptr<Context> &context);
  44. static void SetOutputType(const std::shared_ptr<Context> &context, enum DataType output_type);
  45. static enum DataType GetOutputType(const std::shared_ptr<Context> &context);
  46. static void SetPrecisionMode(const std::shared_ptr<Context> &context, const std::string &precision_mode);
  47. static std::string GetPrecisionMode(const std::shared_ptr<Context> &context);
  48. static void SetOpSelectImplMode(const std::shared_ptr<Context> &context, const std::string &op_select_impl_mode);
  49. static std::string GetOpSelectImplMode(const std::shared_ptr<Context> &context);
  50. };
  51. } // namespace mindspore
  52. #endif // MINDSPORE_INCLUDE_API_CONTEXT_H