fix: stabilize openclaw concurrent gateway sessions

This commit is contained in:
Haitao Pan 2026-05-29 10:41:46 +08:00
parent 6d31a95f70
commit de3d637296
4 changed files with 355 additions and 6 deletions

View File

@ -1233,11 +1233,13 @@ func (o *SessionOrchestrator) openClawGatewayRequestWithRetry(
// isOpenClawRetryableGatewayError reports whether a gateway error payload
// describes a failure that the caller treats as retryable. It looks at the
// payload's "code" entry (trimmed and upper-cased) and, failing that, at the
// "message" entry (trimmed and lower-cased).
//
// NOTE(review): this span looks like a rendered diff with the +/- markers
// stripped, so the code check and the message check each appear twice below —
// presumably the pre-change line followed by its post-change replacement.
// Confirm against the committed file; as shown the braces do not balance.
func isOpenClawRetryableGatewayError(errorPayload map[string]any) bool {
code := strings.TrimSpace(strings.ToUpper(shared.StringArg(errorPayload, "code", "")))
// Presumably the pre-change condition: socket/offline codes only.
if code == "SOCKET_CLOSED" || code == "SOCKET_FAILURE" || code == "OFFLINE" {
// Presumably the post-change condition: additionally accepts INVALID_HANDSHAKE.
if code == "SOCKET_CLOSED" || code == "SOCKET_FAILURE" || code == "OFFLINE" || code == "INVALID_HANDSHAKE" {
return true
}
message := strings.TrimSpace(strings.ToLower(shared.StringArg(errorPayload, "message", "")))
// Presumably the pre-change fallback: only the "socket closed" phrasing.
return strings.Contains(message, "socket closed")
// Presumably the post-change fallback: also matches the handshake phrasings,
// for payloads that carry no recognised code.
return strings.Contains(message, "socket closed") ||
strings.Contains(message, "invalid handshake") ||
strings.Contains(message, "first request must be connect")
}
func firstNonEmptyString(values map[string]any, keys ...string) string {

View File

@ -324,6 +324,119 @@ func TestHTTPHandlerGatewayOpenClawAdmissionQueuesExcessConcurrentSSE(t *testing
}
}
// TestHTTPHandlerGatewayOpenClawHandlesFiveConcurrentE2ECases fires five
// session.start requests at the bridge at the same instant, with the OpenClaw
// admission limit set to five active slots. It expects every request to
// stream a final result without being queued, without any gateway stability
// error appearing in the body, and with the bridge having issued exactly one
// connect to the fake gateway.
func TestHTTPHandlerGatewayOpenClawHandlesFiveConcurrentE2ECases(t *testing.T) {
	gateway := newAcpFakeOpenClawGateway(t)
	defer gateway.Close()
	// NOTE(review): presumably a per-call agent.wait delay in milliseconds that
	// keeps all five requests in flight together — confirm against the fake.
	gateway.agentWaitDelayMs.Store(200)

	t.Setenv("GATEWAY_RPC_URL", gateway.URL())
	t.Setenv("BRIDGE_AUTH_TOKEN", "bridge-test-token")
	t.Setenv("BRIDGE_CONFIG_PATH", filepath.Join(t.TempDir(), "missing-config.yaml"))
	t.Setenv("XWORKMATE_BRIDGE_OPENCLAW_GATEWAY_MAX_ACTIVE", "5")
	t.Setenv("XWORKMATE_BRIDGE_OPENCLAW_GATEWAY_MAX_QUEUED", "20")
	t.Setenv("XWORKMATE_BRIDGE_OPENCLAW_GATEWAY_QUEUE_TIMEOUT", "5s")

	server := NewServer()
	httpServer := httptest.NewServer(server.Handler())
	defer httpServer.Close()

	prompts := []string{
		"从单机权限 → 网络边界 → Web安全 → 云身份 → Zero Trust → AI Agent 身份 → AI模型与知识保护 演进 制作 使用codex 制作连续制作 7张的一些列图片",
		"参考附件模版制作 ,围绕 从单机权限 → 网络边界 → Web安全 → 云身份 → Zero Trust → AI Agent 身份 → AI模型与知识保护 演进 连续制作 7张的一些列图片",
		"拆章节 -> 每章调用 Codex -> 每章 GPT images2 生成图 -> 汇总排版 -> 输出 PDF make artifact",
		"围绕 从单机权限 → 网络边界 → Web安全 → 云身份 → Zero Trust → AI Agent 身份 → AI模型与知识保护 演进 右侧是当下 测试制作视频",
		"从单机权限 → 网络边界 → Web安全 → 云身份 → Zero Trust → AI Agent 身份 → AI模型与知识保护 演进 拆章节 -> 每章调用 Codex -> 每章 GPT images2 生成图 -> 汇总排版 -> 制作视频",
	}

	// postPrompt sends one session.start request and returns the full response
	// body, or an error when the request fails or the status is not 200.
	postPrompt := func(index int, prompt string) (string, error) {
		payload := fmt.Sprintf(
			`{"jsonrpc":"2.0","id":"e2e-%d","method":"session.start","params":{"sessionId":"e2e-s%d","threadId":"e2e-t%d","taskPrompt":%q,"workingDirectory":%q,"routing":{"routingMode":"explicit","explicitExecutionTarget":"gateway","preferredGatewayProviderId":"openclaw"}}}`,
			index,
			index,
			index,
			prompt,
			t.TempDir(),
		)
		request, err := http.NewRequest(
			http.MethodPost,
			httpServer.URL+"/acp/rpc",
			strings.NewReader(payload),
		)
		if err != nil {
			return "", err
		}
		request.Header.Set("Content-Type", "application/json")
		request.Header.Set("Accept", "text/event-stream")
		request.Header.Set("Authorization", "Bearer bridge-test-token")

		response, err := http.DefaultClient.Do(request)
		if err != nil {
			return "", err
		}
		defer func() { _ = response.Body.Close() }()

		raw, err := io.ReadAll(response.Body)
		if err != nil {
			return "", err
		}
		if response.StatusCode != http.StatusOK {
			return "", fmt.Errorf("expected 200, got %d: %s", response.StatusCode, string(raw))
		}
		return string(raw), nil
	}

	type result struct {
		body string
		err  error
	}
	// Buffered to len(prompts) so no worker blocks on delivering its outcome.
	results := make(chan result, len(prompts))
	start := make(chan struct{})
	var wg sync.WaitGroup
	for index, prompt := range prompts {
		wg.Add(1)
		go func(index int, prompt string) {
			defer wg.Done()
			// Hold every worker at the gate so the requests are released together.
			<-start
			body, err := postPrompt(index, prompt)
			results <- result{body: body, err: err}
		}(index, prompt)
	}
	close(start)
	waitForOpenClawGatewayCount(t, gateway.ChatSendCount, len(prompts))
	wg.Wait()
	close(results)

	stabilityMarkers := []string{
		"invalid handshake",
		"SOCKET_CLOSED",
		"ACP_HTTP_CONNECTION_CLOSED",
		"GATEWAY_CONNECT_FAILED",
	}
	completed := 0
	for outcome := range results {
		if outcome.err != nil {
			t.Fatalf("concurrent e2e request failed: %v", outcome.err)
		}
		if strings.Contains(outcome.body, `"event":"queued"`) {
			t.Fatalf("expected five active OpenClaw slots without queueing, got queued event: %s", outcome.body)
		}
		for _, marker := range stabilityMarkers {
			if strings.Contains(outcome.body, marker) {
				t.Fatalf("unexpected gateway stability error %q in body: %s", marker, outcome.body)
			}
		}
		// A finished stream carries both a result payload and the terminator.
		hasResult := strings.Contains(outcome.body, `"result"`)
		if hasResult && strings.Contains(outcome.body, `data: [DONE]`) {
			completed++
		}
	}

	if completed != len(prompts) {
		t.Fatalf("expected all five e2e requests to return final result, got %d", completed)
	}
	if connects := gateway.ConnectCount(); connects != 1 {
		t.Fatalf("expected bridge to reuse one established OpenClaw connection, got %d connects", connects)
	}
	if sends := gateway.ChatSendCount(); sends != len(prompts) {
		t.Fatalf("expected five chat.send calls, got %d", sends)
	}
	if waits := gateway.AgentWaitCount(); waits != len(prompts) {
		t.Fatalf("expected five agent.wait calls, got %d", waits)
	}
}
func TestHTTPHandlerGatewayOpenClawAdmissionRejectsWhenQueueFull(t *testing.T) {
gateway := newAcpFakeOpenClawGateway(t)
defer gateway.Close()

View File

@ -138,8 +138,7 @@ func (m *Manager) Connect(
}
m.mu.Unlock()
current.configure(request, notify)
return current.connect()
return current.connect(request, notify)
}
func (m *Manager) Request(
@ -223,6 +222,7 @@ type session struct {
runtimeID string
mu sync.Mutex
connectMu sync.Mutex
writeMu sync.Mutex
notify func(map[string]any)
config ConnectRequest
@ -273,7 +273,17 @@ func (s *session) setNotify(notify func(map[string]any)) {
s.notify = notify
}
func (s *session) connect() ConnectResult {
func (s *session) connect(request ConnectRequest, notify func(map[string]any)) ConnectResult {
s.connectMu.Lock()
defer s.connectMu.Unlock()
if snapshot, ok := s.connectedSnapshotFor(request, notify); ok {
return ConnectResult{
OK: true,
Snapshot: snapshot,
}
}
s.configure(request, notify)
s.appendLog(
"info",
"connect",
@ -310,6 +320,29 @@ func (s *session) connect() ConnectResult {
}
}
// connectedSnapshotFor returns the session's current snapshot as a map when
// the session already holds a connection, its snapshot status is "connected",
// and its stored config points at the same target as request. On that path
// the session's notify callback is replaced with notify. In every other case
// it returns (nil, false) and changes nothing. s.mu is held for the whole call.
func (s *session) connectedSnapshotFor(
	request ConnectRequest,
	notify func(map[string]any),
) (map[string]any, bool) {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Short-circuit order matters: the target comparison only runs once a live,
	// connected socket is known to exist.
	reusable := s.conn != nil &&
		s.snapshot.Status == "connected" &&
		sameConnectTarget(s.config, request)
	if !reusable {
		return nil, false
	}

	s.notify = notify
	return s.snapshot.Map(), true
}
// sameConnectTarget reports whether two connect requests address the same
// gateway: equal Mode and Endpoint.Host after trimming surrounding
// whitespace, and equal Endpoint.Port and Endpoint.TLS. No other request
// fields are compared.
func sameConnectTarget(current ConnectRequest, next ConnectRequest) bool {
	if strings.TrimSpace(current.Mode) != strings.TrimSpace(next.Mode) {
		return false
	}
	if strings.TrimSpace(current.Endpoint.Host) != strings.TrimSpace(next.Endpoint.Host) {
		return false
	}
	if current.Endpoint.Port != next.Endpoint.Port {
		return false
	}
	return current.Endpoint.TLS == next.Endpoint.TLS
}
func (s *session) connectAttempt() (ConnectResult, *GatewayError) {
url := fmt.Sprintf(
"%s://%s:%d",
@ -352,7 +385,7 @@ func (s *session) connectAttempt() (ConnectResult, *GatewayError) {
s.closeConn(conn)
return ConnectResult{}, gatewayErr
}
requestResult := s.requestRemote("connect", params, 12*time.Second, false)
requestResult := s.requestRemoteOnConn("connect", params, 12*time.Second, false, conn)
if !requestResult.OK {
s.closeConn(conn)
return ConnectResult{}, mapToGatewayError(requestResult.Error, "connect failed")
@ -434,6 +467,16 @@ func (s *session) requestRemote(
params map[string]any,
timeout time.Duration,
requireConnected bool,
) RequestResult {
return s.requestRemoteOnConn(method, params, timeout, requireConnected, nil)
}
func (s *session) requestRemoteOnConn(
method string,
params map[string]any,
timeout time.Duration,
requireConnected bool,
boundConn *websocket.Conn,
) RequestResult {
if timeout <= 0 {
timeout = defaultRequestTimeout
@ -441,6 +484,9 @@ func (s *session) requestRemote(
s.mu.Lock()
conn := s.conn
if boundConn != nil {
conn = boundConn
}
connected := s.snapshot.Status == "connected"
if conn == nil || (requireConnected && !connected) {
s.mu.Unlock()
@ -491,6 +537,9 @@ func (s *session) requestRemote(
s.mu.Unlock()
if !response.OK {
gatewayErr := parseRemoteError(response.Error)
if isInvalidHandshakeGatewayError(gatewayErr) {
s.resetConnAfterProtocolError(conn, gatewayErr)
}
if !shouldAutoReconnectForCodes(
gatewayErr.Code,
gatewayErr.DetailCode(),
@ -532,6 +581,60 @@ func (s *session) requestRemote(
}
}
// resetConnAfterProtocolError discards conn as the session's active
// connection after the gateway reported a protocol-level error on it.
//
// It is a no-op when conn is nil or is no longer the session's current
// connection. Otherwise it clears s.conn, records the error on the snapshot
// (status "error", status text "Disconnected"), answers every pending request
// returned by takePendingLocked with a SOCKET_CLOSED response, closes conn,
// and appends a warning log entry.
//
// The recorded message and code come from err when its trimmed fields are
// non-empty, and otherwise default to "gateway protocol handshake failed" /
// "INVALID_HANDSHAKE".
func (s *session) resetConnAfterProtocolError(conn *websocket.Conn, err *GatewayError) {
	if conn == nil {
		return
	}

	lastError, lastErrorCode := "gateway protocol handshake failed", "INVALID_HANDSHAKE"
	if err != nil {
		if trimmed := strings.TrimSpace(err.Message); trimmed != "" {
			lastError = trimmed
		}
		if trimmed := strings.TrimSpace(err.Code); trimmed != "" {
			lastErrorCode = trimmed
		}
	}

	s.mu.Lock()
	if s.conn != conn {
		// The session's connection is no longer conn; leave its state alone.
		s.mu.Unlock()
		return
	}
	s.conn = nil
	waiters := s.takePendingLocked()
	s.snapshot.Status = "error"
	s.snapshot.StatusText = "Disconnected"
	s.snapshot.LastError = lastError
	s.snapshot.LastErrorCode = lastErrorCode
	s.snapshot.LastErrorDetailCode = ""
	s.mu.Unlock()

	// Waiters are answered after the lock is released; each one receives its
	// own freshly built error map.
	for _, waiter := range waiters {
		socketClosed := &GatewayError{
			Message: "socket closed",
			Code:    "SOCKET_CLOSED",
		}
		waiter <- remoteResponse{
			OK:    false,
			Error: socketClosed.Map(),
		}
	}

	_ = conn.Close()
	s.appendLog("warn", "socket", lastError)
}
// isInvalidHandshakeGatewayError reports whether err describes an invalid
// handshake: either its code is INVALID_HANDSHAKE (ignoring case and
// surrounding whitespace) or its lower-cased message contains one of the
// known handshake phrasings. A nil err is never a handshake error.
func isInvalidHandshakeGatewayError(err *GatewayError) bool {
	if err == nil {
		return false
	}
	if strings.TrimSpace(strings.ToUpper(err.Code)) == "INVALID_HANDSHAKE" {
		return true
	}

	lowered := strings.TrimSpace(strings.ToLower(err.Message))
	for _, phrase := range []string{
		"invalid handshake",
		"first request must be connect",
	} {
		if strings.Contains(lowered, phrase) {
			return true
		}
	}
	return false
}
func (s *session) disconnect() {
s.mu.Lock()
s.manualDisconnect = true

View File

@ -92,6 +92,102 @@ func TestManagerReconnectsAfterSocketClose(t *testing.T) {
}
}
// TestManagerSerializesConcurrentConnectReuseBeforeRequests starts five
// workers that each call Connect and then send a chat.send request, all
// released at the same moment. The fake server is told to reject any
// non-connect frame that arrives before connect on a socket, so the test
// passes only if no such rejection is counted and the server sees exactly
// one connect.
func TestManagerSerializesConcurrentConnectReuseBeforeRequests(t *testing.T) {
	server := newFakeGatewayServer(t)
	server.enforceConnectFirst.Store(true)
	defer server.Close()

	manager := NewManager()
	manager.ReconnectDelay = 20 * time.Millisecond

	const workers = 5
	discard := func(map[string]any) {}

	// connectThenSend runs one worker's sequence and returns a description of
	// the first failure, or "" when both steps succeed. Failure descriptions
	// always carry a prefix, so they are never empty.
	connectThenSend := func() string {
		connected := manager.Connect(buildTestConnectRequest(server.Port()), discard)
		if !connected.OK {
			return "connect failed: " + stringValue(connected.Error["message"])
		}
		sent := manager.Request(
			"runtime-1",
			"chat.send",
			map[string]any{"message": "pong"},
			2*time.Second,
			discard,
		)
		if !sent.OK {
			return "request failed: " + stringValue(sent.Error["message"])
		}
		return ""
	}

	gate := make(chan struct{})
	failures := make(chan string, workers)
	var wg sync.WaitGroup
	for worker := 0; worker < workers; worker++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			<-gate
			if failure := connectThenSend(); failure != "" {
				failures <- failure
			}
		}()
	}
	close(gate)
	wg.Wait()
	close(failures)

	for failure := range failures {
		t.Fatal(failure)
	}
	if rejected := server.InvalidHandshakeCount(); rejected != 0 {
		t.Fatalf("expected no invalid handshake, got %d", rejected)
	}
	if connects := server.ConnectCount(); connects != 1 {
		t.Fatalf("expected concurrent connect calls to reuse one established gateway session, got %d", connects)
	}
}
// TestManagerDropsConnectionAfterInvalidHandshake connects once, makes the
// fake server answer the next non-connect request with INVALID_HANDSHAKE, and
// checks that the failure is surfaced with that code. It then verifies that a
// subsequent Connect performs a second real connect on the server and that a
// request sent afterwards succeeds.
func TestManagerDropsConnectionAfterInvalidHandshake(t *testing.T) {
	server := newFakeGatewayServer(t)
	defer server.Close()

	manager := NewManager()
	discard := func(map[string]any) {}

	if first := manager.Connect(buildTestConnectRequest(server.Port()), discard); !first.OK {
		t.Fatalf("expected connect success, got %#v", first.Error)
	}

	// Arm the one-shot rejection before issuing the request that should fail.
	server.invalidNextRequest.Store(true)
	rejected := manager.Request(
		"runtime-1",
		"chat.send",
		map[string]any{"message": "pong"},
		2*time.Second,
		discard,
	)
	if rejected.OK {
		t.Fatalf("expected invalid handshake failure")
	}
	if code := stringValue(rejected.Error["code"]); code != "INVALID_HANDSHAKE" {
		t.Fatalf("expected invalid handshake code, got %#v", rejected.Error)
	}

	if second := manager.Connect(buildTestConnectRequest(server.Port()), discard); !second.OK {
		t.Fatalf("expected reconnect success, got %#v", second.Error)
	}
	accepted := manager.Request(
		"runtime-1",
		"chat.send",
		map[string]any{"message": "pong"},
		2*time.Second,
		discard,
	)
	if !accepted.OK {
		t.Fatalf("expected request after reconnect to succeed, got %#v", accepted.Error)
	}

	if connects := server.ConnectCount(); connects != 2 {
		t.Fatalf("expected reconnect after invalid handshake, got %d connects", connects)
	}
}
func TestManagerSuppressesReconnectForPairingRequired(t *testing.T) {
server := newFakeGatewayServer(t)
server.connectErrorCode = "NOT_PAIRED"
@ -179,7 +275,10 @@ type fakeGatewayServer struct {
server *http.Server
listener net.Listener
connectCount atomic.Int32
invalidHandshakeCount atomic.Int32
closeAfterConnect atomic.Bool
enforceConnectFirst atomic.Bool
invalidNextRequest atomic.Bool
connectErrorCode string
connectErrorDetailCode string
}
@ -208,6 +307,7 @@ func newFakeGatewayServer(t *testing.T) *fakeGatewayServer {
"nonce": "nonce-1",
},
})
connected := false
for {
_, payload, err := conn.ReadMessage()
if err != nil {
@ -222,9 +322,36 @@ func newFakeGatewayServer(t *testing.T) *fakeGatewayServer {
}
id := frame["id"]
method := stringValue(frame["method"])
if fake.enforceConnectFirst.Load() && !connected && method != "connect" {
fake.invalidHandshakeCount.Add(1)
_ = conn.WriteJSON(map[string]any{
"type": "res",
"id": id,
"ok": false,
"error": map[string]any{
"code": "INVALID_HANDSHAKE",
"message": "invalid handshake: first request must be connect",
},
})
continue
}
if fake.invalidNextRequest.Swap(false) && method != "connect" {
fake.invalidHandshakeCount.Add(1)
_ = conn.WriteJSON(map[string]any{
"type": "res",
"id": id,
"ok": false,
"error": map[string]any{
"code": "INVALID_HANDSHAKE",
"message": "invalid handshake: first request must be connect",
},
})
continue
}
switch method {
case "connect":
fake.connectCount.Add(1)
connected = true
if fake.connectErrorCode != "" {
_ = conn.WriteJSON(map[string]any{
"type": "res",
@ -296,6 +423,10 @@ func (f *fakeGatewayServer) ConnectCount() int {
return int(f.connectCount.Load())
}
// InvalidHandshakeCount returns the current value of the fake server's
// invalidHandshakeCount counter as an int.
func (f *fakeGatewayServer) InvalidHandshakeCount() int {
	rejected := f.invalidHandshakeCount.Load()
	return int(rejected)
}
// Close shuts down the fake gateway's HTTP server. The error from the
// underlying Close call is deliberately discarded.
func (f *fakeGatewayServer) Close() {
	httpServer := f.server
	_ = httpServer.Close()
}