diff --git a/internal/acp/providers_sync.go b/internal/acp/providers_sync.go index 4d1daad..c2919f1 100644 --- a/internal/acp/providers_sync.go +++ b/internal/acp/providers_sync.go @@ -3,6 +3,8 @@ package acp import ( "sort" "strings" + + "xworkmate-bridge/internal/shared" ) type syncedProvider struct { @@ -72,6 +74,9 @@ func (s *Server) availableProviders() []string { providers[provider.ProviderID] = struct{}{} } s.mu.Unlock() + for _, providerID := range shared.DetectACPProviders() { + providers[providerID] = struct{}{} + } ordered := make([]string, 0, len(providers)) for providerID := range providers { ordered = append(ordered, providerID) diff --git a/internal/acp/providers_sync_test.go b/internal/acp/providers_sync_test.go index 10f80e9..4cd7be1 100644 --- a/internal/acp/providers_sync_test.go +++ b/internal/acp/providers_sync_test.go @@ -6,6 +6,8 @@ import ( "net/http" "net/http/httptest" "os" + "path/filepath" + "strings" "testing" "xworkmate-bridge/internal/shared" @@ -17,6 +19,9 @@ func TestCapabilitiesIgnoreLocalProviderAutodetectUntilSync(t *testing.T) { t.Fatalf("write fake provider: %v", err) } t.Setenv("ACP_CLAUDE_BIN", fakeProvider) + t.Setenv("ACP_CODEX_BIN", "") + t.Setenv("ACP_GEMINI_BIN", "") + t.Setenv("ACP_OPENCODE_BIN", "") server := NewServer() result, rpcErr := server.handleRequest(shared.RPCRequest{ @@ -28,8 +33,15 @@ func TestCapabilitiesIgnoreLocalProviderAutodetectUntilSync(t *testing.T) { } providers, _ := result["providers"].([]string) - if len(providers) != 0 { - t.Fatalf("expected no providers before sync, got %#v", providers) + found := false + for _, provider := range providers { + if provider == "claude" { + found = true + break + } + } + if !found { + t.Fatalf("expected autodetected local provider before sync, got %#v", providers) } } @@ -207,3 +219,80 @@ func TestRunSingleAgentUsesFrozenExternalProviderParams(t *testing.T) { t.Fatalf("expected frozen provider output, got %#v", result.response) } } + +func TestRunSingleAgentFallsBackWorkingDirectoryToHome(t *testing.T) { + home := t.TempDir() + t.Setenv("HOME", home) + fakeOpencode := filepath.Join(t.TempDir(), "opencode") + if err := os.WriteFile(fakeOpencode, []byte("#!/bin/sh\necho local-ok\n"), 0o755); err != nil { + t.Fatalf("write fake opencode: %v", err) + } + t.Setenv("ACP_OPENCODE_BIN", fakeOpencode) + + server := NewServer() + session := server.getOrCreateSession("session-local", "thread-local") + result := server.runSingleAgent( + context.Background(), + "session.start", + session, + map[string]any{ + "provider": "opencode", + "taskPrompt": "hello", + "workingDirectory": filepath.Join(t.TempDir(), "missing"), + }, + "turn-local", + func(map[string]any) {}, + ) + if result.err != nil { + t.Fatalf("expected success, got rpc error: %v", result.err) + } + if got := result.response["output"]; got != "local-ok" { + t.Fatalf("expected local provider output, got %#v", result.response) + } + if got := result.response["effectiveWorkingDirectory"]; got != home { + t.Fatalf("expected effectiveWorkingDirectory %q, got %#v", home, got) + } +} + +func TestHandleRPCForwardsInboundBearerToExternalProvider(t *testing.T) { + externalServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("Authorization"); got != "Bearer bridge-token" { + t.Fatalf("expected forwarded bearer header, got %q", got) + } + _ = json.NewEncoder(w).Encode(map[string]any{ + "jsonrpc": "2.0", + "id": "run-auth", + "result": map[string]any{ + "success": true, + "output": "forwarded-auth-ok", + }, + }) + })) + defer externalServer.Close() + + server := NewServer() + server.syncProviders([]syncedProvider{{ + ProviderID: "codex", + Label: "Codex", + Endpoint: externalServer.URL, + Enabled: true, + }}) + + recorder := httptest.NewRecorder() + request := httptest.NewRequest( + http.MethodPost, + "http://127.0.0.1/acp/rpc", + strings.NewReader(`{"jsonrpc":"2.0","id":"run-auth","method":"session.start","params":{"sessionId":"s1","threadId":"t1","taskPrompt":"hello","workingDirectory":"`+t.TempDir()+`","routing":{"routingMode":"explicit","explicitExecutionTarget":"singleAgent","explicitProviderId":"codex"}}}`), + ) + request.Header.Set("Content-Type", "application/json") + request.Header.Set("Authorization", "Bearer bridge-token") + + server.HandleRPC(recorder, request) + + if recorder.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", recorder.Code) + } + if !strings.Contains(recorder.Body.String(), "forwarded-auth-ok") { + t.Fatalf("expected forwarded provider response, got %q", recorder.Body.String()) + } +} diff --git a/internal/acp/routing_test.go b/internal/acp/routing_test.go index 350cb0c..46a29d1 100644 --- a/internal/acp/routing_test.go +++ b/internal/acp/routing_test.go @@ -297,7 +297,16 @@ func TestExecuteSessionTaskExplicitRoutingDoesNotRecordProjectMemory(t *testing. } } -func TestExecuteSessionTaskExplicitProviderRequiresSyncedCatalog(t *testing.T) { +func TestExecuteSessionTaskExplicitProviderUsesAutodetectedLocalProvider(t *testing.T) { + fakeClaude := filepath.Join(t.TempDir(), "claude") + if err := os.WriteFile(fakeClaude, []byte("#!/bin/sh\nprintf 'autodetected-provider-ok\\n'"), 0o755); err != nil { + t.Fatalf("write fake claude: %v", err) + } + t.Setenv("ACP_CLAUDE_BIN", fakeClaude) + t.Setenv("ACP_CODEX_BIN", "") + t.Setenv("ACP_GEMINI_BIN", "") + t.Setenv("ACP_OPENCODE_BIN", "") + server := NewServer() response, rpcErr := server.executeSessionTask(task{ req: shared.RPCRequest{ @@ -315,13 +324,16 @@ func TestExecuteSessionTaskExplicitProviderRequiresSyncedCatalog(t *testing.T) { }, }) if rpcErr != nil { - t.Fatalf("expected structured unavailable response, got rpc error: %v", rpcErr) + t.Fatalf("expected structured response, got rpc error: %v", rpcErr) } - if got := response["unavailable"]; got != true { - t.Fatalf("expected unavailable response, got %#v", response) + if success, _ := response["success"].(bool); !success { + t.Fatalf("expected success response, got %#v", response) } - if got := response["unavailableCode"]; got != "PROVIDER_UNAVAILABLE" { - t.Fatalf("expected PROVIDER_UNAVAILABLE, got %#v", response) + if got := response["provider"]; got != "claude" { + t.Fatalf("expected claude provider, got %#v", response) + } + if got := response["output"]; got != "autodetected-provider-ok" { + t.Fatalf("expected autodetected provider output, got %#v", response) } } diff --git a/internal/acp/server.go b/internal/acp/server.go index c15bdb3..557ebb6 100644 --- a/internal/acp/server.go +++ b/internal/acp/server.go @@ -217,6 +217,10 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) { s.writeJSONError(w, nil, http.StatusBadRequest, -32700, err.Error()) return } + request.Params = injectInboundAuthorizationHeader( + request.Params, + r.Header.Get("Authorization"), + ) accept := strings.ToLower(r.Header.Get("Accept")) stream := strings.Contains(accept, "text/event-stream") @@ -269,6 +273,13 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) { _ = json.NewEncoder(w).Encode(shared.ResultEnvelope(request.ID, response)) } +func (s *Server) authorized(r *http.Request) bool { + if s == nil || s.authService == nil { + return false + } + return s.authService.ValidateAuthorizationHeader(r.Header.Get("Authorization")) +} + func (s *Server) handleRequest( request shared.RPCRequest, notify func(map[string]any), @@ -689,6 +700,10 @@ func (s *Server) runSingleAgent( workingDirectory := strings.TrimSpace( shared.StringArg(params, "workingDirectory", ""), ) + workingDirectory, effectiveWorkingDirectory := shared.NormalizeProviderWorkingDirectory( + provider, + workingDirectory, + ) model := strings.TrimSpace(shared.StringArg(params, "model", "")) prompt := strings.TrimSpace(shared.StringArg(params, "taskPrompt", "")) prompt = shared.AugmentPromptWithAttachments(prompt, params) @@ -715,6 +730,9 @@ func (s *Server) runSingleAgent( if _, exists := result["turnId"]; !exists { result["turnId"] = turnID } + if _, exists := result["effectiveWorkingDirectory"]; !exists && effectiveWorkingDirectory != "" { + result["effectiveWorkingDirectory"] = effectiveWorkingDirectory + } return taskResult{response: result} } s.emitSessionUpdate(session, notify, turnID, map[string]any{ @@ -757,6 +775,9 @@ func (s *Server) runSingleAgent( if _, exists := result["turnId"]; !exists { result["turnId"] = turnID } + if _, exists := result["effectiveWorkingDirectory"]; !exists && effectiveWorkingDirectory != "" { + result["effectiveWorkingDirectory"] = effectiveWorkingDirectory + } return taskResult{response: result} } s.emitSessionUpdate(session, notify, turnID, map[string]any{ @@ -827,11 +848,12 @@ func (s *Server) runSingleAgent( return taskResult{ response: map[string]any{ - "success": true, - "output": output, - "turnId": turnID, - "mode": "single-agent", - "provider": provider, + "success": true, + "output": output, + "turnId": turnID, + "mode": "single-agent", + "provider": provider, + "effectiveWorkingDirectory": effectiveWorkingDirectory, }, } } diff --git a/internal/shared/tools.go b/internal/shared/tools.go index af954e1..5596ebb 100644 --- a/internal/shared/tools.go +++ b/internal/shared/tools.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "net/http" + "os" "os/exec" "sort" "strings" @@ -84,6 +85,40 @@ func RunProviderCommand( return output, nil } +func NormalizeProviderWorkingDirectory(provider, requested string) (string, string) { + requested = strings.TrimSpace(requested) + if requested == "" { + return "", "" + } + switch strings.TrimSpace(strings.ToLower(provider)) { + case "codex", "opencode": + default: + return requested, requested + } + if canAccessWorkingDirectory(requested) { + return requested, requested + } + home, err := os.UserHomeDir() + if err != nil { + return requested, requested + } + home = strings.TrimSpace(home) + if home == "" { + return requested, requested + } + return home, home +} + +func canAccessWorkingDirectory(path string) bool { + info, err := os.Stat(path) + if err != nil || !info.IsDir() { + return false + } + cmd := exec.Command("pwd") + cmd.Dir = path + return cmd.Run() == nil +} + func ResolveProviderCommand( provider, model, diff --git a/internal/shared/tools_test.go b/internal/shared/tools_test.go new file mode 100644 index 0000000..11e4a77 --- /dev/null +++ b/internal/shared/tools_test.go @@ -0,0 +1,38 @@ +package shared + +import ( + "path/filepath" + "testing" +) + +func TestNormalizeProviderWorkingDirectoryPreservesAccessibleDir(t *testing.T) { + accessible := t.TempDir() + + got, effective := NormalizeProviderWorkingDirectory("opencode", accessible) + + if got != accessible || effective != accessible { + t.Fatalf("expected accessible dir preserved, got %q %q", got, effective) + } +} + +func TestNormalizeProviderWorkingDirectoryFallsBackToHome(t *testing.T) { + home := t.TempDir() + t.Setenv("HOME", home) + missing := filepath.Join(t.TempDir(), "missing") + + got, effective := NormalizeProviderWorkingDirectory("codex", missing) + + if got != home || effective != home { + t.Fatalf("expected fallback to home %q, got %q %q", home, got, effective) + } +} + +func TestNormalizeProviderWorkingDirectorySkipsUnknownProvider(t *testing.T) { + dir := t.TempDir() + + got, effective := NormalizeProviderWorkingDirectory("claude", dir) + + if got != dir || effective != dir { + t.Fatalf("expected unknown provider to keep dir, got %q %q", got, effective) + } +}