fix: route xworkmate-bridge ACP providers with cwd fallback

This commit is contained in:
Haitao Pan 2026-04-09 13:27:30 +08:00
parent 3f41409363
commit 7cd0bdfbc4
6 changed files with 214 additions and 13 deletions

View File

@ -3,6 +3,8 @@ package acp
import (
"sort"
"strings"
"xworkmate-bridge/internal/shared"
)
type syncedProvider struct {
@ -72,6 +74,9 @@ func (s *Server) availableProviders() []string {
providers[provider.ProviderID] = struct{}{}
}
s.mu.Unlock()
for _, providerID := range shared.DetectACPProviders() {
providers[providerID] = struct{}{}
}
ordered := make([]string, 0, len(providers))
for providerID := range providers {
ordered = append(ordered, providerID)

View File

@ -6,6 +6,8 @@ import (
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"xworkmate-bridge/internal/shared"
@ -17,6 +19,9 @@ func TestCapabilitiesIgnoreLocalProviderAutodetectUntilSync(t *testing.T) {
t.Fatalf("write fake provider: %v", err)
}
t.Setenv("ACP_CLAUDE_BIN", fakeProvider)
t.Setenv("ACP_CODEX_BIN", "")
t.Setenv("ACP_GEMINI_BIN", "")
t.Setenv("ACP_OPENCODE_BIN", "")
server := NewServer()
result, rpcErr := server.handleRequest(shared.RPCRequest{
@ -28,8 +33,15 @@ func TestCapabilitiesIgnoreLocalProviderAutodetectUntilSync(t *testing.T) {
}
providers, _ := result["providers"].([]string)
if len(providers) != 0 {
t.Fatalf("expected no providers before sync, got %#v", providers)
found := false
for _, provider := range providers {
if provider == "claude" {
found = true
break
}
}
if !found {
t.Fatalf("expected autodetected local provider before sync, got %#v", providers)
}
}
@ -207,3 +219,80 @@ func TestRunSingleAgentUsesFrozenExternalProviderParams(t *testing.T) {
t.Fatalf("expected frozen provider output, got %#v", result.response)
}
}
func TestRunSingleAgentFallsBackWorkingDirectoryToHome(t *testing.T) {
home := t.TempDir()
t.Setenv("HOME", home)
fakeOpencode := filepath.Join(t.TempDir(), "opencode")
if err := os.WriteFile(fakeOpencode, []byte("#!/bin/sh\necho local-ok\n"), 0o755); err != nil {
t.Fatalf("write fake opencode: %v", err)
}
t.Setenv("ACP_OPENCODE_BIN", fakeOpencode)
server := NewServer()
session := server.getOrCreateSession("session-local", "thread-local")
result := server.runSingleAgent(
context.Background(),
"session.start",
session,
map[string]any{
"provider": "opencode",
"taskPrompt": "hello",
"workingDirectory": filepath.Join(t.TempDir(), "missing"),
},
"turn-local",
func(map[string]any) {},
)
if result.err != nil {
t.Fatalf("expected success, got rpc error: %v", result.err)
}
if got := result.response["output"]; got != "local-ok" {
t.Fatalf("expected local provider output, got %#v", result.response)
}
if got := result.response["effectiveWorkingDirectory"]; got != home {
t.Fatalf("expected effectiveWorkingDirectory %q, got %#v", home, got)
}
}
func TestHandleRPCForwardsInboundBearerToExternalProvider(t *testing.T) {
externalServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got := r.Header.Get("Authorization"); got != "Bearer bridge-token" {
t.Fatalf("expected forwarded bearer header, got %q", got)
}
_ = json.NewEncoder(w).Encode(map[string]any{
"jsonrpc": "2.0",
"id": "run-auth",
"result": map[string]any{
"success": true,
"output": "forwarded-auth-ok",
},
})
}))
defer externalServer.Close()
server := NewServer()
server.syncProviders([]syncedProvider{{
ProviderID: "codex",
Label: "Codex",
Endpoint: externalServer.URL,
Enabled: true,
}})
recorder := httptest.NewRecorder()
request := httptest.NewRequest(
http.MethodPost,
"http://127.0.0.1/acp/rpc",
strings.NewReader(`{"jsonrpc":"2.0","id":"run-auth","method":"session.start","params":{"sessionId":"s1","threadId":"t1","taskPrompt":"hello","workingDirectory":"`+t.TempDir()+`","routing":{"routingMode":"explicit","explicitExecutionTarget":"singleAgent","explicitProviderId":"codex"}}}`),
)
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Authorization", "Bearer bridge-token")
server.HandleRPC(recorder, request)
if recorder.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", recorder.Code)
}
if !strings.Contains(recorder.Body.String(), "forwarded-auth-ok") {
t.Fatalf("expected forwarded provider response, got %q", recorder.Body.String())
}
}

