Source code for cpu.distributed

"""The code of this module is modified from:

- https://github.com/facebookresearch/detectron2/blob/main/detectron2/utils/comm.py
- https://github.com/pytorch/vision/blob/main/references/detection/utils.py
"""
import functools
import logging
import os
import socket
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.distributed as dist
from torch import Tensor
from torch._C._distributed_c10d import ProcessGroup

# Names exported by ``from <this module> import *``.
__all__ = [
    "all_gather",
    "gather",
    "reduce_dict",
    "setup_print_for_distributed",
    "get_world_size",
    "get_rank",
    "is_main_process",
    "init_distributed",
]

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


@functools.lru_cache()
def _get_global_gloo_group() -> ProcessGroup:
    """Return a gloo-backed process group that contains all ranks.

    When the default backend is not ``nccl`` the default ``WORLD`` group is
    returned as is; otherwise a new group is created with the ``gloo``
    backend. The result is memoized by :func:`functools.lru_cache`, so the
    group is created at most once per process.
    """
    if dist.get_backend() != "nccl":
        return dist.group.WORLD
    return dist.new_group(backend="gloo")


def all_gather(data: Any, group: Optional[ProcessGroup] = None) -> List[Any]:
    """Run :meth:`all_gather` on arbitrary picklable data (not necessarily tensors).

    Args:
        data: Any picklable object.
        group (ProcessGroup): A torch process group. When omitted, the group
            returned by ``_get_global_gloo_group()`` is used.

    Returns:
        list[data]: List of data gathered from each rank. When only one
        process is involved, this is simply ``[data]``.
    """
    if get_world_size() == 1:
        return [data]

    # Default to the CPU (gloo) group to reduce GPU RAM usage.
    pg = _get_global_gloo_group() if group is None else group

    num_ranks = dist.get_world_size(pg)
    if num_ranks == 1:
        return [data]

    # One placeholder slot per rank; filled in place by all_gather_object.
    gathered = [None] * num_ranks
    dist.all_gather_object(gathered, data, group=pg)
    return gathered
def gather(data: Any, dst: int = 0, group: Optional[ProcessGroup] = None) -> List[Any]:
    """Run :meth:`gather` on arbitrary picklable data (not necessarily tensors).

    Args:
        data: Any picklable object.
        dst (int): Destination rank.
        group (ProcessGroup): A torch process group. When omitted, the group
            returned by ``_get_global_gloo_group()`` is used.

    Returns:
        list[data]: On ``dst``, a list of data gathered from each rank.
        On every other rank, an empty list. When only one process is
        involved, this is simply ``[data]``.
    """
    if get_world_size() == 1:
        return [data]

    pg = _get_global_gloo_group() if group is None else group

    num_ranks = dist.get_world_size(pg)
    if num_ranks == 1:
        return [data]

    if dist.get_rank(pg) != dst:
        # Non-destination ranks only contribute their object.
        dist.gather_object(data, None, dst=dst, group=pg)
        return []

    # Destination rank: one placeholder slot per rank, filled in place.
    received = [None] * num_ranks
    dist.gather_object(data, received, dst=dst, group=pg)
    return received
def reduce_dict(input_dict: Dict[str, Tensor], average: bool = True) -> Dict[str, Tensor]:
    """Reduce the values of a dictionary across all processes.

    The values are stacked into a single tensor (in sorted key order),
    all-reduced, and optionally divided by the world size, so every process
    ends up with the same result.

    Args:
        input_dict (dict): All the values will be reduced.
        average (bool): Whether to do average or sum.

    Returns:
        dict: A dict with the same fields as input_dict, after reduction.
        With fewer than two processes, ``input_dict`` itself is returned.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict

    with torch.no_grad():
        # Sort the keys so that the stacking order is consistent across processes.
        keys = sorted(input_dict.keys())
        stacked = torch.stack([input_dict[key] for key in keys], dim=0)
        dist.all_reduce(stacked)
        if average:
            stacked /= world_size
        # Iterating the stacked tensor yields one entry per key along dim 0.
        return dict(zip(keys, stacked))
def setup_print_for_distributed(is_master: bool) -> None:
    """Replace ``builtins.print`` so that non-master processes stay silent.

    The replacement accepts an extra keyword argument ``force``; passing
    ``force=True`` prints regardless of ``is_master``.

    Args:
        is_master (bool): If the current process is the master process or not.
    """
    import builtins

    # Capture whatever ``print`` is installed right now; the wrapper delegates to it.
    original_print = builtins.print

    def _filtered_print(*args, **kwargs):
        # ``force`` is consumed here and never forwarded to the wrapped print.
        forced = kwargs.pop("force", False)
        if not (is_master or forced):
            return
        original_print(*args, **kwargs)

    builtins.print = _filtered_print
def get_world_size() -> int:
    """Return the number of processes in the current process group.

    Returns 1 when ``torch.distributed`` is unavailable or not initialized.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1
