single-gpu generation for integration testing #640
QQ: what is the threshold for a good eval? Shall we just use common sense to tell whether the output makes sense or not? @jaysonfrancis

@fduwjj Initially yes. Most require enabling distributed inference to be useful, but sharing my thoughts below.

@jaysonfrancis Thanks for your answer. Do we expect the pretrained model to always offer meaningful answers? If so, why do people still do lots of fine-tuning and RL?

Thanks @fduwjj -- Hmm, it depends on the downstream task; there may be use cases that want to infer natively from the underlying data distribution alone, without any external biases from instruct/reward tuning. Anyway, my thoughts here were more focused on test-driven generation and catching issues that may go unnoticed during dev.

@jaysonfrancis Could you kindly add documentation on the rationale and a short user manual for how to run this generation script, so that folks know how to use it, what it is for, and how to tell whether the check succeeded? WDYT?
Thank you for putting up the scripts! Looks awesome in general. I had some questions.
Question regarding multi-device generation (cc @tianyu-l @fduwjj): I've got a solution working with TP, but I had to avoid sharding inputs along the sequence dim (including norms) because the prompt length increases at each step. To avoid making code changes, I can just apply a custom plan within this script; I can also gate it with a boolean if we'd like. Curious for your thoughts, or whether I missed something. Thanks!

For that we can have a boolean config.

Could you elaborate a bit? cc: @kwen2501 on distributed inference in torchtitan with a minimal script.
Sure. As it is now, the input during generation eventually becomes odd-length, so the sequence dimension (1) cannot be sharded evenly across ranks. Below is an example workaround:

```python
parallelize_module(
    model,
    tp_mesh,
    {
        "tok_embeddings": RowwiseParallel(
            input_layouts=Replicate(),
            output_layouts=Shard(1) if is_train else Replicate(),  # <-- New
        ),
        "norm": SequenceParallel(),
        "output": ColwiseParallel(
            input_layouts=Shard(1),
            output_layouts=Shard(-1) if loss_parallel else Replicate(),
            use_local_output=not loss_parallel,
        ),
    },
)
```
This I understood. But why can't an odd sequence dimension be sharded? IIUC, the TP API is able to handle uneven sharding (by auto-padding when the split is uneven). I somehow feel that as long as loss parallel is disabled, everything should work. Can you show the error log when the sequence dim is sharded? Thanks!
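As a side note, a minimal repro of the uneven-sharding question might look roughly like the sketch below. This is an illustration under assumptions, not code from the PR: it assumes PyTorch's public `torch.distributed.tensor` DTensor API (`distribute_tensor`, `Shard`, `redistribute`) and 2 GPUs launched via `torchrun`.

```python
# repro_uneven_shard.py -- hypothetical, not part of the PR.
# Shard an odd-length sequence dim across 2 TP ranks and gather it back,
# to observe how DTensor handles the uneven split.
# Launch: torchrun --nproc_per_node=2 repro_uneven_shard.py
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard, distribute_tensor


def main():
    mesh = init_device_mesh("cuda", (2,))
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    torch.manual_seed(0)  # same full tensor on every rank

    # Odd sequence length (9) cannot be split evenly across 2 ranks.
    x = torch.randn(1, 9, 4096, device="cuda")
    dx = distribute_tensor(x, mesh, [Shard(1)])  # shard along the sequence dim
    print(f"rank {rank}: local shard shape {dx.to_local().shape}")

    # Replicate back; if uneven sharding is padded/unpadded correctly,
    # this round-trips to the original shape and values.
    full = dx.redistribute(mesh, [Replicate()]).to_local()
    print(f"rank {rank}: round-trip ok = {torch.equal(full, x)}")


if __name__ == "__main__":
    main()
```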
The lines under review:

```python
    tokenizer.encode(prompt, bos=True, eos=False), dtype=torch.long
)
.view(1, -1)
.repeat(batch_size, 1)
```
Question: why would we run `batch_size` copies of the same inference task? I thought `batch_size` is useful when we have multiple distinct prompts. In our case maybe it's not necessary?
I was thinking that supporting a batch might be useful for profiling bandwidth/memory on the forward pass alone. It currently defaults to 1, but I can remove it.
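For reference, a small illustration of what the reviewed `.view(1, -1).repeat(batch_size, 1)` chain does (the token ids below are made up):

```python
import torch

# Made-up token ids standing in for tokenizer.encode(prompt, bos=True, eos=False)
prompt_ids = [128000, 3923, 374, 279, 7438, 315, 2324, 30]
batch_size = 4

input_ids = (
    torch.tensor(prompt_ids, dtype=torch.long)
    .view(1, -1)            # (1, T): add a batch dimension
    .repeat(batch_size, 1)  # (B, T): the same prompt repeated B times
)
print(input_ids.shape)  # torch.Size([4, 8])
```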
Okay, good to know -- probably an implementation issue somewhere else in my script; I will troubleshoot. To clarify, this is the layout I'm using:

```python
"tok_embeddings": RowwiseParallel(
    input_layouts=Replicate(),
    output_layouts=Replicate(),  # Shard(1)
),
```

Traceback:

```
Traceback (most recent call last):
File "/opt/jf/envs/p311/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
return f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^
File "/home/jf/oss/torchtitan/test/generate/test_generate.py", line 139, in example_generate
responses, _ = generate(
^^^^^^^^^
File "/opt/jf/envs/p311/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/jf/oss/torchtitan/test/generate/generation.py", line 65, in generate
next_token, logits = generate_next_token(
^^^^^^^^^^^^^^^^^^^^
File "/home/jf/oss/torchtitan/test/generate/generation.py", line 42, in generate_next_token
logits = model(x) # (B, T, vocab_size)
^^^^^^^^
File "/opt/jf/envs/p311/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/jf/envs/p311/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jf/oss/torchtitan/torchtitan/models/llama/model.py", line 443, in forward
h = layer(h, self.freqs_cis)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/jf/envs/p311/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/jf/envs/p311/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jf/oss/torchtitan/torchtitan/models/llama/model.py", line 324, in forward
h = x + self.attention(self.attention_norm(x), freqs_cis)
~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
File "/opt/jf/envs/p311/lib/python3.11/site-packages/torch/distributed/_functional_collectives.py", line 645, in __torch_dispatch__
unwrapped_args = tree_map_only(AsyncCollectiveTensor, unwrap, args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/jf/envs/p311/lib/python3.11/site-packages/torch/utils/_pytree.py", line 1136, in tree_map_only
return tree_map(map_only(__type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/jf/envs/p311/lib/python3.11/site-packages/torch/utils/_pytree.py", line 964, in tree_map
return treespec.unflatten(map(func, *flat_args))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/jf/envs/p311/lib/python3.11/site-packages/torch/utils/_pytree.py", line 803, in unflatten
leaves = list(leaves)
^^^^^^^^^^^^
File "/opt/jf/envs/p311/lib/python3.11/site-packages/torch/utils/_pytree.py", line 1082, in wrapped
return func(x)
^^^^^^^
File "/opt/jf/envs/p311/lib/python3.11/site-packages/torch/distributed/_functional_collectives.py", line 636, in unwrap
return e.trigger_wait()
^^^^^^^^^^^^^^^^
File "/opt/jf/envs/p311/lib/python3.11/site-packages/torch/distributed/_functional_collectives.py", line 609, in trigger_wait
out = wait_tensor(self.elem)
^^^^^^^^^^^^^^^^^^^^^^
File "/opt/jf/envs/p311/lib/python3.11/site-packages/torch/distributed/_functional_collectives.py", line 141, in wait_tensor
return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/jf/envs/p311/lib/python3.11/site-packages/torch/_ops.py", line 1123, in __call__
return self._op(*args, **(kwargs or {}))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.distributed.DistBackendError: [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=12, OpType=COALESCED, NumelIn=49152, NumelOut=24576, Timeout(ms)=300000) ran for 300000 milliseconds before timing out.
```
Hmm, sounds like there might be a bug in dealing with uneven sharding somewhere. I see that the timeout happens on the collective wait in the attention block. It would be great if you could provide a minimal repro of this issue, so we can take it from there.
1) Single-device run (non-distributed) ✅

```bash
NGPU=1 CONFIG_FILE=./train_configs/llama3_8b.toml \
CHECKPOINT_DIR=./outputs/L318B_DCP/ \
PROMPT="What is the meaning of life? That is a great question. It is" \
./test/generate/run_llama_pred.sh --max_new_tokens=32 --temperature=0.8 --seed=3
```

Output
2) Multi-device run (activating dist) ❌

```bash
NGPU=2 CONFIG_FILE=./train_configs/llama3_8b.toml \
CHECKPOINT_DIR=./outputs/L318B_DCP/ \
PROMPT="What is the meaning of life? That is a great question. It is" \
./test/generate/run_llama_pred.sh --max_new_tokens=32 --temperature=0.8 --seed=3
```

Causes timeout, same as above. Collective stack traces:

```
Collective 4 at entry 3 errors for group 0:default_pg collective nccl:gather [[48262]] [[0]] 2 completed
Found errors: Culprit rank 0; Error type: COLLECTIVE_DTYPE_MISMATCH, Expected dtypes: '['Byte']/['Byte']' does not match found dtype: '['UNKNOWN_SCALAR']/['Byte', 'Byte']'.
Collective stack traces:
Collective 6 at entry 5 errors for group 0:default_pg collective nccl:scatter [[48262]] [[0]] 2 completed
Found errors: Culprit rank 0; Error type: COLLECTIVE_DTYPE_MISMATCH, Expected dtypes: '['Byte']/['Byte']' does not match found dtype: '['UNKNOWN_SCALAR']/['Byte', 'Byte']'.
Collective stack traces:
Collective 7 at entry 6 errors for group 0:default_pg collective nccl:scatter [[1]] [[0]] 2 completed
Found errors: Culprit rank 0; Error type: COLLECTIVE_DTYPE_MISMATCH, Expected dtypes: '['Long']/['Long']' does not match found dtype: '['UNKNOWN_SCALAR']/['Long', 'Long']'.
Collective stack traces:
Collective 142 at entry 141 errors for group 0:default_pg collective nccl:all_gather_into_tensor_coalesced [[1, 8, 4096]] [[2, 8, 4096]] 2 completed
Found errors: Culprit rank 0; Error type: COLLECTIVE_STATE_MISMATCH, Expected state: 'completed' does not match found state: 'scheduled'.
Collective stack traces:
Collective 143 at entry 142 errors for group 0:default_pg collective nccl:reduce_scatter_tensor_coalesced [[2, 8, 4096]] [[1, 8, 4096]] 2 scheduled
Found errors: Culprit rank 0; Error type: SIZE_OR_SYNTAX_MISMATCH, Expected input sizes: '[[2, 8, 4096]]' does not match found input sizes: '[[2, 9, 4096]]'.
Collective stack traces:
Collective 144 at entry 143 errors for group 0:default_pg collective nccl:all_gather_into_tensor_coalesced [[1, 8, 4096]] [[2, 8, 4096]] 2 scheduled
Found errors: Culprit rank 0; Error type: SIZE_OR_SYNTAX_MISMATCH, Expected input sizes: '[[1, 8, 4096]]' does not match found input sizes: '[[1, 9, 4096]]'.
Collective stack traces:
Collective 145 at entry 144 errors for group 0:default_pg collective nccl:reduce_scatter_tensor_coalesced [[2, 8, 4096]] [[1, 8, 4096]] 2 scheduled
Found errors: Culprit rank 0; Error type: SIZE_OR_SYNTAX_MISMATCH, Expected input sizes: '[[2, 8, 4096]]' does not match found input sizes: '[[2, 9, 4096]]'.
Collective stack traces:
Collective 146 at entry 145 errors for group 0:default_pg collective nccl:all_gather_into_tensor_coalesced [[1, 8, 4096]] [[2, 8, 4096]] 2 scheduled
Found errors: Culprit rank 0; Error type: SIZE_OR_SYNTAX_MISMATCH, Expected input sizes: '[[1, 8, 4096]]' does not match found input sizes: '[[1, 9, 4096]]'.
Collective stack traces:
Collective 148 at entry 147 errors for group 0:default_pg collective nccl:all_gather_into_tensor_coalesced [[1, 8, 4096]] [[2, 8, 4096]] 2 scheduled
Found errors: Culprit rank 0; Error type: SIZE_OR_SYNTAX_MISMATCH, Expected input sizes: '[[1, 8, 4096]]' does not match found input sizes: '[[1, 9, 4096]]'.
Collective stack traces:
Collective 149 at entry 148 errors for group 0:default_pg collective nccl:reduce_scatter_tensor_coalesced [[2, 8, 4096]] [[1, 8, 4096]] 2 scheduled
Found errors: Culprit rank 0; Error type: SIZE_OR_SYNTAX_MISMATCH, Expected input sizes: '[[2, 8, 4096]]' does not match found input sizes: '[[2, 9, 4096]]'.
Collective stack traces:
Too many mismatches for process_group 09000494312501845833:default_pg, aborting
```

3) Multi-device run (matching response) ✅

Temporarily added the tp_plan from torchchat, flag here: torchtitan/test/generate/test_generate.py, line 107 (febda82)

```bash
NGPU=2 CONFIG_FILE=./train_configs/llama3_8b.toml \
CHECKPOINT_DIR=./outputs/L318B_DCP/ \
PROMPT="What is the meaning of life? That is a great question. It is" \
./test/generate/run_llama_pred.sh --max_new_tokens=32 --temperature=0.8 --seed=3
```

Output
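For illustration only, a hedged sketch of the kind of flag-gated plan selection discussed above. None of these names come from the PR (`apply_tp_plan` and `is_train` are hypothetical), and this is not the actual torchchat plan; it only shows the idea of switching to a plan with no sequence-dim sharding for generation.

```python
# Hedged sketch: choose a TP plan based on a hypothetical is_train flag.
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    SequenceParallel,
    parallelize_module,
)


def apply_tp_plan(model, tp_mesh, is_train: bool):
    if is_train:
        # Training-style plan with sequence parallelism (loss parallel omitted).
        plan = {
            "tok_embeddings": RowwiseParallel(
                input_layouts=Replicate(), output_layouts=Shard(1)
            ),
            "norm": SequenceParallel(),
            "output": ColwiseParallel(
                input_layouts=Shard(1), output_layouts=Replicate()
            ),
        }
    else:
        # Generation: keep activations replicated along the sequence dim so a
        # growing, possibly odd-length prompt never has to shard evenly.
        plan = {
            "tok_embeddings": RowwiseParallel(
                input_layouts=Replicate(), output_layouts=Replicate()
            ),
            "output": ColwiseParallel(
                input_layouts=Replicate(), output_layouts=Replicate()
            ),
        }
    parallelize_module(model, tp_mesh, plan)
```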
Other details
def generate_hack()
```python
@torch.no_grad()
def generate_hack(
    model,
    input_ids: torch.Tensor,
    *,
    max_new_tokens: int,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
) -> torch.Tensor:
    """
    FOR DEMO.
    Constant input shape using prefill padding and a sliding window.
    This seems to work with NGPU>1 without code changes to the TP strategy.
    """
    # Ensure batch dimension: (T,) --> (B, T)
    if input_ids.ndim == 1:
        input_ids = input_ids.unsqueeze(0)

    # Pre-fill input with 0 to maintain a fixed seqlen
    B, T = input_ids.shape
    seqlen = T + max_new_tokens
    padded_input = torch.full(
        (B, seqlen), 0, dtype=input_ids.dtype, device=input_ids.device
    )
    padded_input[:, -T:] = input_ids
    generated_tokens = padded_input.clone()

    for _ in range(max_new_tokens):
        next_token = generate_next_token(
            model,
            x=generated_tokens,
            temperature=temperature,
            top_k=top_k,
        )
        # Shift left and append next_token
        generated_tokens = torch.cat([generated_tokens[:, 1:], next_token], dim=1)

    return generated_tokens
```
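A hypothetical usage sketch for the demo helper above; `model`, `tokenizer`, and `prompt` are assumed to come from the surrounding test script, and none of this is code from the PR.

```python
import torch

# Hypothetical call site for generate_hack; model/tokenizer/prompt are assumed
# to be provided by the surrounding script.
max_new_tokens = 32
prompt_ids = torch.tensor(
    tokenizer.encode(prompt, bos=True, eos=False), dtype=torch.long
)
out = generate_hack(
    model,
    prompt_ids,
    max_new_tokens=max_new_tokens,
    temperature=0.8,
    top_k=None,
)
# After max_new_tokens shifts, the fixed-size window ends with the new tokens.
new_tokens = out[:, -max_new_tokens:]
print(tokenizer.decode(new_tokens[0].tolist()))
```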
@jaysonfrancis Nice work on generation! It would be valuable to have evaluation in the training loop, but I see there are challenges with enabling it due to TP and PP. |
Added a simple generation tool to help validate the integration of artifacts (converted checkpoints, tokenizers, etc.).
Update 10/29: [WIP] working on supporting single-host TP generation