@@ -1,5 +1,5 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -37,7 +37,7 @@ class AvgPoolingGradCPUKernel : public MKLCPUKernel {
  std::vector<size_t> kernel_size_;
};

MS_REG_CPU_KERNEL(AvgPoolGradCpu,
MS_REG_CPU_KERNEL(AvgPoolGrad,
                  KernelAttr()
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)

@@ -1,5 +1,5 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -32,14 +32,14 @@ MS_REG_GPU_KERNEL_ONE(MaxPoolGrad,
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddOutputAttr(kNumberTypeFloat16),
                      PoolingGradGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(AvgPoolGradGpu,
MS_REG_GPU_KERNEL_ONE(AvgPoolGrad,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32),
                      PoolingGradGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(AvgPoolGradGpu,
MS_REG_GPU_KERNEL_ONE(AvgPoolGrad,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeFloat16)

@@ -254,7 +254,7 @@ class PoolingGradGpuKernel : public GpuKernel {
  }
  void SetPoolingMode(const CNodePtr &kernel_node) {
    mode_ = AnfAlgo::GetCNodeName(kernel_node);
    if (mode_ == "AvgPoolGradGpu") {
    if (mode_ == "AvgPoolGrad") {
      pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
      pad_value_ = 0.0;
    } else {

@@ -0,0 +1,200 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/optimizer/ascend/mindir/avg_pool_grad_unify_mindir.h"

#include <vector>
#include <memory>
#include <algorithm>
#include <functional>
#include <string>

#include "utils/utils.h"
#include "utils/ms_context.h"
#include "utils/check_convert_utils.h"
#include "backend/optimizer/common/helper.h"
#include "runtime/device/kernel_info.h"
#include "backend/session/anf_runtime_algorithm.h"

namespace mindspore {
namespace opt {
namespace {
constexpr size_t kAvgPoolGradInputNum = 3;
constexpr size_t kShapeDimNum = 4;
constexpr float kKernelMatrixInitNum = 1.0;
constexpr size_t kFloat32Len = 4;  // size of float32

std::vector<int64_t> GetInputXShape(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  std::vector<int64_t> shapes;
  auto shape_size_t = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
  std::transform(shape_size_t.begin(), shape_size_t.end(), std::back_inserter(shapes), SizeToLong);
  return shapes;
}

int64_t windowed_output_size(int64_t input_size, int64_t ksize, int64_t stride, PadMode pad_mode, int64_t *pad_before,
                             int64_t *pad_after) {
  int64_t output = 0;
  *pad_before = 0;
  *pad_after = 0;
  if (pad_mode == PadMode::VALID) {
    output = (input_size - ksize + stride) / stride;
  } else if (pad_mode == PadMode::SAME) {
    output = (input_size + stride - 1) / stride;
    int64_t pad_need = std::max(int64_t(0), (output - 1) * stride + ksize - input_size);
    *pad_before = pad_need / 2;
    *pad_after = pad_need - *pad_before;
  } else {
    MS_LOG(EXCEPTION) << "The pad mode of AvgPoolGrad should be SAME or VALID.";
  }
  return output;
}

ValueNodePtr CreateMeanMatrixValueNode(const FuncGraphPtr &func_graph, const std::vector<int64_t> &x_shape,
                                       const std::vector<int64_t> &k_size, const std::vector<int64_t> &stride,
                                       const PadMode pad_mode, const TypeId x_dtype) {
  MS_EXCEPTION_IF_NULL(func_graph);
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  MS_EXCEPTION_IF_NULL(kernel_graph);
  if (x_shape.size() != kShapeDimNum || k_size.size() != kShapeDimNum || stride.size() != kShapeDimNum) {
    MS_LOG(EXCEPTION) << "The dim of x_shape or kernel_size or strides of AvgPoolGrad should be 4.";
  }
  int64_t pad_top, pad_bottom, pad_left, pad_right;
  int64_t h_output = windowed_output_size(x_shape[2], k_size[2], stride[2], pad_mode, &pad_top, &pad_bottom);
  int64_t w_output = windowed_output_size(x_shape[3], k_size[3], stride[3], pad_mode, &pad_left, &pad_right);
  // `assist_input_matrix` is a 2D matrix with the input shape after padding:
  // padded elements are 0 and all other elements are 1.
  // Each output element maps to the sliding window `[h*h_stride : h*h_stride + h_ksize,
  // w*w_stride : w*w_stride + w_ksize]` of `assist_input_matrix`, so the sum over the sliding window
  // is the number of input elements associated with that output element.
  std::vector<std::vector<float>> assist_input_matrix;
  std::vector<int64_t> in_shape_after_padding_2d = {x_shape[2] + pad_top + pad_bottom,
                                                    x_shape[3] + pad_left + pad_right};
  std::vector<float> tmp_zero_vector(in_shape_after_padding_2d[1], 0.0);
  std::vector<float> tmp_one_vector(in_shape_after_padding_2d[1], 1.0);
  for (int64_t i = 0; i < in_shape_after_padding_2d[1]; ++i) {
    if (i < pad_left || i >= (in_shape_after_padding_2d[1] - pad_right)) {
      tmp_one_vector[i] = 0.0;
    }
  }
  for (int64_t i = 0; i < in_shape_after_padding_2d[0]; ++i) {
    if (i < pad_top || i >= (in_shape_after_padding_2d[0] - pad_bottom)) {
      assist_input_matrix.emplace_back(tmp_zero_vector);
    } else {
      assist_input_matrix.emplace_back(tmp_one_vector);
    }
  }

  // calculate output
  std::vector<float> hw_output(h_output * w_output, 0.0);
  for (int64_t h = 0; h < h_output; ++h) {
    for (int64_t w = 0; w < w_output; ++w) {
      float curr_sum = 0;
      for (int64_t i = h * stride[2]; i < h * stride[2] + k_size[2]; ++i) {
        for (int64_t j = w * stride[3]; j < w * stride[3] + k_size[3]; ++j) {
          curr_sum += assist_input_matrix[i][j];
        }
      }
      if (curr_sum > 0) {
        hw_output[h * w_output + w] = 1.0 / curr_sum;
      }
    }
  }

  // make output tensor
  std::vector<int64_t> output_shape = {x_shape[0], x_shape[1], h_output, w_output};
  auto output_size = std::accumulate(output_shape.begin(), output_shape.end(), int64_t(1), std::multiplies<int64_t>());
  std::vector<float> output(output_size, 0.0);
  for (int64_t i = 0; i < output_shape[0] * output_shape[1]; ++i) {
    size_t copy_size = hw_output.size() * kFloat32Len;
    (void)memcpy_s(&output[i * hw_output.size()], copy_size, &hw_output[0], copy_size);
  }
  auto output_tensor = std::make_shared<tensor::Tensor>(x_dtype, output_shape, &output[0], kNumberTypeFloat32);
  MS_EXCEPTION_IF_NULL(output_tensor);
  auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(x_dtype), output_shape);
  MS_EXCEPTION_IF_NULL(abstract);
  auto mean_matrix_vnode = kernel_graph->NewValueNode(abstract, output_tensor);
  MS_EXCEPTION_IF_NULL(mean_matrix_vnode);
  kernel_graph->AddValueNodeToGraph(mean_matrix_vnode);
  return mean_matrix_vnode;
}

ValueNodePtr CreateKernelMatrixValueNode(const FuncGraphPtr &func_graph, const std::vector<int64_t> &x_shape,
                                         const std::vector<int64_t> &k_size, const TypeId x_dtype) {
  MS_EXCEPTION_IF_NULL(func_graph);
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  MS_EXCEPTION_IF_NULL(kernel_graph);
  if (x_shape.size() != kShapeDimNum || k_size.size() != kShapeDimNum) {
    MS_LOG(EXCEPTION) << "The dim of x_shape or kernel_size of AvgPoolGrad should be 4.";
  }
  std::vector<int64_t> kernel_shape = {1, x_shape[1], k_size[2], k_size[3]};
  auto data_size = std::accumulate(kernel_shape.begin(), kernel_shape.end(), int64_t(1), std::multiplies<int64_t>());
  std::vector<float> data(data_size, kKernelMatrixInitNum);
  auto kernel_matrix_tensor = std::make_shared<tensor::Tensor>(x_dtype, kernel_shape, &data[0], kNumberTypeFloat32);
  MS_EXCEPTION_IF_NULL(kernel_matrix_tensor);
  auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(x_dtype), kernel_shape);
  MS_EXCEPTION_IF_NULL(abstract);
  auto kernel_matrix_vnode = kernel_graph->NewValueNode(abstract, kernel_matrix_tensor);
  MS_EXCEPTION_IF_NULL(kernel_matrix_vnode);
  kernel_graph->AddValueNodeToGraph(kernel_matrix_vnode);
  return kernel_matrix_vnode;
}
}  // namespace

const BaseRef AvgPoolGradUnifyMindIR::DefinePattern() const {
  VarPtr X1 = std::make_shared<Var>();
  VarPtr X2 = std::make_shared<Var>();
  VarPtr G = std::make_shared<Var>();
  VectorRef pattern({prim::kPrimAvgPoolGrad, X1, X2, G});
  return pattern;
}

const AnfNodePtr AvgPoolGradUnifyMindIR::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                                 const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  auto avgpool_grad = CheckAnfNodeIfCNodeAndInputSize(node, kAvgPoolGradInputNum);
  auto x_shape = GetInputXShape(avgpool_grad);
  auto x_dtype = AnfAlgo::GetPrevNodeOutputInferDataType(avgpool_grad, 0);
  auto k_size = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(avgpool_grad, kAttrKernelSize);
  auto stride = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(avgpool_grad, kAttrStrides);
  auto pad_mode = PadMode(AnfAlgo::GetNodeAttr<int64_t>(avgpool_grad, kAttrPadMode));

  auto x_shape_vnode = CreateShapeValueNode(graph, x_shape);
  auto mean_matrix_vnode = CreateMeanMatrixValueNode(graph, x_shape, k_size, stride, pad_mode, x_dtype);
  auto kernel_matrix_vnode = CreateKernelMatrixValueNode(graph, x_shape, k_size, x_dtype);

  std::vector<AnfNodePtr> avgpool_grad_vm_inputs = {NewValueNode(std::make_shared<Primitive>(kAvgPoolGradVmOpName)),
                                                    x_shape_vnode, avgpool_grad->input(3), mean_matrix_vnode,
                                                    kernel_matrix_vnode};
  auto avgpool_grad_vm = graph->NewCNode(avgpool_grad_vm_inputs);
  MS_EXCEPTION_IF_NULL(avgpool_grad_vm);
  avgpool_grad_vm->set_scope(avgpool_grad->scope());
  avgpool_grad_vm->set_abstract(avgpool_grad->abstract());
  AnfAlgo::CopyNodeAttr(kAttrKernelSize, avgpool_grad, avgpool_grad_vm);
  AnfAlgo::CopyNodeAttr(kAttrStrides, avgpool_grad, avgpool_grad_vm);
  AnfAlgo::CopyNodeAttr(kAttrPadMode, avgpool_grad, avgpool_grad_vm);
  AnfAlgo::CopyNodeAttr(kAttrFormat, avgpool_grad, avgpool_grad_vm);
  auto input_names = std::vector<std::string>{"x_origin", "grad", "mean_matrix", "kernel_matrix"};
  AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names), avgpool_grad_vm);
  auto output_names = std::vector<std::string>{"output"};
  AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names), avgpool_grad_vm);
  return avgpool_grad_vm;
}
}  // namespace opt
}  // namespace mindspore
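
Reviewer note: the mean matrix built by CreateMeanMatrixValueNode above is simply 1/(number of non-padded input elements under each pooling window). A small, self-contained NumPy sketch of the same computation may make the C++ easier to check; the helper names below are illustrative only and not part of this patch.

import numpy as np


def windowed_output_size(input_size, ksize, stride, pad_mode):
    """Mirror of the C++ helper: returns (output_size, pad_before, pad_after)."""
    if pad_mode == "VALID":
        return (input_size - ksize + stride) // stride, 0, 0
    output = (input_size + stride - 1) // stride  # SAME
    pad_need = max(0, (output - 1) * stride + ksize - input_size)
    return output, pad_need // 2, pad_need - pad_need // 2


def mean_matrix(x_shape, k_size, stride, pad_mode):
    """NCHW mean matrix: 1 / (count of non-padded inputs under each window)."""
    n, c, h_in, w_in = x_shape
    h_out, pad_top, pad_bottom = windowed_output_size(h_in, k_size[2], stride[2], pad_mode)
    w_out, pad_left, pad_right = windowed_output_size(w_in, k_size[3], stride[3], pad_mode)
    # assist matrix: 1 where a real input element sits, 0 in the padded border
    assist = np.zeros((h_in + pad_top + pad_bottom, w_in + pad_left + pad_right), np.float32)
    assist[pad_top:pad_top + h_in, pad_left:pad_left + w_in] = 1.0
    hw = np.zeros((h_out, w_out), np.float32)
    for h in range(h_out):
        for w in range(w_out):
            count = assist[h * stride[2]:h * stride[2] + k_size[2],
                           w * stride[3]:w * stride[3] + k_size[3]].sum()
            hw[h, w] = 1.0 / count if count > 0 else 0.0
    return np.broadcast_to(hw, (n, c, h_out, w_out)).copy()


# 1x1x4x4 input, 3x3 kernel, stride 2, SAME padding: windows touching the padded
# border cover fewer real elements, so their entries grow.
# Expected 2x2 result per channel: [[1/9, 1/6], [1/6, 1/4]].
print(mean_matrix((1, 1, 4, 4), (1, 1, 3, 3), (1, 1, 2, 2), "SAME"))
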
@@ -0,0 +1,34 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_AVG_POOL_GRAD_UNIFY_MINDIR_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_AVG_POOL_GRAD_UNIFY_MINDIR_H_

#include <memory>
#include "backend/optimizer/common/optimizer.h"

namespace mindspore {
namespace opt {
class AvgPoolGradUnifyMindIR : public PatternProcessPass {
 public:
  explicit AvgPoolGradUnifyMindIR(bool multigraph = true)
      : PatternProcessPass("avg_pool_grad_unify_mindir", multigraph) {}
  ~AvgPoolGradUnifyMindIR() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_AVG_POOL_GRAD_UNIFY_MINDIR_H_

@@ -104,47 +104,6 @@ ValueNodePtr CreateKeepPorbValueNode(const FuncGraphPtr &func_graph, const AnfNo
  return keep_prob_value;
}

ValueNodePtr CreateShapeValueNode(const FuncGraphPtr &func_graph, const std::vector<int64_t> &shape,
                                  bool is_pynative = false) {
  MS_LOG(INFO) << "CreateShapeValueNode start.";
  MS_EXCEPTION_IF_NULL(func_graph);
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  MS_EXCEPTION_IF_NULL(kernel_graph);
  ValuePtr shape_value = nullptr;
  AbstractBasePtr abstract = nullptr;
  if (is_pynative) {
    // pynative mode need to create tensor
    int64_t shape_dim = SizeToLong(shape.size());
    std::vector<int64_t> shape_vec_shape = {shape_dim};
    auto shape_tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt64, shape_vec_shape);
    MS_EXCEPTION_IF_NULL(shape_tensor);
    auto data_ptr = shape_tensor->data_c();
    MS_EXCEPTION_IF_NULL(data_ptr);
    auto elem_num = shape.size() * kInt64Len;
    auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(shape_tensor->data().nbytes()), &shape[0], elem_num);
    if (ret_code != 0) {
      MS_LOG(EXCEPTION) << "Failed to copy data into Tensor.";
    }
    shape_value = shape_tensor;
    abstract = std::make_shared<abstract::AbstractTensor>(kInt64, shape_vec_shape);
  } else {
    std::vector<ValuePtr> dim_values{};
    abstract::AbstractBasePtrList abs{};
    for (const auto &dim : shape) {
      dim_values.push_back(MakeValue(dim));
      abs.push_back(std::make_shared<abstract::AbstractScalar>(dim));
    }
    shape_value = std::make_shared<ValueTuple>(dim_values);
    abstract = std::make_shared<abstract::AbstractTuple>(abs);
  }
  MS_EXCEPTION_IF_NULL(shape_value);
  MS_EXCEPTION_IF_NULL(abstract);
  auto shape_value_node = kernel_graph->NewValueNode(abstract, shape_value);
  MS_EXCEPTION_IF_NULL(shape_value_node);
  kernel_graph->AddValueNodeToGraph(shape_value_node);
  return shape_value_node;
}

std::vector<int64_t> CalDropoutGenMaskOutput(const std::vector<int64_t> &shape) {
  auto output_size = std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
  auto output_count = output_size / kMaskAlignNum;

@@ -35,6 +35,8 @@
namespace mindspore {
namespace opt {
constexpr size_t kType32Len = 4;
constexpr size_t kType64Len = 8;

std::vector<int64_t> Convert2Int(const std::vector<size_t> &v) {
  std::vector<int64_t> result;
  (void)std::transform(v.begin(), v.end(), std::back_inserter(result), SizeToInt);

@@ -495,6 +497,46 @@ CNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr
  return tuple_getitem;
}

ValueNodePtr CreateShapeValueNode(const FuncGraphPtr &func_graph, const std::vector<int64_t> &shape, bool to_tensor) {
  MS_EXCEPTION_IF_NULL(func_graph);
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  MS_EXCEPTION_IF_NULL(kernel_graph);
  ValuePtr shape_value = nullptr;
  AbstractBasePtr abstract = nullptr;
  if (to_tensor) {
    // create Tensor
    int64_t shape_dim = SizeToLong(shape.size());
    std::vector<int64_t> shape_vec_shape = {shape_dim};
    auto shape_tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt64, shape_vec_shape);
    MS_EXCEPTION_IF_NULL(shape_tensor);
    auto data_ptr = shape_tensor->data_c();
    MS_EXCEPTION_IF_NULL(data_ptr);
    auto elem_num = shape.size() * kType64Len;
    auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(shape_tensor->data().nbytes()), &shape[0], elem_num);
    if (ret_code != 0) {
      MS_LOG(EXCEPTION) << "Failed to copy data into Tensor.";
    }
    shape_value = shape_tensor;
    abstract = std::make_shared<abstract::AbstractTensor>(kInt64, shape_vec_shape);
  } else {
    // create ValueTuple
    std::vector<ValuePtr> dim_values{};
    abstract::AbstractBasePtrList abs{};
    for (const auto &dim : shape) {
      dim_values.push_back(MakeValue(dim));
      abs.push_back(std::make_shared<abstract::AbstractScalar>(dim));
    }
    shape_value = std::make_shared<ValueTuple>(dim_values);
    abstract = std::make_shared<abstract::AbstractTuple>(abs);
  }
  MS_EXCEPTION_IF_NULL(shape_value);
  MS_EXCEPTION_IF_NULL(abstract);
  auto shape_value_node = kernel_graph->NewValueNode(abstract, shape_value);
  MS_EXCEPTION_IF_NULL(shape_value_node);
  kernel_graph->AddValueNodeToGraph(shape_value_node);
  return shape_value_node;
}

void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &input_attrs) {
  MS_EXCEPTION_IF_NULL(cnode);
  std::vector<AnfNodePtr> new_inputs;

@@ -168,6 +168,9 @@ void RemoveNopNode(session::KernelGraph *const graph);
CNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx);
ValueNodePtr CreateShapeValueNode(const FuncGraphPtr &func_graph, const std::vector<int64_t> &shape,
                                  bool to_tensor = false);
bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node);
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,

@@ -41,6 +41,7 @@
#include "backend/optimizer/ascend/mindir/optimizer_unify_output.h"
#include "backend/optimizer/ascend/mindir/sparse_softmax_cross_entropy_with_logits_unify_mindir.h"
#include "backend/optimizer/ascend/mindir/slice_grad_unify_mindir.h"
#include "backend/optimizer/ascend/mindir/avg_pool_grad_unify_mindir.h"
#include "runtime/device/kernel_adjust.h"
#include "runtime/device/ascend/ascend_stream_assign.h"
#include "backend/session/anf_runtime_algorithm.h"

@@ -225,6 +226,7 @@ void AscendSession::UnifyMindIR(const KernelGraphPtr &graph) {
  unify_mindir_pm->AddPass(std::make_shared<opt::Conv2DBackpropInputUnifyMindIR>());
  unify_mindir_pm->AddPass(std::make_shared<opt::Conv2DBackpropFilterUnifyMindIR>());
  unify_mindir_pm->AddPass(std::make_shared<opt::SliceGradUnifyMindIR>());
  unify_mindir_pm->AddPass(std::make_shared<opt::AvgPoolGradUnifyMindIR>());
  unify_mindir_pm->AddPass(std::make_shared<opt::FtrlUnifyOutput>());
  unify_mindir_pm->AddPass(std::make_shared<opt::MomentumUnifyOutput>());
  unify_mindir_pm->AddPass(std::make_shared<opt::RMSPropUnifyOutput>());

@@ -1,5 +1,5 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -42,7 +42,7 @@ static std::map<std::string, std::pair<std::vector<size_t>, std::vector<size_t>>
  {prim::kPrimMaxPool->name(), {{0}, {0}}},
  {prim::kPrimMaxPoolGrad->name(), {{0, 1, 2}, {0}}},
  {kAvgPoolOpName, {{0}, {0}}},
  {kAvgPoolGradGpuOpName, {{0, 1, 2}, {0}}},
  {kAvgPoolGradOpName, {{0, 1, 2}, {0}}},
  {kFusedBatchNormEx, {{0}, {0}}},
  {kFusedBatchNormExWithActivation, {{0}, {0}}},
  {kFusedBatchNormExWithAddAndActivation, {{0, 5}, {0}}},

@@ -211,7 +211,8 @@ constexpr auto kEmbeddingLookupProxyOpName = "EmbeddingLookupProxy";
constexpr auto kGatherV2OpName = "Gather";
constexpr auto kPaddingOpName = "Padding";
constexpr auto kAvgPoolOpName = "AvgPool";
constexpr auto kAvgPoolGradGpuOpName = "AvgPoolGradGpu";
constexpr auto kAvgPoolGradOpName = "AvgPoolGrad";
constexpr auto kAvgPoolGradVmOpName = "AvgPoolGradVm";
constexpr auto kmaxPoolGradOpName = "MaxPoolGrad";
constexpr auto kMaxPoolWithArgmaxOpName = "MaxPoolWithArgmax";
constexpr auto kMaxPoolGradWithArgmaxOpName = "MaxPoolGradWithArgmax";

@@ -216,7 +216,6 @@ inline const PrimitivePtr kPrimApplyCenteredRMSProp = std::make_shared<Primitive
inline const PrimitivePtr kPrimAvgPool = std::make_shared<Primitive>("AvgPool");
inline const PrimitivePtr kPrimAvgPoolGrad = std::make_shared<Primitive>("AvgPoolGrad");
inline const PrimitivePtr kPrimAvgPoolGradVm = std::make_shared<Primitive>("AvgPoolGradVm");
inline const PrimitivePtr kPrimAvgPoolGradCpu = std::make_shared<Primitive>("AvgPoolGradCpu");
inline const PrimitivePtr kPrimFusedSparseAdam = std::make_shared<Primitive>("FusedSparseAdam");
inline const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("FusedBatchNorm");
inline const PrimitivePtr kPrimFusedBatchNormEx = std::make_shared<Primitive>("FusedBatchNormEx");

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -15,7 +15,6 @@
"""Define the grad rules of neural network related operations."""
import os
import numpy as np
from mindspore.ops import _selected_grad_ops as SG
from mindspore.ops.primitive import constexpr
from mindspore.common.tensor import Tensor
@@ -250,149 +249,20 @@ def get_bprop_max_pool_grad(self):
    return bprop


def _windowed_output_size(input_size, ksize, stride, pad_mode):
    """
    helper func for AvgPoolGrad
    """
    tmp_output = 0
    tmp_pad_need = 0
    tmp_pad_before = 0
    tmp_pad_after = 0
    if pad_mode == 'VALID':
        tmp_output = (input_size - ksize + stride) // stride
        tmp_pad_before = 0
        tmp_pad_after = 0
    elif pad_mode == 'SAME':
        tmp_output = (input_size + stride - 1) // stride
        tmp_pad_need = max(0, (tmp_output - 1) * stride + ksize - input_size)
        tmp_pad_before = tmp_pad_need // 2
        tmp_pad_after = tmp_pad_need - tmp_pad_before
    return tmp_output, tmp_pad_before, tmp_pad_after


@constexpr
def _get_mean_matrix(x_shape, ksize, stride, pad_mode, x_dtype):
    """
    helper func for AvgPoolGrad.

    `assist_input_matrix` is a 2d matrix with input_shape after padding,
    the value of element which is padded is 0, else are 1.
    For each element of output, it is mapped for slide window: `[h*h_stride : h*h_stride + h_ksize,
    w*w_stride : w*w_stride + w_ksize]` of `assist_input_matrix`, so the sum of slide window is the
    number of input that associate with output element.
    """

    n_input, c_input, h_input, w_input = x_shape
    h_ksize, w_ksize = ksize[2], ksize[3]
    h_stride, w_stride = stride[2], stride[3]
    n_output = n_input
    c_output = c_input
    h_output, w_output = 0, 0
    pad_top, pad_bottom, pad_left, pad_right = 0, 0, 0, 0
    h_output, pad_top, pad_bottom = _windowed_output_size(h_input, h_ksize,
                                                          h_stride, pad_mode)
    w_output, pad_left, pad_right = _windowed_output_size(w_input, w_ksize,
                                                          w_stride, pad_mode)

    output_size = n_output * c_output * h_output * w_output
    output_shape = (n_output, c_output, h_output, w_output)
    output = np.array([0.0] * output_size)
    output = np.reshape(output, output_shape)

    in_shape_after_padding_2d = (h_input + pad_top + pad_bottom, w_input + pad_left + pad_right)
    assist_input_matrix = np.ones(in_shape_after_padding_2d).astype(np.float32)
    if pad_top > 0:
        assist_input_matrix[:pad_top, :] = 0
    if pad_bottom > 0:
        assist_input_matrix[-pad_bottom:, :] = 0
    if pad_left > 0:
        assist_input_matrix[:, :pad_left] = 0
    if pad_right > 0:
        assist_input_matrix[:, -pad_right:] = 0

    for h in range(h_output):
        for w in range(w_output):
            curr_input = assist_input_matrix[h * h_stride: h * h_stride + h_ksize, w * w_stride: w * w_stride + w_ksize]
            curr_sum = np.sum(curr_input)
            if curr_sum > 0:
                output[:, :, h, w] = 1. / curr_sum
    return Tensor(output, x_dtype)


@constexpr
def _get_kernel_matrix(x_shape_nchw, kernel_matrix_shape, pad_mode, x_dtype):
    kernel_matrix = np.ones(kernel_matrix_shape)
    return Tensor(kernel_matrix, x_dtype)


@bprop_getters.register(P.AvgPool)
def get_bprop_avg_pool_grad(self):
    """Grad definition for `AvgPool` operation."""
    avgpool_grad = G.AvgPoolGrad(
        kernel_size=self.kernel_size,
        strides=self.strides,
        pad_mode=self.pad_mode,
        data_format=self.format)

    # the parameter of AvgPoolGrad in GPU and TBE/CPU is not same
    if self.target == "GPU":
        avgpool_grad_gpu = G.AvgPoolGradGpu(
            kernel_size=self.kernel_size,
            strides=self.strides,
            pad_mode=self.pad_mode,
            data_format=self.format)

        def bprop_gpu(x, out, dout):
            dx = avgpool_grad_gpu(x, out, dout)
            return (dx,)

        bprop_fn = bprop_gpu

    elif self.target == "CPU":
        avgpool_grad_cpu = G.AvgPoolGradCpu(
            kernel_size=self.kernel_size,
            strides=self.strides,
            pad_mode=self.pad_mode,
            data_format=self.format)

        def bprop_cpu(x, out, dout):
            dx = avgpool_grad_cpu(x, out, dout)
            return (dx,)

        bprop_fn = bprop_cpu

    elif self.target == "GE":
        avgpool_grad_ge = G.AvgPoolGrad(
            kernel_size=self.kernel_size,
            strides=self.strides,
            pad_mode=self.pad_mode)
        shape_op = P.Shape()

        def bprop_ge(x, out, dout):
            dx = avgpool_grad_ge(shape_op(x), dout)
            return (dx,)

        bprop_fn = bprop_ge

    else:
        avgpool_grad_vm = G.AvgPoolGradVm(
            kernel_size=self.kernel_size,
            strides=self.strides,
            pad_mode=self.pad_mode)
        k_size_nchw = avgpool_grad_vm.kernel_size
        stride_nchw = avgpool_grad_vm.strides
        pad_mode = self.pad_mode

        def bprop_vm(x, out, dout):
            x_shape_nchw = F.shape(x)
            x_dtype = F.dtype(x)
            kernel_matrix_shape = (1, x_shape_nchw[1],
                                   k_size_nchw[2],
                                   k_size_nchw[3])
            mean_matrix = _get_mean_matrix(x_shape_nchw, k_size_nchw, stride_nchw, pad_mode, x_dtype)
            kernel_matrix = _get_kernel_matrix(x_shape_nchw, kernel_matrix_shape, pad_mode, x_dtype)
            dx = avgpool_grad_vm(x_shape_nchw, dout, mean_matrix, kernel_matrix)
            return (dx,)

        bprop_fn = bprop_vm

    return bprop_fn

    def bprop(x, out, dout):
        dx = avgpool_grad(x, out, dout)
        return (dx,)

    return bprop


@bprop_getters.register(P.DropoutGenMask)
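
Reviewer note: with the target-specific branches above removed, the decomposition into mean_matrix and an all-ones kernel_matrix now happens only in the Ascend unify-MindIR pass earlier in this change. As a rough NumPy reference (illustrative only, not part of this patch, assuming the exclude-padding averaging convention used by the kernels), the gradient that decomposition is expected to produce scales dout by 1/count per window and scatters it back over the window, which the all-ones kernel_matrix presumably realizes on the device.

import numpy as np


def avg_pool_grad_reference(x_shape, dout, k_size, stride, pad_mode):
    """Reference dx for AvgPool: dout scaled by 1/count (the mean matrix) and
    scattered back over every pooling window (the all-ones kernel matrix)."""
    n, c, h_in, w_in = x_shape
    kh, kw, sh, sw = k_size[2], k_size[3], stride[2], stride[3]
    h_out, w_out = dout.shape[2], dout.shape[3]
    if pad_mode == "SAME":
        pad_top = max(0, (h_out - 1) * sh + kh - h_in) // 2
        pad_left = max(0, (w_out - 1) * sw + kw - w_in) // 2
    else:  # VALID
        pad_top = pad_left = 0
    dx = np.zeros((n, c, h_in, w_in), np.float32)
    for h in range(h_out):
        for w in range(w_out):
            rows = [i for i in range(h * sh - pad_top, h * sh - pad_top + kh) if 0 <= i < h_in]
            cols = [j for j in range(w * sw - pad_left, w * sw - pad_left + kw) if 0 <= j < w_in]
            count = len(rows) * len(cols)  # the quantity the mean matrix stores as 1/count
            for i in rows:
                for j in cols:
                    dx[:, :, i, j] += dout[:, :, h, w] / count
    return dx


# Every window redistributes its dout value exactly once, so the totals must match.
dout = np.arange(4, dtype=np.float32).reshape(1, 1, 2, 2)
dx = avg_pool_grad_reference((1, 1, 4, 4), dout, (1, 1, 3, 3), (1, 1, 2, 2), "SAME")
assert np.isclose(dx.sum(), dout.sum())
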
@@ -892,23 +892,6 @@ class _PoolGrad(PrimitiveWithInfer):
        self.add_prim_attr("strides", self.strides)


class AvgPoolGrad(_PoolGrad):
    """Gradients of the avg pool operation for ge."""

    @prim_attr_register
    def __init__(self, kernel_size=1, strides=1, pad_mode="VALID"):
        super(AvgPoolGrad, self).__init__(kernel_size, strides, pad_mode)

    def __infer__(self, origin_input, dout):
        out = {
            'value': None,
            'shape': tuple(origin_input['value']),
            'dtype': dout['dtype'],
        }

        return out


class AvgPoolGradVm(_PoolGrad):
    """Gradients of the avg pool operation for vm."""

@@ -927,26 +910,12 @@ class AvgPoolGradVm(_PoolGrad):
        return out


class AvgPoolGradGpu(_PoolGrad):
    """Gradients of the avg pool operation for gpu."""

    @prim_attr_register
    def __init__(self, kernel_size=1, strides=1, pad_mode="VALID", data_format="NCHW"):
        super(AvgPoolGradGpu, self).__init__(kernel_size, strides, pad_mode, data_format)

    def infer_shape(self, x1_shape, x2_shape, grad_shape):
        return x1_shape

    def infer_dtype(self, x1_dtype, x2_dtype, grad_dtype):
        return x1_dtype


class AvgPoolGradCpu(_PoolGrad):
    """Gradients of the avg pool operation for cpu."""
class AvgPoolGrad(_PoolGrad):
    """Gradients of the avg pool operation."""

    @prim_attr_register
    def __init__(self, kernel_size=1, strides=1, pad_mode="VALID", data_format="NCHW"):
        super(AvgPoolGradCpu, self).__init__(kernel_size, strides, pad_mode, data_format)
        super(AvgPoolGrad, self).__init__(kernel_size, strides, pad_mode, data_format)

    def infer_shape(self, x1_shape, x2_shape, grad_shape):
        return x1_shape

@@ -1801,10 +1801,9 @@ class MaxPoolWithArgmax(_Pool):
         [33. 34. 35.]]]]
    """

    @prim_attr_register
    def __init__(self, kernel_size=1, strides=1, pad_mode="valid", data_format="NCHW"):
        super(MaxPoolWithArgmax, self).__init__(kernel_size, strides, pad_mode, data_format)
        self.is_tbe = context.get_context("device_target") == "Ascend"
        self.is_gpu = context.get_context("device_target") == "GPU"

    def infer_shape(self, x_shape):
        out_shape = _Pool.infer_shape(self, x_shape)

@@ -1887,14 +1886,6 @@ class AvgPool(_Pool):
    @prim_attr_register
    def __init__(self, kernel_size=1, strides=1, pad_mode="valid", data_format="NCHW"):
        if context.get_context("device_target") == "GPU":
            self.target = "GPU"
        elif context.get_context("device_target") == "CPU":
            self.target = "CPU"
        elif context.get_context("enable_ge"):
            self.target = "GE"
        else:
            self.target = "OTHER"
        super(AvgPool, self).__init__(kernel_size, strides, pad_mode, data_format)
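
Reviewer note: after this change the Python bprop of AvgPool no longer branches on the device target. A minimal usage sketch for exercising the unified path (assuming the public MindSpore 1.x API; shapes and values are arbitrary):

import numpy as np
from mindspore import Tensor, nn, ops


class AvgPoolNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.pool = ops.AvgPool(kernel_size=3, strides=2, pad_mode="SAME")

    def construct(self, x):
        return self.pool(x)


# Gradient of AvgPool w.r.t. its input; the bprop now builds a single G.AvgPoolGrad,
# which the Ascend pass in this change later rewrites into AvgPoolGradVm.
x = Tensor(np.random.randn(1, 1, 4, 4).astype(np.float32))
dx = ops.GradOperation()(AvgPoolNet())(x)
print(dx.shape)  # (1, 1, 4, 4)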