Browse Source

!27216 fix tensorarray

Merge pull request !27216 from VectorSL/fix-tensorarray
tags/v1.6.0
i-robot Gitee 4 years ago
parent
commit
017a2231cd
17 changed files with 76 additions and 38 deletions
  1. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/cpu/rl/tensor_array_clear_kernel.cc
  2. +2
    -2
      mindspore/ccsrc/backend/kernel_compiler/cpu/rl/tensor_array_close_kernel.cc
  3. +1
    -3
      mindspore/ccsrc/backend/kernel_compiler/cpu/rl/tensor_array_create_kernel.cc
  4. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/cpu/rl/tensor_array_read_kernel.cc
  5. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/cpu/rl/tensor_array_size_kernel.cc
  6. +5
    -5
      mindspore/ccsrc/backend/kernel_compiler/cpu/rl/tensor_array_stack_kernel.cc
  7. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/cpu/rl/tensor_array_stack_kernel.h
  8. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/cpu/rl/tensor_array_write_kernel.cc
  9. +9
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/rl/tensor_array_clear_kernel.cc
  10. +9
    -3
      mindspore/ccsrc/backend/kernel_compiler/gpu/rl/tensor_array_close_kernel.cc
  11. +2
    -3
      mindspore/ccsrc/backend/kernel_compiler/gpu/rl/tensor_array_create_kernel.cc
  12. +9
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/rl/tensor_array_read_kernel.cc
  13. +8
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/rl/tensor_array_size_kernel.cc
  14. +10
    -3
      mindspore/ccsrc/backend/kernel_compiler/gpu/rl/tensor_array_stack_kernel.cc
  15. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/rl/tensor_array_stack_kernel.h
  16. +9
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/rl/tensor_array_write_kernel.cc
  17. +6
    -6
      mindspore/ccsrc/runtime/device/tensor_array_manager.h

+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/cpu/rl/tensor_array_clear_kernel.cc View File

