Browse Source

!15161 support hbm performance test

From: @laiyongqiang
Reviewed-by: @jjfeing,@majorzhang
Signed-off-by: @majorzhang
pull/15161/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
0fe493e205
8 changed files with 104 additions and 87 deletions
  1. +4
    -4
      mindspore/ccsrc/backend/optimizer/mem_reuse/mem_dynamic_allocator.h
  2. +45
    -10
      mindspore/ccsrc/backend/optimizer/somas/somas.cc
  3. +37
    -3
      mindspore/ccsrc/runtime/device/ascend/ascend_memory_pool.cc
  4. +2
    -2
      mindspore/ccsrc/runtime/device/ascend/ascend_memory_pool.h
  5. +2
    -1
      mindspore/ccsrc/runtime/device/kernel_runtime.cc
  6. +7
    -64
      mindspore/ccsrc/runtime/device/memory_manager.cc
  7. +1
    -1
      mindspore/ccsrc/runtime/device/memory_manager.h
  8. +6
    -2
      tests/st/ops/ascend/test_tbe_ops/test_relu.py

+ 4
- 4
mindspore/ccsrc/backend/optimizer/mem_reuse/mem_dynamic_allocator.h View File

@@ -109,16 +109,16 @@ class DynamicMemPoolBestFit {
protected: protected:
// The real size by memory alloc aligned. // The real size by memory alloc aligned.
virtual size_t AlignMemorySize(size_t size) const; virtual size_t AlignMemorySize(size_t size) const;
// Get the minimum memory unit size using for dynamic extend.
virtual size_t mem_alloc_unit_size() const { return DYNAMIC_MEM_ALLOC_UNIT_SIZE; }
// Calculate memory block required alloc size when adding the memory block.
virtual size_t CalMemBlockAllocSize(size_t size);
private: private:
// Get the minimum memory unit size using for dynamic extend.
size_t mem_alloc_unit_size() const { return DYNAMIC_MEM_ALLOC_UNIT_SIZE; }
// Find the idle memory buf by aligned size when memory alloc. // Find the idle memory buf by aligned size when memory alloc.
DeviceMemPtr FindIdleMemBuf(size_t size); DeviceMemPtr FindIdleMemBuf(size_t size);
// Add the memory block and memory buf when memory alloc not find the idle memory buf. // Add the memory block and memory buf when memory alloc not find the idle memory buf.
DeviceMemPtr AddMemBlockAndMemBuf(size_t size); DeviceMemPtr AddMemBlockAndMemBuf(size_t size);
// Calculate memory block required alloc size when adding the memory block.
size_t CalMemBlockAllocSize(size_t size);
// Judge whether need divide the memory buf by alloc size and memory buf size. // Judge whether need divide the memory buf by alloc size and memory buf size.
bool IsDivide(size_t tensor_size, size_t mem_buf_size) const; bool IsDivide(size_t tensor_size, size_t mem_buf_size) const;
// Divide the memory buf by alloc size. // Divide the memory buf by alloc size.


+ 45
- 10
mindspore/ccsrc/backend/optimizer/somas/somas.cc View File

@@ -447,6 +447,7 @@ void Somas::InitSomasOutputAndWorkspaceTensors(const session::KernelGraph *graph
auto kernel_mod = AnfAlgo::GetKernelMod(kernel); auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
MS_EXCEPTION_IF_NULL(kernel_mod); MS_EXCEPTION_IF_NULL(kernel_mod);
auto output_sizes = kernel_mod->GetOutputSizeList(); auto output_sizes = kernel_mod->GetOutputSizeList();
auto index = 0;
for (const auto &size : output_sizes) { for (const auto &size : output_sizes) {
auto output_tensor_index = tensor_index; auto output_tensor_index = tensor_index;
tensor_index++; tensor_index++;
@@ -455,15 +456,21 @@ void Somas::InitSomasOutputAndWorkspaceTensors(const session::KernelGraph *graph
tensor->lifetime_.start_ = node->GetId(); tensor->lifetime_.start_ = node->GetId();
tensor->lifetime_.end_ = node->GetId(); tensor->lifetime_.end_ = node->GetId();
tensor->type_ = kOutputOnly; tensor->type_ = kOutputOnly;
if (AnfAlgo::OutputAddrExist(kernel, index)) {
tensor->aligned_size_ = 0;
}

tensors_list_.push_back(tensor); tensors_list_.push_back(tensor);
tensors_map_[output_tensor_index] = tensor; tensors_map_[output_tensor_index] = tensor;
stream->tensors_.push_back(tensor); stream->tensors_.push_back(tensor);
node->tensors_.insert(tensor); node->tensors_.insert(tensor);
node->output_tensors_.push_back(tensor); node->output_tensors_.push_back(tensor);
index++;
} }


// WorkSpace Tensor // WorkSpace Tensor
auto workspace_sizes = kernel_mod->GetWorkspaceSizeList(); auto workspace_sizes = kernel_mod->GetWorkspaceSizeList();
index = 0;
for (const auto &size : workspace_sizes) { for (const auto &size : workspace_sizes) {
auto workspace_tensor_index = tensor_index; auto workspace_tensor_index = tensor_index;
tensor_index++; tensor_index++;
@@ -471,11 +478,15 @@ void Somas::InitSomasOutputAndWorkspaceTensors(const session::KernelGraph *graph
tensor->type_ = kWorkspace; tensor->type_ = kWorkspace;
tensor->lifetime_.start_ = node->GetId(); tensor->lifetime_.start_ = node->GetId();
tensor->lifetime_.end_ = node->GetId(); tensor->lifetime_.end_ = node->GetId();
if (AnfAlgo::WorkspaceAddrExist(kernel, index)) {
tensor->aligned_size_ = 0;
}
tensors_list_.push_back(tensor); tensors_list_.push_back(tensor);
tensors_map_[workspace_tensor_index] = tensor; tensors_map_[workspace_tensor_index] = tensor;
stream->tensors_.push_back(tensor); stream->tensors_.push_back(tensor);
node->tensors_.insert(tensor); node->tensors_.insert(tensor);
node->workspace_tensors_.push_back(tensor); node->workspace_tensors_.push_back(tensor);
index++;
} }
} }
} }
@@ -874,8 +885,12 @@ void Somas::GenContiguousList(const session::KernelGraph *graph) {


// Contiguous input // Contiguous input
if ((!node->input_tensors_.empty()) && (!node->input_tensors_[0]->contiguous_)) { if ((!node->input_tensors_.empty()) && (!node->input_tensors_[0]->contiguous_)) {
node->input_tensors_[0]->aligned_size_ += kGapSize;
node->input_tensors_[node->input_tensors_.size() - 1]->aligned_size_ += kGapSize;
if (node->input_tensors_[0]->aligned_size_) {
node->input_tensors_[0]->aligned_size_ += kGapSize;
}
if (node->input_tensors_[node->input_tensors_.size() - 1]->aligned_size_) {
node->input_tensors_[node->input_tensors_.size() - 1]->aligned_size_ += kGapSize;
}
std::vector<size_t> inputs; std::vector<size_t> inputs;
for (const auto &input_tensor : node->input_tensors_) { for (const auto &input_tensor : node->input_tensors_) {
comm_input_total_size_ += input_tensor->aligned_size_; comm_input_total_size_ += input_tensor->aligned_size_;
@@ -887,8 +902,12 @@ void Somas::GenContiguousList(const session::KernelGraph *graph) {


// Contiguous output // Contiguous output
if ((!node->output_tensors_.empty()) && (!node->output_tensors_[0]->contiguous_)) { if ((!node->output_tensors_.empty()) && (!node->output_tensors_[0]->contiguous_)) {
node->output_tensors_[0]->aligned_size_ += kGapSize;
node->output_tensors_[node->output_tensors_.size() - 1]->aligned_size_ += kGapSize;
if (node->output_tensors_[0]->aligned_size_) {
node->output_tensors_[0]->aligned_size_ += kGapSize;
}
if (node->output_tensors_[node->output_tensors_.size() - 1]->aligned_size_) {
node->output_tensors_[node->output_tensors_.size() - 1]->aligned_size_ += kGapSize;
}
std::vector<size_t> outputs; std::vector<size_t> outputs;
for (const auto &output_tensor : node->output_tensors_) { for (const auto &output_tensor : node->output_tensors_) {
comm_output_total_size_ += output_tensor->aligned_size_; comm_output_total_size_ += output_tensor->aligned_size_;
@@ -1097,17 +1116,33 @@ bool Somas::Assign(const session::KernelGraph *graph) {
// Ref Node Preprocessing // Ref Node Preprocessing
UpdateRefTensorsConflict(); UpdateRefTensorsConflict();
std::map<size_t, size_t> contiguous_list_with_ref_index_map = GetContiguousListContainRefTensor(); std::map<size_t, size_t> contiguous_list_with_ref_index_map = GetContiguousListContainRefTensor();
vector<vector<size_t>> contiguous_tensors_list_removed_ref = contiguous_tensors_list_;
vector<vector<size_t>> contiguous_tensors_list_removed = contiguous_tensors_list_;
std::set<vector<size_t>> contiguous_tensors_list_to_remove; std::set<vector<size_t>> contiguous_tensors_list_to_remove;
for (auto ref_list_pair : contiguous_list_with_ref_index_map) { for (auto ref_list_pair : contiguous_list_with_ref_index_map) {
contiguous_tensors_list_to_remove.insert(contiguous_tensors_list_[ref_list_pair.second]); contiguous_tensors_list_to_remove.insert(contiguous_tensors_list_[ref_list_pair.second]);
} }


// Remove each contiguous list whose tensors all have an aligned size of 0
for (auto contiguous_list : contiguous_tensors_list_) {
bool all_outputs = true;
for (auto tensor_id : contiguous_list) {
auto tensor = tensors_list_[tensor_id];
if (tensor->aligned_size_ != 0) {
all_outputs = false;
break;
}
}

if (all_outputs) {
contiguous_tensors_list_to_remove.insert(contiguous_list);
}
}

for (auto contiguous_list : contiguous_tensors_list_to_remove) { for (auto contiguous_list : contiguous_tensors_list_to_remove) {
auto iterator = std::find(contiguous_tensors_list_removed_ref.begin(), contiguous_tensors_list_removed_ref.end(),
contiguous_list);
if (iterator != contiguous_tensors_list_removed_ref.end()) {
contiguous_tensors_list_removed_ref.erase(iterator);
auto iterator =
std::find(contiguous_tensors_list_removed.begin(), contiguous_tensors_list_removed.end(), contiguous_list);
if (iterator != contiguous_tensors_list_removed.end()) {
contiguous_tensors_list_removed.erase(iterator);
} else { } else {
MS_LOG(WARNING) << "Could not find contiguous list to remove for ref"; MS_LOG(WARNING) << "Could not find contiguous list to remove for ref";
} }
@@ -1142,7 +1177,7 @@ bool Somas::Assign(const session::KernelGraph *graph) {


somas_solver_ = std::make_shared<SomasSolverPre>(); somas_solver_ = std::make_shared<SomasSolverPre>();
auto status = auto status =
somas_solver_->Solving(graph, &solver_tensor_desc_map_, &reuse_matrix_, contiguous_tensors_list_removed_ref, false);
somas_solver_->Solving(graph, &solver_tensor_desc_map_, &reuse_matrix_, contiguous_tensors_list_removed, false);
MS_LOG(INFO) << "End Solving"; MS_LOG(INFO) << "End Solving";
if (status != SUCCESS) { if (status != SUCCESS) {
GenGraphStatisticInfo(); GenGraphStatisticInfo();


+ 37
- 3
mindspore/ccsrc/runtime/device/ascend/ascend_memory_pool.cc View File

@@ -14,6 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */


#include <algorithm>
#include "runtime/device/ascend/ascend_memory_pool.h" #include "runtime/device/ascend/ascend_memory_pool.h"
#include "runtime/device/ascend/ascend_kernel_runtime.h" #include "runtime/device/ascend/ascend_kernel_runtime.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
@@ -21,6 +22,9 @@
namespace mindspore { namespace mindspore {
namespace device { namespace device {
namespace ascend { namespace ascend {
// The minimum unit size (256MB) of memory block used for dynamic extend.
static const size_t ASCEND_DYNAMIC_MEM_ALLOC_UNIT_SIZE = 256 << 20;

void AscendMemoryPool::Init(uint8_t *device_mem_base, uint64_t device_mem_size, uint64_t dynamic_mem_offset) { void AscendMemoryPool::Init(uint8_t *device_mem_base, uint64_t device_mem_size, uint64_t dynamic_mem_offset) {
static bool initialized = false; static bool initialized = false;
if (initialized) { if (initialized) {
@@ -40,11 +44,43 @@ void AscendMemoryPool::Init(uint8_t *device_mem_base, uint64_t device_mem_size,
initialized = true; initialized = true;
} }


// Calculate how large a memory block to allocate for a request of `size`.
// The block size starts at ASCEND_DYNAMIC_MEM_ALLOC_UNIT_SIZE and grows until it is
// at least `size` (by doubling in PyNative mode, by one unit at a time otherwise),
// and is then capped at the value reported by free_mem_size().
// Raises via MS_LOG(EXCEPTION) when free_mem_size() is smaller than `size`.
size_t AscendMemoryPool::CalMemBlockAllocSize(size_t size) {
auto device_free_mem_size = free_mem_size();
// The request cannot be satisfied at all: report the sizes and offsets involved.
if (device_free_mem_size < size) {
MS_LOG(EXCEPTION) << "Memory not enough: current free memory size[" << device_free_mem_size
<< "] is smaller than required size[" << size << "], dynamic offset ["
<< graph_dynamic_mem_offset_ << "] memory pool offset["
<< device_mem_size_ - device_mem_pool_offset_ << "])";
// NOTE(review): presumably unreachable if MS_LOG(EXCEPTION) throws -- confirm.
return 0;
}
auto alloc_mem_size = ASCEND_DYNAMIC_MEM_ALLOC_UNIT_SIZE;
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
// The growth policy depends on the execution mode read from the context.
const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
if (pynative_mode) {
// Grow by doubling until the block is at least as large as the request.
while (alloc_mem_size < size) {
alloc_mem_size = alloc_mem_size * 2;
}
} else {
// Grow one unit at a time, giving the smallest multiple of the unit that is >= size.
while (alloc_mem_size < size) {
alloc_mem_size = alloc_mem_size + ASCEND_DYNAMIC_MEM_ALLOC_UNIT_SIZE;
}
}
// Never ask for more than is currently free; the earlier check guarantees this is still >= size.
alloc_mem_size = std::min(alloc_mem_size, device_free_mem_size);
return alloc_mem_size;
}

size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr *addr) { size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr *addr) {
MS_LOG(INFO) << "Malloc Memory: Pool, total[" << device_mem_size_ << "] (dynamic[" << graph_dynamic_mem_offset_
<< "] memory pool[" << device_mem_size_ - device_mem_pool_offset_ << "])"
<< " malloc [" << size << "]";

if (size == 0) { if (size == 0) {
MS_LOG(EXCEPTION) << "Failed to alloc memory pool resource, the size is zero!"; MS_LOG(EXCEPTION) << "Failed to alloc memory pool resource, the size is zero!";
} }
if (device_mem_pool_offset_ - size <= graph_dynamic_mem_offset_) {

if (device_mem_pool_offset_ - size < graph_dynamic_mem_offset_) {
MS_LOG(EXCEPTION) << "Failed to alloc memory pool memory, the current device_mem_pool_offset_ [" MS_LOG(EXCEPTION) << "Failed to alloc memory pool memory, the current device_mem_pool_offset_ ["
<< device_mem_pool_offset_ << "], current graph_dynamic_mem_offset_ " << graph_dynamic_mem_offset_ << device_mem_pool_offset_ << "], current graph_dynamic_mem_offset_ " << graph_dynamic_mem_offset_
<< "], need memory size [" << size << "]"; << "], need memory size [" << size << "]";
@@ -76,8 +112,6 @@ size_t AscendMemoryPool::AlignMemorySize(size_t size) const {
return size; return size;
} }


size_t AscendMemoryPool::mem_alloc_unit_size() const { return DYNAMIC_MEM_ALLOC_UNIT_SIZE / 4; }

void AscendMemoryPool::set_device_mem_pool_base(uint8_t *device_mem_pool_base) { void AscendMemoryPool::set_device_mem_pool_base(uint8_t *device_mem_pool_base) {
MS_EXCEPTION_IF_NULL(device_mem_pool_base); MS_EXCEPTION_IF_NULL(device_mem_pool_base);
device_mem_pool_base_ = device_mem_pool_base; device_mem_pool_base_ = device_mem_pool_base;


+ 2
- 2
mindspore/ccsrc/runtime/device/ascend/ascend_memory_pool.h View File

@@ -50,8 +50,8 @@ class AscendMemoryPool : public DynamicMemPoolBestFit {
protected: protected:
// The real size by memory alloc aligned. // The real size by memory alloc aligned.
size_t AlignMemorySize(size_t size) const override; size_t AlignMemorySize(size_t size) const override;
// Get the minimum memory unit size using for dynamic extend.
size_t mem_alloc_unit_size() const override;
// Calculate memory block required alloc size when adding the memory block.
size_t CalMemBlockAllocSize(size_t size) override;


private: private:
AscendMemoryPool() = default; AscendMemoryPool() = default;


+ 2
- 1
mindspore/ccsrc/runtime/device/kernel_runtime.cc View File

@@ -340,7 +340,8 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
#endif #endif
auto tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index); auto tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
device_address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id); device_address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id);
MS_LOG(DEBUG) << "Malloc static memory for " << item->fullname_with_scope();
MS_LOG(INFO) << "Malloc Input for graph " << graph->graph_id() << ", node: " << item->fullname_with_scope()
<< " index: " << index << " size: " << tensor_size;
if (mem_manager_->MallocMem(kStaticMem, tensor_size, device_address, graph->graph_id()) == nullptr) { if (mem_manager_->MallocMem(kStaticMem, tensor_size, device_address, graph->graph_id()) == nullptr) {
MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size; MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size;
} }


+ 7
- 64
mindspore/ccsrc/runtime/device/memory_manager.cc View File

@@ -65,10 +65,12 @@ void MemoryManager::MallocSomasDynamicMem(const session::KernelGraph *graph) {


size_t total_allocated_size = somas_reuse_util_ptr->GetTotalMemSize(); size_t total_allocated_size = somas_reuse_util_ptr->GetTotalMemSize();
MS_LOG(INFO) << "Graph " << graph->graph_id() << ": TotalSomasReuseDynamicSize [" << total_allocated_size << "]"; MS_LOG(INFO) << "Graph " << graph->graph_id() << ": TotalSomasReuseDynamicSize [" << total_allocated_size << "]";
auto base_ptr = MallocDynamicMem(total_allocated_size, false);
MS_LOG(INFO) << "Somas Reuse Memory Base Address [" << static_cast<void *>(base_ptr) << "], End Address ["
<< static_cast<void *>(base_ptr + total_allocated_size) << "]";
somas_reuse_util_ptr->set_mem_base_addr(base_ptr);
if (total_allocated_size > 0) {
auto base_ptr = MallocDynamicMem(total_allocated_size, false);
MS_LOG(INFO) << "Somas Reuse Memory Base Address [" << static_cast<void *>(base_ptr) << "], End Address ["
<< static_cast<void *>(base_ptr + total_allocated_size) << "]";
somas_reuse_util_ptr->set_mem_base_addr(base_ptr);
}


auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
@@ -166,66 +168,7 @@ uint8_t *MemoryManager::MallocMem(MemType type, size_t size, const DeviceAddress
return ptr; return ptr;
} }


// Reserve `size` bytes (after alignment) of static memory and return its address.
// The region is carved out immediately below the current static_mem_offset_, which is
// then lowered to the start of the new region.
// Raises via MS_LOG(EXCEPTION) when the aligned size exceeds static_mem_offset_ or the
// new region would start below dynamic_mem_offset_.
// `graph_id` is not used in this implementation.
uint8_t *MemoryManager::MallocStaticMem(size_t size, bool communication_mem, uint32_t graph_id) {
size_t align_size = 0;
// Communication memory uses its own alignment helper.
if (communication_mem) {
align_size = GetCommunicationAlignSize(size);
} else {
align_size = GetCommonAlignSize(size);
}

MS_LOG(INFO) << "Malloc Memory for Static: total[" << device_mem_size_ << "](dynamic[" << total_dynamic_size_
<< "] static[" << total_static_size_ << "])"
<< " malloc [" << align_size << "] communication_mem: " << communication_mem;

// Guard the unsigned subtraction below against underflow.
if (static_mem_offset_ < align_size) {
MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "](dynamic[" << total_dynamic_size_
<< "] static[" << total_static_size_ << "])"
<< " malloc [" << align_size << "] failed!";
}
// The counter is updated before the overlap check, so the message below already includes this request.
total_static_size_ += align_size;
auto offset = static_mem_offset_ - align_size;
// The new static region must not start below the dynamic offset.
if (dynamic_mem_offset_ > offset) {
MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "](dynamic[" << total_dynamic_size_
<< "] static[" << total_static_size_ << "])"
<< " malloc [" << align_size << "] failed!";
}
static_mem_offset_ = offset;
// For communication memory the returned address is shifted forward by kMemAlignSize.
if (communication_mem) {
return device_mem_base_ + offset + kMemAlignSize;
} else {
return device_mem_base_ + offset;
}
}

// Reserve `size` bytes (after alignment) of dynamic memory and return its address.
// The region starts at the current dynamic_mem_offset_, which is then advanced past it.
// Raises via MS_LOG(EXCEPTION) when the advanced offset would exceed static_mem_offset_.
uint8_t *MemoryManager::MallocDynamicMem(size_t size, bool communication_mem) {
size_t align_size = 0;
// Communication memory uses its own alignment helper.
if (communication_mem) {
align_size = GetCommunicationAlignSize(size);
} else {
align_size = GetCommonAlignSize(size);
}

MS_LOG(INFO) << "Malloc Memory for Dynamic: total[" << device_mem_size_ << "](dynamic[" << total_dynamic_size_
<< "] static[" << total_static_size_ << "])"
<< " malloc [" << align_size << "] communication_mem: " << communication_mem;

// Remember where the new region starts before advancing the offset.
uint64_t offset = dynamic_mem_offset_;
auto new_offset = dynamic_mem_offset_ + align_size;
// The dynamic region must not run past the static offset.
if (new_offset > static_mem_offset_) {
MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "](dynamic[" << total_dynamic_size_
<< "] static[" << total_static_size_ << "])"
<< " malloc [" << align_size << "] failed!";
}
total_dynamic_size_ += align_size;
dynamic_mem_offset_ = new_offset;

// For communication memory the returned address is shifted forward by kMemAlignSize.
if (communication_mem) {
return device_mem_base_ + offset + kMemAlignSize;
} else {
return device_mem_base_ + offset;
}
}
// Default implementation: allocates nothing and always returns nullptr.
// NOTE(review): presumably device-specific subclasses override this virtual -- confirm.
uint8_t *MemoryManager::MallocDynamicMem(size_t size, bool communication_mem) { return nullptr; }


bool MemoryManager::MallocMemFromMemPool(const DeviceAddressPtr address, size_t size) { bool MemoryManager::MallocMemFromMemPool(const DeviceAddressPtr address, size_t size) {
auto device_ptr = MallocMemFromMemPool(size); auto device_ptr = MallocMemFromMemPool(size);


+ 1
- 1
mindspore/ccsrc/runtime/device/memory_manager.h View File

@@ -64,7 +64,7 @@ class MemoryManager {
size_t GetCommunicationAlignSize(size_t input_size) const; size_t GetCommunicationAlignSize(size_t input_size) const;


protected: protected:
virtual uint8_t *MallocStaticMem(size_t size, bool communication_mem, uint32_t graph_id = kInvalidGraphId);
virtual uint8_t *MallocStaticMem(size_t size, bool communication_mem, uint32_t graph_id = kInvalidGraphId) = 0;
virtual uint8_t *MallocDynamicMem(size_t size, bool communication_mem); virtual uint8_t *MallocDynamicMem(size_t size, bool communication_mem);
uint8_t *device_mem_base_{nullptr}; uint8_t *device_mem_base_{nullptr};
uint64_t device_mem_size_{0}; uint64_t device_mem_size_{0};


+ 6
- 2
tests/st/ops/ascend/test_tbe_ops/test_relu.py View File

@@ -20,7 +20,7 @@ from mindspore import Tensor
from mindspore.common.api import ms_function from mindspore.common.api import ms_function
from mindspore.ops import operations as P from mindspore.ops import operations as P


context.set_context(device_target="Ascend")
context.set_context(device_target="Ascend", mode=context.GRAPH_MODE, variable_memory_max_size="31GB")




class Net(nn.Cell): class Net(nn.Cell):
@@ -34,8 +34,12 @@ class Net(nn.Cell):




def test_net(): def test_net():
x = np.random.randn(2, 3, 3, 4).astype(np.float32)
# element count: (31GB / 2 - 512) / sizeof(float32) = 4160749440
x = np.random.randn(16, 120, 2167057).astype(np.float32)
relu = Net() relu = Net()
output = relu(Tensor(x)) output = relu(Tensor(x))
expect = 1 * (x > 0) * x
print(x) print(x)
print(output.asnumpy()) print(output.asnumpy())
print(expect)
assert (output.asnumpy() == expect).all()

Loading…
Cancel
Save