| @@ -206,7 +206,7 @@ bool CompareShape(const std::vector<ValuePtr> &x_shape, const std::vector<ValueP | |||||
| } | } | ||||
| AbstractBasePtr DoInferReduceShape(const AbstractTuplePtr &x_shape, const ValuePtr &x_shp_value, | AbstractBasePtr DoInferReduceShape(const AbstractTuplePtr &x_shape, const ValuePtr &x_shp_value, | ||||
| const ValueTuplePtr &axis_value_ptr, const PrimitivePtr &primitive) { | |||||
| const ValueSequeuePtr &axis_value_ptr, const PrimitivePtr &primitive) { | |||||
| size_t x_rank = x_shape->size(); | size_t x_rank = x_shape->size(); | ||||
| std::set<int> axis_set; | std::set<int> axis_set; | ||||
| auto axis_data = axis_value_ptr->value(); | auto axis_data = axis_value_ptr->value(); | ||||
| @@ -348,17 +348,17 @@ AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitiveP | |||||
| << " evaluator shape's data field can't be anything: " << args_spec_list[1]->ToString(); | << " evaluator shape's data field can't be anything: " << args_spec_list[1]->ToString(); | ||||
| } | } | ||||
| // Axis can be scalar, tuple or None | |||||
| AbstractTuplePtr axis = nullptr; | |||||
| // Axis can be scalar, tuple or list | |||||
| AbstractSequeuePtr axis = nullptr; | |||||
| if (args_spec_list[1]->isa<AbstractScalar>()) { | if (args_spec_list[1]->isa<AbstractScalar>()) { | ||||
| MS_LOG(DEBUG) << op_name << " evaluator second parameter is scalar"; | MS_LOG(DEBUG) << op_name << " evaluator second parameter is scalar"; | ||||
| AbstractBasePtrList axis_list = {dyn_cast<AbstractScalar>(args_spec_list[1])}; | AbstractBasePtrList axis_list = {dyn_cast<AbstractScalar>(args_spec_list[1])}; | ||||
| axis = std::make_shared<AbstractTuple>(axis_list); | axis = std::make_shared<AbstractTuple>(axis_list); | ||||
| } else if (args_spec_list[1]->isa<AbstractTuple>()) { | |||||
| MS_LOG(DEBUG) << op_name << " evaluator second parameter is tuple"; | |||||
| axis = args_spec_list[1]->cast<AbstractTuplePtr>(); | |||||
| } else if (args_spec_list[1]->isa<AbstractSequeue>()) { | |||||
| MS_LOG(DEBUG) << op_name << " evaluator second parameter is sequeue"; | |||||
| axis = args_spec_list[1]->cast<AbstractSequeuePtr>(); | |||||
| } else { | } else { | ||||
| MS_LOG(EXCEPTION) << op_name << " evaluator second parameter should be a scalar or tuple, but got " | |||||
| MS_LOG(EXCEPTION) << op_name << " evaluator second parameter should be a scalar or tuple or list, but got " | |||||
| << args_spec_list[1]->ToString(); | << args_spec_list[1]->ToString(); | ||||
| } | } | ||||
| @@ -367,7 +367,7 @@ AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitiveP | |||||
| MS_LOG(EXCEPTION) << op_name | MS_LOG(EXCEPTION) << op_name | ||||
| << " evaluator shape's data field can't be anything: " << args_spec_list[1]->ToString(); | << " evaluator shape's data field can't be anything: " << args_spec_list[1]->ToString(); | ||||
| } | } | ||||
| auto axis_value_ptr = axis_value->cast<ValueTuplePtr>(); | |||||
| auto axis_value_ptr = axis_value->cast<ValueSequeuePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(axis_value_ptr); | MS_EXCEPTION_IF_NULL(axis_value_ptr); | ||||
| return DoInferReduceShape(shape_x, x_shp_value, axis_value_ptr, primitive); | return DoInferReduceShape(shape_x, x_shp_value, axis_value_ptr, primitive); | ||||
| @@ -261,17 +261,19 @@ void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, const OpExe | |||||
| auto obj = out_args[i]; | auto obj = out_args[i]; | ||||
| if (py::isinstance<tensor::Tensor>(obj)) { | if (py::isinstance<tensor::Tensor>(obj)) { | ||||
| auto arg = py::cast<tensor::TensorPtr>(obj); | auto arg = py::cast<tensor::TensorPtr>(obj); | ||||
| if (arg->data_type() == it->second) { | |||||
| TypeId arg_type_id = arg->data_type(); | |||||
| if (prim::type_map.find(arg_type_id) == prim::type_map.end() || arg_type_id == it->second) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| if (signature[i].rw == SignatureEnumRW::kRWWrite) { | if (signature[i].rw == SignatureEnumRW::kRWWrite) { | ||||
| prim::RaiseExceptionForConvertRefDtype(prim->name(), TypeIdToMsTypeStr(arg->data_type()), | |||||
| prim::RaiseExceptionForConvertRefDtype(prim->name(), TypeIdToMsTypeStr(arg_type_id), | |||||
| TypeIdToMsTypeStr(it->second)); | TypeIdToMsTypeStr(it->second)); | ||||
| } | } | ||||
| } | } | ||||
| if (!py::isinstance<tensor::Tensor>(obj) && !py::isinstance<py::int_>(obj) && !py::isinstance<py::float_>(obj)) { | if (!py::isinstance<tensor::Tensor>(obj) && !py::isinstance<py::int_>(obj) && !py::isinstance<py::float_>(obj)) { | ||||
| MS_EXCEPTION(TypeError) << "For '" << prim->name() << "', the " << i << "th input is a not support type: " | |||||
| MS_EXCEPTION(TypeError) << "For '" << prim->name() << "', the " << i | |||||
| << "th input is a not support implicit conversion type: " | |||||
| << py::cast<std::string>(obj.attr("__class__").attr("__name__")) << ", and the value is " | << py::cast<std::string>(obj.attr("__class__").attr("__name__")) << ", and the value is " | ||||
| << py::cast<py::str>(obj) << "."; | << py::cast<py::str>(obj) << "."; | ||||
| } | } | ||||
| @@ -239,7 +239,8 @@ class Tensor(Tensor_): | |||||
| Check all array elements along a given axis evaluate to True. | Check all array elements along a given axis evaluate to True. | ||||
| Args: | Args: | ||||
| axis (Union[None, int, tuple(int)): Dimensions of reduction. | |||||
| axis (Union[None, int, tuple(int)): Dimensions of reduction, | |||||
| when axis is None or empty tuple, reduce all dimensions. | |||||
| Default: (), reduce all dimensions. | Default: (), reduce all dimensions. | ||||
| keep_dims (bool): Whether to keep the reduced dimensions. | keep_dims (bool): Whether to keep the reduced dimensions. | ||||
| Default : False, don't keep these reduced dimensions. | Default : False, don't keep these reduced dimensions. | ||||
| @@ -257,7 +258,8 @@ class Tensor(Tensor_): | |||||
| Check any array element along a given axis evaluate to True. | Check any array element along a given axis evaluate to True. | ||||
| Args: | Args: | ||||
| axis (Union[None, int, tuple(int)): Dimensions of reduction. | |||||
| axis (Union[None, int, tuple(int)): Dimensions of reduction, | |||||
| when axis is None or empty tuple, reduce all dimensions. | |||||
| Default: (), reduce all dimensions. | Default: (), reduce all dimensions. | ||||
| keep_dims (bool): Whether to keep the reduced dimensions. | keep_dims (bool): Whether to keep the reduced dimensions. | ||||
| Default : False, don't keep these reduced dimensions. | Default : False, don't keep these reduced dimensions. | ||||
| @@ -338,21 +338,21 @@ TypePtr FunctionStrToType(const std::string &type_name) { | |||||
| TypePtr StringToType(const std::string &type_name) { | TypePtr StringToType(const std::string &type_name) { | ||||
| TypePtr type = nullptr; | TypePtr type = nullptr; | ||||
| if (type_name.compare("None") == 0) { | |||||
| if (type_name == "None") { | |||||
| type = std::make_shared<TypeNone>(); | type = std::make_shared<TypeNone>(); | ||||
| } else if (type_name.compare("Ellipsis") == 0) { | |||||
| } else if (type_name == "Ellipsis") { | |||||
| type = std::make_shared<TypeEllipsis>(); | type = std::make_shared<TypeEllipsis>(); | ||||
| } else if (type_name.compare("TypeType") == 0) { | |||||
| } else if (type_name == "TypeType") { | |||||
| type = std::make_shared<TypeType>(); | type = std::make_shared<TypeType>(); | ||||
| } else if (type_name.compare("SymbolicKeyType") == 0) { | |||||
| } else if (type_name == "SymbolicKeyType") { | |||||
| type = std::make_shared<SymbolicKeyType>(); | type = std::make_shared<SymbolicKeyType>(); | ||||
| } else if (type_name.compare("RefKeyType") == 0) { | |||||
| } else if (type_name == "RefKeyType") { | |||||
| type = std::make_shared<RefKeyType>(); | type = std::make_shared<RefKeyType>(); | ||||
| } else if (type_name.compare("EnvType") == 0) { | |||||
| } else if (type_name == "EnvType") { | |||||
| type = std::make_shared<EnvType>(); | type = std::make_shared<EnvType>(); | ||||
| } else if (type_name.compare("Number") == 0) { | |||||
| } else if (type_name == "Number") { | |||||
| type = std::make_shared<Number>(); | type = std::make_shared<Number>(); | ||||
| } else if (type_name.compare("Bool") == 0) { | |||||
| } else if (type_name == "Bool") { | |||||
| type = std::make_shared<Bool>(); | type = std::make_shared<Bool>(); | ||||
| } else if (type_name.compare(0, strlen("Int"), "Int") == 0) { | } else if (type_name.compare(0, strlen("Int"), "Int") == 0) { | ||||
| type = StringToNumberType<Int>(type_name, "Int"); | type = StringToNumberType<Int>(type_name, "Int"); | ||||
| @@ -372,16 +372,18 @@ TypePtr StringToType(const std::string &type_name) { | |||||
| type = ListStrToType(type_name); | type = ListStrToType(type_name); | ||||
| } else if (type_name.compare(0, strlen("Tuple"), "Tuple") == 0) { | } else if (type_name.compare(0, strlen("Tuple"), "Tuple") == 0) { | ||||
| type = TupleStrToType(type_name); | type = TupleStrToType(type_name); | ||||
| } else if (type_name.compare("Slice") == 0) { | |||||
| } else if (type_name == "Slice") { | |||||
| type = std::make_shared<Slice>(); | type = std::make_shared<Slice>(); | ||||
| } else if (type_name.compare("Dictionary") == 0) { | |||||
| } else if (type_name == "Dictionary") { | |||||
| type = std::make_shared<Dictionary>(); | type = std::make_shared<Dictionary>(); | ||||
| } else if (type_name.compare("String") == 0) { | |||||
| } else if (type_name == "String") { | |||||
| type = std::make_shared<String>(); | type = std::make_shared<String>(); | ||||
| } else if (type_name.compare("Problem") == 0) { | |||||
| } else if (type_name == "Problem") { | |||||
| type = std::make_shared<Problem>(); | type = std::make_shared<Problem>(); | ||||
| } else if (type_name.compare(0, strlen("Function"), "Function") == 0) { | } else if (type_name.compare(0, strlen("Function"), "Function") == 0) { | ||||
| type = FunctionStrToType(type_name); | type = FunctionStrToType(type_name); | ||||
| } else if (type_name == "mstype") { | |||||
| type = std::make_shared<TypeType>(); | |||||
| } else { | } else { | ||||
| // - unsupported to convert | // - unsupported to convert | ||||
| // Class | // Class | ||||
| @@ -389,7 +391,6 @@ TypePtr StringToType(const std::string &type_name) { | |||||
| // JTagged | // JTagged | ||||
| // Anything | // Anything | ||||
| // External | // External | ||||
| // Problem | |||||
| MS_LOG(EXCEPTION) << "Unsupported type name: " << type_name << "!"; | MS_LOG(EXCEPTION) << "Unsupported type name: " << type_name << "!"; | ||||
| } | } | ||||
| return type; | return type; | ||||
| @@ -403,10 +404,7 @@ bool IsParentOrChildrenType(TypePtr const &x, TypePtr const &base_type) { | |||||
| if (base_type->type_id() == kTypeUnknown || x->type_id() == kTypeUnknown) { | if (base_type->type_id() == kTypeUnknown || x->type_id() == kTypeUnknown) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| if (base_type->type_id() == x->parent_type() || x->type_id() == base_type->parent_type()) { | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| return base_type->type_id() == x->parent_type() || x->type_id() == base_type->parent_type(); | |||||
| } | } | ||||
| bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type) { | bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type) { | ||||
| @@ -206,8 +206,8 @@ void Cloner::InlineCloneParameters(const FuncGraphPtr &func_graph, const AnfNode | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| auto &old_params = func_graph->parameters(); | auto &old_params = func_graph->parameters(); | ||||
| if (old_params.size() != params.size()) { | if (old_params.size() != params.size()) { | ||||
| MS_LOG(EXCEPTION) << "Origin params size[" << old_params.size() << "], inline params size[" << params.size() << "]"; | |||||
| return; | |||||
| MS_EXCEPTION(TypeError) << "Origin params size[" << old_params.size() << "], inline params size[" << params.size() | |||||
| << "]"; | |||||
| } | } | ||||
| for (size_t i = 0; i < old_params.size(); ++i) { | for (size_t i = 0; i < old_params.size(); ++i) { | ||||
| repl_node_[old_params[i]] = params[i]; | repl_node_[old_params[i]] = params[i]; | ||||
| @@ -762,3 +762,10 @@ def get_stride_info_from_tuple(data_shape, index_tuple): | |||||
| end_strides.append(data_shape[item]) | end_strides.append(data_shape[item]) | ||||
| step_strides.append(1) | step_strides.append(1) | ||||
| return tuple(begin_strides), tuple(end_strides), tuple(step_strides), shrink_axis | return tuple(begin_strides), tuple(end_strides), tuple(step_strides), shrink_axis | ||||
| @constexpr | |||||
| def mstype_eq(x, y): | |||||
| if x == y: | |||||
| return True | |||||
| return False | |||||
| @@ -14,7 +14,7 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """equal_impl""" | """equal_impl""" | ||||
| from . import _constexpr_utils as const_utils | |||||
| from ...composite import base | from ...composite import base | ||||
| from ... import functional as F | from ... import functional as F | ||||
| @@ -32,8 +32,8 @@ def _equal_scalar(x, y): | |||||
| Determine if two numbers are equal. | Determine if two numbers are equal. | ||||
| Args: | Args: | ||||
| x (Number): x | |||||
| y (NUmber): y | |||||
| x (Number): first input number. | |||||
| y (NUmber): second input number. | |||||
| Returns: | Returns: | ||||
| bool, if x == y return true, x != y return false. | bool, if x == y return true, x != y return false. | ||||
| @@ -41,14 +41,29 @@ def _equal_scalar(x, y): | |||||
| return F.scalar_eq(x, y) | return F.scalar_eq(x, y) | ||||
| @equal.register("mstype", "mstype") | |||||
| def _equal_mstype(x, y): | |||||
| """ | |||||
| Determine if two mindspore types are equal. | |||||
| Args: | |||||
| x (mstype): first input mindspore type. | |||||
| y (mstype): second input mindspore type. | |||||
| Returns: | |||||
| bool, if x == y return true, x != y return false. | |||||
| """ | |||||
| return const_utils.mstype_eq(x, y) | |||||
| @equal.register("String", "String") | @equal.register("String", "String") | ||||
| def _equal_string(x, y): | def _equal_string(x, y): | ||||
| """ | """ | ||||
| Determine if two strings are equal. | Determine if two strings are equal. | ||||
| Args: | Args: | ||||
| x: str | |||||
| y: str | |||||
| x (str): first input string. | |||||
| y (str): second input string. | |||||
| Returns: | Returns: | ||||
| bool, if x == y return true, x != y return false. | bool, if x == y return true, x != y return false. | ||||
| @@ -62,8 +77,8 @@ def _string_equal_none(x, y): | |||||
| Determine if string equals none. | Determine if string equals none. | ||||
| Args: | Args: | ||||
| x: str. | |||||
| y: None. | |||||
| x (str): first input string. | |||||
| y (None) second input None. | |||||
| Returns: | Returns: | ||||
| bool, return false. | bool, return false. | ||||
| @@ -77,8 +92,8 @@ def _none_equal_string(x, y): | |||||
| Determine if string equals none. | Determine if string equals none. | ||||
| Args: | Args: | ||||
| x: None. | |||||
| y: str. | |||||
| x (None): first input None. | |||||
| y (str): second input string. | |||||
| Returns: | Returns: | ||||
| bool, return false. | bool, return false. | ||||
| @@ -92,8 +107,8 @@ def _none_equal_none(x, y): | |||||
| Determine if none equals none. | Determine if none equals none. | ||||
| Args: | Args: | ||||
| x: None. | |||||
| y: None. | |||||
| x (None): first input None. | |||||
| y (None): second input None. | |||||
| Returns: | Returns: | ||||
| bool, return true. | bool, return true. | ||||
| @@ -107,8 +122,8 @@ def _scalar_equal_none(x, y): | |||||
| Determine if number equals none. | Determine if number equals none. | ||||
| Args: | Args: | ||||
| x: Number. | |||||
| y: None. | |||||
| x (Number): first input number. | |||||
| y (None): second input None. | |||||
| Returns: | Returns: | ||||
| bool, return false. | bool, return false. | ||||
| @@ -122,8 +137,8 @@ def _none_equal_scalar(x, y): | |||||
| Determine if number equals none. | Determine if number equals none. | ||||
| Args: | Args: | ||||
| x: None. | |||||
| y: NUmber. | |||||
| x (None): first input None. | |||||
| y (Number): second input Number. | |||||
| Returns: | Returns: | ||||
| bool, return false. | bool, return false. | ||||
| @@ -137,8 +152,8 @@ def _euqal_tuple(x, y): | |||||
| Determine if two tuples are equal by element. | Determine if two tuples are equal by element. | ||||
| Args: | Args: | ||||
| x (tuple): x | |||||
| y (tuple): y | |||||
| x (tuple): first input tuple. | |||||
| y (tuple): second input tuple. | |||||
| Returns: | Returns: | ||||
| bool, if x and y are equal by element return true, else return false. | bool, if x and y are equal by element return true, else return false. | ||||
| @@ -152,8 +167,8 @@ def _euqal_list(x, y): | |||||
| Determine if two lists are equal by element. | Determine if two lists are equal by element. | ||||
| Args: | Args: | ||||
| x (list): x | |||||
| y (list): y | |||||
| x (list): first input list. | |||||
| y (list): second input list. | |||||
| Returns: | Returns: | ||||
| bool, if x and y are equal by element return true, else return false. | bool, if x and y are equal by element return true, else return false. | ||||
| @@ -167,8 +182,8 @@ def _tuple_euqal_none(x, y): | |||||
| Determine if tuple element equals none element. | Determine if tuple element equals none element. | ||||
| Args: | Args: | ||||
| x: Tuple. | |||||
| y: None. | |||||
| x(tuple): first input tuple. | |||||
| y (None): second input None. | |||||
| Returns: | Returns: | ||||
| bool, return false. | bool, return false. | ||||
| @@ -182,8 +197,8 @@ def _none_equal_tuple(x, y): | |||||
| Determine if tuple element equals none element. | Determine if tuple element equals none element. | ||||
| Args: | Args: | ||||
| x: None. | |||||
| y: Tuple. | |||||
| x (None): first input None. | |||||
| y (tuple): second input tuple. | |||||
| Returns: | Returns: | ||||
| bool, return false. | bool, return false. | ||||
| @@ -199,8 +214,8 @@ def _tensor_equal_tensor(x, y): | |||||
| Determine if two tensors are equal. | Determine if two tensors are equal. | ||||
| Args: | Args: | ||||
| x : Tensor. | |||||
| y : Tensor. | |||||
| x (Tensor): first input tensor. | |||||
| y (Tensor): second input tensor. | |||||
| Returns: | Returns: | ||||
| bool, if x == y return true, x != y return false. | bool, if x == y return true, x != y return false. | ||||
| @@ -214,8 +229,8 @@ def _tensor_equal_none(x, y): | |||||
| Determine if tensor equal none. | Determine if tensor equal none. | ||||
| Args: | Args: | ||||
| x : Tensor. | |||||
| y : None. | |||||
| x (Tensor): first input tensor. | |||||
| y (None): second input None. | |||||
| Returns: | Returns: | ||||
| bool, return false. | bool, return false. | ||||
| @@ -229,8 +244,8 @@ def _none_equal_tensor(x, y): | |||||
| Determine if tensor equal none. | Determine if tensor equal none. | ||||
| Args: | Args: | ||||
| x : None. | |||||
| y : Tensor. | |||||
| x (None): first input None. | |||||
| y (Tensor): second input tensor. | |||||
| Returns: | Returns: | ||||
| bool, return false. | bool, return false. | ||||
| @@ -245,7 +260,7 @@ def _list_equal_none(x, y): | |||||
| Args: | Args: | ||||
| x (list): The first input which is a list. | x (list): The first input which is a list. | ||||
| y (none): The second input which is none. | |||||
| y (None): The second input which is none. | |||||
| Returns: | Returns: | ||||
| bool, return false. | bool, return false. | ||||
| @@ -259,7 +274,7 @@ def _none_equal_list(x, y): | |||||
| Determine if none equal list. | Determine if none equal list. | ||||
| Args: | Args: | ||||
| x (none): The first input which is none. | |||||
| x (None): The first input which is none. | |||||
| y (list): The second input which is a list. | y (list): The second input which is a list. | ||||
| Returns: | Returns: | ||||
| @@ -49,6 +49,7 @@ def test_bool_and_int_tensor_add(): | |||||
| ret_expect = Tensor(np.array([[2, 3, 4], [5, 6, 7]], dtype=np.int32)) | ret_expect = Tensor(np.array([[2, 3, 4], [5, 6, 7]], dtype=np.int32)) | ||||
| assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all() | assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all() | ||||
| def test_float_tensor_and_int_tensor_add(): | def test_float_tensor_and_int_tensor_add(): | ||||
| x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32)) | x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32)) | ||||
| y = Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)) | y = Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)) | ||||
| @@ -0,0 +1,65 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ test enumerate""" | |||||
| import numpy as np | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore import context | |||||
| from mindspore.common import dtype as mstype | |||||
| context.set_context(mode=context.GRAPH_MODE) | |||||
| def test_equal_two_const_mstype(): | |||||
| class Net(nn.Cell): | |||||
| def __init__(self): | |||||
| super(Net, self).__init__() | |||||
| self.type_base = mstype.float32 | |||||
| self.type_0 = mstype.float32 | |||||
| self.type_1 = mstype.float16 | |||||
| self.type_2 = mstype.int32 | |||||
| self.type_3 = mstype.tuple_ | |||||
| def construct(self): | |||||
| ret_0 = self.type_0 == self.type_base | |||||
| ret_1 = self.type_1 == self.type_base | |||||
| ret_2 = self.type_2 == self.type_base | |||||
| ret_3 = self.type_3 == self.type_base | |||||
| return ret_0, ret_1, ret_2, ret_3 | |||||
| net = Net() | |||||
| assert net() == (True, False, False, False) | |||||
| def test_equal_two_tensor_mstype(): | |||||
| class Net(nn.Cell): | |||||
| def __init__(self): | |||||
| super(Net, self).__init__() | |||||
| def construct(self, x, y, z): | |||||
| ret_x = x.dtype == mstype.float32 | |||||
| ret_y = y.dtype == mstype.int32 | |||||
| ret_z = z.dtype == mstype.bool_ | |||||
| ret_xy = x.dtype == y.dtype | |||||
| ret_xz = x.dtype == z.dtype | |||||
| ret_yz = y.dtype == z.dtype | |||||
| return ret_x, ret_y, ret_z, ret_xy, ret_xz, ret_yz | |||||
| net = Net() | |||||
| x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.float32) | |||||
| y = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.int32) | |||||
| z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.bool_) | |||||
| assert net(x, y, z) == (True, True, True, False, False, False) | |||||
| @@ -96,7 +96,7 @@ def test_float_tensor_and_str_add(): | |||||
| y = "ok" | y = "ok" | ||||
| with pytest.raises(TypeError) as er: | with pytest.raises(TypeError) as er: | ||||
| ret = x + y | ret = x + y | ||||
| assert "For 'TensorAdd', the 1th input is a not support type: str" in str(er.value) | |||||
| assert "For 'TensorAdd', the 1th input is a not support implicit conversion type: str" in str(er.value) | |||||
| def test_float_tensor_and_tuple_add(): | def test_float_tensor_and_tuple_add(): | ||||
| @@ -104,7 +104,7 @@ def test_float_tensor_and_tuple_add(): | |||||
| y = (1, 2, 3) | y = (1, 2, 3) | ||||
| with pytest.raises(TypeError) as er: | with pytest.raises(TypeError) as er: | ||||
| ret = x + y | ret = x + y | ||||
| assert "For 'TensorAdd', the 1th input is a not support type: tuple" in str(er.value) | |||||
| assert "For 'TensorAdd', the 1th input is a not support implicit conversion type: tuple" in str(er.value) | |||||
| def test_float_tensor_and_list_add(): | def test_float_tensor_and_list_add(): | ||||
| @@ -112,7 +112,7 @@ def test_float_tensor_and_list_add(): | |||||
| y = [1, 2, 3] | y = [1, 2, 3] | ||||
| with pytest.raises(TypeError) as er: | with pytest.raises(TypeError) as er: | ||||
| ret = x + y | ret = x + y | ||||
| assert "For 'TensorAdd', the 1th input is a not support type: list" in str(er.value) | |||||
| assert "For 'TensorAdd', the 1th input is a not support implicit conversion type: list" in str(er.value) | |||||
| def test_float_tensor_and_bool_tensors_add_grad(): | def test_float_tensor_and_bool_tensors_add_grad(): | ||||