-
Notifications
You must be signed in to change notification settings - Fork 150
/
convert_ckpt.py
44 lines (26 loc) · 1.3 KB
/
convert_ckpt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
import argparse
def add_additional_channels(state_dict, num_additional_channels):
    """Append zero-filled input channels to the ``input_blocks.0.0.weight`` entry.

    The tensor stored under ``"input_blocks.0.0.weight"`` is replaced by a new
    tensor that keeps the existing weights as its leading input channels
    (dimension 1) and appends ``num_additional_channels`` channels of zeros.
    The output-channel count, kernel dimensions, dtype and device are taken
    from the existing weight rather than being hard-coded.

    Args:
        state_dict: Mapping from parameter names to tensors. It should be just
            from the unet model, not the entire SD or GLIGEN checkpoint.
        num_additional_channels: Number of zero-filled input channels to
            append. ``0`` leaves ``state_dict`` untouched.

    Returns:
        None. ``state_dict`` is modified in place.

    Raises:
        ValueError: If ``num_additional_channels`` is negative.
    """
    if num_additional_channels < 0:
        raise ValueError(
            "num_additional_channels must be non-negative, "
            f"got {num_additional_channels}"
        )
    if num_additional_channels == 0:
        return

    key = "input_blocks.0.0.weight"
    if key not in state_dict:
        # A state dict without the target entry is left unchanged.
        return

    old_conv_weight = state_dict[key]
    out_channels = old_conv_weight.shape[0]
    trailing_dims = tuple(old_conv_weight.shape[2:])

    # new_zeros inherits dtype and device from the existing weight, so a
    # half-precision or non-CPU weight is not silently turned into float32/CPU.
    zero_channels = old_conv_weight.new_zeros(
        (out_channels, num_additional_channels, *trailing_dims)
    )
    state_dict[key] = torch.cat([old_conv_weight, zero_channels], dim=1)
if __name__ == "__main__":
    # Adds 5 additional input channels (for inpainting) to the
    # "input_blocks.0.0.weight" entry of a GLIGEN ckpt and writes the result
    # to a new checkpoint file.
    parser = argparse.ArgumentParser()
    # Both paths are required: without them torch.load / torch.save would be
    # handed None and fail with an unhelpful error.
    parser.add_argument("--ckpt_path", type=str, default=None, required=True, help="")
    parser.add_argument("--new_ckpt_path", type=str, default=None, required=True, help="")
    args = parser.parse_args()

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only run this on checkpoints from a trusted source.
    ckpt = torch.load(args.ckpt_path, map_location="cpu")

    # 4 + 1 = 5 extra channels, matching the previously hard-coded 4+4+1
    # input-channel shape. Reuses the helper instead of duplicating its logic.
    add_additional_channels(ckpt["model"], 4 + 1)

    # Only the "model" entry is written to the converted checkpoint.
    torch.save({"model": ckpt["model"]}, args.new_ckpt_path)