Browse Source

!8566 add dynamic for cache ops

From: @fangzehua
Reviewed-by: 
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
efd3e6a168
16 changed files with 237 additions and 162 deletions
  1. +0
    -2
      mindspore/ccsrc/backend/kernel_compiler/cpu/assign_cpu_kernel.cc
  2. +0
    -1
      mindspore/ccsrc/backend/kernel_compiler/cpu/assign_cpu_kernel.h
  3. +0
    -2
      mindspore/ccsrc/backend/kernel_compiler/cpu/cache_swap_hashmap_cpu_kernel.cc
  4. +0
    -2
      mindspore/ccsrc/backend/kernel_compiler/cpu/cache_swap_hashmap_cpu_kernel.h
  5. +76
    -53
      mindspore/ccsrc/backend/kernel_compiler/cpu/map_cache_idx_cpu_kernel.cc
  6. +1
    -2
      mindspore/ccsrc/backend/kernel_compiler/cpu/map_cache_idx_cpu_kernel.h
  7. +0
    -1
      mindspore/ccsrc/backend/kernel_compiler/cpu/search_cache_idx_cpu_kernel.cc
  8. +0
    -2
      mindspore/ccsrc/backend/kernel_compiler/cpu/search_cache_idx_cpu_kernel.h
  9. +15
    -14
      mindspore/ccsrc/backend/kernel_compiler/cpu/update_cache_cpu_kernel.cc
  10. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/cpu/update_cache_cpu_kernel.h
  11. +6
    -1
      mindspore/core/abstract/infer_functions.h
  12. +93
    -0
      mindspore/core/abstract/prim_arrays.cc
  13. +3
    -0
      mindspore/core/abstract/primitive_infer_map.cc
  14. +3
    -0
      mindspore/core/base/core_ops.h
  15. +39
    -29
      mindspore/ops/operations/_cache_ops.py
  16. +0
    -52
      tests/st/ops/cpu/test_cache_ops.py

+ 0
- 2
mindspore/ccsrc/backend/kernel_compiler/cpu/assign_cpu_kernel.cc View File

@@ -35,7 +35,6 @@ void AssignCPUKernel::InitKernel(const CNodePtr &kernel_node) {
} }
} }
input_x_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); input_x_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);

if (input_x_dtype_ == kNumberTypeFloat32 || input_x_dtype_ == kNumberTypeInt32) { if (input_x_dtype_ == kNumberTypeFloat32 || input_x_dtype_ == kNumberTypeInt32) {
input_x_dtype_size_ = 4; input_x_dtype_size_ = 4;
} else if (input_x_dtype_ == kNumberTypeFloat64 || input_x_dtype_ == kNumberTypeInt64) { } else if (input_x_dtype_ == kNumberTypeFloat64 || input_x_dtype_ == kNumberTypeInt64) {
@@ -75,6 +74,5 @@ void AssignCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret; MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret;
} }
} }

} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

+ 0
- 1
mindspore/ccsrc/backend/kernel_compiler/cpu/assign_cpu_kernel.h View File

@@ -60,7 +60,6 @@ MS_REG_CPU_KERNEL(
Assign, Assign,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
AssignCPUKernel); AssignCPUKernel);

} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore




+ 0
- 2
mindspore/ccsrc/backend/kernel_compiler/cpu/cache_swap_hashmap_cpu_kernel.cc View File

@@ -20,7 +20,6 @@


namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {

template <typename T> template <typename T>
void Compress(HashmapEntry<T> *entry_p, const size_t &length, T entry) { void Compress(HashmapEntry<T> *entry_p, const size_t &length, T entry) {
T i = (entry + 1) % length, off = 1; T i = (entry + 1) % length, off = 1;
@@ -107,6 +106,5 @@ void CacheSwapHashmapCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inpu
} }
} }
} }

} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

+ 0
- 2
mindspore/ccsrc/backend/kernel_compiler/cpu/cache_swap_hashmap_cpu_kernel.h View File

@@ -25,7 +25,6 @@


namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {

class CacheSwapHashmapCPUKernel : public CPUKernel { class CacheSwapHashmapCPUKernel : public CPUKernel {
public: public:
CacheSwapHashmapCPUKernel() = default; CacheSwapHashmapCPUKernel() = default;
@@ -82,7 +81,6 @@ MS_REG_CPU_KERNEL(CacheSwapHashmap,
.AddOutputAttr(kNumberTypeInt32) .AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32), .AddOutputAttr(kNumberTypeInt32),
CacheSwapHashmapCPUKernel); CacheSwapHashmapCPUKernel);

} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore




+ 76
- 53
mindspore/ccsrc/backend/kernel_compiler/cpu/map_cache_idx_cpu_kernel.cc View File

@@ -22,7 +22,6 @@


namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {

template <typename T> template <typename T>
struct HashmapEntry { struct HashmapEntry {
T key; T key;
@@ -60,8 +59,9 @@ T HashFunc(const T &key, const size_t &m) {
} }


template <typename T> template <typename T>
void Compress(HashmapEntry<T> *entry_p, const size_t &length, T entry) {
int Compress(HashmapEntry<T> *entry_p, const size_t &length, T entry) {
T i = (entry + 1) % length, off = 1; T i = (entry + 1) % length, off = 1;
int compress_count = 0;
for (; !entry_p[i].IsEmpty(); i = (i + 1) % length, off++) { for (; !entry_p[i].IsEmpty(); i = (i + 1) % length, off++) {
if (entry_p[i].tag > off) { if (entry_p[i].tag > off) {
entry_p[entry].key = entry_p[i].key; entry_p[entry].key = entry_p[i].key;
@@ -72,21 +72,20 @@ void Compress(HashmapEntry<T> *entry_p, const size_t &length, T entry) {
off = 0; off = 0;
entry = i; entry = i;
} }
compress_count++;
} }
return compress_count;
} }


void MapCacheIdxCPUKernel::InitKernel(const CNodePtr &kernel_node) { void MapCacheIdxCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
node_ = kernel_node;
auto hashmap_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); auto hashmap_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto emb_idx_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);


if (hashmap_shape.size() != 2) { if (hashmap_shape.size() != 2) {
MS_LOG(EXCEPTION) << "Dimension of HashMap must be 2, (n, 4)"; MS_LOG(EXCEPTION) << "Dimension of HashMap must be 2, (n, 4)";
} }


for (size_t i = 0; i < emb_idx_shape.size(); ++i) {
batch_size_ *= emb_idx_shape[i];
}

hashmap_length_ = hashmap_shape[0]; hashmap_length_ = hashmap_shape[0];
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
} }
@@ -108,100 +107,124 @@ bool MapCacheIdxCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
template <typename T> template <typename T>
void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) { const std::vector<kernel::AddressPtr> &outputs) {
auto emb_idx_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 1);
batch_size_ = 1;
for (size_t i = 0; i < emb_idx_shape.size(); ++i) {
batch_size_ *= emb_idx_shape[i];
}
HashmapEntry<T> *hashmap = reinterpret_cast<HashmapEntry<T> *>(inputs[0]->addr); HashmapEntry<T> *hashmap = reinterpret_cast<HashmapEntry<T> *>(inputs[0]->addr);
auto input_indices = reinterpret_cast<T *>(inputs[1]->addr); auto input_indices = reinterpret_cast<T *>(inputs[1]->addr);
T *step_ = reinterpret_cast<T *>(inputs[2]->addr); T *step_ = reinterpret_cast<T *>(inputs[2]->addr);
T emb_max_num = *reinterpret_cast<T *>(inputs[3]->addr); T emb_max_num = *reinterpret_cast<T *>(inputs[3]->addr);
T cache_max_num = *reinterpret_cast<T *>(inputs[4]->addr);
T offset = *reinterpret_cast<T *>(inputs[4]->addr);
auto output_cache_idx = reinterpret_cast<T *>(outputs[0]->addr); auto output_cache_idx = reinterpret_cast<T *>(outputs[0]->addr);
auto output_old_emb_idx = reinterpret_cast<T *>(outputs[1]->addr); auto output_old_emb_idx = reinterpret_cast<T *>(outputs[1]->addr);
auto output_miss_emb_idx = reinterpret_cast<T *>(outputs[2]->addr); auto output_miss_emb_idx = reinterpret_cast<T *>(outputs[2]->addr);
auto output_swap_cache_idx = reinterpret_cast<T *>(outputs[3]->addr); auto output_swap_cache_idx = reinterpret_cast<T *>(outputs[3]->addr);


std::vector<T> output_miss_idx(batch_size_, -1);
std::vector<T> miss_idx;
size_t miss_count = 0;
float total_count = 0; float total_count = 0;
int count_size = 0; int count_size = 0;
float hit_count = 0; float hit_count = 0;

