Browse Source

!929 Cleanup dataset UT: test_batch, save_and_check support

Merge pull request !929 from cathwong/ckw_dataset_ut_cleanup
tags/v0.3.0-alpha
mindspore-ci-bot Gitee 5 years ago
parent
commit
cae8a921e3
10 changed files with 67 additions and 77 deletions
  1. BIN
      tests/ut/data/dataset/golden/batch_12_result.npz
  2. +32
    -19
      tests/ut/python/dataset/test_batch.py
  3. +5
    -6
      tests/ut/python/dataset/test_center_crop.py
  4. +0
    -2
      tests/ut/python/dataset/test_decode.py
  5. +1
    -1
      tests/ut/python/dataset/test_random_color_adjust.py
  6. +1
    -1
      tests/ut/python/dataset/test_random_erasing.py
  7. +0
    -2
      tests/ut/python/dataset/test_random_resize.py
  8. +5
    -5
      tests/ut/python/dataset/test_type_cast.py
  9. +6
    -12
      tests/ut/python/dataset/test_zip.py
  10. +17
    -29
      tests/ut/python/dataset/util.py

BIN
tests/ut/data/dataset/golden/batch_12_result.npz View File


+ 32
- 19
tests/ut/python/dataset/test_batch.py View File

@@ -37,6 +37,7 @@ def test_batch_01():
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES) data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
data1 = data1.batch(batch_size, drop_remainder) data1 = data1.batch(batch_size, drop_remainder)


assert sum([1 for _ in data1]) == 6
filename = "batch_01_result.npz" filename = "batch_01_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)


@@ -56,6 +57,7 @@ def test_batch_02():
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES) data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
data1 = data1.batch(batch_size, drop_remainder=drop_remainder) data1 = data1.batch(batch_size, drop_remainder=drop_remainder)


assert sum([1 for _ in data1]) == 2
filename = "batch_02_result.npz" filename = "batch_02_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)


@@ -75,6 +77,7 @@ def test_batch_03():
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES) data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
data1 = data1.batch(batch_size=batch_size, drop_remainder=drop_remainder) data1 = data1.batch(batch_size=batch_size, drop_remainder=drop_remainder)


assert sum([1 for _ in data1]) == 4
filename = "batch_03_result.npz" filename = "batch_03_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)


@@ -94,6 +97,7 @@ def test_batch_04():
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES) data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
data1 = data1.batch(batch_size, drop_remainder) data1 = data1.batch(batch_size, drop_remainder)


assert sum([1 for _ in data1]) == 2
filename = "batch_04_result.npz" filename = "batch_04_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)


@@ -111,6 +115,7 @@ def test_batch_05():
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES) data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
data1 = data1.batch(batch_size) data1 = data1.batch(batch_size)


assert sum([1 for _ in data1]) == 12
filename = "batch_05_result.npz" filename = "batch_05_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)


@@ -130,6 +135,7 @@ def test_batch_06():
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES) data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
data1 = data1.batch(drop_remainder=drop_remainder, batch_size=batch_size) data1 = data1.batch(drop_remainder=drop_remainder, batch_size=batch_size)


assert sum([1 for _ in data1]) == 1
filename = "batch_06_result.npz" filename = "batch_06_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)


@@ -152,6 +158,7 @@ def test_batch_07():
data1 = data1.batch(num_parallel_workers=num_parallel_workers, drop_remainder=drop_remainder, data1 = data1.batch(num_parallel_workers=num_parallel_workers, drop_remainder=drop_remainder,
batch_size=batch_size) batch_size=batch_size)


assert sum([1 for _ in data1]) == 3
filename = "batch_07_result.npz" filename = "batch_07_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)


@@ -171,6 +178,7 @@ def test_batch_08():
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES) data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
data1 = data1.batch(batch_size, num_parallel_workers=num_parallel_workers) data1 = data1.batch(batch_size, num_parallel_workers=num_parallel_workers)


assert sum([1 for _ in data1]) == 2
filename = "batch_08_result.npz" filename = "batch_08_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)


