
single-gpu generation for integration testing #640

Draft

jaysonfrancis wants to merge 7 commits into main

Conversation

jaysonfrancis

@jaysonfrancis jaysonfrancis commented Oct 22, 2024

Added a simple generation tool to help validate integration of artifacts (converted checkpoints, tokenizers, etc.).

Update 10/29: [WIP] working on supporting single-host TP generation

> python test/generate/test_generate.py \
  --config=./train_configs/debug_model.toml \
  --checkpoint=./outputs/checkpoint/ \
  --prompt="The meaning of life is" \
  --seed=14 \
  --batch_size=2 > out.json
> cat out.json
[
    {
        "response_idx": 0,
        "input_n_tokens": 6,
        "output_n_tokens": 14,
        "input_text": "<|begin_of_text|>The meaning of life is",
        "output_text": " as old as humankind. There have been more books about meaning than"
    },
    {
        "response_idx": 1,
        "input_n_tokens": 6,
        "output_n_tokens": 14,
        "input_text": "<|begin_of_text|>The meaning of life is",
        "output_text": " something that has puzzled philosophers over the millennia, but for lack of convincing"
    }
]

@facebook-github-bot facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) on Oct 22, 2024
@jaysonfrancis jaysonfrancis marked this pull request as ready for review October 22, 2024 23:04
@fduwjj
Contributor

fduwjj commented Oct 25, 2024

QQ: what is the threshold for a good eval? Shall we just use common sense to tell whether the output makes sense or not? @jaysonfrancis

@jaysonfrancis
Author

@fduwjj Initially yes. Most require enabling distributed inference to be useful, but I'm sharing my thoughts below:

  1. SW/HW Integrity; Receiving a response without errors (binary)

  2. Model Integrity; Receiving a response that makes sense (qualitative). Receiving a deterministic and expected response (e.g. matching cached responses after sw/hw/driver or unrelated model updates; see the sketch after this list)

  3. Tokenization consistency; special tokens being parsed, in/out token counts match expectations.

  4. Checkpoints; converted, pruned and/or manipulated weights load & fwd pass w/o errors (also helpful for continued pre-training)

  5. Stress test/sizing of max seq lengths; mem utilization vs sequence length given network arch & compute budget

  6. Distributed Inference; devices/comms working for native PP/TP and showcasing perf benchmarks

  7. Evals; batch inference to support train-time eval loops for collecting early proxy metrics other than NLL
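
For item 2, a minimal sketch of what a cached-response regression check could look like, assuming the out.json schema shown in the PR description (the helper name and golden file path here are made up):

import json
from pathlib import Path

def check_against_golden(output_path: str, golden_path: str) -> None:
    """Compare a fresh out.json against a previously cached ("golden") run."""
    fresh = json.loads(Path(output_path).read_text())
    golden = json.loads(Path(golden_path).read_text())
    assert len(fresh) == len(golden), "response count changed"
    for new, old in zip(fresh, golden):
        # Tokenization consistency (item 3): token counts should match exactly
        assert new["input_n_tokens"] == old["input_n_tokens"]
        assert new["output_n_tokens"] == old["output_n_tokens"]
        # Determinism (item 2): with a fixed seed, the decoded text should be identical
        assert new["output_text"] == old["output_text"], (
            f"response {new['response_idx']} drifted from the cached output"
        )

# e.g. check_against_golden("out.json", "golden/out_seed14.json")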

@fduwjj
Contributor

fduwjj commented Oct 25, 2024

@jaysonfrancis thanks for your answer. Do we expect the pretrained model to always offer meaningful answers? If so, why do people still do lots of fine-tuning and RL?

@jaysonfrancis
Author

Thanks @fduwjj -- Hmm, it depends on the downstream task; there may be use cases that want to infer natively from the underlying data distribution alone, without any external biases from instruct/reward tuning.

Anyways, my thoughts for this were more focused on test-driven generation and catching any issues that may go unnoticed during dev.

@fduwjj
Contributor

fduwjj commented Oct 25, 2024

@jaysonfrancis Maybe you can kindly add documentation on the rationale and a user manual on how to run this generation, so that folks know how to use it, what it is used for, and how to tell whether the check is successful or not. WDYT?

@jaysonfrancis jaysonfrancis marked this pull request as draft October 29, 2024 01:34
Contributor

@tianyu-l tianyu-l left a comment


Thank you for putting up the scripts! Looks awesome in general. I had some questions.

test/generate/test_generate.py (outdated, resolved)
test/generate/generation.py (outdated, resolved)
test/generate/test_generate.py (outdated, resolved)
@jaysonfrancis
Author

> Thank you for putting up the scripts! Looks awesome in general. I had some questions.

@tianyu-l @fduwjj Thank you for the review. Looking to support TP; will see how far I can get 😶‍🌫️ – I'll follow up with docs/usage shortly after.

@tianyu-l tianyu-l linked an issue Oct 31, 2024 that may be closed by this pull request
@jaysonfrancis
Author

Question regarding multi-device generation (cc @tianyu-l @fduwjj)

I've got a solution working with TP, but I had to avoid sharding inputs along sequence dims (including norms) due to prompt length increasing at each step.

In order to avoid making code changes, I can just apply a custom plan within this script. I can also gate it with a boolean if we'd like.

Curious for your thoughts or if I missed something. Thanks!

@tianyu-l
Contributor

> In order to avoid making code changes, I can just apply a custom plan within this script. I can also gate it with a boolean if we'd like.

For that we can have a boolean config training.enable_sequence_parallel. But I didn't understand why

> I had to avoid sharding inputs along sequence dims (including norms) due to prompt length increasing at each step.

Could you elaborate a bit?

cc: @kwen2501 on distributed inference in torchtitan with a minimal script.

@jaysonfrancis
Author

jaysonfrancis commented Oct 31, 2024

> Could you elaborate a bit?

Sure, as it is now, the input during generation eventually becomes odd length, so the sequence dimension (1) cannot be sharded evenly across ranks.
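
As a rough illustration (a sketch, assuming TP degree 2 and Shard(1) on the sequence dim; per-rank shard sizes split roughly the way torch.chunk would split them):

tp_degree = 2
prompt_len = 16
for step in range(4):
    seqlen = prompt_len + step  # one new token is appended per decode step
    base, rem = divmod(seqlen, tp_degree)
    shard_sizes = [base + (1 if r < rem else 0) for r in range(tp_degree)]
    print(f"seqlen={seqlen} -> per-rank shards along dim 1: {shard_sizes}")

# seqlen=16 -> [8, 8]
# seqlen=17 -> [9, 8]   <-- uneven
# seqlen=18 -> [9, 9]
# seqlen=19 -> [10, 9]  <-- uneven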

Below is an example workaround, where is_train is set to False when running this generate script.

parallelize_module(
    model,
    tp_mesh,
    {
        "tok_embeddings": RowwiseParallel(
            input_layouts=Replicate(),
            output_layouts=Shard(1) if is_train else Replicate(),  # <-- New
        ),
        "norm": SequenceParallel(),
        "output": ColwiseParallel(
            input_layouts=Shard(1),
            output_layouts=Shard(-1) if loss_parallel else Replicate(),
            use_local_output=not loss_parallel,
        ),
    },
)

@tianyu-l
Contributor

> Sure, as it is now, the input during generation eventually becomes odd length, so the sequence dimension (1) cannot be sharded evenly across ranks.

This I understood. But why can't an odd sequence dimension be sharded? IIUC, the TP API is able to handle uneven sharding (by auto-padding when it is uneven). I somehow feel that as long as loss parallel is disabled, everything should work. Can you show the error log when the sequence dim is sharded? Thanks!

tokenizer.encode(prompt, bos=True, eos=False), dtype=torch.long
)
.view(1, -1)
.repeat(batch_size, 1)
Contributor

Question: Why would we run batch_size copies of the same inference task? I thought batch_size is useful when we have multiple distinct prompts. In our case maybe it's not necessary?

Author

I was thinking batch support might be useful for profiling bandwidth/memory on the forward pass alone. It currently defaults to 1, but I can remove it.
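
A rough sketch of the kind of sizing sweep meant here (hypothetical: it assumes CUDA, and that the generate() helper and input_ids tensor from this script behave as shown):

import torch

for batch_size in (1, 2, 4, 8):
    torch.cuda.reset_peak_memory_stats()
    batched = input_ids.view(1, -1).repeat(batch_size, 1)  # replicate the prompt along the batch dim
    _ = generate(model, batched, max_new_tokens=32)
    peak_gib = torch.cuda.max_memory_allocated() / 2**30
    print(f"batch_size={batch_size}: peak allocated {peak_gib:.2f} GiB")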

@jaysonfrancis
Author

jaysonfrancis commented Oct 31, 2024

> This I understood. But why can't an odd sequence dimension be sharded? IIUC, the TP API is able to handle uneven sharding

Okay good to know, probably an implementation issue somewhere else in my script. I will troubleshoot.

To clarify, Shard(1) on the tok_embeddings output is what initially triggers the timeout, leaving everything else unchanged (tested on both Ampere & Hopper).

"tok_embeddings": RowwiseParallel(
    input_layouts=Replicate(),
    output_layouts=Replicate(),  # Shard(1)
),
Traceback

Traceback (most recent call last):
  File "/opt/jf/envs/p311/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/home/jf/oss/torchtitan/test/generate/test_generate.py", line 139, in example_generate
    responses, _ = generate(
                   ^^^^^^^^^
  File "/opt/jf/envs/p311/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/jf/oss/torchtitan/test/generate/generation.py", line 65, in generate
    next_token, logits = generate_next_token(
                         ^^^^^^^^^^^^^^^^^^^^
  File "/home/jf/oss/torchtitan/test/generate/generation.py", line 42, in generate_next_token
    logits = model(x)  # (B, T, vocab_size)
             ^^^^^^^^
  File "/opt/jf/envs/p311/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/jf/envs/p311/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jf/oss/torchtitan/torchtitan/models/llama/model.py", line 443, in forward
    h = layer(h, self.freqs_cis)
        ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/jf/envs/p311/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/jf/envs/p311/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jf/oss/torchtitan/torchtitan/models/llama/model.py", line 324, in forward
    h = x + self.attention(self.attention_norm(x), freqs_cis)
        ~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  File "/opt/jf/envs/p311/lib/python3.11/site-packages/torch/distributed/_functional_collectives.py", line 645, in __torch_dispatch__
    unwrapped_args = tree_map_only(AsyncCollectiveTensor, unwrap, args)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/jf/envs/p311/lib/python3.11/site-packages/torch/utils/_pytree.py", line 1136, in tree_map_only
    return tree_map(map_only(__type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/jf/envs/p311/lib/python3.11/site-packages/torch/utils/_pytree.py", line 964, in tree_map
    return treespec.unflatten(map(func, *flat_args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/jf/envs/p311/lib/python3.11/site-packages/torch/utils/_pytree.py", line 803, in unflatten
    leaves = list(leaves)
             ^^^^^^^^^^^^
  File "/opt/jf/envs/p311/lib/python3.11/site-packages/torch/utils/_pytree.py", line 1082, in wrapped
    return func(x)
           ^^^^^^^
  File "/opt/jf/envs/p311/lib/python3.11/site-packages/torch/distributed/_functional_collectives.py", line 636, in unwrap
    return e.trigger_wait()
           ^^^^^^^^^^^^^^^^
  File "/opt/jf/envs/p311/lib/python3.11/site-packages/torch/distributed/_functional_collectives.py", line 609, in trigger_wait
    out = wait_tensor(self.elem)
          ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/jf/envs/p311/lib/python3.11/site-packages/torch/distributed/_functional_collectives.py", line 141, in wait_tensor
    return torch.ops._c10d_functional.wait_tensor(tensor)  # type: ignore[attr-defined]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/jf/envs/p311/lib/python3.11/site-packages/torch/_ops.py", line 1123, in __call__
    return self._op(*args, **(kwargs or {}))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.distributed.DistBackendError: [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=12, OpType=COALESCED, NumelIn=49152, NumelOut=24576, Timeout(ms)=300000) ran for 300000 milliseconds before timing out.
[rank1]:[E1031 22:42:30.764730470 ProcessGroupNCCL.cpp:1484] [PG ID 0 PG GUID 0(default_pg) Rank 1] Received a dump signal due to a collective timeout from this local rank and we will try our best to dump the debug info. Last enqueued NCCL work: 13, last completed NCCL work: 11.This is most likely caused by incorrect usages of collectives, e.g., wrong sizes used across ranks, the order of collectives is not same for all ranks or the scheduled collective, for some reason, didn't run. Additionally, this can be caused by GIL deadlock or other reasons such as network errors or bugs in the communications library (e.g. NCCL), etc.

@tianyu-l
Contributor

tianyu-l commented Nov 1, 2024

> Shard(1) on the tok_embeddings output is what initially triggers the timeout.

Hmm sounds like there might be a bug in dealing with uneven sharding somewhere. I see that the timeout happens on

h = x + self.attention(self.attention_norm(x), freqs_cis)

It would be great if you could provide a minimal repro on this issue, so we can take it from there.
cc: @wz337

@jaysonfrancis
Author

jaysonfrancis commented Nov 1, 2024

1) Single-device run (non-distributed) ✅

NGPU=1 CONFIG_FILE=./train_configs/llama3_8b.toml \
  CHECKPOINT_DIR=./outputs/L318B_DCP/ \
  PROMPT="What is the meaning of life? That is a great question. It is" \
  ./test/generate/run_llama_pred.sh --max_new_tokens=32 --temperature=0.8 --seed=3
Output

{
    "metadata": {
        "generated_n_tokens": 32,
        "input_n_tokens": 16,
        "generation_time_sec": 1.3203406818211079,
        "tokens_per_sec": 36.35425361111723,
        "batch_size": 1,
        "seed": 3,
        "timestamp": "2024-11-02T02:48:12",
        "memory/max_active(GiB)": 29.974102020263672,
        "memory/max_active(%)": 37.83867588127289,
        "memory/max_reserved(GiB)": 33.85546875,
        "memory/max_reserved(%)": 42.73843159584148,
        "memory/num_alloc_retries": 0,
        "memory/num_ooms": 0,
        "world_size": 1,
        "torch_version": "2.6.0.dev20241101+cu121"
    },
    "responses": [
        {
            "response_idx": 0,
            "input_text": "<|begin_of_text|>What is the meaning of life? That is a great question. It is",
            "output_text": " also a topic that has been debated for centuries and has been a fundamental topic of philosophy. The answer can be found in many places and in many different ways."
        }
    ]
}

2) Multi-device run (distributed enabled) ❌

NGPU=2 CONFIG_FILE=./train_configs/llama3_8b.toml \
  CHECKPOINT_DIR=./outputs/L318B_DCP/ \
  PROMPT="What is the meaning of life? That is a great question. It is" \
  ./test/generate/run_llama_pred.sh --max_new_tokens=32 --temperature=0.8 --seed=3

Causes a timeout, with the same traceback as in my reply above. Below is a compressed trace from the flight recorder.

Collective stack traces

Collective 4 at entry 3 errors  for group 0:default_pg collective nccl:gather  [[48262]] [[0]] 2 completed  
Found errors: Culprit rank 0; Error type: COLLECTIVE_DTYPE_MISMATCH, Expected dtypes: '['Byte']/['Byte']' does not match found dtype: '['UNKNOWN_SCALAR']/['Byte', 'Byte']'.
 
Collective stack traces: 
Collective 6 at entry 5 errors  for group 0:default_pg collective nccl:scatter  [[48262]] [[0]] 2 completed  
Found errors: Culprit rank 0; Error type: COLLECTIVE_DTYPE_MISMATCH, Expected dtypes: '['Byte']/['Byte']' does not match found dtype: '['UNKNOWN_SCALAR']/['Byte', 'Byte']'.
 
Collective stack traces: 
Collective 7 at entry 6 errors  for group 0:default_pg collective nccl:scatter  [[1]] [[0]] 2 completed  
Found errors: Culprit rank 0; Error type: COLLECTIVE_DTYPE_MISMATCH, Expected dtypes: '['Long']/['Long']' does not match found dtype: '['UNKNOWN_SCALAR']/['Long', 'Long']'.
 
Collective stack traces: 
Collective 142 at entry 141 errors  for group 0:default_pg collective nccl:all_gather_into_tensor_coalesced  [[1, 8, 4096]] [[2, 8, 4096]] 2 completed  
Found errors: Culprit rank 0; Error type: COLLECTIVE_STATE_MISMATCH, Expected state: 'completed' does not match found state: 'scheduled'.
 
Collective stack traces: 
Collective 143 at entry 142 errors  for group 0:default_pg collective nccl:reduce_scatter_tensor_coalesced  [[2, 8, 4096]] [[1, 8, 4096]] 2 scheduled  
Found errors: Culprit rank 0; Error type: SIZE_OR_SYNTAX_MISMATCH, Expected input sizes: '[[2, 8, 4096]]' does not match found input sizes: '[[2, 9, 4096]]'.
 
Collective stack traces: 
Collective 144 at entry 143 errors  for group 0:default_pg collective nccl:all_gather_into_tensor_coalesced  [[1, 8, 4096]] [[2, 8, 4096]] 2 scheduled  
Found errors: Culprit rank 0; Error type: SIZE_OR_SYNTAX_MISMATCH, Expected input sizes: '[[1, 8, 4096]]' does not match found input sizes: '[[1, 9, 4096]]'.
 
Collective stack traces: 
Collective 145 at entry 144 errors  for group 0:default_pg collective nccl:reduce_scatter_tensor_coalesced  [[2, 8, 4096]] [[1, 8, 4096]] 2 scheduled  
Found errors: Culprit rank 0; Error type: SIZE_OR_SYNTAX_MISMATCH, Expected input sizes: '[[2, 8, 4096]]' does not match found input sizes: '[[2, 9, 4096]]'.
 
Collective stack traces: 
Collective 146 at entry 145 errors  for group 0:default_pg collective nccl:all_gather_into_tensor_coalesced  [[1, 8, 4096]] [[2, 8, 4096]] 2 scheduled  
Found errors: Culprit rank 0; Error type: SIZE_OR_SYNTAX_MISMATCH, Expected input sizes: '[[1, 8, 4096]]' does not match found input sizes: '[[1, 9, 4096]]'.
 
Collective stack traces: 
Collective 148 at entry 147 errors  for group 0:default_pg collective nccl:all_gather_into_tensor_coalesced  [[1, 8, 4096]] [[2, 8, 4096]] 2 scheduled  
Found errors: Culprit rank 0; Error type: SIZE_OR_SYNTAX_MISMATCH, Expected input sizes: '[[1, 8, 4096]]' does not match found input sizes: '[[1, 9, 4096]]'.
 
Collective stack traces: 
Collective 149 at entry 148 errors  for group 0:default_pg collective nccl:reduce_scatter_tensor_coalesced  [[2, 8, 4096]] [[1, 8, 4096]] 2 scheduled  
Found errors: Culprit rank 0; Error type: SIZE_OR_SYNTAX_MISMATCH, Expected input sizes: '[[2, 8, 4096]]' does not match found input sizes: '[[2, 9, 4096]]'.
 
Collective stack traces: 
Too many mismatches for process_group 09000494312501845833:default_pg, aborting

3) Multi-device run (matching response) ✅

Temporarily added the tp_plan from torchchat; flag here:

use_torchchat_tp = False

NGPU=2 CONFIG_FILE=./train_configs/llama3_8b.toml \
  CHECKPOINT_DIR=./outputs/L318B_DCP/ \
  PROMPT="What is the meaning of life? That is a great question. It is" \
  ./test/generate/run_llama_pred.sh --max_new_tokens=32 --temperature=0.8 --seed=3
Output

{
    "metadata": {
        "generated_n_tokens": 32,
        "input_n_tokens": 16,
        "generation_time_sec": 2.821475793607533,
        "tokens_per_sec": 17.012373492181304,
        "batch_size": 1,
        "seed": 3,
        "timestamp": "2024-11-02T03:07:33",
        "memory/max_active(GiB)": 15.05003833770752,
        "memory/max_active(%)": 18.998851818021414,
        "memory/max_reserved(GiB)": 15.263671875,
        "memory/max_reserved(%)": 19.268538301690388,
        "memory/num_alloc_retries": 0,
        "memory/num_ooms": 0,
        "world_size": 2,
        "torch_version": "2.6.0.dev20241101+cu121"
    },
    "responses": [
        {
            "response_idx": 0,
            "input_text": "<|begin_of_text|>What is the meaning of life? That is a great question. It is",
            "output_text": " also a topic that has been debated for centuries and has been a fundamental topic of philosophy. The answer can be found in many places and in many different ways."
        }
    ]
}


Other details

  • The checkpoint I am using was produced via scripts/convert_llama_to_dcp.py (Llama 3.1-8B/original --> DCP)
  • Also tested by building PyTorch from source, with and without P2P
  • If I force the input to be fixed length, it runs, but that is not a valid solution. Below is an example.
def generate_hack()

from typing import Optional

import torch


@torch.no_grad()
def generate_hack(
    model,
    input_ids: torch.Tensor,
    *,
    max_new_tokens: int,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
) -> torch.Tensor:
    """
    FOR DEMO.
    Constant input shape using prefill padding and a sliding window.
    This seems to work with NGPU>1 without code changes to the TP strategy.
    """

    # ensure batch dimension (T,) --> (B, T)
    if input_ids.ndim == 1:
        input_ids = input_ids.unsqueeze(0)

    # Pre-fill input with 0 to maintain fixed seqlen
    B, T = input_ids.shape
    seqlen = T + max_new_tokens
    padded_input = torch.full((B, seqlen), 0, dtype=input_ids.dtype, device=input_ids.device)
    padded_input[:, -T:] = input_ids

    generated_tokens = padded_input.clone()

    for _ in range(max_new_tokens):

        next_token = generate_next_token(
            model,
            x=generated_tokens,
            temperature=temperature,
            top_k=top_k,
        )

        # shift left and append next_token
        generated_tokens = torch.cat([generated_tokens[:, 1:], next_token], dim=1)

    return generated_tokens
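
A usage sketch of the demo above (prompt_ids and the argument values here are placeholders): since the buffer always holds the last T + max_new_tokens positions, the prompt ends up in the leading T columns and the generated continuation in the trailing max_new_tokens columns.

out = generate_hack(model, prompt_ids, max_new_tokens=32, temperature=0.8)
prompt_part = out[:, :-32]  # the original prompt; the zero padding has been shifted out
new_tokens = out[:, -32:]   # the 32 newly generated tokens, i.e. out[:, -max_new_tokens:]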

@casper-hansen

@jaysonfrancis Nice work on generation! It would be valuable to have evaluation in the training loop, but I see there are challenges with enabling it due to TP and PP.

Successfully merging this pull request may close these issues.

Inference with the checkpoint
5 participants