// search_cache_idx // search_cache_idx
for (size_t i = 0; i < batch_size_; ++i) { for (size_t i = 0; i < batch_size_; ++i) {
if (input_indices[i] == emb_max_num) {
output_miss_idx[i] = -1;
output_cache_idx[i] = cache_max_num;
output_miss_emb_idx[i] = -1;
T key = input_indices[i] - offset;
if (key >= emb_max_num || key < 0) {
output_cache_idx[i] = -1;
continue; continue;
} }


T key = input_indices[i];
T tmp_entry = HashFunc(key, hashmap_length_); T tmp_entry = HashFunc(key, hashmap_length_);


int count = 1;
size_t count = 1;
count_size += 1; count_size += 1;
while ((!hashmap[tmp_entry].IsEmpty() && !hashmap[tmp_entry].IsKey(key))) { while ((!hashmap[tmp_entry].IsEmpty() && !hashmap[tmp_entry].IsKey(key))) {
tmp_entry = (tmp_entry + 1) % hashmap_length_; tmp_entry = (tmp_entry + 1) % hashmap_length_;
if (count > hashmap_length_) {
MS_LOG(ERROR) << "Hashmap is full, search cache idx failed!";
break;
}
count += 1; count += 1;
} }


total_count += count; total_count += count;
if (hashmap[tmp_entry].IsEmpty()) { if (hashmap[tmp_entry].IsEmpty()) {
output_miss_idx[i] = i;
output_miss_emb_idx[i] = key;
miss_idx.emplace_back(i);
output_miss_emb_idx[miss_count] = key;
output_cache_idx[i] = -1; output_cache_idx[i] = -1;
miss_count++;
} else { } else {
hit_count += 1; hit_count += 1;
output_miss_idx[i] = -1;
output_cache_idx[i] = hashmap[tmp_entry].value; output_cache_idx[i] = hashmap[tmp_entry].value;
hashmap[tmp_entry].step = step_[0]; hashmap[tmp_entry].step = step_[0];
output_miss_emb_idx[i] = -1;
} }
} }
MS_LOG(INFO) << "avg search count: " << total_count / count_size;
MS_LOG(INFO) << "cache hit rate: " << hit_count / count_size;
MS_LOG(INFO) << "Miss count: " << miss_count;
MS_LOG(INFO) << "Avg search count: " << total_count / count_size;
MS_LOG(INFO) << "Cache hit rate: " << hit_count / count_size;

float total_insert_count = 0;
float total_delete_count = 0;


// swap hash map // swap hash map
for (size_t i = 0; i < batch_size_; ++i) {
if (output_miss_emb_idx[i] < 0) {
output_swap_cache_idx[i] = -1;
output_old_emb_idx[i] = -1;
} else {
T emb_idx = output_miss_emb_idx[i];
T entry = HashFunc(emb_idx, hashmap_length_);
T tag_count = 1;
while (!hashmap[entry].IsEmpty()) {
entry = (entry + 1) % hashmap_length_;
tag_count++;
for (size_t i = 0; i < miss_count; ++i) {
T emb_idx = output_miss_emb_idx[i];
T entry = HashFunc(emb_idx, hashmap_length_);
size_t tag_count = 1;
while (!hashmap[entry].IsEmpty()) {
entry = (entry + 1) % hashmap_length_;
if (tag_count > hashmap_length_) {
MS_LOG(ERROR) << "Hashmap is full, insert new key failed!";
break;
} }
tag_count++;
}


hashmap[entry].key = emb_idx;
hashmap[entry].step = step_[0];
hashmap[entry].tag = tag_count;

T tmp_entry = (entry + 1) % hashmap_length_;
hashmap[entry].key = emb_idx;
hashmap[entry].step = step_[0];
hashmap[entry].tag = tag_count;


while (hashmap[tmp_entry].IsEmpty() || hashmap[tmp_entry].IsUsing(step_[0])) {
tmp_entry = (tmp_entry + 1) % hashmap_length_;
T tmp_entry = (entry + 1) % hashmap_length_;
size_t delete_count = 1;
while (hashmap[tmp_entry].IsEmpty() || hashmap[tmp_entry].IsUsing(step_[0])) {
tmp_entry = (tmp_entry + 1) % hashmap_length_;
if (delete_count > hashmap_length_) {
MS_LOG(ERROR) << "Hashmap is full, delete old key failed!";
break;
} }

output_swap_cache_idx[i] = hashmap[tmp_entry].value;
output_old_emb_idx[i] = hashmap[tmp_entry].key;
hashmap[entry].value = output_swap_cache_idx[i];
hashmap[tmp_entry].SetEmpty();
Compress(hashmap, hashmap_length_, tmp_entry);
delete_count++;
} }

output_swap_cache_idx[i] = hashmap[tmp_entry].value;
output_old_emb_idx[i] = hashmap[tmp_entry].key;
hashmap[entry].value = output_swap_cache_idx[i];
hashmap[tmp_entry].SetEmpty();
int compress_count = Compress(hashmap, hashmap_length_, tmp_entry);
total_delete_count += (compress_count + delete_count);
total_insert_count += tag_count;
} }


MS_LOG(INFO) << "Insert count: " << total_insert_count / miss_count;
MS_LOG(INFO) << "Delete count: " << total_delete_count / miss_count;

// update step // update step
step_[0] += 1; step_[0] += 1;


// update cache idx // update cache idx
for (size_t i = 0; i < batch_size_; ++i) {
if (output_miss_idx[i] < 0 || output_miss_idx[i] >= cache_max_num) {
continue;
}
output_cache_idx[i] = output_swap_cache_idx[i];
for (size_t i = 0; i < miss_count; ++i) {
int idx = miss_idx[i];
output_cache_idx[idx] = output_swap_cache_idx[i];
} }
}


