|
|
|
@@ -36,6 +36,7 @@ void ReducePrecision(const FuncGraphPtr &graph, const AnfNodePtr &node, size_t i |
|
|
|
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), i)}; |
|
|
|
auto cast = graph->NewCNode(inputs); |
|
|
|
MS_EXCEPTION_IF_NULL(cast); |
|
|
|
prim->AddAttr("dst_type", TypeIdToType(cast_type)); |
|
|
|
auto cast_shape = {AnfAlgo::GetInputDeviceShape(node, i)}; |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({cast_type}, cast_shape, cast.get()); |
|
|
|
FuncGraphManagerPtr manager = graph->manager(); |
|
|
|
|