Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 9 additions & 22 deletions dlclivegui/processors/dlc_processor_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,12 +209,6 @@ def _accept_loop(self):
try:
conn = self.listener.accept()

# Apply safe timeout to client socket
try:
conn._socket.settimeout(self._socket_timeout)
except Exception:
pass

logger.debug(f"Client connected from {self.listener.last_accepted}")
self.conns.add(conn)

Expand All @@ -238,7 +232,7 @@ def _rx_loop(self, conn):
self._handle_client_message(msg)
continue

if getattr(conn._socket, "_closed", False):
if conn.closed:
raise EOFError

except (EOFError, OSError, ConnectionError, BrokenPipeError):
Expand All @@ -253,10 +247,6 @@ def _rx_loop(self, conn):

def _close_conn(self, conn):
"""Force-close client connection."""
try:
conn._socket.shutdown(socket.SHUT_RDWR)
except Exception:
pass
try:
conn.close()
except Exception:
Expand Down Expand Up @@ -397,12 +387,11 @@ def broadcast(self, payload):
def process(self, pose, **kwargs):
curr_time = self.timing_func()

if self.save_original:
self.original_pose.append(pose.copy())

self.curr_step += 1

if self.recording:
if self.save_original and self.original_pose is not None:
self.original_pose.append(pose.copy())
self.time_stamp.append(curr_time)
self.step.append(self.curr_step)
self.frame_time.append(kwargs.get("frame_time", -1))
Expand Down Expand Up @@ -578,9 +567,6 @@ def _initialize_filters(self, vals):
logger.debug(f"Initialized One-Euro filters with parameters: {self.filter_kwargs}")

def process(self, pose, **kwargs):
if self.save_original:
self.original_pose.append(pose.copy())

# Extract keypoints and confidence
xy = pose[:, :2]
conf = pose[:, 2]
Expand Down Expand Up @@ -633,6 +619,8 @@ def process(self, pose, **kwargs):

