Merge pull request #87 from TACJu/main
Release the training code, inference code and model weights of MaskGen
cornettoyu authored Feb 24, 2025
2 parents 0647fa1 + 60d4745 commit e85c999
Showing 28 changed files with 3,858 additions and 42 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -9,6 +9,7 @@ This repo hosts the code and models for the following projects:
- TiTok: [An Image is Worth 32 Tokens for Reconstruction and Generation](https://yucornetto.github.io/projects/titok.html)

## Updates
- 02/24/2025: We release the training code, inference code and model weights of MaskGen.
- 01/17/2025: We release the training code, inference code and model weights of TA-TiTok.
- 01/14/2025: The [tech report](https://arxiv.org/abs/2501.07730) of TA-TiTok and MaskGen is available. TA-TiTok is an innovative text-aware transformer-based 1-dimensional tokenizer designed to handle both discrete and continuous tokens. MaskGen is a powerful and efficient text-to-image masked generative model trained exclusively on open data. For more details, refer to the [README_MaskGen](README_MaskGen.md).
- 11/04/2024: We release the [tech report](https://arxiv.org/abs/2411.00776) and code for RAR models.
96 changes: 88 additions & 8 deletions README_MaskGen.md
@@ -30,10 +30,6 @@ Building on TA-TiTok, we present MaskGen, a versatile text-to-image masked gener

#### We propose MaskGen, a family of text-to-image masked generative models built upon TA-TiTok. The MaskGen VQ and MaskGen KL variants utilize compact sequences of 128 discrete tokens and 32 continuous tokens, respectively. Trained exclusively on open data, MaskGen achieves performance comparable to models trained on proprietary datasets, while offering significantly lower training cost and substantially faster inference speed.

## TODO

- [ ] Release training code, inference code and checkpoints of MaskGen


## TA-TiTok Model Zoo
| arch | #tokens | Link | rFID | IS |
@@ -50,8 +46,8 @@ Please note that these models are only for research purposes.
## MaskGen Model Zoo
| Model | arch | Link | MJHQ-30K FID | GenEval Overall |
| ------------- | ------------- | ------------- | ------------- | ------------- |
| MaskGen-L | KL | TODO | 7.24 | 0.52 |
| MaskGen-XL | KL | TODO | 6.53 | 0.55 |
| MaskGen-L | KL | [checkpoint](https://huggingface.co/turkeyju/generator_maskgen_kl_l) | 7.24 | 0.52 |
| MaskGen-XL | KL | [checkpoint](https://huggingface.co/turkeyju/generator_maskgen_kl_xl) | 6.53 | 0.55 |

Please note that these models are only for research purposes.

@@ -60,7 +56,7 @@ Please note that these models are only for research purposes.
pip3 install -r requirements.txt
```

## Get Started
## Get Started - TA-TiTok
```python
import torch
from PIL import Image
@@ -76,7 +72,7 @@ tatitok_tokenizer.eval()
tatitok_tokenizer.requires_grad_(False)

# or alternatively, download from hf
# hf_hub_download(repo_id="fun-research/TA-TiTok", filename="tokenizer_tatitok_bl32_vae.bin", local_dir="./")
# hf_hub_download(repo_id="fun-research/TA-TiTok", filename="tatitok_bl32_vae.bin", local_dir="./")

# load config
# config = demo_util.get_config("configs/infer/TA-TiTok/tatitok_bl32_vae.yaml")
@@ -124,6 +120,78 @@ reconstructed_image = (reconstructed_image * 255.0).permute(0, 2, 3, 1).to("cpu"
# save() returns None, so the result is not assigned back to reconstructed_image
Image.fromarray(reconstructed_image).save("assets/ILSVRC2012_val_00010240_recon.png")
```

## Get Started - MaskGen
```python
import torch
from PIL import Image
import numpy as np
import open_clip
import demo_util
from huggingface_hub import hf_hub_download
from modeling.tatitok import TATiTok
from modeling.maskgen import MaskGen_KL

torch.manual_seed(42)
torch.cuda.manual_seed(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Tokenizer: load tokenizer tatitok_bl32_vae
tatitok_tokenizer = TATiTok.from_pretrained("turkeyju/tokenizer_tatitok_bl32_vae")
tatitok_tokenizer.eval()
tatitok_tokenizer.requires_grad_(False)

# or alternatively, download from hf
# hf_hub_download(repo_id="fun-research/TA-TiTok", filename="tatitok_bl32_vae.bin", local_dir="./")

# load config
# config = demo_util.get_config("configs/infer/TA-TiTok/tatitok_bl32_vae.yaml")
# tatitok_tokenizer = demo_util.get_tatitok_tokenizer(config)

# Generator: choose one from ["maskgen_kl_l", "maskgen_kl_xl"]
maskgen_kl_generator = MaskGen_KL.from_pretrained("turkeyju/generator_maskgen_kl_xl")
maskgen_kl_generator.eval()
maskgen_kl_generator.requires_grad_(False)

# or alternatively, download from hf
# hf_hub_download(repo_id="fun-research/TA-TiTok", filename="maskgen_kl_xl.bin", local_dir="./")

# load config
# config = demo_util.get_config("configs/infer/MaskGen/maskgen_kl_xl.yaml")
# maskgen_kl_generator = demo_util.get_maskgen_kl_generator(config)

clip_encoder, _, _ = open_clip.create_model_and_transforms('ViT-L-14-336', pretrained='openai')
del clip_encoder.visual
clip_tokenizer = open_clip.get_tokenizer('ViT-L-14-336')
clip_encoder.transformer.batch_first = False
clip_encoder.eval()
clip_encoder.requires_grad_(False)

device = "cuda"
tatitok_tokenizer = tatitok_tokenizer.to(device)
maskgen_kl_generator = maskgen_kl_generator.to(device)
clip_encoder = clip_encoder.to(device)

# generate an image
text = ["A cozy cabin in the middle of a snowy forest, surrounded by tall trees with lights glowing through the windows, a northern lights display visible in the sky."]
text_guidance = clip_tokenizer(text).to(device)
cast_dtype = clip_encoder.transformer.get_cast_dtype()
text_guidance = clip_encoder.token_embedding(text_guidance).to(cast_dtype) # [batch_size, n_ctx, d_model]
text_guidance = text_guidance + clip_encoder.positional_embedding.to(cast_dtype)
text_guidance = text_guidance.permute(1, 0, 2) # NLD -> LND
text_guidance = clip_encoder.transformer(text_guidance, attn_mask=clip_encoder.attn_mask)
text_guidance = text_guidance.permute(1, 0, 2) # LND -> NLD
text_guidance = clip_encoder.ln_final(text_guidance) # [batch_size, n_ctx, transformer.width]

generated_tokens = maskgen_kl_generator.sample_tokens(1, clip_tokenizer, clip_encoder, num_iter=32, cfg=3.0, aes_scores=6.5, captions=text)

# de-tokenization
reconstructed_image = tatitok_tokenizer.decode_tokens(generated_tokens, text_guidance)
reconstructed_image = torch.clamp(reconstructed_image, 0.0, 1.0)
reconstructed_image = (reconstructed_image * 255.0).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()[0]
# save() returns None, so the result is not assigned back to reconstructed_image
Image.fromarray(reconstructed_image).save("assets/maskgen_kl_generator_generated.png")
```
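
Each generator variant named in this README pairs a Hugging Face repo, an inference config, and a checkpoint filename. Below is a minimal sketch of a lookup that keeps the three in sync. The helper name `resolve_maskgen_kl` and the dictionary are illustrative and not part of the repo; the strings are copied from this README and from the config files in this commit.

```python
from typing import Dict

# Values copied from the README and configs/infer/MaskGen/*.yaml in this commit.
MASKGEN_KL_VARIANTS: Dict[str, Dict[str, str]] = {
    "maskgen_kl_l": {
        "hub_repo": "turkeyju/generator_maskgen_kl_l",
        "config": "configs/infer/MaskGen/maskgen_kl_l.yaml",
        "checkpoint": "maskgen_kl_l.bin",
    },
    "maskgen_kl_xl": {
        "hub_repo": "turkeyju/generator_maskgen_kl_xl",
        "config": "configs/infer/MaskGen/maskgen_kl_xl.yaml",
        "checkpoint": "maskgen_kl_xl.bin",
    },
}


def resolve_maskgen_kl(name: str) -> Dict[str, str]:
    """Hypothetical helper: return repo id, config path and checkpoint file for a variant."""
    try:
        return MASKGEN_KL_VARIANTS[name]
    except KeyError:
        choices = ", ".join(sorted(MASKGEN_KL_VARIANTS))
        raise ValueError(f"unknown generator {name!r}; choose one of: {choices}") from None


variant = resolve_maskgen_kl("maskgen_kl_xl")
print(variant["hub_repo"], variant["config"], variant["checkpoint"])
```

The resolved `hub_repo` string is the same value that the snippet above passes to `MaskGen_KL.from_pretrained`.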

## Training Preparation
We use the [webdataset](https://github.com/webdataset/webdataset) format for data loading. Before training, convert your dataset into webdataset format.

@@ -141,6 +209,18 @@ WANDB_MODE=offline accelerate launch --num_machines=4 --num_processes=8 --machin
experiment.project="tatitok_bl32_vae" \
experiment.name="tatitok_bl32_vae_run1" \
experiment.output_dir="tatitok_bl32_vae_run1" \

# Training for MaskGen-{VQ/KL}-{L/XL} Stage1
WANDB_MODE=offline accelerate launch --num_machines=4 --num_processes=8 --machine_rank=0 --main_process_ip=127.0.0.1 --main_process_port=9999 --same_network scripts/train_maskgen.py config=configs/training/MaskGen/maskgen_{vq/kl}_{l/xl}_stage1.yaml \
experiment.project="maskgen_{vq/kl}_{l/xl}_stage1" \
experiment.name="maskgen_{vq/kl}_{l/xl}_stage1_run1" \
experiment.output_dir="maskgen_{vq/kl}_{l/xl}_stage1_run1" \

# Training for MaskGen-{VQ/KL}-{L/XL} Stage2
WANDB_MODE=offline accelerate launch --num_machines=4 --num_processes=8 --machine_rank=0 --main_process_ip=127.0.0.1 --main_process_port=9999 --same_network scripts/train_maskgen.py config=configs/training/MaskGen/maskgen_{vq/kl}_{l/xl}_stage2.yaml \
experiment.project="maskgen_{vq/kl}_{l/xl}_stage2" \
experiment.name="maskgen_{vq/kl}_{l/xl}_stage2_run1" \
experiment.output_dir="maskgen_{vq/kl}_{l/xl}_stage2_run1" \
```
To enable online wandb logging, remove the `WANDB_MODE=offline` prefix from the command, provided you have configured wandb.
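
The `{vq/kl}` and `{l/xl}` placeholders in the commands above each stand for a single choice. Below is a small sketch that expands them into concrete command lines. The function name `maskgen_train_command` is hypothetical and the launch flags are copied from the commands above. It assumes a config file exists for every combination, which this diff view does not confirm.

```python
from itertools import product
from typing import List

# Launch flags copied verbatim from the README commands.
LAUNCH_PREFIX = [
    "accelerate", "launch",
    "--num_machines=4", "--num_processes=8", "--machine_rank=0",
    "--main_process_ip=127.0.0.1", "--main_process_port=9999", "--same_network",
    "scripts/train_maskgen.py",
]


def maskgen_train_command(arch: str, size: str, stage: int, run: str = "run1") -> List[str]:
    """Hypothetical helper: build the argument list for one MaskGen training run."""
    if arch not in ("vq", "kl"):
        raise ValueError(f"arch must be 'vq' or 'kl', got {arch!r}")
    if size not in ("l", "xl"):
        raise ValueError(f"size must be 'l' or 'xl', got {size!r}")
    if stage not in (1, 2):
        raise ValueError(f"stage must be 1 or 2, got {stage!r}")
    name = f"maskgen_{arch}_{size}_stage{stage}"
    return LAUNCH_PREFIX + [
        f"config=configs/training/MaskGen/{name}.yaml",
        f"experiment.project={name}",
        f"experiment.name={name}_{run}",
        f"experiment.output_dir={name}_{run}",
    ]


if __name__ == "__main__":
    # Print all eight combinations; nothing is launched.
    for arch, size, stage in product(("vq", "kl"), ("l", "xl"), (1, 2)):
        print("WANDB_MODE=offline " + " ".join(maskgen_train_command(arch, size, stage)))
```

Building the command as a list rather than a single string avoids the trailing `\` continuation characters and makes it easy to pass to `subprocess.run` if desired.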

36 changes: 36 additions & 0 deletions configs/infer/MaskGen/maskgen_kl_l.yaml
@@ -0,0 +1,36 @@
experiment:
    tokenizer_checkpoint: "tatitok_bl32_vae.bin"
    generator_checkpoint: "maskgen_kl_l.bin"

model:
    vq_model:
        quantize_mode: vae
        token_size: 16
        vit_enc_model_size: base
        vit_dec_model_size: large
        vit_enc_patch_size: 16
        vit_dec_patch_size: 16
        num_latent_tokens: 32
        scale_factor: 0.7525
        finetune_decoder: False
        is_legacy: False
    maskgen:
        decoder_embed_dim: 1024
        decoder_depth: 16
        decoder_num_heads: 16
        micro_condition: true
        micro_condition_embed_dim: 256
        text_drop_prob: 0.1
        cfg: 3.0
        cfg_schedule: "linear"
        num_iter: 32
        temperature: 1.0
        sample_aesthetic_score: 6.5

losses:
    diffloss_d: 8
    diffloss_w: 1024

dataset:
    preprocessing:
        crop_size: 256
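
The `cfg: 3.0` and `cfg_schedule: "linear"` settings suggest that the guidance scale ramps up during sampling. The sampler itself is not shown in this diff, so the following is only a sketch of one common reading, borrowed from MAR-style samplers: the scale grows linearly from 1.0 to `cfg` with the fraction of tokens already generated. The function name and the formula are assumptions, not the repo's confirmed implementation.

```python
def linear_cfg_scale(cfg: float, num_unmasked: int, seq_len: int) -> float:
    """Assumed linear CFG schedule: 1.0 when nothing is generated, `cfg` when all tokens are."""
    if seq_len <= 0:
        raise ValueError("seq_len must be positive")
    if not 0 <= num_unmasked <= seq_len:
        raise ValueError("num_unmasked must lie in [0, seq_len]")
    return 1.0 + (cfg - 1.0) * num_unmasked / seq_len


if __name__ == "__main__":
    # 32 latent tokens and cfg=3.0, as in the KL configs.
    for done in (0, 8, 16, 24, 32):
        print(done, linear_cfg_scale(3.0, done, 32))
```

Under this reading, early iterations are guided weakly and the final ones use the full configured scale.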
36 changes: 36 additions & 0 deletions configs/infer/MaskGen/maskgen_kl_xl.yaml
@@ -0,0 +1,36 @@
experiment:
    tokenizer_checkpoint: "tatitok_bl32_vae.bin"
    generator_checkpoint: "maskgen_kl_xl.bin"

model:
    vq_model:
        quantize_mode: vae
        token_size: 16
        vit_enc_model_size: base
        vit_dec_model_size: large
        vit_enc_patch_size: 16
        vit_dec_patch_size: 16
        num_latent_tokens: 32
        scale_factor: 0.7525
        finetune_decoder: False
        is_legacy: False
    maskgen:
        decoder_embed_dim: 1280
        decoder_depth: 20
        decoder_num_heads: 16
        micro_condition: true
        micro_condition_embed_dim: 256
        text_drop_prob: 0.1
        cfg: 3.0
        cfg_schedule: "linear"
        num_iter: 32
        temperature: 1.0
        sample_aesthetic_score: 6.5

losses:
    diffloss_d: 8
    diffloss_w: 1280

dataset:
    preprocessing:
        crop_size: 256
35 changes: 35 additions & 0 deletions configs/infer/MaskGen/maskgen_vq_l.yaml
@@ -0,0 +1,35 @@
experiment:
    tokenizer_checkpoint: "tatitok_bl128_vq.bin"
    generator_checkpoint: "maskgen_vq_l.bin"

model:
    vq_model:
        quantize_mode: vq
        codebook_size: 8192
        token_size: 64
        use_l2_norm: false
        commitment_cost: 0.25
        clustering_vq: true
        vit_enc_model_size: base
        vit_dec_model_size: large
        vit_enc_patch_size: 16
        vit_dec_patch_size: 16
        num_latent_tokens: 128
        finetune_decoder: False
        is_legacy: False
    maskgen:
        decoder_embed_dim: 1024
        decoder_depth: 16
        decoder_num_heads: 16
        micro_condition: true
        micro_condition_embed_dim: 256
        text_drop_prob: 0.1
        condition_num_classes: 1000
        cfg: 12.0
        num_iter: 16
        temperature: 2.0
        sample_aesthetic_score: 6.5

dataset:
    preprocessing:
        crop_size: 256
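
With `num_latent_tokens: 128` and `num_iter: 16`, a masked generative sampler has to decide how many tokens stay masked after each iteration. Below is a sketch of the cosine schedule popularized by MaskGIT. It is offered as an assumption, since the MaskGen sampler is not part of this excerpt, and `cosine_mask_schedule` is a hypothetical name.

```python
import math
from typing import List


def cosine_mask_schedule(seq_len: int, num_iter: int) -> List[int]:
    """Assumed cosine schedule: number of tokens still masked after each iteration."""
    if seq_len <= 0 or num_iter <= 0:
        raise ValueError("seq_len and num_iter must be positive")
    remaining = []
    prev = seq_len  # every token starts masked
    for step in range(num_iter):
        if step == num_iter - 1:
            mask_len = 0  # the final iteration reveals everything that is left
        else:
            ratio = math.cos(math.pi / 2.0 * (step + 1) / num_iter)
            mask_len = math.floor(seq_len * ratio)
            # Keep at least one token masked for later steps, and never
            # exceed one fewer than the previous count when that is possible.
            mask_len = max(1, min(prev - 1, mask_len))
        remaining.append(mask_len)
        prev = mask_len
    return remaining


if __name__ == "__main__":
    # Settings taken from the VQ inference config.
    print(cosine_mask_schedule(128, 16))
```

A cosine schedule reveals few tokens in the first iterations and many in the last ones, which is why a small `num_iter` can still cover a 128-token sequence.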
35 changes: 35 additions & 0 deletions configs/infer/MaskGen/maskgen_vq_xl.yaml
@@ -0,0 +1,35 @@
experiment:
    tokenizer_checkpoint: "tatitok_bl128_vq.bin"
    generator_checkpoint: "maskgen_vq_xl.bin"

model:
    vq_model:
        quantize_mode: vq
        codebook_size: 8192
        token_size: 64
        use_l2_norm: false
        commitment_cost: 0.25
        clustering_vq: true
        vit_enc_model_size: base
        vit_dec_model_size: large
        vit_enc_patch_size: 16
        vit_dec_patch_size: 16
        num_latent_tokens: 128
        finetune_decoder: False
        is_legacy: False
    maskgen:
        decoder_embed_dim: 1280
        decoder_depth: 20
        decoder_num_heads: 16
        micro_condition: true
        micro_condition_embed_dim: 256
        text_drop_prob: 0.1
        condition_num_classes: 1000
        cfg: 12.0
        num_iter: 16
        temperature: 2.0
        sample_aesthetic_score: 6.5

dataset:
    preprocessing:
        crop_size: 256
82 changes: 82 additions & 0 deletions configs/training/MaskGen/maskgen_kl_l_stage1.yaml
@@ -0,0 +1,82 @@
experiment:
    project: "maskgen_kl_l_stage1"
    name: "maskgen_kl_l_stage1_run1"
    output_dir: "maskgen_kl_l_stage1_run1"
    max_train_examples: 250_000_000
    save_every: 50_000
    eval_every: 50_000
    generate_every: 5_000
    log_every: 50
    log_grad_norm_every: 1_000
    resume: True

model:
    vq_model:
        quantize_mode: vae
        token_size: 16
        vit_enc_model_size: base
        vit_dec_model_size: large
        vit_enc_patch_size: 16
        vit_dec_patch_size: 16
        num_latent_tokens: 32
        scale_factor: 0.7525
        finetune_decoder: False
        is_legacy: False
    maskgen:
        decoder_embed_dim: 1024
        decoder_depth: 16
        decoder_num_heads: 16
        micro_condition: True
        micro_condition_embed_dim: 256
        text_drop_prob: 0.1
        cfg: 3.0
        cfg_schedule: "linear"
        num_iter: 32
        temperature: 1.0
        sample_aesthetic_score: 6.0

losses:
    diffloss_d: 8
    diffloss_w: 1024

dataset:
    params:
        train_shards_path_or_url: "datacomp5+::cc12m::laion-en-aesthetic"
        eval_shards_path_or_url: "coco"
        pretokenization: "true"
        num_workers_per_gpu: 12
        dataset_with_class_label: False
        dataset_with_text_label: True
    preprocessing:
        resize_shorter_edge: 256
        crop_size: 256
        random_crop: True
        random_flip: True
        res_ratio_filtering: True

optimizer:
    name: adamw
    params:
        learning_rate: 1e-4
        beta1: 0.9
        beta2: 0.95
        weight_decay: 0.02

lr_scheduler:
    scheduler: "constant"
    params:
        learning_rate: ${optimizer.params.learning_rate}
        warmup_steps: 50_000
        end_lr: 1e-5

training:
    gradient_accumulation_steps: 1
    per_gpu_batch_size: 32
    mixed_precision: "fp16"
    enable_tf32: True
    enable_wandb: True
    use_ema: True
    seed: 42
    max_train_steps: 1_000_000
    num_generated_images: 2
    max_grad_norm: 1.0
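
The `lr_scheduler` block asks for a `constant` schedule with `warmup_steps: 50_000`. Below is a minimal sketch of what such a schedule usually computes (linear warmup from zero, then a flat rate), assuming the repo follows the common convention. `end_lr` is ignored here on the assumption that a constant schedule does not decay, and the function name is hypothetical.

```python
def constant_with_warmup_lr(step: int, base_lr: float = 1e-4, warmup_steps: int = 50_000) -> float:
    """Assumed schedule: linear warmup to base_lr over warmup_steps, then constant."""
    if step < 0:
        raise ValueError("step must be non-negative")
    if step < warmup_steps:
        # Linear ramp from 0 at step 0 up to base_lr at step == warmup_steps.
        return base_lr * step / max(1, warmup_steps)
    return base_lr


if __name__ == "__main__":
    # Defaults mirror optimizer.params.learning_rate and lr_scheduler.params.warmup_steps.
    for step in (0, 10_000, 50_000, 1_000_000):
        print(step, constant_with_warmup_lr(step))
```

With `max_train_steps: 1_000_000`, the warmup under this reading covers the first 5% of training.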