-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
89 lines (80 loc) · 3.21 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import gradio as gr
from PIL import Image
import torch
from segment_anything import SamPredictor, sam_model_registry
from diffusers import StableDiffusionInpaintPipeline
from groundingdino.util.inference import load_model, predict, annotate
from groundingdino.util import box_ops
import numpy as np
from torchvision import transforms
from torchvision.transforms import ToPILImage
from torchvision import transforms as T
# Every model below is placed on this device.
device = "cuda"
# Segment Anything architecture key; the checkpoint loaded below is the ViT-H weights file.
model_type = "vit_h"

# Segment Anything, wrapped in a predictor that turns box prompts into masks.
_sam_model = sam_model_registry[model_type](checkpoint="sam_vit_h_4b8939.pth")
predictor = SamPredictor(_sam_model.to(device=device))

# Stable Diffusion 2 inpainting pipeline, loaded in half precision.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=torch.float16,
).to(device)

# GroundingDINO detector built from its config file and weights checkpoint.
groundingdino_model = load_model(
    "GroundingDINO_SwinT_OGC.py",
    "groundingdino_swint_ogc.pth",
)
def show_mask(mask, image, random_color=True):
    """Overlay one binary mask on an image and return the RGBA composite.

    Args:
        mask: Binary mask whose last two dimensions are (H, W). May be a NumPy
            array or a torch tensor on any device.
        image: NumPy image accepted by ``Image.fromarray`` (presumably uint8
            RGB, e.g. the array returned by ``load_image``). Its size must match
            the mask's for ``Image.alpha_composite`` to succeed.
        random_color: If True, tint the mask with a random RGB color at 0.8
            alpha. Otherwise use a fixed blue at 0.6 alpha.

    Returns:
        NumPy RGBA array of ``image`` with the tinted mask composited on top.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.8])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    # Bring torch masks to host NumPy *before* combining them with the NumPy
    # color. Multiplying a CUDA tensor by an ndarray raises a device-mismatch
    # error, which is what happened for masks taken straight from SAM.
    if isinstance(mask, torch.Tensor):
        mask = mask.detach().cpu().numpy()
    mask = np.asarray(mask)
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    annotated_frame_pil = Image.fromarray(image).convert("RGBA")
    mask_image_pil = Image.fromarray((mask_image * 255).astype(np.uint8)).convert("RGBA")
    return np.array(Image.alpha_composite(annotated_frame_pil, mask_image_pil))
def process_boxes(boxes, src):
    """Convert detector boxes into SAM box prompts.

    Args:
        boxes: Tensor of boxes in normalized (cx, cy, w, h) form, the layout
            ``box_cxcywh_to_xyxy`` consumes. These are the boxes returned by
            GroundingDINO's ``predict`` in ``edit_image``.
        src: Source image array. Only its first two dimensions (H, W) are used.

    Returns:
        Boxes in absolute (x0, y0, x1, y1) pixels, mapped through
        ``predictor.transform.apply_boxes_torch`` and moved to ``device``.
    """
    H, W = src.shape[:2]
    # Build the pixel scale on the boxes' own device/dtype. The legacy
    # torch.Tensor([...]) constructor always makes a CPU float32 tensor, which
    # breaks the multiply if the boxes ever live on the GPU.
    scale = torch.tensor([W, H, W, H], dtype=boxes.dtype, device=boxes.device)
    boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * scale
    return predictor.transform.apply_boxes_torch(boxes_xyxy, src.shape[:2]).to(device)
def edit_image(image, item, prompt, box_threshold, text_threshold):
    """Detect ``item`` in ``image``, segment it, and inpaint it according to ``prompt``.

    GroundingDINO finds boxes matching the ``item`` caption. SAM turns those
    boxes into pixel masks. The Stable Diffusion inpainting pipeline then
    repaints the masked region using ``prompt``.

    Args:
        image: Input PIL image.
        item: Caption describing what to find and replace.
        prompt: Text prompt for the inpainting pipeline.
        box_threshold: Box threshold forwarded to GroundingDINO's ``predict``.
        text_threshold: Text threshold forwarded to GroundingDINO's ``predict``.

    Returns:
        The inpainted PIL image, resized back to the input image's size.

    Raises:
        ValueError: If no region matching ``item`` passes the thresholds.
    """
    src, img = load_image(image)
    boxes, _, _ = predict(
        model=groundingdino_model,
        image=img,
        caption=item,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
    )
    # With no boxes, SAM yields no masks and indexing them fails with an
    # opaque IndexError. Report the actual problem to the user instead.
    if len(boxes) == 0:
        raise ValueError(
            f"No region matching {item!r} was detected; "
            "try a different item description or lower the thresholds."
        )
    predictor.set_image(src)
    new_boxes = process_boxes(boxes, src)
    masks, _, _ = predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=new_boxes,
        multimask_output=False,
    )
    # With multimask_output=False there is one mask per box along dim 0.
    # Union them so every detected instance of ``item`` is inpainted, not just
    # whichever box happened to be listed first.
    combined_mask = masks.any(dim=0)[0].cpu().numpy()
    edited_image = pipe(
        prompt=prompt,
        image=image.resize((512, 512)),
        mask_image=Image.fromarray(combined_mask).resize((512, 512)),
    ).images[0]
    # The pipeline runs on a fixed 512x512 canvas; restore the caller's
    # original dimensions (and therefore aspect ratio).
    return edited_image.resize(image.size)
def load_image(image):
    """Prepare an uploaded PIL image for both SAM and GroundingDINO.

    Args:
        image: Input PIL image in any mode (e.g. RGBA or grayscale uploads).

    Returns:
        A tuple ``(image_np, image_tensor)``:

        - ``image_np``: the RGB image as a NumPy array, used for
          ``predictor.set_image``.
        - ``image_tensor``: the image resized (shorter side 800 px), converted
          to a tensor and normalized with the per-channel mean/std below, used
          for GroundingDINO's ``predict``.

    Raises:
        ValueError: If no image was provided (``image`` is None).
    """
    if image is None:
        raise ValueError("No image provided; please upload an image.")
    # Normalize carries exactly three channel statistics, and the NumPy copy is
    # handed to SAM with its default RGB format. Force 3-channel RGB so RGBA,
    # grayscale or palette uploads do not break either consumer.
    image = image.convert("RGB")
    transform = T.Compose(
        [
            T.Resize(800),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    image_transformed = transform(image)
    return np.asarray(image), image_transformed
def gradio_interface(image, item, prompt, box_threshold, text_threshold):
    """Gradio callback that forwards the widget values to ``edit_image``.

    The parameter order mirrors the ``inputs`` list of the ``gr.Interface``
    built in this module: image, item, prompt, box threshold, text threshold.
    """
    return edit_image(
        image=image,
        item=item,
        prompt=prompt,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
    )
# Input widgets, listed in the positional order gradio_interface receives them.
_input_components = [
    gr.Image(type="pil"),
    gr.Textbox(label="Item"),
    gr.Textbox(label="Prompt"),
    gr.Slider(minimum=0, maximum=1, value=0.5, label="Box Threshold"),
    gr.Slider(minimum=0, maximum=1, value=0.2, label="Text Threshold"),
]

iface = gr.Interface(
    fn=gradio_interface,
    inputs=_input_components,
    outputs=gr.Image(type="pil"),
    title="Image Inpainting",
    description="Upload an image and specify your editing criteria to see the result.",
)
# Start the web UI; this runs as soon as the module is executed.
iface.launch()