!8344 add dynamic shape of row-split operations

From: @hwjiaorui
Reviewed-by: 
Signed-off-by:
Tag: v1.1.0
Committed by: mindspore-ci-bot (5 years ago)
Commit: baf6a059cf
11 changed files with 381 additions and 4 deletions:

  1. +3 -1    mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc
  2. +10 -1   mindspore/core/abstract/infer_functions.h
  3. +127 -0  mindspore/core/abstract/prim_maths.cc
  4. +47 -0   mindspore/core/abstract/prim_others.cc
  5. +5 -2    mindspore/core/abstract/primitive_infer_map.cc
  6. +5 -0    mindspore/core/ir/dtype/type.h
  7. +4 -0    mindspore/ops/_op_impl/tbe/__init__.py
  8. +57 -0   mindspore/ops/_op_impl/tbe/cast_ds.py
  9. +42 -0   mindspore/ops/_op_impl/tbe/equal_ds.py
  10. +41 -0  mindspore/ops/_op_impl/tbe/minimum_ds.py
  11. +40 -0  mindspore/ops/_op_impl/tbe/sub_ds.py

mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc (+3 -1)

@@ -43,6 +43,7 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An
     todos.push_back(node);
   }
 
+  std::set<string> DynamicShapeConstInputToAttr = {kCastOpName, kExpandDimsOpName};
   for (auto &t : todos) {
     CNodePtr cnode = t->cast<CNodePtr>();
     ConstInputToAttrInfoRegister reg;
@@ -61,7 +62,8 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An
         continue;
       }
     }
-    if (AnfAlgo::IsDynamicShape(cnode)) {
+    if (AnfAlgo::IsDynamicShape(cnode) &&
+        DynamicShapeConstInputToAttr.find(AnfAlgo::GetCNodeName(cnode)) == DynamicShapeConstInputToAttr.end()) {
       MS_LOG(INFO) << "current node is dynamic shape " << cnode->fullname_with_scope();
       continue;
     }
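
Note on the change above: this pass folds constant inputs into node attributes. Previously it skipped every dynamic-shape node; now only dynamic-shape nodes outside the new whitelist are skipped, so Cast and ExpandDims still get their constant inputs converted — cast_ds.py below, for instance, declares dst_type as a required attr of the dynamic-shape kernel. A minimal Python sketch of the resulting decision, with hypothetical names:

# Hypothetical sketch of the skip condition; names are illustrative only.
DYNAMIC_SHAPE_CONST_INPUT_TO_ATTR = {"Cast", "ExpandDims"}

def skip_const_input_to_attr(node_name, is_dynamic_shape):
    # Dynamic-shape nodes keep constant inputs as real inputs, except the
    # whitelisted ops, whose const inputs are still folded into attributes.
    return is_dynamic_shape and node_name not in DYNAMIC_SHAPE_CONST_INPUT_TO_ATTR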


mindspore/core/abstract/infer_functions.h (+10 -1)

@@ -223,7 +223,16 @@ AbstractBasePtr InferImplReduceScatter(const AnalysisEnginePtr &, const Primitiv

AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                     const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSub(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                             const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                              const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMinimum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const AbstractBasePtrList &args_spec_list);
template <typename T>
AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
  // Inputs: a tuple or list or dict.


mindspore/core/abstract/prim_maths.cc (+127 -0)

@@ -96,5 +96,132 @@ AbstractBasePtr InferImplSquare(const AnalysisEnginePtr &, const PrimitivePtr &p
  CheckArgsSize(op_name, args_spec_list, 1);
  return args_spec_list[0]->Broaden();
}

AbstractBasePtr InferImplSub(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                             const AbstractBasePtrList &args_spec_list) {
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, 2);
  auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  MS_EXCEPTION_IF_NULL(input_x);
  MS_EXCEPTION_IF_NULL(input_x->shape());

  auto input_y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
  MS_EXCEPTION_IF_NULL(input_y);
  MS_EXCEPTION_IF_NULL(input_y->shape());

  auto x_shape = input_x->shape()->shape();
  auto y_shape = input_y->shape()->shape();
  auto output_shape = BroadcastShape(x_shape, y_shape);
  if (output_shape.empty()) {
    MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << ", "
                      << args_spec_list[1]->ToString();
  }

  auto x_type = input_x->BuildType();
  MS_EXCEPTION_IF_NULL(x_type);
  MS_EXCEPTION_IF_NULL(x_type->cast<TensorTypePtr>());
  auto y_type = input_y->BuildType();
  MS_EXCEPTION_IF_NULL(y_type);
  MS_EXCEPTION_IF_NULL(y_type->cast<TensorTypePtr>());

  auto x_element = x_type->cast<TensorTypePtr>()->element();
  MS_EXCEPTION_IF_NULL(x_element);
  auto y_element = y_type->cast<TensorTypePtr>()->element();
  MS_EXCEPTION_IF_NULL(y_element);

  auto x_element_type = x_element->number_type();
  auto y_element_type = y_element->number_type();

  auto x_priority = type_priority_map.find(x_element_type);
  if (x_priority == type_priority_map.end()) {
    MS_LOG(EXCEPTION) << "input_x type is " << x_element_type << ", which is not a number type.";
  }
  auto y_priority = type_priority_map.find(y_element_type);
  if (y_priority == type_priority_map.end()) {
    MS_LOG(EXCEPTION) << "input_y type is " << y_element_type << ", which is not a number type.";
  }

  if (x_priority->second >= y_priority->second) {
    return std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(output_shape));
  } else {
    return std::make_shared<AbstractTensor>(input_y->element(), std::make_shared<Shape>(output_shape));
  }
}

AbstractBasePtr InferImplEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const AbstractBasePtrList &args_spec_list) {
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, 2);
  auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  MS_EXCEPTION_IF_NULL(input_x);
  MS_EXCEPTION_IF_NULL(input_x->shape());

  auto input_y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
  MS_EXCEPTION_IF_NULL(input_y);
  MS_EXCEPTION_IF_NULL(input_y->shape());

  auto x_shape = input_x->shape()->shape();
  auto y_shape = input_y->shape()->shape();
  auto out_shape = BroadcastShape(x_shape, y_shape);
  if (out_shape.empty()) {
    MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << ", "
                      << args_spec_list[1]->ToString();
  }

  auto output_type = std::make_shared<Bool>();
  auto ret = std::make_shared<AbstractTensor>(output_type, out_shape);
  return ret;
}

AbstractBasePtr InferImplMinimum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const AbstractBasePtrList &args_spec_list) {
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, 2);
  auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  MS_EXCEPTION_IF_NULL(input_x);
  MS_EXCEPTION_IF_NULL(input_x->shape());

  auto input_y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
  MS_EXCEPTION_IF_NULL(input_y);
  MS_EXCEPTION_IF_NULL(input_y->shape());

  auto x_shape = input_x->shape()->shape();
  auto y_shape = input_y->shape()->shape();
  auto output_shape = BroadcastShape(x_shape, y_shape);
  if (output_shape.empty()) {
    MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << ", "
                      << args_spec_list[1]->ToString();
  }

  auto x_type = input_x->BuildType();
  MS_EXCEPTION_IF_NULL(x_type);
  MS_EXCEPTION_IF_NULL(x_type->cast<TensorTypePtr>());
  auto y_type = input_y->BuildType();
  MS_EXCEPTION_IF_NULL(y_type);
  MS_EXCEPTION_IF_NULL(y_type->cast<TensorTypePtr>());

  auto x_element = x_type->cast<TensorTypePtr>()->element();
  MS_EXCEPTION_IF_NULL(x_element);
  auto y_element = y_type->cast<TensorTypePtr>()->element();
  MS_EXCEPTION_IF_NULL(y_element);

  auto x_element_type = x_element->number_type();
  auto y_element_type = y_element->number_type();

  auto x_priority = type_priority_map.find(x_element_type);
  if (x_priority == type_priority_map.end()) {
    MS_LOG(EXCEPTION) << "input_x type is " << x_element_type << ", which is not a number type.";
  }
  auto y_priority = type_priority_map.find(y_element_type);
  if (y_priority == type_priority_map.end()) {
    MS_LOG(EXCEPTION) << "input_y type is " << y_element_type << ", which is not a number type.";
  }

  if (x_priority->second >= y_priority->second) {
    return std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(output_shape));
  } else {
    return std::make_shared<AbstractTensor>(input_y->element(), std::make_shared<Shape>(output_shape));
  }
}
} // namespace abstract
} // namespace mindspore
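
InferImplSub and InferImplMinimum above share one rule: the output shape is the NumPy-style broadcast of the input shapes, and the output element type is whichever input ranks higher in type_priority_map (added to type.h below); InferImplEqual uses the same broadcast but always yields bool. A hedged Python sketch of both rules — broadcast_shape is a reconstruction of what the C++ BroadcastShape helper appears to do, not its actual code:

from itertools import zip_longest

# Priorities copied from the type_priority_map added in type.h below.
TYPE_PRIORITY = {"bool": 0, "uint8": 1, "int8": 2, "int16": 3, "int32": 4,
                 "int64": 5, "float16": 6, "float32": 7, "float64": 8}

def broadcast_shape(x_shape, y_shape):
    """NumPy-style broadcast; returns [] on mismatch, like the empty-shape check above."""
    out = []
    for a, b in zip_longest(reversed(x_shape), reversed(y_shape), fillvalue=1):
        if a != b and a != 1 and b != 1:
            return []
        out.append(max(a, b))
    return out[::-1]

def infer_sub(x_shape, x_dtype, y_shape, y_dtype):
    out_shape = broadcast_shape(x_shape, y_shape)
    assert out_shape, "BroadcastShape fail"
    out_dtype = x_dtype if TYPE_PRIORITY[x_dtype] >= TYPE_PRIORITY[y_dtype] else y_dtype
    return out_shape, out_dtype

# float16 (priority 6) outranks int32 (priority 4); [2, 1] x [1, 3] broadcasts to [2, 3].
assert infer_sub([2, 1], "int32", [1, 3], "float16") == ([2, 3], "float16")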

mindspore/core/abstract/prim_others.cc (+47 -0)

@@ -440,5 +440,52 @@ AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitiveP
  MS_EXCEPTION_IF_NULL(x->shape());
  return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(x->shape()->shape()));
}

AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                              const AbstractBasePtrList &args_spec_list) {
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, 1);
  auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  MS_EXCEPTION_IF_NULL(input_x);
  MS_EXCEPTION_IF_NULL(input_x->shape());
  auto input_type = primitive->GetAttr("dst_type")->cast<TypePtr>();
  auto ret = std::make_shared<AbstractTensor>(input_type, input_x->shape()->shape());
  return ret;
}

AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const AbstractBasePtrList &args_spec_list) {
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, 2);
  auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  MS_EXCEPTION_IF_NULL(x);
  MS_EXCEPTION_IF_NULL(x->shape());

  auto axis = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
  MS_EXCEPTION_IF_NULL(axis);

  std::vector<int64_t> shape;
  std::vector<int64_t> x_shape = x->shape()->shape();
  shape.insert(shape.end(), x_shape.begin(), x_shape.end());

  auto axis_value = axis->BuildValue();
  if (!axis_value->isa<tensor::Tensor>()) {
    MS_LOG(EXCEPTION) << "axis_value should be a tensor, but got " << axis_value->type_name();
  }
  auto axis_tensor = axis_value->cast<tensor::TensorPtr>();
  int value = *(static_cast<int *>(axis_tensor->data_c()));
  if (value < -(SizeToInt(x_shape.size()) + 1) || value > SizeToInt(x_shape.size())) {
    MS_LOG(EXCEPTION) << "axis value should be in range [-input_x.dim-1, input_x.dim], but axis value is " << value
                      << " and input_x.dim is " << x_shape.size();
  }
  if (value < 0) {
    value = value + SizeToInt(x_shape.size()) + 1;
  }
  shape.insert(shape.begin() + value, 1);

  auto ret = std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape));
  return ret;
}

} // namespace abstract
} // namespace mindspore
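
InferImplExpandDims above normalizes a negative axis before inserting the new unit dimension; a short worked example of that arithmetic:

# Worked example of the axis handling in InferImplExpandDims.
x_shape = [2, 3]                         # input rank 2, so valid axis range is [-3, 2]
axis = -1
assert -(len(x_shape) + 1) <= axis <= len(x_shape)
if axis < 0:
    axis += len(x_shape) + 1             # -1 -> 2
x_shape.insert(axis, 1)
assert x_shape == [2, 3, 1]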

mindspore/core/abstract/primitive_infer_map.cc (+5 -2)

@@ -37,12 +37,13 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
    // Maths
    {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}},
    {prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}},
    {prim::kPrimMul, {InferImplMul, true}},
    {prim::kPrimTensorAdd, {InferImplTensorAdd, true}},
    {prim::kPrimSquare, {InferImplSquare, true}},
    {prim::kPrimSqrt, {InferImplSqrt, true}},
    {prim::kPrimSub, {InferImplSub, true}},
    {prim::kPrimEqual, {InferImplEqual, true}},
    {prim::kPrimMinimum, {InferImplMinimum, true}},
    // Array
    {prim::kPrimScalarToArray, {InferImplScalarToArray, true}},
    {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}},
@@ -128,6 +129,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
    {prim::kPrimAllGather, {InferImplAllGather, true}},
    {prim::kPrimReduceScatter, {InferImplReduceScatter, true}},
    {prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, true}},
    {prim::kPrimCast, {InferImplCast, true}},
    {prim::kPrimExpandDims, {InferImplExpandDims, true}},
  };
  return prim_eval_implement_map;
}


mindspore/core/ir/dtype/type.h (+5 -0)

@@ -127,6 +127,11 @@ const std::unordered_map<TypeId, std::string> type_name_map = {
    {kNumberTypeInt16, "int16"},     {kNumberTypeInt32, "int32"},     {kNumberTypeInt64, "int64"},
    {kNumberTypeFloat16, "float16"}, {kNumberTypeFloat32, "float32"}, {kNumberTypeFloat64, "float64"}};

const std::unordered_map<TypeId, int> type_priority_map = {
    {kNumberTypeBool, 0},    {kNumberTypeUInt8, 1},   {kNumberTypeInt8, 2},
    {kNumberTypeInt16, 3},   {kNumberTypeInt32, 4},   {kNumberTypeInt64, 5},
    {kNumberTypeFloat16, 6}, {kNumberTypeFloat32, 7}, {kNumberTypeFloat64, 8}};

std::ostream &operator<<(std::ostream &os, const TypePtrList &types);
} // namespace mindspore



mindspore/ops/_op_impl/tbe/__init__.py (+4 -0)

@@ -50,6 +50,7 @@ from .batchnorm_grad import _batch_norm_grad_tbe
from .bias_add import _bias_add_tbe
from .bias_add_grad import _bias_add_grad_tbe
from .cast import _cast_tbe
from .cast_ds import _cast_ds_tbe
from .conv2d import _conv2d_tbe
from .conv2d_backprop_filter import _conv2d_backprop_filter_tbe
from .conv2d_backprop_input import _conv2d_backprop_input_tbe
@@ -80,6 +81,7 @@ from .trans_data import _trans_data_tbe
from .top_k import _top_k_tbe
from .matmul import _matmul_tbe
from .sub import _sub_tbe
from .sub_ds import _sub_ds_tbe
from .scatter_nd import _scatter_nd_tbe
from .scatter_nd_d import _scatter_nd_d_tbe
from .scatter_nd_add import _scatter_nd_add_tbe
@@ -121,6 +123,7 @@ from .npu_get_float_status import _npu_get_float_status_tbe
from .npu_alloc_float_status import _npu_alloc_float_status_tbe
from .one_hot import _one_hot_tbe
from .equal import _equal_tbe
from .equal_ds import _equal_ds_tbe
from .less import _less_tbe
from .less_equal import _less_equal_tbe
from .logical_and import _logical_and_tbe
@@ -159,6 +162,7 @@ from .select import _select_tbe
from .pow import _pow_tbe
from .maximum import _maximum_tbe
from .minimum import _minimum_tbe
from .minimum_ds import _minimum_ds_tbe
from .minimum_grad import _minimum_grad_tbe
from .maximum_grad import _maximum_grad_tbe
from .concat import _concat_tbe


mindspore/ops/_op_impl/tbe/cast_ds.py (+57 -0)

@@ -0,0 +1,57 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Cast op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

cast_ds_op_info = TBERegOp("Cast") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("cast.so") \
    .compute_cost(10) \
    .kernel_name("cast") \
    .partial_flag(True) \
    .attr("dst_type", "required", "int", "all") \
    .dynamic_shape(True) \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .op_pattern("formatAgnostic") \
    .dtype_format(DataType.BOOL_None, DataType.F16_None) \
    .dtype_format(DataType.BOOL_None, DataType.U8_None) \
    .dtype_format(DataType.BOOL_None, DataType.F32_None) \
    .dtype_format(DataType.BOOL_None, DataType.I32_None) \
    .dtype_format(DataType.I8_None, DataType.F16_None) \
    .dtype_format(DataType.I8_None, DataType.F32_None) \
    .dtype_format(DataType.I8_None, DataType.I32_None) \
    .dtype_format(DataType.U8_None, DataType.F16_None) \
    .dtype_format(DataType.U8_None, DataType.F32_None) \
    .dtype_format(DataType.U8_None, DataType.I32_None) \
    .dtype_format(DataType.I32_None, DataType.BOOL_None) \
    .dtype_format(DataType.I32_None, DataType.F16_None) \
    .dtype_format(DataType.I32_None, DataType.F32_None) \
    .dtype_format(DataType.I32_None, DataType.I8_None) \
    .dtype_format(DataType.I32_None, DataType.U8_None) \
    .dtype_format(DataType.F16_None, DataType.U8_None) \
    .dtype_format(DataType.F16_None, DataType.F32_None) \
    .dtype_format(DataType.F16_None, DataType.I32_None) \
    .dtype_format(DataType.F32_None, DataType.F16_None) \
    .dtype_format(DataType.F32_None, DataType.I32_None) \
    .get_op_info()


@op_info_register(cast_ds_op_info)
def _cast_ds_tbe():
"""Cast TBE register"""
return

mindspore/ops/_op_impl/tbe/equal_ds.py (+42 -0)

@@ -0,0 +1,42 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Equal op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

equal_ds_op_info = TBERegOp("Equal") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("equal.so") \
    .compute_cost(10) \
    .kernel_name("equal") \
    .partial_flag(True) \
    .dynamic_shape(True) \
    .input(0, "x1", False, "required", "all") \
    .input(1, "x2", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .op_pattern("broadcast") \
    .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.BOOL_Default) \
    .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.BOOL_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \
    .get_op_info()


@op_info_register(equal_ds_op_info)
def _equal_ds_tbe():
"""Equal TBE register"""
return

mindspore/ops/_op_impl/tbe/minimum_ds.py (+41 -0)

@@ -0,0 +1,41 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================


"""Minimum op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

minimum_ds_op_info = TBERegOp("Minimum") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("minimum.so") \
    .compute_cost(10) \
    .kernel_name("minimum") \
    .partial_flag(True) \
    .dynamic_shape(True) \
    .input(0, "x1", False, "required", "all") \
    .input(1, "x2", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .op_pattern("broadcast") \
    .dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None) \
    .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \
    .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
    .get_op_info()


@op_info_register(minimum_ds_op_info)
def _minimum_ds_tbe():
"""Minimum TBE register"""
return

mindspore/ops/_op_impl/tbe/sub_ds.py (+40 -0)

@@ -0,0 +1,40 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Sub op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

sub_ds_op_info = TBERegOp("Sub") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("sub.so") \
    .compute_cost(10) \
    .kernel_name("sub") \
    .partial_flag(True) \
    .dynamic_shape(True) \
    .input(0, "x1", False, "required", "all") \
    .input(1, "x2", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .op_pattern("broadcast") \
    .dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None) \
    .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \
    .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
    .get_op_info()


@op_info_register(sub_ds_op_info)
def _sub_ds_tbe():
"""Add TBE register"""
return
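
Taken together, the *_ds files register dynamic-shape (dynamic_shape(True)) TBE variants alongside the existing static kernels for Cast, Equal, Minimum, and Sub. A hedged usage sketch, assuming a MindSpore 1.x installation with an Ascend back end (where these TBE kernels apply); the broadcast result follows the infer functions added above:

import numpy as np
from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(device_target="Ascend")  # assumed environment

sub = P.Sub()
x = Tensor(np.ones((2, 1), np.float32))
y = Tensor(np.ones((1, 3), np.float32))
out = sub(x, y)       # shapes broadcast per InferImplSub
print(out.shape)      # (2, 3)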
