xworkmate-bridge/internal/gatewayruntime/runtime_test.go
Haitao Pan d7cf863fd5 fix(test): add appThreadKey to validate-openclaw-session.sh to pass plugin validation
Since the OpenClaw plugins now enforce appThreadKey to prevent disconnected task maps, the smoke test must supply this key as well.
2026-06-06 06:09:28 +08:00

601 lines
16 KiB
Go

package gatewayruntime
import (
"encoding/json"
"fmt"
"net"
"net/http"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/gorilla/websocket"
)
// TestManagerConnectAndRequest exercises the happy path against the fake
// gateway: Connect succeeds and surfaces the device token issued by the
// server, a follow-up "health" request returns {"status": "ok"}, and at least
// one notification is delivered along the way.
func TestManagerConnectAndRequest(t *testing.T) {
	server := newFakeGatewayServer(t)
	defer server.Close()

	manager := NewManager()
	manager.ReconnectDelay = 20 * time.Millisecond

	// The manager may invoke the callback from its own goroutines, so every
	// access to the collected notifications goes through a mutex.
	var receivedMu sync.Mutex
	received := make([]map[string]any, 0, 8)
	collect := func(message map[string]any) {
		receivedMu.Lock()
		received = append(received, message)
		receivedMu.Unlock()
	}

	connectResult := manager.Connect(buildTestConnectRequest(server.Port()), collect)
	if !connectResult.OK {
		t.Fatalf("expected connect success, got %#v", connectResult.Error)
	}
	if token := connectResult.ReturnedDeviceToken; token != "device-token-1" {
		t.Fatalf("expected returned device token, got %#v", token)
	}

	health := manager.Request("runtime-1", "health", map[string]any{}, 2*time.Second, collect)
	if !health.OK {
		t.Fatalf("expected health success, got %#v", health.Error)
	}
	if payload, isMap := health.Payload.(map[string]any); !isMap || payload["status"] != "ok" {
		t.Fatalf("unexpected health payload %#v", health.Payload)
	}

	receivedMu.Lock()
	notificationCount := len(received)
	receivedMu.Unlock()
	if notificationCount == 0 {
		t.Fatalf("expected notifications during connect")
	}
}
// TestManagerConnectAdvertisesCurrentOpenClawProtocol checks that the connect
// request sent by the manager pins both minProtocol and maxProtocol to
// defaultProtocolVersion. The fake rejects any other advertised range, and
// the recorded params are inspected directly as well.
func TestManagerConnectAdvertisesCurrentOpenClawProtocol(t *testing.T) {
	server := newFakeGatewayServer(t)
	server.expectedProtocol = defaultProtocolVersion
	defer server.Close()

	result := NewManager().Connect(buildTestConnectRequest(server.Port()), func(map[string]any) {})
	if !result.OK {
		t.Fatalf("expected connect success, got %#v", result.Error)
	}

	// JSON numbers decode as float64, so compare against that representation.
	want := float64(defaultProtocolVersion)
	params := server.LastConnectParams()
	checks := []struct{ key, format string }{
		{key: "minProtocol", format: "expected minProtocol %d, got %#v"},
		{key: "maxProtocol", format: "expected maxProtocol %d, got %#v"},
	}
	for _, check := range checks {
		if params[check.key] != want {
			t.Fatalf(check.format, defaultProtocolVersion, params[check.key])
		}
	}
}
// TestGatewayFakeRejectsProtocol3AndAcceptsCurrentProtocol pins the fake
// gateway's protocol gate. The manager, which advertises
// defaultProtocolVersion, connects successfully. A raw websocket client that
// answers the challenge with a min/max protocol of 3 receives an ok=false
// response whose error message is "protocol mismatch".
func TestGatewayFakeRejectsProtocol3AndAcceptsCurrentProtocol(t *testing.T) {
	server := newFakeGatewayServer(t)
	server.expectedProtocol = defaultProtocolVersion
	defer server.Close()

	manager := NewManager()
	if current := manager.Connect(buildTestConnectRequest(server.Port()), func(map[string]any) {}); !current.OK {
		t.Fatalf("expected current bridge protocol to connect, got %#v", current.Error)
	}

	wsURL := fmt.Sprintf("ws://127.0.0.1:%d", server.Port())
	legacyConn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
	if err != nil {
		t.Fatalf("dial fake gateway: %v", err)
	}
	defer func() {
		_ = legacyConn.Close()
	}()

	// The fake always opens a socket with a connect.challenge event.
	var challenge map[string]any
	if err := legacyConn.ReadJSON(&challenge); err != nil {
		t.Fatalf("read challenge: %v", err)
	}
	if challenge["event"] != "connect.challenge" {
		t.Fatalf("expected connect challenge, got %#v", challenge)
	}

	legacyConnect := map[string]any{
		"type":   "req",
		"id":     "legacy-connect",
		"method": "connect",
		"params": map[string]any{
			"minProtocol": float64(3),
			"maxProtocol": float64(3),
		},
	}
	if err := legacyConn.WriteJSON(legacyConnect); err != nil {
		t.Fatalf("write legacy connect: %v", err)
	}

	var response map[string]any
	if err := legacyConn.ReadJSON(&response); err != nil {
		t.Fatalf("read legacy connect response: %v", err)
	}
	if response["ok"] != false {
		t.Fatalf("expected protocol 3 rejection, got %#v", response)
	}
	if message := stringValue(asMap(response["error"])["message"]); message != "protocol mismatch" {
		t.Fatalf("expected protocol mismatch error, got %#v", response)
	}
}
// TestManagerConnectPreservesEndpointPath verifies that a non-root endpoint
// path is kept when dialing. The fake answers 404 for any other path, so a
// dropped path makes Connect fail.
func TestManagerConnectPreservesEndpointPath(t *testing.T) {
	const scopedPath = "/gateway/openclaw"

	server := newFakeGatewayServer(t)
	server.expectedPath = scopedPath
	defer server.Close()

	manager := NewManager()
	request := buildTestConnectRequest(server.Port())
	request.Endpoint.Path = scopedPath

	if result := manager.Connect(request, func(map[string]any) {}); !result.OK {
		t.Fatalf("expected connect success through path-scoped endpoint, got %#v", result.Error)
	}
}
// TestManagerReconnectsAfterSocketClose makes the fake drop the socket shortly
// after the first successful connect. It then waits, for up to 3 seconds, for
// a "connected" gateway snapshot emitted once the server has seen a second
// connect request.
func TestManagerReconnectsAfterSocketClose(t *testing.T) {
	server := newFakeGatewayServer(t)
	server.closeAfterConnect.Store(true)
	defer server.Close()

	manager := NewManager()
	manager.ReconnectDelay = 25 * time.Millisecond

	// One-slot buffer plus a non-blocking send: repeated snapshots never block
	// the notify callback, and a single signal is enough for the test.
	reconnected := make(chan struct{}, 1)
	watchSnapshots := func(message map[string]any) {
		params := asMap(message["params"])
		if strings.TrimSpace(stringValue(message["method"])) != "xworkmate.gateway.snapshot" {
			return
		}
		snapshot := asMap(params["snapshot"])
		if snapshot["status"] != "connected" || server.ConnectCount() < 2 {
			return
		}
		select {
		case reconnected <- struct{}{}:
		default:
		}
	}

	if result := manager.Connect(buildTestConnectRequest(server.Port()), watchSnapshots); !result.OK {
		t.Fatalf("expected connect success, got %#v", result.Error)
	}

	deadline := time.After(3 * time.Second)
	select {
	case <-reconnected:
	case <-deadline:
		t.Fatalf("expected reconnect to complete; connect count=%d", server.ConnectCount())
	}
}
// TestManagerSerializesConcurrentConnectReuseBeforeRequests fires several
// Connect+Request pairs at one manager simultaneously. The fake gateway
// answers INVALID_HANDSHAKE to any non-connect request that arrives on a
// socket before a successful connect.
//
// The test expects three things:
//   - every worker succeeds;
//   - the server never saw an invalid handshake;
//   - exactly one connect handshake happened, because the concurrent Connect
//     calls must share a single gateway session.
func TestManagerSerializesConcurrentConnectReuseBeforeRequests(t *testing.T) {
	server := newFakeGatewayServer(t)
	server.enforceConnectFirst.Store(true)
	defer server.Close()
	manager := NewManager()
	manager.ReconnectDelay = 20 * time.Millisecond
	const workers = 5
	// Closed once every worker is parked, to maximise overlap between them.
	start := make(chan struct{})
	// Each worker sends at most one message, so this buffer never blocks.
	errs := make(chan string, workers)
	var wg sync.WaitGroup
	for index := 0; index < workers; index++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			<-start
			connectResult := manager.Connect(
				buildTestConnectRequest(server.Port()),
				func(map[string]any) {},
			)
			if !connectResult.OK {
				errs <- "connect failed: " + stringValue(connectResult.Error["message"])
				return
			}
			requestResult := manager.Request(
				"runtime-1",
				"chat.send",
				map[string]any{"message": "pong"},
				2*time.Second,
				func(map[string]any) {},
			)
			if !requestResult.OK {
				errs <- "request failed: " + stringValue(requestResult.Error["message"])
				return
			}
		}()
	}
	close(start)
	wg.Wait()
	close(errs)
	// Report every failing worker rather than stopping at the first one, so a
	// single run shows whether failures are isolated or systematic. The
	// server-side counters below only make sense if all workers succeeded.
	failures := 0
	for err := range errs {
		failures++
		t.Error(err)
	}
	if failures > 0 {
		t.FailNow()
	}
	if got := server.InvalidHandshakeCount(); got != 0 {
		t.Fatalf("expected no invalid handshake, got %d", got)
	}
	if got := server.ConnectCount(); got != 1 {
		t.Fatalf("expected concurrent connect calls to reuse one established gateway session, got %d", got)
	}
}
// TestManagerDropsConnectionAfterInvalidHandshake checks the recovery path
// after an INVALID_HANDSHAKE response:
//   - the failing request surfaces the INVALID_HANDSHAKE code;
//   - a subsequent Connect performs a fresh handshake (the server counts two
//     connects in total);
//   - requests succeed again afterwards.
func TestManagerDropsConnectionAfterInvalidHandshake(t *testing.T) {
	server := newFakeGatewayServer(t)
	defer server.Close()

	manager := NewManager()
	ignore := func(map[string]any) {}

	first := manager.Connect(buildTestConnectRequest(server.Port()), ignore)
	if !first.OK {
		t.Fatalf("expected connect success, got %#v", first.Error)
	}

	// Arm the fake so the next (chat.send) request is rejected with
	// INVALID_HANDSHAKE.
	server.invalidNextRequest.Store(true)
	rejected := manager.Request("runtime-1", "chat.send", map[string]any{"message": "pong"}, 2*time.Second, ignore)
	if rejected.OK {
		t.Fatalf("expected invalid handshake failure")
	}
	if code := stringValue(rejected.Error["code"]); code != "INVALID_HANDSHAKE" {
		t.Fatalf("expected invalid handshake code, got %#v", rejected.Error)
	}

	second := manager.Connect(buildTestConnectRequest(server.Port()), ignore)
	if !second.OK {
		t.Fatalf("expected reconnect success, got %#v", second.Error)
	}

	retried := manager.Request("runtime-1", "chat.send", map[string]any{"message": "pong"}, 2*time.Second, ignore)
	if !retried.OK {
		t.Fatalf("expected request after reconnect to succeed, got %#v", retried.Error)
	}

	if connects := server.ConnectCount(); connects != 2 {
		t.Fatalf("expected reconnect after invalid handshake, got %d connects", connects)
	}
}
// TestManagerSuppressesReconnectForPairingRequired makes the fake reject every
// connect with code NOT_PAIRED / detail PAIRING_REQUIRED. It asserts that the
// manager does not retry on its own: after several ReconnectDelay periods the
// server has still seen exactly one connect attempt.
func TestManagerSuppressesReconnectForPairingRequired(t *testing.T) {
	server := newFakeGatewayServer(t)
	server.connectErrorCode = "NOT_PAIRED"
	server.connectErrorDetailCode = "PAIRING_REQUIRED"
	defer server.Close()
	manager := NewManager()
	manager.ReconnectDelay = 20 * time.Millisecond
	result := manager.Connect(buildTestConnectRequest(server.Port()), func(map[string]any) {})
	if result.OK {
		t.Fatalf("expected connect failure")
	}
	// 120ms spans six ReconnectDelay periods: plenty of time for an
	// unsuppressed retry loop to hit the server again.
	time.Sleep(120 * time.Millisecond)
	// Read the counter once so the failure message reports the value that was
	// actually checked.
	if attempts := server.ConnectCount(); attempts != 1 {
		t.Fatalf("expected reconnect suppression, got %d connect attempts", attempts)
	}
}
func TestSessionEmitsNormalizedChatRunPushEvents(t *testing.T) {
manager := NewManager()
session := newSession(manager, "runtime-1")
notifications := make([]map[string]any, 0, 8)
session.setNotify(func(message map[string]any) {
notifications = append(notifications, message)
})
session.handleEvent(
"chat",
map[string]any{"seq": 7},
map[string]any{
"runId": "run-1",
"sessionKey": "main",
"state": "final",
"message": map[string]any{
"role": "assistant",
"content": []any{
map[string]any{"type": "text", "text": "XWORKMATE_OK"},
},
},
},
)
session.handleEvent(
"agent",
map[string]any{"seq": 8},
map[string]any{
"runId": "run-1",
"stream": "assistant",
"data": map[string]any{
"text": "DELTA_TEXT",
},
},
)
normalized := make([]map[string]any, 0, 2)
for _, notification := range notifications {
if strings.TrimSpace(stringValue(notification["method"])) != "xworkmate.gateway.push" {
continue
}
params := asMap(notification["params"])
event := asMap(params["event"])
if strings.TrimSpace(stringValue(event["event"])) != "chat.run" {
continue
}
normalized = append(normalized, asMap(event["payload"]))
}
if len(normalized) != 2 {
t.Fatalf("expected 2 normalized chat.run notifications, got %#v", normalized)
}
if normalized[0]["runId"] != "run-1" || normalized[0]["state"] != "final" {
t.Fatalf("unexpected normalized chat payload %#v", normalized[0])
}
if normalized[0]["assistantText"] != "XWORKMATE_OK" {
t.Fatalf("expected final assistant text, got %#v", normalized[0])
}
if normalized[0]["terminal"] != true {
t.Fatalf("expected terminal final chat.run, got %#v", normalized[0])
}
if normalized[1]["assistantText"] != "DELTA_TEXT" || normalized[1]["state"] != "delta" {
t.Fatalf("unexpected normalized agent payload %#v", normalized[1])
}
}
// fakeGatewayServer is an in-process websocket stand-in for the gateway used
// by the tests in this file. Counters and toggles are atomics because the
// connection handler goroutines read and update them while tests run.
type fakeGatewayServer struct {
	// Transport: the HTTP server and the loopback listener it serves on.
	server   *http.Server
	listener net.Listener

	// Observations recorded by the connection handler.
	connectCount          atomic.Int32 // "connect" requests received, including rejected ones
	invalidHandshakeCount atomic.Int32 // INVALID_HANDSHAKE responses sent
	lastConnectParams     atomic.Value // params map of the most recent "connect" request

	// Behaviour toggles that tests may flip while connections are live.
	closeAfterConnect   atomic.Bool // close the socket shortly after the first successful connect
	enforceConnectFirst atomic.Bool // reject non-connect requests on a socket before its connect
	invalidNextRequest  atomic.Bool // one-shot: reject the next request with INVALID_HANDSHAKE

	// Plain configuration fields, read by handler goroutines without
	// synchronization; set them before the first connection is made.
	expectedPath           string // when non-empty, other request paths get a 404
	expectedProtocol       int    // when > 0, connect must advertise exactly this min/max protocol
	connectErrorCode       string // when non-empty, every connect is rejected with this code
	connectErrorDetailCode string // error.details.code accompanying connectErrorCode
}
// newFakeGatewayServer starts an in-process websocket server on an ephemeral
// 127.0.0.1 port. It emulates the parts of the gateway protocol these tests
// exercise.
//
// Handshake:
//   - every upgraded socket first receives a "connect.challenge" event;
//   - when expectedPath is set, requests for any other path get a 404.
//
// "connect" requests are always counted and their params recorded. They are
// answered with, in order of precedence:
//   - a protocol-mismatch error, when expectedProtocol is set and the client
//     did not advertise exactly that min/max;
//   - the configured connectErrorCode error, when one is set;
//   - otherwise, a success payload carrying "device-token-1".
//
// Other methods:
//   - "health" answers {"status": "ok"};
//   - any other method answers an empty success payload.
//
// Failure injection (INVALID_HANDSHAKE errors):
//   - enforceConnectFirst rejects any non-connect request that arrives before
//     a successful connect on the same socket;
//   - invalidNextRequest rejects the next non-connect request once.
//
// Callers must Close the returned server.
func newFakeGatewayServer(t *testing.T) *fakeGatewayServer {
	t.Helper()
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("listen: %v", err)
	}
	fake := &fakeGatewayServer{listener: listener}
	upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
	mux := http.NewServeMux()
	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		if fake.expectedPath != "" && r.URL.Path != fake.expectedPath {
			http.NotFound(w, r)
			return
		}
		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			return
		}
		defer func() {
			_ = conn.Close()
		}()

		// Write errors are ignored throughout: a failed write means the peer
		// went away, and the next ReadMessage fails and ends this handler.
		writeOK := func(id any, payload map[string]any) {
			_ = conn.WriteJSON(map[string]any{
				"type":    "res",
				"id":      id,
				"ok":      true,
				"payload": payload,
			})
		}
		writeError := func(id any, errorPayload map[string]any) {
			_ = conn.WriteJSON(map[string]any{
				"type":  "res",
				"id":    id,
				"ok":    false,
				"error": errorPayload,
			})
		}
		writeInvalidHandshake := func(id any) {
			fake.invalidHandshakeCount.Add(1)
			writeError(id, map[string]any{
				"code":    "INVALID_HANDSHAKE",
				"message": "invalid handshake: first request must be connect",
			})
		}

		_ = conn.WriteJSON(map[string]any{
			"type":  "event",
			"event": "connect.challenge",
			"payload": map[string]any{
				"nonce": "nonce-1",
			},
		})

		// connected becomes true only after a connect is accepted. A rejected
		// connect does not complete the handshake for enforceConnectFirst.
		connected := false
		for {
			_, payload, err := conn.ReadMessage()
			if err != nil {
				return
			}
			var frame map[string]any
			if err := json.Unmarshal(payload, &frame); err != nil {
				continue
			}
			if frame["type"] != "req" {
				continue
			}
			id := frame["id"]
			method := stringValue(frame["method"])
			if fake.enforceConnectFirst.Load() && !connected && method != "connect" {
				writeInvalidHandshake(id)
				continue
			}
			// Check the method before consuming the one-shot flag, so a
			// connect frame cannot silently swallow the injected failure.
			if method != "connect" && fake.invalidNextRequest.Swap(false) {
				writeInvalidHandshake(id)
				continue
			}
			switch method {
			case "connect":
				fake.connectCount.Add(1)
				params := asMap(frame["params"])
				fake.lastConnectParams.Store(params)
				// JSON numbers decode as float64, hence the conversions.
				if fake.expectedProtocol > 0 &&
					(params["minProtocol"] != float64(fake.expectedProtocol) ||
						params["maxProtocol"] != float64(fake.expectedProtocol)) {
					writeError(id, map[string]any{
						"code":    "INVALID_REQUEST",
						"message": "protocol mismatch",
						"details": map[string]any{
							"code":                 "PROTOCOL_MISMATCH",
							"clientMinProtocol":    params["minProtocol"],
							"clientMaxProtocol":    params["maxProtocol"],
							"expectedProtocol":     fake.expectedProtocol,
							"minimumProbeProtocol": fake.expectedProtocol,
						},
					})
					continue
				}
				if fake.connectErrorCode != "" {
					writeError(id, map[string]any{
						"code":    fake.connectErrorCode,
						"message": "connect failed",
						"details": map[string]any{
							"code": fake.connectErrorDetailCode,
						},
					})
					continue
				}
				connected = true
				writeOK(id, map[string]any{
					"server": map[string]any{"host": "127.0.0.1"},
					"snapshot": map[string]any{
						"sessionDefaults": map[string]any{"mainSessionKey": "main"},
					},
					"auth": map[string]any{
						"role":        "operator",
						"scopes":      defaultOperatorScopes,
						"deviceToken": "device-token-1",
					},
				})
				// Simulate a dropped socket once, right after the first
				// successful handshake, so reconnect logic can be tested.
				if fake.closeAfterConnect.Load() && fake.connectCount.Load() == 1 {
					go func() {
						time.Sleep(20 * time.Millisecond)
						_ = conn.Close()
					}()
				}
			case "health":
				writeOK(id, map[string]any{
					"status": "ok",
				})
			default:
				writeOK(id, map[string]any{})
			}
		}
	})
	fake.server = &http.Server{Handler: mux}
	go func() {
		// Serve returns once Close is called; its error is not interesting here.
		_ = fake.server.Serve(listener)
	}()
	return fake
}
// Port reports the TCP port the fake gateway's listener is bound to; it was
// chosen by the OS when listening on 127.0.0.1:0.
func (f *fakeGatewayServer) Port() int {
	addr := f.listener.Addr().(*net.TCPAddr)
	return addr.Port
}
// LastConnectParams returns the params map of the most recent "connect"
// request the fake received. It returns an empty map if no connect has
// arrived yet.
func (f *fakeGatewayServer) LastConnectParams() map[string]any {
	switch stored := f.lastConnectParams.Load().(type) {
	case nil:
		return map[string]any{}
	case map[string]any:
		return stored
	default:
		return nil
	}
}
// ConnectCount reports how many "connect" requests the fake has received
// across all sockets, including ones it rejected.
func (f *fakeGatewayServer) ConnectCount() int {
	count := f.connectCount.Load()
	return int(count)
}
// InvalidHandshakeCount reports how many INVALID_HANDSHAKE responses the fake
// has sent across all sockets.
func (f *fakeGatewayServer) InvalidHandshakeCount() int {
	count := f.invalidHandshakeCount.Load()
	return int(count)
}
// Close shuts down the fake's HTTP server and listener. Per net/http,
// http.Server.Close does not close hijacked connections, so websocket
// connections that are already upgraded are not torn down by this call.
func (f *fakeGatewayServer) Close() {
	// Best-effort test teardown; there is nothing useful to do with an error.
	err := f.server.Close()
	_ = err
}
// buildTestConnectRequest returns a fully populated ConnectRequest for
// runtime "runtime-1". It targets the plain-text (non-TLS) fake gateway at
// 127.0.0.1:port with shared-token auth and a fixed test device identity.
func buildTestConnectRequest(port int) ConnectRequest {
	endpoint := Endpoint{
		Host: "127.0.0.1",
		Port: port,
		TLS:  false,
	}
	packageInfo := PackageInfo{
		AppName: "XWorkmate",
		Version: "1.0.0",
	}
	deviceInfo := DeviceInfo{
		Platform:        "macos",
		PlatformVersion: "14.0",
		DeviceFamily:    "Mac",
		ModelIdentifier: "Mac14,5",
	}
	// Static key pair used only by tests.
	identity := DeviceIdentity{
		DeviceID:            "device-1",
		PublicKeyBase64URL:  "tl4fnKW7VLD0Cl4lQTu2CEgHPs4PWAX7eVgWfWQWk2Q",
		PrivateKeyBase64URL: "dr7GfMKoO-lJBtgA0dE5m6f_X4kEFsxChDc7mW8mkXu2Xh-cpbsUsPQKXiVBO7YISAc-zg9YBft5WBZ9ZBaTZA",
	}

	return ConnectRequest{
		RuntimeID:          "runtime-1",
		Mode:               "openclaw",
		ClientID:           "openclaw-macos",
		Locale:             "en_US",
		UserAgent:          "XWorkmate/1.0.0",
		Endpoint:           endpoint,
		ConnectAuthMode:    "shared-token",
		ConnectAuthFields:  []string{"token"},
		ConnectAuthSources: []string{"shared:form"},
		HasSharedAuth:      true,
		HasDeviceToken:     false,
		PackageInfo:        packageInfo,
		DeviceInfo:         deviceInfo,
		Identity:           identity,
		Auth: AuthConfig{
			Token: "shared-token",
		},
	}
}