| @@ -31,12 +31,10 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| // namespace to support composite operators definition | // namespace to support composite operators definition | ||||
| namespace prim { | namespace prim { | ||||
| namespace { | |||||
| using PatternListType = std::initializer_list<BaseRef>; | |||||
| const std::map<TypeId, size_t> type_map = {{kNumberTypeBool, 1}, {kNumberTypeInt8, 2}, {kNumberTypeUInt8, 3}, | const std::map<TypeId, size_t> type_map = {{kNumberTypeBool, 1}, {kNumberTypeInt8, 2}, {kNumberTypeUInt8, 3}, | ||||
| {kNumberTypeInt16, 4}, {kNumberTypeInt32, 5}, {kNumberTypeInt64, 6}, | {kNumberTypeInt16, 4}, {kNumberTypeInt32, 5}, {kNumberTypeInt64, 6}, | ||||
| {kNumberTypeFloat16, 7}, {kNumberTypeFloat32, 8}, {kNumberTypeFloat64, 9}}; | {kNumberTypeFloat16, 7}, {kNumberTypeFloat32, 8}, {kNumberTypeFloat64, 9}}; | ||||
| namespace { | |||||
| const std::vector<Signature> &GetSignature(const ValuePtr &function) { | const std::vector<Signature> &GetSignature(const ValuePtr &function) { | ||||
| static const auto empty = std::vector<Signature>(); | static const auto empty = std::vector<Signature>(); | ||||
| if (function->isa<Primitive>() && function->cast<PrimitivePtr>()->has_signature()) { | if (function->isa<Primitive>() && function->cast<PrimitivePtr>()->has_signature()) { | ||||
| @@ -56,6 +56,8 @@ class DoSignatureMetaFuncGraph : public MetaFuncGraph { | |||||
| }; | }; | ||||
| using RWSignaturePtr = std::shared_ptr<DoSignatureMetaFuncGraph>; | using RWSignaturePtr = std::shared_ptr<DoSignatureMetaFuncGraph>; | ||||
| extern const std::map<TypeId, size_t> type_map; | |||||
| AnfNodePtr GenerateCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function, | AnfNodePtr GenerateCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function, | ||||
| const AbstractBasePtrList &args_spec_list, const AnfNodePtrList &old_node_inputs); | const AbstractBasePtrList &args_spec_list, const AnfNodePtrList &old_node_inputs); | ||||
| } // namespace prim | } // namespace prim | ||||
| @@ -160,36 +160,102 @@ std::map<SignatureEnumDType, std::vector<size_t>> GetTypeIndex(const std::vector | |||||
| return type_indexes; | return type_indexes; | ||||
| } | } | ||||
| std::map<SignatureEnumDType, size_t> GetDstType(const py::tuple &py_args, | |||||
| std::map<SignatureEnumDType, TypeId> GetDstType(const py::tuple &py_args, | |||||
| const std::map<SignatureEnumDType, std::vector<size_t>> &type_indexes) { | const std::map<SignatureEnumDType, std::vector<size_t>> &type_indexes) { | ||||
| std::map<SignatureEnumDType, size_t> dst_type; | |||||
| std::map<SignatureEnumDType, TypeId> dst_type; | |||||
| for (auto it = type_indexes.begin(); it != type_indexes.end(); (void)++it) { | for (auto it = type_indexes.begin(); it != type_indexes.end(); (void)++it) { | ||||
| auto type = it->first; | auto type = it->first; | ||||
| auto indexes = it->second; | auto indexes = it->second; | ||||
| if (indexes.size() < 2) { | |||||
| if (type == SignatureEnumDType::kDTypeEmptyDefaultValue || indexes.size() < 2) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| size_t m_index = indexes[0]; | |||||
| for (size_t i = 1; i < indexes.size(); ++i) { | |||||
| if (py::isinstance<tensor::Tensor>(py_args[indexes[i]])) { | |||||
| m_index = indexes[i]; | |||||
| size_t priority = 0; | |||||
| TypeId max_type = TypeId::kTypeUnknown; | |||||
| bool has_float = false; | |||||
| bool has_int = false; | |||||
| for (size_t index : indexes) { | |||||
| if (!has_float && py::isinstance<py::float_>(py_args[index])) { | |||||
| has_float = true; | |||||
| } | |||||
| if (!has_int && !py::isinstance<py::bool_>(py_args[index]) && py::isinstance<py::int_>(py_args[index])) { | |||||
| has_int = true; | |||||
| } | |||||
| if (py::isinstance<tensor::Tensor>(py_args[index])) { | |||||
| auto arg = py::cast<tensor::TensorPtr>(py_args[index]); | |||||
| TypeId arg_type_id = arg->data_type(); | |||||
| auto type_priority = prim::type_map.find(arg_type_id); | |||||
| if (type_priority->second > priority) { | |||||
| max_type = type_priority->first; | |||||
| priority = type_priority->second; | |||||
| } | |||||
| } | |||||
| } | |||||
| if (max_type == TypeId::kNumberTypeBool) { | |||||
| if (has_int) { | |||||
| max_type = TypeId::kNumberTypeInt32; | |||||
| } | |||||
| if (has_float) { | |||||
| max_type = TypeId::kNumberTypeFloat32; | |||||
| } | } | ||||
| } | } | ||||
| (void)dst_type.insert(std::make_pair(type, m_index)); | |||||
| (void)dst_type.insert(std::make_pair(type, max_type)); | |||||
| } | } | ||||
| return dst_type; | return dst_type; | ||||
| } | } | ||||
| std::string TypeIdToMsTypeStr(const TypeId &type_id) { | |||||
| switch (type_id) { | |||||
| case kNumberTypeFloat16: | |||||
| return "float16"; | |||||
| case kNumberTypeFloat32: | |||||
| return "float32"; | |||||
| case kNumberTypeFloat64: | |||||
| return "float64"; | |||||
| case kNumberTypeInt8: | |||||
| return "int8"; | |||||
| case kNumberTypeInt16: | |||||
| return "int16"; | |||||
| case kNumberTypeInt32: | |||||
| return "int32"; | |||||
| case kNumberTypeInt64: | |||||
| return "int64"; | |||||
| case kNumberTypeUInt8: | |||||
| return "uint8"; | |||||
| case kNumberTypeUInt16: | |||||
| return "uint16"; | |||||
| case kNumberTypeUInt32: | |||||
| return "uint32"; | |||||
| case kNumberTypeUInt64: | |||||
| return "uint64"; | |||||
| case kNumberTypeBool: | |||||
| return "bool_"; | |||||
| default: | |||||
| MS_LOG(EXCEPTION) << "For implicit type conversion, not support the type: " << TypeIdToType(type_id); | |||||
| } | |||||
| } | |||||
| py::object DoAutoCast(const py::object arg, const TypeId &type_id) { | |||||
| py::tuple args(3); | |||||
| std::string module_name = "mindspore.ops.functional"; | |||||
| std::string op_name = "cast"; | |||||
| args[0] = parse::python_adapter::GetPyFn(module_name, op_name); | |||||
| args[1] = "Cast"; | |||||
| std::string dst_type_str = TypeIdToMsTypeStr(type_id); | |||||
| module_name = "mindspore.common.dtype"; | |||||
| py::object dst_type = parse::python_adapter::GetPyFn(module_name, dst_type_str); | |||||
| py::tuple inputs(2); | |||||
| inputs[0] = arg; | |||||
| inputs[1] = dst_type; | |||||
| args[2] = inputs; | |||||
| return RunOp(args)[0]; | |||||
| } | |||||
| py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *const out_args, | py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *const out_args, | ||||
| py::list *const out_args_list) { | py::list *const out_args_list) { | ||||
| auto &py_args = *out_args; | auto &py_args = *out_args; | ||||
| py::tuple input_mask(args.size()); | py::tuple input_mask(args.size()); | ||||
| for (size_t i = 0; i < args.size(); ++i) { | for (size_t i = 0; i < args.size(); ++i) { | ||||
| if (py::hasattr(args[i], "__parameter__")) { | |||||
| input_mask[i] = true; | |||||
| } else { | |||||
| input_mask[i] = false; | |||||
| } | |||||
| input_mask[i] = py::hasattr(args[i], "__parameter__"); | |||||
| py_args[i] = GetTupleObj(args[i]); | py_args[i] = GetTupleObj(args[i]); | ||||
| } | } | ||||
| auto signature = prim->signatures(); | auto signature = prim->signatures(); | ||||
| @@ -197,26 +263,36 @@ py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tu | |||||
| (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes), | (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes), | ||||
| [](const Signature &sig) { return sig.dtype; }); | [](const Signature &sig) { return sig.dtype; }); | ||||
| int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue); | int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue); | ||||
| if (dtypes.size() == 0 || static_cast<int>(dtypes.size()) == empty_dtype_count) { | |||||
| if (dtypes.empty() || static_cast<int>(dtypes.size()) == empty_dtype_count) { | |||||
| return input_mask; | return input_mask; | ||||
| } | } | ||||
| auto type_indexes = GetTypeIndex(dtypes); | auto type_indexes = GetTypeIndex(dtypes); | ||||
| auto dst_type = GetDstType(py_args, type_indexes); | auto dst_type = GetDstType(py_args, type_indexes); | ||||
| for (size_t i = 0; i < py_args.size(); ++i) { | |||||
| for (size_t i = 0; i < dtypes.size(); ++i) { | |||||
| if (dtypes[i] == SignatureEnumDType::kDTypeEmptyDefaultValue) { | |||||
| continue; | |||||
| } | |||||
| auto it = dst_type.find(dtypes[i]); | auto it = dst_type.find(dtypes[i]); | ||||
| if (it != dst_type.end() && it->second != i && | |||||
| (py::isinstance<py::int_>(py_args[i]) || py::isinstance<py::float_>(py_args[i]))) { | |||||
| auto tensor_ptr = py::cast<tensor::TensorPtr>(py_args[it->second]); | |||||
| if (py::isinstance<py::int_>(py_args[i])) { | |||||
| py_args[i] = std::make_shared<tensor::Tensor>(py::cast<py::int_>(py_args[i]), tensor_ptr->Dtype()); | |||||
| (*out_args_list)[i] = py_args[i]; | |||||
| } else { | |||||
| double arg_value = py::cast<py::float_>(py_args[i]); | |||||
| py_args[i] = std::make_shared<tensor::Tensor>(arg_value, tensor_ptr->Dtype()); | |||||
| (*out_args_list)[i] = py_args[i]; | |||||
| } | |||||
| if (it == dst_type.end() || it->second == kTypeUnknown) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| if (py::isinstance<tensor::Tensor>(py_args[i])) { | |||||
| auto arg = py::cast<tensor::TensorPtr>(py_args[i]); | |||||
| if (arg->data_type() == it->second) { | |||||
| continue; | |||||
| } | |||||
| if (signature[i].rw == SignatureEnumRW::kRWWrite) { | |||||
| MS_LOG(EXCEPTION) << "In op '" << prim->name() << "', \n" | |||||
| << "the type of writable argument is '" << TypeIdToMsTypeStr(arg->data_type()) << "', " | |||||
| << "but the largest type in the same SignatureEumDtype is '" << TypeIdToMsTypeStr(it->second) | |||||
| << "'. The writable arg type is not equal to the largest type, " | |||||
| << "so can not cast automatically."; | |||||
| } | |||||
| } | |||||
| py::object cast_output = DoAutoCast(py_args[i], it->second); | |||||
| (*out_args)[i] = cast_output; | |||||
| (*out_args_list)[i] = cast_output; | |||||
| } | } | ||||
| return input_mask; | return input_mask; | ||||
| } | } | ||||
| @@ -0,0 +1,81 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ test implicit conversion """ | |||||
| import numpy as np | |||||
| from mindspore import Tensor | |||||
| def test_float_tensor_and_int_add(): | |||||
| x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32)) | |||||
| y = 2 | |||||
| ret_actual = x + y | |||||
| ret_expect = Tensor(np.array([[2.1, 2.2, 2.3], [2.4, 2.5, 2.6]], dtype=np.float32)) | |||||
| assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all() | |||||
| def test_bool_tensor_and_float_add(): | |||||
| x = Tensor(np.array([[True, False], [False, True]], dtype=np.bool_)) | |||||
| y = 3.3 | |||||
| ret_actual = x + y | |||||
| ret_expect = Tensor(np.array([[4.3, 3.3], [3.3, 4.3]], dtype=np.float32)) | |||||
| assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all() | |||||
| def test_bool_tensor_and_int_add(): | |||||
| x = Tensor(np.array([[True, False], [False, True]], dtype=np.bool_)) | |||||
| y = 3 | |||||
| ret_actual = x + y | |||||
| ret_expect = Tensor(np.array([[4, 3], [3, 4]], dtype=np.int32)) | |||||
| assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all() | |||||
| def test_bool_and_int_tensor_add(): | |||||
| x = True | |||||
| y = Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)) | |||||
| ret_actual = x + y | |||||
| ret_expect = Tensor(np.array([[2, 3, 4], [5, 6, 7]], dtype=np.int32)) | |||||
| assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all() | |||||
| def test_float_tensor_and_int_tensor_add(): | |||||
| x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32)) | |||||
| y = Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)) | |||||
| ret_actual = x + y | |||||
| ret_expect = Tensor(np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float32)) | |||||
| assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all() | |||||
| def test_float_tensor_and_float_tensor_add(): | |||||
| x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float64)) | |||||
| y = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)) | |||||
| ret_actual = x + y | |||||
| ret_expect = Tensor(np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float64)) | |||||
| assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all() | |||||
| def test_int_tensor_and_int_tensor_add(): | |||||
| x = Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int16)) | |||||
| y = Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)) | |||||
| ret_actual = x + y | |||||
| ret_expect = Tensor(np.array([[2, 4, 6], [8, 10, 12]], dtype=np.int32)) | |||||
| assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all() | |||||
| def test_float_tensor_and_bool_tensors_add(): | |||||
| x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32)) | |||||
| y = Tensor(np.array([[True, True, True], [False, False, False]], dtype=np.bool_)) | |||||
| ret_actual = x + y | |||||
| ret_expect = Tensor(np.array([[1.1, 1.2, 1.3], [0.4, 0.5, 0.6]], dtype=np.float32)) | |||||
| assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all() | |||||
| @@ -0,0 +1,81 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ test implicit conversion """ | |||||
| import numpy as np | |||||
| from mindspore import Tensor | |||||
| def test_float_tensor_and_int_add(): | |||||
| x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32)) | |||||
| y = 2 | |||||
| ret_actual = x + y | |||||
| ret_expect = Tensor(np.array([[2.1, 2.2, 2.3], [2.4, 2.5, 2.6]], dtype=np.float32)) | |||||
| assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all() | |||||
| def test_bool_tensor_and_float_add(): | |||||
| x = Tensor(np.array([[True, False], [False, True]], dtype=np.bool_)) | |||||
| y = 3.3 | |||||
| ret_actual = x + y | |||||
| ret_expect = Tensor(np.array([[4.3, 3.3], [3.3, 4.3]], dtype=np.float32)) | |||||
| assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all() | |||||
| def test_bool_tensor_and_int_add(): | |||||
| x = Tensor(np.array([[True, False], [False, True]], dtype=np.bool_)) | |||||
| y = 3 | |||||
| ret_actual = x + y | |||||
| ret_expect = Tensor(np.array([[4, 3], [3, 4]], dtype=np.int32)) | |||||
| assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all() | |||||
| def test_bool_and_int_tensor_add(): | |||||
| x = True | |||||
| y = Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)) | |||||
| ret_actual = x + y | |||||
| ret_expect = Tensor(np.array([[2, 3, 4], [5, 6, 7]], dtype=np.int32)) | |||||
| assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all() | |||||
| def test_float_tensor_and_int_tensor_add(): | |||||
| x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32)) | |||||
| y = Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)) | |||||
| ret_actual = x + y | |||||
| ret_expect = Tensor(np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float32)) | |||||
| assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all() | |||||
| def test_float_tensor_and_float_tensor_add(): | |||||
| x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32)) | |||||
| y = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float16)) | |||||
| ret_actual = x + y | |||||
| ret_expect = Tensor(np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float32)) | |||||
| assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all() | |||||
| def test_int_tensor_and_int_tensor_add(): | |||||
| x = Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int8)) | |||||
| y = Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)) | |||||
| ret_actual = x + y | |||||
| ret_expect = Tensor(np.array([[2, 4, 6], [8, 10, 12]], dtype=np.int32)) | |||||
| assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all() | |||||
| def test_float_tensor_and_bool_tensors_add(): | |||||
| x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32)) | |||||
| y = Tensor(np.array([[True, True, True], [False, False, False]], dtype=np.bool_)) | |||||
| ret_actual = x + y | |||||
| ret_expect = Tensor(np.array([[1.1, 1.2, 1.3], [0.4, 0.5, 0.6]], dtype=np.float32)) | |||||
| assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all() | |||||
| @@ -403,7 +403,7 @@ def max_pool_grad(x, dout, pool_h, pool_w, stride): | |||||
| """Grad of max pooling.""" | """Grad of max pooling.""" | ||||
| dout = dout.transpose(0, 2, 3, 1) | dout = dout.transpose(0, 2, 3, 1) | ||||
| pool_size = pool_h * pool_w | pool_size = pool_h * pool_w | ||||
| dmax = np.zeros((dout.size, pool_size)) | |||||
| dmax = np.zeros((dout.size, pool_size), dout.dtype) | |||||
| col = im2col(x, pool_h, pool_w, stride) | col = im2col(x, pool_h, pool_w, stride) | ||||
| col = col.reshape(-1, pool_h * pool_w) | col = col.reshape(-1, pool_h * pool_w) | ||||
| arg_max = np.argmax(col, axis=1) | arg_max = np.argmax(col, axis=1) | ||||
| @@ -418,7 +418,7 @@ def max_pool_grad_with_argmax(x, dout, arg_max, pool_h, pool_w, stride): | |||||
| """Grad of max pooling with argmax.""" | """Grad of max pooling with argmax.""" | ||||
| dout = dout.transpose(0, 2, 3, 1) | dout = dout.transpose(0, 2, 3, 1) | ||||
| pool_size = pool_h * pool_w | pool_size = pool_h * pool_w | ||||
| dmax = np.zeros((dout.size, pool_size)) | |||||
| dmax = np.zeros((dout.size, pool_size), dout.dtype) | |||||
| dmax[np.arange(arg_max.size), arg_max.flatten()] = dout.flatten() | dmax[np.arange(arg_max.size), arg_max.flatten()] = dout.flatten() | ||||
| dmax = dmax.reshape(dout.shape + (pool_size,)) | dmax = dmax.reshape(dout.shape + (pool_size,)) | ||||
| dcol = dmax.reshape(dmax.shape[0] * dmax.shape[1] * dmax.shape[2], -1) | dcol = dmax.reshape(dmax.shape[0] * dmax.shape[1] * dmax.shape[2], -1) | ||||