@@ -190,6 +198,7 @@ def test_batch_09():
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES) data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
data1 = data1.batch(batch_size, drop_remainder=drop_remainder) data1 = data1.batch(batch_size, drop_remainder=drop_remainder)


assert sum([1 for _ in data1]) == 1
filename = "batch_09_result.npz" filename = "batch_09_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)


@@ -209,6 +218,7 @@ def test_batch_10():
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES) data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
data1 = data1.batch(batch_size, drop_remainder=drop_remainder) data1 = data1.batch(batch_size, drop_remainder=drop_remainder)


assert sum([1 for _ in data1]) == 0
filename = "batch_10_result.npz" filename = "batch_10_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)


@@ -228,10 +238,30 @@ def test_batch_11():
data1 = ds.TFRecordDataset(DATA_DIR, schema_file) data1 = ds.TFRecordDataset(DATA_DIR, schema_file)
data1 = data1.batch(batch_size) data1 = data1.batch(batch_size)


assert sum([1 for _ in data1]) == 1
filename = "batch_11_result.npz" filename = "batch_11_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)




def test_batch_12():
"""
Test batch: batch_size boolean value True, treated as valid value 1
"""
logger.info("test_batch_12")
# define parameters
batch_size = True
parameters = {"params": {'batch_size': batch_size}}

# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
data1 = data1.batch(batch_size=batch_size)

assert sum([1 for _ in data1]) == 12
filename = "batch_12_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)



def test_batch_exception_01(): def test_batch_exception_01():
""" """
Test batch exception: num_parallel_workers=0 Test batch exception: num_parallel_workers=0
@@ -302,7 +332,7 @@ def test_batch_exception_04():


def test_batch_exception_05(): def test_batch_exception_05():
""" """
Test batch exception: batch_size wrong type, boolean value False
Test batch exception: batch_size boolean value False, treated as invalid value 0
""" """
logger.info("test_batch_exception_05") logger.info("test_batch_exception_05")


@@ -317,23 +347,6 @@ def test_batch_exception_05():
assert "batch_size" in str(e) assert "batch_size" in str(e)




def skip_test_batch_exception_06():
"""
Test batch exception: batch_size wrong type, boolean value True
"""
logger.info("test_batch_exception_06")

# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
try:
data1 = data1.batch(batch_size=True)
sum([1 for _ in data1])

except BaseException as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "batch_size" in str(e)


def test_batch_exception_07(): def test_batch_exception_07():
""" """
Test batch exception: drop_remainder wrong type Test batch exception: drop_remainder wrong type
@@ -473,12 +486,12 @@ if __name__ == '__main__':
test_batch_09() test_batch_09()
test_batch_10() test_batch_10()
test_batch_11() test_batch_11()
test_batch_12()
test_batch_exception_01() test_batch_exception_01()
test_batch_exception_02() test_batch_exception_02()
test_batch_exception_03() test_batch_exception_03()
test_batch_exception_04() test_batch_exception_04()
test_batch_exception_05() test_batch_exception_05()
skip_test_batch_exception_06()
test_batch_exception_07() test_batch_exception_07()
test_batch_exception_08() test_batch_exception_08()
test_batch_exception_09() test_batch_exception_09()


+ 5
- 6
tests/ut/python/dataset/test_center_crop.py View File

@@ -28,7 +28,7 @@ SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"


def test_center_crop_op(height=375, width=375, plot=False): def test_center_crop_op(height=375, width=375, plot=False):
""" """
Test random_vertical
Test CenterCrop
""" """
logger.info("Test CenterCrop") logger.info("Test CenterCrop")


@@ -55,7 +55,7 @@ def test_center_crop_op(height=375, width=375, plot=False):


def test_center_crop_md5(height=375, width=375): def test_center_crop_md5(height=375, width=375):
""" """
Test random_vertical
Test CenterCrop
""" """
logger.info("Test CenterCrop") logger.info("Test CenterCrop")


@@ -69,13 +69,12 @@ def test_center_crop_md5(height=375, width=375):
# expected md5 from images # expected md5 from images


