252 lines
6.7 KiB
Go
252 lines
6.7 KiB
Go
package geminiadapter
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"github.com/gorilla/websocket"
|
|
|
|
"xworkmate-bridge/internal/shared"
|
|
)
|
|
|
|
type stubClient struct {
|
|
initResult initializeResult
|
|
initErr error
|
|
callResult map[string]any
|
|
callErr error
|
|
lastMethod string
|
|
lastParams map[string]any
|
|
}
|
|
|
|
func (s *stubClient) Initialize() (initializeResult, error) {
|
|
return s.initResult, s.initErr
|
|
}
|
|
|
|
func (s *stubClient) Call(method string, params map[string]any) (map[string]any, error) {
|
|
s.lastMethod = method
|
|
s.lastParams = params
|
|
return s.callResult, s.callErr
|
|
}
|
|
|
|
func (s *stubClient) Close() error {
|
|
return nil
|
|
}
|
|
|
|
func TestHandleCapabilitiesSynthesizesProviderResponse(t *testing.T) {
|
|
server := NewServer(&stubClient{
|
|
initResult: initializeResult{
|
|
ProtocolVersion: 1,
|
|
AuthMethods: []map[string]any{
|
|
{"id": "gemini-api-key"},
|
|
},
|
|
AgentCapabilities: map[string]any{
|
|
"mcpCapabilities": map[string]any{"http": true},
|
|
},
|
|
},
|
|
})
|
|
|
|
result := server.handleRequest(shared.RPCRequest{
|
|
Method: "acp.capabilities",
|
|
Params: map[string]any{},
|
|
})
|
|
if got := result["singleAgent"]; got != true {
|
|
t.Fatalf("expected singleAgent true, got %#v", result)
|
|
}
|
|
providers, _ := result["providers"].([]string)
|
|
if len(providers) != 1 || providers[0] != "gemini" {
|
|
t.Fatalf("expected gemini provider, got %#v", result)
|
|
}
|
|
}
|
|
|
|
func TestHandleRPCSessionStartReturnsUpstreamResult(t *testing.T) {
|
|
stub := &stubClient{
|
|
initResult: initializeResult{ProtocolVersion: 1},
|
|
callResult: map[string]any{
|
|
"result": map[string]any{
|
|
"success": true,
|
|
"output": "hello",
|
|
},
|
|
},
|
|
}
|
|
server := NewServer(stub)
|
|
server.upstreamMethod = "session.start"
|
|
|
|
body, _ := json.Marshal(shared.RPCRequest{
|
|
JSONRPC: "2.0",
|
|
ID: 1,
|
|
Method: "session.start",
|
|
Params: map[string]any{
|
|
"sessionId": "s1",
|
|
"taskPrompt": "hello",
|
|
},
|
|
})
|
|
request := httptest.NewRequest(http.MethodPost, "http://127.0.0.1/acp/rpc", bytes.NewReader(body))
|
|
recorder := httptest.NewRecorder()
|
|
|
|
server.HandleRPC(recorder, request)
|
|
|
|
if recorder.Code != http.StatusOK {
|
|
t.Fatalf("expected 200, got %d", recorder.Code)
|
|
}
|
|
var envelope map[string]any
|
|
if err := json.NewDecoder(recorder.Body).Decode(&envelope); err != nil {
|
|
t.Fatalf("decode response: %v", err)
|
|
}
|
|
result := envelope["result"].(map[string]any)
|
|
if got := result["output"]; got != "hello" {
|
|
t.Fatalf("expected output hello, got %#v", result)
|
|
}
|
|
if stub.lastMethod != "session.start" {
|
|
t.Fatalf("expected upstream method session.start, got %q", stub.lastMethod)
|
|
}
|
|
}
|
|
|
|
func TestHandleSessionStartFallsBackToPromptRunner(t *testing.T) {
|
|
stub := &stubClient{
|
|
initResult: initializeResult{ProtocolVersion: 1},
|
|
}
|
|
server := NewServer(stub)
|
|
server.sessionRunner = func(ctx context.Context, model, prompt, workingDirectory string) (string, error) {
|
|
if model != "gemini-2.5-pro" {
|
|
t.Fatalf("expected model gemini-2.5-pro, got %q", model)
|
|
}
|
|
if workingDirectory != "/tmp/demo" {
|
|
t.Fatalf("expected workingDirectory /tmp/demo, got %q", workingDirectory)
|
|
}
|
|
expectedPrompt := "## User Turn 1\nReply with exactly pong"
|
|
if prompt != expectedPrompt {
|
|
t.Fatalf("unexpected prompt %q", prompt)
|
|
}
|
|
return "pong", nil
|
|
}
|
|
|
|
result := server.handleRequest(shared.RPCRequest{
|
|
Method: "session.start",
|
|
Params: map[string]any{
|
|
"sessionId": "s1",
|
|
"taskPrompt": "Reply with exactly pong",
|
|
"model": "gemini-2.5-pro",
|
|
"workingDirectory": "/tmp/demo",
|
|
},
|
|
})
|
|
if got := result["output"]; got != "pong" {
|
|
t.Fatalf("expected output pong, got %#v", result)
|
|
}
|
|
if got := result["upstreamMethod"]; got != "prompt" {
|
|
t.Fatalf("expected prompt upstream method, got %#v", result)
|
|
}
|
|
}
|
|
|
|
func TestHandleSessionMessageReusesAdapterLocalHistory(t *testing.T) {
|
|
stub := &stubClient{
|
|
initResult: initializeResult{ProtocolVersion: 1},
|
|
}
|
|
server := NewServer(stub)
|
|
callCount := 0
|
|
server.sessionRunner = func(ctx context.Context, model, prompt, workingDirectory string) (string, error) {
|
|
callCount++
|
|
if callCount == 1 {
|
|
expected := "## User Turn 1\nFirst turn"
|
|
if prompt != expected {
|
|
t.Fatalf("unexpected first prompt %q", prompt)
|
|
}
|
|
return "first-reply", nil
|
|
}
|
|
expected := "## User Turn 1\nFirst turn\n\n## User Turn 2\nSecond turn"
|
|
if prompt != expected {
|
|
t.Fatalf("unexpected second prompt %q", prompt)
|
|
}
|
|
if workingDirectory != "/tmp/demo" {
|
|
t.Fatalf("expected inherited workingDirectory, got %q", workingDirectory)
|
|
}
|
|
if model != "gemini-2.5-flash" {
|
|
t.Fatalf("expected inherited model, got %q", model)
|
|
}
|
|
return "second-reply", nil
|
|
}
|
|
|
|
server.handleRequest(shared.RPCRequest{
|
|
Method: "session.start",
|
|
Params: map[string]any{
|
|
"sessionId": "s1",
|
|
"taskPrompt": "First turn",
|
|
"model": "gemini-2.5-flash",
|
|
"workingDirectory": "/tmp/demo",
|
|
},
|
|
})
|
|
result := server.handleRequest(shared.RPCRequest{
|
|
Method: "session.message",
|
|
Params: map[string]any{
|
|
"sessionId": "s1",
|
|
"taskPrompt": "Second turn",
|
|
},
|
|
})
|
|
if got := result["output"]; got != "second-reply" {
|
|
t.Fatalf("expected second reply, got %#v", result)
|
|
}
|
|
}
|
|
|
|
func TestSessionCloseDropsAdapterLocalState(t *testing.T) {
|
|
server := NewServer(&stubClient{initResult: initializeResult{ProtocolVersion: 1}})
|
|
server.sessionRunner = func(ctx context.Context, model, prompt, workingDirectory string) (string, error) {
|
|
return "ok", nil
|
|
}
|
|
server.handleRequest(shared.RPCRequest{
|
|
Method: "session.start",
|
|
Params: map[string]any{
|
|
"sessionId": "s1",
|
|
"taskPrompt": "hello",
|
|
},
|
|
})
|
|
result := server.handleRequest(shared.RPCRequest{
|
|
Method: "session.close",
|
|
Params: map[string]any{
|
|
"sessionId": "s1",
|
|
},
|
|
})
|
|
if got := result["closed"]; got != true {
|
|
t.Fatalf("expected closed true, got %#v", result)
|
|
}
|
|
}
|
|
|
|
func TestHandleWebSocketCapabilities(t *testing.T) {
|
|
server := NewServer(&stubClient{
|
|
initResult: initializeResult{ProtocolVersion: 1},
|
|
})
|
|
httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
server.HandleWebSocket(w, r)
|
|
}))
|
|
defer httpServer.Close()
|
|
|
|
wsURL := "ws" + httpServer.URL[len("http"):]
|
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
|
if err != nil {
|
|
t.Fatalf("dial websocket: %v", err)
|
|
}
|
|
defer func() {
|
|
_ = conn.Close()
|
|
}()
|
|
|
|
if err := conn.WriteJSON(shared.RPCRequest{
|
|
JSONRPC: "2.0",
|
|
ID: "cap-1",
|
|
Method: "acp.capabilities",
|
|
Params: map[string]any{},
|
|
}); err != nil {
|
|
t.Fatalf("write json: %v", err)
|
|
}
|
|
var envelope map[string]any
|
|
if err := conn.ReadJSON(&envelope); err != nil {
|
|
t.Fatalf("read json: %v", err)
|
|
}
|
|
result := envelope["result"].(map[string]any)
|
|
providers := result["providers"].([]any)
|
|
if len(providers) != 1 || providers[0] != "gemini" {
|
|
t.Fatalf("expected gemini provider over websocket, got %#v", result)
|
|
}
|
|
}
|