You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

context.h 2.6 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_INCLUDE_API_CONTEXT_H
  17. #define MINDSPORE_INCLUDE_API_CONTEXT_H
  18. #include <map>
  19. #include <any>
  20. #include <string>
  21. #include <memory>
  22. #include "include/api/types.h"
  23. namespace mindspore {
  24. constexpr auto kDeviceTypeAscend310 = "Ascend310";
  25. constexpr auto kDeviceTypeAscend910 = "Ascend910";
  26. constexpr auto kDeviceTypeGPU = "GPU";
  27. struct MS_API Context {
  28. virtual ~Context() = default;
  29. std::map<std::string, std::any> params;
  30. };
  31. struct MS_API GlobalContext : public Context {
  32. static std::shared_ptr<Context> GetGlobalContext();
  33. static void SetGlobalDeviceTarget(const std::string &device_target);
  34. static std::string GetGlobalDeviceTarget();
  35. static void SetGlobalDeviceID(const uint32_t &device_id);
  36. static uint32_t GetGlobalDeviceID();
  37. };
  38. struct MS_API ModelContext : public Context {
  39. static void SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path);
  40. static std::string GetInsertOpConfigPath(const std::shared_ptr<Context> &context);
  41. static void SetInputFormat(const std::shared_ptr<Context> &context, const std::string &format);
  42. static std::string GetInputFormat(const std::shared_ptr<Context> &context);
  43. static void SetInputShape(const std::shared_ptr<Context> &context, const std::string &shape);
  44. static std::string GetInputShape(const std::shared_ptr<Context> &context);
  45. static void SetOutputType(const std::shared_ptr<Context> &context, enum DataType output_type);
  46. static enum DataType GetOutputType(const std::shared_ptr<Context> &context);
  47. static void SetPrecisionMode(const std::shared_ptr<Context> &context, const std::string &precision_mode);
  48. static std::string GetPrecisionMode(const std::shared_ptr<Context> &context);
  49. static void SetOpSelectImplMode(const std::shared_ptr<Context> &context, const std::string &op_select_impl_mode);
  50. static std::string GetOpSelectImplMode(const std::shared_ptr<Context> &context);
  51. };
  52. } // namespace mindspore
  53. #endif // MINDSPORE_INCLUDE_API_CONTEXT_H