Browse Source

add DeviceType

pull/15438/head
zhujingxuan 5 years ago
parent
commit
71a20c7353
3 changed files with 5 additions and 0 deletions
  1. +1
    -0
      mindspore/core/ops/op_utils.h
  2. +1
    -0
      mindspore/lite/schema/model.fbs
  3. +3
    -0
      mindspore/lite/tools/anf_exporter/anf_exporter.cc

+ 1
- 0
mindspore/core/ops/op_utils.h View File

@@ -230,6 +230,7 @@ constexpr auto kSpliceContext = "context";
constexpr auto kSpliceForwardIndexes = "forward_indexes";
constexpr auto kSpliceOutputDims = "output_dim";
constexpr auto kSideEffectIO = "side_effect_io";
constexpr auto kDeviceType = "device_type";
const std::set<TypePtr> common_valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16,
kUInt32, kUInt64, kFloat16, kFloat32, kFloat64};



+ 1
- 0
mindspore/lite/schema/model.fbs View File

@@ -74,6 +74,7 @@ table CNode {
inputIndex: [uint];
outputIndex: [uint];
quantType: QuantType = QUANT_NONE;
deviceType: int = -1; // 1 = CPU, 2 = GPU, 3 = NPU, -1 = UNKNOWN
}

table SubGraph {


+ 3
- 0
mindspore/lite/tools/anf_exporter/anf_exporter.cc View File

@@ -26,6 +26,7 @@
#include "tools/common/tensor_util.h"
#include "abstract/abstract_value.h"
#include "mindspore/core/ir/primitive.h"
#include "mindspore/core/ops/op_utils.h"
#include "ops/fusion/partial_fusion.h"
#include "ops/depend.h"
#include "ops/make_tuple.h"
@@ -341,6 +342,8 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc
}
node->name = cnode->fullname_with_scope();
node->primitive = std::unique_ptr<schema::PrimitiveT>(primT);
auto device_type_attr = cnode->GetAttr(mindspore::ops::kDeviceType);
node->deviceType = (device_type_attr != nullptr) ? GetValue<int32_t>(device_type_attr) : -1;
ret = SetOpInputNode(cnode, meta_graphT, node.get());
if (ret != RET_OK) {
MS_LOG(ERROR) << "SetOpInputNode failed";


Loading…
Cancel
Save