Merge pull request !1766 from vlne-v1/I1J0M0-amp-do-auto-cast-failed
tags/v0.5.0-beta
@@ -65,6 +65,7 @@ test_temp_summary_event_file/
 *.ckpt
 *.shp
 *.pkl
+*.pb
 .clangd
 mindspore/version.py
 mindspore/default_config.py
@@ -253,7 +253,7 @@ std::string Dtype2String(const std::string &dtypes) {
 std::string TypeId2String(TypeId type_id) {
   auto iter = type_id_str_map.find(type_id);
   if (iter == type_id_str_map.end()) {
-    MS_EXCEPTION(ArgumentError) << "Illegal input dtype." << TypeIdLabel(type_id);
+    return std::string(TypeIdLabel(type_id));
   }
   return iter->second;
 }
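Note: this hunk turns an unrecognized TypeId from a hard failure into a best-effort string. A minimal Python sketch of the new lookup behavior (map entries and helper name are stand-ins, not the real C++ API):

type_id_str_map = {30: 'Bool', 43: 'Float32'}  # illustrative entries only

def type_id_to_string(type_id):
    # Fall back to a printable label rather than raising, mirroring the C++ fix.
    return type_id_str_map.get(type_id, 'UnknownType({})'.format(type_id))

print(type_id_to_string(43))   # 'Float32'
print(type_id_to_string(999))  # 'UnknownType(999)' instead of an exception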
@@ -47,16 +47,6 @@ const std::vector<Signature> &GetSignature(const ValuePtr &function) {
   return empty;
 }

-const std::string GetOpName(const ValuePtr &function) {
-  std::string name = "";
-  if (function->isa<Primitive>()) {
-    name = function->cast<PrimitivePyPtr>()->name();
-  } else if (function->isa<MetaFuncGraph>()) {
-    name = function->cast<MetaFuncGraphPtr>()->name();
-  }
-  return name;
-}
-
 void ProcessDefault(const std::string &func_name, const AbstractBasePtrList &args_spec_list,
                     const std::vector<Signature> &signature, bool has_var, std::vector<AnfNodePtr> *const op_inputs) {
   std::size_t sig_size = signature.size();
@@ -93,7 +83,8 @@ void setMaxType(TypeId *max_type_id, TypeId *max_type, size_t *max_type_number,
   *max_type_number = type_number;
 }

-TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector<size_t> indexs) {
+TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector<size_t> indexs,
+                    const std::set<size_t> &write_indexs) {
   TypeId max_type_id = kTypeUnknown;
   TypeId max_type = kTypeUnknown;
   size_t max_type_number = 0;
@@ -103,7 +94,12 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
     TypeId arg_type = kTypeUnknown;
     AbstractBasePtr arg_value = args_spec_list[index];
     if (arg_value->isa<abstract::AbstractRef>()) {
-      arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
+      auto is_write = (write_indexs.find(index) != write_indexs.end());
+      if (is_write) {
+        arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref_origin();
+      } else {
+        arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
+      }
     }
     if (arg_value->isa<abstract::AbstractTensor>()) {
       auto tensor = arg_value->cast<abstract::AbstractTensorPtr>();
@@ -157,7 +153,8 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
 // Get the largest type of index in the same SignatureEnumDType of arguments.
 std::map<SignatureEnumDType, TypeId> GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes,
-                                                 const abstract::AbstractBasePtrList &args_spec_list) {
+                                                 const abstract::AbstractBasePtrList &args_spec_list,
+                                                 const std::set<size_t> &write_indexs) {
   // record index for signature.dtypes of the same type
   // eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}}
   std::map<SignatureEnumDType, std::vector<size_t>> type_indexs;
@@ -192,7 +189,7 @@ std::map<SignatureEnumDType, TypeId> GetMaxDtype(const std::vector<SignatureEnum
       (void)dst_type.insert(std::make_pair(type, kTypeUnknown));
       continue;
     }
-    (void)dst_type.insert(std::make_pair(type, GetMaxTypeId(args_spec_list, indexs)));
+    (void)dst_type.insert(std::make_pair(type, GetMaxTypeId(args_spec_list, indexs, write_indexs)));
   }
   return dst_type;
 }
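Note: the `// eg. [T, T1, ...]` comment above describes the grouping step; a small Python sketch of that index bucketing (pure illustration, not the C++ implementation) may help. GetMaxTypeId then picks the widest dtype per group, now consulting ref_origin for write indices:

def group_indexes_by_dtype(dtype_tags):
    # [T, T1, T, T2, T, T1, T3] -> {T: [0, 2, 4], T1: [1, 5], T2: [3], T3: [6]}
    groups = {}
    for index, tag in enumerate(dtype_tags):
        groups.setdefault(tag, []).append(index)
    return groups

print(group_indexes_by_dtype(['T', 'T1', 'T', 'T2', 'T', 'T1', 'T3']))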
@@ -205,9 +202,9 @@ AnfNodePtr DoCast(const AnfNodePtr &param, const TypeId &type_id, const FuncGrap
   return NewCNode({cast_node, param, dtype_node}, graph);
 }

-void DoAutoCast(const std::vector<Signature> &signature, const abstract::AbstractBasePtrList &args_spec_list,
-                const FuncGraphPtr &graph, std::vector<AnfNodePtr> *const op_inputs,
-                const std::set<size_t> &write_indexs) {
+void DoAutoCast(const std::string &func_name, const std::vector<Signature> &signature,
+                const abstract::AbstractBasePtrList &args_spec_list, const FuncGraphPtr &graph,
+                std::vector<AnfNodePtr> *const op_inputs, const std::set<size_t> &write_indexs) {
   std::vector<SignatureEnumDType> dtypes;
   (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
                        [](const Signature &sig) { return sig.dtype; });
@@ -216,16 +213,23 @@ void DoAutoCast(const std::vector<Signature> &signature, const abstract::Abstrac
     return;
   }
   // Stat the index of the arguments with the largest type in the same SignatureEnumDType.
-  std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, args_spec_list);
+  std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, args_spec_list, write_indexs);
   // Identify which arg requires auto cast
   for (size_t i = 0; i < args_spec_list.size(); ++i) {
     auto it = dst_type.find(dtypes[i]);
     if (it == dst_type.end() || it->second == kTypeUnknown) {
       continue;
     }
+    auto rw_it = write_indexs.find(i);
+    auto is_write = (rw_it != write_indexs.end());
     AbstractBasePtr arg_value = args_spec_list[i];
     if (arg_value->isa<abstract::AbstractRef>()) {
-      arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
+      if (is_write) {
+        arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref_origin();
+      } else {
+        arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
+      }
     }
     TypeId arg_type_id = kTypeUnknown;
     if (arg_value->isa<abstract::AbstractTensor>()) {
@@ -243,10 +247,9 @@ void DoAutoCast(const std::vector<Signature> &signature, const abstract::Abstrac
     if (it_map == type_map.end()) {
       continue;
     }
-    auto rw_it = write_indexs.find(i);
-    if (rw_it != write_indexs.end()) {
+    if (is_write) {
       if (arg_type_id != it->second) {
-        MS_LOG(EXCEPTION) << "In op '" << GetOpName(graph) << "', argument '" << args_spec_list[i]
+        MS_LOG(EXCEPTION) << "In op '" << func_name << "', argument '" << args_spec_list[i]
                           << "' can not cast type from '" << TypeIdLabel(arg_type_id) << "' to '"
                           << TypeIdLabel(it->second) << "' automatically.";
       }
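Note: the net effect of the two DoAutoCast hunks is that a write target (e.g. the Parameter passed to Assign) is judged by its original dtype and is never cast; only read inputs are. A hedged user-level illustration -- whether this exact net is rejected depends on Assign's signature in this version:

import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore import Tensor, Parameter, context

class AssignNet(nn.Cell):
    def __init__(self):
        super(AssignNet, self).__init__()
        self.var = Parameter(Tensor(np.zeros([2]).astype(np.float16)), name='var')
        self.assign = P.Assign()

    def construct(self, value):
        # 'var' is the write target: auto cast may widen 'value', but it must not
        # silently retype 'var'; a true mismatch now fails with the
        # "... can not cast type ... automatically." message, naming the op via func_name.
        return self.assign(self.var, value)

context.set_context(mode=context.GRAPH_MODE)
# AssignNet()(Tensor(np.ones([2]).astype(np.float32)))  # expected to be rejected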
@@ -299,8 +302,8 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
       if (sig == SignatureEnumRW::kRWRead) {
         param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefValue), param});
       } else if (sig == SignatureEnumRW::kRWWrite) {
-        param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefOrigin), param});
         write_indexs.insert(i);
+        param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefKey), param});
       }
       // If sig is SignatureEnumRW::kRWRef, not do anything.
     } else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) {
@@ -310,7 +313,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
   }
   // process default
   ProcessDefault(func_name, args_spec_list, signature, has_var, &op_inputs);
-  DoAutoCast(signature, args_spec_list, func_graph, &op_inputs, write_indexs);
+  DoAutoCast(func_name, signature, args_spec_list, func_graph, &op_inputs, write_indexs);
   return func_graph->NewCNode(op_inputs);
 }
 }  // namespace
@@ -160,7 +160,7 @@ AbstractBasePtr InferImplGetRefOrigin(const AnalysisEnginePtr &, const Primitive
                                       const AbstractBasePtrList &args_spec_list) {
   // arguments: value
   if (args_spec_list.size() != 1) {
-    MS_LOG(EXCEPTION) << "get_ref_value requires 1 parameters, while the input size is " << args_spec_list.size()
+    MS_LOG(EXCEPTION) << "get_ref_origin requires 1 parameters, while the input size is " << args_spec_list.size()
                       << ".";
   }
   TypePtr type = args_spec_list[0]->GetTypeTrack();
@@ -81,8 +81,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
   // Ref eliminate
   make_ref_eliminate_ = MakeSubstitution(MakeRefEliminater(), "make_ref_eliminate", prim::kPrimMakeRef);
-  get_make_ref_eliminate_ =
-    MakeSubstitution(GetMakeRefEliminater(), "get_make_ref_eliminate", {prim::kPrimGetRefKey, prim::kPrimGetRefValue});
+  get_make_ref_eliminate_ = MakeSubstitution(GetMakeRefEliminater(), "get_make_ref_eliminate",
+                                             {prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin});
   replace_refkey_by_param_ =
     MakeSubstitution(ReplaceRefkeyByParam(), "replace_refkey_by_param", IsValueNode<RefKey>, opt::FORCE_RENORM);
@@ -48,6 +48,7 @@ class MakeRefEliminater : public AnfVisitor {
 // {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X
 // {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y
+// {prim::kPrimGetRefOrigin, {prim::kPrimMakeRef, X, Y, Z}} -> Z
 class GetMakeRefEliminater : public AnfVisitor {
  public:
   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
@@ -71,6 +72,10 @@ class GetMakeRefEliminater : public AnfVisitor {
       return ref->input(2);
     }
+    if (cnode->IsApply(prim::kPrimGetRefOrigin)) {
+      return ref->input(3);
+    }
     return nullptr;
   }
 };
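Note: a toy Python rendering of the three rewrite rules listed in the comments above, with nested tuples standing in for CNodes:

def eliminate_get_make_ref(node):
    # ('GetRefOrigin', ('MakeRef', X, Y, Z)) -> Z, and likewise for Key/Value.
    op, ref = node
    if ref[0] != 'MakeRef':
        return None
    slot = {'GetRefKey': 1, 'GetRefValue': 2, 'GetRefOrigin': 3}.get(op)
    return ref[slot] if slot is not None else None

print(eliminate_get_make_ref(('GetRefOrigin', ('MakeRef', 'X', 'Y', 'Z'))))  # Z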
@@ -315,7 +315,7 @@ void FunctionBlock::InsertDependItemsBeforeReturn() {
   ValueNodePtr make_tuple_op = NewValueNode(prim::kPrimMakeTuple);
   ValueNodePtr depend_op = NewValueNode(prim::kPrimDepend);
-  ValueNodePtr get_refkey_op = NewValueNode(prim::kPrimGetRefKey);
+  ValueNodePtr get_ref_origin_op = NewValueNode(prim::kPrimGetRefOrigin);
   ValueNodePtr stop_gradient_op = NewValueNode(prim::kPrimStopGradient);
   const std::string primitive_name("assign");
   const std::string module_name("mindspore.ops.functional");
@@ -329,8 +329,8 @@ void FunctionBlock::InsertDependItemsBeforeReturn() {
   vec_states.emplace_back(make_tuple_op);
   for (auto &item : state_assign_) {
     auto source = ReadVariable(item.second);
-    auto refkey = func_graph()->NewCNode({get_refkey_op, item.first});
-    auto assign = func_graph()->NewCNode({assign_op, refkey, source});
+    auto origin = func_graph()->NewCNode({get_ref_origin_op, item.first});
+    auto assign = func_graph()->NewCNode({assign_op, origin, source});
     MS_LOG(INFO) << "SetState read " << item.first->ToString() << ", " << item.second;
     vec_states.emplace_back(assign);
   }
@@ -801,8 +801,8 @@ bool AbstractRef::operator==(const AbstractBase &other) const {
 std::string AbstractRef::ToString() const {
   std::ostringstream buffer;
   buffer << type_name() << "("
-         << "key: " << ref_key_->ToString() << "ref_value: " << ref_->ToString()
-         << "origin_value: " << ref_origin_->ToString();
+         << "key: " << ref_key_->ToString() << " ref_value: " << ref_->ToString()
+         << " origin_value: " << ref_origin_->ToString();
   auto value = GetValueTrack();
   if (value) {
     buffer << ", value: " << value->ToString();
@@ -783,7 +783,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
     AbstractBasePtr abs = node_conf->GetEvaluatedValue()->abstract();
     AbstractRefPtr ref_abs = abs->cast<AbstractRefPtr>();
     if (ref_abs == nullptr) {
-      MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref.";
+      MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref, but " << abs->ToString();
       return nullptr;
     }
     auto key_abs = ref_abs->ref_key();
@@ -170,7 +170,7 @@ def get_py_obj_dtype(obj):
         Type of MindSpore type.
     """
     # Tensor
-    if hasattr(obj, 'dtype'):
+    if hasattr(obj, 'dtype') and callable(obj.dtype) and isinstance(obj.dtype(), typing.Type):
        return tensor_type(obj.dtype())
    if hasattr(obj, '__primitive_flag__') or hasattr(obj, 'construct'):
        return function
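Note: the stricter guard matters because `dtype` is not always a callable returning a MindSpore type; numpy arrays, for instance, expose it as a plain attribute, so the old check would have attempted `obj.dtype()` and crashed:

import numpy as np

arr = np.ones([2])
print(hasattr(arr, 'dtype'))  # True -- the old test stopped here
print(callable(arr.dtype))    # False -- the new test now falls through safely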
@@ -31,7 +31,9 @@ from ...common.tensor import Tensor
 from ..operations.math_ops import _infer_shape_reduce
 from .._utils import get_concat_offset
 from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
+from ..._c_expression import signature_rw as sig_rw
+from ..._c_expression import signature_kind as sig_kind
+from ..._c_expression import signature_dtype as sig_dtype


 def _check_infer_attr_reduce(axis, keep_dims, prim_name):
     validator.check_value_type('keep_dims', keep_dims, [bool], prim_name)
@@ -2156,13 +2158,17 @@ class ScatterUpdate(PrimitiveWithInfer):
         >>> input_x = mindspore.Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32))
         >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
         >>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
-        >>> op = P.ScatterNdUpdate()
+        >>> op = P.ScatterUpdate()
         >>> output = op(input_x, indices, update)
     """
+    __mindspore_signature__ = (
+        ('x', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
+        ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1),
+        ('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
+    )
+
     @prim_attr_register
     def __init__(self, use_locking=True):
-        """Init ScatterNdUpdate"""
+        """Init ScatterUpdate"""
         self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])

     def infer_shape(self, x_shape, indices_shape, value_shape):
@@ -2201,7 +2207,11 @@ class ScatterNdUpdate(PrimitiveWithInfer):
         >>> op = P.ScatterNdUpdate()
         >>> output = op(input_x, indices, update)
     """
+    __mindspore_signature__ = (
+        ('x', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
+        ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1),
+        ('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
+    )
+
     @prim_attr_register
     def __init__(self, use_locking=True):
         """Init ScatterNdUpdate"""
@@ -179,7 +179,7 @@ class AssignAdd(PrimitiveWithInfer):
         return value

     def infer_dtype(self, variable, value):
-        args = {"value": value}
+        args = {"variable": variable, "value": value}
         validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name)
         return value
@@ -222,7 +222,7 @@ class AssignSub(PrimitiveWithInfer):
         return value

     def infer_dtype(self, variable, value):
-        args = {"value": value}
+        args = {"variable": variable, "value": value}
         validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name)
         return value
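Note: with `variable` added to `args`, the validator compares both entries for dtype equality instead of checking `value` in isolation. A compact Python paraphrase of the rule being enforced (not the validator's actual code):

def check_type_same(args, prim_name):
    # All named inputs must share one dtype, as AssignAdd/AssignSub now require.
    dtypes = set(args.values())
    if len(dtypes) != 1:
        raise TypeError("For '{}', dtypes differ: {}".format(prim_name, args))

check_type_same({"variable": "float16", "value": "float16"}, "AssignAdd")    # ok
# check_type_same({"variable": "float16", "value": "float32"}, "AssignAdd")  # raises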
@@ -58,6 +58,8 @@ class Assign(PrimitiveWithInfer):
         return variable

     def infer_dtype(self, variable, value):
+        args = {"variable": variable, "value": value}
+        validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
         return variable
@@ -1,3 +1,18 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
 """test layer switch"""
 import numpy as np
 import mindspore
@@ -345,19 +345,6 @@ class Conv2dNativeNet(nn.Cell):
         return self.flatten(self.conv(input_x, self.weight))

-
-class MakeRefKeyNet(nn.Cell):
-    """ MakeRefKeyNet definition """
-
-    def __init__(self):
-        super(MakeRefKeyNet, self).__init__()
-        self.y = Parameter(Tensor([1.0], mindspore.float32), name="y")
-
-    def construct(self, x):
-        key = P.MakeRefKey("y")()
-        P.Assign()(key, x)
-        return x
-
 class StateNet(nn.Cell):
     """ StateTestTensor definition """
@@ -538,10 +525,6 @@ test_cases = [
         'block': Grad(NetWithLossClass(Conv2dNativeNet())),
         'desc_inputs': [Tensor(np.ones([1, 3, 16, 16], np.float32)), Tensor(np.zeros([1, 1764], np.float32))],
     }),
-    ('MakeRefKey', {
-        'block': MakeRefKeyNet(),
-        'desc_inputs': [Tensor([2.0], mindspore.float32)],
-    }),
     ('StateTest', {
         'block': StateNet(),
         'desc_inputs': [Tensor(np.ones([2, 1, 2, 2]).astype(np.float32))],
@@ -0,0 +1,75 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+test assign sub
+"""
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+import mindspore.ops.operations as P
+from mindspore import Tensor
+from mindspore.common.initializer import initializer
+from mindspore.common.parameter import Parameter
+import mindspore as ms
+
+
+class AssignW(nn.Cell):
+    def __init__(self):
+        super(AssignW, self).__init__()
+        self.assign = P.Assign()
+
+    def construct(self, x, w):
+        self.assign(x, w)
+        return x
+
+
+class Net(nn.Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.b = Parameter(initializer('ones', [5]), name='b')
+        self.assign = AssignW()
+
+    def construct(self, value):
+        return self.assign(self.b, value)
+
+
+def test_assign_through_cell():
+    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
+    net = Net()
+    net.to_float(ms.float16)
+    net.add_flags_recursive(fp16=False)
+    input_data = Tensor(np.ones([5]).astype(np.float32))
+    net(input_data)
+    with pytest.raises(TypeError):
+        net(None)
+
+
+class NetScatterNdUpdate(nn.Cell):
+    def __init__(self):
+        super(NetScatterNdUpdate, self).__init__()
+        self.b = Parameter(initializer('ones', [5, 5]), name='b')
+        self.scatter = P.ScatterNdUpdate()
+
+    def construct(self, idx, x):
+        return self.scatter(self.b, idx, x)
+
+
+def test_scatter_nd_update():
+    context.set_context(mode=context.GRAPH_MODE)
+    net = NetScatterNdUpdate()
+    x = Tensor(np.ones([5]).astype(np.float16))
+    idx = Tensor(np.ones([1]).astype(np.int32))
+    net(idx, x)
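Note on the last case: `b` comes from `initializer('ones', [5, 5])` while `x` is float16; under the signatures added above, 'x' and 'value' share group T, the write target's ref_origin dtype wins as the max, so the float16 value is cast up and compilation succeeds. This is a hedged reading that assumes the Parameter defaults to float32 here:

# Quick dtype sanity check of the scenario exercised above:
import numpy as np

x = np.ones([5]).astype(np.float16)
b = np.ones([5, 5]).astype(np.float32)  # assumed Parameter default dtype
assert np.promote_types(x.dtype, b.dtype) == np.float32  # value casts up to b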