# Store processed data (only if recording)
if self.recording:
if self.save_original and self.original_pose is not None:
self.original_pose.append(pose.copy())
self.center_x.append(vals[0])
self.center_y.append(vals[1])
self.heading_direction.append(vals[2])
Expand Down Expand Up @@ -690,7 +678,7 @@ class ExampleProcessorSocketFilterKeypoints(BaseProcessorSocket): # pragma: no
},
"save_original": {
"type": "bool",
"default": False,
"default": True,
"description": "Save raw pose arrays for analysis",
},
}
Expand All @@ -702,7 +690,7 @@ def __init__(
use_perf_counter=False,
use_filter=False,
filter_kwargs: dict | None = None,
save_original=False,
save_original=True,
p_cutoff=0.4,
):
super().__init__(
Expand Down Expand Up @@ -741,9 +729,6 @@ def _initialize_filters(self, vals):
logger.debug(f"Initialized One-Euro filters with parameters: {self.filter_kwargs}")

def process(self, pose, **kwargs):
if self.save_original:
self.original_pose.append(pose.copy())

# Extract keypoints and confidence
xy = pose[:, :2]
conf = pose[:, 2]
Expand Down Expand Up @@ -801,6 +786,8 @@ def process(self, pose, **kwargs):

# Store processed data (only if recording)
if self.recording:
if self.save_original and self.original_pose is not None:
self.original_pose.append(pose.copy())
self.center_x.append(vals[0])
self.center_y.append(vals[1])
self.heading_direction.append(vals[2])
Expand Down
134 changes: 130 additions & 4 deletions tests/custom_processors/test_base_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def test_base_process_without_and_with_recording(socket_mod):
BaseProcessorSocket.process() should:
- increment curr_step always,
- when recording, append time/step/frame_time/pose_time,
- when save_original=True, store copies of pose arrays.
- when save_original=True, store copies of pose arrays only while recording.
"""
BaseProcessorSocket = socket_mod.BaseProcessorSocket
proc = BaseProcessorSocket(bind=("127.0.0.1", 0), save_original=True)
Expand All @@ -136,10 +136,9 @@ def test_base_process_without_and_with_recording(socket_mod):
assert len(proc.step) == 0
assert len(proc.frame_time) == 0
assert len(proc.pose_time) == 0
# When not recording, save_original is still respected
# Raw poses must stay aligned with recorded metadata.
assert proc.original_pose is not None
assert len(proc.original_pose) == 1
np.testing.assert_allclose(proc.original_pose[0], pose)
assert len(proc.original_pose) == 0

# Start recording and push two frames
proc._handle_client_message({"cmd": "start_recording"})
Expand All @@ -150,6 +149,9 @@ def test_base_process_without_and_with_recording(socket_mod):
assert len(proc.step) == 2
assert len(proc.frame_time) == 2
assert len(proc.pose_time) == 2
assert len(proc.original_pose) == 2
np.testing.assert_allclose(proc.original_pose[0], pose)
np.testing.assert_allclose(proc.original_pose[1], pose)

# Data snapshot integrity
data = proc.get_data()
Expand All @@ -166,6 +168,128 @@ def test_base_process_without_and_with_recording(socket_mod):
proc.stop()


def test_save_ignores_pre_recording_original_pose_frames(socket_mod):
"""
save_original data must stay aligned with recorded metadata even if process()
is called before recording starts.
"""
BaseProcessorSocket = socket_mod.BaseProcessorSocket
proc = BaseProcessorSocket(bind=("127.0.0.1", 0), save_original=True)

try:
n_keypoints = 4
bodyparts = _mk_bodyparts(n_keypoints)
proc.set_dlc_cfg({"metadata": {"bodyparts": bodyparts}})

pose = _mk_pose(n_keypoints=n_keypoints)

for _ in range(3):
proc.process(pose, frame_time=0.001, pose_time=0.002)

assert len(proc.original_pose) == 0
assert len(proc.frame_time) == 0

proc._handle_client_message({"cmd": "start_recording"})
for _ in range(2):
proc.process(pose, frame_time=0.01, pose_time=0.02)
proc._handle_client_message({"cmd": "stop_recording"})

filename = "unit_test_pre_recording_frames.pkl"
ret = proc.save(filename)
assert ret == 1

data_dir = _module_data_dir(socket_mod)
pkl_path = data_dir / filename
h5_path = data_dir / (Path(filename).stem + "_DLC.hdf5")

assert pkl_path.exists()
assert h5_path.exists()

with open(pkl_path, "rb") as f:
payload = pickle.load(f)

assert len(payload["frame_time"]) == 2
assert len(payload["time_stamp"]) == 2

pytest.importorskip("tables")
df = pd.read_hdf(h5_path, key="df_with_missing")
assert df.shape[0] == 2
assert list(df["frame_time"]) == [0.01, 0.01]
assert list(df["pose_time"]) == list(payload["time_stamp"])

finally:
proc.stop()
try:
pkl_path.unlink(missing_ok=True)
h5_path.unlink(missing_ok=True)
except Exception:
pass


@pytest.mark.parametrize(
("class_name", "n_keypoints"),
[
("ExampleProcessorSocketCalculateMousePose", 27),
("ExampleProcessorSocketFilterKeypoints", 10),
],
)
def test_subclass_save_ignores_pre_recording_original_pose_frames(socket_mod, class_name, n_keypoints):
"""
Concrete processors must keep original_pose aligned with recorded metadata
even when process() is called before recording starts.
"""
processor_class = getattr(socket_mod, class_name)
proc = processor_class(bind=("127.0.0.1", 0), save_original=True)

try:
bodyparts = _mk_bodyparts(n_keypoints)
proc.set_dlc_cfg({"metadata": {"bodyparts": bodyparts}})

pose = _mk_pose(n_keypoints=n_keypoints)

for _ in range(4):
proc.process(pose, frame_time=0.001, pose_time=0.002)

assert len(proc.original_pose) == 0
assert len(proc.frame_time) == 0

proc._handle_client_message({"cmd": "start_recording"})
for _ in range(3):
proc.process(pose, frame_time=0.01, pose_time=0.02)
proc._handle_client_message({"cmd": "stop_recording"})

filename = f"unit_test_{class_name}.pkl"
ret = proc.save(filename)
assert ret == 1

data_dir = _module_data_dir(socket_mod)
pkl_path = data_dir / filename
h5_path = data_dir / (Path(filename).stem + "_DLC.hdf5")

assert pkl_path.exists()
assert h5_path.exists()

with open(pkl_path, "rb") as f:
payload = pickle.load(f)

assert len(payload["frame_time"]) == 3
assert len(payload["time_stamp"]) == 3

pytest.importorskip("tables")
df = pd.read_hdf(h5_path, key="df_with_missing")
assert df.shape[0] == 3
assert list(df["frame_time"]) == [0.01, 0.01, 0.01]
assert list(df["pose_time"]) == list(payload["time_stamp"])

finally:
proc.stop()
try:
pkl_path.unlink(missing_ok=True)
h5_path.unlink(missing_ok=True)
except Exception:
pass


def test_base_broadcast_handles_bad_connections(socket_mod):
"""
broadcast() must handle failing connections gracefully and drop them.
Expand Down Expand Up @@ -270,6 +394,8 @@ def test_save_writes_pkl_and_hdf5_with_labels(socket_mod, caplog):
# frame_time & pose_time columns are present
assert "frame_time" in df.columns
assert "pose_time" in df.columns
assert list(df["frame_time"]) == [0.01, 0.01, 0.01]
assert list(df["pose_time"]) == list(payload["time_stamp"])

# sanity check values for first row
for i, bp in enumerate(bodyparts):
Expand Down
Loading