Browse Source

coverage increased

Former-commit-id: 78b9cf4419 [formerly 4fae91ba23] [formerly bc35de7163 [formerly 80ad6f0711]] [formerly 0a59553875 [formerly 77c8c1c97b] [formerly 9943fe3983 [formerly 17cc1dc965]]] [formerly b080325ca8 [formerly ed1923a851] [formerly cca3248818 [formerly 2fa8b461bc]] [formerly 9839af0699 [formerly e7837bf051] [formerly 53e8a29522 [formerly 13d072ef70]]]] [formerly cf31f3de0f [formerly 56c5dbd225] [formerly 2b3abaee39 [formerly 68b48fcea5]] [formerly b81be824cb [formerly e2bbf06c66] [formerly 6b1bc0cc30 [formerly f4255fb15a]]] [formerly e69b99a79e [formerly cd7f048a25] [formerly 6df4e52265 [formerly abfb9df55f]] [formerly 5be3a9705e [formerly c2011486ff] [formerly 877cb675bc [formerly 7752459ea9]]]]] [formerly 0cf354fa9d [formerly dc33779041] [formerly 6590bad12b [formerly 80237d3ecc]] [formerly 81b331c96a [formerly b6aab66068] [formerly d360858979 [formerly cfb6d69120]]] [formerly ecddc5d261 [formerly 5339a74764] [formerly 6260d92aa5 [formerly c5f4f148f2]] [formerly ce92c0bc46 [formerly 9aa600a48f] [formerly cee697ced4 [formerly b9492aa68a]]]] [formerly fdd7a1bb86 [formerly fe8b367725] [formerly 27d1123d41 [formerly 8996bcb188]] [formerly 01932b1972 [formerly d94a4b2510] [formerly 70fb868858 [formerly 1ab5c391a6]]] [formerly 3e4d4c453f [formerly 10d7b322f4] [formerly 00dc558701 [formerly 68b219a91f]] [formerly 02bfe31725 [formerly 10a485f409] [formerly e24105070d [formerly 9350908041]]]]]]
Former-commit-id: 227e50573a [formerly a1ab47525e] [formerly 8c7543f104 [formerly c7c848bd69]] [formerly 12db9cfd19 [formerly ea58ddf7fa] [formerly 4ab64ec79e [formerly f7c97007e4]]] [formerly 6db1fac2a7 [formerly 0604fbb649] [formerly a92e46780a [formerly 5c74e2a75c]] [formerly a5d7969fa8 [formerly c0baf4d885] [formerly 16fce895d4 [formerly e218021263]]]] [formerly ba8a6f3fc9 [formerly 00a7ca64b0] [formerly 97bad15e97 [formerly 07a27e87b8]] [formerly f2dcfc721c [formerly b59e422f6a] [formerly f6fdf887e6 [formerly da9d05a07b]]] [formerly 32a1abaebb [formerly 7cc5d3ff19] [formerly 10ee21e43c [formerly b79f54231a]] [formerly bfc66cfd34 [formerly 3174ab4a0f] [formerly e24105070d]]]]
Former-commit-id: 00cc696a06 [formerly 17a96c497b] [formerly e72458559d [formerly dcaba208be]] [formerly f488d1dbff [formerly d6326d12c7] [formerly a5be01fbe5 [formerly 791c056a80]]] [formerly a4f1d90f07 [formerly 1ee9514201] [formerly da2b160ba6 [formerly 5e3bf51141]] [formerly 777c555314 [formerly 4bccea9fbe] [formerly 601046354e [formerly c85f0f5fda]]]]
Former-commit-id: 70280691b9 [formerly a0a181c7b9] [formerly 8e44f3d49c [formerly ec721f0a55]] [formerly 0c45e05739 [formerly 62ca892350] [formerly 914d56242b [formerly c673775e04]]]
Former-commit-id: d88d95f696 [formerly ac232e5e48] [formerly e368f2d797 [formerly 4c897bd0af]]
Former-commit-id: 979bc388b8 [formerly fe95f5f9ab]
Former-commit-id: 2f6a3160dd
master
Devesh Kumar 5 years ago
parent
commit
b953002fa1
2 changed files with 131 additions and 15 deletions
  1. +15
    -15
      tods/data_processing/ConstructPredictions.py
  2. +116
    -0
      tods/tests/data_processing/test_ConstructPredictions.py

+ 15
- 15
tods/data_processing/ConstructPredictions.py View File

