Browse Source

update argmaxwithvalue

tags/v0.5.0-beta
VectorSL 5 years ago
parent
commit
46afb18e25
3 changed files with 17 additions and 30 deletions
  1. +13 -26 mindspore/ccsrc/kernel/gpu/arrays/argmaxwithvalue_gpu_kernel.h
  2. +3 -3 mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cu
  3. +1 -1 mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cuh

+ 13
- 26
mindspore/ccsrc/kernel/gpu/arrays/argmaxwithvalue_gpu_kernel.h View File

@@ -26,15 +26,7 @@ namespace kernel {
 template <typename T, typename S>
 class ArgmaxWithValueGpuKernel : public GpuKernel {
  public:
-  ArgmaxWithValueGpuKernel()
-      : input_size_(0),
-        output_size_(0),
-        workspace_size_(0),
-        axis_(0),
-        dims_(1),
-        bound_(0),
-        outerSize_(0),
-        innerSize_(0) {}
+  ArgmaxWithValueGpuKernel() : input_size_(0), output_size_(0), bound_(0), outerSize_(0), innerSize_(0) {}
   ~ArgmaxWithValueGpuKernel() override = default;

   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -46,37 +38,36 @@ class ArgmaxWithValueGpuKernel : public GpuKernel {
     T *input = GetDeviceAddress<T>(inputs, 0);
     T *output = GetDeviceAddress<T>(outputs, 1);
     S *index = GetDeviceAddress<S>(outputs, 0);
-    CalArgmaxWithValue(input_size_ / sizeof(T), input, bound_, outerSize_, innerSize_, axis_, dims_, index, output,
+    CalArgmaxWithValue(input_size_ / sizeof(T), input, bound_, outerSize_, innerSize_, index, output,
                        reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }

   bool Init(const CNodePtr &kernel_node) override {
-    shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    std::vector<size_t> shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
     auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 1);
-    dims_ = shape_.size();
-
-    axis_ = GetAttr<int>(kernel_node, "axis");
-    if (axis_ < 0) {
-      axis_ += dims_;
+    int dims = shape.size();
+    int axis = GetAttr<int>(kernel_node, "axis");
+    if (axis < 0) {
+      axis += dims;
     }
     input_size_ = sizeof(T);
-    for (auto x : shape_) {
+    for (auto x : shape) {
       input_size_ *= x;
     }
     output_size_ = sizeof(S);
     for (auto x : output_shape) {
       output_size_ *= x;
     }
-    bound_ = shape_[axis_];
+    bound_ = shape[axis];
     outerSize_ = 1;
-    for (int i = axis_ - 1; i >= 0; i--) {
-      outerSize_ *= shape_[i];
+    for (int i = axis - 1; i >= 0; i--) {
+      outerSize_ *= shape[i];
     }

     innerSize_ = 1;
-    for (int i = axis_ + 1; i < dims_; i++) {
-      innerSize_ *= shape_[i];
+    for (int i = axis + 1; i < dims; i++) {
+      innerSize_ *= shape[i];
     }
     InitSizeLists();
     return true;
@@ -92,13 +83,9 @@ class ArgmaxWithValueGpuKernel : public GpuKernel {
  private:
   size_t input_size_;
   size_t output_size_;
-  size_t workspace_size_;
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
-  std::vector<size_t> shape_;
-  int axis_;
-  int dims_;
   int bound_;
   int outerSize_;
   int innerSize_;


+ 3
- 3
mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cu View File

@@ -44,15 +44,15 @@ __global__ void ArgmaxWithValue(size_t size, const T* input, const int bound, in


 template <typename T, typename S>
 void CalArgmaxWithValue(size_t size, const T* input, const int bound_, const int outerSize_, const int innerSize_,
-                        int axis_, int dims_, S* index, T* output, cudaStream_t cuda_stream) {
+                        S* index, T* output, cudaStream_t cuda_stream) {
   ArgmaxWithValue<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, bound_, outerSize_, innerSize_,
                                                                      index, output);
   return;
 }

 template void CalArgmaxWithValue<float, int>(size_t size, const float* input, const int bound_, const int outerSize_,
-                                             const int innerSize_, int axis_, int dims_, int* index, float* output,
+                                             const int innerSize_, int* index, float* output,
                                              cudaStream_t cuda_stream);
 template void CalArgmaxWithValue<half, int>(size_t size, const half* input, const int bound_, const int outerSize_,
-                                            const int innerSize_, int axis_, int dims_, int* index, half* output,
+                                            const int innerSize_, int* index, half* output,
                                             cudaStream_t cuda_stream);

+ 1
- 1
mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cuh View File

@@ -18,5 +18,5 @@
 #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ARGMAXWITHVALUE_H_
 template <typename T, typename S>
 void CalArgmaxWithValue(size_t size, const T* input, const int bound_, const int outerSize_, const int innerSize_,
-                        int axis_, int dims_, S* index, T* output, cudaStream_t cuda_stream);
+                        S* index, T* output, cudaStream_t cuda_stream);
 #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ARGMAXWITHVALUE_H_

Loading…
Cancel
Save