Skip to content

Commit

Permalink
interposer prototype
Browse files Browse the repository at this point in the history
Signed-off-by: Vladimir Mandic <[email protected]>
  • Loading branch information
vladmandic committed Feb 13, 2025
1 parent f94196b commit d73a185
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 3 deletions.
5 changes: 3 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@

### Highlight for 2025-02-13

We're back with another update!
- Starting with massive UI update with full [localization](https://vladmandic.github.io/sdnext-docs/Locale/) for 8 languages and 100+ new [hints](https://vladmandic.github.io/sdnext-docs/Hints/)
We're back with another update with over 50 commits!
- Starting with massive UI update with full [localization](https://vladmandic.github.io/sdnext-docs/Locale/) for 8 languages
and 100+ new [hints](https://vladmandic.github.io/sdnext-docs/Hints/)
- Big update to [Docker](https://vladmandic.github.io/sdnext-docs/Docker/) containers with support for all major compute platforms
- A lot of [outpainting](https://vladmandic.github.io/sdnext-docs/Outpaint/) goodies
- Support for new models: [AlphaVLLM Lumina 2](https://github.com/Alpha-VLLM/Lumina-Image-2.0) and [Ostris Flex.1-Alpha](https://huggingface.co/ostris/Flex.1-alpha)
Expand Down
164 changes: 164 additions & 0 deletions modules/interposer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
# converted from <https://github.com/city96/SD-Latent-Interposer>

import os
import time
import torch
import torch.nn as nn
from safetensors.torch import load_file


# v1 = Stable Diffusion 1.x
# xl = Stable Diffusion Extra Large (SDXL)
# v3 = Stable Diffusion Version Three (SD3)
# fx = Black Forest Labs Flux dot One
# cc = Stable Cascade (Stage C) [not used]
# ca = Stable Cascade (Stage A/B)
config = {
"v1-to-xl": {"ch_in": 4, "ch_out": 4, "ch_mid": 64, "scale": 1.0, "blocks": 12},
"v1-to-v3": {"ch_in": 4, "ch_out":16, "ch_mid": 64, "scale": 1.0, "blocks": 12},
"xl-to-v1": {"ch_in": 4, "ch_out": 4, "ch_mid": 64, "scale": 1.0, "blocks": 12},
"xl-to-v3": {"ch_in": 4, "ch_out":16, "ch_mid": 64, "scale": 1.0, "blocks": 12},
"v3-to-v1": {"ch_in":16, "ch_out": 4, "ch_mid": 64, "scale": 1.0, "blocks": 12},
"v3-to-xl": {"ch_in":16, "ch_out": 4, "ch_mid": 64, "scale": 1.0, "blocks": 12},
"fx-to-v1": {"ch_in":16, "ch_out": 4, "ch_mid": 64, "scale": 1.0, "blocks": 12},
"fx-to-xl": {"ch_in":16, "ch_out": 4, "ch_mid": 64, "scale": 1.0, "blocks": 12},
"fx-to-v3": {"ch_in":16, "ch_out":16, "ch_mid": 64, "scale": 1.0, "blocks": 12},
"ca-to-v1": {"ch_in": 4, "ch_out": 4, "ch_mid": 64, "scale": 0.5, "blocks": 12},
"ca-to-xl": {"ch_in": 4, "ch_out": 4, "ch_mid": 64, "scale": 0.5, "blocks": 12},
"ca-to-v3": {"ch_in": 4, "ch_out":16, "ch_mid": 64, "scale": 0.5, "blocks": 12},
}


class ResBlock(nn.Module):
"""Block with residuals"""
def __init__(self, ch):
super().__init__()
self.join = nn.ReLU()
self.norm = nn.BatchNorm2d(ch)
self.long = nn.Sequential(
nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1),
nn.SiLU(),
nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1),
nn.SiLU(),
nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1),
nn.Dropout(0.1)
)
def forward(self, x):
x = self.norm(x)
return self.join(self.long(x) + x)


class ExtractBlock(nn.Module):
"""Increase no. of channels by [out/in]"""
def __init__(self, ch_in, ch_out):
super().__init__()
self.join = nn.ReLU()
self.short = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1)
self.long = nn.Sequential(
nn.Conv2d( ch_in, ch_out, kernel_size=3, stride=1, padding=1),
nn.SiLU(),
nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1),
nn.SiLU(),
nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1),
nn.Dropout(0.1)
)
def forward(self, x):
return self.join(self.long(x) + self.short(x))


class InterposerModel(nn.Module):
"""
NN layout, ported from:
https://github.com/city96/SD-Latent-Interposer/blob/main/interposer.py
"""
def __init__(self, ch_in=4, ch_out=4, ch_mid=64, scale=1.0, blocks=12):
super().__init__()
self.ch_in = ch_in
self.ch_out = ch_out
self.ch_mid = ch_mid
self.blocks = blocks
self.scale = scale

