|
|
|
@@ -62,8 +62,8 @@ def init(backend_name=None): |
|
|
|
""" |
|
|
|
if _is_role_pserver() or _is_role_sched(): |
|
|
|
return |
|
|
|
device_target = context.get_context("device_target") |
|
|
|
if backend_name is None: |
|
|
|
device_target = context.get_context("device_target") |
|
|
|
if device_target == "Ascend": |
|
|
|
backend_name = "hccl" |
|
|
|
elif device_target == "GPU": |
|
|
|
@@ -74,6 +74,8 @@ def init(backend_name=None): |
|
|
|
raise TypeError("Backend name must be a string, but got {}".format(type(backend_name))) |
|
|
|
|
|
|
|
if backend_name == "hccl": |
|
|
|
if device_target != "Ascend": |
|
|
|
raise RuntimeError("Device target should be 'Ascend' to init hccl, but got {}".format(device_target)) |
|
|
|
init_hccl() |
|
|
|
GlobalComm.BACKEND = Backend("hccl") |
|
|
|
GlobalComm.WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP |
|
|
|
|