@@ -1 +1 @@
-Subproject commit ef1a6b06781035540023819afead4bdfbd49af81
+Subproject commit 6d01f5e364224da0d58b8b20761b9af67587950b
@@ -36,9 +36,6 @@ def expand_gelu(expand_info):
         # create tensor input.
         input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
         graph_scope.set_input(input_x)
-        dtype = input_x.dtype
-        if dtype == 'float16':
-            input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float32'})
         # cal y
         mul_0 = graph_builder.emit('Mul', [input_x, input_x])
@@ -58,8 +55,6 @@ def expand_gelu(expand_info):
         mul_x = graph_builder.emit('Mul', [input_x, tanh_y_add_one])
         result = graph_builder.emit('Mul', [const_half, mul_x])
-        if dtype == 'float16':
-            result = graph_builder.emit('Cast', [result], attrs={'dst_type': 'float16'})
         # set graph output.
         graph_scope.set_output(result)
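For context, expand_gelu lowers GELU to its tanh approximation, and with this change the float16 round trip is no longer built into the expander itself. Below is a minimal NumPy sketch of the computation the emitted graph encodes; the constants 0.044715 and sqrt(2/pi) are assumptions taken from the standard approximation, since the diff only shows the CSVALUE name:

```python
import numpy as np

def gelu_tanh(x):
    """Sketch of the tanh-approximated GELU that expand_gelu emits as
    Mul/TensorAdd/Tanh nodes. Constants are assumptions from the standard
    approximation, not read from the diff."""
    csvalue = 0.044715                      # assumed value of CSVALUE
    sqrt_two_over_pi = 0.7978845608028564   # sqrt(2 / pi)
    y = sqrt_two_over_pi * (x + csvalue * x ** 3)   # the mul_0/mul_1 chain
    return 0.5 * x * (1.0 + np.tanh(y))             # const_half * mul_x

print(gelu_tanh(np.array([-1.0, 0.0, 1.0], dtype=np.float32)))
```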
@@ -43,9 +43,6 @@ def expand_gelugrad(expand_info):
         input_x = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
         input_y = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format'])
         graph_scope.set_input(input_dy, input_x, input_y)
-        dtype = input_dy.dtype
-        if dtype == 'float16':
-            input_dy = graph_builder.emit('Cast', [input_dy], attrs={'dst_type': 'float32'})
         # create some const var
         const_csvalue = graph_builder.value(input_dy.dtype, CSVALUE, input_desc_0['format'])
@@ -83,8 +80,6 @@ def expand_gelugrad(expand_info):
         result_tmp = graph_builder.emit('TensorAdd', [half_mul_tanh_res_add_one, mul_final])
         result = graph_builder.emit('Mul', [input_dy, result_tmp])
-        if dtype == 'float16':
-            result = graph_builder.emit('Cast', [result], attrs={'dst_type': 'float16'})
         # set graph output.
         graph_scope.set_output(result)
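The backward expander follows the analytic derivative of the same approximation, and the dtype bookkeeping disappears here too because float16 constants are now handled inside GraphBuilder.value. A hedged NumPy sketch of the math, with names mirroring the graph tensors in the hunk above (same assumed constants as before):

```python
import numpy as np

def gelu_grad_tanh(dy, x):
    """Sketch of the derivative expand_gelugrad builds: with
    u = sqrt(2/pi) * (x + c*x^3),
    d/dx [0.5*x*(1 + tanh(u))] = 0.5*(1 + tanh(u)) + 0.5*x*(1 - tanh(u)^2)*u'."""
    csvalue = 0.044715
    sqrt_two_over_pi = 0.7978845608028564
    u = sqrt_two_over_pi * (x + csvalue * x ** 3)
    t = np.tanh(u)
    half_mul_tanh_res_add_one = 0.5 * (1.0 + t)
    du_dx = sqrt_two_over_pi * (1.0 + 3.0 * csvalue * x ** 2)
    mul_final = 0.5 * x * (1.0 - t ** 2) * du_dx         # sech^2(u) = 1 - tanh^2(u)
    result_tmp = half_mul_tanh_res_add_one + mul_final   # TensorAdd
    return dy * result_tmp                               # Mul with incoming gradient

x = np.array([-1.0, 0.0, 1.0], dtype=np.float32)
print(gelu_grad_tanh(np.ones_like(x), x))
```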
@@ -149,7 +149,15 @@ class GraphBuilder:
         """Create a new Value"""
         if name in (None, ''):
             name = self._alloc_tensor_name()
-        return Value(name, dtype, value, data_format)
+        if dtype == "float16":
+            # A float16 Value is wrongly promoted to float32 downstream, and there is
+            # no good fix for that yet, so declare the value as float32 first and then
+            # cast it back to float16.
+            v_fp32 = Value(name, "float32", value, data_format)
+            v = self.emit("Cast", [v_fp32], attrs={"dst_type": "float16"})
+        else:
+            v = Value(name, dtype, value, data_format)
+        return v

     def op(self, prim, output, inputs, attrs=None):
         """Insert an operator into graph"""
@@ -166,9 +174,9 @@ class GraphBuilder:
         """Emit a new operation"""
         if attrs is None:
             attrs = {}
-        if isinstance(inputs, Tensor):
+        if isinstance(inputs, (Tensor, Value)):
             inputs = [inputs]
-        tensor_inputs = [t for t in inputs if isinstance(t, Tensor)]
+        tensor_inputs = [t for t in inputs if isinstance(t, (Tensor, Value))]
         out_shape, out_dtype, out_format = OpInfer.infer(prim, tensor_inputs, attrs)
         output = self.tensor(out_shape, out_dtype, out_format, name)
         self.op(prim, output, inputs, attrs)
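The emit() change is what makes the value() workaround possible: value() now passes a Value rather than a Tensor into emit('Cast', ...), so both isinstance checks must accept Value or the Cast would silently lose its input before inference. A plain-Python illustration of that failure mode, using stub classes rather than the MindSpore ones:

```python
class Tensor:
    pass

class Value:
    pass

inputs = [Value()]  # what value() now feeds into emit()

# Old filter: the Value is dropped, so OpInfer would see no inputs at all.
old_tensor_inputs = [t for t in inputs if isinstance(t, Tensor)]
# New filter: the Value survives and type/shape inference can run on it.
new_tensor_inputs = [t for t in inputs if isinstance(t, (Tensor, Value))]

print(len(old_tensor_inputs), len(new_tensor_inputs))  # 0 1
```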