You've already forked OpenShot-ComfyUI
mirror of
https://github.com/OpenShot/OpenShot-ComfyUI.git
synced 2026-03-02 08:56:22 -08:00
2855 lines
114 KiB
Python
2855 lines
114 KiB
Python
import json
|
|
import os
|
|
import hashlib
|
|
import subprocess
|
|
import shutil
|
|
import time
|
|
import tempfile
|
|
from contextlib import nullcontext
|
|
from urllib.parse import urlparse
|
|
from fractions import Fraction
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch.hub import download_url_to_file
|
|
from PIL import Image
|
|
|
|
import comfy.model_management as mm
|
|
from comfy.utils import ProgressBar, common_upscale
|
|
import folder_paths
|
|
from hydra import initialize_config_dir
|
|
from hydra.core.global_hydra import GlobalHydra
|
|
|
|
# Optional dependency: SAM2. Import failures are deferred so the node pack
# still loads in environments without sam2; _require_sam2() re-raises the
# captured error on first actual use.
try:
    import sam2.build_sam as sam2_build
    from sam2.build_sam import build_sam2
    from sam2.sam2_image_predictor import SAM2ImagePredictor
except Exception as ex:  # pragma: no cover - runtime env specific
    sam2_build = None
    build_sam2 = None
    SAM2ImagePredictor = None
    _sam2_import_error = ex
else:
    _sam2_import_error = None
|
|
|
|
# Optional dependency: transformers (GroundingDINO). Deferred like SAM2;
# _require_groundingdino() surfaces the captured import error on use.
try:
    from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
except Exception as ex:  # pragma: no cover - runtime env specific
    AutoModelForZeroShotObjectDetection = None
    AutoProcessor = None
    _groundingdino_import_error = ex
else:
    _groundingdino_import_error = None
|
|
|
|
# Optional dependency: TransNetV2 scene detector. Deferred; see _require_transnet().
try:
    from transnetv2_pytorch import TransNetV2 as _TransNetV2
except Exception as ex:  # pragma: no cover - runtime env specific
    _TransNetV2 = None
    _transnet_import_error = ex
else:
    _transnet_import_error = None
|
|
|
|
|
|
# Subdirectory under Comfy's models dir where SAM2 checkpoints are stored.
SAM2_MODEL_DIR = "sam2"
# Version tag printed in logs to identify this node-pack build.
OPENSHOT_NODEPACK_VERSION = "v1.1.2-track-object-keyframes"
# Hugging Face model ids offered for GroundingDINO text-prompted detection.
GROUNDING_DINO_MODEL_IDS = (
    "IDEA-Research/grounding-dino-tiny",
    "IDEA-Research/grounding-dino-base",
)
# Process-wide cache of loaded GroundingDINO (processor, model) pairs,
# keyed by "model_id::device".
GROUNDING_DINO_CACHE = {}
|
|
def _sam2_debug_enabled():
|
|
# Temporary: always-on debug while we diagnose chunk/carry drift.
|
|
return True
|
|
|
|
|
|
def _sam2_debug(*parts):
    """Print a namespaced debug line when SAM2 debugging is enabled."""
    if not _sam2_debug_enabled():
        return
    try:
        print("[OpenShot-SAM2-DEBUG]", *parts)
    except Exception:
        # Logging must never break the node; swallow print failures.
        pass
|
|
|
|
|
|
# Known SAM2.1 checkpoints: download URL plus the Hydra config name used to
# build the model. Keys are the filenames offered in the model dropdown.
SAM2_MODELS = {
    "sam2.1_hiera_tiny.safetensors": {
        "url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt",
        "config": "sam2.1_hiera_t.yaml",
    },
    "sam2.1_hiera_small.safetensors": {
        "url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt",
        "config": "sam2.1_hiera_s.yaml",
    },
    "sam2.1_hiera_base_plus.safetensors": {
        "url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt",
        "config": "sam2.1_hiera_b+.yaml",
    },
    "sam2.1_hiera_large.safetensors": {
        "url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt",
        "config": "sam2.1_hiera_l.yaml",
    },
}
|
|
|
|
|
|
def _require_sam2():
    """Raise a descriptive error if the SAM2 package failed to import."""
    if build_sam2 is not None and SAM2ImagePredictor is not None:
        return
    raise RuntimeError(
        "SAM2 imports failed. Ensure `sam2` is available in Comfy runtime. Error: {}".format(_sam2_import_error)
    )
|
|
|
|
|
|
def _require_groundingdino():
    """Raise a descriptive error if transformers/GroundingDINO failed to import."""
    if AutoModelForZeroShotObjectDetection is not None and AutoProcessor is not None:
        return
    raise RuntimeError(
        "GroundingDINO imports failed. Install requirements and restart ComfyUI. Error: {}".format(
            _groundingdino_import_error
        )
    )
|
|
|
|
|
|
def _require_transnet():
    """Raise a descriptive error if transnetv2-pytorch failed to import."""
    if _TransNetV2 is not None:
        return
    raise RuntimeError(
        "TransNetV2 imports failed. Install `transnetv2-pytorch` and restart ComfyUI. Error: {}".format(
            _transnet_import_error
        )
    )
|
|
|
|
|
|
def _model_storage_dir():
    """Return (creating if needed) the directory that holds SAM2 checkpoints."""
    storage = os.path.join(folder_paths.models_dir, SAM2_MODEL_DIR)
    os.makedirs(storage, exist_ok=True)
    return storage
|
|
|
|
|
|
def _safe_get_filename_list(model_dir_name):
    """List model filenames for a Comfy folder key, tolerating unregistered keys."""
    try:
        return list(folder_paths.get_filename_list(model_dir_name) or [])
    except Exception:
        # Folder key may not be registered in some Comfy installs; fall back
        # to scanning the models directory directly.
        base = os.path.join(folder_paths.models_dir, model_dir_name)
        if not os.path.isdir(base):
            return []
        return sorted(
            entry
            for entry in os.listdir(base)
            if os.path.isfile(os.path.join(base, entry))
        )
|
|
|
|
|
|
def _safe_get_full_path(model_dir_name, name):
    """Resolve a model filename to a full path, with a manual fallback.

    Returns "" when the file cannot be located anywhere.
    """
    try:
        resolved = folder_paths.get_full_path(model_dir_name, name)
    except Exception:
        resolved = None
    if resolved:
        return resolved
    candidate = os.path.join(folder_paths.models_dir, model_dir_name, name)
    return candidate if os.path.exists(candidate) else ""
|
|
|
|
|
|
def _model_options():
    """Dropdown choices: known downloadable SAM2 models first, then extra local files."""
    options = list(SAM2_MODELS.keys())
    local_names = sorted(set(_safe_get_filename_list(SAM2_MODEL_DIR)))
    extras = [name for name in local_names if name not in options]
    options.extend(extras)
    return options
|
|
|
|
|
|
def _download_if_needed(model_name):
    """Return a local checkpoint path, downloading a known SAM2 weight if absent.

    Raises ValueError for an empty name or an unknown model that is not
    already present locally.
    """
    model_name = str(model_name or "").strip()
    if not model_name:
        raise ValueError("Model name is required")

    local = _safe_get_full_path(SAM2_MODEL_DIR, model_name)
    if local and os.path.exists(local):
        return local

    if model_name not in SAM2_MODELS:
        raise ValueError("Model not found locally and no download mapping for '{}'".format(model_name))

    url = SAM2_MODELS[model_name]["url"]
    # Keep the upstream filename (e.g. ".pt") for the downloaded artifact.
    filename = os.path.basename(urlparse(url).path)
    destination = os.path.join(_model_storage_dir(), filename)
    if not os.path.exists(destination):
        download_url_to_file(url, destination)
    return destination
|
|
|
|
|
|
def _resolve_config_candidates(model_name, checkpoint_path):
    """Build an ordered, de-duplicated list of YAML config names to try."""
    names = []

    info = SAM2_MODELS.get(model_name)
    if info and info.get("config"):
        names.append(str(info["config"]))

    # Derive likely config names from the checkpoint basename, covering the
    # "2.1" / "2_1" / "2" naming differences across sam2 package releases.
    stem = os.path.basename(checkpoint_path).replace(".pt", "")
    derived = {
        stem,
        stem.replace("2.1", "2_1"),
        stem.replace("2.1", "2"),
        stem.replace("sam2.1", "sam2"),
        stem.replace("sam2_1", "sam2"),
    }
    names.extend("{}.yaml".format(variant) for variant in sorted(derived))

    # Preserve first-seen order while dropping duplicates.
    seen = set()
    unique = []
    for candidate in names:
        if candidate not in seen:
            seen.add(candidate)
            unique.append(candidate)
    return unique
|
|
|
|
|
|
def _pack_config_dir():
|
|
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "sam2_configs")
|
|
|
|
|
|
def _init_hydra_for_local_configs():
    """(Re)initialize Hydra to read configs from the bundled sam2_configs dir.

    Raises RuntimeError when the bundled directory is missing.
    """
    cfg_dir = _pack_config_dir()
    if not os.path.isdir(cfg_dir):
        raise RuntimeError("OpenShot SAM2 config directory not found: {}".format(cfg_dir))
    hydra_state = GlobalHydra.instance()
    if hydra_state.is_initialized():
        # Hydra is process-global; clear any previous init before re-pointing it.
        hydra_state.clear()
    initialize_config_dir(config_dir=cfg_dir, version_base=None)
|
|
|
|
|
|
def _to_device_dtype(device_name, precision):
|
|
device_name = str(device_name or "").strip().lower()
|
|
if device_name in ("", "auto"):
|
|
device = mm.get_torch_device()
|
|
elif device_name == "cpu":
|
|
device = torch.device("cpu")
|
|
elif device_name == "cuda":
|
|
device = torch.device("cuda")
|
|
elif device_name == "mps":
|
|
device = torch.device("mps")
|
|
else:
|
|
device = mm.get_torch_device()
|
|
|
|
precision = str(precision or "fp16").strip().lower()
|
|
if precision == "bf16":
|
|
dtype = torch.bfloat16
|
|
elif precision == "fp32":
|
|
dtype = torch.float32
|
|
else:
|
|
dtype = torch.float16
|
|
return device, dtype
|
|
|
|
|
|
def _parse_points(text):
|
|
text = str(text or "").strip()
|
|
if not text:
|
|
return []
|
|
try:
|
|
parsed = json.loads(text.replace("'", '"'))
|
|
except Exception:
|
|
return []
|
|
if not isinstance(parsed, list):
|
|
return []
|
|
pts = []
|
|
for item in parsed:
|
|
if not isinstance(item, dict):
|
|
continue
|
|
try:
|
|
pts.append((float(item["x"]), float(item["y"])))
|
|
except Exception:
|
|
continue
|
|
return pts
|
|
|
|
|
|
def _parse_rects(text):
|
|
text = str(text or "").strip()
|
|
if not text:
|
|
return []
|
|
try:
|
|
parsed = json.loads(text.replace("'", '"'))
|
|
except Exception:
|
|
return []
|
|
if not isinstance(parsed, list):
|
|
return []
|
|
out = []
|
|
for item in parsed:
|
|
if not isinstance(item, dict):
|
|
continue
|
|
if all(k in item for k in ("x1", "y1", "x2", "y2")):
|
|
try:
|
|
x1 = float(item["x1"])
|
|
y1 = float(item["y1"])
|
|
x2 = float(item["x2"])
|
|
y2 = float(item["y2"])
|
|
except Exception:
|
|
continue
|
|
elif all(k in item for k in ("x", "y", "w", "h")):
|
|
try:
|
|
x1 = float(item["x"])
|
|
y1 = float(item["y"])
|
|
x2 = x1 + float(item["w"])
|
|
y2 = y1 + float(item["h"])
|
|
except Exception:
|
|
continue
|
|
else:
|
|
continue
|
|
out.append((x1, y1, x2, y2))
|
|
return out
|
|
|
|
|
|
def _parse_tracking_selection(text):
|
|
text = str(text or "").strip()
|
|
if not text:
|
|
return {"seed_frame_idx": 0, "schedule": {}}
|
|
try:
|
|
parsed = json.loads(text.replace("'", '"'))
|
|
except Exception:
|
|
return {"seed_frame_idx": 0, "schedule": {}}
|
|
if not isinstance(parsed, dict):
|
|
return {"seed_frame_idx": 0, "schedule": {}}
|
|
|
|
try:
|
|
seed_frame_idx = max(0, int(parsed.get("seed_frame", 1)) - 1)
|
|
except Exception:
|
|
seed_frame_idx = 0
|
|
|
|
frames = parsed.get("frames", {})
|
|
if not isinstance(frames, dict):
|
|
frames = {}
|
|
|
|
schedule = {}
|
|
for frame_key, frame_data in frames.items():
|
|
if not isinstance(frame_data, dict):
|
|
continue
|
|
try:
|
|
frame_idx = int(frame_key)
|
|
except Exception:
|
|
continue
|
|
frame_idx = max(0, frame_idx - 1)
|
|
|
|
pos = []
|
|
neg = []
|
|
for item in frame_data.get("positive_points", []) or []:
|
|
if not isinstance(item, dict):
|
|
continue
|
|
try:
|
|
pos.append((float(item["x"]), float(item["y"])))
|
|
except Exception:
|
|
continue
|
|
for item in frame_data.get("negative_points", []) or []:
|
|
if not isinstance(item, dict):
|
|
continue
|
|
try:
|
|
neg.append((float(item["x"]), float(item["y"])))
|
|
except Exception:
|
|
continue
|
|
|
|
pos_rects = []
|
|
neg_rects = []
|
|
for item in frame_data.get("positive_rects", []) or []:
|
|
if not isinstance(item, dict):
|
|
continue
|
|
try:
|
|
pos_rects.append(
|
|
(
|
|
float(item["x1"]),
|
|
float(item["y1"]),
|
|
float(item["x2"]),
|
|
float(item["y2"]),
|
|
)
|
|
)
|
|
except Exception:
|
|
continue
|
|
for item in frame_data.get("negative_rects", []) or []:
|
|
if not isinstance(item, dict):
|
|
continue
|
|
try:
|
|
neg_rects.append(
|
|
(
|
|
float(item["x1"]),
|
|
float(item["y1"]),
|
|
float(item["x2"]),
|
|
float(item["y2"]),
|
|
)
|
|
)
|
|
except Exception:
|
|
continue
|
|
|
|
points = []
|
|
labels = []
|
|
object_prompts = []
|
|
for idx, (x, y) in enumerate(pos):
|
|
obj_id = int(idx)
|
|
points.append((x, y))
|
|
labels.append(1)
|
|
object_prompts.append(
|
|
{
|
|
"obj_id": obj_id,
|
|
"points": [(x, y)],
|
|
"labels": [1],
|
|
"positive_rects": [],
|
|
}
|
|
)
|
|
for x, y in neg:
|
|
points.append((x, y))
|
|
labels.append(0)
|
|
for extra_idx, rect in enumerate(pos_rects):
|
|
obj_id = int(len(object_prompts) + extra_idx)
|
|
object_prompts.append(
|
|
{
|
|
"obj_id": obj_id,
|
|
"points": [],
|
|
"labels": [],
|
|
"positive_rects": [rect],
|
|
}
|
|
)
|
|
|
|
if points or pos_rects or neg_rects:
|
|
schedule[int(frame_idx)] = {
|
|
"points": points,
|
|
"labels": labels,
|
|
"positive_rects": pos_rects,
|
|
"negative_rects": neg_rects,
|
|
"object_prompts": object_prompts,
|
|
}
|
|
|
|
return {"seed_frame_idx": int(seed_frame_idx), "schedule": schedule}
|
|
|
|
|
|
def _clip_rect(rect, width, height):
|
|
x1, y1, x2, y2 = [float(v) for v in rect]
|
|
left = max(0, min(int(np.floor(min(x1, x2))), int(width)))
|
|
top = max(0, min(int(np.floor(min(y1, y2))), int(height)))
|
|
right = max(0, min(int(np.ceil(max(x1, x2))), int(width)))
|
|
bottom = max(0, min(int(np.ceil(max(y1, y2))), int(height)))
|
|
if right <= left or bottom <= top:
|
|
return None
|
|
return (left, top, right, bottom)
|
|
|
|
|
|
def _rect_center_points(rects):
|
|
out = []
|
|
for x1, y1, x2, y2 in rects:
|
|
out.append(((float(x1) + float(x2)) * 0.5, (float(y1) + float(y2)) * 0.5))
|
|
return out
|
|
|
|
|
|
def _mask_stack_like(base_mask, image):
|
|
if base_mask is None:
|
|
return None
|
|
mask = base_mask.float()
|
|
if mask.ndim == 2:
|
|
mask = mask.unsqueeze(0)
|
|
if mask.ndim == 4:
|
|
mask = mask.squeeze(-1)
|
|
if mask.ndim != 3:
|
|
return None
|
|
b = int(image.shape[0])
|
|
h = int(image.shape[1])
|
|
w = int(image.shape[2])
|
|
if int(mask.shape[0]) == 1 and b > 1:
|
|
mask = mask.repeat(b, 1, 1)
|
|
if int(mask.shape[0]) != b:
|
|
return None
|
|
if int(mask.shape[1]) != h or int(mask.shape[2]) != w:
|
|
mask = F.interpolate(mask.unsqueeze(1), size=(h, w), mode="nearest").squeeze(1)
|
|
return torch.clamp(mask, 0.0, 1.0)
|
|
|
|
|
|
def _apply_negative_rects(mask_tensor, negative_rects):
    """Zero out the given rect regions in a [B, H, W] mask tensor.

    Returns the input unchanged (same object) when there is nothing to do
    or the tensor rank is unexpected; otherwise returns a modified clone.
    """
    if mask_tensor is None or not negative_rects:
        return mask_tensor
    if mask_tensor.ndim != 3:
        return mask_tensor
    height = int(mask_tensor.shape[1])
    width = int(mask_tensor.shape[2])
    result = mask_tensor.clone()
    for rect in negative_rects:
        bounds = _clip_rect(rect, width, height)
        if bounds is None:
            continue
        left, top, right, bottom = bounds
        result[:, top:bottom, left:right] = 0.0
    return result