std::vector<size_t> out_shape;
out_shape.emplace_back(miss_count);
std::vector<TypeId> dtypes;
for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(node_); i++) {
dtypes.push_back(AnfAlgo::GetOutputInferDataType(node_, i));
}
AnfAlgo::SetOutputInferTypeAndShape(dtypes, {AnfAlgo::GetOutputInferShape(node_, 0), out_shape, out_shape, out_shape},
node_.get());
}
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

+ 1
- 2
mindspore/ccsrc/backend/kernel_compiler/cpu/map_cache_idx_cpu_kernel.h View File

@@ -27,7 +27,6 @@


namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {

class MapCacheIdxCPUKernel : public CPUKernel { class MapCacheIdxCPUKernel : public CPUKernel {
public: public:
MapCacheIdxCPUKernel() = default; MapCacheIdxCPUKernel() = default;
@@ -45,6 +44,7 @@ class MapCacheIdxCPUKernel : public CPUKernel {
size_t batch_size_{1}; size_t batch_size_{1};
size_t hashmap_length_{1}; size_t hashmap_length_{1};
TypeId dtype_{kTypeUnknown}; TypeId dtype_{kTypeUnknown};
CNodePtr node_ = nullptr;
}; };


MS_REG_CPU_KERNEL(MapCacheIdx, MS_REG_CPU_KERNEL(MapCacheIdx,
@@ -98,7 +98,6 @@ MS_REG_CPU_KERNEL(MapCacheIdx,
.AddOutputAttr(kNumberTypeInt32) .AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32), .AddOutputAttr(kNumberTypeInt32),
MapCacheIdxCPUKernel); MapCacheIdxCPUKernel);

} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore




+ 0
- 1
mindspore/ccsrc/backend/kernel_compiler/cpu/search_cache_idx_cpu_kernel.cc View File

@@ -99,6 +99,5 @@ void SearchCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs
MS_LOG(INFO) << "avg search count: " << total_count / count_size; MS_LOG(INFO) << "avg search count: " << total_count / count_size;
MS_LOG(INFO) << "cache hit rate: " << hit_count / count_size; MS_LOG(INFO) << "cache hit rate: " << hit_count / count_size;
} }

} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

+ 0
- 2
mindspore/ccsrc/backend/kernel_compiler/cpu/search_cache_idx_cpu_kernel.h View File

@@ -27,7 +27,6 @@


namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {

template <typename T> template <typename T>
struct HashmapEntry { struct HashmapEntry {
T key; T key;
@@ -133,7 +132,6 @@ MS_REG_CPU_KERNEL(SearchCacheIdx,
.AddOutputAttr(kNumberTypeInt32) .AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32), .AddOutputAttr(kNumberTypeInt32),
SearchCacheIdxCPUKernel); SearchCacheIdxCPUKernel);

} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore




+ 15
- 14
mindspore/ccsrc/backend/kernel_compiler/cpu/update_cache_cpu_kernel.cc View File

@@ -21,20 +21,9 @@
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
void UpdateCacheCPUKernel::InitKernel(const CNodePtr &kernel_node) { void UpdateCacheCPUKernel::InitKernel(const CNodePtr &kernel_node) {
auto indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
auto update_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
if (indices_shape.size() < 2) {
MS_LOG(EXCEPTION) << "indices shape less than 2";
}

for (size_t i = 0; i < indices_shape.size(); ++i) {
batch_size_ *= indices_shape[i];
}
MS_EXCEPTION_IF_NULL(kernel_node);
node_ = kernel_node;


for (size_t i = 0; i < update_shape.size(); ++i) {
update_size_ *= update_shape[i];
}
update_length_ = update_size_ / batch_size_;
input_x_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); input_x_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
indices_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 1); indices_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 1);


@@ -64,6 +53,19 @@ bool UpdateCacheCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
template <typename T> template <typename T>
void UpdateCacheCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, void UpdateCacheCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) { const std::vector<kernel::AddressPtr> &outputs) {
auto indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 1);
auto update_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 2);

batch_size_ = 1;
for (size_t i = 0; i < indices_shape.size(); ++i) {
batch_size_ *= indices_shape[i];
}
MS_LOG(INFO) << "UpdateCache batch_size:" << batch_size_;
update_size_ = 1;
for (size_t i = 0; i < update_shape.size(); ++i) {
update_size_ *= update_shape[i];
}
update_length_ = update_shape[1];
char *input_x = reinterpret_cast<char *>(inputs[0]->addr); char *input_x = reinterpret_cast<char *>(inputs[0]->addr);
T *indices = reinterpret_cast<T *>(inputs[1]->addr); T *indices = reinterpret_cast<T *>(inputs[1]->addr);
char *update = reinterpret_cast<char *>(inputs[2]->addr); char *update = reinterpret_cast<char *>(inputs[2]->addr);
@@ -80,6 +82,5 @@ void UpdateCacheCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
} }
} }
} }

} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/cpu/update_cache_cpu_kernel.h View File

