Merge pull request !31531 from chenweifeng/aicpu-priority-replay-buffer-kernelr1.7
| @@ -30,6 +30,8 @@ if(EXISTS ${CMAKE_C_COMPILER} AND EXISTS ${CMAKE_CXX_COMPILER}) | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/environ/environ_set.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/environ/environ_get.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/environ/environ_destroy_all.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/replay_buffer/fifo_replay_buffer.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/replay_buffer/priority_replay_buffer.cc | |||
| ) | |||
| add_library(mindspore_aicpu_kernels SHARED | |||
| @@ -0,0 +1,93 @@ | |||
| /** | |||
| * Copyright 2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "replay_buffer/fifo_replay_buffer.h" | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include "securec/include/securec.h" | |||
| #include "common/kernel_log.h" | |||
| namespace aicpu { | |||
| FIFOReplayBuffer::FIFOReplayBuffer(size_t capacity, const std::vector<size_t> &schema) | |||
| : capacity_(capacity), head_(-1), size_(0), schema_(schema) { | |||
| for (const auto &size : schema) { | |||
| size_t alloc_size = size * capacity; | |||
| if (alloc_size == 0) { | |||
| AICPU_LOGW("Malloc size can not be 0."); | |||
| return; | |||
| } | |||
| void *ptr = malloc(alloc_size); | |||
| AddressPtr item = std::make_shared<Address>(ptr, alloc_size); | |||
| buffer_.emplace_back(item); | |||
| } | |||
| } | |||
| FIFOReplayBuffer::~FIFOReplayBuffer() { | |||
| for (const auto &item : buffer_) { | |||
| free(item->addr); | |||
| item->addr = nullptr; | |||
| } | |||
| } | |||
| bool FIFOReplayBuffer::Push(const std::vector<AddressPtr> &inputs) { | |||
| if (inputs.size() != schema_.size()) { | |||
| AICPU_LOGE("Transition element num error. Expect %u, but got %u.", schema_.size(), inputs.size()); | |||
| } | |||
| // Head point to the latest item. | |||
| head_ = head_ >= capacity_ ? 0 : head_ + 1; | |||
| size_ = size_ >= capacity_ ? capacity_ : size_ + 1; | |||
| for (size_t i = 0; i < inputs.size(); i++) { | |||
| void *offset = reinterpret_cast<uint8_t *>(buffer_[i]->addr) + head_ * schema_[i]; | |||
| auto ret = memcpy_s(offset, buffer_[i]->size, inputs[i]->addr, inputs[i]->size); | |||
| if (ret != EOK) { | |||
| AICPU_LOGE("memcpy_s() failed. Error code: %d.", ret); | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| std::vector<AddressPtr> FIFOReplayBuffer::GetItem(size_t idx) { | |||
| if (idx >= capacity_ || idx >= size_) { | |||
| AICPU_LOGE("Idex: %u out of range %u.", idx, std::min(capacity_, size_)); | |||
| } | |||
| std::vector<AddressPtr> ret; | |||
| for (size_t i = 0; i < schema_.size(); i++) { | |||
| void *offset = reinterpret_cast<uint8_t *>(buffer_[i]->addr) + schema_[i] * idx; | |||
| ret.push_back(std::make_shared<Address>(offset, schema_[i])); | |||
| } | |||
| return ret; | |||
| } | |||
| std::vector<std::vector<AddressPtr>> FIFOReplayBuffer::GetItems(const std::vector<size_t> &indices) { | |||
| std::vector<std::vector<AddressPtr>> ret; | |||
| for (const auto &idx : indices) { | |||
| auto item = GetItem(idx); | |||
| (void)ret.emplace_back(item); | |||
| } | |||
| return ret; | |||
| } | |||
// Expose the raw column buffers directly (whole capacity, not only the valid transitions).
const std::vector<AddressPtr> &FIFOReplayBuffer::GetAll() const { return buffer_; }
| } // namespace aicpu | |||
| @@ -0,0 +1,68 @@ | |||
| /** | |||
| * Copyright 2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_FIFO_REPLAY_BUFFER_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_FIFO_REPLAY_BUFFER_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| namespace aicpu { | |||
// A contiguous memory region described by its base pointer and byte length.
// The struct does not own the memory it points at.
struct Address {
  Address() = default;
  Address(void *address_addr, size_t address_size) : addr(address_addr), size(address_size) {}
  void *addr{nullptr};  // Base pointer of the region.
  size_t size{0};       // Region length in bytes.
};
// Shared handle to an Address descriptor (the pointed-to memory is managed elsewhere).
using AddressPtr = std::shared_ptr<Address>;
// The FIFOReplayBuffer is container storing experiences.
// It lets the reinforcement learning agents remember and reuse experiences from the past.
// When the replay buffer is full, the oldest transition will be overridden.
// Layout: one flat column buffer per schema entry, each holding `capacity` fixed-size slots.
class FIFOReplayBuffer {
 public:
  // Construct a fixed-length FIFO replay buffer.
  // capacity: maximum number of transitions retained.
  // schema: byte size of each element in one transition.
  FIFOReplayBuffer(size_t capacity, const std::vector<size_t> &schema);
  ~FIFOReplayBuffer();
  // Push a transition to replay buffer. If the replay buffer is full, the oldest one will be overridden.
  bool Push(const std::vector<AddressPtr> &inputs);
  // Get a transition by the index. The returned addresses alias internal storage (no copy).
  std::vector<AddressPtr> GetItem(size_t idx);
  // Get transitions by the indices.
  std::vector<std::vector<AddressPtr>> GetItems(const std::vector<size_t> &indices);
  // Get all transitions.
  const std::vector<AddressPtr> &GetAll() const;
  // Return the latest transition index. It returns -1 if the replay buffer is empty.
  // NOTE: head_ is size_t, so the "-1" is actually size_t(-1) (SIZE_MAX).
  size_t head() const { return head_; }
  // Return the valid transitions number.
  size_t size() const { return size_; }
 protected:
  size_t capacity_;               // Maximum number of transitions.
  size_t head_;                   // Slot of the most recently pushed transition.
  size_t size_;                   // Number of valid transitions (<= capacity_).
  std::vector<AddressPtr> buffer_;  // One flat buffer per schema column.
  std::vector<size_t> schema_;    // Per-column element byte sizes.
};
| } // namespace aicpu | |||
| #endif | |||
| @@ -0,0 +1,128 @@ | |||
| /** | |||
| * Copyright 2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
#include "replay_buffer/priority_replay_buffer.h"
#include <algorithm>
#include <cmath>
#include <limits>
#include <memory>
#include <tuple>
#include <vector>
#include "common/kernel_log.h"
| namespace aicpu { | |||
// Build a segment tree whose leaves hold per-transition priority aggregates.
PriorityTree::PriorityTree(size_t capacity, const PriorityItem &init_value)
    : SegmentTree<PriorityItem>(capacity, init_value) {}
| PriorityItem PriorityTree::ReduceOp(const PriorityItem &lhs, const PriorityItem &rhs) { | |||
| return PriorityItem(lhs.sum_priority + rhs.sum_priority, std::min(lhs.min_priority, rhs.min_priority)); | |||
| } | |||
// Walk from the root down to the leaf whose cumulative priority first covers
// `prefix_sum`, and return that leaf's transition index.
size_t PriorityTree::GetPrefixSumIdx(float prefix_sum) const {
  size_t idx = 1;  // Slot 1 is the root of the 1-based tree.
  while (idx < capacity_) {
    // The left child covers the leading portion of this node's interval.
    const auto &left_priority = buffer_[kNumSubnodes * idx].sum_priority;
    if (prefix_sum <= left_priority) {
      idx = kNumSubnodes * idx;
    } else {
      // Consume the left subtree's mass and continue into the right child.
      prefix_sum -= left_priority;
      idx = kNumSubnodes * idx + kRightOffset;
    }
  }
  // Leaves occupy slots [capacity_, 2 * capacity_); translate to a transition index.
  return idx - capacity_;
}
| PriorityReplayBuffer::PriorityReplayBuffer(uint32_t seed, float alpha, float beta, size_t capacity, | |||
| const std::vector<size_t> &schema) | |||
| : alpha_(alpha), beta_(beta), capacity_(capacity), max_priority_(1.0), schema_(schema) { | |||
| random_engine_.seed(seed); | |||
| fifo_replay_buffer_ = std::make_unique<FIFOReplayBuffer>(capacity, schema); | |||
| priority_tree_ = std::make_unique<PriorityTree>(capacity); | |||
| } | |||
| bool PriorityReplayBuffer::Push(const std::vector<AddressPtr> &items) { | |||
| (void)fifo_replay_buffer_->Push(items); | |||
| auto idx = fifo_replay_buffer_->head(); | |||
| // Set max priority for the newest item. | |||
| priority_tree_->Insert(idx, {max_priority_, max_priority_}); | |||
| return true; | |||
| } | |||
| bool PriorityReplayBuffer::UpdatePriorities(const std::vector<size_t> &indices, const std::vector<float> &priorities) { | |||
| if (indices.size() != priorities.size()) { | |||
| return false; | |||
| } | |||
| for (size_t i = 0; i < indices.size(); i++) { | |||
| float priority = static_cast<float>(pow(priorities[i], alpha_)); | |||
| if (priority <= 0.0f) { | |||
| AICPU_LOGW("The sum priority is %f. It may leads to converge issue."); | |||
| priority = std::numeric_limits<decltype(priority)>::epsilon(); | |||
| } | |||
| priority_tree_->Insert(indices[i], {priority, priority}); | |||
| // Record max priority of transitions | |||
| max_priority_ = std::max(max_priority_, priority); | |||
| } | |||
| return true; | |||
| } | |||
| std::tuple<std::vector<size_t>, std::vector<float>, std::vector<std::vector<AddressPtr>>> PriorityReplayBuffer::Sample( | |||
| size_t batch_size) { | |||
| if (batch_size == 0) { | |||
| AICPU_LOGD("The batch size can not be zero."); | |||
| } | |||
| const PriorityItem &root = priority_tree_->Root(); | |||
| float sum_priority = root.sum_priority; | |||
| float min_priority = root.min_priority; | |||
| size_t size = fifo_replay_buffer_->size(); | |||
| float max_weight = Weight(min_priority, sum_priority, size); | |||
| float segment_len = root.sum_priority / batch_size; | |||
| std::vector<size_t> indices; | |||
| std::vector<float> weights; | |||
| std::vector<std::vector<AddressPtr>> items; | |||
| for (size_t i = 0; i < batch_size; i++) { | |||
| float mass = (dist_(random_engine_) + i) * segment_len; | |||
| size_t idx = priority_tree_->GetPrefixSumIdx(mass); | |||
| (void)indices.emplace_back(idx); | |||
| float priority = priority_tree_->GetByIndex(idx).sum_priority; | |||
| if (max_weight <= 0.0f) { | |||
| AICPU_LOGW("The sum priority is %f. It may leads to converge issue."); | |||
| max_weight = std::numeric_limits<decltype(max_weight)>::epsilon(); | |||
| } | |||
| (void)weights.emplace_back(Weight(priority, sum_priority, size) / max_weight); | |||
| (void)items.emplace_back(fifo_replay_buffer_->GetItem(idx)); | |||
| } | |||
| return std::forward_as_tuple(indices, weights, items); | |||
| } | |||
| inline float PriorityReplayBuffer::Weight(float priority, float sum_priority, size_t size) const { | |||
| if (sum_priority <= 0.0f) { | |||
| AICPU_LOGW("The sum priority is %f. It may leads to converge issue."); | |||
| sum_priority = std::numeric_limits<decltype(sum_priority)>::epsilon(); | |||
| } | |||
| float sample_prob = priority / sum_priority; | |||
| float weight = static_cast<float>(pow(sample_prob * size, -beta_)); | |||
| return weight; | |||
| } | |||
| } // namespace aicpu | |||
| @@ -0,0 +1,82 @@ | |||
| /** | |||
| * Copyright 2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_PRIORITY_REPLAY_BUFFER_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_PRIORITY_REPLAY_BUFFER_H_ | |||
| #include <vector> | |||
| #include <tuple> | |||
| #include <memory> | |||
| #include <limits> | |||
| #include <random> | |||
| #include "replay_buffer/fifo_replay_buffer.h" | |||
| #include "replay_buffer/segment_tree.h" | |||
| namespace aicpu { | |||
// Node value of PriorityTree: the subtree's total priority and its smallest
// leaf priority. The defaults form the identity element of ReduceOp
// (zero sum, maximal minimum).
struct PriorityItem {
  PriorityItem() = default;
  PriorityItem(float sum, float min) : sum_priority(sum), min_priority(min) {}
  float sum_priority{0};
  float min_priority{std::numeric_limits<float>::max()};
};
// PriorityTree is tree which the value of node contains sum and minimal priority of its subnodes.
// The sum supports prefix-sum (proportional) sampling; the minimum yields the max weight for
// importance-sampling normalization.
class PriorityTree : public SegmentTree<PriorityItem> {
 public:
  // Build a tree of `capacity` leaves, all initialized to init_value.
  explicit PriorityTree(size_t capacity, const PriorityItem &init_value = PriorityItem());
  // Calculate sum and minimal priority of its subnodes.
  PriorityItem ReduceOp(const PriorityItem &lhs, const PriorityItem &rhs) override;
  // Find the minimal index greater than prefix_sum.
  size_t GetPrefixSumIdx(float prefix_sum) const;
};
// PriorityReplayBuffer is experience container used in Deep Q-Networks.
// The algorithm is proposed in `Prioritized Experience Replay <https://arxiv.org/abs/1511.05952>`.
// Same as the normal replay buffer, it lets the reinforcement learning agents remember and reuse experiences from the
// past. Besides, it replays important transitions more frequently and improve sample efficiency.
class PriorityReplayBuffer {
 public:
  // Construct a fixed-length priority replay buffer.
  // seed: seed for the sampling RNG; alpha: priority exponent; beta: bias-correction exponent;
  // capacity: maximum number of transitions; schema: byte size of each transition element.
  PriorityReplayBuffer(uint32_t seed, float alpha, float beta, size_t capacity, const std::vector<size_t> &schema);
  // Push an experience transition to the buffer which will be given the highest priority.
  bool Push(const std::vector<AddressPtr> &items);
  // Sample a batch transitions with indices and bias correction weights.
  std::tuple<std::vector<size_t>, std::vector<float>, std::vector<std::vector<AddressPtr>>> Sample(size_t batch_size);
  // Update experience transitions priorities.
  bool UpdatePriorities(const std::vector<size_t> &indices, const std::vector<float> &priorities);
 private:
  // Importance-sampling weight of one transition, before normalization.
  float Weight(float priority, float sum_priority, size_t size) const;
  float alpha_;         // Exponent applied to raw priorities.
  float beta_;          // Exponent of the bias-correction weights.
  size_t capacity_;     // Maximum number of transitions.
  float max_priority_;  // Running maximum priority; assigned to new pushes.
  std::vector<size_t> schema_;  // Per-column transition element byte sizes.
  std::default_random_engine random_engine_;
  std::uniform_real_distribution<float> dist_{0, 1};
  std::unique_ptr<FIFOReplayBuffer> fifo_replay_buffer_;  // Underlying transition storage.
  std::unique_ptr<PriorityTree> priority_tree_;           // Per-transition priority aggregates.
};
| } // namespace aicpu | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_PRIORITY_REPLAY_BUFFER_H_ | |||
| @@ -0,0 +1,86 @@ | |||
| /** | |||
| * Copyright 2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_SEGMENT_TREE_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_SEGMENT_TREE_H_ | |||
| #include <vector> | |||
| #include "common/kernel_log.h" | |||
| namespace aicpu { | |||
| constexpr size_t kRootIndex = 1; | |||
| constexpr size_t kNumSubnodes = 2; | |||
| constexpr size_t kRightOffset = 1; | |||
| // SegmentTree is a binary tree use for storing intervals information. | |||
| // It allows querying stored intervals by given point with complex O(logN). | |||
| // It is constructed from a fixed-lengh array for performance. | |||
| // The intervals information are templated as type T. User could override | |||
| // the ReduceOp() method for generic perpose. | |||
| template <typename T> | |||
| class SegmentTree { | |||
| public: | |||
| // Construct fixed-length segment tree. | |||
| SegmentTree(size_t capacity, const T &init_value) { | |||
| size_t capacity_pow_two = 1; | |||
| while (capacity_pow_two < capacity) { | |||
| capacity_pow_two *= kNumSubnodes; | |||
| } | |||
| capacity_ = capacity_pow_two; | |||
| buffer_.resize(capacity_ * kNumSubnodes, init_value); | |||
| } | |||
| virtual ~SegmentTree() = default; | |||
| // Insert item to the segment tree. | |||
| void Insert(size_t idx, const T &value) { | |||
| if (idx >= capacity_) { | |||
| AICPU_LOGE("The index %d out of range %d.", idx, capacity_); | |||
| } | |||
| // Update leaf node value. | |||
| idx += capacity_; | |||
| buffer_[idx] = value; | |||
| // Update non-leaf node value. | |||
| idx /= kNumSubnodes; | |||
| while (idx >= kRootIndex) { | |||
| buffer_[idx] = ReduceOp(buffer_[kNumSubnodes * idx], buffer_[kNumSubnodes * idx + kRightOffset]); | |||
| idx /= kNumSubnodes; | |||
| } | |||
| } | |||
| // Get root node information. | |||
| const T &Root() { return buffer_[kRootIndex]; } | |||
| // Get leaf node information with index. | |||
| const T &GetByIndex(size_t idx) { | |||
| if (idx >= capacity_) { | |||
| AICPU_LOGE("The index %d out of range %d.", idx, capacity_); | |||
| } | |||
| return buffer_[idx + capacity_]; | |||
| } | |||
| // Reduce method for non leaf node. Subclass should override it for general perpose. | |||
| virtual T ReduceOp(const T &lhs, const T &rhs) = 0; | |||
| protected: | |||
| size_t capacity_; | |||
| std::vector<T> buffer_; | |||
| }; | |||
| } // namespace aicpu | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_SEGMENT_TREE_H_ | |||