|
|
|
@@ -342,24 +342,25 @@ void PrimitiveC::GetAttrDataFromInput(const AnfNodePtr inputNode, std::vector<in |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
schema::PrimitiveT *PrimitiveC::GetPrimitiveT() const { return this->primitive_; } |
|
|
|
schema::PrimitiveT *PrimitiveC::primitiveT() const { return this->primitive_; } |
|
|
|
|
|
|
|
void PrimitiveC::ClearPrimitiveT() { this->primitive_ = nullptr; } |
|
|
|
|
|
|
|
void PrimitiveC::SetInputQuantParams(const std::vector<std::vector<schema::QuantParamT>> &input_quant_param) { |
|
|
|
void PrimitiveC::set_input_quant_params(const std::vector<std::vector<schema::QuantParamT>> &input_quant_param) { |
|
|
|
this->input_quant_param_ = input_quant_param; |
|
|
|
} |
|
|
|
|
|
|
|
void PrimitiveC::SetInputQuantParam(const size_t &index, const std::vector<schema::QuantParamT> &input_quant_param) { |
|
|
|
void PrimitiveC::set_input_quant_param(const size_t &index, const std::vector<schema::QuantParamT> &input_quant_param) { |
|
|
|
MS_ASSERT(index < this->input_quant_param_.size()); |
|
|
|
this->input_quant_param_[index] = input_quant_param; |
|
|
|
} |
|
|
|
|
|
|
|
void PrimitiveC::SetOutputQuantParams(const std::vector<std::vector<schema::QuantParamT>> &output_quant_param) { |
|
|
|
void PrimitiveC::set_output_quant_params(const std::vector<std::vector<schema::QuantParamT>> &output_quant_param) { |
|
|
|
this->output_quant_param_ = output_quant_param; |
|
|
|
} |
|
|
|
|
|
|
|
void PrimitiveC::SetOutputQuantParam(const size_t &index, const std::vector<schema::QuantParamT> &output_quant_param) { |
|
|
|
void PrimitiveC::set_output_quant_param(const size_t &index, |
|
|
|
const std::vector<schema::QuantParamT> &output_quant_param) { |
|
|
|
MS_ASSERT(index < this->output_quant_param_.size()); |
|
|
|
this->output_quant_param_[index] = output_quant_param; |
|
|
|
} |
|
|
|
@@ -396,16 +397,16 @@ void PrimitiveC::ClearInputOutputQuantParam() { |
|
|
|
void PrimitiveC::AddInputQuantParam(std::vector<schema::QuantParamT> quant_param) { |
|
|
|
this->input_quant_param_.emplace_back(quant_param); |
|
|
|
} |
|
|
|
std::vector<std::vector<schema::QuantParamT>> PrimitiveC::GetInputQuantParams() const { return input_quant_param_; } |
|
|
|
std::vector<std::vector<schema::QuantParamT>> PrimitiveC::input_quant_params() const { return input_quant_param_; } |
|
|
|
|
|
|
|
void PrimitiveC::AddOutputQuantParam(std::vector<schema::QuantParamT> quant_param) { |
|
|
|
this->output_quant_param_.emplace_back(quant_param); |
|
|
|
} |
|
|
|
std::vector<std::vector<schema::QuantParamT>> PrimitiveC::GetOutputQuantParams() const { return output_quant_param_; } |
|
|
|
std::vector<std::vector<schema::QuantParamT>> PrimitiveC::output_quant_params() const { return output_quant_param_; } |
|
|
|
|
|
|
|
void PrimitiveC::SetQuantType(const schema::QuantType &quant_type) { this->quant_type_ = quant_type; } |
|
|
|
void PrimitiveC::set_quant_type(const schema::QuantType &quant_type) { this->quant_type_ = quant_type; } |
|
|
|
|
|
|
|
schema::QuantType PrimitiveC::GetQuantType() const { return quant_type_; } |
|
|
|
schema::QuantType PrimitiveC::quant_type() const { return quant_type_; } |
|
|
|
|
|
|
|
std::shared_ptr<PrimitiveC> GetReturnPrim() { |
|
|
|
auto return_primitiveT = new (std::nothrow) schema::PrimitiveT; |
|
|
|
@@ -463,7 +464,7 @@ std::shared_ptr<PrimitiveC> NewPrimitiveC(const Primitive &prim, const std::vect |
|
|
|
MS_LOG(ERROR) << "make_shared PrimitiveC failed"; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
primc->SetQuantType(quantType); |
|
|
|
primc->set_quant_type(quantType); |
|
|
|
auto ret = primc->UnPackAttr(prim, inputs); |
|
|
|
if (ret != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "UnPackAttr failed"; |
|
|
|
@@ -956,8 +957,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { |
|
|
|
} |
|
|
|
|
|
|
|
#else |
|
|
|
void PrimitiveC::SetQuantType(schema::QuantType quant_type) { this->quant_type_ = quant_type; } |
|
|
|
schema::QuantType PrimitiveC::GetQuantType() const { return quant_type_; } |
|
|
|
void PrimitiveC::set_quant_type(schema::QuantType quant_type) { this->quant_type_ = quant_type; } |
|
|
|
schema::QuantType PrimitiveC::quant_type() const { return quant_type_; } |
|
|
|
#endif |
|
|
|
|
|
|
|
int PrimitiveC::Type() const { |
|
|
|
@@ -970,18 +971,18 @@ int PrimitiveC::Type() const { |
|
|
|
return this->primitive_->value_type(); |
|
|
|
#endif |
|
|
|
} |
|
|
|
bool PrimitiveC::GetInferFlag() const { return this->infer_flag_; } |
|
|
|
bool PrimitiveC::infer_flag() const { return this->infer_flag_; } |
|
|
|
|
|
|
|
void PrimitiveC::SetInferFlag(bool flag) { this->infer_flag_ = flag; } |
|
|
|
void PrimitiveC::set_infer_flag(bool flag) { this->infer_flag_ = flag; } |
|
|
|
|
|
|
|
int PrimitiveC::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) { |
|
|
|
auto input = inputs_.front(); |
|
|
|
int PrimitiveC::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) { |
|
|
|
auto input = inputs.front(); |
|
|
|
MS_ASSERT(input != nullptr); |
|
|
|
auto output = outputs_.front(); |
|
|
|
auto output = outputs.front(); |
|
|
|
MS_ASSERT(output != nullptr); |
|
|
|
output->set_shape(input->shape()); |
|
|
|
output->set_data_type(input->data_type()); |
|
|
|
output->SetFormat(input->GetFormat()); |
|
|
|
output->set_format(input->format()); |
|
|
|
return 0; |
|
|
|
} |
|
|
|
|
|
|
|
|