Browse Source

!4963 fix bug of switch layer join

Merge pull request !4963 from fary86/fix_switch_layer_join_bug
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
afd16fbf0a
12 changed files with 229 additions and 35 deletions
  1. +93
    -3
      mindspore/ccsrc/pipeline/jit/parse/data_converter.cc
  2. +1
    -1
      mindspore/ccsrc/pipeline/jit/parse/parse_base.h
  3. +3
    -2
      mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
  4. +22
    -2
      mindspore/ccsrc/utils/convert_utils.cc
  5. +5
    -1
      mindspore/core/abstract/abstract_value.cc
  6. +10
    -2
      mindspore/core/abstract/utils.cc
  7. +6
    -6
      mindspore/core/ir/scalar.h
  8. +1
    -1
      tests/mindspore_test_framework/utils/check_gradient.py
  9. +70
    -0
      tests/ut/python/ops/test_control_ops.py
  10. +4
    -3
      tests/ut/python/ops/test_ops.py
  11. +6
    -6
      tests/ut/python/ops/test_ops_reid.py
  12. +8
    -8
      tests/ut/python/parameter_feature/test_var_grad.py

+ 93
- 3
mindspore/ccsrc/pipeline/jit/parse/data_converter.cc View File

@@ -283,9 +283,99 @@ bool ConvertOtherObj(py::object obj, ValuePtr *const data) {
MS_LOG(ERROR) << "Resolve type is invalid " << ((std::string)py::str(obj));
return false;
}

bool ConvertIntegerWithType(const int &obj, ValuePtr *const data, TypePtr dtype = nullptr) {
if (dtype == nullptr) {
*data = std::make_shared<Int32Imm>(obj);
return true;
}

auto int_dypte = dyn_cast<Int>(dtype);
if (int_dypte != nullptr) {
switch (int_dypte->nbits()) {
case 8:
*data = std::make_shared<Int8Imm>(static_cast<int8_t>(obj));
break;
case 16:
*data = std::make_shared<Int16Imm>(obj);
break;
case 32:
*data = std::make_shared<Int32Imm>(obj);
break;
case 64:
*data = std::make_shared<Int64Imm>(obj);
break;
default:
*data = std::make_shared<Int32Imm>(obj);
}
return true;
}

auto uint_dypte = dyn_cast<UInt>(dtype);
if (int_dypte != nullptr) {
switch (uint_dypte->nbits()) {
case 8:
*data = std::make_shared<UInt8Imm>(obj);
break;
case 16:
*data = std::make_shared<UInt16Imm>(obj);
break;
case 32:
*data = std::make_shared<UInt32Imm>(obj);
break;
case 64:
*data = std::make_shared<UInt64Imm>(obj);
break;
default:
*data = std::make_shared<UInt32Imm>(obj);
}
return true;
}

auto float_dypte = dyn_cast<Float>(dtype);
if (float_dypte != nullptr) {
switch (float_dypte->nbits()) {
case 32:
*data = std::make_shared<FP32Imm>(obj);
break;
case 64:
*data = std::make_shared<FP64Imm>(obj);
break;
default:
*data = std::make_shared<FP32Imm>(obj);
}
return true;
}

return false;
}

bool ConvertFloatWithType(const float &obj, ValuePtr *const data, TypePtr dtype = nullptr) {
if (dtype == nullptr) {
*data = std::make_shared<FP32Imm>(obj);
return true;
}

auto float_dypte = dyn_cast<Float>(dtype);
if (float_dypte == nullptr) {
return false;
}

switch (float_dypte->nbits()) {
case 32:
*data = std::make_shared<FP32Imm>(obj);
break;
case 64:
*data = std::make_shared<FP64Imm>(obj);
break;
default:
*data = std::make_shared<FP32Imm>(obj);
}
return true;
}
} // namespace

bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature) {
bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature, TypePtr dtype) {
// check parameter valid
if (data == nullptr) {
MS_LOG(ERROR) << "Data is null pointer";
@@ -299,9 +389,9 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature
} else if (py::isinstance<py::bool_>(obj)) {
converted = std::make_shared<BoolImm>(py::cast<bool>(obj));
} else if (py::isinstance<py::int_>(obj)) {
converted = std::make_shared<Int32Imm>(py::cast<int>(obj));
ret = ConvertIntegerWithType(py::cast<int>(obj), &converted, dtype);
} else if (py::isinstance<py::float_>(obj)) {
converted = std::make_shared<FP32Imm>(py::cast<float>(obj));
ret = ConvertFloatWithType(py::cast<float>(obj), &converted, dtype);
} else if (py::isinstance<py::str>(obj)) {
converted = std::make_shared<StringImm>(py::cast<std::string>(obj));
} else if (py::isinstance<py::dict>(obj)) {


+ 1
- 1
mindspore/ccsrc/pipeline/jit/parse/parse_base.h View File

@@ -139,7 +139,7 @@ enum ClassInstanceTypeDef {
};

// Convert python object to ValuePtr
bool ConvertData(const py::object &obj, ValuePtr *data, bool use_signature = false);
bool ConvertData(const py::object &obj, ValuePtr *data, bool use_signature = false, TypePtr dtype = nullptr);

// Convert python obj to graph
FuncGraphPtr ConvertToFuncGraph(const py::object &obj,


+ 3
- 2
mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc View File

@@ -407,9 +407,9 @@ py::tuple PreparePyInputs(const PrimitivePyPtr &prim_py, const AbstractBasePtrLi

AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dict &output) {
// Convert to AbstractValue based on type and shape
auto out_dtype = output["dtype"];
if (output["value"].is_none()) {
auto out_shape = output["shape"];
auto out_dtype = output["dtype"];
py::object min_shape = output.contains("min_shape") ? (py::object)output["min_shape"] : (py::object)py::none();
py::object max_shape = output.contains("max_shape") ? (py::object)output["max_shape"] : (py::object)py::none();

@@ -417,7 +417,8 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic
}
// Convert pyobject to Value, then to AbstractValue
ValuePtr converted_ret = nullptr;
bool converted = parse::ConvertData(output["value"], &converted_ret);
TypePtr dtype = py::isinstance<Type>(out_dtype) ? out_dtype.cast<TypePtr>() : nullptr;
bool converted = parse::ConvertData(output["value"], &converted_ret, false, dtype);
if (!converted) {
MS_LOG(EXCEPTION) << "Convert data failed";
}


+ 22
- 2
mindspore/ccsrc/utils/convert_utils.cc View File

@@ -45,14 +45,34 @@ py::object ValuePtrToPyData(const ValuePtr &value) {
MS_LOG(EXCEPTION) << "value is null";
}
py::object ret;
if (value->isa<Int32Imm>()) {
MS_LOG(DEBUG) << "int";
if (value->isa<Int8Imm>()) {
MS_LOG(DEBUG) << "int8";
py::int_ v = value->cast<Int8ImmPtr>()->value();
ret = v;
} else if (value->isa<Int16Imm>()) {
MS_LOG(DEBUG) << "int16";
py::int_ v = value->cast<Int16ImmPtr>()->value();
ret = v;
} else if (value->isa<Int32Imm>()) {
MS_LOG(DEBUG) << "int32";
py::int_ v = value->cast<Int32ImmPtr>()->value();
ret = v;
} else if (value->isa<Int64Imm>()) {
MS_LOG(DEBUG) << "int64";
py::int_ v = value->cast<Int64ImmPtr>()->value();
ret = v;
} else if (value->isa<UInt8Imm>()) {
MS_LOG(DEBUG) << "uint8";
py::int_ v = value->cast<UInt8ImmPtr>()->value();
ret = v;
} else if (value->isa<UInt16Imm>()) {
MS_LOG(DEBUG) << "uint16";
py::int_ v = value->cast<UInt16ImmPtr>()->value();
ret = v;
} else if (value->isa<UInt32Imm>()) {
MS_LOG(DEBUG) << "uint32";
py::int_ v = value->cast<UInt32ImmPtr>()->value();
ret = v;
} else if (value->isa<UInt64Imm>()) {
MS_LOG(DEBUG) << "uint64";
py::int_ v = value->cast<UInt64ImmPtr>()->value();


+ 5
- 1
mindspore/core/abstract/abstract_value.cc View File

@@ -97,8 +97,12 @@ AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
}
auto value_self = GetValueTrack();
MS_EXCEPTION_IF_NULL(value_self);
ValuePtr res_value = ValueJoin(value_self, other->GetValueTrack());
TypePtr res_type = TypeJoin(GetTypeTrack(), other->GetTypeTrack());
if (res_type == kAnyType) {
MS_EXCEPTION(TypeError) << "Type join failed, type1 = " << GetTypeTrack()->ToString()
<< ", type2 = " << other->GetTypeTrack()->ToString();
}
ValuePtr res_value = ValueJoin(value_self, other->GetValueTrack());
if (res_value == value_self) {
return shared_from_base<AbstractBase>();
}


+ 10
- 2
mindspore/core/abstract/utils.cc View File

@@ -50,9 +50,17 @@ ShapePtr ShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2) {
if (*shape1 == *shape2) {
return shape1;
}
// lengths of two shapes are not same, join failed
if (shape1->shape().size() != shape2->shape().size()) {
MS_LOG(WARNING) << "Unsupported shape join. shape1 = " << shape1->ToString() << ", shape2 = " << shape2->ToString();
return shape1;
// special case: shape(1), shape() -> shape(1)
if (shape1->shape().size() == 1 && shape1->shape()[0] == 1 && shape2->shape().size() == 0) {
return shape1;
}
if (shape2->shape().size() == 1 && shape2->shape()[0] == 1 && shape1->shape().size() == 0) {
return shape2;
}
MS_EXCEPTION(ValueError) << "Unsupported shape join. shape1 = " << shape1->ToString()
<< ", shape2 = " << shape2->ToString();
}
std::vector<int> dims;
bool has_dynamic_shape = false;


+ 6
- 6
mindspore/core/ir/scalar.h View File

@@ -105,7 +105,7 @@ class Int8Imm : public IntergerImm {

std::string DumpText() const override {
std::ostringstream oss;
oss << "I8(" << v_ << ")";
oss << "I8(" << int(v_) << ")";
return oss.str();
}

@@ -131,7 +131,7 @@ class Int16Imm : public IntergerImm {

std::string DumpText() const override {
std::ostringstream oss;
oss << "I16(" << v_ << ")";
oss << "I16(" << int(v_) << ")";
return oss.str();
}

@@ -157,7 +157,7 @@ class Int32Imm : public IntergerImm {

std::string DumpText() const override {
std::ostringstream oss;
oss << "I32(" << v_ << ")";
oss << "I32(" << int(v_) << ")";
return oss.str();
}

@@ -211,7 +211,7 @@ class UInt8Imm : public IntergerImm {

std::string DumpText() const override {
std::ostringstream oss;
oss << "U8(" << v_ << ")";
oss << "U8(" << unsigned(v_) << ")";
return oss.str();
}

@@ -239,7 +239,7 @@ class UInt16Imm : public IntergerImm {

std::string DumpText() const override {
std::ostringstream oss;
oss << "U16(" << v_ << ")";
oss << "U16(" << unsigned(v_) << ")";
return oss.str();
}

@@ -267,7 +267,7 @@ class UInt32Imm : public IntergerImm {

std::string DumpText() const override {
std::ostringstream oss;
oss << "U32(" << v_ << ")";
oss << "U32(" << unsigned(v_) << ")";
return oss.str();
}



+ 1
- 1
tests/mindspore_test_framework/utils/check_gradient.py View File

@@ -324,7 +324,7 @@ class ScalarGradChecker(_GradChecker):
self.input_selector = [i for i in range(self.nin)]

def get_sens(self, i):
return 1
return 1.0

def check_against_numeric(self, out_index):
args = list(self.args)


+ 70
- 0
tests/ut/python/ops/test_control_ops.py View File

@@ -916,3 +916,73 @@ def test_recursive_call():
with pytest.raises(RuntimeError):
net(input_data)
context.set_context(max_call_depth=old_max_call_depth)


def test_switch_layer_shape_join_failed():
class AddFuncNet(nn.Cell):
def __init__(self, funcs, new_func):
super(AddFuncNet, self).__init__()
self.funcs = funcs
self.new_func = new_func

def construct(self, i, inputs):
final_funcs = self.funcs + (self.new_func,)
x = final_funcs[i](inputs)
return x

class ReLUTuple(nn.Cell):
def __init__(self):
super(ReLUTuple, self).__init__()
self.op = nn.ReLU()

def construct(self, x):
return self.op(x[0])

func1 = nn.Softmax()
func2 = nn.ReLU()
func3 = ReLUTuple()

funcs = (func1, func2)


net = AddFuncNet(funcs, func3)

inp = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
i = Tensor(1, mstype.int32)
with pytest.raises(ValueError) as err:
net(i, inp)


def test_switch_layer_dtype_join_failed():
class Cast(nn.Cell):
def __init__(self, dtype):
super(Cast, self).__init__()
self.op = P.Cast()
self.dtype = dtype

def construct(self, x):
y = self.op(x, self.dtype)
return y + y

class SwitchNegNet(nn.Cell):
def __init__(self, funcs):
super(SwitchNegNet, self).__init__()
self.funcs = funcs
self.op = P.Neg()

def construct(self, i, inputs):
x = self.funcs[i](inputs)
x = self.op(x)
return x


func1 = nn.ReLU()
func2 = Cast(mstype.int32)
funcs = (func1, func2)
net = SwitchNegNet(funcs)

inp = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
i = Tensor(0, mstype.int32)

with pytest.raises(TypeError) as err:
net(i, inp)

+ 4
- 3
tests/ut/python/ops/test_ops.py View File

@@ -33,6 +33,7 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \
pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception)
from ....mindspore_test_framework.pipeline.gradient.compile_gradient \
import pipeline_for_compile_grad_ge_graph_for_case_by_case_config
from ....ops_common import convert


grad_all_with_sens = C.GradOperation('grad_all_with_sens', get_all=True, sens_param=True)
@@ -1703,7 +1704,7 @@ test_case_nn_ops = [
('ResizeBilinear', {
'block': P.ResizeBilinear((5, 5)),
'desc_inputs': [Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mstype.float16)],
'desc_bprop': [Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mstype.float16)]}),
'desc_bprop': [Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mstype.float32)]}),
('ResizeBilinearGrad', {
'block': G.ResizeBilinearGrad(),
'desc_inputs': [Tensor([[[[1, 2, 3, 4, 5]]]], mstype.float32), Tensor([[[[1, 2, 3, 4, 5]]]], mstype.float32)],
@@ -1712,7 +1713,7 @@ test_case_nn_ops = [
('ROIAlign', {
'block': P.ROIAlign(7, 7, 0.03125, 2),
'desc_inputs': [[2, 256, 192, 320], [1024, 5]],
'desc_bprop': [[7, 7]]}),
'desc_bprop': [[1024, 256, 7, 7]]}),
('ROIAlignGrad', {
'block': G.ROIAlignGrad((1, 1, 1, 1), 2, 2, 0.5, 2),
'desc_inputs': [[1, 1, 2, 2], [1, 5]],
@@ -2315,7 +2316,7 @@ test_case_other_ops = [
('IOU', {
'block': P.IOU(),
'desc_inputs': [Tensor(np.ones((256, 4), np.float16)), Tensor(np.ones((128, 4), np.float16))],
'desc_bprop': [[128, 256]]}),
'desc_bprop': [convert([128, 256], np.float16)]}),
('Summary', {
'block': SummaryNet(),
'desc_inputs': [Tensor(np.array([1.1]).astype(np.float32)),


+ 6
- 6
tests/ut/python/ops/test_ops_reid.py View File

@@ -118,29 +118,29 @@ test_case_reid_ops = [
'desc_inputs': [[256, 8]],
'desc_bprop': [[256, 8]]}),
('Pow', {
'block': P.Pow(), # 输入有标量插件产生了段错误。
'block': P.Pow(),
'desc_const': [2.0],
'desc_inputs': [[1, 512]],
'desc_bprop': [[1, 512]]}),
('LogicalNot', {
'block': P.LogicalNot(),
'desc_inputs': [convert([256], np.bool_)],
'desc_bprop': [[256]]}), # 自定义算子 input bool没转换,gongchen提单。
'desc_bprop': [convert([256], np.bool_)]}),
('Equal', {
'block': P.Equal(),
'desc_inputs': [convert([256], np.float16), convert([256], np.float16)],
'desc_bprop': [[256]]}),
'desc_bprop': [convert([256], np.bool_)]}),
('Greater', {
'block': P.Greater(),
'desc_inputs': [convert([256], np.float16), convert([256], np.float16)],
'desc_bprop': [[256]]}),
'desc_bprop': [convert([256], np.bool_)]}),
('Dropout', {
'block': nn.Dropout(),
'desc_inputs': [[1, 512, 7, 7]],
'desc_bprop': [[1, 512, 7, 7]]}), # 输入有标量插件产生了段错误。
'desc_bprop': [[1, 512, 7, 7]]}),
('MatMul', {
'block': P.MatMul(),
'desc_inputs': [[64, 512], [512, 64]], # fp16不行。很有问题。
'desc_inputs': [[64, 512], [512, 64]],
'desc_bprop': [[64, 64]]}),
('Maximum', {
'block': P.Maximum(),


+ 8
- 8
tests/ut/python/parameter_feature/test_var_grad.py View File

@@ -84,8 +84,8 @@ class Bprop(Cell):
self.grad = grad_op
self.with_sens = False
self.sens = sens
if sens:
self.sens = Tensor(sens, dtype=mstype.float32)
if not sens is None:
self.sens = sens if isinstance(sens, Tensor) else Tensor(sens, dtype=mstype.float32)
self.with_sens = True

def construct(self, *inputs):
@@ -115,7 +115,7 @@ def test_all_var_args_grad_with_sens():

x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
sens = Tensor(1.0, dtype=mstype.float32)
sens = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
net = VarNet(SecondNet())
grad_net = GradNet(net)
_ = grad_net(x, y, sens)
@@ -167,7 +167,7 @@ def test_grad_all_var_args_with_sens():

x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
sens = Tensor(1.0, dtype=mstype.float32)
sens = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
net = VarNet(SecondNet())
grad_net = GradNet(net)
_ = grad_net(x, y, sens)
@@ -185,7 +185,7 @@ def test_grad_var_args_with_sens():

x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
sens = Tensor(1.0, dtype=mstype.float32)
sens = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
net = VarNet(SecondNet())
grad_net = GradNet(net)
_ = grad_net(x, y, sens)
@@ -244,7 +244,7 @@ def test_var_args_grad():

x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
sens = Tensor(1.0, dtype=mstype.float32)
sens = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
net = VarNet(SecondNet())
grad_net = GradNet(net)
_ = grad_net(x, y, sens)
@@ -292,14 +292,14 @@ def test_grad_within_if_else():
self.net = net
grad_op = C.GradOperation(
name='grad', get_all=False, get_by_list=True, sens_param=True)
self.grad = Bprop(self.net, True, self.weights, grad_op, 1.0)
sens = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
self.grad = Bprop(self.net, True, self.weights, grad_op, sens)

def construct(self, *inputs):
return self.grad(*inputs)

x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
_ = Tensor(1.0, dtype=mstype.float32)
net = VarNet(SecondNet())
grad_net = GradNet(net)
out = grad_net(x, y)


Loading…
Cancel
Save