# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Backpropagators (bprops) for primitives."""
from .. import functional as F
from ..composite import multitype_ops as C
from .grad_base import bprops

# Unused parameters are placeholders: every bprop takes the primitive's inputs
# followed by `out` (the forward result) and `dout` (the incoming gradient),
# whether or not it actually reads them.


@bprops.register("scalar_add")
def bprop_scalar_add(x, y, out, dout):
    """Backpropagator for primitive `scalar_add`.

    Addition forwards the incoming gradient unchanged to both operands.
    """
    return dout, dout


@bprops.register("scalar_mul")
def bprop_scalar_mul(x, y, out, dout):
    """Backpropagator for primitive `scalar_mul`.

    For out = x * y: d/dx = y and d/dy = x.
    """
    grad_x = dout * y
    grad_y = dout * x
    return grad_x, grad_y


@bprops.register("scalar_sub")
def bprop_scalar_sub(x, y, out, dout):
    """Backpropagator for primitive `scalar_sub`.

    For out = x - y: d/dx = 1 and d/dy = -1.
    """
    grad_y = -dout
    return dout, grad_y


@bprops.register("scalar_div")
def bprop_scalar_div(x, y, out, dout):
    """Backpropagator for primitive `scalar_div`.

    For out = x / y: d/dx = 1 / y and d/dy = -x / y**2, written as -out / y
    so the forward result is reused.
    """
    grad_x = dout / y
    grad_y = (-dout) * (out / y)
    return grad_x, grad_y


@bprops.register("scalar_pow")
def bprop_scalar_pow(x, y, out, dout):
    """Backpropagator for primitive `scalar_pow`.

    For out = x ** y: d/dx = y * x ** (y - 1) and d/dy = log(x) * out.
    """
    grad_x = dout * (y * (x ** (y - 1)))
    grad_y = dout * (F.scalar_log(x) * out)
    return grad_x, grad_y


@bprops.register("scalar_exp")
def bprop_scalar_exp(x, out, dout):
    """Backpropagator for primitive `scalar_exp`.

    exp is its own derivative, so the forward result `out` is the local slope.
    """
    grad_x = dout * out
    return (grad_x,)


@bprops.register("scalar_uadd")
def bprop_scalar_uadd(x, out, dout):
    """Backpropagator for primitive `scalar_uadd` (identity gradient)."""
    return (dout,)


@bprops.register("scalar_usub")
def bprop_scalar_usub(x, out, dout):
    """Backpropagator for primitive `scalar_usub` (negated gradient)."""
    grad_x = -dout
    return (grad_x,)
@bprops.register("scalar_gt")
def bprop_scalar_gt(x, y, out, dout):
    """Backpropagator for primitive `scalar_gt`.

    Comparisons are not differentiable; both inputs receive zero gradients.
    """
    return C.zeros_like(x), C.zeros_like(y)


@bprops.register("scalar_lt")
def bprop_scalar_lt(x, y, out, dout):
    """Backpropagator for primitive `scalar_lt`.

    Comparisons are not differentiable; both inputs receive zero gradients.
    """
    return C.zeros_like(x), C.zeros_like(y)


@bprops.register("scalar_ge")
def bprop_scalar_ge(x, y, out, dout):
    """Backpropagator for primitive `scalar_ge`.

    Comparisons are not differentiable; both inputs receive zero gradients.
    """
    return C.zeros_like(x), C.zeros_like(y)


@bprops.register("scalar_le")
def bprop_scalar_le(x, y, out, dout):
    """Backpropagator for primitive `scalar_le`.

    Comparisons are not differentiable; both inputs receive zero gradients.
    """
    return C.zeros_like(x), C.zeros_like(y)


@bprops.register("scalar_eq")
def bprop_scalar_eq(x, y, out, dout):
    """Backpropagator for primitive `scalar_eq`.

    Comparisons are not differentiable; both inputs receive zero gradients.
    """
    return C.zeros_like(x), C.zeros_like(y)


@bprops.register("scalar_ne")
def bprop_scalar_ne(x, y, out, dout):
    """Backpropagator for primitive `scalar_ne`.

    Comparisons are not differentiable; both inputs receive zero gradients.
    """
    return C.zeros_like(x), C.zeros_like(y)


@bprops.register("scalar_cast")
def bprop_scalar_cast(x, t, out, dout):
    """Backpropagator for primitive `scalar_cast`.

    The incoming gradient is cast back to the type of the original input `x`.
    """
    # NOTE(review): the second result is `t` itself rather than
    # `C.zeros_like(t)` as other non-differentiable inputs in this file use;
    # presumably intentional because `t` is a type object -- confirm.
    return F.scalar_cast(dout, F.typeof(x)), t


@bprops.register("tuple_getitem")
def bprop_tuple_getitem(data, idx, out, dout):
    """Backpropagator for primitive `tuple_getitem`.

    Scatters `dout` into slot `idx` of an all-zero tuple shaped like `data`;
    the index itself gets a zero gradient.
    """
    return F.tuple_setitem(C.zeros_like(data), idx, dout), C.zeros_like(idx)


@bprops.register("list_getitem")
def bprop_list_getitem(data, idx, out, dout):
    """Backpropagator for primitive `list_getitem`.

    Scatters `dout` into slot `idx` of an all-zero list shaped like `data`;
    the index itself gets a zero gradient.
    """
    return F.list_setitem(C.zeros_like(data), idx, dout), C.zeros_like(idx)


@bprops.register("identity")
def bprop_identity(x, out, dout):
    """Backpropagator for primitive `identity` (gradient passes through)."""
    return (dout,)


