|
|
|
@@ -1,5 +1,5 @@ |
|
|
|
/** |
|
|
|
* Copyright 2020 Huawei Technologies Co., Ltd |
|
|
|
* Copyright 2021 Huawei Technologies Co., Ltd |
|
|
|
* |
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
* you may not use this file except in compliance with the License. |
|
|
|
@@ -15,19 +15,41 @@ |
|
|
|
*/ |
|
|
|
|
|
|
|
#include "ops/custom.h" |
|
|
|
#include "utils/check_convert_utils.h" |
|
|
|
#include "abstract/primitive_infer_map.h" |
|
|
|
#include "ops/op_utils.h" |
|
|
|
#include <memory> |
|
|
|
#include <map> |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace ops { |
|
|
|
void Custom::Init(const std::vector<int64_t> &custom) { this->set_custom(custom); } |
|
|
|
void Custom::Init(const std::string &type, const std::map<std::string, std::vector<uint8_t>> &attrs) { |
|
|
|
this->set_type(type); |
|
|
|
this->set_attr(attrs); |
|
|
|
} |
|
|
|
|
|
|
|
void Custom::set_type(const std::string &type) { this->AddAttr(kType, MakeValue(type)); } |
|
|
|
|
|
|
|
void Custom::set_custom(const std::vector<int64_t> &custom) { this->AddAttr(kCustom, MakeValue(custom)); } |
|
|
|
std::string Custom::get_type() const { |
|
|
|
auto value_ptr = this->GetAttr(kType); |
|
|
|
return GetValue<std::string>(value_ptr); |
|
|
|
} |
|
|
|
|
|
|
|
void Custom::set_attr(const std::map<std::string, std::vector<uint8_t>> &attrs) { |
|
|
|
ValuePtrList value_ptr_list; |
|
|
|
for (const auto &attr : attrs) { |
|
|
|
value_ptr_list.emplace_back(MakeValue<std::string>(attr.first)); |
|
|
|
value_ptr_list.emplace_back(MakeValue<std::vector<uint8_t>>(attr.second)); |
|
|
|
} |
|
|
|
this->AddAttr(kAttr, MakeValue(value_ptr_list)); |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<int64_t> Custom::get_custom() const { |
|
|
|
auto value_ptr = this->GetAttr(kCustom); |
|
|
|
return GetValue<std::vector<int64_t>>(value_ptr); |
|
|
|
std::map<std::string, std::vector<uint8_t>> Custom::get_attr() const { |
|
|
|
std::map<std::string, std::vector<uint8_t>> attrs; |
|
|
|
auto value_ptr_list = GetValue<ValuePtrList>(this->GetAttr(kAttr)); |
|
|
|
for (size_t i = 0; i < value_ptr_list.size(); i += 2) { |
|
|
|
auto key = GetValue<std::string>(value_ptr_list[i]); |
|
|
|
auto value = GetValue<std::vector<uint8_t>>(value_ptr_list[i + 1]); |
|
|
|
attrs[key] = value; |
|
|
|
} |
|
|
|
return attrs; |
|
|
|
} |
|
|
|
REGISTER_PRIMITIVE_C(kNameCustom, Custom); |
|
|
|
} // namespace ops |
|
|
|
|