You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

graph_tests.cc 4.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <gtest/gtest.h>
  17. #include <cstdio>
  18. #include <string>
  19. #include "schema/inner/ms_generated.h"
  20. #include "src/graph.h"
  21. #include "common/file_utils.h"
  22. #include "test/test_context.h"
  23. #include "include/session.h"
  24. namespace mindspore {
  25. namespace predict {
  26. class GraphTest : public ::testing::Test {
  27. protected:
  28. void SetUp() {}
  29. void TearDown() {}
  30. std::string root;
  31. };
  32. void InitMsGraphAllTensor(SubGraphDefT *msSubgraph) {
  33. ASSERT_NE(msSubgraph, nullptr);
  34. std::unique_ptr<TensorDefT> tensor (new (std::nothrow) TensorDefT);
  35. ASSERT_NE(tensor, nullptr);
  36. tensor->refCount = MSConst_WEIGHT_REFCOUNT;
  37. tensor->format = Format_NCHW;
  38. tensor->dataType = DataType_DT_FLOAT;
  39. tensor->dims = {1, 1, 1, 2};
  40. tensor->offset = -1;
  41. tensor->data.resize(0);
  42. msSubgraph->allTensors.emplace_back(std::move(tensor));
  43. std::unique_ptr<TensorDefT> tensor2(new (std::nothrow) TensorDefT);
  44. ASSERT_NE(tensor2, nullptr);
  45. tensor2->refCount = MSConst_WEIGHT_REFCOUNT;
  46. tensor2->format = Format_NCHW;
  47. tensor2->dataType = DataType_DT_FLOAT;
  48. tensor2->dims = {1, 1, 1, 2};
  49. tensor2->offset = -1;
  50. tensor2->data.resize(0);
  51. msSubgraph->allTensors.emplace_back(std::move(tensor2));
  52. std::unique_ptr<TensorDefT> tensor3(new (std::nothrow) TensorDefT);
  53. ASSERT_NE(tensor3, nullptr);
  54. tensor3->refCount = 0;
  55. tensor3->format = Format_NCHW;
  56. tensor3->dataType = DataType_DT_FLOAT;
  57. tensor3->dims = {1, 1, 1, 2};
  58. tensor3->offset = -1;
  59. tensor3->data.resize(0);
  60. msSubgraph->allTensors.emplace_back(std::move(tensor3));
  61. }
  62. void FreeOutputs(std::map<std::string, std::vector<Tensor *>> *outputs) {
  63. for (auto &output : (*outputs)) {
  64. for (auto &outputTensor : output.second) {
  65. delete outputTensor;
  66. }
  67. }
  68. outputs->clear();
  69. }
  70. void FreeInputs(std::vector<Tensor *> *inputs) {
  71. for (auto &input : *inputs) {
  72. input->SetData(nullptr);
  73. delete input;
  74. }
  75. inputs->clear();
  76. return;
  77. }
  78. TEST_F(GraphTest, CreateFromFileAdd) {
  79. auto msGraph = std::unique_ptr<GraphDefT>(new (std::nothrow) GraphDefT());
  80. ASSERT_NE(msGraph, nullptr);
  81. msGraph->name = "test1";
  82. auto msSubgraph = std::unique_ptr<SubGraphDefT>(new (std::nothrow) SubGraphDefT());
  83. ASSERT_NE(msSubgraph, nullptr);
  84. msSubgraph->name = msGraph->name + "_1";
  85. msSubgraph->inputIndex = {0, 1};
  86. msSubgraph->outputIndex = {2};
  87. std::unique_ptr<NodeDefT> node(new (std::nothrow) NodeDefT);
  88. ASSERT_NE(node, nullptr);
  89. std::unique_ptr<OpDefT> opDef(new (std::nothrow) OpDefT);
  90. ASSERT_NE(opDef, nullptr);
  91. node->opDef = std::move(opDef);
  92. node->opDef->isLastConv = false;
  93. node->opDef->inputIndex = {static_cast<unsigned int>(0), 1};
  94. node->opDef->outputIndex = {static_cast<unsigned int>(2)};
  95. node->opDef->name = msSubgraph->name + std::to_string(0);
  96. node->fmkType = FmkType_CAFFE;
  97. auto attr = std::unique_ptr<AddT>(new (std::nothrow) AddT());
  98. ASSERT_NE(attr, nullptr);
  99. attr->format = DataFormatType_NCHW;
  100. node->opDef->attr.type = OpT_Add;
  101. node->opDef->attr.value = attr.release();
  102. msSubgraph->nodes.emplace_back(std::move(node));
  103. InitMsGraphAllTensor(msSubgraph.get());
  104. msGraph->subgraphs.emplace_back(std::move(msSubgraph));
  105. flatbuffers::FlatBufferBuilder builder(1024);
  106. auto offset = mindspore::predict::GraphDef::Pack(builder, msGraph.get());
  107. builder.Finish(offset);
  108. int size = builder.GetSize();
  109. void *content = builder.GetBufferPointer();
  110. Context ctx;
  111. auto session = CreateSession(static_cast<char *>(content), size, ctx);
  112. std::vector<float> tmpT = {1, 2};
  113. void *in1Data = tmpT.data();
  114. std::vector<float> tmpT2 = {3, 5};
  115. void *in2Data = tmpT2.data();
  116. auto inputs = session->GetInput();
  117. inputs[0]->SetData(in1Data);
  118. inputs[1]->SetData(in2Data);
  119. auto ret = session->Run(inputs);
  120. EXPECT_EQ(0, ret);
  121. auto outputs = session->GetAllOutput();
  122. EXPECT_EQ(4, reinterpret_cast<float *>(outputs.begin()->second.front()->GetData())[0]);
  123. EXPECT_EQ(7, reinterpret_cast<float *>(outputs.begin()->second.front()->GetData())[1]);
  124. FreeOutputs(&outputs);
  125. FreeInputs(&inputs);
  126. }
  127. } // namespace predict
  128. } // namespace mindspore