| @@ -19,55 +19,20 @@ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "runtime/device/cpu/cpu_device_address.h" | |||
| #include "utils/cache_embedding_hashmap_struct.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| struct HashmapEntry { | |||
| T key; | |||
| T value; | |||
| T step; | |||
| T tag; | |||
| bool IsEmpty() { | |||
| if (this->tag == NULLTAG) | |||
| return true; | |||
| else | |||
| return false; | |||
| } | |||
| bool IsUsing(const T &train_step) { | |||
| if (this->step >= (train_step - 1)) | |||
| return true; | |||
| else | |||
| return false; | |||
| } | |||
| bool IsKey(const T &emb_idx) { | |||
| if (this->key == emb_idx) | |||
| return true; | |||
| else | |||
| return false; | |||
| } | |||
| void SetEmpty() { this->tag = NULLTAG; } | |||
| }; | |||
| template <typename T> | |||
| T HashFunc(const T &key, const size_t &m) { | |||
| return (T)(((0.6180339 * key) - floor(0.6180339 * key)) * m); | |||
| } | |||
| template <typename T> | |||
| int Compress(HashmapEntry<T> *entry_p, const size_t &length, T entry) { | |||
| T i = (entry + 1) % length, off = 1; | |||
| int compress_count = 0; | |||
| for (; !entry_p[i].IsEmpty(); i = (i + 1) % length, off++) { | |||
| if (entry_p[i].tag > off) { | |||
| entry_p[entry].key = entry_p[i].key; | |||
| entry_p[entry].value = entry_p[i].value; | |||
| entry_p[entry].step = entry_p[i].step; | |||
| entry_p[entry].tag = entry_p[i].tag - off; | |||
| if (entry_p[i].tag_ > off) { | |||
| entry_p[entry].key_ = entry_p[i].key_; | |||
| entry_p[entry].value_ = entry_p[i].value_; | |||
| entry_p[entry].step_ = entry_p[i].step_; | |||
| entry_p[entry].tag_ = entry_p[i].tag_ - off; | |||
| entry_p[i].SetEmpty(); | |||
| off = 0; | |||
| entry = i; | |||
| @@ -127,6 +92,7 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | |||
| float total_count = 0; | |||
| int count_size = 0; | |||
| float hit_count = 0; | |||
| // search_cache_idx | |||
| for (size_t i = 0; i < batch_size_; ++i) { | |||
| T key = input_indices[i] - offset; | |||
| @@ -140,7 +106,7 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | |||
| while ((!hashmap[tmp_entry].IsEmpty() && !hashmap[tmp_entry].IsKey(key))) { | |||
| tmp_entry = (tmp_entry + 1) % hashmap_length_; | |||
| if (count > hashmap_length_) { | |||
| MS_LOG(ERROR) << "Hashmap is full, search cache idx failed!"; | |||
| MS_LOG(EXCEPTION) << "Hashmap is full, search cache idx failed, please set a larger vocab_cache_size!"; | |||
| break; | |||
| } | |||
| count += 1; | |||
| @@ -153,8 +119,8 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | |||
| miss_count++; | |||
| } else { | |||
| hit_count += 1; | |||
| output_cache_idx[i] = hashmap[tmp_entry].value; | |||
| hashmap[tmp_entry].step = step_[0]; | |||
| output_cache_idx[i] = hashmap[tmp_entry].value_; | |||
| hashmap[tmp_entry].step_ = step_[0]; | |||
| } | |||
| } | |||
| if (miss_count != 0) { | |||
| @@ -175,27 +141,27 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | |||
| while (!hashmap[entry].IsEmpty()) { | |||
| entry = (entry + 1) % hashmap_length_; | |||
| if (tag_count > hashmap_length_) { | |||
| MS_LOG(ERROR) << "Hashmap is full, insert new key failed!"; | |||
| MS_LOG(EXCEPTION) << "Hashmap is full, insert new key failed, please set a larger vocab_cache_size!"; | |||
| break; | |||
| } | |||
| tag_count++; | |||
| } | |||
| hashmap[entry].key = emb_idx; | |||
| hashmap[entry].step = step_[0]; | |||
| hashmap[entry].tag = tag_count; | |||
| hashmap[entry].key_ = emb_idx; | |||
| hashmap[entry].step_ = step_[0]; | |||
| hashmap[entry].tag_ = tag_count; | |||
| T tmp_entry = (entry + 1) % hashmap_length_; | |||
| size_t delete_count = 1; | |||
| while (hashmap[tmp_entry].IsEmpty() || hashmap[tmp_entry].IsUsing(step_[0])) { | |||
| tmp_entry = (tmp_entry + 1) % hashmap_length_; | |||
| if (delete_count > hashmap_length_) { | |||
| MS_LOG(ERROR) << "Hashmap is full, delete old key failed!"; | |||
| MS_LOG(EXCEPTION) << "Hashmap is full, delete old key failed, please set a larger vocab_cache_size!"; | |||
| break; | |||
| } | |||
| delete_count++; | |||
| } | |||
| output_swap_cache_idx[i] = hashmap[tmp_entry].value; | |||
| output_old_emb_idx[i] = hashmap[tmp_entry].key; | |||
| hashmap[entry].value = output_swap_cache_idx[i]; | |||
| output_swap_cache_idx[i] = hashmap[tmp_entry].value_; | |||
| output_old_emb_idx[i] = hashmap[tmp_entry].key_; | |||
| hashmap[entry].value_ = output_swap_cache_idx[i]; | |||
| hashmap[tmp_entry].SetEmpty(); | |||
| int compress_count = Compress(hashmap, hashmap_length_, tmp_entry); | |||
| total_delete_count += (compress_count + delete_count); | |||
| @@ -23,8 +23,6 @@ | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel.h" | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" | |||
| #define NULLTAG 0 | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| class MapCacheIdxCPUKernel : public CPUKernel { | |||
| @@ -188,12 +188,18 @@ void ReplaceOldNode(std::unordered_map<int64_t, BufferFusionInfo_t> *buffer_fusi | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| auto buffer_fusion_info = (*buffer_fusion_infos)[fusion_id]; | |||
| if (buffer_fusion_info.outputs_list.size() == 1) { // single output | |||
| if (kernel_graph != nullptr) { | |||
| kernel_graph->FrontBackendlMapUpdate(buffer_fusion_info.outputs_list[0], buffer_fusion_kernel); | |||
| } | |||
| (void)manager->Replace(buffer_fusion_info.outputs_list[0], buffer_fusion_kernel); | |||
| ReplaceInputNodeInOtherFusionScope(buffer_fusion_infos, fusion_id, buffer_fusion_info.outputs_list[0], | |||
| buffer_fusion_kernel); | |||
| } else { // multiple output | |||
| for (size_t index = 0; index < buffer_fusion_info.outputs_list.size(); ++index) { | |||
| auto tuple_item = CreateTupleGetItem(buffer_fusion_kernel, kernel_graph, index); | |||
| if (kernel_graph != nullptr) { | |||
| kernel_graph->FrontBackendlMapUpdate(buffer_fusion_info.outputs_list[index], tuple_item); | |||
| } | |||
| (void)manager->Replace(buffer_fusion_info.outputs_list[index], tuple_item); | |||
| ReplaceInputNodeInOtherFusionScope(buffer_fusion_infos, fusion_id, buffer_fusion_info.outputs_list[index], | |||
| tuple_item); | |||
| @@ -274,6 +274,10 @@ tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple) { | |||
| bool IsNopNode(const AnfNodePtr &node) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| auto target = GetCNodeTarget(node); | |||
| if (target == kCPUDevice) { | |||
| return false; | |||
| } | |||
| if (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kAscendDevice && | |||
| context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) { | |||
| return false; | |||
| @@ -0,0 +1,535 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "frontend/parallel/cache_embedding/cache_embedding.h" | |||
| #include <random> | |||
| #include <vector> | |||
| #include <list> | |||
| #include <utility> | |||
| #include <memory> | |||
| #include <unordered_map> | |||
| #include <unordered_set> | |||
| #include <string> | |||
| #include <algorithm> | |||
| #include "backend/optimizer/common/helper.h" | |||
| #include "frontend/optimizer/optimizer.h" | |||
| #include "ir/func_graph.h" | |||
| #include "utils/cache_embedding_hashmap_struct.h" | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| using ParamMap = std::unordered_map<ParameterPtr, ParameterPtr>; | |||
| using ParamSet = std::unordered_set<ParameterPtr>; | |||
| using NodePairList = std::vector<std::pair<AnfNodePtr, AnfNodePtr>>; | |||
| ParamMap AddCacheParameters(const FuncGraphPtr &graph, const ParamSet ¶meter_cache_enable_set) { | |||
| ParamMap cache_host_params_map; | |||
| for (auto ¶m : parameter_cache_enable_set) { | |||
| auto param_info = param->param_info(); | |||
| if (param_info && param_info->cache_enable()) { | |||
| auto data_type = param->Type(); | |||
| auto data_element_type = data_type->cast<mindspore::TensorTypePtr>()->element(); | |||
| auto type_id = data_element_type->type_id(); | |||
| auto cache_shape = param_info->cache_shape(); | |||
| auto ori_param_name = param->name(); | |||
| auto new_tensor = std::make_shared<tensor::Tensor>(type_id, cache_shape); | |||
| ParamInfoPtr new_param_info = std::make_shared<ParamInfo>(); | |||
| auto cache_name = ori_param_name + "_cache"; | |||
| new_param_info->set_name(cache_name); | |||
| new_tensor->set_param_info(new_param_info); | |||
| auto cache_param = graph->AddWeightParameter(cache_name); | |||
| cache_param->set_default_param(MakeValue(new_tensor)); | |||
| cache_param->set_abstract(new_tensor->ToAbstract()); | |||
| cache_host_params_map[cache_param] = param; | |||
| } | |||
| } | |||
| return cache_host_params_map; | |||
| } | |||
| bool CheckHostCacheParamSize(const ParamSet ¶meter_cache_enable_set) { | |||
| int64_t host_size = 0; | |||
| int64_t cache_size = 0; | |||
| for (auto &host_param : parameter_cache_enable_set) { | |||
| auto tmp_host_size = host_param->abstract()->GetShapeTrack()->cast<abstract::ShapePtr>()->shape()[0]; | |||
| auto host_param_info = host_param->param_info(); | |||
| auto cache_shape = host_param_info->cache_shape(); | |||
| if (cache_shape.empty()) { | |||
| MS_LOG(EXCEPTION) << "The value of cache_shape is empty."; | |||
| } | |||
| auto tmp_cache_size = cache_shape[0]; | |||
| if ((host_size != 0 && tmp_host_size != host_size) || (cache_size != 0 && tmp_cache_size != cache_size)) { | |||
| MS_LOG(EXCEPTION) | |||
| << "If EmbeddingLookup are cache enable, vocab_size and vocab_cache_size of different cells must be the same."; | |||
| } | |||
| cache_size = tmp_cache_size; | |||
| host_size = tmp_host_size; | |||
| } | |||
| if (cache_size >= host_size) { | |||
| MS_LOG(WARNING) << "vocab_cache_size >= vocab_size, there is no need use cache."; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| void ReplaceCacheParams(const FuncGraphPtr &graph, const ParamMap &map) { | |||
| auto manager = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| for (auto &ele : map) { | |||
| if (!manager->Replace(ele.second, ele.first)) { | |||
| MS_LOG(EXCEPTION) << "host param: " << ele.second->name() << ", replace node failed."; | |||
| } | |||
| } | |||
| } | |||
| ParamSet MapKeysToSet(const ParamMap &map) { | |||
| ParamSet set; | |||
| for (auto &ele : map) { | |||
| set.insert(ele.first); | |||
| } | |||
| return set; | |||
| } | |||
| ParamSet FindParamCacheEnable(const FuncGraphPtr &graph) { | |||
| ParamSet parameter_cache_enable_set; | |||
| auto parameters = graph->parameters(); | |||
| auto params_size = parameters.size(); | |||
| for (size_t i = 0; i < params_size; ++i) { | |||
| auto param = parameters[i]->cast<ParameterPtr>(); | |||
| auto param_info = param->param_info(); | |||
| if (param_info && param_info->cache_enable()) { | |||
| parameter_cache_enable_set.insert(param); | |||
| } | |||
| } | |||
| return parameter_cache_enable_set; | |||
| } | |||
| CNodePtrList FindUniqueCacheEnable(const CNodePtrList &cnodes) { | |||
| size_t cnodes_size = cnodes.size(); | |||
| CNodePtrList unique_cache_enable; | |||
| for (size_t i = 0; i < cnodes_size; ++i) { | |||
| if (IsPrimitiveCNode(cnodes[i], prim::kPrimUnique)) { | |||
| auto unique_node = cnodes[i]; | |||
| auto unique_prim = GetCNodePrimitive(unique_node); | |||
| MS_EXCEPTION_IF_NULL(unique_prim); | |||
| auto attr_value = unique_prim->GetAttr(kAttrCacheEnable); | |||
| if (attr_value != nullptr && GetValue<bool>(attr_value)) { | |||
| unique_cache_enable.emplace_back(unique_node); | |||
| } | |||
| } | |||
| } | |||
| if (unique_cache_enable.size() > 1) { | |||
| MS_LOG(EXCEPTION) << "Support only one of Unique op cache enable, but got " << unique_cache_enable.size(); | |||
| } | |||
| return unique_cache_enable; | |||
| } | |||
| void BindAndInitCacheTensor(const ParamMap ¶m_pair_list, const ParameterPtr &hashmap) { | |||
| auto hashmap_tensor_value = hashmap->default_param(); | |||
| auto hashmap_tensor = hashmap_tensor_value->cast<std::shared_ptr<tensor::Tensor>>(); | |||
| for (auto &ele : param_pair_list) { | |||
| auto host_tensor_value = ele.second->default_param(); | |||
| auto host_tensor = host_tensor_value->cast<std::shared_ptr<tensor::Tensor>>(); | |||
| auto cache_tensor_value = ele.first->default_param(); | |||
| auto cache_tensor = cache_tensor_value->cast<std::shared_ptr<tensor::Tensor>>(); | |||
| // bind host, cache, hashmap | |||
| host_tensor->set_cache_enable(true); | |||
| host_tensor->set_hashmap_tensor_ptr(hashmap_tensor); | |||
| host_tensor->set_cache_tensor_ptr(cache_tensor); | |||
| // init cache tensor data | |||
| auto cache_byte_size = cache_tensor->Size(); | |||
| int ret = memcpy_s(cache_tensor->data_c(), cache_byte_size, host_tensor->data_c(), cache_byte_size); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "Memcpy failed."; | |||
| } | |||
| } | |||
| } | |||
| template <typename T> | |||
| void InitHashMapData(void *data, const int64_t host_size, const int64_t cache_size, const size_t hashmap_size, | |||
| const size_t byte_size) { | |||
| MS_LOG(INFO) << "Start init hashmap data."; | |||
| MS_EXCEPTION_IF_NULL(data); | |||
| HashmapEntry<T> *hashmap_data = static_cast<HashmapEntry<T> *>(data); | |||
| MS_EXCEPTION_IF_NULL(hashmap_data); | |||
| int ret = memset_s(hashmap_data, byte_size, 0, byte_size); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "Memset failed."; | |||
| } | |||
| std::vector<T> host_range; | |||
| host_range.reserve(host_size); | |||
| for (int64_t i = 0; i < host_size; ++i) { | |||
| host_range.emplace_back(i); | |||
| } | |||
| std::random_shuffle(host_range.begin(), host_range.end()); | |||
| size_t size = cache_size; | |||
| size_t hashmap_count = 0; | |||
| for (size_t i = 0; i < size; ++i) { | |||
| auto random_key = host_range[i]; | |||
| auto entry = HashFunc(random_key, hashmap_size); | |||
| size_t count = 1; | |||
| while (!hashmap_data[entry].IsEmpty() && !hashmap_data[entry].IsKey(random_key)) { | |||
| count += 1; | |||
| entry = (entry + 1) % hashmap_size; | |||
| } | |||
| if (hashmap_data[entry].IsEmpty()) { | |||
| hashmap_count++; | |||
| hashmap_data[entry].key_ = random_key; | |||
| hashmap_data[entry].value_ = i; | |||
| hashmap_data[entry].step_ = kInitStep; | |||
| hashmap_data[entry].tag_ = count; | |||
| } | |||
| } | |||
| MS_LOG(INFO) << "Hashmap init success, with " << hashmap_count << " / " << hashmap_size; | |||
| } | |||
| AnfNodePtr InitHashMap(const FuncGraphPtr &func_graph, const int64_t host_size, const int64_t cache_size, | |||
| TypeId type_id) { | |||
| // init new tensor | |||
| size_t hashmap_size = cache_size * kEmptyRate; | |||
| std::vector<int64_t> host_shape{static_cast<int64_t>(hashmap_size), 4}; | |||
| auto new_tensor = std::make_shared<tensor::Tensor>(type_id, host_shape); | |||
| size_t byte_size = new_tensor->Size(); | |||
| if (type_id == TypeId::kNumberTypeInt64) { | |||
| InitHashMapData<int64_t>(new_tensor->data_c(), host_size, cache_size, hashmap_size, byte_size); | |||
| } else { | |||
| InitHashMapData<int32_t>(new_tensor->data_c(), host_size, cache_size, hashmap_size, byte_size); | |||
| } | |||
| ParamInfoPtr new_param_info = std::make_shared<ParamInfo>(); | |||
| std::string hashmap_name = "cache_hashmap"; | |||
| new_param_info->set_name(hashmap_name); | |||
| new_tensor->set_param_info(new_param_info); | |||
| auto hashmap = func_graph->AddWeightParameter(hashmap_name); | |||
| hashmap->set_default_param(MakeValue(new_tensor)); | |||
| hashmap->set_abstract(new_tensor->ToAbstract()); | |||
| return hashmap; | |||
| } | |||
| AnfNodePtr InitStep(const FuncGraphPtr &func_graph, TypeId type_id) { | |||
| std::vector<int64_t> host_shape{1}; | |||
| auto new_tensor = std::make_shared<tensor::Tensor>(type_id, host_shape); | |||
| auto step_data = static_cast<int64_t *>(new_tensor->data_c()); | |||
| step_data[0] = 0; | |||
| ParamInfoPtr new_param_info = std::make_shared<ParamInfo>(); | |||
| std::string step_name = "cache_step"; | |||
| new_param_info->set_name(step_name); | |||
| new_tensor->set_param_info(new_param_info); | |||
| auto step = func_graph->AddWeightParameter(step_name); | |||
| step->set_default_param(MakeValue(new_tensor)); | |||
| step->set_abstract(new_tensor->ToAbstract()); | |||
| return step; | |||
| } | |||
| AnfNodePtr CreateMapCacheIdx(const FuncGraphPtr &func_graph, const AnfNodePtr &indices, | |||
| const ParamMap &cache_host_params_map) { | |||
| auto iter = cache_host_params_map.begin(); | |||
| int64_t cache_size = iter->first->abstract()->GetShapeTrack()->cast<abstract::ShapePtr>()->shape()[0]; | |||
| int64_t host_size = iter->second->abstract()->GetShapeTrack()->cast<abstract::ShapePtr>()->shape()[0]; | |||
| auto indices_type = indices->Type(); | |||
| auto indices_element_type = indices_type->cast<mindspore::TensorTypePtr>()->element(); | |||
| auto indices_type_id = indices_element_type->type_id(); | |||
| auto hashmap = InitHashMap(func_graph, host_size, cache_size, indices_type_id); | |||
| auto step = InitStep(func_graph, indices_type_id); | |||
| auto max_num = NewValueNode(MakeValue(host_size)); | |||
| auto hashmap_param = hashmap->cast<ParameterPtr>(); | |||
| BindAndInitCacheTensor(cache_host_params_map, hashmap_param); | |||
| // add rank_id | |||
| int64_t offset_value = 0; | |||
| std::string rank_id_str = common::GetEnv("RANK_ID"); | |||
| if (!rank_id_str.empty()) { | |||
| int64_t rank_id = atoi(rank_id_str.c_str()); | |||
| offset_value = rank_id * host_size; | |||
| } | |||
| auto offset = NewValueNode(MakeValue(offset_value)); | |||
| auto max_num_imm = std::make_shared<Int64Imm>(SizeToLong(host_size)); | |||
| auto max_num_abstract_scalar = std::make_shared<abstract::AbstractScalar>(max_num_imm); | |||
| max_num->set_abstract(max_num_abstract_scalar); | |||
| auto offset_imm = std::make_shared<Int64Imm>(SizeToLong(offset_value)); | |||
| auto offset_abstract_scalar = std::make_shared<abstract::AbstractScalar>(offset_imm); | |||
| offset->set_abstract(offset_abstract_scalar); | |||
| PrimitivePtr map_cache_primitive = prim::kPrimMapCacheIdx; | |||
| map_cache_primitive->set_attr(kAttrPrimitiveTarget, MakeValue("CPU")); | |||
| std::vector<AnfNodePtr> map_cache_nodes{NewValueNode(map_cache_primitive), hashmap, indices, step, max_num, offset}; | |||
| auto map_cache_idx = func_graph->NewCNode(map_cache_nodes); | |||
| auto indices_ori_shp = indices->Shape(); | |||
| auto indices_shp = indices_ori_shp->cast<abstract::ShapePtr>(); | |||
| ShapeVector shape; | |||
| ShapeVector min_shape; | |||
| ShapeVector max_shape; | |||
| if (!indices_shp->max_shape().empty()) { | |||
| max_shape = indices_shp->max_shape(); | |||
| } else { | |||
| max_shape = indices_shp->shape(); | |||
| } | |||
| for (size_t i = 0; i < max_shape.size(); i++) { | |||
| shape.emplace_back(-1); | |||
| min_shape.emplace_back(1); | |||
| } | |||
| auto cache_idx = std::make_shared<abstract::AbstractTensor>(indices_element_type, indices_shp); | |||
| auto old_emb_idx = std::make_shared<abstract::AbstractTensor>( | |||
| indices_element_type, std::make_shared<abstract::Shape>(shape, min_shape, max_shape)); | |||
| auto miss_emb_idx = std::make_shared<abstract::AbstractTensor>( | |||
| indices_element_type, std::make_shared<abstract::Shape>(shape, min_shape, max_shape)); | |||
| auto swap_emb_idx = std::make_shared<abstract::AbstractTensor>( | |||
| indices_element_type, std::make_shared<abstract::Shape>(shape, min_shape, max_shape)); | |||
| std::vector<std::shared_ptr<abstract::AbstractBase>> elements = {cache_idx, old_emb_idx, miss_emb_idx, swap_emb_idx}; | |||
| auto abstract = std::make_shared<abstract::AbstractTuple>(elements); | |||
| map_cache_idx->set_abstract(abstract); | |||
| return map_cache_idx; | |||
| } | |||
| AnfNodePtr CreateTupleGetItem(const FuncGraphPtr &func_graph, const AnfNodePtr &input, size_t index) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| auto idx = NewValueNode(SizeToLong(index)); | |||
| MS_EXCEPTION_IF_NULL(idx); | |||
| auto imm = std::make_shared<Int64Imm>(SizeToLong(index)); | |||
| auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm); | |||
| idx->set_abstract(abstract_scalar); | |||
| auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, idx}); | |||
| auto input_abstract_tuple = dyn_cast<abstract::AbstractTuple>(input->abstract()); | |||
| auto tuple_getitem_abstract = input_abstract_tuple->elements()[index]; | |||
| tuple_getitem->set_abstract(tuple_getitem_abstract); | |||
| return tuple_getitem; | |||
| } | |||
| void CreateTupleGetItems(const FuncGraphPtr &func_graph, const AnfNodePtr &input, std::vector<AnfNodePtr> *outputs) { | |||
| auto input_abstract_tuple = dyn_cast<abstract::AbstractTuple>(input->abstract()); | |||
| auto size = input_abstract_tuple->elements().size(); | |||
| for (size_t i = 0; i < size; ++i) { | |||
| (*outputs).emplace_back(CreateTupleGetItem(func_graph, input, i)); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(outputs); | |||
| } | |||
| AnfNodePtr CreateEmbeddingLookup(const FuncGraphPtr &graph, ParameterPtr params, AnfNodePtr indices) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| PrimitivePtr emb_lookup_primitive = prim::kPrimEmbeddingLookup; | |||
| emb_lookup_primitive->set_attr(kAttrPrimitiveTarget, MakeValue("CPU")); | |||
| emb_lookup_primitive->set_attr(kAttrOffset, MakeValue<int64_t>(0)); | |||
| std::vector<AnfNodePtr> emb_lookup_nodes{NewValueNode(emb_lookup_primitive), params, indices}; | |||
| auto emb_lookup = graph->NewCNode(emb_lookup_nodes); | |||
| return emb_lookup; | |||
| } | |||
| AnfNodePtr CreateCacheSwapTable(const FuncGraphPtr &graph, ParameterPtr cache_table, AnfNodePtr swap_cache_idx, | |||
| AnfNodePtr miss_value) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| PrimitivePtr cache_swap_table_primitive = prim::kPrimCacheSwapTable; | |||
| std::vector<AnfNodePtr> cache_swap_table_nodes{NewValueNode(cache_swap_table_primitive), cache_table, swap_cache_idx, | |||
| miss_value}; | |||
| auto cache_swap_table = graph->NewCNode(cache_swap_table_nodes); | |||
| return cache_swap_table; | |||
| } | |||
| AnfNodePtr CreateUpdateCache(const FuncGraphPtr &graph, ParameterPtr params, AnfNodePtr old_emb_idx, | |||
| AnfNodePtr old_value) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| PrimitivePtr update_cache_primitive = prim::kPrimUpdateCache; | |||
| update_cache_primitive->set_attr(kAttrPrimitiveTarget, MakeValue("CPU")); | |||
| auto params_ori_shp = params->Shape(); | |||
| MS_EXCEPTION_IF_NULL(params_ori_shp); | |||
| auto params_shp = params_ori_shp->cast<abstract::ShapePtr>(); | |||
| MS_EXCEPTION_IF_NULL(params_shp); | |||
| auto params_shape = params_shp->shape(); | |||
| auto max_size = params_shape[0]; | |||
| auto max_size_node = NewValueNode(MakeValue(max_size)); | |||
| auto max_num_imm = std::make_shared<Int64Imm>(SizeToLong(max_size)); | |||
| auto max_num_abstract_scalar = std::make_shared<abstract::AbstractScalar>(max_num_imm); | |||
| max_size_node->set_abstract(max_num_abstract_scalar); | |||
| std::vector<AnfNodePtr> update_cache_nodes{NewValueNode(update_cache_primitive), params, old_emb_idx, old_value, | |||
| max_size_node}; | |||
| auto update_cache = graph->NewCNode(update_cache_nodes); | |||
| return update_cache; | |||
| } | |||
| NodePairList CreateEmbSwapUpdate(const FuncGraphPtr &graph, ParamMap param_pair_list, | |||
| const AnfNodePtrList &map_cache_idx_node_outputs) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| NodePairList node_pair_list; | |||
| for (auto &ele : param_pair_list) { | |||
| auto emb_lookup = CreateEmbeddingLookup(graph, ele.second, map_cache_idx_node_outputs[2]); | |||
| auto cache_swap_table = CreateCacheSwapTable(graph, ele.first, map_cache_idx_node_outputs[3], emb_lookup); | |||
| auto update_cache = CreateUpdateCache(graph, ele.second, map_cache_idx_node_outputs[1], cache_swap_table); | |||
| node_pair_list.emplace_back(std::make_pair(cache_swap_table, update_cache)); | |||
| } | |||
| return node_pair_list; | |||
| } | |||
| AnfNodePtr CreateControlDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &prior_node, | |||
| const AnfNodePtr &behind_node) { | |||
| // Create control depend | |||
| MS_EXCEPTION_IF_NULL(main_graph); | |||
| AnfNodePtrList cd_inputs = {NewValueNode(prim::kPrimControlDepend), prior_node, behind_node}; | |||
| auto control_depend_cnode = main_graph->NewCNode(cd_inputs); | |||
| return control_depend_cnode; | |||
| } | |||
| AnfNodePtr CreateDepend(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &invalid_nodes, | |||
| const AnfNodePtr &patron_node) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| std::vector<AnfNodePtr> make_tuple_list{NewValueNode(prim::kPrimMakeTuple)}; | |||
| std::copy(invalid_nodes.begin(), invalid_nodes.end(), std::back_inserter(make_tuple_list)); | |||
| auto make_tuple = graph->NewCNode(make_tuple_list); | |||
| std::vector<AnfNodePtr> depend_list{NewValueNode(prim::kPrimDepend), patron_node, make_tuple}; | |||
| auto depend_cnode = graph->NewCNode(depend_list); | |||
| depend_cnode->set_abstract(patron_node->abstract()); | |||
| return depend_cnode; | |||
| } | |||
| CNodePtrList FindSparseGatherV2WithCache(const CNodePtrList &cnodes, const ParamSet ¶m_set) { | |||
| size_t cnodes_size = cnodes.size(); | |||
| CNodePtrList sparse_gather_v2_with_cache; | |||
| for (size_t i = 0; i < cnodes_size; ++i) { | |||
| if (IsPrimitiveCNode(cnodes[i], prim::kPrimSparseGatherV2)) { | |||
| auto param_node = cnodes[i]->input(1)->cast<ParameterPtr>(); | |||
| if (param_set.find(param_node) != param_set.end()) { | |||
| sparse_gather_v2_with_cache.push_back(cnodes[i]); | |||
| } | |||
| } | |||
| } | |||
| if (sparse_gather_v2_with_cache.empty()) { | |||
| MS_LOG(EXCEPTION) << "Can not find SparseGatherV2 with cache param."; | |||
| } | |||
| auto indices = sparse_gather_v2_with_cache[0]->input(2); | |||
| for (auto &ele : sparse_gather_v2_with_cache) { | |||
| if (ele->input(2) != indices) { | |||
| MS_LOG(EXCEPTION) << "SparseGatherV2 which with cache param have different indices!."; | |||
| } | |||
| } | |||
| return sparse_gather_v2_with_cache; | |||
| } | |||
| AnfNodePtr FindGatherV2FromSparseGatherV2(const FuncGraphPtr &graph, const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| AnfNodePtrList gatherv2_nodes; | |||
| auto user_set = graph->manager()->node_users()[node]; | |||
| for (auto &ele : user_set) { | |||
| if (IsPrimitiveCNode(ele.first, prim::kPrimGatherV2)) { | |||
| gatherv2_nodes.emplace_back(ele.first); | |||
| } | |||
| } | |||
| if (gatherv2_nodes.size() != 1) { | |||
| MS_LOG(EXCEPTION) << "SparseGatherV2 with cache can only used by one of gatherv2, but got " | |||
| << gatherv2_nodes.size(); | |||
| } | |||
| return gatherv2_nodes[0]; | |||
| } | |||
| void AddCacheEmbedding(const FuncGraphPtr &graph) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| std::list<CNodePtr> orders = graph->GetOrderedCnodes(); | |||
| CNodePtrList cnodes(orders.begin(), orders.end()); | |||
| size_t cnodes_size = cnodes.size(); | |||
| auto manager = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| bool training = graph->has_flag("training"); | |||
| auto param_cache_enable_set = FindParamCacheEnable(graph); | |||
| if (param_cache_enable_set.empty()) { | |||
| MS_LOG(INFO) << "Parameters are all not cache enable."; | |||
| return; | |||
| } else { | |||
| MS_LOG(INFO) << "Parameters have cache enable."; | |||
| } | |||
| if (!CheckHostCacheParamSize(param_cache_enable_set)) { | |||
| return; | |||
| } | |||
| if (training) { | |||
| // If training, create cache parameters corresponding to the host params with is cache_enable. | |||
| // Replace the host params. Create hashmap then insert MapCacheIdx op after Unique with has 'cache_enable' attr. | |||
| // Bind hashmap tensor ptr and cache tensor ptr to host tensor, so that we can flush values | |||
| // from cache to host in each epoch end. | |||
| // Create EmbeddingLookup(CPU), CacheSwapTable(Ascend), UpdateCache(CPU) for each pair of params, in order to | |||
| // flush miss values to cache params and write back old values to host params. | |||
| // If no use pipe in training, EmbeddingLookup and CacheSwapTable must execute before SparseGatherV2, so add | |||
| // ControlDepend between them. And add Depend for UpdateCache op and ControlDepnd op to add nodes into graph. | |||
| auto unique_cache_enable = FindUniqueCacheEnable(cnodes); | |||
| if (unique_cache_enable.empty()) { | |||
| MS_LOG(WARNING) << "Parameters have cache enable, but not find Unique op cache enable."; | |||
| return; | |||
| } | |||
| auto cache_host_params_map = AddCacheParameters(graph, param_cache_enable_set); | |||
| auto param_set = MapKeysToSet(cache_host_params_map); | |||
| ReplaceCacheParams(graph, cache_host_params_map); | |||
| graph->set_flag(GRAPH_FLAG_CACHE_ENABLE, true); | |||
| auto unique_node = unique_cache_enable[0]; | |||
| CNodePtrList sparse_gatherv2_with_cache = FindSparseGatherV2WithCache(cnodes, param_set); | |||
| auto unique_node_output_0 = CreateTupleGetItem(graph, unique_node, 0); | |||
| auto map_cache_idx = CreateMapCacheIdx(graph, unique_node_output_0, cache_host_params_map); | |||
| AnfNodePtrList map_cache_idx_node_outputs; | |||
| CreateTupleGetItems(graph, map_cache_idx, &map_cache_idx_node_outputs); | |||
| if (!manager->Replace(sparse_gatherv2_with_cache[0]->input(2), map_cache_idx_node_outputs[0])) { | |||
| MS_LOG(EXCEPTION) << "MapCacheIdx output[0] replace node failed"; | |||
| } | |||
| auto node_pair_list = CreateEmbSwapUpdate(graph, cache_host_params_map, map_cache_idx_node_outputs); | |||
| AnfNodePtr last_node = cnodes[cnodes_size - 1]; | |||
| CNodePtr return_node; | |||
| if (last_node->isa<CNode>()) { | |||
| return_node = last_node->cast<CNodePtr>(); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(return_node); | |||
| if (!IsPrimitiveCNode(return_node, prim::kPrimReturn)) { | |||
| MS_LOG(EXCEPTION) << "The last cnode after sorting, not return cnode."; | |||
| } | |||
| if (return_node->inputs().size() < 2) { | |||
| MS_LOG(EXCEPTION) << "Number of return node inputs should be great than or equal to 2."; | |||
| } | |||
| AnfNodePtrList invalid_nodes; | |||
| for (auto &ele : node_pair_list) { | |||
| std::transform(sparse_gatherv2_with_cache.begin(), sparse_gatherv2_with_cache.end(), | |||
| std::back_inserter(invalid_nodes), [&graph, &ele](const AnfNodePtr &sparse_gatherv2) { | |||
| return CreateControlDepend(graph, ele.first, sparse_gatherv2); | |||
| }); | |||
| invalid_nodes.emplace_back(ele.second); | |||
| } | |||
| auto depend_node = CreateDepend(graph, invalid_nodes, return_node->input(1)); | |||
| if (!manager->Replace(return_node->input(1), depend_node)) { | |||
| MS_LOG(EXCEPTION) << "Depend replace node failed"; | |||
| } | |||
| } else { | |||
| // If eval, Use EmbeddingLookup(CPU) op to replace GatherV2. | |||
| // The network is the same as Host-Device mode. | |||
| auto unique_cache_enable = FindUniqueCacheEnable(cnodes); | |||
| if (unique_cache_enable.empty()) { | |||
| MS_LOG(WARNING) << "Parameters have cache enable, but not find Unique op cache enable."; | |||
| return; | |||
| } | |||
| graph->set_flag(GRAPH_FLAG_CACHE_ENABLE, true); | |||
| // replace GatherV2 to EmbeddingLookupCPU | |||
| auto indices = unique_cache_enable[0]->input(1); | |||
| auto sparse_gatherv2_with_cache = FindSparseGatherV2WithCache(cnodes, param_cache_enable_set); | |||
| for (auto &ele : sparse_gatherv2_with_cache) { | |||
| auto anf_ele = ele->cast<AnfNodePtr>(); | |||
| auto gatherv2 = FindGatherV2FromSparseGatherV2(graph, anf_ele); | |||
| auto param = ele->input(1)->cast<ParameterPtr>(); | |||
| auto embedding_lookup = CreateEmbeddingLookup(graph, param, indices); | |||
| if (!manager->Replace(gatherv2, embedding_lookup)) { | |||
| MS_LOG(EXCEPTION) << "Depend replace node failed"; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,28 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_CACHE_EMBEDDING_CACHE_EMBEDDING_H_ | |||
| #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_CACHE_EMBEDDING_CACHE_EMBEDDING_H_ | |||
| #include "ir/anf.h" | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| // Automatically adding control depend based on effect order and side effect analysis. | |||
| void AddCacheEmbedding(const FuncGraphPtr &graph); | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_CACHE_EMBEDDING_CACHE_EMBEDDING_H_ | |||
| @@ -36,11 +36,14 @@ | |||
| #include "frontend/optimizer/graph_transform.h" | |||
| #include "frontend/parallel/step_parallel.h" | |||
| #include "frontend/parallel/step_auto_parallel.h" | |||
| #include "frontend/parallel/cache_embedding/cache_embedding.h" | |||
| #include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h" | |||
| #include "frontend/optimizer/recompute.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "pipeline/jit/pipeline_split.h" | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| #include "ps/util.h" | |||
| #endif | |||
| namespace mindspore { | |||
| namespace pipeline { | |||
| using OptPassGroupMap = opt::OptPassGroupMap; | |||
| @@ -391,6 +394,26 @@ bool AddRecomputationPass(const ResourcePtr &res) { | |||
| return true; | |||
| } | |||
| bool AddCacheEmbeddingPass(const ResourcePtr &res) { | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| if (ps::Util::IsParamServerMode()) { | |||
| return true; | |||
| } | |||
| #endif | |||
| FuncGraphPtr func_graph = res->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| parallel::AddCacheEmbedding(func_graph); | |||
| if (func_graph->has_flag(GRAPH_FLAG_CACHE_ENABLE)) { | |||
| auto params = func_graph->parameters(); | |||
| AbstractBasePtrList args_spec_list; | |||
| std::for_each(params.begin(), params.end(), | |||
| [&args_spec_list](const AnfNodePtr &node) { args_spec_list.push_back(node->abstract()); }); | |||
| func_graph = pipeline::Renormalize(res, func_graph, args_spec_list); | |||
| } | |||
| return true; | |||
| } | |||
| bool MergeDupGraphPass(const ResourcePtr &res) { | |||
| FuncGraphPtr func_graph = res->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| @@ -500,6 +523,7 @@ std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStru | |||
| {"tuple_transform", OptPassTransformGraphGroup}, | |||
| {"opt_graph_kernel_a", OptPassGraphKernelGroupA}, | |||
| {"opt_graph_kernel_b", OptPassGraphKernelGroupB}, | |||
| {"add_cache_embedding", AddCacheEmbeddingPass}, | |||
| {"add_control_depend", AddControlDependPass}, | |||
| {"add_recomputation", AddRecomputationPass}}; | |||
| @@ -37,6 +37,7 @@ bool PipelineSplitPass(const ResourcePtr &res); | |||
| bool ValidatePass(const ResourcePtr &res); | |||
| bool ConvertPrepareAdapt(const ResourcePtr &res); | |||
| bool AddControlDependPass(const ResourcePtr &res); | |||
| bool AddCacheEmbeddingPass(const ResourcePtr &res); | |||
| bool InferenceOptPreparePass(const ResourcePtr &res); | |||
| void ReclaimOptimizer(); | |||
| bool PynativeOptPass(const ResourcePtr &res); | |||
| @@ -32,6 +32,8 @@ REGISTER_PYBIND_DEFINE(ParamInfo, ([](const py::module *m) { | |||
| .def_property("parallel_optimizer", &ParamInfo::parallel_optimizer, | |||
| &ParamInfo::set_parallel_optimizer) | |||
| .def_property("comm_fusion", &ParamInfo::comm_fusion, &ParamInfo::set_comm_fusion) | |||
| .def_property("cache_enable", &ParamInfo::cache_enable, &ParamInfo::set_cache_enable) | |||
| .def_property("cache_shape", &ParamInfo::cache_shape, &ParamInfo::set_cache_shape) | |||
| .def(py::pickle( | |||
| [](const ParamInfo &p) { // __getstate__ | |||
| return py::make_tuple(p.name(), p.requires_grad(), p.layerwise_parallel()); | |||
| @@ -24,6 +24,7 @@ | |||
| #include "pybind_api/api_register.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/shape_utils.h" | |||
| #include "utils/cache_embedding_hashmap_struct.h" | |||
| namespace mindspore { | |||
| namespace tensor { | |||
| @@ -272,6 +273,68 @@ py::int_ TensorPy::GetPyItemSize(const Tensor &tensor) { return tensor.data().it | |||
| py::int_ TensorPy::GetPyNBytes(const Tensor &tensor) { return tensor.data().nbytes(); } | |||
| template <typename T> | |||
| void MemCopyFromCacheToHost(void *hashmap_addr, void *host_addr, void *cache_addr, size_t host_max, size_t cache_max, | |||
| size_t hashmap_size, size_t col_size) { | |||
| auto host_data = static_cast<char *>(host_addr); | |||
| auto cache_data = static_cast<char *>(cache_addr); | |||
| auto hashmap_data = static_cast<HashmapEntry<T> *>(hashmap_addr); | |||
| // default param type float | |||
| size_t param_type_size = 4; | |||
| size_t single_col_bytes = param_type_size * col_size; | |||
| for (size_t i = 0; i < hashmap_size; ++i) { | |||
| if (!hashmap_data[i].IsEmpty()) { | |||
| size_t host_offset = single_col_bytes * hashmap_data[i].key_; | |||
| size_t cache_offset = single_col_bytes * hashmap_data[i].value_; | |||
| if (cache_offset + single_col_bytes <= cache_max) { | |||
| auto ret = | |||
| memcpy_s(host_data + host_offset, host_max - host_offset, cache_data + cache_offset, single_col_bytes); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "Memcpy failed."; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| MS_LOG(INFO) << "Memcpy from cache to host success!"; | |||
| } | |||
| void TensorPy::FlushFromCache(const Tensor &tensor) { | |||
| py::gil_scoped_release gil_release; | |||
| if (tensor.NeedWait()) { | |||
| tensor.Wait(); | |||
| } | |||
| tensor.data_sync(); | |||
| if (tensor.cache_enable()) { | |||
| MS_LOG(INFO) << tensor.ToString() << " is cache enable."; | |||
| auto hashmap_tensor_ptr = tensor.hashmap_tensor_ptr(); | |||
| auto cache_tensor_ptr = tensor.cache_tensor_ptr(); | |||
| if (hashmap_tensor_ptr != nullptr && cache_tensor_ptr != nullptr) { | |||
| hashmap_tensor_ptr->data_sync(); | |||
| cache_tensor_ptr->data_sync(); | |||
| auto hashmap_size = hashmap_tensor_ptr->shape_c()[0]; | |||
| auto host_shape = tensor.shape_c(); | |||
| auto cache_shape = cache_tensor_ptr->shape_c(); | |||
| if (host_shape.size() != 2 && host_shape.size() != 2 && host_shape[1] != cache_shape[1]) { | |||
| MS_LOG(EXCEPTION) << "Got host shape and cache shape invalid." | |||
| << "host shape:" << host_shape << ", cache shape:" << cache_shape; | |||
| } | |||
| auto host_data_max_size = tensor.Size(); | |||
| auto cache_data_max_size = cache_tensor_ptr->Size(); | |||
| auto hashmap_data_type = hashmap_tensor_ptr->data_type(); | |||
| if (hashmap_data_type == TypeId::kNumberTypeInt32) { | |||
| MemCopyFromCacheToHost<int32_t>(hashmap_tensor_ptr->data_c(), tensor.data_c(), cache_tensor_ptr->data_c(), | |||
| host_data_max_size, cache_data_max_size, hashmap_size, host_shape[1]); | |||
| } else if (hashmap_data_type == TypeId::kNumberTypeInt64) { | |||
| MemCopyFromCacheToHost<int32_t>(hashmap_tensor_ptr->data_c(), tensor.data_c(), cache_tensor_ptr->data_c(), | |||
| host_data_max_size, cache_data_max_size, hashmap_size, host_shape[1]); | |||
| } else { | |||
| MS_LOG(ERROR) << "Hashmap dtype only suppotr int32, in64."; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| py::array TensorPy::SyncAsNumpy(const Tensor &tensor) { | |||
| { | |||
| py::gil_scoped_release gil_release; | |||
| @@ -457,6 +520,16 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { | |||
| array([[1., 1., 1.], | |||
| [1., 1., 1.]]) | |||
| )mydelimiter") | |||
| .def("_flush_from_cache", TensorPy::FlushFromCache, R"mydelimiter( | |||
| Flush Cache data to Host if tensor is cache enable. | |||
| Returns: | |||
| None. | |||
| Examples: | |||
| >>> data = mindspore.Tensor(np.ones((2, 3))) | |||
| >>> data._flush_from_cache() | |||
| )mydelimiter") | |||
| .def("is_init", &Tensor::is_init, R"mydelimiter( | |||
| Get tensor init_flag. | |||
| @@ -115,6 +115,8 @@ class TensorPy { | |||
| static py::int_ GetPyItemSize(const Tensor &tensor); | |||
| static py::int_ GetPyNBytes(const Tensor &tensor); | |||
| static void FlushFromCache(const Tensor &tensor); | |||
| }; | |||
| } // namespace tensor | |||
| } // namespace mindspore | |||
| @@ -268,7 +268,7 @@ void CPUKernelRuntime::CreateOutputTensors(session::KernelGraph *kernel_graph, | |||
| bound_addresses_.clear(); | |||
| auto output_nodes = kernel_graph->outputs(); | |||
| for (const auto &item : output_nodes) { | |||
| auto item_with_index = AnfAlgo::VisitKernelWithReturnType(item, 0, true); | |||
| auto item_with_index = AnfAlgo::VisitKernelWithReturnType(item, 0, false); | |||
| auto out = CreatTensorForOutput(kernel_graph, item_with_index, tensor_to_node); | |||
| outputs->push_back(std::move(out)); | |||
| } | |||
| @@ -0,0 +1,51 @@ | |||
| /** | |||
| * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | |||
| * | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_UTIL_CACHE_EMBBEDDING_HASHMAP_STRUCT_H_ | |||
| #define MINDSPORE_CCSRC_UTIL_CACHE_EMBBEDDING_HASHMAP_STRUCT_H_ | |||
| #include <math.h> | |||
| namespace mindspore { | |||
| const int64_t kNullTag = 0; | |||
| const int64_t kInitStep = -5; | |||
| const int64_t kEmptyRate = 4; | |||
| const double kGoldenRatio = 0.6180339; | |||
| template <typename T> | |||
| struct HashmapEntry { | |||
| T key_; | |||
| T value_; | |||
| T step_; | |||
| T tag_; | |||
| bool IsEmpty() { return tag_ == kNullTag; } | |||
| bool IsUsing(const T train_step) { return step_ >= (train_step - 1); } | |||
| bool IsKey(const T emb_idx) { return key_ == emb_idx; } | |||
| void SetEmpty() { tag_ = kNullTag; } | |||
| }; | |||
| template <typename T> | |||
| T HashFunc(const T key, const size_t m) { | |||
| return (T)(((kGoldenRatio * key) - floor(kGoldenRatio * key)) * m); | |||
| } | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_UTIL_CACHE_EMBBEDDING_HASHMAP_STRUCT_H_ | |||
| @@ -350,6 +350,7 @@ constexpr auto kAttrPrimitiveTarget = "primitive_target"; | |||
| constexpr auto kAttrUseLocking = "use_locking"; | |||
| constexpr auto kAttrReduceScatterFlag = "reduce_scatter_flag"; | |||
| constexpr auto kAttrOffset = "offset"; | |||
| constexpr auto kAttrCacheEnable = "cache_enable"; | |||
| constexpr auto kAttrPsKey = "ps_key"; | |||
| constexpr auto kAttrOptimizerType = "optim_type"; | |||
| constexpr auto kAttrChildGraph = "child_graph"; | |||
| @@ -131,7 +131,7 @@ class Parameter(Tensor_): | |||
| if self.init_mode is not None: | |||
| data = self.init_mode | |||
| else: | |||
| # cast to break deep infinit loop while deepcopy | |||
| # cast to break deep infinite loop while deepcopy | |||
| data = Tensor(self) | |||
| return ( | |||
| Parameter, (data, self.name, self.requires_grad, self.layerwise_parallel)) | |||
| @@ -348,6 +348,8 @@ class Parameter(Tensor_): | |||
| x.is_param_ps = self.is_param_ps | |||
| x.init_in_server = self.init_in_server | |||
| x.cache_enable = self.cache_enable | |||
| if self.cache_shape: | |||
| x.cache_shape = self.cache_shape | |||
| if init != 'same': | |||
| shape = self.shape | |||
| dtype = self.dtype | |||
| @@ -375,6 +377,28 @@ class Parameter(Tensor_): | |||
| raise TypeError("`parallel_optimizer` parameter must be bool type") | |||
| self._param_info.parallel_optimizer = value | |||
| @property | |||
| def cache_enable(self): | |||
| """Return whether the parameter is cache enable.""" | |||
| return self._param_info.cache_enable | |||
| @cache_enable.setter | |||
| def cache_enable(self, value=True): | |||
| if not isinstance(value, bool): | |||
| raise TypeError("`cache_enable` parameter must be bool type") | |||
| self._param_info.cache_enable = value | |||
| @property | |||
| def cache_shape(self): | |||
| """Return the cache shape corresponding to the parameter if use cache.""" | |||
| return self._param_info.cache_shape | |||
| @cache_shape.setter | |||
| def cache_shape(self, value): | |||
| if not isinstance(value, (tuple, list)): | |||
| raise TypeError("`cache_shape` parameter must be tuple or list type") | |||
| self._param_info.cache_shape = value | |||
| @property | |||
| def requires_grad(self): | |||
| """Return whether the parameter requires gradient.""" | |||
| @@ -308,6 +308,10 @@ class Tensor(Tensor_): | |||
| """Convert tensor to numpy array.""" | |||
| return Tensor_.asnumpy(self) | |||
| def _flush_from_cache(self): | |||
| """Flush cache data to host if tensor is cache enable.""" | |||
| Tensor_._flush_from_cache(self) | |||
| def all(self, axis=(), keep_dims=False): | |||
| """ | |||
| Check all array elements along a given axis evaluate to True. | |||
| @@ -60,6 +60,7 @@ using ValueNodePtr = std::shared_ptr<ValueNode>; | |||
| class CNode; | |||
| using CNodePtr = std::shared_ptr<CNode>; | |||
| using CNodePtrList = std::vector<CNodePtr>; | |||
| class FuncGraph; | |||
| using FuncGraphSet = OrderedSet<FuncGraphPtr>; | |||
| @@ -88,7 +89,7 @@ using ParamInfoPtr = std::shared_ptr<ParamInfo>; | |||
| // intermediate_abstract: return the cached inferring abstract value. | |||
| // Type/Shape: return the related info of this AnfNode. When this AnfNode is an | |||
| // input of other CNodes, you can get the related info by this method. | |||
| // debug_info: return the information retrived from parser. Set it using set_debug_info. | |||
| // debug_info: return the information retrieved from parser. Set it using set_debug_info. | |||
| // fullname_with_scope: return the detailed debug info. | |||
| class AnfNode : public Base { | |||
| public: | |||
| @@ -167,7 +167,6 @@ class MetaTensor : public Value { | |||
| // Get tensor's param_info info. | |||
| ParamInfoPtr param_info() const { return param_info_; } | |||
| bool is_parameter() const { return is_parameter_; } | |||
| // Set tensor's param_info info. | |||
| void set_param_info(const ParamInfoPtr ¶m_info) { | |||
| is_parameter_ = true; | |||
| @@ -81,6 +81,12 @@ class ParamInfo { | |||
| bool parallel_optimizer() const { return parallel_optimizer_; } | |||
| void set_parallel_optimizer(bool parallel_optimizer) { parallel_optimizer_ = parallel_optimizer; } | |||
| bool cache_enable() const { return cache_enable_; } | |||
| void set_cache_enable(bool cache_enable) { cache_enable_ = cache_enable; } | |||
| std::vector<int64_t> cache_shape() const { return cache_shape_; } | |||
| void set_cache_shape(const std::vector<int64_t> &cache_shape) { cache_shape_ = cache_shape; } | |||
| private: | |||
| std::string name_{"Parameter"}; | |||
| bool requires_grad_{true}; | |||
| @@ -92,6 +98,8 @@ class ParamInfo { | |||
| int32_t cloned_index_{0}; | |||
| int32_t fusion_type_{1}; | |||
| bool parallel_optimizer_{true}; | |||
| bool cache_enable_{false}; | |||
| std::vector<int64_t> cache_shape_; | |||
| }; | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_IR_PARAM_INFO_H_ | |||
| @@ -449,6 +449,9 @@ Tensor::Tensor(const Tensor &tensor) | |||
| event_(tensor.event_), | |||
| sync_status_(tensor.sync_status_), | |||
| device_sync_(tensor.device_sync_), | |||
| cache_enable_(tensor.cache_enable_), | |||
| cache_tensor_ptr_(tensor.cache_tensor_ptr_), | |||
| hashmap_tensor_ptr_(tensor.hashmap_tensor_ptr_), | |||
| padding_type_(tensor.padding_type()) {} | |||
| Tensor::Tensor(const Tensor &tensor, TypeId data_type) | |||
| @@ -459,6 +462,9 @@ Tensor::Tensor(const Tensor &tensor, TypeId data_type) | |||
| event_(tensor.event_), | |||
| sync_status_(tensor.sync_status_), | |||
| device_sync_(tensor.device_sync_), | |||
| cache_enable_(tensor.cache_enable_), | |||
| cache_tensor_ptr_(tensor.cache_tensor_ptr_), | |||
| hashmap_tensor_ptr_(tensor.hashmap_tensor_ptr_), | |||
| padding_type_(tensor.padding_type()) {} | |||
| Tensor::Tensor(TypeId data_type, const ShapeVector &shape, TensorDataPtr data) | |||
| @@ -511,7 +517,7 @@ bool Tensor::ValueEqual(const Tensor &tensor) const { | |||
| return (&tensor == this || (MetaTensor::operator==(tensor) && data_->equals(*tensor.data_))); | |||
| } | |||
| // assgin value to this tensor | |||
| // assign value to this tensor | |||
| Tensor &Tensor::AssignValue(const Tensor &tensor) { | |||
| if (this != &tensor) { | |||
| MetaTensor::operator=(tensor); | |||
| @@ -206,7 +206,7 @@ class Tensor : public MetaTensor { | |||
| // it do real value comparison. | |||
| bool ValueEqual(const Tensor &tensor) const; | |||
| // assgin value to this tensor | |||
| // assign value to this tensor | |||
| Tensor &AssignValue(const Tensor &tensor); | |||
| bool operator==(const Value &other) const override { | |||
| @@ -291,6 +291,18 @@ class Tensor : public MetaTensor { | |||
| TypePtr cast_dtype() { return cast_dtype_; } | |||
| void set_cast_dtype(TypePtr dtype = nullptr) { cast_dtype_ = dtype; } | |||
| // used if cache_enable, in order to update tensor from cache to host | |||
| bool cache_enable() const { return cache_enable_; } | |||
| void set_cache_enable(bool cache_enable = true) { cache_enable_ = cache_enable; } | |||
| std::shared_ptr<Tensor> hashmap_tensor_ptr() const { return hashmap_tensor_ptr_; } | |||
| void set_hashmap_tensor_ptr(std::shared_ptr<Tensor> hashmap_tensor_ptr = nullptr) { | |||
| hashmap_tensor_ptr_ = hashmap_tensor_ptr; | |||
| } | |||
| std::shared_ptr<Tensor> cache_tensor_ptr() const { return cache_tensor_ptr_; } | |||
| void set_cache_tensor_ptr(std::shared_ptr<Tensor> cache_tensor_ptr = nullptr) { | |||
| cache_tensor_ptr_ = cache_tensor_ptr; | |||
| } | |||
| void SetNeedWait(bool need_wait) { | |||
| if (event_ != nullptr) { | |||
| event_->set_need_wait(need_wait); | |||
| @@ -335,6 +347,9 @@ class Tensor : public MetaTensor { | |||
| mutable TensorSyncStatus sync_status_{kNeedSyncHostToDevice}; | |||
| bool graph_output_{false}; | |||
| DeviceSyncPtr device_sync_{nullptr}; | |||
| bool cache_enable_{false}; | |||
| std::shared_ptr<Tensor> cache_tensor_ptr_{nullptr}; | |||
| std::shared_ptr<Tensor> hashmap_tensor_ptr_{nullptr}; | |||
| std::vector<Axis> padding_type_; | |||
| TypePtr cast_dtype_{nullptr}; | |||
| }; | |||
| @@ -21,6 +21,7 @@ const char GRAPH_FLAG_MIX_PRECISION_FP16[] = "fp16"; | |||
| const char GRAPH_FLAG_MIX_PRECISION_FP32[] = "fp32"; | |||
| const char GRAPH_FLAG_HAS_EFFECT[] = "has_effect"; | |||
| const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[] = "_effect_patial_order"; | |||
| const char GRAPH_FLAG_CACHE_ENABLE[] = "cache_enable"; | |||
| const char GRAPH_FLAG_RANDOM_EFFECT[] = "_random_effect"; | |||
| const char GRAPH_FLAG_SIDE_EFFECT[] = "_side_effect"; | |||
| @@ -21,6 +21,7 @@ extern const char GRAPH_FLAG_MIX_PRECISION_FP16[]; | |||
| extern const char GRAPH_FLAG_MIX_PRECISION_FP32[]; | |||
| extern const char GRAPH_FLAG_HAS_EFFECT[]; | |||
| extern const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[]; | |||
| extern const char GRAPH_FLAG_CACHE_ENABLE[]; | |||
| extern const char GRAPH_FLAG_RANDOM_EFFECT[]; | |||
| extern const char GRAPH_FLAG_SIDE_EFFECT[]; | |||
| @@ -172,8 +172,8 @@ class EmbeddingLookup(Cell): | |||
| or None. Default: None | |||
| sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: True. | |||
| vocab_cache_size (int): Cache size of the dictionary of embeddings. Default: 0. It is valid only in | |||
| parameter server trainning mode and 'DEVICE' target. And the moment parameter of corresponding | |||
| optimizer will also be set to the cache size. In addition, it should be noted that it will cost the 'DEVICE' | |||
| 'DEVICE' target. And the moment parameter of corresponding optimizer will also be set to the cache size. | |||
| In addition, it should be noted that it will cost the 'DEVICE' | |||
| memory, so suggests setting a reasonable value to avoid insufficient memory. | |||
| Inputs: | |||
| @@ -205,7 +205,12 @@ class EmbeddingLookup(Cell): | |||
| max_norm=None, sparse=True, vocab_cache_size=0): | |||
| super(EmbeddingLookup, self).__init__() | |||
| validator.check_value_type('sparse', sparse, [bool], self.cls_name) | |||
| self.vocab_size = validator.check_positive_int(vocab_size, 'vocab_size') | |||
| self.vocab_cache_size = validator.check_non_negative_int(vocab_cache_size, 'vocab_cache_size') | |||
| self.target = target | |||
| self.sparse = sparse | |||
| self.cache_enable = self.vocab_cache_size > 0 | |||
| self.forward_unique = False | |||
| if target not in ('CPU', 'DEVICE'): | |||
| raise ValueError('Attr \'target\' of \'EmbeddingLookup\' Op passed ' | |||
| + str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.') | |||
| @@ -216,21 +221,23 @@ class EmbeddingLookup(Cell): | |||
| else: | |||
| self.gatherv2 = P.GatherV2() | |||
| self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') | |||
| self.vocab_size = validator.check_positive_int(vocab_size, 'vocab_size') | |||
| self.vocab_cache_size = validator.check_non_negative_int(vocab_cache_size, 'vocab_cache_size') | |||
| self._process_vocab_cache(slice_mode) | |||
| enable_ps = _get_ps_context("enable_ps") | |||
| if enable_ps: | |||
| self._process_vocab_cache(slice_mode) | |||
| self.embedding_size = validator.check_positive_int(embedding_size, 'embedding_size') | |||
| self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]), | |||
| name='embedding_table') | |||
| if self.cache_enable: | |||
| self._set_voacb_cache_enable(vocab_cache_size, embedding_size, vocab_size) | |||
| if self.cache_enable and enable_ps: | |||
| self._set_voacb_cache_enable_for_ps(vocab_cache_size, embedding_size, vocab_size) | |||
| parallel_mode = _get_parallel_mode() | |||
| is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) | |||
| self.forward_unique = False | |||
| self.gather_revert = P.GatherV2() | |||
| self.unique = P.Unique().shard(((1,),)) | |||
| self.reshape_first = P.Reshape() | |||
| self.reshape = P.Reshape() | |||
| self.unique = P.Unique() | |||
| self.shape = P.Shape() | |||
| if is_auto_parallel: | |||
| self.unique = P.Unique().shard(((1,),)) | |||
| indices_shape_size = 2 | |||
| if slice_mode == "field_slice" and is_auto_parallel: | |||
| if not manual_shapes: | |||
| @@ -270,12 +277,34 @@ class EmbeddingLookup(Cell): | |||
| if is_auto_parallel: | |||
| raise ValueError("slice_mode should support mode in nn.EmbeddingLookup, but get " | |||
| + str(slice_mode)) | |||
| if self.cache_enable and not enable_ps: | |||
| if is_auto_parallel: | |||
| raise ValueError("parallel mode haven't supported cache enable yet.") | |||
| self._set_cache_enable() | |||
| self.embedding_table.unique = self.forward_unique | |||
| self.max_norm = max_norm | |||
| if self.max_norm is not None: | |||
| self.max_norm = validator.check_positive_float(self.max_norm, 'max_norm', self.cls_name) | |||
| self.max_norm = Tensor(self.max_norm, dtype=mstype.float32) | |||
| def _set_cache_enable(self): | |||
| """EmbeddingLookup cache check for not ps env.""" | |||
| if self.target != 'DEVICE': | |||
| logger.warning("The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target, " | |||
| "so it will be ignored.") | |||
| return | |||
| if not self.sparse: | |||
| logger.warning("The configuration of 'vocab_cache_size' is valid only 'sparse' is true, " | |||
| "so it will be ignored.") | |||
| return | |||
| logger.info("EmbeddingLookup cache enable takes effect.") | |||
| self.forward_unique = True | |||
| self.unique = P.Unique().add_prim_attr('primitive_target', 'CPU') | |||
| self.unique.add_prim_attr('cache_enable', True) | |||
| self.embedding_table.cache_enable = self.cache_enable | |||
| self.embedding_table.cache_shape = (self.vocab_cache_size, self.embedding_size) | |||
| self.reshape_first = P.Reshape().add_prim_attr('primitive_target', 'CPU') | |||
| def _process_vocab_cache(self, slice_mode): | |||
| """PS embeddingLookup cache check and process.""" | |||
| self.cache_enable = False | |||
| @@ -302,7 +331,7 @@ class EmbeddingLookup(Cell): | |||
| if _is_role_worker(): | |||
| self.vocab_size = self.vocab_cache_size | |||
| def _set_voacb_cache_enable(self, vocab_cache_size, embedding_size, vocab_size): | |||
| def _set_voacb_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size): | |||
| """PS embeddingLookup cache enable set.""" | |||
| self.embedding_table.cache_enable = True | |||
| self.embedding_table.is_param_ps = True | |||
| @@ -316,7 +345,7 @@ class EmbeddingLookup(Cell): | |||
| else: | |||
| if self.forward_unique: | |||
| shp = self.shape(indices) + (self.embedding_size,) | |||
| indices_flatten = self.reshape(indices, (-1,)) | |||
| indices_flatten = self.reshape_first(indices, (-1,)) | |||
| unique_id, unique_idx = self.unique(indices_flatten) | |||
| weight_unique = self.gatherv2(self.embedding_table, unique_id, 0) | |||
| weight_flatten = self.gather_revert(weight_unique, unique_idx, 0) | |||
| @@ -156,8 +156,8 @@ class Optimizer(Cell): | |||
| break | |||
| ps_filter = lambda x: x.is_param_ps | |||
| self.ps_parameters = tuple(ps_filter(x) for x in self.parameters) | |||
| ps_cache_filter = lambda x: x.cache_enable | |||
| self.cache_enable = tuple(ps_cache_filter(x) for x in self.parameters) | |||
| cache_filter = lambda x: x.cache_enable | |||
| self.cache_enable = tuple(cache_filter(x) for x in self.parameters) | |||
| self.reciprocal_scale = Tensor(1.0 / loss_scale, mstype.float32) | |||
| self.need_scale = loss_scale != 1.0 | |||
| self.global_step_increase_tensor = Tensor(1, mstype.int32) | |||
| @@ -526,6 +526,9 @@ class Model: | |||
| train_dataset.reset() | |||
| # if param is cache enable, flush data from cache to host before epoch end | |||
| self._flush_from_cache(cb_params) | |||
| list_callback.epoch_end(run_context) | |||
| should_stop = should_stop or run_context.get_stop_requested() | |||
| if should_stop: | |||
| @@ -784,5 +787,11 @@ class Model: | |||
| predict_net.compile(*predict_data) | |||
| return predict_net.parameter_layout_dict | |||
| def _flush_from_cache(self, cb_params): | |||
| """Flush cache data to host if tensor is cache enable.""" | |||
| params = cb_params.train_network.get_parameters() | |||
| for param in params: | |||
| if param.cache_enable: | |||
| Tensor(param)._flush_from_cache() | |||
| __all__ = ["Model"] | |||
| @@ -53,8 +53,8 @@ def init_var_dict(init_args, in_vars): | |||
| ''' | |||
| var_map = {} | |||
| _, _max_val = init_args | |||
| for _, iterm in enumerate(in_vars): | |||
| key, shape, method = iterm | |||
| for _, item in enumerate(in_vars): | |||
| key, shape, method = item | |||
| if key not in var_map.keys(): | |||
| if method in ['random', 'uniform']: | |||
| var_map[key] = Parameter(initializer( | |||
| @@ -257,9 +257,11 @@ class WideDeepModel(nn.Cell): | |||
| self.wide_embeddinglookup.embedding_table.set_param_ps() | |||
| else: | |||
| self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, | |||
| target='DEVICE', sparse=sparse) | |||
| target='DEVICE', sparse=sparse, | |||
| vocab_cache_size=self.vocab_cache_size) | |||
| self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, | |||
| target='DEVICE', sparse=sparse) | |||
| target='DEVICE', sparse=sparse, | |||
| vocab_cache_size=self.vocab_cache_size) | |||
| self.embedding_table = self.deep_embeddinglookup.embedding_table | |||
| def construct(self, id_hldr, wt_hldr): | |||
| @@ -57,8 +57,8 @@ def init_var_dict(init_args, in_vars): | |||
| """ | |||
| var_map = {} | |||
| _, _max_val = init_args | |||
| for _, iterm in enumerate(in_vars): | |||
| key, shape, method = iterm | |||
| for _, item in enumerate(in_vars): | |||
| key, shape, method = item | |||
| if key not in var_map.keys(): | |||
| if method in ['random', 'uniform']: | |||
| var_map[key] = Parameter(initializer(Uniform(_max_val), shape, | |||