You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_data.cc 2.4 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "cxx_api/graph/graph_data.h"
  17. #include "utils/log_adapter.h"
  18. #ifdef ENABLE_ACL
  19. #include "framework/common/helper/model_helper.h"
  20. #endif
  21. namespace mindspore {
  22. Graph::GraphData::GraphData(const FuncGraphPtr &func_graph, enum ModelType model_type)
  23. : func_graph_(nullptr), om_data_(), model_type_(ModelType::kUnknownType), data_graph_({}) {
  24. if (model_type != ModelType::kMindIR) {
  25. MS_LOG(EXCEPTION) << "Invalid ModelType " << model_type;
  26. }
  27. func_graph_ = func_graph;
  28. model_type_ = model_type;
  29. }
  30. Graph::GraphData::GraphData(const Buffer &om_data, enum ModelType model_type)
  31. : func_graph_(nullptr), om_data_(om_data), model_type_(model_type), data_graph_({}) {
  32. if (model_type_ != ModelType::kOM) {
  33. MS_LOG(EXCEPTION) << "Invalid ModelType " << model_type_;
  34. }
  35. #ifdef ENABLE_ACL
  36. // check om
  37. ge::ModelHelper helper;
  38. ge::ModelData model_data;
  39. model_data.model_data = om_data_.MutableData();
  40. model_data.model_len = om_data_.DataSize();
  41. ge::Status ret = helper.LoadRootModel(model_data);
  42. if (ret != ge::SUCCESS) {
  43. MS_LOG(EXCEPTION) << "Invalid input data cannot parse to om.";
  44. }
  45. #else
  46. MS_LOG(EXCEPTION) << "Unsupported ModelType OM.";
  47. #endif
  48. }
  49. Graph::GraphData::~GraphData() {}
  50. FuncGraphPtr Graph::GraphData::GetFuncGraph() const {
  51. if (model_type_ != ModelType::kMindIR) {
  52. MS_LOG(ERROR) << "Invalid ModelType " << model_type_;
  53. return nullptr;
  54. }
  55. return func_graph_;
  56. }
  57. Buffer Graph::GraphData::GetOMData() const {
  58. if (model_type_ != ModelType::kOM) {
  59. MS_LOG(ERROR) << "Invalid ModelType " << model_type_;
  60. return Buffer();
  61. }
  62. return om_data_;
  63. }
  64. void Graph::GraphData::SetPreprocess(const std::vector<std::shared_ptr<dataset::Execute>> &data_graph) {
  65. data_graph_ = data_graph;
  66. }
  67. } // namespace mindspore