xworkmate-bridge/internal/geminiadapter/server_test.go
2026-04-10 16:17:32 +08:00

252 lines
6.7 KiB
Go

package geminiadapter
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gorilla/websocket"
"xworkmate-bridge/internal/shared"
)
// stubClient is the test double passed to NewServer in place of a real
// upstream client. It replays canned values and remembers the arguments of the
// most recent Call.
type stubClient struct {
	// Values handed back by Initialize.
	initResult initializeResult
	initErr    error

	// Values handed back by Call.
	callResult map[string]any
	callErr    error

	// Arguments captured by the most recent Call.
	lastMethod string
	lastParams map[string]any
}
// Initialize returns the canned initResult and initErr configured on the stub.
func (s *stubClient) Initialize() (initializeResult, error) {
	result, err := s.initResult, s.initErr
	return result, err
}
// Call records method and params on the stub so tests can inspect them, then
// returns the canned callResult and callErr.
func (s *stubClient) Call(method string, params map[string]any) (map[string]any, error) {
	s.lastMethod, s.lastParams = method, params
	return s.callResult, s.callErr
}
// Close does nothing and always reports success.
func (s *stubClient) Close() error {
	return nil
}
func TestHandleCapabilitiesSynthesizesProviderResponse(t *testing.T) {
server := NewServer(&stubClient{
initResult: initializeResult{
ProtocolVersion: 1,
AuthMethods: []map[string]any{
{"id": "gemini-api-key"},
},
AgentCapabilities: map[string]any{
"mcpCapabilities": map[string]any{"http": true},
},
},
})
result := server.handleRequest(shared.RPCRequest{
Method: "acp.capabilities",
Params: map[string]any{},
})
if got := result["singleAgent"]; got != true {
t.Fatalf("expected singleAgent true, got %#v", result)
}
providers, _ := result["providers"].([]string)
if len(providers) != 1 || providers[0] != "gemini" {
t.Fatalf("expected gemini provider, got %#v", result)
}
}
// TestHandleRPCSessionStartReturnsUpstreamResult posts a session.start JSON-RPC
// request to HandleRPC on a server whose upstreamMethod is "session.start".
// It expects HTTP 200, a response envelope whose result carries output
// "hello", and the stub client to have been called with method
// "session.start".
func TestHandleRPCSessionStartReturnsUpstreamResult(t *testing.T) {
	stub := &stubClient{
		initResult: initializeResult{ProtocolVersion: 1},
		callResult: map[string]any{
			"result": map[string]any{
				"success": true,
				"output":  "hello",
			},
		},
	}
	server := NewServer(stub)
	server.upstreamMethod = "session.start"

	body, err := json.Marshal(shared.RPCRequest{
		JSONRPC: "2.0",
		ID:      1,
		Method:  "session.start",
		Params: map[string]any{
			"sessionId":  "s1",
			"taskPrompt": "hello",
		},
	})
	if err != nil {
		t.Fatalf("marshal request: %v", err)
	}

	request := httptest.NewRequest(http.MethodPost, "http://127.0.0.1/acp/rpc", bytes.NewReader(body))
	recorder := httptest.NewRecorder()
	server.HandleRPC(recorder, request)

	if recorder.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d", recorder.Code)
	}

	var envelope map[string]any
	if err := json.NewDecoder(recorder.Body).Decode(&envelope); err != nil {
		t.Fatalf("decode response: %v", err)
	}

	// Use a checked assertion so an envelope without a result object fails
	// the test with the envelope contents instead of panicking.
	result, ok := envelope["result"].(map[string]any)
	if !ok {
		t.Fatalf("expected result object, got %#v", envelope)
	}
	if got := result["output"]; got != "hello" {
		t.Fatalf("expected output hello, got %#v", result)
	}
	if stub.lastMethod != "session.start" {
		t.Fatalf("expected upstream method session.start, got %q", stub.lastMethod)
	}
}
func TestHandleSessionStartFallsBackToPromptRunner(t *testing.T) {
stub := &stubClient{
initResult: initializeResult{ProtocolVersion: 1},
}
server := NewServer(stub)
server.sessionRunner = func(ctx context.Context, model, prompt, workingDirectory string) (string, error) {
if model != "gemini-2.5-pro" {
t.Fatalf("expected model gemini-2.5-pro, got %q", model)
}
if workingDirectory != "/tmp/demo" {
t.Fatalf("expected workingDirectory /tmp/demo, got %q", workingDirectory)
}
expectedPrompt := "## User Turn 1\nReply with exactly pong"
if prompt != expectedPrompt {
t.Fatalf("unexpected prompt %q", prompt)
}
return "pong", nil
}
result := server.handleRequest(shared.RPCRequest{
Method: "session.start",
Params: map[string]any{
"sessionId": "s1",
"taskPrompt": "Reply with exactly pong",
"model": "gemini-2.5-pro",
"workingDirectory": "/tmp/demo",
},
})
if got := result["output"]; got != "pong" {
t.Fatalf("expected output pong, got %#v", result)
}
if got := result["upstreamMethod"]; got != "prompt" {
t.Fatalf("expected prompt upstream method, got %#v", result)
}
}
func TestHandleSessionMessageReusesAdapterLocalHistory(t *testing.T) {
stub := &stubClient{
initResult: initializeResult{ProtocolVersion: 1},
}
server := NewServer(stub)
callCount := 0
server.sessionRunner = func(ctx context.Context, model, prompt, workingDirectory string) (string, error) {
callCount++
if callCount == 1 {
expected := "## User Turn 1\nFirst turn"
if prompt != expected {
t.Fatalf("unexpected first prompt %q", prompt)
}
return "first-reply", nil
}
expected := "## User Turn 1\nFirst turn\n\n## User Turn 2\nSecond turn"
if prompt != expected {
t.Fatalf("unexpected second prompt %q", prompt)
}
if workingDirectory != "/tmp/demo" {
t.Fatalf("expected inherited workingDirectory, got %q", workingDirectory)
}
if model != "gemini-2.5-flash" {
t.Fatalf("expected inherited model, got %q", model)
}
return "second-reply", nil
}
server.handleRequest(shared.RPCRequest{
Method: "session.start",
Params: map[string]any{
"sessionId": "s1",
"taskPrompt": "First turn",
"model": "gemini-2.5-flash",
"workingDirectory": "/tmp/demo",
},
})
result := server.handleRequest(shared.RPCRequest{
Method: "session.message",
Params: map[string]any{
"sessionId": "s1",
"taskPrompt": "Second turn",
},
})
if got := result["output"]; got != "second-reply" {
t.Fatalf("expected second reply, got %#v", result)
}
}
func TestSessionCloseDropsAdapterLocalState(t *testing.T) {
server := NewServer(&stubClient{initResult: initializeResult{ProtocolVersion: 1}})
server.sessionRunner = func(ctx context.Context, model, prompt, workingDirectory string) (string, error) {
return "ok", nil
}
server.handleRequest(shared.RPCRequest{
Method: "session.start",
Params: map[string]any{
"sessionId": "s1",
"taskPrompt": "hello",
},
})
result := server.handleRequest(shared.RPCRequest{
Method: "session.close",
Params: map[string]any{
"sessionId": "s1",
},
})
if got := result["closed"]; got != true {
t.Fatalf("expected closed true, got %#v", result)
}
}
// TestHandleWebSocketCapabilities serves HandleWebSocket from an httptest
// server, dials it with a websocket client, sends an acp.capabilities request
// and expects the reply envelope's result to list "gemini" as the only
// provider.
func TestHandleWebSocketCapabilities(t *testing.T) {
	server := NewServer(&stubClient{
		initResult: initializeResult{ProtocolVersion: 1},
	})
	httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		server.HandleWebSocket(w, r)
	}))
	defer httpServer.Close()

	// Replace the leading "http" of the test server URL with "ws".
	wsURL := "ws" + httpServer.URL[len("http"):]
	conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
	if err != nil {
		t.Fatalf("dial websocket: %v", err)
	}
	defer func() {
		// Best-effort cleanup; a close error does not affect the assertions.
		_ = conn.Close()
	}()

	if err := conn.WriteJSON(shared.RPCRequest{
		JSONRPC: "2.0",
		ID:      "cap-1",
		Method:  "acp.capabilities",
		Params:  map[string]any{},
	}); err != nil {
		t.Fatalf("write json: %v", err)
	}

	var envelope map[string]any
	if err := conn.ReadJSON(&envelope); err != nil {
		t.Fatalf("read json: %v", err)
	}

	// Checked assertions: an envelope without a result object, or a result
	// without a providers list, fails the test with the payload instead of
	// panicking.
	result, ok := envelope["result"].(map[string]any)
	if !ok {
		t.Fatalf("expected result object over websocket, got %#v", envelope)
	}
	// After the JSON round trip the providers arrive as []any, not []string.
	providers, _ := result["providers"].([]any)
	if len(providers) != 1 || providers[0] != "gemini" {
		t.Fatalf("expected gemini provider over websocket, got %#v", result)
	}
}