diff --git a/internal/desktop/webrtc.go b/internal/desktop/webrtc.go index 065a5d8..d057bb6 100644 --- a/internal/desktop/webrtc.go +++ b/internal/desktop/webrtc.go @@ -16,6 +16,7 @@ import ( const ( desktopReliableInputChannelLabel = "input" desktopMoveInputChannelLabel = "input-move" + desktopICEGatheringTimeout = 10 * time.Second ) type WebRTCServer struct { @@ -230,7 +231,9 @@ func (w *WebRTCServer) ProcessOffer(sdpOffer string) (string, error) { return "", fmt.Errorf("failed to set local description: %w", err) } - <-gatherComplete + if err := waitForICEGatheringComplete(gatherComplete, desktopICEGatheringTimeout); err != nil { + return "", err + } localDesc := pc.LocalDescription() if localDesc == nil { @@ -240,6 +243,15 @@ func (w *WebRTCServer) ProcessOffer(sdpOffer string) (string, error) { return localDesc.SDP, nil } +func waitForICEGatheringComplete(done <-chan struct{}, timeout time.Duration) error { + select { + case <-done: + return nil + case <-time.After(timeout): + return fmt.Errorf("timed out waiting for ICE gathering after %s", timeout) + } +} + // AddICECandidate adds a remote ICE candidate func (w *WebRTCServer) AddICECandidate(candidate webrtc.ICECandidateInit) error { w.mu.Lock() diff --git a/internal/desktop/webrtc_test.go b/internal/desktop/webrtc_test.go index 4eb8bec..712bc0d 100644 --- a/internal/desktop/webrtc_test.go +++ b/internal/desktop/webrtc_test.go @@ -1,6 +1,9 @@ package desktop -import "testing" +import ( + "testing" + "time" +) func TestIsDesktopInputDataChannelLabelAllowsReliableAndMoveChannels(t *testing.T) { if !isDesktopInputDataChannelLabel(desktopReliableInputChannelLabel) { @@ -13,3 +16,25 @@ func TestIsDesktopInputDataChannelLabelAllowsReliableAndMoveChannels(t *testing. t.Fatalf("expected unrelated data channel label to be ignored") } } + +func TestWaitForICEGatheringCompleteReturnsWhenDoneCloses(t *testing.T) { + done := make(chan struct{}) + close(done) + + if err := waitForICEGatheringComplete(done, time.Second); err != nil { + t.Fatalf("expected closed gathering channel to succeed: %v", err) + } +} + +func TestWaitForICEGatheringCompleteTimesOut(t *testing.T) { + done := make(chan struct{}) + + start := time.Now() + err := waitForICEGatheringComplete(done, 10*time.Millisecond) + if err == nil { + t.Fatalf("expected timeout error") + } + if elapsed := time.Since(start); elapsed > time.Second { + t.Fatalf("timeout helper waited too long: %s", elapsed) + } +}