|
|
|
|
|
|
def _tensor_to_pil_image(img):
    """Convert an [H, W, C] float tensor in 0..1 to a PIL image."""
    clamped = torch.clamp(img, 0.0, 1.0)
    pixels = clamped.mul(255.0).byte().cpu().numpy()
    return Image.fromarray(pixels)
|
|
|
|
|
|
def _resolve_dino_device(device_name):
|
|
device_name = str(device_name or "auto").strip().lower()
|
|
if device_name == "auto":
|
|
return mm.get_torch_device()
|
|
return torch.device(device_name)
|
|
|
|
|
|
def _get_groundingdino_model_and_processor(model_id, device):
    """Load (or fetch from cache) the GroundingDINO (processor, model) pair.

    Cached per (model_id, device); the model is moved to `device` and put
    in eval mode before caching.
    """
    cache_key = "{}::{}".format(str(model_id), str(device))
    cached = GROUNDING_DINO_CACHE.get(cache_key)
    if cached is not None:
        return cached
    processor = AutoProcessor.from_pretrained(model_id)
    model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id)
    model.to(device)
    model.eval()
    GROUNDING_DINO_CACHE[cache_key] = (processor, model)
    return processor, model
|
|
|
|
|
|
def _detect_groundingdino_boxes(image_tensor, prompt, model_id, box_threshold, text_threshold, device_name):
    """Run GroundingDINO text-prompted detection on the first image of a batch.

    Args:
        image_tensor: IMAGE tensor [B, H, W, C] in 0..1; only index 0 is used.
        prompt: free-text query; a trailing "." is appended if missing, as
            GroundingDINO expects period-terminated phrases.
        model_id: Hugging Face model id (see GROUNDING_DINO_MODEL_IDS).
        box_threshold: box confidence threshold.
        text_threshold: text-match threshold (dropped on the oldest API).
        device_name: "auto" or an explicit torch device string.

    Returns:
        List of (x1, y1, x2, y2) float tuples in pixel coordinates; empty
        when prompt/image are missing or nothing is detected.
    """
    prompt = str(prompt or "").strip()
    if not prompt:
        return []
    _require_groundingdino()
    if not prompt.endswith("."):
        prompt = "{}.".format(prompt)
    if image_tensor is None or int(image_tensor.shape[0]) <= 0:
        return []

    device = _resolve_dino_device(device_name)
    processor, model = _get_groundingdino_model_and_processor(model_id, device)
    pil = _tensor_to_pil_image(image_tensor[0])
    h = int(image_tensor.shape[1])
    w = int(image_tensor.shape[2])
    with torch.inference_mode():
        inputs = processor(images=pil, text=prompt, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        outputs = model(**inputs)
    post_kwargs = {
        "target_sizes": [(h, w)],
        "text_threshold": float(text_threshold),
    }
    # The post-processing signature changed across transformers releases:
    # try box_threshold=, then threshold= with text_threshold, then the
    # oldest form with threshold= and target_sizes only.
    try:
        result = processor.post_process_grounded_object_detection(
            outputs,
            inputs["input_ids"],
            box_threshold=float(box_threshold),
            **post_kwargs,
        )[0]
    except TypeError:
        try:
            result = processor.post_process_grounded_object_detection(
                outputs,
                inputs["input_ids"],
                threshold=float(box_threshold),
                **post_kwargs,
            )[0]
        except TypeError:
            result = processor.post_process_grounded_object_detection(
                outputs,
                inputs["input_ids"],
                threshold=float(box_threshold),
                target_sizes=[(h, w)],
            )[0]
    boxes = result.get("boxes")
    labels = result.get("labels")
    scores = result.get("scores")
    if boxes is None or boxes.numel() == 0:
        _sam2_debug("dino-detect", "prompt=", prompt, "detections=0")
        return []

    boxes_cpu = boxes.detach().cpu()
    out_boxes = [tuple(float(v) for v in boxes_cpu[i].tolist()) for i in range(int(boxes_cpu.shape[0]))]

    # Detailed detection diagnostics for prompt-quality debugging.
    details = []
    for i in range(int(boxes_cpu.shape[0])):
        try:
            lbl = str(labels[i]) if labels is not None else ""
        except Exception:
            lbl = ""
        try:
            score = float(scores[i].item()) if scores is not None else 0.0
        except Exception:
            score = 0.0
        b = out_boxes[i]
        details.append({
            "i": i,
            "label": lbl,
            "score": round(score, 4),
            "box": [round(float(b[0]), 1), round(float(b[1]), 1), round(float(b[2]), 1), round(float(b[3]), 1)],
        })
    _sam2_debug(
        "dino-detect",
        "prompt=", prompt,
        "detections=", len(out_boxes),
        # Cap the dump at 12 detections to keep log lines bounded.
        "details=", json.dumps(details[:12]),
    )

    return out_boxes
|
|
|
|
|
|
def _sam2_add_prompts(model, state, frame_idx, obj_id, coords, labels, positive_rects):
    """Feed point and box prompts for one object into a SAM2 video predictor.

    Tries the API variants found across sam2 releases (``add_new_points``
    vs ``add_new_points_or_box``) and collects failure messages.

    Returns:
        List of error strings from prompt calls that failed on every
        variant; empty when everything was accepted.

    Raises:
        RuntimeError: only when *point* prompts fail on every API variant.
        Box-prompt failures are reported via the returned error list.
    """
    errors = []
    if coords is not None and labels is not None and len(coords) > 0 and len(labels) > 0:
        # for/else: `break` on the first variant that succeeds; the `else`
        # (no break) means every variant raised, so give up loudly.
        for call in (
            lambda: model.add_new_points(
                inference_state=state,
                frame_idx=int(frame_idx),
                obj_id=int(obj_id),
                points=coords,
                labels=labels,
            ),
            lambda: model.add_new_points_or_box(
                inference_state=state,
                frame_idx=int(frame_idx),
                obj_id=int(obj_id),
                points=coords,
                labels=labels,
            ),
        ):
            try:
                call()
                break
            except Exception as ex:
                errors.append(str(ex))
        else:
            raise RuntimeError("Failed SAM2 add points across API variants: {}".format(errors))

    for rect in positive_rects or []:
        box = np.array([float(rect[0]), float(rect[1]), float(rect[2]), float(rect[3])], dtype=np.float32)
        rect_errors = []
        # Some sam2 builds require explicit empty point arrays alongside `box`,
        # hence the second variant below.
        for call in (
            lambda: model.add_new_points_or_box(
                inference_state=state,
                frame_idx=int(frame_idx),
                obj_id=int(obj_id),
                box=box,
            ),
            lambda: model.add_new_points_or_box(
                inference_state=state,
                frame_idx=int(frame_idx),
                obj_id=int(obj_id),
                points=np.empty((0, 2), dtype=np.float32),
                labels=np.empty((0,), dtype=np.int32),
                box=box,
            ),
        ):
            try:
                call()
                # Success wipes the failures from earlier variants of this rect.
                rect_errors = []
                break
            except Exception as ex:
                rect_errors.append(str(ex))
        if rect_errors:
            errors.extend(rect_errors)
    return errors
|
|
|
|
|
|
def _resolve_video_path_for_sam2(path_text):
    """Resolve Comfy-style path text to an absolute local file path for SAM2 video predictor."""
    text = str(path_text or "").strip()
    if not text:
        return ""
    # Strip Comfy annotation suffixes like "clip.mp4 [input]".
    if text.endswith("]") and " [" in text:
        text = text.rsplit(" [", 1)[0].strip()

    if os.path.isabs(text) and os.path.exists(text):
        return text

    # Ask Comfy to resolve annotated or plain names.
    try:
        annotated = folder_paths.get_annotated_filepath(text)
        if annotated and os.path.exists(annotated):
            return annotated
    except Exception:
        pass

    # Fall back to Comfy's input directory, then to the bare basename there
    # (covers nested/odd relative path tokens from callers).
    try:
        input_dir = folder_paths.get_input_directory()
        direct = os.path.join(input_dir, text)
        if os.path.exists(direct):
            return direct
        by_name = os.path.join(input_dir, os.path.basename(text))
        if os.path.exists(by_name):
            return by_name
    except Exception:
        pass

    # Nothing resolved: hand back the cleaned text for the caller to report.
    return text
|
|
|
|
|
|
def _ensure_mp4_for_sam2(video_path):
|
|
"""Convert non-MP4 input videos to MP4 for SAM2VideoPredictor compatibility."""
|
|
video_path = str(video_path or "").strip()
|
|
if not video_path:
|
|
return video_path
|
|
ext = os.path.splitext(video_path)[1].lower()
|
|
if ext == ".mp4":
|
|
return video_path
|
|
if not os.path.isfile(video_path):
|
|
return video_path
|
|
|
|
cache_dir = os.path.join(folder_paths.get_temp_directory(), "openshot_sam2_mp4_cache")
|
|
os.makedirs(cache_dir, exist_ok=True)
|
|
|
|
st = os.stat(video_path)
|
|
key = "{}|{}|{}".format(video_path, int(st.st_mtime_ns), int(st.st_size))
|
|
digest = hashlib.sha256(key.encode("utf-8")).hexdigest()[:16]
|
|
out_path = os.path.join(cache_dir, "{}.mp4".format(digest))
|
|
if os.path.exists(out_path):
|
|
return out_path
|
|
|
|
cmd = [
|
|
"ffmpeg",
|
|
"-y",
|
|
"-i",
|
|
video_path,
|
|
"-an",
|
|
"-c:v",
|
|
"libx264",
|
|
"-pix_fmt",
|
|
"yuv420p",
|
|
"-crf",
|
|
"18",
|
|
out_path,
|
|
]
|
|
try:
|
|
subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, text=True)
|
|
except FileNotFoundError:
|
|
raise RuntimeError("ffmpeg not found; required to convert '{}' to MP4".format(video_path))
|
|
except subprocess.CalledProcessError as ex:
|
|
err = (ex.stderr or "").strip()
|
|
if len(err) > 500:
|
|
err = err[:500] + "...(truncated)"
|
|
raise RuntimeError("ffmpeg conversion to MP4 failed: {}".format(err))
|
|
return out_path
|
|
|
|
|
|
def _load_video_frame_tensor_for_dino(video_path, frame_index=0):
    """Load one RGB frame from video as IMAGE tensor shape [1,H,W,C] in 0..1.

    Best-effort: returns None when the path cannot be resolved or any step
    of the extraction/decoding fails.
    """
    resolved = _resolve_video_path_for_sam2(video_path)
    resolved = _ensure_mp4_for_sam2(resolved)
    if not resolved or not os.path.isfile(resolved):
        return None

    try:
        frame_index = int(max(0, frame_index))
    except Exception:
        frame_index = 0

    work_dir = tempfile.mkdtemp(prefix="openshot_dino_frame_", dir=folder_paths.get_temp_directory())
    frame_png = os.path.join(work_dir, "seed.png")
    # Select exactly the n-th decoded frame.
    select_filter = r"select=eq(n\,{})".format(frame_index)
    cmd = [
        "ffmpeg",
        "-y",
        "-i",
        resolved,
        "-vf",
        select_filter,
        "-vframes",
        "1",
        frame_png,
    ]
    try:
        subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, text=True)
        if not os.path.isfile(frame_png):
            return None
        rgb = Image.open(frame_png).convert("RGB")
        pixels = np.asarray(rgb, dtype=np.float32) / 255.0
        return torch.from_numpy(pixels).unsqueeze(0)
    except Exception:
        # Deliberate best-effort: any extraction failure yields "no frame".
        return None
    finally:
        shutil.rmtree(work_dir, ignore_errors=True)
|
|
|
|
|
|
def _build_sam2_video_predictor(config_name, checkpoint, torch_device):
    """Build a SAM2 video predictor across package variants.

    Probes the builder function names used by different sam2 releases and,
    for each one found, first tries passing ``device=`` and then no kwargs
    (older signatures).

    Raises:
        RuntimeError: when the sam2 build module is unavailable or no
        builder succeeds; the message lists the builders found and the
        last non-TypeError failure.
    """
    if sam2_build is None:
        raise RuntimeError("sam2.build_sam module unavailable")

    # Builder entry points seen across sam2 package versions.
    candidate_names = (
        "build_sam2_video_predictor",
        "build_video_predictor",
        "build_sam_video_predictor",
    )
    found = []
    last_error = None
    for name in candidate_names:
        fn = getattr(sam2_build, name, None)
        if not callable(fn):
            continue
        found.append(name)
        for kwargs in (
            {"device": torch_device},
            {},
        ):
            try:
                return fn(config_name, checkpoint, **kwargs)
            except TypeError:
                # Signature mismatch: try the next kwargs variant.
                continue
            except Exception as ex:
                last_error = ex
                continue
    raise RuntimeError(
        "Could not build SAM2 video predictor. Found builders={} last_error={}".format(found, last_error)
    )
