
Commit

polish code style
dujiangsu committed May 18, 2022
1 parent 6136636 commit f29bfeb
Showing 43 changed files with 1,207 additions and 1,184 deletions.
10 changes: 10 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,10 @@
repos:
- repo: https://github.com/pre-commit/mirrors-yapf
  rev: v0.32.0
  hooks:
  - id: yapf
    args: ['--style=.style.yapf', '--parallel', '--in-place']
- repo: https://github.com/pre-commit/mirrors-clang-format
  rev: v13.0.1
  hooks:
  - id: clang-format
5 changes: 5 additions & 0 deletions .style.yapf
@@ -0,0 +1,5 @@
[style]
based_on_style = google
spaces_before_comment = 4
split_before_logical_operator = true
column_limit = 120
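
For context, the snippet below is not part of the commit; it is a minimal sketch of how the style settings above can be exercised programmatically, assuming `yapf==0.32.0` is installed. The temporary style file simply reproduces the `.style.yapf` contents from this diff, and the messy input line is an arbitrary stand-in.

import pathlib
import tempfile

from yapf.yapflib.yapf_api import FormatCode    # assumes `pip install yapf==0.32.0`

STYLE = """\
[style]
based_on_style = google
spaces_before_comment = 4
split_before_logical_operator = true
column_limit = 120
"""

# A deliberately messy one-liner, similar to the dict literals reformatted in this commit.
messy = "cfg = {'model_class' : 1,'model_type'  : 2}\n"

with tempfile.TemporaryDirectory() as tmp:
    style_file = pathlib.Path(tmp) / ".style.yapf"
    style_file.write_text(STYLE)
    # In yapf 0.32, FormatCode returns (formatted_source, changed).
    formatted, changed = FormatCode(messy, style_config=str(style_file))

print(formatted, end="")
print("changed:", changed)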
4 changes: 3 additions & 1 deletion energon/cli/__init__.py
@@ -4,14 +4,16 @@

app = typer.Typer()


@app.callback()
def callback():
"""
Typer app, including Click subapp
"""


typer_click_object = typer.main.get_command(app)
typer_click_object.add_command(service, "service")

if __name__ == "__main__":
typer_click_object()
typer_click_object()
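
As an aside, not part of the diff: the pattern above follows Typer's documented recipe for combining a Typer app with plain Click commands. A self-contained sketch, with a hypothetical `hello` command standing in where the real CLI attaches `service`:

import click
import typer

app = typer.Typer()


@app.callback()
def callback():
    """Typer app, including Click subapp"""


@click.command()
def hello():
    click.echo("hello from a plain Click command")


# Convert the Typer app into a Click group, then attach an ordinary Click
# command to it -- the same bridge used above with the `service` command.
typer_click_object = typer.main.get_command(app)
typer_click_object.add_command(hello, "hello")

if __name__ == "__main__":
    typer_click_object()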
37 changes: 19 additions & 18 deletions energon/cli/service.py
@@ -38,8 +38,7 @@ def launches(model_class=None,
f'Worker Server Host: {server_host} \n'
f'Worker Server Port: {server_port} \n'
f'Unvicorn Log Level: {log_level} \n'
f'Remove padding: {rm_padding} \n'
)
f'Remove padding: {rm_padding} \n')

if half:
dtype = torch.half
@@ -51,7 +50,7 @@

engine_port = server_port
worker_port = server_port + 1
worker_rank = 1 # start from 1
worker_rank = 1 # start from 1

process_list = []
for i in range(num_worker):
@@ -63,21 +62,23 @@

sig_server = inspect.signature(engine_server)
parameters_server = sig_server.parameters

cfg = {'model_class' : model_class,
'model_type' : model_type,
'max_batch_size' : max_batch_size,
'tp_init_size' : tp_init_size,
'pp_init_size' : pp_init_size,
'host' : host,
'port' : port,
'dtype' : dtype,
'checkpoint' : checkpoint,
'tokenizer_path' : tokenizer_path,
'server_host' : server_host,
'server_port' : engine_port,
'log_level' : log_level,
'rm_padding' : rm_padding}

cfg = {
'model_class': model_class,
'model_type': model_type,
'max_batch_size': max_batch_size,
'tp_init_size': tp_init_size,
'pp_init_size': pp_init_size,
'host': host,
'port': port,
'dtype': dtype,
'checkpoint': checkpoint,
'tokenizer_path': tokenizer_path,
'server_host': server_host,
'server_port': engine_port,
'log_level': log_level,
'rm_padding': rm_padding
}

argv = dict()
for name, _ in parameters_server.items():
31 changes: 20 additions & 11 deletions energon/communication/__init__.py
@@ -1,17 +1,26 @@
from .collective import all_gather, reduce_scatter, all_reduce, broadcast, reduce
from .p2p import (send_forward, send_forward_recv_forward,
send_backward_recv_forward, send_backward,
send_backward_recv_backward, send_forward_recv_backward,
send_forward_backward_recv_forward_backward, recv_forward,
recv_backward)
from .p2p import (send_forward, send_forward_recv_forward, send_backward_recv_forward, send_backward,
send_backward_recv_backward, send_forward_recv_backward, send_forward_backward_recv_forward_backward,
recv_forward, recv_backward)
from .ring import ring_forward
from .utils import send_tensor_meta, recv_tensor_meta

__all__ = [
'all_gather', 'reduce_scatter', 'all_reduce', 'broadcast', 'reduce',
'send_forward', 'send_forward_recv_forward',
'send_forward_backward_recv_forward_backward', 'send_backward',
'send_backward_recv_backward', 'send_backward_recv_forward',
'send_forward_recv_backward', 'recv_backward', 'recv_forward',
'ring_forward', 'send_tensor_meta', 'recv_tensor_meta',
'all_gather',
'reduce_scatter',
'all_reduce',
'broadcast',
'reduce',
'send_forward',
'send_forward_recv_forward',
'send_forward_backward_recv_forward_backward',
'send_backward',
'send_backward_recv_backward',
'send_backward_recv_forward',
'send_forward_recv_backward',
'recv_backward',
'recv_forward',
'ring_forward',
'send_tensor_meta',
'recv_tensor_meta',
]
2 changes: 1 addition & 1 deletion energon/communication/collective.py
@@ -141,7 +141,7 @@ def scatter_object_list(scatter_object_output_list, scatter_object_input_list, s
# set tensor device to cuda if backend is nccl
device = torch.cuda.current_device() if dist.get_backend(group) == 'nccl' else torch.device("cpu")

my_rank = dist.get_rank() # use global rank
my_rank = dist.get_rank() # use global rank
if my_rank == src:
tensor_list, tensor_sizes = zip(
*[dist.distributed_c10d._object_to_tensor(obj) for obj in scatter_object_input_list])
32 changes: 13 additions & 19 deletions energon/communication/p2p.py
@@ -12,7 +12,6 @@
import operator
from .utils import split_tensor_into_1d_equal_chunks, gather_split_1d_tensor


TensorShape = Union[torch.Size, List[int], Tuple[int]]


@@ -88,13 +87,11 @@ def _communicate(tensor_send_next=None,

if tensor_send_prev is not None or recv_prev:
if prev_rank is None:
prev_rank = gpc.get_prev_global_rank(
ParallelMode.PIPELINE)
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)

if tensor_send_next is not None or recv_next:
if next_rank is None:
next_rank = gpc.get_next_global_rank(
ParallelMode.PIPELINE)
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)

if tensor_send_prev is not None:
send_prev_split = _get_tensor_shape(tensor_send_prev.shape, scatter_gather_tensors)[1]
@@ -184,9 +181,7 @@ def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False):
:type next_rank: int, optional
"""
if not gpc.is_pipeline_last_stage():
_communicate(tensor_send_next=output_tensor,
next_rank=next_rank,
scatter_gather_tensors=scatter_gather_tensors)
_communicate(tensor_send_next=output_tensor, next_rank=next_rank, scatter_gather_tensors=scatter_gather_tensors)


def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False):
@@ -342,15 +337,14 @@ def send_forward_backward_recv_forward_backward(output_tensor,
:return: (the input tensor in forward step, the grad of output tensor in forward step)
:rtype: (Tensor, Tensor)
"""
input_tensor, output_tensor_grad = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=input_tensor_grad,
recv_prev=recv_prev,
recv_next=recv_next,
recv_prev_shape=input_tensor_shape,
recv_next_shape=output_grad_shape,
prev_rank=prev_rank,
next_rank=next_rank,
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors)
input_tensor, output_tensor_grad = _communicate(tensor_send_next=output_tensor,
tensor_send_prev=input_tensor_grad,
recv_prev=recv_prev,
recv_next=recv_next,
recv_prev_shape=input_tensor_shape,
recv_next_shape=output_grad_shape,
prev_rank=prev_rank,
next_rank=next_rank,
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors)
return input_tensor, output_tensor_grad
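
For readers unfamiliar with the pipeline hand-off these wrappers implement, here is a minimal two-process sketch, independent of energon and not part of the diff, of the forward send/receive pattern between adjacent stages using plain `torch.distributed` over the gloo backend; the address, port, and tensor contents are placeholders.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank: int, world_size: int):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    activation = torch.zeros(4)
    if rank == 0:
        # Analogous to send_forward: the earlier stage pushes its output tensor
        # to the next pipeline stage.
        activation += 42.0
        dist.send(activation, dst=1)
    else:
        # Analogous to recv_forward: the later stage blocks until the tensor
        # from the previous stage arrives.
        dist.recv(activation, src=0)
        print(f"rank {rank} received {activation.tolist()}")

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2)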
10 changes: 4 additions & 6 deletions energon/communication/ring.py
@@ -30,15 +30,13 @@ def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode):
dtype=tensor_send_next.dtype)

# send to next rank
send_next_op = torch.distributed.P2POp(
torch.distributed.isend, tensor_send_next,
gpc.get_next_global_rank(parallel_mode))
send_next_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_next,
gpc.get_next_global_rank(parallel_mode))
ops.append(send_next_op)

# receive from prev rank
recv_prev_op = torch.distributed.P2POp(
torch.distributed.irecv, tensor_recv_prev,
gpc.get_prev_global_rank(parallel_mode))
recv_prev_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_prev,
gpc.get_prev_global_rank(parallel_mode))
ops.append(recv_prev_op)

if current_rank % 2 == 0:
10 changes: 3 additions & 7 deletions energon/communication/utils.py
@@ -80,9 +80,7 @@ def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
start_index = partition_size * gpc.get_local_rank(ParallelMode.PARALLEL_1D)
end_index = start_index + partition_size
if new_buffer:
data = torch.empty(partition_size, dtype=tensor.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
data = torch.empty(partition_size, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False)
data.copy_(tensor.view(-1)[start_index:end_index])
else:
data = tensor.view(-1)[start_index:end_index]
@@ -101,9 +99,7 @@ def gather_split_1d_tensor(tensor):
world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
numel = torch.numel(tensor)
numel_gathered = world_size * numel
gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
chunks = [gathered[i*numel:(i+1)*numel] for i in range(world_size)]
gathered = torch.empty(numel_gathered, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False)
chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)]
dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.PARALLEL_1D))
return gathered
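
A single-process sketch, not part of the diff and using no distributed backend, of the round trip these two helpers perform across `ParallelMode.PARALLEL_1D` ranks: each rank keeps one equal 1-D slice of a tensor, and an all-gather reassembles the full buffer; here plain indexing stands in for the collective.

import torch

world_size = 4                                   # stand-in for the 1D tensor-parallel size
tensor = torch.arange(16, dtype=torch.float32).reshape(4, 4)

# split_tensor_into_1d_equal_chunks: rank r keeps view(-1)[r*p : (r+1)*p].
partition_size = tensor.numel() // world_size
chunks = [tensor.view(-1)[r * partition_size:(r + 1) * partition_size] for r in range(world_size)]

# gather_split_1d_tensor: a flat buffer is carved into per-rank slices and
# filled by all_gather; here each "rank's" slice is copied in directly.
gathered = torch.empty(tensor.numel(), dtype=tensor.dtype)
for r, chunk in enumerate(chunks):
    gathered[r * partition_size:(r + 1) * partition_size].copy_(chunk)

assert torch.equal(gathered.view_as(tensor), tensor)
print("round trip ok")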
41 changes: 14 additions & 27 deletions energon/context/parallel_context.py
@@ -134,7 +134,7 @@ def add_local_rank(self, parallel_mode: ParallelMode, rank: int):
"""
self._check_parallel_mode(parallel_mode)
self._local_ranks[parallel_mode] = rank

def rm_local_rank(self, parallel_mode: ParallelMode):
"""Removes the local rank of the current device for `parallel_mode` to the context.
"""
@@ -327,20 +327,14 @@ def add_ranks_in_group(self, parallel_mode: ParallelMode, ranks: list):
"""
self._check_parallel_mode(parallel_mode)
self._ranks_in_group[parallel_mode] = ranks

def rm_ranks_in_group(self, parallel_mode: ParallelMode):
"""Removes the ranks of the current device for `parallel_mode` in the group.
"""
self._check_parallel_mode(parallel_mode)
self._ranks_in_group.pop(parallel_mode)

def init_global_dist(self,
rank: int,
world_size: int,
backend: str,
host: str,
port: int
):
def init_global_dist(self, rank: int, world_size: int, backend: str, host: str, port: int):
"""Initializes the global distributed environment
:param rank: rank for the default process group
:type rank: int
@@ -355,10 +349,7 @@ def init_global_dist(self,
"""
# initialize the default process group
init_method = f'tcp://{host}:{port}'
dist.init_process_group(rank=rank,
world_size=world_size,
backend=backend,
init_method=init_method)
dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)
ranks = list(range(world_size))
# None will give the default global process group for pytorch dist operations
cpu_group = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else None
@@ -399,8 +390,7 @@ def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str)
setattr(self, attr_name, ele['size'])
else:
raise NotImplementedError(
f"Parallel configuration does not support this kind of argument, please use int or dict"
)
f"Parallel configuration does not support this kind of argument, please use int or dict")

def init_parallel_groups(self):
"""Initializes the parallel groups.
@@ -463,12 +453,10 @@ def init_parallel_groups(self):
for initializer_cfg in pg_init:
cfg = initializer_cfg.copy()
initializer_type = cfg.pop('type')
initializer = DIST_GROUP_INITIALIZER.get_module(initializer_type)(
rank, world_size, self.config,
self.data_parallel_size,
self.pipeline_parallel_size,
self.tensor_parallel_size,
**cfg)
initializer = DIST_GROUP_INITIALIZER.get_module(initializer_type)(rank, world_size, self.config,
self.data_parallel_size,
self.pipeline_parallel_size,
self.tensor_parallel_size, **cfg)
parallel_setting = initializer.init_dist_group()
if isinstance(parallel_setting, list):
for args in parallel_setting:
@@ -504,7 +492,7 @@ def destroy_vice_groups(self):
if mode is not ParallelMode.GLOBAL:
modes.append(mode)
dist.destroy_process_group(group)

for mode in modes:
self._deregister_dist(mode)

@@ -555,10 +543,9 @@ def set_seed(self, seed: int):
seed_str = ', '.join([f'{k}: {v}' for k, v in seeds.items()])

if self._verbose:
self._logger.info(
f"initialized seed on rank {global_rank}, "
f"numpy: {seed}, python random: {seed}, {seed_str},"
f"the default parallel seed is {ParallelMode.DATA}.")
self._logger.info(f"initialized seed on rank {global_rank}, "
f"numpy: {seed}, python random: {seed}, {seed_str},"
f"the default parallel seed is {ParallelMode.DATA}.")
else:
if self._verbose:
self._logger.info(
@@ -573,4 +560,4 @@ def set_virtual_pipeline_parallel_size(self, size):
self.virtual_pipeline_parallel_size = size

def set_virtual_pipeline_parallel_rank(self, rank):
self.virtual_pipeline_parallel_rank = rank
self.virtual_pipeline_parallel_rank = rank
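
To illustrate the TCP rendezvous that `init_global_dist` builds, here is a minimal single-process example, not part of the diff, using the gloo backend; host and port are placeholders.

import torch.distributed as dist

host, port = "127.0.0.1", 29500                  # placeholder rendezvous address
init_method = f"tcp://{host}:{port}"             # same URL scheme init_global_dist uses

# With world_size=1 this runs standalone; init_global_dist performs the same
# call with the real rank/world_size/backend passed in by the launcher.
dist.init_process_group(backend="gloo", init_method=init_method, rank=0, world_size=1)
print("rank", dist.get_rank(), "of", dist.get_world_size())
dist.destroy_process_group()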
2 changes: 1 addition & 1 deletion energon/engine/__init__.py
@@ -1 +1 @@
from .engine import InferenceEngine
from .engine import InferenceEngine
(Diffs for the remaining changed files were not loaded.)
