| @@ -1 +1 @@ | |||||
| Subproject commit 43f5d24337bf785251eefae2d810c7d5684194d6 | |||||
| Subproject commit 63cb729373ae8b1b14bc14176c14dac6d18d0e4d | |||||
| @@ -320,6 +320,224 @@ class Validator: | |||||
| raise TypeError(f"{msg_prefix} `{arg_name}` must be float.") | raise TypeError(f"{msg_prefix} `{arg_name}` must be float.") | ||||
| class ParamValidator: | |||||
| """Parameter validator. NOTICE: this class will be replaced by `class Validator`""" | |||||
| @staticmethod | |||||
| def equal(arg_name, arg_value, cond_str, cond): | |||||
| """Judging valid value.""" | |||||
| if not cond: | |||||
| raise ValueError(f'The `{arg_name}` must be {cond_str}, but got {arg_value}.') | |||||
| @staticmethod | |||||
| def check(arg_name, arg_value, value_name, value, rel=Rel.EQ): | |||||
| """This method is only used for check int values, since when compare float values, | |||||
| we need consider float error.""" | |||||
| rel_fn = Rel.get_fns(rel) | |||||
| if not rel_fn(arg_value, value): | |||||
| rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}') | |||||
| raise ValueError(f'The `{arg_name}` should be {rel_str}, but got {arg_value}.') | |||||
| @staticmethod | |||||
| def check_integer(arg_name, arg_value, value, rel): | |||||
| """Integer value judgment.""" | |||||
| rel_fn = Rel.get_fns(rel) | |||||
| type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool) | |||||
| if type_mismatch or not rel_fn(arg_value, value): | |||||
| rel_str = Rel.get_strs(rel).format(value) | |||||
| raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.') | |||||
| return arg_value | |||||
| @staticmethod | |||||
| def check_shape_length(arg_name, arg_value, value, rel): | |||||
| """Shape length judgment.""" | |||||
| rel_fn = Rel.get_fns(rel) | |||||
| type_mismatch = not isinstance(arg_value, int) | |||||
| if type_mismatch or not rel_fn(arg_value, value): | |||||
| rel_str = Rel.get_strs(rel).format(value) | |||||
| raise ValueError(f'The length of `{arg_name}` should be an int and must {rel_str}, but got {arg_value}') | |||||
| return arg_value | |||||
| @staticmethod | |||||
| def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel): | |||||
| """This method is only used for check int values, | |||||
| since when compare float values, we need consider float error.""" | |||||
| rel_fn = Rel.get_fns(rel) | |||||
| type_mismatch = not isinstance(arg_value, int) | |||||
| if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit): | |||||
| rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit) | |||||
| raise ValueError(f'The `{arg_name}` should be an int in range {rel_str}, but got {arg_value}.') | |||||
| return arg_value | |||||
| @staticmethod | |||||
| def check_isinstance(arg_name, arg_value, classes): | |||||
| """Check arg isinstance of classes""" | |||||
| if not isinstance(arg_value, classes): | |||||
| raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.') | |||||
| return arg_value | |||||
| @staticmethod | |||||
| def check_number_range(arg_name, arg_value, lower_limit, upper_limit, rel): | |||||
| """Is it necessary to consider error when comparing float values.""" | |||||
| rel_fn = Rel.get_fns(rel) | |||||
| if not rel_fn(arg_value, lower_limit, upper_limit): | |||||
| rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit) | |||||
| raise ValueError(f'The `{arg_name}` should be in range {rel_str}, but got {arg_value}.') | |||||
| return arg_value | |||||
| @staticmethod | |||||
| def check_subclass(arg_name, type_, template_type, with_type_of=True): | |||||
| """Check whether some type is subclass of another type""" | |||||
| if not isinstance(template_type, Iterable): | |||||
| template_type = (template_type,) | |||||
| if not any([mstype.issubclass_(type_, x) for x in template_type]): | |||||
| type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_) | |||||
| raise TypeError(f'The {"type of" if with_type_of else ""} `{arg_name}` should be subclass' | |||||
| f' of {",".join((str(x) for x in template_type))}, but got {type_str}.') | |||||
| @staticmethod | |||||
| def check_args_tensor(args): | |||||
| """Check whether args are all tensor.""" | |||||
| if not isinstance(args, dict): | |||||
| raise TypeError("The args should be a dict.") | |||||
| for arg, value in args.items(): | |||||
| ParamValidator.check_subclass(arg, value, mstype.tensor) | |||||
| @staticmethod | |||||
| def check_bool(arg_name, arg_value): | |||||
| """Check arg isinstance of bool""" | |||||
| if not isinstance(arg_value, bool): | |||||
| raise ValueError(f'The `{arg_name}` should be isinstance of bool, but got {arg_value}.') | |||||
| return arg_value | |||||
| @staticmethod | |||||
| def check_type(arg_name, arg_value, valid_types): | |||||
| """Type checking.""" | |||||
| def raise_error_msg(): | |||||
| """func for raising error message when check failed""" | |||||
| type_names = [t.__name__ for t in valid_types] | |||||
| num_types = len(valid_types) | |||||
| raise TypeError(f'The type of `{arg_name}` should be {"one of " if num_types > 1 else ""}' | |||||
| f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.') | |||||
| if isinstance(arg_value, type(mstype.tensor)): | |||||
| arg_value = arg_value.element_type() | |||||
| # Notice: bool is subclass of int, so `check_type('x', True, [int])` will check fail, and | |||||
| # `check_type('x', True, [bool, int])` will check pass | |||||
| if isinstance(arg_value, bool) and bool not in tuple(valid_types): | |||||
| raise_error_msg() | |||||
| if isinstance(arg_value, tuple(valid_types)): | |||||
| return arg_value | |||||
| raise_error_msg() | |||||
| @staticmethod | |||||
| def check_typename(arg_name, arg_type, valid_types): | |||||
| """Does it contain the _name_ attribute.""" | |||||
| def get_typename(t): | |||||
| return t.__name__ if hasattr(t, '__name__') else str(t) | |||||
| if isinstance(arg_type, type(mstype.tensor)): | |||||
| arg_type = arg_type.element_type() | |||||
| if arg_type in valid_types: | |||||
| return arg_type | |||||
| type_names = [get_typename(t) for t in valid_types] | |||||
| if len(valid_types) == 1: | |||||
| raise ValueError(f'The type of `{arg_name}` should be {type_names[0]},' | |||||
| f' but got {get_typename(arg_type)}.') | |||||
| raise ValueError(f'The type of `{arg_name}` should be one of {type_names},' | |||||
| f' but got {get_typename(arg_type)}.') | |||||
| @staticmethod | |||||
| def check_string(arg_name, arg_value, valid_values): | |||||
| """String type judgment.""" | |||||
| if isinstance(arg_value, str) and arg_value in valid_values: | |||||
| return arg_value | |||||
| if len(valid_values) == 1: | |||||
| raise ValueError(f'The `{arg_name}` should be str and must be {valid_values[0]},' | |||||
| f' but got {arg_value}.') | |||||
| raise ValueError(f'The `{arg_name}` should be str and must be one of {valid_values},' | |||||
| f' but got {arg_value}.') | |||||
| @staticmethod | |||||
| def check_type_same(args, valid_values): | |||||
| """Determine whether the types are the same.""" | |||||
| name = list(args.keys())[0] | |||||
| value = list(args.values())[0] | |||||
| if isinstance(value, type(mstype.tensor)): | |||||
| value = value.element_type() | |||||
| for arg_name, arg_value in args.items(): | |||||
| if isinstance(arg_value, type(mstype.tensor)): | |||||
| arg_value = arg_value.element_type() | |||||
| if arg_value not in valid_values: | |||||
| raise TypeError(f'The `{arg_name}` should be in {valid_values},' | |||||
| f' but `{arg_name}` is {arg_value}.') | |||||
| if arg_value != value: | |||||
| raise TypeError(f'`{arg_name}` should be same as `{name}`,' | |||||
| f' but `{arg_name}` is {arg_value}, `{name}` is {value}.') | |||||
| @staticmethod | |||||
| def check_two_types_same(arg1_name, arg1_type, arg2_name, arg2_type): | |||||
| """Determine whether the types of two variables are the same.""" | |||||
| if arg1_type != arg2_type: | |||||
| raise TypeError(f'The type of `{arg1_name}` and `{arg2_name}` should be same.') | |||||
| @staticmethod | |||||
| def check_value_on_integer(arg_name, arg_value, value, rel): | |||||
| """Judging integer type.""" | |||||
| rel_fn = Rel.get_fns(rel) | |||||
| type_match = isinstance(arg_value, int) | |||||
| if type_match and (not rel_fn(arg_value, value)): | |||||
| rel_str = Rel.get_strs(rel).format(value) | |||||
| raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.') | |||||
| return arg_value | |||||
| @staticmethod | |||||
| def check_param_equal(param1_name, param1_value, param2_name, param2_value): | |||||
| """Judging the equality of parameters.""" | |||||
| if param1_value != param2_value: | |||||
| raise ValueError(f"`{param1_name}` must equal `{param2_name}`," | |||||
| f" but got `{param1_name}` = {param1_value}," | |||||
| f" `{param2_name}` = {param2_value}.") | |||||
| @staticmethod | |||||
| def check_const_input(arg_name, arg_value): | |||||
| """Check valid value.""" | |||||
| if arg_value is None: | |||||
| raise ValueError(f'The `{arg_name}` must be a const input, but got {arg_value}.') | |||||
| @staticmethod | |||||
| def check_float_positive(arg_name, arg_value): | |||||
| """Float type judgment.""" | |||||
| if isinstance(arg_value, float): | |||||
| if arg_value > 0: | |||||
| return arg_value | |||||
| raise ValueError(f"The `{arg_name}` must be positive, but got {arg_value}.") | |||||
| raise TypeError(f"`{arg_name}` must be float!") | |||||
| @staticmethod | |||||
| def check_pad_value_by_mode(op_name, pad_mode, padding): | |||||
| """Validate value of padding according to pad_mode""" | |||||
| if pad_mode != 'pad' and padding != 0: | |||||
| raise ValueError(f"For op '{op_name}', padding must be zero when pad_mode is '{pad_mode}'.") | |||||
| return padding | |||||
| @staticmethod | |||||
| def check_empty_shape_input(arg_name, arg_value): | |||||
| """Check zeros value.""" | |||||
| if 0 in arg_value: | |||||
| raise ValueError(f"Input `{arg_name}` cannot be empty.") | |||||
| @staticmethod | |||||
| def check_scalar_shape_input(arg_name, arg_value): | |||||
| """Check scalar shape input.""" | |||||
| if arg_value != []: | |||||
| raise ValueError(f"Input `{arg_name}` shape should be (). got {arg_value}") | |||||
| def check_int(input_param): | def check_int(input_param): | ||||
| """Int type judgment.""" | """Int type judgment.""" | ||||
| if isinstance(input_param, int) and not isinstance(input_param, bool): | if isinstance(input_param, int) and not isinstance(input_param, bool): | ||||
| @@ -201,6 +201,7 @@ void KernelRuntime::RunOpAssignOutputMemory(const AnfNodePtr &kernel) { | |||||
| if (AnfAlgo::GetCNodeName(kernel) == "ApplyMomentum") { | if (AnfAlgo::GetCNodeName(kernel) == "ApplyMomentum") { | ||||
| auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, 0); | auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, 0); | ||||
| AnfAlgo::SetOutputAddr(device_address, 0, kernel.get()); | AnfAlgo::SetOutputAddr(device_address, 0, kernel.get()); | ||||
| AnfAlgo::SetOutputAddr(device_address, 1, kernel.get()); | |||||
| return; | return; | ||||
| } | } | ||||
| @@ -27,7 +27,6 @@ namespace kernel { | |||||
| constexpr auto kInitDataSetQueue = "InitDataSetQueue"; | constexpr auto kInitDataSetQueue = "InitDataSetQueue"; | ||||
| constexpr auto kInitData = "InitData"; | constexpr auto kInitData = "InitData"; | ||||
| constexpr auto kGetNext = "GetNext"; | constexpr auto kGetNext = "GetNext"; | ||||
| constexpr auto kDropoutGenMask = "DropoutGenMask"; | |||||
| constexpr auto kPrint = "Print"; | constexpr auto kPrint = "Print"; | ||||
| constexpr auto kOutputTypes = "output_types"; | constexpr auto kOutputTypes = "output_types"; | ||||
| @@ -55,7 +55,6 @@ MS_REG_GPU_KERNEL_ONE(BatchNorm, | |||||
| .AddOutputAttr(kNumberTypeFloat32) | .AddOutputAttr(kNumberTypeFloat32) | ||||
| .AddOutputAttr(kNumberTypeFloat32) | .AddOutputAttr(kNumberTypeFloat32) | ||||
| .AddOutputAttr(kNumberTypeFloat32) | .AddOutputAttr(kNumberTypeFloat32) | ||||
| .AddOutputAttr(kNumberTypeFloat32) | |||||
| .AddOutputAttr(kNumberTypeFloat32), | .AddOutputAttr(kNumberTypeFloat32), | ||||
| FusedBatchNormGpuKernel, float) | FusedBatchNormGpuKernel, float) | ||||
| MS_REG_GPU_KERNEL_ONE(BatchNorm, | MS_REG_GPU_KERNEL_ONE(BatchNorm, | ||||
| @@ -69,7 +68,6 @@ MS_REG_GPU_KERNEL_ONE(BatchNorm, | |||||
| .AddOutputAttr(kNumberTypeFloat16) | .AddOutputAttr(kNumberTypeFloat16) | ||||
| .AddOutputAttr(kNumberTypeFloat16) | .AddOutputAttr(kNumberTypeFloat16) | ||||
| .AddOutputAttr(kNumberTypeFloat16) | .AddOutputAttr(kNumberTypeFloat16) | ||||
| .AddOutputAttr(kNumberTypeFloat16) | |||||
| .AddOutputAttr(kNumberTypeFloat16), | .AddOutputAttr(kNumberTypeFloat16), | ||||
| FusedBatchNormGpuKernel, half) | FusedBatchNormGpuKernel, half) | ||||
| } // namespace kernel | } // namespace kernel | ||||
| @@ -157,9 +157,6 @@ class FusedBatchNormGpuKernel : public GpuKernel { | |||||
| output_size_list_.push_back(para_size); // running variance | output_size_list_.push_back(para_size); // running variance | ||||
| output_size_list_.push_back(para_size); // save mean | output_size_list_.push_back(para_size); // save mean | ||||
| output_size_list_.push_back(para_size); // save variance | output_size_list_.push_back(para_size); // save variance | ||||
| if (!is_train_) { | |||||
| output_size_list_.push_back(para_size); // reserve | |||||
| } | |||||
| return; | return; | ||||
| } | } | ||||
| @@ -30,6 +30,9 @@ namespace mindspore { | |||||
| namespace kernel { | namespace kernel { | ||||
| namespace tbe { | namespace tbe { | ||||
| static std::map<string, string> tbe_func_adapter_map = { | static std::map<string, string> tbe_func_adapter_map = { | ||||
| {"softmax", "softmax_v2"}, | |||||
| {"log_softmax", "log_softmax_v2"}, | |||||
| {"apply_momentum", "apply_momentum_d"}, | |||||
| {"re_lu6", "relu6"}, | {"re_lu6", "relu6"}, | ||||
| {"re_lu6_grad", "relu6_grad"}, | {"re_lu6_grad", "relu6_grad"}, | ||||
| {"re_lu", "relu"}, | {"re_lu", "relu"}, | ||||
| @@ -344,8 +344,23 @@ bool IsNopNode(const AnfNodePtr &node) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool IsAllNopNode(session::KernelGraph *const graph) { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| auto execution_order = graph->execution_order(); | |||||
| for (auto &cnode : execution_order) { | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| if (!IsNopNode(cnode)) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| void HideNopNode(session::KernelGraph *const graph) { | void HideNopNode(session::KernelGraph *const graph) { | ||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| if (IsAllNopNode(graph) == true) { | |||||
| return; | |||||
| } | |||||
| auto execution_order = graph->execution_order(); | auto execution_order = graph->execution_order(); | ||||
| MS_LOG(INFO) << "nop node info (Before Remove) size: " << execution_order.size(); | MS_LOG(INFO) << "nop node info (Before Remove) size: " << execution_order.size(); | ||||
| std::vector<CNodePtr> new_nodes; | std::vector<CNodePtr> new_nodes; | ||||
| @@ -361,6 +376,9 @@ void HideNopNode(session::KernelGraph *const graph) { | |||||
| void RemoveNopNode(session::KernelGraph *const graph) { | void RemoveNopNode(session::KernelGraph *const graph) { | ||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| if (IsAllNopNode(graph) == true) { | |||||
| return; | |||||
| } | |||||
| bool changed = true; | bool changed = true; | ||||
| while (changed) { | while (changed) { | ||||
| changed = false; | changed = false; | ||||
| @@ -177,6 +177,7 @@ const char kNameAbsGrad[] = "AbsGrad"; | |||||
| const char kNameBinaryCrossEntropy[] = "BinaryCrossEntropy"; | const char kNameBinaryCrossEntropy[] = "BinaryCrossEntropy"; | ||||
| const char kNameBinaryCrossEntropyGrad[] = "BinaryCrossEntropyGrad"; | const char kNameBinaryCrossEntropyGrad[] = "BinaryCrossEntropyGrad"; | ||||
| const char kNameSparseApplyAdagrad[] = "SparseApplyAdagrad"; | const char kNameSparseApplyAdagrad[] = "SparseApplyAdagrad"; | ||||
| const char kNameSparseApplyFtrlD[] = "SparseApplyFtrlD"; | |||||
| const char kNameAcosh[] = "Acosh"; | const char kNameAcosh[] = "Acosh"; | ||||
| const char kNameAcoshGrad[] = "AcoshGrad"; | const char kNameAcoshGrad[] = "AcoshGrad"; | ||||
| const char kNameFloorMod[] = "FloorMod"; | const char kNameFloorMod[] = "FloorMod"; | ||||
| @@ -206,7 +207,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma | |||||
| {string(kNameMaxPool), ADPT_DESC(MaxPool)}, | {string(kNameMaxPool), ADPT_DESC(MaxPool)}, | ||||
| {string(kNameAvgPool), ADPT_DESC(AvgPool)}, | {string(kNameAvgPool), ADPT_DESC(AvgPool)}, | ||||
| {string(kNameMaxPoolWithArgmax), ADPT_DESC(MaxPoolWithArgmax)}, | {string(kNameMaxPoolWithArgmax), ADPT_DESC(MaxPoolWithArgmax)}, | ||||
| {string(kNameTopK), ADPT_DESC(TopKV2)}, | |||||
| {string(kNameTopK), ADPT_DESC(TopK)}, | |||||
| {string(kNamePack), ADPT_DESC(Pack)}, | {string(kNamePack), ADPT_DESC(Pack)}, | ||||
| {string(kNameUnpack), ADPT_DESC(Unpack)}, | {string(kNameUnpack), ADPT_DESC(Unpack)}, | ||||
| {string(kNameSplitD), ADPT_DESC(SplitD)}, | {string(kNameSplitD), ADPT_DESC(SplitD)}, | ||||
| @@ -240,15 +241,15 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma | |||||
| {string(kNameSquare), ADPT_DESC(Square)}, | {string(kNameSquare), ADPT_DESC(Square)}, | ||||
| {prim::kPrimTanh->name(), ADPT_DESC(Tanh)}, | {prim::kPrimTanh->name(), ADPT_DESC(Tanh)}, | ||||
| {prim::kPrimTanhGrad->name(), ADPT_DESC(TanhGrad)}, | {prim::kPrimTanhGrad->name(), ADPT_DESC(TanhGrad)}, | ||||
| {string(kNameResizeNearestNeighborD), ADPT_DESC(ResizeNearestNeighborD)}, | |||||
| {string(kNameResizeNearestNeighborGrad), ADPT_DESC(ResizeNearestNeighborGrad)}, | |||||
| {string(kNameResizeNearestNeighborD), ADPT_DESC(ResizeNearestNeighborV2D)}, | |||||
| {string(kNameResizeNearestNeighborGrad), ADPT_DESC(ResizeNearestNeighborV2Grad)}, | |||||
| {string(kNameApplyAdam), ADPT_DESC(ApplyAdam)}, | {string(kNameApplyAdam), ADPT_DESC(ApplyAdam)}, | ||||
| {string(kNameReLU6), ADPT_DESC(Relu6)}, | {string(kNameReLU6), ADPT_DESC(Relu6)}, | ||||
| {string(kNameReLU6Grad), ADPT_DESC(Relu6Grad)}, | {string(kNameReLU6Grad), ADPT_DESC(Relu6Grad)}, | ||||
| {string(kNameElu), ADPT_DESC(Elu)}, | {string(kNameElu), ADPT_DESC(Elu)}, | ||||
| {string(kNameEluGrad), ADPT_DESC(EluGrad)}, | {string(kNameEluGrad), ADPT_DESC(EluGrad)}, | ||||
| {string(kNameResizeBilinearGrad), ADPT_DESC(ResizeBilinearGrad)}, | |||||
| {string(kNameResizeBilinear), ADPT_DESC(ResizeBilinearD)}, | |||||
| {string(kNameResizeBilinearGrad), ADPT_DESC(ResizeBilinearV2Grad)}, | |||||
| {string(kNameResizeBilinear), ADPT_DESC(ResizeBilinearV2D)}, | |||||
| {string(kNameZerosLike), ADPT_DESC(ZerosLike)}, | {string(kNameZerosLike), ADPT_DESC(ZerosLike)}, | ||||
| {string(kNameOnesLike), ADPT_DESC(OnesLike)}, | {string(kNameOnesLike), ADPT_DESC(OnesLike)}, | ||||
| {string(kNameScatterNdUpdate), ADPT_DESC(ScatterNdUpdate)}, | {string(kNameScatterNdUpdate), ADPT_DESC(ScatterNdUpdate)}, | ||||
| @@ -329,7 +330,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma | |||||
| {prim::kPrimMinimum->name(), ADPT_DESC(Minimum)}, | {prim::kPrimMinimum->name(), ADPT_DESC(Minimum)}, | ||||
| {prim::kPrimSelect->name(), ADPT_DESC(Select)}, | {prim::kPrimSelect->name(), ADPT_DESC(Select)}, | ||||
| {string(kNameLessEqual), ADPT_DESC(LessEqual)}, | {string(kNameLessEqual), ADPT_DESC(LessEqual)}, | ||||
| {prim::kPrimLogSoftmax->name(), ADPT_DESC(LogSoftmax)}, | |||||
| {prim::kPrimLogSoftmax->name(), ADPT_DESC(LogSoftmaxV2)}, | |||||
| {string(kNameTruncatedNormal), ADPT_DESC(TruncatedNormal)}, | {string(kNameTruncatedNormal), ADPT_DESC(TruncatedNormal)}, | ||||
| {string(kNameStridedSliceGrad), ADPT_DESC(StridedSliceGrad)}, | {string(kNameStridedSliceGrad), ADPT_DESC(StridedSliceGrad)}, | ||||
| {prim::kPrimGelu->name(), ADPT_DESC(Gelu)}, | {prim::kPrimGelu->name(), ADPT_DESC(Gelu)}, | ||||
| @@ -363,7 +364,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma | |||||
| {prim::kPrimMatMul->name(), ADPT_DESC(MatMul)}, | {prim::kPrimMatMul->name(), ADPT_DESC(MatMul)}, | ||||
| {string(kNameConst), ADPT_DESC(Constant, Const)}, | {string(kNameConst), ADPT_DESC(Constant, Const)}, | ||||
| {string(kNameSoftmax), ADPT_DESC(Softmax)}, | |||||
| {string(kNameSoftmax), ADPT_DESC(SoftmaxV2)}, | |||||
| {string(kNameSoftmaxGrad), ADPT_DESC(SoftmaxGrad)}, | {string(kNameSoftmaxGrad), ADPT_DESC(SoftmaxGrad)}, | ||||
| {string(kNameParam), ADPT_DESC(Data)}, | {string(kNameParam), ADPT_DESC(Data)}, | ||||
| {string(kNameROIAlign), ADPT_DESC(ROIAlign)}, | {string(kNameROIAlign), ADPT_DESC(ROIAlign)}, | ||||
| @@ -373,6 +374,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma | |||||
| {string(kNameBinaryCrossEntropy), ADPT_DESC(BinaryCrossEntropy)}, | {string(kNameBinaryCrossEntropy), ADPT_DESC(BinaryCrossEntropy)}, | ||||
| {string(kNameBinaryCrossEntropyGrad), ADPT_DESC(BinaryCrossEntropyGrad)}, | {string(kNameBinaryCrossEntropyGrad), ADPT_DESC(BinaryCrossEntropyGrad)}, | ||||
| {string(kNameSparseApplyAdagrad), ADPT_DESC(SparseApplyAdagradD)}, | {string(kNameSparseApplyAdagrad), ADPT_DESC(SparseApplyAdagradD)}, | ||||
| {string(kNameSparseApplyFtrlD), ADPT_DESC(SparseApplyFtrlD)}, | |||||
| {string(kNameAcosh), ADPT_DESC(Acosh)}, | {string(kNameAcosh), ADPT_DESC(Acosh)}, | ||||
| {string(kNameAcoshGrad), ADPT_DESC(AcoshGrad)}, | {string(kNameAcoshGrad), ADPT_DESC(AcoshGrad)}, | ||||
| {string(kNameFloorMod), ADPT_DESC(FloorMod)}, | {string(kNameFloorMod), ADPT_DESC(FloorMod)}, | ||||
| @@ -390,6 +392,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma | |||||
| {string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSProp)}}; | {string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSProp)}}; | ||||
| #ifdef ENABLE_GE | #ifdef ENABLE_GE | ||||
| adpt_map[string(kNamePrint)] = ADPT_DESC(Print); | adpt_map[string(kNamePrint)] = ADPT_DESC(Print); | ||||
| adpt_map[string(kNameApplyAdam)] = ADPT_DESC(ApplyAdamD); | |||||
| #endif | #endif | ||||
| return adpt_map; | return adpt_map; | ||||
| } | } | ||||
| @@ -1127,8 +1130,8 @@ void DfGraphConvertor::UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr | |||||
| if (desc == nullptr) { | if (desc == nullptr) { | ||||
| MS_LOG(ERROR) << "Update data op descriptor failed! TensorDesc is null."; | MS_LOG(ERROR) << "Update data op descriptor failed! TensorDesc is null."; | ||||
| } else { | } else { | ||||
| (void)std::static_pointer_cast<Data>(op)->update_input_desc_data(*desc); | |||||
| (void)std::static_pointer_cast<Data>(op)->update_output_desc_out(*desc); | |||||
| (void)std::static_pointer_cast<Data>(op)->update_input_desc_x(*desc); | |||||
| (void)std::static_pointer_cast<Data>(op)->update_output_desc_y(*desc); | |||||
| } | } | ||||
| } | } | ||||
| @@ -736,7 +736,7 @@ class OpAdapter : public BaseOpAdapter { | |||||
| return static_cast<int64_t>(GetValue<int>(value)); | return static_cast<int64_t>(GetValue<int>(value)); | ||||
| } | } | ||||
| // specialization for int to Vector | |||||
| // specialization for int or tuple broadcast to Vector | |||||
| static std::vector<int64_t> ConvertAny(const ValuePtr &value, const std::string &name, | static std::vector<int64_t> ConvertAny(const ValuePtr &value, const std::string &name, | ||||
| const AnyTraits<std::vector<int64_t>> anyTraitsInt) { | const AnyTraits<std::vector<int64_t>> anyTraitsInt) { | ||||
| return ConvertAnyUtil(value, name, anyTraitsInt); | return ConvertAnyUtil(value, name, anyTraitsInt); | ||||
| @@ -35,15 +35,21 @@ GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits<mindspore::tensor | |||||
| std::vector<int64_t> ConvertAnyUtil(const ValuePtr &value, const std::string &name, | std::vector<int64_t> ConvertAnyUtil(const ValuePtr &value, const std::string &name, | ||||
| const AnyTraits<std::vector<int64_t>>) { | const AnyTraits<std::vector<int64_t>>) { | ||||
| int64_t data = GetValue<int>(value); | |||||
| MS_EXCEPTION_IF_NULL(value); | |||||
| std::vector<int64_t> list; | std::vector<int64_t> list; | ||||
| int size = 2; // 2 int in list | |||||
| if (name == "pad") { | if (name == "pad") { | ||||
| size = 4; // 4 int in list | |||||
| list = TransformUtil::ConvertIntToList(data, size); | |||||
| if (!value->isa<ValueSequeue>()) { | |||||
| MS_LOG(EXCEPTION) << "Value should be ValueTuple, but got" << value->type_name(); | |||||
| } | |||||
| auto vec = value->cast<ValueSequeuePtr>(); | |||||
| list.resize(vec->value().size() + 2); | |||||
| list[0] = 1; | list[0] = 1; | ||||
| list[1] = 1; | list[1] = 1; | ||||
| (void)std::transform(vec->value().begin(), vec->value().end(), list.begin() + 2, | |||||
| [](const ValuePtr &val) { return static_cast<int64_t>(GetValue<int>(val)); }); | |||||
| } else { | } else { | ||||
| int64_t data = GetValue<int>(value); | |||||
| int size = 2; // 2 int in list | |||||
| list = TransformUtil::ConvertIntToList(data, size); | list = TransformUtil::ConvertIntToList(data, size); | ||||
| } | } | ||||
| @@ -138,11 +138,10 @@ OUTPUT_MAP(ApplyMomentum) = {{0, OUTPUT_DESC(var)}}; | |||||
| INPUT_MAP(Summary) = {{2, INPUT_DESC(x)}}; | INPUT_MAP(Summary) = {{2, INPUT_DESC(x)}}; | ||||
| ATTR_MAP(Summary) = EMPTY_ATTR_MAP; | ATTR_MAP(Summary) = EMPTY_ATTR_MAP; | ||||
| // data | |||||
| // Data | |||||
| INPUT_MAP(Data) = EMPTY_INPUT_MAP; | INPUT_MAP(Data) = EMPTY_INPUT_MAP; | ||||
| ATTR_MAP(Data) = EMPTY_ATTR_MAP; | ATTR_MAP(Data) = EMPTY_ATTR_MAP; | ||||
| // resnet ops in ge | |||||
| // BatchNorm | // BatchNorm | ||||
| INPUT_MAP(BatchNorm) = {{1, INPUT_DESC(x)}, | INPUT_MAP(BatchNorm) = {{1, INPUT_DESC(x)}, | ||||
| {2, INPUT_DESC(scale)}, | {2, INPUT_DESC(scale)}, | ||||
| @@ -156,13 +155,14 @@ OUTPUT_MAP(BatchNorm) = {{0, OUTPUT_DESC(y)}, | |||||
| {1, OUTPUT_DESC(batch_mean)}, | {1, OUTPUT_DESC(batch_mean)}, | ||||
| {2, OUTPUT_DESC(batch_variance)}, | {2, OUTPUT_DESC(batch_variance)}, | ||||
| {3, OUTPUT_DESC(reserve_space_1)}, | {3, OUTPUT_DESC(reserve_space_1)}, | ||||
| {4, OUTPUT_DESC(reserve_space_2)}, | |||||
| {5, OUTPUT_DESC(reserve_space_3)}}; | |||||
| {4, OUTPUT_DESC(reserve_space_2)}}; | |||||
| // BatchNormGrad | // BatchNormGrad | ||||
| INPUT_MAP(BatchNormGrad) = {{1, INPUT_DESC(y_backprop)}, {2, INPUT_DESC(x)}, | |||||
| {3, INPUT_DESC(scale)}, {4, INPUT_DESC(reserve_space_1)}, | |||||
| {5, INPUT_DESC(reserve_space_2)}, {6, INPUT_DESC(reserve_space_3)}}; | |||||
| INPUT_MAP(BatchNormGrad) = {{1, INPUT_DESC(y_backprop)}, | |||||
| {2, INPUT_DESC(x)}, | |||||
| {3, INPUT_DESC(scale)}, | |||||
| {4, INPUT_DESC(reserve_space_1)}, | |||||
| {5, INPUT_DESC(reserve_space_2)}}; | |||||
| ATTR_MAP(BatchNormGrad) = {{"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}, | ATTR_MAP(BatchNormGrad) = {{"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}, | ||||
| {"epsilon", ATTR_DESC(epsilon, AnyTraits<float>())}, | {"epsilon", ATTR_DESC(epsilon, AnyTraits<float>())}, | ||||
| {"is_training", ATTR_DESC(is_training, AnyTraits<bool>())}}; | {"is_training", ATTR_DESC(is_training, AnyTraits<bool>())}}; | ||||
| @@ -193,10 +193,9 @@ ATTR_MAP(PRelu) = EMPTY_ATTR_MAP; | |||||
| OUTPUT_MAP(PRelu) = {{0, OUTPUT_DESC(y)}}; | OUTPUT_MAP(PRelu) = {{0, OUTPUT_DESC(y)}}; | ||||
| // PReluGrad | // PReluGrad | ||||
| INPUT_MAP(PReluGrad) = { | |||||
| {1, INPUT_DESC(input_gradients)}, {2, INPUT_DESC(input_features)}, {3, INPUT_DESC(input_weights)}}; | |||||
| INPUT_MAP(PReluGrad) = {{1, INPUT_DESC(grads)}, {2, INPUT_DESC(features)}, {3, INPUT_DESC(weights)}}; | |||||
| ATTR_MAP(PReluGrad) = EMPTY_ATTR_MAP; | ATTR_MAP(PReluGrad) = EMPTY_ATTR_MAP; | ||||
| OUTPUT_MAP(PReluGrad) = {{0, OUTPUT_DESC(output_backprops_dx)}, {1, OUTPUT_DESC(output_backprops_da)}}; | |||||
| OUTPUT_MAP(PReluGrad) = {{0, OUTPUT_DESC(dx)}, {1, OUTPUT_DESC(da)}}; | |||||
| // Sigmoid | // Sigmoid | ||||
| INPUT_MAP(Sigmoid) = {{1, INPUT_DESC(x)}}; | INPUT_MAP(Sigmoid) = {{1, INPUT_DESC(x)}}; | ||||
| @@ -241,12 +240,12 @@ ATTR_MAP(CumsumD) = {{"exclusive", ATTR_DESC(exclusive, AnyTraits<bool>())}, | |||||
| {"reverse", ATTR_DESC(reverse, AnyTraits<bool>())}}; | {"reverse", ATTR_DESC(reverse, AnyTraits<bool>())}}; | ||||
| OUTPUT_MAP(CumsumD) = {{0, OUTPUT_DESC(y)}}; | OUTPUT_MAP(CumsumD) = {{0, OUTPUT_DESC(y)}}; | ||||
| // softmax | |||||
| INPUT_MAP(Softmax) = {{1, INPUT_DESC(x)}}; | |||||
| ATTR_MAP(Softmax) = { | |||||
| {"axis", ATTR_DESC(axis, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}, | |||||
| // SoftmaxV2 | |||||
| INPUT_MAP(SoftmaxV2) = {{1, INPUT_DESC(x)}}; | |||||
| ATTR_MAP(SoftmaxV2) = { | |||||
| {"axis", ATTR_DESC(axes, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}, | |||||
| }; | }; | ||||
| OUTPUT_MAP(Softmax) = {{0, OUTPUT_DESC(y)}}; | |||||
| OUTPUT_MAP(SoftmaxV2) = {{0, OUTPUT_DESC(y)}}; | |||||
| // SoftmaxGrad | // SoftmaxGrad | ||||
| INPUT_MAP(SoftmaxGrad) = {{1, INPUT_DESC(softmax)}, {2, INPUT_DESC(grad_softmax)}}; | INPUT_MAP(SoftmaxGrad) = {{1, INPUT_DESC(softmax)}, {2, INPUT_DESC(grad_softmax)}}; | ||||
| @@ -271,14 +270,14 @@ OUTPUT_MAP(GatherV2) = {{0, OUTPUT_DESC(y)}}; | |||||
| // ReduceSumD | // ReduceSumD | ||||
| INPUT_MAP(ReduceSumD) = {{1, INPUT_DESC(x)}}; | INPUT_MAP(ReduceSumD) = {{1, INPUT_DESC(x)}}; | ||||
| INPUT_ATTR_MAP(ReduceSumD) = { | INPUT_ATTR_MAP(ReduceSumD) = { | ||||
| {2, ATTR_DESC(axis, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}}; | |||||
| {2, ATTR_DESC(axes, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}}; | |||||
| ATTR_MAP(ReduceSumD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}}; | ATTR_MAP(ReduceSumD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}}; | ||||
| OUTPUT_MAP(ReduceSumD) = {{0, OUTPUT_DESC(y)}}; | OUTPUT_MAP(ReduceSumD) = {{0, OUTPUT_DESC(y)}}; | ||||
| // ReduceProdD | // ReduceProdD | ||||
| INPUT_MAP(ReduceProdD) = {{1, INPUT_DESC(x)}}; | INPUT_MAP(ReduceProdD) = {{1, INPUT_DESC(x)}}; | ||||
| INPUT_ATTR_MAP(ReduceProdD) = { | INPUT_ATTR_MAP(ReduceProdD) = { | ||||
| {2, ATTR_DESC(axis, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}}; | |||||
| {2, ATTR_DESC(axes, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}}; | |||||
| ATTR_MAP(ReduceProdD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}}; | ATTR_MAP(ReduceProdD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}}; | ||||
| OUTPUT_MAP(ReduceProdD) = {{0, OUTPUT_DESC(y)}}; | OUTPUT_MAP(ReduceProdD) = {{0, OUTPUT_DESC(y)}}; | ||||
| @@ -289,7 +288,7 @@ ATTR_MAP(CumprodD) = {{"exclusive", ATTR_DESC(exclusive, AnyTraits<bool>())}, | |||||
| {"reverse", ATTR_DESC(reverse, AnyTraits<bool>())}}; | {"reverse", ATTR_DESC(reverse, AnyTraits<bool>())}}; | ||||
| OUTPUT_MAP(CumprodD) = {{0, OUTPUT_DESC(y)}}; | OUTPUT_MAP(CumprodD) = {{0, OUTPUT_DESC(y)}}; | ||||
| // SoftmaxCrossEntropyWithLogits/ | |||||
| // SoftmaxCrossEntropyWithLogits | |||||
| INPUT_MAP(SoftmaxCrossEntropyWithLogits) = {{1, INPUT_DESC(features)}, {2, INPUT_DESC(labels)}}; | INPUT_MAP(SoftmaxCrossEntropyWithLogits) = {{1, INPUT_DESC(features)}, {2, INPUT_DESC(labels)}}; | ||||
| ATTR_MAP(SoftmaxCrossEntropyWithLogits) = EMPTY_ATTR_MAP; | ATTR_MAP(SoftmaxCrossEntropyWithLogits) = EMPTY_ATTR_MAP; | ||||
| OUTPUT_MAP(SoftmaxCrossEntropyWithLogits) = {{0, OUTPUT_DESC(loss)}, {1, OUTPUT_DESC(backprop)}}; | OUTPUT_MAP(SoftmaxCrossEntropyWithLogits) = {{0, OUTPUT_DESC(loss)}, {1, OUTPUT_DESC(backprop)}}; | ||||
| @@ -301,7 +300,7 @@ INPUT_ATTR_MAP(MeanGrad) = {{2, ATTR_DESC(mean_grad_output_shape_value, kOpForma | |||||
| ATTR_MAP(MeanGrad) = {{"mode", ATTR_DESC(mode, AnyTraits<int64_t>())}}; | ATTR_MAP(MeanGrad) = {{"mode", ATTR_DESC(mode, AnyTraits<int64_t>())}}; | ||||
| INPUT_MAP(SliceD) = {{1, INPUT_DESC(x)}}; | INPUT_MAP(SliceD) = {{1, INPUT_DESC(x)}}; | ||||
| INPUT_ATTR_MAP(SliceD) = {{2, ATTR_DESC(begin, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())}, | |||||
| INPUT_ATTR_MAP(SliceD) = {{2, ATTR_DESC(offsets, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())}, | |||||
| {3, ATTR_DESC(size, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())}}; | {3, ATTR_DESC(size, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())}}; | ||||
| ATTR_MAP(SliceD) = EMPTY_ATTR_MAP; | ATTR_MAP(SliceD) = EMPTY_ATTR_MAP; | ||||
| OUTPUT_MAP(SliceD) = {{0, OUTPUT_DESC(y)}}; | OUTPUT_MAP(SliceD) = {{0, OUTPUT_DESC(y)}}; | ||||
| @@ -411,42 +410,10 @@ ATTR_MAP(BoundingBoxDecode) = { | |||||
| }; | }; | ||||
| OUTPUT_MAP(BoundingBoxDecode) = {{0, OUTPUT_DESC(bboxes)}}; | OUTPUT_MAP(BoundingBoxDecode) = {{0, OUTPUT_DESC(bboxes)}}; | ||||
| #ifdef VALID_CODE | |||||
| // Less | |||||
| INPUT_MAP(Less) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(y)}}; | |||||
| ATTR_MAP(Less) = EMPTY_ATTR_MAP; | |||||
| OUTPUT_MAP(Less) = {{0, OUTPUT_DESC(z)}}; | |||||
| // Cast | |||||
| INPUT_MAP(Cast) = {{1, INPUT_DESC(x)}}; | |||||
| INPUT_ATTR_MAP(Cast) = {{2, ATTR_DESC(dst_type, AnyTraits<GEType>())}}; | |||||
| ATTR_MAP(Cast) = {{"Truncate", ATTR_DESC(truncate, AnyTraits<bool>())}}; | |||||
| OUTPUT_MAP(Cast) = {{0, OUTPUT_DESC(y)}}; | |||||
| // Minimum | |||||
| INPUT_MAP(Minimum) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(y)}}; | |||||
| ATTR_MAP(Minimum) = {{"alpha", ATTR_DESC(alpha, AnyTraits<float>())}, {"beta", ATTR_DESC(beta, AnyTraits<float>())}}; | |||||
| OUTPUT_MAP(Minimum) = {{0, OUTPUT_DESC(z)}}; | |||||
| // Sub | |||||
| INPUT_MAP(Sub) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; | |||||
| ATTR_MAP(Sub) = {{"alpha", ATTR_DESC(alpha, AnyTraits<float>())}, {"beta", ATTR_DESC(beta, AnyTraits<float>())}}; | |||||
| #endif | |||||
| // TopKV2 | |||||
| INPUT_MAP(TopKV2) = { | |||||
| {1, INPUT_DESC(input)}, | |||||
| {2, INPUT_DESC(k)}, | |||||
| }; | |||||
| ATTR_MAP(TopKV2) = {{"T", ATTR_DESC(T, AnyTraits<GEType>())}, {"sorted", ATTR_DESC(sorted, AnyTraits<bool>())}}; | |||||
| OUTPUT_MAP(TopKV2) = { | |||||
| {0, OUTPUT_DESC(values)}, | |||||
| {1, OUTPUT_DESC(indices)}, | |||||
| }; | |||||
| // TopK | |||||
| INPUT_MAP(TopK) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(k)}}; | |||||
| ATTR_MAP(TopK) = {{"sorted", ATTR_DESC(sorted, AnyTraits<bool>())}}; | |||||
| OUTPUT_MAP(TopK) = {{0, OUTPUT_DESC(values)}, {1, OUTPUT_DESC(indices)}}; | |||||
| // Multiply | // Multiply | ||||
| INPUT_MAP(Multiply) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(y)}}; | INPUT_MAP(Multiply) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(y)}}; | ||||
| @@ -485,17 +452,17 @@ INPUT_MAP(Iou) = {{1, INPUT_DESC(bboxes)}, {2, INPUT_DESC(gtboxes)}}; | |||||
| ATTR_MAP(Iou) = {{"mode", ATTR_DESC(mode, AnyTraits<std::string>())}}; | ATTR_MAP(Iou) = {{"mode", ATTR_DESC(mode, AnyTraits<std::string>())}}; | ||||
| OUTPUT_MAP(Iou) = {{0, OUTPUT_DESC(overlap)}}; | OUTPUT_MAP(Iou) = {{0, OUTPUT_DESC(overlap)}}; | ||||
| // ResizeNearestNeighborD | |||||
| INPUT_MAP(ResizeNearestNeighborD) = {{1, INPUT_DESC(images)}}; | |||||
| ATTR_MAP(ResizeNearestNeighborD) = { | |||||
| // ResizeNearestNeighborV2D | |||||
| INPUT_MAP(ResizeNearestNeighborV2D) = {{1, INPUT_DESC(x)}}; | |||||
| ATTR_MAP(ResizeNearestNeighborV2D) = { | |||||
| {"size", ATTR_DESC(size, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}, | {"size", ATTR_DESC(size, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}, | ||||
| {"align_corners", ATTR_DESC(align_corners, AnyTraits<bool>())}}; | {"align_corners", ATTR_DESC(align_corners, AnyTraits<bool>())}}; | ||||
| OUTPUT_MAP(ResizeNearestNeighborD) = {{0, OUTPUT_DESC(y)}}; | |||||
| OUTPUT_MAP(ResizeNearestNeighborV2D) = {{0, OUTPUT_DESC(y)}}; | |||||
| // ResizeNearestNeighborGrad | |||||
| INPUT_MAP(ResizeNearestNeighborGrad) = {{1, INPUT_DESC(grads)}, {2, INPUT_DESC(size)}}; | |||||
| ATTR_MAP(ResizeNearestNeighborGrad) = {{"align_corners", ATTR_DESC(align_corners, AnyTraits<bool>())}}; | |||||
| OUTPUT_MAP(ResizeNearestNeighborGrad) = {{0, OUTPUT_DESC(y)}}; | |||||
| // ResizeNearestNeighborV2Grad | |||||
| INPUT_MAP(ResizeNearestNeighborV2Grad) = {{1, INPUT_DESC(grads)}, {2, INPUT_DESC(size)}}; | |||||
| ATTR_MAP(ResizeNearestNeighborV2Grad) = {{"align_corners", ATTR_DESC(align_corners, AnyTraits<bool>())}}; | |||||
| OUTPUT_MAP(ResizeNearestNeighborV2Grad) = {{0, OUTPUT_DESC(y)}}; | |||||
| // ApplyAdam | // ApplyAdam | ||||
| INPUT_MAP(ApplyAdam) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(m)}, {3, INPUT_DESC(v)}, | INPUT_MAP(ApplyAdam) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(m)}, {3, INPUT_DESC(v)}, | ||||
| @@ -504,33 +471,38 @@ INPUT_MAP(ApplyAdam) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(m)}, | |||||
| {10, INPUT_DESC(grad)}}; | {10, INPUT_DESC(grad)}}; | ||||
| ATTR_MAP(ApplyAdam) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}, | ATTR_MAP(ApplyAdam) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}, | ||||
| {"use_nesterov", ATTR_DESC(use_nesterov, AnyTraits<bool>())}}; | {"use_nesterov", ATTR_DESC(use_nesterov, AnyTraits<bool>())}}; | ||||
| #ifdef ENABLE_GE | |||||
| OUTPUT_MAP(ApplyAdam) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(m)}, {2, OUTPUT_DESC(v)}}; | |||||
| #else | |||||
| OUTPUT_MAP(ApplyAdam) = {{0, OUTPUT_DESC(var)}}; | OUTPUT_MAP(ApplyAdam) = {{0, OUTPUT_DESC(var)}}; | ||||
| #endif | |||||
| // ApplyAdamD | |||||
| INPUT_MAP(ApplyAdamD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(m)}, {3, INPUT_DESC(v)}, | |||||
| {4, INPUT_DESC(beta1_power)}, {5, INPUT_DESC(beta2_power)}, {6, INPUT_DESC(lr)}, | |||||
| {7, INPUT_DESC(beta1)}, {8, INPUT_DESC(beta2)}, {9, INPUT_DESC(epsilon)}, | |||||
| {10, INPUT_DESC(grad)}}; | |||||
| ATTR_MAP(ApplyAdamD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}, | |||||
| {"use_nesterov", ATTR_DESC(use_nesterov, AnyTraits<bool>())}}; | |||||
| OUTPUT_MAP(ApplyAdamD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(m)}, {2, OUTPUT_DESC(v)}}; | |||||
| // Relu6 | // Relu6 | ||||
| INPUT_MAP(Relu6) = {{1, INPUT_DESC(features)}}; | |||||
| INPUT_MAP(Relu6) = {{1, INPUT_DESC(x)}}; | |||||
| ATTR_MAP(Relu6) = EMPTY_ATTR_MAP; | ATTR_MAP(Relu6) = EMPTY_ATTR_MAP; | ||||
| OUTPUT_MAP(Relu6) = {{0, OUTPUT_DESC(activations)}}; | |||||
| OUTPUT_MAP(Relu6) = {{0, OUTPUT_DESC(y)}}; | |||||
| // Relu6Grad | // Relu6Grad | ||||
| INPUT_MAP(Relu6Grad) = {{1, INPUT_DESC(gradients)}, {2, INPUT_DESC(features)}}; | INPUT_MAP(Relu6Grad) = {{1, INPUT_DESC(gradients)}, {2, INPUT_DESC(features)}}; | ||||
| ATTR_MAP(Relu6Grad) = EMPTY_ATTR_MAP; | ATTR_MAP(Relu6Grad) = EMPTY_ATTR_MAP; | ||||
| OUTPUT_MAP(Relu6Grad) = {{0, OUTPUT_DESC(backprops)}}; | OUTPUT_MAP(Relu6Grad) = {{0, OUTPUT_DESC(backprops)}}; | ||||
| // ResizeBilinearGrad | |||||
| INPUT_MAP(ResizeBilinearGrad) = {{1, INPUT_DESC(grads)}, {2, INPUT_DESC(original_image)}}; | |||||
| ATTR_MAP(ResizeBilinearGrad) = {{"align_corners", ATTR_DESC(align_corners, AnyTraits<bool>())}}; | |||||
| OUTPUT_MAP(ResizeBilinearGrad) = {{0, OUTPUT_DESC(y)}}; | |||||
| // ResizeBilinearV2Grad | |||||
| INPUT_MAP(ResizeBilinearV2Grad) = {{1, INPUT_DESC(grads)}, {2, INPUT_DESC(original_image)}}; | |||||
| ATTR_MAP(ResizeBilinearV2Grad) = {{"align_corners", ATTR_DESC(align_corners, AnyTraits<bool>())}}; | |||||
| OUTPUT_MAP(ResizeBilinearV2Grad) = {{0, OUTPUT_DESC(y)}}; | |||||
| // ResizeBilinear | |||||
| INPUT_MAP(ResizeBilinearD) = {{1, INPUT_DESC(images)}}; | |||||
| ATTR_MAP(ResizeBilinearD) = { | |||||
| // ResizeBilinearV2D | |||||
| INPUT_MAP(ResizeBilinearV2D) = {{1, INPUT_DESC(x)}}; | |||||
| ATTR_MAP(ResizeBilinearV2D) = { | |||||
| {"size", ATTR_DESC(size, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}, | {"size", ATTR_DESC(size, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}, | ||||
| {"align_corners", ATTR_DESC(align_corners, AnyTraits<bool>())}}; | {"align_corners", ATTR_DESC(align_corners, AnyTraits<bool>())}}; | ||||
| OUTPUT_MAP(ResizeBilinearD) = {{0, OUTPUT_DESC(y)}}; | |||||
| OUTPUT_MAP(ResizeBilinearV2D) = {{0, OUTPUT_DESC(y)}}; | |||||
| // ZerosLike | // ZerosLike | ||||
| INPUT_MAP(ZerosLike) = {{1, INPUT_DESC(x)}}; | INPUT_MAP(ZerosLike) = {{1, INPUT_DESC(x)}}; | ||||
| @@ -549,9 +521,9 @@ OUTPUT_MAP(NMSWithMask) = { | |||||
| {0, OUTPUT_DESC(selected_boxes)}, {1, OUTPUT_DESC(selected_idx)}, {2, OUTPUT_DESC(selected_mask)}}; | {0, OUTPUT_DESC(selected_boxes)}, {1, OUTPUT_DESC(selected_idx)}, {2, OUTPUT_DESC(selected_mask)}}; | ||||
| // Unpack | // Unpack | ||||
| INPUT_MAP(Unpack) = {{1, INPUT_DESC(value)}}; | |||||
| INPUT_MAP(Unpack) = {{1, INPUT_DESC(x)}}; | |||||
| ATTR_MAP(Unpack) = {{"axis", ATTR_DESC(axis, AnyTraits<int>())}, {"num", ATTR_DESC(num, AnyTraits<int>())}}; | ATTR_MAP(Unpack) = {{"axis", ATTR_DESC(axis, AnyTraits<int>())}, {"num", ATTR_DESC(num, AnyTraits<int>())}}; | ||||
| DYN_OUTPUT_MAP(Unpack) = {{0, DYN_OUTPUT_DESC(output)}}; | |||||
| DYN_OUTPUT_MAP(Unpack) = {{0, DYN_OUTPUT_DESC(y)}}; | |||||
| // ScatterNdUpdate | // ScatterNdUpdate | ||||
| INPUT_MAP(ScatterNdUpdate) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}}; | INPUT_MAP(ScatterNdUpdate) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}}; | ||||
| @@ -584,8 +556,8 @@ INPUT_MAP(SigmoidCrossEntropyWithLogitsGrad) = { | |||||
| ATTR_MAP(SigmoidCrossEntropyWithLogitsGrad) = EMPTY_ATTR_MAP; | ATTR_MAP(SigmoidCrossEntropyWithLogitsGrad) = EMPTY_ATTR_MAP; | ||||
| OUTPUT_MAP(SigmoidCrossEntropyWithLogitsGrad) = {{0, OUTPUT_DESC(gradient)}}; | OUTPUT_MAP(SigmoidCrossEntropyWithLogitsGrad) = {{0, OUTPUT_DESC(gradient)}}; | ||||
| // ScatterNd | |||||
| INPUT_MAP(ScatterNdD) = {{1, INPUT_DESC(indices)}, {2, INPUT_DESC(updates)}}; | |||||
| // ScatterNdD | |||||
| INPUT_MAP(ScatterNdD) = {{1, INPUT_DESC(indices)}, {2, INPUT_DESC(x)}}; | |||||
| INPUT_ATTR_MAP(ScatterNdD) = { | INPUT_ATTR_MAP(ScatterNdD) = { | ||||
| {3, ATTR_DESC(shape, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}}; | {3, ATTR_DESC(shape, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}}; | ||||
| ATTR_MAP(ScatterNdD) = EMPTY_ATTR_MAP; | ATTR_MAP(ScatterNdD) = EMPTY_ATTR_MAP; | ||||
| @@ -607,13 +579,13 @@ ATTR_MAP(MirrorPadGrad) = {{"mode", ATTR_DESC(mode, AnyTraits<std::string>())}}; | |||||
| OUTPUT_MAP(MirrorPadGrad) = {{0, OUTPUT_DESC(y)}}; | OUTPUT_MAP(MirrorPadGrad) = {{0, OUTPUT_DESC(y)}}; | ||||
| // GatherNd | // GatherNd | ||||
| INPUT_MAP(GatherNd) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; | |||||
| INPUT_MAP(GatherNd) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(indices)}}; | |||||
| ATTR_MAP(GatherNd) = EMPTY_ATTR_MAP; | ATTR_MAP(GatherNd) = EMPTY_ATTR_MAP; | ||||
| OUTPUT_MAP(GatherNd) = {{0, OUTPUT_DESC(y)}}; | OUTPUT_MAP(GatherNd) = {{0, OUTPUT_DESC(y)}}; | ||||
| // ROIAlign | // ROIAlign | ||||
| INPUT_MAP(ROIAlign) = {{1, INPUT_DESC(features)}, {2, INPUT_DESC(rois)}}; | INPUT_MAP(ROIAlign) = {{1, INPUT_DESC(features)}, {2, INPUT_DESC(rois)}}; | ||||
| OUTPUT_MAP(ROIAlign) = {{0, OUTPUT_DESC(output)}}; | |||||
| OUTPUT_MAP(ROIAlign) = {{0, OUTPUT_DESC(y)}}; | |||||
| ATTR_MAP(ROIAlign) = {{"pooled_height", ATTR_DESC(pooled_height, AnyTraits<int>())}, | ATTR_MAP(ROIAlign) = {{"pooled_height", ATTR_DESC(pooled_height, AnyTraits<int>())}, | ||||
| {"pooled_width", ATTR_DESC(pooled_width, AnyTraits<int>())}, | {"pooled_width", ATTR_DESC(pooled_width, AnyTraits<int>())}, | ||||
| {"spatial_scale", ATTR_DESC(spatial_scale, AnyTraits<float>())}, | {"spatial_scale", ATTR_DESC(spatial_scale, AnyTraits<float>())}, | ||||
| @@ -632,13 +604,13 @@ ATTR_MAP(ROIAlignGrad) = { | |||||
| // ArgMaxD | // ArgMaxD | ||||
| INPUT_MAP(ArgMaxD) = {{1, INPUT_DESC(x)}}; | INPUT_MAP(ArgMaxD) = {{1, INPUT_DESC(x)}}; | ||||
| ATTR_MAP(ArgMaxD) = {{"axis", ATTR_DESC(dimension, AnyTraits<int>())}, | ATTR_MAP(ArgMaxD) = {{"axis", ATTR_DESC(dimension, AnyTraits<int>())}, | ||||
| {"output_type", ATTR_DESC(output_type, AnyTraits<GEType>())}}; | |||||
| {"output_type", ATTR_DESC(dtype, AnyTraits<GEType>())}}; | |||||
| OUTPUT_MAP(ArgMaxD) = {{0, OUTPUT_DESC(y)}}; | OUTPUT_MAP(ArgMaxD) = {{0, OUTPUT_DESC(y)}}; | ||||
| // ArgMinD | // ArgMinD | ||||
| INPUT_MAP(ArgMinD) = {{1, INPUT_DESC(x)}}; | INPUT_MAP(ArgMinD) = {{1, INPUT_DESC(x)}}; | ||||
| ATTR_MAP(ArgMinD) = {{"axis", ATTR_DESC(dimension, AnyTraits<int>())}, | ATTR_MAP(ArgMinD) = {{"axis", ATTR_DESC(dimension, AnyTraits<int>())}, | ||||
| {"output_type", ATTR_DESC(output_type, AnyTraits<GEType>())}}; | |||||
| {"output_type", ATTR_DESC(dtype, AnyTraits<GEType>())}}; | |||||
| OUTPUT_MAP(ArgMinD) = {{0, OUTPUT_DESC(y)}}; | OUTPUT_MAP(ArgMinD) = {{0, OUTPUT_DESC(y)}}; | ||||
| // ArgMaxWithValue | // ArgMaxWithValue | ||||
| @@ -656,14 +628,14 @@ OUTPUT_MAP(ArgMinWithValue) = {{0, OUTPUT_DESC(indice)}, {1, OUTPUT_DESC(values) | |||||
| // ReduceAllD | // ReduceAllD | ||||
| INPUT_MAP(ReduceAllD) = {{1, INPUT_DESC(x)}}; | INPUT_MAP(ReduceAllD) = {{1, INPUT_DESC(x)}}; | ||||
| INPUT_ATTR_MAP(ReduceAllD) = { | INPUT_ATTR_MAP(ReduceAllD) = { | ||||
| {2, ATTR_DESC(axis, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}}; | |||||
| {2, ATTR_DESC(axes, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}}; | |||||
| ATTR_MAP(ReduceAllD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}}; | ATTR_MAP(ReduceAllD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}}; | ||||
| OUTPUT_MAP(ReduceAllD) = {{0, OUTPUT_DESC(y)}}; | OUTPUT_MAP(ReduceAllD) = {{0, OUTPUT_DESC(y)}}; | ||||
| // ReduceMeanD | // ReduceMeanD | ||||
| INPUT_MAP(ReduceMeanD) = {{1, INPUT_DESC(x)}}; | INPUT_MAP(ReduceMeanD) = {{1, INPUT_DESC(x)}}; | ||||
| INPUT_ATTR_MAP(ReduceMeanD) = { | INPUT_ATTR_MAP(ReduceMeanD) = { | ||||
| {2, ATTR_DESC(axis, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}}; | |||||
| {2, ATTR_DESC(axes, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}}; | |||||
| ATTR_MAP(ReduceMeanD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}}; | ATTR_MAP(ReduceMeanD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}}; | ||||
| OUTPUT_MAP(ReduceMeanD) = {{0, OUTPUT_DESC(y)}}; | OUTPUT_MAP(ReduceMeanD) = {{0, OUTPUT_DESC(y)}}; | ||||
| @@ -708,11 +680,12 @@ INPUT_MAP(BiasAddGrad) = {{1, INPUT_DESC(x)}}; | |||||
| ATTR_MAP(BiasAddGrad) = {{"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}}; | ATTR_MAP(BiasAddGrad) = {{"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}}; | ||||
| OUTPUT_MAP(BiasAddGrad) = {{0, OUTPUT_DESC(y)}}; | OUTPUT_MAP(BiasAddGrad) = {{0, OUTPUT_DESC(y)}}; | ||||
| // maxpoolgrad | |||||
| // MaxPoolGrad | |||||
| INPUT_MAP(MaxPoolGrad) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}, {3, INPUT_DESC(grad)}}; | INPUT_MAP(MaxPoolGrad) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}, {3, INPUT_DESC(grad)}}; | ||||
| ATTR_MAP(MaxPoolGrad) = {{"ksize", ATTR_DESC(ksize, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())}, | ATTR_MAP(MaxPoolGrad) = {{"ksize", ATTR_DESC(ksize, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())}, | ||||
| {"strides", ATTR_DESC(strides, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())}, | {"strides", ATTR_DESC(strides, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())}, | ||||
| {"padding", ATTR_DESC(padding, AnyTraits<std::string>())}}; | |||||
| {"padding", ATTR_DESC(padding, AnyTraits<std::string>())}, | |||||
| {"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}}; | |||||
| OUTPUT_MAP(MaxPoolGrad) = {{0, OUTPUT_DESC(y)}}; | OUTPUT_MAP(MaxPoolGrad) = {{0, OUTPUT_DESC(y)}}; | ||||
| // avgpoolgrad | // avgpoolgrad | ||||
| @@ -738,7 +711,7 @@ ATTR_MAP(MaxPoolGradWithArgmax) = {{"ksize", ATTR_DESC(ksize, AnyTraits<int>(), | |||||
| OUTPUT_MAP(MaxPoolGradWithArgmax) = {{0, OUTPUT_DESC(y)}}; | OUTPUT_MAP(MaxPoolGradWithArgmax) = {{0, OUTPUT_DESC(y)}}; | ||||
| // ExtractImagePatches | // ExtractImagePatches | ||||
| INPUT_MAP(ExtractImagePatches) = {{1, INPUT_DESC(images)}}; | |||||
| INPUT_MAP(ExtractImagePatches) = {{1, INPUT_DESC(x)}}; | |||||
| ATTR_MAP(ExtractImagePatches) = {{"ksizes", ATTR_DESC(ksizes, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())}, | ATTR_MAP(ExtractImagePatches) = {{"ksizes", ATTR_DESC(ksizes, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())}, | ||||
| {"strides", ATTR_DESC(strides, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())}, | {"strides", ATTR_DESC(strides, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())}, | ||||
| {"rates", ATTR_DESC(rates, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())}, | {"rates", ATTR_DESC(rates, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())}, | ||||
| @@ -751,28 +724,34 @@ ATTR_MAP(Conv2D) = { | |||||
| {"stride", ATTR_DESC(strides, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}, | {"stride", ATTR_DESC(strides, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}, | ||||
| {"pad_list", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}, | {"pad_list", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}, | ||||
| {"dilation", ATTR_DESC(dilations, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}, | {"dilation", ATTR_DESC(dilations, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}, | ||||
| {"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}, | |||||
| {"group", ATTR_DESC(groups, AnyTraits<int>())}, | |||||
| }; | }; | ||||
| OUTPUT_MAP(Conv2D) = {{0, OUTPUT_DESC(y)}}; | OUTPUT_MAP(Conv2D) = {{0, OUTPUT_DESC(y)}}; | ||||
| // Conv2DBackpropInputD | // Conv2DBackpropInputD | ||||
| INPUT_MAP(Conv2DBackpropInputD) = {{1, INPUT_DESC(out_backprop)}, {2, INPUT_DESC(filters)}}; | |||||
| INPUT_MAP(Conv2DBackpropInputD) = {{1, INPUT_DESC(out_backprop)}, {2, INPUT_DESC(filter)}}; | |||||
| INPUT_ATTR_MAP(Conv2DBackpropInputD) = { | INPUT_ATTR_MAP(Conv2DBackpropInputD) = { | ||||
| {3, ATTR_DESC(input_sizes, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}}; | |||||
| {3, ATTR_DESC(input_size, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}}; | |||||
| ATTR_MAP(Conv2DBackpropInputD) = { | ATTR_MAP(Conv2DBackpropInputD) = { | ||||
| {"pad_list", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}, | {"pad_list", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}, | ||||
| {"stride", ATTR_DESC(strides, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}, | |||||
| {"stride", ATTR_DESC(strides, "pad", AnyTraits<std::vector<int64_t>>())}, | |||||
| {"dilation", ATTR_DESC(dilations, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}, | {"dilation", ATTR_DESC(dilations, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}, | ||||
| {"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}, | |||||
| {"group", ATTR_DESC(groups, AnyTraits<int>())}, | |||||
| }; | }; | ||||
| OUTPUT_MAP(Conv2DBackpropInputD) = {{0, OUTPUT_DESC(y)}}; | OUTPUT_MAP(Conv2DBackpropInputD) = {{0, OUTPUT_DESC(y)}}; | ||||
| // Conv2DBackpropFilterD | // Conv2DBackpropFilterD | ||||
| INPUT_MAP(Conv2DBackpropFilterD) = {{1, INPUT_DESC(out_backprop)}, {2, INPUT_DESC(x)}}; | INPUT_MAP(Conv2DBackpropFilterD) = {{1, INPUT_DESC(out_backprop)}, {2, INPUT_DESC(x)}}; | ||||
| INPUT_ATTR_MAP(Conv2DBackpropFilterD) = { | INPUT_ATTR_MAP(Conv2DBackpropFilterD) = { | ||||
| {3, ATTR_DESC(filter_sizes, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}}; | |||||
| {3, ATTR_DESC(filter_size, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}}; | |||||
| ATTR_MAP(Conv2DBackpropFilterD) = { | ATTR_MAP(Conv2DBackpropFilterD) = { | ||||
| {"pad_list", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}, | {"pad_list", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}, | ||||
| {"stride", ATTR_DESC(strides, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}, | |||||
| {"stride", ATTR_DESC(strides, "pad", AnyTraits<std::vector<int64_t>>())}, | |||||
| {"dilation", ATTR_DESC(dilations, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}, | {"dilation", ATTR_DESC(dilations, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}, | ||||
| {"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}, | |||||
| {"group", ATTR_DESC(groups, AnyTraits<int>())}, | |||||
| }; | }; | ||||
| OUTPUT_MAP(Conv2DBackpropFilterD) = {{0, OUTPUT_DESC(y)}}; | OUTPUT_MAP(Conv2DBackpropFilterD) = {{0, OUTPUT_DESC(y)}}; | ||||
| @@ -810,8 +789,8 @@ OUTPUT_MAP(DepthwiseConv2DBackpropFilterD) = {{0, OUTPUT_DESC(filter_grad)}}; | |||||
| // MatMul | // MatMul | ||||
| INPUT_MAP(MatMul) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; | INPUT_MAP(MatMul) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; | ||||
| ATTR_MAP(MatMul) = {{"transpose_a", ATTR_DESC(transpose_a, AnyTraits<bool>())}, | |||||
| {"transpose_b", ATTR_DESC(transpose_b, AnyTraits<bool>())}}; | |||||
| ATTR_MAP(MatMul) = {{"transpose_a", ATTR_DESC(transpose_x1, AnyTraits<bool>())}, | |||||
| {"transpose_b", ATTR_DESC(transpose_x2, AnyTraits<bool>())}}; | |||||
| OUTPUT_MAP(MatMul) = {{0, OUTPUT_DESC(y)}}; | OUTPUT_MAP(MatMul) = {{0, OUTPUT_DESC(y)}}; | ||||
| // Merge | // Merge | ||||
| @@ -858,10 +837,10 @@ ATTR_MAP(Sub) = EMPTY_ATTR_MAP; | |||||
| OUTPUT_MAP(Sub) = {{0, OUTPUT_DESC(y)}}; | OUTPUT_MAP(Sub) = {{0, OUTPUT_DESC(y)}}; | ||||
| // SplitD | // SplitD | ||||
| INPUT_MAP(SplitD) = {{1, INPUT_DESC(value)}}; | |||||
| INPUT_MAP(SplitD) = {{1, INPUT_DESC(x)}}; | |||||
| ATTR_MAP(SplitD) = {{"axis", ATTR_DESC(split_dim, AnyTraits<int>())}, | ATTR_MAP(SplitD) = {{"axis", ATTR_DESC(split_dim, AnyTraits<int>())}, | ||||
| {"output_num", ATTR_DESC(num_split, AnyTraits<int>())}}; | {"output_num", ATTR_DESC(num_split, AnyTraits<int>())}}; | ||||
| DYN_OUTPUT_MAP(SplitD) = {{0, DYN_OUTPUT_DESC(output)}}; | |||||
| DYN_OUTPUT_MAP(SplitD) = {{0, DYN_OUTPUT_DESC(y)}}; | |||||
| // Neg | // Neg | ||||
| INPUT_MAP(Neg) = {{1, INPUT_DESC(x)}}; | INPUT_MAP(Neg) = {{1, INPUT_DESC(x)}}; | ||||
| @@ -888,12 +867,12 @@ OUTPUT_MAP(Pack) = {{0, OUTPUT_DESC(y)}}; | |||||
| // ConcatD | // ConcatD | ||||
| INPUT_MAP(ConcatD) = EMPTY_INPUT_MAP; | INPUT_MAP(ConcatD) = EMPTY_INPUT_MAP; | ||||
| DYN_INPUT_MAP(ConcatD) = {{1, DYN_INPUT_DESC(input_values)}}; | |||||
| DYN_INPUT_MAP(ConcatD) = {{1, DYN_INPUT_DESC(x)}}; | |||||
| ATTR_MAP(ConcatD) = { | ATTR_MAP(ConcatD) = { | ||||
| {"axis", ATTR_DESC(concat_dim, AnyTraits<int>())}, | {"axis", ATTR_DESC(concat_dim, AnyTraits<int>())}, | ||||
| {"inputNums", ATTR_DESC(N, AnyTraits<int>())}, | {"inputNums", ATTR_DESC(N, AnyTraits<int>())}, | ||||
| }; | }; | ||||
| OUTPUT_MAP(ConcatD) = {{0, OUTPUT_DESC(output_data)}}; | |||||
| OUTPUT_MAP(ConcatD) = {{0, OUTPUT_DESC(y)}}; | |||||
| // Less | // Less | ||||
| INPUT_MAP(Less) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; | INPUT_MAP(Less) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; | ||||
| @@ -928,14 +907,14 @@ OUTPUT_MAP(TanhGrad) = {{0, OUTPUT_DESC(z)}}; | |||||
| // ReduceMinD | // ReduceMinD | ||||
| INPUT_MAP(ReduceMinD) = {{1, INPUT_DESC(x)}}; | INPUT_MAP(ReduceMinD) = {{1, INPUT_DESC(x)}}; | ||||
| INPUT_ATTR_MAP(ReduceMinD) = { | INPUT_ATTR_MAP(ReduceMinD) = { | ||||
| {2, ATTR_DESC(axis, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}}; | |||||
| {2, ATTR_DESC(axes, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}}; | |||||
| ATTR_MAP(ReduceMinD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}}; | ATTR_MAP(ReduceMinD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}}; | ||||
| OUTPUT_MAP(ReduceMinD) = {{0, OUTPUT_DESC(y)}}; | OUTPUT_MAP(ReduceMinD) = {{0, OUTPUT_DESC(y)}}; | ||||
| // ReduceMaxD | // ReduceMaxD | ||||
| INPUT_MAP(ReduceMaxD) = {{1, INPUT_DESC(x)}}; | INPUT_MAP(ReduceMaxD) = {{1, INPUT_DESC(x)}}; | ||||
| INPUT_ATTR_MAP(ReduceMaxD) = { | INPUT_ATTR_MAP(ReduceMaxD) = { | ||||
| {2, ATTR_DESC(axis, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}}; | |||||
| {2, ATTR_DESC(axes, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}}; | |||||
| ATTR_MAP(ReduceMaxD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}}; | ATTR_MAP(ReduceMaxD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}}; | ||||
| OUTPUT_MAP(ReduceMaxD) = {{0, OUTPUT_DESC(y)}}; | OUTPUT_MAP(ReduceMaxD) = {{0, OUTPUT_DESC(y)}}; | ||||
| @@ -1020,11 +999,11 @@ INPUT_MAP(LessEqual) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; | |||||
| ATTR_MAP(LessEqual) = EMPTY_ATTR_MAP; | ATTR_MAP(LessEqual) = EMPTY_ATTR_MAP; | ||||
| OUTPUT_MAP(LessEqual) = {{0, OUTPUT_DESC(y)}}; | OUTPUT_MAP(LessEqual) = {{0, OUTPUT_DESC(y)}}; | ||||
| // LogSoftmax | |||||
| INPUT_MAP(LogSoftmax) = {{1, INPUT_DESC(logits)}}; | |||||
| ATTR_MAP(LogSoftmax) = { | |||||
| {"axis", ATTR_DESC(axis, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}}; | |||||
| OUTPUT_MAP(LogSoftmax) = {{0, OUTPUT_DESC(logsoftmax)}}; | |||||
| // LogSoftmaxV2 | |||||
| INPUT_MAP(LogSoftmaxV2) = {{1, INPUT_DESC(logits)}}; | |||||
| ATTR_MAP(LogSoftmaxV2) = { | |||||
| {"axis", ATTR_DESC(axes, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}}; | |||||
| OUTPUT_MAP(LogSoftmaxV2) = {{0, OUTPUT_DESC(logsoftmax)}}; | |||||
| // RandomChoiceWithMask | // RandomChoiceWithMask | ||||
| INPUT_MAP(RandomChoiceWithMask) = {{1, INPUT_DESC(x)}}; | INPUT_MAP(RandomChoiceWithMask) = {{1, INPUT_DESC(x)}}; | ||||
| @@ -1106,8 +1085,8 @@ OUTPUT_MAP(LayerNormGrad) = {{0, OUTPUT_DESC(pd_x)}, {1, OUTPUT_DESC(pd_gamma)}, | |||||
| // BatchMatMul | // BatchMatMul | ||||
| INPUT_MAP(BatchMatMul) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; | INPUT_MAP(BatchMatMul) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; | ||||
| ATTR_MAP(BatchMatMul) = {{"transpose_x1", ATTR_DESC(adj_x, AnyTraits<bool>())}, | |||||
| {"transpose_x2", ATTR_DESC(adj_y, AnyTraits<bool>())}}; | |||||
| ATTR_MAP(BatchMatMul) = {{"transpose_x1", ATTR_DESC(adj_x1, AnyTraits<bool>())}, | |||||
| {"transpose_x2", ATTR_DESC(adj_x2, AnyTraits<bool>())}}; | |||||
| OUTPUT_MAP(BatchMatMul) = {{0, OUTPUT_DESC(y)}}; | OUTPUT_MAP(BatchMatMul) = {{0, OUTPUT_DESC(y)}}; | ||||
| // DropoutDoMask | // DropoutDoMask | ||||
| @@ -1156,7 +1135,20 @@ INPUT_MAP(SparseApplyAdagradD) = { | |||||
| {1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(grad)}, {4, INPUT_DESC(indices)}}; | {1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(grad)}, {4, INPUT_DESC(indices)}}; | ||||
| ATTR_MAP(SparseApplyAdagradD) = {{"lr", ATTR_DESC(lr, AnyTraits<float>())}, | ATTR_MAP(SparseApplyAdagradD) = {{"lr", ATTR_DESC(lr, AnyTraits<float>())}, | ||||
| {"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}}; | {"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}}; | ||||
| OUTPUT_MAP(SparseApplyAdagradD) = {{0, OUTPUT_DESC(var)}}; | |||||
| OUTPUT_MAP(SparseApplyAdagradD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}}; | |||||
| // SparseApplyFtrlD | |||||
| INPUT_MAP(SparseApplyFtrlD) = {{1, INPUT_DESC(var)}, | |||||
| {2, INPUT_DESC(accum)}, | |||||
| {3, INPUT_DESC(linear)}, | |||||
| {4, INPUT_DESC(grad)}, | |||||
| {5, INPUT_DESC(indices)}}; | |||||
| ATTR_MAP(SparseApplyFtrlD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}, | |||||
| {"lr", ATTR_DESC(lr, AnyTraits<float>())}, | |||||
| {"l1", ATTR_DESC(l1, AnyTraits<float>())}, | |||||
| {"l2", ATTR_DESC(l2, AnyTraits<float>())}, | |||||
| {"lr_power", ATTR_DESC(lr_power, AnyTraits<float>())}}; | |||||
| OUTPUT_MAP(SparseApplyFtrlD) = {{0, OUTPUT_DESC(var)}}; | |||||
| // SpaceToDepth | // SpaceToDepth | ||||
| INPUT_MAP(SpaceToDepth) = {{1, INPUT_DESC(x)}}; | INPUT_MAP(SpaceToDepth) = {{1, INPUT_DESC(x)}}; | ||||
| @@ -114,20 +114,22 @@ DECLARE_OP_ADAPTER(Reshape) | |||||
| DECLARE_OP_USE_OUTPUT(Reshape) | DECLARE_OP_USE_OUTPUT(Reshape) | ||||
| DECLARE_OP_ADAPTER(Iou) | DECLARE_OP_ADAPTER(Iou) | ||||
| DECLARE_OP_USE_OUTPUT(Iou) | DECLARE_OP_USE_OUTPUT(Iou) | ||||
| DECLARE_OP_ADAPTER(ResizeNearestNeighborD) | |||||
| DECLARE_OP_USE_OUTPUT(ResizeNearestNeighborD) | |||||
| DECLARE_OP_ADAPTER(ResizeNearestNeighborGrad) | |||||
| DECLARE_OP_USE_OUTPUT(ResizeNearestNeighborGrad) | |||||
| DECLARE_OP_ADAPTER(ResizeNearestNeighborV2D) | |||||
| DECLARE_OP_USE_OUTPUT(ResizeNearestNeighborV2D) | |||||
| DECLARE_OP_ADAPTER(ResizeNearestNeighborV2Grad) | |||||
| DECLARE_OP_USE_OUTPUT(ResizeNearestNeighborV2Grad) | |||||
| DECLARE_OP_ADAPTER(ApplyAdam) | DECLARE_OP_ADAPTER(ApplyAdam) | ||||
| DECLARE_OP_USE_OUTPUT(ApplyAdam) | DECLARE_OP_USE_OUTPUT(ApplyAdam) | ||||
| DECLARE_OP_ADAPTER(ApplyAdamD) | |||||
| DECLARE_OP_USE_OUTPUT(ApplyAdamD) | |||||
| DECLARE_OP_ADAPTER(Relu6) | DECLARE_OP_ADAPTER(Relu6) | ||||
| DECLARE_OP_USE_OUTPUT(Relu6) | DECLARE_OP_USE_OUTPUT(Relu6) | ||||
| DECLARE_OP_ADAPTER(Relu6Grad) | DECLARE_OP_ADAPTER(Relu6Grad) | ||||
| DECLARE_OP_USE_OUTPUT(Relu6Grad) | DECLARE_OP_USE_OUTPUT(Relu6Grad) | ||||
| DECLARE_OP_ADAPTER(ResizeBilinearD) | |||||
| DECLARE_OP_USE_OUTPUT(ResizeBilinearD) | |||||
| DECLARE_OP_ADAPTER(ResizeBilinearGrad) | |||||
| DECLARE_OP_USE_OUTPUT(ResizeBilinearGrad) | |||||
| DECLARE_OP_ADAPTER(ResizeBilinearV2D) | |||||
| DECLARE_OP_USE_OUTPUT(ResizeBilinearV2D) | |||||
| DECLARE_OP_ADAPTER(ResizeBilinearV2Grad) | |||||
| DECLARE_OP_USE_OUTPUT(ResizeBilinearV2Grad) | |||||
| DECLARE_OP_ADAPTER(ZerosLike) | DECLARE_OP_ADAPTER(ZerosLike) | ||||
| DECLARE_OP_USE_OUTPUT(ZerosLike) | DECLARE_OP_USE_OUTPUT(ZerosLike) | ||||
| DECLARE_OP_ADAPTER(OnesLike) | DECLARE_OP_ADAPTER(OnesLike) | ||||
| @@ -213,8 +215,8 @@ DECLARE_OP_USE_OUTPUT(Merge) | |||||
| DECLARE_OP_ADAPTER(Switch) | DECLARE_OP_ADAPTER(Switch) | ||||
| DECLARE_OP_USE_OUTPUT(Switch) | DECLARE_OP_USE_OUTPUT(Switch) | ||||
| DECLARE_OP_ADAPTER(TopKV2) | |||||
| DECLARE_OP_USE_OUTPUT(TopKV2) | |||||
| DECLARE_OP_ADAPTER(TopK) | |||||
| DECLARE_OP_USE_OUTPUT(TopK) | |||||
| DECLARE_OP_ADAPTER(RealDiv) | DECLARE_OP_ADAPTER(RealDiv) | ||||
| DECLARE_OP_USE_OUTPUT(RealDiv) | DECLARE_OP_USE_OUTPUT(RealDiv) | ||||
| @@ -264,8 +266,8 @@ DECLARE_OP_ADAPTER(Select) | |||||
| DECLARE_OP_USE_OUTPUT(Select) | DECLARE_OP_USE_OUTPUT(Select) | ||||
| DECLARE_OP_ADAPTER(LessEqual) | DECLARE_OP_ADAPTER(LessEqual) | ||||
| DECLARE_OP_USE_OUTPUT(LessEqual) | DECLARE_OP_USE_OUTPUT(LessEqual) | ||||
| DECLARE_OP_ADAPTER(LogSoftmax) | |||||
| DECLARE_OP_USE_OUTPUT(LogSoftmax) | |||||
| DECLARE_OP_ADAPTER(LogSoftmaxV2) | |||||
| DECLARE_OP_USE_OUTPUT(LogSoftmaxV2) | |||||
| DECLARE_OP_ADAPTER(TruncatedNormal) | DECLARE_OP_ADAPTER(TruncatedNormal) | ||||
| DECLARE_OP_USE_OUTPUT(TruncatedNormal) | DECLARE_OP_USE_OUTPUT(TruncatedNormal) | ||||
| DECLARE_OP_ADAPTER(StridedSliceGrad) | DECLARE_OP_ADAPTER(StridedSliceGrad) | ||||
| @@ -400,8 +402,8 @@ DECLARE_OP_ADAPTER(Sigmoid) | |||||
| DECLARE_OP_USE_OUTPUT(Sigmoid) | DECLARE_OP_USE_OUTPUT(Sigmoid) | ||||
| DECLARE_OP_ADAPTER(SigmoidGrad) | DECLARE_OP_ADAPTER(SigmoidGrad) | ||||
| DECLARE_OP_USE_OUTPUT(SigmoidGrad) | DECLARE_OP_USE_OUTPUT(SigmoidGrad) | ||||
| DECLARE_OP_ADAPTER(Softmax) | |||||
| DECLARE_OP_USE_OUTPUT(Softmax) | |||||
| DECLARE_OP_ADAPTER(SoftmaxV2) | |||||
| DECLARE_OP_USE_OUTPUT(SoftmaxV2) | |||||
| DECLARE_OP_ADAPTER(SoftmaxGrad) | DECLARE_OP_ADAPTER(SoftmaxGrad) | ||||
| DECLARE_OP_USE_OUTPUT(SoftmaxGrad) | DECLARE_OP_USE_OUTPUT(SoftmaxGrad) | ||||
| DECLARE_OP_ADAPTER(Greater) | DECLARE_OP_ADAPTER(Greater) | ||||
| @@ -444,6 +446,8 @@ DECLARE_OP_ADAPTER(Round) | |||||
| DECLARE_OP_USE_OUTPUT(Round) | DECLARE_OP_USE_OUTPUT(Round) | ||||
| DECLARE_OP_ADAPTER(ApplyFtrl) | DECLARE_OP_ADAPTER(ApplyFtrl) | ||||
| DECLARE_OP_USE_OUTPUT(ApplyFtrl) | DECLARE_OP_USE_OUTPUT(ApplyFtrl) | ||||
| DECLARE_OP_ADAPTER(SparseApplyFtrlD) | |||||
| DECLARE_OP_USE_OUTPUT(SparseApplyFtrlD) | |||||
| DECLARE_OP_ADAPTER(Diag) | DECLARE_OP_ADAPTER(Diag) | ||||
| DECLARE_OP_USE_OUTPUT(Diag) | DECLARE_OP_USE_OUTPUT(Diag) | ||||
| DECLARE_OP_ADAPTER(DiagPart) | DECLARE_OP_ADAPTER(DiagPart) | ||||
| @@ -31,6 +31,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| const char kShapeSeperator[] = ","; | const char kShapeSeperator[] = ","; | ||||
| const char kShapeScalar[] = "[0]"; | |||||
| static std::map<std::string, TypeId> print_type_map = { | static std::map<std::string, TypeId> print_type_map = { | ||||
| {"int8_t", TypeId::kNumberTypeInt8}, {"uint8_t", TypeId::kNumberTypeUInt8}, | {"int8_t", TypeId::kNumberTypeInt8}, {"uint8_t", TypeId::kNumberTypeUInt8}, | ||||
| {"int16_t", TypeId::kNumberTypeInt16}, {"uint16_t", TypeId::kNumberTypeUInt16}, | {"int16_t", TypeId::kNumberTypeInt16}, {"uint16_t", TypeId::kNumberTypeUInt16}, | ||||
| @@ -81,6 +82,73 @@ bool PrintTensorToString(const char *str_data_ptr, mindspore::tensor::Tensor *co | |||||
| return true; | return true; | ||||
| } | } | ||||
| template <typename T> | |||||
| void PrintScalarToString(const char *str_data_ptr, const string &tensor_type) { | |||||
| const T *data_ptr = reinterpret_cast<const T *>(str_data_ptr); | |||||
| std::ostringstream buf_scalar; | |||||
| buf_scalar << "Tensor shape :1 " << tensor_type; | |||||
| buf_scalar << "\nval:"; | |||||
| buf_scalar << *data_ptr; | |||||
| std::cout << buf_scalar.str() << std::endl; | |||||
| } | |||||
| void PrintScalarToBoolString(const char *str_data_ptr, const string &tensor_type) { | |||||
| const bool *data_ptr = reinterpret_cast<const bool *>(str_data_ptr); | |||||
| std::ostringstream buf_scalar; | |||||
| buf_scalar << "Tensor shape :1 " << tensor_type; | |||||
| buf_scalar << "\nval:"; | |||||
| if (*data_ptr == true) { | |||||
| buf_scalar << "True"; | |||||
| } else { | |||||
| buf_scalar << "False"; | |||||
| } | |||||
| std::cout << buf_scalar.str() << std::endl; | |||||
| } | |||||
| void convertDataItem2Scalar(const char *str_data_ptr, const string &tensor_type) { | |||||
| auto type_iter = print_type_map.find(tensor_type); | |||||
| auto type_id = type_iter->second; | |||||
| if (type_id == TypeId::kNumberTypeBool) { | |||||
| PrintScalarToBoolString(str_data_ptr, tensor_type); | |||||
| } else if (type_id == TypeId::kNumberTypeInt8) { | |||||
| PrintScalarToString<int8_t>(str_data_ptr, tensor_type); | |||||
| } else if (type_id == TypeId::kNumberTypeUInt8) { | |||||
| PrintScalarToString<uint8_t>(str_data_ptr, tensor_type); | |||||
| } else if (type_id == TypeId::kNumberTypeInt16) { | |||||
| PrintScalarToString<int16_t>(str_data_ptr, tensor_type); | |||||
| } else if (type_id == TypeId::kNumberTypeUInt16) { | |||||
| PrintScalarToString<uint16_t>(str_data_ptr, tensor_type); | |||||
| } else if (type_id == TypeId::kNumberTypeInt32) { | |||||
| PrintScalarToString<int32_t>(str_data_ptr, tensor_type); | |||||
| } else if (type_id == TypeId::kNumberTypeUInt32) { | |||||
| PrintScalarToString<uint32_t>(str_data_ptr, tensor_type); | |||||
| } else if (type_id == TypeId::kNumberTypeInt64) { | |||||
| PrintScalarToString<int64_t>(str_data_ptr, tensor_type); | |||||
| } else if (type_id == TypeId::kNumberTypeUInt64) { | |||||
| PrintScalarToString<uint64_t>(str_data_ptr, tensor_type); | |||||
| } else if (type_id == TypeId::kNumberTypeFloat16) { | |||||
| PrintScalarToString<float16>(str_data_ptr, tensor_type); | |||||
| } else if (type_id == TypeId::kNumberTypeFloat32) { | |||||
| PrintScalarToString<float>(str_data_ptr, tensor_type); | |||||
| } else if (type_id == TypeId::kNumberTypeFloat64) { | |||||
| PrintScalarToString<double>(str_data_ptr, tensor_type); | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "Cannot print scalar because of unsupport data type: " << tensor_type << "."; | |||||
| } | |||||
| } // namespace mindspore | |||||
| bool judgeLengthValid(const size_t str_len, const string &tensor_type) { | |||||
| auto type_iter = type_size_map.find(tensor_type); | |||||
| if (type_iter == type_size_map.end()) { | |||||
| MS_LOG(EXCEPTION) << "type of scalar to print is not support."; | |||||
| } | |||||
| if (str_len != type_iter->second) { | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| #ifndef NO_DLIB | #ifndef NO_DLIB | ||||
| bool ConvertDataItem2Tensor(const std::vector<tdt::DataItem> &items) { | bool ConvertDataItem2Tensor(const std::vector<tdt::DataItem> &items) { | ||||
| // Acquire Python GIL | // Acquire Python GIL | ||||
| @@ -92,14 +160,22 @@ bool ConvertDataItem2Tensor(const std::vector<tdt::DataItem> &items) { | |||||
| ret_end_sequence = true; | ret_end_sequence = true; | ||||
| break; | break; | ||||
| } | } | ||||
| std::shared_ptr<std::string> str_data_ptr = std::static_pointer_cast<std::string>(item.dataPtr_); | |||||
| MS_EXCEPTION_IF_NULL(str_data_ptr); | |||||
| if (item.tensorShape_ == kShapeScalar) { | |||||
| if (!judgeLengthValid(str_data_ptr->size(), item.tensorType_)) { | |||||
| MS_LOG(EXCEPTION) << "Print op receive data length is invalid."; | |||||
| } | |||||
| convertDataItem2Scalar(str_data_ptr->data(), item.tensorType_); | |||||
| continue; | |||||
| } | |||||
| std::vector<int> tensor_shape; | std::vector<int> tensor_shape; | ||||
| size_t totaldims = 1; | size_t totaldims = 1; | ||||
| if (!ParseTensorShape(item.tensorShape_, &tensor_shape, &totaldims)) { | if (!ParseTensorShape(item.tensorShape_, &tensor_shape, &totaldims)) { | ||||
| MS_LOG(ERROR) << "Tensor print can not parse tensor shape, receive info" << item.tensorShape_; | MS_LOG(ERROR) << "Tensor print can not parse tensor shape, receive info" << item.tensorShape_; | ||||
| continue; | continue; | ||||
| } | } | ||||
| std::shared_ptr<std::string> str_data_ptr = std::static_pointer_cast<std::string>(item.dataPtr_); | |||||
| MS_EXCEPTION_IF_NULL(str_data_ptr); | |||||
| if (item.tensorType_ == "string") { | if (item.tensorType_ == "string") { | ||||
| std::string data(reinterpret_cast<const char *>(str_data_ptr->c_str()), item.dataLen_); | std::string data(reinterpret_cast<const char *>(str_data_ptr->c_str()), item.dataLen_); | ||||
| @@ -377,12 +377,10 @@ def get_bprop_batch_norm(self): | |||||
| if is_training: | if is_training: | ||||
| saved_reserve_1 = out[3] | saved_reserve_1 = out[3] | ||||
| saved_reserve_2 = out[4] | saved_reserve_2 = out[4] | ||||
| saved_reserve_3 = out[5] | |||||
| else: | else: | ||||
| saved_reserve_1 = mean | saved_reserve_1 = mean | ||||
| saved_reserve_2 = variance | saved_reserve_2 = variance | ||||
| saved_reserve_3 = variance | |||||
| out = input_grad(dout[0], x, scale, saved_reserve_1, saved_reserve_2, saved_reserve_3) | |||||
| out = input_grad(dout[0], x, scale, saved_reserve_1, saved_reserve_2) | |||||
| dx = out[0] | dx = out[0] | ||||
| dscale = out[1] | dscale = out[1] | ||||
| dbias = out[2] | dbias = out[2] | ||||
| @@ -18,3 +18,8 @@ from .dropout_genmask import _dropout_genmask_aicpu | |||||
| from .get_next import _get_next_aicpu | from .get_next import _get_next_aicpu | ||||
| from .print_tensor import _print_aicpu | from .print_tensor import _print_aicpu | ||||
| from .topk import _top_k_aicpu | from .topk import _top_k_aicpu | ||||
| from .is_finite import _is_finite_aicpu | |||||
| from .reshape import _reshape_aicpu | |||||
| from .flatten import _flatten_aicpu | |||||
| from .squeeze import _squeeze_aicpu | |||||
| from .expand_dims import _expand_dims_aicpu | |||||
| @@ -0,0 +1,52 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ExpandDims op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType | |||||
| expand_dims_op_info = AiCPURegOp("ExpandDims") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .input(0, "x", "required") \ | |||||
| .output(0, "y", "required") \ | |||||
| .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.I8_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.I16_Default, DataType.I16_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.I64_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.U16_Default, DataType.U16_Default) \ | |||||
| .dtype_format(DataType.U32_Default, DataType.U32_Default) \ | |||||
| .dtype_format(DataType.U64_Default, DataType.U64_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F64_Default, DataType.F64_Default) \ | |||||
| .dtype_format(DataType.BOOL_NCHW, DataType.BOOL_NCHW) \ | |||||
| .dtype_format(DataType.I8_NCHW, DataType.I8_NCHW) \ | |||||
| .dtype_format(DataType.I16_NCHW, DataType.I16_NCHW) \ | |||||
| .dtype_format(DataType.I32_NCHW, DataType.I32_NCHW) \ | |||||
| .dtype_format(DataType.I64_NCHW, DataType.I64_NCHW) \ | |||||
| .dtype_format(DataType.U8_NCHW, DataType.U8_NCHW) \ | |||||
| .dtype_format(DataType.U16_NCHW, DataType.U16_NCHW) \ | |||||
| .dtype_format(DataType.U32_NCHW, DataType.U32_NCHW) \ | |||||
| .dtype_format(DataType.U64_NCHW, DataType.U64_NCHW) \ | |||||
| .dtype_format(DataType.F16_NCHW, DataType.F16_NCHW) \ | |||||
| .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW) \ | |||||
| .dtype_format(DataType.F64_NCHW, DataType.F64_NCHW) \ | |||||
| .get_op_info() | |||||
| @op_info_register(expand_dims_op_info) | |||||
| def _expand_dims_aicpu(): | |||||
| """ExpandDims AiCPU register""" | |||||
| return | |||||
| @@ -0,0 +1,48 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Flatten op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType | |||||
| flatten_op_info = AiCPURegOp("Flatten") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .input(0, "x", "required") \ | |||||
| .output(0, "y", "required") \ | |||||
| .dtype_format(DataType.I8_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.I16_Default, DataType.I16_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.I64_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.U16_Default, DataType.U16_Default) \ | |||||
| .dtype_format(DataType.U32_Default, DataType.U32_Default) \ | |||||
| .dtype_format(DataType.U64_Default, DataType.U64_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.I8_NCHW, DataType.I8_NCHW) \ | |||||
| .dtype_format(DataType.I16_NCHW, DataType.I16_NCHW) \ | |||||
| .dtype_format(DataType.I32_NCHW, DataType.I32_NCHW) \ | |||||
| .dtype_format(DataType.I64_NCHW, DataType.I64_NCHW) \ | |||||
| .dtype_format(DataType.U8_NCHW, DataType.U8_NCHW) \ | |||||
| .dtype_format(DataType.U16_NCHW, DataType.U16_NCHW) \ | |||||
| .dtype_format(DataType.U32_NCHW, DataType.U32_NCHW) \ | |||||
| .dtype_format(DataType.U64_NCHW, DataType.U64_NCHW) \ | |||||
| .dtype_format(DataType.F16_NCHW, DataType.F16_NCHW) \ | |||||
| .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW) \ | |||||
| .get_op_info() | |||||
| @op_info_register(flatten_op_info) | |||||
| def _flatten_aicpu(): | |||||
| """Flatten AiCPU register""" | |||||
| return | |||||
| @@ -0,0 +1,52 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """IsFinite op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType | |||||
| is_finite_op_info = AiCPURegOp("IsFinite") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .input(0, "x", "required") \ | |||||
| .output(0, "y", "required") \ | |||||
| .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.I8_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.I16_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.U16_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.U32_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.U64_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.F64_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.BOOL_NCHW, DataType.BOOL_NCHW) \ | |||||
| .dtype_format(DataType.I8_NCHW, DataType.BOOL_NCHW) \ | |||||
| .dtype_format(DataType.I16_NCHW, DataType.BOOL_NCHW) \ | |||||
| .dtype_format(DataType.I32_NCHW, DataType.BOOL_NCHW) \ | |||||
| .dtype_format(DataType.I64_NCHW, DataType.BOOL_NCHW) \ | |||||
| .dtype_format(DataType.U8_NCHW, DataType.BOOL_NCHW) \ | |||||
| .dtype_format(DataType.U16_NCHW, DataType.BOOL_NCHW) \ | |||||
| .dtype_format(DataType.U32_NCHW, DataType.BOOL_NCHW) \ | |||||
| .dtype_format(DataType.U64_NCHW, DataType.BOOL_NCHW) \ | |||||
| .dtype_format(DataType.F16_NCHW, DataType.BOOL_NCHW) \ | |||||
| .dtype_format(DataType.F32_NCHW, DataType.BOOL_NCHW) \ | |||||
| .dtype_format(DataType.F64_NCHW, DataType.BOOL_NCHW) \ | |||||
| .get_op_info() | |||||
| @op_info_register(is_finite_op_info) | |||||
| def _is_finite_aicpu(): | |||||
| """IsFinite AiCPU register""" | |||||
| return | |||||
| @@ -0,0 +1,52 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Reshape op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType | |||||
| reshape_op_info = AiCPURegOp("Reshape") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .input(0, "x", "required") \ | |||||
| .output(0, "y", "required") \ | |||||
| .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.I8_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.I16_Default, DataType.I16_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.I64_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.U16_Default, DataType.U16_Default) \ | |||||
| .dtype_format(DataType.U32_Default, DataType.U32_Default) \ | |||||
| .dtype_format(DataType.U64_Default, DataType.U64_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F64_Default, DataType.F64_Default) \ | |||||
| .dtype_format(DataType.BOOL_NCHW, DataType.BOOL_NCHW) \ | |||||
| .dtype_format(DataType.I8_NCHW, DataType.I8_NCHW) \ | |||||
| .dtype_format(DataType.I16_NCHW, DataType.I16_NCHW) \ | |||||
| .dtype_format(DataType.I32_NCHW, DataType.I32_NCHW) \ | |||||
| .dtype_format(DataType.I64_NCHW, DataType.I64_NCHW) \ | |||||
| .dtype_format(DataType.U8_NCHW, DataType.U8_NCHW) \ | |||||
| .dtype_format(DataType.U16_NCHW, DataType.U16_NCHW) \ | |||||
| .dtype_format(DataType.U32_NCHW, DataType.U32_NCHW) \ | |||||
| .dtype_format(DataType.U64_NCHW, DataType.U64_NCHW) \ | |||||
| .dtype_format(DataType.F16_NCHW, DataType.F16_NCHW) \ | |||||
| .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW) \ | |||||
| .dtype_format(DataType.F64_NCHW, DataType.F64_NCHW) \ | |||||
| .get_op_info() | |||||
| @op_info_register(reshape_op_info) | |||||
| def _reshape_aicpu(): | |||||
| """Rpeshape AiCPU register""" | |||||
| return | |||||
| @@ -0,0 +1,52 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Squeeze op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType | |||||
| squeeze_op_info = AiCPURegOp("Squeeze") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .input(0, "x", "required") \ | |||||
| .output(0, "y", "required") \ | |||||
| .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.I8_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.I16_Default, DataType.I16_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.I64_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.U16_Default, DataType.U16_Default) \ | |||||
| .dtype_format(DataType.U32_Default, DataType.U32_Default) \ | |||||
| .dtype_format(DataType.U64_Default, DataType.U64_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F64_Default, DataType.F64_Default) \ | |||||
| .dtype_format(DataType.BOOL_NCHW, DataType.BOOL_NCHW) \ | |||||
| .dtype_format(DataType.I8_NCHW, DataType.I8_NCHW) \ | |||||
| .dtype_format(DataType.I16_NCHW, DataType.I16_NCHW) \ | |||||
| .dtype_format(DataType.I32_NCHW, DataType.I32_NCHW) \ | |||||
| .dtype_format(DataType.I64_NCHW, DataType.I64_NCHW) \ | |||||
| .dtype_format(DataType.U8_NCHW, DataType.U8_NCHW) \ | |||||
| .dtype_format(DataType.U16_NCHW, DataType.U16_NCHW) \ | |||||
| .dtype_format(DataType.U32_NCHW, DataType.U32_NCHW) \ | |||||
| .dtype_format(DataType.U64_NCHW, DataType.U64_NCHW) \ | |||||
| .dtype_format(DataType.F16_NCHW, DataType.F16_NCHW) \ | |||||
| .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW) \ | |||||
| .dtype_format(DataType.F64_NCHW, DataType.F64_NCHW) \ | |||||
| .get_op_info() | |||||
| @op_info_register(squeeze_op_info) | |||||
| def _squeeze_aicpu(): | |||||
| """Squeeze AiCPU register""" | |||||
| return | |||||
| @@ -61,9 +61,6 @@ from .reduce_mean_d import _reduce_mean_d_tbe | |||||
| from .scatter_nd import _scatter_nd_tbe | from .scatter_nd import _scatter_nd_tbe | ||||
| from .scatter_nd_d import _scatter_nd_d_tbe | from .scatter_nd_d import _scatter_nd_d_tbe | ||||
| from .reduce_mean import _reduce_mean_tbe | from .reduce_mean import _reduce_mean_tbe | ||||
| from .reshape import _reshape_tbe | |||||
| from .expand_dims import _expand_dims_tbe | |||||
| from .squeeze import _squeeze_tbe | |||||
| from .tile import _tile_tbe | from .tile import _tile_tbe | ||||
| from .atomic_addr_clean import _atomic_addr_clean_tbe | from .atomic_addr_clean import _atomic_addr_clean_tbe | ||||
| from .gather_v2 import _gather_v2_tbe | from .gather_v2 import _gather_v2_tbe | ||||
| @@ -30,22 +30,23 @@ apply_momentum_op_info = TBERegOp("ApplyMomentum") \ | |||||
| .input(3, "grad", False, "required", "all") \ | .input(3, "grad", False, "required", "all") \ | ||||
| .input(4, "momentum", False, "required", "all") \ | .input(4, "momentum", False, "required", "all") \ | ||||
| .output(0, "var", False, "required", "all") \ | .output(0, "var", False, "required", "all") \ | ||||
| .output(1, "accum", False, "required", "all") \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | ||||
| DataType.F16_Default, DataType.F16_Default) \ | |||||
| DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_5HD, | .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_5HD, | ||||
| DataType.F16_Default, DataType.F16_5HD) \ | |||||
| DataType.F16_Default, DataType.F16_5HD, DataType.F16_5HD) \ | |||||
| .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_C1HWNCoC0, | .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_C1HWNCoC0, | ||||
| DataType.F16_Default, DataType.F16_C1HWNCoC0) \ | |||||
| DataType.F16_Default, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \ | |||||
| .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_FracZ, | .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_FracZ, | ||||
| DataType.F16_Default, DataType.F16_FracZ) \ | |||||
| DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | ||||
| DataType.F32_Default, DataType.F32_Default) \ | |||||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_5HD, | .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_5HD, | ||||
| DataType.F32_Default, DataType.F32_5HD) \ | |||||
| DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD) \ | |||||
| .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_C1HWNCoC0, | .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_C1HWNCoC0, | ||||
| DataType.F32_Default, DataType.F32_C1HWNCoC0) \ | |||||
| DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \ | |||||
| .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_FracZ, | .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_FracZ, | ||||
| DataType.F32_Default, DataType.F32_FracZ) \ | |||||
| DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ) \ | |||||
| .get_op_info() | .get_op_info() | ||||
| @@ -36,19 +36,18 @@ batch_norm_op_info = TBERegOp("BatchNorm") \ | |||||
| .output(2, "batch_variance", False, "required", "all") \ | .output(2, "batch_variance", False, "required", "all") \ | ||||
| .output(3, "reserve_space_1", False, "optional", "all") \ | .output(3, "reserve_space_1", False, "optional", "all") \ | ||||
| .output(4, "reserve_space_2", False, "optional", "all") \ | .output(4, "reserve_space_2", False, "optional", "all") \ | ||||
| .output(5, "reserve_space_3", False, "optional", "all") \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | .dtype_format(DataType.F16_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | ||||
| DataType.F32_Default, DataType.F16_Default, DataType.F32_Default, DataType.F32_Default, | DataType.F32_Default, DataType.F16_Default, DataType.F32_Default, DataType.F32_Default, | ||||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||||
| DataType.F32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | .dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | ||||
| DataType.F32_5HD, DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD, | DataType.F32_5HD, DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD, | ||||
| DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||||
| DataType.F32_5HD, DataType.F32_5HD) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | ||||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | ||||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||||
| DataType.F32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | ||||
| DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | ||||
| DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||||
| DataType.F32_5HD, DataType.F32_5HD) \ | |||||
| .get_op_info() | .get_op_info() | ||||
| @@ -599,4 +599,13 @@ class DataType: | |||||
| F32_NCHW = ("float32", "NCHW") | F32_NCHW = ("float32", "NCHW") | ||||
| F32_NHWC = ("float32", "NHWC") | F32_NHWC = ("float32", "NHWC") | ||||
| F32_HWCN = ("float32", "HWCN") | F32_HWCN = ("float32", "HWCN") | ||||
| F64_None = ("float64", "") | |||||
| F64_Default = ("float64", "DefaultFormat") | |||||
| F64_5HD = ("float64", "NC1HWC0") | |||||
| F64_FracZ = ("float64", "FracZ") | |||||
| F64_FracNZ = ("float64", "FRACTAL_NZ") | |||||
| F64_C1HWNCoC0 = ("float64", "C1HWNCoC0") | |||||
| F64_NCHW = ("float64", "NCHW") | |||||
| F64_NHWC = ("float64", "NHWC") | |||||
| F64_HWCN = ("float64", "HWCN") | |||||
| @@ -85,11 +85,11 @@ class BatchNormGrad(PrimitiveWithInfer): | |||||
| self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name) | self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name) | ||||
| self.add_prim_attr('data_format', "NCHW") | self.add_prim_attr('data_format', "NCHW") | ||||
| def infer_shape(self, y_backprop_shape, x_shape, scale_shape, reserve_1_shape, reserve_2_shape, reserve_3_shape): | |||||
| def infer_shape(self, y_backprop_shape, x_shape, scale_shape, reserve_1_shape, reserve_2_shape): | |||||
| validator.check("BatchNorm y_backprop_shape", y_backprop_shape, "BatchNorm x_shape", x_shape) | validator.check("BatchNorm y_backprop_shape", y_backprop_shape, "BatchNorm x_shape", x_shape) | ||||
| return (x_shape, scale_shape, scale_shape, reserve_1_shape, reserve_2_shape) | return (x_shape, scale_shape, scale_shape, reserve_1_shape, reserve_2_shape) | ||||
| def infer_dtype(self, y_backprop_type, x_type, scale_type, reserve_1_type, reserve_2_type, reserve_3_type): | |||||
| def infer_dtype(self, y_backprop_type, x_type, scale_type, reserve_1_type, reserve_2_type): | |||||
| return (x_type, scale_type, scale_type, reserve_1_type, reserve_2_type) | return (x_type, scale_type, scale_type, reserve_1_type, reserve_2_type) | ||||
| @@ -209,7 +209,7 @@ class Conv2DBackpropFilter(PrimitiveWithInfer): | |||||
| 'value': None, | 'value': None, | ||||
| 'shape': w_size_v, | 'shape': w_size_v, | ||||
| 'dtype': doutput['dtype'], | 'dtype': doutput['dtype'], | ||||
| } | |||||
| } | |||||
| return out | return out | ||||
| @@ -349,7 +349,7 @@ class FlattenGrad(PrimitiveWithInfer): | |||||
| 'value': None, | 'value': None, | ||||
| 'shape': args[1]['value'], | 'shape': args[1]['value'], | ||||
| 'dtype': args[0]['dtype'], | 'dtype': args[0]['dtype'], | ||||
| } | |||||
| } | |||||
| return out | return out | ||||
| @@ -1657,6 +1657,8 @@ class IsFinite(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| validator.check_subclass("x", x_dtype, mstype.tensor, self.name) | |||||
| validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type + (mstype.bool_,), self.name) | |||||
| return mstype.bool_ | return mstype.bool_ | ||||
| class FloatStatus(PrimitiveWithInfer): | class FloatStatus(PrimitiveWithInfer): | ||||
| @@ -580,7 +580,7 @@ class BatchNorm(PrimitiveWithInfer): | |||||
| >>> mean = Tensor(np.ones([64]), mindspore.float32) | >>> mean = Tensor(np.ones([64]), mindspore.float32) | ||||
| >>> variance = Tensor(np.ones([64]), mindspore.float32) | >>> variance = Tensor(np.ones([64]), mindspore.float32) | ||||
| >>> batch_norm = P.BatchNorm() | >>> batch_norm = P.BatchNorm() | ||||
| >>> output = batch_norm(input_x, scale, bias, mean, variance) | |||||
| >>> output = batch_norm(input_x, scale, bias, mean, variance | |||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| @@ -589,8 +589,7 @@ class BatchNorm(PrimitiveWithInfer): | |||||
| validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name) | validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name) | ||||
| self.add_prim_attr('data_format', "NCHW") | self.add_prim_attr('data_format', "NCHW") | ||||
| self.init_prim_io_names(inputs=['x', 'scale', 'offset', 'mean', 'variance'], | self.init_prim_io_names(inputs=['x', 'scale', 'offset', 'mean', 'variance'], | ||||
| outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2', | |||||
| 'reserve_space_3']) | |||||
| outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2']) | |||||
| def infer_shape(self, input_x, scale, bias, mean, variance): | def infer_shape(self, input_x, scale, bias, mean, variance): | ||||
| validator.check_integer("scale rank", len(scale), 1, Rel.EQ, self.name) | validator.check_integer("scale rank", len(scale), 1, Rel.EQ, self.name) | ||||
| @@ -600,7 +599,7 @@ class BatchNorm(PrimitiveWithInfer): | |||||
| validator.check_integer("mean rank", len(mean), 1, Rel.EQ, self.name) | validator.check_integer("mean rank", len(mean), 1, Rel.EQ, self.name) | ||||
| validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name) | validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name) | ||||
| validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, self.name) | validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, self.name) | ||||
| return (input_x, scale, scale, scale, scale, scale) | |||||
| return (input_x, scale, scale, scale, scale) | |||||
| def infer_dtype(self, input_x, scale, bias, mean, variance): | def infer_dtype(self, input_x, scale, bias, mean, variance): | ||||
| validator.check_tensor_type_same({"input_x": input_x}, [mstype.float16, mstype.float32], self.name) | validator.check_tensor_type_same({"input_x": input_x}, [mstype.float16, mstype.float32], self.name) | ||||
| @@ -613,7 +612,7 @@ class BatchNorm(PrimitiveWithInfer): | |||||
| else: | else: | ||||
| args_moving = {"mean": mean, "variance": variance} | args_moving = {"mean": mean, "variance": variance} | ||||
| validator.check_tensor_type_same(args_moving, [mstype.float16, mstype.float32], self.name) | validator.check_tensor_type_same(args_moving, [mstype.float16, mstype.float32], self.name) | ||||
| return (input_x, scale, bias, input_x, input_x, input_x) | |||||
| return (input_x, scale, bias, input_x, input_x) | |||||
| class Conv2D(PrimitiveWithInfer): | class Conv2D(PrimitiveWithInfer): | ||||
| @@ -1428,8 +1427,11 @@ class ApplyMomentum(PrimitiveWithInfer): | |||||
| def __init__(self, use_nesterov=False, use_locking=False, gradient_scale=1.0): | def __init__(self, use_nesterov=False, use_locking=False, gradient_scale=1.0): | ||||
| self.init_prim_io_names(inputs=['variable', 'accumulation', 'learning_rate', 'gradient', 'momentum'], | self.init_prim_io_names(inputs=['variable', 'accumulation', 'learning_rate', 'gradient', 'momentum'], | ||||
| outputs=['output']) | outputs=['output']) | ||||
| self.is_tbe = context.get_context("device_target") == "Ascend" | |||||
| def infer_shape(self, v_shape, a_shape, l_shape, g_shape, m_shape): | def infer_shape(self, v_shape, a_shape, l_shape, g_shape, m_shape): | ||||
| if self.is_tbe: | |||||
| return v_shape, v_shape | |||||
| return v_shape | return v_shape | ||||
| def infer_dtype(self, v_dtype, a_dtype, l_dtype, g_dtype, m_dtype): | def infer_dtype(self, v_dtype, a_dtype, l_dtype, g_dtype, m_dtype): | ||||
| @@ -1440,6 +1442,8 @@ class ApplyMomentum(PrimitiveWithInfer): | |||||
| validator.check_scalar_or_tensor_type_same({"l_dtype": l_dtype}, valid_types, self.name) | validator.check_scalar_or_tensor_type_same({"l_dtype": l_dtype}, valid_types, self.name) | ||||
| validator.check_scalar_or_tensor_type_same({"g_dtype": g_dtype}, valid_types, self.name) | validator.check_scalar_or_tensor_type_same({"g_dtype": g_dtype}, valid_types, self.name) | ||||
| validator.check_scalar_or_tensor_type_same({"m_dtype": m_dtype}, valid_types, self.name) | validator.check_scalar_or_tensor_type_same({"m_dtype": m_dtype}, valid_types, self.name) | ||||
| if self.is_tbe: | |||||
| return g_dtype, g_dtype | |||||
| return g_dtype | return g_dtype | ||||
| @@ -2578,13 +2582,13 @@ class SparseApplyAdagrad(PrimitiveWithInfer): | |||||
| validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name) | validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name) | ||||
| validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name) | validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name) | ||||
| validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name) | validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name) | ||||
| return var_shape | |||||
| return var_shape, accum_shape | |||||
| def infer_dtype(self, var_type, accum_type, grad_type, indices_type): | def infer_dtype(self, var_type, accum_type, grad_type, indices_type): | ||||
| args = {'var': var_type, 'accum': accum_type, 'grad': grad_type} | args = {'var': var_type, 'accum': accum_type, 'grad': grad_type} | ||||
| validator.check_tensor_type_same(args, (mstype.float32,), self.name) | validator.check_tensor_type_same(args, (mstype.float32,), self.name) | ||||
| validator.check_tensor_type_same({'indices': indices_type}, [mstype.int32], self.name) | validator.check_tensor_type_same({'indices': indices_type}, [mstype.int32], self.name) | ||||
| return var_type | |||||
| return var_type, accum_type | |||||
| class LARSUpdate(PrimitiveWithInfer): | class LARSUpdate(PrimitiveWithInfer): | ||||
| @@ -0,0 +1,114 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| from mindspore import Tensor | |||||
| from mindspore.ops import operations as P | |||||
| import mindspore.nn as nn | |||||
| from mindspore.common.api import ms_function | |||||
| import numpy as np | |||||
| import mindspore.context as context | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
| class Net(nn.Cell): | |||||
| def __init__(self): | |||||
| super(Net, self).__init__() | |||||
| self.expand_dims = P.ExpandDims() | |||||
| def construct(self, tensor, dim): | |||||
| return self.expand_dims(tensor, dim) | |||||
| def test_net_bool(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.bool) | |||||
| net = Net() | |||||
| output = net(Tensor(x), -1) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == np.expand_dims(x, -1))) | |||||
| def test_net_int8(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.int8) | |||||
| net = Net() | |||||
| output = net(Tensor(x), -1) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == np.expand_dims(x, -1))) | |||||
| def test_net_uint8(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.uint8) | |||||
| net = Net() | |||||
| output = net(Tensor(x), -1) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == np.expand_dims(x, -1))) | |||||
| def test_net_int16(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.int16) | |||||
| net = Net() | |||||
| output = net(Tensor(x), -1) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == np.expand_dims(x, -1))) | |||||
| def test_net_uint16(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.uint16) | |||||
| net = Net() | |||||
| output = net(Tensor(x), -1) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == np.expand_dims(x, -1))) | |||||
| def test_net_int32(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.int32) | |||||
| net = Net() | |||||
| output = net(Tensor(x), -1) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == np.expand_dims(x, -1))) | |||||
| def test_net_uint32(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.uint32) | |||||
| net = Net() | |||||
| output = net(Tensor(x), -1) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == np.expand_dims(x, -1))) | |||||
| def test_net_int64(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.int64) | |||||
| net = Net() | |||||
| output = net(Tensor(x), -1) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == np.expand_dims(x, -1))) | |||||
| def test_net_uint64(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.uint64) | |||||
| net = Net() | |||||
| output = net(Tensor(x), -1) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == np.expand_dims(x, -1))) | |||||
| def test_net_float16(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.float16) | |||||
| net = Net() | |||||
| output = net(Tensor(x), -1) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == np.expand_dims(x, -1))) | |||||
| def test_net_float32(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.float32) | |||||
| net = Net() | |||||
| output = net(Tensor(x), -1) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == np.expand_dims(x, -1))) | |||||
| def test_net_float64(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.float64) | |||||
| net = Net() | |||||
| output = net(Tensor(x), -1) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == np.expand_dims(x, -1))) | |||||
| @@ -0,0 +1,99 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| from mindspore import Tensor | |||||
| from mindspore.ops import operations as P | |||||
| import mindspore.nn as nn | |||||
| import numpy as np | |||||
| import mindspore.context as context | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
| class Net(nn.Cell): | |||||
| def __init__(self): | |||||
| super(Net, self).__init__() | |||||
| self.flatten = P.Flatten() | |||||
| def construct(self, tensor): | |||||
| return self.flatten(tensor) | |||||
| def test_net_int8(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.int8) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == x.flatten())) | |||||
| def test_net_uint8(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.uint8) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == x.flatten())) | |||||
| def test_net_int16(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.int16) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == x.flatten())) | |||||
| def test_net_uint16(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.uint16) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == x.flatten())) | |||||
| def test_net_int32(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.int32) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == x.flatten())) | |||||
| def test_net_uint32(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.uint32) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == x.flatten())) | |||||
| def test_net_int64(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.int64) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == x.flatten())) | |||||
| def test_net_uint64(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.uint64) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == x.flatten())) | |||||
| def test_net_float16(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.float16) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == x.flatten())) | |||||
| def test_net_float32(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.float32) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == x.flatten())) | |||||
| @@ -0,0 +1,114 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| from mindspore import Tensor | |||||
| from mindspore.ops import operations as P | |||||
| import mindspore.nn as nn | |||||
| from mindspore.common.api import ms_function | |||||
| import numpy as np | |||||
| import mindspore.context as context | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
| class Net(nn.Cell): | |||||
| def __init__(self): | |||||
| super(Net, self).__init__() | |||||
| self.isfinite = P.IsFinite() | |||||
| def construct(self, tensor): | |||||
| return self.isfinite(tensor) | |||||
| def test_net_bool(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.bool) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == np.isfinite(x))) | |||||
| def test_net_int8(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.int8) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == np.isfinite(x))) | |||||
| def test_net_uint8(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.uint8) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == np.isfinite(x))) | |||||
| def test_net_int16(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.int16) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == np.isfinite(x))) | |||||
| def test_net_uint16(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.uint16) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == np.isfinite(x))) | |||||
| def test_net_int32(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.int32) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == np.isfinite(x))) | |||||
| def test_net_uint32(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.uint32) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == np.isfinite(x))) | |||||
| def test_net_int64(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.int64) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == np.isfinite(x))) | |||||
| def test_net_uint64(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.uint64) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == np.isfinite(x))) | |||||
| def test_net_float16(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.float16) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == np.isfinite(x))) | |||||
| def test_net_float32(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.float32) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == np.isfinite(x))) | |||||
| def test_net_float64(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.float64) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == np.isfinite(x))) | |||||
| @@ -0,0 +1,114 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| from mindspore import Tensor | |||||
| from mindspore.ops import operations as P | |||||
| import mindspore.nn as nn | |||||
| from mindspore.common.api import ms_function | |||||
| import numpy as np | |||||
| import mindspore.context as context | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
| class Net(nn.Cell): | |||||
| def __init__(self): | |||||
| super(Net, self).__init__() | |||||
| self.reshape = P.Reshape() | |||||
| def construct(self, tensor): | |||||
| return self.reshape(tensor, (4,4)) | |||||
| def test_net_bool(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.bool) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == np.reshape(x, (4,4)))) | |||||
| def test_net_int8(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.int8) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == np.reshape(x, (4,4)))) | |||||
| def test_net_uint8(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.uint8) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == np.reshape(x, (4,4)))) | |||||
| def test_net_int16(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.int16) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == np.reshape(x, (4,4)))) | |||||
| def test_net_uint16(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.uint16) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == np.reshape(x, (4,4)))) | |||||
| def test_net_int32(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.int32) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == np.reshape(x, (4,4)))) | |||||
| def test_net_uint32(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.uint32) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == np.reshape(x, (4,4)))) | |||||
| def test_net_int64(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.int64) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == np.reshape(x, (4,4)))) | |||||
| def test_net_uint64(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.uint64) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == np.reshape(x, (4,4)))) | |||||
| def test_net_float16(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.float16) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == np.reshape(x, (4,4)))) | |||||
| def test_net_float32(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.float32) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == np.reshape(x, (4,4)))) | |||||
| def test_net_float64(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.float64) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == np.reshape(x, (4,4)))) | |||||
| @@ -0,0 +1,113 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| from mindspore import Tensor | |||||
| from mindspore.ops import operations as P | |||||
| import mindspore.nn as nn | |||||
| import numpy as np | |||||
| import mindspore.context as context | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
| class Net(nn.Cell): | |||||
| def __init__(self): | |||||
| super(Net, self).__init__() | |||||
| self.squeeze = P.Squeeze() | |||||
| def construct(self, tensor): | |||||
| return self.squeeze(tensor) | |||||
| def test_net_bool(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.bool) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == x.squeeze())) | |||||
| def test_net_int8(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.int8) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == x.squeeze())) | |||||
| def test_net_uint8(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.uint8) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == x.squeeze())) | |||||
| def test_net_int16(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.int16) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == x.squeeze())) | |||||
| def test_net_uint16(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.uint16) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == x.squeeze())) | |||||
| def test_net_int32(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.int32) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == x.squeeze())) | |||||
| def test_net_uint32(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.uint32) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == x.squeeze())) | |||||
| def test_net_int64(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.int64) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == x.squeeze())) | |||||
| def test_net_uint64(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.uint64) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == x.squeeze())) | |||||
| def test_net_float16(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.float16) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == x.squeeze())) | |||||
| def test_net_float32(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.float32) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == x.squeeze())) | |||||
| def test_net_float64(): | |||||
| x = np.random.randn(1, 16, 1, 1).astype(np.float64) | |||||
| net = Net() | |||||
| output = net(Tensor(x)) | |||||
| print(output.asnumpy()) | |||||
| assert(np.all(output.asnumpy() == x.squeeze())) | |||||
| @@ -1,59 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/backend_common_test.h" | |||||
| #include "common/py_func_graph_fetcher.h" | |||||
| #include "operator/ops.h" | |||||
| #include "ir/meta_tensor.h" | |||||
| #include "debug/anf_ir_dump.h" | |||||
| #include "utils/utils.h" | |||||
| #include "pre_activate/common/optimizer.h" | |||||
| #include "pre_activate/ascend/ir_fission/batch_norm_grad_split.h" | |||||
| #include "session/anf_runtime_algorithm.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| class TestHWBatchNormGradSplit : public BackendCommon { | |||||
| public: | |||||
| TestHWBatchNormGradSplit() : get_py_fun_("gtest_input.pre_activate.batch_norm_grad_split", true) {} | |||||
| public: | |||||
| UT::PyFuncGraphFetcher get_py_fun_; | |||||
| }; | |||||
| TEST_F(TestHWBatchNormGradSplit, test_split) { | |||||
| get_py_fun_.SetDoResolve(true); | |||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_batch_norm_grad_split", "before"); | |||||
| EXPECT_NE(g, nullptr); | |||||
| std::vector<int> shp_x{1, 64, 112, 112}; | |||||
| std::vector<int> shp_b{64}; | |||||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x); | |||||
| auto b_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_b); | |||||
| AbstractBasePtrList args_spec_list{x_abstract, x_abstract, b_abstract, b_abstract, b_abstract, b_abstract}; | |||||
| auto kernel_graph = GetKernelGraph(g, args_spec_list); | |||||
| EXPECT_NE(kernel_graph, nullptr); | |||||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||||
| auto pm = std::make_shared<opt::PassManager>(); | |||||
| auto pass = std::make_shared<opt::BatchNormGradSplit>(); | |||||
| pm->AddPass(pass); | |||||
| optimizer->AddPassManager(pm); | |||||
| auto new_graph = optimizer->Optimize(kernel_graph); | |||||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_batch_norm_grad_split", "after"); | |||||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -189,7 +189,8 @@ TEST_F(TestConvert, TestConvertBatchNorm) { | |||||
| TEST_F(TestConvert, TestConvertConvBackpropInput) { | TEST_F(TestConvert, TestConvertConvBackpropInput) { | ||||
| auto prim = prim::kPrimConv2DBackpropInput; | auto prim = prim::kPrimConv2DBackpropInput; | ||||
| prim->AddAttr("stride", MakeValue(1)); | |||||
| const std::vector<int> list{1,1}; | |||||
| prim->AddAttr("stride", MakeValue(list)); | |||||
| prim->AddAttr("pad", MakeValue(0)); | prim->AddAttr("pad", MakeValue(0)); | ||||
| prim->AddAttr("pad_mode", MakeValue(std::string("pad"))); | prim->AddAttr("pad_mode", MakeValue(std::string("pad"))); | ||||
| prim->AddAttr("dilation", MakeValue(1)); | prim->AddAttr("dilation", MakeValue(1)); | ||||
| @@ -218,7 +219,8 @@ TEST_F(TestConvert, TestConvertConvBackpropInput) { | |||||
| TEST_F(TestConvert, TestConvertConvBackpropFilter) { | TEST_F(TestConvert, TestConvertConvBackpropFilter) { | ||||
| auto prim = prim::kPrimConv2DBackpropFilter; | auto prim = prim::kPrimConv2DBackpropFilter; | ||||
| prim->AddAttr("stride", MakeValue(1)); | |||||
| const std::vector<int> list{1,1}; | |||||
| prim->AddAttr("stride", MakeValue(list)); | |||||
| prim->AddAttr("pad", MakeValue(0)); | prim->AddAttr("pad", MakeValue(0)); | ||||
| prim->AddAttr("pad_mode", MakeValue(std::string("pad"))); | prim->AddAttr("pad_mode", MakeValue(std::string("pad"))); | ||||
| prim->AddAttr("dilation", MakeValue(1)); | prim->AddAttr("dilation", MakeValue(1)); | ||||
| @@ -38,7 +38,7 @@ def tensor_run_opt(opt, iters, learning_rate, momentum, | |||||
| gradient, variable, moment): | gradient, variable, moment): | ||||
| """ tensor_run_opt """ | """ tensor_run_opt """ | ||||
| success = True | success = True | ||||
| new_weight = opt(variable, moment, learning_rate, gradient, momentum) | |||||
| new_weight = opt(variable, moment, learning_rate, gradient, momentum)[0] | |||||
| success = F.depend(success, F.assign(variable, new_weight)) | success = F.depend(success, F.assign(variable, new_weight)) | ||||
| return success | return success | ||||
| @@ -670,7 +670,7 @@ test_case_nn_ops = [ | |||||
| 'skip': []}), | 'skip': []}), | ||||
| ('BatchNormGrad', { | ('BatchNormGrad', { | ||||
| 'block': G.BatchNormGrad(), | 'block': G.BatchNormGrad(), | ||||
| 'desc_inputs': [[128, 64, 32, 32], [128, 64, 32, 32], [64], [64], [64], [64]], | |||||
| 'desc_inputs': [[128, 64, 32, 32], [128, 64, 32, 32], [64], [64], [64]], | |||||
| 'desc_bprop': [[128, 64, 32, 32], [64], [64], [64], [64]], | 'desc_bprop': [[128, 64, 32, 32], [64], [64], [64], [64]], | ||||
| 'skip': ['backward']}), | 'skip': ['backward']}), | ||||
| ('TopK', { | ('TopK', { | ||||
| @@ -807,7 +807,7 @@ test_case_nn_ops = [ | |||||
| ('SparseApplyAdagrad', { | ('SparseApplyAdagrad', { | ||||
| 'block': P.SparseApplyAdagrad(0.5), | 'block': P.SparseApplyAdagrad(0.5), | ||||
| 'desc_inputs': [[3, 3], [3, 3], [3, 3], Tensor(np.ones((3,), np.int32))], | 'desc_inputs': [[3, 3], [3, 3], [3, 3], Tensor(np.ones((3,), np.int32))], | ||||
| 'desc_bprop': [3, 3], | |||||
| 'desc_bprop': [[3, 3], [3, 3]], | |||||
| 'skip': ['backward']}), | 'skip': ['backward']}), | ||||
| ('Flatten_1', { | ('Flatten_1', { | ||||
| 'block': NetForFlatten(), | 'block': NetForFlatten(), | ||||