
!16905 code check

From: @wilfchen
Reviewed-by: @limingqi107, @cristoval
Signed-off-by: @cristoval
tags/v1.3.0
mindspore-ci-bot committed 4 years ago
commit d02d78a435
14 changed files with 245 additions and 344 deletions
1. +4 -13 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gathernd_gpu_kernel.h
2. +4 -13 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.h
3. +168 -0 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_common.h
4. +4 -134 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.h
5. +4 -133 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.h
6. +2 -12 mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_init_kernel.cc
7. +2 -12 mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_iterator_kernel.cc
8. +21 -0 mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_utils.cc
9. +1 -0 mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_utils.h
10. +3 -13 mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h
11. +3 -13 mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.h
12. +2 -1 mindspore/ccsrc/backend/optimizer/gpu/relu_v2_pass.cc
13. +13 -0 mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
14. +14 -0 mindspore/ccsrc/backend/session/anf_runtime_algorithm.h

+4 -13 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gathernd_gpu_kernel.h

@@ -21,6 +21,7 @@
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/gathernd.cuh"
#include "backend/session/anf_runtime_algorithm.h"
namespace mindspore {
namespace kernel {
@@ -113,13 +114,13 @@ class GatherNdGpuFwdKernel : public GpuKernel {
protected:
void InitSizeLists() override {
- size_t size = GetSize(input_shapes_);
+ size_t size = AnfAlgo::TensorSizeInByte<T>(input_shapes_);
input_size_list_.push_back(size);
- size = GetSize(indices_shapes_);
+ size = AnfAlgo::TensorSizeInByte<T>(indices_shapes_);
input_size_list_.push_back(size);
- size = GetSize(output_shapes_);
+ size = AnfAlgo::TensorSizeInByte<T>(output_shapes_);
output_size_list_.push_back(size);
}
@@ -140,16 +141,6 @@ class GatherNdGpuFwdKernel : public GpuKernel {
dims_.emplace_back(dim_indices_last);
return;
}
- size_t GetSize(const std::vector<size_t> &shape) const {
- if (shape.size() == 0) {
- return 0;
- }
- size_t result = sizeof(T);
- for (size_t i = 0; i < shape.size(); i++) {
- result *= shape[i];
- }
- return result;
- }
std::vector<size_t> input_shapes_;
std::vector<size_t> indices_shapes_;
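
Note: the GetSize helper removed here (and in the other kernels below) is replaced by the shared AnfAlgo::TensorSizeInByte<T> added to anf_runtime_algorithm.h in this change. A minimal standalone sketch of its semantics, with a local function standing in for the real AnfAlgo member:

// Standalone sketch of the shared helper's behavior; TensorSizeInByte here
// is a local restatement for illustration, not the real AnfAlgo member.
#include <cstddef>
#include <vector>

template <typename T>
size_t TensorSizeInByte(const std::vector<size_t> &shape) {
  if (shape.size() == 0) {
    return 0;  // empty shape means zero bytes, matching the removed GetSize
  }
  size_t result = sizeof(T);
  for (size_t i = 0; i < shape.size(); i++) {
    result *= shape[i];  // element size times every dimension
  }
  return result;
}

// e.g. TensorSizeInByte<float>({2, 3, 4}) == 2 * 3 * 4 * sizeof(float) == 96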


+4 -13 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.h

@@ -22,6 +22,7 @@
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/gatherv2.cuh"
#include "backend/session/anf_runtime_algorithm.h"

namespace mindspore {
namespace kernel {
@@ -92,14 +93,14 @@ class GatherV2GpuFwdKernel : public GpuKernel {

protected:
void InitSizeLists() override {
- size_t size = GetSize(input_shapes_);
+ size_t size = AnfAlgo::TensorSizeInByte<T>(input_shapes_);
input_size_list_.push_back(size);
- size = GetSize(indices_shapes_);
+ size = AnfAlgo::TensorSizeInByte<T>(indices_shapes_);
input_size_list_.push_back(size);
if (is_dynamic_shape_) {
input_size_list_.push_back(sizeof(int64_t));
}
- size = GetSize(output_shapes_);
+ size = AnfAlgo::TensorSizeInByte<T>(output_shapes_);
output_size_list_.push_back(size);
}

@@ -125,16 +126,6 @@ class GatherV2GpuFwdKernel : public GpuKernel {
dims_[2] = dim_after_indices;
return;
}
- size_t GetSize(const std::vector<size_t> &shape) const {
- if (shape.size() == 0) {
- return 0;
- }
- size_t result = sizeof(T);
- for (size_t i = 0; i < shape.size(); i++) {
- result *= shape[i];
- }
- return result;
- }

std::vector<size_t> input_shapes_;
std::vector<size_t> indices_shapes_;


+168 -0 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_common.h (new file)

@@ -0,0 +1,168 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GPU_COMMON_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GPU_COMMON_H_

#include <vector>
#include <bitset>
#include <algorithm>
#include "backend/session/anf_runtime_algorithm.h"

namespace mindspore {
namespace kernel {
constexpr size_t MAX_DIMS = 8;
class StridedSliceGpuCommon {
public:
StridedSliceGpuCommon() : null_output_(false) {}

void CollectInfo(const CNodePtr &kernel_node) {
FillEmptyDims(kernel_node);
ParseMasks(kernel_node);
FillOutputDim();
null_output_ = IsNullOutput();
}

protected:
void FillEmptyDims(const CNodePtr &kernel_node) {
begin_ = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "begin");
end_ = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "end");
strides_ = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "strides");

for (size_t i = 0; i < MAX_DIMS; i++) {
if (i < begin_.size()) {
int64_t dim = input_shape_[i];
begin_[i] = std::min(begin_[i] < 0 ? std::max(begin_[i] + dim, static_cast<int64_t>(0)) : begin_[i], dim - 1);
} else {
begin_.push_back(0);
}

if (i < end_.size()) {
int64_t dim = input_shape_[i];
end_[i] = std::max(end_[i] < 0 ? end_[i] + dim : std::min(end_[i], dim), static_cast<int64_t>(-1));
} else {
end_.push_back(i < input_shape_.size() ? input_shape_[i] : 1);
}

if (i >= strides_.size()) {
strides_.push_back(1);
}

if (i >= input_shape_.size()) {
input_shape_.push_back(1);
}
}
}

void ParseMasks(const CNodePtr &kernel_node) {
auto begin_mask_int = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "begin_mask");
auto begin_mask = Dec2Bin(begin_mask_int);
for (size_t i = 0; i < begin_mask.size(); i++) {
if (begin_mask[i]) {
begin_[i] = 0;
}
}

auto end_mask_int = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "end_mask");
auto end_mask = Dec2Bin(end_mask_int);
for (size_t j = 0; j < end_mask.size(); j++) {
if (end_mask[j]) {
end_[j] = input_shape_[j];
}
}

auto ellipsis_mask_int = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "ellipsis_mask");
auto ellipsis_mask = Dec2Bin(ellipsis_mask_int);
for (size_t k = 0; k < ellipsis_mask.size(); k++) {
if (ellipsis_mask[k]) {
begin_[k] = 0;
end_[k] = input_shape_[k];
strides_[k] = 1;
}
}

auto new_axis_mask_int = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "new_axis_mask");
auto new_axis_mask = Dec2Bin(new_axis_mask_int);
for (size_t l = 0; l < new_axis_mask.size(); l++) {
if (new_axis_mask[l]) {
begin_[l] = 0;
end_[l] = input_shape_[l];
strides_[l] = 1;
}
}

auto shrink_axis_mask_int = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "shrink_axis_mask");
auto shrink_axis_mask = Dec2Bin(shrink_axis_mask_int);
for (size_t m = 0; m < shrink_axis_mask.size(); m++) {
if (shrink_axis_mask[m]) {
end_[m] = end_[m] > begin_[m] ? begin_[m] + 1 : begin_[m] - 1;
strides_[m] = end_[m] > begin_[m] ? 1 : -1;
}
}
}

std::vector<bool> Dec2Bin(const int64_t &mask) {
auto mask_str = std::bitset<MAX_DIMS>(mask).to_string();
int64_t dim_idx = 0;
std::vector<bool> result = {false, false, false, false};
for (int64_t i = mask_str.size() - 1; i >= 0; i--) {
if (mask_str[i] == '1') {
result[dim_idx] = true;
}
dim_idx++;
}
return result;
}

void FillOutputDim() {
for (size_t i = 0; i < MAX_DIMS; i++) {
if (begin_[i] <= end_[i] && strides_[i] > 0) {
output_shape_.push_back((end_[i] - 1 - begin_[i]) / strides_[i] + 1);
} else if (begin_[i] > end_[i] && strides_[i] < 0) {
output_shape_.push_back((end_[i] - begin_[i] + 1) / strides_[i] + 1);
} else {
output_shape_.push_back(0);
}
}
}

bool IsNullOutput() {
for (size_t i = 0; i < MAX_DIMS; i++) {
if (begin_[i] >= end_[i] && strides_[i] > 0) {
return true;
}
if (begin_[i] < end_[i] && strides_[i] < 0) {
return true;
}
}
return false;
}

std::vector<int64_t> begin_;
std::vector<int64_t> end_;
std::vector<int64_t> strides_;
std::vector<size_t> input_shape_;
std::vector<size_t> output_shape_;
bool null_output_;

std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore

#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GPU_COMMON_H_
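
Note on the mask handling above: each *_mask attribute is an integer in which bit i toggles special treatment of dimension i, and Dec2Bin unpacks it into per-dimension flags. A simplified standalone sketch of that decoding (local names, sized to MAX_DIMS = 8 as an assumption; not the committed API):

// Simplified sketch of the bit-unpacking idea behind Dec2Bin.
#include <bitset>
#include <cstdint>
#include <vector>

constexpr size_t kSliceDims = 8;  // assumption: mirrors MAX_DIMS above

std::vector<bool> DecodeMask(int64_t mask) {
  std::bitset<kSliceDims> bits(static_cast<unsigned long long>(mask));
  std::vector<bool> flags(kSliceDims, false);
  for (size_t i = 0; i < kSliceDims; i++) {
    flags[i] = bits[i];  // std::bitset indexes from the least significant bit
  }
  return flags;
}

// e.g. begin_mask = 5 (binary 101): dimensions 0 and 2 ignore `begin` and
// start from index 0 after ParseMasks.

For the output extent, FillOutputDim computes (end - 1 - begin) / stride + 1 for a positive stride; e.g. begin = 1, end = 7, stride = 2 gives (7 - 1 - 1) / 2 + 1 = 3 elements (indices 1, 3, 5).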

+4 -134 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.h

@@ -22,15 +22,15 @@
#include <algorithm>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/arrays/strided_slice_gpu_common.h"
#include "backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh"

namespace mindspore {
namespace kernel {
- constexpr size_t MAX_DIMS = 8;
template <typename T>
- class StridedSliceGpuKernel : public GpuKernel {
+ class StridedSliceGpuKernel : public GpuKernel, public StridedSliceGpuCommon {
public:
- StridedSliceGpuKernel() : null_output_(false) {}
+ StridedSliceGpuKernel() = default;
~StridedSliceGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
@@ -57,11 +57,7 @@ class StridedSliceGpuKernel : public GpuKernel {
return false;
}

- FillEmptyDims(kernel_node);
- ParseMasks(kernel_node);
- FillOutputDim();
- null_output_ = IsNullOutput();
-
+ CollectInfo(kernel_node);
InitSizeLists();
return true;
}
@@ -80,132 +76,6 @@ class StridedSliceGpuKernel : public GpuKernel {
}
output_size_list_.push_back(size1);
}

- private:
- void FillEmptyDims(const CNodePtr &kernel_node) {
- begin_ = GetAttr<std::vector<int64_t>>(kernel_node, "begin");
- end_ = GetAttr<std::vector<int64_t>>(kernel_node, "end");
- strides_ = GetAttr<std::vector<int64_t>>(kernel_node, "strides");
-
- for (size_t i = 0; i < MAX_DIMS; i++) {
- if (i < begin_.size()) {
- int64_t dim = input_shape_[i];
- begin_[i] = std::min(begin_[i] < 0 ? std::max(begin_[i] + dim, static_cast<int64_t>(0)) : begin_[i], dim - 1);
- } else {
- begin_.push_back(0);
- }
-
- if (i < end_.size()) {
- int64_t dim = input_shape_[i];
- end_[i] = std::max(end_[i] < 0 ? end_[i] + dim : std::min(end_[i], dim), static_cast<int64_t>(-1));
- } else {
- end_.push_back(i < input_shape_.size() ? input_shape_[i] : 1);
- }
-
- if (i >= strides_.size()) {
- strides_.push_back(1);
- }
-
- if (i >= input_shape_.size()) {
- input_shape_.push_back(1);
- }
- }
- }
-
- void ParseMasks(const CNodePtr &kernel_node) {
- auto begin_mask_int = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "begin_mask"));
- auto begin_mask = Dec2Bin(begin_mask_int);
- for (size_t i = 0; i < begin_mask.size(); i++) {
- if (begin_mask[i]) {
- begin_[i] = 0;
- }
- }
-
- auto end_mask_int = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "end_mask"));
- auto end_mask = Dec2Bin(end_mask_int);
- for (size_t j = 0; j < end_mask.size(); j++) {
- if (end_mask[j]) {
- end_[j] = input_shape_[j];
- }
- }
-
- auto ellipsis_mask_int = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "ellipsis_mask"));
- auto ellipsis_mask = Dec2Bin(ellipsis_mask_int);
- for (size_t k = 0; k < ellipsis_mask.size(); k++) {
- if (ellipsis_mask[k]) {
- begin_[k] = 0;
- end_[k] = input_shape_[k];
- strides_[k] = 1;
- }
- }
-
- auto new_axis_mask_int = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "new_axis_mask"));
- auto new_axis_mask = Dec2Bin(new_axis_mask_int);
- for (size_t l = 0; l < new_axis_mask.size(); l++) {
- if (new_axis_mask[l]) {
- begin_[l] = 0;
- end_[l] = input_shape_[l];
- strides_[l] = 1;
- }
- }
-
- auto shrink_axis_mask_int = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "shrink_axis_mask"));
- auto shrink_axis_mask = Dec2Bin(shrink_axis_mask_int);
- for (size_t m = 0; m < shrink_axis_mask.size(); m++) {
- if (shrink_axis_mask[m]) {
- end_[m] = end_[m] > begin_[m] ? begin_[m] + 1 : begin_[m] - 1;
- strides_[m] = end_[m] > begin_[m] ? 1 : -1;
- }
- }
- }
-
- std::vector<bool> Dec2Bin(const int64_t &mask) {
- auto mask_str = std::bitset<MAX_DIMS>(mask).to_string();
- int64_t dim_idx = 0;
- std::vector<bool> result = {false, false, false, false};
- for (int64_t i = mask_str.size() - 1; i >= 0; i--) {
- if (mask_str[i] == '1') {
- result[dim_idx] = true;
- }
- dim_idx++;
- }
- return result;
- }
-
- void FillOutputDim() {
- for (size_t i = 0; i < MAX_DIMS; i++) {
- if (begin_[i] <= end_[i] && strides_[i] > 0) {
- output_shape_.push_back((end_[i] - 1 - begin_[i]) / strides_[i] + 1);
- } else if (begin_[i] > end_[i] && strides_[i] < 0) {
- output_shape_.push_back((end_[i] - begin_[i] + 1) / strides_[i] + 1);
- } else {
- output_shape_.push_back(0);
- }
- }
- }
-
- bool IsNullOutput() {
- for (size_t i = 0; i < MAX_DIMS; i++) {
- if (begin_[i] >= end_[i] && strides_[i] > 0) {
- return true;
- }
- if (begin_[i] < end_[i] && strides_[i] < 0) {
- return true;
- }
- }
- return false;
- }
-
- std::vector<int64_t> begin_;
- std::vector<int64_t> end_;
- std::vector<int64_t> strides_;
- std::vector<size_t> input_shape_;
- std::vector<size_t> output_shape_;
- bool null_output_;
-
- std::vector<size_t> input_size_list_;
- std::vector<size_t> output_size_list_;
- std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore


+4 -133 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.h

@@ -22,15 +22,15 @@
#include <algorithm>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/arrays/strided_slice_gpu_common.h"
#include "backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh"

namespace mindspore {
namespace kernel {
- constexpr size_t MAX_DIMS = 7;
template <typename T>
- class StridedSliceGradGpuKernel : public GpuKernel {
+ class StridedSliceGradGpuKernel : public GpuKernel, public StridedSliceGpuCommon {
public:
- StridedSliceGradGpuKernel() : null_output_(false) {}
+ StridedSliceGradGpuKernel() = default;
~StridedSliceGradGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
@@ -59,10 +59,7 @@ class StridedSliceGradGpuKernel : public GpuKernel {
return false;
}

- FillEmptyDims(kernel_node);
- ParseMasks(kernel_node);
- FillOutputDim();
- null_output_ = IsNullOutput();
+ CollectInfo(kernel_node);
InitSizeLists();
return true;
}
@@ -81,132 +78,6 @@ class StridedSliceGradGpuKernel : public GpuKernel {
}
output_size_list_.push_back(size1);
}

- private:
- void FillEmptyDims(const CNodePtr &kernel_node) {
- begin_ = GetAttr<std::vector<int64_t>>(kernel_node, "begin");
- end_ = GetAttr<std::vector<int64_t>>(kernel_node, "end");
- strides_ = GetAttr<std::vector<int64_t>>(kernel_node, "strides");
-
- for (size_t i = 0; i < MAX_DIMS; i++) {
- if (i < begin_.size()) {
- int64_t dim = input_shape_[i];
- begin_[i] = std::min(begin_[i] < 0 ? std::max(begin_[i] + dim, static_cast<int64_t>(0)) : begin_[i], dim - 1);
- } else {
- begin_.push_back(0);
- }
-
- if (i < end_.size()) {
- int64_t dim = input_shape_[i];
- end_[i] = std::max(end_[i] < 0 ? end_[i] + dim : std::min(end_[i], dim), static_cast<int64_t>(-1));
- } else {
- end_.push_back(i < input_shape_.size() ? input_shape_[i] : 1);
- }
-
- if (i >= strides_.size()) {
- strides_.push_back(1);
- }
-
- if (i >= input_shape_.size()) {
- input_shape_.push_back(1);
- }
- }
- }
-
- void ParseMasks(const CNodePtr &kernel_node) {
- auto begin_mask_int = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "begin_mask"));
- auto begin_mask = Dec2Bin(begin_mask_int);
- for (size_t i = 0; i < begin_mask.size(); i++) {
- if (begin_mask[i]) {
- begin_[i] = 0;
- }
- }
-
- auto end_mask_int = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "end_mask"));
- auto end_mask = Dec2Bin(end_mask_int);
- for (size_t j = 0; j < end_mask.size(); j++) {
- if (end_mask[j]) {
- end_[j] = input_shape_[j];
- }
- }
-
- auto ellipsis_mask_int = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "ellipsis_mask"));
- auto ellipsis_mask = Dec2Bin(ellipsis_mask_int);
- for (size_t k = 0; k < ellipsis_mask.size(); k++) {
- if (ellipsis_mask[k]) {
- begin_[k] = 0;
- end_[k] = input_shape_[k];
- strides_[k] = 1;
- }
- }
-
- auto new_axis_mask_int = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "new_axis_mask"));
- auto new_axis_mask = Dec2Bin(new_axis_mask_int);
- for (size_t l = 0; l < new_axis_mask.size(); l++) {
- if (new_axis_mask[l]) {
- begin_[l] = 0;
- end_[l] = input_shape_[l];
- strides_[l] = 1;
- }
- }
-
- auto shrink_axis_mask_int = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "shrink_axis_mask"));
- auto shrink_axis_mask = Dec2Bin(shrink_axis_mask_int);
- for (size_t m = 0; m < shrink_axis_mask.size(); m++) {
- if (shrink_axis_mask[m]) {
- end_[m] = end_[m] > begin_[m] ? begin_[m] + 1 : begin_[m] - 1;
- strides_[m] = end_[m] > begin_[m] ? 1 : -1;
- }
- }
- }
-
- std::vector<bool> Dec2Bin(const int64_t &mask) {
- auto mask_str = std::bitset<MAX_DIMS>(mask).to_string();
- int64_t dim_idx = 0;
- std::vector<bool> result = {false, false, false, false};
- for (int64_t i = mask_str.size() - 1; i >= 0; i--) {
- if (mask_str[i] == '1') {
- result[dim_idx] = true;
- }
- dim_idx++;
- }
- return result;
- }
-
- void FillOutputDim() {
- for (size_t i = 0; i < MAX_DIMS; i++) {
- if (begin_[i] <= end_[i] && strides_[i] > 0) {
- output_shape_.push_back((end_[i] - 1 - begin_[i]) / strides_[i] + 1);
- } else if (begin_[i] > end_[i] && strides_[i] < 0) {
- output_shape_.push_back((end_[i] - begin_[i] + 1) / strides_[i] + 1);
- } else {
- output_shape_.push_back(0);
- }
- }
- }
-
- bool IsNullOutput() {
- for (size_t i = 0; i < MAX_DIMS; i++) {
- if (begin_[i] >= end_[i] && strides_[i] > 0) {
- return true;
- }
- if (begin_[i] < end_[i] && strides_[i] < 0) {
- return true;
- }
- }
- return false;
- }
-
- std::vector<int64_t> begin_;
- std::vector<int64_t> end_;
- std::vector<int64_t> strides_;
- std::vector<size_t> input_shape_;
- std::vector<size_t> output_shape_;
- bool null_output_;
-
- std::vector<size_t> input_size_list_;
- std::vector<size_t> output_size_list_;
- std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore


+2 -12 mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_init_kernel.cc

@@ -36,18 +36,8 @@ const std::vector<size_t> &DatasetInitKernel::GetWorkspaceSizeList() const { ret
bool DatasetInitKernel::Init(const CNodePtr &kernel_node) {
queue_name_ = GetAttr<std::string>(kernel_node, "queue_name");
std::vector<std::vector<int>> shapes;
- std::vector<std::vector<int64_t>> shapes_me = GetAttr<const std::vector<std::vector<int64_t>>>(kernel_node, "shapes");
- (void)std::transform(shapes_me.begin(), shapes_me.end(), std::back_inserter(shapes),
- [](const std::vector<int64_t> &values) {
- std::vector<int> shape;
- (void)std::transform(values.begin(), values.end(), std::back_inserter(shape),
- [](const int64_t &value) { return static_cast<int>(value); });
- return shape;
- });
- auto types = GetAttr<const std::vector<TypePtr>>(kernel_node, "types");
- if (shapes.size() != types.size()) {
- MS_LOG(EXCEPTION) << "Invalid shapes: " << shapes << ", types: " << types;
- }
+ std::vector<TypePtr> types;
+ GetShapeAndType(kernel_node, &shapes, &types);

for (size_t i = 0; i < shapes.size(); i++) {
int unit = UnitSizeInBytes(types[i]->type_id());


+2 -12 mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_iterator_kernel.cc

@@ -47,18 +47,8 @@ bool DatasetIteratorKernel::Init(const CNodePtr &kernel_node) {
kernel_node_ = kernel_node;
queue_name_ = GetAttr<std::string>(kernel_node, "shared_name");
std::vector<std::vector<int>> shapes;
- std::vector<std::vector<int64_t>> shapes_me = GetAttr<const std::vector<std::vector<int64_t>>>(kernel_node, "shapes");
- (void)std::transform(shapes_me.begin(), shapes_me.end(), std::back_inserter(shapes),
- [](const std::vector<int64_t> &values) {
- std::vector<int> shape;
- (void)std::transform(values.begin(), values.end(), std::back_inserter(shape),
- [](const int64_t &value) { return static_cast<int>(value); });
- return shape;
- });
- auto types = GetAttr<const std::vector<TypePtr>>(kernel_node, "types");
- if (shapes.size() != types.size()) {
- MS_LOG(EXCEPTION) << "Invalid shapes: " << shapes << ", types: " << types;
- }
+ std::vector<TypePtr> types;
+ GetShapeAndType(kernel_node, &shapes, &types);

for (size_t i = 0; i < shapes.size(); i++) {
int unit = UnitSizeInBytes(types[i]->type_id());


+21 -0 mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_utils.cc

@@ -15,6 +15,8 @@
*/

#include "backend/kernel_compiler/gpu/data/dataset_utils.h"
+ #include <algorithm>
+ #include "backend/session/anf_runtime_algorithm.h"

namespace mindspore {
namespace kernel {
@@ -64,5 +66,24 @@ int ElementNums(const std::vector<int> &shape) {

return nums;
}

+ void GetShapeAndType(const CNodePtr &kernel_node, std::vector<std::vector<int>> *shapes, std::vector<TypePtr> *types) {
+ MS_EXCEPTION_IF_NULL(shapes);
+ MS_EXCEPTION_IF_NULL(types);
+ std::vector<std::vector<int64_t>> shapes_me =
+ AnfAlgo::GetNodeAttr<std::vector<std::vector<int64_t>>>(kernel_node, "shapes");
+ (void)std::transform(shapes_me.begin(), shapes_me.end(), std::back_inserter(*shapes),
+ [](const std::vector<int64_t> &values) {
+ std::vector<int> shape;
+ (void)std::transform(values.begin(), values.end(), std::back_inserter(shape),
+ [](const int64_t &value) { return static_cast<int>(value); });
+ return shape;
+ });
+
+ *types = AnfAlgo::GetNodeAttr<std::vector<TypePtr>>(kernel_node, "types");
+ if (shapes->size() != types->size()) {
+ MS_LOG(EXCEPTION) << "Invalid shapes: " << *shapes << ", types: " << *types;
+ }
+ }
} // namespace kernel
} // namespace mindspore
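
Note: both dataset kernels above now delegate their "shapes"/"types" attribute parsing to this helper. A hypothetical call site, mirroring the updated Init bodies in dataset_init_kernel.cc and dataset_iterator_kernel.cc:

// Hypothetical kernel Init fragment (names as in the diffs above).
std::vector<std::vector<int>> shapes;
std::vector<TypePtr> types;
GetShapeAndType(kernel_node, &shapes, &types);  // raises MS_LOG(EXCEPTION) on a shapes/types size mismatch
for (size_t i = 0; i < shapes.size(); i++) {
  int unit = UnitSizeInBytes(types[i]->type_id());
  // ... derive each buffer size from `unit` and shapes[i] ...
}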

+1 -0 mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_utils.h

@@ -23,6 +23,7 @@ namespace mindspore {
namespace kernel {
size_t UnitSizeInBytes(const mindspore::TypeId &t);
int ElementNums(const std::vector<int> &shape);
+ void GetShapeAndType(const CNodePtr &kernel_node, std::vector<std::vector<int>> *shapes, std::vector<TypePtr> *types);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DATASET_UTILS_KERNEL_H_

+3 -13 mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h

@@ -25,6 +25,8 @@
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
#include "backend/session/anf_runtime_algorithm.h"

namespace mindspore {
namespace kernel {
constexpr int MAX_DIMS = 7;
@@ -68,7 +70,7 @@ class BroadcastOpGpuKernel : public GpuKernel {
auto shape1 = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
auto shape2 = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 1);
auto shape3 = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0);
- need_broadcast_ = IsBroadcast(shape1, shape2);
+ need_broadcast_ = AnfAlgo::IsTensorBroadcast(shape1, shape2);
if (need_broadcast_ && shape1.size() > 7) {
MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than 7";
}
@@ -165,18 +167,6 @@ class BroadcastOpGpuKernel : public GpuKernel {
MS_LOG(EXCEPTION) << "operation " << kernel_name << " is not supported.";
}

- bool IsBroadcast(const std::vector<size_t> &lhs, const std::vector<size_t> &rhs) {
- if (lhs.size() != rhs.size()) {
- return true;
- }
- for (size_t i = 0; i < lhs.size(); i++) {
- if (lhs[i] != rhs[i]) {
- return true;
- }
- }
- return false;
- }

BroadcastOpType op_type_;
bool need_broadcast_;
bool is_comp_op_;


+3 -13 mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.h

@@ -25,6 +25,8 @@
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cuh"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
#include "backend/session/anf_runtime_algorithm.h"

namespace mindspore {
namespace kernel {
template <typename T>
@@ -68,7 +70,7 @@ class BroadcastOpGradGpuKernel : public GpuKernel {
auto shape1 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto shape2 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
auto shape3 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
- need_broadcast_ = IsBroadcast(shape1, shape2);
+ need_broadcast_ = AnfAlgo::IsTensorBroadcast(shape1, shape2);
if (need_broadcast_ && shape1.size() > 4) {
MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than 4";
}
@@ -144,18 +146,6 @@ class BroadcastOpGradGpuKernel : public GpuKernel {
}
}

- bool IsBroadcast(const std::vector<size_t> &lhs, const std::vector<size_t> &rhs) {
- if (lhs.size() != rhs.size()) {
- return true;
- }
- for (size_t i = 0; i < lhs.size(); i++) {
- if (lhs[i] != rhs[i]) {
- return true;
- }
- }
- return false;
- }

BroadcastGradOpType op_type_;
bool need_broadcast_;
size_t input1_num_;


+2 -1 mindspore/ccsrc/backend/optimizer/gpu/relu_v2_pass.cc

@@ -29,6 +29,7 @@ namespace mindspore {
namespace opt {
namespace {
const size_t kReluV2OutputNum = 2;
+ const size_t kBitPerUInt = 32;

CNodePtr GetRelu(const CNodePtr &relu_grad) {
MS_EXCEPTION_IF_NULL(relu_grad);
@@ -80,7 +81,7 @@ CNodePtr CreateReluV2(const FuncGraphPtr &graph, const CNodePtr &relu) {
auto element_num =
std::accumulate(output_shape.begin(), output_shape.end(), static_cast<size_t>(1), std::multiplies<size_t>());

- std::vector<size_t> mask_shape = {(element_num + 31) / 32};
+ std::vector<size_t> mask_shape = {(element_num + kBitPerUInt - 1) / kBitPerUInt};
auto shapes = {AnfAlgo::GetOutputInferShape(relu, 0), mask_shape};
auto types = {AnfAlgo::GetOutputInferDataType(relu, 0), kNumberTypeUInt32};
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, new_node.get());
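
Note: the rewritten mask_shape is the same ceiling division as before, now spelled with the named constant: it counts the 32-bit words needed to hold one mask bit per element. A quick check of the arithmetic (values are illustrative):

// (n + kBitPerUInt - 1) / kBitPerUInt == ceil(n / 32.0) for integer n
const size_t kBitPerUInt = 32;
size_t element_num = 100;
size_t words = (element_num + kBitPerUInt - 1) / kBitPerUInt;  // (100 + 31) / 32 = 4
// 4 words * 32 bits = 128 bits, enough for 100 per-element mask bits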


+13 -0 mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc

@@ -2074,5 +2074,18 @@ void AnfRuntimeAlgorithm::GetRealInputs(const AnfNodePtr &node, std::vector<sess
GetRealOutputRecursively(input_node, 0, inputs);
}
}

+ bool AnfRuntimeAlgorithm::IsTensorBroadcast(const std::vector<size_t> &lhs, const std::vector<size_t> &rhs) {
+ if (lhs.size() != rhs.size()) {
+ return true;
+ }
+ for (size_t i = 0; i < lhs.size(); i++) {
+ if (lhs[i] != rhs[i]) {
+ return true;
+ }
+ }
+ return false;
+ }

} // namespace session
} // namespace mindspore
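
Note: the helper reports true whenever the two shapes differ in rank or in any dimension, i.e. whenever a kernel must take the broadcast path. Hypothetical checks of its semantics (shape values chosen for illustration):

AnfAlgo::IsTensorBroadcast({4, 1, 3}, {4, 2, 3});  // true:  dimension 1 differs
AnfAlgo::IsTensorBroadcast({4, 2, 3}, {4, 2, 3});  // false: identical shapes
AnfAlgo::IsTensorBroadcast({2, 3}, {4, 2, 3});     // true:  ranks differ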

+14 -0 mindspore/ccsrc/backend/session/anf_runtime_algorithm.h

@@ -279,6 +279,20 @@ class AnfRuntimeAlgorithm {
static AnfNodeIndexSet GetUpdateStateUsers(const FuncGraphManagerPtr &manager, const AnfNodePtr &node);
// Get node real inputs, skip `MakeTuple`, `TupleGetItem`, `Depend`, `Load`, `UpdateState` etc.
static void GetRealInputs(const AnfNodePtr &anf_node, std::vector<session::KernelWithIndex> *inputs);
+ // Check whether tensors need broadcast or not.
+ static bool IsTensorBroadcast(const std::vector<size_t> &lhs, const std::vector<size_t> &rhs);
+ // Calc tensor size in byte.
+ template <typename T>
+ static size_t TensorSizeInByte(const std::vector<size_t> &shape) {
+ if (shape.size() == 0) {
+ return 0;
+ }
+ size_t result = sizeof(T);
+ for (size_t i = 0; i < shape.size(); i++) {
+ result *= shape[i];
+ }
+ return result;
+ }
};
} // namespace session
using AnfAlgo = session::AnfRuntimeAlgorithm;

