| @@ -321,6 +321,13 @@ AnfNodePtr MixedPrecisionCastHelper(AnfNodePtr source_node, AbstractBasePtr node | |||||
| } | } | ||||
| target_node = func_graph->NewCNode({NewValueNode(prim::kPrimMakeDict), func_graph->NewCNode(dict_key_nodes), | target_node = func_graph->NewCNode({NewValueNode(prim::kPrimMakeDict), func_graph->NewCNode(dict_key_nodes), | ||||
| func_graph->NewCNode(dict_value_nodes)}); | func_graph->NewCNode(dict_value_nodes)}); | ||||
| } else if (node_type->isa<AbstractKeywordArg>()) { | |||||
| auto x = node_type->cast<AbstractKeywordArgPtr>(); | |||||
| std::string kwarg_key = x->get_key(); | |||||
| AnfNodePtr kwarg_value_node = | |||||
| func_graph->NewCNode({NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kwarg_key), source_node}); | |||||
| AnfNodePtr node = MixedPrecisionCastHelper(kwarg_value_node, x->get_arg(), target_type, func_graph); | |||||
| target_node = func_graph->NewCNode({NewValueNode(prim::kPrimMakeKeywordArg), NewValueNode(kwarg_key), node}); | |||||
| } | } | ||||
| return target_node; | return target_node; | ||||
| } | } | ||||
| @@ -219,3 +219,31 @@ def test_dict_cast(): | |||||
| y = Tensor(np.array([4, 5.5, 6.5]), mstype.float32) | y = Tensor(np.array([4, 5.5, 6.5]), mstype.float32) | ||||
| net = FirstNet() | net = FirstNet() | ||||
| net(x, y) | net(x, y) | ||||
| def test_kwarg_cast(): | |||||
| class FirstNet(nn.Cell): | |||||
| def __init__(self): | |||||
| super(FirstNet, self).__init__() | |||||
| self.net = SecondNet().add_flags_recursive(fp16=True) | |||||
| self.add = P.TensorAdd() | |||||
| def construct(self, tensor_a, tensor_b): | |||||
| tensor_c = self.add(tensor_a, tensor_b) | |||||
| dictionary = {"key": tensor_a} | |||||
| result = self.net(key1=tensor_c, key2=dictionary) | |||||
| return result | |||||
| class SecondNet(nn.Cell): | |||||
| def __init__(self): | |||||
| super(SecondNet, self).__init__() | |||||
| self.add = P.TensorAdd() | |||||
| def construct(self, key1=1, key2=2): | |||||
| tensor_d = self.add(key1, key2["key"]) | |||||
| return tensor_d | |||||
| x = Tensor(np.array([1, 2.5, 3.5]), mstype.float32) | |||||
| y = Tensor(np.array([4, 5.5, 6.5]), mstype.float32) | |||||
| net = FirstNet() | |||||
| net(x, y) | |||||