|
|
|
@@ -16,6 +16,8 @@ |
|
|
|
import math |
|
|
|
|
|
|
|
from mindinsight.datavisual.proto_files.mindinsight_summary_pb2 import Summary |
|
|
|
from mindinsight.utils.exceptions import ParamValueError |
|
|
|
from mindinsight.datavisual.utils.utils import calc_histogram_bins |
|
|
|
|
|
|
|
|
|
|
|
def _mask_invalid_number(num): |
|
|
|
@@ -26,6 +28,49 @@ def _mask_invalid_number(num): |
|
|
|
return num |
|
|
|
|
|
|
|
|
|
|
|
class Bucket: |
|
|
|
""" |
|
|
|
Bucket data class. |
|
|
|
|
|
|
|
Args: |
|
|
|
left (double): Left edge of the histogram bucket. |
|
|
|
width (double): Width of the histogram bucket. |
|
|
|
count (int): Count of numbers fallen in the histogram bucket. |
|
|
|
""" |
|
|
|
def __init__(self, left, width, count): |
|
|
|
self._left = left |
|
|
|
self._width = width |
|
|
|
self._count = count |
|
|
|
|
|
|
|
@property |
|
|
|
def left(self): |
|
|
|
"""Gets left edge of the histogram bucket.""" |
|
|
|
return self._left |
|
|
|
|
|
|
|
@property |
|
|
|
def count(self): |
|
|
|
"""Gets count of numbers fallen in the histogram bucket.""" |
|
|
|
return self._count |
|
|
|
|
|
|
|
@property |
|
|
|
def width(self): |
|
|
|
"""Gets width of the histogram bucket.""" |
|
|
|
return self._width |
|
|
|
|
|
|
|
@property |
|
|
|
def right(self): |
|
|
|
"""Gets right edge of the histogram bucket.""" |
|
|
|
return self._left + self._width |
|
|
|
|
|
|
|
def as_tuple(self): |
|
|
|
"""Gets the bucket as tuple.""" |
|
|
|
return self._left, self._width, self._count |
|
|
|
|
|
|
|
def __repr__(self): |
|
|
|
"""Returns repr(self).""" |
|
|
|
return "Bucket(left={}, width={}, count={})".format(self._left, self._width, self._count) |
|
|
|
|
|
|
|
|
|
|
|
class HistogramContainer: |
|
|
|
""" |
|
|
|
Histogram data container. |
|
|
|
@@ -35,16 +80,19 @@ class HistogramContainer: |
|
|
|
""" |
|
|
|
def __init__(self, histogram_message: Summary.Histogram): |
|
|
|
self._msg = histogram_message |
|
|
|
self._original_buckets = tuple((bucket.left, bucket.width, bucket.count) for bucket in self._msg.buckets) |
|
|
|
original_buckets = [Bucket(bucket.left, bucket.width, bucket.count) for bucket in self._msg.buckets] |
|
|
|
# Ensure buckets are sorted from min to max. |
|
|
|
original_buckets.sort(key=lambda bucket: bucket.left) |
|
|
|
self._original_buckets = tuple(original_buckets) |
|
|
|
self._count = sum(bucket.count for bucket in self._original_buckets) |
|
|
|
self._max = _mask_invalid_number(histogram_message.max) |
|
|
|
self._min = _mask_invalid_number(histogram_message.min) |
|
|
|
self._visual_max = self._max |
|
|
|
self._visual_min = self._min |
|
|
|
# default bin number |
|
|
|
self._visual_bins = 10 |
|
|
|
self._count = sum(bucket[2] for bucket in self._original_buckets) |
|
|
|
self._visual_bins = calc_histogram_bins(self._count) |
|
|
|
# Note that tuple is immutable, so sharing tuple is often safe. |
|
|
|
self._re_sampled_buckets = self._original_buckets |
|
|
|
self._re_sampled_buckets = () |
|
|
|
|
|
|
|
@property |
|
|
|
def max(self): |
|
|
|
@@ -63,7 +111,7 @@ class HistogramContainer: |
|
|
|
|
|
|
|
@property |
|
|
|
def original_msg(self): |
|
|
|
"""Get original proto message""" |
|
|
|
"""Gets original proto message.""" |
|
|
|
return self._msg |
|
|
|
|
|
|
|
def set_visual_range(self, max_val: float, min_val: float, bins: int) -> None: |
|
|
|
@@ -77,6 +125,13 @@ class HistogramContainer: |
|
|
|
min_val (float): Min value for visual histogram. |
|
|
|
bins (int): Bins number for visual histogram. |
|
|
|
""" |
|
|
|
if max_val < min_val: |
|
|
|
raise ParamValueError( |
|
|
|
"Invalid input. max_val({}) is less or equal than min_val({}).".format(max_val, min_val)) |
|
|
|
|
|
|
|
if bins < 1: |
|
|
|
raise ParamValueError("Invalid input bins({}). Must be greater than 0.".format(bins)) |
|
|
|
|
|
|
|
self._visual_max = max_val |
|
|
|
self._visual_min = min_val |
|
|
|
self._visual_bins = bins |
|
|
|
@@ -84,15 +139,104 @@ class HistogramContainer: |
|
|
|
# mark _re_sampled_buckets to empty |
|
|
|
self._re_sampled_buckets = () |
|
|
|
|
|
|
|
def _re_sample_buckets(self): |
|
|
|
# Will call re-sample logic in later PR. |
|
|
|
self._re_sampled_buckets = self._original_buckets |
|
|
|
def _calc_intersection_len(self, max1, min1, max2, min2): |
|
|
|
"""Calculates intersection length of [min1, max1] and [min2, max2].""" |
|
|
|
if max1 < min1: |
|
|
|
raise ParamValueError( |
|
|
|
"Invalid input. max1({}) is less than min1({}).".format(max1, min1)) |
|
|
|
|
|
|
|
if max2 < min2: |
|
|
|
raise ParamValueError( |
|
|
|
"Invalid input. max2({}) is less than min2({}).".format(max2, min2)) |
|
|
|
|
|
|
|
if min1 <= min2: |
|
|
|
if max1 <= min2: |
|
|
|
# return value must be calculated by max1.__sub__ |
|
|
|
return max1 - max1 |
|
|
|
if max1 <= max2: |
|
|
|
return max1 - min2 |
|
|
|
# max1 > max2 |
|
|
|
return max2 - min2 |
|
|
|
|
|
|
|
# min1 > min2 |
|
|
|
if max2 <= min1: |
|
|
|
return max2 - max2 |
|
|
|
if max2 <= max1: |
|
|
|
return max2 - min1 |
|
|
|
return max1 - min1 |
|
|
|
|
|
|
|
def buckets(self): |
|
|
|
def _re_sample_buckets(self): |
|
|
|
"""Re-samples buckets according to visual_max, visual_min and visual_bins.""" |
|
|
|
if self._visual_max == self._visual_min: |
|
|
|
# Adjust visual range if max equals min. |
|
|
|
self._visual_max += 0.5 |
|
|
|
self._visual_min -= 0.5 |
|
|
|
|
|
|
|
width = (self._visual_max - self._visual_min) / self._visual_bins |
|
|
|
|
|
|
|
if not self.count: |
|
|
|
self._re_sampled_buckets = tuple( |
|
|
|
Bucket(self._visual_min + width * i, width, 0) |
|
|
|
for i in range(self._visual_bins)) |
|
|
|
return |
|
|
|
|
|
|
|
re_sampled = [] |
|
|
|
original_pos = 0 |
|
|
|
original_bucket = self._original_buckets[original_pos] |
|
|
|
for i in range(self._visual_bins): |
|
|
|
cur_left = self._visual_min + width * i |
|
|
|
cur_right = cur_left + width |
|
|
|
cur_estimated_count = 0.0 |
|
|
|
|
|
|
|
# Skip no bucket range. |
|
|
|
if cur_right <= original_bucket.left: |
|
|
|
re_sampled.append(Bucket(cur_left, width, math.ceil(cur_estimated_count))) |
|
|
|
continue |
|
|
|
|
|
|
|
# Skip no intersect range. |
|
|
|
while cur_left >= original_bucket.right: |
|
|
|
original_pos += 1 |
|
|
|
if original_pos >= len(self._original_buckets): |
|
|
|
break |
|
|
|
original_bucket = self._original_buckets[original_pos] |
|
|
|
|
|
|
|
# entering with this condition: cur_right > original_bucket.left and cur_left < original_bucket.right |
|
|
|
while True: |
|
|
|
if original_pos >= len(self._original_buckets): |
|
|
|
break |
|
|
|
original_bucket = self._original_buckets[original_pos] |
|
|
|
|
|
|
|
intersection = self._calc_intersection_len( |
|
|
|
min1=cur_left, max1=cur_right, |
|
|
|
min2=original_bucket.left, max2=original_bucket.right) |
|
|
|
estimated_count = (intersection / original_bucket.width) * original_bucket.count |
|
|
|
|
|
|
|
cur_estimated_count += estimated_count |
|
|
|
if cur_right > original_bucket.right: |
|
|
|
# Need to sample next original bucket to this visual bucket. |
|
|
|
original_pos += 1 |
|
|
|
else: |
|
|
|
# Current visual bucket has taken all intersect buckets into account. |
|
|
|
break |
|
|
|
|
|
|
|
re_sampled.append(Bucket(cur_left, width, math.ceil(cur_estimated_count))) |
|
|
|
|
|
|
|
self._re_sampled_buckets = tuple(re_sampled) |
|
|
|
|
|
|
|
def buckets(self, convert_to_tuple=True): |
|
|
|
""" |
|
|
|
Get visual buckets instead of original buckets. |
|
|
|
|
|
|
|
Args: |
|
|
|
convert_to_tuple (bool): Whether convert bucket object to tuple. |
|
|
|
|
|
|
|
Returns: |
|
|
|
tuple, contains buckets. |
|
|
|
""" |
|
|
|
if not self._re_sampled_buckets: |
|
|
|
self._re_sample_buckets() |
|
|
|
|
|
|
|
return self._re_sampled_buckets |
|
|
|
if not convert_to_tuple: |
|
|
|
return self._re_sampled_buckets |
|
|
|
|
|
|
|
return tuple(bucket.as_tuple() for bucket in self._re_sampled_buckets) |