@@ -46,6 +46,7 @@ class UpdateCacheCPUKernel : public CPUKernel {
TypeId input_x_dtype_{kTypeUnknown}; TypeId input_x_dtype_{kTypeUnknown};
TypeId indices_dtype_{kTypeUnknown}; TypeId indices_dtype_{kTypeUnknown};
size_t input_x_dtype_size_ = 4; size_t input_x_dtype_size_ = 4;
CNodePtr node_ = nullptr;
}; };


MS_REG_CPU_KERNEL(UpdateCache, MS_REG_CPU_KERNEL(UpdateCache,
@@ -101,7 +102,6 @@ MS_REG_CPU_KERNEL(UpdateCache,
.AddInputAttr(kNumberTypeInt64) .AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64), .AddOutputAttr(kNumberTypeInt64),
UpdateCacheCPUKernel); UpdateCacheCPUKernel);

} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore




+ 6
- 1
mindspore/core/abstract/infer_functions.h View File

@@ -201,7 +201,12 @@ AbstractBasePtr InferImplDiv(const AnalysisEnginePtr &, const PrimitivePtr &prim
const AbstractBasePtrList &args_spec_list); const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplRealDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr InferImplRealDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list); const AbstractBasePtrList &args_spec_list);

AbstractBasePtr InferImplMapCacheIdx(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplCacheSwapTable(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplUpdateCache(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list); const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr InferImplShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,


+ 93
- 0
mindspore/core/abstract/prim_arrays.cc View File

@@ -273,6 +273,99 @@ AbstractBasePtr InferImplScatterUpdate(const AnalysisEnginePtr &, const Primitiv
return std::make_shared<AbstractTensor>(x->element(), x->shape()); return std::make_shared<AbstractTensor>(x->element(), x->shape());
} }


AbstractBasePtr InferImplMapCacheIdx(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 5);
auto hash_map = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(hash_map);
MS_EXCEPTION_IF_NULL(hash_map->shape());

auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
auto indices_shp = indices->shape();
MS_EXCEPTION_IF_NULL(indices);
MS_EXCEPTION_IF_NULL(indices_shp);

ShapeVector shape;
ShapeVector min_shape;
ShapeVector max_shape;
if (!indices_shp->max_shape().empty()) {
max_shape = indices_shp->max_shape();
} else {
max_shape = indices_shp->shape();
}
for (size_t i = 0; i < max_shape.size(); i++) {
shape.emplace_back(Shape::SHP_ANY);
min_shape.emplace_back(1);
}

auto cache_idx = std::make_shared<AbstractTensor>(hash_map->element(), indices->shape());
auto old_emb_idx =
std::make_shared<AbstractTensor>(hash_map->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
auto miss_emb_idx =
std::make_shared<AbstractTensor>(hash_map->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
auto swap_emb_idx =
std::make_shared<AbstractTensor>(hash_map->element(), std::make_shared<Shape>(shape, min_shape, max_shape));

AbstractBasePtrList elements = {cache_idx, old_emb_idx, miss_emb_idx, swap_emb_idx};
return std::make_shared<AbstractTuple>(elements);
}

AbstractBasePtr InferImplCacheSwapTable(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 3);
auto cache_table = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
auto cache_table_shp = cache_table->shape();
MS_EXCEPTION_IF_NULL(cache_table);
MS_EXCEPTION_IF_NULL(cache_table_shp);

auto swap_cache_idx = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
auto swap_cache_idx_shp = swap_cache_idx->shape();
MS_EXCEPTION_IF_NULL(swap_cache_idx);
MS_EXCEPTION_IF_NULL(swap_cache_idx_shp);

auto cache_table_shape = cache_table_shp->shape();
auto swap_cache_idx_shape = swap_cache_idx_shp->shape();
ShapeVector shape;
shape.emplace_back(swap_cache_idx_shape[0]);
shape.emplace_back(cache_table_shape[1]);
auto swap_cache_idx_max_shape = swap_cache_idx_shp->max_shape();
ShapeVector max_shape;
ShapeVector min_shape;
if (!swap_cache_idx_max_shape.empty()) {
max_shape.emplace_back(swap_cache_idx_max_shape[0]);
max_shape.emplace_back(cache_table_shape[1]);
} else {
max_shape = shape;
}
for (size_t i = 0; i < max_shape.size(); ++i) {
min_shape.emplace_back(1);
}

AbstractTensorPtr ret =
std::make_shared<AbstractTensor>(cache_table->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
return ret;
}

AbstractBasePtr InferImplUpdateCache(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();
auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(input_x);
MS_EXCEPTION_IF_NULL(input_x->shape());

auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
MS_EXCEPTION_IF_NULL(indices);
MS_EXCEPTION_IF_NULL(indices->shape());

ShapeVector shape;
shape.emplace_back(1);

AbstractTensorPtr ret = std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(shape));
return ret;
}

AbstractBasePtr InferImplDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr InferImplDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) { const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name(); const std::string op_name = primitive->name();


+ 3
- 0
mindspore/core/abstract/primitive_infer_map.cc View File

@@ -57,6 +57,9 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}}, {prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}},
{prim::kPrimScatterAdd, {InferImplScatterAdd, true}}, {prim::kPrimScatterAdd, {InferImplScatterAdd, true}},
{prim::kPrimScatterUpdate, {InferImplScatterUpdate, true}}, {prim::kPrimScatterUpdate, {InferImplScatterUpdate, true}},
{prim::kPrimMapCacheIdx, {InferImplMapCacheIdx, true}},
{prim::kPrimCacheSwapTable, {InferImplCacheSwapTable, true}},
{prim::kPrimUpdateCache, {InferImplUpdateCache, true}},
{prim::kPrimDiv, {InferImplDiv, true}}, {prim::kPrimDiv, {InferImplDiv, true}},
{prim::kPrimRealDiv, {InferImplRealDiv, true}}, {prim::kPrimRealDiv, {InferImplRealDiv, true}},
{prim::kPrimShape, {InferImplShape, false}}, {prim::kPrimShape, {InferImplShape, false}},


