# -*- coding: UTF-8 -*-
"""
 Copyright 2021 Tianshu AI Platform. All Rights Reserved.

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 =============================================================
"""
import struct
import numpy as np
import json
from oneflow.customized.utils.plugin_hparams_pb2 import HParamsPluginData
from oneflow.customized.utils.graph_pb2 import GraphDef

# Maps the numeric ``sample.type`` code of a projector sample to its name.
_SAMPLE_TYPES = {1: 'audio', 2: 'text', 3: 'image'}


def get_parser(value, step, wall_time):
    """
    Decode one value of an event into a plain dict.

    The decoder is picked from whichever field is set on ``value``
    (checked in this order: simple_value, image, audio, histo, projector,
    metadata). For ``metadata`` the plugin name selects the decoder.

    :param value: A value field of an event
    :param step: Stored unchanged under the ``step`` key of the result
    :param wall_time: Stored unchanged under the ``wall_time`` key of the result
    :return: dict = {tag, step, wall_time, value, type}; embedding results
             additionally carry a ``label`` key
    :raises Exception: if none of the known fields is set, or the metadata
                       plugin name is not one of the handled plugins
    """
    if value.HasField('simple_value'):
        parsed = _get_scalar(value)
    elif value.HasField('image'):
        parsed = _get_image(value)
    elif value.HasField('audio'):
        parsed = _get_audio(value)
    elif value.HasField('histo'):
        parsed = _get_hist(value)
    elif value.HasField('projector'):
        parsed = _get_projector(value)
    elif value.HasField('metadata'):
        parsed = _get_plugin(value)
    else:
        raise Exception(f'cannot parse this data: {value}')

    # step / wall_time come first so the key order matches earlier releases.
    data = dict(step=step, wall_time=wall_time)
    data.update(parsed)
    return data


def _get_plugin(value):
    """
    Decode a value whose ``metadata`` field is set, based on its plugin name.

    :param value: A value field of an event
    :return: Decoded dict produced by the plugin-specific decoder
    :raises Exception: if the plugin name is not handled
    """
    plugin_name = value.metadata.plugin_data.plugin_name
    if plugin_name == 'hparams':
        return _get_hparams(value)
    if plugin_name == 'text':
        return _get_text(value)
    if plugin_name == 'featuremap':
        return _get_featuremap(value)
    if plugin_name == 'transformer':
        # Text attention data shares the plugin name and is told apart by tag.
        if 'transformertext' in value.tag:
            return _get_TransformerText(value)
        return _get_transformer(value)
    if plugin_name == 'hiddenstate':
        return _get_state(value)
    raise Exception(f'cannot parse {plugin_name} data.')


def _decode_byte(tensor):
    """
    Unpack a single float from ``tensor.tensor_content``.

    :param tensor: Object exposing ``dtype`` and ``tensor_content``
    :return: The unpacked float when ``tensor.dtype == 1``, otherwise None
    """
    # dtype code 1 is treated as float (per the original author's note).
    if tensor.dtype == 1:
        return struct.unpack('f', tensor.tensor_content)[0]
    return None


def _decoder_tensor(tensor):
    """
    Rebuild an ndarray from a tensor whose content is a byte stream.

    :param tensor: Object exposing ``tensor_shape.dim`` (items with ``size``),
                   ``tensor_content`` and ``dtype``
    :return: ndarray read from ``tensor_content`` and reshaped to the dim sizes
    """
    # NOTE(review): ``tensor.dtype`` is handed to numpy as-is, so it is assumed
    # to be something numpy accepts as a dtype -- confirm against the proto.
    tensor_shape = tuple(dim.size for dim in tensor.tensor_shape.dim)
    tensor_content = np.frombuffer(tensor.tensor_content, dtype=tensor.dtype)
    return tensor_content.reshape(tensor_shape)


def get_graph(event):
    """
    Decode a graph event

    :param event: An event exposing ``graph_def`` and ``wall_time``
    :return: dict = {wall_time, value (parsed GraphDef), type='graph'}
    """
    graph = GraphDef()
    graph.ParseFromString(event.graph_def)
    return dict(wall_time=event.wall_time, value=graph, type='graph')


def _get_scalar(value):
    """
    Decode a scalar event

    :param value: A value field of an event
    :return: Decoded scalar
    """
    return dict(tag=value.tag, value=value.simple_value, type='scalar')


def _get_image(value):
    """
    Decode an image event

    :param value: A value field of an event
    :return: Decoded image
    """
    image = value.image
    dic = {
        'width': image.width,
        'height': image.height,
        'colorspace': image.colorspace,
        'encoded_image_string': image.encoded_image_string,
    }
    return dict(tag=value.tag, value=dic, type='image')


def _get_text(value):
    """
    Return text data

    :param value: A value field of an event
    :return: text data, with every entry of ``tensor.string_val`` decoded
    """
    texts = np.array([v.decode() for v in value.tensor.string_val])
    return dict(tag=value.tag, value=texts, type='text')


def _get_audio(value):
    """
    Decode an audio event

    :param value: A value field of an event
    :return: Decoded audio
    """
    audio = value.audio
    dic = {
        'sample_rate': audio.sample_rate,
        'num_channels': audio.num_channels,
        'length_frames': audio.length_frames,
        'encoded_audio_string': audio.encoded_audio_string,
    }
    return dict(tag=value.tag, value=dic, type='audio')


def _get_hist(value):
    """
    Decode a histogram event

    :param value: A value field of an event
    :return: Decoded histogram; the bucket fields are converted to ndarrays
    """
    histo = value.histo
    dic = {
        'min': histo.min,
        'max': histo.max,
        'num': histo.num,
        'sum': histo.sum,
        'sum_squares': histo.sum_squares,
        'bucket_limit': np.array(histo.bucket_limit),
        'bucket': np.array(histo.bucket),
    }
    return dict(tag=value.tag, value=dic, type='hist')


