| @@ -0,0 +1,206 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/host/dynamic_broadcast_gradient_args_kernel.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| namespace { | |||
| const int kInputNum = 2; | |||
| std::vector<std::vector<int64_t>> GetGradientIndices(const std::vector<std::vector<int64_t>> &reverse_shape, | |||
| const size_t largest_rank) { | |||
| std::vector<std::vector<int64_t>> grad_reduce_idx(kInputNum); | |||
| // indices of j-th component of each input. | |||
| bool prev_is_one[kInputNum]; | |||
| bool current_is_one[kInputNum]; | |||
| for (int i = 0; i < kInputNum; ++i) { | |||
| prev_is_one[i] = false; | |||
| current_is_one[i] = false; | |||
| } | |||
| bool set_one = false; | |||
| for (size_t j = 0; j < largest_rank; ++j) { | |||
| int output_dim = -1; | |||
| bool output_dim_set = false; | |||
| bool none_is_one = true; | |||
| // Find which indices are 1. | |||
| for (int i = 0; i < kInputNum; ++i) { | |||
| if (reverse_shape[i][j] == 1) { | |||
| current_is_one[i] = true; | |||
| none_is_one = false; | |||
| } else { | |||
| current_is_one[i] = false; | |||
| if (!output_dim_set || reverse_shape[i][j] == static_cast<int64_t>(output_dim)) { | |||
| output_dim = reverse_shape[i][j]; | |||
| output_dim_set = true; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Input[0] and input[1] Cannot broadcast!"; | |||
| } | |||
| } | |||
| } | |||
| // All dimensions are 1. | |||
| if (!output_dim_set) { | |||
| for (int i = 0; i < kInputNum; ++i) { | |||
| grad_reduce_idx[i].push_back(largest_rank - 1 - j); | |||
| } | |||
| continue; | |||
| } else if (std::equal(current_is_one, current_is_one + kInputNum, prev_is_one) && set_one) { | |||
| for (int i = 0; i < kInputNum; ++i) { | |||
| if (current_is_one[i] && !none_is_one) { | |||
| grad_reduce_idx[i].push_back(largest_rank - 1 - j); | |||
| } | |||
| } | |||
| } else { | |||
| for (int i = 0; i < kInputNum; ++i) { | |||
| if (current_is_one[i] && !none_is_one) { | |||
| grad_reduce_idx[i].push_back(largest_rank - 1 - j); | |||
| } | |||
| } | |||
| } | |||
| set_one = true; | |||
| for (int i = 0; i < kInputNum; ++i) { | |||
| prev_is_one[i] = current_is_one[i]; | |||
| } | |||
| } | |||
| return grad_reduce_idx; | |||
| } | |||
| std::vector<std::vector<int64_t>> CalculateOutput(const std::vector<std::vector<int64_t>> &x) { | |||
| std::vector<std::vector<int64_t>> grad_reduce_idx(kInputNum); | |||
| bool all_equal = true; | |||
| size_t largest_rank = 0; | |||
| for (int i = 0; i < kInputNum; ++i) { | |||
| if (x[i] != x[0]) { | |||
| all_equal = false; | |||
| } | |||
| if (x[i].size() > largest_rank) { | |||
| largest_rank = x[i].size(); | |||
| } | |||
| } | |||
| if (all_equal) { | |||
| return grad_reduce_idx; | |||
| } | |||
| // Reverse input the shapes | |||
| std::vector<std::vector<int64_t>> reverse_shape(kInputNum); | |||
| for (int i = 0; i < kInputNum; ++i) { | |||
| reverse_shape[i] = x[i]; | |||
| std::reverse(reverse_shape[i].begin(), reverse_shape[i].end()); | |||
| } | |||
| // 1-extend and align all vectors. | |||
| for (int i = 0; i < kInputNum; ++i) { | |||
| if (reverse_shape[i].size() < largest_rank) { | |||
| reverse_shape[i].resize(largest_rank, 1); | |||
| } | |||
| } | |||
| grad_reduce_idx = GetGradientIndices(reverse_shape, largest_rank); | |||
| return grad_reduce_idx; | |||
| } | |||
| std::vector<int64_t> GetInputShape(const CNodePtr &cnode, size_t index) { | |||
| auto address_x = AnfAlgo::GetPrevNodeMutableOutputAddr(cnode, index); | |||
| auto shape_x = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index); | |||
| auto type_x = AnfAlgo::GetOutputInferDataType(cnode, index); | |||
| if (type_x != TypeId::kNumberTypeInt64) { | |||
| MS_LOG(EXCEPTION) << "Input x type must be int64, but :" << type_x; | |||
| } | |||
| if (shape_x.size() != 1) { | |||
| MS_LOG(EXCEPTION) << "Input" << index << " must be [1-D], but " << shape_x.size() << "-D."; | |||
| } | |||
| size_t x_num = shape_x[0]; | |||
| std::vector<int64_t> x{SizeToLong(x_num)}; | |||
| auto x_shape_value = std::make_shared<tensor::Tensor>(type_x, x); | |||
| x_shape_value->set_device_address(address_x); | |||
| x_shape_value->data_sync(); | |||
| auto x_value = reinterpret_cast<int64_t *>(x_shape_value->data_c()); | |||
| MS_EXCEPTION_IF_NULL(x_value); | |||
| std::vector<int64_t> input_shape = {x_value, x_value + x_num}; | |||
| return input_shape; | |||
| } | |||
| size_t SetOutputValue(const CNodePtr &cnode, const std::vector<std::vector<int64_t>> &grad_reduce_idx, size_t index, | |||
| size_t input_num) { | |||
| std::vector<int64_t> output; | |||
| size_t idx_num = grad_reduce_idx[index].size(); | |||
| for (size_t k = 0; k < idx_num; ++k) { | |||
| output.push_back(grad_reduce_idx[index][idx_num - 1 - k]); | |||
| } | |||
| auto out_addr = AnfAlgo::GetOutputAddr(cnode, index); | |||
| MS_EXCEPTION_IF_NULL(out_addr); | |||
| size_t out_size = idx_num; | |||
| if (idx_num == 0) { | |||
| out_size = input_num; | |||
| for (size_t k = 0; k < input_num; ++k) { | |||
| output.push_back(k); | |||
| } | |||
| } | |||
| std::vector<int64_t> out_shape{SizeToLong(out_size)}; | |||
| auto output_type = TypeId::kNumberTypeInt64; | |||
| auto tensor_for_sync = std::make_shared<tensor::Tensor>(output_type, out_shape); | |||
| auto data_ptr = static_cast<int64_t *>(tensor_for_sync->data_c()); | |||
| for (size_t i = 0; i < out_size; ++i) { | |||
| MS_LOG(DEBUG) << "DEBUG r" << index << "_output_shape[" << i << "]:" << output[i]; | |||
| *(data_ptr + i) = output[i]; | |||
| } | |||
| out_addr->SyncHostToDevice(out_shape, LongToSize(tensor_for_sync->data().nbytes()), tensor_for_sync->data_type(), | |||
| tensor_for_sync->data_c(), tensor_for_sync->device_info().host_format_); | |||
| return out_size; | |||
| } | |||
| } // namespace | |||
| void DynamicBroadcastGradientArgsKernel::Execute() { | |||
| MS_LOG(INFO) << "Execute DynamicBroadcastGradientArgsKernel Start"; | |||
| auto cnode = cnode_ptr_.lock(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto input_num = AnfAlgo::GetInputTensorNum(cnode); | |||
| if (input_num != 2) { | |||
| MS_LOG(EXCEPTION) << "Invalid Input Num:" << input_num; | |||
| } | |||
| std::vector<std::vector<int64_t>> input_shapes(kInputNum); | |||
| input_shapes[0] = GetInputShape(cnode, 0); | |||
| input_shapes[1] = GetInputShape(cnode, 1); | |||
| auto grad_reduce_idx = CalculateOutput(input_shapes); | |||
| auto r0_size = SetOutputValue(cnode, grad_reduce_idx, 0, input_shapes[0].size()); | |||
| auto r1_size = SetOutputValue(cnode, grad_reduce_idx, 1, input_shapes[1].size()); | |||
| std::vector<size_t> r0_shp{r0_size}; | |||
| std::vector<size_t> r1_shp{r1_size}; | |||
| auto output_type = TypeId::kNumberTypeInt64; | |||
| AnfAlgo::SetOutputInferTypeAndShape({output_type, output_type}, {r0_shp, r1_shp}, cnode.get()); | |||
| MS_LOG(INFO) << "Execute DynamicBroadcastGradientArgsKernel End"; | |||
| } | |||
| device::DynamicKernelPtr DynamicBroadcastGradientArgsKernelMod::GenDynamicKernel(const CNodePtr &cnode_ptr, | |||
| void *stream_ptr) { | |||
| return std::make_shared<DynamicBroadcastGradientArgsKernel>(stream_ptr, cnode_ptr); | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,44 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_DYNAMIC_BROADCAST_GRADIENT_ARGS_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_DYNAMIC_BROADCAST_GRADIENT_ARGS_KERNEL_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <string> | |||
| #include "runtime/device/ascend/executor/host_dynamic_kernel.h" | |||
| #include "backend/kernel_compiler/host/host_kernel_mod.h" | |||
| using HostDynamicKernel = mindspore::device::ascend::HostDynamicKernel; | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| class DynamicBroadcastGradientArgsKernel : public HostDynamicKernel { | |||
| public: | |||
| DynamicBroadcastGradientArgsKernel(void *stream, const CNodePtr &cnode_ptr) : HostDynamicKernel(stream, cnode_ptr) {} | |||
| ~DynamicBroadcastGradientArgsKernel() override = default; | |||
| void Execute() override; | |||
| }; | |||
| class DynamicBroadcastGradientArgsKernelMod : public HostKernelMod { | |||
| public: | |||
| DynamicBroadcastGradientArgsKernelMod() = default; | |||
| ~DynamicBroadcastGradientArgsKernelMod() override = default; | |||
| device::DynamicKernelPtr GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) override; | |||
| }; | |||
| MS_HOST_REG_KERNEL(DynamicBroadcastGradientArgs, DynamicBroadcastGradientArgsKernelMod); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_DYNAMIC_BROADCAST_GRADIENT_ARGS_KERNEL_H_ | |||
| @@ -29,11 +29,6 @@ void HostMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<K | |||
| MS_LOG(INFO) << "HostMetadataInfo."; | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| MS_EXCEPTION_IF_NULL(kernel_info_list); | |||
| std::string op_name = AnfAlgo::GetCNodeName(kernel_node); | |||
| if (op_name != kDynamicShape) { | |||
| MS_LOG(DEBUG) << "Host does not have op [" << op_name << "]"; | |||
| return; | |||
| } | |||
| std::vector<std::string> inputs_format{}; | |||
| std::vector<TypeId> inputs_type{}; | |||
| @@ -74,6 +74,7 @@ constexpr auto kFastGeLU = "FastGeLU"; | |||
| constexpr auto kFastGeLUGrad = "FastGeLUGrad"; | |||
| constexpr auto kZerosLike = "ZerosLike"; | |||
| constexpr auto kOnesLike = "OnesLike"; | |||
| constexpr auto kDynamicBroadcastGradientArgs = "DynamicBroadcastGradientArgs"; | |||
| // NN | |||
| constexpr auto kCTCLoss = "CTCLoss"; | |||
| @@ -632,6 +633,8 @@ inline const PrimitivePtr kPrimStringConcat = std::make_shared<Primitive>("strin | |||
| inline const PrimitivePtr kPrimDictLen = std::make_shared<Primitive>("dict_len"); | |||
| inline const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop"); | |||
| inline const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared<Primitive>("BroadcastGradientArgs"); | |||
| inline const PrimitivePtr kPrimDynamicBroadcastGradientArgs = | |||
| std::make_shared<Primitive>(kDynamicBroadcastGradientArgs); | |||
| class DoSignaturePrimitive : public Primitive { | |||
| public: | |||
| @@ -0,0 +1,80 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ops/dynamic_broadcast_gradient_args.h" | |||
| #include <set> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <map> | |||
| #include <memory> | |||
| #include "ops/op_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| namespace { | |||
| size_t CheckInputsAndGetShape(const AbstractBasePtr &input_arg, const string &prim_name) { | |||
| MS_EXCEPTION_IF_NULL(input_arg); | |||
| if (input_arg->isa<abstract::AbstractTensor>()) { | |||
| auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_arg->BuildShape())[kShape]; | |||
| auto input_size = input_shape.size(); | |||
| if (input_size != 1) { | |||
| MS_EXCEPTION(TypeError) << prim_name << " input must be 1-D, but dims is " << input_size; | |||
| } | |||
| if (input_shape[0] == abstract::Shape::SHP_ANY) { | |||
| auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_arg->BuildShape())[kMaxShape]; | |||
| if (max_shape.empty()) { | |||
| MS_LOG(EXCEPTION) << prim_name << " input shape is dynamic, but max shape is empty."; | |||
| } | |||
| return max_shape[0]; | |||
| } | |||
| return input_shape[0]; | |||
| } else if (input_arg->isa<abstract::AbstractTuple>()) { | |||
| auto x_shape = dyn_cast<abstract::AbstractTuple>(input_arg); | |||
| auto x_shape_data = x_shape->elements(); | |||
| return x_shape_data.size(); | |||
| } else { | |||
| MS_EXCEPTION(TypeError) << prim_name << " input must be a tuple or Tensor."; | |||
| } | |||
| } | |||
| abstract::TupleShapePtr Infer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto prim_name = primitive->name(); | |||
| (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 2, prim_name); | |||
| auto x_shape = CheckInputsAndGetShape(input_args[0], prim_name); | |||
| auto y_shape = CheckInputsAndGetShape(input_args[1], prim_name); | |||
| ShapeVector shape{abstract::Shape::SHP_ANY}; | |||
| ShapeVector min_shape{1L}; | |||
| size_t max_size = x_shape > y_shape ? x_shape : y_shape; | |||
| ShapeVector max_shape{SizeToLong(max_size)}; | |||
| auto out_shape = std::make_shared<abstract::Shape>(shape, min_shape, max_shape); | |||
| return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{out_shape, out_shape}); | |||
| } | |||
| } // namespace | |||
| AbstractBasePtr DynamicBroadcastGradientArgsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| auto types = std::vector<TypePtr>{kInt64, kInt64}; | |||
| auto output_type = std::make_shared<Tuple>(types); | |||
| return abstract::MakeAbstract(Infer(primitive, input_args), output_type); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(DynamicBroadcastGradientArgs, prim::kPrimDynamicBroadcastGradientArgs, | |||
| DynamicBroadcastGradientArgsInfer, nullptr, true); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,39 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_OPS_DYNAMIC_BROADCAST_GRADIENT_ARGS_H_ | |||
| #define MINDSPORE_CORE_OPS_DYNAMIC_BROADCAST_GRADIENT_ARGS_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "ops/primitive_c.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| class DynamicBroadcastGradientArgs : public PrimitiveC { | |||
| public: | |||
| DynamicBroadcastGradientArgs() : PrimitiveC(prim::kPrimDynamicBroadcastGradientArgs->name()) {} | |||
| ~DynamicBroadcastGradientArgs() = default; | |||
| MS_DECLARE_PARENT(DynamicBroadcastGradientArgs, PrimitiveC); | |||
| void Init() {} | |||
| }; | |||
| AbstractBasePtr DynamicBroadcastGradientArgsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimDynamicBroadcastGradientArgsPtr = std::shared_ptr<DynamicBroadcastGradientArgs>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_OPS_DYNAMIC_BROADCAST_GRADIENT_ARGS_H_ | |||
| @@ -1,40 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """DynamicShape op""" | |||
| from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType | |||
| dynamic_shape_op_info = AiCPURegOp("DynamicShape") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .input(0, "x", "required") \ | |||
| .output(0, "y", "required") \ | |||
| .dtype_format(DataType.I8_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.I16_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.I64_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.U16_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.U32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.U64_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.F16_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.F64_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.BOOL_Default, DataType.I32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(dynamic_shape_op_info) | |||
| def _dynamic_shape_aicpu(): | |||
| """Unique AiCPU register""" | |||
| return | |||
| @@ -21,7 +21,7 @@ from ..._checkparam import Rel | |||
| from ..._checkparam import Validator as validator | |||
| from ... import context | |||
| from ...common import dtype as mstype | |||
| from ..primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register | |||
| from ..primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register, Primitive | |||
| from ..operations.math_ops import _infer_shape_reduce | |||
| from ...communication.management import GlobalComm | |||
| from .. import signature as sig | |||
| @@ -1111,3 +1111,45 @@ class DynamicStitch(PrimitiveWithCheck): | |||
| mstype.number_type + (mstype.bool_,), self.name) | |||
| validator.check(f"type of data[{i}]", data_type[i], f"type of data[0]", data_type[0], Rel.EQ, self.name) | |||
| return data_type[0] | |||
| class DynamicBroadcastGradientArgs(Primitive): | |||
| """ | |||
| Broadcast the two input shapes, return the dimensions that each need to be broadcast. | |||
| Input shape `s0` and shape `s1` can be broadcast to a common shape if for each dimension pair they are either equal | |||
| or input is one or the target dimension is -1. In case of -1 in target shape, it will be replaced by the input | |||
| shape's value in that dimension. | |||
| Inputs: | |||
| - **s0** (Tensor) - A `1-D` tensor. The data type should be one of the following types: int32, int64, | |||
| uint32, uint64. | |||
| - **s1** (Tensor) - A `1-D` tensor with the same type as `s0`. | |||
| Outputs: | |||
| Tuple(Tensor), tuple of 2 tensors, r0 and r1. The first one is the index tensor and the other one is the mask | |||
| tensor. | |||
| - **r0** (Tensor) - The output shape is 1-D with the same type as s0. | |||
| - **r1** (Tensor) - The output shape is 1-D with the same type as s0. | |||
| Raises: | |||
| ValueError: if the `s0` and `s1` are incompatible, or if a - 1 in the target shape is in an invalid | |||
| location. | |||
| Supported Platforms: | |||
| ``Ascend`` | |||
| Examples: | |||
| >>> shape0 = (4, 2, 1) | |||
| >>> shape1 = (2, 7) | |||
| >>> from mindspore.ops.operations import _inner_ops | |||
| >>> args = _inner_ops.DynamicBroadcastGradientArgs() | |||
| >>> r0, r1 = args(Tensor(shape0), Tensor(shape1)) | |||
| >>> print(r0, r1) | |||
| [2], [0] | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """Init BroadcastGradientArgs""" | |||
| @@ -0,0 +1,44 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| import mindspore.context as context | |||
| from mindspore.ops.operations import _inner_ops | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.args = _inner_ops.BroadcastGradientArgs() | |||
| def construct(self, s0, s1): | |||
| return self.args(s0, s1) | |||
| def test_net(): | |||
| shape0 = (4, 2, 1) | |||
| shape1 = (2, 7) | |||
| net = Net() | |||
| r0, r1 = net(shape0, shape1) | |||
| print(r0, r1) | |||
| r0_expected = [2] | |||
| r1_expected = [0] | |||
| assert np.array_equal(r0_expected, r0.asnumpy()) | |||
| assert np.array_equal(r1_expected, r1.asnumpy()) | |||