Fix ACP provider endpoint routing
This commit is contained in:
parent
17c0fa6f16
commit
cce9833689
@ -1,6 +1,10 @@
|
||||
package acp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
@ -52,13 +56,37 @@ func TestResolveSingleAgentForwardEndpointManual(t *testing.T) {
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "preserves upstream endpoint",
|
||||
name: "preserves http rpc endpoint",
|
||||
provider: syncedProvider{
|
||||
ProviderID: "custom",
|
||||
Endpoint: "https://upstream-provider.example.com/acp/rpc",
|
||||
},
|
||||
want: "https://upstream-provider.example.com/acp/rpc",
|
||||
},
|
||||
{
|
||||
name: "normalizes http acp endpoint to rpc endpoint",
|
||||
provider: syncedProvider{
|
||||
ProviderID: "opencode",
|
||||
Endpoint: "http://127.0.0.1:39992/acp",
|
||||
},
|
||||
want: "http://127.0.0.1:39992/acp/rpc",
|
||||
},
|
||||
{
|
||||
name: "normalizes websocket opencode endpoint to http rpc endpoint",
|
||||
provider: syncedProvider{
|
||||
ProviderID: "opencode",
|
||||
Endpoint: "ws://127.0.0.1:39992/acp",
|
||||
},
|
||||
want: "http://127.0.0.1:39992/acp/rpc",
|
||||
},
|
||||
{
|
||||
name: "does not duplicate nested acp path",
|
||||
provider: syncedProvider{
|
||||
ProviderID: "opencode",
|
||||
Endpoint: "http://127.0.0.1:39992/acp",
|
||||
},
|
||||
want: "http://127.0.0.1:39992/acp/rpc",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
@ -92,6 +120,64 @@ func TestNormalizeAuthorizationHeader(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexCompatTranslatesSessionLifecycleToThreadAndTurnRPC(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var methods []string
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
_ = r.Body.Close()
|
||||
}()
|
||||
var request map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
|
||||
t.Fatalf("decode request: %v", err)
|
||||
}
|
||||
method := stringValue(request["method"])
|
||||
methods = append(methods, method)
|
||||
result := map[string]any{}
|
||||
switch method {
|
||||
case "thread/start":
|
||||
result["id"] = "codex-thread-1"
|
||||
case "turn/start":
|
||||
result["output"] = "pong"
|
||||
default:
|
||||
t.Fatalf("unexpected codex upstream method %q", method)
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": request["id"],
|
||||
"result": result,
|
||||
})
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
compat := newProviderCompat(syncedProvider{
|
||||
ProviderID: "codex",
|
||||
Label: "Codex",
|
||||
Endpoint: upstream.URL,
|
||||
Enabled: true,
|
||||
})
|
||||
result, err := compat.StartSession(
|
||||
context.Background(),
|
||||
"session-1",
|
||||
"thread-1",
|
||||
map[string]any{
|
||||
"taskPrompt": "Reply with exactly pong",
|
||||
"workingDirectory": t.TempDir(),
|
||||
},
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("StartSession failed: %v", err)
|
||||
}
|
||||
if got := result["output"]; got != "pong" {
|
||||
t.Fatalf("expected pong output, got %#v", result)
|
||||
}
|
||||
if len(methods) != 2 || methods[0] != "thread/start" || methods[1] != "turn/start" {
|
||||
t.Fatalf("expected thread/start then turn/start, got %#v", methods)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalACPNotificationCollectorExtractsNestedSessionUpdateText(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@ -31,12 +31,25 @@ func resolveSingleAgentForwardEndpoint(provider syncedProvider) string {
|
||||
isWS := strings.HasPrefix(parsed.Scheme, "ws")
|
||||
isHTTP := strings.HasPrefix(parsed.Scheme, "http")
|
||||
|
||||
path := strings.TrimRight(parsed.Path, "/")
|
||||
path := strings.TrimRight(parsed.EscapedPath(), "/")
|
||||
if path == "" {
|
||||
path = strings.TrimRight(parsed.Path, "/")
|
||||
}
|
||||
|
||||
if isWS && !strings.Contains(path, "/acp") {
|
||||
parsed.Path = path + "/acp"
|
||||
} else if isHTTP && !strings.Contains(path, "/acp/rpc") {
|
||||
parsed.Path = path + "/acp/rpc"
|
||||
if isWS {
|
||||
if path == "/acp" || strings.HasSuffix(path, "/acp") {
|
||||
parsed.Path = path
|
||||
} else {
|
||||
parsed.Path = path + "/acp"
|
||||
}
|
||||
} else if isHTTP {
|
||||
if path == "/acp/rpc" || strings.HasSuffix(path, "/acp/rpc") {
|
||||
parsed.Path = path
|
||||
} else if path == "/acp" || strings.HasSuffix(path, "/acp") {
|
||||
parsed.Path = strings.TrimSuffix(path, "/acp") + "/acp/rpc"
|
||||
} else {
|
||||
parsed.Path = path + "/acp/rpc"
|
||||
}
|
||||
}
|
||||
|
||||
return parsed.String()
|
||||
@ -61,7 +74,7 @@ type externalACPNotificationCollector struct {
|
||||
|
||||
func (c *externalACPNotificationCollector) observe(notification map[string]any) {
|
||||
method := strings.TrimSpace(stringValue(notification["method"]))
|
||||
if method != "session.update" && method != "acp.session.update" && method != "session/update" {
|
||||
if method != "session.update" && method != "acp.session.update" && method != "session/update" && !strings.HasPrefix(method, "item/") && !strings.HasPrefix(method, "turn/") {
|
||||
return
|
||||
}
|
||||
params := asMap(notification["params"])
|
||||
@ -150,6 +163,9 @@ func extractExternalACPNotificationText(notification map[string]any) string {
|
||||
if text := extractExternalACPTextValue(update); text != "" {
|
||||
return text
|
||||
}
|
||||
if text := extractExternalACPTextValue(asMap(payload["item"])); text != "" {
|
||||
return text
|
||||
}
|
||||
if text := extractExternalACPTextValue(payload); text != "" {
|
||||
return text
|
||||
}
|
||||
@ -174,7 +190,7 @@ func extractExternalACPTextValue(value any) string {
|
||||
return strings.TrimSpace(builder.String())
|
||||
}
|
||||
for key, child := range v {
|
||||
if key == "text" || key == "message" || key == "content" || key == "delta" || key == "value" || key == "sessionId" || key == "session_id" || key == "sessionUpdate" || key == "session_update" {
|
||||
if key == "text" || key == "message" || key == "content" || key == "delta" || key == "value" || key == "sessionId" || key == "session_id" || key == "sessionUpdate" || key == "session_update" || key == "threadId" || key == "turnId" || key == "itemId" {
|
||||
continue
|
||||
}
|
||||
if text := extractExternalACPTextValue(child); text != "" {
|
||||
|
||||
@ -9,6 +9,7 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
@ -25,7 +26,11 @@ type externalACPCompat struct {
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
type codexCompat struct{ *externalACPCompat }
|
||||
type codexCompat struct {
|
||||
*externalACPCompat
|
||||
mu sync.Mutex
|
||||
threads map[string]string
|
||||
}
|
||||
type opencodeCompat struct{ *externalACPCompat }
|
||||
type geminiCompat struct{ *externalACPCompat }
|
||||
type hermesCompat struct{ *externalACPCompat }
|
||||
@ -49,7 +54,10 @@ func newProviderCompat(provider syncedProvider) ProviderCompat {
|
||||
case "hermes":
|
||||
return &hermesCompat{externalACPCompat: base}
|
||||
default:
|
||||
return &codexCompat{externalACPCompat: base}
|
||||
return &codexCompat{
|
||||
externalACPCompat: base,
|
||||
threads: make(map[string]string),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -81,6 +89,249 @@ func (c *externalACPCompat) Probe(ctx context.Context) ProviderProbeResult {
|
||||
return ProviderProbeResult{Available: true, Status: "ok"}
|
||||
}
|
||||
|
||||
func (c *codexCompat) Probe(ctx context.Context) ProviderProbeResult {
|
||||
_, err := c.codexCall(ctx, "initialize", codexInitializeParams(), nil)
|
||||
if err != nil {
|
||||
if strings.Contains(strings.ToLower(err.Error()), "already initialized") {
|
||||
return ProviderProbeResult{Available: true, Status: "ok"}
|
||||
}
|
||||
return ProviderProbeResult{Available: false, Status: err.Error()}
|
||||
}
|
||||
return ProviderProbeResult{Available: true, Status: "ok"}
|
||||
}
|
||||
|
||||
func (c *codexCompat) StartSession(ctx context.Context, sessionID string, threadID string, params map[string]any, sink SessionNotificationSink) (map[string]any, error) {
|
||||
thread, err := c.codexCall(ctx, "thread/start", codexThreadStartParams(params), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
codexThreadID := codexThreadIDFromResult(thread)
|
||||
if codexThreadID == "" {
|
||||
return nil, fmt.Errorf("codex thread/start response missing thread id")
|
||||
}
|
||||
c.rememberThread(sessionID, threadID, codexThreadID)
|
||||
return c.startTurn(ctx, codexThreadID, params, sink)
|
||||
}
|
||||
|
||||
func (c *codexCompat) SendMessage(ctx context.Context, sessionID string, threadID string, params map[string]any, sink SessionNotificationSink) (map[string]any, error) {
|
||||
codexThreadID := c.lookupThread(sessionID, threadID)
|
||||
if codexThreadID == "" {
|
||||
codexThreadID = strings.TrimSpace(threadID)
|
||||
}
|
||||
if codexThreadID != "" {
|
||||
thread, err := c.codexCall(ctx, "thread/resume", map[string]any{"threadId": codexThreadID}, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resolved := codexThreadIDFromResult(thread); resolved != "" {
|
||||
codexThreadID = resolved
|
||||
c.rememberThread(sessionID, threadID, codexThreadID)
|
||||
}
|
||||
}
|
||||
if codexThreadID == "" {
|
||||
return c.StartSession(ctx, sessionID, threadID, params, sink)
|
||||
}
|
||||
return c.startTurn(ctx, codexThreadID, params, sink)
|
||||
}
|
||||
|
||||
func (c *codexCompat) CloseSession(ctx context.Context, sessionID string) error {
|
||||
c.mu.Lock()
|
||||
delete(c.threads, sessionID)
|
||||
c.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *codexCompat) CancelSession(ctx context.Context, sessionID string) error {
|
||||
codexThreadID := c.lookupThread(sessionID, "")
|
||||
if codexThreadID == "" {
|
||||
return nil
|
||||
}
|
||||
_, err := c.codexCall(ctx, "turn/interrupt", map[string]any{"threadId": codexThreadID}, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *codexCompat) startTurn(ctx context.Context, codexThreadID string, params map[string]any, sink SessionNotificationSink) (map[string]any, error) {
|
||||
result, err := c.codexCall(
|
||||
ctx,
|
||||
"turn/start",
|
||||
map[string]any{
|
||||
"threadId": codexThreadID,
|
||||
"input": codexUserInput(params),
|
||||
},
|
||||
sink,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, ok := result["output"]; !ok {
|
||||
if summary := strings.TrimSpace(shared.StringArg(result, "summary", "")); summary != "" {
|
||||
result["output"] = summary
|
||||
}
|
||||
}
|
||||
result["providerThreadId"] = codexThreadID
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (c *codexCompat) codexCall(ctx context.Context, method string, params map[string]any, sink SessionNotificationSink) (map[string]any, error) {
|
||||
if c.transport() != "ws" {
|
||||
return c.rpcCall(ctx, method, params, sink)
|
||||
}
|
||||
return c.callWSRPCWithInitialize(ctx, method, params, sink)
|
||||
}
|
||||
|
||||
func (c *codexCompat) callWSRPCWithInitialize(ctx context.Context, method string, params map[string]any, sink SessionNotificationSink) (map[string]any, error) {
|
||||
headers := http.Header{}
|
||||
if c.authHeader != "" {
|
||||
headers.Set("Authorization", c.authHeader)
|
||||
}
|
||||
conn, _, err := websocket.DefaultDialer.DialContext(ctx, c.endpoint, headers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
if _, err := c.writeAndReadWSRPC(ctx, conn, "initialize", codexInitializeParams(), nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.writeAndReadWSRPC(ctx, conn, method, params, sink)
|
||||
}
|
||||
|
||||
func (c *codexCompat) writeAndReadWSRPC(ctx context.Context, conn *websocket.Conn, method string, params map[string]any, sink SessionNotificationSink) (map[string]any, error) {
|
||||
requestID := fmt.Sprintf("req-%d", time.Now().UnixNano())
|
||||
request := map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": requestID,
|
||||
"method": method,
|
||||
"params": params,
|
||||
}
|
||||
if err := conn.WriteJSON(request); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
collector := &externalACPNotificationCollector{}
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
_, payload, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var decoded map[string]any
|
||||
if err := json.Unmarshal(payload, &decoded); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode websocket rpc response: %w", err)
|
||||
}
|
||||
|
||||
methodName := strings.TrimSpace(shared.StringArg(decoded, "method", ""))
|
||||
if methodName != "" {
|
||||
collector.observe(decoded)
|
||||
if isExternalSessionUpdateMethod(methodName) && sink != nil {
|
||||
update := shared.AsMap(decoded["params"])
|
||||
if len(update) > 0 {
|
||||
sink(update)
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if fmt.Sprintf("%v", decoded["id"]) != requestID {
|
||||
continue
|
||||
}
|
||||
|
||||
result, err := parseExternalRPCResult(decoded)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return collector.apply(result), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *codexCompat) rememberThread(sessionID string, threadID string, codexThreadID string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if sessionID != "" {
|
||||
c.threads[sessionID] = codexThreadID
|
||||
}
|
||||
if threadID != "" {
|
||||
c.threads[threadID] = codexThreadID
|
||||
}
|
||||
}
|
||||
|
||||
func (c *codexCompat) lookupThread(sessionID string, threadID string) string {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if sessionID != "" {
|
||||
if value := strings.TrimSpace(c.threads[sessionID]); value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
if threadID != "" {
|
||||
return strings.TrimSpace(c.threads[threadID])
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func codexThreadStartParams(params map[string]any) map[string]any {
|
||||
result := map[string]any{}
|
||||
if cwd := strings.TrimSpace(shared.StringArg(params, "workingDirectory", "")); cwd != "" {
|
||||
result["cwd"] = cwd
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func codexInitializeParams() map[string]any {
|
||||
return map[string]any{
|
||||
"clientInfo": map[string]any{
|
||||
"name": "xworkmate-bridge",
|
||||
"version": "1.1.0",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func codexUserInput(params map[string]any) []any {
|
||||
input := map[string]any{
|
||||
"type": "text",
|
||||
"text": shared.StringArg(params, "taskPrompt", ""),
|
||||
}
|
||||
if attachments := anyList(params["attachments"]); len(attachments) > 0 {
|
||||
input["attachments"] = attachments
|
||||
}
|
||||
return []any{input}
|
||||
}
|
||||
|
||||
func codexThreadIDFromResult(result map[string]any) string {
|
||||
for _, key := range []string{"threadId", "id"} {
|
||||
if value := strings.TrimSpace(shared.StringArg(result, key, "")); value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
thread := shared.AsMap(result["thread"])
|
||||
for _, key := range []string{"id", "threadId"} {
|
||||
if value := strings.TrimSpace(shared.StringArg(thread, key, "")); value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func anyList(value any) []any {
|
||||
switch typed := value.(type) {
|
||||
case []any:
|
||||
return typed
|
||||
case []map[string]any:
|
||||
result := make([]any, 0, len(typed))
|
||||
for _, item := range typed {
|
||||
result = append(result, item)
|
||||
}
|
||||
return result
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *externalACPCompat) StartSession(ctx context.Context, sessionID string, threadID string, params map[string]any, sink SessionNotificationSink) (map[string]any, error) {
|
||||
return c.rpcCall(ctx, "session.start", params, sink)
|
||||
}
|
||||
@ -221,7 +472,7 @@ func isExternalSessionUpdateMethod(method string) bool {
|
||||
case "session.update", "acp.session.update", "session/update":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
return strings.HasPrefix(method, "item/") || strings.HasPrefix(method, "turn/")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -47,7 +47,6 @@ func (s *Server) executeSessionTask(t task) (map[string]any, *shared.RPCError) {
|
||||
return s.handleRequest(t.req, t.notify)
|
||||
}
|
||||
|
||||
|
||||
func newExternalSingleAgentProvider(
|
||||
t *testing.T,
|
||||
providerID string,
|
||||
@ -66,16 +65,24 @@ func newExternalSingleAgentProvider(
|
||||
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
|
||||
t.Fatalf("decode request: %v", err)
|
||||
}
|
||||
method := strings.TrimSpace(shared.StringArg(request, "method", ""))
|
||||
result := map[string]any{
|
||||
"success": true,
|
||||
"output": output,
|
||||
"turnId": "turn-" + providerID,
|
||||
"provider": providerID,
|
||||
"mode": "single-agent",
|
||||
}
|
||||
switch method {
|
||||
case "thread/start", "thread/resume":
|
||||
result = map[string]any{"id": "provider-thread-" + providerID}
|
||||
case "turn/start":
|
||||
result["summary"] = output
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": request["id"],
|
||||
"result": map[string]any{
|
||||
"success": true,
|
||||
"output": output,
|
||||
"turnId": "turn-" + providerID,
|
||||
"provider": providerID,
|
||||
"mode": "single-agent",
|
||||
},
|
||||
"result": result,
|
||||
})
|
||||
}))
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user