@@ -27,17 +27,34 @@ namespace dataset {
 SentencePieceTokenizerOp::SentencePieceTokenizerOp(const std::shared_ptr<SentencePieceVocab> vocab,
                                                     const SPieceTokenizerLoadType load_type,
                                                     const SPieceTokenizerOutType out_type)
-    : vocab_(vocab), load_type_(load_type), out_type_(out_type) {}
+    : vocab_(vocab), load_type_(load_type), out_type_(out_type) {
+  // Parse the serialized model once here and cache the result, so Compute() does not reload it per row.
+  if (vocab_ == nullptr) {
+    model_status_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "vocab object is null.");
+    return;
+  }
+  auto status = processor_.LoadFromSerializedProto(vocab_.get()->model_proto());
+  if (!status.ok()) {
+    model_status_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "parse vocab model failed.");
+  } else {
+    model_status_ = Status::OK();
+  }
+}
 
 SentencePieceTokenizerOp::SentencePieceTokenizerOp(const std::string &model_path, const std::string &model_filename,
                                                     const SPieceTokenizerLoadType load_type,
                                                     const SPieceTokenizerOutType out_type)
     : load_type_(load_type), out_type_(out_type) {
   (void)GetModelRealPath(model_path, model_filename);
+  // Load the model file once here; Compute() only checks the cached status.
+  auto status = processor_.Load(file_path_);
+  if (!status.ok()) {
+    model_status_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "load vocab model failed.");
+  } else {
+    model_status_ = Status::OK();
+  }
 }
 
 Status SentencePieceTokenizerOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
   IO_CHECK(input, output);
+  if (!model_status_.IsOk()) {
+    return model_status_;
+  }
   if (input->Rank() != 0 || input->type() != DataType::DE_STRING) {
     RETURN_STATUS_UNEXPECTED("the input tensor should be scalar string tensor");
   }
@@ -45,18 +62,6 @@ Status SentencePieceTokenizerOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
   std::string_view sentence_v;
   RETURN_IF_NOT_OK(input->GetItemAt(&sentence_v, {}));
   std::string sentence{sentence_v};
-  if (load_type_ == SPieceTokenizerLoadType::kFile) {
-    auto status = processor_.Load(file_path_);
-    if (!status.ok()) {
-      RETURN_STATUS_UNEXPECTED("load sentence piece model failed.");
-    }
-  } else {
-    RETURN_UNEXPECTED_IF_NULL(vocab_);
-    auto status = processor_.LoadFromSerializedProto(vocab_.get()->model_proto());
-    if (!status.ok()) {
-      RETURN_STATUS_UNEXPECTED("sentence piece load model failed.");
-    }
-  }
 
   if (out_type_ == SPieceTokenizerOutType::kString) {
     std::vector<std::string> pieces;
@@ -58,6 +58,7 @@ class SentencePieceTokenizerOp : public TensorOp {
   std::string file_path_;
   SPieceTokenizerLoadType load_type_;
   sentencepiece::SentencePieceProcessor processor_;
+  // Status of the model load performed in the constructor; returned from Compute() on failure.
+  Status model_status_;
 };
 }  // namespace dataset
 }  // namespace mindspore
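
As a side note, the load-once pattern introduced above can be tried in isolation against the SentencePiece C++ API. This is a minimal standalone sketch, not MindSpore code; the ToyTokenizer class name and the m.model path are assumptions for illustration:

    // toy_tokenizer.cc: load a SentencePiece model once and reuse the cached status on every call.
    #include <iostream>
    #include <string>
    #include <vector>
    #include <sentencepiece_processor.h>

    class ToyTokenizer {
     public:
      explicit ToyTokenizer(const std::string &model_file) {
        // Load in the constructor; Tokenize() only checks the cached status.
        load_status_ = processor_.Load(model_file);
      }
      bool Tokenize(const std::string &sentence, std::vector<std::string> *pieces) {
        if (!load_status_.ok()) {
          return false;  // fail fast instead of reloading the model per call
        }
        return processor_.Encode(sentence, pieces).ok();
      }
     private:
      sentencepiece::SentencePieceProcessor processor_;
      sentencepiece::util::Status load_status_;
    };

    int main() {
      ToyTokenizer tokenizer("m.model");  // assumed model path
      std::vector<std::string> pieces;
      if (tokenizer.Tokenize("I saw a girl with a telescope.", &pieces)) {
        for (const auto &piece : pieces) {
          std::cout << piece << ' ';
        }
        std::cout << '\n';
      }
      return 0;
    }
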
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+import copy
 import mindspore.dataset.text as text
 import mindspore.dataset as ds
 from mindspore.dataset.text import SentencePieceModel, to_str, SPieceTokenizerOutType
@@ -121,6 +121,44 @@ def test_build_from_dataset():
             assert value == expect[key]
 
 
+def apply_func(dataset):
+    input_columns = ['text']
+    output_columns = ['text2']
+    dataset = dataset.rename(input_columns, output_columns)
+    return dataset
+
+
+def zip_test(dataset):
+    dataset_1 = copy.deepcopy(dataset)
+    dataset_2 = copy.deepcopy(dataset)
+    dataset_1 = dataset_1.apply(apply_func)
+    dataset_zip = ds.zip((dataset_1, dataset_2))
+    expect = ['▁I', '▁sa', 'w', '▁a', '▁girl', '▁with', '▁a', '▁te', 'les', 'co', 'pe', '.']
+    for i in dataset_zip.create_dict_iterator():
+        ret = to_str(i["text"])
+        for key, value in enumerate(ret):
+            assert value == expect[key]
+
+
+def concat_test(dataset):
+    dataset_1 = copy.deepcopy(dataset)
+    dataset = dataset.concat(dataset_1)
+    expect = ['▁I', '▁sa', 'w', '▁a', '▁girl', '▁with', '▁a', '▁te', 'les', 'co', 'pe', '.']
+    for i in dataset.create_dict_iterator():
+        ret = to_str(i["text"])
+        for key, value in enumerate(ret):
+            assert value == expect[key]
+
+
+def test_with_zip_concat():
+    data = ds.TextFileDataset(VOCAB_FILE, shuffle=False)
+    vocab = text.SentencePieceVocab.from_dataset(data, [""], 5000, 0.9995, SentencePieceModel.UNIGRAM, {})
+    tokenizer = text.SentencePieceTokenizer(vocab, out_type=SPieceTokenizerOutType.STRING)
+    dataset = ds.TextFileDataset(DATA_FILE, shuffle=False)
+    dataset = dataset.map(operations=tokenizer, num_parallel_workers=2)
+    zip_test(dataset)
+    concat_test(dataset)
+
+
 if __name__ == "__main__":
     test_from_vocab_to_str_UNIGRAM()
     test_from_vocab_to_str_BPE()
@@ -130,3 +168,4 @@ if __name__ == "__main__":
     test_from_file_to_str()
     test_from_file_to_int()
     test_build_from_dataset()
+    test_with_zip_concat()
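
For reference, a minimal sketch of driving the file-based loading path (the second constructor changed above) from Python. The ./m.model path and the save_model round-trip are assumptions for illustration; VOCAB_FILE and DATA_FILE refer to the constants already defined at the top of this test file, and the imports shown earlier are reused:

    def sketch_tokenizer_from_model_file():
        # Train a vocab from the raw corpus, serialize it to ./m.model, then
        # construct the tokenizer from the file path instead of the vocab object.
        vocab = text.SentencePieceVocab.from_file([VOCAB_FILE], 5000, 0.9995, SentencePieceModel.UNIGRAM, {})
        text.SentencePieceVocab.save_model(vocab, "./", "m.model")
        tokenizer = text.SentencePieceTokenizer("./m.model", out_type=SPieceTokenizerOutType.STRING)
        dataset = ds.TextFileDataset(DATA_FILE, shuffle=False)
        dataset = dataset.map(operations=tokenizer)
        for row in dataset.create_dict_iterator():
            print(to_str(row["text"]))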