+ 3
- 0
mindspore/core/base/core_ops.h View File

@@ -98,6 +98,9 @@ inline const PrimitivePtr kPrimUnsortedSegmentSum = std::make_shared<Primitive>(
inline const PrimitivePtr kPrimUnsortedSegmentMin = std::make_shared<Primitive>("UnsortedSegmentMin"); inline const PrimitivePtr kPrimUnsortedSegmentMin = std::make_shared<Primitive>("UnsortedSegmentMin");
inline const PrimitivePtr kPrimConcatOffset = std::make_shared<Primitive>("ConcatOffset"); inline const PrimitivePtr kPrimConcatOffset = std::make_shared<Primitive>("ConcatOffset");
inline const PrimitivePtr kPrimReshape = std::make_shared<Primitive>("Reshape"); inline const PrimitivePtr kPrimReshape = std::make_shared<Primitive>("Reshape");
inline const PrimitivePtr kPrimMapCacheIdx = std::make_shared<Primitive>("MapCacheIdx");
inline const PrimitivePtr kPrimUpdateCache = std::make_shared<Primitive>("UpdateCache");
inline const PrimitivePtr kPrimCacheSwapTable = std::make_shared<Primitive>("CacheSwapTable");
inline const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile"); inline const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile");
inline const PrimitivePtr kPrimAddN = std::make_shared<Primitive>("AddN"); inline const PrimitivePtr kPrimAddN = std::make_shared<Primitive>("AddN");
inline const PrimitivePtr kPrimAccumulateNV2 = std::make_shared<Primitive>("AccumulateNV2"); inline const PrimitivePtr kPrimAccumulateNV2 = std::make_shared<Primitive>("AccumulateNV2");


+ 39
- 29
mindspore/ops/operations/_cache_ops.py View File

@@ -15,11 +15,11 @@
"""cache_ops""" """cache_ops"""
from ..._checkparam import Validator as validator from ..._checkparam import Validator as validator
from ...common import dtype as mstype from ...common import dtype as mstype
from ..primitive import PrimitiveWithInfer, prim_attr_register
from ..primitive import PrimitiveWithInfer, prim_attr_register, PrimitiveWithCheck
from .. import signature as sig from .. import signature as sig




class UpdateCache(PrimitiveWithInfer):
class UpdateCache(PrimitiveWithCheck):
""" """
Update the value fo input_x, similar to ScatterNdUpdate. Update the value fo input_x, similar to ScatterNdUpdate.
The diffirent is that UpdateCache will not update when indices < 0 or indices >= max_num. The diffirent is that UpdateCache will not update when indices < 0 or indices >= max_num.
@@ -47,15 +47,12 @@ class UpdateCache(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['input_x', 'indices', 'update', 'max_num'], self.init_prim_io_names(inputs=['input_x', 'indices', 'update', 'max_num'],
outputs=['out']) outputs=['out'])


def infer_shape(self, input_x_shape, indices_shape, update_shape, max_num_shape):

if len(indices_shape) < 2:
raise ValueError("The dimension of 'indices' in UpdateCache must >= 2, "
"but got %d." % len(indices_shape))
def check_shape(self, input_x_shape, indices_shape, update_shape, max_num_shape):
return [1] return [1]


def infer_dtype(self, input_x_dtype, indices_dtype, update_dtype, max_num_dtype):
validator.check_tensor_dtype_valid("indices", indices_dtype, mstype.int_type, self.name)
def check_dtype(self, input_x_dtype, indices_dtype, update_dtype, max_num_dtype):
validator.check_tensor_dtype_valid(
"indices", indices_dtype, mstype.int_type, self.name)
return input_x_dtype return input_x_dtype




@@ -139,7 +136,8 @@ class SearchCacheIdx(PrimitiveWithInfer):


