|
|
|
@@ -131,15 +131,18 @@ void *AscendPsCache::MallocMemory(size_t size) { |
|
|
|
return device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(size); |
|
|
|
} |
|
|
|
|
|
|
|
bool AscendPsCache::MallocConstantMemory(size_t constant_value) { |
|
|
|
bool AscendPsCache::MallocConstantMemory(size_t cache_vocab_size) { |
|
|
|
offset_addr_ = reinterpret_cast<int *>(device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(sizeof(int))); |
|
|
|
MS_ERROR_IF_NULL(offset_addr_); |
|
|
|
rtMemset(offset_addr_, sizeof(int), 0, sizeof(int)); |
|
|
|
cache_vocab_size_addr_ = |
|
|
|
reinterpret_cast<int *>(device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(sizeof(int))); |
|
|
|
MS_ERROR_IF_NULL(cache_vocab_size_addr_); |
|
|
|
rtMemset(cache_vocab_size_addr_, sizeof(int), constant_value, sizeof(int)); |
|
|
|
return true; |
|
|
|
int copy_value = SizeToInt(cache_vocab_size); |
|
|
|
if (!CopyHostMemToDevice(cache_vocab_size_addr_, ©_value, sizeof(int))) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
return SynchronizeStream(); |
|
|
|
} |
|
|
|
|
|
|
|
bool AscendPsCache::RecordEvent() { |
|
|
|
|