Debugging chunk boundaries with SAM2

This commit is contained in:
Jonathan Thomas
2026-02-21 16:10:01 -06:00
parent bc0899ebb9
commit fc36ce04cc

View File

@@ -58,6 +58,19 @@ GROUNDING_DINO_MODEL_IDS = (
"IDEA-Research/grounding-dino-base",
)
GROUNDING_DINO_CACHE = {}
def _sam2_debug_enabled():
v = str(os.environ.get("OPENSHOT_SAM2_DEBUG", "")).strip().lower()
return v in ("1", "true", "yes", "on", "debug")
def _sam2_debug(*parts):
    """Print *parts* with a SAM2 debug prefix when debug mode is on.

    Best-effort logging helper: does nothing unless
    :func:`_sam2_debug_enabled` returns True, and swallows any exception
    raised by ``print`` so debug output can never break segmentation.
    """
    if not _sam2_debug_enabled():
        return
    try:
        print("[OpenShot-SAM2-DEBUG]", *parts)
    except Exception:
        pass
SAM2_MODELS = {
"sam2.1_hiera_tiny.safetensors": {
"url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt",
@@ -1409,6 +1422,12 @@ class OpenShotSam2VideoSegmentationAddPoints:
)
if dino_boxes:
_sam2_debug(
"dino-seed",
"prompt=", dino_prompt,
"seed_frame_idx=", int(seed_frame_idx),
"boxes=", len(dino_boxes),
)
seed_entry = dict(prompt_schedule.get(seed_frame_idx, {}) or {})
seed_object_prompts = list(seed_entry.get("object_prompts") or [])
next_obj_id = 0
@@ -1462,6 +1481,14 @@ class OpenShotSam2VideoSegmentationAddPoints:
if not has_any_positive and not allow_empty_schedule:
raise ValueError("No positive points/rectangles provided")
_sam2_debug(
"add_points",
"seed_frame_idx=", int(seed_frame_idx),
"schedule_frames=", sorted([int(k) for k in prompt_schedule.keys()]),
"windowed=", bool(windowed_mode),
"has_prev_state=", bool(prev_inference_state is not None),
)
serial_schedule = []
for fidx in sorted(prompt_schedule.keys()):
entry = prompt_schedule.get(fidx, {}) or {}
@@ -1839,6 +1866,14 @@ class OpenShotSam2VideoSegmentationChunked:
labels = np.ones((len(centers),), dtype=np.int32)
object_prompts = list(entry.get("object_prompts") or [])
_sam2_debug(
"apply_prompt_entry",
"frame_idx=", int(frame_idx),
"points=", len(points_list),
"rects=", len(rects),
"object_prompts=", len(object_prompts),
"neg_points=", len(global_negative_points),
)
if object_prompts:
for op in object_prompts:
if not isinstance(op, dict):
@@ -1875,6 +1910,12 @@ class OpenShotSam2VideoSegmentationChunked:
def _seed_window_prompt(self, model, local_state, inference_state):
carries = dict(inference_state.get("object_carries", {}) or {})
_sam2_debug(
"seed_window_prompt",
"next_frame_idx=", int(max(0, inference_state.get("next_frame_idx", 0) or 0)),
"carry_count=", len(carries),
"seed_rects=", len(list(inference_state.get("seed_rects") or [])),
)
if carries:
for raw_obj_id, point in carries.items():
try:
@@ -1918,6 +1959,12 @@ class OpenShotSam2VideoSegmentationChunked:
return
obj_id = int(inference_state.get("object_index", 0))
_sam2_add_prompts(model, local_state, 0, obj_id, points, labels, rects)
_sam2_debug(
"seed_window_prompt-applied",
"obj_id=", int(obj_id),
"points=", int(points.shape[0]) if hasattr(points, "shape") else 0,
"rects=", len(rects),
)
def _collect_range_masks(self, model, state_obj, frame_start, frame_count):
frame_start = int(max(0, frame_start))
@@ -1994,6 +2041,13 @@ class OpenShotSam2VideoSegmentationChunked:
global_start = int(max(0, inference_state.get("next_frame_idx", 0) or 0))
applied_frames = set(int(v) for v in (inference_state.get("prompt_frames_applied") or []))
_sam2_debug(
"segment_windowed-start",
"global_start=", int(global_start),
"bsz=", int(bsz),
"applied_frames=", sorted(list(applied_frames))[:8],
"carry_count=", len(dict(inference_state.get("object_carries", {}) or {})),
)
for entry in self._prompt_schedule(inference_state):
gidx = int(entry.get("frame_idx", 0))
if gidx < global_start or gidx >= (global_start + bsz):
@@ -2001,6 +2055,7 @@ class OpenShotSam2VideoSegmentationChunked:
if gidx in applied_frames:
continue
self._apply_prompt_entry(model, local_state, inference_state, int(gidx - global_start), entry)
_sam2_debug("segment_windowed-prompt-applied", "global_frame=", int(gidx), "local_frame=", int(gidx - global_start))
applied_frames.add(gidx)
inference_state["prompt_frames_applied"] = sorted(list(applied_frames))
@@ -2085,6 +2140,14 @@ class OpenShotSam2VideoSegmentationChunked:
float(xs.float().mean().item()),
float(ys.float().mean().item()),
]
else:
_sam2_debug(
"carry-drop-windowed",
"obj_id=", int(obj_id_int),
"area_ratio=", round(float(area_ratio), 6),
"fill_ratio=", round(float(fill_ratio), 6),
"touches_edge=", bool(touches_edge),
)
if combined is None:
combined = torch.zeros((h, w), dtype=torch.bool, device=out_mask_logits.device)
by_idx[idx] = combined.float().cpu()
@@ -2095,6 +2158,11 @@ class OpenShotSam2VideoSegmentationChunked:
for obj_id in sorted(seen_obj_ids)
if str(obj_id) in carries
}
_sam2_debug(
"segment_windowed-end",
"kept_carries=", sorted([int(v) for v in seen_obj_ids]),
"next_frame_idx=", int(inference_state.get("next_frame_idx", 0) or 0),
)
for i in range(bsz):
out_chunks.append(by_idx.get(i, torch.zeros((h, w), dtype=torch.float32)))