/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "parallel/tensor_layout/map.h"
#include <algorithm>
#include <iterator>
#include <memory>
#include <vector>
#include "common/utils.h"
#include "parallel/status.h"
#include "parallel/tensor_layout/shape_util.h"
#include "utils/convert_utils.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace parallel {
/*
 * Initializes the map from `array` via Array::Init, then validates it with IsValidMap().
 * Returns Status::SUCCESS on success. Returns Status::FAILED if Array::Init fails
 * (whatever status it reported), or if the map is invalid (an error is logged).
 */
Status Map::Init(const std::vector<int32_t> &array) {
  Status status = Array::Init(array);
  if (status != Status::SUCCESS) {
    return Status::FAILED;
  }
  if (!IsValidMap()) {
    MS_LOG(ERROR) << "invalid map " << this->ToString();
    return Status::FAILED;
  }
  return Status::SUCCESS;
}

/*
 * A map is valid when:
 *   - every negative entry is exactly MAP_NONE, and
 *   - every entry other than MAP_NONE appears at most once (MAP_NONE may repeat).
 */
bool Map::IsValidMap() {
  if (std::any_of(array_.begin(), array_.end(),
                  [](int32_t value) { return (value < 0) && (value != MAP_NONE); })) {
    return false;
  }
  // After sorting, equal values are contiguous, so a duplicate non-MAP_NONE value
  // must show up as an adjacent equal pair.
  std::vector<int32_t> sorted_array = array_;
  std::sort(sorted_array.begin(), sorted_array.end());
  auto duplicate = std::adjacent_find(sorted_array.begin(), sorted_array.end(), [](int32_t lhs, int32_t rhs) {
    return (lhs == rhs) && (lhs != MAP_NONE);
  });
  return duplicate == sorted_array.end();
}

/*
 * Returns the largest entry of the map, or MAP_NONE if the map is empty.
 */
int32_t Map::GetMaxItem() const {
  if (array_.empty()) {
    return MAP_NONE;
  }
  return *std::max_element(array_.begin(), array_.end());
}

/*
 * Returns the index of the first entry equal to `value`, or MAP_NONE if no entry matches.
 */
int32_t Map::GetIndexByValue(int32_t value) const {
  auto iter = std::find(array_.begin(), array_.end(), value);
  if (iter == array_.end()) {
    return MAP_NONE;
  }
  return static_cast<int32_t>(std::distance(array_.begin(), iter));
}

/*
 * expand.size()
should be equal to array_.size() */ std::shared_ptr Map::ExpandMapByNone(const Arrangement &expand_num_list) const { if (expand_num_list.GetDimSize() != GetDimSize()) { return nullptr; } std::vector new_shape; for (uint32_t i = 0; i != GetDimSize(); i++) { if (GetDimByIdx(i) == MAP_NONE) { for (int32_t j = 0; j < expand_num_list.GetDimByIdx(i); j++) { new_shape.push_back(MAP_NONE); } } else { new_shape.push_back(GetDimByIdx(i)); int32_t j = 1; while (j < expand_num_list.GetDimByIdx(i)) { new_shape.push_back(MAP_NONE); j++; } } } auto map_new = std::make_shared(); (void)map_new->Init(new_shape); return map_new; } /* * expand.size() should be equal to array_.size() */ std::shared_ptr Map::ExpandMapByDecreaseNumber(const Arrangement &expand_num_list) const { if (GetMaxItem() >= static_cast(expand_num_list.GetDimSize())) { return nullptr; } std::vector new_shape; for (uint32_t i = 0; i < GetDimSize(); i++) { if (GetDimByIdx(i) == MAP_NONE) { new_shape.push_back(MAP_NONE); } else { int32_t start_map = expand_num_list.ComputeReverseAccumulateSumInReverseOrder()[static_cast(GetDimByIdx(i))]; for (int32_t k = expand_num_list.GetDimByReverseIdx(static_cast(GetDimByIdx(i))) - 1; k >= 0; k--) { new_shape.push_back(k + start_map); } } } auto map_new = std::make_shared(); (void)map_new->Init(new_shape); return map_new; } std::shared_ptr> Map::ReMapVector(const std::vector &input_vector) const { if (GetMaxItem() >= static_cast(input_vector.size())) { return nullptr; } std::vector out; Arrangement empty_arrangement; for (uint32_t i = 0; i < GetDimSize(); i++) { if (GetDimByIdx(i) == MAP_NONE) { out.push_back(empty_arrangement); } else { out.push_back(input_vector[IntToUint(SizeToInt(input_vector.size()) - 1 - GetDimByIdx(i))]); } } return std::make_shared>(out); } bool Map::CheckNoneByIdxList(std::vector idx_list) const { for (auto &value : idx_list) { if (GetDimByIdx(SizeToUint(value)) != MAP_NONE) { return false; } } return true; } Map Map::SqueezeMapByIdxList(std::vector 
idx_list) const { std::vector out_shape; for (size_t i = 0; i < GetDimSize(); i++) { auto it = std::find(idx_list.begin(), idx_list.end(), i); if (it == idx_list.end()) { out_shape.push_back(GetDimByIdx(SizeToUint(i))); } } if (out_shape.empty()) { MS_LOG(ERROR) << "out_shape size is 0, this may not happen under current situation"; out_shape.push_back(MAP_NONE); } Map out; (void)out.Init(out_shape); return out; } } // namespace parallel } // namespace mindspore