self.head = ExtractBlock(self.ch_in, self.ch_mid)
self.core = nn.Sequential(
nn.Upsample(scale_factor=self.scale, mode="nearest"),
*[ResBlock(self.ch_mid) for _ in range(blocks)],
nn.BatchNorm2d(self.ch_mid),
nn.SiLU(),
)
self.tail = nn.Conv2d(self.ch_mid, self.ch_out, kernel_size=3, stride=1, padding=1)

def forward(self, x):
y = self.head(x)
z = self.core(y)
return self.tail(z)


def map_model_name(name: str):
if name == 'sd':
return 'v1'
if name == 'sdxl':
return 'xl'
if name == 'sd3':
return 'v3'
if name == 'f1':
return 'fx'
return name


class Interposer:
def __init__(self):
self.version = 4.0 # network revision
self.loaded = None # current model name
self.model = None # current model
self.vae = None # current VAE

def convert(self, src: str, dst: str, latents: torch.Tensor):
from diffusers import AutoencoderKL
from huggingface_hub import hf_hub_download
from modules import shared, devices

src = map_model_name(src)
dst = map_model_name(dst)
if src == dst:
return None
model_name = f"{src}-to-{dst}"
if model_name not in config:
shared.log.error(f'Interposer: model="{model_name}" unknown')
return None
if (self.loaded != model_name) or (self.model is None):
model_fn = hf_hub_download(
repo_id="city96/SD-Latent-Interposer",
subfolder=f"v{self.version}",
filename=f"{model_name}_interposer-v{self.version}.safetensors",
cache_dir=shared.opts.hfcache_dir,
)
self.model = InterposerModel(**config[model_name])
self.model = self.model.to(device=devices.cpu, dtype=torch.float32)
self.model.eval()
self.model.load_state_dict(load_file(model_fn))
self.loaded = model_name
if dst == 'v1':
vae_repo = 'stable-diffusion-v1-5/stable-diffusion-v1-5'
self.vae = AutoencoderKL.from_pretrained(vae_repo, subfolder='vae', cache_dir=shared.opts.hfcache_dir, torch_dtype=devices.dtype)
elif dst == 'xl':
vae_repo = 'madebyollin/sdxl-vae-fp16-fix'
self.vae = AutoencoderKL.from_pretrained(vae_repo, cache_dir=shared.opts.hfcache_dir, torch_dtype=devices.dtype)
elif dst == 'v3':
vae_repo = 'stabilityai/stable-diffusion-3.5-large'
self.vae = AutoencoderKL.from_pretrained(vae_repo, subfolder='vae', cache_dir=shared.opts.hfcache_dir, torch_dtype=devices.dtype)
elif dst == 'fx':
vae_repo = 'black-forest-labs/FLUX.1-dev'
self.vae = AutoencoderKL.from_pretrained(vae_repo, subfolder='vae', cache_dir=shared.opts.hfcache_dir, torch_dtype=devices.dtype)

t0 = time.time()
if self.model is None or self.vae is None:
return None
with torch.no_grad():
latent = latents.clone().cpu().float() # force fp32, always run on CPU
output = self.model(latent)
output = output.to(device=latents.device, dtype=latents.dtype)
t1 = time.time()
shared.log.debug(f'Interposer: src={src}/{list(latents.shape)} dst={dst}/{list(output.shape)} model="{os.path.basename(model_fn)}" vae="{vae_repo}" time={t1-t0:.2f}')
# shared.log.debug(f'Interposer: src={latents.aminmax()} dst={output.aminmax()}')
return output
16 changes: 16 additions & 0 deletions modules/processing_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,18 @@ def full_vqgan_decode(latents, model):
return decoded


def vae_interpose(latents, target):
from modules.interposer import Interposer
interposer = Interposer()
converted = interposer.convert(src=shared.sd_model_type, dst=target, latents=latents)
if converted is None:
return None
interposer.vae = interposer.vae.to(device=devices.device, dtype=devices.dtype)
decoded = interposer.vae.decode(converted, return_dict=False)[0]
interposer.vae = interposer.vae.to(device=devices.cpu)
return decoded


def full_vae_decode(latents, model):
t0 = time.time()
if not hasattr(model, 'vae') and hasattr(model, 'pipe'):
Expand All @@ -96,6 +108,10 @@ def full_vae_decode(latents, model):
devices.torch_gc(force=True)
shared.mem_mon.reset()

# decoded = vae_interpose(latents, shared.opts.vae_interpose)
# if decoded is not None:
# return decoded

base_device = None
if shared.opts.diffusers_move_unet and not getattr(model, 'has_accelerate', False):
base_device = sd_models.move_base(model, devices.cpu)
Expand Down
2 changes: 1 addition & 1 deletion wiki
Submodule wiki updated from 3f91c3 to 068dbd

0 comments on commit d73a185

Please sign in to comment.