diff --git a/README.md b/README.md index 0bc7173a..9f8122e5 100644 --- a/README.md +++ b/README.md @@ -171,49 +171,49 @@ models we currently offer, along with their foundational information. + + - - + + - - - - + + - - + + - - - - + + + + + - + - @@ -221,38 +221,37 @@ models we currently offer, along with their foundational information. - + + - - - + + - - + + + - - + - - +
Model NameCogVideoX1.5-5B (Latest)CogVideoX1.5-5B-I2V (Latest) CogVideoX-2B CogVideoX-5B CogVideoX-5B-I2VCogVideoX1.5-5BCogVideoX1.5-5B-I2V
Release DateNovember 8, 2024November 8, 2024 August 6, 2024 August 27, 2024 September 19, 2024November 8, 2024November 8, 2024
Video Resolution720 * 480 1360 * 768256 <= W <=1360
256 <= H <=768
W,H % 16 == 0
256 <= W <=1360
256 <= H <=768
W,H % 16 == 0
720 * 480
Inference PrecisionFP16*(recommended), BF16, FP32, FP8*, INT8, not supported: INT4BF16(recommended), FP16, FP32, FP8*, INT8, not supported: INT4 BF16FP16*(Recommended), BF16, FP32, FP8*, INT8, Not supported: INT4BF16 (Recommended), FP16, FP32, FP8*, INT8, Not supported: INT4
Single GPU Memory UsageSAT FP16: 18GB
diffusers FP16: from 4GB*
diffusers INT8(torchao): from 3.6GB*
SAT BF16: 26GB
diffusers BF16: from 5GB*
diffusers INT8(torchao): from 4.4GB*
SAT BF16: 66GB
Single GPU Memory Usage
SAT BF16: 66GB
SAT FP16: 18GB
diffusers FP16: 4GB minimum*
diffusers INT8 (torchao): 3.6GB minimum*
SAT BF16: 26GB
diffusers BF16: 5GB minimum*
diffusers INT8 (torchao): 4.4GB minimum*
Multi-GPU Memory UsageNot Supported
FP16: 10GB* using diffusers
BF16: 15GB* using diffusers
Not supported
Inference Speed
(Step = 50, FP/BF16)
Single A100: ~1000 seconds (5-second video)
Single H100: ~550 seconds (5-second video)
Single A100: ~90 seconds
Single H100: ~45 seconds
Single A100: ~180 seconds
Single H100: ~90 seconds
Single A100: ~1000 seconds (5-second video)
Single H100: ~550 seconds (5-second video)
Prompt Language
Prompt Token Limit226 Tokens 224 Tokens226 Tokens
Video Length5 seconds or 10 seconds 6 seconds5 or 10 seconds
Frame Rate8 frames / second16 frames / second16 frames / second 8 frames / second
Positional Encoding3d_sincos_pos_embedPosition Encoding3d_rope_pos_embed3d_sincos_pos_embed 3d_rope_pos_embed 3d_rope_pos_embed + learnable_pos_embed3d_rope_pos_embed3d_rope_pos_embed
Download Link (Diffusers) Coming Soon 🤗 HuggingFace
🤖 ModelScope
🟣 WiseModel
🤗 HuggingFace
🤖 ModelScope
🟣 WiseModel
🤗 HuggingFace
🤖 ModelScope
🟣 WiseModel
Coming Soon
Download Link (SAT)SAT 🤗 HuggingFace
🤖 ModelScope
🟣 WiseModel
SAT
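
> Note on the memory columns above: the low diffusers figures (FP16 from 4GB, INT8 via torchao from 3.6GB) assume the diffusers memory optimizations are turned on. A minimal, illustrative sketch with the already-released CogVideoX-5B diffusers weights (the CogVideoX1.5 diffusers weights are listed as "Coming Soon"); the prompt and sampling parameters here are examples, and it assumes a diffusers release that ships `CogVideoXPipeline`:

```python
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

# Load in BF16; CogVideoX-5B is the currently released diffusers checkpoint.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)

# Memory optimizations behind the low "from X GB" figures in the table.
pipe.enable_sequential_cpu_offload()
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()

video = pipe(
    prompt="A panda playing guitar in a bamboo forest",
    num_inference_steps=50,
    num_frames=49,       # ~6-second clip at 8 fps for the 2B/5B generation
    guidance_scale=6.0,
).frames[0]
export_to_video(video, "output.mp4", fps=8)
```
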
diff --git a/README_ja.md b/README_ja.md index 1bc3d137..9962d1b9 100644 --- a/README_ja.md +++ b/README_ja.md @@ -163,88 +163,87 @@ CogVideoXは、[清影](https://chatglm.cn/video?fr=osm_cogvideox) と同源の + + - - - + + + - - - - + + - - + + - - - - + + + + - - - - + + + + - - - - + + + + - - + + - - + + - - + + - + + - - + - - +
モデル名CogVideoX1.5-5B (最新)CogVideoX1.5-5B-I2V (最新) CogVideoX-2B CogVideoX-5B CogVideoX-5B-I2VCogVideoX1.5-5BCogVideoX1.5-5B-I2V
リリース日公開日2024年11月8日2024年11月8日 2024年8月6日 2024年8月27日 2024年9月19日2024年11月8日2024年11月8日
ビデオ解像度720 * 480 1360 * 768256 <= W <=1360
256 <= H <=768
W,H % 16 == 0
256 <= W <=1360
256 <= H <=768
W,H % 16 == 0
720 * 480
推論精度FP16*(推奨), BF16, FP32, FP8*, INT8, INT4は非対応BF16(推奨), FP16, FP32, FP8*, INT8, INT4は非対応 BF16FP16*(推奨), BF16, FP32,FP8*,INT8,INT4非対応BF16(推奨), FP16, FP32,FP8*,INT8,INT4非対応
シングルGPUメモリ消費SAT FP16: 18GB
diffusers FP16: 4GBから*
diffusers INT8(torchao): 3.6GBから*
SAT BF16: 26GB
diffusers BF16: 5GBから*
diffusers INT8(torchao): 4.4GBから*
SAT BF16: 66GB
単一GPUメモリ消費量
SAT BF16: 66GB
SAT FP16: 18GB
diffusers FP16: 4GB以上*
diffusers INT8(torchao): 3.6GB以上*
SAT BF16: 26GB
diffusers BF16: 5GB以上*
diffusers INT8(torchao): 4.4GB以上*
マルチGPUメモリ消費FP16: 10GB* using diffusers
BF16: 15GB* using diffusers
サポートなし
複数GPU推論メモリ消費量非対応
FP16: 10GB* diffusers使用
BF16: 15GB* diffusers使用
推論速度
(ステップ数 = 50, FP/BF16)
単一A100: 約90秒
単一H100: 約45秒
単一A100: 約180秒
単一H100: 約90秒
単一A100: 約1000秒(5秒動画)
単一H100: 約550秒(5秒動画)
推論速度
(Step = 50, FP/BF16)
シングルA100: ~1000秒(5秒ビデオ)
シングルH100: ~550秒(5秒ビデオ)
シングルA100: ~90秒
シングルH100: ~45秒
シングルA100: ~180秒
シングルH100: ~90秒
プロンプト言語 英語*
プロンプトトークン制限226トークンプロンプト長さの上限 224トークン226トークン
ビデオの長さ6秒ビデオ長さ 5秒または10秒6秒
フレームレート8 フレーム / 秒16 フレーム / 秒16フレーム/秒8フレーム/秒
位置エンコーディング3d_sincos_pos_embed3d_rope_pos_embed3d_sincos_pos_embed 3d_rope_pos_embed 3d_rope_pos_embed + learnable_pos_embed3d_rope_pos_embed3d_rope_pos_embed
ダウンロードリンク (Diffusers) 近日公開 🤗 HuggingFace
🤖 ModelScope
🟣 WiseModel
🤗 HuggingFace
🤖 ModelScope
🟣 WiseModel
🤗 HuggingFace
🤖 ModelScope
🟣 WiseModel
近日公開
ダウンロードリンク (SAT)SAT 🤗 HuggingFace
🤖 ModelScope
🟣 WiseModel
SAT
diff --git a/README_zh.md b/README_zh.md index a88cc369..c66fc855 100644 --- a/README_zh.md +++ b/README_zh.md @@ -154,49 +154,49 @@ CogVideoX是 [清影](https://chatglm.cn/video?fr=osm_cogvideox) 同源的开源 + + - - + + - - - - + + + - + - + - + - @@ -204,39 +204,37 @@ CogVideoX是 [清影](https://chatglm.cn/video?fr=osm_cogvideox) 同源的开源 - + - + - + - + + - - + - - - +
模型名CogVideoX1.5-5B (最新)CogVideoX1.5-5B-I2V (最新) CogVideoX-2B CogVideoX-5B CogVideoX-5B-I2V CogVideoX1.5-5BCogVideoX1.5-5B-I2V
发布时间2024年11月8日2024年11月8日 2024年8月6日 2024年8月27日 2024年9月19日2024年11月8日2024年11月8日
视频分辨率720 * 480 1360 * 768 256 <= W <=1360
256 <= H <=768
W,H % 16 == 0
720 * 480
推理精度BF16 FP16*(推荐), BF16, FP32,FP8*,INT8,不支持INT4 BF16(推荐), FP16, FP32,FP8*,INT8,不支持INT4BF16
单GPU显存消耗
SAT BF16: 66GB
SAT FP16: 18GB
diffusers FP16: 4GB起*
diffusers INT8(torchao): 3.6GB起*
SAT BF16: 26GB
diffusers BF16: 5GB起*
diffusers INT8(torchao): 4.4GB起*
SAT BF16: 66GB
多GPU推理显存消耗不支持
FP16: 10GB* using diffusers
BF16: 15GB* using diffusers
Not supported
推理速度
(Step = 50, FP/BF16)
单卡A100: ~1000秒(5秒视频)
单卡H100: ~550秒(5秒视频)
单卡A100: ~90秒
单卡H100: ~45秒
单卡A100: ~180秒
单卡H100: ~90秒
单卡A100: ~1000秒(5秒视频)
单卡H100: ~550秒(5秒视频)
提示词语言
提示词长度上限226 Tokens 224 Tokens226 Tokens
视频长度6 秒 5 秒 或 10 秒6 秒
帧率8 帧 / 秒 16 帧 / 秒 8 帧 / 秒
位置编码3d_sincos_pos_embed3d_rope_pos_embed3d_sincos_pos_embed 3d_rope_pos_embed 3d_rope_pos_embed + learnable_pos_embed3d_rope_pos_embed3d_rope_pos_embed
下载链接 (Diffusers) 即将推出 🤗 HuggingFace
🤖 ModelScope
🟣 WiseModel
🤗 HuggingFace
🤖 ModelScope
🟣 WiseModel
🤗 HuggingFace
🤖 ModelScope
🟣 WiseModel
即将推出
下载链接 (SAT)SAT 🤗 HuggingFace
🤖 ModelScope
🟣 WiseModel
SAT
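
> The resolution and duration limits repeated in all three tables translate directly into the SAT configs below: width and height must be multiples of 16 within 256–1360 and 256–768 respectively, and at 16 fps a 5-second clip is 16 × 5 + 1 = 81 frames (161 for 10 seconds), matching the `num_frames` comments added in the YAML. A small illustrative helper (the function name is mine, not part of the repo):

```python
def check_cogvideox15_input(width: int, height: int, seconds: int) -> int:
    """Validate a CogVideoX1.5 resolution and return the frame count for the clip.

    Limits come from the table above; the frame count follows the `num_frames`
    comments in the SAT configs (81 for 5 s, 161 for 10 s at 16 fps).
    """
    if not (256 <= width <= 1360 and width % 16 == 0):
        raise ValueError(f"width {width} must be a multiple of 16 in [256, 1360]")
    if not (256 <= height <= 768 and height % 16 == 0):
        raise ValueError(f"height {height} must be a multiple of 16 in [256, 768]")
    if seconds not in (5, 10):
        raise ValueError("CogVideoX1.5 generates 5- or 10-second clips")
    return seconds * 16 + 1  # 16 fps, plus the initial frame

assert check_cogvideox15_input(1360, 768, 5) == 81
assert check_cogvideox15_input(1360, 768, 10) == 161
```
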
diff --git a/sat/configs/cogvideox1.5_5b.yaml b/sat/configs/cogvideox1.5_5b.yaml index 0000ec21..62d46be6 100644 --- a/sat/configs/cogvideox1.5_5b.yaml +++ b/sat/configs/cogvideox1.5_5b.yaml @@ -23,7 +23,7 @@ model: params: time_embed_dim: 512 elementwise_affine: True - num_frames: 81 + num_frames: 81 # for 5 seconds and 161 for 10 seconds time_compressed_rate: 4 latent_width: 300 latent_height: 300 diff --git a/sat/configs/cogvideox1.5_5b_i2v.yaml b/sat/configs/cogvideox1.5_5b_i2v.yaml index c65f0b7a..65d90f9e 100644 --- a/sat/configs/cogvideox1.5_5b_i2v.yaml +++ b/sat/configs/cogvideox1.5_5b_i2v.yaml @@ -25,11 +25,10 @@ model: network_config: target: dit_video_concat.DiffusionTransformer params: -# space_interpolation: 1.875 ofs_embed_dim: 512 time_embed_dim: 512 elementwise_affine: True - num_frames: 81 + num_frames: 81 # for 5 seconds and 161 for 10 seconds time_compressed_rate: 4 latent_width: 300 latent_height: 300 diff --git a/sat/configs/inference.yaml b/sat/configs/inference.yaml index a93bb997..644e71ab 100644 --- a/sat/configs/inference.yaml +++ b/sat/configs/inference.yaml @@ -1,16 +1,14 @@ args: - image2video: False # True for image2video, False for text2video +# image2video: True # True for image2video, False for text2video latent_channels: 16 mode: inference load: "{your CogVideoX SAT folder}/transformer" # This is for Full model without lora adapter - # load: "{your lora folder} such as zRzRzRzRzRzRzR/lora-disney-08-20-13-28" # This is for Full model without lora adapter batch_size: 1 input_type: txt input_file: configs/test.txt - sampling_image_size: [480, 720] - sampling_num_frames: 13 # Must be 13, 11 or 9 - sampling_fps: 8 -# fp16: True # For CogVideoX-2B - bf16: True # For CogVideoX-5B and CoGVideoX-5B-I2V - output_dir: outputs/ + sampling_image_size: [768, 1360] # remove this for I2V + sampling_num_frames: 22 # 42 for 10 seconds and 22 for 5 seconds + sampling_fps: 16 + bf16: True + output_dir: outputs force_inference: True \ No newline at end of file diff --git a/sat/diffusion_video.py b/sat/diffusion_video.py index 10635b4d..226ed6e5 100644 --- a/sat/diffusion_video.py +++ b/sat/diffusion_video.py @@ -192,13 +192,13 @@ def decode_first_stage(self, z): for i in range(fake_cp_size): end_frame = start_frame + latent_time // fake_cp_size + (1 if i < latent_time % fake_cp_size else 0) - fake_cp_rank0 = True if i == 0 else False + use_cp = True if i == 0 else False clear_fake_cp_cache = True if i == fake_cp_size - 1 else False with torch.no_grad(): recon = self.first_stage_model.decode( z_now[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache, - fake_cp_rank0=fake_cp_rank0, + use_cp=use_cp, ) recons.append(recon) start_frame = end_frame diff --git a/sat/dit_video_concat.py b/sat/dit_video_concat.py index b55a3f18..22c3821f 100644 --- a/sat/dit_video_concat.py +++ b/sat/dit_video_concat.py @@ -7,7 +7,6 @@ import torch from torch import nn import torch.nn.functional as F - from sat.model.base_model import BaseModel, non_conflict from sat.model.mixins import BaseMixin from sat.transformer_defaults import HOOKS_DEFAULT, attention_fn_default diff --git a/sat/inference.sh b/sat/inference.sh index a22ef872..9904433f 100755 --- a/sat/inference.sh +++ b/sat/inference.sh @@ -4,7 +4,7 @@ echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" environs="WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 LOCAL_WORLD_SIZE=1" -run_cmd="$environs python sample_video.py --base configs/cogvideox1.5_5b.yaml configs/test_inference.yaml --seed $RANDOM" +run_cmd="$environs python 
sample_video.py --base configs/cogvideox1.5_5b.yaml configs/inference.yaml --seed $RANDOM" echo ${run_cmd} eval ${run_cmd} diff --git a/sat/vae_modules/autoencoder.py b/sat/vae_modules/autoencoder.py index 9642fb44..0adad855 100644 --- a/sat/vae_modules/autoencoder.py +++ b/sat/vae_modules/autoencoder.py @@ -1,17 +1,13 @@ import logging import math import re -import random from abc import abstractmethod from contextlib import contextmanager from typing import Any, Dict, List, Optional, Tuple, Union -import numpy as np import pytorch_lightning as pl import torch import torch.distributed -import torch.nn as nn -from einops import rearrange from packaging import version from vae_modules.ema import LitEma @@ -56,34 +52,16 @@ def __init__( if version.parse(torch.__version__) >= version.parse("2.0.0"): self.automatic_optimization = False - # def apply_ckpt(self, ckpt: Union[None, str, dict]): - # if ckpt is None: - # return - # if isinstance(ckpt, str): - # ckpt = { - # "target": "sgm.modules.checkpoint.CheckpointEngine", - # "params": {"ckpt_path": ckpt}, - # } - # engine = instantiate_from_config(ckpt) - # engine(self) - def apply_ckpt(self, ckpt: Union[None, str, dict]): if ckpt is None: return - self.init_from_ckpt(ckpt) - - def init_from_ckpt(self, path, ignore_keys=list()): - sd = torch.load(path, map_location="cpu")["state_dict"] - keys = list(sd.keys()) - for k in keys: - for ik in ignore_keys: - if k.startswith(ik): - print("Deleting key {} from state_dict.".format(k)) - del sd[k] - missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False) - print("Missing keys: ", missing_keys) - print("Unexpected keys: ", unexpected_keys) - print(f"Restored from {path}") + if isinstance(ckpt, str): + ckpt = { + "target": "sgm.modules.checkpoint.CheckpointEngine", + "params": {"ckpt_path": ckpt}, + } + engine = instantiate_from_config(ckpt) + engine(self) @abstractmethod def get_input(self, batch) -> Any: @@ -119,7 +97,9 @@ def decode(self, *args, **kwargs) -> torch.Tensor: def instantiate_optimizer_from_config(self, params, lr, cfg): logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config") - return get_obj_from_str(cfg["target"])(params, lr=lr, **cfg.get("params", dict())) + return get_obj_from_str(cfg["target"])( + params, lr=lr, **cfg.get("params", dict()) + ) def configure_optimizers(self) -> Any: raise NotImplementedError() @@ -216,12 +196,13 @@ def get_last_layer(self): return self.decoder.get_last_layer() def encode( - self, - x: torch.Tensor, - return_reg_log: bool = False, - unregularized: bool = False, + self, + x: torch.Tensor, + return_reg_log: bool = False, + unregularized: bool = False, + **kwargs, ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: - z = self.encoder(x) + z = self.encoder(x, **kwargs) if unregularized: return z, dict() z, reg_log = self.regularization(z) diff --git a/sat/vae_modules/cp_enc_dec.py b/sat/vae_modules/cp_enc_dec.py index 1d9c34f5..4d773240 100644 --- a/sat/vae_modules/cp_enc_dec.py +++ b/sat/vae_modules/cp_enc_dec.py @@ -101,8 +101,6 @@ def _gather(input_, dim): group = get_context_parallel_group() cp_rank = get_context_parallel_rank() - # print('in _gather, cp_rank:', cp_rank, 'input_size:', input_.shape) - input_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous() if cp_rank == 0: input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous() @@ -127,12 +125,9 @@ def _gather(input_, dim): def _conv_split(input_, dim, kernel_size): cp_world_size = get_context_parallel_world_size() - # Bypass the function 
if context parallel is 1 if cp_world_size == 1: return input_ - # print('in _conv_split, cp_rank:', cp_rank, 'input_size:', input_.shape) - cp_rank = get_context_parallel_rank() dim_size = (input_.size()[dim] - kernel_size) // cp_world_size @@ -140,14 +135,11 @@ def _conv_split(input_, dim, kernel_size): if cp_rank == 0: output = input_.transpose(dim, 0)[: dim_size + kernel_size].transpose(dim, 0) else: - # output = input_.transpose(dim, 0)[cp_rank * dim_size + 1:(cp_rank + 1) * dim_size + kernel_size].transpose(dim, 0) output = input_.transpose(dim, 0)[ cp_rank * dim_size + kernel_size : (cp_rank + 1) * dim_size + kernel_size ].transpose(dim, 0) output = output.contiguous() - # print('out _conv_split, cp_rank:', cp_rank, 'input_size:', output.shape) - return output @@ -160,9 +152,6 @@ def _conv_gather(input_, dim, kernel_size): group = get_context_parallel_group() cp_rank = get_context_parallel_rank() - - # print('in _conv_gather, cp_rank:', cp_rank, 'input_size:', input_.shape) - input_first_kernel_ = input_.transpose(0, dim)[:kernel_size].transpose(0, dim).contiguous() if cp_rank == 0: input_ = input_.transpose(0, dim)[kernel_size:].transpose(0, dim).contiguous() @@ -255,17 +244,12 @@ def _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding=Non if recv_rank % cp_world_size == cp_world_size - 1: recv_rank += cp_world_size - # req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group) - # recv_buffer = torch.empty_like(input_[-kernel_size + 1:]).contiguous() - # req_recv = torch.distributed.recv(recv_buffer, recv_rank, group=group) - # req_recv.wait() recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous() if cp_rank < cp_world_size - 1: req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group) if cp_rank > 0: req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group) - # req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group) - # req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group) + if cp_rank == 0: if cache_padding is not None: @@ -421,7 +405,6 @@ def forward(self, input_): def Normalize(in_channels, gather=False, **kwargs): - # same for 3D and 2D if gather: return ContextParallelGroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) else: @@ -468,8 +451,8 @@ def __init__( kernel_size=1, ) - def forward(self, f, zq, clear_fake_cp_cache=True, fake_cp_rank0=True): - if f.shape[2] > 1 and get_context_parallel_rank() == 0 and fake_cp_rank0: + def forward(self, f, zq, clear_fake_cp_cache=True, fake_cp=True): + if f.shape[2] > 1 and get_context_parallel_rank() == 0 and fake_cp: f_first, f_rest = f[:, :, :1], f[:, :, 1:] f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:] zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:] @@ -531,13 +514,11 @@ def __init__( self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) self.compress_time = compress_time - def forward(self, x, fake_cp_rank0=True): + def forward(self, x, fake_cp=True): if self.compress_time and x.shape[2] > 1: - if get_context_parallel_rank() == 0 and fake_cp_rank0: - # print(x.shape) + if get_context_parallel_rank() == 0 and fake_cp: # split first frame x_first, x_rest = x[:, :, 0], x[:, :, 1:] - x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest") splits = torch.split(x_rest, 32, dim=1) @@ -545,8 +526,6 @@ def forward(self, x, fake_cp_rank0=True): 
torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits ] x_rest = torch.cat(interpolated_splits, dim=1) - - # x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest") x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2) else: splits = torch.split(x, 32, dim=1) @@ -555,13 +534,10 @@ def forward(self, x, fake_cp_rank0=True): ] x = torch.cat(interpolated_splits, dim=1) - # x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") - else: # only interpolate 2D t = x.shape[2] x = rearrange(x, "b c t h w -> (b t) c h w") - # x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") splits = torch.split(x, 32, dim=1) interpolated_splits = [ @@ -590,12 +566,12 @@ def __init__(self, in_channels, with_conv, compress_time=False, out_channels=Non self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0) self.compress_time = compress_time - def forward(self, x, fake_cp_rank0=True): + def forward(self, x, fake_cp=True): if self.compress_time and x.shape[2] > 1: h, w = x.shape[-2:] x = rearrange(x, "b c t h w -> (b h w) c t") - if get_context_parallel_rank() == 0 and fake_cp_rank0: + if get_context_parallel_rank() == 0 and fake_cp: # split first frame x_first, x_rest = x[..., 0], x[..., 1:] @@ -693,17 +669,13 @@ def __init__( padding=0, ) - def forward(self, x, temb, zq=None, clear_fake_cp_cache=True, fake_cp_rank0=True): + def forward(self, x, temb, zq=None, clear_fake_cp_cache=True, fake_cp=True): h = x - # if isinstance(self.norm1, torch.nn.GroupNorm): - # h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1) if zq is not None: - h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0) + h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=fake_cp) else: h = self.norm1(h) - # if isinstance(self.norm1, torch.nn.GroupNorm): - # h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1) h = nonlinearity(h) h = self.conv1(h, clear_cache=clear_fake_cp_cache) @@ -711,14 +683,10 @@ def forward(self, x, temb, zq=None, clear_fake_cp_cache=True, fake_cp_rank0=True if temb is not None: h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None] - # if isinstance(self.norm2, torch.nn.GroupNorm): - # h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1) if zq is not None: - h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0) + h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=fake_cp) else: h = self.norm2(h) - # if isinstance(self.norm2, torch.nn.GroupNorm): - # h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1) h = nonlinearity(h) h = self.dropout(h) @@ -827,32 +795,33 @@ def __init__( kernel_size=3, ) - def forward(self, x, clear_fake_cp_cache=True, fake_cp_rank0=True): + def forward(self, x, use_cp=True): + global _USE_CP + _USE_CP = use_cp + # timestep embedding temb = None # downsampling - h = self.conv_in(x, clear_cache=clear_fake_cp_cache) + hs = [self.conv_in(x)] for i_level in range(self.num_resolutions): for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](h, temb, clear_fake_cp_cache=clear_fake_cp_cache) + h = self.down[i_level].block[i_block](hs[-1], temb) if len(self.down[i_level].attn) > 0: - print("Attention not implemented") h = self.down[i_level].attn[i_block](h) + hs.append(h) if i_level != self.num_resolutions - 1: - h = self.down[i_level].downsample(h, 
fake_cp_rank0=fake_cp_rank0) + hs.append(self.down[i_level].downsample(hs[-1])) # middle - h = self.mid.block_1(h, temb, clear_fake_cp_cache=clear_fake_cp_cache) - h = self.mid.block_2(h, temb, clear_fake_cp_cache=clear_fake_cp_cache) + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.block_2(h, temb) # end - # h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1) h = self.norm_out(h) - # h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1) - h = nonlinearity(h) - h = self.conv_out(h, clear_cache=clear_fake_cp_cache) + h = self.conv_out(h) return h @@ -895,11 +864,9 @@ def __init__( zq_ch = z_channels # compute in_ch_mult, block_in and curr_res at lowest res - in_ch_mult = (1,) + tuple(ch_mult) block_in = ch * ch_mult[self.num_resolutions - 1] curr_res = resolution // 2 ** (self.num_resolutions - 1) self.z_shape = (1, z_channels, curr_res, curr_res) - print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) self.conv_in = ContextParallelCausalConv3d( chan_in=z_channels, @@ -955,11 +922,6 @@ def __init__( up.block = block up.attn = attn if i_level != 0: - # # Symmetrical enc-dec - if i_level <= self.temporal_compress_level: - up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True) - else: - up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False) if i_level < self.num_resolutions - self.temporal_compress_level: up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False) else: @@ -974,7 +936,9 @@ def __init__( kernel_size=3, ) - def forward(self, z, clear_fake_cp_cache=True, fake_cp_rank0=True): + def forward(self, z, clear_fake_cp_cache=True, use_cp=True): + global _USE_CP + _USE_CP = use_cp self.last_z_shape = z.shape # timestep embedding @@ -987,25 +951,25 @@ def forward(self, z, clear_fake_cp_cache=True, fake_cp_rank0=True): h = self.conv_in(z, clear_cache=clear_fake_cp_cache) # middle - h = self.mid.block_1(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0) - h = self.mid.block_2(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0) + h = self.mid.block_1(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp) + h = self.mid.block_2(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp) # upsampling for i_level in reversed(range(self.num_resolutions)): for i_block in range(self.num_res_blocks + 1): h = self.up[i_level].block[i_block]( - h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0 + h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp ) if len(self.up[i_level].attn) > 0: h = self.up[i_level].attn[i_block](h, zq) if i_level != 0: - h = self.up[i_level].upsample(h, fake_cp_rank0=fake_cp_rank0) + h = self.up[i_level].upsample(h, fake_cp=use_cp) # end if self.give_pre_end: return h - h = self.norm_out(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0) + h = self.norm_out(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp) h = nonlinearity(h) h = self.conv_out(h, clear_cache=clear_fake_cp_cache)
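
> For context on the fake-context-parallel plumbing above: the `decode_first_stage` change earlier in this patch walks the latent time axis chunk by chunk, passing `use_cp=True` only for the first chunk and `clear_fake_cp_cache=True` only on the last, with any remainder frames absorbed by the earliest chunks. A sketch of just that boundary arithmetic (hypothetical helper, for illustration only):

```python
def fake_cp_schedule(latent_time: int, fake_cp_size: int):
    """Mirror of decode_first_stage's chunking: returns
    (start_frame, end_frame, use_cp, clear_fake_cp_cache) per decode pass."""
    schedule, start_frame = [], 0
    for i in range(fake_cp_size):
        end_frame = start_frame + latent_time // fake_cp_size + (
            1 if i < latent_time % fake_cp_size else 0
        )
        schedule.append((start_frame, end_frame, i == 0, i == fake_cp_size - 1))
        start_frame = end_frame
    return schedule

# 22 latent frames split over 3 passes:
# [(0, 8, True, False), (8, 15, False, False), (15, 22, False, True)]
print(fake_cp_schedule(22, 3))
```
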