Browse Source

!14287 fix loss scale for cache embedding

From: @fangzehua
Reviewed-by: @zhoufeng54,@jjfeing
Signed-off-by: @jjfeing
pull/14287/MERGE
mindspore-ci-bot Gitee 5 years ago
parent
commit
d346a861bc
2 changed files with 6 additions and 0 deletions
  1. +5
    -0
      mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc
  2. +1
    -0
      mindspore/core/base/core_ops.h

+ 5
- 0
mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc View File

@@ -714,6 +714,11 @@ void AddCacheEmbedding(const FuncGraphPtr &graph, bool is_pipe) {
if (!CheckHostCacheParamSize(param_cache_enable_set)) {
return;
}
for (auto &node : cnodes) {
if (IsPrimitiveCNode(node, prim::kPrimNPUAllocFloatStatus)) {
MS_LOG(EXCEPTION) << "Cache embedding haven't support loss scale yet.";
}
}
auto unique_cache_enable = FindUniqueCacheEnable(cnodes);
if (unique_cache_enable.empty()) {
MS_LOG(WARNING) << "Parameters have cache enable, but not find Unique op cache enable.";


+ 1
- 0
mindspore/core/base/core_ops.h View File

@@ -465,6 +465,7 @@ inline const PrimitivePtr kPrimPriorBox = std::make_shared<Primitive>("PriorBox"
inline const PrimitivePtr kPrimQuantDTypeCast = std::make_shared<Primitive>("QuantDTypeCast");
inline const PrimitivePtr kPrimWhile = std::make_shared<Primitive>("While");
inline const PrimitivePtr kPrimPull = std::make_shared<Primitive>("Pull");
inline const PrimitivePtr kPrimNPUAllocFloatStatus = std::make_shared<Primitive>("NPUAllocFloatStatus");

// Structures
inline const PrimitivePtr kPrimMakeList = std::make_shared<Primitive>("make_list");


Loading…
Cancel
Save