Merge pull request !4963 from fary86/fix_switch_layer_join_bugtags/v1.0.0
| @@ -283,9 +283,99 @@ bool ConvertOtherObj(py::object obj, ValuePtr *const data) { | |||
| MS_LOG(ERROR) << "Resolve type is invalid " << ((std::string)py::str(obj)); | |||
| return false; | |||
| } | |||
| bool ConvertIntegerWithType(const int &obj, ValuePtr *const data, TypePtr dtype = nullptr) { | |||
| if (dtype == nullptr) { | |||
| *data = std::make_shared<Int32Imm>(obj); | |||
| return true; | |||
| } | |||
| auto int_dypte = dyn_cast<Int>(dtype); | |||
| if (int_dypte != nullptr) { | |||
| switch (int_dypte->nbits()) { | |||
| case 8: | |||
| *data = std::make_shared<Int8Imm>(static_cast<int8_t>(obj)); | |||
| break; | |||
| case 16: | |||
| *data = std::make_shared<Int16Imm>(obj); | |||
| break; | |||
| case 32: | |||
| *data = std::make_shared<Int32Imm>(obj); | |||
| break; | |||
| case 64: | |||
| *data = std::make_shared<Int64Imm>(obj); | |||
| break; | |||
| default: | |||
| *data = std::make_shared<Int32Imm>(obj); | |||
| } | |||
| return true; | |||
| } | |||
| auto uint_dypte = dyn_cast<UInt>(dtype); | |||
| if (int_dypte != nullptr) { | |||
| switch (uint_dypte->nbits()) { | |||
| case 8: | |||
| *data = std::make_shared<UInt8Imm>(obj); | |||
| break; | |||
| case 16: | |||
| *data = std::make_shared<UInt16Imm>(obj); | |||
| break; | |||
| case 32: | |||
| *data = std::make_shared<UInt32Imm>(obj); | |||
| break; | |||
| case 64: | |||
| *data = std::make_shared<UInt64Imm>(obj); | |||
| break; | |||
| default: | |||
| *data = std::make_shared<UInt32Imm>(obj); | |||
| } | |||
| return true; | |||
| } | |||
| auto float_dypte = dyn_cast<Float>(dtype); | |||
| if (float_dypte != nullptr) { | |||
| switch (float_dypte->nbits()) { | |||
| case 32: | |||
| *data = std::make_shared<FP32Imm>(obj); | |||
| break; | |||
| case 64: | |||
| *data = std::make_shared<FP64Imm>(obj); | |||
| break; | |||
| default: | |||
| *data = std::make_shared<FP32Imm>(obj); | |||
| } | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| bool ConvertFloatWithType(const float &obj, ValuePtr *const data, TypePtr dtype = nullptr) { | |||
| if (dtype == nullptr) { | |||
| *data = std::make_shared<FP32Imm>(obj); | |||
| return true; | |||
| } | |||
| auto float_dypte = dyn_cast<Float>(dtype); | |||
| if (float_dypte == nullptr) { | |||
| return false; | |||
| } | |||
| switch (float_dypte->nbits()) { | |||
| case 32: | |||
| *data = std::make_shared<FP32Imm>(obj); | |||
| break; | |||
| case 64: | |||
| *data = std::make_shared<FP64Imm>(obj); | |||
| break; | |||
| default: | |||
| *data = std::make_shared<FP32Imm>(obj); | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace | |||
| bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature) { | |||
| bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature, TypePtr dtype) { | |||
| // check parameter valid | |||
| if (data == nullptr) { | |||
| MS_LOG(ERROR) << "Data is null pointer"; | |||
| @@ -299,9 +389,9 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature | |||
| } else if (py::isinstance<py::bool_>(obj)) { | |||
| converted = std::make_shared<BoolImm>(py::cast<bool>(obj)); | |||
| } else if (py::isinstance<py::int_>(obj)) { | |||
| converted = std::make_shared<Int32Imm>(py::cast<int>(obj)); | |||
| ret = ConvertIntegerWithType(py::cast<int>(obj), &converted, dtype); | |||
| } else if (py::isinstance<py::float_>(obj)) { | |||
| converted = std::make_shared<FP32Imm>(py::cast<float>(obj)); | |||
| ret = ConvertFloatWithType(py::cast<float>(obj), &converted, dtype); | |||
| } else if (py::isinstance<py::str>(obj)) { | |||
| converted = std::make_shared<StringImm>(py::cast<std::string>(obj)); | |||
| } else if (py::isinstance<py::dict>(obj)) { | |||
| @@ -139,7 +139,7 @@ enum ClassInstanceTypeDef { | |||
| }; | |||
| // Convert python object to ValuePtr | |||
| bool ConvertData(const py::object &obj, ValuePtr *data, bool use_signature = false); | |||
| bool ConvertData(const py::object &obj, ValuePtr *data, bool use_signature = false, TypePtr dtype = nullptr); | |||
| // Convert python obj to graph | |||
| FuncGraphPtr ConvertToFuncGraph(const py::object &obj, | |||
| @@ -407,9 +407,9 @@ py::tuple PreparePyInputs(const PrimitivePyPtr &prim_py, const AbstractBasePtrLi | |||
| AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dict &output) { | |||
| // Convert to AbstractValue based on type and shape | |||
| auto out_dtype = output["dtype"]; | |||
| if (output["value"].is_none()) { | |||
| auto out_shape = output["shape"]; | |||
| auto out_dtype = output["dtype"]; | |||
| py::object min_shape = output.contains("min_shape") ? (py::object)output["min_shape"] : (py::object)py::none(); | |||
| py::object max_shape = output.contains("max_shape") ? (py::object)output["max_shape"] : (py::object)py::none(); | |||
| @@ -417,7 +417,8 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic | |||
| } | |||
| // Convert pyobject to Value, then to AbstractValue | |||
| ValuePtr converted_ret = nullptr; | |||
| bool converted = parse::ConvertData(output["value"], &converted_ret); | |||
| TypePtr dtype = py::isinstance<Type>(out_dtype) ? out_dtype.cast<TypePtr>() : nullptr; | |||
| bool converted = parse::ConvertData(output["value"], &converted_ret, false, dtype); | |||
| if (!converted) { | |||
| MS_LOG(EXCEPTION) << "Convert data failed"; | |||
| } | |||
| @@ -45,14 +45,34 @@ py::object ValuePtrToPyData(const ValuePtr &value) { | |||
| MS_LOG(EXCEPTION) << "value is null"; | |||
| } | |||
| py::object ret; | |||
| if (value->isa<Int32Imm>()) { | |||
| MS_LOG(DEBUG) << "int"; | |||
| if (value->isa<Int8Imm>()) { | |||
| MS_LOG(DEBUG) << "int8"; | |||
| py::int_ v = value->cast<Int8ImmPtr>()->value(); | |||
| ret = v; | |||
| } else if (value->isa<Int16Imm>()) { | |||
| MS_LOG(DEBUG) << "int16"; | |||
| py::int_ v = value->cast<Int16ImmPtr>()->value(); | |||
| ret = v; | |||
| } else if (value->isa<Int32Imm>()) { | |||
| MS_LOG(DEBUG) << "int32"; | |||
| py::int_ v = value->cast<Int32ImmPtr>()->value(); | |||
| ret = v; | |||
| } else if (value->isa<Int64Imm>()) { | |||
| MS_LOG(DEBUG) << "int64"; | |||
| py::int_ v = value->cast<Int64ImmPtr>()->value(); | |||
| ret = v; | |||
| } else if (value->isa<UInt8Imm>()) { | |||
| MS_LOG(DEBUG) << "uint8"; | |||
| py::int_ v = value->cast<UInt8ImmPtr>()->value(); | |||
| ret = v; | |||
| } else if (value->isa<UInt16Imm>()) { | |||
| MS_LOG(DEBUG) << "uint16"; | |||
| py::int_ v = value->cast<UInt16ImmPtr>()->value(); | |||
| ret = v; | |||
| } else if (value->isa<UInt32Imm>()) { | |||
| MS_LOG(DEBUG) << "uint32"; | |||
| py::int_ v = value->cast<UInt32ImmPtr>()->value(); | |||
| ret = v; | |||
| } else if (value->isa<UInt64Imm>()) { | |||
| MS_LOG(DEBUG) << "uint64"; | |||
| py::int_ v = value->cast<UInt64ImmPtr>()->value(); | |||
| @@ -97,8 +97,12 @@ AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) { | |||
| } | |||
| auto value_self = GetValueTrack(); | |||
| MS_EXCEPTION_IF_NULL(value_self); | |||
| ValuePtr res_value = ValueJoin(value_self, other->GetValueTrack()); | |||
| TypePtr res_type = TypeJoin(GetTypeTrack(), other->GetTypeTrack()); | |||
| if (res_type == kAnyType) { | |||
| MS_EXCEPTION(TypeError) << "Type join failed, type1 = " << GetTypeTrack()->ToString() | |||
| << ", type2 = " << other->GetTypeTrack()->ToString(); | |||
| } | |||
| ValuePtr res_value = ValueJoin(value_self, other->GetValueTrack()); | |||
| if (res_value == value_self) { | |||
| return shared_from_base<AbstractBase>(); | |||
| } | |||
| @@ -50,9 +50,17 @@ ShapePtr ShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2) { | |||
| if (*shape1 == *shape2) { | |||
| return shape1; | |||
| } | |||
| // lengths of two shapes are not same, join failed | |||
| if (shape1->shape().size() != shape2->shape().size()) { | |||
| MS_LOG(WARNING) << "Unsupported shape join. shape1 = " << shape1->ToString() << ", shape2 = " << shape2->ToString(); | |||
| return shape1; | |||
| // special case: shape(1), shape() -> shape(1) | |||
| if (shape1->shape().size() == 1 && shape1->shape()[0] == 1 && shape2->shape().size() == 0) { | |||
| return shape1; | |||
| } | |||
| if (shape2->shape().size() == 1 && shape2->shape()[0] == 1 && shape1->shape().size() == 0) { | |||
| return shape2; | |||
| } | |||
| MS_EXCEPTION(ValueError) << "Unsupported shape join. shape1 = " << shape1->ToString() | |||
| << ", shape2 = " << shape2->ToString(); | |||
| } | |||
| std::vector<int> dims; | |||
| bool has_dynamic_shape = false; | |||
| @@ -105,7 +105,7 @@ class Int8Imm : public IntergerImm { | |||
| std::string DumpText() const override { | |||
| std::ostringstream oss; | |||
| oss << "I8(" << v_ << ")"; | |||
| oss << "I8(" << int(v_) << ")"; | |||
| return oss.str(); | |||
| } | |||
| @@ -131,7 +131,7 @@ class Int16Imm : public IntergerImm { | |||
| std::string DumpText() const override { | |||
| std::ostringstream oss; | |||
| oss << "I16(" << v_ << ")"; | |||
| oss << "I16(" << int(v_) << ")"; | |||
| return oss.str(); | |||
| } | |||
| @@ -157,7 +157,7 @@ class Int32Imm : public IntergerImm { | |||
| std::string DumpText() const override { | |||
| std::ostringstream oss; | |||
| oss << "I32(" << v_ << ")"; | |||
| oss << "I32(" << int(v_) << ")"; | |||
| return oss.str(); | |||
| } | |||
| @@ -211,7 +211,7 @@ class UInt8Imm : public IntergerImm { | |||
| std::string DumpText() const override { | |||
| std::ostringstream oss; | |||
| oss << "U8(" << v_ << ")"; | |||
| oss << "U8(" << unsigned(v_) << ")"; | |||
| return oss.str(); | |||
| } | |||
| @@ -239,7 +239,7 @@ class UInt16Imm : public IntergerImm { | |||
| std::string DumpText() const override { | |||
| std::ostringstream oss; | |||
| oss << "U16(" << v_ << ")"; | |||
| oss << "U16(" << unsigned(v_) << ")"; | |||
| return oss.str(); | |||
| } | |||
| @@ -267,7 +267,7 @@ class UInt32Imm : public IntergerImm { | |||
| std::string DumpText() const override { | |||
| std::ostringstream oss; | |||
| oss << "U32(" << v_ << ")"; | |||
| oss << "U32(" << unsigned(v_) << ")"; | |||
| return oss.str(); | |||
| } | |||
| @@ -324,7 +324,7 @@ class ScalarGradChecker(_GradChecker): | |||
| self.input_selector = [i for i in range(self.nin)] | |||
| def get_sens(self, i): | |||
| return 1 | |||
| return 1.0 | |||
| def check_against_numeric(self, out_index): | |||
| args = list(self.args) | |||
| @@ -916,3 +916,73 @@ def test_recursive_call(): | |||
| with pytest.raises(RuntimeError): | |||
| net(input_data) | |||
| context.set_context(max_call_depth=old_max_call_depth) | |||
| def test_switch_layer_shape_join_failed(): | |||
| class AddFuncNet(nn.Cell): | |||
| def __init__(self, funcs, new_func): | |||
| super(AddFuncNet, self).__init__() | |||
| self.funcs = funcs | |||
| self.new_func = new_func | |||
| def construct(self, i, inputs): | |||
| final_funcs = self.funcs + (self.new_func,) | |||
| x = final_funcs[i](inputs) | |||
| return x | |||
| class ReLUTuple(nn.Cell): | |||
| def __init__(self): | |||
| super(ReLUTuple, self).__init__() | |||
| self.op = nn.ReLU() | |||
| def construct(self, x): | |||
| return self.op(x[0]) | |||
| func1 = nn.Softmax() | |||
| func2 = nn.ReLU() | |||
| func3 = ReLUTuple() | |||
| funcs = (func1, func2) | |||
| net = AddFuncNet(funcs, func3) | |||
| inp = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32)) | |||
| i = Tensor(1, mstype.int32) | |||
| with pytest.raises(ValueError) as err: | |||
| net(i, inp) | |||
| def test_switch_layer_dtype_join_failed(): | |||
| class Cast(nn.Cell): | |||
| def __init__(self, dtype): | |||
| super(Cast, self).__init__() | |||
| self.op = P.Cast() | |||
| self.dtype = dtype | |||
| def construct(self, x): | |||
| y = self.op(x, self.dtype) | |||
| return y + y | |||
| class SwitchNegNet(nn.Cell): | |||
| def __init__(self, funcs): | |||
| super(SwitchNegNet, self).__init__() | |||
| self.funcs = funcs | |||
| self.op = P.Neg() | |||
| def construct(self, i, inputs): | |||
| x = self.funcs[i](inputs) | |||
| x = self.op(x) | |||
| return x | |||
| func1 = nn.ReLU() | |||
| func2 = Cast(mstype.int32) | |||
| funcs = (func1, func2) | |||
| net = SwitchNegNet(funcs) | |||
| inp = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32)) | |||
| i = Tensor(0, mstype.int32) | |||
| with pytest.raises(TypeError) as err: | |||
| net(i, inp) | |||
| @@ -33,6 +33,7 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \ | |||
| pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception) | |||
| from ....mindspore_test_framework.pipeline.gradient.compile_gradient \ | |||
| import pipeline_for_compile_grad_ge_graph_for_case_by_case_config | |||
| from ....ops_common import convert | |||
| grad_all_with_sens = C.GradOperation('grad_all_with_sens', get_all=True, sens_param=True) | |||
| @@ -1703,7 +1704,7 @@ test_case_nn_ops = [ | |||
| ('ResizeBilinear', { | |||
| 'block': P.ResizeBilinear((5, 5)), | |||
| 'desc_inputs': [Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mstype.float16)], | |||
| 'desc_bprop': [Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mstype.float16)]}), | |||
| 'desc_bprop': [Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mstype.float32)]}), | |||
| ('ResizeBilinearGrad', { | |||
| 'block': G.ResizeBilinearGrad(), | |||
| 'desc_inputs': [Tensor([[[[1, 2, 3, 4, 5]]]], mstype.float32), Tensor([[[[1, 2, 3, 4, 5]]]], mstype.float32)], | |||
| @@ -1712,7 +1713,7 @@ test_case_nn_ops = [ | |||
| ('ROIAlign', { | |||
| 'block': P.ROIAlign(7, 7, 0.03125, 2), | |||
| 'desc_inputs': [[2, 256, 192, 320], [1024, 5]], | |||
| 'desc_bprop': [[7, 7]]}), | |||
| 'desc_bprop': [[1024, 256, 7, 7]]}), | |||
| ('ROIAlignGrad', { | |||
| 'block': G.ROIAlignGrad((1, 1, 1, 1), 2, 2, 0.5, 2), | |||
| 'desc_inputs': [[1, 1, 2, 2], [1, 5]], | |||
| @@ -2315,7 +2316,7 @@ test_case_other_ops = [ | |||
| ('IOU', { | |||
| 'block': P.IOU(), | |||
| 'desc_inputs': [Tensor(np.ones((256, 4), np.float16)), Tensor(np.ones((128, 4), np.float16))], | |||
| 'desc_bprop': [[128, 256]]}), | |||
| 'desc_bprop': [convert([128, 256], np.float16)]}), | |||
| ('Summary', { | |||
| 'block': SummaryNet(), | |||
| 'desc_inputs': [Tensor(np.array([1.1]).astype(np.float32)), | |||
| @@ -118,29 +118,29 @@ test_case_reid_ops = [ | |||
| 'desc_inputs': [[256, 8]], | |||
| 'desc_bprop': [[256, 8]]}), | |||
| ('Pow', { | |||
| 'block': P.Pow(), # 输入有标量插件产生了段错误。 | |||
| 'block': P.Pow(), | |||
| 'desc_const': [2.0], | |||
| 'desc_inputs': [[1, 512]], | |||
| 'desc_bprop': [[1, 512]]}), | |||
| ('LogicalNot', { | |||
| 'block': P.LogicalNot(), | |||
| 'desc_inputs': [convert([256], np.bool_)], | |||
| 'desc_bprop': [[256]]}), # 自定义算子 input bool没转换,gongchen提单。 | |||
| 'desc_bprop': [convert([256], np.bool_)]}), | |||
| ('Equal', { | |||
| 'block': P.Equal(), | |||
| 'desc_inputs': [convert([256], np.float16), convert([256], np.float16)], | |||
| 'desc_bprop': [[256]]}), | |||
| 'desc_bprop': [convert([256], np.bool_)]}), | |||
| ('Greater', { | |||
| 'block': P.Greater(), | |||
| 'desc_inputs': [convert([256], np.float16), convert([256], np.float16)], | |||
| 'desc_bprop': [[256]]}), | |||
| 'desc_bprop': [convert([256], np.bool_)]}), | |||
| ('Dropout', { | |||
| 'block': nn.Dropout(), | |||
| 'desc_inputs': [[1, 512, 7, 7]], | |||
| 'desc_bprop': [[1, 512, 7, 7]]}), # 输入有标量插件产生了段错误。 | |||
| 'desc_bprop': [[1, 512, 7, 7]]}), | |||
| ('MatMul', { | |||
| 'block': P.MatMul(), | |||
| 'desc_inputs': [[64, 512], [512, 64]], # fp16不行。很有问题。 | |||
| 'desc_inputs': [[64, 512], [512, 64]], | |||
| 'desc_bprop': [[64, 64]]}), | |||
| ('Maximum', { | |||
| 'block': P.Maximum(), | |||
| @@ -84,8 +84,8 @@ class Bprop(Cell): | |||
| self.grad = grad_op | |||
| self.with_sens = False | |||
| self.sens = sens | |||
| if sens: | |||
| self.sens = Tensor(sens, dtype=mstype.float32) | |||
| if not sens is None: | |||
| self.sens = sens if isinstance(sens, Tensor) else Tensor(sens, dtype=mstype.float32) | |||
| self.with_sens = True | |||
| def construct(self, *inputs): | |||
| @@ -115,7 +115,7 @@ def test_all_var_args_grad_with_sens(): | |||
| x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32) | |||
| y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32) | |||
| sens = Tensor(1.0, dtype=mstype.float32) | |||
| sens = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32) | |||
| net = VarNet(SecondNet()) | |||
| grad_net = GradNet(net) | |||
| _ = grad_net(x, y, sens) | |||
| @@ -167,7 +167,7 @@ def test_grad_all_var_args_with_sens(): | |||
| x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32) | |||
| y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32) | |||
| sens = Tensor(1.0, dtype=mstype.float32) | |||
| sens = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32) | |||
| net = VarNet(SecondNet()) | |||
| grad_net = GradNet(net) | |||
| _ = grad_net(x, y, sens) | |||
| @@ -185,7 +185,7 @@ def test_grad_var_args_with_sens(): | |||
| x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32) | |||
| y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32) | |||
| sens = Tensor(1.0, dtype=mstype.float32) | |||
| sens = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32) | |||
| net = VarNet(SecondNet()) | |||
| grad_net = GradNet(net) | |||
| _ = grad_net(x, y, sens) | |||
| @@ -244,7 +244,7 @@ def test_var_args_grad(): | |||
| x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32) | |||
| y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32) | |||
| sens = Tensor(1.0, dtype=mstype.float32) | |||
| sens = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32) | |||
| net = VarNet(SecondNet()) | |||
| grad_net = GradNet(net) | |||
| _ = grad_net(x, y, sens) | |||
| @@ -292,14 +292,14 @@ def test_grad_within_if_else(): | |||
| self.net = net | |||
| grad_op = C.GradOperation( | |||
| name='grad', get_all=False, get_by_list=True, sens_param=True) | |||
| self.grad = Bprop(self.net, True, self.weights, grad_op, 1.0) | |||
| sens = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32) | |||
| self.grad = Bprop(self.net, True, self.weights, grad_op, sens) | |||
| def construct(self, *inputs): | |||
| return self.grad(*inputs) | |||
| x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32) | |||
| y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32) | |||
| _ = Tensor(1.0, dtype=mstype.float32) | |||
| net = VarNet(SecondNet()) | |||
| grad_net = GradNet(net) | |||
| out = grad_net(x, y) | |||