|
|
|
@@ -12,10 +12,10 @@ |
|
|
|
# See the License for the specific language governing permissions and |
|
|
|
# limitations under the License. |
|
|
|
# ============================================================================== |
|
|
|
from util import save_and_check |
|
|
|
|
|
|
|
import mindspore.dataset as ds |
|
|
|
from mindspore import log as logger |
|
|
|
from util import save_and_check |
|
|
|
|
|
|
|
|
|
|
|
DATA_DIR = ["../data/dataset/testTFTestAllTypes/test.data"] |
|
|
|
SCHEMA_DIR = "../data/dataset/testTFTestAllTypes/datasetSchema.json" |
|
|
|
@@ -24,7 +24,7 @@ COLUMNS = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float", |
|
|
|
GENERATE_GOLDEN = False |
|
|
|
|
|
|
|
|
|
|
|
def skip_test_case_0(): |
|
|
|
def test_2ops_repeat_shuffle(): |
|
|
|
""" |
|
|
|
Test Repeat then Shuffle |
|
|
|
""" |
|
|
|
@@ -43,11 +43,11 @@ def skip_test_case_0(): |
|
|
|
ds.config.set_seed(seed) |
|
|
|
data1 = data1.shuffle(buffer_size=buffer_size) |
|
|
|
|
|
|
|
filename = "test_case_0_result.npz" |
|
|
|
filename = "test_2ops_repeat_shuffle.npz" |
|
|
|
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) |
|
|
|
|
|
|
|
|
|
|
|
def skip_test_case_0_reverse(): |
|
|
|
def skip_test_2ops_shuffle_repeat(): |
|
|
|
""" |
|
|
|
Test Shuffle then Repeat |
|
|
|
""" |
|
|
|
@@ -67,11 +67,11 @@ def skip_test_case_0_reverse(): |
|
|
|
data1 = data1.shuffle(buffer_size=buffer_size) |
|
|
|
data1 = data1.repeat(repeat_count) |
|
|
|
|
|
|
|
filename = "test_case_0_reverse_result.npz" |
|
|
|
filename = "test_2ops_shuffle_repeat.npz" |
|
|
|
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) |
|
|
|
|
|
|
|
|
|
|
|
def test_case_1(): |
|
|
|
def test_2ops_repeat_batch(): |
|
|
|
""" |
|
|
|
Test Repeat then Batch |
|
|
|
""" |
|
|
|
@@ -87,11 +87,11 @@ def test_case_1(): |
|
|
|
data1 = data1.repeat(repeat_count) |
|
|
|
data1 = data1.batch(batch_size, drop_remainder=True) |
|
|
|
|
|
|
|
filename = "test_case_1_result.npz" |
|
|
|
filename = "test_2ops_repeat_batch.npz" |
|
|
|
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) |
|
|
|
|
|
|
|
|
|
|
|
def test_case_1_reverse(): |
|
|
|
def test_2ops_batch_repeat(): |
|
|
|
""" |
|
|
|
Test Batch then Repeat |
|
|
|
""" |
|
|
|
@@ -107,11 +107,11 @@ def test_case_1_reverse(): |
|
|
|
data1 = data1.batch(batch_size, drop_remainder=True) |
|
|
|
data1 = data1.repeat(repeat_count) |
|
|
|
|
|
|
|
filename = "test_case_1_reverse_result.npz" |
|
|
|
filename = "test_2ops_batch_repeat.npz" |
|
|
|
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) |
|
|
|
|
|
|
|
|
|
|
|
def test_case_2(): |
|
|
|
def test_2ops_batch_shuffle(): |
|
|
|
""" |
|
|
|
Test Batch then Shuffle |
|
|
|
""" |
|
|
|
@@ -130,11 +130,11 @@ def test_case_2(): |
|
|
|
ds.config.set_seed(seed) |
|
|
|
data1 = data1.shuffle(buffer_size=buffer_size) |
|
|
|
|
|
|
|
filename = "test_case_2_result.npz" |
|
|
|
filename = "test_2ops_batch_shuffle.npz" |
|
|
|
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) |
|
|
|
|
|
|
|
|
|
|
|
def test_case_2_reverse(): |
|
|
|
def test_2ops_shuffle_batch(): |
|
|
|
""" |
|
|
|
Test Shuffle then Batch |
|
|
|
""" |
|
|
|
@@ -153,5 +153,14 @@ def test_case_2_reverse(): |
|
|
|
data1 = data1.shuffle(buffer_size=buffer_size) |
|
|
|
data1 = data1.batch(batch_size, drop_remainder=True) |
|
|
|
|
|
|
|
filename = "test_case_2_reverse_result.npz" |
|
|
|
filename = "test_2ops_shuffle_batch.npz" |
|
|
|
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
test_2ops_repeat_shuffle() |
|
|
|
#test_2ops_shuffle_repeat() |
|
|
|
test_2ops_repeat_batch() |
|
|
|
test_2ops_batch_repeat() |
|
|
|
test_2ops_batch_shuffle() |
|
|
|
test_2ops_shuffle_batch() |