@@ -41,7 +41,7 @@ bool TensorArrayCPUClearKernel::Launch(const std::vector<AddressPtr> &inputs, co
const std::vector<AddressPtr> &) {
auto handle_addr = GetDeviceAddress<int64_t>(inputs, 0);
MS_EXCEPTION_IF_NULL(handle_addr);
TensorArrayPtr tensors_ = TensorArrayMgr::GetInstance().GetTensorArray(handle_addr);
TensorArrayPtr tensors_ = TensorArrayMgr::GetInstance().GetTensorArray(handle_addr[0]);
MS_ERROR_IF_NULL(tensors_);
// Clear TensorArray valid size, but keep the memory.
tensors_->Clear();


+ 2
- 2
mindspore/ccsrc/backend/kernel_compiler/cpu/rl/tensor_array_close_kernel.cc View File

@@ -43,12 +43,12 @@ bool TensorArrayCPUCloseKernel::Launch(const std::vector<AddressPtr> &inputs, co
auto handle_addr = GetDeviceAddress<int64_t>(inputs, 0);
MS_EXCEPTION_IF_NULL(handle_addr);
CPUTensorArrayPtr tensors_ =
std::dynamic_pointer_cast<CPUTensorArray>(TensorArrayMgr::GetInstance().GetTensorArray(handle_addr));
std::dynamic_pointer_cast<CPUTensorArray>(TensorArrayMgr::GetInstance().GetTensorArray(handle_addr[0]));
MS_ERROR_IF_NULL(tensors_);
// Free device mem
tensors_->Free();
// Erase tensorarray
if (!TensorArrayMgr::GetInstance().EraseTensorArray(handle_addr)) {
if (!TensorArrayMgr::GetInstance().EraseTensorArray(handle_addr[0])) {
MS_LOG(EXCEPTION) << "Free tensorarray failed";
}
return true;


+ 1
- 3
mindspore/ccsrc/backend/kernel_compiler/cpu/rl/tensor_array_create_kernel.cc View File

@@ -59,9 +59,7 @@ bool TensorArrayCPUCreateKernel::Launch(const std::vector<AddressPtr> &, const s
out_addr[0] = tensor_array_handle;
MS_LOG(DEBUG) << "Create handle id " << tensor_array_handle;
// Put tensorarray to a saved map : map<handle, tensorarray> in tensorarray manager.
// Only put the device addr as the key to avoid a copy from device to host.
// The output handle address will be kept and won't be reused.
TensorArrayMgr::GetInstance().AddTensorArray(out_addr, tensor_array);
TensorArrayMgr::GetInstance().AddTensorArray(tensor_array_handle, tensor_array);
return true;
}
} // namespace kernel


+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/cpu/rl/tensor_array_read_kernel.cc View File

@@ -53,7 +53,7 @@ bool TensorArrayCPUReadKernel::Launch(const std::vector<AddressPtr> &inputs, con
MS_EXCEPTION_IF_NULL(index);
MS_EXCEPTION_IF_NULL(out_value);
int64_t index_host = index[0];
TensorArrayPtr tensors_ = TensorArrayMgr::GetInstance().GetTensorArray(handle_addr);
TensorArrayPtr tensors_ = TensorArrayMgr::GetInstance().GetTensorArray(handle_addr[0]);
MS_ERROR_IF_NULL(tensors_);
if (!tensors_->CheckReadIndexLogical(index_host)) {
MS_LOG(EXCEPTION) << "Invalid index " << index_host << " for read.";


+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/cpu/rl/tensor_array_size_kernel.cc View File

@@ -42,7 +42,7 @@ bool TensorArrayCPUSizeKernel::Launch(const std::vector<AddressPtr> &inputs, con
auto out_addr = GetDeviceAddress<int64_t>(outputs, 0);
MS_EXCEPTION_IF_NULL(handle_addr);
MS_EXCEPTION_IF_NULL(out_addr);
auto tensors_ = TensorArrayMgr::GetInstance().GetTensorArray(handle_addr);
auto tensors_ = TensorArrayMgr::GetInstance().GetTensorArray(handle_addr[0]);
MS_ERROR_IF_NULL(tensors_);
int64_t valid_size = SizeToLong(tensors_->GetValidSize());
out_addr[0] = valid_size;


+ 5
- 5
mindspore/ccsrc/backend/kernel_compiler/cpu/rl/tensor_array_stack_kernel.cc View File

@@ -24,8 +24,7 @@ namespace mindspore {
namespace kernel {
using mindspore::device::TensorArrayMgr;
using mindspore::device::TensorArrayPtr;
TensorArrayCPUStackKernel::TensorArrayCPUStackKernel()
: handle_(nullptr), value_size_(0), ele_size_(0), type_(nullptr) {
TensorArrayCPUStackKernel::TensorArrayCPUStackKernel() : handle_(0), value_size_(0), ele_size_(0), type_(nullptr) {
ResetResource();
}

@@ -64,7 +63,7 @@ void TensorArrayCPUStackKernel::PostExecute() {
}

void TensorArrayCPUStackKernel::ResetResource() noexcept {
handle_ = nullptr;
handle_ = 0;
value_size_ = 0;
ele_size_ = 0;
shapes_.clear();
@@ -75,10 +74,11 @@ void TensorArrayCPUStackKernel::ResetResource() noexcept {

bool TensorArrayCPUStackKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
handle_ = GetDeviceAddress<int64_t>(inputs, 0);
auto handle_addr = GetDeviceAddress<int64_t>(inputs, 0);
auto out_value = GetDeviceAddress<unsigned char>(outputs, 0);
MS_EXCEPTION_IF_NULL(out_value);
MS_EXCEPTION_IF_NULL(handle_);
MS_EXCEPTION_IF_NULL(handle_addr);
handle_ = handle_addr[0];
TensorArrayPtr tensors_ = TensorArrayMgr::GetInstance().GetTensorArray(handle_);
MS_EXCEPTION_IF_NULL(tensors_);
if (tensors_->GetValidSize() > tensors_->GetRealSize()) {


+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/cpu/rl/tensor_array_stack_kernel.h View File

@@ -43,7 +43,7 @@ class TensorArrayCPUStackKernel : public CPUKernel {

private:
CNodeWeakPtr kernel_node_;
int64_t *handle_;
int64_t handle_;
int64_t value_size_;
int64_t ele_size_;
std::vector<size_t> shapes_;


+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/cpu/rl/tensor_array_write_kernel.cc View File

@@ -56,7 +56,7 @@ bool TensorArrayCPUWriteKernel::Launch(const std::vector<AddressPtr> &inputs, co
MS_EXCEPTION_IF_NULL(value);
int64_t index_host = index[0];
CPUTensorArrayPtr tensors_ =
std::dynamic_pointer_cast<CPUTensorArray>(TensorArrayMgr::GetInstance().GetTensorArray(handle_addr));
std::dynamic_pointer_cast<CPUTensorArray>(TensorArrayMgr::GetInstance().GetTensorArray(handle_addr[0]));
MS_EXCEPTION_IF_NULL(tensors_);
if (!tensors_->CheckValue(type_, shapes_)) {
MS_LOG(EXCEPTION) << "Invalid input data for tensor array write op.";


+ 9
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/rl/tensor_array_clear_kernel.cc View File

@@ -33,6 +33,7 @@ const std::vector<size_t> &TensorArrayClearKernel::GetWorkspaceSizeList() const

bool TensorArrayClearKernel::Init(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_node_ = kernel_node;
InitSizeLists();
return true;
}
@@ -43,9 +44,15 @@ void TensorArrayClearKernel::InitSizeLists() {
}

bool TensorArrayClearKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &, void *) {
const std::vector<AddressPtr> &, void *stream) {
auto handle_addr = GetDeviceAddress<int64_t>(inputs, 0);
TensorArrayPtr tensors_ = TensorArrayMgr::GetInstance().GetTensorArray(handle_addr);
MS_ERROR_IF_NULL(handle_addr);
int64_t handle = 0;
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(&handle, handle_addr, sizeof(int64_t), cudaMemcpyDeviceToHost,
reinterpret_cast<cudaStream_t>(stream)),
"Get handle to host failed");
TensorArrayPtr tensors_ = TensorArrayMgr::GetInstance().GetTensorArray(handle);
MS_ERROR_IF_NULL(tensors_);
// Clear TensorArray valid size, but keep the memory.
tensors_->Clear();


+ 9
- 3
mindspore/ccsrc/backend/kernel_compiler/gpu/rl/tensor_array_close_kernel.cc View File

@@ -44,15 +44,21 @@ void TensorArrayCloseKernel::InitSizeLists() {
}

bool TensorArrayCloseKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &, void *) {
const std::vector<AddressPtr> &, void *stream) {
auto handle_addr = GetDeviceAddress<int64_t>(inputs, 0);
MS_ERROR_IF_NULL(handle_addr);
int64_t handle = 0;
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(&handle, handle_addr, sizeof(int64_t), cudaMemcpyDeviceToHost,
reinterpret_cast<cudaStream_t>(stream)),
"Get handle to host failed");
GPUTensorArrayPtr tensors_ =
std::dynamic_pointer_cast<GPUTensorArray>(TensorArrayMgr::GetInstance().GetTensorArray(handle_addr));
std::dynamic_pointer_cast<GPUTensorArray>(TensorArrayMgr::GetInstance().GetTensorArray(handle));
MS_ERROR_IF_NULL(tensors_);
// Free device mem
tensors_->Free();
// Erase tensorarray
if (!TensorArrayMgr::GetInstance().EraseTensorArray(handle_addr)) {
if (!TensorArrayMgr::GetInstance().EraseTensorArray(handle)) {
MS_LOG(EXCEPTION) << "Free tensorarray failed";
}
return true;


+ 2
- 3
mindspore/ccsrc/backend/kernel_compiler/gpu/rl/tensor_array_create_kernel.cc View File

@@ -64,9 +64,8 @@ bool TensorArrayCreateKernel::Launch(const std::vector<AddressPtr> &, const std:
"Create TensorArray failed");
MS_LOG(DEBUG) << "Create handle id " << tensor_array_handle;
// Put tensorarray to a saved map : map<handle, tensorarray> in tensorarray manager.
// Only put the device addr as the key to avoid a copy from device to host.
// The output handle address will be kept and won't be reused.
TensorArrayMgr::GetInstance().AddTensorArray(out_addr, tensor_array);
// And increase the handle count automatically in AddTensorArray function.
TensorArrayMgr::GetInstance().AddTensorArray(tensor_array_handle, tensor_array);
return true;
}
} // namespace kernel


+ 9
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/rl/tensor_array_read_kernel.cc View File

@@ -52,15 +52,22 @@ void TensorArrayReadKernel::InitSizeLists() {
bool TensorArrayReadKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream) {
auto handle_addr = GetDeviceAddress<int64_t>(inputs, 0);
MS_ERROR_IF_NULL(handle_addr);
auto index = GetDeviceAddress<int64_t>(inputs, 1);
MS_ERROR_IF_NULL(index);
auto out_value = GetDeviceAddress<unsigned char>(outputs, 0);
MS_ERROR_IF_NULL(out_value);
int64_t index_host = 0;
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(&index_host, index, sizeof(int64_t), cudaMemcpyDeviceToHost,
reinterpret_cast<cudaStream_t>(stream)),
"Get index failed");
TensorArrayPtr tensors_ = TensorArrayMgr::GetInstance().GetTensorArray(handle_addr);
"Get index to host failed");
int64_t handle = 0;
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(&handle, handle_addr, sizeof(int64_t), cudaMemcpyDeviceToHost,
reinterpret_cast<cudaStream_t>(stream)),
"Get handle to host failed");
TensorArrayPtr tensors_ = TensorArrayMgr::GetInstance().GetTensorArray(handle);
MS_ERROR_IF_NULL(tensors_);
if (!tensors_->CheckReadIndexLogical(index_host)) {
MS_LOG(EXCEPTION) << "Invalid index " << index_host << " for read.";


+ 8
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/rl/tensor_array_size_kernel.cc View File

@@ -45,7 +45,14 @@ bool TensorArraySizeKernel::Launch(const std::vector<AddressPtr> &inputs, const
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
auto handle_addr = GetDeviceAddress<int64_t>(inputs, 0);
auto out_addr = GetDeviceAddress<int64_t>(outputs, 0);
auto tensors_ = TensorArrayMgr::GetInstance().GetTensorArray(handle_addr);
MS_ERROR_IF_NULL(out_addr);
MS_ERROR_IF_NULL(handle_addr);
int64_t handle = 0;
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(&handle, handle_addr, sizeof(int64_t), cudaMemcpyDeviceToHost,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"Get handle to host failed");
auto tensors_ = TensorArrayMgr::GetInstance().GetTensorArray(handle);
MS_ERROR_IF_NULL(tensors_);
int64_t valid_size = SizeToLong(tensors_->GetValidSize());
MS_LOG(DEBUG) << "Launch TensorArraySize, valid size is " << valid_size;


+ 10
- 3
mindspore/ccsrc/backend/kernel_compiler/gpu/rl/tensor_array_stack_kernel.cc View File

@@ -27,7 +27,7 @@ namespace kernel {
using mindspore::device::TensorArrayMgr;
using mindspore::device::TensorArrayPtr;
TensorArrayStackKernel::TensorArrayStackKernel()
: handle_(nullptr), value_size_(0), ele_size_(0), stream_ptr_(nullptr), type_(nullptr) {
: handle_(0), value_size_(0), ele_size_(0), stream_ptr_(nullptr), type_(nullptr) {
ResetResource();
}

@@ -59,6 +59,7 @@ void TensorArrayStackKernel::PostExecute() {
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_ptr_)),
"TensorArrayStack cudaStreamSynchronized failed");
TensorArrayPtr tensors_ = TensorArrayMgr::GetInstance().GetTensorArray(handle_);
MS_EXCEPTION_IF_NULL(tensors_);
size_t tensor_size = tensors_->GetValidSize();
auto shape = shapes_;
shape.insert(shape.begin(), tensor_size);
@@ -67,7 +68,7 @@ void TensorArrayStackKernel::PostExecute() {
}

void TensorArrayStackKernel::ResetResource() noexcept {
handle_ = nullptr;
handle_ = 0;
value_size_ = 0;
ele_size_ = 0;
stream_ptr_ = nullptr;
@@ -85,10 +86,16 @@ void TensorArrayStackKernel::InitSizeLists() {
bool TensorArrayStackKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
stream_ptr_ = stream_ptr;
handle_ = GetDeviceAddress<int64_t>(inputs, 0);
auto handle_addr = GetDeviceAddress<int64_t>(inputs, 0);
auto out_value = GetDeviceAddress<unsigned char>(outputs, 0);
MS_ERROR_IF_NULL(out_value);
MS_ERROR_IF_NULL(handle_addr);
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(&handle_, handle_addr, sizeof(int64_t), cudaMemcpyDeviceToHost,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"Get handle to host failed");
TensorArrayPtr tensors_ = TensorArrayMgr::GetInstance().GetTensorArray(handle_);
MS_ERROR_IF_NULL(tensors_);
if (tensors_->GetValidSize() > tensors_->GetRealSize()) {
MS_LOG(EXCEPTION) << "Invalid TensorArray size, maybe should Clear() TensorArray before next usage.";
}


+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/rl/tensor_array_stack_kernel.h View File

@@ -43,7 +43,7 @@ class TensorArrayStackKernel : public GpuKernel {
void InitSizeLists() override;

private:
int64_t *handle_;
int64_t handle_;
int64_t value_size_;
int64_t ele_size_;
void *stream_ptr_;


+ 9
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/rl/tensor_array_write_kernel.cc View File

@@ -56,14 +56,21 @@ bool TensorArrayWriteKernel::Launch(const std::vector<AddressPtr> &inputs, const
auto handle_addr = GetDeviceAddress<int64_t>(inputs, 0);
auto index = GetDeviceAddress<int64_t>(inputs, 1);
auto value = GetDeviceAddress<unsigned char>(inputs, 2);

MS_ERROR_IF_NULL(handle_addr);
MS_ERROR_IF_NULL(index);
MS_ERROR_IF_NULL(value);
int64_t index_host = 0;
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(&index_host, index, sizeof(int64_t), cudaMemcpyDeviceToHost,
reinterpret_cast<cudaStream_t>(stream)),
"Get index failed");
int64_t handle = 0;
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(&handle, handle_addr, sizeof(int64_t), cudaMemcpyDeviceToHost,
reinterpret_cast<cudaStream_t>(stream)),
"Get handle to host failed");
GPUTensorArrayPtr tensors_ =
std::dynamic_pointer_cast<GPUTensorArray>(TensorArrayMgr::GetInstance().GetTensorArray(handle_addr));
std::dynamic_pointer_cast<GPUTensorArray>(TensorArrayMgr::GetInstance().GetTensorArray(handle));
MS_EXCEPTION_IF_NULL(tensors_);
if (!tensors_->CheckValue(type_, shapes_)) {
MS_LOG(EXCEPTION) << "Invalid input data for tensor array write op.";


+ 6
- 6
mindspore/ccsrc/runtime/device/tensor_array_manager.h View File

@@ -42,14 +42,14 @@ class TensorArrayMgr {
TensorArrayMgr(const TensorArrayMgr &) = delete;
TensorArrayMgr(const TensorArrayMgr &&) = delete;

void AddTensorArray(const int64_t *handle, const TensorArrayPtr &ta) {
MS_LOG(DEBUG) << "Add a TensorArray to map, handle addr is " << handle;
void AddTensorArray(const int64_t handle, const TensorArrayPtr &ta) {
MS_LOG(DEBUG) << "Add a TensorArray to map, handle is " << handle;
tensors_map_.emplace(std::make_pair(handle, ta));
// Increase handle count when added a TensorArray.
tensor_array_handle_count += 1;
}

TensorArrayPtr GetTensorArray(const int64_t *handle) {
TensorArrayPtr GetTensorArray(const int64_t handle) {
if (!tensors_map_.count(handle)) {
MS_LOG(EXCEPTION) << "Error handle [" << handle << "] to get tensorarray";
} else {
@@ -58,7 +58,7 @@ class TensorArrayMgr {
}
}

bool EraseTensorArray(const int64_t *handle) {
bool EraseTensorArray(const int64_t handle) {
if (tensors_map_.count(handle)) {
MS_LOG(DEBUG) << "Erase tensorarray from map, handle number is " << handle;
tensors_map_.erase(handle);
@@ -72,8 +72,8 @@ class TensorArrayMgr {
int64_t GetHandleCount() const { return tensor_array_handle_count; }

private:
// Store the TensorArrays in a map, as pair(handle_addr, TensorArrayPtr).
std::map<const int64_t *, TensorArrayPtr> tensors_map_;
// Store the TensorArrays in a map, as pair(handle, TensorArrayPtr).
std::map<const int64_t, TensorArrayPtr> tensors_map_;
// Used as an unique handle number for each TensorArray.
std::atomic<int64_t> tensor_array_handle_count{0};
};


Loading…
Cancel
Save