GitOrigin-RevId: 686fff4f73
tags/v1.3.1
| @@ -309,12 +309,17 @@ def test_broadcast(): | |||||
| output2_shape = (20, 10, 20) | output2_shape = (20, 10, 20) | ||||
| data2 = np.random.random(input2_shape).astype(np.float32) | data2 = np.random.random(input2_shape).astype(np.float32) | ||||
| input3_shape = (10, 10) | |||||
| output3_shape = (10, 10) | |||||
| data3 = np.random.random(input3_shape).astype(np.float32) | |||||
| def compare_fn(x, y): | def compare_fn(x, y): | ||||
| assert x.shape[0] == y | assert x.shape[0] == y | ||||
| cases = [ | cases = [ | ||||
| {"input": [data1, output1_shape], "output": output1_shape}, | {"input": [data1, output1_shape], "output": output1_shape}, | ||||
| {"input": [data2, output2_shape], "output": output2_shape}, | {"input": [data2, output2_shape], "output": output2_shape}, | ||||
| {"input": [data3, output3_shape], "output": output3_shape}, | |||||
| ] | ] | ||||
| opr_test(cases, F.broadcast_to, compare_fn=compare_fn) | opr_test(cases, F.broadcast_to, compare_fn=compare_fn) | ||||
| @@ -24,14 +24,14 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | |||||
| return Broadcast::make(); | return Broadcast::make(); | ||||
| } | } | ||||
| cg::OperatorNodeBase* apply_on_var_node( | |||||
| auto apply_on_var_node( | |||||
| const OpDef& def, | const OpDef& def, | ||||
| const VarNodeArray& inputs) { | const VarNodeArray& inputs) { | ||||
| auto&& op = def.cast_final_safe<Broadcast>(); | auto&& op = def.cast_final_safe<Broadcast>(); | ||||
| size_t nr_inp = inputs.size(); | size_t nr_inp = inputs.size(); | ||||
| mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp); | mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp); | ||||
| OperatorNodeConfig config{op.make_name()}; | OperatorNodeConfig config{op.make_name()}; | ||||
| return opr::Broadcast::make(inputs[0], inputs[1], config).node()->owner_opr(); | |||||
| return opr::Broadcast::make(inputs[0], inputs[1], config); | |||||
| } | } | ||||
| bool valid_broadcast(const TensorShape& src_shape, | bool valid_broadcast(const TensorShape& src_shape, | ||||