Skip to content

Commit

Permalink
Merge branch 'master' into olruwase/ds_5241
Browse files Browse the repository at this point in the history
  • Loading branch information
tjruwase authored Jan 30, 2025
2 parents c1b87ea + 065ca8a commit 5da6cd0
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 5 deletions.
19 changes: 19 additions & 0 deletions deepspeed/launcher/multinode_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,31 @@ def name(self):

def validate_args(self):
    """Validate the launcher arguments for this backend.

    Runs the parent-class validation, then ``_setup_mpi_environment()``, and
    finally rejects the options this backend does not handle. Any exception
    raised by those two calls propagates unchanged.

    Raises:
        ValueError: if ``args.include`` or ``args.exclude`` is not the empty
            string, or if ``args.num_nodes`` or ``args.num_gpus`` is not ``-1``.
    """
    super().validate_args()

    # Validate and set MPI environment variables
    self._setup_mpi_environment()

    args = self.args

    #TODO: Allow for include/exclude at node-level but not gpu-level
    filters_requested = args.include != "" or args.exclude != ""
    if filters_requested:
        raise ValueError(f"{self.name} backend does not support worker include/exclusion")

    limits_requested = args.num_nodes != -1 or args.num_gpus != -1
    if limits_requested:
        raise ValueError(f"{self.name} backend does not support limiting num nodes/gpus")

def _setup_mpi_environment(self):
    """Copy the ``OMPI_COMM_WORLD_*`` variables into ``LOCAL_RANK``, ``RANK`` and ``WORLD_SIZE``.

    All three source variables must be present in ``os.environ``; nothing is
    written unless every one of them is found.

    Raises:
        EnvironmentError: if at least one source variable is absent. The
            message lists the names that are missing.
    """
    # NOTE(review): these variables are presumably only visible to processes
    # started by an MPI launcher -- confirm this method is meant to run there.

    # Source variable -> variable written by this method.
    mpi_to_generic = {
        'OMPI_COMM_WORLD_LOCAL_RANK': 'LOCAL_RANK',
        'OMPI_COMM_WORLD_RANK': 'RANK',
        'OMPI_COMM_WORLD_SIZE': 'WORLD_SIZE',
    }

    # Collect every absent name so the error pinpoints what to fix.
    missing = [var for var in mpi_to_generic if var not in os.environ]
    if missing:
        raise EnvironmentError("MPI environment variables are not set. "
                               "Ensure you are running the script with an MPI-compatible launcher. "
                               f"Missing: {', '.join(missing)}")

    # Now safe to read all
    for mpi_var, generic_var in mpi_to_generic.items():
        os.environ[generic_var] = os.environ[mpi_var]

def get_cmd(self, environment, active_resources):
total_process_count = sum(self.resource_pool.values())

Expand Down
71 changes: 66 additions & 5 deletions tests/unit/launcher/test_multinode_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ def runner_info():
return env, hosts, world_info, args


@pytest.fixture
def mock_mpi_env(monkeypatch):
    """Set the three ``OMPI_COMM_WORLD_*`` variables for the duration of a test."""
    # Provide the 3 required MPI variables:
    mpi_vars = (
        ('OMPI_COMM_WORLD_LOCAL_RANK', '0'),
        ('OMPI_COMM_WORLD_RANK', '0'),
        ('OMPI_COMM_WORLD_SIZE', '1'),
    )
    for var_name, var_value in mpi_vars:
        monkeypatch.setenv(var_name, var_value)


def test_pdsh_runner(runner_info):
env, resource_pool, world_info, args = runner_info
runner = mnrunner.PDSHRunner(args, world_info)
Expand All @@ -27,34 +35,87 @@ def test_pdsh_runner(runner_info):
assert env['PDSH_RCMD_TYPE'] == 'ssh'


def test_openmpi_runner(runner_info):
def test_openmpi_runner(runner_info, mock_mpi_env):
    """The OpenMPI command built from the default args starts with ``mpirun`` and contains ``eth0``."""
    env, pool, world_info, args = runner_info
    launch_cmd = mnrunner.OpenMPIRunner(args, world_info, pool).get_cmd(env, pool)
    assert launch_cmd[0] == 'mpirun'
    assert 'eth0' in launch_cmd


