Browse Source

code clean

r1.7
wilfChen 4 years ago
parent
commit
5da148b523
8 changed files with 30 additions and 31 deletions
  1. +5
    -5
      mindspore/ccsrc/plugin/device/cpu/kernel/rl/fifo_replay_buffer.cc
  2. +3
    -3
      mindspore/ccsrc/plugin/device/cpu/kernel/rl/fifo_replay_buffer.h
  3. +10
    -10
      mindspore/ccsrc/plugin/device/cpu/kernel/rl/priority_replay_buffer.cc
  4. +3
    -3
      mindspore/ccsrc/plugin/device/cpu/kernel/rl/priority_replay_buffer.h
  5. +7
    -7
      mindspore/ccsrc/plugin/device/cpu/kernel/rl/priority_replay_buffer_cpu_kernel.cc
  6. +1
    -1
      mindspore/ccsrc/plugin/device/cpu/kernel/rl/priority_replay_buffer_cpu_kernel.h
  7. +1
    -1
      mindspore/ccsrc/plugin/device/cpu/kernel/rl/replay_buffer_factory.h
  8. +0
    -1
      mindspore/ccsrc/plugin/device/cpu/kernel/rl/segment_tree.h

+ 5
- 5
mindspore/ccsrc/plugin/device/cpu/kernel/rl/fifo_replay_buffer.cc View File

@@ -36,7 +36,7 @@ FIFOReplayBuffer::FIFOReplayBuffer(size_t capacity, const std::vector<size_t> &s

void *ptr = device::cpu::CPUMemoryPool::GetInstance().AllocTensorMem(alloc_size);
AddressPtr item = std::make_shared<Address>(ptr, alloc_size);
buffer_.emplace_back(item);
(void)buffer_.emplace_back(item);
}
}

