@@ -19,55 +19,20 @@
#include <memory>
#include <vector>
#include "runtime/device/cpu/cpu_device_address.h"
#include "utils/cache_embedding_hashmap_struct.h"
namespace mindspore {
namespace kernel {
template <typename T>
struct HashmapEntry {
  T key;
  T value;
  T step;
  T tag;
  bool IsEmpty() {
    if (this->tag == NULLTAG)
      return true;
    else
      return false;
  }
  bool IsUsing(const T &train_step) {
    if (this->step >= (train_step - 1))
      return true;
    else
      return false;
  }
  bool IsKey(const T &emb_idx) {
    if (this->key == emb_idx)
      return true;
    else
      return false;
  }
  void SetEmpty() { this->tag = NULLTAG; }
};
template <typename T>
T HashFunc(const T &key, const size_t &m) {
  return (T)(((0.6180339 * key) - floor(0.6180339 * key)) * m);
}
template <typename T>
int Compress(HashmapEntry<T> *entry_p, const size_t &length, T entry) {
  T i = (entry + 1) % length, off = 1;
  int compress_count = 0;
  for (; !entry_p[i].IsEmpty(); i = (i + 1) % length, off++) {
    if (entry_p[i].tag > off) {
      entry_p[entry].key = entry_p[i].key;
      entry_p[entry].value = entry_p[i].value;
      entry_p[entry].step = entry_p[i].step;
      entry_p[entry].tag = entry_p[i].tag - off;
    if (entry_p[i].tag_ > off) {
      entry_p[entry].key_ = entry_p[i].key_;
      entry_p[entry].value_ = entry_p[i].value_;
      entry_p[entry].step_ = entry_p[i].step_;
      entry_p[entry].tag_ = entry_p[i].tag_ - off;
      entry_p[i].SetEmpty();
      off = 0;
      entry = i;
@@ -127,6 +92,7 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
  float total_count = 0;
  int count_size = 0;
  float hit_count = 0;
  // search_cache_idx
  for (size_t i = 0; i < batch_size_; ++i) {
    T key = input_indices[i] - offset;
@@ -140,7 +106,7 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
    while ((!hashmap[tmp_entry].IsEmpty() && !hashmap[tmp_entry].IsKey(key))) {
      tmp_entry = (tmp_entry + 1) % hashmap_length_;
      if (count > hashmap_length_) {
        MS_LOG(ERROR) << "Hashmap is full, search cache idx failed!";
        MS_LOG(EXCEPTION) << "Hashmap is full, search cache idx failed, please set a larger vocab_cache_size!";
        break;
      }
      count += 1;
@@ -153,8 +119,8 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
      miss_count++;
    } else {
      hit_count += 1;
      output_cache_idx[i] = hashmap[tmp_entry].value;
      hashmap[tmp_entry].step = step_[0];
      output_cache_idx[i] = hashmap[tmp_entry].value_;
      hashmap[tmp_entry].step_ = step_[0];
    }
  }
  if (miss_count != 0) {
@@ -175,27 +141,27 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
    while (!hashmap[entry].IsEmpty()) {
      entry = (entry + 1) % hashmap_length_;
      if (tag_count > hashmap_length_) {
        MS_LOG(ERROR) << "Hashmap is full, insert new key failed!";
        MS_LOG(EXCEPTION) << "Hashmap is full, insert new key failed, please set a larger vocab_cache_size!";
        break;
      }
      tag_count++;
    }
    hashmap[entry].key = emb_idx;
    hashmap[entry].step = step_[0];
    hashmap[entry].tag = tag_count;
    hashmap[entry].key_ = emb_idx;
    hashmap[entry].step_ = step_[0];
    hashmap[entry].tag_ = tag_count;
    T tmp_entry = (entry + 1) % hashmap_length_;
    size_t delete_count = 1;
    while (hashmap[tmp_entry].IsEmpty() || hashmap[tmp_entry].IsUsing(step_[0])) {
      tmp_entry = (tmp_entry + 1) % hashmap_length_;
      if (delete_count > hashmap_length_) {
        MS_LOG(ERROR) << "Hashmap is full, delete old key failed!";
        MS_LOG(EXCEPTION) << "Hashmap is full, delete old key failed, please set a larger vocab_cache_size!";
        break;
      }
      delete_count++;
    }
    output_swap_cache_idx[i] = hashmap[tmp_entry].value;
    output_old_emb_idx[i] = hashmap[tmp_entry].key;
    hashmap[entry].value = output_swap_cache_idx[i];
    output_swap_cache_idx[i] = hashmap[tmp_entry].value_;
    output_old_emb_idx[i] = hashmap[tmp_entry].key_;
    hashmap[entry].value_ = output_swap_cache_idx[i];
    hashmap[tmp_entry].SetEmpty();
    int compress_count = Compress(hashmap, hashmap_length_, tmp_entry);
    total_delete_count += (compress_count + delete_count);
@@ -23,8 +23,6 @@
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#define NULLTAG 0
namespace mindspore {
namespace kernel {
class MapCacheIdxCPUKernel : public CPUKernel {
@@ -188,12 +188,18 @@ void ReplaceOldNode(std::unordered_map<int64_t, BufferFusionInfo_t> *buffer_fusi
  MS_EXCEPTION_IF_NULL(manager);
  auto buffer_fusion_info = (*buffer_fusion_infos)[fusion_id];
  if (buffer_fusion_info.outputs_list.size() == 1) {  // single output
    if (kernel_graph != nullptr) {
      kernel_graph->FrontBackendlMapUpdate(buffer_fusion_info.outputs_list[0], buffer_fusion_kernel);
    }
    (void)manager->Replace(buffer_fusion_info.outputs_list[0], buffer_fusion_kernel);
    ReplaceInputNodeInOtherFusionScope(buffer_fusion_infos, fusion_id, buffer_fusion_info.outputs_list[0],
                                       buffer_fusion_kernel);
  } else {  // multiple output
    for (size_t index = 0; index < buffer_fusion_info.outputs_list.size(); ++index) {
      auto tuple_item = CreateTupleGetItem(buffer_fusion_kernel, kernel_graph, index);
      if (kernel_graph != nullptr) {
        kernel_graph->FrontBackendlMapUpdate(buffer_fusion_info.outputs_list[index], tuple_item);
      }
      (void)manager->Replace(buffer_fusion_info.outputs_list[index], tuple_item);
      ReplaceInputNodeInOtherFusionScope(buffer_fusion_infos, fusion_id, buffer_fusion_info.outputs_list[index],
                                         tuple_item);
@@ -274,6 +274,10 @@ tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple) {
bool IsNopNode(const AnfNodePtr &node) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  auto target = GetCNodeTarget(node);
  if (target == kCPUDevice) {
    return false;
  }
  if (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kAscendDevice &&
      context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
    return false;
@@ -0,0 +1,535 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "frontend/parallel/cache_embedding/cache_embedding.h"
#include <random>
#include <vector>
#include <list>
#include <utility>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <string>
#include <algorithm>
#include "backend/optimizer/common/helper.h"
#include "frontend/optimizer/optimizer.h"
#include "ir/func_graph.h"
#include "utils/cache_embedding_hashmap_struct.h"
namespace mindspore {
namespace parallel {
using ParamMap = std::unordered_map<ParameterPtr, ParameterPtr>;
using ParamSet = std::unordered_set<ParameterPtr>;
using NodePairList = std::vector<std::pair<AnfNodePtr, AnfNodePtr>>;

ParamMap AddCacheParameters(const FuncGraphPtr &graph, const ParamSet &parameter_cache_enable_set) {
  ParamMap cache_host_params_map;
  for (auto &param : parameter_cache_enable_set) {
    auto param_info = param->param_info();
    if (param_info && param_info->cache_enable()) {
      auto data_type = param->Type();
      auto data_element_type = data_type->cast<mindspore::TensorTypePtr>()->element();
      auto type_id = data_element_type->type_id();
      auto cache_shape = param_info->cache_shape();
      auto ori_param_name = param->name();
      auto new_tensor = std::make_shared<tensor::Tensor>(type_id, cache_shape);
      ParamInfoPtr new_param_info = std::make_shared<ParamInfo>();
      auto cache_name = ori_param_name + "_cache";
      new_param_info->set_name(cache_name);
      new_tensor->set_param_info(new_param_info);
      auto cache_param = graph->AddWeightParameter(cache_name);
      cache_param->set_default_param(MakeValue(new_tensor));
      cache_param->set_abstract(new_tensor->ToAbstract());
      cache_host_params_map[cache_param] = param;
    }
  }
  return cache_host_params_map;
}
bool CheckHostCacheParamSize(const ParamSet &parameter_cache_enable_set) {
  int64_t host_size = 0;
  int64_t cache_size = 0;
  for (auto &host_param : parameter_cache_enable_set) {
    auto tmp_host_size = host_param->abstract()->GetShapeTrack()->cast<abstract::ShapePtr>()->shape()[0];
    auto host_param_info = host_param->param_info();
    auto cache_shape = host_param_info->cache_shape();
    if (cache_shape.empty()) {
      MS_LOG(EXCEPTION) << "The value of cache_shape is empty.";
    }
    auto tmp_cache_size = cache_shape[0];
    if ((host_size != 0 && tmp_host_size != host_size) || (cache_size != 0 && tmp_cache_size != cache_size)) {
      MS_LOG(EXCEPTION)
        << "If EmbeddingLookup cells are cache enabled, the vocab_size and vocab_cache_size of the different cells must be the same.";
    }
    cache_size = tmp_cache_size;
    host_size = tmp_host_size;
  }
  if (cache_size >= host_size) {
    MS_LOG(WARNING) << "vocab_cache_size >= vocab_size, there is no need to use the cache.";
    return false;
  }
  return true;
}
void ReplaceCacheParams(const FuncGraphPtr &graph, const ParamMap &map) {
  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  for (auto &ele : map) {
    if (!manager->Replace(ele.second, ele.first)) {
      MS_LOG(EXCEPTION) << "host param: " << ele.second->name() << ", replace node failed.";
    }
  }
}

ParamSet MapKeysToSet(const ParamMap &map) {
  ParamSet set;
  for (auto &ele : map) {
    set.insert(ele.first);
  }
  return set;
}

ParamSet FindParamCacheEnable(const FuncGraphPtr &graph) {
  ParamSet parameter_cache_enable_set;
  auto parameters = graph->parameters();
  auto params_size = parameters.size();
  for (size_t i = 0; i < params_size; ++i) {
    auto param = parameters[i]->cast<ParameterPtr>();
    auto param_info = param->param_info();
    if (param_info && param_info->cache_enable()) {
      parameter_cache_enable_set.insert(param);
    }
  }
  return parameter_cache_enable_set;
}
CNodePtrList FindUniqueCacheEnable(const CNodePtrList &cnodes) {
  size_t cnodes_size = cnodes.size();
  CNodePtrList unique_cache_enable;
  for (size_t i = 0; i < cnodes_size; ++i) {
    if (IsPrimitiveCNode(cnodes[i], prim::kPrimUnique)) {
      auto unique_node = cnodes[i];
      auto unique_prim = GetCNodePrimitive(unique_node);
      MS_EXCEPTION_IF_NULL(unique_prim);
      auto attr_value = unique_prim->GetAttr(kAttrCacheEnable);
      if (attr_value != nullptr && GetValue<bool>(attr_value)) {
        unique_cache_enable.emplace_back(unique_node);
      }
    }
  }
  if (unique_cache_enable.size() > 1) {
    MS_LOG(EXCEPTION) << "Only one cache-enabled Unique op is supported, but got " << unique_cache_enable.size();
  }
  return unique_cache_enable;
}
void BindAndInitCacheTensor(const ParamMap &param_pair_list, const ParameterPtr &hashmap) {
  auto hashmap_tensor_value = hashmap->default_param();
  auto hashmap_tensor = hashmap_tensor_value->cast<std::shared_ptr<tensor::Tensor>>();
  for (auto &ele : param_pair_list) {
    auto host_tensor_value = ele.second->default_param();
    auto host_tensor = host_tensor_value->cast<std::shared_ptr<tensor::Tensor>>();
    auto cache_tensor_value = ele.first->default_param();
    auto cache_tensor = cache_tensor_value->cast<std::shared_ptr<tensor::Tensor>>();
    // bind host, cache, hashmap
    host_tensor->set_cache_enable(true);
    host_tensor->set_hashmap_tensor_ptr(hashmap_tensor);
    host_tensor->set_cache_tensor_ptr(cache_tensor);
    // init cache tensor data
    auto cache_byte_size = cache_tensor->Size();
    int ret = memcpy_s(cache_tensor->data_c(), cache_byte_size, host_tensor->data_c(), cache_byte_size);
    if (ret != 0) {
      MS_LOG(EXCEPTION) << "Memcpy failed.";
    }
  }
}

template <typename T>
void InitHashMapData(void *data, const int64_t host_size, const int64_t cache_size, const size_t hashmap_size,
                     const size_t byte_size) {
  MS_LOG(INFO) << "Start init hashmap data.";
  MS_EXCEPTION_IF_NULL(data);
  HashmapEntry<T> *hashmap_data = static_cast<HashmapEntry<T> *>(data);
  MS_EXCEPTION_IF_NULL(hashmap_data);
  int ret = memset_s(hashmap_data, byte_size, 0, byte_size);
  if (ret != 0) {
    MS_LOG(EXCEPTION) << "Memset failed.";
  }
  std::vector<T> host_range;
  host_range.reserve(host_size);
  for (int64_t i = 0; i < host_size; ++i) {
    host_range.emplace_back(i);
  }
  std::random_shuffle(host_range.begin(), host_range.end());
  size_t size = cache_size;
  size_t hashmap_count = 0;
  for (size_t i = 0; i < size; ++i) {
    auto random_key = host_range[i];
    auto entry = HashFunc(random_key, hashmap_size);
    size_t count = 1;
    while (!hashmap_data[entry].IsEmpty() && !hashmap_data[entry].IsKey(random_key)) {
      count += 1;
      entry = (entry + 1) % hashmap_size;
    }
    if (hashmap_data[entry].IsEmpty()) {
      hashmap_count++;
      hashmap_data[entry].key_ = random_key;
      hashmap_data[entry].value_ = i;
      hashmap_data[entry].step_ = kInitStep;
      hashmap_data[entry].tag_ = count;
    }
  }
  MS_LOG(INFO) << "Hashmap init success, with " << hashmap_count << " / " << hashmap_size;
}
AnfNodePtr InitHashMap(const FuncGraphPtr &func_graph, const int64_t host_size, const int64_t cache_size,
                       TypeId type_id) {
  // init new tensor
  size_t hashmap_size = cache_size * kEmptyRate;
  std::vector<int64_t> host_shape{static_cast<int64_t>(hashmap_size), 4};
  auto new_tensor = std::make_shared<tensor::Tensor>(type_id, host_shape);
  size_t byte_size = new_tensor->Size();
  if (type_id == TypeId::kNumberTypeInt64) {
    InitHashMapData<int64_t>(new_tensor->data_c(), host_size, cache_size, hashmap_size, byte_size);
  } else {
    InitHashMapData<int32_t>(new_tensor->data_c(), host_size, cache_size, hashmap_size, byte_size);
  }
  ParamInfoPtr new_param_info = std::make_shared<ParamInfo>();
  std::string hashmap_name = "cache_hashmap";
  new_param_info->set_name(hashmap_name);
  new_tensor->set_param_info(new_param_info);
  auto hashmap = func_graph->AddWeightParameter(hashmap_name);
  hashmap->set_default_param(MakeValue(new_tensor));
  hashmap->set_abstract(new_tensor->ToAbstract());
  return hashmap;
}

AnfNodePtr InitStep(const FuncGraphPtr &func_graph, TypeId type_id) {
  std::vector<int64_t> host_shape{1};
  auto new_tensor = std::make_shared<tensor::Tensor>(type_id, host_shape);
  auto step_data = static_cast<int64_t *>(new_tensor->data_c());
  step_data[0] = 0;
  ParamInfoPtr new_param_info = std::make_shared<ParamInfo>();
  std::string step_name = "cache_step";
  new_param_info->set_name(step_name);
  new_tensor->set_param_info(new_param_info);
  auto step = func_graph->AddWeightParameter(step_name);
  step->set_default_param(MakeValue(new_tensor));
  step->set_abstract(new_tensor->ToAbstract());
  return step;
}
AnfNodePtr CreateMapCacheIdx(const FuncGraphPtr &func_graph, const AnfNodePtr &indices,
                             const ParamMap &cache_host_params_map) {
  auto iter = cache_host_params_map.begin();
  int64_t cache_size = iter->first->abstract()->GetShapeTrack()->cast<abstract::ShapePtr>()->shape()[0];
  int64_t host_size = iter->second->abstract()->GetShapeTrack()->cast<abstract::ShapePtr>()->shape()[0];
  auto indices_type = indices->Type();
  auto indices_element_type = indices_type->cast<mindspore::TensorTypePtr>()->element();
  auto indices_type_id = indices_element_type->type_id();
  auto hashmap = InitHashMap(func_graph, host_size, cache_size, indices_type_id);
  auto step = InitStep(func_graph, indices_type_id);
  auto max_num = NewValueNode(MakeValue(host_size));
  auto hashmap_param = hashmap->cast<ParameterPtr>();
  BindAndInitCacheTensor(cache_host_params_map, hashmap_param);
  // add rank_id
  int64_t offset_value = 0;
  std::string rank_id_str = common::GetEnv("RANK_ID");
  if (!rank_id_str.empty()) {
    int64_t rank_id = atoi(rank_id_str.c_str());
    offset_value = rank_id * host_size;
  }
  auto offset = NewValueNode(MakeValue(offset_value));
  auto max_num_imm = std::make_shared<Int64Imm>(SizeToLong(host_size));
  auto max_num_abstract_scalar = std::make_shared<abstract::AbstractScalar>(max_num_imm);
  max_num->set_abstract(max_num_abstract_scalar);
  auto offset_imm = std::make_shared<Int64Imm>(SizeToLong(offset_value));
  auto offset_abstract_scalar = std::make_shared<abstract::AbstractScalar>(offset_imm);
  offset->set_abstract(offset_abstract_scalar);
  PrimitivePtr map_cache_primitive = prim::kPrimMapCacheIdx;
  map_cache_primitive->set_attr(kAttrPrimitiveTarget, MakeValue("CPU"));
  std::vector<AnfNodePtr> map_cache_nodes{NewValueNode(map_cache_primitive), hashmap, indices, step, max_num, offset};
  auto map_cache_idx = func_graph->NewCNode(map_cache_nodes);
  auto indices_ori_shp = indices->Shape();
  auto indices_shp = indices_ori_shp->cast<abstract::ShapePtr>();
  ShapeVector shape;
  ShapeVector min_shape;
  ShapeVector max_shape;
  if (!indices_shp->max_shape().empty()) {
    max_shape = indices_shp->max_shape();
  } else {
    max_shape = indices_shp->shape();
  }
  for (size_t i = 0; i < max_shape.size(); i++) {
    shape.emplace_back(-1);
    min_shape.emplace_back(1);
  }
  auto cache_idx = std::make_shared<abstract::AbstractTensor>(indices_element_type, indices_shp);
  auto old_emb_idx = std::make_shared<abstract::AbstractTensor>(
    indices_element_type, std::make_shared<abstract::Shape>(shape, min_shape, max_shape));
  auto miss_emb_idx = std::make_shared<abstract::AbstractTensor>(
    indices_element_type, std::make_shared<abstract::Shape>(shape, min_shape, max_shape));
  auto swap_emb_idx = std::make_shared<abstract::AbstractTensor>(
    indices_element_type, std::make_shared<abstract::Shape>(shape, min_shape, max_shape));
  std::vector<std::shared_ptr<abstract::AbstractBase>> elements = {cache_idx, old_emb_idx, miss_emb_idx, swap_emb_idx};
  auto abstract = std::make_shared<abstract::AbstractTuple>(elements);
  map_cache_idx->set_abstract(abstract);
  return map_cache_idx;
}
AnfNodePtr CreateTupleGetItem(const FuncGraphPtr &func_graph, const AnfNodePtr &input, size_t index) {
  MS_EXCEPTION_IF_NULL(func_graph);
  auto idx = NewValueNode(SizeToLong(index));
  MS_EXCEPTION_IF_NULL(idx);
  auto imm = std::make_shared<Int64Imm>(SizeToLong(index));
  auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
  idx->set_abstract(abstract_scalar);
  auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, idx});
  auto input_abstract_tuple = dyn_cast<abstract::AbstractTuple>(input->abstract());
  auto tuple_getitem_abstract = input_abstract_tuple->elements()[index];
  tuple_getitem->set_abstract(tuple_getitem_abstract);
  return tuple_getitem;
}

void CreateTupleGetItems(const FuncGraphPtr &func_graph, const AnfNodePtr &input, std::vector<AnfNodePtr> *outputs) {
  auto input_abstract_tuple = dyn_cast<abstract::AbstractTuple>(input->abstract());
  auto size = input_abstract_tuple->elements().size();
  for (size_t i = 0; i < size; ++i) {
    (*outputs).emplace_back(CreateTupleGetItem(func_graph, input, i));
  }
  MS_EXCEPTION_IF_NULL(outputs);
}
AnfNodePtr CreateEmbeddingLookup(const FuncGraphPtr &graph, ParameterPtr params, AnfNodePtr indices) {
  MS_EXCEPTION_IF_NULL(graph);
  PrimitivePtr emb_lookup_primitive = prim::kPrimEmbeddingLookup;
  emb_lookup_primitive->set_attr(kAttrPrimitiveTarget, MakeValue("CPU"));
  emb_lookup_primitive->set_attr(kAttrOffset, MakeValue<int64_t>(0));
  std::vector<AnfNodePtr> emb_lookup_nodes{NewValueNode(emb_lookup_primitive), params, indices};
  auto emb_lookup = graph->NewCNode(emb_lookup_nodes);
  return emb_lookup;
}

AnfNodePtr CreateCacheSwapTable(const FuncGraphPtr &graph, ParameterPtr cache_table, AnfNodePtr swap_cache_idx,
                                AnfNodePtr miss_value) {
  MS_EXCEPTION_IF_NULL(graph);
  PrimitivePtr cache_swap_table_primitive = prim::kPrimCacheSwapTable;
  std::vector<AnfNodePtr> cache_swap_table_nodes{NewValueNode(cache_swap_table_primitive), cache_table, swap_cache_idx,
                                                 miss_value};
  auto cache_swap_table = graph->NewCNode(cache_swap_table_nodes);
  return cache_swap_table;
}

AnfNodePtr CreateUpdateCache(const FuncGraphPtr &graph, ParameterPtr params, AnfNodePtr old_emb_idx,
                             AnfNodePtr old_value) {
  MS_EXCEPTION_IF_NULL(graph);
  PrimitivePtr update_cache_primitive = prim::kPrimUpdateCache;
  update_cache_primitive->set_attr(kAttrPrimitiveTarget, MakeValue("CPU"));
  auto params_ori_shp = params->Shape();
  MS_EXCEPTION_IF_NULL(params_ori_shp);
  auto params_shp = params_ori_shp->cast<abstract::ShapePtr>();
  MS_EXCEPTION_IF_NULL(params_shp);
  auto params_shape = params_shp->shape();
  auto max_size = params_shape[0];
  auto max_size_node = NewValueNode(MakeValue(max_size));
  auto max_num_imm = std::make_shared<Int64Imm>(SizeToLong(max_size));
  auto max_num_abstract_scalar = std::make_shared<abstract::AbstractScalar>(max_num_imm);
  max_size_node->set_abstract(max_num_abstract_scalar);
  std::vector<AnfNodePtr> update_cache_nodes{NewValueNode(update_cache_primitive), params, old_emb_idx, old_value,
                                             max_size_node};
  auto update_cache = graph->NewCNode(update_cache_nodes);
  return update_cache;
}

NodePairList CreateEmbSwapUpdate(const FuncGraphPtr &graph, ParamMap param_pair_list,
                                 const AnfNodePtrList &map_cache_idx_node_outputs) {
  MS_EXCEPTION_IF_NULL(graph);
  NodePairList node_pair_list;
  for (auto &ele : param_pair_list) {
    auto emb_lookup = CreateEmbeddingLookup(graph, ele.second, map_cache_idx_node_outputs[2]);
    auto cache_swap_table = CreateCacheSwapTable(graph, ele.first, map_cache_idx_node_outputs[3], emb_lookup);
    auto update_cache = CreateUpdateCache(graph, ele.second, map_cache_idx_node_outputs[1], cache_swap_table);
    node_pair_list.emplace_back(std::make_pair(cache_swap_table, update_cache));
  }
  return node_pair_list;
}

AnfNodePtr CreateControlDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &prior_node,
                               const AnfNodePtr &behind_node) {
  // Create control depend
  MS_EXCEPTION_IF_NULL(main_graph);
  AnfNodePtrList cd_inputs = {NewValueNode(prim::kPrimControlDepend), prior_node, behind_node};
  auto control_depend_cnode = main_graph->NewCNode(cd_inputs);
  return control_depend_cnode;
}

AnfNodePtr CreateDepend(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &invalid_nodes,
                        const AnfNodePtr &patron_node) {
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> make_tuple_list{NewValueNode(prim::kPrimMakeTuple)};
  std::copy(invalid_nodes.begin(), invalid_nodes.end(), std::back_inserter(make_tuple_list));
  auto make_tuple = graph->NewCNode(make_tuple_list);
  std::vector<AnfNodePtr> depend_list{NewValueNode(prim::kPrimDepend), patron_node, make_tuple};
  auto depend_cnode = graph->NewCNode(depend_list);
  depend_cnode->set_abstract(patron_node->abstract());
  return depend_cnode;
}
CNodePtrList FindSparseGatherV2WithCache(const CNodePtrList &cnodes, const ParamSet &param_set) {
  size_t cnodes_size = cnodes.size();
  CNodePtrList sparse_gather_v2_with_cache;
  for (size_t i = 0; i < cnodes_size; ++i) {
    if (IsPrimitiveCNode(cnodes[i], prim::kPrimSparseGatherV2)) {
      auto param_node = cnodes[i]->input(1)->cast<ParameterPtr>();
      if (param_set.find(param_node) != param_set.end()) {
        sparse_gather_v2_with_cache.push_back(cnodes[i]);
      }
    }
  }
  if (sparse_gather_v2_with_cache.empty()) {
    MS_LOG(EXCEPTION) << "Cannot find a SparseGatherV2 with a cache-enabled param.";
  }
  auto indices = sparse_gather_v2_with_cache[0]->input(2);
  for (auto &ele : sparse_gather_v2_with_cache) {
    if (ele->input(2) != indices) {
      MS_LOG(EXCEPTION) << "SparseGatherV2 nodes with cache-enabled params must share the same indices!";
    }
  }
  return sparse_gather_v2_with_cache;
}

AnfNodePtr FindGatherV2FromSparseGatherV2(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(graph);
  AnfNodePtrList gatherv2_nodes;
  auto user_set = graph->manager()->node_users()[node];
  for (auto &ele : user_set) {
    if (IsPrimitiveCNode(ele.first, prim::kPrimGatherV2)) {
      gatherv2_nodes.emplace_back(ele.first);
    }
  }
  if (gatherv2_nodes.size() != 1) {
    MS_LOG(EXCEPTION) << "SparseGatherV2 with cache can only be used by exactly one GatherV2, but got "
                      << gatherv2_nodes.size();
  }
  return gatherv2_nodes[0];
}
void AddCacheEmbedding(const FuncGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  std::list<CNodePtr> orders = graph->GetOrderedCnodes();
  CNodePtrList cnodes(orders.begin(), orders.end());
  size_t cnodes_size = cnodes.size();
  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  bool training = graph->has_flag("training");
  auto param_cache_enable_set = FindParamCacheEnable(graph);
  if (param_cache_enable_set.empty()) {
    MS_LOG(INFO) << "No parameter is cache enabled.";
    return;
  } else {
    MS_LOG(INFO) << "Found cache-enabled parameters.";
  }
  if (!CheckHostCacheParamSize(param_cache_enable_set)) {
    return;
  }
  if (training) {
    // If training, create cache parameters corresponding to the cache-enabled host params and replace the
    // host params. Create the hashmap, then insert a MapCacheIdx op after the Unique op that has the
    // 'cache_enable' attr.
    // Bind the hashmap tensor ptr and cache tensor ptr to the host tensor, so that values can be flushed
    // from cache to host at the end of each epoch.
    // Create EmbeddingLookup(CPU), CacheSwapTable(Ascend), UpdateCache(CPU) for each pair of params, in order to
    // flush miss values to cache params and write back old values to host params.
    // If the pipeline is not used in training, EmbeddingLookup and CacheSwapTable must execute before
    // SparseGatherV2, so add ControlDepend between them. And add Depend for the UpdateCache op and the
    // ControlDepend op to add those nodes into the graph.
    auto unique_cache_enable = FindUniqueCacheEnable(cnodes);
    if (unique_cache_enable.empty()) {
      MS_LOG(WARNING) << "Parameters are cache enabled, but no cache-enabled Unique op was found.";
      return;
    }
    auto cache_host_params_map = AddCacheParameters(graph, param_cache_enable_set);
    auto param_set = MapKeysToSet(cache_host_params_map);
    ReplaceCacheParams(graph, cache_host_params_map);
    graph->set_flag(GRAPH_FLAG_CACHE_ENABLE, true);
    auto unique_node = unique_cache_enable[0];
    CNodePtrList sparse_gatherv2_with_cache = FindSparseGatherV2WithCache(cnodes, param_set);
    auto unique_node_output_0 = CreateTupleGetItem(graph, unique_node, 0);
    auto map_cache_idx = CreateMapCacheIdx(graph, unique_node_output_0, cache_host_params_map);
    AnfNodePtrList map_cache_idx_node_outputs;
    CreateTupleGetItems(graph, map_cache_idx, &map_cache_idx_node_outputs);
    if (!manager->Replace(sparse_gatherv2_with_cache[0]->input(2), map_cache_idx_node_outputs[0])) {
      MS_LOG(EXCEPTION) << "MapCacheIdx output[0] replace node failed";
    }
    auto node_pair_list = CreateEmbSwapUpdate(graph, cache_host_params_map, map_cache_idx_node_outputs);
    AnfNodePtr last_node = cnodes[cnodes_size - 1];
    CNodePtr return_node;
    if (last_node->isa<CNode>()) {
      return_node = last_node->cast<CNodePtr>();
    }
    MS_EXCEPTION_IF_NULL(return_node);
    if (!IsPrimitiveCNode(return_node, prim::kPrimReturn)) {
      MS_LOG(EXCEPTION) << "The last cnode after sorting is not a Return cnode.";
    }
    if (return_node->inputs().size() < 2) {
      MS_LOG(EXCEPTION) << "Number of return node inputs should be greater than or equal to 2.";
    }
    AnfNodePtrList invalid_nodes;
    for (auto &ele : node_pair_list) {
      std::transform(sparse_gatherv2_with_cache.begin(), sparse_gatherv2_with_cache.end(),
                     std::back_inserter(invalid_nodes), [&graph, &ele](const AnfNodePtr &sparse_gatherv2) {
                       return CreateControlDepend(graph, ele.first, sparse_gatherv2);
                     });
      invalid_nodes.emplace_back(ele.second);
    }
    auto depend_node = CreateDepend(graph, invalid_nodes, return_node->input(1));
    if (!manager->Replace(return_node->input(1), depend_node)) {
      MS_LOG(EXCEPTION) << "Depend replace node failed";
    }
  } else {
    // If eval, use the EmbeddingLookup(CPU) op to replace GatherV2.
    // The network is the same as Host-Device mode.
    auto unique_cache_enable = FindUniqueCacheEnable(cnodes);
    if (unique_cache_enable.empty()) {
      MS_LOG(WARNING) << "Parameters are cache enabled, but no cache-enabled Unique op was found.";
      return;
    }
    graph->set_flag(GRAPH_FLAG_CACHE_ENABLE, true);
    // replace GatherV2 with EmbeddingLookup(CPU)
    auto indices = unique_cache_enable[0]->input(1);
    auto sparse_gatherv2_with_cache = FindSparseGatherV2WithCache(cnodes, param_cache_enable_set);
    for (auto &ele : sparse_gatherv2_with_cache) {
      auto anf_ele = ele->cast<AnfNodePtr>();
      auto gatherv2 = FindGatherV2FromSparseGatherV2(graph, anf_ele);
      auto param = ele->input(1)->cast<ParameterPtr>();
      auto embedding_lookup = CreateEmbeddingLookup(graph, param, indices);
      if (!manager->Replace(gatherv2, embedding_lookup)) {
        MS_LOG(EXCEPTION) << "EmbeddingLookup replace node failed";
      }
    }
  }
}
} // namespace parallel
} // namespace mindspore
@@ -0,0 +1,28 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_CACHE_EMBEDDING_CACHE_EMBEDDING_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_CACHE_EMBEDDING_CACHE_EMBEDDING_H_
#include "ir/anf.h"
namespace mindspore {
namespace parallel {
// Rewrite the graph for host-device cache embedding: replace cache-enabled host params with cache params
// and insert the MapCacheIdx/EmbeddingLookup/CacheSwapTable/UpdateCache ops that keep them in sync.
void AddCacheEmbedding(const FuncGraphPtr &graph);
} // namespace parallel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_CACHE_EMBEDDING_CACHE_EMBEDDING_H_
@@ -36,11 +36,14 @@
#include "frontend/optimizer/graph_transform.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/step_auto_parallel.h"
#include "frontend/parallel/cache_embedding/cache_embedding.h"
#include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h"
#include "frontend/optimizer/recompute.h"
#include "utils/log_adapter.h"
#include "pipeline/jit/pipeline_split.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
#include "ps/util.h"
#endif
namespace mindspore {
namespace pipeline {
using OptPassGroupMap = opt::OptPassGroupMap;
@@ -391,6 +394,26 @@ bool AddRecomputationPass(const ResourcePtr &res) {
  return true;
}

bool AddCacheEmbeddingPass(const ResourcePtr &res) {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
  if (ps::Util::IsParamServerMode()) {
    return true;
  }
#endif
  FuncGraphPtr func_graph = res->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  parallel::AddCacheEmbedding(func_graph);
  if (func_graph->has_flag(GRAPH_FLAG_CACHE_ENABLE)) {
    auto params = func_graph->parameters();
    AbstractBasePtrList args_spec_list;
    std::for_each(params.begin(), params.end(),
                  [&args_spec_list](const AnfNodePtr &node) { args_spec_list.push_back(node->abstract()); });
    func_graph = pipeline::Renormalize(res, func_graph, args_spec_list);
  }
  return true;
}

bool MergeDupGraphPass(const ResourcePtr &res) {
  FuncGraphPtr func_graph = res->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
@@ -500,6 +523,7 @@ std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStru
                                   {"tuple_transform", OptPassTransformGraphGroup},
                                   {"opt_graph_kernel_a", OptPassGraphKernelGroupA},
                                   {"opt_graph_kernel_b", OptPassGraphKernelGroupB},
                                   {"add_cache_embedding", AddCacheEmbeddingPass},
                                   {"add_control_depend", AddControlDependPass},
                                   {"add_recomputation", AddRecomputationPass}};
@@ -37,6 +37,7 @@ bool PipelineSplitPass(const ResourcePtr &res);
bool ValidatePass(const ResourcePtr &res);
bool ConvertPrepareAdapt(const ResourcePtr &res);
bool AddControlDependPass(const ResourcePtr &res);
bool AddCacheEmbeddingPass(const ResourcePtr &res);
bool InferenceOptPreparePass(const ResourcePtr &res);
void ReclaimOptimizer();
bool PynativeOptPass(const ResourcePtr &res);
@@ -32,6 +32,8 @@ REGISTER_PYBIND_DEFINE(ParamInfo, ([](const py::module *m) {
    .def_property("parallel_optimizer", &ParamInfo::parallel_optimizer,
                  &ParamInfo::set_parallel_optimizer)
    .def_property("comm_fusion", &ParamInfo::comm_fusion, &ParamInfo::set_comm_fusion)
    .def_property("cache_enable", &ParamInfo::cache_enable, &ParamInfo::set_cache_enable)
    .def_property("cache_shape", &ParamInfo::cache_shape, &ParamInfo::set_cache_shape)
    .def(py::pickle(
      [](const ParamInfo &p) {  // __getstate__
        return py::make_tuple(p.name(), p.requires_grad(), p.layerwise_parallel());
@@ -24,6 +24,7 @@
#include "pybind_api/api_register.h"
#include "abstract/abstract_value.h"
#include "utils/shape_utils.h"
#include "utils/cache_embedding_hashmap_struct.h"
namespace mindspore {
namespace tensor {
@@ -272,6 +273,68 @@ py::int_ TensorPy::GetPyItemSize(const Tensor &tensor) { return tensor.data().it
py::int_ TensorPy::GetPyNBytes(const Tensor &tensor) { return tensor.data().nbytes(); }

template <typename T>
void MemCopyFromCacheToHost(void *hashmap_addr, void *host_addr, void *cache_addr, size_t host_max, size_t cache_max,
                            size_t hashmap_size, size_t col_size) {
  auto host_data = static_cast<char *>(host_addr);
  auto cache_data = static_cast<char *>(cache_addr);
  auto hashmap_data = static_cast<HashmapEntry<T> *>(hashmap_addr);
  // default param type float
  size_t param_type_size = 4;
  size_t single_col_bytes = param_type_size * col_size;
  for (size_t i = 0; i < hashmap_size; ++i) {
    if (!hashmap_data[i].IsEmpty()) {
      size_t host_offset = single_col_bytes * hashmap_data[i].key_;
      size_t cache_offset = single_col_bytes * hashmap_data[i].value_;
      if (cache_offset + single_col_bytes <= cache_max) {
        auto ret =
          memcpy_s(host_data + host_offset, host_max - host_offset, cache_data + cache_offset, single_col_bytes);
        if (ret != 0) {
          MS_LOG(EXCEPTION) << "Memcpy failed.";
        }
      }
    }
  }
  MS_LOG(INFO) << "Memcpy from cache to host success!";
}
void TensorPy::FlushFromCache(const Tensor &tensor) {
  py::gil_scoped_release gil_release;
  if (tensor.NeedWait()) {
    tensor.Wait();
  }
  tensor.data_sync();
  if (tensor.cache_enable()) {
    MS_LOG(INFO) << tensor.ToString() << " is cache enabled.";
    auto hashmap_tensor_ptr = tensor.hashmap_tensor_ptr();
    auto cache_tensor_ptr = tensor.cache_tensor_ptr();
    if (hashmap_tensor_ptr != nullptr && cache_tensor_ptr != nullptr) {
      hashmap_tensor_ptr->data_sync();
      cache_tensor_ptr->data_sync();
      auto hashmap_size = hashmap_tensor_ptr->shape_c()[0];
      auto host_shape = tensor.shape_c();
      auto cache_shape = cache_tensor_ptr->shape_c();
      if (host_shape.size() != 2 || cache_shape.size() != 2 || host_shape[1] != cache_shape[1]) {
        MS_LOG(EXCEPTION) << "Host shape or cache shape is invalid. "
                          << "host shape:" << host_shape << ", cache shape:" << cache_shape;
      }
      auto host_data_max_size = tensor.Size();
      auto cache_data_max_size = cache_tensor_ptr->Size();
      auto hashmap_data_type = hashmap_tensor_ptr->data_type();
      if (hashmap_data_type == TypeId::kNumberTypeInt32) {
        MemCopyFromCacheToHost<int32_t>(hashmap_tensor_ptr->data_c(), tensor.data_c(), cache_tensor_ptr->data_c(),
                                        host_data_max_size, cache_data_max_size, hashmap_size, host_shape[1]);
      } else if (hashmap_data_type == TypeId::kNumberTypeInt64) {
        MemCopyFromCacheToHost<int64_t>(hashmap_tensor_ptr->data_c(), tensor.data_c(), cache_tensor_ptr->data_c(),
                                        host_data_max_size, cache_data_max_size, hashmap_size, host_shape[1]);
      } else {
        MS_LOG(ERROR) << "Hashmap dtype only supports int32 and int64.";
      }
    }
  }
}
py::array TensorPy::SyncAsNumpy(const Tensor &tensor) {
  {
    py::gil_scoped_release gil_release;
@@ -457,6 +520,16 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
                             array([[1., 1., 1.],
                                    [1., 1., 1.]])
                         )mydelimiter")
                         .def("_flush_from_cache", TensorPy::FlushFromCache, R"mydelimiter(
                             Flush cache data to host if the tensor is cache enabled.
                             Returns:
                                 None.
                             Examples:
                                 >>> data = mindspore.Tensor(np.ones((2, 3)))
                                 >>> data._flush_from_cache()
                         )mydelimiter")
                         .def("is_init", &Tensor::is_init, R"mydelimiter(
                             Get tensor init_flag.
@@ -115,6 +115,8 @@ class TensorPy {
  static py::int_ GetPyItemSize(const Tensor &tensor);
  static py::int_ GetPyNBytes(const Tensor &tensor);

  static void FlushFromCache(const Tensor &tensor);
};
} // namespace tensor
} // namespace mindspore
@@ -268,7 +268,7 @@ void CPUKernelRuntime::CreateOutputTensors(session::KernelGraph *kernel_graph,
  bound_addresses_.clear();
  auto output_nodes = kernel_graph->outputs();
  for (const auto &item : output_nodes) {
    auto item_with_index = AnfAlgo::VisitKernelWithReturnType(item, 0, true);
    auto item_with_index = AnfAlgo::VisitKernelWithReturnType(item, 0, false);
    auto out = CreatTensorForOutput(kernel_graph, item_with_index, tensor_to_node);
    outputs->push_back(std::move(out));
  }
@@ -0,0 +1,51 @@
/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_UTIL_CACHE_EMBBEDDING_HASHMAP_STRUCT_H_
#define MINDSPORE_CCSRC_UTIL_CACHE_EMBBEDDING_HASHMAP_STRUCT_H_
#include <math.h>
namespace mindspore {
const int64_t kNullTag = 0;
const int64_t kInitStep = -5;
const int64_t kEmptyRate = 4;
const double kGoldenRatio = 0.6180339;
template <typename T>
struct HashmapEntry {
  T key_;
  T value_;
  T step_;
  T tag_;
  bool IsEmpty() { return tag_ == kNullTag; }
  bool IsUsing(const T train_step) { return step_ >= (train_step - 1); }
  bool IsKey(const T emb_idx) { return key_ == emb_idx; }
  void SetEmpty() { tag_ = kNullTag; }
};
template <typename T>
T HashFunc(const T key, const size_t m) {
  return (T)(((kGoldenRatio * key) - floor(kGoldenRatio * key)) * m);
}
} // namespace mindspore
#endif // MINDSPORE_CCSRC_UTIL_CACHE_EMBBEDDING_HASHMAP_STRUCT_H_
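A minimal, self-contained sketch of how the HashmapEntry/HashFunc pair above is used by InitHashMapData and the MapCacheIdx kernel (illustration only; it assumes just this new header on the include path): a key's home slot is HashFunc(key, size), collisions are resolved by linear probing, and tag_ records the probe count so that kNullTag (0) can still mark an empty slot.

#include <cstdint>
#include <cstdio>
#include <vector>
#include "utils/cache_embedding_hashmap_struct.h"

int main() {
  const size_t hashmap_size = 8;
  // Value-initialized entries have tag_ == kNullTag, so every slot starts out empty.
  std::vector<mindspore::HashmapEntry<int32_t>> hashmap(hashmap_size);

  // Insert a few keys: probe from the home slot until an empty slot is found,
  // storing the probe count in tag_ (it starts at 1 so 0 keeps meaning "empty").
  const int32_t keys[] = {3, 11, 19};
  for (int32_t key : keys) {
    int32_t entry = mindspore::HashFunc(key, hashmap_size);
    int32_t tag = 1;
    while (!hashmap[entry].IsEmpty() && !hashmap[entry].IsKey(key)) {
      entry = (entry + 1) % static_cast<int32_t>(hashmap_size);
      ++tag;
    }
    hashmap[entry].key_ = key;
    hashmap[entry].value_ = key * 10;  // payload, e.g. the row index inside the cache table
    hashmap[entry].step_ = 0;
    hashmap[entry].tag_ = tag;
  }

  // Look a key up with the same probe sequence.
  int32_t entry = mindspore::HashFunc<int32_t>(11, hashmap_size);
  while (!hashmap[entry].IsEmpty() && !hashmap[entry].IsKey(11)) {
    entry = (entry + 1) % static_cast<int32_t>(hashmap_size);
  }
  std::printf("key 11 -> value %d\n", hashmap[entry].value_);
  return 0;
}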
@@ -350,6 +350,7 @@ constexpr auto kAttrPrimitiveTarget = "primitive_target";
constexpr auto kAttrUseLocking = "use_locking";
constexpr auto kAttrReduceScatterFlag = "reduce_scatter_flag";
constexpr auto kAttrOffset = "offset";
constexpr auto kAttrCacheEnable = "cache_enable";
constexpr auto kAttrPsKey = "ps_key";
constexpr auto kAttrOptimizerType = "optim_type";
constexpr auto kAttrChildGraph = "child_graph";
@@ -131,7 +131,7 @@ class Parameter(Tensor_):
        if self.init_mode is not None:
            data = self.init_mode
        else:
            # cast to break deep infinit loop while deepcopy
            # cast to break deep infinite loop while deepcopy
            data = Tensor(self)
        return (
            Parameter, (data, self.name, self.requires_grad, self.layerwise_parallel))
@@ -348,6 +348,8 @@ class Parameter(Tensor_):
        x.is_param_ps = self.is_param_ps
        x.init_in_server = self.init_in_server
        x.cache_enable = self.cache_enable
        if self.cache_shape:
            x.cache_shape = self.cache_shape
        if init != 'same':
            shape = self.shape
            dtype = self.dtype
@@ -375,6 +377,28 @@ class Parameter(Tensor_):
            raise TypeError("`parallel_optimizer` parameter must be bool type")
        self._param_info.parallel_optimizer = value

    @property
    def cache_enable(self):
        """Return whether the parameter is cache enabled."""
        return self._param_info.cache_enable

    @cache_enable.setter
    def cache_enable(self, value=True):
        if not isinstance(value, bool):
            raise TypeError("`cache_enable` parameter must be bool type")
        self._param_info.cache_enable = value

    @property
    def cache_shape(self):
        """Return the cache shape corresponding to the parameter if cache is enabled."""
        return self._param_info.cache_shape

    @cache_shape.setter
    def cache_shape(self, value):
        if not isinstance(value, (tuple, list)):
            raise TypeError("`cache_shape` parameter must be tuple or list type")
        self._param_info.cache_shape = value

    @property
    def requires_grad(self):
        """Return whether the parameter requires gradient."""
@@ -308,6 +308,10 @@ class Tensor(Tensor_):
        """Convert tensor to numpy array."""
        return Tensor_.asnumpy(self)

    def _flush_from_cache(self):
        """Flush cache data to host if the tensor is cache enabled."""
        Tensor_._flush_from_cache(self)

    def all(self, axis=(), keep_dims=False):
        """
        Check all array elements along a given axis evaluate to True.
@@ -60,6 +60,7 @@ using ValueNodePtr = std::shared_ptr<ValueNode>;
class CNode;
using CNodePtr = std::shared_ptr<CNode>;
using CNodePtrList = std::vector<CNodePtr>;

class FuncGraph;
using FuncGraphSet = OrderedSet<FuncGraphPtr>;
@@ -88,7 +89,7 @@ using ParamInfoPtr = std::shared_ptr<ParamInfo>;
// intermediate_abstract: return the cached inferring abstract value.
// Type/Shape: return the related info of this AnfNode. When this AnfNode is an
// input of other CNodes, you can get the related info by this method.
// debug_info: return the information retrived from parser. Set it using set_debug_info.
// debug_info: return the information retrieved from parser. Set it using set_debug_info.
// fullname_with_scope: return the detailed debug info.
class AnfNode : public Base {
 public:
@@ -167,7 +167,6 @@ class MetaTensor : public Value {
  // Get tensor's param_info info.
  ParamInfoPtr param_info() const { return param_info_; }
  bool is_parameter() const { return is_parameter_; }
  // Set tensor's param_info info.
  void set_param_info(const ParamInfoPtr &param_info) {
    is_parameter_ = true;
@@ -81,6 +81,12 @@ class ParamInfo {
  bool parallel_optimizer() const { return parallel_optimizer_; }
  void set_parallel_optimizer(bool parallel_optimizer) { parallel_optimizer_ = parallel_optimizer; }

  bool cache_enable() const { return cache_enable_; }
  void set_cache_enable(bool cache_enable) { cache_enable_ = cache_enable; }

  std::vector<int64_t> cache_shape() const { return cache_shape_; }
  void set_cache_shape(const std::vector<int64_t> &cache_shape) { cache_shape_ = cache_shape; }

 private:
  std::string name_{"Parameter"};
  bool requires_grad_{true};
@@ -92,6 +98,8 @@ class ParamInfo {
  int32_t cloned_index_{0};
  int32_t fusion_type_{1};
  bool parallel_optimizer_{true};
  bool cache_enable_{false};
  std::vector<int64_t> cache_shape_;
};
} // namespace mindspore
#endif // MINDSPORE_CORE_IR_PARAM_INFO_H_
| @@ -449,6 +449,9 @@ Tensor::Tensor(const Tensor &tensor) | |||||
| event_(tensor.event_), | event_(tensor.event_), | ||||
| sync_status_(tensor.sync_status_), | sync_status_(tensor.sync_status_), | ||||
| device_sync_(tensor.device_sync_), | device_sync_(tensor.device_sync_), | ||||
| cache_enable_(tensor.cache_enable_), | |||||
| cache_tensor_ptr_(tensor.cache_tensor_ptr_), | |||||
| hashmap_tensor_ptr_(tensor.hashmap_tensor_ptr_), | |||||
| padding_type_(tensor.padding_type()) {} | padding_type_(tensor.padding_type()) {} | ||||
| Tensor::Tensor(const Tensor &tensor, TypeId data_type) | Tensor::Tensor(const Tensor &tensor, TypeId data_type) | ||||
| @@ -459,6 +462,9 @@ Tensor::Tensor(const Tensor &tensor, TypeId data_type) | |||||
| event_(tensor.event_), | event_(tensor.event_), | ||||
| sync_status_(tensor.sync_status_), | sync_status_(tensor.sync_status_), | ||||
| device_sync_(tensor.device_sync_), | device_sync_(tensor.device_sync_), | ||||
| cache_enable_(tensor.cache_enable_), | |||||
| cache_tensor_ptr_(tensor.cache_tensor_ptr_), | |||||
| hashmap_tensor_ptr_(tensor.hashmap_tensor_ptr_), | |||||
| padding_type_(tensor.padding_type()) {} | padding_type_(tensor.padding_type()) {} | ||||
| Tensor::Tensor(TypeId data_type, const ShapeVector &shape, TensorDataPtr data) | Tensor::Tensor(TypeId data_type, const ShapeVector &shape, TensorDataPtr data) | ||||
| @@ -511,7 +517,7 @@ bool Tensor::ValueEqual(const Tensor &tensor) const { | |||||
| return (&tensor == this || (MetaTensor::operator==(tensor) && data_->equals(*tensor.data_))); | return (&tensor == this || (MetaTensor::operator==(tensor) && data_->equals(*tensor.data_))); | ||||
| } | } | ||||
| // assgin value to this tensor | |||||
| // assign value to this tensor | |||||
| Tensor &Tensor::AssignValue(const Tensor &tensor) { | Tensor &Tensor::AssignValue(const Tensor &tensor) { | ||||
| if (this != &tensor) { | if (this != &tensor) { | ||||
| MetaTensor::operator=(tensor); | MetaTensor::operator=(tensor); | ||||
| @@ -206,7 +206,7 @@ class Tensor : public MetaTensor { | |||||
| // it do real value comparison. | // it do real value comparison. | ||||
| bool ValueEqual(const Tensor &tensor) const; | bool ValueEqual(const Tensor &tensor) const; | ||||
| // assgin value to this tensor | |||||
| // assign value to this tensor | |||||
| Tensor &AssignValue(const Tensor &tensor); | Tensor &AssignValue(const Tensor &tensor); | ||||
| bool operator==(const Value &other) const override { | bool operator==(const Value &other) const override { | ||||
| @@ -291,6 +291,18 @@ class Tensor : public MetaTensor { | |||||
| TypePtr cast_dtype() { return cast_dtype_; } | TypePtr cast_dtype() { return cast_dtype_; } | ||||
| void set_cast_dtype(TypePtr dtype = nullptr) { cast_dtype_ = dtype; } | void set_cast_dtype(TypePtr dtype = nullptr) { cast_dtype_ = dtype; } | ||||
| // Used when cache_enable is set, in order to update the tensor from cache to host. | |||||
| bool cache_enable() const { return cache_enable_; } | |||||
| void set_cache_enable(bool cache_enable = true) { cache_enable_ = cache_enable; } | |||||
| std::shared_ptr<Tensor> hashmap_tensor_ptr() const { return hashmap_tensor_ptr_; } | |||||
| void set_hashmap_tensor_ptr(std::shared_ptr<Tensor> hashmap_tensor_ptr = nullptr) { | |||||
| hashmap_tensor_ptr_ = hashmap_tensor_ptr; | |||||
| } | |||||
| std::shared_ptr<Tensor> cache_tensor_ptr() const { return cache_tensor_ptr_; } | |||||
| void set_cache_tensor_ptr(std::shared_ptr<Tensor> cache_tensor_ptr = nullptr) { | |||||
| cache_tensor_ptr_ = cache_tensor_ptr; | |||||
| } | |||||
| void SetNeedWait(bool need_wait) { | void SetNeedWait(bool need_wait) { | ||||
| if (event_ != nullptr) { | if (event_ != nullptr) { | ||||
| event_->set_need_wait(need_wait); | event_->set_need_wait(need_wait); | ||||
| @@ -335,6 +347,9 @@ class Tensor : public MetaTensor { | |||||
| mutable TensorSyncStatus sync_status_{kNeedSyncHostToDevice}; | mutable TensorSyncStatus sync_status_{kNeedSyncHostToDevice}; | ||||
| bool graph_output_{false}; | bool graph_output_{false}; | ||||
| DeviceSyncPtr device_sync_{nullptr}; | DeviceSyncPtr device_sync_{nullptr}; | ||||
| bool cache_enable_{false}; | |||||
| std::shared_ptr<Tensor> cache_tensor_ptr_{nullptr}; | |||||
| std::shared_ptr<Tensor> hashmap_tensor_ptr_{nullptr}; | |||||
| std::vector<Axis> padding_type_; | std::vector<Axis> padding_type_; | ||||
| TypePtr cast_dtype_{nullptr}; | TypePtr cast_dtype_{nullptr}; | ||||
| }; | }; | ||||
| @@ -21,6 +21,7 @@ const char GRAPH_FLAG_MIX_PRECISION_FP16[] = "fp16"; | |||||
| const char GRAPH_FLAG_MIX_PRECISION_FP32[] = "fp32"; | const char GRAPH_FLAG_MIX_PRECISION_FP32[] = "fp32"; | ||||
| const char GRAPH_FLAG_HAS_EFFECT[] = "has_effect"; | const char GRAPH_FLAG_HAS_EFFECT[] = "has_effect"; | ||||
| const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[] = "_effect_patial_order"; | const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[] = "_effect_patial_order"; | ||||
| const char GRAPH_FLAG_CACHE_ENABLE[] = "cache_enable"; | |||||
| const char GRAPH_FLAG_RANDOM_EFFECT[] = "_random_effect"; | const char GRAPH_FLAG_RANDOM_EFFECT[] = "_random_effect"; | ||||
| const char GRAPH_FLAG_SIDE_EFFECT[] = "_side_effect"; | const char GRAPH_FLAG_SIDE_EFFECT[] = "_side_effect"; | ||||
| @@ -21,6 +21,7 @@ extern const char GRAPH_FLAG_MIX_PRECISION_FP16[]; | |||||
| extern const char GRAPH_FLAG_MIX_PRECISION_FP32[]; | extern const char GRAPH_FLAG_MIX_PRECISION_FP32[]; | ||||
| extern const char GRAPH_FLAG_HAS_EFFECT[]; | extern const char GRAPH_FLAG_HAS_EFFECT[]; | ||||
| extern const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[]; | extern const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[]; | ||||
| extern const char GRAPH_FLAG_CACHE_ENABLE[]; | |||||
| extern const char GRAPH_FLAG_RANDOM_EFFECT[]; | extern const char GRAPH_FLAG_RANDOM_EFFECT[]; | ||||
| extern const char GRAPH_FLAG_SIDE_EFFECT[]; | extern const char GRAPH_FLAG_SIDE_EFFECT[]; | ||||
| @@ -172,8 +172,8 @@ class EmbeddingLookup(Cell): | |||||
| or None. Default: None | or None. Default: None | ||||
| sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: True. | sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: True. | ||||
| vocab_cache_size (int): Cache size of the dictionary of embeddings. Default: 0. It is valid only in | vocab_cache_size (int): Cache size of the dictionary of embeddings. Default: 0. It is valid only in | ||||
| parameter server trainning mode and 'DEVICE' target. And the moment parameter of corresponding | |||||
| optimizer will also be set to the cache size. In addition, it should be noted that it will cost the 'DEVICE' | |||||
| 'DEVICE' target. The moment parameter of the corresponding optimizer will also be set to the cache size. | |||||
| In addition, it should be noted that it will cost the 'DEVICE' | |||||
| memory, so suggests setting a reasonable value to avoid insufficient memory. | memory, so suggests setting a reasonable value to avoid insufficient memory. | ||||
| Inputs: | Inputs: | ||||
| @@ -205,7 +205,12 @@ class EmbeddingLookup(Cell): | |||||
| max_norm=None, sparse=True, vocab_cache_size=0): | max_norm=None, sparse=True, vocab_cache_size=0): | ||||
| super(EmbeddingLookup, self).__init__() | super(EmbeddingLookup, self).__init__() | ||||
| validator.check_value_type('sparse', sparse, [bool], self.cls_name) | validator.check_value_type('sparse', sparse, [bool], self.cls_name) | ||||
| self.vocab_size = validator.check_positive_int(vocab_size, 'vocab_size') | |||||
| self.vocab_cache_size = validator.check_non_negative_int(vocab_cache_size, 'vocab_cache_size') | |||||
| self.target = target | self.target = target | ||||
| self.sparse = sparse | |||||
| self.cache_enable = self.vocab_cache_size > 0 | |||||
| self.forward_unique = False | |||||
| if target not in ('CPU', 'DEVICE'): | if target not in ('CPU', 'DEVICE'): | ||||
| raise ValueError('Attr \'target\' of \'EmbeddingLookup\' Op passed ' | raise ValueError('Attr \'target\' of \'EmbeddingLookup\' Op passed ' | ||||
| + str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.') | + str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.') | ||||
| @@ -216,21 +221,23 @@ class EmbeddingLookup(Cell): | |||||
| else: | else: | ||||
| self.gatherv2 = P.GatherV2() | self.gatherv2 = P.GatherV2() | ||||
| self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') | self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') | ||||
| self.vocab_size = validator.check_positive_int(vocab_size, 'vocab_size') | |||||
| self.vocab_cache_size = validator.check_non_negative_int(vocab_cache_size, 'vocab_cache_size') | |||||
| self._process_vocab_cache(slice_mode) | |||||
| enable_ps = _get_ps_context("enable_ps") | |||||
| if enable_ps: | |||||
| self._process_vocab_cache(slice_mode) | |||||
| self.embedding_size = validator.check_positive_int(embedding_size, 'embedding_size') | self.embedding_size = validator.check_positive_int(embedding_size, 'embedding_size') | ||||
| self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]), | self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]), | ||||
| name='embedding_table') | name='embedding_table') | ||||
| if self.cache_enable: | |||||
| self._set_voacb_cache_enable(vocab_cache_size, embedding_size, vocab_size) | |||||
| if self.cache_enable and enable_ps: | |||||
| self._set_voacb_cache_enable_for_ps(vocab_cache_size, embedding_size, vocab_size) | |||||
| parallel_mode = _get_parallel_mode() | parallel_mode = _get_parallel_mode() | ||||
| is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) | is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) | ||||
| self.forward_unique = False | |||||
| self.gather_revert = P.GatherV2() | self.gather_revert = P.GatherV2() | ||||
| self.unique = P.Unique().shard(((1,),)) | |||||
| self.reshape_first = P.Reshape() | |||||
| self.reshape = P.Reshape() | self.reshape = P.Reshape() | ||||
| self.unique = P.Unique() | |||||
| self.shape = P.Shape() | self.shape = P.Shape() | ||||
| if is_auto_parallel: | |||||
| self.unique = P.Unique().shard(((1,),)) | |||||
| indices_shape_size = 2 | indices_shape_size = 2 | ||||
| if slice_mode == "field_slice" and is_auto_parallel: | if slice_mode == "field_slice" and is_auto_parallel: | ||||
| if not manual_shapes: | if not manual_shapes: | ||||
| @@ -270,12 +277,34 @@ class EmbeddingLookup(Cell): | |||||
| if is_auto_parallel: | if is_auto_parallel: | ||||
| raise ValueError("slice_mode should support mode in nn.EmbeddingLookup, but get " | raise ValueError("slice_mode should support mode in nn.EmbeddingLookup, but get " | ||||
| + str(slice_mode)) | + str(slice_mode)) | ||||
| if self.cache_enable and not enable_ps: | |||||
| if is_auto_parallel: | |||||
| raise ValueError("parallel mode haven't supported cache enable yet.") | |||||
| self._set_cache_enable() | |||||
| self.embedding_table.unique = self.forward_unique | self.embedding_table.unique = self.forward_unique | ||||
| self.max_norm = max_norm | self.max_norm = max_norm | ||||
| if self.max_norm is not None: | if self.max_norm is not None: | ||||
| self.max_norm = validator.check_positive_float(self.max_norm, 'max_norm', self.cls_name) | self.max_norm = validator.check_positive_float(self.max_norm, 'max_norm', self.cls_name) | ||||
| self.max_norm = Tensor(self.max_norm, dtype=mstype.float32) | self.max_norm = Tensor(self.max_norm, dtype=mstype.float32) | ||||
| def _set_cache_enable(self): | |||||
| """EmbeddingLookup cache check for not ps env.""" | |||||
| if self.target != 'DEVICE': | |||||
| logger.warning("The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target, " | |||||
| "so it will be ignored.") | |||||
| return | |||||
| if not self.sparse: | |||||
| logger.warning("The configuration of 'vocab_cache_size' is valid only 'sparse' is true, " | |||||
| "so it will be ignored.") | |||||
| return | |||||
| logger.info("EmbeddingLookup cache enable takes effect.") | |||||
| self.forward_unique = True | |||||
| self.unique = P.Unique().add_prim_attr('primitive_target', 'CPU') | |||||
| self.unique.add_prim_attr('cache_enable', True) | |||||
| self.embedding_table.cache_enable = self.cache_enable | |||||
| self.embedding_table.cache_shape = (self.vocab_cache_size, self.embedding_size) | |||||
| self.reshape_first = P.Reshape().add_prim_attr('primitive_target', 'CPU') | |||||
| def _process_vocab_cache(self, slice_mode): | def _process_vocab_cache(self, slice_mode): | ||||
| """PS embeddingLookup cache check and process.""" | """PS embeddingLookup cache check and process.""" | ||||
| self.cache_enable = False | self.cache_enable = False | ||||
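In the non-PS path, `_set_cache_enable` only takes effect when the target is 'DEVICE' and sparse lookup is on; otherwise the cache configuration is ignored with a warning. A hedged usage sketch (the vocabulary and cache sizes are made up for illustration):

```python
# Sketch only: constructing EmbeddingLookup so the non-PS cache path is taken.
import mindspore.nn as nn

embedding = nn.EmbeddingLookup(vocab_size=100000,
                               embedding_size=16,
                               target='DEVICE',        # cache is ignored for 'CPU'
                               sparse=True,            # cache is ignored if not sparse
                               vocab_cache_size=3000)
# With vocab_cache_size > 0 outside parameter-server mode, forward_unique is
# enabled and the Unique/Reshape ops are pinned to CPU with 'cache_enable' set.
```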
| @@ -302,7 +331,7 @@ class EmbeddingLookup(Cell): | |||||
| if _is_role_worker(): | if _is_role_worker(): | ||||
| self.vocab_size = self.vocab_cache_size | self.vocab_size = self.vocab_cache_size | ||||
| def _set_voacb_cache_enable(self, vocab_cache_size, embedding_size, vocab_size): | |||||
| def _set_voacb_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size): | |||||
| """PS embeddingLookup cache enable set.""" | """PS embeddingLookup cache enable set.""" | ||||
| self.embedding_table.cache_enable = True | self.embedding_table.cache_enable = True | ||||
| self.embedding_table.is_param_ps = True | self.embedding_table.is_param_ps = True | ||||
| @@ -316,7 +345,7 @@ class EmbeddingLookup(Cell): | |||||
| else: | else: | ||||
| if self.forward_unique: | if self.forward_unique: | ||||
| shp = self.shape(indices) + (self.embedding_size,) | shp = self.shape(indices) + (self.embedding_size,) | ||||
| indices_flatten = self.reshape(indices, (-1,)) | |||||
| indices_flatten = self.reshape_first(indices, (-1,)) | |||||
| unique_id, unique_idx = self.unique(indices_flatten) | unique_id, unique_idx = self.unique(indices_flatten) | ||||
| weight_unique = self.gatherv2(self.embedding_table, unique_id, 0) | weight_unique = self.gatherv2(self.embedding_table, unique_id, 0) | ||||
| weight_flatten = self.gather_revert(weight_unique, unique_idx, 0) | weight_flatten = self.gather_revert(weight_unique, unique_idx, 0) | ||||
| @@ -156,8 +156,8 @@ class Optimizer(Cell): | |||||
| break | break | ||||
| ps_filter = lambda x: x.is_param_ps | ps_filter = lambda x: x.is_param_ps | ||||
| self.ps_parameters = tuple(ps_filter(x) for x in self.parameters) | self.ps_parameters = tuple(ps_filter(x) for x in self.parameters) | ||||
| ps_cache_filter = lambda x: x.cache_enable | |||||
| self.cache_enable = tuple(ps_cache_filter(x) for x in self.parameters) | |||||
| cache_filter = lambda x: x.cache_enable | |||||
| self.cache_enable = tuple(cache_filter(x) for x in self.parameters) | |||||
| self.reciprocal_scale = Tensor(1.0 / loss_scale, mstype.float32) | self.reciprocal_scale = Tensor(1.0 / loss_scale, mstype.float32) | ||||
| self.need_scale = loss_scale != 1.0 | self.need_scale = loss_scale != 1.0 | ||||
| self.global_step_increase_tensor = Tensor(1, mstype.int32) | self.global_step_increase_tensor = Tensor(1, mstype.int32) | ||||
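With the rename from `ps_cache_filter` to `cache_filter`, the optimizer now keeps a per-parameter tuple of cache flags regardless of parameter-server mode. A tiny sketch of what that tuple amounts to (the network variable is hypothetical):

```python
# Sketch only: the per-parameter cache_enable mask the optimizer builds.
params = net.trainable_params()  # 'net' is a hypothetical nn.Cell
cache_enable = tuple(p.cache_enable for p in params)
# e.g. (True, False, False) when only the embedding table is cache-enabled.
```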
| @@ -526,6 +526,9 @@ class Model: | |||||
| train_dataset.reset() | train_dataset.reset() | ||||
| # if a parameter is cache-enabled, flush data from cache to host before the epoch ends | |||||
| self._flush_from_cache(cb_params) | |||||
| list_callback.epoch_end(run_context) | list_callback.epoch_end(run_context) | ||||
| should_stop = should_stop or run_context.get_stop_requested() | should_stop = should_stop or run_context.get_stop_requested() | ||||
| if should_stop: | if should_stop: | ||||
| @@ -784,5 +787,11 @@ class Model: | |||||
| predict_net.compile(*predict_data) | predict_net.compile(*predict_data) | ||||
| return predict_net.parameter_layout_dict | return predict_net.parameter_layout_dict | ||||
| def _flush_from_cache(self, cb_params): | |||||
| """Flush cache data to host if tensor is cache enable.""" | |||||
| params = cb_params.train_network.get_parameters() | |||||
| for param in params: | |||||
| if param.cache_enable: | |||||
| Tensor(param)._flush_from_cache() | |||||
| __all__ = ["Model"] | __all__ = ["Model"] | ||||
| @@ -53,8 +53,8 @@ def init_var_dict(init_args, in_vars): | |||||
| ''' | ''' | ||||
| var_map = {} | var_map = {} | ||||
| _, _max_val = init_args | _, _max_val = init_args | ||||
| for _, iterm in enumerate(in_vars): | |||||
| key, shape, method = iterm | |||||
| for _, item in enumerate(in_vars): | |||||
| key, shape, method = item | |||||
| if key not in var_map.keys(): | if key not in var_map.keys(): | ||||
| if method in ['random', 'uniform']: | if method in ['random', 'uniform']: | ||||
| var_map[key] = Parameter(initializer( | var_map[key] = Parameter(initializer( | ||||
| @@ -257,9 +257,11 @@ class WideDeepModel(nn.Cell): | |||||
| self.wide_embeddinglookup.embedding_table.set_param_ps() | self.wide_embeddinglookup.embedding_table.set_param_ps() | ||||
| else: | else: | ||||
| self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, | self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, | ||||
| target='DEVICE', sparse=sparse) | |||||
| target='DEVICE', sparse=sparse, | |||||
| vocab_cache_size=self.vocab_cache_size) | |||||
| self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, | self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, | ||||
| target='DEVICE', sparse=sparse) | |||||
| target='DEVICE', sparse=sparse, | |||||
| vocab_cache_size=self.vocab_cache_size) | |||||
| self.embedding_table = self.deep_embeddinglookup.embedding_table | self.embedding_table = self.deep_embeddinglookup.embedding_table | ||||
| def construct(self, id_hldr, wt_hldr): | def construct(self, id_hldr, wt_hldr): | ||||
| @@ -57,8 +57,8 @@ def init_var_dict(init_args, in_vars): | |||||
| """ | """ | ||||
| var_map = {} | var_map = {} | ||||
| _, _max_val = init_args | _, _max_val = init_args | ||||
| for _, iterm in enumerate(in_vars): | |||||
| key, shape, method = iterm | |||||
| for _, item in enumerate(in_vars): | |||||
| key, shape, method = item | |||||
| if key not in var_map.keys(): | if key not in var_map.keys(): | ||||
| if method in ['random', 'uniform']: | if method in ['random', 'uniform']: | ||||
| var_map[key] = Parameter(initializer(Uniform(_max_val), shape, | var_map[key] = Parameter(initializer(Uniform(_max_val), shape, | ||||