def test_btl_nic_openmpi_runner(runner_info):
def test_btl_nic_openmpi_runner(runner_info, mock_mpi_env):
    """With ``-mca btl_tcp_if_include eth1`` as launcher arg, the command has ``eth1`` and not ``eth0``."""
    env, pool, world_info, _ = runner_info
    launcher_argv = ['--launcher_arg', '-mca btl_tcp_if_include eth1', 'test_launcher.py']
    args = parse_args(launcher_argv)

    launch_cmd = mnrunner.OpenMPIRunner(args, world_info, pool).get_cmd(env, pool)
    assert 'eth0' not in launch_cmd
    assert 'eth1' in launch_cmd


def test_btl_nic_two_dashes_openmpi_runner(runner_info):
def test_btl_nic_two_dashes_openmpi_runner(runner_info, mock_mpi_env):
    """With ``--mca btl_tcp_if_include eth1`` as launcher arg, the command has ``eth1`` and not ``eth0``."""
    env, pool, world_info, _ = runner_info
    launcher_argv = ['--launcher_arg', '--mca btl_tcp_if_include eth1', 'test_launcher.py']
    args = parse_args(launcher_argv)

    launch_cmd = mnrunner.OpenMPIRunner(args, world_info, pool).get_cmd(env, pool)
    assert 'eth0' not in launch_cmd
    assert 'eth1' in launch_cmd


def test_setup_mpi_environment_success(monkeypatch):
    """Test that _setup_mpi_environment correctly sets environment variables when MPI variables exist."""
    # monkeypatch restores every variable it touched at teardown, so the
    # environment is cleaned up even when an assertion below fails.
    monkeypatch.setenv('OMPI_COMM_WORLD_LOCAL_RANK', '0')
    monkeypatch.setenv('OMPI_COMM_WORLD_RANK', '1')
    monkeypatch.setenv('OMPI_COMM_WORLD_SIZE', '2')

    # Register the output variables too: setenv records their prior state, so
    # whatever the code under test writes is rolled back and any pre-existing
    # value is restored. The placeholder also proves the values get overwritten.
    for name in ('LOCAL_RANK', 'RANK', 'WORLD_SIZE'):
        monkeypatch.setenv(name, 'placeholder-set-by-test')

    args = parse_args(['--launcher_arg', '--mca btl_tcp_if_include eth1', 'test_launcher.py'])

    runner = mnrunner.OpenMPIRunner(args, None, None)
    # Set up the MPI environment
    runner._setup_mpi_environment()

    assert os.environ['LOCAL_RANK'] == '0'
    assert os.environ['RANK'] == '1'
    assert os.environ['WORLD_SIZE'] == '2'


def test_setup_mpi_environment_missing_variables(monkeypatch):
    """Test that _setup_mpi_environment raises an EnvironmentError when MPI variables are missing."""

    # Clear relevant environment variables; monkeypatch puts back any that
    # were set before the test, so other tests see an unchanged environment.
    for name in ('OMPI_COMM_WORLD_LOCAL_RANK', 'OMPI_COMM_WORLD_RANK', 'OMPI_COMM_WORLD_SIZE'):
        monkeypatch.delenv(name, raising=False)

    args = parse_args(['--launcher_arg', '--mca btl_tcp_if_include eth1', 'test_launcher.py'])

    with pytest.raises(EnvironmentError, match="MPI environment variables are not set"):
        mnrunner.OpenMPIRunner(args, None, None)


def test_setup_mpi_environment_fail(monkeypatch):
    """Test that _setup_mpi_environment fails if only partial MPI variables are provided."""
    # monkeypatch undoes these changes at teardown, including when the
    # expected exception is not raised and the test fails.
    monkeypatch.setenv('OMPI_COMM_WORLD_LOCAL_RANK', '0')
    monkeypatch.delenv('OMPI_COMM_WORLD_RANK', raising=False)  # missing variable
    monkeypatch.setenv('OMPI_COMM_WORLD_SIZE', '2')

    args = parse_args(['--launcher_arg', '--mca btl_tcp_if_include eth1', 'test_launcher.py'])

    with pytest.raises(EnvironmentError, match="MPI environment variables are not set"):
        mnrunner.OpenMPIRunner(args, None, None)


def test_mpich_runner(runner_info):
env, resource_pool, world_info, args = runner_info
runner = mnrunner.MPICHRunner(args, world_info, resource_pool)
Expand Down

0 comments on commit 5da6cd0

Please sign in to comment.