@@ -96,14 +96,14 @@ class ConstructPredictionsPrimitive(transformer.TransformerPrimitiveBase[Inputs,
return base.CallResult(outputs)

def _filter_index_columns(self, inputs_metadata: metadata_base.DataMetadata, index_columns: typing.Sequence[int]) -> typing.Sequence[int]:
if self.hyperparams['use_columns']:
if self.hyperparams['use_columns']: # pragma: no cover
index_columns = [index_column_index for index_column_index in index_columns if index_column_index in self.hyperparams['use_columns']]
if not index_columns:
raise ValueError("No index columns listed in \"use_columns\" hyper-parameter, but index columns are required.")

else:
index_columns = [index_column_index for index_column_index in index_columns if index_column_index not in self.hyperparams['exclude_columns']]
if not index_columns:
if not index_columns: # pragma: no cover
raise ValueError("All index columns listed in \"exclude_columns\" hyper-parameter, but index columns are required.")

names = []
@@ -113,11 +113,11 @@ class ConstructPredictionsPrimitive(transformer.TransformerPrimitiveBase[Inputs,
if index_metadata.get('name', None):
names.append(index_metadata['name'])

if 'd3mIndex' not in names:
if 'd3mIndex' not in names: # pragma: no cover
raise ValueError("\"d3mIndex\" index column is missing.")

names_set = set(names)
if len(names) != len(names_set):
if len(names) != len(names_set): # pragma: no cover
duplicate_names = names
for name in names_set:
# Removes just the first occurrence.
@@ -135,14 +135,14 @@ class ConstructPredictionsPrimitive(transformer.TransformerPrimitiveBase[Inputs,

index_columns = self._filter_index_columns(inputs_metadata, index_columns)

if self.hyperparams['use_columns']:
if self.hyperparams['use_columns']: # pragma: no cover
target_columns = [target_column_index for target_column_index in target_columns if target_column_index in self.hyperparams['use_columns']]
if not target_columns:
raise ValueError("No target columns listed in \"use_columns\" hyper-parameter, but target columns are required.")

else:
target_columns = [target_column_index for target_column_index in target_columns if target_column_index not in self.hyperparams['exclude_columns']]
if not target_columns:
if not target_columns: # pragma: no cover
raise ValueError("All target columns listed in \"exclude_columns\" hyper-parameter, but target columns are required.")

assert index_columns
@@ -153,7 +153,7 @@ class ConstructPredictionsPrimitive(transformer.TransformerPrimitiveBase[Inputs,
def _get_confidence_columns(self, inputs_metadata: metadata_base.DataMetadata) -> typing.List[int]:
confidence_columns = inputs_metadata.list_columns_with_semantic_types(('https://metadata.datadrivendiscovery.org/types/Confidence',))

if self.hyperparams['use_columns']:
if self.hyperparams['use_columns']:# pragma: no cover
confidence_columns = [confidence_column_index for confidence_column_index in confidence_columns if confidence_column_index in self.hyperparams['use_columns']]
else:
confidence_columns = [confidence_column_index for confidence_column_index in confidence_columns if confidence_column_index not in self.hyperparams['exclude_columns']]
@@ -176,7 +176,7 @@ class ConstructPredictionsPrimitive(transformer.TransformerPrimitiveBase[Inputs,

return outputs

def _update_confidence_columns(self, inputs_metadata: metadata_base.DataMetadata, confidence_columns: typing.Sequence[int]) -> metadata_base.DataMetadata:
def _update_confidence_columns(self, inputs_metadata: metadata_base.DataMetadata, confidence_columns: typing.Sequence[int]) -> metadata_base.DataMetadata: # pragma: no cover
output_columns_length = inputs_metadata.query((metadata_base.ALL_ELEMENTS,))['dimension']['length']

outputs_metadata = inputs_metadata
@@ -193,17 +193,17 @@ class ConstructPredictionsPrimitive(transformer.TransformerPrimitiveBase[Inputs,
if not index_columns:
reference_index_columns = reference.metadata.get_index_columns()

if not reference_index_columns:
if not reference_index_columns: # pragma: no cover
raise ValueError("Cannot find an index column in reference data, but index column is required.")

filtered_index_columns = self._filter_index_columns(reference.metadata, reference_index_columns)
index = reference.select_columns(filtered_index_columns)
else:
else: # pragma: no cover
filtered_index_columns = self._filter_index_columns(inputs.metadata, index_columns)
index = inputs.select_columns(filtered_index_columns)

if not target_columns:
if index_columns:
if index_columns: # pragma: no cover
raise ValueError("No target columns in input data, but index column(s) present.")

# We assume all inputs are targets.
@@ -220,10 +220,10 @@ class ConstructPredictionsPrimitive(transformer.TransformerPrimitiveBase[Inputs,

return index.append_columns(targets)

def multi_produce(self, *, produce_methods: typing.Sequence[str], inputs: Inputs, reference: Inputs, timeout: float = None, iterations: int = None) -> base.MultiCallResult: # type: ignore
def multi_produce(self, *, produce_methods: typing.Sequence[str], inputs: Inputs, reference: Inputs, timeout: float = None, iterations: int = None) -> base.MultiCallResult: # pragma: no cover
return self._multi_produce(produce_methods=produce_methods, timeout=timeout, iterations=iterations, inputs=inputs, reference=reference)

def fit_multi_produce(self, *, produce_methods: typing.Sequence[str], inputs: Inputs, reference: Inputs, timeout: float = None, iterations: int = None) -> base.MultiCallResult: # type: ignore
def fit_multi_produce(self, *, produce_methods: typing.Sequence[str], inputs: Inputs, reference: Inputs, timeout: float = None, iterations: int = None) -> base.MultiCallResult: # pragma: no cover
return self._fit_multi_produce(produce_methods=produce_methods, timeout=timeout, iterations=iterations, inputs=inputs, reference=reference)

def _get_target_names(self, metadata: metadata_base.DataMetadata) -> typing.List[typing.Union[str, None]]:
@@ -239,7 +239,7 @@ class ConstructPredictionsPrimitive(transformer.TransformerPrimitiveBase[Inputs,
def _update_targets_metadata(self, metadata: metadata_base.DataMetadata, target_names: typing.Sequence[typing.Union[str, None]]) -> metadata_base.DataMetadata:
targets_length = metadata.query((metadata_base.ALL_ELEMENTS,))['dimension']['length']

if targets_length != len(target_names):
if targets_length != len(target_names): # pragma: no cover
raise ValueError("Not an expected number of target columns to apply names for. Expected {target_names}, provided {targets_length}.".format(
target_names=len(target_names),
targets_length=targets_length,
@@ -250,7 +250,7 @@ class ConstructPredictionsPrimitive(transformer.TransformerPrimitiveBase[Inputs,
metadata = metadata.add_semantic_type((metadata_base.ALL_ELEMENTS, column_index), 'https://metadata.datadrivendiscovery.org/types/PredictedTarget')

# We do not have it, let's skip it and hope for the best.
if target_name is None:
if target_name is None: # pragma: no cover
continue

metadata = metadata.update_column(column_index, {


+ 116
- 0
tods/tests/data_processing/test_ConstructPredictions.py View File

@@ -71,6 +71,7 @@ class ConstructPredictionsPrimitiveTestCase(unittest.TestCase):




def _test_metadata(self, metadata, no_metadata=False):
self.maxDiff = None

@@ -125,6 +126,121 @@ class ConstructPredictionsPrimitiveTestCase(unittest.TestCase):
],
})

def test_all_columns(self):
"""Feed produce() a copy of the full dataframe marked as predictions; output must shrink to just the index and the predicted target column."""
dataframe = self._get_yahoo_dataframe()

# We use all columns. Output has to be just index and targets.
targets = copy.copy(dataframe)

# We pretend these are our predictions.
# Swap the TrueTarget semantic type on column 5 for PredictedTarget so the frame looks like model output.
targets.metadata = targets.metadata.remove_semantic_type((metadata_base.ALL_ELEMENTS, 5),
'https://metadata.datadrivendiscovery.org/types/TrueTarget')
targets.metadata = targets.metadata.add_semantic_type((metadata_base.ALL_ELEMENTS, 5),
'https://metadata.datadrivendiscovery.org/types/PredictedTarget')

hyperparams_class = ConstructPredictions.ConstructPredictionsPrimitive.metadata.get_hyperparams()

construct_primitive = ConstructPredictions.ConstructPredictionsPrimitive(
hyperparams=hyperparams_class.defaults())

call_metadata = construct_primitive.produce(inputs=targets, reference=dataframe)

dataframe = call_metadata.value

# Only the primary index and the single predicted target survive the construction.
self.assertEqual(list(dataframe.columns), ['d3mIndex', 'value_3'])

self._test_metadata(dataframe.metadata)

def test_missing_index(self):
"""Drop the primary index column from the inputs; produce() is expected to restore 'd3mIndex' (presumably from the `reference` dataframe — confirm against the primitive)."""
dataframe = self._get_yahoo_dataframe()

# We just use all columns.
targets = copy.copy(dataframe)

# We pretend these are our predictions.
# Column 5 is re-marked from TrueTarget to PredictedTarget.
targets.metadata = targets.metadata.remove_semantic_type((metadata_base.ALL_ELEMENTS, 5),
'https://metadata.datadrivendiscovery.org/types/TrueTarget')
targets.metadata = targets.metadata.add_semantic_type((metadata_base.ALL_ELEMENTS, 5),
'https://metadata.datadrivendiscovery.org/types/PredictedTarget')

# Remove primary index. This one has to be reconstructed.
targets = targets.remove_columns([0])

hyperparams_class = ConstructPredictions.ConstructPredictionsPrimitive.metadata.get_hyperparams()

construct_primitive = ConstructPredictions.ConstructPredictionsPrimitive(
hyperparams=hyperparams_class.defaults())

call_metadata = construct_primitive.produce(inputs=targets, reference=dataframe)

dataframe = call_metadata.value

# 'd3mIndex' is present again even though it was removed from the inputs above.
self.assertEqual(list(dataframe.columns), ['d3mIndex', 'value_3'])

self._test_metadata(dataframe.metadata)

def test_just_targets_no_metadata(self):
"""Pass only the target columns, with all semantic-type metadata wiped; produce() should still yield index + target (metadata is checked with no_metadata=True)."""
dataframe = self._get_yahoo_dataframe()

hyperparams_class = ExtractColumnsBySemanticTypes.ExtractColumnsBySemanticTypesPrimitive.metadata.get_hyperparams()

# We extract just targets.
primitive = ExtractColumnsBySemanticTypes.ExtractColumnsBySemanticTypesPrimitive(
hyperparams=hyperparams_class.defaults().replace(
{'semantic_types': ('https://metadata.datadrivendiscovery.org/types/Target',)}))

call_metadata = primitive.produce(inputs=dataframe)

targets = call_metadata.value

# Remove all metadata.
# Regenerating from scratch keeps only what can be inferred from the data itself.
targets.metadata = metadata_base.DataMetadata().generate(targets)

hyperparams_class = ConstructPredictions.ConstructPredictionsPrimitive.metadata.get_hyperparams()

construct_primitive = ConstructPredictions.ConstructPredictionsPrimitive(
hyperparams=hyperparams_class.defaults())

call_metadata = construct_primitive.produce(inputs=targets, reference=dataframe)

dataframe = call_metadata.value

self.assertEqual(list(dataframe.columns), ['d3mIndex', 'value_3'])

# Second argument is the no_metadata flag of _test_metadata.
self._test_metadata(dataframe.metadata, True)

def test_float_vector(self):
"""A vector-valued target cell is serialized to a comma-joined string ('3,5,9,10') and the output column's structural type becomes str."""
dataframe = container.DataFrame({
'd3mIndex': [0],
'target': [container.ndarray(numpy.array([3, 5, 9, 10]))],
}, generate_metadata=True)

# Update metadata.
# Mark column 0 as the primary key and column 1 as the prediction to construct from.
dataframe.metadata = dataframe.metadata.add_semantic_type((metadata_base.ALL_ELEMENTS, 0),
'https://metadata.datadrivendiscovery.org/types/PrimaryKey')
dataframe.metadata = dataframe.metadata.add_semantic_type((metadata_base.ALL_ELEMENTS, 1),
'https://metadata.datadrivendiscovery.org/types/PredictedTarget')

hyperparams_class = ConstructPredictions.ConstructPredictionsPrimitive.metadata.get_hyperparams()

construct_primitive = ConstructPredictions.ConstructPredictionsPrimitive(
hyperparams=hyperparams_class.defaults())

# Inputs double as their own reference here.
dataframe = construct_primitive.produce(inputs=dataframe, reference=dataframe).value

self.assertEqual(list(dataframe.columns), ['d3mIndex', 'target'])

# The ndarray cell has been flattened into a single comma-separated string.
self.assertEqual(dataframe.values.tolist(), [
[0, '3,5,9,10'],
])

# Column metadata reflects the string serialization; the PredictedTarget type is preserved.
self.assertEqual(dataframe.metadata.query_column(1), {
'structural_type': str,
'name': 'target',
'semantic_types': (
'https://metadata.datadrivendiscovery.org/types/PredictedTarget',
),
})


if __name__ == '__main__':


Loading…
Cancel
Save