fix: stabilize openclaw concurrent gateway sessions
This commit is contained in:
parent
6d31a95f70
commit
de3d637296
@ -1233,11 +1233,13 @@ func (o *SessionOrchestrator) openClawGatewayRequestWithRetry(
|
||||
|
||||
func isOpenClawRetryableGatewayError(errorPayload map[string]any) bool {
|
||||
code := strings.TrimSpace(strings.ToUpper(shared.StringArg(errorPayload, "code", "")))
|
||||
if code == "SOCKET_CLOSED" || code == "SOCKET_FAILURE" || code == "OFFLINE" {
|
||||
if code == "SOCKET_CLOSED" || code == "SOCKET_FAILURE" || code == "OFFLINE" || code == "INVALID_HANDSHAKE" {
|
||||
return true
|
||||
}
|
||||
message := strings.TrimSpace(strings.ToLower(shared.StringArg(errorPayload, "message", "")))
|
||||
return strings.Contains(message, "socket closed")
|
||||
return strings.Contains(message, "socket closed") ||
|
||||
strings.Contains(message, "invalid handshake") ||
|
||||
strings.Contains(message, "first request must be connect")
|
||||
}
|
||||
|
||||
func firstNonEmptyString(values map[string]any, keys ...string) string {
|
||||
|
||||
@ -324,6 +324,119 @@ func TestHTTPHandlerGatewayOpenClawAdmissionQueuesExcessConcurrentSSE(t *testing
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPHandlerGatewayOpenClawHandlesFiveConcurrentE2ECases(t *testing.T) {
|
||||
gateway := newAcpFakeOpenClawGateway(t)
|
||||
defer gateway.Close()
|
||||
gateway.agentWaitDelayMs.Store(200)
|
||||
|
||||
t.Setenv("GATEWAY_RPC_URL", gateway.URL())
|
||||
t.Setenv("BRIDGE_AUTH_TOKEN", "bridge-test-token")
|
||||
t.Setenv("BRIDGE_CONFIG_PATH", filepath.Join(t.TempDir(), "missing-config.yaml"))
|
||||
t.Setenv("XWORKMATE_BRIDGE_OPENCLAW_GATEWAY_MAX_ACTIVE", "5")
|
||||
t.Setenv("XWORKMATE_BRIDGE_OPENCLAW_GATEWAY_MAX_QUEUED", "20")
|
||||
t.Setenv("XWORKMATE_BRIDGE_OPENCLAW_GATEWAY_QUEUE_TIMEOUT", "5s")
|
||||
server := NewServer()
|
||||
httpServer := httptest.NewServer(server.Handler())
|
||||
defer httpServer.Close()
|
||||
|
||||
prompts := []string{
|
||||
"从单机权限 → 网络边界 → Web安全 → 云身份 → Zero Trust → AI Agent 身份 → AI模型与知识保护 演进 制作 使用codex 制作连续制作 7张的一些列图片",
|
||||
"参考附件模版制作 ,围绕 从单机权限 → 网络边界 → Web安全 → 云身份 → Zero Trust → AI Agent 身份 → AI模型与知识保护 演进 连续制作 7张的一些列图片",
|
||||
"拆章节 -> 每章调用 Codex -> 每章 GPT images2 生成图 -> 汇总排版 -> 输出 PDF make artifact",
|
||||
"围绕 从单机权限 → 网络边界 → Web安全 → 云身份 → Zero Trust → AI Agent 身份 → AI模型与知识保护 演进 右侧是当下 测试制作视频",
|
||||
"从单机权限 → 网络边界 → Web安全 → 云身份 → Zero Trust → AI Agent 身份 → AI模型与知识保护 演进 拆章节 -> 每章调用 Codex -> 每章 GPT images2 生成图 -> 汇总排版 -> 制作视频",
|
||||
}
|
||||
type result struct {
|
||||
body string
|
||||
err error
|
||||
}
|
||||
results := make(chan result, len(prompts))
|
||||
start := make(chan struct{})
|
||||
var wg sync.WaitGroup
|
||||
for index, prompt := range prompts {
|
||||
wg.Add(1)
|
||||
go func(index int, prompt string) {
|
||||
defer wg.Done()
|
||||
<-start
|
||||
body := fmt.Sprintf(
|
||||
`{"jsonrpc":"2.0","id":"e2e-%d","method":"session.start","params":{"sessionId":"e2e-s%d","threadId":"e2e-t%d","taskPrompt":%q,"workingDirectory":%q,"routing":{"routingMode":"explicit","explicitExecutionTarget":"gateway","preferredGatewayProviderId":"openclaw"}}}`,
|
||||
index,
|
||||
index,
|
||||
index,
|
||||
prompt,
|
||||
t.TempDir(),
|
||||
)
|
||||
request, err := http.NewRequest(
|
||||
http.MethodPost,
|
||||
httpServer.URL+"/acp/rpc",
|
||||
strings.NewReader(body),
|
||||
)
|
||||
if err != nil {
|
||||
results <- result{err: err}
|
||||
return
|
||||
}
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
request.Header.Set("Accept", "text/event-stream")
|
||||
request.Header.Set("Authorization", "Bearer bridge-test-token")
|
||||
response, err := http.DefaultClient.Do(request)
|
||||
if err != nil {
|
||||
results <- result{err: err}
|
||||
return
|
||||
}
|
||||
defer func() { _ = response.Body.Close() }()
|
||||
responseBody, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
results <- result{err: err}
|
||||
return
|
||||
}
|
||||
if response.StatusCode != http.StatusOK {
|
||||
results <- result{err: fmt.Errorf("expected 200, got %d: %s", response.StatusCode, string(responseBody))}
|
||||
return
|
||||
}
|
||||
results <- result{body: string(responseBody)}
|
||||
}(index, prompt)
|
||||
}
|
||||
close(start)
|
||||
waitForOpenClawGatewayCount(t, gateway.ChatSendCount, len(prompts))
|
||||
wg.Wait()
|
||||
close(results)
|
||||
|
||||
var finalCount int
|
||||
for item := range results {
|
||||
if item.err != nil {
|
||||
t.Fatalf("concurrent e2e request failed: %v", item.err)
|
||||
}
|
||||
if strings.Contains(item.body, `"event":"queued"`) {
|
||||
t.Fatalf("expected five active OpenClaw slots without queueing, got queued event: %s", item.body)
|
||||
}
|
||||
for _, unexpected := range []string{
|
||||
"invalid handshake",
|
||||
"SOCKET_CLOSED",
|
||||
"ACP_HTTP_CONNECTION_CLOSED",
|
||||
"GATEWAY_CONNECT_FAILED",
|
||||
} {
|
||||
if strings.Contains(item.body, unexpected) {
|
||||
t.Fatalf("unexpected gateway stability error %q in body: %s", unexpected, item.body)
|
||||
}
|
||||
}
|
||||
if strings.Contains(item.body, `"result"`) && strings.Contains(item.body, `data: [DONE]`) {
|
||||
finalCount += 1
|
||||
}
|
||||
}
|
||||
if finalCount != len(prompts) {
|
||||
t.Fatalf("expected all five e2e requests to return final result, got %d", finalCount)
|
||||
}
|
||||
if got := gateway.ConnectCount(); got != 1 {
|
||||
t.Fatalf("expected bridge to reuse one established OpenClaw connection, got %d connects", got)
|
||||
}
|
||||
if got := gateway.ChatSendCount(); got != len(prompts) {
|
||||
t.Fatalf("expected five chat.send calls, got %d", got)
|
||||
}
|
||||
if got := gateway.AgentWaitCount(); got != len(prompts) {
|
||||
t.Fatalf("expected five agent.wait calls, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPHandlerGatewayOpenClawAdmissionRejectsWhenQueueFull(t *testing.T) {
|
||||
gateway := newAcpFakeOpenClawGateway(t)
|
||||
defer gateway.Close()
|
||||
|
||||
@ -138,8 +138,7 @@ func (m *Manager) Connect(
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
current.configure(request, notify)
|
||||
return current.connect()
|
||||
return current.connect(request, notify)
|
||||
}
|
||||
|
||||
func (m *Manager) Request(
|
||||
@ -223,6 +222,7 @@ type session struct {
|
||||
runtimeID string
|
||||
|
||||
mu sync.Mutex
|
||||
connectMu sync.Mutex
|
||||
writeMu sync.Mutex
|
||||
notify func(map[string]any)
|
||||
config ConnectRequest
|
||||
@ -273,7 +273,17 @@ func (s *session) setNotify(notify func(map[string]any)) {
|
||||
s.notify = notify
|
||||
}
|
||||
|
||||
func (s *session) connect() ConnectResult {
|
||||
func (s *session) connect(request ConnectRequest, notify func(map[string]any)) ConnectResult {
|
||||
s.connectMu.Lock()
|
||||
defer s.connectMu.Unlock()
|
||||
|
||||
if snapshot, ok := s.connectedSnapshotFor(request, notify); ok {
|
||||
return ConnectResult{
|
||||
OK: true,
|
||||
Snapshot: snapshot,
|
||||
}
|
||||
}
|
||||
s.configure(request, notify)
|
||||
s.appendLog(
|
||||
"info",
|
||||
"connect",
|
||||
@ -310,6 +320,29 @@ func (s *session) connect() ConnectResult {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) connectedSnapshotFor(
|
||||
request ConnectRequest,
|
||||
notify func(map[string]any),
|
||||
) (map[string]any, bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.conn == nil || s.snapshot.Status != "connected" {
|
||||
return nil, false
|
||||
}
|
||||
if !sameConnectTarget(s.config, request) {
|
||||
return nil, false
|
||||
}
|
||||
s.notify = notify
|
||||
return s.snapshot.Map(), true
|
||||
}
|
||||
|
||||
func sameConnectTarget(current ConnectRequest, next ConnectRequest) bool {
|
||||
return strings.TrimSpace(current.Mode) == strings.TrimSpace(next.Mode) &&
|
||||
strings.TrimSpace(current.Endpoint.Host) == strings.TrimSpace(next.Endpoint.Host) &&
|
||||
current.Endpoint.Port == next.Endpoint.Port &&
|
||||
current.Endpoint.TLS == next.Endpoint.TLS
|
||||
}
|
||||
|
||||
func (s *session) connectAttempt() (ConnectResult, *GatewayError) {
|
||||
url := fmt.Sprintf(
|
||||
"%s://%s:%d",
|
||||
@ -352,7 +385,7 @@ func (s *session) connectAttempt() (ConnectResult, *GatewayError) {
|
||||
s.closeConn(conn)
|
||||
return ConnectResult{}, gatewayErr
|
||||
}
|
||||
requestResult := s.requestRemote("connect", params, 12*time.Second, false)
|
||||
requestResult := s.requestRemoteOnConn("connect", params, 12*time.Second, false, conn)
|
||||
if !requestResult.OK {
|
||||
s.closeConn(conn)
|
||||
return ConnectResult{}, mapToGatewayError(requestResult.Error, "connect failed")
|
||||
@ -434,6 +467,16 @@ func (s *session) requestRemote(
|
||||
params map[string]any,
|
||||
timeout time.Duration,
|
||||
requireConnected bool,
|
||||
) RequestResult {
|
||||
return s.requestRemoteOnConn(method, params, timeout, requireConnected, nil)
|
||||
}
|
||||
|
||||
func (s *session) requestRemoteOnConn(
|
||||
method string,
|
||||
params map[string]any,
|
||||
timeout time.Duration,
|
||||
requireConnected bool,
|
||||
boundConn *websocket.Conn,
|
||||
) RequestResult {
|
||||
if timeout <= 0 {
|
||||
timeout = defaultRequestTimeout
|
||||
@ -441,6 +484,9 @@ func (s *session) requestRemote(
|
||||
|
||||
s.mu.Lock()
|
||||
conn := s.conn
|
||||
if boundConn != nil {
|
||||
conn = boundConn
|
||||
}
|
||||
connected := s.snapshot.Status == "connected"
|
||||
if conn == nil || (requireConnected && !connected) {
|
||||
s.mu.Unlock()
|
||||
@ -491,6 +537,9 @@ func (s *session) requestRemote(
|
||||
s.mu.Unlock()
|
||||
if !response.OK {
|
||||
gatewayErr := parseRemoteError(response.Error)
|
||||
if isInvalidHandshakeGatewayError(gatewayErr) {
|
||||
s.resetConnAfterProtocolError(conn, gatewayErr)
|
||||
}
|
||||
if !shouldAutoReconnectForCodes(
|
||||
gatewayErr.Code,
|
||||
gatewayErr.DetailCode(),
|
||||
@ -532,6 +581,60 @@ func (s *session) requestRemote(
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) resetConnAfterProtocolError(conn *websocket.Conn, err *GatewayError) {
|
||||
if conn == nil {
|
||||
return
|
||||
}
|
||||
message := "gateway protocol handshake failed"
|
||||
code := "INVALID_HANDSHAKE"
|
||||
if err != nil {
|
||||
if strings.TrimSpace(err.Message) != "" {
|
||||
message = strings.TrimSpace(err.Message)
|
||||
}
|
||||
if strings.TrimSpace(err.Code) != "" {
|
||||
code = strings.TrimSpace(err.Code)
|
||||
}
|
||||
}
|
||||
s.mu.Lock()
|
||||
if s.conn != conn {
|
||||
s.mu.Unlock()
|
||||
return
|
||||
}
|
||||
s.conn = nil
|
||||
pending := s.takePendingLocked()
|
||||
s.snapshot.Status = "error"
|
||||
s.snapshot.StatusText = "Disconnected"
|
||||
s.snapshot.LastError = message
|
||||
s.snapshot.LastErrorCode = code
|
||||
s.snapshot.LastErrorDetailCode = ""
|
||||
s.mu.Unlock()
|
||||
|
||||
for _, ch := range pending {
|
||||
ch <- remoteResponse{
|
||||
OK: false,
|
||||
Error: (&GatewayError{
|
||||
Message: "socket closed",
|
||||
Code: "SOCKET_CLOSED",
|
||||
}).Map(),
|
||||
}
|
||||
}
|
||||
_ = conn.Close()
|
||||
s.appendLog("warn", "socket", message)
|
||||
}
|
||||
|
||||
func isInvalidHandshakeGatewayError(err *GatewayError) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
code := strings.TrimSpace(strings.ToUpper(err.Code))
|
||||
if code == "INVALID_HANDSHAKE" {
|
||||
return true
|
||||
}
|
||||
message := strings.TrimSpace(strings.ToLower(err.Message))
|
||||
return strings.Contains(message, "invalid handshake") ||
|
||||
strings.Contains(message, "first request must be connect")
|
||||
}
|
||||
|
||||
func (s *session) disconnect() {
|
||||
s.mu.Lock()
|
||||
s.manualDisconnect = true
|
||||
|
||||
@ -92,6 +92,102 @@ func TestManagerReconnectsAfterSocketClose(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerSerializesConcurrentConnectReuseBeforeRequests(t *testing.T) {
|
||||
server := newFakeGatewayServer(t)
|
||||
server.enforceConnectFirst.Store(true)
|
||||
defer server.Close()
|
||||
|
||||
manager := NewManager()
|
||||
manager.ReconnectDelay = 20 * time.Millisecond
|
||||
const workers = 5
|
||||
start := make(chan struct{})
|
||||
errs := make(chan string, workers)
|
||||
var wg sync.WaitGroup
|
||||
for index := 0; index < workers; index++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-start
|
||||
connectResult := manager.Connect(
|
||||
buildTestConnectRequest(server.Port()),
|
||||
func(map[string]any) {},
|
||||
)
|
||||
if !connectResult.OK {
|
||||
errs <- "connect failed: " + stringValue(connectResult.Error["message"])
|
||||
return
|
||||
}
|
||||
requestResult := manager.Request(
|
||||
"runtime-1",
|
||||
"chat.send",
|
||||
map[string]any{"message": "pong"},
|
||||
2*time.Second,
|
||||
func(map[string]any) {},
|
||||
)
|
||||
if !requestResult.OK {
|
||||
errs <- "request failed: " + stringValue(requestResult.Error["message"])
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
close(start)
|
||||
wg.Wait()
|
||||
close(errs)
|
||||
|
||||
for err := range errs {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got := server.InvalidHandshakeCount(); got != 0 {
|
||||
t.Fatalf("expected no invalid handshake, got %d", got)
|
||||
}
|
||||
if got := server.ConnectCount(); got != 1 {
|
||||
t.Fatalf("expected concurrent connect calls to reuse one established gateway session, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerDropsConnectionAfterInvalidHandshake(t *testing.T) {
|
||||
server := newFakeGatewayServer(t)
|
||||
defer server.Close()
|
||||
|
||||
manager := NewManager()
|
||||
result := manager.Connect(buildTestConnectRequest(server.Port()), func(map[string]any) {})
|
||||
if !result.OK {
|
||||
t.Fatalf("expected connect success, got %#v", result.Error)
|
||||
}
|
||||
|
||||
server.invalidNextRequest.Store(true)
|
||||
failed := manager.Request(
|
||||
"runtime-1",
|
||||
"chat.send",
|
||||
map[string]any{"message": "pong"},
|
||||
2*time.Second,
|
||||
func(map[string]any) {},
|
||||
)
|
||||
if failed.OK {
|
||||
t.Fatalf("expected invalid handshake failure")
|
||||
}
|
||||
if got := stringValue(failed.Error["code"]); got != "INVALID_HANDSHAKE" {
|
||||
t.Fatalf("expected invalid handshake code, got %#v", failed.Error)
|
||||
}
|
||||
|
||||
reconnected := manager.Connect(buildTestConnectRequest(server.Port()), func(map[string]any) {})
|
||||
if !reconnected.OK {
|
||||
t.Fatalf("expected reconnect success, got %#v", reconnected.Error)
|
||||
}
|
||||
requestResult := manager.Request(
|
||||
"runtime-1",
|
||||
"chat.send",
|
||||
map[string]any{"message": "pong"},
|
||||
2*time.Second,
|
||||
func(map[string]any) {},
|
||||
)
|
||||
if !requestResult.OK {
|
||||
t.Fatalf("expected request after reconnect to succeed, got %#v", requestResult.Error)
|
||||
}
|
||||
if got := server.ConnectCount(); got != 2 {
|
||||
t.Fatalf("expected reconnect after invalid handshake, got %d connects", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerSuppressesReconnectForPairingRequired(t *testing.T) {
|
||||
server := newFakeGatewayServer(t)
|
||||
server.connectErrorCode = "NOT_PAIRED"
|
||||
@ -179,7 +275,10 @@ type fakeGatewayServer struct {
|
||||
server *http.Server
|
||||
listener net.Listener
|
||||
connectCount atomic.Int32
|
||||
invalidHandshakeCount atomic.Int32
|
||||
closeAfterConnect atomic.Bool
|
||||
enforceConnectFirst atomic.Bool
|
||||
invalidNextRequest atomic.Bool
|
||||
connectErrorCode string
|
||||
connectErrorDetailCode string
|
||||
}
|
||||
@ -208,6 +307,7 @@ func newFakeGatewayServer(t *testing.T) *fakeGatewayServer {
|
||||
"nonce": "nonce-1",
|
||||
},
|
||||
})
|
||||
connected := false
|
||||
for {
|
||||
_, payload, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
@ -222,9 +322,36 @@ func newFakeGatewayServer(t *testing.T) *fakeGatewayServer {
|
||||
}
|
||||
id := frame["id"]
|
||||
method := stringValue(frame["method"])
|
||||
if fake.enforceConnectFirst.Load() && !connected && method != "connect" {
|
||||
fake.invalidHandshakeCount.Add(1)
|
||||
_ = conn.WriteJSON(map[string]any{
|
||||
"type": "res",
|
||||
"id": id,
|
||||
"ok": false,
|
||||
"error": map[string]any{
|
||||
"code": "INVALID_HANDSHAKE",
|
||||
"message": "invalid handshake: first request must be connect",
|
||||
},
|
||||
})
|
||||
continue
|
||||
}
|
||||
if fake.invalidNextRequest.Swap(false) && method != "connect" {
|
||||
fake.invalidHandshakeCount.Add(1)
|
||||
_ = conn.WriteJSON(map[string]any{
|
||||
"type": "res",
|
||||
"id": id,
|
||||
"ok": false,
|
||||
"error": map[string]any{
|
||||
"code": "INVALID_HANDSHAKE",
|
||||
"message": "invalid handshake: first request must be connect",
|
||||
},
|
||||
})
|
||||
continue
|
||||
}
|
||||
switch method {
|
||||
case "connect":
|
||||
fake.connectCount.Add(1)
|
||||
connected = true
|
||||
if fake.connectErrorCode != "" {
|
||||
_ = conn.WriteJSON(map[string]any{
|
||||
"type": "res",
|
||||
@ -296,6 +423,10 @@ func (f *fakeGatewayServer) ConnectCount() int {
|
||||
return int(f.connectCount.Load())
|
||||
}
|
||||
|
||||
func (f *fakeGatewayServer) InvalidHandshakeCount() int {
|
||||
return int(f.invalidHandshakeCount.Load())
|
||||
}
|
||||
|
||||
func (f *fakeGatewayServer) Close() {
|
||||
_ = f.server.Close()
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user