| @@ -227,7 +227,7 @@ AbstractBasePtr DoInferReduceShape(const AbstractTuplePtr &x_shape, const ValueP | |||
| } | |||
| for (auto &elem : axis_data) { | |||
| int64_t e_value = CheckAxis(primitive->name(), elem, -SizeToLong(x_rank), SizeToLong(x_rank) - 1); | |||
| int64_t e_value = CheckAxis(primitive->name(), elem, -SizeToLong(x_rank), SizeToLong(x_rank)); | |||
| (void)axis_set.insert(e_value); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(x_shp_value->cast<ValueTuplePtr>()); | |||
| @@ -1019,10 +1019,10 @@ bool SetMindIRGraphAction(const ResourcePtr &res) { | |||
| if (!AbstractBasePtrListDeepEqual(func_args, broaded_args)) { | |||
| MS_LOG(EXCEPTION) << "The input arguments is not compatible with the function graph which has been exported before." | |||
| << "Please check the args is same with export.\n" | |||
| << "The export input argument size : " << func_args.size() << "\n" | |||
| << "The load input argument size : " << broaded_args.size() << "\n" | |||
| << "Export input args info:" << abstract::ArgsToString(func_args) << "\n" | |||
| << "The input args info:" << abstract::ArgsToString(broaded_args); | |||
| << "The export input argument size: " << func_args.size() << "\n" | |||
| << "The load input argument size: " << broaded_args.size() << "\n" | |||
| << "Export input args info: " << abstract::ArgsToString(func_args) << "\n" | |||
| << "The input args info: " << abstract::ArgsToString(broaded_args); | |||
| } | |||
| // suppose that there is not KeywordArgument for the top graph | |||
| @@ -326,11 +326,10 @@ void AnalysisResultCacheMgr::Todo() { | |||
| std::string ArgsToString(const AbstractBasePtrList &args_spec_list) { | |||
| std::ostringstream buffer; | |||
| buffer << "("; | |||
| for (const auto &item : args_spec_list) { | |||
| buffer << item->ToString() << " # "; | |||
| buffer << item->BuildType()->ToString() << "," << item->BuildShape()->ToString() << " #" | |||
| << "\n"; | |||
| } | |||
| buffer << " )"; | |||
| return buffer.str(); | |||
| } | |||
| } // namespace abstract | |||
| @@ -445,6 +445,36 @@ void ConvertAbstractFunctionToPython(const AbstractBasePtr &abs_base, py::dict * | |||
| } | |||
| } | |||
| } | |||
| bool CheckType(const TypePtr &expected_type, const TypePtr &x) { | |||
| // As x and predicate both are mindspore type statically, here we only to judge whether | |||
| // x is predicate or is a subclass of predicate. | |||
| return IsIdentidityOrSubclass(x, expected_type); | |||
| } | |||
| // Join all types in args_type_list; | |||
| TypePtr TypeJoin(const TypePtrList &args_type_list) { | |||
| if (args_type_list.empty()) { | |||
| MS_LOG(EXCEPTION) << "args_type_list is empty"; | |||
| } | |||
| TypePtr type_tmp = args_type_list[0]; | |||
| for (std::size_t i = 1; i < args_type_list.size(); i++) { | |||
| type_tmp = abstract::TypeJoin(type_tmp, args_type_list[i]); | |||
| } | |||
| return type_tmp; | |||
| } | |||
| TypePtr CheckTypeList(const TypePtr &predicate, const TypePtrList &args_type_list) { | |||
| MS_EXCEPTION_IF_NULL(predicate); | |||
| for (const auto &arg_type : args_type_list) { | |||
| MS_EXCEPTION_IF_NULL(arg_type); | |||
| if (!CheckType(predicate, arg_type)) { | |||
| MS_LOG(EXCEPTION) << "The expected is " << predicate->ToString() << ", not " << arg_type->ToString(); | |||
| } | |||
| } | |||
| return TypeJoin(args_type_list); | |||
| } | |||
| } // end anonymous namespace | |||
| py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) { | |||
| @@ -49,8 +49,6 @@ AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const Primit | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplHSigmoid(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| @@ -45,7 +45,7 @@ TypePtr CheckType(TypePtr type, const TypePtrList &accepts, const std::string &e | |||
| if (ok) { | |||
| return type; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << error_message_prefix << accepts << " but is " << type->ToString(); | |||
| MS_EXCEPTION(TypeError) << error_message_prefix << " should be " << accepts << ",but got " << type->ToString(); | |||
| } | |||
| } | |||
| @@ -79,7 +79,8 @@ TypePtr CheckTensorsDTypeSame(const AbstractTensorPtrList &tensor_list, const Ty | |||
| TypePtr sample_type = sample_elem->BuildType(); | |||
| MS_EXCEPTION_IF_NULL(sample_type); | |||
| std::ostringstream loginfoBuffer; | |||
| loginfoBuffer << "same type, got"; | |||
| loginfoBuffer << "[" << sample_tensor->BuildType()->ToString(); | |||
| bool error_flag = false; | |||
| // Check if other elements have the same type with the first element. | |||
| for (size_t index = 1; index < tensor_list.size(); ++index) { | |||
| MS_EXCEPTION_IF_NULL(tensor_list[index]); | |||
| @@ -87,12 +88,14 @@ TypePtr CheckTensorsDTypeSame(const AbstractTensorPtrList &tensor_list, const Ty | |||
| MS_EXCEPTION_IF_NULL(elem); | |||
| auto a_type = elem->BuildType(); | |||
| MS_EXCEPTION_IF_NULL(a_type); | |||
| loginfoBuffer << " " << a_type->ToString(); | |||
| loginfoBuffer << "," << tensor_list[index]->BuildType()->ToString(); | |||
| if (sample_type->type_id() != a_type->type_id()) { | |||
| MS_LOG(EXCEPTION) << "Expected type " << sample_type->ToString() << ", but got " << a_type->ToString() | |||
| << ", index " << index; | |||
| error_flag = true; | |||
| } | |||
| } | |||
| if (error_flag) { | |||
| MS_EXCEPTION(ValueError) << error_message_prefix << " must be same, but got " << loginfoBuffer.str() << "]"; | |||
| } | |||
| MS_LOG(DEBUG) << error_message_prefix << loginfoBuffer.str(); | |||
| return CheckTensorDType(sample_tensor, accepts, error_message_prefix); | |||
| } | |||
| @@ -167,15 +170,19 @@ int64_t CheckAxis(const std::string &op, const ValuePtr &axis, int64_t minimum, | |||
| } | |||
| int64_t axis_value = GetValue<int64_t>(axis); | |||
| if (axis_value > max || axis_value < minimum) { | |||
| MS_LOG(EXCEPTION) << op << " evaluator axis value should be in the range [" << minimum << ", " << max | |||
| << "], but get " << axis_value; | |||
| MS_LOG(EXCEPTION) << "The primitive[" << op << "]'s axis value should be in the range [" << minimum << ", " << max | |||
| << "], but got " << axis_value; | |||
| } | |||
| if (axis_value < 0) { | |||
| axis_value = axis_value + SizeToLong(max); | |||
| } | |||
| return axis_value; | |||
| } | |||
| void CheckArgsSize(const std::string &op, const mindspore::abstract::AbstractBasePtrList &args_spec_list, | |||
| size_t size_expect) { | |||
| if (args_spec_list.size() != size_expect) { | |||
| MS_LOG(EXCEPTION) << op << " input args size should be " << size_expect << ", but got " << args_spec_list.size(); | |||
| MS_LOG(EXCEPTION) << op << " input arguments size should be " << size_expect << ", but got " | |||
| << args_spec_list.size(); | |||
| } | |||
| for (size_t i = 0; i < size_expect; i++) { | |||
| @@ -200,65 +207,6 @@ void CheckShapeAnyAndPositive(const std::string &op, const ShapeVector &shape) { | |||
| } | |||
| } | |||
| int64_t CheckAttrPositiveInt64(const std::string &op, const ValuePtr &attr, const std::string &attr_name) { | |||
| MS_EXCEPTION_IF_NULL(attr); | |||
| auto int64_value = attr->cast<Int64ImmPtr>(); | |||
| MS_EXCEPTION_IF_NULL(int64_value); | |||
| int64_t attr_val = int64_value->value(); | |||
| if (attr_val <= 0) { | |||
| MS_LOG(EXCEPTION) << op << " invalid " << attr_name << " value: " << attr_val << ", should be greater then 0"; | |||
| } | |||
| return attr_val; | |||
| } | |||
| std::vector<int64_t> CheckAttrIntOrTuple(const std::string &op, const ValuePtr &attr, const size_t start_idx, | |||
| const size_t num_element) { | |||
| std::vector<int64_t> result; | |||
| MS_EXCEPTION_IF_NULL(attr); | |||
| if (attr->isa<ValueTuple>()) { | |||
| auto tuple_attr = attr->cast<ValueTuplePtr>(); | |||
| MS_EXCEPTION_IF_NULL(tuple_attr); | |||
| std::vector<ValuePtr> attr_vec = tuple_attr->value(); | |||
| if (start_idx > attr_vec.size() || start_idx + num_element > attr_vec.size()) { | |||
| MS_EXCEPTION(IndexError) << op << " attr index is out of range, attr size is " << attr_vec.size() | |||
| << "but start idx got" << start_idx << " num element " << num_element; | |||
| } | |||
| auto it_start = attr_vec.begin() + start_idx; | |||
| (void)std::transform(it_start, it_start + num_element, std::back_inserter(result), | |||
| [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); }); | |||
| } else { | |||
| auto int64_imm = attr->cast<Int64ImmPtr>(); | |||
| MS_EXCEPTION_IF_NULL(int64_imm); | |||
| int64_t attr_val = int64_imm->value(); | |||
| (void)result.insert(result.begin(), num_element, attr_val); | |||
| } | |||
| return result; | |||
| } | |||
| std::string CheckAttrStringSet(const std::string &op, const ValuePtr &attr, const std::string &attr_name, | |||
| const std::set<std::string> &val_set) { | |||
| MS_EXCEPTION_IF_NULL(attr); | |||
| auto string_attr = attr->cast<StringImmPtr>(); | |||
| MS_EXCEPTION_IF_NULL(string_attr); | |||
| std::string attr_val = string_attr->value(); | |||
| if (val_set.find(attr_val) == val_set.end()) { | |||
| std::ostringstream buffer; | |||
| bool f_begin = true; | |||
| buffer << "{"; | |||
| for (auto &x : val_set) { | |||
| if (!f_begin) { | |||
| buffer << ", "; | |||
| } else { | |||
| f_begin = false; | |||
| } | |||
| buffer << x; | |||
| } | |||
| buffer << "}"; | |||
| MS_LOG(EXCEPTION) << op << "Unsupported " << attr_name << ": " << attr_val << ". use " << buffer.str(); | |||
| } | |||
| return attr_val; | |||
| } | |||
| void CheckRequiredArgsSize(const std::string &op, const mindspore::abstract::AbstractBasePtrList &args_spec_list, | |||
| size_t size_expect) { | |||
| if (args_spec_list.size() < size_expect) { | |||
| @@ -268,6 +216,5 @@ void CheckRequiredArgsSize(const std::string &op, const mindspore::abstract::Abs | |||
| MS_EXCEPTION_IF_NULL(args_spec_list[i]); | |||
| } | |||
| } | |||
| } // namespace abstract | |||
| } // namespace mindspore | |||
| @@ -53,8 +53,6 @@ void CheckShapeAllPositive(const std::string &op, const ShapeVector &shape); | |||
| void CheckShapeAnyAndPositive(const std::string &op, const ShapeVector &shape); | |||
| int64_t CheckAttrPositiveInt64(const std::string &op, const ValuePtr &attr, const std::string &attr_name); | |||
| std::vector<int64_t> CheckAttrIntOrTuple(const std::string &op, const ValuePtr &attr, const size_t start_idx, | |||
| const size_t num_element); | |||
| @@ -149,8 +149,6 @@ AbstractBasePtr InferImplStack(const AnalysisEnginePtr &, const PrimitivePtr &pr | |||
| ValuePtr axis = primitive->GetAttr("axis"); | |||
| // Axis value should be in [-(rank_base + 1), rank_base). | |||
| int64_t axis_value = CheckAxis(op_name, axis, -(rank_base + 1), rank_base); | |||
| // If axis is negative, add offset(rank_base + 1) to turn it to positive. | |||
| axis_value = GetPositiveAxis(axis_value, LongToSize(rank_base + 1)); | |||
| for (size_t i = 1; i < tuple_len; ++i) { | |||
| AbstractTensorPtr tensor = nullptr; | |||
| @@ -950,8 +948,7 @@ AbstractBasePtr InferImplSplit(const AnalysisEnginePtr &, const PrimitivePtr &pr | |||
| int64_t rank = SizeToLong(x_shape.size()); | |||
| ValuePtr axis = primitive->GetAttr("axis"); | |||
| int64_t axis_value = CheckAxis(op_name, axis, -(rank + 1), rank); | |||
| uint64_t axis_value_pos = LongToUlong(GetPositiveAxis(axis_value, LongToSize(rank))); | |||
| int64_t axis_value_pos = CheckAxis(op_name, axis, -(rank + 1), rank); | |||
| int64_t output_num_value = GetValue<int64_t>(primitive->GetAttr("output_num")); | |||
| if ((x_shape[axis_value_pos] != Shape::SHP_ANY) && (x_shape[axis_value_pos] % output_num_value != 0)) { | |||
| MS_LOG(EXCEPTION) << "x_shape[" << axis_value_pos << "] = " << x_shape[axis_value_pos] | |||
| @@ -1097,8 +1094,6 @@ AbstractBasePtr InferImplConcat(const AnalysisEnginePtr &, const PrimitivePtr &p | |||
| ValuePtr axis = primitive->GetAttr("axis"); | |||
| // Axis value should be in [-(rank_base + 1), rank_base). | |||
| int64_t axis_value = CheckAxis(op_name, axis, -(rank_base + 1), rank_base); | |||
| // If axis is negative, add offset(rank_base) to turn it to positive. | |||
| axis_value = GetPositiveAxis(axis_value, LongToSize(rank_base)); | |||
| int64_t all_shp = shape_base[axis_value]; | |||
| int64_t min_all_shp = min_shape_base[axis_value]; | |||
| @@ -139,14 +139,14 @@ AbstractBasePtr InferImplBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr | |||
| CheckMinMaxShape(x_shape, &x_min_shape, &x_max_shape); | |||
| auto input_tensor = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| (void)CheckTensorDType(input_tensor, {kFloat16, kFloat32}, "param x of BatchNorm should be"); | |||
| (void)CheckTensorDType(input_tensor, {kFloat16, kFloat32}, "Input argument[x] of BatchNorm"); | |||
| AbstractTensorPtrList tensorPtrList = std::vector<AbstractTensorPtr>(); | |||
| for (size_t i = 1; i < args_spec_list.size(); ++i) { | |||
| auto param = CheckArg<AbstractTensor>(op_name, args_spec_list, i); | |||
| tensorPtrList.push_back(param); | |||
| } | |||
| (void)CheckTensorsDTypeSame(tensorPtrList, {kFloat16, kFloat32}, | |||
| "param gamma, beta, mean, variance of Batchnorm should be"); | |||
| "Input arguments[gamma, beta, mean, variance] of BatchNorm"); | |||
| auto data_format_ptr = primitive->GetAttr("format"); | |||
| MS_EXCEPTION_IF_NULL(data_format_ptr); | |||
| @@ -240,113 +240,6 @@ void CheckShape(const std::string &op_name, const ShapeVector &w_shape, const Ab | |||
| CheckShapeAllPositive(op_name + " w_max_shape", w_max_shape); | |||
| } | |||
| AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| constexpr auto kConv2DInputNum = 2; | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, kConv2DInputNum); | |||
| AbstractTensorPtr input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| MS_EXCEPTION_IF_NULL(input_x); | |||
| MS_EXCEPTION_IF_NULL(input_x->shape()); | |||
| ShapeVector x_shape = input_x->shape()->shape(); | |||
| ShapeVector x_min_shape = input_x->shape()->min_shape(); | |||
| ShapeVector x_max_shape = input_x->shape()->max_shape(); | |||
| CheckMinMaxShape(x_shape, &x_min_shape, &x_max_shape); | |||
| CheckShapeAnyAndPositive(op_name + " x_shape", x_shape); | |||
| CheckShapeAllPositive(op_name + " x_min_shape", x_min_shape); | |||
| CheckShapeAllPositive(op_name + " x_max_shape", x_max_shape); | |||
| AbstractTensorPtr input_w = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); | |||
| MS_EXCEPTION_IF_NULL(input_w); | |||
| MS_EXCEPTION_IF_NULL(input_w->shape()); | |||
| ShapeVector w_shape = input_w->shape()->shape(); | |||
| CheckShape(op_name, w_shape, input_w); | |||
| const uint64_t n_axis = 0; | |||
| uint64_t c_axis = 1; | |||
| uint64_t h_axis = 2; | |||
| uint64_t w_axis = 3; | |||
| int64_t data_format = GetAndCheckFormat(primitive->GetAttr("format")); | |||
| if (data_format == Format::NHWC) { | |||
| c_axis = 3; | |||
| h_axis = 1; | |||
| w_axis = 2; | |||
| } | |||
| int64_t group = CheckAttrPositiveInt64(op_name, primitive->GetAttr("group"), "group"); | |||
| if ((x_shape[c_axis] != Shape::SHP_ANY) && (w_shape[c_axis] != Shape::SHP_ANY) && | |||
| ((x_shape[c_axis] / group) != w_shape[c_axis])) { | |||
| MS_LOG(EXCEPTION) << "x_shape[C_in] / group must be equal to w_shape[C_in]: " << w_shape[c_axis] << ", but got " | |||
| << (x_shape[c_axis] / group); | |||
| } | |||
| int64_t out_channel = CheckAttrPositiveInt64(op_name, primitive->GetAttr("out_channel"), "out_channel"); | |||
| if ((w_shape[n_axis] != Shape::SHP_ANY) && (w_shape[n_axis] != out_channel)) { | |||
| MS_LOG(EXCEPTION) << "w_shape[" << n_axis << "] = " << w_shape[n_axis] << " must be equal to " << out_channel; | |||
| } | |||
| const size_t kernel_size_num_element = 2; | |||
| std::vector<int64_t> kernel_size = | |||
| CheckAttrIntOrTuple(op_name, primitive->GetAttr("kernel_size"), 0, kernel_size_num_element); | |||
| if ((w_shape[h_axis] != Shape::SHP_ANY) && (w_shape[h_axis] != kernel_size[0])) { | |||
| MS_LOG(EXCEPTION) << "weight height: " << w_shape[h_axis] << " must be equal to " << kernel_size[0]; | |||
| } | |||
| if ((w_shape[w_axis] != Shape::SHP_ANY) && (w_shape[w_axis] != kernel_size[1])) { | |||
| MS_LOG(EXCEPTION) << "weight width: " << w_shape[w_axis] << " must be equal to " << kernel_size[1]; | |||
| } | |||
| std::vector<int64_t> stride = | |||
| CheckAttrIntOrTuple(op_name, primitive->GetAttr("stride"), stride_start_idx, stride_num_element); | |||
| std::vector<int64_t> dilation = | |||
| CheckAttrIntOrTuple(op_name, primitive->GetAttr("dilation"), dilation_start_idx, dilation_num_element); | |||
| std::vector<int64_t> padding = | |||
| CheckAttrIntOrTuple(op_name, primitive->GetAttr("pad"), padding_start_idx, padding_num_element); | |||
| int64_t pad_mode; | |||
| CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr("pad_mode"), &pad_mode); | |||
| std::vector<int64_t> output_hw; | |||
| std::vector<int64_t> pad_list; | |||
| std::vector<int64_t> output_hw_min; | |||
| std::vector<int64_t> pad_list_min; | |||
| std::vector<int64_t> output_hw_max; | |||
| std::vector<int64_t> pad_list_max; | |||
| Conv2DPadFunction(&output_hw, &pad_list, x_shape[h_axis], x_shape[w_axis], kernel_size, stride, dilation, pad_mode, | |||
| padding); | |||
| if (x_shape[h_axis] == Shape::SHP_ANY) { | |||
| output_hw[0] = Shape::SHP_ANY; | |||
| } | |||
| if (x_shape[w_axis] == Shape::SHP_ANY) { | |||
| output_hw[1] = Shape::SHP_ANY; | |||
| } | |||
| Conv2DPadFunction(&output_hw_min, &pad_list_min, x_min_shape[h_axis], x_min_shape[w_axis], kernel_size, stride, | |||
| dilation, pad_mode, padding); | |||
| Conv2DPadFunction(&output_hw_max, &pad_list_max, x_max_shape[h_axis], x_max_shape[w_axis], kernel_size, stride, | |||
| dilation, pad_mode, padding); | |||
| std::vector<ValuePtr> pad_list_val = {MakeValue(pad_list[0]), MakeValue(pad_list[1]), MakeValue(pad_list[2]), | |||
| MakeValue(pad_list[3])}; | |||
| primitive->set_attr("pad_list", MakeValue(pad_list_val)); | |||
| ShapeVector output_shape; | |||
| ShapeVector output_shape_min; | |||
| ShapeVector output_shape_max; | |||
| if (data_format == Format::NHWC) { | |||
| output_shape = {x_shape[n_axis], output_hw[0], output_hw[1], out_channel}; | |||
| output_shape_min = {x_min_shape[n_axis], output_hw_min[0], output_hw_min[1], out_channel}; | |||
| output_shape_max = {x_max_shape[n_axis], output_hw_max[0], output_hw_max[1], out_channel}; | |||
| } else { | |||
| output_shape = {x_shape[n_axis], out_channel, output_hw[0], output_hw[1]}; | |||
| output_shape_min = {x_min_shape[n_axis], out_channel, output_hw_min[0], output_hw_min[1]}; | |||
| output_shape_max = {x_max_shape[n_axis], out_channel, output_hw_max[0], output_hw_max[1]}; | |||
| } | |||
| CheckShapeAnyAndPositive(op_name + " output_shape", output_shape); | |||
| CheckShapeAllPositive(op_name + " output_shape_min", output_shape_min); | |||
| CheckShapeAllPositive(op_name + " output_shape_max", output_shape_max); | |||
| TypePtr x_type = input_x->element()->GetTypeTrack(); | |||
| if (x_type->type_id() == TypeId::kNumberTypeInt8) { | |||
| x_type = kInt32; | |||
| } | |||
| ShapePtr output_shape_ptr = std::make_shared<Shape>(output_shape, output_shape_min, output_shape_max); | |||
| return std::make_shared<AbstractTensor>(x_type, output_shape_ptr); | |||
| } | |||
| AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: at least one tensor(y_backprop) | |||
| @@ -186,7 +186,6 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||
| {prim::kPrimPooling, R{InferImplPooling, nullptr, true}}, | |||
| {prim::kPrimPoolingGrad, R{InferImplPoolingGrad, nullptr, true}}, | |||
| {prim::kPrimBatchNorm, R{InferImplBatchNorm, nullptr, true}}, | |||
| {prim::kPrimConv2D, R{InferImplConv2D, nullptr, true}}, | |||
| {prim::kPrimBpropCut, R{InferImplBpropCut, nullptr, true}}, | |||
| {prim::kPrimDropout, R{InferImplDropout, nullptr, true}}, | |||
| {prim::kPrimSparseApplyFtrl, R{InferImplSparseApplyFtrl, nullptr, true}}, | |||
| @@ -183,83 +183,6 @@ AbstractBasePtr SensitivityTransform(const AbstractBasePtr &spec) { | |||
| return spec->Clone(); | |||
| } | |||
| namespace { | |||
| // Join all types in args_type_list; | |||
| TypePtr TypeJoin(const TypePtrList &args_type_list) { | |||
| if (args_type_list.empty()) { | |||
| MS_LOG(EXCEPTION) << "args_type_list is empty"; | |||
| } | |||
| TypePtr type_tmp = args_type_list[0]; | |||
| for (std::size_t i = 1; i < args_type_list.size(); i++) { | |||
| type_tmp = abstract::TypeJoin(type_tmp, args_type_list[i]); | |||
| } | |||
| return type_tmp; | |||
| } | |||
| } // namespace | |||
| bool CheckType(const TypePtr &expected_type, const TypePtr &x) { | |||
| // As x and predicate both are mindspore type statically, here we only to judge whether | |||
| // x is predicate or is a subclass of predicate. | |||
| return IsIdentidityOrSubclass(x, expected_type); | |||
| } | |||
| TypePtr CheckTypeList(const TypePtr &predicate, const TypePtrList &args_type_list) { | |||
| MS_EXCEPTION_IF_NULL(predicate); | |||
| for (const auto &arg_type : args_type_list) { | |||
| MS_EXCEPTION_IF_NULL(arg_type); | |||
| if (!CheckType(predicate, arg_type)) { | |||
| MS_LOG(EXCEPTION) << "The expected is " << predicate->ToString() << ", not " << arg_type->ToString(); | |||
| } | |||
| } | |||
| return TypeJoin(args_type_list); | |||
| } | |||
| int64_t GetPositiveAxis(int64_t axis_value, size_t increment) { | |||
| if (axis_value < 0) { | |||
| axis_value = axis_value + SizeToLong(increment); | |||
| } | |||
| if (axis_value < 0) { | |||
| MS_LOG(EXCEPTION) << "axis_value should not still <0"; | |||
| } | |||
| return axis_value; | |||
| } | |||
| // Return if two shapes can be broadcast. | |||
| // Broadcast shape is placed in broadcast_output_shape. | |||
| ShapeVector RealBroadcast(const std::string &op, ShapeVector x_shape, ShapeVector y_shape) { | |||
| std::reverse(x_shape.begin(), x_shape.end()); | |||
| std::reverse(y_shape.begin(), y_shape.end()); | |||
| // Fill a placeholder value 1 which will be replaced later. | |||
| size_t std_len = x_shape.size() > y_shape.size() ? x_shape.size() : y_shape.size(); | |||
| y_shape.resize(std_len, 1); | |||
| x_shape.resize(std_len, 1); | |||
| ShapeVector broadcast_shape; | |||
| for (size_t i = 0; i < std_len; i++) { | |||
| int64_t x_i = x_shape[i]; // i-th dimension of x | |||
| int64_t y_i = y_shape[i]; // i-th dimension of y | |||
| int64_t output_i = 0; // i-th dimension of the output | |||
| if (x_i == y_i) { | |||
| output_i = x_i; | |||
| } else if (x_i == 1) { | |||
| output_i = y_i; | |||
| } else if (y_i == 1) { | |||
| output_i = x_i; | |||
| } else { | |||
| MS_LOG(EXCEPTION) | |||
| << op | |||
| << " evaluator the shape of first tensor and the shape of second tensor do not meet the broadcasting " | |||
| "requirements"; | |||
| } | |||
| broadcast_shape.push_back(output_i); | |||
| } | |||
| std::reverse(broadcast_shape.begin(), broadcast_shape.end()); | |||
| return broadcast_shape; | |||
| } | |||
| ShapeVector BroadcastShape(ShapeVector shpx, ShapeVector shpy) { | |||
| int dlen = SizeToInt(shpx.size()) - SizeToInt(shpy.size()); | |||
| if (dlen < 0) { | |||
| @@ -43,20 +43,11 @@ AbstractBasePtrList AbstractJoin(const AbstractBasePtrList &spec1, const Abstrac | |||
| // else self.Clone; | |||
| AbstractBasePtr SensitivityTransform(const AbstractBasePtr &spec); | |||
| TypePtr CheckTypeList(const TypePtr &predicate, const TypePtrList &args_type_list); | |||
| bool CheckType(const TypePtr &expected_type, const TypePtr &x); | |||
| int64_t GetPositiveAxis(int64_t axis_value, size_t increment); | |||
| ShapeVector BroadcastShape(ShapeVector shpx, ShapeVector shpy); | |||
| MS_CORE_API size_t TypeIdSize(const TypeId data_type); | |||
| size_t ShapeSize(const std::vector<size_t> &shape); | |||
| // Get broadcasted shape for binary element-wise operation | |||
| ShapePtr GetBroadcastShape(const std::string &op, const AbstractTensorPtr &tensor_x, const AbstractTensorPtr &tensor_y); | |||
| // Check dynamic shape routine | |||
| void CheckMinMaxShape(const ShapeVector &shape, ShapeVector *min_shape, ShapeVector *max_shape); | |||
| @@ -61,11 +61,10 @@ AbstractBasePtr LayerNormInfer(const abstract::AnalysisEnginePtr &, const Primit | |||
| // begin_norm_axis and begin_params_axis should be smaller than the size of input_x and >= -1 | |||
| ValuePtr bna_ptr = primitive->GetAttr("begin_norm_axis"); | |||
| int64_t begin_norm_axis = abstract::CheckAxis(op_name, bna_ptr, -1, SizeToLong(input_rank) - 1); | |||
| int64_t begin_norm_axis = abstract::CheckAxis(op_name, bna_ptr, -1, SizeToLong(input_rank)); | |||
| ValuePtr bpa_ptr = primitive->GetAttr("begin_params_axis"); | |||
| int64_t begin_params_axis = abstract::CheckAxis(op_name, bpa_ptr, -1, SizeToLong(input_rank) - 1); | |||
| begin_params_axis = abstract::GetPositiveAxis(begin_params_axis, input_rank); | |||
| int64_t begin_params_axis = abstract::CheckAxis(op_name, bpa_ptr, -1, SizeToLong(input_rank)); | |||
| // the beta and gama shape should be x_shape[begin_params_axis:] | |||
| auto valid_types = {kFloat16, kFloat32}; | |||
| @@ -104,7 +104,8 @@ TEST_F(TestComposite, test_TupleSlice_arg_two_numbers) { | |||
| engine_->Run(tupleSliceGraphPtr, args_spec_list); | |||
| FAIL() << "Excepted exception :Args type is wrong"; | |||
| } catch (std::runtime_error const &err) { | |||
| ASSERT_TRUE(std::string(err.what()).find("TupleSlice input args size should be 2, but got 3") != std::string::npos); | |||
| ASSERT_TRUE(std::string(err.what()).find("TupleSlice input arguments size should be 2, but got 3") != | |||
| std::string::npos); | |||
| } catch (...) { | |||
| FAIL() << "Excepted exception :Args type is wrong"; | |||
| } | |||
| @@ -250,7 +251,7 @@ TEST_F(TestComposite, test_UnpackCall_3args) { | |||
| MetaFuncGraphPtr unPackCallPtr = std::make_shared<prim::UnpackCall>("UnPackCall"); | |||
| FuncGraphPtr unPackCallGraphPtr = UTCompositeUtils::MakeFuncGraph(unPackCallPtr, 3); | |||
| auto fn_arg= std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimMakeTuple); | |||
| auto fn_arg = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimMakeTuple); | |||
| AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4}); | |||
| AbstractBasePtrList eles; | |||
| for (size_t i = 0; i < 6; i++) { | |||