Merge pull request !2024 from Kang/master
@@ -286,6 +286,22 @@ AnfNodePtr MixedPrecisionCastHelper(AnfNodePtr source_node, AbstractBasePtr node
       ++idx;
     }
     target_node = func_graph->NewCNode(nodes);
+  } else if (node_type->isa<AbstractDictionary>()) {
+    auto x = node_type->cast<AbstractDictionaryPtr>();
+    auto &items = x->elements();
+    std::vector<AnfNodePtr> dict_key_nodes;
+    std::vector<AnfNodePtr> dict_value_nodes;
+    dict_key_nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple));
+    dict_value_nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple));
+    for (const auto &item : items) {
+      AnfNodePtr dict_value_node =
+        func_graph->NewCNode({NewValueNode(prim::kPrimDictGetItem), source_node, NewValueNode(item.first)});
+      AnfNodePtr node = MixedPrecisionCastHelper(dict_value_node, item.second, target_type, func_graph);
+      dict_key_nodes.emplace_back(NewValueNode(item.first));
+      dict_value_nodes.emplace_back(node);
+    }
+    target_node = func_graph->NewCNode({NewValueNode(prim::kPrimMakeDict), func_graph->NewCNode(dict_key_nodes),
+                                        func_graph->NewCNode(dict_value_nodes)});
   }
   return target_node;
 }
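The new AbstractDictionary branch above extends MixedPrecisionCastHelper to dict-typed values: each entry is read with kPrimDictGetItem, cast recursively through MixedPrecisionCastHelper, and the dictionary is rebuilt with kPrimMakeDict from a tuple of the original keys and a tuple of the cast values. A minimal Python sketch of the equivalent behaviour, with illustrative names only (cast_helper stands in for the recursive helper; this is not the actual pipeline API):

    def cast_dict_values(target_type, dictionary, cast_helper):
        # Keys are kept unchanged; each value is cast (recursively for
        # nested containers) to the requested target type.
        return {key: cast_helper(target_type, value) for key, value in dictionary.items()}

    # e.g. cast_dict_values("float16", {"key": 1.0}, lambda t, v: (t, v))
    # returns {"key": ("float16", 1.0)}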
@@ -308,7 +308,7 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
     evaluator = std::make_shared<UnpackGraphEvaluator>(prim);
     return evaluator;
   }
-  if (prim->name() == prim::kPrimMixedPrecisionCast->name()) {
+  if (prim->Hash() == prim::kPrimMixedPrecisionCast->Hash() && prim->name() == prim::kPrimMixedPrecisionCast->name()) {
     evaluator = std::make_shared<MixedPrecisionCastEvaluator>(prim);
     return evaluator;
   }
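This hunk tightens the dispatch to MixedPrecisionCastEvaluator: the primitive must now match prim::kPrimMixedPrecisionCast by both Hash() and name(), rather than by name alone, so a different primitive that merely shares the name no longer resolves to this evaluator.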
@@ -25,6 +25,7 @@ from mindspore.nn import Momentum
 from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
+from mindspore.ops import functional as F
 from mindspore.train.parallel_utils import ParallelMode
 from tests.ops_common import convert
 from ....train_step_wrap import train_step_with_loss_warp
@@ -185,3 +186,36 @@ def test_grad_conv_prelu():
     net = GetParamGrad(net)
     net.set_train()
     net(*all_inputs)
+
+
+def test_dict_cast():
+    class FirstNet(nn.Cell):
+        def __init__(self):
+            super(FirstNet, self).__init__()
+            self.net = SecondNet()
+            self.sub = P.Sub()
+
+        def construct(self, tensor_a, tensor_b):
+            a = F.mixed_precision_cast(mstype.float16, tensor_a)
+            b = F.mixed_precision_cast(mstype.float16, tensor_b)
+            c = self.sub(a, b)
+            dictionary = {"key": a}
+            result = self.net(c, key1=a, key2=dictionary)
+            return result
+
+    class SecondNet(nn.Cell):
+        def __init__(self):
+            super(SecondNet, self).__init__()
+            self.add = P.TensorAdd()
+
+        def construct(self, tensor_c, **kwargs):
+            d = F.mixed_precision_cast(mstype.float16, tensor_c)
+            dict_cast = F.mixed_precision_cast(mstype.float16, kwargs)
+            e = self.add(d, dict_cast["key1"])
+            f = self.add(e, dict_cast["key2"]["key"])
+            return f
+
+    x = Tensor(np.array([1, 2.5, 3.5]), mstype.float32)
+    y = Tensor(np.array([4, 5.5, 6.5]), mstype.float32)
+    net = FirstNet()
+    net(x, y)
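The added test_dict_cast case exercises the new dictionary path end to end: FirstNet forwards a dict through keyword arguments, SecondNet applies F.mixed_precision_cast to the whole kwargs dict, and indexing dict_cast["key2"]["key"] checks that the value inside the nested dictionary was cast as well.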