| @@ -20,14 +20,30 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| namespace { | |||
| const size_t kInputNum = 2; | |||
| const int kInputNum = 2; | |||
| const size_t one = 1; | |||
| void UpdatePreIsOne(std::vector<bool> *prev_is_one, std::vector<bool> current_is_one) { | |||
| for (size_t i = 0; i < kInputNum; ++i) { | |||
| (*prev_is_one)[i] = current_is_one[i]; | |||
| } | |||
| } | |||
| void AddElementToGradReduceIdx(std::vector<std::vector<int64_t>> *grad_reduce_idx, std::vector<bool> current_is_one, | |||
| bool none_is_one, const size_t largest_rank, size_t j) { | |||
| MS_EXCEPTION_IF_NULL(grad_reduce_idx); | |||
| for (size_t i = 0; i < kInputNum; ++i) { | |||
| if (current_is_one[i] && !none_is_one) { | |||
| (void)(*grad_reduce_idx)[i].emplace_back(SizeToLong(largest_rank - one - j)); | |||
| } | |||
| } | |||
| } | |||
| std::vector<std::vector<int64_t>> GetGradientIndices(const std::vector<std::vector<int64_t>> &reverse_shape, | |||
| const size_t largest_rank) { | |||
| std::vector<std::vector<int64_t>> grad_reduce_idx(kInputNum); | |||
| // indices of j-th component of each input. | |||
| bool prev_is_one[kInputNum]; | |||
| bool current_is_one[kInputNum]; | |||
| std::vector<bool> prev_is_one(kInputNum); | |||
| std::vector<bool> current_is_one(kInputNum); | |||
| for (size_t i = 0; i < kInputNum; ++i) { | |||
| prev_is_one[i] = false; | |||
| current_is_one[i] = false; | |||
| @@ -46,37 +62,26 @@ std::vector<std::vector<int64_t>> GetGradientIndices(const std::vector<std::vect | |||
| } else { | |||
| current_is_one[i] = false; | |||
| if (!output_dim_set || reverse_shape[i][j] == static_cast<int64_t>(output_dim)) { | |||
| output_dim = static_cast<int>(reverse_shape[i][j]); | |||
| output_dim = LongToInt(reverse_shape[i][j]); | |||
| output_dim_set = true; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Input[0] and input[1] Cannot broadcast!"; | |||
| } | |||
| } | |||
| } | |||
| // All dimensions are 1. | |||
| if (!output_dim_set) { | |||
| for (size_t i = 0; i < kInputNum; ++i) { | |||
| (void)grad_reduce_idx[i].emplace_back(largest_rank - 1 - j); | |||
| (void)grad_reduce_idx[i].emplace_back(SizeToLong(largest_rank - one - j)); | |||
| } | |||
| continue; | |||
| } else if (std::equal(current_is_one, current_is_one + kInputNum, prev_is_one) && set_one) { | |||
| for (size_t i = 0; i < kInputNum; ++i) { | |||
| if (current_is_one[i] && !none_is_one) { | |||
| (void)grad_reduce_idx[i].emplace_back(largest_rank - 1 - j); | |||
| } | |||
| } | |||
| } else if (std::equal(current_is_one.begin(), current_is_one.end(), prev_is_one.begin()) && set_one) { | |||
| AddElementToGradReduceIdx(&grad_reduce_idx, current_is_one, none_is_one, largest_rank, j); | |||
| } else { | |||
| for (size_t i = 0; i < kInputNum; ++i) { | |||
| if (current_is_one[i] && !none_is_one) { | |||
| (void)grad_reduce_idx[i].emplace_back(largest_rank - 1 - j); | |||
| } | |||
| } | |||
| AddElementToGradReduceIdx(&grad_reduce_idx, current_is_one, none_is_one, largest_rank, j); | |||
| } | |||
| set_one = true; | |||
| for (size_t i = 0; i < kInputNum; ++i) { | |||
| prev_is_one[i] = current_is_one[i]; | |||
| } | |||
| UpdatePreIsOne(&prev_is_one, current_is_one); | |||
| } | |||
| return grad_reduce_idx; | |||
| } | |||
| @@ -172,9 +177,10 @@ size_t SetOutputValue(const CNodePtr &cnode, const std::vector<std::vector<int64 | |||
| *(data_ptr + i) = output[i]; | |||
| } | |||
| (void)out_addr->SyncHostToDevice(out_shape, LongToSize(tensor_for_sync->data().nbytes()), | |||
| tensor_for_sync->data_type(), tensor_for_sync->data_c(), | |||
| tensor_for_sync->device_info().host_format_); | |||
| if (!out_addr->SyncHostToDevice(out_shape, LongToSize(tensor_for_sync->data().nbytes()), tensor_for_sync->data_type(), | |||
| tensor_for_sync->data_c(), tensor_for_sync->device_info().host_format_)) { | |||
| MS_LOG(EXCEPTION) << "Output Value SyncHostToDevice failed."; | |||
| } | |||
| return out_size; | |||
| } | |||
| } // namespace | |||
| @@ -30,7 +30,7 @@ namespace kernel { | |||
| void GetRealInputSize(const nlohmann::json &input_json, std::vector<size_t> *input_size_list, size_t *size_i) { | |||
| if (input_json[kJShape].size() == 1 && input_json[kJShape][0] == -2) { | |||
| auto input_max_shape = input_json[kJShape]; | |||
| auto input_max_shape = input_json[kJRange]; | |||
| for (auto &max_shape : input_max_shape) { | |||
| (*size_i) *= LongToSize(max_shape[1]); | |||
| } | |||
| @@ -77,7 +77,7 @@ void GetInputSizeList(const nlohmann::json &input_json, std::vector<size_t> *inp | |||
| void GetRealOutputSize(const nlohmann::json &output_json, std::vector<size_t> *output_size_list, size_t *size_i) { | |||
| if (output_json[kJShape].size() == 1 && output_json[kJShape][0] == -2) { | |||
| auto output_max_shape = output_json[kJShape]; | |||
| auto output_max_shape = output_json[kJRange]; | |||
| for (auto &max_shape : output_max_shape) { | |||
| (*size_i) *= LongToSize(max_shape[1]); | |||
| } | |||
| @@ -122,10 +122,9 @@ bool CheckIndexOutput(const CNodePtr &node, const std::shared_ptr<kernel::Kernel | |||
| void ChangeNodeInferInfo(const CNodePtr &cnode, const CNodePtr &cast, const size_t cast_index) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| MS_EXCEPTION_IF_NULL(cast); | |||
| using Shape = std::vector<size_t>; | |||
| auto cast_dtype = AnfAlgo::GetOutputInferDataType(cast, 0); | |||
| auto cast_shape = AnfAlgo::GetOutputInferShape(cast, 0); | |||
| std::vector<Shape> shapes; | |||
| auto cast_shape = AnfAlgo::GetOutputDetailShape(cast, 0); | |||
| std::vector<abstract::BaseShapePtr> shapes; | |||
| std::vector<TypeId> types; | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(cnode); | |||
| for (size_t index = 0; index < output_num; ++index) { | |||
| @@ -134,10 +133,10 @@ void ChangeNodeInferInfo(const CNodePtr &cnode, const CNodePtr &cast, const size | |||
| (void)types.emplace_back(cast_dtype); | |||
| continue; | |||
| } | |||
| (void)shapes.emplace_back(AnfAlgo::GetOutputInferShape(cnode, index)); | |||
| (void)shapes.emplace_back(AnfAlgo::GetOutputDetailShape(cnode, index)); | |||
| (void)types.emplace_back(AnfAlgo::GetOutputInferDataType(cnode, index)); | |||
| } | |||
| AnfAlgo::SetOutputInferTypeAndShape(types, shapes, cnode.get()); | |||
| AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, cnode.get()); | |||
| auto prim_op = AnfAlgo::GetCNodePrimitive(cnode); | |||
| if (prim_op != nullptr) { | |||
| (void)prim_op->AddAttr("cast_type", TypeIdToType(cast_dtype)); | |||
| @@ -48,6 +48,7 @@ std::string OpTilingCalculateAdapter::GetRealOpType(const std::string &op_type) | |||
| {"Softmax", "SoftmaxV2"}, | |||
| {"DropoutDoMask", "DropOutDoMask"}, | |||
| {"IOU", "Iou"}, | |||
| {"DynamicBroadcastTo", "BroadcastTo"}, | |||
| }; | |||
| auto iter = kOpTypeMap.find(op_type); | |||
| if (iter == kOpTypeMap.end()) { | |||
| @@ -110,18 +110,34 @@ TypePtr CheckScalarType(const AbstractScalarPtr &scalar, const TypePtrList &acce | |||
| return CheckType(type, accepts, error_message_prefix); | |||
| } | |||
| ShapePtr CheckShapeSame(const std::string &op, const AbstractTensorPtr &tensor_base, const AbstractTensorPtr &tensor) { | |||
| void CheckShapeSame(const std::string &op, const AbstractTensorPtr &tensor_base, const AbstractTensorPtr &tensor) { | |||
| MS_EXCEPTION_IF_NULL(tensor_base); | |||
| ShapePtr shape_base = tensor_base->shape(); | |||
| MS_EXCEPTION_IF_NULL(shape_base); | |||
| MS_EXCEPTION_IF_NULL(tensor); | |||
| ShapePtr shape = tensor->shape(); | |||
| MS_EXCEPTION_IF_NULL(shape); | |||
| if (*shape != *shape_base) { | |||
| if (shape_base->IsDimUnknown() || shape->IsDimUnknown()) { | |||
| return; | |||
| } | |||
| auto shape_vector = shape->shape(); | |||
| auto shape_base_vector = shape_base->shape(); | |||
| if (shape_vector.size() != shape_base_vector.size()) { | |||
| MS_LOG(EXCEPTION) << op << " evaluator first arg shape " << shape->ToString() | |||
| << " are not consistent with second arg shape " << shape_base->ToString(); | |||
| } | |||
| return shape_base; | |||
| for (size_t i = 0; i < shape_vector.size(); i++) { | |||
| if (shape_vector[i] == Shape::SHP_ANY || shape_base_vector[i] == Shape::SHP_ANY) { | |||
| continue; | |||
| } | |||
| if (shape_vector[i] != shape_base_vector[i]) { | |||
| MS_LOG(EXCEPTION) << op << " evaluator first arg shape " << shape->ToString() | |||
| << " are not consistent with second arg shape " << shape_base->ToString(); | |||
| } | |||
| } | |||
| return; | |||
| } | |||
| TypePtr CheckDtypeSame(const std::string &op, const AbstractTensorPtr &tensor_base, const AbstractTensorPtr &tensor) { | |||
| @@ -41,7 +41,7 @@ TypePtr CheckTensorsDTypeSame(const AbstractTensorPtrList &tensor_list, const Ty | |||
| TypePtr CheckScalarType(const AbstractScalarPtr &scalar, const TypePtrList &accepts, | |||
| const std::string &error_message_prefix); | |||
| ShapePtr CheckShapeSame(const std::string &op, const AbstractTensorPtr &tensor_base, const AbstractTensorPtr &tensor); | |||
| void CheckShapeSame(const std::string &op, const AbstractTensorPtr &tensor_base, const AbstractTensorPtr &tensor); | |||
| TypePtr CheckDtypeSame(const std::string &op, const AbstractTensorPtr &tensor_base, const AbstractTensorPtr &tensor); | |||
| @@ -114,7 +114,7 @@ AbstractBasePtr InferImplStack(const AnalysisEnginePtr &, const PrimitivePtr &pr | |||
| for (size_t i = 1; i < tuple_len; ++i) { | |||
| AbstractTensorPtr tensor = CheckArg<AbstractTensor>(op_name, arg->elements(), i); | |||
| (void)CheckDtypeSame(op_name, tensor_base, tensor); | |||
| (void)CheckShapeSame(op_name, tensor_base, tensor); | |||
| CheckShapeSame(op_name, tensor_base, tensor); | |||
| } | |||
| auto element = tensor_base->element(); | |||
| MS_EXCEPTION_IF_NULL(element); | |||
| @@ -1241,6 +1241,7 @@ AbstractBasePtr InferImplMaskedSelect(const AnalysisEnginePtr &, const Primitive | |||
| AbstractBasePtr InferImplDynamicStitch(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| bool output_shape_unknow = false; | |||
| auto prim_name = primitive->name(); | |||
| constexpr int64_t args_size = 2; | |||
| (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(args_spec_list.size()), kEqual, args_size, | |||
| @@ -1253,6 +1254,22 @@ AbstractBasePtr InferImplDynamicStitch(const AnalysisEnginePtr &, const Primitiv | |||
| auto input_tuple = args_spec_list[0]->cast<abstract::AbstractSequeuePtr>(); | |||
| MS_EXCEPTION_IF_NULL(input_tuple); | |||
| auto indices = input_tuple->elements(); | |||
| auto input_indice_size = input_tuple->size(); | |||
| int64_t first_dim_size = 0; | |||
| for (size_t i = 0; i < input_indice_size; i++) { | |||
| auto indicei = indices[i]->cast<abstract::AbstractTensorPtr>(); | |||
| MS_EXCEPTION_IF_NULL(indicei); | |||
| auto valuei = indicei->BuildValue(); | |||
| MS_EXCEPTION_IF_NULL(valuei); | |||
| if (!valuei->isa<tensor::Tensor>()) { | |||
| output_shape_unknow = true; | |||
| continue; | |||
| } | |||
| auto indicei_value = CheckAndConvertUtils::CheckTensorIntValue("indices", valuei, prim_name); | |||
| auto indicei_max = std::max_element(indicei_value.begin(), indicei_value.end()); | |||
| first_dim_size = *indicei_max > first_dim_size ? *indicei_max : first_dim_size; | |||
| } | |||
| auto indices0 = indices[0]->cast<abstract::AbstractTensorPtr>(); | |||
| MS_EXCEPTION_IF_NULL(indices0); | |||
| auto indices0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(indices0->BuildShape())[kShape]; | |||
| @@ -1282,7 +1299,12 @@ AbstractBasePtr InferImplDynamicStitch(const AnalysisEnginePtr &, const Primitiv | |||
| std::set<TypePtr> valid_types = ops::common_valid_types; | |||
| auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim_name); | |||
| ShapeVector out_shape = {abstract::Shape::SHP_ANY}; | |||
| ShapeVector out_shape; | |||
| if (output_shape_unknow) { | |||
| out_shape.push_back(abstract::Shape::SHP_ANY); | |||
| } else { | |||
| out_shape.push_back(first_dim_size + 1); | |||
| } | |||
| for (size_t i = indices0_shape.size(); i < data0_shape.size(); ++i) { | |||
| out_shape.push_back(data0_shape[i]); | |||
| } | |||
| @@ -60,7 +60,7 @@ AbstractBasePtr InferImplSqrtGrad(const AnalysisEnginePtr &, const PrimitivePtr | |||
| auto out = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| auto dout = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); | |||
| (void)CheckDtypeSame(op_name, out, dout); | |||
| (void)CheckShapeSame(op_name, out, dout); | |||
| CheckShapeSame(op_name, out, dout); | |||
| return out->Broaden(); | |||
| } | |||
| @@ -47,6 +47,7 @@ std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode) { | |||
| const auto kStridedSlice = prim::kPrimStridedSlice->name(); | |||
| const auto kStridedSliceGrad = prim::kPrimStridedSliceGrad->name(); | |||
| const auto kReduceSum = prim::kPrimReduceSum->name(); | |||
| const auto kDynamicBroadcastTo = prim::kPrimDynamicBroadcastTo->name(); | |||
| const auto kUnsortedSegmentSum = prim::kPrimUnsortedSegmentSum->name(); | |||
| const auto kUnsortedSegmentMin = prim::kPrimUnsortedSegmentMin->name(); | |||
| const auto kUnsortedSegmentMax = prim::kPrimUnsortedSegmentMax->name(); | |||
| @@ -78,7 +79,8 @@ std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode) { | |||
| {kTile, {1}}, | |||
| {kReshape, {1}}, | |||
| {kSlice, {1, 2}}, | |||
| {kSliceGrad, {2, 3}}}; | |||
| {kSliceGrad, {2, 3}}, | |||
| {kDynamicBroadcastTo, {1}}}; | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| @@ -95,6 +95,7 @@ constexpr auto kDiagPart = "DiagPart"; | |||
| constexpr auto kDynamicBroadcastGradientArgs = "DynamicBroadcastGradientArgs"; | |||
| constexpr auto kTranspose = "Transpose"; | |||
| constexpr auto kSplitV = "SplitV"; | |||
| constexpr auto kDynamicBroadcastTo = "DynamicBroadcastTo"; | |||
| // NN | |||
| constexpr auto kCTCLoss = "CTCLoss"; | |||
| @@ -170,6 +171,7 @@ inline const PrimitivePtr kPrimStackPush = std::make_shared<Primitive>("StackPus | |||
| inline const PrimitivePtr kPrimStackPop = std::make_shared<Primitive>("StackPop"); | |||
| // Arrays | |||
| inline const PrimitivePtr kPrimDynamicBroadcastTo = std::make_shared<Primitive>(kDynamicBroadcastTo); | |||
| inline const PrimitivePtr kPrimBroadcastTo = std::make_shared<Primitive>("BroadcastTo"); | |||
| inline const PrimitivePtr kPrimScalarToArray = std::make_shared<Primitive>("scalar_to_array"); | |||
| inline const PrimitivePtr kPrimTopK = std::make_shared<Primitive>("TopK"); | |||
| @@ -73,8 +73,8 @@ abstract::ShapePtr AudioSpectrogramInferShape(const PrimitivePtr &primitive, | |||
| } | |||
| TypePtr AudioSpectrogramInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||
| const int64_t x_index = 0; | |||
| return CheckAndConvertUtils::GetInputTensorType(input_args, x_index, prim->name()); | |||
| const size_t x_index = 0; | |||
| return CheckAndConvertUtils::GetTensorInputType(prim->name(), input_args, x_index); | |||
| } | |||
| } // namespace | |||
| @@ -115,7 +115,20 @@ TypePtr BatchMatmulInferType(const PrimitivePtr &prim, const std::vector<Abstrac | |||
| std::map<std::string, TypePtr> types; | |||
| (void)types.emplace("x", input_args[0]->BuildType()); | |||
| (void)types.emplace("w", input_args[1]->BuildType()); | |||
| return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); | |||
| (void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); | |||
| TypePtr x_type = input_args[0]->BuildType(); | |||
| if (x_type->type_id() == TypeId::kNumberTypeInt8) { | |||
| x_type = kInt32; | |||
| } | |||
| if (prim->HasAttr("cast_type")) { | |||
| auto out_type = prim->GetAttr("cast_type"); | |||
| MS_EXCEPTION_IF_NULL(out_type); | |||
| if (!out_type->isa<Type>()) { | |||
| MS_EXCEPTION(ValueError) << "MatMul cast_type must be a `Type`"; | |||
| } | |||
| x_type = out_type->cast<TypePtr>(); | |||
| } | |||
| return x_type; | |||
| } | |||
| } // namespace | |||
| @@ -48,17 +48,19 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A | |||
| auto x_shape_vector = x_shape->shape(); | |||
| auto mask_shape_vector = mask_shape->shape(); | |||
| int64_t x_size = 1; | |||
| for (size_t i = 0; i < x_shape_vector.size(); i++) { | |||
| x_size *= x_shape_vector[i]; | |||
| } | |||
| if (mask_shape_vector.size() != 1) { | |||
| MS_EXCEPTION(ValueError) << "DropoutDoMask input mask must be 1-dimension."; | |||
| } | |||
| auto mask_size = mask_shape_vector[0] * 8; | |||
| if (x_size > mask_size) { | |||
| MS_EXCEPTION(ValueError) << "DropoutDoMask input mask do not match input, input_x shape: " << x_shape->ToString() | |||
| << ", mask shape: " << mask_shape->ToString(); | |||
| if (!x_shape->IsDynamic() && !mask_shape->IsDynamic()) { | |||
| int64_t x_size = 1; | |||
| for (size_t i = 0; i < x_shape_vector.size(); i++) { | |||
| x_size *= x_shape_vector[i]; | |||
| } | |||
| if (mask_shape_vector.size() != 1) { | |||
| MS_EXCEPTION(ValueError) << "DropoutDoMask input mask must be 1-dimension."; | |||
| } | |||
| auto mask_size = mask_shape_vector[0] * 8; | |||
| if (x_size > mask_size) { | |||
| MS_EXCEPTION(ValueError) << "DropoutDoMask input mask do not match input, input_x shape: " << x_shape->ToString() | |||
| << ", mask shape: " << mask_shape->ToString(); | |||
| } | |||
| } | |||
| auto keep_prop = input_args[kInputIndex2]; | |||
| if (keep_prop->isa<abstract::AbstractTensor>()) { | |||
| @@ -86,7 +86,6 @@ ShapeVector CalOutputShape(const AbstractBasePtrList shape_list) { | |||
| } | |||
| count = count * value; | |||
| } | |||
| // convert to bytes(8 bits) mask, using round up | |||
| int64_t n128s = count / mask_convert_len; | |||
| if ((count % mask_convert_len) != 0) { | |||
| @@ -106,7 +105,18 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A | |||
| AbstractBasePtr shape_args = input_args[0]; | |||
| MS_EXCEPTION_IF_NULL(shape_args); | |||
| ShapeVector out_shape; | |||
| if (shape_args->isa<abstract::AbstractTensor>()) { | |||
| auto shape_value = shape_args->BuildValue(); | |||
| MS_EXCEPTION_IF_NULL(shape_value); | |||
| if (shape_value->isa<tensor::Tensor>()) { | |||
| auto mask_shape = CheckAndConvertUtils::CheckTensorIntValue("shape", shape_value, op_name); | |||
| std::vector<ValuePtr> value_elements; | |||
| std::transform(mask_shape.begin(), mask_shape.end(), std::back_inserter(value_elements), | |||
| [](int64_t elem) { return MakeValue(elem); }); | |||
| out_shape = CalDynamicOutputShape(value_elements); | |||
| return std::make_shared<abstract::Shape>(out_shape); | |||
| } | |||
| auto shape_abstract = dyn_cast<abstract::AbstractTensor>(shape_args); | |||
| MS_EXCEPTION_IF_NULL(shape_abstract); | |||
| auto shape_base = shape_abstract->BuildShape(); | |||
| @@ -139,7 +149,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A | |||
| auto x_shape = dyn_cast<abstract::AbstractTuple>(shape_args); | |||
| auto x_shape_data = x_shape->elements(); | |||
| ShapeVector out_shape = CalOutputShape(x_shape_data); | |||
| out_shape = CalOutputShape(x_shape_data); | |||
| return std::make_shared<abstract::Shape>(out_shape); | |||
| } | |||
| TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| @@ -0,0 +1,91 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ops/dynamic_broadcast_to.h" | |||
| #include <set> | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| namespace { | |||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto prim_name = primitive->name(); | |||
| CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 2, prim_name); | |||
| auto input_y = input_args[1]; | |||
| MS_EXCEPTION_IF_NULL(input_y); | |||
| abstract::ShapePtr y_shape; | |||
| auto y_value = input_y->BuildValue(); | |||
| MS_EXCEPTION_IF_NULL(y_value); | |||
| if (input_y->isa<abstract::AbstractTensor>()) { | |||
| if (y_value->isa<tensor::Tensor>()) { | |||
| auto shape_value = CheckAndConvertUtils::CheckTensorIntValue("shape", y_value, prim_name); | |||
| return std::make_shared<abstract::Shape>(shape_value); | |||
| } | |||
| y_shape = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, 1); | |||
| auto shape_value = y_shape->shape(); | |||
| if (shape_value.size() != 1) { | |||
| MS_EXCEPTION(TypeError) << "shape size error: " << shape_value.size(); | |||
| } | |||
| std::vector<int64_t> output_shape; | |||
| std::vector<int64_t> max_shape; | |||
| std::vector<int64_t> min_shape; | |||
| if (y_shape->IsDynamic()) { | |||
| // max shape unknown | |||
| output_shape.push_back(-2); | |||
| } else { | |||
| auto out_dims = LongToSize(y_shape->shape()[0]); | |||
| for (size_t i = 0; i < out_dims; i++) { | |||
| output_shape.push_back(-1); | |||
| } | |||
| auto min_value = input_y->cast<abstract::AbstractTensorPtr>()->get_min_value(); | |||
| auto max_value = input_y->cast<abstract::AbstractTensorPtr>()->get_max_value(); | |||
| if (!min_value || !max_value) { | |||
| MS_EXCEPTION(ValueError) << "For BroadcastTo, inputs['shape'] min or max value is empty."; | |||
| } | |||
| min_shape = GetValue<std::vector<int64_t>>(min_value); | |||
| max_shape = GetValue<std::vector<int64_t>>(max_value); | |||
| if (min_shape.size() != out_dims || max_shape.size() != out_dims) { | |||
| MS_EXCEPTION(ValueError) << "For BroadcastTo, inputs['shape'] min or max value not match with out dims."; | |||
| } | |||
| } | |||
| return std::make_shared<abstract::Shape>(output_shape, min_shape, max_shape); | |||
| } else if (input_y->isa<abstract::AbstractTuple>()) { | |||
| auto out_shape = GetValue<std::vector<int64_t>>(y_value); | |||
| return std::make_shared<abstract::Shape>(out_shape); | |||
| } | |||
| MS_EXCEPTION(TypeError) << "For BroadcastTo, input args must be tensor or tuple."; | |||
| } | |||
| TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||
| for (const auto &item : input_args) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| } | |||
| auto x_dtype = input_args[0]->BuildType()->cast<TensorTypePtr>(); | |||
| std::set<TypePtr> template_types = {kTensorType}; | |||
| CheckAndConvertUtils::CheckSubClass("x_dtype", x_dtype, template_types, prim->name()); | |||
| return x_dtype->element(); | |||
| } | |||
| } // namespace | |||
| AbstractBasePtr DynamicBroadcastToInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args)); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(DynamicBroadcastTo, prim::kPrimDynamicBroadcastTo, DynamicBroadcastToInfer, nullptr, true); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,43 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_OPS_DYNAMIC_BROADCAST_TO_H_ | |||
| #define MINDSPORE_CORE_OPS_DYNAMIC_BROADCAST_TO_H_ | |||
| #include <map> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "ops/op_utils.h" | |||
| #include "ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| class DynamicBroadcastTo : public PrimitiveC { | |||
| public: | |||
| DynamicBroadcastTo() : PrimitiveC(prim::kPrimDynamicBroadcastTo->name()) { InitIOName({"x", "shape"}, {"y"}); } | |||
| ~DynamicBroadcastTo() = default; | |||
| MS_DECLARE_PARENT(DynamicBroadcastTo, PrimitiveC); | |||
| void Init() {} | |||
| }; | |||
| AbstractBasePtr DynamicBroadcastToInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimDynamicBroadcastToPtr = std::shared_ptr<DynamicBroadcastTo>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_OPS_DYNAMIC_BROADCAST_TO_H_ | |||
| @@ -48,8 +48,8 @@ AbstractBasePtr ExpandDimsInfer(const abstract::AnalysisEnginePtr &, const Primi | |||
| (void)out_shape.insert(out_shape.begin() + dim_val, 1, 1); | |||
| // Infer type | |||
| const int64_t x_index = 0; | |||
| auto x_type = CheckAndConvertUtils::GetInputTensorType(input_args, x_index, prim_name); | |||
| const size_t x_index = 0; | |||
| auto x_type = CheckAndConvertUtils::GetTensorInputType(prim_name, input_args, x_index); | |||
| std::set<TypePtr> valid_x_type = {kTensorType}; | |||
| (void)CheckAndConvertUtils::CheckSubClass("x_type", x_type, valid_x_type, prim_name); | |||
| return std::make_shared<abstract::AbstractTensor>(x_type, out_shape); | |||
| @@ -24,9 +24,9 @@ | |||
| namespace mindspore { | |||
| namespace ops { | |||
| namespace { | |||
| constexpr int64_t kDoutIndex = 0; | |||
| constexpr int64_t kInputIndex = 1; | |||
| constexpr int64_t kFilterSizeIdex = 2; | |||
| constexpr size_t kDoutIndex = 0; | |||
| constexpr size_t kInputIndex = 1; | |||
| constexpr size_t kFilterSizeIdex = 2; | |||
| constexpr size_t kStride2dSize = 2; | |||
| constexpr size_t kStride4dSize = 4; | |||
| @@ -27,7 +27,7 @@ namespace ops { | |||
| namespace { | |||
| constexpr size_t kDoutIndex = 0; | |||
| constexpr size_t kInputIndex = 1; | |||
| constexpr int64_t kSizeIndex = 2; | |||
| constexpr size_t kSizeIndex = 2; | |||
| void SetPadList(const PrimitivePtr &primitive, const std::vector<int64_t> &dout_shape_norm, | |||
| const std::vector<int64_t> &x_size_v) { | |||
| @@ -44,8 +44,8 @@ AbstractBasePtr DropoutGradInfer(const abstract::AnalysisEnginePtr &, const Prim | |||
| const int64_t input_num = 2; | |||
| CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, op_name); | |||
| const int64_t dy_index = 0; | |||
| const int64_t mask_index = 1; | |||
| const size_t dy_index = 0; | |||
| const size_t mask_index = 1; | |||
| auto dy_type = input_args[dy_index]->BuildType(); | |||
| auto mask_type = input_args[mask_index]->BuildType(); | |||
| @@ -37,7 +37,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A | |||
| } | |||
| auto dout = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0); | |||
| auto out = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 1); | |||
| (void)abstract::CheckShapeSame(prim_name, out, dout); | |||
| abstract::CheckShapeSame(prim_name, out, dout); | |||
| auto x = input_args[0]->BuildShape(); | |||
| MS_EXCEPTION_IF_NULL(x); | |||
| auto shape_element = x->cast<abstract::ShapePtr>(); | |||
| @@ -124,7 +124,7 @@ AbstractBasePtr StridedSliceGradInfer(const abstract::AnalysisEnginePtr &, const | |||
| StridedSliceGradInferType(primitive, input_args)); | |||
| } | |||
| void StridedSliceGrad::set_begin_mask(const int64_t begin_mask) { | |||
| void StridedSliceGrad::set_begin_mask(int64_t begin_mask) { | |||
| (void)CheckAndConvertUtils::CheckInteger(kBeginMask, begin_mask, kGreaterEqual, 0, this->name()); | |||
| (void)this->AddAttr(kBeginMask, MakeValue(begin_mask)); | |||
| } | |||
| @@ -133,7 +133,7 @@ int64_t StridedSliceGrad::get_begin_mask() const { | |||
| MS_EXCEPTION_IF_NULL(value_ptr); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| void StridedSliceGrad::set_end_mask(const int64_t end_mask) { | |||
| void StridedSliceGrad::set_end_mask(int64_t end_mask) { | |||
| (void)CheckAndConvertUtils::CheckInteger(kEndMask, end_mask, kGreaterEqual, 0, this->name()); | |||
| (void)this->AddAttr(kEndMask, MakeValue(end_mask)); | |||
| } | |||
| @@ -141,7 +141,7 @@ int64_t StridedSliceGrad::get_end_mask() const { | |||
| auto value_ptr = GetAttr(kEndMask); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| void StridedSliceGrad::set_ellipsis_mask(const int64_t ellipsis_mask) { | |||
| void StridedSliceGrad::set_ellipsis_mask(int64_t ellipsis_mask) { | |||
| (void)CheckAndConvertUtils::CheckInteger(kEllipsisMask, ellipsis_mask, kGreaterEqual, 0, this->name()); | |||
| std::bitset<sizeof(int64_t) * 8> bs(ellipsis_mask); | |||
| std::ostringstream buffer; | |||
| @@ -155,7 +155,7 @@ int64_t StridedSliceGrad::get_ellipsis_mask() const { | |||
| auto value_ptr = GetAttr(kEllipsisMask); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| void StridedSliceGrad::set_new_axis_mask(const int64_t new_axis_mask) { | |||
| void StridedSliceGrad::set_new_axis_mask(int64_t new_axis_mask) { | |||
| (void)CheckAndConvertUtils::CheckInteger(kNewAxisMask, new_axis_mask, kGreaterEqual, 0, this->name()); | |||
| (void)this->AddAttr(kNewAxisMask, MakeValue(new_axis_mask)); | |||
| } | |||
| @@ -163,7 +163,7 @@ int64_t StridedSliceGrad::get_new_axis_mask() const { | |||
| auto value_ptr = GetAttr(kNewAxisMask); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| void StridedSliceGrad::set_shrink_axis_mask(const int64_t shrink_axis_mask) { | |||
| void StridedSliceGrad::set_shrink_axis_mask(int64_t shrink_axis_mask) { | |||
| (void)CheckAndConvertUtils::CheckInteger(kShrinkAxisMask, shrink_axis_mask, kGreaterEqual, 0, this->name()); | |||
| (void)this->AddAttr(kShrinkAxisMask, MakeValue(shrink_axis_mask)); | |||
| } | |||
| @@ -171,8 +171,8 @@ int64_t StridedSliceGrad::get_shrink_axis_mask() const { | |||
| auto value_ptr = GetAttr(kShrinkAxisMask); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| void StridedSliceGrad::Init(const int64_t begin_mask, const int64_t end_mask, const int64_t ellipsis_mask, | |||
| const int64_t new_axis_mask, const int64_t shrink_axis_mask) { | |||
| void StridedSliceGrad::Init(int64_t begin_mask, int64_t end_mask, int64_t ellipsis_mask, int64_t new_axis_mask, | |||
| int64_t shrink_axis_mask) { | |||
| this->set_begin_mask(begin_mask); | |||
| this->set_end_mask(end_mask); | |||
| this->set_ellipsis_mask(ellipsis_mask); | |||
| @@ -34,13 +34,13 @@ class MS_CORE_API StridedSliceGrad : public PrimitiveC { | |||
| ~StridedSliceGrad() = default; | |||
| MS_DECLARE_PARENT(StridedSliceGrad, PrimitiveC); | |||
| void Init(const int64_t begin_mask = 0, const int64_t end_mask = 0, const int64_t ellipsis_mask = 0, | |||
| const int64_t new_axis_mask = 0, const int64_t shrink_axis_mask = 0); | |||
| void set_begin_mask(const int64_t begin_mask); | |||
| void set_end_mask(const int64_t end_mask); | |||
| void set_ellipsis_mask(const int64_t ellipsis_mask); | |||
| void set_new_axis_mask(const int64_t new_axis_mask); | |||
| void set_shrink_axis_mask(const int64_t shrink_axis_mask); | |||
| void Init(int64_t begin_mask = 0, int64_t end_mask = 0, int64_t ellipsis_mask = 0, int64_t new_axis_mask = 0, | |||
| int64_t shrink_axis_mask = 0); | |||
| void set_begin_mask(int64_t begin_mask); | |||
| void set_end_mask(int64_t end_mask); | |||
| void set_ellipsis_mask(int64_t ellipsis_mask); | |||
| void set_new_axis_mask(int64_t new_axis_mask); | |||
| void set_shrink_axis_mask(int64_t shrink_axis_mask); | |||
| int64_t get_begin_mask() const; | |||
| int64_t get_end_mask() const; | |||
| int64_t get_ellipsis_mask() const; | |||
| @@ -29,7 +29,7 @@ namespace ops { | |||
| namespace { | |||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| auto op_name = primitive->name(); | |||
| (void)CheckAndConvertUtils::CheckInteger("infer_shape", int64_t(input_args.size()), kGreaterEqual, 1, op_name); | |||
| (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, 1, op_name); | |||
| return CheckAndConvertUtils::GetTensorInputShape(op_name, input_args, 0); | |||
| } | |||
| @@ -33,9 +33,9 @@ AbstractBasePtr QuantDTypeCastInfer(const abstract::AnalysisEnginePtr &, const P | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| const int64_t input_num = 1; | |||
| const int64_t x_index = 0; | |||
| const size_t x_index = 0; | |||
| CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, primitive->name()); | |||
| auto input_type = CheckAndConvertUtils::GetInputTensorType(input_args, x_index, primitive->name()); | |||
| auto input_type = CheckAndConvertUtils::GetTensorInputType(primitive->name(), input_args, x_index); | |||
| auto dst_type = TypeIdToType(TypeId(GetValue<int64_t>(primitive->GetAttr(kDstT)))); | |||
| MS_EXCEPTION_IF_NULL(dst_type); | |||
| if (input_type != dst_type) { | |||
| @@ -48,7 +48,7 @@ void InferImplReduceFuncCalShape(ShapeVector *shape, const ShapeVector &x_shape, | |||
| ValuePtrList::iterator it; | |||
| if (keep_dims_value) { | |||
| for (it = axis_items.begin(); it != axis_items.end(); ++it) { | |||
| auto axis_value = GetValue<int64_t>(*it); | |||
| auto axis_value = InferImplReduceFuncCheckAxis(GetValue<int64_t>(*it), x_shape.size()); | |||
| shape->at(LongToSize(axis_value)) = 1; | |||
| } | |||
| } else { | |||
| @@ -108,10 +108,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A | |||
| auto axis_shape = axis_tensor->shape()->shape(); | |||
| if (axis_shape.size() == 1 && axis_shape[0] == -1 && !keep_dims) { | |||
| out_shape.push_back(-2); | |||
| for (size_t i = 0; i < input_shape.size(); ++i) { | |||
| out_min_shape.push_back(1); | |||
| out_max_shape.push_back(max_v); | |||
| } | |||
| out_min_shape = input_min_shape; | |||
| out_max_shape = input_max_shape; | |||
| } else if (!keep_dims) { | |||
| for (size_t i = 0; i < input_shape.size() - axis_shape.size(); ++i) { | |||
| out_shape.push_back(-1); | |||
| @@ -136,7 +134,6 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A | |||
| } | |||
| MS_EXCEPTION_IF_NULL(axis_ptr); | |||
| if (axis_ptr->isa<tensor::Tensor>()) { | |||
| MS_LOG(ERROR) << "Tensor with value"; | |||
| auto axis_type = input_args[1]->BuildType(); | |||
| MS_EXCEPTION_IF_NULL(axis_type); | |||
| auto axis_type_id = axis_type->cast<TensorTypePtr>(); | |||
| @@ -178,8 +175,9 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A | |||
| TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| return CheckAndConvertUtils::CheckTensorTypeValid("x dtype", input_args[0]->BuildType(), common_valid_types, | |||
| "ReduceSum"); | |||
| auto x_type = input_args[0]->BuildType(); | |||
| (void)CheckAndConvertUtils::CheckTensorTypeValid("x dtype", x_type, common_valid_types, prim->name()); | |||
| return x_type; | |||
| } | |||
| } // namespace | |||
| @@ -321,7 +321,7 @@ abstract::ShapePtr StridedSliceInferShape(const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto prim_name = primitive->name(); | |||
| const int64_t x_index = 0; | |||
| const size_t x_index = 0; | |||
| auto x_shape = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, x_index); | |||
| if (x_shape->IsDynamic()) { | |||
| MS_EXCEPTION(ValueError) << "input x dynamic shape is currently not supported."; | |||
| @@ -363,12 +363,12 @@ abstract::ShapePtr StridedSliceInferShape(const PrimitivePtr &primitive, | |||
| } | |||
| TypePtr StridedSliceInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| const int64_t x_index = 0; | |||
| return CheckAndConvertUtils::GetInputTensorType(input_args, x_index, primitive->name()); | |||
| const size_t x_index = 0; | |||
| return CheckAndConvertUtils::GetTensorInputType(primitive->name(), input_args, x_index); | |||
| } | |||
| } // namespace | |||
| void StridedSlice::set_begin_mask(const int64_t begin_mask) { | |||
| void StridedSlice::set_begin_mask(int64_t begin_mask) { | |||
| (void)CheckAndConvertUtils::CheckInteger(kBeginMask, begin_mask, kGreaterEqual, 0, this->name()); | |||
| (void)this->AddAttr(kBeginMask, MakeValue(begin_mask)); | |||
| } | |||
| @@ -376,7 +376,7 @@ int64_t StridedSlice::get_begin_mask() const { | |||
| auto value_ptr = GetAttr(kBeginMask); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| void StridedSlice::set_end_mask(const int64_t end_mask) { | |||
| void StridedSlice::set_end_mask(int64_t end_mask) { | |||
| (void)CheckAndConvertUtils::CheckInteger(kEndMask, end_mask, kGreaterEqual, 0, this->name()); | |||
| (void)this->AddAttr(kEndMask, MakeValue(end_mask)); | |||
| } | |||
| @@ -384,7 +384,7 @@ int64_t StridedSlice::get_end_mask() const { | |||
| auto value_ptr = GetAttr(kEndMask); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| void StridedSlice::set_ellipsis_mask(const int64_t ellipsis_mask) { | |||
| void StridedSlice::set_ellipsis_mask(int64_t ellipsis_mask) { | |||
| (void)CheckAndConvertUtils::CheckInteger(kEllipsisMask, ellipsis_mask, kGreaterEqual, 0, this->name()); | |||
| std::bitset<sizeof(int64_t) * 8> bs(ellipsis_mask); | |||
| std::ostringstream buffer; | |||
| @@ -398,7 +398,7 @@ int64_t StridedSlice::get_ellipsis_mask() const { | |||
| auto value_ptr = GetAttr(kEllipsisMask); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| void StridedSlice::set_new_axis_mask(const int64_t new_axis_mask) { | |||
| void StridedSlice::set_new_axis_mask(int64_t new_axis_mask) { | |||
| (void)CheckAndConvertUtils::CheckInteger(kNewAxisMask, new_axis_mask, kGreaterEqual, 0, this->name()); | |||
| (void)this->AddAttr(kNewAxisMask, MakeValue(new_axis_mask)); | |||
| } | |||
| @@ -406,7 +406,7 @@ int64_t StridedSlice::get_new_axis_mask() const { | |||
| auto value_ptr = GetAttr(kNewAxisMask); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| void StridedSlice::set_shrink_axis_mask(const int64_t shrink_axis_mask) { | |||
| void StridedSlice::set_shrink_axis_mask(int64_t shrink_axis_mask) { | |||
| (void)CheckAndConvertUtils::CheckInteger(kShrinkAxisMask, shrink_axis_mask, kGreaterEqual, 0, this->name()); | |||
| (void)this->AddAttr(kShrinkAxisMask, MakeValue(shrink_axis_mask)); | |||
| } | |||
| @@ -414,8 +414,8 @@ int64_t StridedSlice::get_shrink_axis_mask() const { | |||
| auto value_ptr = GetAttr(kShrinkAxisMask); | |||
| return GetValue<int64_t>(value_ptr); | |||
| } | |||
| void StridedSlice::Init(const int64_t begin_mask, const int64_t end_mask, const int64_t ellipsis_mask, | |||
| const int64_t new_axis_mask, const int64_t shrink_axis_mask) { | |||
| void StridedSlice::Init(int64_t begin_mask, int64_t end_mask, int64_t ellipsis_mask, int64_t new_axis_mask, | |||
| int64_t shrink_axis_mask) { | |||
| this->set_begin_mask(begin_mask); | |||
| this->set_end_mask(end_mask); | |||
| this->set_ellipsis_mask(ellipsis_mask); | |||
| @@ -38,18 +38,18 @@ class MS_CORE_API StridedSlice : public PrimitiveC { | |||
| ~StridedSlice() = default; | |||
| MS_DECLARE_PARENT(StridedSlice, PrimitiveC); | |||
| /// \brief Init. Refer to the parameters of python API @ref mindspore.ops.StridedSlice for the inputs. | |||
| void Init(const int64_t begin_mask = 0, const int64_t end_mask = 0, const int64_t ellipsis_mask = 0, | |||
| const int64_t new_axis_mask = 0, const int64_t shrink_axis_mask = 0); | |||
| void Init(int64_t begin_mask = 0, int64_t end_mask = 0, int64_t ellipsis_mask = 0, int64_t new_axis_mask = 0, | |||
| int64_t shrink_axis_mask = 0); | |||
| /// \brief Set begin_mask. | |||
| void set_begin_mask(const int64_t begin_mask); | |||
| void set_begin_mask(int64_t begin_mask); | |||
| /// \brief Set end_mask. | |||
| void set_end_mask(const int64_t end_mask); | |||
| void set_end_mask(int64_t end_mask); | |||
| /// \brief Set ellipsis_mask. | |||
| void set_ellipsis_mask(const int64_t ellipsis_mask); | |||
| void set_ellipsis_mask(int64_t ellipsis_mask); | |||
| /// \brief Set new_axis_mask. | |||
| void set_new_axis_mask(const int64_t new_axis_mask); | |||
| void set_new_axis_mask(int64_t new_axis_mask); | |||
| /// \brief Set shrink_axis_mask. | |||
| void set_shrink_axis_mask(const int64_t shrink_axis_mask); | |||
| void set_shrink_axis_mask(int64_t shrink_axis_mask); | |||
| /// \brief Get begin_mask. | |||
| /// | |||
| /// \return begin_mask. | |||
| @@ -378,28 +378,6 @@ void CheckAndConvertUtils::CheckInputArgs(const std::vector<AbstractBasePtr> &in | |||
| } | |||
| } | |||
| TypePtr CheckAndConvertUtils::GetInputTensorType(const std::vector<AbstractBasePtr> &input_args, const size_t index, | |||
| const std::string &prim_name) { | |||
| if (input_args.size() <= index) { | |||
| MS_EXCEPTION(ValueError) << "The primitive[" << prim_name << "]'s input index[" << index | |||
| << "] is out of the input number " << input_args.size(); | |||
| } | |||
| auto input_arg = input_args[index]; | |||
| if (input_arg == nullptr) { | |||
| MS_EXCEPTION(ValueError) << "The primitive[" << prim_name << "]'s input index[" << index << "] is nullptr."; | |||
| } | |||
| auto base_type = input_arg->BuildType(); | |||
| MS_EXCEPTION_IF_NULL(base_type); | |||
| if (!base_type->isa<TensorType>()) { | |||
| MS_EXCEPTION(ValueError) << "The primitive[" << prim_name << "]'s input index[" << index << "] is not a tensor."; | |||
| } | |||
| auto tensor_type = base_type->cast<TensorTypePtr>(); | |||
| MS_EXCEPTION_IF_NULL(tensor_type); | |||
| auto type = tensor_type->element(); | |||
| MS_EXCEPTION_IF_NULL(type); | |||
| return type; | |||
| } | |||
| ShapeMap CheckAndConvertUtils::ConvertShapePtrToShapeMap(const BaseShapePtr &shape) { | |||
| MS_EXCEPTION_IF_NULL(shape); | |||
| if (!shape->isa<abstract::Shape>()) { | |||
| @@ -416,8 +394,8 @@ ShapeMap CheckAndConvertUtils::ConvertShapePtrToShapeMap(const BaseShapePtr &sha | |||
| abstract::ShapePtr CheckAndConvertUtils::GetTensorInputShape(const std::string &prim_name, | |||
| const std::vector<AbstractBasePtr> &input_args, | |||
| int64_t index) { | |||
| auto abstract = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, LongToSize(index)); | |||
| size_t index) { | |||
| auto abstract = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, index); | |||
| MS_EXCEPTION_IF_NULL(abstract); | |||
| auto base_shape = abstract->BuildShape(); | |||
| MS_EXCEPTION_IF_NULL(base_shape); | |||
| @@ -429,6 +407,28 @@ abstract::ShapePtr CheckAndConvertUtils::GetTensorInputShape(const std::string & | |||
| return shape; | |||
| } | |||
| TypePtr CheckAndConvertUtils::GetTensorInputType(const std::string &prim_name, | |||
| const std::vector<AbstractBasePtr> &input_args, size_t index) { | |||
| if (input_args.size() <= index) { | |||
| MS_EXCEPTION(ValueError) << "For " << prim_name << ", the index " << index << " is out of the input number " | |||
| << input_args.size(); | |||
| } | |||
| auto input_arg = input_args[index]; | |||
| if (input_arg == nullptr) { | |||
| MS_EXCEPTION(ValueError) << "The " << index << "'s input of " << prim_name << " is nullptr."; | |||
| } | |||
| auto base_type = input_arg->BuildType(); | |||
| MS_EXCEPTION_IF_NULL(base_type); | |||
| if (!base_type->isa<TensorType>()) { | |||
| MS_EXCEPTION(ValueError) << "The " << index << "'s input type of " << prim_name << " is not Tensor."; | |||
| } | |||
| auto tensor_type = base_type->cast<TensorTypePtr>(); | |||
| MS_EXCEPTION_IF_NULL(tensor_type); | |||
| auto type = tensor_type->element(); | |||
| MS_EXCEPTION_IF_NULL(type); | |||
| return type; | |||
| } | |||
| void CheckAndConvertUtils::Check(const string &arg_name, int64_t arg_value, CompareEnum compare_type, const string &, | |||
| int64_t value, const string &prim_name, ExceptionType) { | |||
| auto iter = kCompareMap<float>.find(compare_type); | |||
| @@ -219,7 +219,9 @@ class CheckAndConvertUtils { | |||
| static ShapeMap ConvertShapePtrToShapeMap(const BaseShapePtr &shape); | |||
| static abstract::ShapePtr GetTensorInputShape(const std::string &prim_name, | |||
| const std::vector<AbstractBasePtr> &input_args, int64_t index); | |||
| const std::vector<AbstractBasePtr> &input_args, size_t index); | |||
| static TypePtr GetTensorInputType(const std::string &prim_name, const std::vector<AbstractBasePtr> &input_args, | |||
| size_t index); | |||
| static void Check(const std::string &arg_name, int64_t arg_value, CompareEnum compare_type, | |||
| const std::string &value_name, int64_t value, const std::string &prim_name = "", | |||
| ExceptionType exception_type = ValueError); | |||
| @@ -313,8 +315,6 @@ class CheckAndConvertUtils { | |||
| static size_t GetRemoveMonadAbsNum(const AbstractBasePtrList &abs_list); | |||
| static void CheckInputArgs(const std::vector<AbstractBasePtr> &input_args, const CompareEnum compare_operator, | |||
| const int64_t match_value, const std::string &prim_name); | |||
| static TypePtr GetInputTensorType(const std::vector<AbstractBasePtr> &input_args, const size_t index, | |||
| const std::string &prim_name); | |||
| static bool HasDynamicShapeInput(const AbstractBasePtrList &abs_list); | |||
| private: | |||
| @@ -55,7 +55,9 @@ from .batch_matmul_ds import _batch_matmul_ds_tbe | |||
| from .batchnorm import _batch_norm_tbe | |||
| from .batchnorm_grad import _batch_norm_grad_tbe | |||
| from .bias_add import _bias_add_tbe | |||
| from .bias_add_ds import _bias_add_ds_tbe | |||
| from .bias_add_grad import _bias_add_grad_tbe | |||
| from .bias_add_grad_ds import _bias_add_grad_ds_tbe | |||
| from .cast import _cast_tbe | |||
| from .cast_ds import _cast_ds_tbe | |||
| from .conv2d import _conv2d_tbe | |||
| @@ -113,6 +115,7 @@ from .scatter_nd_sub import _scatter_nd_sub_tbe | |||
| from .scatter_non_aliasing_add import _scatter_non_aliasing_add_tbe | |||
| from .reduce_mean import _reduce_mean_tbe | |||
| from .tile import _tile_tbe | |||
| from .tile_ds import _tile_ds_tbe | |||
| from .atomic_addr_clean import _atomic_addr_clean_tbe | |||
| from .gather_v2 import _gather_v2_tbe | |||
| from .gather_v2_ds import _gather_v2_ds_tbe | |||
| @@ -185,6 +188,7 @@ from .sparse_apply_proximal_adagrad_ds import _sparse_apply_proximal_adagrad_ds | |||
| from .apply_proximal_adagrad import _apply_proximal_adagrad | |||
| from .transpose import _transpose_tbe | |||
| from .transpose_d import _transpose_d_tbe | |||
| from .transpose_ds import _transpose_ds_tbe | |||
| from .truncate_div import _truncate_div_tbe | |||
| from .truncate_mod import _truncate_mod_tbe | |||
| from .unsorted_segment_sum import _unsorted_segment_sum_tbe | |||
| @@ -350,6 +354,7 @@ from .gru_v2_hidden_grad_cell import _gru_v2_hidden_grad_cell_tbe | |||
| from .lstm_input_grad import _lstm_input_grad_tbe | |||
| from .confusion_matrix import _confusion_matrix_tbe | |||
| from .broadcast_to import _broadcast_to_tbe | |||
| from .broadcast_to_ds import _broadcast_to_ds_tbe | |||
| from .strided_read import _strided_read_tbe | |||
| from .strided_write import _strided_write_tbe | |||
| from .range import _range_tbe | |||
| @@ -0,0 +1,39 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """BiasAdd op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| bias_add_grad_op_info = TBERegOp("BiasAdd") \ | |||
| .fusion_type("COMMREDUCE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("bias_add.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("bias_add") \ | |||
| .partial_flag(True) \ | |||
| .dynamic_shape(True) \ | |||
| .attr("format", "required", "str", "all") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .input(1, "bias", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .is_dynamic_format(True) \ | |||
| .dtype_format(DataType.None_None, DataType.None_None, DataType.None_None) \ | |||
| .get_op_info() | |||
| @op_info_register(bias_add_grad_op_info) | |||
| def _bias_add_ds_tbe(): | |||
| """BiasAdd TBE register""" | |||
| return | |||
| @@ -0,0 +1,52 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """BiasAddGrad op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| bias_add_grad_op_info = TBERegOp("BiasAddGrad") \ | |||
| .fusion_type("COMMREDUCE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("bias_add_grad.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("bias_add_grad") \ | |||
| .partial_flag(True) \ | |||
| .dynamic_shape(True) \ | |||
| .attr("format", "required", "str", "all") \ | |||
| .input(0, "output_backprop", False, "required", "all") \ | |||
| .output(0, "output", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_FracNZ, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_NHWC) \ | |||
| .dtype_format(DataType.F32_FracNZ, DataType.F32_NHWC) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_NHWC) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_NHWC) \ | |||
| .dtype_format(DataType.F16_NDC1HWC0, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_NDC1HWC0, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F16_NDC1HWC0, DataType.F16_NHWC) \ | |||
| .dtype_format(DataType.F32_NDC1HWC0, DataType.F32_NHWC) \ | |||
| .dtype_format(DataType.F16_FRACTAL_Z_3D, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_FRACTAL_Z_3D, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F16_FRACTAL_Z_3D, DataType.F16_NHWC) \ | |||
| .dtype_format(DataType.F32_FRACTAL_Z_3D, DataType.F32_NHWC) \ | |||
| .get_op_info() | |||
| @op_info_register(bias_add_grad_op_info) | |||
| def _bias_add_grad_ds_tbe(): | |||
| """BiasAddGrad TBE register""" | |||
| return | |||
| @@ -0,0 +1,42 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """BroadcastTo op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| broadcast_to_op_info = TBERegOp("DynamicBroadcastTo") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("broadcast_to.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("broadcast_to") \ | |||
| .partial_flag(True) \ | |||
| .dynamic_shape(True) \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .input(1, "shape", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(broadcast_to_op_info) | |||
| def _broadcast_to_ds_tbe(): | |||
| """BroadcastTo TBE register""" | |||
| return | |||
| @@ -34,12 +34,8 @@ matmul_op_info = TBERegOp("MatMul") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_Default, DataType.I8_Default, | |||
| DataType.F16_FracNZ) \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_Default, DataType.I8_Default, | |||
| DataType.F32_FracNZ) \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F32_Default, DataType.I8_Default, | |||
| DataType.F16_FracNZ) \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F32_Default, DataType.I8_Default, | |||
| DataType.F32_FracNZ) \ | |||
| .get_op_info() | |||
| @@ -0,0 +1,42 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Dynamic Tile op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| tile_op_info = TBERegOp("Tile") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("tile.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("tile") \ | |||
| .partial_flag(True) \ | |||
| .dynamic_shape(True) \ | |||
| .input(0, "x1", False, "required", "all") \ | |||
| .input(1, "multiples", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(tile_op_info) | |||
| def _tile_ds_tbe(): | |||
| """Tile TBE register""" | |||
| return | |||
| @@ -1403,3 +1403,29 @@ class SliceGetItem(Primitive): | |||
| if value == "step": | |||
| return slice_value.step | |||
| raise AttributeError("\'slice\' object has no attribute {}".format(value)) | |||
| class DynamicBroadcastTo(Primitive): | |||
| """ | |||
| Broadcasts input tensor to a given shape. | |||
| Inputs: | |||
| - **input_x** (Tensor) - The input tensor. The data type should be one of the following types: | |||
| float16, float32, int32, int8, uint8. | |||
| The shape is :math:`(N,*)` where :math:`*` means,any number of additional dimensions. | |||
| - **shape** (Tensor): The target shape to broadcast. | |||
| Outputs: | |||
| Tensor, with the given `shape` and the same data type as `input_x`. | |||
| Raises: | |||
| ValueError: if the target and input shapes are incompatible. | |||
| Supported Platforms: | |||
| ``Ascend`` | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """Initialize DynamicBroadcastTo""" | |||
| self.init_prim_io_names(inputs=['x', 'shape'], outputs=['y']) | |||
| @@ -518,7 +518,7 @@ class Reshape(PrimitiveWithInfer): | |||
| neg_index = i | |||
| else: | |||
| dim_prod *= shp_i | |||
| arr_prod = np.prod(x_shp) | |||
| if -1 in x_shp: | |||
| if 'max_shape' in x: | |||
| x_max_shape = x['max_shape'] | |||
| @@ -542,6 +542,7 @@ class Reshape(PrimitiveWithInfer): | |||
| 'max_shape': tuple(max_shape), | |||
| 'min_shape': tuple(min_shape)} | |||
| else: | |||
| arr_prod = np.prod(x_shp) | |||
| if dim_prod <= 0: | |||
| raise ValueError(f"For '{self.name}', the shape of 'input_x' is {x_shp}, " | |||
| f"the value of 'input_shape' is {shape_v}. " | |||
| @@ -451,9 +451,8 @@ class _Reduce(PrimitiveWithInfer): | |||
| axis_shape = axis_shape_list[0] | |||
| if axis_shape == -1 and not self.keep_dims: | |||
| out_shape = np.array([-2]).tolist() | |||
| output_min_shape = np.ones_like(input_shp).tolist() | |||
| output_max_shape = max_v * np.ones_like(input_shp) | |||
| output_max_shape = output_max_shape.tolist() | |||
| output_min_shape = input_x['min_shape'] | |||
| output_max_shape = input_x['max_shape'] | |||
| elif not self.keep_dims: | |||
| out_shape = -1 * np.ones_like(input_shp[:-axis_shape]) | |||
| out_shape = out_shape.tolist() | |||
| @@ -467,12 +466,12 @@ class _Reduce(PrimitiveWithInfer): | |||
| output_max_shape = max_v * np.ones_like(input_shp) | |||
| output_max_shape = output_max_shape.tolist() | |||
| else: | |||
| out_shape = _infer_shape_reduce(input_shp, axis_v, self.keep_dims, self.name) | |||
| output_max_shape = _infer_shape_reduce(input_x['max_shape'], axis_v, self.keep_dims, self.name) | |||
| output_min_shape = _infer_shape_reduce(input_x['min_shape'], axis_v, self.keep_dims, self.name) | |||
| out_shape = _infer_shape_reduce(input_shp, axis_v, self.keep_dims, self.name) | |||
| else: | |||
| if axis_v is None: | |||
| raise ValueError(f"For {self.name}, the 'axis' cannot be None.") | |||
| raise ValueError(f"For {self.name}, axis must be const, its value cannot be None.") | |||
| out_shape = _infer_shape_reduce(input_shp, axis_v, self.keep_dims, self.name) | |||
| output_max_shape = out_shape | |||
| output_min_shape = out_shape | |||
| @@ -160,8 +160,8 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve | |||
| validator.check_value_type('network', network, nn.Cell) | |||
| validator.check_value_type('optimizer', optimizer, (nn.Optimizer, boost.FreezeOpt)) | |||
| if not isinstance(level, str): | |||
| raise TypeError("The argument `level` must be a string in ['O0', 'O2', 'O3', 'auto'], \ | |||
| but got type {}.".format(type(level))) | |||
| raise TypeError(f"The argument `level` must be a string in ['O0', 'O2', 'O3', 'auto'], " | |||
| f"but got type {str(type(level))}.") | |||
| validator.check('level', level, "", ['O0', 'O2', 'O3', 'auto'], Rel.IN) | |||
| validator.check('boost_level', boost_level, "", ['O0', 'O1', 'O2'], Rel.IN) | |||