def _get_hparams(value):
    """
    Decode an hparams event

    :param value: A value field of an event
    :return: dict whose value is the HParamsPluginData parsed from
             ``metadata.plugin_data.content``
    """
    plugin_data = HParamsPluginData()
    plugin_data.ParseFromString(value.metadata.plugin_data.content)
    return dict(tag=value.tag, value=plugin_data, type='hparams')


def _get_embedding(value):
    """
    Decode a projector embedding, or the sample attached to it.

    :param value: A value field of an event with ``projector.embedding`` set
    :return: For a sample: dict with tag prefixed by ``sample_`` and a value
             of {type, X}. Otherwise: dict with the embedding value and its
             label (an empty array when no label is set).
    :raises KeyError: if the sample type code is not 1, 2 or 3
    """
    embedding = value.projector.embedding
    if embedding.HasField('sample'):
        sample = embedding.sample
        data = dict(type=_SAMPLE_TYPES[sample.type],
                    X=_decoder_tensor(sample.X))
        return dict(tag='sample_' + value.tag, value=data, type='embedding')

    if embedding.HasField('label'):
        label = _decoder_tensor(embedding.label)
    else:
        label = np.array([])
    return dict(tag=value.tag,
                value=_decoder_tensor(embedding.value),
                label=label,
                type='embedding')


def _get_exception(value):
    """
    Decode a projector exception

    :param value: A value field of an event
    :return: dict whose value is the decoded ``projector.exception.value``
    """
    return dict(tag=value.tag,
                value=_decoder_tensor(value.projector.exception.value),
                type='exception')


def _get_projector(value):
    """
    Decode a projector event

    :param value: A value field of an event
    :return: Decoded embedding when ``projector.embedding`` is set,
             otherwise the decoded exception
    """
    if value.projector.HasField('embedding'):
        return _get_embedding(value)
    return _get_exception(value)


def filter_graph(file):
    """
    Strip ``variable`` layers from a serialized graph.

    Layers whose ``class_name`` is ``"variable"`` are dropped from every
    sub graph, and their names are then removed from the ``inbound_nodes``
    of all remaining layers.

    :param file: JSON text holding a list of sub graphs, each with
                 ``config.layers``; ``inbound_nodes`` is assumed to be a
                 list of layer names
    :return: The filtered graph, serialized back to a JSON string
    """
    graph = json.loads(file)

    # First pass: drop variable layers and remember their names. Names are
    # gathered over ALL sub graphs before any inbound node is filtered.
    variable_names = set()
    for sub_graph in graph:
        cfg = sub_graph["config"]
        kept_layers = []
        for layer in cfg["layers"]:
            if layer["class_name"] == "variable":
                variable_names.add(layer["name"])
            else:
                kept_layers.append(layer)
        cfg["layers"] = kept_layers

    # Second pass: remove every reference to a dropped variable from the
    # inbound nodes of the layers that are left.
    for sub_graph in graph:
        for layer in sub_graph["config"]["layers"]:
            in_nodes = layer["inbound_nodes"]
            in_nodes[:] = [node for node in in_nodes
                           if node not in variable_names]

    return json.dumps(graph)


def _tensor_event(value, kind):
    """
    Build the result dict for events that only carry ``value.tensor``.

    :param value: A value field of an event
    :param kind: String stored under the ``type`` key
    :return: dict = {tag, value (ndarray copy of the decoded tensor), type}
    """
    return dict(tag=value.tag,
                value=np.array(_decoder_tensor(value.tensor)),
                type=kind)


def _get_featuremap(value):
    """
    Decode a featuremap event

    :param value: A value field of an event
    :return: Decoded featuremap
    """
    return _tensor_event(value, 'featuremap')


def _get_transformer(value):
    """
    Decode a transformer event

    :param value: A value field of an event
    :return: Decoded transformer data
    """
    return _tensor_event(value, 'transformer')


def _get_TransformerText(value):
    """
    Decode a transformer text event

    :param value: A value field of an event
    :return: For tags containing ``transformertext-sentence`` the decoded
             tensor; otherwise the attention data from ``value.transformer``.
             Both are reported with type ``transformer``.
    """
    if "transformertext-sentence" in value.tag:
        return _tensor_event(value, 'transformer')
    return dict(tag=value.tag,
                value=_decoder_TransformerText(value.transformer),
                type='transformer')


def _get_state(value):
    """
    Decode a hiddenstate event

    :param value: A value field of an event
    :return: Decoded hidden state
    """
    return _tensor_event(value, 'hiddenstate')


def _decoder_TransformerText(value):
    """
    Flatten transformer text attention data into a dict.

    :param value: Object exposing ``attentionItem`` entries plus the
                  ``bidirectional``, ``default_filter``, ``displayMode``,
                  ``head`` and ``layer`` attributes
    :return: dict mapping each attention item tag to {attn, left_text,
             right_text}, followed by the five display attributes
    """
    data = {}
    for item in value.attentionItem:
        data[item.tag] = {
            'attn': _decoder_tensor(item.attn),
            'left_text': _decoder_tensor(item.left),
            'right_text': _decoder_tensor(item.right),
        }
    # Written after the items, so these keys win over an item tag that
    # happens to have the same name.
    data["bidirectional"] = value.bidirectional
    data["default_filter"] = value.default_filter
    data["displayMode"] = value.displayMode
    data["head"] = value.head
    data["layer"] = value.layer
    return data