diff --git a/tests/ut/python/dataset/test_graphdata.py b/tests/ut/python/dataset/test_graphdata.py index 67aa42cb25..35c6d02fc7 100644 --- a/tests/ut/python/dataset/test_graphdata.py +++ b/tests/ut/python/dataset/test_graphdata.py @@ -22,7 +22,7 @@ DATASET_FILE = "../data/mindrecord/testGraphData/testdata" def test_graphdata_getfullneighbor(): g = ds.GraphData(DATASET_FILE, 2) nodes = g.get_all_nodes(1) - assert len(nodes) is 10 + assert len(nodes) == 10 nodes_list = nodes.tolist() neighbor = g.get_all_neighbors(nodes_list, 2) assert neighbor.shape == (10, 6)