@@ -57,7 +57,7 @@ bool FIFOReplayBuffer::Push(const std::vector<AddressPtr> &inputs) {
size_ = size_ >= capacity_ ? capacity_ : size_ + 1;

for (size_t i = 0; i < inputs.size(); i++) {
void *offset = reinterpret_cast<char *>(buffer_[i]->addr) + head_ * schema_[i];
void *offset = reinterpret_cast<uint8_t *>(buffer_[i]->addr) + head_ * schema_[i];
auto ret = memcpy_s(offset, buffer_[i]->size, inputs[i]->addr, inputs[i]->size);
if (ret != EOK) {
MS_LOG(EXCEPTION) << "memcpy_s() failed. Error code: " << ret;
@@ -74,7 +74,7 @@ std::vector<AddressPtr> FIFOReplayBuffer::GetItem(size_t idx) {

std::vector<AddressPtr> ret;
for (size_t i = 0; i < schema_.size(); i++) {
void *offset = reinterpret_cast<char *>(buffer_[i]->addr) + schema_[i] * idx;
void *offset = reinterpret_cast<uint8_t *>(buffer_[i]->addr) + schema_[i] * idx;
ret.push_back(std::make_shared<Address>(offset, schema_[i]));
}

@@ -85,12 +85,12 @@ std::vector<std::vector<AddressPtr>> FIFOReplayBuffer::GetItems(const std::vecto
std::vector<std::vector<AddressPtr>> ret;
for (const auto &idx : indices) {
auto item = GetItem(idx);
ret.emplace_back(item);
(void)ret.emplace_back(item);
}

return ret;
}

const std::vector<AddressPtr> &FIFOReplayBuffer::GetAll() { return buffer_; }
const std::vector<AddressPtr> &FIFOReplayBuffer::GetAll() const { return buffer_; }
} // namespace kernel
} // namespace mindspore

+ 3
- 3
mindspore/ccsrc/plugin/device/cpu/kernel/rl/fifo_replay_buffer.h View File

@@ -42,13 +42,13 @@ class FIFOReplayBuffer {
std::vector<std::vector<AddressPtr>> GetItems(const std::vector<size_t> &indices);

// Get all transitions.
const std::vector<AddressPtr> &GetAll();
const std::vector<AddressPtr> &GetAll() const;

// Return the latest transition index. It returns -1 if the replay buffer is empty.
size_t head() { return head_; }
size_t head() const { return head_; }

// Return the number of valid transitions.
size_t size() { return size_; }
size_t size() const { return size_; }

protected:
size_t capacity_;


+ 10
- 10
mindspore/ccsrc/plugin/device/cpu/kernel/rl/priority_replay_buffer.cc View File

@@ -33,7 +33,7 @@ PriorityItem PriorityTree::ReduceOp(const PriorityItem &lhs, const PriorityItem
return PriorityItem(lhs.sum_priority + rhs.sum_priority, std::min(lhs.min_priority, rhs.min_priority));
}

size_t PriorityTree::GetPrefixSumIdx(float prefix_sum) {
size_t PriorityTree::GetPrefixSumIdx(float prefix_sum) const {
size_t idx = 1;
while (idx < capacity_) {
const auto &left_priority = buffer_[kNumSubnodes * idx].sum_priority;
@@ -48,7 +48,7 @@ size_t PriorityTree::GetPrefixSumIdx(float prefix_sum) {
return idx - capacity_;
}

PriorityReplayBuffer::PriorityReplayBuffer(int seed, float alpha, float beta, size_t capacity,
PriorityReplayBuffer::PriorityReplayBuffer(uint32_t seed, float alpha, float beta, size_t capacity,
const std::vector<size_t> &schema)
: alpha_(alpha), beta_(beta), capacity_(capacity), max_priority_(1.0), schema_(schema) {
random_engine_.seed(seed);
@@ -57,7 +57,7 @@ PriorityReplayBuffer::PriorityReplayBuffer(int seed, float alpha, float beta, si
}

bool PriorityReplayBuffer::Push(const std::vector<AddressPtr> &items) {
fifo_replay_buffer_->Push(items);
(void)fifo_replay_buffer_->Push(items);
auto idx = fifo_replay_buffer_->head();

// Set max priority for the newest item.
@@ -71,7 +71,7 @@ bool PriorityReplayBuffer::UpdatePriorities(const std::vector<size_t> &indices,
}

for (size_t i = 0; i < indices.size(); i++) {
float priority = pow(priorities[i], alpha_);
float priority = static_cast<float>(pow(priorities[i], alpha_));
if (priority <= 0.0f) {
MS_LOG(WARNING) << "The priority is " << priority << ". It may lead to converge issue.";
priority = kMinPriority;
@@ -91,7 +91,7 @@ std::tuple<std::vector<size_t>, std::vector<float>, std::vector<std::vector<Addr
const PriorityItem &root = priority_tree_->Root();
float sum_priority = root.sum_priority;
float min_priority = root.min_priority;
float size = fifo_replay_buffer_->size();
size_t size = fifo_replay_buffer_->size();
float max_weight = Weight(min_priority, sum_priority, size);
float segment_len = root.sum_priority / batch_size;

@@ -102,27 +102,27 @@ std::tuple<std::vector<size_t>, std::vector<float>, std::vector<std::vector<Addr
float mass = (dist_(random_engine_) + i) * segment_len;
size_t idx = priority_tree_->GetPrefixSumIdx(mass);

indices.emplace_back(idx);
(void)indices.emplace_back(idx);
float priority = priority_tree_->GetByIndex(idx).sum_priority;

if (max_weight <= 0.0f) {
MS_LOG(WARNING) << "The max priority is " << max_weight << ". It may leads to converge issue.";
max_weight = kMinPriority;
}
weights.emplace_back(Weight(priority, sum_priority, size) / max_weight);
items.emplace_back(fifo_replay_buffer_->GetItem(idx));
(void)weights.emplace_back(Weight(priority, sum_priority, size) / max_weight);
(void)items.emplace_back(fifo_replay_buffer_->GetItem(idx));
}

return std::forward_as_tuple(indices, weights, items);
}

float PriorityReplayBuffer::Weight(float priority, float sum_priority, size_t size) {
inline float PriorityReplayBuffer::Weight(float priority, float sum_priority, size_t size) const {
if (sum_priority <= 0.0f) {
MS_LOG(WARNING) << "The sum priority is " << sum_priority << ". It may leads to converge issue.";
sum_priority = kMinPriority;
}
float sample_prob = priority / sum_priority;
float weight = pow(sample_prob * size, -beta_);
float weight = static_cast<float>(pow(sample_prob * size, -beta_));
return weight;
}
} // namespace kernel


+ 3
- 3
mindspore/ccsrc/plugin/device/cpu/kernel/rl/priority_replay_buffer.h View File

@@ -47,7 +47,7 @@ class PriorityTree : public SegmentTree<PriorityItem> {
PriorityItem ReduceOp(const PriorityItem &lhs, const PriorityItem &rhs) override;

// Find the minimal index whose prefix sum is greater than prefix_sum.
size_t GetPrefixSumIdx(float prefix_sum);
size_t GetPrefixSumIdx(float prefix_sum) const;
};

// PriorityReplayBuffer is an experience container used in Deep Q-Networks.
@@ -57,7 +57,7 @@ class PriorityTree : public SegmentTree<PriorityItem> {
class PriorityReplayBuffer {
public:
// Construct a fixed-length priority replay buffer.
PriorityReplayBuffer(int seed, float alpha, float beta, size_t capacity, const std::vector<size_t> &schema);
PriorityReplayBuffer(uint32_t seed, float alpha, float beta, size_t capacity, const std::vector<size_t> &schema);

// Push an experience transition to the buffer; the new transition is given the highest priority.
bool Push(const std::vector<AddressPtr> &items);
@@ -69,7 +69,7 @@ class PriorityReplayBuffer {
bool UpdatePriorities(const std::vector<size_t> &indices, const std::vector<float> &priorities);

private:
inline float Weight(float priority, float sum_priority, size_t size);
float Weight(float priority, float sum_priority, size_t size) const;

float alpha_;
float beta_;


+ 7
- 7
mindspore/ccsrc/plugin/device/cpu/kernel/rl/priority_replay_buffer_cpu_kernel.cc View File

@@ -42,7 +42,7 @@ void PriorityReplayBufferCreateCpuKernel::InitKernel(const CNodePtr &kernel_node
MS_EXCEPTION_IF_CHECK_FAIL(dtypes.size() == shapes.size(), "The dtype and shapes should be same.");
std::vector<size_t> schema;
for (size_t i = 0; i < shapes.size(); i++) {
size_t num_element = std::accumulate(shapes[i].begin(), shapes[i].end(), 1, std::multiplies<size_t>());
size_t num_element = std::accumulate(shapes[i].begin(), shapes[i].end(), 1ULL, std::multiplies<size_t>());
size_t type_size = GetTypeByte(dtypes[i]);
schema.push_back(num_element * type_size);
}
@@ -50,9 +50,9 @@ void PriorityReplayBufferCreateCpuKernel::InitKernel(const CNodePtr &kernel_node
unsigned int seed = 0;
std::random_device rd;
if (seed1 != 0) {
seed = IntToUint(seed1);
seed = static_cast<unsigned int>(seed1);
} else if (seed0 != 0) {
seed = IntToUint(seed0);
seed = static_cast<unsigned int>(seed0);
} else {
seed = rd();
}
@@ -77,7 +77,7 @@ void PriorityReplayBufferPushCpuKernel::InitKernel(const CNodePtr &kernel_node)

bool PriorityReplayBufferPushCpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
prioriory_replay_buffer_->Push(inputs);
(void)prioriory_replay_buffer_->Push(inputs);

// Return a placeholder so that this node is not removed by dead code elimination.
auto handle = GetDeviceAddress<int64_t>(outputs, 0);
@@ -87,14 +87,14 @@ bool PriorityReplayBufferPushCpuKernel::Launch(const std::vector<AddressPtr> &in

void PriorityReplayBufferSampleCpuKernel::InitKernel(const CNodePtr &kernel_node) {
handle_ = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "handle");
batch_size_ = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "batch_size");
batch_size_ = LongToSize(common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "batch_size"));
const auto &dtypes = common::AnfAlgo::GetNodeAttr<std::vector<TypePtr>>(kernel_node, "dtypes");
const auto &shapes = common::AnfAlgo::GetNodeAttr<std::vector<std::vector<int64_t>>>(kernel_node, "shapes");
prioriory_replay_buffer_ = PriorityReplayBufferFactory::GetInstance().GetByHandle(handle_);
MS_EXCEPTION_IF_NULL(prioriory_replay_buffer_);

for (size_t i = 0; i < shapes.size(); i++) {
size_t num_element = std::accumulate(shapes[i].begin(), shapes[i].end(), 1, std::multiplies<size_t>());
size_t num_element = std::accumulate(shapes[i].begin(), shapes[i].end(), 1ULL, std::multiplies<size_t>());
size_t type_size = GetTypeByte(dtypes[i]);
schema_.push_back(num_element * type_size);
}
@@ -150,7 +150,7 @@ bool PriorityReplayBufferUpdateCpuKernel::Launch(const std::vector<AddressPtr> &
"memcpy_s() failed.");
MS_EXCEPTION_IF_CHECK_FAIL(memcpy_s(priorities.data(), inputs[1]->size, inputs[1]->addr, inputs[1]->size) == EOK,
"memcpy_s() failed.");
prioriory_replay_buffer_->UpdatePriorities(indices, priorities);
(void)prioriory_replay_buffer_->UpdatePriorities(indices, priorities);

// Return a placeholder so that this node is not removed by dead code elimination.
auto handle = GetDeviceAddress<int64_t>(outputs, 0);


+ 1
- 1
mindspore/ccsrc/plugin/device/cpu/kernel/rl/priority_replay_buffer_cpu_kernel.h View File

@@ -82,7 +82,7 @@ class PriorityReplayBufferSampleCpuKernel : public NativeCpuKernelMod {

private:
int64_t handle_{-1};
int64_t batch_size_{0};
size_t batch_size_{0};
std::vector<size_t> schema_;
std::shared_ptr<PriorityReplayBuffer> prioriory_replay_buffer_{nullptr};
};


+ 1
- 1
mindspore/ccsrc/plugin/device/cpu/kernel/rl/replay_buffer_factory.h View File

@@ -56,7 +56,7 @@ class ReplayBufferFactory {
template <typename... _Args>
std::tuple<int, std::shared_ptr<T>> Create(_Args... args) {
auto instance = std::make_shared<T>(args...);
map_handle_to_instances_.insert(std::make_pair(++handle_, instance));
(void)map_handle_to_instances_.insert(std::make_pair(++handle_, instance));
return std::make_tuple(handle_, instance);
}



+ 0
- 1
mindspore/ccsrc/plugin/device/cpu/kernel/rl/segment_tree.h View File

@@ -80,7 +80,6 @@ class SegmentTree {

protected:
size_t capacity_;
size_t size_;
std::vector<T> buffer_;
};
} // namespace kernel


Loading…
Cancel
Save