View File

@ -297,7 +297,16 @@ func TestExecuteSessionTaskExplicitRoutingDoesNotRecordProjectMemory(t *testing.
}
}
func TestExecuteSessionTaskExplicitProviderRequiresSyncedCatalog(t *testing.T) {
func TestExecuteSessionTaskExplicitProviderUsesAutodetectedLocalProvider(t *testing.T) {
fakeClaude := filepath.Join(t.TempDir(), "claude")
if err := os.WriteFile(fakeClaude, []byte("#!/bin/sh\nprintf 'autodetected-provider-ok\\n'"), 0o755); err != nil {
t.Fatalf("write fake claude: %v", err)
}
t.Setenv("ACP_CLAUDE_BIN", fakeClaude)
t.Setenv("ACP_CODEX_BIN", "")
t.Setenv("ACP_GEMINI_BIN", "")
t.Setenv("ACP_OPENCODE_BIN", "")
server := NewServer()
response, rpcErr := server.executeSessionTask(task{
req: shared.RPCRequest{
@ -315,13 +324,16 @@ func TestExecuteSessionTaskExplicitProviderRequiresSyncedCatalog(t *testing.T) {
},
})
if rpcErr != nil {
t.Fatalf("expected structured unavailable response, got rpc error: %v", rpcErr)
t.Fatalf("expected structured response, got rpc error: %v", rpcErr)
}
if got := response["unavailable"]; got != true {
t.Fatalf("expected unavailable response, got %#v", response)
if success, _ := response["success"].(bool); !success {
t.Fatalf("expected success response, got %#v", response)
}
if got := response["unavailableCode"]; got != "PROVIDER_UNAVAILABLE" {
t.Fatalf("expected PROVIDER_UNAVAILABLE, got %#v", response)
if got := response["provider"]; got != "claude" {
t.Fatalf("expected claude provider, got %#v", response)
}
if got := response["output"]; got != "autodetected-provider-ok" {
t.Fatalf("expected autodetected provider output, got %#v", response)
}
}

View File

@ -217,6 +217,10 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
s.writeJSONError(w, nil, http.StatusBadRequest, -32700, err.Error())
return
}
request.Params = injectInboundAuthorizationHeader(
request.Params,
r.Header.Get("Authorization"),
)
accept := strings.ToLower(r.Header.Get("Accept"))
stream := strings.Contains(accept, "text/event-stream")
@ -269,6 +273,13 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
_ = json.NewEncoder(w).Encode(shared.ResultEnvelope(request.ID, response))
}
func (s *Server) authorized(r *http.Request) bool {
if s == nil || s.authService == nil {
return false
}
return s.authService.ValidateAuthorizationHeader(r.Header.Get("Authorization"))
}
func (s *Server) handleRequest(
request shared.RPCRequest,
notify func(map[string]any),
@ -689,6 +700,10 @@ func (s *Server) runSingleAgent(
workingDirectory := strings.TrimSpace(
shared.StringArg(params, "workingDirectory", ""),
)
workingDirectory, effectiveWorkingDirectory := shared.NormalizeProviderWorkingDirectory(
provider,
workingDirectory,
)
model := strings.TrimSpace(shared.StringArg(params, "model", ""))
prompt := strings.TrimSpace(shared.StringArg(params, "taskPrompt", ""))
prompt = shared.AugmentPromptWithAttachments(prompt, params)
@ -715,6 +730,9 @@ func (s *Server) runSingleAgent(
if _, exists := result["turnId"]; !exists {
result["turnId"] = turnID
}
if _, exists := result["effectiveWorkingDirectory"]; !exists && effectiveWorkingDirectory != "" {
result["effectiveWorkingDirectory"] = effectiveWorkingDirectory
}
return taskResult{response: result}
}
s.emitSessionUpdate(session, notify, turnID, map[string]any{
@ -757,6 +775,9 @@ func (s *Server) runSingleAgent(
if _, exists := result["turnId"]; !exists {
result["turnId"] = turnID
}
if _, exists := result["effectiveWorkingDirectory"]; !exists && effectiveWorkingDirectory != "" {
result["effectiveWorkingDirectory"] = effectiveWorkingDirectory
}
return taskResult{response: result}
}
s.emitSessionUpdate(session, notify, turnID, map[string]any{
@ -827,11 +848,12 @@ func (s *Server) runSingleAgent(
return taskResult{
response: map[string]any{
"success": true,
"output": output,
"turnId": turnID,
"mode": "single-agent",
"provider": provider,
"success": true,
"output": output,
"turnId": turnID,
"mode": "single-agent",
"provider": provider,
"effectiveWorkingDirectory": effectiveWorkingDirectory,
},
}
}

View File

@ -8,6 +8,7 @@ import (
"fmt"
"io"
"net/http"
"os"
"os/exec"
"sort"
"strings"
@ -84,6 +85,40 @@ func RunProviderCommand(
return output, nil
}
func NormalizeProviderWorkingDirectory(provider, requested string) (string, string) {
requested = strings.TrimSpace(requested)
if requested == "" {
return "", ""
}
switch strings.TrimSpace(strings.ToLower(provider)) {
case "codex", "opencode":
default:
return requested, requested
}
if canAccessWorkingDirectory(requested) {
return requested, requested
}
home, err := os.UserHomeDir()
if err != nil {
return requested, requested
}
home = strings.TrimSpace(home)
if home == "" {
return requested, requested
}
return home, home
}
func canAccessWorkingDirectory(path string) bool {
info, err := os.Stat(path)
if err != nil || !info.IsDir() {
return false
}
cmd := exec.Command("pwd")
cmd.Dir = path
return cmd.Run() == nil
}
func ResolveProviderCommand(
provider,
model,

View File

@ -0,0 +1,38 @@
package shared
import (
"path/filepath"
"testing"
)
func TestNormalizeProviderWorkingDirectoryPreservesAccessibleDir(t *testing.T) {
accessible := t.TempDir()
got, effective := NormalizeProviderWorkingDirectory("opencode", accessible)
if got != accessible || effective != accessible {
t.Fatalf("expected accessible dir preserved, got %q %q", got, effective)
}
}
func TestNormalizeProviderWorkingDirectoryFallsBackToHome(t *testing.T) {
home := t.TempDir()
t.Setenv("HOME", home)
missing := filepath.Join(t.TempDir(), "missing")
got, effective := NormalizeProviderWorkingDirectory("codex", missing)
if got != home || effective != home {
t.Fatalf("expected fallback to home %q, got %q %q", home, got, effective)
}
}
func TestNormalizeProviderWorkingDirectorySkipsUnknownProvider(t *testing.T) {
dir := t.TempDir()
got, effective := NormalizeProviderWorkingDirectory("claude", dir)
if got != dir || effective != dir {
t.Fatalf("expected unknown provider to keep dir, got %q %q", got, effective)
}
}