diff --git a/README.md b/README.md
index 0bc7173a..9f8122e5 100644
--- a/README.md
+++ b/README.md
@@ -171,49 +171,49 @@ models we currently offer, along with their foundational information.
Model Name |
+ CogVideoX1.5-5B (Latest) |
+ CogVideoX1.5-5B-I2V (Latest) |
CogVideoX-2B |
CogVideoX-5B |
CogVideoX-5B-I2V |
- CogVideoX1.5-5B |
- CogVideoX1.5-5B-I2V |
Release Date |
+ November 8, 2024 |
+ November 8, 2024 |
August 6, 2024 |
August 27, 2024 |
September 19, 2024 |
- November 8, 2024 |
- November 8, 2024 |
Video Resolution |
- 720 * 480 |
1360 * 768 |
- 256 <= W <=1360 256 <= H <=768 W,H % 16 == 0 |
+ 256 <= W <= 1360 256 <= H <= 768 W,H % 16 == 0 |
+ 720 * 480 |
Inference Precision |
- FP16*(recommended), BF16, FP32, FP8*, INT8, not supported: INT4 |
- BF16(recommended), FP16, FP32, FP8*, INT8, not supported: INT4 |
BF16 |
+ FP16* (Recommended), BF16, FP32, FP8*, INT8, Not supported: INT4 |
+ BF16 (Recommended), FP16, FP32, FP8*, INT8, Not supported: INT4 |
- Single GPU Memory Usage |
- SAT FP16: 18GB diffusers FP16: from 4GB* diffusers INT8(torchao): from 3.6GB* |
- SAT BF16: 26GB diffusers BF16 : from 5GB* diffusers INT8(torchao): from 4.4GB* |
- SAT BF16: 66GB
|
+ Single GPU Memory Usage
|
+ SAT BF16: 66GB
|
+ SAT FP16: 18GB diffusers FP16: 4GB minimum* diffusers INT8 (torchao): 3.6GB minimum* |
+ SAT BF16: 26GB diffusers BF16: 5GB minimum* diffusers INT8 (torchao): 4.4GB minimum* |
Multi-GPU Memory Usage |
+ Not Supported
|
FP16: 10GB* using diffusers
|
BF16: 15GB* using diffusers
|
- Not supported
|
Inference Speed (Step = 50, FP/BF16) |
+ Single A100: ~1000 seconds (5-second video) Single H100: ~550 seconds (5-second video) |
Single A100: ~90 seconds Single H100: ~45 seconds |
Single A100: ~180 seconds Single H100: ~90 seconds |
- Single A100: ~1000 seconds (5-second video) Single H100: ~550 seconds (5-second video) |
Prompt Language |
@@ -221,38 +221,37 @@ models we currently offer, along with their foundational information.
Prompt Token Limit |
- 226 Tokens |
224 Tokens |
+ 226 Tokens |
Video Length |
+ 5 seconds or 10 seconds |
6 seconds |
- 5 or 10 seconds |
Frame Rate |
- 8 frames / second |
- 16 frames / second |
+ 16 frames / second |
+ 8 frames / second |
- Positional Encoding |
- 3d_sincos_pos_embed |
+ Position Encoding |
+ 3d_rope_pos_embed |
+ 3d_sincos_pos_embed |
3d_rope_pos_embed |
3d_rope_pos_embed + learnable_pos_embed |
- 3d_rope_pos_embed |
- 3d_rope_pos_embed |
Download Link (Diffusers) |
+ Coming Soon |
🤗 HuggingFace 🤖 ModelScope 🟣 WiseModel |
🤗 HuggingFace 🤖 ModelScope 🟣 WiseModel |
🤗 HuggingFace 🤖 ModelScope 🟣 WiseModel |
- Coming Soon |
Download Link (SAT) |
- SAT |
🤗 HuggingFace 🤖 ModelScope 🟣 WiseModel |
+ SAT |
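For readers scanning the memory rows above: the asterisked "minimum" figures for the Diffusers path generally assume the pipeline's memory optimizations are switched on. A minimal sketch of that low-memory path, using the already-released CogVideoX-5B Diffusers weights (the CogVideoX1.5 Diffusers checkpoints above are still marked Coming Soon); the prompt, frame count, and output filename are illustrative only, and actual memory use depends on resolution, frame count, and GPU:

```python
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

# Low-memory settings behind the asterisked "minimum" figures in the table above.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.enable_sequential_cpu_offload()  # stream weights to the GPU piece by piece
pipe.vae.enable_slicing()             # decode the latent batch in slices
pipe.vae.enable_tiling()              # decode each frame in spatial tiles

video = pipe(
    prompt="A panda playing a guitar in a bamboo forest",  # illustrative prompt
    num_frames=49,              # roughly 6 seconds at 8 fps for CogVideoX-5B
    num_inference_steps=50,
    guidance_scale=6.0,
).frames[0]
export_to_video(video, "output.mp4", fps=8)
```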
diff --git a/README_ja.md b/README_ja.md
index 1bc3d137..9962d1b9 100644
--- a/README_ja.md
+++ b/README_ja.md
@@ -163,88 +163,87 @@ CogVideoXは、[清影](https://chatglm.cn/video?fr=osm_cogvideox) と同源の
モデル名 |
+ CogVideoX1.5-5B (最新) |
+ CogVideoX1.5-5B-I2V (最新) |
CogVideoX-2B |
CogVideoX-5B |
CogVideoX-5B-I2V |
- CogVideoX1.5-5B |
- CogVideoX1.5-5B-I2V |
- リリース日 |
+ 公開日 |
+ 2024年11月8日 |
+ 2024年11月8日 |
2024年8月6日 |
2024年8月27日 |
2024年9月19日 |
- 2024年11月8日 |
- 2024年11月8日 |
ビデオ解像度 |
- 720 * 480 |
1360 * 768 |
- 256 <= W <=1360 256 <= H <=768 W,H % 16 == 0 |
+ 256 <= W <= 1360 256 <= H <= 768 W,H % 16 == 0 |
+ 720 * 480 |
推論精度 |
- FP16*(推奨), BF16, FP32, FP8*, INT8, INT4は非対応 |
- BF16(推奨), FP16, FP32, FP8*, INT8, INT4は非対応 |
BF16 |
+ FP16* (推奨), BF16, FP32, FP8*, INT8, INT4非対応 |
+ BF16 (推奨), FP16, FP32, FP8*, INT8, INT4非対応 |
- シングルGPUメモリ消費 |
- SAT FP16: 18GB diffusers FP16: 4GBから* diffusers INT8(torchao): 3.6GBから* |
- SAT BF16: 26GB diffusers BF16: 5GBから* diffusers INT8(torchao): 4.4GBから* |
- SAT BF16: 66GB
|
+ 単一GPUメモリ消費量
|
+ SAT BF16: 66GB
|
+ SAT FP16: 18GB diffusers FP16: 4GB以上* diffusers INT8(torchao): 3.6GB以上* |
+ SAT BF16: 26GB diffusers BF16: 5GB以上* diffusers INT8(torchao): 4.4GB以上* |
- マルチGPUメモリ消費 |
- FP16: 10GB* using diffusers
|
- BF16: 15GB* using diffusers
|
- サポートなし
|
+ 複数GPU推論メモリ消費量 |
+ 非対応
|
+ FP16: 10GB* diffusers使用
|
+ BF16: 15GB* diffusers使用
|
- 推論速度 (ステップ数 = 50, FP/BF16) |
- 単一A100: 約90秒 単一H100: 約45秒 |
- 単一A100: 約180秒 単一H100: 約90秒 |
- 単一A100: 約1000秒(5秒動画) 単一H100: 約550秒(5秒動画) |
+ 推論速度 (Step = 50, FP/BF16) |
+ シングルA100: ~1000秒(5秒ビデオ) シングルH100: ~550秒(5秒ビデオ) |
+ シングルA100: ~90秒 シングルH100: ~45秒 |
+ シングルA100: ~180秒 シングルH100: ~90秒 |
プロンプト言語 |
英語* |
- プロンプトトークン制限 |
- 226トークン |
+ プロンプト長さの上限 |
224トークン |
+ 226トークン |
- ビデオの長さ |
- 6秒 |
+ ビデオ長さ |
5秒または10秒 |
+ 6秒 |
フレームレート |
- 8 フレーム / 秒 |
- 16 フレーム / 秒 |
+ 16フレーム/秒 |
+ 8フレーム/秒 |
位置エンコーディング |
- 3d_sincos_pos_embed |
+ 3d_rope_pos_embed |
+ 3d_sincos_pos_embed |
3d_rope_pos_embed |
3d_rope_pos_embed + learnable_pos_embed |
- 3d_rope_pos_embed |
- 3d_rope_pos_embed |
ダウンロードリンク (Diffusers) |
+ 近日公開 |
🤗 HuggingFace 🤖 ModelScope 🟣 WiseModel |
🤗 HuggingFace 🤖 ModelScope 🟣 WiseModel |
🤗 HuggingFace 🤖 ModelScope 🟣 WiseModel |
- 近日公開 |
ダウンロードリンク (SAT) |
- SAT |
🤗 HuggingFace 🤖 ModelScope 🟣 WiseModel |
+ SAT |
diff --git a/README_zh.md b/README_zh.md
index a88cc369..c66fc855 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -154,49 +154,49 @@ CogVideoX是 [清影](https://chatglm.cn/video?fr=osm_cogvideox) 同源的开源
模型名 |
+ CogVideoX1.5-5B (最新) |
+ CogVideoX1.5-5B-I2V (最新) |
CogVideoX-2B |
CogVideoX-5B |
CogVideoX-5B-I2V |
- CogVideoX1.5-5B |
- CogVideoX1.5-5B-I2V |
发布时间 |
+ 2024年11月8日 |
+ 2024年11月8日 |
2024年8月6日 |
2024年8月27日 |
2024年9月19日 |
- 2024年11月8日 |
- 2024年11月8日 |
视频分辨率 |
- 720 * 480 |
1360 * 768 |
256 <= W <=1360 256 <= H <=768 W,H % 16 == 0 |
-
+ 720 * 480 |
+
推理精度 |
+ BF16 |
FP16*(推荐), BF16, FP32,FP8*,INT8,不支持INT4 |
BF16(推荐), FP16, FP32,FP8*,INT8,不支持INT4 |
- BF16 |
单GPU显存消耗
|
+ SAT BF16: 66GB
|
SAT FP16: 18GB diffusers FP16: 4GB起* diffusers INT8(torchao): 3.6G起* |
SAT BF16: 26GB diffusers BF16 : 5GB起* diffusers INT8(torchao): 4.4G起* |
- SAT BF16: 66GB
|
多GPU推理显存消耗 |
+ 不支持
|
FP16: 10GB* using diffusers
|
BF16: 15GB* using diffusers
|
- Not support
|
推理速度 (Step = 50, FP/BF16) |
+ 单卡A100: ~1000秒(5秒视频) 单卡H100: ~550秒(5秒视频) |
单卡A100: ~90秒 单卡H100: ~45秒 |
单卡A100: ~180秒 单卡H100: ~90秒 |
- 单卡A100: ~1000秒(5秒视频) 单卡H100: ~550秒(5秒视频) |
提示词语言 |
@@ -204,39 +204,37 @@ CogVideoX是 [清影](https://chatglm.cn/video?fr=osm_cogvideox) 同源的开源
提示词长度上限 |
- 226 Tokens |
224 Tokens |
+ 226 Tokens |
视频长度 |
- 6 秒 |
5 秒 或 10 秒 |
+ 6 秒 |
帧率 |
- 8 帧 / 秒 |
16 帧 / 秒 |
+ 8 帧 / 秒 |
位置编码 |
- 3d_sincos_pos_embed |
+ 3d_rope_pos_embed |
+ 3d_sincos_pos_embed |
3d_rope_pos_embed |
3d_rope_pos_embed + learnable_pos_embed |
- 3d_rope_pos_embed |
- 3d_rope_pos_embed |
下载链接 (Diffusers) |
+ 即将推出 |
🤗 HuggingFace 🤖 ModelScope 🟣 WiseModel |
🤗 HuggingFace 🤖 ModelScope 🟣 WiseModel |
🤗 HuggingFace 🤖 ModelScope 🟣 WiseModel |
- 即将推出 |
下载链接 (SAT) |
- SAT |
🤗 HuggingFace 🤖 ModelScope 🟣 WiseModel |
-
+ SAT |
diff --git a/sat/configs/cogvideox1.5_5b.yaml b/sat/configs/cogvideox1.5_5b.yaml
index 0000ec21..62d46be6 100644
--- a/sat/configs/cogvideox1.5_5b.yaml
+++ b/sat/configs/cogvideox1.5_5b.yaml
@@ -23,7 +23,7 @@ model:
params:
time_embed_dim: 512
elementwise_affine: True
- num_frames: 81
+ num_frames: 81 # 81 frames for a 5-second video, 161 for a 10-second video
time_compressed_rate: 4
latent_width: 300
latent_height: 300
diff --git a/sat/configs/cogvideox1.5_5b_i2v.yaml b/sat/configs/cogvideox1.5_5b_i2v.yaml
index c65f0b7a..65d90f9e 100644
--- a/sat/configs/cogvideox1.5_5b_i2v.yaml
+++ b/sat/configs/cogvideox1.5_5b_i2v.yaml
@@ -25,11 +25,10 @@ model:
network_config:
target: dit_video_concat.DiffusionTransformer
params:
-# space_interpolation: 1.875
ofs_embed_dim: 512
time_embed_dim: 512
elementwise_affine: True
- num_frames: 81
+ num_frames: 81 # 81 frames for a 5-second video, 161 for a 10-second video
time_compressed_rate: 4
latent_width: 300
latent_height: 300
diff --git a/sat/configs/inference.yaml b/sat/configs/inference.yaml
index a93bb997..644e71ab 100644
--- a/sat/configs/inference.yaml
+++ b/sat/configs/inference.yaml
@@ -1,16 +1,14 @@
args:
- image2video: False # True for image2video, False for text2video
+# image2video: True # True for image2video, False for text2video
latent_channels: 16
mode: inference
load: "{your CogVideoX SAT folder}/transformer" # This is for Full model without lora adapter
- # load: "{your lora folder} such as zRzRzRzRzRzRzR/lora-disney-08-20-13-28" # This is for Full model without lora adapter
batch_size: 1
input_type: txt
input_file: configs/test.txt
- sampling_image_size: [480, 720]
- sampling_num_frames: 13 # Must be 13, 11 or 9
- sampling_fps: 8
-# fp16: True # For CogVideoX-2B
- bf16: True # For CogVideoX-5B and CoGVideoX-5B-I2V
- output_dir: outputs/
+ sampling_image_size: [768, 1360] # remove this line for I2V
+ sampling_num_frames: 22 # 22 for 5-second videos, 42 for 10-second videos
+ sampling_fps: 16
+ bf16: True
+ output_dir: outputs
force_inference: True
\ No newline at end of file
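To tie the new sampling settings to the durations listed in the tables above: at sampling_fps: 16, the num_frames values noted in the model configs (81 and 161) come out to roughly 5 and 10 seconds. A quick sanity check in Python, using only values copied from the configs above (nothing here is derived from the model code):

```python
# Durations implied by the updated configs: num_frames / fps.
fps = 16  # sampling_fps in configs/inference.yaml
for num_frames, sampling_num_frames in [(81, 22), (161, 42)]:
    seconds = num_frames / fps
    print(f"num_frames={num_frames:3d} -> ~{seconds:.0f} s (sampling_num_frames={sampling_num_frames})")
# num_frames= 81 -> ~5 s (sampling_num_frames=22)
# num_frames=161 -> ~10 s (sampling_num_frames=42)
```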
diff --git a/sat/diffusion_video.py b/sat/diffusion_video.py
index 10635b4d..226ed6e5 100644
--- a/sat/diffusion_video.py
+++ b/sat/diffusion_video.py
@@ -192,13 +192,13 @@ def decode_first_stage(self, z):
for i in range(fake_cp_size):
end_frame = start_frame + latent_time // fake_cp_size + (1 if i < latent_time % fake_cp_size else 0)
- fake_cp_rank0 = True if i == 0 else False
+ use_cp = True if i == 0 else False
clear_fake_cp_cache = True if i == fake_cp_size - 1 else False
with torch.no_grad():
recon = self.first_stage_model.decode(
z_now[:, :, start_frame:end_frame].contiguous(),
clear_fake_cp_cache=clear_fake_cp_cache,
- fake_cp_rank0=fake_cp_rank0,
+ use_cp=use_cp,
)
recons.append(recon)
start_frame = end_frame
diff --git a/sat/dit_video_concat.py b/sat/dit_video_concat.py
index b55a3f18..22c3821f 100644
--- a/sat/dit_video_concat.py
+++ b/sat/dit_video_concat.py
@@ -7,7 +7,6 @@
import torch
from torch import nn
import torch.nn.functional as F
-
from sat.model.base_model import BaseModel, non_conflict
from sat.model.mixins import BaseMixin
from sat.transformer_defaults import HOOKS_DEFAULT, attention_fn_default
diff --git a/sat/inference.sh b/sat/inference.sh
index a22ef872..9904433f 100755
--- a/sat/inference.sh
+++ b/sat/inference.sh
@@ -4,7 +4,7 @@ echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
environs="WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 LOCAL_WORLD_SIZE=1"
-run_cmd="$environs python sample_video.py --base configs/cogvideox1.5_5b.yaml configs/test_inference.yaml --seed $RANDOM"
+run_cmd="$environs python sample_video.py --base configs/cogvideox1.5_5b.yaml configs/inference.yaml --seed $RANDOM"
echo ${run_cmd}
eval ${run_cmd}
diff --git a/sat/vae_modules/autoencoder.py b/sat/vae_modules/autoencoder.py
index 9642fb44..0adad855 100644
--- a/sat/vae_modules/autoencoder.py
+++ b/sat/vae_modules/autoencoder.py
@@ -1,17 +1,13 @@
import logging
import math
import re
-import random
from abc import abstractmethod
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union
-import numpy as np
import pytorch_lightning as pl
import torch
import torch.distributed
-import torch.nn as nn
-from einops import rearrange
from packaging import version
from vae_modules.ema import LitEma
@@ -56,34 +52,16 @@ def __init__(
if version.parse(torch.__version__) >= version.parse("2.0.0"):
self.automatic_optimization = False
- # def apply_ckpt(self, ckpt: Union[None, str, dict]):
- # if ckpt is None:
- # return
- # if isinstance(ckpt, str):
- # ckpt = {
- # "target": "sgm.modules.checkpoint.CheckpointEngine",
- # "params": {"ckpt_path": ckpt},
- # }
- # engine = instantiate_from_config(ckpt)
- # engine(self)
-
def apply_ckpt(self, ckpt: Union[None, str, dict]):
if ckpt is None:
return
- self.init_from_ckpt(ckpt)
-
- def init_from_ckpt(self, path, ignore_keys=list()):
- sd = torch.load(path, map_location="cpu")["state_dict"]
- keys = list(sd.keys())
- for k in keys:
- for ik in ignore_keys:
- if k.startswith(ik):
- print("Deleting key {} from state_dict.".format(k))
- del sd[k]
- missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False)
- print("Missing keys: ", missing_keys)
- print("Unexpected keys: ", unexpected_keys)
- print(f"Restored from {path}")
+ if isinstance(ckpt, str):
+ ckpt = {
+ "target": "sgm.modules.checkpoint.CheckpointEngine",
+ "params": {"ckpt_path": ckpt},
+ }
+ engine = instantiate_from_config(ckpt)
+ engine(self)
@abstractmethod
def get_input(self, batch) -> Any:
@@ -119,7 +97,9 @@ def decode(self, *args, **kwargs) -> torch.Tensor:
def instantiate_optimizer_from_config(self, params, lr, cfg):
logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
- return get_obj_from_str(cfg["target"])(params, lr=lr, **cfg.get("params", dict()))
+ return get_obj_from_str(cfg["target"])(
+ params, lr=lr, **cfg.get("params", dict())
+ )
def configure_optimizers(self) -> Any:
raise NotImplementedError()
@@ -216,12 +196,13 @@ def get_last_layer(self):
return self.decoder.get_last_layer()
def encode(
- self,
- x: torch.Tensor,
- return_reg_log: bool = False,
- unregularized: bool = False,
+ self,
+ x: torch.Tensor,
+ return_reg_log: bool = False,
+ unregularized: bool = False,
+ **kwargs,
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
- z = self.encoder(x)
+ z = self.encoder(x, **kwargs)
if unregularized:
return z, dict()
z, reg_log = self.regularization(z)
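Because encode() now forwards **kwargs to the encoder, caller-side flags such as the use_cp toggle introduced in cp_enc_dec.py (next file) can be passed straight through. A minimal sketch; autoencoder and frames are illustrative names, not identifiers from this diff:

```python
# **kwargs are forwarded to self.encoder(x, **kwargs), so the fake
# context-parallel path can be toggled per call; use_cp is the flag
# added to Encoder.forward in cp_enc_dec.py below.
z = autoencoder.encode(frames, use_cp=False)
z, reg_log = autoencoder.encode(frames, return_reg_log=True, use_cp=False)
```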
diff --git a/sat/vae_modules/cp_enc_dec.py b/sat/vae_modules/cp_enc_dec.py
index 1d9c34f5..4d773240 100644
--- a/sat/vae_modules/cp_enc_dec.py
+++ b/sat/vae_modules/cp_enc_dec.py
@@ -101,8 +101,6 @@ def _gather(input_, dim):
group = get_context_parallel_group()
cp_rank = get_context_parallel_rank()
- # print('in _gather, cp_rank:', cp_rank, 'input_size:', input_.shape)
-
input_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous()
if cp_rank == 0:
input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()
@@ -127,12 +125,9 @@ def _gather(input_, dim):
def _conv_split(input_, dim, kernel_size):
cp_world_size = get_context_parallel_world_size()
- # Bypass the function if context parallel is 1
if cp_world_size == 1:
return input_
- # print('in _conv_split, cp_rank:', cp_rank, 'input_size:', input_.shape)
-
cp_rank = get_context_parallel_rank()
dim_size = (input_.size()[dim] - kernel_size) // cp_world_size
@@ -140,14 +135,11 @@ def _conv_split(input_, dim, kernel_size):
if cp_rank == 0:
output = input_.transpose(dim, 0)[: dim_size + kernel_size].transpose(dim, 0)
else:
- # output = input_.transpose(dim, 0)[cp_rank * dim_size + 1:(cp_rank + 1) * dim_size + kernel_size].transpose(dim, 0)
output = input_.transpose(dim, 0)[
cp_rank * dim_size + kernel_size : (cp_rank + 1) * dim_size + kernel_size
].transpose(dim, 0)
output = output.contiguous()
- # print('out _conv_split, cp_rank:', cp_rank, 'input_size:', output.shape)
-
return output
@@ -160,9 +152,6 @@ def _conv_gather(input_, dim, kernel_size):
group = get_context_parallel_group()
cp_rank = get_context_parallel_rank()
-
- # print('in _conv_gather, cp_rank:', cp_rank, 'input_size:', input_.shape)
-
input_first_kernel_ = input_.transpose(0, dim)[:kernel_size].transpose(0, dim).contiguous()
if cp_rank == 0:
input_ = input_.transpose(0, dim)[kernel_size:].transpose(0, dim).contiguous()
@@ -255,17 +244,12 @@ def _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding=Non
if recv_rank % cp_world_size == cp_world_size - 1:
recv_rank += cp_world_size
- # req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group)
- # recv_buffer = torch.empty_like(input_[-kernel_size + 1:]).contiguous()
- # req_recv = torch.distributed.recv(recv_buffer, recv_rank, group=group)
- # req_recv.wait()
recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous()
if cp_rank < cp_world_size - 1:
req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group)
if cp_rank > 0:
req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)
- # req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group)
- # req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)
+
if cp_rank == 0:
if cache_padding is not None:
@@ -421,7 +405,6 @@ def forward(self, input_):
def Normalize(in_channels, gather=False, **kwargs):
- # same for 3D and 2D
if gather:
return ContextParallelGroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
else:
@@ -468,8 +451,8 @@ def __init__(
kernel_size=1,
)
- def forward(self, f, zq, clear_fake_cp_cache=True, fake_cp_rank0=True):
- if f.shape[2] > 1 and get_context_parallel_rank() == 0 and fake_cp_rank0:
+ def forward(self, f, zq, clear_fake_cp_cache=True, fake_cp=True):
+ if f.shape[2] > 1 and get_context_parallel_rank() == 0 and fake_cp:
f_first, f_rest = f[:, :, :1], f[:, :, 1:]
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
@@ -531,13 +514,11 @@ def __init__(
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
self.compress_time = compress_time
- def forward(self, x, fake_cp_rank0=True):
+ def forward(self, x, fake_cp=True):
if self.compress_time and x.shape[2] > 1:
- if get_context_parallel_rank() == 0 and fake_cp_rank0:
- # print(x.shape)
+ if get_context_parallel_rank() == 0 and fake_cp:
# split first frame
x_first, x_rest = x[:, :, 0], x[:, :, 1:]
-
x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest")
splits = torch.split(x_rest, 32, dim=1)
@@ -545,8 +526,6 @@ def forward(self, x, fake_cp_rank0=True):
torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
]
x_rest = torch.cat(interpolated_splits, dim=1)
-
- # x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest")
x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
else:
splits = torch.split(x, 32, dim=1)
@@ -555,13 +534,10 @@ def forward(self, x, fake_cp_rank0=True):
]
x = torch.cat(interpolated_splits, dim=1)
- # x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
-
else:
# only interpolate 2D
t = x.shape[2]
x = rearrange(x, "b c t h w -> (b t) c h w")
- # x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
splits = torch.split(x, 32, dim=1)
interpolated_splits = [
@@ -590,12 +566,12 @@ def __init__(self, in_channels, with_conv, compress_time=False, out_channels=Non
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
self.compress_time = compress_time
- def forward(self, x, fake_cp_rank0=True):
+ def forward(self, x, fake_cp=True):
if self.compress_time and x.shape[2] > 1:
h, w = x.shape[-2:]
x = rearrange(x, "b c t h w -> (b h w) c t")
- if get_context_parallel_rank() == 0 and fake_cp_rank0:
+ if get_context_parallel_rank() == 0 and fake_cp:
# split first frame
x_first, x_rest = x[..., 0], x[..., 1:]
@@ -693,17 +669,13 @@ def __init__(
padding=0,
)
- def forward(self, x, temb, zq=None, clear_fake_cp_cache=True, fake_cp_rank0=True):
+ def forward(self, x, temb, zq=None, clear_fake_cp_cache=True, fake_cp=True):
h = x
- # if isinstance(self.norm1, torch.nn.GroupNorm):
- # h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
if zq is not None:
- h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
+ h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=fake_cp)
else:
h = self.norm1(h)
- # if isinstance(self.norm1, torch.nn.GroupNorm):
- # h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
h = nonlinearity(h)
h = self.conv1(h, clear_cache=clear_fake_cp_cache)
@@ -711,14 +683,10 @@ def forward(self, x, temb, zq=None, clear_fake_cp_cache=True, fake_cp_rank0=True
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]
- # if isinstance(self.norm2, torch.nn.GroupNorm):
- # h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
if zq is not None:
- h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
+ h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=fake_cp)
else:
h = self.norm2(h)
- # if isinstance(self.norm2, torch.nn.GroupNorm):
- # h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
h = nonlinearity(h)
h = self.dropout(h)
@@ -827,32 +795,33 @@ def __init__(
kernel_size=3,
)
- def forward(self, x, clear_fake_cp_cache=True, fake_cp_rank0=True):
+ def forward(self, x, use_cp=True):
+ global _USE_CP
+ _USE_CP = use_cp
+
# timestep embedding
temb = None
# downsampling
- h = self.conv_in(x, clear_cache=clear_fake_cp_cache)
+ hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
- h = self.down[i_level].block[i_block](h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
+ h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0:
- print("Attention not implemented")
h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
if i_level != self.num_resolutions - 1:
- h = self.down[i_level].downsample(h, fake_cp_rank0=fake_cp_rank0)
+ hs.append(self.down[i_level].downsample(hs[-1]))
# middle
- h = self.mid.block_1(h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
- h = self.mid.block_2(h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.block_2(h, temb)
# end
- # h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
h = self.norm_out(h)
- # h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
-
h = nonlinearity(h)
- h = self.conv_out(h, clear_cache=clear_fake_cp_cache)
+ h = self.conv_out(h)
return h
@@ -895,11 +864,9 @@ def __init__(
zq_ch = z_channels
# compute in_ch_mult, block_in and curr_res at lowest res
- in_ch_mult = (1,) + tuple(ch_mult)
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
- print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
self.conv_in = ContextParallelCausalConv3d(
chan_in=z_channels,
@@ -955,11 +922,6 @@ def __init__(
up.block = block
up.attn = attn
if i_level != 0:
- # # Symmetrical enc-dec
- if i_level <= self.temporal_compress_level:
- up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True)
- else:
- up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
if i_level < self.num_resolutions - self.temporal_compress_level:
up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
else:
@@ -974,7 +936,9 @@ def __init__(
kernel_size=3,
)
- def forward(self, z, clear_fake_cp_cache=True, fake_cp_rank0=True):
+ def forward(self, z, clear_fake_cp_cache=True, use_cp=True):
+ global _USE_CP
+ _USE_CP = use_cp
self.last_z_shape = z.shape
# timestep embedding
@@ -987,25 +951,25 @@ def forward(self, z, clear_fake_cp_cache=True, fake_cp_rank0=True):
h = self.conv_in(z, clear_cache=clear_fake_cp_cache)
# middle
- h = self.mid.block_1(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
- h = self.mid.block_2(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
+ h = self.mid.block_1(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp)
+ h = self.mid.block_2(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](
- h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0
+ h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp
)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h, zq)
if i_level != 0:
- h = self.up[i_level].upsample(h, fake_cp_rank0=fake_cp_rank0)
+ h = self.up[i_level].upsample(h, fake_cp=use_cp)
# end
if self.give_pre_end:
return h
- h = self.norm_out(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
+ h = self.norm_out(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp)
h = nonlinearity(h)
h = self.conv_out(h, clear_cache=clear_fake_cp_cache)