|
|
|
|
|
|
class OpenShotTransNetSceneDetect:
    """ComfyUI node: detect scene cuts in a video with TransNetV2.

    Outputs a JSON string listing detected scene ranges in seconds.
    """

    @classmethod
    def IS_CHANGED(cls, **kwargs):
        # Always re-run: the result depends on the video file on disk,
        # which Comfy cannot hash through this node's inputs.
        return ""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "source_video_path": ("STRING", {"default": ""}),
                # Cut-probability threshold passed to TransNetV2.
                "threshold": ("FLOAT", {"default": 0.50, "min": 0.01, "max": 0.99, "step": 0.01}),
                # Scenes shorter than this (in frames) are merged into the previous one.
                "min_scene_length_frames": ("INT", {"default": 30, "min": 1, "max": 10000}),
                "device": (["auto", "cuda", "cpu", "mps"], {"default": "auto"}),
            },
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("scene_ranges_json",)
    FUNCTION = "detect"
    CATEGORY = "OpenShot/Video"

    def _resolve_device_name(self, device_name):
        """Map 'auto' to the best available backend; pass through explicit names."""
        value = str(device_name or "auto").strip().lower()
        if value != "auto":
            return value
        if torch.cuda.is_available():
            return "cuda"
        if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
            return "mps"
        return "cpu"

    def _build_model(self, device_name):
        """Instantiate TransNetV2, tolerating constructors without a `device` kwarg."""
        errors = []
        for kwargs in (
            {"device": device_name},
            {},
        ):
            try:
                return _TransNetV2(**kwargs)
            except Exception as ex:
                errors.append(str(ex))
        raise RuntimeError("Failed to initialize TransNetV2 model: {}".format(errors[:2]))

    def _extract_scenes(self, raw):
        """Normalize library output to ([(start_sec, end_sec), ...], fps_or_None).

        Accepts either a dict with "scenes"/"fps" keys or a bare
        list/ndarray of (start, end) pairs; entries may themselves be dicts
        using several possible key spellings.
        """
        fps = None
        scenes = None

        if isinstance(raw, dict):
            scenes = raw.get("scenes")
            fps_value = raw.get("fps")
            try:
                if fps_value is not None:
                    fps = float(fps_value)
            except Exception:
                fps = None
        else:
            scenes = raw

        normalized = []
        if isinstance(scenes, np.ndarray):
            scenes = scenes.tolist()

        if isinstance(scenes, list):
            for entry in scenes:
                start = end = None
                if isinstance(entry, dict):
                    start = entry.get("start_seconds", entry.get("start_time", entry.get("start")))
                    end = entry.get("end_seconds", entry.get("end_time", entry.get("end")))
                elif isinstance(entry, (list, tuple)) and len(entry) >= 2:
                    start, end = entry[0], entry[1]
                try:
                    start_f = float(start)
                    end_f = float(end)
                except Exception:
                    continue
                # Drop empty or inverted ranges.
                if end_f <= start_f:
                    continue
                normalized.append((start_f, end_f))
        return normalized, fps

    def _run_inference(self, model, video_path, threshold):
        """Call the first inference method the installed TransNetV2 exposes.

        Tries threshold kwarg first, then a bare call; a TypeError means a
        signature mismatch (retry), any other error moves to the next method.
        """
        errors = []
        for fn_name in ("detect_scenes", "analyze_video", "predict_video"):
            fn = getattr(model, fn_name, None)
            if not callable(fn):
                continue
            for kwargs in (
                {"threshold": float(threshold)},
                {},
            ):
                try:
                    return fn(video_path, **kwargs)
                except TypeError:
                    continue
                except Exception as ex:
                    errors.append("{}: {}".format(fn_name, ex))
                    break
        raise RuntimeError("TransNetV2 inference failed: {}".format(errors[:2]))

    def _apply_min_scene_length(self, scenes, fps, min_scene_length_frames):
        """Merge scenes shorter than the minimum into their predecessor.

        Needs fps to convert the frame count to seconds; without a valid
        fps the scenes are returned unchanged. The first scene is always
        kept so there is something to merge into.
        """
        if not scenes:
            return []
        if not fps or fps <= 0:
            return scenes
        min_seconds = float(min_scene_length_frames) / float(fps)
        if min_seconds <= 0:
            return scenes

        out = []
        for start_sec, end_sec in scenes:
            if not out:
                out.append([start_sec, end_sec])
                continue
            duration = end_sec - start_sec
            if duration < min_seconds:
                # Too short: extend the previous scene to cover it.
                out[-1][1] = max(out[-1][1], end_sec)
                continue
            out.append([start_sec, end_sec])
        return [(float(s), float(e)) for s, e in out if e > s]

    def detect(self, source_video_path, threshold, min_scene_length_frames, device):
        """Node entry point: returns (scene_ranges_json,).

        Raises ValueError when the video path cannot be resolved and
        RuntimeError when TransNetV2 is unavailable or inference fails.
        """
        _require_transnet()
        video_path = _resolve_video_path_for_sam2(source_video_path)
        if not video_path or not os.path.exists(video_path):
            raise ValueError("Video path not found: {}".format(source_video_path))

        device_name = self._resolve_device_name(device)
        model = self._build_model(device_name)
        raw = self._run_inference(model, video_path, threshold)
        scenes, fps = self._extract_scenes(raw)
        scenes = sorted(scenes, key=lambda item: (item[0], item[1]))
        scenes = self._apply_min_scene_length(scenes, fps, int(min_scene_length_frames))

        payload = {
            "version": 1,
            "detector": "openshot-transnetv2",
            "source_video_path": str(video_path),
            "fps": float(fps) if fps else None,
            "segments": [
                {
                    # 1-based segment index for UI display.
                    "index": idx,
                    "start_seconds": round(float(start_sec), 6),
                    "end_seconds": round(float(end_sec), 6),
                }
                for idx, (start_sec, end_sec) in enumerate(scenes, start=1)
            ],
        }
        return (json.dumps(payload),)
