You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_impl_stub.cc 2.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "stub/graph_impl_stub.h"
  17. namespace mindspore {
  18. GraphImplStubAdd::GraphImplStubAdd() { Init({2, 2}); }
  19. GraphImplStubAdd::GraphImplStubAdd(const std::vector<int64_t> &add_shape) { Init(add_shape); }
  20. GraphImplStubAdd::~GraphImplStubAdd() {}
  21. void GraphImplStubAdd::Init(const std::vector<int64_t> &add_shape) {
  22. auto element_cnt = [add_shape]() -> size_t {
  23. size_t element_num = 1;
  24. for (auto dim : add_shape) {
  25. if (dim <= 0) {
  26. return 0;
  27. }
  28. element_num *= dim;
  29. }
  30. return element_num;
  31. };
  32. auto ele_size = element_cnt() * sizeof(float);
  33. MSTensor tensor_x1 = MSTensor("x1", mindspore::DataType::kNumberTypeFloat32, add_shape, nullptr, ele_size);
  34. MSTensor tensor_x2 = MSTensor("x2", mindspore::DataType::kNumberTypeFloat32, add_shape, nullptr, ele_size);
  35. MSTensor tensor_y = MSTensor("y", mindspore::DataType::kNumberTypeFloat32, add_shape, nullptr, ele_size);
  36. inputs_.push_back(tensor_x1);
  37. inputs_.push_back(tensor_x2);
  38. outputs_.push_back(tensor_y);
  39. }
  40. Status GraphImplStubAdd::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
  41. if (inputs.size() != inputs_.size()) {
  42. return mindspore::kCoreFailed;
  43. }
  44. for (size_t i = 0; i < inputs.size(); i++) {
  45. if (inputs[i].DataSize() != inputs_[i].DataSize()) {
  46. return mindspore::kCoreFailed;
  47. }
  48. if (inputs_[i].DataSize() != 0 && inputs[i].Data() == nullptr) {
  49. return mindspore::kCoreFailed;
  50. }
  51. }
  52. auto x1 = reinterpret_cast<const float *>(inputs[0].Data().get());
  53. auto x2 = reinterpret_cast<const float *>(inputs[1].Data().get());
  54. MSTensor output = outputs_[0].Clone();
  55. auto y = reinterpret_cast<float *>(output.MutableData());
  56. for (size_t i = 0; i < outputs_[0].DataSize() / sizeof(float); i++) {
  57. y[i] = x1[i] + x2[i];
  58. }
  59. outputs->push_back(output);
  60. return mindspore::kSuccess;
  61. }
  62. Status GraphImplStubAdd::Load() { return kSuccess; }
  63. std::vector<MSTensor> GraphImplStubAdd::GetInputs() { return inputs_; }
  64. std::vector<MSTensor> GraphImplStubAdd::GetOutputs() { return outputs_; }
  65. } // namespace mindspore

A lightweight and high-performance service module that helps MindSpore developers efficiently deploy online inference services in the production environment.