filename = "test_center_crop_01_result.npz" filename = "test_center_crop_01_result.npz"
parameters = {"params": {}}
save_and_check_md5(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)




def test_center_crop_comp(height=375, width=375, plot=False): def test_center_crop_comp(height=375, width=375, plot=False):
""" """
Test random_vertical between python and c image augmentation
Test CenterCrop between python and c image augmentation
""" """
logger.info("Test CenterCrop") logger.info("Test CenterCrop")


@@ -114,4 +113,4 @@ if __name__ == "__main__":
test_center_crop_op(300, 600) test_center_crop_op(300, 600)
test_center_crop_op(600, 300) test_center_crop_op(600, 300)
test_center_crop_md5(600, 600) test_center_crop_md5(600, 600)
test_center_crop_comp()
test_center_crop_comp()

+ 0
- 2
tests/ut/python/dataset/test_decode.py View File

@@ -18,9 +18,7 @@ Testing Decode op in DE
import cv2 import cv2
import mindspore.dataset.transforms.vision.c_transforms as vision import mindspore.dataset.transforms.vision.c_transforms as vision
import numpy as np import numpy as np

import mindspore.dataset as ds import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger from mindspore import log as logger


DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]


+ 1
- 1
tests/ut/python/dataset/test_random_color_adjust.py View File

@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
""" """
Testing RandomRotation op in DE
Testing RandomColorAdjust op in DE
""" """
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import mindspore.dataset.transforms.vision.c_transforms as c_vision import mindspore.dataset.transforms.vision.c_transforms as c_vision


+ 1
- 1
tests/ut/python/dataset/test_random_erasing.py View File

@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
""" """
Testing RandomRotation op in DE
Testing RandomErasing op in DE
""" """
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np


+ 0
- 2
tests/ut/python/dataset/test_random_resize.py View File

@@ -17,9 +17,7 @@ Testing the resize op in DE
""" """
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import mindspore.dataset.transforms.vision.c_transforms as vision import mindspore.dataset.transforms.vision.c_transforms as vision

import mindspore.dataset as ds import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger from mindspore import log as logger


DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]


+ 5
- 5
tests/ut/python/dataset/test_type_cast.py View File

@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
""" """
Testing RandomRotation op in DE
Testing TypeCast op in DE
""" """
import mindspore.dataset.transforms.vision.c_transforms as c_vision import mindspore.dataset.transforms.vision.c_transforms as c_vision
import mindspore.dataset.transforms.vision.py_transforms as py_vision import mindspore.dataset.transforms.vision.py_transforms as py_vision
@@ -31,9 +31,9 @@ SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"


def test_type_cast(): def test_type_cast():
""" """
Test type_cast_op
Test TypeCast op
""" """
logger.info("test_type_cast_op")
logger.info("test_type_cast")


# First dataset # First dataset
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
@@ -71,9 +71,9 @@ def test_type_cast():


def test_type_cast_string(): def test_type_cast_string():
""" """
Test type_cast_op
Test TypeCast op
""" """
logger.info("test_type_cast_op")
logger.info("test_type_cast_string")


# First dataset # First dataset
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)


+ 6
- 12
tests/ut/python/dataset/test_zip.py View File