@bprops.register("make_ref")
def bprop_make_ref(key, x, y, out, dout):
    """Backpropagator for primitive `make_ref`.

    Only the value input `x` receives `dout`; `key` and `y` get zeros.
    """
    return (C.zeros_like(key), dout, C.zeros_like(y))


@bprops.register("get_ref_value")
def bprop_get_ref_value(x, out, dout):
    """Backpropagator for primitive `get_ref_value` (gradient passes through)."""
    return (dout,)


@bprops.register("get_ref_key")
def bprop_get_ref_key(x, out, dout):
    """Backpropagator for primitive `get_ref_key` (zero gradient)."""
    return (C.zeros_like(x),)


@bprops.register("scalar_to_array")
def bprop_scalar_to_array(x, out, dout):
    """Backpropagator for primitive `scalar_to_array`.

    Converts the array gradient back to a scalar (inverse of the forward op).
    """
    return (F.array_to_scalar(dout),)


@bprops.register("array_to_scalar")
def bprop_array_to_scalar(x, out, dout):
    """Backpropagator for primitive `array_to_scalar`.

    Converts the scalar gradient back to an array (inverse of the forward op).
    """
    return (F.scalar_to_array(dout),)


@bprops.register("dot")
def bprop_dot(x, y, out, dout):
    """Backpropagator for primitive `dot`.

    For out = x @ y (2-D operands, as implied by the (1, 0) transposes):
    dx = dout @ y^T and dy = x^T @ dout.
    """
    return F.dot(dout, F.transpose(y, (1, 0))), F.dot(F.transpose(x, (1, 0)), dout)


@bprops.register("reshape")
def bprop_reshape(xs, shp, out, dout):
    """Backpropagator for primitive `reshape`.

    Reshapes `dout` back to the shape of the input; the target shape gets zeros.
    """
    return F.reshape(dout, F.shape(xs)), C.zeros_like(shp)


@bprops.register("distribute")
def bprop_distribute(arr, shp, out, dout):
    """Backpropagator for primitive `distribute`.

    Sum-reduces `dout` (via `scalar_add`) back to the shape of `arr`.
    """
    return F.array_reduce(F.scalar_add, dout, F.shape(arr)), C.zeros_like(shp)


@bprops.register("shape")
def bprop_shape(arr, out, dout):
    """Backpropagator for primitive `shape` (zero gradient)."""
    return (C.zeros_like(arr),)


@bprops.register("broadcast_shape")
def bprop_broadcast_shape(shp1, shp2, out, dout):
    """Backpropagator for primitive `broadcast_shape` (zero gradients)."""
    return C.zeros_like(shp1), C.zeros_like(shp2)


@bprops.register("J")
def bprop_j(x, out, dout):
    """Backpropagator for primitive `J`; applies `F.jinv` to the gradient."""
    return (F.jinv(dout),)


@bprops.register("array_reduce")
def bprop_array_reduce(fn, x, shp, out, dout):
    """Backpropagator for primitive `array_reduce`.

    Redistributes `dout` to the shape of the reduced input `x`.
    """
    # NOTE(review): three inputs (fn, x, shp) but only two gradients are
    # returned (none for `fn`); confirm the caller expects this arity.
    return F.distribute(dout, F.shape(x)), C.zeros_like(shp)


@bprops.register("Depend")
def bprop_depend(x, y, out, dout):
    """Backpropagator for primitive `Depend`.

    `x` receives the gradient; the dependency `y` gets zeros.
    """
    return dout, C.zeros_like(y)


@bprops.register("embed")
def bprop_embed(x, out, dout):
    """Backpropagator for primitive `embed` (zero gradient)."""
    return (C.zeros_like(x),)


@bprops.register("bool_not")
def bprop_bool_not(x, out, dout):
    """Backpropagator for primitive `bool_not` (zero gradient)."""
    return (C.zeros_like(x),)
@bprops.register("bool_or")
def bprop_bool_or(x, y, out, dout):
    """Backpropagator for primitive `bool_or` (zero gradients)."""
    return C.zeros_like(x), C.zeros_like(y)


@bprops.register("stop_gradient")
def bprop_stop_gradient(x, out, dout):
    """Backpropagator for primitive `stop_gradient`.

    Blocks gradient flow by returning zeros instead of `dout`.
    """
    return (C.zeros_like(x),)


@bprops.register("bool_and")
def bprop_bool_and(x, y, out, dout):
    """Backpropagator for primitive `bool_and` (zero gradients)."""
    return C.zeros_like(x), C.zeros_like(y)


@bprops.register("ControlDepend")
def bprop_control_depend(x, y, out, dout):
    """Backpropagator for primitive `ControlDepend` (zero gradients)."""
    return C.zeros_like(x), C.zeros_like(y)


@bprops.register("switch")
def bprop_switch(cond, tb, fb, out, dout):
    """Backpropagator for primitive `switch`.

    The branch selected by `cond` receives `dout`; the other branch gets
    zeros shaped like itself, and `cond` gets a zero gradient.
    """
    return (C.zeros_like(cond),
            F.switch(cond, dout, C.zeros_like(tb)),
            F.switch(cond, C.zeros_like(fb), dout))


def _fprop_switch_layer(index, layers):
    """Forward propagator for primitive `switch_layer`.

    Returns the forward result `F.switch_layer(index, layers)` paired with a
    closure that maps the incoming gradient `dout` to the gradient tuple.
    """
    # NOTE(review): unlike the functions above, this is not registered with
    # `bprops` here; presumably it is looked up by name elsewhere -- confirm.
    def _bprop_switch_layer(dout):
        # NOTE(review): slot order is (dout, zeros for `index`, () for
        # `layers`); what the leading `dout` slot maps to is not visible
        # here -- confirm against the caller.
        return dout, C.zeros_like(index), ()
    return F.switch_layer(index, layers), _bprop_switch_layer