|
|
|
|
|
|
def _probe_video_info(path_text):
|
|
"""Probe basic video metadata via ffprobe."""
|
|
path_text = str(path_text or "").strip()
|
|
if not path_text:
|
|
return {}
|
|
cmd = [
|
|
"ffprobe",
|
|
"-v",
|
|
"error",
|
|
"-select_streams",
|
|
"v:0",
|
|
"-show_entries",
|
|
"stream=avg_frame_rate,r_frame_rate:format=duration",
|
|
"-of",
|
|
"json",
|
|
path_text,
|
|
]
|
|
try:
|
|
result = subprocess.run(cmd, check=True, capture_output=True, text=True)
|
|
except Exception:
|
|
return {}
|
|
try:
|
|
payload = json.loads(result.stdout or "{}")
|
|
except Exception:
|
|
return {}
|
|
|
|
stream = {}
|
|
streams = payload.get("streams")
|
|
if isinstance(streams, list) and streams:
|
|
stream = streams[0] if isinstance(streams[0], dict) else {}
|
|
fmt = payload.get("format") if isinstance(payload.get("format"), dict) else {}
|
|
|
|
def _parse_rate(text_value):
|
|
text_value = str(text_value or "").strip()
|
|
if not text_value or text_value in ("0/0", "N/A"):
|
|
return None
|
|
if "/" in text_value:
|
|
try:
|
|
frac = Fraction(text_value)
|
|
if frac > 0:
|
|
return frac
|
|
except Exception:
|
|
return None
|
|
try:
|
|
value = float(text_value)
|
|
if value > 0:
|
|
return Fraction(value).limit_denominator(1000000)
|
|
except Exception:
|
|
return None
|
|
return None
|
|
|
|
fps = _parse_rate(stream.get("avg_frame_rate")) or _parse_rate(stream.get("r_frame_rate"))
|
|
duration = None
|
|
try:
|
|
duration = float(fmt.get("duration"))
|
|
except Exception:
|
|
duration = None
|
|
|
|
return {
|
|
"fps": fps,
|
|
"duration": duration,
|
|
}
|
|
|
|
|
|
class OpenShotSceneRangesFromSegments:
    """ComfyUI node: rebuild absolute scene ranges from rendered segment files.

    Given the ordered list of per-scene segment videos and the original
    source video, probes each segment's duration via ffprobe and
    accumulates start/end seconds, 1-based frame numbers, and timecodes in
    source-video time.
    """

    @classmethod
    def IS_CHANGED(cls, **kwargs):
        # Always re-run: output depends on files on disk.
        return ""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                # Wildcard: list/tuple of paths, a JSON array string, or one path.
                "segment_paths": ("*",),
                "source_video_path": ("STRING", {"default": ""}),
            },
            "optional": {
                # Used when the source video's fps cannot be probed.
                "fallback_fps": ("FLOAT", {"default": 30.0, "min": 1.0, "max": 240.0, "step": 0.001}),
            },
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("scene_ranges_json",)
    FUNCTION = "build"
    CATEGORY = "OpenShot/Video"

    def _as_path_list(self, segment_paths):
        """Coerce the wildcard input into a clean list of non-empty path strings."""
        if isinstance(segment_paths, (list, tuple)):
            return [str(p).strip() for p in segment_paths if str(p or "").strip()]
        if isinstance(segment_paths, str):
            text = segment_paths.strip()
            if not text:
                return []
            try:
                parsed = json.loads(text)
                if isinstance(parsed, list):
                    return [str(p).strip() for p in parsed if str(p or "").strip()]
            except Exception:
                pass
            # Not a JSON list: treat the whole string as one path.
            return [text]
        return []

    def _timecode(self, seconds_value, fps_fraction):
        """Format seconds as a ";frames"-style timecode.

        Produces "HH:MM:SS;FF", shortened to "MM:SS;FF" or "SS;FF" when the
        leading components are zero; fps determines the frame part.
        """
        # Guard against missing/invalid fps: default to 30.
        fps_fraction = fps_fraction if isinstance(fps_fraction, Fraction) and fps_fraction > 0 else Fraction(30, 1)
        fps_float = float(fps_fraction)
        total_seconds = max(0.0, float(seconds_value or 0.0))
        hours = int(total_seconds // 3600)
        minutes = int((total_seconds % 3600) // 60)
        secs = int(total_seconds % 60)
        frames = int(round((total_seconds - int(total_seconds)) * fps_float))
        fps_ceiling = int(round(fps_float)) or 1
        if frames >= fps_ceiling:
            # Frame rounding overflowed into the next second; cascade the carry.
            frames = 0
            secs += 1
            if secs >= 60:
                secs = 0
                minutes += 1
                if minutes >= 60:
                    minutes = 0
                    hours += 1
        if hours > 0:
            return "{:02d}:{:02d}:{:02d};{:02d}".format(hours, minutes, secs, frames)
        if minutes > 0:
            return "{:02d}:{:02d};{:02d}".format(minutes, secs, frames)
        return "{:02d};{:02d}".format(secs, frames)

    def build(self, segment_paths, source_video_path, fallback_fps=30.0):
        """Node entry point: returns (scene_ranges_json,)."""
        paths = self._as_path_list(segment_paths)
        if not paths:
            return (json.dumps({"segments": []}),)

        source_info = _probe_video_info(source_video_path)
        fps_fraction = source_info.get("fps")
        if fps_fraction is None or fps_fraction <= 0:
            try:
                fps_fraction = Fraction(float(fallback_fps)).limit_denominator(1000000)
            except Exception:
                fps_fraction = Fraction(30, 1)
        fps_float = float(fps_fraction)

        source_duration = source_info.get("duration")
        running_start = 0.0
        segments = []

        for idx, segment_path in enumerate(paths, start=1):
            info = _probe_video_info(segment_path)
            duration = info.get("duration")
            if duration is None:
                # Unreadable segment: skip without advancing the clock.
                continue
            duration = max(0.0, float(duration))
            if duration <= 0.0:
                continue
            start_seconds = running_start
            end_seconds = running_start + duration
            if source_duration is not None:
                # Never run past the end of the source video.
                end_seconds = min(end_seconds, float(source_duration))
            if end_seconds <= start_seconds:
                continue

            # Frames are 1-based and inclusive.
            start_frame = int(round(start_seconds * fps_float)) + 1
            end_frame = int(round(end_seconds * fps_float))
            if end_frame < start_frame:
                end_frame = start_frame

            segments.append(
                {
                    "index": idx,
                    "path": str(segment_path),
                    "start_seconds": round(start_seconds, 6),
                    "end_seconds": round(end_seconds, 6),
                    "duration_seconds": round(end_seconds - start_seconds, 6),
                    "start_frame": int(start_frame),
                    "end_frame": int(end_frame),
                    "start_timecode": self._timecode(start_seconds, fps_fraction),
                    "end_timecode": self._timecode(end_seconds, fps_fraction),
                }
            )
            running_start = end_seconds

        payload = {
            "version": 1,
            "source_video_path": str(source_video_path or ""),
            "fps": {
                "num": int(fps_fraction.numerator),
                "den": int(fps_fraction.denominator),
                "float": fps_float,
            },
            "segments": segments,
        }
        return (json.dumps(payload),)
|
|
|
|
|
|
class OpenShotDownloadAndLoadSAM2Model:
    """ComfyUI node: download (if missing) and load a SAM2 checkpoint.

    Produces a SAM2MODEL dict bundling the predictor with its device, dtype
    and the segmentor mode it was built for ("video" or "single_image").
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": (_model_options(),),
                "segmentor": (["video", "single_image"], {"default": "video"}),
                "device": (["auto", "cuda", "cpu", "mps"], {"default": "auto"}),
                "precision": (["fp16", "bf16", "fp32"], {"default": "fp16"}),
            }
        }

    RETURN_TYPES = ("SAM2MODEL",)
    RETURN_NAMES = ("sam2_model",)
    FUNCTION = "load"
    CATEGORY = "OpenShot/SAM2"

    def load(self, model, segmentor, device, precision):
        """Resolve checkpoint + config candidates, then build the SAM2 predictor."""
        _require_sam2()

        checkpoint = _download_if_needed(model)
        candidates = _resolve_config_candidates(model, checkpoint)
        torch_device, dtype = _to_device_dtype(device, precision)

        _init_hydra_for_local_configs()
        print(
            "[OpenShot-ComfyUI:{}] Loading SAM2 model='{}' checkpoint='{}' configs={}".format(
                OPENSHOT_NODEPACK_VERSION, model, checkpoint, candidates
            )
        )

        want_video = str(segmentor or "video") == "video"
        loaded = None
        failure = None
        for cfg in candidates:
            try:
                if want_video:
                    loaded = _build_sam2_video_predictor(cfg, checkpoint, torch_device)
                else:
                    loaded = build_sam2(cfg, checkpoint, device=torch_device)
            except Exception as ex:
                failure = ex
                # Missing config names are expected across SAM2 package variants.
                if "Cannot find primary config" in str(ex):
                    continue
                raise
            break

        if loaded is None:
            raise RuntimeError(
                "Failed loading SAM2 model. Tried configs {}. Last error: {}".format(candidates, failure)
            )

        return ({
            "model": loaded,
            "device": torch_device,
            "dtype": dtype,
            "segmentor": str(segmentor or "video"),
            "model_name": str(model),
            "checkpoint": str(checkpoint),
        },)
|
|
|
|
|
|
class OpenShotSam2Segmentation:
    """ComfyUI node: per-frame SAM2 image segmentation over a batch of frames.

    Prompts may come from explicit point/rect JSON strings, a GroundingDINO
    text prompt (boxes detected per frame), an optional base mask, or an
    automatic center-point seed. Returns one mask per input frame.
    """

    @classmethod
    def IS_CHANGED(cls, **kwargs):
        # Constant value fed to ComfyUI's change-detection for this node.
        # NOTE(review): confirm whether a constant return caches or forces
        # re-execution in the target ComfyUI version.
        return ""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "sam2_model": ("SAM2MODEL",),
                "image": ("IMAGE",),
                "auto_mode": ("BOOLEAN", {"default": False}),
                "keep_model_loaded": ("BOOLEAN", {"default": False}),
            },
            "optional": {
                "positive_points_json": ("STRING", {"default": ""}),
                "negative_points_json": ("STRING", {"default": ""}),
                "positive_rects_json": ("STRING", {"default": ""}),
                "negative_rects_json": ("STRING", {"default": ""}),
                "dino_prompt": ("STRING", {"default": ""}),
                "dino_model_id": (GROUNDING_DINO_MODEL_IDS,),
                "dino_box_threshold": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.01}),
                "dino_text_threshold": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01}),
                "dino_device": (("auto", "cpu", "cuda", "mps"),),
                "base_mask": ("MASK",),
                "meta_batch": ("VHS_BatchManager",),
            },
        }

    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("mask",)
    FUNCTION = "segment"
    CATEGORY = "OpenShot/SAM2"

    def segment(
        self,
        sam2_model,
        image,
        auto_mode,
        keep_model_loaded,
        positive_points_json="",
        negative_points_json="",
        positive_rects_json="",
        negative_rects_json="",
        dino_prompt="",
        dino_model_id="IDEA-Research/grounding-dino-tiny",
        dino_box_threshold=0.35,
        dino_text_threshold=0.25,
        dino_device="auto",
        base_mask=None,
        # BUGFIX: INPUT_TYPES declares an optional "meta_batch" socket, but
        # this method previously did not accept the kwarg, so connecting a
        # VHS_BatchManager raised TypeError. Accepted (and unused) here for
        # compatibility with how ComfyUI passes connected optional inputs.
        meta_batch=None,
    ):
        """Run SAM2 on each frame and union all prompt-derived masks.

        Returns a single (N, H, W) float mask batch clamped to [0, 1].
        Raises RuntimeError indirectly if SAM2 is unavailable (_require_sam2).
        """
        _require_sam2()

        model = sam2_model["model"]
        device = sam2_model["device"]
        dtype = sam2_model["dtype"]

        positive = _parse_points(positive_points_json)
        negative = _parse_points(negative_points_json)
        positive_rects = _parse_rects(positive_rects_json)
        negative_rects = _parse_rects(negative_rects_json)

        predictor = SAM2ImagePredictor(model)
        base_mask_stack = _mask_stack_like(base_mask, image)

        out_masks = []
        autocast_device = mm.get_autocast_device(device)
        # torch.autocast is not supported on MPS; fall back to no-op context.
        autocast_ok = not mm.is_device_mps(device)
        with torch.autocast(autocast_device, dtype=dtype) if autocast_ok else nullcontext():
            for frame_idx, frame in enumerate(image):
                frame_np = np.clip((frame.cpu().numpy() * 255.0), 0, 255).astype(np.uint8)
                predictor.set_image(frame_np[..., :3])
                h, w = frame_np.shape[0], frame_np.shape[1]

                # Start from the base mask (if any); every prompt below is OR'd in.
                final_mask = torch.zeros((h, w), dtype=torch.float32)
                if base_mask_stack is not None:
                    final_mask = torch.maximum(final_mask, (base_mask_stack[frame_idx].cpu() > 0.5).float())

                seed_points = list(positive)
                if bool(auto_mode) and not seed_points and not positive_rects and base_mask_stack is None:
                    # Auto mode with no other prompt: seed from the frame center.
                    seed_points = [(float(w) * 0.5, float(h) * 0.5)]

                if seed_points or negative:
                    pos_arr = np.array(seed_points, dtype=np.float32) if seed_points else np.empty((0, 2), dtype=np.float32)
                    neg_arr = np.array(negative, dtype=np.float32) if negative else np.empty((0, 2), dtype=np.float32)
                    coords = np.concatenate((pos_arr, neg_arr), axis=0)
                    labels = np.concatenate(
                        (
                            np.ones((len(pos_arr),), dtype=np.int32),
                            np.zeros((len(neg_arr),), dtype=np.int32),
                        ),
                        axis=0,
                    )
                    masks, _scores, _logits = predictor.predict(
                        point_coords=coords,
                        point_labels=labels,
                        multimask_output=False,
                    )
                    final_mask = torch.maximum(final_mask, torch.from_numpy(masks[0]).float())

                frame_positive_rects = list(positive_rects)
                dino_prompt_text = str(dino_prompt or "").strip()
                if dino_prompt_text:
                    # Text-prompted boxes are detected per frame and treated as
                    # additional positive rectangles.
                    dino_boxes = _detect_groundingdino_boxes(
                        image[frame_idx:frame_idx + 1],
                        dino_prompt_text,
                        dino_model_id,
                        float(dino_box_threshold),
                        float(dino_text_threshold),
                        dino_device,
                    )
                    if dino_boxes:
                        frame_positive_rects.extend([tuple(box) for box in dino_boxes])

                for rect in frame_positive_rects:
                    clipped = _clip_rect(rect, w, h)
                    if not clipped:
                        continue
                    left, top, right, bottom = clipped
                    box = np.array([float(left), float(top), float(right), float(bottom)], dtype=np.float32)
                    try:
                        masks, _scores, _logits = predictor.predict(box=box, multimask_output=False)
                    except TypeError:
                        # Older SAM2 predict() signatures require explicit point args.
                        masks, _scores, _logits = predictor.predict(box=box, point_coords=None, point_labels=None, multimask_output=False)
                    final_mask = torch.maximum(final_mask, torch.from_numpy(masks[0]).float())

                # Negative rectangles are carved out last, after all unions.
                final_mask = _apply_negative_rects(final_mask.unsqueeze(0), negative_rects).squeeze(0)
                out_masks.append(torch.clamp(final_mask, 0.0, 1.0))

        if not keep_model_loaded:
            # Push the model to the offload device and free cached VRAM.
            model.to(mm.unet_offload_device())
            mm.soft_empty_cache()

        return (torch.stack(out_masks, dim=0),)
|
|
|
|
|
|
class OpenShotSam2VideoSegmentationAddPoints:
    """ComfyUI node: seed a SAM2 video tracking session with prompts.

    Collects point/rect/DINO prompts into a per-frame prompt schedule and
    either (a) returns a lightweight "windowed" state dict (no full-video SAM2
    state held in memory), or (b) initializes a real SAM2 inference state and
    applies the first-frame prompts immediately.
    """

    @classmethod
    def IS_CHANGED(cls, **kwargs):
        # Constant value fed to ComfyUI's change-detection for this node.
        # NOTE(review): confirm whether a constant return caches or forces
        # re-execution in the target ComfyUI version.
        return ""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "sam2_model": ("SAM2MODEL",),
                "frame_index": ("INT", {"default": 0, "min": 0}),
                "object_index": ("INT", {"default": 0, "min": 0}),
                "windowed_mode": ("BOOLEAN", {"default": True}),
                "offload_video_to_cpu": ("BOOLEAN", {"default": False}),
                "offload_state_to_cpu": ("BOOLEAN", {"default": False}),
                "auto_mode": ("BOOLEAN", {"default": False}),
            },
            "optional": {
                "image": ("IMAGE",),
                "video_path": ("STRING", {"default": ""}),
                "positive_points_json": ("STRING", {"default": ""}),
                "negative_points_json": ("STRING", {"default": ""}),
                "positive_rects_json": ("STRING", {"default": ""}),
                "negative_rects_json": ("STRING", {"default": ""}),
                "tracking_selection_json": ("STRING", {"default": "{}"}),
                "dino_prompt": ("STRING", {"default": ""}),
                "dino_model_id": (GROUNDING_DINO_MODEL_IDS,),
                "dino_box_threshold": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.01}),
                "dino_text_threshold": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01}),
                "dino_device": (("auto", "cpu", "cuda", "mps"),),
                "prev_inference_state": ("SAM2INFERENCESTATE",),
                "base_mask": ("MASK",),
                # BUGFIX: add_points() already accepts and uses meta_batch for
                # windowed-state caching, but the socket was never declared, so
                # ComfyUI could not supply it and the meta-batch cache path was
                # unreachable. Declared here exactly as the sibling nodes do.
                "meta_batch": ("VHS_BatchManager",),
            },
        }

    RETURN_TYPES = ("SAM2MODEL", "SAM2INFERENCESTATE")
    RETURN_NAMES = ("sam2_model", "inference_state")
    FUNCTION = "add_points"
    CATEGORY = "OpenShot/SAM2"

    def add_points(
        self,
        sam2_model,
        frame_index,
        object_index,
        windowed_mode,
        offload_video_to_cpu,
        offload_state_to_cpu,
        auto_mode,
        image=None,
        video_path="",
        positive_points_json="",
        negative_points_json="",
        positive_rects_json="",
        negative_rects_json="",
        tracking_selection_json="{}",
        dino_prompt="",
        dino_model_id="IDEA-Research/grounding-dino-tiny",
        dino_box_threshold=0.35,
        dino_text_threshold=0.25,
        dino_device="auto",
        prev_inference_state=None,
        base_mask=None,
        meta_batch=None,
    ):
        """Build (or reuse) a SAM2 inference state seeded with the given prompts.

        Returns ``(sam2_model, inference_state)`` where inference_state is a
        dict (windowed mode) or wraps a live SAM2 state (non-windowed mode).
        Raises ValueError for a non-video model, missing positive prompts, or
        missing image/video inputs; RuntimeError on DINO or init_state failure.
        """
        model = sam2_model["model"]
        device = sam2_model["device"]
        dtype = sam2_model["dtype"]
        segmentor = sam2_model.get("segmentor", "video")
        if segmentor != "video":
            raise ValueError("Loaded SAM2 model is not configured for video")

        pos = _parse_points(positive_points_json)
        neg = _parse_points(negative_points_json)
        pos_rects = _parse_rects(positive_rects_json)
        neg_rects = _parse_rects(negative_rects_json)
        tracking_selection = _parse_tracking_selection(tracking_selection_json)
        prompt_schedule = dict(tracking_selection.get("schedule") or {})
        seed_frame_idx = int(max(0, tracking_selection.get("seed_frame_idx", int(max(0, frame_index)))))

        # Build a stable run key so cached meta-batch state is reused only for
        # the same source/prompt/seed inputs within one generation.
        run_key_payload = {
            "video_path": str(video_path or ""),
            "frame_index": int(max(0, frame_index)),
            "seed_frame_idx": int(seed_frame_idx),
            "object_index": int(max(0, object_index)),
            "auto_mode": bool(auto_mode),
            "dino_prompt": str(dino_prompt or "").strip(),
            "dino_model_id": str(dino_model_id or ""),
            "dino_box_threshold": float(dino_box_threshold),
            "dino_text_threshold": float(dino_text_threshold),
            "positive_points_json": str(positive_points_json or ""),
            "negative_points_json": str(negative_points_json or ""),
            "positive_rects_json": str(positive_rects_json or ""),
            "negative_rects_json": str(negative_rects_json or ""),
            "tracking_selection_json": str(tracking_selection_json or "{}"),
        }
        meta_run_key = hashlib.sha256(json.dumps(run_key_payload, sort_keys=True).encode("utf-8")).hexdigest()

        # In windowed meta-batch mode, reuse evolving state between chunks only
        # when the request matches this exact run key.
        if bool(windowed_mode) and prev_inference_state is None and meta_batch is not None:
            try:
                cached_state = getattr(meta_batch, "_openshot_sam2_window_state", None)
                if isinstance(cached_state, dict) and cached_state.get("windowed_mode", False):
                    cached_key = str(cached_state.get("_meta_run_key", ""))
                    if cached_key and cached_key == meta_run_key:
                        _sam2_debug("meta-cache", "reuse=1", "run_key=", cached_key[:10])
                        return (sam2_model, cached_state)
                    # Different run: drop stale cache so prompts are rebuilt.
                    _sam2_debug("meta-cache", "reuse=0", "reason=run_key_mismatch")
                    try:
                        setattr(meta_batch, "_openshot_sam2_window_state", None)
                    except Exception:
                        pass
            except Exception:
                pass

        if base_mask is not None:
            # Seed a positive point at the centroid of the first base-mask frame.
            mask_stack = _mask_stack_like(base_mask, image) if image is not None else None
            if mask_stack is not None and int(mask_stack.shape[0]) > 0:
                ys, xs = torch.where(mask_stack[0] > 0.5)
                if xs.numel() > 0:
                    pos.append((float(xs.float().mean().item()), float(ys.float().mean().item())))

        if bool(auto_mode) and (not pos) and (not pos_rects):
            # Auto mode with no explicit prompt: seed from the frame center.
            if image is not None:
                h = int(image.shape[1])
                w = int(image.shape[2])
                pos = [(float(w) * 0.5, float(h) * 0.5)]

        dino_prompt = str(dino_prompt or "").strip()
        if dino_prompt:
            dino_image = image
            if dino_image is None and str(video_path or "").strip():
                dino_image = _load_video_frame_tensor_for_dino(video_path, seed_frame_idx)

            if dino_image is not None:
                try:
                    dino_boxes = _detect_groundingdino_boxes(
                        dino_image,
                        dino_prompt,
                        dino_model_id,
                        float(dino_box_threshold),
                        float(dino_text_threshold),
                        dino_device,
                    )
                    print(
                        "[OpenShot-ComfyUI:{}] DINO prompt='{}' boxes={} model='{}' box_th={} text_th={}".format(
                            OPENSHOT_NODEPACK_VERSION,
                            dino_prompt,
                            len(dino_boxes),
                            dino_model_id,
                            float(dino_box_threshold),
                            float(dino_text_threshold),
                        )
                    )
                except Exception as ex:
                    raise RuntimeError(
                        "GroundingDINO detection failed for prompt '{}': {}".format(dino_prompt, ex)
                    )
            else:
                dino_boxes = []
                print(
                    "[OpenShot-ComfyUI:{}] DINO prompt='{}' skipped (no image or video_path frame available)".format(
                        OPENSHOT_NODEPACK_VERSION,
                        dino_prompt,
                    )
                )

            if dino_boxes:
                _sam2_debug(
                    "dino-seed",
                    "prompt=", dino_prompt,
                    "seed_frame_idx=", int(seed_frame_idx),
                    "boxes=", len(dino_boxes),
                )
                # Each detected box becomes its own tracked object (fresh obj_id)
                # on the seed frame's schedule entry.
                seed_entry = dict(prompt_schedule.get(seed_frame_idx, {}) or {})
                seed_object_prompts = list(seed_entry.get("object_prompts") or [])
                next_obj_id = 0
                for op in seed_object_prompts:
                    try:
                        next_obj_id = max(next_obj_id, int(op.get("obj_id", 0)) + 1)
                    except Exception:
                        continue
                for box in dino_boxes:
                    seed_object_prompts.append(
                        {
                            "obj_id": int(next_obj_id),
                            "points": [],
                            "labels": [],
                            "positive_rects": [tuple(box)],
                        }
                    )
                    next_obj_id += 1
                seed_entry["object_prompts"] = seed_object_prompts
                seed_rects = list(seed_entry.get("positive_rects") or [])
                seed_rects.extend([tuple(b) for b in dino_boxes])
                seed_entry["positive_rects"] = seed_rects
                prompt_schedule[int(seed_frame_idx)] = seed_entry
                _sam2_debug(
                    "dino-seed-boxes",
                    "seed_frame_idx=", int(seed_frame_idx),
                    "boxes=", json.dumps([[round(float(v),1) for v in b] for b in dino_boxes[:12]]),
                )

        # Backward-compatible seed injection if no explicit keyframed payload exists.
        if seed_frame_idx not in prompt_schedule and (pos or neg or pos_rects or neg_rects):
            points = []
            labels = []
            for x, y in pos:
                points.append((float(x), float(y)))
                labels.append(1)
            for x, y in neg:
                points.append((float(x), float(y)))
                labels.append(0)
            prompt_schedule[int(seed_frame_idx)] = {
                "points": points,
                "labels": labels,
                "positive_rects": list(pos_rects),
                "negative_rects": list(neg_rects),
            }

        # Require at least one positive prompt somewhere, unless DINO/auto mode
        # may legitimately produce an empty schedule (e.g. object not yet visible).
        has_any_positive = False
        for entry in prompt_schedule.values():
            labels = list(entry.get("labels") or [])
            object_prompts = list(entry.get("object_prompts") or [])
            has_object_positive = any(bool((op or {}).get("positive_rects") or []) for op in object_prompts if isinstance(op, dict))
            if any(int(v) == 1 for v in labels) or bool(entry.get("positive_rects") or []) or has_object_positive:
                has_any_positive = True
                break
        allow_empty_schedule = bool(dino_prompt) or bool(auto_mode)
        if not has_any_positive and not allow_empty_schedule:
            raise ValueError("No positive points/rectangles provided")

        _sam2_debug(
            "add_points",
            "seed_frame_idx=", int(seed_frame_idx),
            "schedule_frames=", sorted([int(k) for k in prompt_schedule.keys()]),
            "windowed=", bool(windowed_mode),
            "has_prev_state=", bool(prev_inference_state is not None),
        )

        # Serialize the schedule into plain JSON-safe lists, sorted by frame.
        serial_schedule = []
        for fidx in sorted(prompt_schedule.keys()):
            entry = prompt_schedule.get(fidx, {}) or {}
            serial_schedule.append(
                {
                    "frame_idx": int(fidx),
                    "points": [[float(p[0]), float(p[1])] for p in (entry.get("points") or [])],
                    "labels": [int(v) for v in (entry.get("labels") or [])],
                    "positive_rects": [[float(r[0]), float(r[1]), float(r[2]), float(r[3])] for r in (entry.get("positive_rects") or [])],
                    "negative_rects": [[float(r[0]), float(r[1]), float(r[2]), float(r[3])] for r in (entry.get("negative_rects") or [])],
                    "object_prompts": [
                        {
                            "obj_id": int(op.get("obj_id", 0)),
                            "points": [[float(p[0]), float(p[1])] for p in (op.get("points") or [])],
                            "labels": [int(v) for v in (op.get("labels") or [])],
                            "positive_rects": [
                                [float(r[0]), float(r[1]), float(r[2]), float(r[3])]
                                for r in (op.get("positive_rects") or [])
                            ],
                        }
                        for op in (entry.get("object_prompts") or [])
                        if isinstance(op, dict)
                    ],
                }
            )

        # Keep these for backward compatibility / fallback behavior.
        if prompt_schedule:
            first_frame = int(sorted(prompt_schedule.keys())[0])
            first_entry = prompt_schedule.get(first_frame, {}) or {}
        else:
            first_frame = int(seed_frame_idx)
            first_entry = {}
        pos_seed = [tuple(p) for p, lbl in zip(first_entry.get("points") or [], first_entry.get("labels") or []) if int(lbl) == 1]
        neg_seed = [tuple(p) for p, lbl in zip(first_entry.get("points") or [], first_entry.get("labels") or []) if int(lbl) == 0]
        pos_arr = np.atleast_2d(np.array(pos_seed, dtype=np.float32)) if pos_seed else np.empty((0, 2), dtype=np.float32)
        neg_arr = np.atleast_2d(np.array(neg_seed, dtype=np.float32)) if neg_seed else np.empty((0, 2), dtype=np.float32)
        coords = np.concatenate((pos_arr, neg_arr), axis=0) if (len(pos_arr) or len(neg_arr)) else np.empty((0, 2), dtype=np.float32)
        labels = np.concatenate((np.ones((len(pos_arr),), dtype=np.int32), np.zeros((len(neg_arr),), dtype=np.int32)), axis=0) if (len(pos_arr) or len(neg_arr)) else np.empty((0,), dtype=np.int32)
        first_pos_rects = [tuple(r) for r in (first_entry.get("positive_rects") or [])]
        first_neg_rects = [tuple(r) for r in (first_entry.get("negative_rects") or [])]

        # Windowed mode does not hold full-video SAM2 state in memory.
        if bool(windowed_mode):
            state = dict(prev_inference_state or {})
            state["windowed_mode"] = True
            state["seed_points"] = coords.tolist()
            state["seed_labels"] = labels.tolist()
            state["last_points"] = coords.tolist()
            state["last_labels"] = labels.tolist()
            state["seed_rects"] = [[float(a), float(b), float(c), float(d)] for (a, b, c, d) in first_pos_rects]
            state["negative_rects"] = [[float(a), float(b), float(c), float(d)] for (a, b, c, d) in first_neg_rects]
            state["active_negative_rects"] = [[float(a), float(b), float(c), float(d)] for (a, b, c, d) in first_neg_rects]
            state["prompt_schedule"] = serial_schedule
            state["object_index"] = int(object_index)
            state["next_frame_idx"] = int(max(0, state.get("next_frame_idx", 0) or 0))
            state["num_frames"] = int(state.get("num_frames", 0) or 0)
            state["offload_video_to_cpu"] = bool(offload_video_to_cpu)
            state["offload_state_to_cpu"] = bool(offload_state_to_cpu)
            state["object_carries"] = dict(state.get("object_carries", {}) or {})
            state["prompt_frames_applied"] = list(state.get("prompt_frames_applied", []) or [])
            state["boundary_reseed_frames"] = int(max(1, state.get("boundary_reseed_frames", 4) or 4))
            state["_meta_run_key"] = str(meta_run_key)
            if meta_batch is not None:
                # Best-effort cache on the batch manager so later chunks reuse it.
                try:
                    setattr(meta_batch, "_openshot_sam2_window_state", state)
                except Exception:
                    pass
            return (sam2_model, state)

        if (image is None and not str(video_path or "").strip()) and prev_inference_state is None:
            raise ValueError("Image or video_path input is required for initial inference state")

        model.to(device)
        if prev_inference_state is None:
            # Support SAM2 API variants for init_state signature.
            init_errors = []
            state = None
            num_frames = 0

            # Preferred path for newer SAM2 video predictors: initialize from source video path.
            if str(video_path or "").strip():
                vp = _resolve_video_path_for_sam2(video_path)
                vp = _ensure_mp4_for_sam2(vp)
                print(
                    "[OpenShot-ComfyUI:{}] SAM2 init_state path='{}' exists={} ext='{}'".format(
                        OPENSHOT_NODEPACK_VERSION,
                        vp,
                        os.path.exists(vp),
                        os.path.splitext(vp)[1].lower(),
                    )
                )
                # Prefer CPU-offloaded inference state to avoid huge VRAM spikes on long videos.
                for call in (
                    lambda: model.init_state(
                        vp,
                        offload_video_to_cpu=bool(offload_video_to_cpu),
                        offload_state_to_cpu=bool(offload_state_to_cpu),
                    ),
                    lambda: model.init_state(vp, offload_video_to_cpu=bool(offload_video_to_cpu)),
                    lambda: model.init_state(vp, offload_state_to_cpu=bool(offload_state_to_cpu)),
                    lambda: model.init_state(vp),
                    lambda: model.init_state(vp, device=device),
                ):
                    try:
                        state = call()
                        break
                    except Exception as ex:
                        init_errors.append(str(ex))

            # Fallback for tensor-accepting SAM2 variants.
            if state is None and image is not None:
                b, h, w, _c = image.shape
                if hasattr(model, "image_size"):
                    size = int(model.image_size)
                    image = common_upscale(image.movedim(-1, 1), size, size, "bilinear", "disabled").movedim(1, -1)
                video_tensor = image.permute(0, 3, 1, 2).contiguous()
                for call in (
                    lambda: model.init_state(video_tensor, h, w, device=device),
                    lambda: model.init_state(video_tensor, h, w),
                    lambda: model.init_state(video_tensor, device=device),
                    lambda: model.init_state(video_tensor),
                ):
                    try:
                        state = call()
                        num_frames = int(b)
                        break
                    except Exception as ex:
                        init_errors.append(str(ex))
            if state is None:
                short_errors = init_errors[:2]
                raise RuntimeError(
                    "SAM2 init_state failed; path='{}' exists={} ext='{}' errors={}".format(
                        vp if str(video_path or "").strip() else "",
                        (os.path.exists(vp) if str(video_path or "").strip() else False),
                        (os.path.splitext(vp)[1].lower() if str(video_path or "").strip() else ""),
                        short_errors,
                    )
                )
        else:
            state = prev_inference_state["inference_state"]
            num_frames = int(prev_inference_state.get("num_frames", 0) or 0)

        autocast_device = mm.get_autocast_device(device)
        # torch.autocast is not supported on MPS; fall back to no-op context.
        autocast_ok = not mm.is_device_mps(device)
        with torch.inference_mode():
            with torch.autocast(autocast_device, dtype=dtype) if autocast_ok else nullcontext():
                add_errors = []
                if len(coords) or len(first_pos_rects):
                    add_errors = _sam2_add_prompts(
                        model,
                        state,
                        int(first_frame),
                        int(object_index),
                        coords,
                        labels,
                        first_pos_rects,
                    )
                if add_errors:
                    raise RuntimeError("Failed applying one or more SAM2 rectangle prompts: {}".format(add_errors[:3]))

        if num_frames <= 0:
            # The SAM2 state may be a dict or an object depending on version.
            try:
                num_frames = int(state.get("num_frames", 0) or 0)
            except Exception:
                try:
                    num_frames = int(getattr(state, "num_frames", 0) or 0)
                except Exception:
                    num_frames = 0

        return (
            sam2_model,
            {
                "inference_state": state,
                "num_frames": num_frames,
                "next_frame_idx": 0,
                "negative_rects": [[float(a), float(b), float(c), float(d)] for (a, b, c, d) in first_neg_rects],
                "active_negative_rects": [[float(a), float(b), float(c), float(d)] for (a, b, c, d) in first_neg_rects],
                "seed_rects": [[float(a), float(b), float(c), float(d)] for (a, b, c, d) in first_pos_rects],
                "prompt_schedule": serial_schedule,
                "prompt_frames_applied": [int(first_frame)] if (len(coords) or len(first_pos_rects)) else [],
                "object_carries": {},
            },
        )
|
|
|
|
|
|
class OpenShotSam2VideoSegmentationChunked:
|
|
@classmethod
def IS_CHANGED(cls, **kwargs):
    # Constant value fed to ComfyUI's change-detection hook for this node.
    # NOTE(review): confirm whether a constant return caches or forces
    # re-execution in the target ComfyUI version.
    return ""
|
|
|
|
@classmethod
|
|
def INPUT_TYPES(cls):
|
|
return {
|
|
"required": {
|
|
"sam2_model": ("SAM2MODEL",),
|
|
"inference_state": ("SAM2INFERENCESTATE",),
|
|
"image": ("IMAGE",),
|
|
"start_frame": ("INT", {"default": 0, "min": 0}),
|
|
"chunk_size_frames": ("INT", {"default": 32, "min": 1, "max": 4096}),
|
|
"keep_model_loaded": ("BOOLEAN", {"default": False}),
|
|
},
|
|
"optional": {
|
|
"meta_batch": ("VHS_BatchManager",),
|
|
},
|
|
}
|
|
|
|
RETURN_TYPES = ("MASK",)
|
|
RETURN_NAMES = ("mask",)
|
|
FUNCTION = "segment_chunk"
|
|
CATEGORY = "OpenShot/SAM2"
|
|
|
|
def _get_frames_per_batch(self, meta_batch, fallback):
|
|
if meta_batch is None:
|
|
return int(fallback)
|
|
if isinstance(meta_batch, dict):
|
|
for key in ("frames_per_batch", "batch_size", "frames"):
|
|
try:
|
|
if key in meta_batch and int(meta_batch[key]) > 0:
|
|
return int(meta_batch[key])
|
|
except Exception:
|
|
pass
|
|
for name in ("frames_per_batch", "batch_size", "frames"):
|
|
try:
|
|
value = getattr(meta_batch, name)
|
|
value = int(value)
|
|
if value > 0:
|
|
return value
|
|
except Exception:
|
|
pass
|
|
return int(fallback)
|
|
|
|
def _write_window_jpegs(self, image):
    """Dump a frame tensor to a fresh temp directory of numbered JPEGs.

    Returns ``(directory, frame_count, height, width)``. Frames are assumed
    to be float images in [0, 1] with channels last.
    """
    frames = np.clip(image.detach().cpu().numpy() * 255.0, 0, 255).astype(np.uint8)

    base_dir = os.path.join(folder_paths.get_temp_directory(), "openshot_sam2_windows")
    os.makedirs(base_dir, exist_ok=True)

    # Unique per call: millisecond timestamp plus a short random token.
    unique_name = "w{}_{}".format(
        int(time.time() * 1000),
        hashlib.sha256(os.urandom(16)).hexdigest()[:8],
    )
    window_dir = os.path.join(base_dir, unique_name)
    os.makedirs(window_dir, exist_ok=True)

    for index, frame in enumerate(frames):
        target = os.path.join(window_dir, "{:05d}.jpg".format(index))
        # Drop any alpha channel; frames are written as plain RGB JPEGs.
        Image.fromarray(frame[..., :3], mode="RGB").save(target, format="JPEG", quality=95)

    return window_dir, int(frames.shape[0]), int(frames.shape[1]), int(frames.shape[2])
|
|
|
|
def _init_window_state(self, model, window_dir, device, inference_state):
    """Create a SAM2 inference state for one JPEG window directory.

    Different SAM2 releases expose different ``init_state`` signatures, so
    each call shape is attempted in order until one succeeds; the collected
    failure messages are surfaced if none do.
    """
    offload_video = bool(inference_state.get("offload_video_to_cpu", False))
    offload_state = bool(inference_state.get("offload_state_to_cpu", False))

    attempts = (
        lambda: model.init_state(
            window_dir,
            offload_video_to_cpu=offload_video,
            offload_state_to_cpu=offload_state,
        ),
        lambda: model.init_state(window_dir, offload_video_to_cpu=offload_video),
        lambda: model.init_state(window_dir, offload_state_to_cpu=offload_state),
        lambda: model.init_state(window_dir),
        lambda: model.init_state(window_dir, device=device),
    )

    failures = []
    for attempt in attempts:
        try:
            return attempt()
        except Exception as ex:
            failures.append(str(ex))
    raise RuntimeError("SAM2 window init_state failed: {}".format(failures[:3]))
|
|
|
|
def _prompt_schedule(self, inference_state):
|
|
raw = inference_state.get("prompt_schedule") or []
|
|
out = []
|
|
for item in raw:
|
|
if not isinstance(item, dict):
|
|
continue
|
|
try:
|
|
frame_idx = int(item.get("frame_idx", 0))
|
|
except Exception:
|
|
frame_idx = 0
|
|
points = []
|
|
for p in (item.get("points") or []):
|
|
if not isinstance(p, (list, tuple)) or len(p) != 2:
|
|
continue
|
|
try:
|
|
points.append((float(p[0]), float(p[1])))
|
|
except Exception:
|
|
continue
|
|
labels = []
|
|
for v in (item.get("labels") or []):
|
|
try:
|
|
labels.append(int(v))
|
|
except Exception:
|
|
labels.append(0)
|
|
pos_rects = []
|
|
for r in (item.get("positive_rects") or []):
|
|
if not isinstance(r, (list, tuple)) or len(r) != 4:
|
|
continue
|
|
try:
|
|
pos_rects.append((float(r[0]), float(r[1]), float(r[2]), float(r[3])))
|
|
except Exception:
|
|
continue
|
|
neg_rects = []
|
|
for r in (item.get("negative_rects") or []):
|
|
if not isinstance(r, (list, tuple)) or len(r) != 4:
|
|
continue
|
|
try:
|
|
neg_rects.append((float(r[0]), float(r[1]), float(r[2]), float(r[3])))
|
|
except Exception:
|
|
continue
|
|
object_prompts = []
|
|
for op in (item.get("object_prompts") or []):
|
|
if not isinstance(op, dict):
|
|
continue
|
|
try:
|
|
obj_id = int(op.get("obj_id", 0))
|
|
except Exception:
|
|
obj_id = 0
|
|
op_points = []
|
|
for p in (op.get("points") or []):
|
|
if not isinstance(p, (list, tuple)) or len(p) != 2:
|
|
continue
|
|
try:
|
|
op_points.append((float(p[0]), float(p[1])))
|
|
except Exception:
|
|
continue
|
|
op_labels = []
|
|
for v in (op.get("labels") or []):
|
|
try:
|
|
op_labels.append(int(v))
|
|
except Exception:
|
|
op_labels.append(0)
|
|
op_rects = []
|
|
for r in (op.get("positive_rects") or []):
|
|
if not isinstance(r, (list, tuple)) or len(r) != 4:
|
|
continue
|
|
try:
|
|
op_rects.append((float(r[0]), float(r[1]), float(r[2]), float(r[3])))
|
|
except Exception:
|
|
continue
|
|
object_prompts.append(
|
|
{
|
|
"obj_id": int(max(0, obj_id)),
|
|
"points": op_points,
|
|
"labels": op_labels,
|
|
"positive_rects": op_rects,
|
|
}
|
|
)
|
|
|
|
out.append(
|
|
{
|
|
"frame_idx": int(max(0, frame_idx)),
|
|
"points": points,
|
|
"labels": labels,
|
|
"positive_rects": pos_rects,
|
|
"negative_rects": neg_rects,
|
|
"object_prompts": object_prompts,
|
|
}
|
|
)
|
|
out.sort(key=lambda x: int(x.get("frame_idx", 0)))
|
|
return out
|
|
|
|
def _apply_prompt_entry(self, model, state_obj, inference_state, frame_idx, entry):
    """Apply one prompt-schedule entry to a live SAM2 state at ``frame_idx``.

    Entries with ``object_prompts`` are applied per object id (global negative
    points are appended to every object); otherwise a single combined prompt
    is applied under the session's ``object_index``. Mutates
    ``inference_state`` in place: updates ``object_carries`` (one seed
    point + optional bbox per object for later window re-seeding),
    ``active_negative_rects`` and ``last_points``/``last_labels``.
    """
    points_list = list(entry.get("points") or [])
    labels_list = [int(v) for v in (entry.get("labels") or [])]
    rects = [tuple(r) for r in (entry.get("positive_rects") or [])]
    neg_rects = [tuple(r) for r in (entry.get("negative_rects") or [])]
    # Entry-level points labeled 0 are shared negatives for every object prompt.
    global_negative_points = [p for p, lbl in zip(points_list, labels_list) if int(lbl) == 0]

    points = np.array(points_list, dtype=np.float32) if points_list else np.empty((0, 2), dtype=np.float32)
    labels = np.array(labels_list, dtype=np.int32) if labels_list else np.empty((0,), dtype=np.int32)
    # Normalize degenerate array shapes to (N, 2) / (N,).
    if points.ndim == 1 and points.size > 0:
        points = points.reshape(1, 2)
    if labels.ndim == 0 and labels.size > 0:
        labels = labels.reshape(1)
    # Rect-only entries: synthesize positive points at the rect centers.
    if (points.size == 0 or labels.size == 0) and rects:
        centers = _rect_center_points(rects)
        points = np.array(centers, dtype=np.float32)
        labels = np.ones((len(centers),), dtype=np.int32)

    object_prompts = list(entry.get("object_prompts") or [])
    _sam2_debug(
        "apply_prompt_entry",
        "frame_idx=", int(frame_idx),
        "points=", len(points_list),
        "rects=", len(rects),
        "object_prompts=", len(object_prompts),
        "neg_points=", len(global_negative_points),
    )
    if object_prompts:
        # Per-object path: each prompt targets its own (non-negative) obj_id.
        for op in object_prompts:
            if not isinstance(op, dict):
                continue
            obj_id = int(max(0, int(op.get("obj_id", 0))))
            op_points_list = list(op.get("points") or [])
            op_labels_list = [int(v) for v in (op.get("labels") or [])]
            op_rects = [tuple(r) for r in (op.get("positive_rects") or [])]
            if global_negative_points:
                # Append the shared negatives (label 0) to this object's prompt.
                op_points_list = list(op_points_list) + list(global_negative_points)
                op_labels_list = list(op_labels_list) + [0 for _ in global_negative_points]
            op_points = np.array(op_points_list, dtype=np.float32) if op_points_list else np.empty((0, 2), dtype=np.float32)
            op_labels = np.array(op_labels_list, dtype=np.int32) if op_labels_list else np.empty((0,), dtype=np.int32)
            if op_points.ndim == 1 and op_points.size > 0:
                op_points = op_points.reshape(1, 2)
            if op_labels.ndim == 0 and op_labels.size > 0:
                op_labels = op_labels.reshape(1)
            # Rect-only object prompts also fall back to rect-center points.
            if (op_points.size == 0 or op_labels.size == 0) and op_rects:
                centers = _rect_center_points(op_rects)
                op_points = np.array(centers, dtype=np.float32)
                op_labels = np.ones((len(centers),), dtype=np.int32)
            _sam2_add_prompts(model, state_obj, int(frame_idx), obj_id, op_points, op_labels, op_rects)
            if op_points.size > 0:
                # Record this object's first point (and first rect, if any) so
                # later windows can re-seed the object at their boundary.
                carries = dict(inference_state.get("object_carries", {}) or {})
                px = float(op_points[0][0])
                py = float(op_points[0][1])
                if op_rects:
                    b = tuple(op_rects[0])
                    bx = [float(b[0]), float(b[1]), float(b[2]), float(b[3])]
                else:
                    # Do not invent a tiny bbox for point-only prompts.
                    # Boundary replay should use the point itself unless we
                    # have a real object bbox from SAM2 propagation.
                    bx = None
                carries[str(int(obj_id))] = {"point": [px, py], "bbox": bx}
                inference_state["object_carries"] = carries
    else:
        # Legacy single-object path: apply everything under the session index.
        obj_id = int(inference_state.get("object_index", 0))
        _sam2_add_prompts(model, state_obj, int(frame_idx), obj_id, points, labels, rects)
    inference_state["active_negative_rects"] = [[float(a), float(b), float(c), float(d)] for (a, b, c, d) in neg_rects]
    if points.size > 0:
        inference_state["last_points"] = points.tolist()
        inference_state["last_labels"] = labels.tolist()
|
|
|
|
    def _seed_window_prompt(self, model, local_state, inference_state):
        """Seed a freshly created per-window SAM2 state at local frame 0.

        Priority order:
          1. Replay carried per-object prompts ("object_carries") from the
             previous window, one positive point (plus optional bbox) per
             tracked object, then return.
          2. Otherwise, only on the very first window, fall back to the
             original seed points/labels/rects — unless the prompt schedule
             already contains an explicit positive frame-0 entry.

        Args:
            model: SAM2 video predictor used by _sam2_add_prompts.
            local_state: window-local SAM2 inference state to seed.
            inference_state: shared bookkeeping dict (carries, seeds,
                schedule, frame cursor); read-only here.
        """
        carries = dict(inference_state.get("object_carries", {}) or {})
        _sam2_debug(
            "seed_window_prompt",
            "next_frame_idx=", int(max(0, inference_state.get("next_frame_idx", 0) or 0)),
            "carry_count=", len(carries),
            "seed_rects=", len(list(inference_state.get("seed_rects") or [])),
        )
        if carries:
            for raw_obj_id, payload in carries.items():
                try:
                    obj_id = int(raw_obj_id)
                except Exception:
                    continue

                # Payload may be a bare [x, y] point or a dict carrying both
                # "point" and an optional "bbox".
                point = payload
                bbox = None
                if isinstance(payload, dict):
                    point = payload.get("point")
                    bbox = payload.get("bbox")

                if not isinstance(point, (list, tuple)) or len(point) != 2:
                    continue
                try:
                    x = float(point[0])
                    y = float(point[1])
                except Exception:
                    continue

                # A carried bbox (when present and well-formed) is replayed as
                # a positive rect alongside the point.
                rects = []
                if isinstance(bbox, (list, tuple)) and len(bbox) == 4:
                    try:
                        rects = [(
                            float(bbox[0]),
                            float(bbox[1]),
                            float(bbox[2]),
                            float(bbox[3]),
                        )]
                    except Exception:
                        rects = []

                pts = np.array([[x, y]], dtype=np.float32)
                lbs = np.array([1], dtype=np.int32)
                _sam2_add_prompts(model, local_state, 0, obj_id, pts, lbs, rects)
            return

        # Only apply original seed prompts on the very first chunk.
        # Re-using frame-1 seeds on later chunks can cause background drift
        # after objects leave frame.
        next_frame_idx = int(max(0, inference_state.get("next_frame_idx", 0) or 0))
        use_initial_seed = (next_frame_idx <= 0)

        # If an explicit frame-0 prompt exists in prompt_schedule (e.g. DINO boxes),
        # do not also inject fallback seed prompts here, to avoid duplicate/competing seeds.
        has_initial_prompt = False
        if use_initial_seed:
            for entry in list(inference_state.get("prompt_schedule") or []):
                if not isinstance(entry, dict):
                    continue
                try:
                    ef = int(entry.get("frame_idx", 0))
                except Exception:
                    ef = 0
                if ef != 0:
                    continue
                # A frame-0 entry counts if it has any positive label (1) or
                # any positive rect, either top-level or per-object.
                labels0 = [int(v) for v in (entry.get("labels") or [])]
                if any(v == 1 for v in labels0) or bool(entry.get("positive_rects") or []):
                    has_initial_prompt = True
                    break
                for op in list(entry.get("object_prompts") or []):
                    if not isinstance(op, dict):
                        continue
                    op_labels0 = [int(v) for v in (op.get("labels") or [])]
                    if any(v == 1 for v in op_labels0) or bool(op.get("positive_rects") or []):
                        has_initial_prompt = True
                        break
                if has_initial_prompt:
                    break

        seed_points = inference_state.get("seed_points") if (use_initial_seed and not has_initial_prompt) else []
        seed_labels = inference_state.get("seed_labels") if (use_initial_seed and not has_initial_prompt) else []
        seed_rects = inference_state.get("seed_rects") if (use_initial_seed and not has_initial_prompt) else []
        if use_initial_seed and has_initial_prompt:
            _sam2_debug("seed_window_prompt-skip-fallback", "reason=frame0_prompt_schedule")

        # "last_points"/"last_labels" (from a prior chunk) take precedence
        # over the initial seeds.
        points = np.array(inference_state.get("last_points") or seed_points or [], dtype=np.float32)
        labels = np.array(inference_state.get("last_labels") or seed_labels or [], dtype=np.int32)
        rects = [tuple(r) for r in (seed_rects or []) if isinstance(r, (list, tuple)) and len(r) == 4]
        # Normalize a single point/label to batched shape.
        if points.ndim == 1 and points.size > 0:
            points = points.reshape(1, 2)
        if labels.ndim == 0 and labels.size > 0:
            labels = labels.reshape(1)
        # Rect-only seeding: synthesize positive points at rect centers.
        if (points.size == 0 or labels.size == 0) and rects:
            centers = _rect_center_points(rects)
            points = np.array(centers, dtype=np.float32)
            labels = np.ones((len(centers),), dtype=np.int32)
        if points.size == 0 and not rects:
            return
        obj_id = int(inference_state.get("object_index", 0))
        _sam2_add_prompts(model, local_state, 0, obj_id, points, labels, rects)
        _sam2_debug(
            "seed_window_prompt-applied",
            "obj_id=", int(obj_id),
            "points=", int(points.shape[0]) if hasattr(points, "shape") else 0,
            "rects=", len(rects),
        )
|
|
|
|
def _collect_range_masks(self, model, state_obj, frame_start, frame_count):
|
|
frame_start = int(max(0, frame_start))
|
|
frame_count = int(max(0, frame_count))
|
|
if frame_count <= 0:
|
|
return []
|
|
try:
|
|
iterator = model.propagate_in_video(
|
|
state_obj,
|
|
start_frame_idx=frame_start,
|
|
max_frame_num_to_track=frame_count,
|
|
)
|
|
except TypeError:
|
|
iterator = model.propagate_in_video(state_obj)
|
|
|
|
by_idx = {}
|
|
frame_end = frame_start + frame_count
|
|
for out_frame_idx, out_obj_ids, out_mask_logits in iterator:
|
|
idx = int(out_frame_idx)
|
|
if idx < frame_start:
|
|
continue
|
|
if idx >= frame_end:
|
|
break
|
|
combined = None
|
|
for i, _obj_id in enumerate(out_obj_ids):
|
|
current = out_mask_logits[i, 0] > 0.0
|
|
combined = current if combined is None else torch.logical_or(combined, current)
|
|
if combined is None:
|
|
_n, _c, h, w = out_mask_logits.shape
|
|
combined = torch.zeros((h, w), dtype=torch.bool, device=out_mask_logits.device)
|
|
by_idx[idx] = combined.float().cpu()
|
|
del out_mask_logits
|
|
if not by_idx:
|
|
return []
|
|
h = int(next(iter(by_idx.values())).shape[0])
|
|
w = int(next(iter(by_idx.values())).shape[1])
|
|
return [by_idx.get(i, torch.zeros((h, w), dtype=torch.float32)) for i in range(frame_start, frame_end)]
|
|
|
|
def _update_prompt_from_last_mask(self, inference_state, masks):
|
|
last = None
|
|
for m in reversed(masks):
|
|
if torch.any(m > 0.0):
|
|
last = m
|
|
break
|
|
if last is None:
|
|
return
|
|
ys, xs = torch.where(last > 0.0)
|
|
if xs.numel() == 0:
|
|
return
|
|
cx = float(xs.float().mean().item())
|
|
cy = float(ys.float().mean().item())
|
|
inference_state["last_points"] = [[cx, cy]]
|
|
inference_state["last_labels"] = [1]
|
|
|
|
    def _segment_windowed(self, sam2_model, inference_state, image, keep_model_loaded, meta_batch=None):
        """Segment one chunk of frames in windowed mode.

        Each call writes the incoming frames to a temporary JPEG window,
        builds a fresh window-local SAM2 state, seeds it from carried
        prompts, applies any scheduled prompts that fall inside the window,
        propagates, and records per-object centroid+bbox "carries" back into
        inference_state so the next window can continue tracking.

        Args:
            sam2_model: dict with "model", "device", "dtype".
            inference_state: shared mutable bookkeeping dict (frame cursor,
                prompt schedule, object carries, seeds).
            image: (B, H, W, C) frame batch for this window.
            keep_model_loaded: when False, offload the model afterwards.
            meta_batch: optional VHS meta-batch object; the window state is
                attached to it for downstream nodes.

        Returns:
            Tuple of one (B, H, W) float mask tensor (one mask per frame).
        """
        model = sam2_model["model"]
        device = sam2_model["device"]
        dtype = sam2_model["dtype"]
        model.to(device)
        autocast_device = mm.get_autocast_device(device)
        # Autocast is skipped on MPS devices.
        autocast_ok = not mm.is_device_mps(device)

        window_dir = None
        local_state = None
        out_chunks = []
        try:
            # Materialize this chunk's frames as JPEGs for SAM2's video loader.
            window_dir, bsz, h, w = self._write_window_jpegs(image)
            progress = ProgressBar(bsz)
            with torch.inference_mode():
                with torch.autocast(autocast_device, dtype=dtype) if autocast_ok else nullcontext():
                    local_state = self._init_window_state(model, window_dir, device, inference_state)
                    # Seed from carried prompt so chunk-to-chunk tracking continues.
                    self._seed_window_prompt(model, local_state, inference_state)

                    # Global frame index of local frame 0 within the full video.
                    global_start = int(max(0, inference_state.get("next_frame_idx", 0) or 0))
                    applied_frames = set(int(v) for v in (inference_state.get("prompt_frames_applied") or []))
                    _sam2_debug(
                        "segment_windowed-start",
                        "global_start=", int(global_start),
                        "bsz=", int(bsz),
                        "applied_frames=", sorted(list(applied_frames))[:8],
                        "carry_count=", len(dict(inference_state.get("object_carries", {}) or {})),
                    )
                    # Apply scheduled prompts that land inside this window,
                    # translating global frame indices to window-local ones.
                    # Each schedule entry is applied at most once across the run.
                    for entry in self._prompt_schedule(inference_state):
                        gidx = int(entry.get("frame_idx", 0))
                        if gidx < global_start or gidx >= (global_start + bsz):
                            continue
                        if gidx in applied_frames:
                            continue
                        self._apply_prompt_entry(model, local_state, inference_state, int(gidx - global_start), entry)
                        _sam2_debug("segment_windowed-prompt-applied", "global_frame=", int(gidx), "local_frame=", int(gidx - global_start))
                        applied_frames.add(gidx)
                        inference_state["prompt_frames_applied"] = sorted(list(applied_frames))

                    # Boundary replay: reinforce carried prompts for first N local frames.
                    boundary_reseed_frames = int(max(1, inference_state.get("boundary_reseed_frames", 4) or 4))
                    carries_for_reseed = dict(inference_state.get("object_carries", {}) or {})
                    if carries_for_reseed and boundary_reseed_frames > 1:
                        max_local = int(min(bsz, boundary_reseed_frames))
                        for local_f in range(1, max_local):
                            for raw_obj_id, payload in carries_for_reseed.items():
                                try:
                                    obj_id = int(raw_obj_id)
                                except Exception:
                                    continue
                                # Payload may be a bare [x, y] point or a dict
                                # with "point" and optional "bbox".
                                point = payload
                                bbox = None
                                if isinstance(payload, dict):
                                    point = payload.get("point")
                                    bbox = payload.get("bbox")
                                if not isinstance(point, (list, tuple)) or len(point) != 2:
                                    continue
                                try:
                                    x = float(point[0])
                                    y = float(point[1])
                                except Exception:
                                    continue
                                rects = []
                                if isinstance(bbox, (list, tuple)) and len(bbox) == 4:
                                    try:
                                        rects = [(
                                            float(bbox[0]), float(bbox[1]), float(bbox[2]), float(bbox[3])
                                        )]
                                    except Exception:
                                        rects = []
                                pts = np.array([[x, y]], dtype=np.float32)
                                lbs = np.array([1], dtype=np.int32)
                                _sam2_add_prompts(model, local_state, int(local_f), obj_id, pts, lbs, rects)
                        _sam2_debug("boundary_reseed", "frames=", int(max_local), "objects=", len(carries_for_reseed))

                    def _entry_has_positive_seed(entry):
                        # True when the entry contributes any positive prompt:
                        # a label of 1 or a positive rect, top-level or per-object.
                        labels = list((entry or {}).get("labels") or [])
                        if any(int(v) == 1 for v in labels):
                            return True
                        if bool((entry or {}).get("positive_rects") or []):
                            return True
                        for op in list((entry or {}).get("object_prompts") or []):
                            if not isinstance(op, dict):
                                continue
                            op_labels = list(op.get("labels") or [])
                            if any(int(v) == 1 for v in op_labels) or bool(op.get("positive_rects") or []):
                                return True
                        return False

                    by_idx = {}
                    carries = dict(inference_state.get("object_carries", {}) or {})
                    has_window_prompt = False
                    for entry in self._prompt_schedule(inference_state):
                        gidx = int(entry.get("frame_idx", 0))
                        if gidx < global_start or gidx >= (global_start + bsz):
                            continue
                        if _entry_has_positive_seed(entry):
                            has_window_prompt = True
                            break

                    # Gracefully handle prompt-only runs when detector finds no boxes.
                    # If nothing is currently seeded, emit empty masks for this chunk.
                    if (not carries) and (not has_window_prompt):
                        for _ in range(bsz):
                            progress.update(1)
                    else:
                        # Older SAM2 builds reject the range kwargs.
                        try:
                            iterator = model.propagate_in_video(
                                local_state,
                                start_frame_idx=0,
                                max_frame_num_to_track=bsz,
                            )
                        except TypeError:
                            iterator = model.propagate_in_video(local_state)

                        seen_obj_ids = set()
                        for out_frame_idx, out_obj_ids, out_mask_logits in iterator:
                            idx = int(out_frame_idx)
                            if idx < 0 or idx >= bsz:
                                continue
                            # Union all per-object masks for this frame while
                            # refreshing each object's carried centroid+bbox.
                            combined = None
                            for i, _obj_id in enumerate(out_obj_ids):
                                current = out_mask_logits[i, 0] > 0.0
                                combined = current if combined is None else torch.logical_or(combined, current)
                                if torch.any(current):
                                    ys, xs = torch.where(current)
                                    if xs.numel() > 0:
                                        obj_id_int = int(_obj_id)
                                        min_x = int(xs.min().item())
                                        max_x = int(xs.max().item())
                                        min_y = int(ys.min().item())
                                        max_y = int(ys.max().item())
                                        seen_obj_ids.add(obj_id_int)
                                        carries[str(obj_id_int)] = {
                                            "point": [
                                                float(xs.float().mean().item()),
                                                float(ys.float().mean().item()),
                                            ],
                                            "bbox": [
                                                float(min_x), float(min_y), float(max_x), float(max_y),
                                            ],
                                        }
                            if combined is None:
                                combined = torch.zeros((h, w), dtype=torch.bool, device=out_mask_logits.device)
                            by_idx[idx] = combined.float().cpu()
                            progress.update(1)
                            del out_mask_logits
                        # Drop carries for objects that vanished this window.
                        inference_state["object_carries"] = {
                            str(obj_id): carries.get(str(obj_id))
                            for obj_id in sorted(seen_obj_ids)
                            if str(obj_id) in carries
                        }
                        _sam2_debug(
                            "segment_windowed-end",
                            "kept_carries=", sorted([int(v) for v in seen_obj_ids]),
                            "next_frame_idx=", int(inference_state.get("next_frame_idx", 0) or 0),
                        )

                    # Fill frames the iterator never produced with zero masks.
                    for i in range(bsz):
                        out_chunks.append(by_idx.get(i, torch.zeros((h, w), dtype=torch.float32)))
                    # Do not overwrite user/keyframe prompts with a single centroid carry point.
                    # For multi-target prompts (e.g. several cars), centroid carry causes drift/fizzle.
                    inference_state["next_frame_idx"] = int(inference_state.get("next_frame_idx", 0) or 0) + bsz
                    inference_state["num_frames"] = int(inference_state.get("num_frames", 0) or 0) + bsz
                    if "_meta_run_key" not in inference_state:
                        inference_state["_meta_run_key"] = ""
                    if meta_batch is not None:
                        # Best-effort: expose window state to downstream nodes.
                        try:
                            setattr(meta_batch, "_openshot_sam2_window_state", inference_state)
                        except Exception:
                            pass
        finally:
            # Always release the window-local state, temp JPEGs, and (optionally)
            # the model, even when propagation raised.
            if local_state is not None and hasattr(model, "reset_state"):
                try:
                    model.reset_state(local_state)
                except Exception:
                    pass
            if window_dir and os.path.isdir(window_dir):
                shutil.rmtree(window_dir, ignore_errors=True)
            if not keep_model_loaded:
                model.to(mm.unet_offload_device())
                mm.soft_empty_cache()

        stacked = torch.stack(out_chunks, dim=0)
        return (stacked,)
|
|
|
|
def segment_chunk(self, sam2_model, inference_state, image, start_frame, chunk_size_frames, keep_model_loaded, meta_batch=None):
|
|
model = sam2_model["model"]
|
|
device = sam2_model["device"]
|
|
dtype = sam2_model["dtype"]
|
|
segmentor = sam2_model.get("segmentor", "video")
|
|
if segmentor != "video":
|
|
raise ValueError("Loaded SAM2 model is not configured for video")
|
|
|
|
if bool(inference_state.get("windowed_mode", False)):
|
|
return self._segment_windowed(sam2_model, inference_state, image, keep_model_loaded, meta_batch=meta_batch)
|
|
|
|
state = inference_state["inference_state"]
|
|
chunk_size_frames = int(max(1, chunk_size_frames))
|
|
effective_chunk = self._get_frames_per_batch(meta_batch, chunk_size_frames)
|
|
# Force this node to track VHS chunking cadence exactly.
|
|
try:
|
|
effective_chunk = min(effective_chunk, int(image.shape[0]))
|
|
except Exception:
|
|
pass
|
|
|
|
# Persist frame cursor inside the shared inference_state object so each
|
|
# meta-batch call continues from the prior chunk without recomputing frame 0.
|
|
if "next_frame_idx" not in inference_state:
|
|
inference_state["next_frame_idx"] = int(max(0, start_frame))
|
|
current_start = int(max(0, inference_state.get("next_frame_idx", start_frame)))
|
|
|
|
total_frames = int(inference_state.get("num_frames", 0) or 0)
|
|
if total_frames > 0:
|
|
remaining = max(0, total_frames - current_start)
|
|
effective_chunk = min(effective_chunk, remaining) if remaining > 0 else 0
|
|
|
|
if effective_chunk <= 0:
|
|
raise RuntimeError("No remaining SAM2 frames to process (cursor at end of video)")
|
|
|
|
model.to(device)
|
|
autocast_device = mm.get_autocast_device(device)
|
|
autocast_ok = not mm.is_device_mps(device)
|
|
|
|
out_chunks = []
|
|
progress = ProgressBar(effective_chunk)
|
|
with torch.inference_mode():
|
|
with torch.autocast(autocast_device, dtype=dtype) if autocast_ok else nullcontext():
|
|
end_frame = current_start + effective_chunk
|
|
schedule_by_frame = {
|
|
int(entry.get("frame_idx", 0)): entry
|
|
for entry in self._prompt_schedule(inference_state)
|
|
}
|
|
applied_frames = set(int(v) for v in (inference_state.get("prompt_frames_applied") or []))
|
|
for frame_idx in sorted(schedule_by_frame.keys()):
|
|
if frame_idx < current_start or frame_idx >= end_frame:
|
|
continue
|
|
if frame_idx in applied_frames:
|
|
continue
|
|
self._apply_prompt_entry(model, state, inference_state, frame_idx, schedule_by_frame[frame_idx])
|
|
applied_frames.add(frame_idx)
|
|
inference_state["prompt_frames_applied"] = sorted(list(applied_frames))
|
|
|
|
try:
|
|
iterator = model.propagate_in_video(
|
|
state,
|
|
start_frame_idx=current_start,
|
|
max_frame_num_to_track=effective_chunk,
|
|
)
|
|
except TypeError:
|
|
iterator = model.propagate_in_video(state)
|
|
|
|
carries = dict(inference_state.get("object_carries", {}) or {})
|
|
seen_obj_ids = set()
|
|
for out_frame_idx, out_obj_ids, out_mask_logits in iterator:
|
|
idx = int(out_frame_idx)
|
|
if idx < current_start:
|
|
continue
|
|
if idx >= end_frame:
|
|
break
|
|
|
|
combined = None
|
|
for i, _obj_id in enumerate(out_obj_ids):
|
|
current = out_mask_logits[i, 0] > 0.0
|
|
combined = current if combined is None else torch.logical_or(combined, current)
|
|
if torch.any(current):
|
|
ys, xs = torch.where(current)
|
|
if xs.numel() > 0:
|
|
h_cur = int(current.shape[0])
|
|
w_cur = int(current.shape[1])
|
|
obj_id_int = int(_obj_id)
|
|
area = int(xs.numel())
|
|
area_ratio = float(area) / float(max(1, h_cur * w_cur))
|
|
min_x = int(xs.min().item())
|
|
max_x = int(xs.max().item())
|
|
min_y = int(ys.min().item())
|
|
max_y = int(ys.max().item())
|
|
bbox_w = int(max_x - min_x + 1)
|
|
bbox_h = int(max_y - min_y + 1)
|
|
bbox_area = max(1, bbox_w * bbox_h)
|
|
fill_ratio = float(area) / float(bbox_area)
|
|
touches_edge = (min_x <= 1) or (min_y <= 1) or (max_x >= (w_cur - 2)) or (max_y >= (h_cur - 2))
|
|
|
|
seen_obj_ids.add(obj_id_int)
|
|
carries[str(obj_id_int)] = {
|
|
"point": [
|
|
float(xs.float().mean().item()),
|
|
float(ys.float().mean().item()),
|
|
],
|
|
"bbox": [
|
|
float(min_x), float(min_y), float(max_x), float(max_y),
|
|
],
|
|
}
|
|
|
|
if combined is None:
|
|
_n, _c, h, w = out_mask_logits.shape
|
|
combined = torch.zeros((h, w), dtype=torch.bool, device=out_mask_logits.device)
|
|
|
|
out_chunks.append(combined.float().cpu())
|
|
progress.update(1)
|
|
del out_mask_logits
|
|
inference_state["object_carries"] = {
|
|
str(obj_id): carries.get(str(obj_id))
|
|
for obj_id in sorted(seen_obj_ids)
|
|
if str(obj_id) in carries
|
|
}
|
|
|
|
if not out_chunks:
|
|
raise RuntimeError(
|
|
"SAM2 chunk produced no frames. Check cursor/chunk size and inference state. "
|
|
"cursor={} chunk={} total={}".format(current_start, effective_chunk, total_frames)
|
|
)
|
|
|
|
inference_state["next_frame_idx"] = current_start + effective_chunk
|
|
if total_frames > 0 and inference_state["next_frame_idx"] >= total_frames:
|
|
if hasattr(model, "reset_state"):
|
|
try:
|
|
model.reset_state(state)
|
|
except Exception:
|
|
pass
|
|
|
|
if not keep_model_loaded:
|
|
model.to(mm.unet_offload_device())
|
|
mm.soft_empty_cache()
|
|
|
|
stacked = torch.stack(out_chunks, dim=0)
|
|
return (stacked,)
|
|
|
|
|
|
def _gaussian_kernel(kernel_size, sigma, device, dtype):
|
|
axis = torch.linspace(-1, 1, kernel_size, device=device, dtype=dtype)
|
|
x, y = torch.meshgrid(axis, axis, indexing="ij")
|
|
d = torch.sqrt(x * x + y * y)
|
|
g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
|
|
return g / g.sum()
|
|
|
|
|
|
def _parse_color_rgba(color_text, default=(1.0, 1.0, 0.0, 1.0)):
|
|
text = str(color_text or "").strip().lower()
|
|
if not text:
|
|
return default
|
|
if text == "transparent":
|
|
return (0.0, 0.0, 0.0, 0.0)
|
|
if text.startswith("#"):
|
|
raw = text[1:]
|
|
try:
|
|
if len(raw) == 6:
|
|
r = int(raw[0:2], 16) / 255.0
|
|
g = int(raw[2:4], 16) / 255.0
|
|
b = int(raw[4:6], 16) / 255.0
|
|
return (r, g, b, 1.0)
|
|
if len(raw) == 8:
|
|
r = int(raw[0:2], 16) / 255.0
|
|
g = int(raw[2:4], 16) / 255.0
|
|
b = int(raw[4:6], 16) / 255.0
|
|
a = int(raw[6:8], 16) / 255.0
|
|
return (r, g, b, a)
|
|
except Exception:
|
|
return default
|
|
return default
|
|
|
|
|
|
class OpenShotImageBlurMasked:
    """Gaussian-blur only the masked regions of an image batch.

    Frames whose mask is entirely empty are passed through untouched, so the
    convolution runs only on frames that actually need it.
    """

    @classmethod
    def IS_CHANGED(cls, **kwargs):
        # Constant fingerprint: caching is driven by upstream inputs.
        return ""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "mask": ("MASK",),
                "blur_radius": ("INT", {"default": 12, "min": 0, "max": 64, "step": 1}),
                "sigma": ("FLOAT", {"default": 4.0, "min": 0.1, "max": 20.0, "step": 0.1}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "blur_masked"
    CATEGORY = "OpenShot/Video"

    def blur_masked(self, image, mask, blur_radius, sigma):
        """Blend a Gaussian-blurred copy of `image` over `mask` regions.

        Args:
            image: (B, H, W, C) float tensor in [0, 1].
            mask: (B, H, W) or (H, W) float mask; a single-frame mask is
                broadcast across the whole image batch.
            blur_radius: kernel half-size in pixels; 0 disables blurring.
            sigma: Gaussian sigma in the kernel's normalized [-1, 1] space.

        Returns:
            Tuple of one (B, H, W, C) image tensor.
        """
        blur_radius = int(max(0, blur_radius))
        if blur_radius == 0:
            return (image,)

        device = mm.get_torch_device()
        img = image.to(device)
        m = mask.to(device).float()
        # Normalize the mask to (B, H, W, 1) and broadcast a single-frame
        # mask over the batch (mirrors OpenShotImageHighlightMasked, which
        # previously handled these cases while this node did not).
        if m.ndim == 2:
            m = m.unsqueeze(0)
        if m.ndim == 3:
            m = m.unsqueeze(-1)
        if int(m.shape[0]) == 1 and int(img.shape[0]) > 1:
            m = m.repeat(int(img.shape[0]), 1, 1, 1)
        m = torch.clamp(m, 0.0, 1.0)

        # Skip all work when every frame's mask is empty.
        has_mask = (m.view(m.shape[0], -1).max(dim=1).values > 0)
        if not bool(has_mask.any()):
            return (image,)

        out = img.clone()
        idx = torch.nonzero(has_mask, as_tuple=False).squeeze(1)
        work = img[idx]
        work_mask = m[idx]

        kernel_size = blur_radius * 2 + 1
        kernel = _gaussian_kernel(kernel_size, float(sigma), device=work.device, dtype=work.dtype)
        # One depthwise kernel per channel (grouped convolution).
        kernel = kernel.repeat(work.shape[-1], 1, 1).unsqueeze(1)

        work_nchw = work.permute(0, 3, 1, 2)
        # Reflect-pad once and convolve with no implicit zero padding. The
        # previous implementation double-padded (reflect pad + conv padding)
        # and then cropped, producing identical pixels but wasting compute
        # and memory on a larger intermediate.
        padded = F.pad(work_nchw, (blur_radius, blur_radius, blur_radius, blur_radius), "reflect")
        blurred = F.conv2d(padded, kernel, groups=work.shape[-1]).permute(0, 2, 3, 1)

        composited = work * (1.0 - work_mask) + blurred * work_mask
        out[idx] = composited
        return (out.to(mm.intermediate_device()),)
|
|
|
|
|
|
class OpenShotImageHighlightMasked:
    """Visually emphasize masked regions: brighten/dim, tint, and outline.

    Compositing order per frame: brightness split (background vs. masked),
    then an optional translucent color fill over the mask, then an optional
    border drawn on the dilation ring around the mask.
    """

    @classmethod
    def IS_CHANGED(cls, **kwargs):
        # Constant fingerprint: caching is driven by upstream inputs.
        return ""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "mask": ("MASK",),
                "highlight_color": ("STRING", {"default": "#F5D742"}),
                "highlight_opacity": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.01}),
                "border_color": ("STRING", {"default": "transparent"}),
                "border_width": ("INT", {"default": 0, "min": 0, "max": 64, "step": 1}),
                "mask_brightness": ("FLOAT", {"default": 1.15, "min": 0.0, "max": 3.0, "step": 0.01}),
                "background_brightness": ("FLOAT", {"default": 0.75, "min": 0.0, "max": 3.0, "step": 0.01}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "highlight_masked"
    CATEGORY = "OpenShot/Video"

    def highlight_masked(
        self,
        image,
        mask,
        highlight_color,
        highlight_opacity,
        border_color,
        border_width,
        mask_brightness,
        background_brightness,
    ):
        """Apply highlight styling to `mask` regions of `image`.

        Args:
            image: (B, H, W, C) float tensor in [0, 1].
            mask: (H, W), (B, H, W) or (B, H, W, 1) float mask; a
                single-frame mask is broadcast across the batch.
            highlight_color: fill color string (see _parse_color_rgba).
            highlight_opacity: fill opacity in [0, 1], scaled by color alpha.
            border_color: outline color string; "transparent" disables it.
            border_width: outline thickness in pixels (0 disables).
            mask_brightness: brightness multiplier inside the mask.
            background_brightness: brightness multiplier outside the mask.

        Returns:
            Tuple of one (B, H, W, C) image tensor. The input is returned
            unchanged when nothing would be drawn or shapes are unusable.
        """
        # Effective fill alpha combines the opacity slider and color alpha.
        hi_r, hi_g, hi_b, hi_a = _parse_color_rgba(highlight_color, default=(0.96, 0.84, 0.26, 1.0))
        bo_r, bo_g, bo_b, bo_a = _parse_color_rgba(border_color, default=(0.0, 0.0, 0.0, 0.0))
        hi_alpha = float(max(0.0, min(1.0, float(highlight_opacity)))) * float(hi_a)
        border_width = int(max(0, border_width))
        mask_brightness = float(max(0.0, min(3.0, float(mask_brightness))))
        background_brightness = float(max(0.0, min(3.0, float(background_brightness))))

        # Nothing visible to draw -> pass through untouched.
        if hi_alpha <= 0.0 and (border_width <= 0 or bo_a <= 0.0):
            return (image,)

        device = mm.get_torch_device()
        img = image.to(device)
        m = mask.to(device).float()
        # Normalize mask to (B, H, W); bail out on anything else.
        if m.ndim == 2:
            m = m.unsqueeze(0)
        if m.ndim == 4:
            m = m.squeeze(-1)
        if m.ndim != 3:
            return (image,)
        m = torch.clamp(m, 0.0, 1.0)
        # Broadcast a single-frame mask across the image batch.
        if int(m.shape[0]) == 1 and int(img.shape[0]) > 1:
            m = m.repeat(int(img.shape[0]), 1, 1)
        if int(m.shape[0]) != int(img.shape[0]):
            return (image,)

        # Process only frames whose mask is non-empty.
        has_mask = (m.view(m.shape[0], -1).max(dim=1).values > 0)
        if not bool(has_mask.any()):
            return (image,)

        out = img.clone()
        idx = torch.nonzero(has_mask, as_tuple=False).squeeze(1)
        work = img[idx]
        work_mask = m[idx].unsqueeze(-1)
        # Brightness split: dim background, brighten masked area.
        work_bg = torch.clamp(work * background_brightness, 0.0, 1.0)
        work_fg = torch.clamp(work * mask_brightness, 0.0, 1.0)
        work = work_bg * (1.0 - work_mask) + work_fg * work_mask

        # Translucent color fill over the mask.
        if hi_alpha > 0.0:
            hi_color = torch.tensor([hi_r, hi_g, hi_b], device=work.device, dtype=work.dtype).view(1, 1, 1, 3)
            fill_alpha = torch.clamp(work_mask * hi_alpha, 0.0, 1.0)
            work = work * (1.0 - fill_alpha) + hi_color * fill_alpha

        # Border: dilate the mask via max-pool and paint the ring
        # (dilated minus original) with the border color.
        if border_width > 0 and bo_a > 0.0:
            k = border_width * 2 + 1
            base = work_mask.permute(0, 3, 1, 2)
            dilated = F.max_pool2d(base, kernel_size=k, stride=1, padding=border_width)
            border = torch.clamp(dilated - base, 0.0, 1.0).permute(0, 2, 3, 1)
            if torch.any(border > 0.0):
                bo_color = torch.tensor([bo_r, bo_g, bo_b], device=work.device, dtype=work.dtype).view(1, 1, 1, 3)
                border_alpha = torch.clamp(border * bo_a, 0.0, 1.0)
                work = work * (1.0 - border_alpha) + bo_color * border_alpha

        out[idx] = work
        return (out.to(mm.intermediate_device()),)
|
|
|
|
|
|
class OpenShotGroundingDinoDetect:
    """Zero-shot, text-prompted object detection via GroundingDINO.

    For each frame, detected boxes are rasterized into a binary mask and
    also returned as a JSON string of per-frame detections.
    """

    # Cache of (processor, model) pairs keyed by "model_id::device",
    # shared by all node instances (class attribute).
    _model_cache = {}

    @classmethod
    def IS_CHANGED(cls, **kwargs):
        # Constant fingerprint: caching is driven by upstream inputs.
        return ""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "prompt": ("STRING", {"default": "person.", "multiline": False}),
                "model_id": (GROUNDING_DINO_MODEL_IDS,),
                "box_threshold": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.01}),
                "text_threshold": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01}),
                "device": (("auto", "cpu", "cuda", "mps"),),
                "keep_model_loaded": ("BOOLEAN", {"default": True}),
            },
        }

    RETURN_TYPES = ("MASK", "STRING")
    RETURN_NAMES = ("mask", "detections_json")
    FUNCTION = "detect"
    CATEGORY = "OpenShot/GroundingDINO"

    def _resolve_device(self, device_name):
        # "auto" defers to ComfyUI's preferred torch device.
        device_name = str(device_name or "auto").strip().lower()
        if device_name == "auto":
            return mm.get_torch_device()
        return torch.device(device_name)

    def _cache_key(self, model_id, device):
        # One cache slot per (model, device) combination.
        return "{}::{}".format(model_id, str(device))

    def _get_model_and_processor(self, model_id, device):
        """Load (or fetch from cache) the HF processor and detector model."""
        key = self._cache_key(model_id, device)
        if key in self._model_cache:
            return self._model_cache[key]

        processor = AutoProcessor.from_pretrained(model_id)
        model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id)
        model.to(device)
        model.eval()
        self._model_cache[key] = (processor, model)
        return processor, model

    def _tensor_to_pil(self, img):
        # img: (H, W, C) float tensor in [0, 1] -> 8-bit PIL image.
        arr = torch.clamp(img, 0.0, 1.0).mul(255.0).byte().cpu().numpy()
        return Image.fromarray(arr)

    def _boxes_to_mask(self, boxes, height, width):
        """Rasterize xyxy boxes into a single {0, 1} float mask."""
        frame_mask = torch.zeros((height, width), dtype=torch.float32)
        for box in boxes:
            x0, y0, x1, y1 = [float(v) for v in box]
            # Clamp to image bounds; floor/ceil rounds outward so thin
            # boxes still cover at least their fractional extent.
            left = int(max(0, min(width, np.floor(x0))))
            top = int(max(0, min(height, np.floor(y0))))
            right = int(max(0, min(width, np.ceil(x1))))
            bottom = int(max(0, min(height, np.ceil(y1))))
            if right <= left or bottom <= top:
                continue
            frame_mask[top:bottom, left:right] = 1.0
        return frame_mask

    def detect(self, image, prompt, model_id, box_threshold, text_threshold, device, keep_model_loaded):
        """Run GroundingDINO frame-by-frame.

        Args:
            image: (B, H, W, C) float tensor in [0, 1].
            prompt: text query; a trailing "." is appended when missing.
            model_id: HF model id from GROUNDING_DINO_MODEL_IDS.
            box_threshold: minimum box confidence.
            text_threshold: minimum text-match confidence.
            device: "auto"/"cpu"/"cuda"/"mps".
            keep_model_loaded: when False, offload the model afterwards.

        Returns:
            (mask batch (B, H, W), JSON string of per-frame detections).

        Raises:
            ValueError: when the prompt is empty.
        """
        _require_groundingdino()

        prompt = str(prompt or "").strip()
        if not prompt:
            raise ValueError("GroundingDINO prompt must not be empty")
        # Ensure the phrase is period-terminated.
        if not prompt.endswith("."):
            prompt = "{}.".format(prompt)

        device = self._resolve_device(device)
        processor, model = self._get_model_and_processor(model_id, device)
        model.to(device)

        batch = int(image.shape[0])
        height = int(image.shape[1])
        width = int(image.shape[2])
        all_masks = []
        all_detections = []

        with torch.inference_mode():
            for i in range(batch):
                pil = self._tensor_to_pil(image[i])
                inputs = processor(images=pil, text=prompt, return_tensors="pt")
                inputs = {k: v.to(device) for k, v in inputs.items()}
                outputs = model(**inputs)
                post_kwargs = {
                    "target_sizes": [(height, width)],
                    "text_threshold": float(text_threshold),
                }
                # The post-processing signature differs across transformers
                # versions (box_threshold vs. threshold; text_threshold may
                # be unsupported); try each form in turn.
                try:
                    result = processor.post_process_grounded_object_detection(
                        outputs,
                        inputs["input_ids"],
                        box_threshold=float(box_threshold),
                        **post_kwargs,
                    )[0]
                except TypeError:
                    try:
                        result = processor.post_process_grounded_object_detection(
                            outputs,
                            inputs["input_ids"],
                            threshold=float(box_threshold),
                            **post_kwargs,
                        )[0]
                    except TypeError:
                        # Minimal signature: drops text_threshold entirely.
                        result = processor.post_process_grounded_object_detection(
                            outputs,
                            inputs["input_ids"],
                            threshold=float(box_threshold),
                            target_sizes=[(height, width)],
                        )[0]

                boxes = result.get("boxes")
                labels = result.get("labels")
                scores = result.get("scores")
                if boxes is None or boxes.numel() == 0:
                    # No detections: empty mask + empty detection list.
                    all_masks.append(torch.zeros((height, width), dtype=torch.float32))
                    all_detections.append({"frame_index": i, "detections": []})
                    continue

                boxes_cpu = boxes.detach().cpu()
                mask = self._boxes_to_mask(boxes_cpu, height, width)
                all_masks.append(mask)

                frame_items = []
                for idx in range(boxes_cpu.shape[0]):
                    frame_items.append(
                        {
                            "label": str(labels[idx]),
                            "score": float(scores[idx].item()),
                            "box_xyxy": [float(v) for v in boxes_cpu[idx].tolist()],
                        }
                    )
                all_detections.append({"frame_index": i, "detections": frame_items})

        if not keep_model_loaded:
            model.to(mm.unet_offload_device())
            mm.soft_empty_cache()

        mask_tensor = torch.stack(all_masks, dim=0).to(mm.intermediate_device())
        return (mask_tensor, json.dumps(all_detections))
|
|
|
|
|
|
# ComfyUI registration hook: maps each node key to its implementing class.
NODE_CLASS_MAPPINGS = {
    "OpenShotTransNetSceneDetect": OpenShotTransNetSceneDetect,
    "OpenShotDownloadAndLoadSAM2Model": OpenShotDownloadAndLoadSAM2Model,
    "OpenShotSam2Segmentation": OpenShotSam2Segmentation,
    "OpenShotSam2VideoSegmentationAddPoints": OpenShotSam2VideoSegmentationAddPoints,
    "OpenShotSam2VideoSegmentationChunked": OpenShotSam2VideoSegmentationChunked,
    "OpenShotImageBlurMasked": OpenShotImageBlurMasked,
    "OpenShotImageHighlightMasked": OpenShotImageHighlightMasked,
    "OpenShotGroundingDinoDetect": OpenShotGroundingDinoDetect,
    "OpenShotSceneRangesFromSegments": OpenShotSceneRangesFromSegments,
}
|
|
|
|
# Human-readable display names for the registered nodes, keyed by the same
# node keys used in NODE_CLASS_MAPPINGS.
NODE_DISPLAY_NAME_MAPPINGS = {
    "OpenShotTransNetSceneDetect": "OpenShot TransNet Scene Detect",
    "OpenShotDownloadAndLoadSAM2Model": "OpenShot Download+Load SAM2",
    "OpenShotSam2Segmentation": "OpenShot SAM2 Segmentation (Image)",
    "OpenShotSam2VideoSegmentationAddPoints": "OpenShot SAM2 Add Video Points",
    "OpenShotSam2VideoSegmentationChunked": "OpenShot SAM2 Video Segmentation (Chunked)",
    "OpenShotImageBlurMasked": "OpenShot Blur Masked (Skip Empty)",
    "OpenShotImageHighlightMasked": "OpenShot Highlight Masked",
    "OpenShotGroundingDinoDetect": "OpenShot GroundingDINO Detect",
    "OpenShotSceneRangesFromSegments": "OpenShot Scene Ranges From Segments",
}
|