def infer_dtype(self, hashmap_dtype, indices_dtype, step_dtype, emb_max_num_dtype, cache_max_num_dtype): def infer_dtype(self, hashmap_dtype, indices_dtype, step_dtype, emb_max_num_dtype, cache_max_num_dtype):
args = {"hashmap": hashmap_dtype, "indices": indices_dtype} args = {"hashmap": hashmap_dtype, "indices": indices_dtype}
validator.check_tensors_dtypes_same_and_valid(args, mstype.int_type, self.name)
validator.check_tensors_dtypes_same_and_valid(
args, mstype.int_type, self.name)
out_dtype = (hashmap_dtype, hashmap_dtype, hashmap_dtype) out_dtype = (hashmap_dtype, hashmap_dtype, hashmap_dtype)
return out_dtype return out_dtype


@@ -172,7 +170,6 @@ class CacheSwapHashmap(PrimitiveWithInfer):
outputs=['swap_cache_idx', 'old_emb_idx']) outputs=['swap_cache_idx', 'old_emb_idx'])


def infer_shape(self, hashmap_shape, miss_emb_idx_shape, step_shape): def infer_shape(self, hashmap_shape, miss_emb_idx_shape, step_shape):

if len(hashmap_shape) != 2: if len(hashmap_shape) != 2:
raise ValueError("The dimension of 'hashmap' in CacheSwapHashmap must be 2, " raise ValueError("The dimension of 'hashmap' in CacheSwapHashmap must be 2, "
"but got %d." % len(hashmap_shape)) "but got %d." % len(hashmap_shape))
@@ -181,12 +178,13 @@ class CacheSwapHashmap(PrimitiveWithInfer):
return out_shape return out_shape


def infer_dtype(self, hashmap_dtype, miss_emb_idx_dtype, step_dtype): def infer_dtype(self, hashmap_dtype, miss_emb_idx_dtype, step_dtype):
validator.check_tensor_dtype_valid("miss_emb_idx", miss_emb_idx_dtype, mstype.int_type, self.name)
validator.check_tensor_dtype_valid(
"miss_emb_idx", miss_emb_idx_dtype, mstype.int_type, self.name)
out_dtype = (miss_emb_idx_dtype, miss_emb_idx_dtype) out_dtype = (miss_emb_idx_dtype, miss_emb_idx_dtype)
return out_dtype return out_dtype




class CacheSwapTable(PrimitiveWithInfer):
class CacheSwapTable(PrimitiveWithCheck):
""" """
Delete a hashmap entry, and insert a new key to hashmap, return the key and value of the deleted entry. Delete a hashmap entry, and insert a new key to hashmap, return the key and value of the deleted entry.


@@ -212,21 +210,20 @@ class CacheSwapTable(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['cache_table', 'swap_cache_idx', 'miss_value'], self.init_prim_io_names(inputs=['cache_table', 'swap_cache_idx', 'miss_value'],
outputs=['old_value']) outputs=['old_value'])


def infer_shape(self, cache_table_shape, swap_cache_idx_shape, miss_value_shape):
def check_shape(self, cache_table_shape, swap_cache_idx_shape, miss_value_shape):
if len(cache_table_shape) != 2: if len(cache_table_shape) != 2:
raise ValueError( raise ValueError(
"cache table shape must be 2, but got %d" % len(cache_table_shape)) "cache table shape must be 2, but got %d" % len(cache_table_shape))
if swap_cache_idx_shape + cache_table_shape[1:] != miss_value_shape:
raise ValueError(
"swap_cache_idx_shape + cache_table_shape[1:] must equal to miss_value_shape")

return miss_value_shape return miss_value_shape


def infer_dtype(self, cache_table_dtype, swap_cache_idx_dtype, miss_value_dtype):
validator.check_tensor_dtype_valid("swap_cache_idx", swap_cache_idx_dtype, mstype.int_type, self.name)
def check_dtype(self, cache_table_dtype, swap_cache_idx_dtype, miss_value_dtype):
validator.check_tensor_dtype_valid(
"swap_cache_idx", swap_cache_idx_dtype, mstype.int_type, self.name)
return miss_value_dtype return miss_value_dtype




class MapCacheIdx(PrimitiveWithInfer):
class MapCacheIdx(PrimitiveWithCheck):
""" """
MapCacheIdx merge SearchCacheIdx, CacheSwapHashmap, UpdateCache together. MapCacheIdx merge SearchCacheIdx, CacheSwapHashmap, UpdateCache together.
When input an indices tensor, it will output the cache indices which search in hashmap. When input an indices tensor, it will output the cache indices which search in hashmap.
@@ -244,21 +241,34 @@ class MapCacheIdx(PrimitiveWithInfer):
def __init__(self): def __init__(self):
"""init MapCacheIdx""" """init MapCacheIdx"""


