Browse Source

!22409 Set "dst_type" attribute when creating Cast op

Merge pull request !22409 from DeshiChen/0826_cast_dst_type
tags/v1.5.0-rc1
i-robot Gitee 4 years ago
parent
commit
7398162ac2
1 changed files with 1 additions and 0 deletions
  1. +1
    -0
      mindspore/ccsrc/backend/optimizer/gpu/reduce_precision_fusion.cc

+ 1
- 0
mindspore/ccsrc/backend/optimizer/gpu/reduce_precision_fusion.cc View File

@@ -36,6 +36,7 @@ void ReducePrecision(const FuncGraphPtr &graph, const AnfNodePtr &node, size_t i
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), i)};
auto cast = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(cast);
prim->AddAttr("dst_type", TypeIdToType(cast_type));
auto cast_shape = {AnfAlgo::GetInputDeviceShape(node, i)};
AnfAlgo::SetOutputInferTypeAndShape({cast_type}, cast_shape, cast.get());
FuncGraphManagerPtr manager = graph->manager();


Loading…
Cancel
Save