Browse Source

fix get num classes of concat

tags/v1.1.0
YangLuo 5 years ago
parent
commit
df9f4f41e8
3 changed files with 35 additions and 0 deletions
  1. +15
    -0
      mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc
  2. +5
    -0
      mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.h
  3. +15
    -0
      tests/ut/python/dataset/test_datasets_manifestop.py

+ 15
- 0
mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc View File

@@ -196,5 +196,20 @@ Status ConcatOp::PreAccept(NodePass *p, bool *modified) {
return p->PreRunOnNode(shared_from_base<ConcatOp>(), modified);
}

// Gets the number of classes
Status ConcatOp::GetNumClasses(int64_t *num_classes) {
int64_t max_num_classes = -1;
for (const auto &child : child_) {
// Choose a dataset which can get valid num_classes
int64_t tmp_num_classes = -1;
child->GetNumClasses(&tmp_num_classes);
if (tmp_num_classes > max_num_classes) {
max_num_classes = tmp_num_classes;
}
}
*num_classes = max_num_classes;
return Status::OK();
}

} // namespace dataset
} // namespace mindspore

+ 5
- 0
mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.h View File

@@ -111,6 +111,11 @@ class ConcatOp : public PipelineOp {
/// \return Status of the node visit
Status PreAccept(NodePass *p, bool *modified) override;

/// \brief Gets the number of classes
/// \param[out] num_classes the number of classes
/// \return Status - The status code return
Status GetNumClasses(int64_t *num_classes) override;

private:
Status Verify(int32_t id, const std::unique_ptr<DataBuffer> &buf);



+ 15
- 0
tests/ut/python/dataset/test_datasets_manifestop.py View File

@@ -113,6 +113,20 @@ def test_manifest_dataset_multi_label_onehot():
count = count + 1


def test_manifest_dataset_get_num_class():
data = ds.ManifestDataset(DATA_FILE, decode=True, shuffle=False)
assert data.num_classes() == 3

padded_samples = [{'image': np.zeros(1, np.uint8), 'label': np.array(1, np.int32)}]
padded_ds = ds.PaddedDataset(padded_samples)

data = data.repeat(2)
padded_ds = padded_ds.repeat(2)

data1 = data + padded_ds
assert data1.num_classes() == 3


if __name__ == '__main__':
test_manifest_dataset_train()
test_manifest_dataset_eval()
@@ -120,3 +134,4 @@ if __name__ == '__main__':
test_manifest_dataset_get_class_index()
test_manifest_dataset_multi_label()
test_manifest_dataset_multi_label_onehot()
test_manifest_dataset_get_num_class()

Loading…
Cancel
Save