-
Notifications
You must be signed in to change notification settings - Fork 499
/
generate.py
executable file
·123 lines (96 loc) · 5.13 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Generate images using pretrained network pickle."""
import argparse
import os
import pickle
import re
import numpy as np
import PIL.Image
import dnnlib
import dnnlib.tflib as tflib
#----------------------------------------------------------------------------
def generate_images(network_pkl, seeds, truncation_psi, outdir, class_idx, dlatents_npz):
    """Render PNG images from a pretrained network pickle into `outdir`.

    Two modes:
      * `dlatents_npz` is given: render every dlatent stored under the
        'dlatents' key of that .npz file (`seeds`, `truncation_psi` and
        `class_idx` are not used).
      * otherwise: render one image per seed in `seeds`.

    Args:
        network_pkl: Path/URL of the network pickle, opened with
            dnnlib.util.open_url and unpickled as a (G, D, Gs) triple.
        seeds: List of integer seeds (seed mode only).
        truncation_psi: Forwarded to Gs.run as `truncation_psi`; when None the
            keyword is omitted.
        outdir: Output directory; created if it does not exist.
        class_idx: Column of the conditioning label set to 1, or None to keep
            the label all zeros.
        dlatents_npz: Path of a .npz file containing a 'dlatents' array, or None.

    Raises:
        ValueError: if neither `seeds` nor `dlatents_npz` is provided, if the
            stored dlatents do not match the synthesis network's input shape,
            or if `class_idx` is outside the network's label width.
    """
    # Fail before the (slow) network load when there is nothing to render.
    if dlatents_npz is None and seeds is None:
        raise ValueError('either seeds or dlatents_npz must be provided')

    tflib.init_tf()
    print('Loading networks from "%s"...' % network_pkl)
    # NOTE(review): pickle.load can execute arbitrary code embedded in the file;
    # only load network pickles from sources you trust.
    with dnnlib.util.open_url(network_pkl) as fp:
        _G, _D, Gs = pickle.load(fp)
    os.makedirs(outdir, exist_ok=True)

    if dlatents_npz is not None:
        _generate_from_dlatents(Gs, dlatents_npz, outdir)
    else:
        _generate_from_seeds(Gs, seeds, truncation_psi, outdir, class_idx)


def _generate_from_dlatents(Gs, dlatents_npz, outdir):
    """Render each dlatent in `dlatents_npz` to `outdir/dlatentNN.png`."""
    print(f'Generating images from dlatents file "{dlatents_npz}"')
    # Use the NpzFile as a context manager so the underlying file handle is closed.
    with np.load(dlatents_npz) as data:
        dlatents = data['dlatents']
    synthesis = Gs.components.synthesis
    # Validate against the synthesis network's own input shape instead of a
    # hard-coded (18, 512), which only matches one network configuration.
    # Raise rather than assert so the check survives `python -O`.
    expected = tuple(synthesis.input_shape[1:])
    if tuple(dlatents.shape[1:]) != expected:
        raise ValueError(
            f'dlatents in "{dlatents_npz}" have shape {tuple(dlatents.shape)}, '
            f'but the synthesis network expects (N,) + {expected}'
        )
    imgs = synthesis.run(dlatents, output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True))
    for i, img in enumerate(imgs):
        fname = f'{outdir}/dlatent{i:02d}.png'
        PIL.Image.fromarray(img, 'RGB').save(fname)
        # Report only after the file has actually been written.
        print(f'Saved {fname}')


def _generate_from_seeds(Gs, seeds, truncation_psi, outdir, class_idx):
    """Render one image per seed to `outdir/seedNNNN.png`."""
    Gs_kwargs = {
        'output_transform': dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True),
        'randomize_noise': False
    }
    if truncation_psi is not None:
        Gs_kwargs['truncation_psi'] = truncation_psi
    # Noise variables of the synthesis network; they are overwritten from the
    # per-seed RNG below (randomize_noise=False), so each image depends only on its seed.
    noise_vars = [var for name, var in Gs.components.synthesis.vars.items() if name.startswith('noise')]
    # Conditioning label with width taken from Gs's second input.
    label = np.zeros([1] + Gs.input_shapes[1][1:])
    if class_idx is not None:
        num_labels = label.shape[1]
        # Reject out-of-range and negative indices (negative would silently pick a class from the end).
        if not 0 <= class_idx < num_labels:
            raise ValueError(f'class_idx={class_idx} is out of range for a network with {num_labels} label(s)')
        label[:, class_idx] = 1
    for seed_idx, seed in enumerate(seeds):
        print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
        rnd = np.random.RandomState(seed)
        # Draw z first, then the noise inputs, so outputs stay reproducible per seed.
        z = rnd.randn(1, *Gs.input_shape[1:]) # [minibatch, component]
        tflib.set_vars({var: rnd.randn(*var.shape.as_list()) for var in noise_vars}) # [height, width]
        images = Gs.run(z, label, **Gs_kwargs) # [minibatch, height, width, channel]
        PIL.Image.fromarray(images[0], 'RGB').save(f'{outdir}/seed{seed:04d}.png')
#----------------------------------------------------------------------------
def _parse_num_range(s):
'''Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.'''
range_re = re.compile(r'^(\d+)-(\d+)$')
m = range_re.match(s)
if m:
return list(range(int(m.group(1)), int(m.group(2))+1))
vals = s.split(',')
return [int(x) for x in vals]
#----------------------------------------------------------------------------
# Usage examples appended to the --help output: main() passes this string as the
# ArgumentParser epilog (with RawDescriptionHelpFormatter, so line breaks are kept),
# and argparse substitutes %(prog)s with the program name. The doubled
# backslashes render as single "\" shell line continuations.
_examples = '''examples:
# Generate curated MetFaces images without truncation (Fig.10 left)
python %(prog)s --outdir=out --trunc=1 --seeds=85,265,297,849 \\
--network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metfaces.pkl
# Generate uncurated MetFaces images with truncation (Fig.12 upper left)
python %(prog)s --outdir=out --trunc=0.7 --seeds=600-605 \\
--network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metfaces.pkl
# Generate class conditional CIFAR-10 images (Fig.17 left, Car)
python %(prog)s --outdir=out --trunc=1 --seeds=0-35 --class=1 \\
--network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/cifar10.pkl
# Render image from projected latent vector
python %(prog)s --outdir=out --dlatents=out/dlatents.npz \\
--network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/ffhq.pkl
'''
#----------------------------------------------------------------------------
def main():
    """Command-line entry point: parse arguments and forward them to generate_images().

    Every option's `dest` matches a parameter of generate_images(), so the parsed
    namespace is passed straight through as keyword arguments.
    """
    arg_parser = argparse.ArgumentParser(
        description='Generate images using pretrained network pickle.',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=_examples,
    )
    add_option = arg_parser.add_argument

    add_option('--network', dest='network_pkl', required=True, help='Network pickle filename')

    # Exactly one input source is required: random seeds or a saved dlatents file.
    source = arg_parser.add_mutually_exclusive_group(required=True)
    source.add_argument('--seeds', type=_parse_num_range, help='List of random seeds')
    source.add_argument('--dlatents', dest='dlatents_npz', help='Generate images for saved dlatents')

    add_option('--trunc', dest='truncation_psi', type=float, default=0.5, help='Truncation psi (default: %(default)s)')
    add_option('--class', dest='class_idx', type=int, help='Class label (default: unconditional)')
    add_option('--outdir', metavar='DIR', required=True, help='Where to save the output images')

    cli_kwargs = vars(arg_parser.parse_args())
    generate_images(**cli_kwargs)
#----------------------------------------------------------------------------
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
#----------------------------------------------------------------------------