Fix ACP provider endpoint routing

This commit is contained in:
Haitao Pan 2026-04-24 11:51:28 +08:00
parent 17c0fa6f16
commit cce9833689
4 changed files with 379 additions and 19 deletions

View File

@ -1,6 +1,10 @@
package acp
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"testing"
)
@ -52,13 +56,37 @@ func TestResolveSingleAgentForwardEndpointManual(t *testing.T) {
want string
}{
{
name: "preserves upstream endpoint",
name: "preserves http rpc endpoint",
provider: syncedProvider{
ProviderID: "custom",
Endpoint: "https://upstream-provider.example.com/acp/rpc",
},
want: "https://upstream-provider.example.com/acp/rpc",
},
{
name: "normalizes http acp endpoint to rpc endpoint",
provider: syncedProvider{
ProviderID: "opencode",
Endpoint: "http://127.0.0.1:39992/acp",
},
want: "http://127.0.0.1:39992/acp/rpc",
},
{
name: "normalizes websocket opencode endpoint to http rpc endpoint",
provider: syncedProvider{
ProviderID: "opencode",
Endpoint: "ws://127.0.0.1:39992/acp",
},
want: "http://127.0.0.1:39992/acp/rpc",
},
{
name: "does not duplicate nested acp path",
provider: syncedProvider{
ProviderID: "opencode",
Endpoint: "http://127.0.0.1:39992/acp",
},
want: "http://127.0.0.1:39992/acp/rpc",
},
}
for _, tc := range cases {
@ -92,6 +120,64 @@ func TestNormalizeAuthorizationHeader(t *testing.T) {
}
}
func TestCodexCompatTranslatesSessionLifecycleToThreadAndTurnRPC(t *testing.T) {
t.Parallel()
var methods []string
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
_ = r.Body.Close()
}()
var request map[string]any
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
t.Fatalf("decode request: %v", err)
}
method := stringValue(request["method"])
methods = append(methods, method)
result := map[string]any{}
switch method {
case "thread/start":
result["id"] = "codex-thread-1"
case "turn/start":
result["output"] = "pong"
default:
t.Fatalf("unexpected codex upstream method %q", method)
}
_ = json.NewEncoder(w).Encode(map[string]any{
"jsonrpc": "2.0",
"id": request["id"],
"result": result,
})
}))
defer upstream.Close()
compat := newProviderCompat(syncedProvider{
ProviderID: "codex",
Label: "Codex",
Endpoint: upstream.URL,
Enabled: true,
})
result, err := compat.StartSession(
context.Background(),
"session-1",
"thread-1",
map[string]any{
"taskPrompt": "Reply with exactly pong",
"workingDirectory": t.TempDir(),
},
nil,
)
if err != nil {
t.Fatalf("StartSession failed: %v", err)
}
if got := result["output"]; got != "pong" {
t.Fatalf("expected pong output, got %#v", result)
}
if len(methods) != 2 || methods[0] != "thread/start" || methods[1] != "turn/start" {
t.Fatalf("expected thread/start then turn/start, got %#v", methods)
}
}
func TestExternalACPNotificationCollectorExtractsNestedSessionUpdateText(t *testing.T) {
t.Parallel()

View File

@ -31,12 +31,25 @@ func resolveSingleAgentForwardEndpoint(provider syncedProvider) string {
isWS := strings.HasPrefix(parsed.Scheme, "ws")
isHTTP := strings.HasPrefix(parsed.Scheme, "http")
path := strings.TrimRight(parsed.Path, "/")
path := strings.TrimRight(parsed.EscapedPath(), "/")
if path == "" {
path = strings.TrimRight(parsed.Path, "/")
}
if isWS && !strings.Contains(path, "/acp") {
parsed.Path = path + "/acp"
} else if isHTTP && !strings.Contains(path, "/acp/rpc") {
parsed.Path = path + "/acp/rpc"
if isWS {
if path == "/acp" || strings.HasSuffix(path, "/acp") {
parsed.Path = path
} else {
parsed.Path = path + "/acp"
}
} else if isHTTP {
if path == "/acp/rpc" || strings.HasSuffix(path, "/acp/rpc") {
parsed.Path = path
} else if path == "/acp" || strings.HasSuffix(path, "/acp") {
parsed.Path = strings.TrimSuffix(path, "/acp") + "/acp/rpc"
} else {
parsed.Path = path + "/acp/rpc"
}
}
return parsed.String()
@ -61,7 +74,7 @@ type externalACPNotificationCollector struct {
func (c *externalACPNotificationCollector) observe(notification map[string]any) {
method := strings.TrimSpace(stringValue(notification["method"]))
if method != "session.update" && method != "acp.session.update" && method != "session/update" {
if method != "session.update" && method != "acp.session.update" && method != "session/update" && !strings.HasPrefix(method, "item/") && !strings.HasPrefix(method, "turn/") {
return
}
params := asMap(notification["params"])
@ -150,6 +163,9 @@ func extractExternalACPNotificationText(notification map[string]any) string {
if text := extractExternalACPTextValue(update); text != "" {
return text
}
if text := extractExternalACPTextValue(asMap(payload["item"])); text != "" {
return text
}
if text := extractExternalACPTextValue(payload); text != "" {
return text
}
@ -174,7 +190,7 @@ func extractExternalACPTextValue(value any) string {
return strings.TrimSpace(builder.String())
}
for key, child := range v {
if key == "text" || key == "message" || key == "content" || key == "delta" || key == "value" || key == "sessionId" || key == "session_id" || key == "sessionUpdate" || key == "session_update" {
if key == "text" || key == "message" || key == "content" || key == "delta" || key == "value" || key == "sessionId" || key == "session_id" || key == "sessionUpdate" || key == "session_update" || key == "threadId" || key == "turnId" || key == "itemId" {
continue
}
if text := extractExternalACPTextValue(child); text != "" {

View File

@ -9,6 +9,7 @@ import (
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/gorilla/websocket"
@ -25,7 +26,11 @@ type externalACPCompat struct {
client *http.Client
}
type codexCompat struct{ *externalACPCompat }
type codexCompat struct {
*externalACPCompat
mu sync.Mutex
threads map[string]string
}
type opencodeCompat struct{ *externalACPCompat }
type geminiCompat struct{ *externalACPCompat }
type hermesCompat struct{ *externalACPCompat }
@ -49,7 +54,10 @@ func newProviderCompat(provider syncedProvider) ProviderCompat {
case "hermes":
return &hermesCompat{externalACPCompat: base}
default:
return &codexCompat{externalACPCompat: base}
return &codexCompat{
externalACPCompat: base,
threads: make(map[string]string),
}
}
}
@ -81,6 +89,249 @@ func (c *externalACPCompat) Probe(ctx context.Context) ProviderProbeResult {
return ProviderProbeResult{Available: true, Status: "ok"}
}
func (c *codexCompat) Probe(ctx context.Context) ProviderProbeResult {
_, err := c.codexCall(ctx, "initialize", codexInitializeParams(), nil)
if err != nil {
if strings.Contains(strings.ToLower(err.Error()), "already initialized") {
return ProviderProbeResult{Available: true, Status: "ok"}
}
return ProviderProbeResult{Available: false, Status: err.Error()}
}
return ProviderProbeResult{Available: true, Status: "ok"}
}
func (c *codexCompat) StartSession(ctx context.Context, sessionID string, threadID string, params map[string]any, sink SessionNotificationSink) (map[string]any, error) {
thread, err := c.codexCall(ctx, "thread/start", codexThreadStartParams(params), nil)
if err != nil {
return nil, err
}
codexThreadID := codexThreadIDFromResult(thread)
if codexThreadID == "" {
return nil, fmt.Errorf("codex thread/start response missing thread id")
}
c.rememberThread(sessionID, threadID, codexThreadID)
return c.startTurn(ctx, codexThreadID, params, sink)
}
func (c *codexCompat) SendMessage(ctx context.Context, sessionID string, threadID string, params map[string]any, sink SessionNotificationSink) (map[string]any, error) {
codexThreadID := c.lookupThread(sessionID, threadID)
if codexThreadID == "" {
codexThreadID = strings.TrimSpace(threadID)
}
if codexThreadID != "" {
thread, err := c.codexCall(ctx, "thread/resume", map[string]any{"threadId": codexThreadID}, nil)
if err != nil {
return nil, err
}
if resolved := codexThreadIDFromResult(thread); resolved != "" {
codexThreadID = resolved
c.rememberThread(sessionID, threadID, codexThreadID)
}
}
if codexThreadID == "" {
return c.StartSession(ctx, sessionID, threadID, params, sink)
}
return c.startTurn(ctx, codexThreadID, params, sink)
}
func (c *codexCompat) CloseSession(ctx context.Context, sessionID string) error {
c.mu.Lock()
delete(c.threads, sessionID)
c.mu.Unlock()
return nil
}
func (c *codexCompat) CancelSession(ctx context.Context, sessionID string) error {
codexThreadID := c.lookupThread(sessionID, "")
if codexThreadID == "" {
return nil
}
_, err := c.codexCall(ctx, "turn/interrupt", map[string]any{"threadId": codexThreadID}, nil)
return err
}
func (c *codexCompat) startTurn(ctx context.Context, codexThreadID string, params map[string]any, sink SessionNotificationSink) (map[string]any, error) {
result, err := c.codexCall(
ctx,
"turn/start",
map[string]any{
"threadId": codexThreadID,
"input": codexUserInput(params),
},
sink,
)
if err != nil {
return nil, err
}
if _, ok := result["output"]; !ok {
if summary := strings.TrimSpace(shared.StringArg(result, "summary", "")); summary != "" {
result["output"] = summary
}
}
result["providerThreadId"] = codexThreadID
return result, nil
}
func (c *codexCompat) codexCall(ctx context.Context, method string, params map[string]any, sink SessionNotificationSink) (map[string]any, error) {
if c.transport() != "ws" {
return c.rpcCall(ctx, method, params, sink)
}
return c.callWSRPCWithInitialize(ctx, method, params, sink)
}
func (c *codexCompat) callWSRPCWithInitialize(ctx context.Context, method string, params map[string]any, sink SessionNotificationSink) (map[string]any, error) {
headers := http.Header{}
if c.authHeader != "" {
headers.Set("Authorization", c.authHeader)
}
conn, _, err := websocket.DefaultDialer.DialContext(ctx, c.endpoint, headers)
if err != nil {
return nil, err
}
defer func() { _ = conn.Close() }()
if _, err := c.writeAndReadWSRPC(ctx, conn, "initialize", codexInitializeParams(), nil); err != nil {
return nil, err
}
return c.writeAndReadWSRPC(ctx, conn, method, params, sink)
}
func (c *codexCompat) writeAndReadWSRPC(ctx context.Context, conn *websocket.Conn, method string, params map[string]any, sink SessionNotificationSink) (map[string]any, error) {
requestID := fmt.Sprintf("req-%d", time.Now().UnixNano())
request := map[string]any{
"jsonrpc": "2.0",
"id": requestID,
"method": method,
"params": params,
}
if err := conn.WriteJSON(request); err != nil {
return nil, err
}
collector := &externalACPNotificationCollector{}
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
_, payload, err := conn.ReadMessage()
if err != nil {
return nil, err
}
var decoded map[string]any
if err := json.Unmarshal(payload, &decoded); err != nil {
return nil, fmt.Errorf("failed to decode websocket rpc response: %w", err)
}
methodName := strings.TrimSpace(shared.StringArg(decoded, "method", ""))
if methodName != "" {
collector.observe(decoded)
if isExternalSessionUpdateMethod(methodName) && sink != nil {
update := shared.AsMap(decoded["params"])
if len(update) > 0 {
sink(update)
}
}
continue
}
if fmt.Sprintf("%v", decoded["id"]) != requestID {
continue
}
result, err := parseExternalRPCResult(decoded)
if err != nil {
return nil, err
}
return collector.apply(result), nil
}
}
func (c *codexCompat) rememberThread(sessionID string, threadID string, codexThreadID string) {
c.mu.Lock()
defer c.mu.Unlock()
if sessionID != "" {
c.threads[sessionID] = codexThreadID
}
if threadID != "" {
c.threads[threadID] = codexThreadID
}
}
func (c *codexCompat) lookupThread(sessionID string, threadID string) string {
c.mu.Lock()
defer c.mu.Unlock()
if sessionID != "" {
if value := strings.TrimSpace(c.threads[sessionID]); value != "" {
return value
}
}
if threadID != "" {
return strings.TrimSpace(c.threads[threadID])
}
return ""
}
func codexThreadStartParams(params map[string]any) map[string]any {
result := map[string]any{}
if cwd := strings.TrimSpace(shared.StringArg(params, "workingDirectory", "")); cwd != "" {
result["cwd"] = cwd
}
return result
}
func codexInitializeParams() map[string]any {
return map[string]any{
"clientInfo": map[string]any{
"name": "xworkmate-bridge",
"version": "1.1.0",
},
}
}
func codexUserInput(params map[string]any) []any {
input := map[string]any{
"type": "text",
"text": shared.StringArg(params, "taskPrompt", ""),
}
if attachments := anyList(params["attachments"]); len(attachments) > 0 {
input["attachments"] = attachments
}
return []any{input}
}
func codexThreadIDFromResult(result map[string]any) string {
for _, key := range []string{"threadId", "id"} {
if value := strings.TrimSpace(shared.StringArg(result, key, "")); value != "" {
return value
}
}
thread := shared.AsMap(result["thread"])
for _, key := range []string{"id", "threadId"} {
if value := strings.TrimSpace(shared.StringArg(thread, key, "")); value != "" {
return value
}
}
return ""
}
func anyList(value any) []any {
switch typed := value.(type) {
case []any:
return typed
case []map[string]any:
result := make([]any, 0, len(typed))
for _, item := range typed {
result = append(result, item)
}
return result
default:
return nil
}
}
func (c *externalACPCompat) StartSession(ctx context.Context, sessionID string, threadID string, params map[string]any, sink SessionNotificationSink) (map[string]any, error) {
return c.rpcCall(ctx, "session.start", params, sink)
}
@ -221,7 +472,7 @@ func isExternalSessionUpdateMethod(method string) bool {
case "session.update", "acp.session.update", "session/update":
return true
default:
return false
return strings.HasPrefix(method, "item/") || strings.HasPrefix(method, "turn/")
}
}

View File

@ -47,7 +47,6 @@ func (s *Server) executeSessionTask(t task) (map[string]any, *shared.RPCError) {
return s.handleRequest(t.req, t.notify)
}
func newExternalSingleAgentProvider(
t *testing.T,
providerID string,
@ -66,16 +65,24 @@ func newExternalSingleAgentProvider(
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
t.Fatalf("decode request: %v", err)
}
method := strings.TrimSpace(shared.StringArg(request, "method", ""))
result := map[string]any{
"success": true,
"output": output,
"turnId": "turn-" + providerID,
"provider": providerID,
"mode": "single-agent",
}
switch method {
case "thread/start", "thread/resume":
result = map[string]any{"id": "provider-thread-" + providerID}
case "turn/start":
result["summary"] = output
}
_ = json.NewEncoder(w).Encode(map[string]any{
"jsonrpc": "2.0",
"id": request["id"],
"result": map[string]any{
"success": true,
"output": output,
"turnId": "turn-" + providerID,
"provider": providerID,
"mode": "single-agent",
},
"result": result,
})
}))
}