@@ -44,8 +44,7 @@ def test_zip_01():
dataz = ds.zip((data1, data2)) dataz = ds.zip((data1, data2))
# Note: zipped dataset has 5 rows and 7 columns # Note: zipped dataset has 5 rows and 7 columns
filename = "zip_01_result.npz" filename = "zip_01_result.npz"
parameters = {"params": {}}
save_and_check_dict(dataz, parameters, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(dataz, filename, generate_golden=GENERATE_GOLDEN)




def test_zip_02(): def test_zip_02():
@@ -59,8 +58,7 @@ def test_zip_02():
dataz = ds.zip((data1, data2)) dataz = ds.zip((data1, data2))
# Note: zipped dataset has 3 rows and 4 columns # Note: zipped dataset has 3 rows and 4 columns
filename = "zip_02_result.npz" filename = "zip_02_result.npz"
parameters = {"params": {}}
save_and_check_dict(dataz, parameters, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(dataz, filename, generate_golden=GENERATE_GOLDEN)




def test_zip_03(): def test_zip_03():
@@ -74,8 +72,7 @@ def test_zip_03():
dataz = ds.zip((data1, data2)) dataz = ds.zip((data1, data2))
# Note: zipped dataset has 3 rows and 7 columns # Note: zipped dataset has 3 rows and 7 columns
filename = "zip_03_result.npz" filename = "zip_03_result.npz"
parameters = {"params": {}}
save_and_check_dict(dataz, parameters, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(dataz, filename, generate_golden=GENERATE_GOLDEN)




def test_zip_04(): def test_zip_04():
@@ -90,8 +87,7 @@ def test_zip_04():
dataz = ds.zip((data1, data2, data3)) dataz = ds.zip((data1, data2, data3))
# Note: zipped dataset has 3 rows and 9 columns # Note: zipped dataset has 3 rows and 9 columns
filename = "zip_04_result.npz" filename = "zip_04_result.npz"
parameters = {"params": {}}
save_and_check_dict(dataz, parameters, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(dataz, filename, generate_golden=GENERATE_GOLDEN)




def test_zip_05(): def test_zip_05():
@@ -109,8 +105,7 @@ def test_zip_05():
dataz = ds.zip((data1, data2)) dataz = ds.zip((data1, data2))
# Note: zipped dataset has 5 rows and 9 columns # Note: zipped dataset has 5 rows and 9 columns
filename = "zip_05_result.npz" filename = "zip_05_result.npz"
parameters = {"params": {}}
save_and_check_dict(dataz, parameters, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(dataz, filename, generate_golden=GENERATE_GOLDEN)




def test_zip_06(): def test_zip_06():
@@ -129,8 +124,7 @@ def test_zip_06():
dataz = dataz.repeat(2) dataz = dataz.repeat(2)
# Note: resultant dataset has 10 rows and 9 columns # Note: resultant dataset has 10 rows and 9 columns
filename = "zip_06_result.npz" filename = "zip_06_result.npz"
parameters = {"params": {}}
save_and_check_dict(dataz, parameters, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(dataz, filename, generate_golden=GENERATE_GOLDEN)




def test_zip_exception_01(): def test_zip_exception_01():


+ 17
- 29
tests/ut/python/dataset/util.py View File

@@ -15,15 +15,15 @@


import json import json
import os import os
import hashlib
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import hashlib

#import jsbeautifier #import jsbeautifier
from mindspore import log as logger from mindspore import log as logger


COLUMNS = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float", COLUMNS = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float",
"col_sint16", "col_sint32", "col_sint64"] "col_sint16", "col_sint32", "col_sint64"]
SAVE_JSON = False




def save_golden(cur_dir, golden_ref_dir, result_dict): def save_golden(cur_dir, golden_ref_dir, result_dict):
@@ -44,15 +44,6 @@ def save_golden_dict(cur_dir, golden_ref_dir, result_dict):
np.savez(golden_ref_dir, np.array(list(result_dict.items()))) np.savez(golden_ref_dir, np.array(list(result_dict.items())))




def save_golden_md5(cur_dir, golden_ref_dir, result_dict):
"""
Save the dictionary (both keys and values) as the golden result in .npz file
"""
logger.info("cur_dir is {}".format(cur_dir))
logger.info("golden_ref_dir is {}".format(golden_ref_dir))
np.savez(golden_ref_dir, np.array(list(result_dict.items())))


def compare_to_golden(golden_ref_dir, result_dict): def compare_to_golden(golden_ref_dir, result_dict):
""" """
Compare as numpy arrays the test result to the golden result Compare as numpy arrays the test result to the golden result
@@ -67,7 +58,7 @@ def compare_to_golden_dict(golden_ref_dir, result_dict):
Compare as dictionaries the test result to the golden result Compare as dictionaries the test result to the golden result
""" """
golden_array = np.load(golden_ref_dir, allow_pickle=True)['arr_0'] golden_array = np.load(golden_ref_dir, allow_pickle=True)['arr_0']
np.testing.assert_equal (result_dict, dict(golden_array))
np.testing.assert_equal(result_dict, dict(golden_array))
# assert result_dict == dict(golden_array) # assert result_dict == dict(golden_array)




@@ -83,7 +74,6 @@ def save_json(filename, parameters, result_dict):
fout.write(jsbeautifier.beautify(json.dumps(out_dict), options)) fout.write(jsbeautifier.beautify(json.dumps(out_dict), options))





def save_and_check(data, parameters, filename, generate_golden=False): def save_and_check(data, parameters, filename, generate_golden=False):
""" """
Save the dataset dictionary and compare (as numpy array) with golden file. Save the dataset dictionary and compare (as numpy array) with golden file.
@@ -111,11 +101,12 @@ def save_and_check(data, parameters, filename, generate_golden=False):


compare_to_golden(golden_ref_dir, result_dict) compare_to_golden(golden_ref_dir, result_dict)


# Save to a json file for inspection
# save_json(filename, parameters, result_dict)
if SAVE_JSON:
# Save result to a json file for inspection
save_json(filename, parameters, result_dict)




def save_and_check_dict(data, parameters, filename, generate_golden=False):
def save_and_check_dict(data, filename, generate_golden=False):
""" """
Save the dataset dictionary and compare (as dictionary) with golden file. Save the dataset dictionary and compare (as dictionary) with golden file.
Use create_dict_iterator to access the dataset. Use create_dict_iterator to access the dataset.
@@ -140,11 +131,13 @@ def save_and_check_dict(data, parameters, filename, generate_golden=False):


compare_to_golden_dict(golden_ref_dir, result_dict) compare_to_golden_dict(golden_ref_dir, result_dict)


# Save to a json file for inspection
# save_json(filename, parameters, result_dict)
if SAVE_JSON:
# Save result to a json file for inspection
parameters = {"params": {}}
save_json(filename, parameters, result_dict)




def save_and_check_md5(data, parameters, filename, generate_golden=False):
def save_and_check_md5(data, filename, generate_golden=False):
""" """
Save the dataset dictionary and compare (as dictionary) with golden file (md5). Save the dataset dictionary and compare (as dictionary) with golden file (md5).
Use create_dict_iterator to access the dataset. Use create_dict_iterator to access the dataset.
@@ -197,8 +190,9 @@ def ordered_save_and_check(data, parameters, filename, generate_golden=False):


compare_to_golden(golden_ref_dir, result_dict) compare_to_golden(golden_ref_dir, result_dict)


# Save to a json file for inspection
# save_json(filename, parameters, result_dict)
if SAVE_JSON:
# Save result to a json file for inspection
save_json(filename, parameters, result_dict)




def diff_mse(in1, in2): def diff_mse(in1, in2):
@@ -211,24 +205,18 @@ def diff_me(in1, in2):
return mse / 255 * 100 return mse / 255 * 100




def diff_ssim(in1, in2):
from skimage.measure import compare_ssim as ssim
val = ssim(in1, in2, multichannel=True)
return (1 - val) * 100


def visualize(image_original, image_transformed): def visualize(image_original, image_transformed):
""" """
visualizes the image using DE op and Numpy op visualizes the image using DE op and Numpy op
""" """
num = len(image_cropped)
num = len(image_transformed)
for i in range(num): for i in range(num):
plt.subplot(2, num, i + 1) plt.subplot(2, num, i + 1)
plt.imshow(image_original[i]) plt.imshow(image_original[i])
plt.title("Original image") plt.title("Original image")


plt.subplot(2, num, i + num + 1) plt.subplot(2, num, i + num + 1)
plt.imshow(image_cropped[i])
plt.imshow(image_transformed[i])
plt.title("Transformed image") plt.title("Transformed image")


plt.show() plt.show()

Loading…
Cancel
Save