From de3d6372960fd2dda66f4697196a80e7effbd2c4 Mon Sep 17 00:00:00 2001 From: Haitao Pan Date: Fri, 29 May 2026 10:41:46 +0800 Subject: [PATCH] fix: stabilize openclaw concurrent gateway sessions --- internal/acp/orchestrator.go | 6 +- internal/acp/web_contract_test.go | 113 ++++++++++++++++++++ internal/gatewayruntime/runtime.go | 111 +++++++++++++++++++- internal/gatewayruntime/runtime_test.go | 131 ++++++++++++++++++++++++ 4 files changed, 355 insertions(+), 6 deletions(-) diff --git a/internal/acp/orchestrator.go b/internal/acp/orchestrator.go index ad130dc..c4ffbf2 100644 --- a/internal/acp/orchestrator.go +++ b/internal/acp/orchestrator.go @@ -1233,11 +1233,13 @@ func (o *SessionOrchestrator) openClawGatewayRequestWithRetry( func isOpenClawRetryableGatewayError(errorPayload map[string]any) bool { code := strings.TrimSpace(strings.ToUpper(shared.StringArg(errorPayload, "code", ""))) - if code == "SOCKET_CLOSED" || code == "SOCKET_FAILURE" || code == "OFFLINE" { + if code == "SOCKET_CLOSED" || code == "SOCKET_FAILURE" || code == "OFFLINE" || code == "INVALID_HANDSHAKE" { return true } message := strings.TrimSpace(strings.ToLower(shared.StringArg(errorPayload, "message", ""))) - return strings.Contains(message, "socket closed") + return strings.Contains(message, "socket closed") || + strings.Contains(message, "invalid handshake") || + strings.Contains(message, "first request must be connect") } func firstNonEmptyString(values map[string]any, keys ...string) string { diff --git a/internal/acp/web_contract_test.go b/internal/acp/web_contract_test.go index 0b88a11..d1f8ed9 100644 --- a/internal/acp/web_contract_test.go +++ b/internal/acp/web_contract_test.go @@ -324,6 +324,119 @@ func TestHTTPHandlerGatewayOpenClawAdmissionQueuesExcessConcurrentSSE(t *testing } } +func TestHTTPHandlerGatewayOpenClawHandlesFiveConcurrentE2ECases(t *testing.T) { + gateway := newAcpFakeOpenClawGateway(t) + defer gateway.Close() + gateway.agentWaitDelayMs.Store(200) + + t.Setenv("GATEWAY_RPC_URL", gateway.URL()) + t.Setenv("BRIDGE_AUTH_TOKEN", "bridge-test-token") + t.Setenv("BRIDGE_CONFIG_PATH", filepath.Join(t.TempDir(), "missing-config.yaml")) + t.Setenv("XWORKMATE_BRIDGE_OPENCLAW_GATEWAY_MAX_ACTIVE", "5") + t.Setenv("XWORKMATE_BRIDGE_OPENCLAW_GATEWAY_MAX_QUEUED", "20") + t.Setenv("XWORKMATE_BRIDGE_OPENCLAW_GATEWAY_QUEUE_TIMEOUT", "5s") + server := NewServer() + httpServer := httptest.NewServer(server.Handler()) + defer httpServer.Close() + + prompts := []string{ + "从单机权限 → 网络边界 → Web安全 → 云身份 → Zero Trust → AI Agent 身份 → AI模型与知识保护 演进 制作 使用codex 制作连续制作 7张的一些列图片", + "参考附件模版制作 ,围绕 从单机权限 → 网络边界 → Web安全 → 云身份 → Zero Trust → AI Agent 身份 → AI模型与知识保护 演进 连续制作 7张的一些列图片", + "拆章节 -> 每章调用 Codex -> 每章 GPT images2 生成图 -> 汇总排版 -> 输出 PDF make artifact", + "围绕 从单机权限 → 网络边界 → Web安全 → 云身份 → Zero Trust → AI Agent 身份 → AI模型与知识保护 演进 右侧是当下 测试制作视频", + "从单机权限 → 网络边界 → Web安全 → 云身份 → Zero Trust → AI Agent 身份 → AI模型与知识保护 演进 拆章节 -> 每章调用 Codex -> 每章 GPT images2 生成图 -> 汇总排版 -> 制作视频", + } + type result struct { + body string + err error + } + results := make(chan result, len(prompts)) + start := make(chan struct{}) + var wg sync.WaitGroup + for index, prompt := range prompts { + wg.Add(1) + go func(index int, prompt string) { + defer wg.Done() + <-start + body := fmt.Sprintf( + `{"jsonrpc":"2.0","id":"e2e-%d","method":"session.start","params":{"sessionId":"e2e-s%d","threadId":"e2e-t%d","taskPrompt":%q,"workingDirectory":%q,"routing":{"routingMode":"explicit","explicitExecutionTarget":"gateway","preferredGatewayProviderId":"openclaw"}}}`, + index, + index, + index, + prompt, + t.TempDir(), + ) + request, err := http.NewRequest( + http.MethodPost, + httpServer.URL+"/acp/rpc", + strings.NewReader(body), + ) + if err != nil { + results <- result{err: err} + return + } + request.Header.Set("Content-Type", "application/json") + request.Header.Set("Accept", "text/event-stream") + request.Header.Set("Authorization", "Bearer bridge-test-token") + response, err := http.DefaultClient.Do(request) + if err != nil { + results <- result{err: err} + return + } + defer func() { _ = response.Body.Close() }() + responseBody, err := io.ReadAll(response.Body) + if err != nil { + results <- result{err: err} + return + } + if response.StatusCode != http.StatusOK { + results <- result{err: fmt.Errorf("expected 200, got %d: %s", response.StatusCode, string(responseBody))} + return + } + results <- result{body: string(responseBody)} + }(index, prompt) + } + close(start) + waitForOpenClawGatewayCount(t, gateway.ChatSendCount, len(prompts)) + wg.Wait() + close(results) + + var finalCount int + for item := range results { + if item.err != nil { + t.Fatalf("concurrent e2e request failed: %v", item.err) + } + if strings.Contains(item.body, `"event":"queued"`) { + t.Fatalf("expected five active OpenClaw slots without queueing, got queued event: %s", item.body) + } + for _, unexpected := range []string{ + "invalid handshake", + "SOCKET_CLOSED", + "ACP_HTTP_CONNECTION_CLOSED", + "GATEWAY_CONNECT_FAILED", + } { + if strings.Contains(item.body, unexpected) { + t.Fatalf("unexpected gateway stability error %q in body: %s", unexpected, item.body) + } + } + if strings.Contains(item.body, `"result"`) && strings.Contains(item.body, `data: [DONE]`) { + finalCount += 1 + } + } + if finalCount != len(prompts) { + t.Fatalf("expected all five e2e requests to return final result, got %d", finalCount) + } + if got := gateway.ConnectCount(); got != 1 { + t.Fatalf("expected bridge to reuse one established OpenClaw connection, got %d connects", got) + } + if got := gateway.ChatSendCount(); got != len(prompts) { + t.Fatalf("expected five chat.send calls, got %d", got) + } + if got := gateway.AgentWaitCount(); got != len(prompts) { + t.Fatalf("expected five agent.wait calls, got %d", got) + } +} + func TestHTTPHandlerGatewayOpenClawAdmissionRejectsWhenQueueFull(t *testing.T) { gateway := newAcpFakeOpenClawGateway(t) defer gateway.Close() diff --git a/internal/gatewayruntime/runtime.go b/internal/gatewayruntime/runtime.go index 50fa361..c93c4ac 100644 --- a/internal/gatewayruntime/runtime.go +++ b/internal/gatewayruntime/runtime.go @@ -138,8 +138,7 @@ func (m *Manager) Connect( } m.mu.Unlock() - current.configure(request, notify) - return current.connect() + return current.connect(request, notify) } func (m *Manager) Request( @@ -223,6 +222,7 @@ type session struct { runtimeID string mu sync.Mutex + connectMu sync.Mutex writeMu sync.Mutex notify func(map[string]any) config ConnectRequest @@ -273,7 +273,17 @@ func (s *session) setNotify(notify func(map[string]any)) { s.notify = notify } -func (s *session) connect() ConnectResult { +func (s *session) connect(request ConnectRequest, notify func(map[string]any)) ConnectResult { + s.connectMu.Lock() + defer s.connectMu.Unlock() + + if snapshot, ok := s.connectedSnapshotFor(request, notify); ok { + return ConnectResult{ + OK: true, + Snapshot: snapshot, + } + } + s.configure(request, notify) s.appendLog( "info", "connect", @@ -310,6 +320,29 @@ func (s *session) connect() ConnectResult { } } +func (s *session) connectedSnapshotFor( + request ConnectRequest, + notify func(map[string]any), +) (map[string]any, bool) { + s.mu.Lock() + defer s.mu.Unlock() + if s.conn == nil || s.snapshot.Status != "connected" { + return nil, false + } + if !sameConnectTarget(s.config, request) { + return nil, false + } + s.notify = notify + return s.snapshot.Map(), true +} + +func sameConnectTarget(current ConnectRequest, next ConnectRequest) bool { + return strings.TrimSpace(current.Mode) == strings.TrimSpace(next.Mode) && + strings.TrimSpace(current.Endpoint.Host) == strings.TrimSpace(next.Endpoint.Host) && + current.Endpoint.Port == next.Endpoint.Port && + current.Endpoint.TLS == next.Endpoint.TLS +} + func (s *session) connectAttempt() (ConnectResult, *GatewayError) { url := fmt.Sprintf( "%s://%s:%d", @@ -352,7 +385,7 @@ func (s *session) connectAttempt() (ConnectResult, *GatewayError) { s.closeConn(conn) return ConnectResult{}, gatewayErr } - requestResult := s.requestRemote("connect", params, 12*time.Second, false) + requestResult := s.requestRemoteOnConn("connect", params, 12*time.Second, false, conn) if !requestResult.OK { s.closeConn(conn) return ConnectResult{}, mapToGatewayError(requestResult.Error, "connect failed") @@ -434,6 +467,16 @@ func (s *session) requestRemote( params map[string]any, timeout time.Duration, requireConnected bool, +) RequestResult { + return s.requestRemoteOnConn(method, params, timeout, requireConnected, nil) +} + +func (s *session) requestRemoteOnConn( + method string, + params map[string]any, + timeout time.Duration, + requireConnected bool, + boundConn *websocket.Conn, ) RequestResult { if timeout <= 0 { timeout = defaultRequestTimeout @@ -441,6 +484,9 @@ func (s *session) requestRemote( s.mu.Lock() conn := s.conn + if boundConn != nil { + conn = boundConn + } connected := s.snapshot.Status == "connected" if conn == nil || (requireConnected && !connected) { s.mu.Unlock() @@ -491,6 +537,9 @@ func (s *session) requestRemote( s.mu.Unlock() if !response.OK { gatewayErr := parseRemoteError(response.Error) + if isInvalidHandshakeGatewayError(gatewayErr) { + s.resetConnAfterProtocolError(conn, gatewayErr) + } if !shouldAutoReconnectForCodes( gatewayErr.Code, gatewayErr.DetailCode(), @@ -532,6 +581,60 @@ func (s *session) requestRemote( } } +func (s *session) resetConnAfterProtocolError(conn *websocket.Conn, err *GatewayError) { + if conn == nil { + return + } + message := "gateway protocol handshake failed" + code := "INVALID_HANDSHAKE" + if err != nil { + if strings.TrimSpace(err.Message) != "" { + message = strings.TrimSpace(err.Message) + } + if strings.TrimSpace(err.Code) != "" { + code = strings.TrimSpace(err.Code) + } + } + s.mu.Lock() + if s.conn != conn { + s.mu.Unlock() + return + } + s.conn = nil + pending := s.takePendingLocked() + s.snapshot.Status = "error" + s.snapshot.StatusText = "Disconnected" + s.snapshot.LastError = message + s.snapshot.LastErrorCode = code + s.snapshot.LastErrorDetailCode = "" + s.mu.Unlock() + + for _, ch := range pending { + ch <- remoteResponse{ + OK: false, + Error: (&GatewayError{ + Message: "socket closed", + Code: "SOCKET_CLOSED", + }).Map(), + } + } + _ = conn.Close() + s.appendLog("warn", "socket", message) +} + +func isInvalidHandshakeGatewayError(err *GatewayError) bool { + if err == nil { + return false + } + code := strings.TrimSpace(strings.ToUpper(err.Code)) + if code == "INVALID_HANDSHAKE" { + return true + } + message := strings.TrimSpace(strings.ToLower(err.Message)) + return strings.Contains(message, "invalid handshake") || + strings.Contains(message, "first request must be connect") +} + func (s *session) disconnect() { s.mu.Lock() s.manualDisconnect = true diff --git a/internal/gatewayruntime/runtime_test.go b/internal/gatewayruntime/runtime_test.go index 2d14caa..dc76686 100644 --- a/internal/gatewayruntime/runtime_test.go +++ b/internal/gatewayruntime/runtime_test.go @@ -92,6 +92,102 @@ func TestManagerReconnectsAfterSocketClose(t *testing.T) { } } +func TestManagerSerializesConcurrentConnectReuseBeforeRequests(t *testing.T) { + server := newFakeGatewayServer(t) + server.enforceConnectFirst.Store(true) + defer server.Close() + + manager := NewManager() + manager.ReconnectDelay = 20 * time.Millisecond + const workers = 5 + start := make(chan struct{}) + errs := make(chan string, workers) + var wg sync.WaitGroup + for index := 0; index < workers; index++ { + wg.Add(1) + go func() { + defer wg.Done() + <-start + connectResult := manager.Connect( + buildTestConnectRequest(server.Port()), + func(map[string]any) {}, + ) + if !connectResult.OK { + errs <- "connect failed: " + stringValue(connectResult.Error["message"]) + return + } + requestResult := manager.Request( + "runtime-1", + "chat.send", + map[string]any{"message": "pong"}, + 2*time.Second, + func(map[string]any) {}, + ) + if !requestResult.OK { + errs <- "request failed: " + stringValue(requestResult.Error["message"]) + return + } + }() + } + close(start) + wg.Wait() + close(errs) + + for err := range errs { + t.Fatal(err) + } + if got := server.InvalidHandshakeCount(); got != 0 { + t.Fatalf("expected no invalid handshake, got %d", got) + } + if got := server.ConnectCount(); got != 1 { + t.Fatalf("expected concurrent connect calls to reuse one established gateway session, got %d", got) + } +} + +func TestManagerDropsConnectionAfterInvalidHandshake(t *testing.T) { + server := newFakeGatewayServer(t) + defer server.Close() + + manager := NewManager() + result := manager.Connect(buildTestConnectRequest(server.Port()), func(map[string]any) {}) + if !result.OK { + t.Fatalf("expected connect success, got %#v", result.Error) + } + + server.invalidNextRequest.Store(true) + failed := manager.Request( + "runtime-1", + "chat.send", + map[string]any{"message": "pong"}, + 2*time.Second, + func(map[string]any) {}, + ) + if failed.OK { + t.Fatalf("expected invalid handshake failure") + } + if got := stringValue(failed.Error["code"]); got != "INVALID_HANDSHAKE" { + t.Fatalf("expected invalid handshake code, got %#v", failed.Error) + } + + reconnected := manager.Connect(buildTestConnectRequest(server.Port()), func(map[string]any) {}) + if !reconnected.OK { + t.Fatalf("expected reconnect success, got %#v", reconnected.Error) + } + requestResult := manager.Request( + "runtime-1", + "chat.send", + map[string]any{"message": "pong"}, + 2*time.Second, + func(map[string]any) {}, + ) + if !requestResult.OK { + t.Fatalf("expected request after reconnect to succeed, got %#v", requestResult.Error) + } + if got := server.ConnectCount(); got != 2 { + t.Fatalf("expected reconnect after invalid handshake, got %d connects", got) + } +} + func TestManagerSuppressesReconnectForPairingRequired(t *testing.T) { server := newFakeGatewayServer(t) server.connectErrorCode = "NOT_PAIRED" @@ -179,7 +275,10 @@ type fakeGatewayServer struct { server *http.Server listener net.Listener connectCount atomic.Int32 + invalidHandshakeCount atomic.Int32 closeAfterConnect atomic.Bool + enforceConnectFirst atomic.Bool + invalidNextRequest atomic.Bool connectErrorCode string connectErrorDetailCode string } @@ -208,6 +307,7 @@ func newFakeGatewayServer(t *testing.T) *fakeGatewayServer { "nonce": "nonce-1", }, }) + connected := false for { _, payload, err := conn.ReadMessage() if err != nil { @@ -222,9 +322,36 @@ func newFakeGatewayServer(t *testing.T) *fakeGatewayServer { } id := frame["id"] method := stringValue(frame["method"]) + if fake.enforceConnectFirst.Load() && !connected && method != "connect" { + fake.invalidHandshakeCount.Add(1) + _ = conn.WriteJSON(map[string]any{ + "type": "res", + "id": id, + "ok": false, + "error": map[string]any{ + "code": "INVALID_HANDSHAKE", + "message": "invalid handshake: first request must be connect", + }, + }) + continue + } + if fake.invalidNextRequest.Swap(false) && method != "connect" { + fake.invalidHandshakeCount.Add(1) + _ = conn.WriteJSON(map[string]any{ + "type": "res", + "id": id, + "ok": false, + "error": map[string]any{ + "code": "INVALID_HANDSHAKE", + "message": "invalid handshake: first request must be connect", + }, + }) + continue + } switch method { case "connect": fake.connectCount.Add(1) + connected = true if fake.connectErrorCode != "" { _ = conn.WriteJSON(map[string]any{ "type": "res", @@ -296,6 +423,10 @@ func (f *fakeGatewayServer) ConnectCount() int { return int(f.connectCount.Load()) } +func (f *fakeGatewayServer) InvalidHandshakeCount() int { + return int(f.invalidHandshakeCount.Load()) +} + func (f *fakeGatewayServer) Close() { _ = f.server.Close() }