diff --git a/dlclivegui/processors/dlc_processor_socket.py b/dlclivegui/processors/dlc_processor_socket.py index d999690..8ded010 100644 --- a/dlclivegui/processors/dlc_processor_socket.py +++ b/dlclivegui/processors/dlc_processor_socket.py @@ -209,12 +209,6 @@ def _accept_loop(self): try: conn = self.listener.accept() - # Apply safe timeout to client socket - try: - conn._socket.settimeout(self._socket_timeout) - except Exception: - pass - logger.debug(f"Client connected from {self.listener.last_accepted}") self.conns.add(conn) @@ -238,7 +232,7 @@ def _rx_loop(self, conn): self._handle_client_message(msg) continue - if getattr(conn._socket, "_closed", False): + if conn.closed: raise EOFError except (EOFError, OSError, ConnectionError, BrokenPipeError): @@ -253,10 +247,6 @@ def _rx_loop(self, conn): def _close_conn(self, conn): """Force-close client connection.""" - try: - conn._socket.shutdown(socket.SHUT_RDWR) - except Exception: - pass try: conn.close() except Exception: @@ -397,12 +387,11 @@ def broadcast(self, payload): def process(self, pose, **kwargs): curr_time = self.timing_func() - if self.save_original: - self.original_pose.append(pose.copy()) - self.curr_step += 1 if self.recording: + if self.save_original and self.original_pose is not None: + self.original_pose.append(pose.copy()) self.time_stamp.append(curr_time) self.step.append(self.curr_step) self.frame_time.append(kwargs.get("frame_time", -1)) @@ -578,9 +567,6 @@ def _initialize_filters(self, vals): logger.debug(f"Initialized One-Euro filters with parameters: {self.filter_kwargs}") def process(self, pose, **kwargs): - if self.save_original: - self.original_pose.append(pose.copy()) - # Extract keypoints and confidence xy = pose[:, :2] conf = pose[:, 2] @@ -633,6 +619,8 @@ def process(self, pose, **kwargs): # Store processed data (only if recording) if self.recording: + if self.save_original and self.original_pose is not None: + self.original_pose.append(pose.copy()) self.center_x.append(vals[0]) self.center_y.append(vals[1]) self.heading_direction.append(vals[2]) @@ -690,7 +678,7 @@ class ExampleProcessorSocketFilterKeypoints(BaseProcessorSocket): # pragma: no }, "save_original": { "type": "bool", - "default": False, + "default": True, "description": "Save raw pose arrays for analysis", }, } @@ -702,7 +690,7 @@ def __init__( use_perf_counter=False, use_filter=False, filter_kwargs: dict | None = None, - save_original=False, + save_original=True, p_cutoff=0.4, ): super().__init__( @@ -741,9 +729,6 @@ def _initialize_filters(self, vals): logger.debug(f"Initialized One-Euro filters with parameters: {self.filter_kwargs}") def process(self, pose, **kwargs): - if self.save_original: - self.original_pose.append(pose.copy()) - # Extract keypoints and confidence xy = pose[:, :2] conf = pose[:, 2] @@ -801,6 +786,8 @@ def process(self, pose, **kwargs): # Store processed data (only if recording) if self.recording: + if self.save_original and self.original_pose is not None: + self.original_pose.append(pose.copy()) self.center_x.append(vals[0]) self.center_y.append(vals[1]) self.heading_direction.append(vals[2]) diff --git a/tests/custom_processors/test_base_processor.py b/tests/custom_processors/test_base_processor.py index da66a40..d38749b 100644 --- a/tests/custom_processors/test_base_processor.py +++ b/tests/custom_processors/test_base_processor.py @@ -120,7 +120,7 @@ def test_base_process_without_and_with_recording(socket_mod): BaseProcessorSocket.process() should: - increment curr_step always, - when recording, append time/step/frame_time/pose_time, - - when save_original=True, store copies of pose arrays. + - when save_original=True, store copies of pose arrays only while recording. """ BaseProcessorSocket = socket_mod.BaseProcessorSocket proc = BaseProcessorSocket(bind=("127.0.0.1", 0), save_original=True) @@ -136,10 +136,9 @@ def test_base_process_without_and_with_recording(socket_mod): assert len(proc.step) == 0 assert len(proc.frame_time) == 0 assert len(proc.pose_time) == 0 - # When not recording, save_original is still respected + # Raw poses must stay aligned with recorded metadata. assert proc.original_pose is not None - assert len(proc.original_pose) == 1 - np.testing.assert_allclose(proc.original_pose[0], pose) + assert len(proc.original_pose) == 0 # Start recording and push two frames proc._handle_client_message({"cmd": "start_recording"}) @@ -150,6 +149,9 @@ def test_base_process_without_and_with_recording(socket_mod): assert len(proc.step) == 2 assert len(proc.frame_time) == 2 assert len(proc.pose_time) == 2 + assert len(proc.original_pose) == 2 + np.testing.assert_allclose(proc.original_pose[0], pose) + np.testing.assert_allclose(proc.original_pose[1], pose) # Data snapshot integrity data = proc.get_data() @@ -166,6 +168,128 @@ def test_base_process_without_and_with_recording(socket_mod): proc.stop() +def test_save_ignores_pre_recording_original_pose_frames(socket_mod): + """ + save_original data must stay aligned with recorded metadata even if process() + is called before recording starts. + """ + BaseProcessorSocket = socket_mod.BaseProcessorSocket + proc = BaseProcessorSocket(bind=("127.0.0.1", 0), save_original=True) + + try: + n_keypoints = 4 + bodyparts = _mk_bodyparts(n_keypoints) + proc.set_dlc_cfg({"metadata": {"bodyparts": bodyparts}}) + + pose = _mk_pose(n_keypoints=n_keypoints) + + for _ in range(3): + proc.process(pose, frame_time=0.001, pose_time=0.002) + + assert len(proc.original_pose) == 0 + assert len(proc.frame_time) == 0 + + proc._handle_client_message({"cmd": "start_recording"}) + for _ in range(2): + proc.process(pose, frame_time=0.01, pose_time=0.02) + proc._handle_client_message({"cmd": "stop_recording"}) + + filename = "unit_test_pre_recording_frames.pkl" + ret = proc.save(filename) + assert ret == 1 + + data_dir = _module_data_dir(socket_mod) + pkl_path = data_dir / filename + h5_path = data_dir / (Path(filename).stem + "_DLC.hdf5") + + assert pkl_path.exists() + assert h5_path.exists() + + with open(pkl_path, "rb") as f: + payload = pickle.load(f) + + assert len(payload["frame_time"]) == 2 + assert len(payload["time_stamp"]) == 2 + + pytest.importorskip("tables") + df = pd.read_hdf(h5_path, key="df_with_missing") + assert df.shape[0] == 2 + assert list(df["frame_time"]) == [0.01, 0.01] + assert list(df["pose_time"]) == list(payload["time_stamp"]) + + finally: + proc.stop() + try: + pkl_path.unlink(missing_ok=True) + h5_path.unlink(missing_ok=True) + except Exception: + pass + + +@pytest.mark.parametrize( + ("class_name", "n_keypoints"), + [ + ("ExampleProcessorSocketCalculateMousePose", 27), + ("ExampleProcessorSocketFilterKeypoints", 10), + ], +) +def test_subclass_save_ignores_pre_recording_original_pose_frames(socket_mod, class_name, n_keypoints): + """ + Concrete processors must keep original_pose aligned with recorded metadata + even when process() is called before recording starts. + """ + processor_class = getattr(socket_mod, class_name) + proc = processor_class(bind=("127.0.0.1", 0), save_original=True) + + try: + bodyparts = _mk_bodyparts(n_keypoints) + proc.set_dlc_cfg({"metadata": {"bodyparts": bodyparts}}) + + pose = _mk_pose(n_keypoints=n_keypoints) + + for _ in range(4): + proc.process(pose, frame_time=0.001, pose_time=0.002) + + assert len(proc.original_pose) == 0 + assert len(proc.frame_time) == 0 + + proc._handle_client_message({"cmd": "start_recording"}) + for _ in range(3): + proc.process(pose, frame_time=0.01, pose_time=0.02) + proc._handle_client_message({"cmd": "stop_recording"}) + + filename = f"unit_test_{class_name}.pkl" + ret = proc.save(filename) + assert ret == 1 + + data_dir = _module_data_dir(socket_mod) + pkl_path = data_dir / filename + h5_path = data_dir / (Path(filename).stem + "_DLC.hdf5") + + assert pkl_path.exists() + assert h5_path.exists() + + with open(pkl_path, "rb") as f: + payload = pickle.load(f) + + assert len(payload["frame_time"]) == 3 + assert len(payload["time_stamp"]) == 3 + + pytest.importorskip("tables") + df = pd.read_hdf(h5_path, key="df_with_missing") + assert df.shape[0] == 3 + assert list(df["frame_time"]) == [0.01, 0.01, 0.01] + assert list(df["pose_time"]) == list(payload["time_stamp"]) + + finally: + proc.stop() + try: + pkl_path.unlink(missing_ok=True) + h5_path.unlink(missing_ok=True) + except Exception: + pass + + def test_base_broadcast_handles_bad_connections(socket_mod): """ broadcast() must handle failing connections gracefully and drop them. @@ -270,6 +394,8 @@ def test_save_writes_pkl_and_hdf5_with_labels(socket_mod, caplog): # frame_time & pose_time columns are present assert "frame_time" in df.columns assert "pose_time" in df.columns + assert list(df["frame_time"]) == [0.01, 0.01, 0.01] + assert list(df["pose_time"]) == list(payload["time_stamp"]) # sanity check values for first row for i, bp in enumerate(bodyparts):