|
|
|
@@ -27,7 +27,7 @@ namespace kernel { |
|
|
|
using mindspore::device::TensorArrayMgr; |
|
|
|
using mindspore::device::TensorArrayPtr; |
|
|
|
// Default constructor: initializes handle_, value_size_, ele_size_, stream_ptr_ and type_
// in the member-initializer list, then calls ResetResource().
// NOTE(review): this text looks like a diff hunk (`@@ -27,7 +27,7 @@`) whose +/- markers
// were stripped, so the two initializer lists below are presumably the removed
// (`handle_(nullptr)`) and the added (`handle_(0)`) version of the same line -- confirm
// against the real source file before treating both as live code.
TensorArrayStackKernel::TensorArrayStackKernel() |
|
|
|
: handle_(nullptr), value_size_(0), ele_size_(0), stream_ptr_(nullptr), type_(nullptr) { |
|
|
|
: handle_(0), value_size_(0), ele_size_(0), stream_ptr_(nullptr), type_(nullptr) { |
|
|
|
ResetResource(); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -59,6 +59,7 @@ void TensorArrayStackKernel::PostExecute() { |
|
|
|
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_ptr_)), |
|
|
|
"TensorArrayStack cudaStreamSynchronized failed"); |
|
|
|
TensorArrayPtr tensors_ = TensorArrayMgr::GetInstance().GetTensorArray(handle_); |
|
|
|
MS_EXCEPTION_IF_NULL(tensors_); |
|
|
|
size_t tensor_size = tensors_->GetValidSize(); |
|
|
|
auto shape = shapes_; |
|
|
|
shape.insert(shape.begin(), tensor_size); |
|
|
|
@@ -67,7 +68,7 @@ void TensorArrayStackKernel::PostExecute() { |
|
|
|
} |
|
|
|
|
|
|
|
void TensorArrayStackKernel::ResetResource() noexcept { |
|
|
|
handle_ = nullptr; |
|
|
|
handle_ = 0; |
|
|
|
value_size_ = 0; |
|
|
|
ele_size_ = 0; |
|
|
|
stream_ptr_ = nullptr; |
|
|
|
@@ -85,10 +86,16 @@ void TensorArrayStackKernel::InitSizeLists() { |
|
|
|
bool TensorArrayStackKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, |
|
|
|
const std::vector<AddressPtr> &outputs, void *stream_ptr) { |
|
|
|
stream_ptr_ = stream_ptr; |
|
|
|
handle_ = GetDeviceAddress<int64_t>(inputs, 0); |
|
|
|
auto handle_addr = GetDeviceAddress<int64_t>(inputs, 0); |
|
|
|
auto out_value = GetDeviceAddress<unsigned char>(outputs, 0); |
|
|
|
MS_ERROR_IF_NULL(out_value); |
|
|
|
MS_ERROR_IF_NULL(handle_addr); |
|
|
|
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, |
|
|
|
cudaMemcpyAsync(&handle_, handle_addr, sizeof(int64_t), cudaMemcpyDeviceToHost, |
|
|
|
reinterpret_cast<cudaStream_t>(stream_ptr)), |
|
|
|
"Get handle to host failed"); |
|
|
|
TensorArrayPtr tensors_ = TensorArrayMgr::GetInstance().GetTensorArray(handle_); |
|
|
|
MS_ERROR_IF_NULL(tensors_); |
|
|
|
if (tensors_->GetValidSize() > tensors_->GetRealSize()) { |
|
|
|
MS_LOG(EXCEPTION) << "Invalid TensorArray size, maybe should Clear() TensorArray before next usage."; |
|
|
|
} |
|
|
|
|