Browse Source

support broadcast of Mul cpu op

tags/v1.1.0
wuxuejian 5 years ago
parent
commit
fe855fe911
1 changed files with 6 additions and 2 deletions
  1. +6
    -2
      mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc

+ 6
- 2
mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc View File

@@ -32,7 +32,9 @@ void ArithmeticCPUKernel::AssignAdd(T *input1, const T *input2, T *out, size_t s
template <typename T>
void ArithmeticCPUKernel::Add(const T *input1, const T *input2, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = input1[i] + input2[i];
std::vector<size_t> idx;
GenIndex(i, &idx);
out[i] = input1[idx[0]] + input2[idx[1]];
}
}

@@ -48,7 +50,9 @@ void ArithmeticCPUKernel::Sub(const T *input1, const T *input2, T *out, size_t s
template <typename T>
void ArithmeticCPUKernel::Mul(const T *input1, const T *input2, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = input1[i] * input2[i];
std::vector<size_t> idx;
GenIndex(i, &idx);
out[i] = input1[idx[0]] * input2[idx[1]];
}
}



Loading…
Cancel
Save