|
|
|
@@ -14,6 +14,7 @@ |
|
|
|
# ============================================================================ |
|
|
|
"""Context for parameter server training mode""" |
|
|
|
|
|
|
|
import os |
|
|
|
from mindspore._c_expression import PSContext |
|
|
|
|
|
|
|
_ps_context = None |
|
|
|
@@ -134,4 +135,10 @@ def _clone_hash_table(dest_param_name, src_param_name): |
|
|
|
ps_context().clone_hash_table(dest_param_name, src_param_name) |
|
|
|
|
|
|
|
def _set_cache_enable(cache_enable): |
|
|
|
# Environment variables are used to specify a maximum number of OpenBLAS threads: |
|
|
|
# In ubuntu(GPU) environment, numpy will use too many threads for computing, |
|
|
|
if cache_enable: |
|
|
|
os.environ['OPENBLAS_NUM_THREADS'] = '2' |
|
|
|
os.environ['GOTO_NUM_THREADS'] = '2' |
|
|
|
os.environ['OMP_NUM_THREADS'] = '2' |
|
|
|
ps_context().set_cache_enable(cache_enable) |