self.init_prim_io_names(inputs=['hashmap', 'indices', 'step', 'emb_max_num', 'cache_max_num'],
self.init_prim_io_names(inputs=['hashmap', 'indices', 'step', 'emb_max_num', 'offset'],
outputs=['cache_idx', 'old_emb_idx', 'miss_emb_idx', 'swap_cache_idx']) outputs=['cache_idx', 'old_emb_idx', 'miss_emb_idx', 'swap_cache_idx'])


def infer_shape(self, hashmap_shape, indices_shape, step_shape, emb_max_num_shape, cache_max_num_shape):
def __check__(self, hashmap, indices, step, emb_max_num, offset):
hashmap_shape = hashmap['shape']
if len(hashmap_shape) != 2: if len(hashmap_shape) != 2:
raise ValueError("The dimension of 'hashmap' in SearchCacheIdx must be 2, " raise ValueError("The dimension of 'hashmap' in SearchCacheIdx must be 2, "
"but got %d." % len(hashmap_shape)) "but got %d." % len(hashmap_shape))
out_shape = (indices_shape, indices_shape,
indices_shape, indices_shape)
return out_shape
out_shape = (indices['shape'], -1, -1, -1)


def infer_dtype(self, hashmap_dtype, indices_dtype, step_dtype, emb_max_num_dtype, cache_max_num_dtype):
hashmap_dtype = hashmap['dtype']
indices_dtype = indices['dtype']
args = {"hashmap": hashmap_dtype, "indices": indices_dtype} args = {"hashmap": hashmap_dtype, "indices": indices_dtype}
validator.check_tensors_dtypes_same_and_valid(args, mstype.int_type, self.name)
validator.check_tensor_type_same(args, mstype.int_type, self.name)
out_dtype = (hashmap_dtype, hashmap_dtype, out_dtype = (hashmap_dtype, hashmap_dtype,
hashmap_dtype, hashmap_dtype) hashmap_dtype, hashmap_dtype)
return out_dtype

out = {'shape': out_shape,
'dtype': out_dtype,
'value': None}
if 'max_shape' in indices:
out['max_shape'] = (indices['max_shape'], indices['max_shape'],
indices['max_shape'], indices['max_shape'])
else:
out['max_shape'] = (indices['shape'], indices['shape'],
indices['shape'], indices['shape'])
if 'min_shape' in indices:
out['min_shape'] = (indices['min_shape'], 0, 0, 0)
else:
out['min_shape'] = (0, 0, 0, 0)
return out

+ 0
- 52
tests/st/ops/cpu/test_cache_ops.py View File

@@ -75,19 +75,6 @@ class CacheSwapHashmapNet(nn.Cell):
return self.ops(self.net.hashmap, miss_emb_idx, self.step) return self.ops(self.net.hashmap, miss_emb_idx, self.step)




class MapCacheIdxNet(nn.Cell):
def __init__(self, hashmap_np):
super().__init__()
self.ops = P.MapCacheIdx()
self.hashmap = Parameter(Tensor(hashmap_np), name="hashmap")
self.emb_max = 25
self.cache_max = 10
self.step = 0

def construct(self, indices):
return self.ops(self.hashmap, indices, self.step, self.emb_max, self.cache_max)


class UpdateCacheNet(nn.Cell): class UpdateCacheNet(nn.Cell):
def __init__(self, x): def __init__(self, x):
super().__init__() super().__init__()
@@ -165,45 +152,6 @@ def test_cache_swap_hashmap():
np.array(hashmap_np_after_ops, np.int32)) np.array(hashmap_np_after_ops, np.int32))




@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_map_cache_idx():
hashmap_np = init_hashmap(10)
indices_np = np.array([10, 2, 20, 5, 3], np.int32)
map_cache_idx = MapCacheIdxNet(hashmap_np)
indices = Tensor(indices_np)
cache_idx, old_emb_idx, miss_emb_idx, swap_cache_idx = map_cache_idx(
indices)

expect_cache_idx = [5, 1, 9, 7, 3]
expect_old_emb_idx = [-1, -1, 21, 15, -1]
expect_miss_emb_idx = [-1, -1, 20, 5, -1]
expect_swap_cache_idx = [-1, -1, 9, 7, -1]

hashmap_np_after_ops = [[5, 7, 0, 1],
[10, 5, 0, 1],
[2, 1, 0, 1],
[20, 9, 0, 1],
[20, 9, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[3, 3, 0, 1],
[21, 9, -5, 0]]

assert np.allclose(cache_idx.asnumpy(),
np.array(expect_cache_idx, np.int32))
assert np.allclose(old_emb_idx.asnumpy(),
np.array(expect_old_emb_idx, np.int32))
assert np.allclose(miss_emb_idx.asnumpy(),
np.array(expect_miss_emb_idx, np.int32))
assert np.allclose(swap_cache_idx.asnumpy(),
np.array(expect_swap_cache_idx, np.int32))
assert np.allclose(map_cache_idx.hashmap.data.asnumpy(),
np.array(hashmap_np_after_ops, np.int32))


@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_cpu @pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard @pytest.mark.env_onecard


Loading…
Cancel
Save