fix: route xworkmate-bridge ACP providers with cwd fallback
This commit is contained in:
parent
3f41409363
commit
7cd0bdfbc4
@ -3,6 +3,8 @@ package acp
|
||||
import (
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"xworkmate-bridge/internal/shared"
|
||||
)
|
||||
|
||||
type syncedProvider struct {
|
||||
@ -72,6 +74,9 @@ func (s *Server) availableProviders() []string {
|
||||
providers[provider.ProviderID] = struct{}{}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
for _, providerID := range shared.DetectACPProviders() {
|
||||
providers[providerID] = struct{}{}
|
||||
}
|
||||
ordered := make([]string, 0, len(providers))
|
||||
for providerID := range providers {
|
||||
ordered = append(ordered, providerID)
|
||||
|
||||
@ -6,6 +6,8 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"xworkmate-bridge/internal/shared"
|
||||
@ -17,6 +19,9 @@ func TestCapabilitiesIgnoreLocalProviderAutodetectUntilSync(t *testing.T) {
|
||||
t.Fatalf("write fake provider: %v", err)
|
||||
}
|
||||
t.Setenv("ACP_CLAUDE_BIN", fakeProvider)
|
||||
t.Setenv("ACP_CODEX_BIN", "")
|
||||
t.Setenv("ACP_GEMINI_BIN", "")
|
||||
t.Setenv("ACP_OPENCODE_BIN", "")
|
||||
|
||||
server := NewServer()
|
||||
result, rpcErr := server.handleRequest(shared.RPCRequest{
|
||||
@ -28,8 +33,15 @@ func TestCapabilitiesIgnoreLocalProviderAutodetectUntilSync(t *testing.T) {
|
||||
}
|
||||
|
||||
providers, _ := result["providers"].([]string)
|
||||
if len(providers) != 0 {
|
||||
t.Fatalf("expected no providers before sync, got %#v", providers)
|
||||
found := false
|
||||
for _, provider := range providers {
|
||||
if provider == "claude" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatalf("expected autodetected local provider before sync, got %#v", providers)
|
||||
}
|
||||
}
|
||||
|
||||
@ -207,3 +219,80 @@ func TestRunSingleAgentUsesFrozenExternalProviderParams(t *testing.T) {
|
||||
t.Fatalf("expected frozen provider output, got %#v", result.response)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunSingleAgentFallsBackWorkingDirectoryToHome(t *testing.T) {
|
||||
home := t.TempDir()
|
||||
t.Setenv("HOME", home)
|
||||
fakeOpencode := filepath.Join(t.TempDir(), "opencode")
|
||||
if err := os.WriteFile(fakeOpencode, []byte("#!/bin/sh\necho local-ok\n"), 0o755); err != nil {
|
||||
t.Fatalf("write fake opencode: %v", err)
|
||||
}
|
||||
t.Setenv("ACP_OPENCODE_BIN", fakeOpencode)
|
||||
|
||||
server := NewServer()
|
||||
session := server.getOrCreateSession("session-local", "thread-local")
|
||||
result := server.runSingleAgent(
|
||||
context.Background(),
|
||||
"session.start",
|
||||
session,
|
||||
map[string]any{
|
||||
"provider": "opencode",
|
||||
"taskPrompt": "hello",
|
||||
"workingDirectory": filepath.Join(t.TempDir(), "missing"),
|
||||
},
|
||||
"turn-local",
|
||||
func(map[string]any) {},
|
||||
)
|
||||
if result.err != nil {
|
||||
t.Fatalf("expected success, got rpc error: %v", result.err)
|
||||
}
|
||||
if got := result.response["output"]; got != "local-ok" {
|
||||
t.Fatalf("expected local provider output, got %#v", result.response)
|
||||
}
|
||||
if got := result.response["effectiveWorkingDirectory"]; got != home {
|
||||
t.Fatalf("expected effectiveWorkingDirectory %q, got %#v", home, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleRPCForwardsInboundBearerToExternalProvider(t *testing.T) {
|
||||
externalServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if got := r.Header.Get("Authorization"); got != "Bearer bridge-token" {
|
||||
t.Fatalf("expected forwarded bearer header, got %q", got)
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": "run-auth",
|
||||
"result": map[string]any{
|
||||
"success": true,
|
||||
"output": "forwarded-auth-ok",
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer externalServer.Close()
|
||||
|
||||
server := NewServer()
|
||||
server.syncProviders([]syncedProvider{{
|
||||
ProviderID: "codex",
|
||||
Label: "Codex",
|
||||
Endpoint: externalServer.URL,
|
||||
Enabled: true,
|
||||
}})
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
request := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"http://127.0.0.1/acp/rpc",
|
||||
strings.NewReader(`{"jsonrpc":"2.0","id":"run-auth","method":"session.start","params":{"sessionId":"s1","threadId":"t1","taskPrompt":"hello","workingDirectory":"`+t.TempDir()+`","routing":{"routingMode":"explicit","explicitExecutionTarget":"singleAgent","explicitProviderId":"codex"}}}`),
|
||||
)
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
request.Header.Set("Authorization", "Bearer bridge-token")
|
||||
|
||||
server.HandleRPC(recorder, request)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", recorder.Code)
|
||||
}
|
||||
if !strings.Contains(recorder.Body.String(), "forwarded-auth-ok") {
|
||||
t.Fatalf("expected forwarded provider response, got %q", recorder.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
@ -297,7 +297,16 @@ func TestExecuteSessionTaskExplicitRoutingDoesNotRecordProjectMemory(t *testing.
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteSessionTaskExplicitProviderRequiresSyncedCatalog(t *testing.T) {
|
||||
func TestExecuteSessionTaskExplicitProviderUsesAutodetectedLocalProvider(t *testing.T) {
|
||||
fakeClaude := filepath.Join(t.TempDir(), "claude")
|
||||
if err := os.WriteFile(fakeClaude, []byte("#!/bin/sh\nprintf 'autodetected-provider-ok\\n'"), 0o755); err != nil {
|
||||
t.Fatalf("write fake claude: %v", err)
|
||||
}
|
||||
t.Setenv("ACP_CLAUDE_BIN", fakeClaude)
|
||||
t.Setenv("ACP_CODEX_BIN", "")
|
||||
t.Setenv("ACP_GEMINI_BIN", "")
|
||||
t.Setenv("ACP_OPENCODE_BIN", "")
|
||||
|
||||
server := NewServer()
|
||||
response, rpcErr := server.executeSessionTask(task{
|
||||
req: shared.RPCRequest{
|
||||
@ -315,13 +324,16 @@ func TestExecuteSessionTaskExplicitProviderRequiresSyncedCatalog(t *testing.T) {
|
||||
},
|
||||
})
|
||||
if rpcErr != nil {
|
||||
t.Fatalf("expected structured unavailable response, got rpc error: %v", rpcErr)
|
||||
t.Fatalf("expected structured response, got rpc error: %v", rpcErr)
|
||||
}
|
||||
if got := response["unavailable"]; got != true {
|
||||
t.Fatalf("expected unavailable response, got %#v", response)
|
||||
if success, _ := response["success"].(bool); !success {
|
||||
t.Fatalf("expected success response, got %#v", response)
|
||||
}
|
||||
if got := response["unavailableCode"]; got != "PROVIDER_UNAVAILABLE" {
|
||||
t.Fatalf("expected PROVIDER_UNAVAILABLE, got %#v", response)
|
||||
if got := response["provider"]; got != "claude" {
|
||||
t.Fatalf("expected claude provider, got %#v", response)
|
||||
}
|
||||
if got := response["output"]; got != "autodetected-provider-ok" {
|
||||
t.Fatalf("expected autodetected provider output, got %#v", response)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -217,6 +217,10 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
|
||||
s.writeJSONError(w, nil, http.StatusBadRequest, -32700, err.Error())
|
||||
return
|
||||
}
|
||||
request.Params = injectInboundAuthorizationHeader(
|
||||
request.Params,
|
||||
r.Header.Get("Authorization"),
|
||||
)
|
||||
|
||||
accept := strings.ToLower(r.Header.Get("Accept"))
|
||||
stream := strings.Contains(accept, "text/event-stream")
|
||||
@ -269,6 +273,13 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
|
||||
_ = json.NewEncoder(w).Encode(shared.ResultEnvelope(request.ID, response))
|
||||
}
|
||||
|
||||
func (s *Server) authorized(r *http.Request) bool {
|
||||
if s == nil || s.authService == nil {
|
||||
return false
|
||||
}
|
||||
return s.authService.ValidateAuthorizationHeader(r.Header.Get("Authorization"))
|
||||
}
|
||||
|
||||
func (s *Server) handleRequest(
|
||||
request shared.RPCRequest,
|
||||
notify func(map[string]any),
|
||||
@ -689,6 +700,10 @@ func (s *Server) runSingleAgent(
|
||||
workingDirectory := strings.TrimSpace(
|
||||
shared.StringArg(params, "workingDirectory", ""),
|
||||
)
|
||||
workingDirectory, effectiveWorkingDirectory := shared.NormalizeProviderWorkingDirectory(
|
||||
provider,
|
||||
workingDirectory,
|
||||
)
|
||||
model := strings.TrimSpace(shared.StringArg(params, "model", ""))
|
||||
prompt := strings.TrimSpace(shared.StringArg(params, "taskPrompt", ""))
|
||||
prompt = shared.AugmentPromptWithAttachments(prompt, params)
|
||||
@ -715,6 +730,9 @@ func (s *Server) runSingleAgent(
|
||||
if _, exists := result["turnId"]; !exists {
|
||||
result["turnId"] = turnID
|
||||
}
|
||||
if _, exists := result["effectiveWorkingDirectory"]; !exists && effectiveWorkingDirectory != "" {
|
||||
result["effectiveWorkingDirectory"] = effectiveWorkingDirectory
|
||||
}
|
||||
return taskResult{response: result}
|
||||
}
|
||||
s.emitSessionUpdate(session, notify, turnID, map[string]any{
|
||||
@ -757,6 +775,9 @@ func (s *Server) runSingleAgent(
|
||||
if _, exists := result["turnId"]; !exists {
|
||||
result["turnId"] = turnID
|
||||
}
|
||||
if _, exists := result["effectiveWorkingDirectory"]; !exists && effectiveWorkingDirectory != "" {
|
||||
result["effectiveWorkingDirectory"] = effectiveWorkingDirectory
|
||||
}
|
||||
return taskResult{response: result}
|
||||
}
|
||||
s.emitSessionUpdate(session, notify, turnID, map[string]any{
|
||||
@ -827,11 +848,12 @@ func (s *Server) runSingleAgent(
|
||||
|
||||
return taskResult{
|
||||
response: map[string]any{
|
||||
"success": true,
|
||||
"output": output,
|
||||
"turnId": turnID,
|
||||
"mode": "single-agent",
|
||||
"provider": provider,
|
||||
"success": true,
|
||||
"output": output,
|
||||
"turnId": turnID,
|
||||
"mode": "single-agent",
|
||||
"provider": provider,
|
||||
"effectiveWorkingDirectory": effectiveWorkingDirectory,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"sort"
|
||||
"strings"
|
||||
@ -84,6 +85,40 @@ func RunProviderCommand(
|
||||
return output, nil
|
||||
}
|
||||
|
||||
func NormalizeProviderWorkingDirectory(provider, requested string) (string, string) {
|
||||
requested = strings.TrimSpace(requested)
|
||||
if requested == "" {
|
||||
return "", ""
|
||||
}
|
||||
switch strings.TrimSpace(strings.ToLower(provider)) {
|
||||
case "codex", "opencode":
|
||||
default:
|
||||
return requested, requested
|
||||
}
|
||||
if canAccessWorkingDirectory(requested) {
|
||||
return requested, requested
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return requested, requested
|
||||
}
|
||||
home = strings.TrimSpace(home)
|
||||
if home == "" {
|
||||
return requested, requested
|
||||
}
|
||||
return home, home
|
||||
}
|
||||
|
||||
func canAccessWorkingDirectory(path string) bool {
|
||||
info, err := os.Stat(path)
|
||||
if err != nil || !info.IsDir() {
|
||||
return false
|
||||
}
|
||||
cmd := exec.Command("pwd")
|
||||
cmd.Dir = path
|
||||
return cmd.Run() == nil
|
||||
}
|
||||
|
||||
func ResolveProviderCommand(
|
||||
provider,
|
||||
model,
|
||||
|
||||
38
internal/shared/tools_test.go
Normal file
38
internal/shared/tools_test.go
Normal file
@ -0,0 +1,38 @@
|
||||
package shared
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNormalizeProviderWorkingDirectoryPreservesAccessibleDir(t *testing.T) {
|
||||
accessible := t.TempDir()
|
||||
|
||||
got, effective := NormalizeProviderWorkingDirectory("opencode", accessible)
|
||||
|
||||
if got != accessible || effective != accessible {
|
||||
t.Fatalf("expected accessible dir preserved, got %q %q", got, effective)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeProviderWorkingDirectoryFallsBackToHome(t *testing.T) {
|
||||
home := t.TempDir()
|
||||
t.Setenv("HOME", home)
|
||||
missing := filepath.Join(t.TempDir(), "missing")
|
||||
|
||||
got, effective := NormalizeProviderWorkingDirectory("codex", missing)
|
||||
|
||||
if got != home || effective != home {
|
||||
t.Fatalf("expected fallback to home %q, got %q %q", home, got, effective)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeProviderWorkingDirectorySkipsUnknownProvider(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
got, effective := NormalizeProviderWorkingDirectory("claude", dir)
|
||||
|
||||
if got != dir || effective != dir {
|
||||
t.Fatalf("expected unknown provider to keep dir, got %q %q", got, effective)
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user