def get_rank() -> int:
    """Return the rank of the current process in the current process group.

    Returns 0 when ``torch.distributed`` is unavailable or not initialized.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0
def is_main_process() -> bool:
    """Return whether the current process is the master process (rank 0)."""
    rank = get_rank()
    return rank == 0
def _is_free_port(port: int) -> bool: ips = socket.gethostbyname_ex(socket.gethostname())[-1] ips.append("localhost") with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: return all(s.connect_ex((ip, port)) != 0 for ip in ips) def _find_free_port() -> int: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # Binding to port 0 will cause the OS to find an available port for us sock.bind(("", 0)) port = sock.getsockname()[1] sock.close() # NOTE: there is still a chance the port could be taken by other processes. return port
def init_distributed(auto: bool = False) -> Tuple[int, int, int]:
    """Initialize the distributed mode as follows:

    - Read ``rank``, ``world_size`` and ``local_rank`` from the environment
      (``RANK``/``WORLD_SIZE``/``LOCAL_RANK``, or ``SLURM_PROCID``/``SLURM_NTASKS``).
    - Set correct cuda device.
    - Initialize the process group, with ``backend="nccl"`` and ``init_method="env://"``.
    - Disable printing when not in master process.

    If none of the expected environment variables is present, nothing is
    initialized and ``(0, 0, 1)`` is returned.

    Args:
        auto (bool): If True, when MASTER_PORT is not free, automatically find a free one.
            Only allowed when ``MASTER_ADDR`` is ``127.0.0.1``. Defaults to False.

    Returns:
        tuple: (``rank``, ``local_rank``, ``world_size``)

    Raises:
        AssertionError: If ``MASTER_ADDR`` or ``MASTER_PORT`` is missing, or if
            ``auto`` is requested while ``MASTER_ADDR`` is not ``127.0.0.1``.
    """
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        # launched by `torch.distributed.launch`
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        local_rank = int(os.environ["LOCAL_RANK"])
    elif "SLURM_PROCID" in os.environ:
        # launched by slurm
        rank = int(os.environ["SLURM_PROCID"])
        world_size = int(os.environ["SLURM_NTASKS"])
        local_rank = rank % torch.cuda.device_count()
    else:
        print("Not using distributed mode.")
        return 0, 0, 1

    assert "MASTER_ADDR" in os.environ and "MASTER_PORT" in os.environ, (
        "init_method='env://' requires the two environment variables: "
        "MASTER_ADDR and MASTER_PORT.")

    if auto:
        assert os.environ["MASTER_ADDR"] == "127.0.0.1", (
            "`auto` is not supported in multi-machine jobs.")
        # Environment values are strings; the socket API needs an integer port.
        port = int(os.environ["MASTER_PORT"])
        # NOTE(review): every process calling this function probes the port on
        # its own, so different ranks may end up choosing different ports --
        # confirm how `auto` is expected to be used with multiple processes.
        if not _is_free_port(port):
            new_port = _find_free_port()
            print(f"Port {port} is not free, use port {new_port} instead.")
            # os.environ only accepts strings.
            os.environ["MASTER_PORT"] = str(new_port)

    print(f"| distributed init (rank {rank})", flush=True)
    # Select this process's GPU before any NCCL collective runs, so that the
    # barrier below uses the correct device instead of the default one.
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl", init_method="env://",
                            rank=rank, world_size=world_size)
    dist.barrier()

    setup_print_for_distributed(rank == 0)
    return rank, local_rank, world_size