|
|
|
@@ -29,6 +29,37 @@ void Split::SetSizeSplits(const std::vector<int> &size_splits) { |
|
|
|
} |
|
|
|
void Split::SetSplitDim(int split_dim) { this->primitive_->value.AsSplit()->splitDim = split_dim; } |
|
|
|
|
|
|
|
int Split::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { |
|
|
|
if (this->primitive_ == nullptr) { |
|
|
|
this->primitive_ = new (std::nothrow) schema::PrimitiveT; |
|
|
|
if (this->primitive_ == nullptr) { |
|
|
|
MS_LOG(ERROR) << "new primitiveT failed"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
this->primitive_->value.type = schema::PrimitiveType_Split; |
|
|
|
} |
|
|
|
if (this->primitive_->value.type != schema::PrimitiveType_Split) { |
|
|
|
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
if (this->primitive_->value.value == nullptr) { |
|
|
|
auto attr = new (std::nothrow) schema::SplitT(); |
|
|
|
if (attr == nullptr) { |
|
|
|
MS_LOG(ERROR) << "new primitiveT value failed"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
attr->splitDim = GetValue<int32_t>(prim.GetAttr("axis")); |
|
|
|
attr->numberSplit = GetValue<int32_t>(prim.GetAttr("output_num")); |
|
|
|
this->primitive_->value.value = attr; |
|
|
|
if (this->primitive_->value.value == nullptr) { |
|
|
|
MS_LOG(ERROR) << "primitive value is nullptr"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
#else |
|
|
|
|
|
|
|
int Split::GetNumberSplit() const { return this->primitive_->value_as_Split()->numberSplit(); } |
|
|
|
@@ -99,12 +130,14 @@ int Split::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outpu |
|
|
|
output_shape.insert(output_shape.begin(), input_shape.begin(), input_shape.end()); |
|
|
|
int split_dim_i = input_shape[split_dim]; |
|
|
|
// support split size is -1 in the end. |
|
|
|
if (i == number_split - 1 && size_split[i] == -1) { |
|
|
|
if (size_split.empty()) { |
|
|
|
split_dim_i = input_shape[split_dim] / number_split; |
|
|
|
} else if (i == number_split - 1 && size_split[i] == -1) { |
|
|
|
for (size_t j = 0; j < size_split.size() - 1; ++j) { |
|
|
|
split_dim_i -= size_split[j]; |
|
|
|
} |
|
|
|
} else { |
|
|
|
split_dim_i = size_split.empty() ? input_shape[split_dim] / number_split : size_split[i]; |
|
|
|
split_dim_i = size_split[i]; |
|
|
|
} |
|
|
|
output_shape[split_dim] = split_dim_i; |
|
|
|
outputs_[i]->set_shape(output_shape); |
|
|
|
|