From fc36ce04ccca6da648e71d90d0964b4b7355aee2 Mon Sep 17 00:00:00 2001 From: Jonathan Thomas Date: Sat, 21 Feb 2026 16:10:01 -0600 Subject: [PATCH] Debugging chunk boundaries with SAM2 --- nodes.py | 68 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/nodes.py b/nodes.py index a039245..daeb086 100644 --- a/nodes.py +++ b/nodes.py @@ -58,6 +58,19 @@ GROUNDING_DINO_MODEL_IDS = ( "IDEA-Research/grounding-dino-base", ) GROUNDING_DINO_CACHE = {} +def _sam2_debug_enabled(): + v = str(os.environ.get("OPENSHOT_SAM2_DEBUG", "")).strip().lower() + return v in ("1", "true", "yes", "on", "debug") + + +def _sam2_debug(*parts): + if _sam2_debug_enabled(): + try: + print("[OpenShot-SAM2-DEBUG]", *parts) + except Exception: + pass + + SAM2_MODELS = { "sam2.1_hiera_tiny.safetensors": { "url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt", @@ -1409,6 +1422,12 @@ class OpenShotSam2VideoSegmentationAddPoints: ) if dino_boxes: + _sam2_debug( + "dino-seed", + "prompt=", dino_prompt, + "seed_frame_idx=", int(seed_frame_idx), + "boxes=", len(dino_boxes), + ) seed_entry = dict(prompt_schedule.get(seed_frame_idx, {}) or {}) seed_object_prompts = list(seed_entry.get("object_prompts") or []) next_obj_id = 0 @@ -1462,6 +1481,14 @@ class OpenShotSam2VideoSegmentationAddPoints: if not has_any_positive and not allow_empty_schedule: raise ValueError("No positive points/rectangles provided") + _sam2_debug( + "add_points", + "seed_frame_idx=", int(seed_frame_idx), + "schedule_frames=", sorted([int(k) for k in prompt_schedule.keys()]), + "windowed=", bool(windowed_mode), + "has_prev_state=", bool(prev_inference_state is not None), + ) + serial_schedule = [] for fidx in sorted(prompt_schedule.keys()): entry = prompt_schedule.get(fidx, {}) or {} @@ -1839,6 +1866,14 @@ class OpenShotSam2VideoSegmentationChunked: labels = np.ones((len(centers),), dtype=np.int32) object_prompts = list(entry.get("object_prompts") or []) + _sam2_debug( + "apply_prompt_entry", + "frame_idx=", int(frame_idx), + "points=", len(points_list), + "rects=", len(rects), + "object_prompts=", len(object_prompts), + "neg_points=", len(global_negative_points), + ) if object_prompts: for op in object_prompts: if not isinstance(op, dict): @@ -1875,6 +1910,12 @@ class OpenShotSam2VideoSegmentationChunked: def _seed_window_prompt(self, model, local_state, inference_state): carries = dict(inference_state.get("object_carries", {}) or {}) + _sam2_debug( + "seed_window_prompt", + "next_frame_idx=", int(max(0, inference_state.get("next_frame_idx", 0) or 0)), + "carry_count=", len(carries), + "seed_rects=", len(list(inference_state.get("seed_rects") or [])), + ) if carries: for raw_obj_id, point in carries.items(): try: @@ -1918,6 +1959,12 @@ class OpenShotSam2VideoSegmentationChunked: return obj_id = int(inference_state.get("object_index", 0)) _sam2_add_prompts(model, local_state, 0, obj_id, points, labels, rects) + _sam2_debug( + "seed_window_prompt-applied", + "obj_id=", int(obj_id), + "points=", int(points.shape[0]) if hasattr(points, "shape") else 0, + "rects=", len(rects), + ) def _collect_range_masks(self, model, state_obj, frame_start, frame_count): frame_start = int(max(0, frame_start)) @@ -1994,6 +2041,13 @@ class OpenShotSam2VideoSegmentationChunked: global_start = int(max(0, inference_state.get("next_frame_idx", 0) or 0)) applied_frames = set(int(v) for v in (inference_state.get("prompt_frames_applied") or [])) + _sam2_debug( + "segment_windowed-start", + "global_start=", int(global_start), + "bsz=", int(bsz), + "applied_frames=", sorted(list(applied_frames))[:8], + "carry_count=", len(dict(inference_state.get("object_carries", {}) or {})), + ) for entry in self._prompt_schedule(inference_state): gidx = int(entry.get("frame_idx", 0)) if gidx < global_start or gidx >= (global_start + bsz): @@ -2001,6 +2055,7 @@ class OpenShotSam2VideoSegmentationChunked: if gidx in applied_frames: continue self._apply_prompt_entry(model, local_state, inference_state, int(gidx - global_start), entry) + _sam2_debug("segment_windowed-prompt-applied", "global_frame=", int(gidx), "local_frame=", int(gidx - global_start)) applied_frames.add(gidx) inference_state["prompt_frames_applied"] = sorted(list(applied_frames)) @@ -2085,6 +2140,14 @@ class OpenShotSam2VideoSegmentationChunked: float(xs.float().mean().item()), float(ys.float().mean().item()), ] + else: + _sam2_debug( + "carry-drop-windowed", + "obj_id=", int(obj_id_int), + "area_ratio=", round(float(area_ratio), 6), + "fill_ratio=", round(float(fill_ratio), 6), + "touches_edge=", bool(touches_edge), + ) if combined is None: combined = torch.zeros((h, w), dtype=torch.bool, device=out_mask_logits.device) by_idx[idx] = combined.float().cpu() @@ -2095,6 +2158,11 @@ class OpenShotSam2VideoSegmentationChunked: for obj_id in sorted(seen_obj_ids) if str(obj_id) in carries } + _sam2_debug( + "segment_windowed-end", + "kept_carries=", sorted([int(v) for v in seen_obj_ids]), + "next_frame_idx=", int(inference_state.get("next_frame_idx", 0) or 0), + ) for i in range(bsz): out_chunks.append(by_idx.get(i, torch.zeros((h, w), dtype=torch.float32)))