Initial standalone ACP bridge repository

This commit is contained in:
Haitao Pan 2026-04-09 09:49:48 +08:00
commit c5157fcb81
47 changed files with 9004 additions and 0 deletions

1
.gitignore vendored Normal file
View File

@ -0,0 +1 @@
build/

22
Makefile Normal file
View File

@ -0,0 +1,22 @@
.DEFAULT_GOAL := help
SHELL := /bin/bash
GO ?= go
OUTPUT_DIR ?= $(CURDIR)/build/bin
OUTPUT_PATH ?= $(OUTPUT_DIR)/xworkmate-go-core
.PHONY: help test build clean
help:
@printf "%-12s %s\n" "test" "Run Go tests"
@printf "%-12s %s\n" "build" "Build xworkmate-go-core helper"
@printf "%-12s %s\n" "clean" "Remove build output"
test:
$(GO) test ./...
build:
bash scripts/build-helper.sh
clean:
rm -rf build

28
README.md Normal file
View File

@ -0,0 +1,28 @@
# XWorkmate Bridge
`xworkmate-bridge` is the standalone repository for the XWorkmate ACP Bridge Server and the embedded Go helper previously stored under `xworkmate-app/go/go_core`.
## What lives here
- ACP Bridge HTTP/WebSocket server
- ACP stdio bridge entrypoint
- Go helper runtime packages used by the ACP bridge
- Unit tests for bridge routing, RPC contracts, mounts, runtime dispatch, and provider sync
## Compatibility
For compatibility with `xworkmate-app`, the built helper binary name remains `xworkmate-go-core`.
## Commands
```bash
make test
make build
./build/bin/xworkmate-go-core serve --listen 127.0.0.1:8787
```
## Environment
- `ACP_LISTEN_ADDR`: listen address for `serve` mode, default `127.0.0.1:8787`
- `OUTPUT_DIR`: optional output directory for `make build`
- `OUTPUT_PATH`: optional explicit build path for `make build`

5
go.mod Normal file
View File

@ -0,0 +1,5 @@
module xworkmate-bridge
go 1.25.0
require github.com/gorilla/websocket v1.5.3

2
go.sum Normal file
View File

@ -0,0 +1,2 @@
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=

393
internal/acp/execution.go Normal file
View File

@ -0,0 +1,393 @@
package acp
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"time"
"github.com/gorilla/websocket"
"xworkmate-bridge/internal/router"
"xworkmate-bridge/internal/shared"
)
const (
externalProviderEndpointKey = "externalProviderEndpoint"
externalProviderAuthorizationHeaderKey = "externalProviderAuthorizationHeader"
externalProviderLabelKey = "externalProviderLabel"
)
func buildResolvedExecutionParams(
params map[string]any,
resolved router.Result,
) map[string]any {
next := make(map[string]any, len(params)+8)
for key, value := range params {
next[key] = value
}
switch resolved.ResolvedExecutionTarget {
case router.ExecutionTargetGateway:
next["mode"] = router.ExecutionTargetGatewayChat
next["executionTarget"] = resolved.ResolvedEndpointTarget
case router.ExecutionTargetMultiAgent:
next["mode"] = router.ExecutionTargetMultiAgent
default:
next["mode"] = router.ExecutionTargetSingleAgent
}
if strings.TrimSpace(resolved.ResolvedProviderID) != "" {
next["provider"] = strings.TrimSpace(resolved.ResolvedProviderID)
}
if strings.TrimSpace(resolved.ResolvedModel) != "" {
next["model"] = strings.TrimSpace(resolved.ResolvedModel)
}
if len(resolved.ResolvedSkills) > 0 {
next["selectedSkills"] = append([]string(nil), resolved.ResolvedSkills...)
}
next["resolvedExecutionTarget"] = resolved.ResolvedExecutionTarget
next["resolvedEndpointTarget"] = resolved.ResolvedEndpointTarget
next["resolvedProviderId"] = resolved.ResolvedProviderID
next["resolvedModel"] = resolved.ResolvedModel
next["resolvedSkills"] = append([]string(nil), resolved.ResolvedSkills...)
return next
}
func injectResolvedExternalProviderParams(
params map[string]any,
provider syncedProvider,
) map[string]any {
if params == nil {
params = map[string]any{}
}
if endpoint := strings.TrimSpace(provider.Endpoint); endpoint != "" {
params[externalProviderEndpointKey] = endpoint
}
if authorization := strings.TrimSpace(provider.AuthorizationHeader); authorization != "" {
params[externalProviderAuthorizationHeaderKey] = authorization
}
if label := strings.TrimSpace(provider.Label); label != "" {
params[externalProviderLabelKey] = label
}
return params
}
func (s *Server) runGateway(
ctx context.Context,
method string,
session *session,
params map[string]any,
turnID string,
notify func(map[string]any),
) taskResult {
_ = ctx
executionTarget := strings.TrimSpace(shared.StringArg(params, "executionTarget", ""))
if executionTarget == "" {
executionTarget = router.EndpointTargetLocal
}
result := s.gateway.RequestByMode(
executionTarget,
method,
params,
2*time.Minute,
notify,
)
if !result.OK {
errMessage := strings.TrimSpace(shared.StringArg(result.Error, "message", "gateway execution failed"))
s.emitSessionUpdate(session, notify, turnID, map[string]any{
"type": "status",
"event": "completed",
"message": errMessage,
"pending": false,
"error": true,
})
return taskResult{
response: map[string]any{
"success": false,
"error": errMessage,
"turnId": turnID,
"mode": router.ExecutionTargetGatewayChat,
},
}
}
payload := asMap(result.Payload)
if len(payload) == 0 {
payload = map[string]any{
"success": true,
"turnId": turnID,
"mode": router.ExecutionTargetGatewayChat,
}
}
if _, ok := payload["turnId"]; !ok {
payload["turnId"] = turnID
}
if _, ok := payload["mode"]; !ok {
payload["mode"] = router.ExecutionTargetGatewayChat
}
return taskResult{response: payload}
}
func (s *Server) runSingleAgentViaExternalProvider(
ctx context.Context,
provider syncedProvider,
method string,
params map[string]any,
notify func(map[string]any),
) (map[string]any, error) {
endpoint := strings.TrimSpace(provider.Endpoint)
if endpoint == "" {
return nil, fmt.Errorf("external provider endpoint is missing")
}
forwardParams := sanitizeExternalACPParams(method, params)
return requestExternalACP(
ctx,
endpoint,
provider.AuthorizationHeader,
method,
forwardParams,
notify,
)
}
func sanitizeExternalACPParams(method string, params map[string]any) map[string]any {
if len(params) == 0 {
return map[string]any{}
}
next := make(map[string]any, len(params))
for key, value := range params {
next[key] = value
}
// Internal routing/runtime fields must not leak into external provider payloads.
delete(next, "metadata")
delete(next, "resolvedExecutionTarget")
delete(next, "resolvedEndpointTarget")
delete(next, "resolvedProviderId")
delete(next, "resolvedModel")
delete(next, "resolvedSkills")
delete(next, externalProviderEndpointKey)
delete(next, externalProviderAuthorizationHeaderKey)
delete(next, externalProviderLabelKey)
// Gateway-only fields are irrelevant in ACP single-agent forwarding.
normalizedMethod := strings.TrimSpace(method)
if normalizedMethod == "session.start" || normalizedMethod == "session.message" {
delete(next, "executionTarget")
delete(next, "agentId")
}
return next
}
func externalProviderFromParams(params map[string]any) (syncedProvider, bool) {
endpoint := strings.TrimSpace(shared.StringArg(params, externalProviderEndpointKey, ""))
if endpoint == "" {
return syncedProvider{}, false
}
return syncedProvider{
ProviderID: strings.TrimSpace(shared.StringArg(params, "provider", "")),
Label: strings.TrimSpace(shared.StringArg(params, externalProviderLabelKey, "")),
Endpoint: endpoint,
AuthorizationHeader: strings.TrimSpace(shared.StringArg(params, externalProviderAuthorizationHeaderKey, "")),
Enabled: true,
}, true
}
func requestExternalACP(
ctx context.Context,
endpoint,
authorization,
method string,
params map[string]any,
notify func(map[string]any),
) (map[string]any, error) {
parsed, err := httpOrWebsocketEndpoint(endpoint)
if err != nil {
return nil, err
}
switch parsed.Scheme {
case "http", "https":
return requestExternalACPHTTP(ctx, parsed, authorization, method, params)
default:
return requestExternalACPWebSocket(ctx, parsed, authorization, method, params, notify)
}
}
func requestExternalACPHTTP(
ctx context.Context,
endpoint *urlSpec,
authorization,
method string,
params map[string]any,
) (map[string]any, error) {
requestBody, _ := json.Marshal(map[string]any{
"jsonrpc": "2.0",
"id": fmt.Sprintf("req-%d", time.Now().UnixNano()),
"method": method,
"params": params,
})
req, err := http.NewRequestWithContext(
ctx,
http.MethodPost,
endpoint.httpRPCEndpoint(),
strings.NewReader(string(requestBody)),
)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json; charset=utf-8")
req.Header.Set("Accept", "application/json")
if strings.TrimSpace(authorization) != "" {
req.Header.Set("Authorization", strings.TrimSpace(authorization))
}
response, err := (&http.Client{Timeout: 2 * time.Minute}).Do(req)
if err != nil {
return nil, err
}
defer response.Body.Close()
var decoded map[string]any
if err := json.NewDecoder(response.Body).Decode(&decoded); err != nil {
return nil, err
}
if errPayload := asMap(decoded["error"]); len(errPayload) > 0 {
return nil, fmt.Errorf(
"%s",
strings.TrimSpace(shared.StringArg(errPayload, "message", "external ACP request failed")),
)
}
return decoded, nil
}
func requestExternalACPWebSocket(
ctx context.Context,
endpoint *urlSpec,
authorization,
method string,
params map[string]any,
notify func(map[string]any),
) (map[string]any, error) {
headers := http.Header{}
if strings.TrimSpace(authorization) != "" {
headers.Set("Authorization", strings.TrimSpace(authorization))
}
conn, _, err := websocket.DefaultDialer.DialContext(
ctx,
endpoint.webSocketEndpoint(),
headers,
)
if err != nil {
return nil, err
}
defer conn.Close()
requestID := fmt.Sprintf("req-%d", time.Now().UnixNano())
if err := conn.WriteJSON(map[string]any{
"jsonrpc": "2.0",
"id": requestID,
"method": method,
"params": params,
}); err != nil {
return nil, err
}
for {
if err := conn.SetReadDeadline(time.Now().Add(2 * time.Minute)); err != nil {
return nil, err
}
var payload map[string]any
if err := conn.ReadJSON(&payload); err != nil {
return nil, err
}
if strings.TrimSpace(shared.StringArg(payload, "id", "")) == requestID &&
(payload["result"] != nil || payload["error"] != nil) {
if errPayload := asMap(payload["error"]); len(errPayload) > 0 {
return nil, fmt.Errorf(
"%s",
strings.TrimSpace(shared.StringArg(errPayload, "message", "external ACP request failed")),
)
}
return payload, nil
}
if notify != nil && strings.TrimSpace(shared.StringArg(payload, "method", "")) != "" {
notify(payload)
}
}
}
type urlSpec struct {
Scheme string
Host string
Port string
Path string
}
func httpOrWebsocketEndpoint(raw string) (*urlSpec, error) {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return nil, fmt.Errorf("missing external ACP endpoint")
}
parsed, err := url.ParseRequestURI(trimmed)
if err != nil {
return nil, err
}
scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme))
if scheme != "http" && scheme != "https" && scheme != "ws" && scheme != "wss" {
return nil, fmt.Errorf("unsupported external ACP scheme: %s", scheme)
}
return &urlSpec{
Scheme: scheme,
Host: parsed.Host,
Path: strings.TrimRight(parsed.Path, "/"),
}, nil
}
func (u *urlSpec) basePath() string {
path := strings.TrimSpace(u.Path)
if path == "" || path == "/" {
return ""
}
if strings.HasSuffix(path, "/acp/rpc") {
path = strings.TrimSuffix(path, "/acp/rpc")
} else if strings.HasSuffix(path, "/acp") {
path = strings.TrimSuffix(path, "/acp")
}
path = strings.TrimRight(path, "/")
if path == "" || path == "/" {
return ""
}
if !strings.HasPrefix(path, "/") {
return "/" + path
}
return path
}
func (u *urlSpec) httpRPCEndpoint() string {
scheme := u.Scheme
if scheme == "ws" {
scheme = "http"
} else if scheme == "wss" {
scheme = "https"
}
basePath := u.basePath()
if basePath == "" {
basePath = "/acp/rpc"
} else {
basePath += "/acp/rpc"
}
return fmt.Sprintf("%s://%s%s", scheme, u.Host, basePath)
}
func (u *urlSpec) webSocketEndpoint() string {
scheme := u.Scheme
if scheme == "http" {
scheme = "ws"
} else if scheme == "https" {
scheme = "wss"
}
basePath := u.basePath()
if basePath == "" {
basePath = "/acp"
} else {
basePath += "/acp"
}
return fmt.Sprintf("%s://%s%s", scheme, u.Host, basePath)
}

View File

@ -0,0 +1,159 @@
package acp
import (
"strings"
"time"
"xworkmate-bridge/internal/gatewayruntime"
"xworkmate-bridge/internal/shared"
)
func handleGatewayConnect(
server *Server,
params map[string]any,
notify func(map[string]any),
) map[string]any {
request := gatewayruntime.ConnectRequest{
RuntimeID: strings.TrimSpace(shared.StringArg(params, "runtimeId", "")),
Mode: strings.TrimSpace(shared.StringArg(params, "mode", "unconfigured")),
ClientID: strings.TrimSpace(shared.StringArg(params, "clientId", "")),
Locale: strings.TrimSpace(shared.StringArg(params, "locale", "")),
UserAgent: strings.TrimSpace(shared.StringArg(params, "userAgent", "")),
ConnectAuthMode: strings.TrimSpace(shared.StringArg(params, "connectAuthMode", "")),
ConnectAuthFields: parseGatewayRuntimeStringSlice(params["connectAuthFields"]),
ConnectAuthSources: parseGatewayRuntimeStringSlice(params["connectAuthSources"]),
HasSharedAuth: parseBool(params["hasSharedAuth"]),
HasDeviceToken: parseBool(params["hasDeviceToken"]),
Endpoint: gatewayruntime.Endpoint{
Host: strings.TrimSpace(shared.StringArg(asMap(params["endpoint"]), "host", "")),
Port: parsePositiveInt(asMap(params["endpoint"])["port"]),
TLS: parseBool(asMap(params["endpoint"])["tls"]),
},
PackageInfo: gatewayruntime.PackageInfo{
AppName: strings.TrimSpace(shared.StringArg(asMap(params["packageInfo"]), "appName", "")),
PackageName: strings.TrimSpace(shared.StringArg(asMap(params["packageInfo"]), "packageName", "")),
Version: strings.TrimSpace(shared.StringArg(asMap(params["packageInfo"]), "version", "")),
BuildNumber: strings.TrimSpace(shared.StringArg(asMap(params["packageInfo"]), "buildNumber", "")),
},
DeviceInfo: gatewayruntime.DeviceInfo{
Platform: strings.TrimSpace(shared.StringArg(asMap(params["deviceInfo"]), "platform", "")),
PlatformVersion: strings.TrimSpace(shared.StringArg(asMap(params["deviceInfo"]), "platformVersion", "")),
DeviceFamily: strings.TrimSpace(shared.StringArg(asMap(params["deviceInfo"]), "deviceFamily", "")),
ModelIdentifier: strings.TrimSpace(shared.StringArg(asMap(params["deviceInfo"]), "modelIdentifier", "")),
},
Identity: gatewayruntime.DeviceIdentity{
DeviceID: strings.TrimSpace(shared.StringArg(asMap(params["identity"]), "deviceId", "")),
PublicKeyBase64URL: strings.TrimSpace(shared.StringArg(asMap(params["identity"]), "publicKeyBase64Url", "")),
PrivateKeyBase64URL: strings.TrimSpace(shared.StringArg(asMap(params["identity"]), "privateKeyBase64Url", "")),
},
Auth: gatewayruntime.AuthConfig{
Token: strings.TrimSpace(shared.StringArg(asMap(params["auth"]), "token", "")),
DeviceToken: strings.TrimSpace(shared.StringArg(asMap(params["auth"]), "deviceToken", "")),
Password: strings.TrimSpace(shared.StringArg(asMap(params["auth"]), "password", "")),
},
}
result := server.gateway.Connect(request, notify)
return map[string]any{
"ok": result.OK,
"snapshot": result.Snapshot,
"auth": result.Auth,
"returnedDeviceToken": result.ReturnedDeviceToken,
"error": result.Error,
}
}
func handleGatewayRequest(
server *Server,
params map[string]any,
notify func(map[string]any),
) map[string]any {
timeout := time.Duration(parsePositiveInt(params["timeoutMs"])) * time.Millisecond
result := server.gateway.Request(
strings.TrimSpace(shared.StringArg(params, "runtimeId", "")),
strings.TrimSpace(shared.StringArg(params, "method", "")),
asMap(params["params"]),
timeout,
notify,
)
return map[string]any{
"ok": result.OK,
"payload": result.Payload,
"error": result.Error,
}
}
func handleGatewayDisconnect(
server *Server,
params map[string]any,
notify func(map[string]any),
) map[string]any {
server.gateway.Disconnect(
strings.TrimSpace(shared.StringArg(params, "runtimeId", "")),
notify,
)
return map[string]any{"accepted": true}
}
func asMap(value any) map[string]any {
if typed, ok := value.(map[string]any); ok {
return typed
}
if typed, ok := value.(map[string]interface{}); ok {
return typed
}
return map[string]any{}
}
func parseGatewayRuntimeStringSlice(value any) []string {
list, ok := value.([]any)
if !ok {
if typed, ok := value.([]string); ok {
return append([]string(nil), typed...)
}
return nil
}
result := make([]string, 0, len(list))
for _, item := range list {
text := strings.TrimSpace(shared.StringArg(map[string]any{"value": item}, "value", ""))
if text == "" {
continue
}
result = append(result, text)
}
return result
}
func parseBool(value any) bool {
switch typed := value.(type) {
case bool:
return typed
case string:
return shared.BoolArg(typed, false)
case float64:
return typed != 0
case int:
return typed != 0
default:
return false
}
}
func parsePositiveInt(value any) int {
switch typed := value.(type) {
case int:
if typed > 0 {
return typed
}
case int64:
if typed > 0 {
return int(typed)
}
case float64:
if typed > 0 {
return int(typed)
}
case string:
return shared.IntArg(typed, 0)
}
return 0
}

View File

@ -0,0 +1,94 @@
package acp
import (
"sort"
"strings"
)
type syncedProvider struct {
ProviderID string
Label string
Endpoint string
AuthorizationHeader string
Enabled bool
}
func parseSyncedProviders(raw any) []syncedProvider {
list, ok := raw.([]any)
if !ok {
return nil
}
providers := make([]syncedProvider, 0, len(list))
for _, item := range list {
entry := asMap(item)
providerID := strings.TrimSpace(sharedString(entry, "providerId"))
if providerID == "" {
continue
}
providers = append(providers, syncedProvider{
ProviderID: providerID,
Label: strings.TrimSpace(sharedString(entry, "label")),
Endpoint: strings.TrimSpace(sharedString(entry, "endpoint")),
AuthorizationHeader: strings.TrimSpace(sharedString(entry, "authorizationHeader")),
Enabled: parseBool(entry["enabled"]),
})
}
return providers
}
func (s *Server) syncProviders(providers []syncedProvider) map[string]any {
s.mu.Lock()
defer s.mu.Unlock()
s.providerCatalog = make(map[string]syncedProvider, len(providers))
for _, provider := range providers {
if strings.TrimSpace(provider.ProviderID) == "" {
continue
}
s.providerCatalog[provider.ProviderID] = provider
}
return map[string]any{
"ok": true,
"providers": syncedProvidersResult(providers),
}
}
func (s *Server) syncedProviderByID(providerID string) (syncedProvider, bool) {
s.mu.Lock()
defer s.mu.Unlock()
provider, ok := s.providerCatalog[strings.TrimSpace(providerID)]
if !ok || !provider.Enabled || strings.TrimSpace(provider.Endpoint) == "" {
return syncedProvider{}, false
}
return provider, true
}
func (s *Server) availableProviders() []string {
providers := make(map[string]struct{})
s.mu.Lock()
for _, provider := range s.providerCatalog {
if !provider.Enabled || strings.TrimSpace(provider.Endpoint) == "" {
continue
}
providers[provider.ProviderID] = struct{}{}
}
s.mu.Unlock()
ordered := make([]string, 0, len(providers))
for providerID := range providers {
ordered = append(ordered, providerID)
}
sort.Strings(ordered)
return ordered
}
func syncedProvidersResult(providers []syncedProvider) []map[string]any {
result := make([]map[string]any, 0, len(providers))
for _, provider := range providers {
result = append(result, map[string]any{
"providerId": provider.ProviderID,
"label": provider.Label,
"endpoint": provider.Endpoint,
"enabled": provider.Enabled,
})
}
return result
}

View File

@ -0,0 +1,209 @@
package acp
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"testing"
"xworkmate-bridge/internal/shared"
)
func TestCapabilitiesIgnoreLocalProviderAutodetectUntilSync(t *testing.T) {
fakeProvider := t.TempDir() + "/fake-claude"
if err := os.WriteFile(fakeProvider, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
t.Fatalf("write fake provider: %v", err)
}
t.Setenv("ACP_CLAUDE_BIN", fakeProvider)
server := NewServer()
result, rpcErr := server.handleRequest(shared.RPCRequest{
Method: "acp.capabilities",
Params: map[string]any{},
}, func(map[string]any) {})
if rpcErr != nil {
t.Fatalf("expected capabilities success, got %v", rpcErr)
}
providers, _ := result["providers"].([]string)
if len(providers) != 0 {
t.Fatalf("expected no providers before sync, got %#v", providers)
}
}
func TestProvidersSyncUpdatesCapabilities(t *testing.T) {
server := NewServer()
_, rpcErr := server.handleRequest(shared.RPCRequest{
Method: "xworkmate.providers.sync",
Params: map[string]any{
"providers": []any{
map[string]any{
"providerId": "claude",
"label": "Claude",
"endpoint": "http://127.0.0.1:9999",
"authorizationHeader": "Bearer test",
"enabled": true,
},
},
},
}, func(map[string]any) {})
if rpcErr != nil {
t.Fatalf("expected sync success, got %v", rpcErr)
}
result, rpcErr := server.handleRequest(shared.RPCRequest{
Method: "acp.capabilities",
Params: map[string]any{},
}, func(map[string]any) {})
if rpcErr != nil {
t.Fatalf("expected capabilities success, got %v", rpcErr)
}
providers, _ := result["providers"].([]string)
if len(providers) == 0 {
t.Fatalf("expected synced provider in capabilities, got %#v", result)
}
found := false
for _, provider := range providers {
if provider == "claude" {
found = true
break
}
}
if !found {
t.Fatalf("expected claude provider after sync, got %#v", providers)
}
}
func TestExecuteSessionTaskUsesSyncedExternalProvider(t *testing.T) {
var lastForwardedParams map[string]any
externalServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/acp/rpc" {
http.NotFound(w, r)
return
}
defer r.Body.Close()
var request map[string]any
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
t.Fatalf("decode request: %v", err)
}
lastForwardedParams = asMap(request["params"])
method, _ := request["method"].(string)
switch method {
case "session.start":
_ = json.NewEncoder(w).Encode(map[string]any{
"jsonrpc": "2.0",
"id": request["id"],
"result": map[string]any{
"success": true,
"output": "external-provider-ok",
"turnId": "turn-external",
"provider": "claude",
"mode": "single-agent",
},
})
default:
_ = json.NewEncoder(w).Encode(map[string]any{
"jsonrpc": "2.0",
"id": request["id"],
"result": map[string]any{"ok": true},
})
}
}))
defer externalServer.Close()
server := NewServer()
server.syncProviders([]syncedProvider{
{
ProviderID: "claude",
Label: "Claude",
Endpoint: externalServer.URL,
AuthorizationHeader: "Bearer test",
Enabled: true,
},
})
response, rpcErr := server.executeSessionTask(task{
req: shared.RPCRequest{
Method: "session.start",
Params: map[string]any{
"sessionId": "session-external",
"threadId": "thread-external",
"taskPrompt": "hello from external provider",
"workingDirectory": t.TempDir(),
"routing": map[string]any{
"routingMode": "explicit",
"explicitExecutionTarget": "singleAgent",
"explicitProviderId": "claude",
},
},
},
})
if rpcErr != nil {
t.Fatalf("expected success, got rpc error: %v", rpcErr)
}
if got := response["output"]; got != "external-provider-ok" {
t.Fatalf("expected external provider output, got %#v", response)
}
if got := response["resolvedProviderId"]; got != "claude" {
t.Fatalf("expected resolved provider claude, got %#v", response)
}
if _, exists := lastForwardedParams["metadata"]; exists {
t.Fatalf("expected metadata to be stripped for external provider request, got %#v", lastForwardedParams)
}
if _, exists := lastForwardedParams[externalProviderEndpointKey]; exists {
t.Fatalf("expected internal endpoint key to be stripped, got %#v", lastForwardedParams)
}
}
func TestRunSingleAgentUsesFrozenExternalProviderParams(t *testing.T) {
externalServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/acp/rpc" {
http.NotFound(w, r)
return
}
defer r.Body.Close()
var request map[string]any
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
t.Fatalf("decode request: %v", err)
}
_ = json.NewEncoder(w).Encode(map[string]any{
"jsonrpc": "2.0",
"id": request["id"],
"result": map[string]any{
"success": true,
"output": "frozen-provider-ok",
"turnId": "turn-frozen",
"provider": "custom-agent-1",
"mode": "single-agent",
},
})
}))
defer externalServer.Close()
server := NewServer()
session := server.getOrCreateSession("session-frozen", "thread-frozen")
result := server.runSingleAgent(
context.Background(),
"session.start",
session,
map[string]any{
"provider": "custom-agent-1",
"taskPrompt": "hello",
"workingDirectory": t.TempDir(),
externalProviderEndpointKey: externalServer.URL,
externalProviderAuthorizationHeaderKey: "Bearer test",
externalProviderLabelKey: "Codex",
},
"turn-frozen",
func(map[string]any) {},
)
if result.err != nil {
t.Fatalf("expected success, got rpc error: %v", result.err)
}
if got := result.response["output"]; got != "frozen-provider-ok" {
t.Fatalf("expected frozen provider output, got %#v", result.response)
}
}

196
internal/acp/routing.go Normal file
View File

@ -0,0 +1,196 @@
package acp
import (
"fmt"
"os"
"strings"
"xworkmate-bridge/internal/memory"
"xworkmate-bridge/internal/router"
"xworkmate-bridge/internal/skills"
)
func handleRoutingResolve(params map[string]any) map[string]any {
result, _ := resolveRoutingMetadataWithProviders(params, nil)
return mergeRoutingResponse(map[string]any{"ok": true}, result)
}
func resolveRoutingMetadata(params map[string]any) (router.Result, bool) {
return resolveRoutingMetadataWithProviders(params, nil)
}
func resolveRoutingMetadataWithProviders(
params map[string]any,
availableProviders []string,
) (router.Result, bool) {
routingParams := asMap(params["routing"])
if len(routingParams) == 0 {
return router.Result{}, false
}
installApproval := asMap(routingParams["installApproval"])
resolver := router.NewResolver()
result := resolver.Resolve(router.Request{
Prompt: strings.TrimSpace(sharedString(params, "taskPrompt")),
WorkingDirectory: strings.TrimSpace(sharedString(params, "workingDirectory")),
RoutingMode: strings.TrimSpace(sharedString(routingParams, "routingMode")),
PreferredGatewayTarget: strings.TrimSpace(sharedString(routingParams, "preferredGatewayTarget")),
ExplicitExecutionTarget: strings.TrimSpace(sharedString(routingParams, "explicitExecutionTarget")),
ExplicitProviderID: strings.TrimSpace(sharedString(routingParams, "explicitProviderId")),
ExplicitModel: strings.TrimSpace(sharedString(routingParams, "explicitModel")),
ExplicitSkills: parseRoutingStringSlice(routingParams["explicitSkills"]),
AllowSkillInstall: parseBool(routingParams["allowSkillInstall"]),
InstallApproval: skills.InstallApproval{
RequestID: strings.TrimSpace(sharedString(installApproval, "requestId")),
ApprovedSkillKeys: parseRoutingStringSlice(installApproval["approvedSkillKeys"]),
},
AvailableSkills: parseRoutingSkillCandidates(routingParams["availableSkills"]),
AvailableProviders: append([]string(nil), availableProviders...),
AIGatewayBaseURL: strings.TrimSpace(sharedString(params, "aiGatewayBaseUrl")),
AIGatewayAPIKey: strings.TrimSpace(sharedString(params, "aiGatewayApiKey")),
})
return result, true
}
func mergeRoutingResponse(response map[string]any, result router.Result) map[string]any {
if response == nil {
response = map[string]any{}
}
response["resolvedExecutionTarget"] = result.ResolvedExecutionTarget
response["resolvedEndpointTarget"] = result.ResolvedEndpointTarget
response["resolvedProviderId"] = result.ResolvedProviderID
response["resolvedModel"] = result.ResolvedModel
response["resolvedSkills"] = append([]string(nil), result.ResolvedSkills...)
response["skillResolutionSource"] = result.SkillResolutionSource
response["needsSkillInstall"] = result.NeedsSkillInstall
response["unavailable"] = result.Unavailable
if strings.TrimSpace(result.UnavailableCode) != "" {
response["unavailableCode"] = result.UnavailableCode
}
if strings.TrimSpace(result.UnavailableMessage) != "" {
response["unavailableMessage"] = result.UnavailableMessage
}
if strings.TrimSpace(result.SkillInstallRequestID) != "" {
response["skillInstallRequestId"] = result.SkillInstallRequestID
}
if len(result.SkillCandidates) > 0 {
response["skillCandidates"] = routingSkillCandidatesMap(result.SkillCandidates)
}
if len(result.MemorySources) > 0 {
response["memorySources"] = routingMemorySourcesMap(result.MemorySources)
}
return response
}
func recordRoutingSuccess(
params map[string]any,
result router.Result,
response map[string]any,
) error {
routingParams := asMap(params["routing"])
if len(routingParams) == 0 {
return nil
}
if strings.EqualFold(
strings.TrimSpace(sharedString(routingParams, "routingMode")),
router.RoutingModeExplicit,
) {
return nil
}
if !parseBool(response["success"]) {
return nil
}
workingDirectory := strings.TrimSpace(sharedString(params, "workingDirectory"))
if workingDirectory == "" {
return nil
}
homeDir, _ := os.UserHomeDir()
service := memory.NewService(homeDir)
return service.RecordSuccess(workingDirectory, memory.SuccessEntry{
ResolvedExecutionTarget: result.ResolvedExecutionTarget,
ResolvedProviderID: result.ResolvedProviderID,
ResolvedModel: result.ResolvedModel,
ResolvedSkills: append([]string(nil), result.ResolvedSkills...),
})
}
func parseRoutingSkillCandidates(raw any) []skills.Candidate {
list, ok := raw.([]any)
if !ok {
return nil
}
candidates := make([]skills.Candidate, 0, len(list))
for _, item := range list {
entry := asMap(item)
candidates = append(candidates, skills.Candidate{
ID: strings.TrimSpace(sharedString(entry, "id")),
Label: strings.TrimSpace(sharedString(entry, "label")),
Description: strings.TrimSpace(sharedString(entry, "description")),
Installed: parseBool(entry["installed"]),
})
}
return candidates
}
func routingSkillCandidatesMap(candidates []skills.Candidate) []map[string]any {
result := make([]map[string]any, 0, len(candidates))
for _, candidate := range candidates {
result = append(result, map[string]any{
"id": candidate.ID,
"label": candidate.Label,
"description": candidate.Description,
"installed": candidate.Installed,
})
}
return result
}
func routingMemorySourcesMap(sources []memory.Source) []map[string]any {
result := make([]map[string]any, 0, len(sources))
for _, source := range sources {
result = append(result, map[string]any{
"path": source.Path,
"scope": source.Scope,
})
}
return result
}
func parseRoutingStringSlice(raw any) []string {
list, ok := raw.([]any)
if !ok {
if typed, ok := raw.([]string); ok {
return append([]string(nil), typed...)
}
return nil
}
values := make([]string, 0, len(list))
for _, item := range list {
value := strings.TrimSpace(sharedStringArg(item))
if value == "" {
continue
}
values = append(values, value)
}
return values
}
func sharedString(params map[string]any, key string) string {
if params == nil {
return ""
}
return strings.TrimSpace(sharedStringArg(params[key]))
}
func sharedStringArg(value any) string {
if value == nil {
return ""
}
switch typed := value.(type) {
case string:
return typed
default:
return fmt.Sprint(value)
}
}

View File

@ -0,0 +1,472 @@
package acp
import (
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"xworkmate-bridge/internal/shared"
)
func newExternalSingleAgentProvider(
t *testing.T,
providerID string,
output string,
) *httptest.Server {
t.Helper()
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/acp/rpc" {
http.NotFound(w, r)
return
}
defer r.Body.Close()
var request map[string]any
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
t.Fatalf("decode request: %v", err)
}
_ = json.NewEncoder(w).Encode(map[string]any{
"jsonrpc": "2.0",
"id": request["id"],
"result": map[string]any{
"success": true,
"output": output,
"turnId": "turn-" + providerID,
"provider": providerID,
"mode": "single-agent",
},
})
}))
}
func TestHandleRoutingResolveCoversNineScenarioBuckets(t *testing.T) {
localAvailableSkills := []map[string]any{
{"id": "pptx", "label": "PPTX", "description": "slides", "installed": true},
{"id": "docx", "label": "DOCX", "description": "docs", "installed": true},
{"id": "xlsx", "label": "XLSX", "description": "sheets", "installed": true},
{"id": "pdf", "label": "PDF", "description": "pdf", "installed": true},
{"id": "image-resizer", "label": "image-resizer", "description": "image resize", "installed": true},
{"id": "browser-automation", "label": "Browser Automation", "description": "browser", "installed": true},
}
cases := []struct {
name string
prompt string
expectedExecutionTarget string
expectedSkillSource string
expectedResolvedSkill string
expectedNeedsSkillInstall bool
}{
{
name: "powerpoint-pptx",
prompt: "create a powerpoint deck for this launch",
expectedExecutionTarget: "single-agent",
expectedSkillSource: "local_match",
expectedResolvedSkill: "PPTX",
},
{
name: "word-docx",
prompt: "draft a word document memo",
expectedExecutionTarget: "single-agent",
expectedSkillSource: "local_match",
expectedResolvedSkill: "DOCX",
},
{
name: "excel-xlsx",
prompt: "build an excel workbook with formulas",
expectedExecutionTarget: "single-agent",
expectedSkillSource: "local_match",
expectedResolvedSkill: "XLSX",
},
{
name: "pdf",
prompt: "merge and fill this pdf form",
expectedExecutionTarget: "single-agent",
expectedSkillSource: "local_match",
expectedResolvedSkill: "PDF",
},
{
name: "image-resizer",
prompt: "batch resize image assets",
expectedExecutionTarget: "single-agent",
expectedSkillSource: "local_match",
expectedResolvedSkill: "image-resizer",
},
{
name: "image-cog",
prompt: "use image-cog to generate consistent characters",
expectedExecutionTarget: "gateway",
expectedSkillSource: "find_skills",
expectedNeedsSkillInstall: true,
},
{
name: "image-video-generation-editting",
prompt: "wan 图生视频并做视频编辑",
expectedExecutionTarget: "gateway",
expectedSkillSource: "find_skills",
expectedNeedsSkillInstall: true,
},
{
name: "video-translator",
prompt: "translate video subtitles and dub the clip",
expectedExecutionTarget: "gateway",
expectedSkillSource: "find_skills",
expectedNeedsSkillInstall: true,
},
{
name: "browser-search-news",
prompt: "跨浏览器执行并搜索最新资讯采集结果",
expectedExecutionTarget: "gateway",
expectedSkillSource: "local_match",
expectedResolvedSkill: "Browser Automation",
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
result := handleRoutingResolve(map[string]any{
"taskPrompt": tc.prompt,
"workingDirectory": "/tmp/workspace",
"routing": map[string]any{
"routingMode": "auto",
"preferredGatewayTarget": "local",
"allowSkillInstall": false,
"availableSkills": func() []any {
values := make([]any, 0, len(localAvailableSkills))
for _, item := range localAvailableSkills {
values = append(values, item)
}
return values
}(),
},
})
if got := result["resolvedExecutionTarget"]; got != tc.expectedExecutionTarget {
t.Fatalf("expected execution target %q, got %#v", tc.expectedExecutionTarget, got)
}
if got := result["skillResolutionSource"]; got != tc.expectedSkillSource {
t.Fatalf("expected skill source %q, got %#v", tc.expectedSkillSource, got)
}
if tc.expectedResolvedSkill != "" {
resolvedSkills, _ := result["resolvedSkills"].([]string)
if len(resolvedSkills) == 0 || resolvedSkills[0] != tc.expectedResolvedSkill {
t.Fatalf("expected resolved skill %q, got %#v", tc.expectedResolvedSkill, result["resolvedSkills"])
}
}
if got := result["needsSkillInstall"]; got != tc.expectedNeedsSkillInstall {
t.Fatalf("expected needsSkillInstall=%v, got %#v", tc.expectedNeedsSkillInstall, got)
}
})
}
}
func TestExecuteSessionTaskAutoRoutingRecordsProjectMemory(t *testing.T) {
homeDir := t.TempDir()
workspaceDir := filepath.Join(t.TempDir(), "workspace")
if err := os.MkdirAll(workspaceDir, 0o755); err != nil {
t.Fatalf("create workspace: %v", err)
}
t.Setenv("HOME", homeDir)
server := NewServer()
providerServer := newExternalSingleAgentProvider(t, "claude", "done")
defer providerServer.Close()
server.syncProviders([]syncedProvider{{
ProviderID: "claude",
Label: "Claude",
Endpoint: providerServer.URL,
Enabled: true,
}})
response, rpcErr := server.executeSessionTask(task{
req: shared.RPCRequest{
Params: map[string]any{
"sessionId": "session-auto",
"threadId": "thread-auto",
"provider": "claude",
"taskPrompt": "create a powerpoint deck for launch",
"workingDirectory": workspaceDir,
"routing": map[string]any{
"routingMode": "auto",
"preferredGatewayTarget": "local",
"availableSkills": []any{
map[string]any{
"id": "pptx",
"label": "PPTX",
"description": "slides",
"installed": true,
},
},
},
},
},
})
if rpcErr != nil {
t.Fatalf("expected success, got rpc error: %v", rpcErr)
}
if success, _ := response["success"].(bool); !success {
t.Fatalf("expected success response, got %#v", response)
}
projectLocalMemory := filepath.Join(workspaceDir, ".xworkmate", "memory.md")
content, err := os.ReadFile(projectLocalMemory)
if err != nil {
t.Fatalf("expected memory file %s: %v", projectLocalMemory, err)
}
text := string(content)
if !strings.Contains(text, "preferred-route: single-agent") {
t.Fatalf("expected preferred route in %s, got %q", projectLocalMemory, text)
}
if !strings.Contains(text, "preferred-skills: PPTX") {
t.Fatalf("expected preferred skills in %s, got %q", projectLocalMemory, text)
}
projectHomeMemory := filepath.Join(
homeDir,
"self-improving",
"projects",
filepath.Base(workspaceDir)+".md",
)
if _, err := os.Stat(projectHomeMemory); !os.IsNotExist(err) {
t.Fatalf("expected auto memory write to stay project-local only, got stat err=%v", err)
}
}
func TestExecuteSessionTaskExplicitRoutingDoesNotRecordProjectMemory(t *testing.T) {
homeDir := t.TempDir()
workspaceDir := filepath.Join(t.TempDir(), "workspace")
if err := os.MkdirAll(workspaceDir, 0o755); err != nil {
t.Fatalf("create workspace: %v", err)
}
t.Setenv("HOME", homeDir)
server := NewServer()
providerServer := newExternalSingleAgentProvider(t, "claude", "done")
defer providerServer.Close()
server.syncProviders([]syncedProvider{{
ProviderID: "claude",
Label: "Claude",
Endpoint: providerServer.URL,
Enabled: true,
}})
response, rpcErr := server.executeSessionTask(task{
req: shared.RPCRequest{
Params: map[string]any{
"sessionId": "session-explicit",
"threadId": "thread-explicit",
"provider": "claude",
"taskPrompt": "create a powerpoint deck for launch",
"workingDirectory": workspaceDir,
"routing": map[string]any{
"routingMode": "explicit",
"explicitExecutionTarget": "singleAgent",
"explicitProviderId": "claude",
"availableSkills": []any{
map[string]any{
"id": "pptx",
"label": "PPTX",
"description": "slides",
"installed": true,
},
},
},
},
},
})
if rpcErr != nil {
t.Fatalf("expected success, got rpc error: %v", rpcErr)
}
if success, _ := response["success"].(bool); !success {
t.Fatalf("expected success response, got %#v", response)
}
projectHomeMemory := filepath.Join(
homeDir,
"self-improving",
"projects",
filepath.Base(workspaceDir)+".md",
)
projectLocalMemory := filepath.Join(workspaceDir, ".xworkmate", "memory.md")
for _, target := range []string{projectHomeMemory, projectLocalMemory} {
if _, err := os.Stat(target); !os.IsNotExist(err) {
t.Fatalf("expected no memory write for explicit routing at %s, err=%v", target, err)
}
}
}
func TestExecuteSessionTaskExplicitProviderRequiresSyncedCatalog(t *testing.T) {
server := NewServer()
response, rpcErr := server.executeSessionTask(task{
req: shared.RPCRequest{
Method: "session.start",
Params: map[string]any{
"sessionId": "session-explicit-provider",
"threadId": "thread-explicit-provider",
"taskPrompt": "create a powerpoint deck for launch",
"routing": map[string]any{
"routingMode": "explicit",
"explicitExecutionTarget": "singleAgent",
"explicitProviderId": "claude",
},
},
},
})
if rpcErr != nil {
t.Fatalf("expected structured unavailable response, got rpc error: %v", rpcErr)
}
if got := response["unavailable"]; got != true {
t.Fatalf("expected unavailable response, got %#v", response)
}
if got := response["unavailableCode"]; got != "PROVIDER_UNAVAILABLE" {
t.Fatalf("expected PROVIDER_UNAVAILABLE, got %#v", response)
}
}
func TestExecuteSessionTaskRequiresRouting(t *testing.T) {
server := NewServer()
_, rpcErr := server.executeSessionTask(task{
req: shared.RPCRequest{
ID: "request-1",
Method: "session.start",
Params: map[string]any{
"sessionId": "session-missing-routing",
"threadId": "thread-missing-routing",
"taskPrompt": "hello",
},
},
})
if rpcErr == nil {
t.Fatalf("expected routing-required error")
}
if rpcErr.Message != "ROUTING_REQUIRED" {
t.Fatalf("expected ROUTING_REQUIRED, got %#v", rpcErr)
}
}
func TestExecuteSessionTaskAutoRoutingPromotesComplexRequestToMultiAgent(t *testing.T) {
workspaceDir := filepath.Join(t.TempDir(), "workspace")
if err := os.MkdirAll(workspaceDir, 0o755); err != nil {
t.Fatalf("create workspace: %v", err)
}
aiGateway := httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"choices":[{"message":{"content":"planner output"}}]}`))
}),
)
defer aiGateway.Close()
server := NewServer()
response, rpcErr := server.executeSessionTask(task{
req: shared.RPCRequest{
Params: map[string]any{
"sessionId": "session-complex",
"threadId": "thread-complex",
"provider": "claude",
"taskPrompt": "collect latest news and summarize it into a report for review",
"workingDirectory": workspaceDir,
"aiGatewayBaseUrl": aiGateway.URL,
"aiGatewayApiKey": "test-key",
"routing": map[string]any{
"routingMode": "auto",
"preferredGatewayTarget": "local",
},
},
},
})
if rpcErr != nil {
t.Fatalf("expected success, got rpc error: %v", rpcErr)
}
if success, _ := response["success"].(bool); !success {
t.Fatalf("expected success response, got %#v", response)
}
if got := response["mode"]; got != "multi-agent" {
t.Fatalf("expected session mode to be promoted to multi-agent, got %#v", got)
}
if got := response["resolvedExecutionTarget"]; got != "multi-agent" {
t.Fatalf("expected resolved execution target multi-agent, got %#v", got)
}
}
func TestHandleRoutingResolveAllowsSkillInstallRetry(t *testing.T) {
tempDir := t.TempDir()
finder := filepath.Join(tempDir, "find-skills.sh")
installer := filepath.Join(tempDir, "install-skills.sh")
if err := os.WriteFile(
finder,
[]byte("#!/bin/sh\nprintf '%s' '{\"candidates\":[{\"id\":\"video-translator\",\"label\":\"video-translator\",\"description\":\"translate video\",\"installed\":false}]}'\n"),
0o755,
); err != nil {
t.Fatalf("write finder: %v", err)
}
if err := os.WriteFile(
installer,
[]byte("#!/bin/sh\nprintf '%s' '{\"candidates\":[{\"id\":\"video-translator\",\"label\":\"video-translator\",\"description\":\"translate video\",\"installed\":true}]}'\n"),
0o755,
); err != nil {
t.Fatalf("write installer: %v", err)
}
t.Setenv("ACP_FIND_SKILLS_BIN", finder)
t.Setenv("ACP_INSTALL_SKILL_BIN", installer)
result := handleRoutingResolve(map[string]any{
"taskPrompt": "translate and dub this video with subtitles",
"workingDirectory": "/tmp/workspace",
"routing": map[string]any{
"routingMode": "auto",
"allowSkillInstall": true,
"availableSkills": []any{
map[string]any{
"id": "docx",
"label": "docx",
"description": "docs",
"installed": true,
},
},
},
})
if got := result["skillResolutionSource"]; got != "find_skills" {
t.Fatalf("expected find_skills source, got %#v", got)
}
if got := result["needsSkillInstall"]; got != true {
t.Fatalf("expected first pass to request install approval, got %#v", got)
}
requestID, _ := result["skillInstallRequestId"].(string)
if strings.TrimSpace(requestID) == "" {
t.Fatalf("expected install request id, got %#v", result)
}
retried := handleRoutingResolve(map[string]any{
"taskPrompt": "translate and dub this video with subtitles",
"workingDirectory": "/tmp/workspace",
"routing": map[string]any{
"routingMode": "auto",
"allowSkillInstall": true,
"installApproval": map[string]any{
"requestId": requestID,
"approvedSkillKeys": []any{"video-translator"},
},
"availableSkills": []any{
map[string]any{
"id": "docx",
"label": "docx",
"description": "docs",
"installed": true,
},
},
},
})
if got := retried["needsSkillInstall"]; got != false {
t.Fatalf("expected install retry to clear needsSkillInstall, got %#v", got)
}
resolvedSkills, _ := retried["resolvedSkills"].([]string)
if len(resolvedSkills) != 1 || resolvedSkills[0] != "video-translator" {
t.Fatalf("expected installed skill to resolve, got %#v", retried["resolvedSkills"])
}
}

1024
internal/acp/server.go Normal file

File diff suppressed because it is too large Load Diff

95
internal/acp/stdio.go Normal file
View File

@ -0,0 +1,95 @@
package acp
import (
"bufio"
"encoding/json"
"errors"
"fmt"
"io"
"strings"
"sync"
"xworkmate-bridge/internal/shared"
)
func RunStdio(input io.Reader, output io.Writer) {
server := NewServer()
reader := bufio.NewReader(input)
var writeMu sync.Mutex
writeMessage := func(message map[string]any) {
payload, _ := jsonMarshal(message)
writeMu.Lock()
defer writeMu.Unlock()
_, _ = output.Write(append(payload, '\n'))
}
for {
payload, err := readStdioMessage(reader)
if err != nil {
if errors.Is(err, io.EOF) {
return
}
writeMessage(shared.ErrorEnvelope(nil, -32700, err.Error()))
continue
}
if len(strings.TrimSpace(string(payload))) == 0 {
continue
}
request, err := shared.DecodeRPCRequest(payload)
if err != nil {
writeMessage(shared.ErrorEnvelope(nil, -32700, err.Error()))
continue
}
response, rpcErr := server.handleRequest(request, writeMessage)
if request.ID == nil {
continue
}
if rpcErr != nil {
writeMessage(
shared.ErrorEnvelope(request.ID, rpcErr.Code, rpcErr.Message),
)
continue
}
writeMessage(shared.ResultEnvelope(request.ID, response))
}
}
func readStdioMessage(reader *bufio.Reader) ([]byte, error) {
line, err := reader.ReadString('\n')
if err != nil {
return nil, err
}
line = strings.TrimSpace(line)
if line == "" {
return nil, nil
}
if strings.HasPrefix(strings.ToLower(line), "content-length:") {
var contentLength int
if _, err := fmt.Sscanf(line, "Content-Length: %d", &contentLength); err != nil {
if _, err2 := fmt.Sscanf(line, "content-length: %d", &contentLength); err2 != nil {
return nil, fmt.Errorf("invalid content-length header")
}
}
for {
headerLine, err := reader.ReadString('\n')
if err != nil {
return nil, err
}
if strings.TrimSpace(headerLine) == "" {
break
}
}
body := make([]byte, contentLength)
if _, err := io.ReadFull(reader, body); err != nil {
return nil, err
}
return body, nil
}
return []byte(line), nil
}
func jsonMarshal(message map[string]any) ([]byte, error) {
return json.Marshal(message)
}

View File

@ -0,0 +1,78 @@
package acp
import (
"encoding/json"
"net/http"
"strings"
"xworkmate-bridge/internal/shared"
)
func (s *Server) allowedOrigins() []string {
raw := strings.TrimSpace(shared.EnvOrDefault(
"ACP_ALLOWED_ORIGINS",
"https://xworkmate.svc.plus,http://localhost:*,http://127.0.0.1:*",
))
if raw == "" {
return nil
}
parts := strings.Split(raw, ",")
origins := make([]string, 0, len(parts))
for _, part := range parts {
candidate := strings.TrimSpace(part)
if candidate == "" {
continue
}
origins = append(origins, candidate)
}
return origins
}
func (s *Server) originAllowed(origin string) bool {
origin = strings.TrimSpace(origin)
if origin == "" {
return true
}
for _, allowed := range s.allowedOrigins() {
if strings.HasSuffix(allowed, ":*") {
if strings.HasPrefix(origin, strings.TrimSuffix(allowed, "*")) {
return true
}
continue
}
if origin == allowed {
return true
}
}
return false
}
func (s *Server) applyCORS(w http.ResponseWriter, r *http.Request) {
origin := strings.TrimSpace(r.Header.Get("Origin"))
if origin == "" || !s.originAllowed(origin) {
return
}
headers := w.Header()
headers.Set("Access-Control-Allow-Origin", origin)
headers.Set("Access-Control-Allow-Methods", "POST, OPTIONS")
headers.Set(
"Access-Control-Allow-Headers",
"Authorization, Content-Type, Accept",
)
headers.Set("Access-Control-Max-Age", "600")
headers.Add("Vary", "Origin")
headers.Add("Vary", "Access-Control-Request-Method")
headers.Add("Vary", "Access-Control-Request-Headers")
}
func (s *Server) writeJSONError(
w http.ResponseWriter,
requestID any,
statusCode int,
code int,
message string,
) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(statusCode)
_ = json.NewEncoder(w).Encode(shared.ErrorEnvelope(requestID, code, message))
}

View File

@ -0,0 +1,111 @@
package acp
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestHandleWebSocketRejectsUnknownOrigin(t *testing.T) {
t.Setenv("ACP_ALLOWED_ORIGINS", "https://xworkmate.svc.plus")
server := NewServer()
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodGet, "http://127.0.0.1/acp", nil)
request.Header.Set("Origin", "https://evil.example.com")
server.HandleWebSocket(recorder, request)
if recorder.Code != http.StatusForbidden {
t.Fatalf("expected 403, got %d", recorder.Code)
}
if got := recorder.Header().Get("Content-Type"); !strings.Contains(got, "application/json") {
t.Fatalf("expected application/json content type, got %q", got)
}
}
func TestHandleRPCAllowsPreflightForConfiguredOrigin(t *testing.T) {
t.Setenv("ACP_ALLOWED_ORIGINS", "https://xworkmate.svc.plus,http://localhost:*")
server := NewServer()
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodOptions, "http://127.0.0.1/acp/rpc", nil)
request.Header.Set("Origin", "https://xworkmate.svc.plus")
request.Header.Set("Access-Control-Request-Method", "POST")
server.HandleRPC(recorder, request)
if recorder.Code != http.StatusNoContent {
t.Fatalf("expected 204, got %d", recorder.Code)
}
if got := recorder.Header().Get("Access-Control-Allow-Origin"); got != "https://xworkmate.svc.plus" {
t.Fatalf("expected allow origin header, got %q", got)
}
}
func TestHandleRPCRejectsUnknownOrigin(t *testing.T) {
t.Setenv("ACP_ALLOWED_ORIGINS", "https://xworkmate.svc.plus")
server := NewServer()
recorder := httptest.NewRecorder()
request := httptest.NewRequest(
http.MethodPost,
"http://127.0.0.1/acp/rpc",
strings.NewReader(`{"jsonrpc":"2.0","id":1,"method":"acp.capabilities"}`),
)
request.Header.Set("Origin", "https://evil.example.com")
request.Header.Set("Content-Type", "application/json")
server.HandleRPC(recorder, request)
if recorder.Code != http.StatusForbidden {
t.Fatalf("expected 403, got %d", recorder.Code)
}
var envelope map[string]any
if err := json.Unmarshal(recorder.Body.Bytes(), &envelope); err != nil {
t.Fatalf("decode error envelope: %v", err)
}
if _, ok := envelope["error"]; !ok {
t.Fatalf("expected JSON-RPC error envelope, got %v", envelope)
}
}
func TestHandleRPCMethodErrorUsesJSONEnvelope(t *testing.T) {
server := NewServer()
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodGet, "http://127.0.0.1/acp/rpc", nil)
server.HandleRPC(recorder, request)
if recorder.Code != http.StatusMethodNotAllowed {
t.Fatalf("expected 405, got %d", recorder.Code)
}
if got := recorder.Header().Get("Content-Type"); !strings.Contains(got, "application/json") {
t.Fatalf("expected application/json content type, got %q", got)
}
}
func TestHandleRPCCapabilitiesStillReturnsJSONResult(t *testing.T) {
server := NewServer()
recorder := httptest.NewRecorder()
request := httptest.NewRequest(
http.MethodPost,
"http://127.0.0.1/acp/rpc",
strings.NewReader(`{"jsonrpc":"2.0","id":1,"method":"acp.capabilities"}`),
)
request.Header.Set("Content-Type", "application/json")
server.HandleRPC(recorder, request)
if recorder.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", recorder.Code)
}
if got := recorder.Header().Get("Content-Type"); !strings.Contains(got, "application/json") {
t.Fatalf("expected application/json content type, got %q", got)
}
if !strings.Contains(recorder.Body.String(), `"providers"`) {
t.Fatalf("expected capabilities response, got %q", recorder.Body.String())
}
}

View File

@ -0,0 +1,203 @@
package dispatch
import (
"maps"
"slices"
"strings"
)
type Provider struct {
ID string
Name string
DefaultArgs []string
Capabilities []string
}
type NodeState struct {
SelectedAgentID string
GatewayConnected bool
ExecutionTarget string
RuntimeMode string
BridgeEnabled bool
BridgeState string
ResolvedCodexCLIPath string
ConfiguredCodexCLIPath string
}
type NodeInfo struct {
ID string
Name string
Version string
}
type Request struct {
Providers []Provider
PreferredProviderID string
RequiredCapabilities []string
NodeState *NodeState
NodeInfo *NodeInfo
}
type Result struct {
Provider *Provider
AgentID string
Metadata map[string]any
}
func Resolve(request Request) Result {
provider := selectProvider(
request.Providers,
request.PreferredProviderID,
request.RequiredCapabilities,
)
if request.NodeState == nil {
return Result{Provider: provider, Metadata: map[string]any{}}
}
state := request.NodeState
nodeInfo := request.NodeInfo
nodeID := "xworkmate-app"
nodeName := "XWorkmate"
nodeVersion := ""
if nodeInfo != nil {
if strings.TrimSpace(nodeInfo.ID) != "" {
nodeID = strings.TrimSpace(nodeInfo.ID)
}
if strings.TrimSpace(nodeInfo.Name) != "" {
nodeName = strings.TrimSpace(nodeInfo.Name)
}
nodeVersion = strings.TrimSpace(nodeInfo.Version)
}
configuredPath := strings.TrimSpace(state.ConfiguredCodexCLIPath)
if strings.TrimSpace(state.ResolvedCodexCLIPath) != "" {
configuredPath = strings.TrimSpace(state.ResolvedCodexCLIPath)
}
localTransport := "stdio-jsonrpc"
if strings.TrimSpace(state.RuntimeMode) == "builtIn" {
localTransport = "ffi-runtime"
}
metadata := map[string]any{
"node": map[string]any{
"id": nodeID,
"name": nodeName,
"version": nodeVersion,
"kind": "app-mediated-cooperative-node",
"gatewayTransport": "websocket-rpc",
},
"dispatch": map[string]any{
"mode": dispatchMode(state.BridgeEnabled),
"executionTarget": strings.TrimSpace(state.ExecutionTarget),
},
"bridge": map[string]any{
"enabled": state.BridgeEnabled,
"state": strings.TrimSpace(state.BridgeState),
"gatewayConnected": state.GatewayConnected,
"runtimeMode": strings.TrimSpace(state.RuntimeMode),
"localTransport": localTransport,
},
}
if configuredPath != "" {
bridge := metadata["bridge"].(map[string]any)
bridge["binaryConfigured"] = true
}
if provider != nil {
metadata["provider"] = map[string]any{
"id": provider.ID,
"name": provider.Name,
"defaultArgs": provider.DefaultArgs,
"capabilities": provider.Capabilities,
}
}
return Result{
Provider: provider,
AgentID: strings.TrimSpace(state.SelectedAgentID),
Metadata: metadata,
}
}
func dispatchMode(bridgeEnabled bool) string {
if bridgeEnabled {
return "cooperative"
}
return "gateway-only"
}
func selectProvider(
providers []Provider,
preferredProviderID string,
requiredCapabilities []string,
) *Provider {
required := normalizeCapabilities(requiredCapabilities)
preferredID := strings.TrimSpace(preferredProviderID)
if preferredID != "" {
for _, provider := range providers {
if provider.ID == preferredID && supportsProvider(provider, required) {
candidate := provider
return &candidate
}
}
}
filtered := make([]Provider, 0, len(providers))
for _, provider := range providers {
if supportsProvider(provider, required) {
filtered = append(filtered, provider)
}
}
if len(filtered) == 0 {
return nil
}
slices.SortFunc(filtered, func(a, b Provider) int {
return strings.Compare(a.ID, b.ID)
})
candidate := filtered[0]
return &candidate
}
func supportsProvider(provider Provider, required map[string]struct{}) bool {
if len(required) == 0 {
return true
}
provided := normalizeCapabilities(provider.Capabilities)
for capability := range required {
if _, ok := provided[capability]; !ok {
return false
}
}
return true
}
func normalizeCapabilities(values []string) map[string]struct{} {
normalized := map[string]struct{}{}
for _, value := range values {
item := strings.TrimSpace(strings.ToLower(value))
if item == "" {
continue
}
normalized[item] = struct{}{}
}
return normalized
}
func ResultMap(result Result) map[string]any {
response := map[string]any{
"metadata": result.Metadata,
}
if result.Provider != nil {
provider := *result.Provider
response["providerId"] = provider.ID
response["provider"] = map[string]any{
"id": provider.ID,
"name": provider.Name,
"defaultArgs": slices.Clone(provider.DefaultArgs),
"capabilities": slices.Clone(provider.Capabilities),
}
}
if strings.TrimSpace(result.AgentID) != "" {
response["agentId"] = strings.TrimSpace(result.AgentID)
}
return maps.Clone(response)
}

View File

@ -0,0 +1,96 @@
package dispatch
import "testing"
func TestResolvePrefersRequestedProviderWhenCapabilitiesMatch(t *testing.T) {
result := Resolve(Request{
Providers: []Provider{
{
ID: "codex",
Name: "Codex",
Capabilities: []string{"chat", "gateway-bridge"},
},
{
ID: "qwen",
Name: "Qwen",
Capabilities: []string{"chat"},
},
},
PreferredProviderID: "codex",
RequiredCapabilities: []string{"gateway-bridge"},
})
if result.Provider == nil || result.Provider.ID != "codex" {
t.Fatalf("expected codex provider, got %#v", result.Provider)
}
}
func TestResolveFallsBackDeterministicallyByID(t *testing.T) {
result := Resolve(Request{
Providers: []Provider{
{
ID: "qwen",
Name: "Qwen",
Capabilities: []string{"chat"},
},
{
ID: "codex",
Name: "Codex",
Capabilities: []string{"chat"},
},
},
RequiredCapabilities: []string{"chat"},
})
if result.Provider == nil || result.Provider.ID != "codex" {
t.Fatalf("expected deterministic codex fallback, got %#v", result.Provider)
}
}
func TestResolveBuildsGatewayDispatchMetadata(t *testing.T) {
result := Resolve(Request{
Providers: []Provider{
{
ID: "codex",
Name: "Codex CLI",
DefaultArgs: []string{"app-server"},
Capabilities: []string{"chat", "gateway-bridge"},
},
},
PreferredProviderID: "codex",
RequiredCapabilities: []string{"gateway-bridge"},
NodeState: &NodeState{
SelectedAgentID: "main",
GatewayConnected: true,
ExecutionTarget: "local",
RuntimeMode: "externalCli",
BridgeEnabled: true,
BridgeState: "registered",
ResolvedCodexCLIPath: "/opt/homebrew/bin/codex",
},
NodeInfo: &NodeInfo{
ID: "xworkmate-app",
Name: "XWorkmate",
Version: "1.0.0",
},
})
if result.Provider == nil || result.Provider.ID != "codex" {
t.Fatalf("expected codex provider, got %#v", result.Provider)
}
if result.AgentID != "main" {
t.Fatalf("expected agent id main, got %q", result.AgentID)
}
dispatch, ok := result.Metadata["dispatch"].(map[string]any)
if !ok || dispatch["mode"] != "cooperative" {
t.Fatalf("expected cooperative dispatch, got %#v", result.Metadata["dispatch"])
}
bridge, ok := result.Metadata["bridge"].(map[string]any)
if !ok || bridge["localTransport"] != "stdio-jsonrpc" {
t.Fatalf("expected stdio-jsonrpc bridge transport, got %#v", result.Metadata["bridge"])
}
provider, ok := result.Metadata["provider"].(map[string]any)
if !ok || provider["id"] != "codex" {
t.Fatalf("expected provider metadata for codex, got %#v", result.Metadata["provider"])
}
}

View File

@ -0,0 +1,98 @@
package gatewayruntime
import "strings"
func normalizeChatRunEvent(event string, payload map[string]any) map[string]any {
switch event {
case "chat":
runID := strings.TrimSpace(stringValue(payload["runId"]))
state := strings.TrimSpace(stringValue(payload["state"]))
if runID == "" && state == "" {
return nil
}
message := asMap(payload["message"])
assistantText := ""
if strings.EqualFold(strings.TrimSpace(stringValue(message["role"])), "assistant") {
assistantText = extractMessageText(message)
}
normalized := map[string]any{
"runId": runID,
"sessionKey": strings.TrimSpace(stringValue(payload["sessionKey"])),
"state": state,
"source": "chat",
"terminal": state == "final" || state == "aborted" || state == "error",
}
if assistantText != "" {
normalized["assistantText"] = assistantText
}
if errorMessage := strings.TrimSpace(stringValue(payload["errorMessage"])); errorMessage != "" {
normalized["errorMessage"] = errorMessage
}
return normalized
case "agent":
runID := strings.TrimSpace(stringValue(payload["runId"]))
if runID == "" {
return nil
}
stream := strings.TrimSpace(stringValue(payload["stream"]))
if !strings.EqualFold(stream, "assistant") {
return nil
}
data := asMap(payload["data"])
assistantText := strings.TrimSpace(stringValue(data["text"]))
if assistantText == "" {
assistantText = extractMessageText(data)
}
if assistantText == "" {
return nil
}
sessionKey := strings.TrimSpace(stringValue(payload["sessionKey"]))
if sessionKey == "" {
sessionKey = strings.TrimSpace(stringValue(data["sessionKey"]))
}
return map[string]any{
"runId": runID,
"sessionKey": sessionKey,
"state": "delta",
"source": "agent",
"stream": stream,
"assistantText": assistantText,
"terminal": false,
}
default:
return nil
}
}
func asList(value any) []any {
switch typed := value.(type) {
case []any:
return typed
default:
return nil
}
}
func extractMessageText(message map[string]any) string {
directContent, ok := message["content"].(string)
if ok {
return strings.TrimSpace(directContent)
}
parts := make([]string, 0, 4)
for _, part := range asList(message["content"]) {
segment := asMap(part)
text := strings.TrimSpace(firstNonEmpty(
stringValue(segment["text"]),
stringValue(segment["thinking"]),
))
if text != "" {
parts = append(parts, text)
continue
}
nestedContent := strings.TrimSpace(stringValue(segment["content"]))
if nestedContent != "" {
parts = append(parts, nestedContent)
}
}
return strings.TrimSpace(strings.Join(parts, "\n"))
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,337 @@
package gatewayruntime
import (
"encoding/json"
"net"
"net/http"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/gorilla/websocket"
)
func TestManagerConnectAndRequest(t *testing.T) {
server := newFakeGatewayServer(t)
defer server.Close()
manager := NewManager()
manager.ReconnectDelay = 20 * time.Millisecond
notifications := make([]map[string]any, 0, 8)
var mu sync.Mutex
notify := func(message map[string]any) {
mu.Lock()
defer mu.Unlock()
notifications = append(notifications, message)
}
result := manager.Connect(buildTestConnectRequest(server.Port()), notify)
if !result.OK {
t.Fatalf("expected connect success, got %#v", result.Error)
}
if result.ReturnedDeviceToken != "device-token-1" {
t.Fatalf("expected returned device token, got %#v", result.ReturnedDeviceToken)
}
requestResult := manager.Request(
"runtime-1",
"health",
map[string]any{},
2*time.Second,
notify,
)
if !requestResult.OK {
t.Fatalf("expected health success, got %#v", requestResult.Error)
}
payload, ok := requestResult.Payload.(map[string]any)
if !ok || payload["status"] != "ok" {
t.Fatalf("unexpected health payload %#v", requestResult.Payload)
}
mu.Lock()
defer mu.Unlock()
if len(notifications) == 0 {
t.Fatalf("expected notifications during connect")
}
}
func TestManagerReconnectsAfterSocketClose(t *testing.T) {
server := newFakeGatewayServer(t)
server.closeAfterConnect.Store(true)
defer server.Close()
manager := NewManager()
manager.ReconnectDelay = 25 * time.Millisecond
reconnected := make(chan struct{}, 1)
notify := func(message map[string]any) {
params := asMap(message["params"])
if strings.TrimSpace(stringValue(message["method"])) != "xworkmate.gateway.snapshot" {
return
}
snapshot := asMap(params["snapshot"])
if snapshot["status"] == "connected" && server.ConnectCount() >= 2 {
select {
case reconnected <- struct{}{}:
default:
}
}
}
result := manager.Connect(buildTestConnectRequest(server.Port()), notify)
if !result.OK {
t.Fatalf("expected connect success, got %#v", result.Error)
}
select {
case <-reconnected:
case <-time.After(3 * time.Second):
t.Fatalf("expected reconnect to complete; connect count=%d", server.ConnectCount())
}
}
func TestManagerSuppressesReconnectForPairingRequired(t *testing.T) {
server := newFakeGatewayServer(t)
server.connectErrorCode = "NOT_PAIRED"
server.connectErrorDetailCode = "PAIRING_REQUIRED"
defer server.Close()
manager := NewManager()
manager.ReconnectDelay = 20 * time.Millisecond
result := manager.Connect(buildTestConnectRequest(server.Port()), func(map[string]any) {})
if result.OK {
t.Fatalf("expected connect failure")
}
time.Sleep(120 * time.Millisecond)
if server.ConnectCount() != 1 {
t.Fatalf("expected reconnect suppression, got %d connect attempts", server.ConnectCount())
}
}
func TestSessionEmitsNormalizedChatRunPushEvents(t *testing.T) {
manager := NewManager()
session := newSession(manager, "runtime-1")
notifications := make([]map[string]any, 0, 8)
session.setNotify(func(message map[string]any) {
notifications = append(notifications, message)
})
session.handleEvent(
"chat",
map[string]any{"seq": 7},
map[string]any{
"runId": "run-1",
"sessionKey": "agent:main:main",
"state": "final",
"message": map[string]any{
"role": "assistant",
"content": []any{
map[string]any{"type": "text", "text": "XWORKMATE_OK"},
},
},
},
)
session.handleEvent(
"agent",
map[string]any{"seq": 8},
map[string]any{
"runId": "run-1",
"stream": "assistant",
"data": map[string]any{
"text": "DELTA_TEXT",
},
},
)
normalized := make([]map[string]any, 0, 2)
for _, notification := range notifications {
if strings.TrimSpace(stringValue(notification["method"])) != "xworkmate.gateway.push" {
continue
}
params := asMap(notification["params"])
event := asMap(params["event"])
if strings.TrimSpace(stringValue(event["event"])) != "chat.run" {
continue
}
normalized = append(normalized, asMap(event["payload"]))
}
if len(normalized) != 2 {
t.Fatalf("expected 2 normalized chat.run notifications, got %#v", normalized)
}
if normalized[0]["runId"] != "run-1" || normalized[0]["state"] != "final" {
t.Fatalf("unexpected normalized chat payload %#v", normalized[0])
}
if normalized[0]["assistantText"] != "XWORKMATE_OK" {
t.Fatalf("expected final assistant text, got %#v", normalized[0])
}
if normalized[0]["terminal"] != true {
t.Fatalf("expected terminal final chat.run, got %#v", normalized[0])
}
if normalized[1]["assistantText"] != "DELTA_TEXT" || normalized[1]["state"] != "delta" {
t.Fatalf("unexpected normalized agent payload %#v", normalized[1])
}
}
type fakeGatewayServer struct {
server *http.Server
listener net.Listener
connectCount atomic.Int32
closeAfterConnect atomic.Bool
connectErrorCode string
connectErrorDetailCode string
}
func newFakeGatewayServer(t *testing.T) *fakeGatewayServer {
t.Helper()
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen: %v", err)
}
fake := &fakeGatewayServer{listener: listener}
upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
_ = conn.WriteJSON(map[string]any{
"type": "event",
"event": "connect.challenge",
"payload": map[string]any{
"nonce": "nonce-1",
},
})
for {
_, payload, err := conn.ReadMessage()
if err != nil {
return
}
var frame map[string]any
if err := json.Unmarshal(payload, &frame); err != nil {
continue
}
if frame["type"] != "req" {
continue
}
id := frame["id"]
method := stringValue(frame["method"])
switch method {
case "connect":
fake.connectCount.Add(1)
if fake.connectErrorCode != "" {
_ = conn.WriteJSON(map[string]any{
"type": "res",
"id": id,
"ok": false,
"error": map[string]any{
"code": fake.connectErrorCode,
"message": "connect failed",
"details": map[string]any{
"code": fake.connectErrorDetailCode,
},
},
})
continue
}
_ = conn.WriteJSON(map[string]any{
"type": "res",
"id": id,
"ok": true,
"payload": map[string]any{
"server": map[string]any{"host": "127.0.0.1"},
"snapshot": map[string]any{
"sessionDefaults": map[string]any{"mainSessionKey": "main"},
},
"auth": map[string]any{
"role": "operator",
"scopes": defaultOperatorScopes,
"deviceToken": "device-token-1",
},
},
})
if fake.closeAfterConnect.Load() && fake.connectCount.Load() == 1 {
go func() {
time.Sleep(20 * time.Millisecond)
_ = conn.Close()
}()
}
case "health":
_ = conn.WriteJSON(map[string]any{
"type": "res",
"id": id,
"ok": true,
"payload": map[string]any{
"status": "ok",
},
})
default:
_ = conn.WriteJSON(map[string]any{
"type": "res",
"id": id,
"ok": true,
"payload": map[string]any{},
})
}
}
})
fake.server = &http.Server{Handler: mux}
go func() {
_ = fake.server.Serve(listener)
}()
return fake
}
func (f *fakeGatewayServer) Port() int {
return f.listener.Addr().(*net.TCPAddr).Port
}
func (f *fakeGatewayServer) ConnectCount() int {
return int(f.connectCount.Load())
}
func (f *fakeGatewayServer) Close() {
_ = f.server.Close()
}
func buildTestConnectRequest(port int) ConnectRequest {
return ConnectRequest{
RuntimeID: "runtime-1",
Mode: "remote",
ClientID: "openclaw-macos",
Locale: "en_US",
UserAgent: "XWorkmate/1.0.0",
Endpoint: Endpoint{
Host: "127.0.0.1",
Port: port,
TLS: false,
},
ConnectAuthMode: "shared-token",
ConnectAuthFields: []string{"token"},
ConnectAuthSources: []string{"shared:form"},
HasSharedAuth: true,
HasDeviceToken: false,
PackageInfo: PackageInfo{
AppName: "XWorkmate",
Version: "1.0.0",
},
DeviceInfo: DeviceInfo{
Platform: "macos",
PlatformVersion: "14.0",
DeviceFamily: "Mac",
ModelIdentifier: "Mac14,5",
},
Identity: DeviceIdentity{
DeviceID: "device-1",
PublicKeyBase64URL: "tl4fnKW7VLD0Cl4lQTu2CEgHPs4PWAX7eVgWfWQWk2Q",
PrivateKeyBase64URL: "dr7GfMKoO-lJBtgA0dE5m6f_X4kEFsxChDc7mW8mkXu2Xh-cpbsUsPQKXiVBO7YISAc-zg9YBft5WBZ9ZBaTZA",
},
Auth: AuthConfig{
Token: "shared-token",
},
}
}

View File

@ -0,0 +1,129 @@
package gatewayruntime
import "time"
const (
defaultProtocolVersion = 3
defaultReconnectDelay = 2 * time.Second
defaultConnectTimeout = 10 * time.Second
defaultChallengeWait = 2 * time.Second
defaultRequestTimeout = 15 * time.Second
)
var defaultOperatorScopes = []string{
"operator.admin",
"operator.read",
"operator.write",
"operator.approvals",
"operator.pairing",
}
type Endpoint struct {
Host string
Port int
TLS bool
}
type PackageInfo struct {
AppName string
PackageName string
Version string
BuildNumber string
}
type DeviceInfo struct {
Platform string
PlatformVersion string
DeviceFamily string
ModelIdentifier string
}
func (d DeviceInfo) PlatformLabel() string {
if d.PlatformVersion == "" {
return d.Platform
}
return d.Platform + " " + d.PlatformVersion
}
type DeviceIdentity struct {
DeviceID string
PublicKeyBase64URL string
PrivateKeyBase64URL string
}
type AuthConfig struct {
Token string
DeviceToken string
Password string
}
type ConnectRequest struct {
RuntimeID string
Mode string
ClientID string
Locale string
UserAgent string
Endpoint Endpoint
ConnectAuthMode string
ConnectAuthFields []string
ConnectAuthSources []string
HasSharedAuth bool
HasDeviceToken bool
PackageInfo PackageInfo
DeviceInfo DeviceInfo
Identity DeviceIdentity
Auth AuthConfig
}
type ConnectResult struct {
OK bool
Snapshot map[string]any
Auth map[string]any
ReturnedDeviceToken string
Error map[string]any
}
type RequestResult struct {
OK bool
Payload any
Error map[string]any
}
type GatewayError struct {
Message string
Code string
Details map[string]any
}
func (e *GatewayError) Error() string {
if e == nil {
return ""
}
return e.Message
}
func (e *GatewayError) DetailCode() string {
if e == nil || e.Details == nil {
return ""
}
if value, ok := e.Details["code"].(string); ok {
return value
}
return ""
}
func (e *GatewayError) Map() map[string]any {
if e == nil {
return map[string]any{}
}
payload := map[string]any{
"message": e.Message,
}
if e.Code != "" {
payload["code"] = e.Code
}
if len(e.Details) > 0 {
payload["details"] = e.Details
}
return payload
}

View File

@ -0,0 +1,49 @@
package handler
import (
"encoding/json"
"net/http"
"xworkmate-bridge/internal/service"
)
type Authenticator interface {
Authenticate(username, password string) error
}
type AuthHandler struct {
service Authenticator
}
func NewAuthHandler(svc Authenticator) *AuthHandler {
return &AuthHandler{service: svc}
}
func NewServiceAdapter(svc *service.AuthService) Authenticator {
return authServiceAdapter{service: svc}
}
func (h *AuthHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var payload struct {
Username string `json:"username"`
Password string `json:"password"`
}
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
http.Error(w, "invalid json", http.StatusBadRequest)
return
}
if err := h.service.Authenticate(payload.Username, payload.Password); err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
}
type authServiceAdapter struct {
service *service.AuthService
}
func (a authServiceAdapter) Authenticate(username, password string) error {
return a.service.Authenticate(nil, username, password)
}

View File

@ -0,0 +1,53 @@
package handler
import (
"bytes"
"errors"
"net/http"
"net/http/httptest"
"testing"
)
type fakeAuthenticator struct {
err error
}
func (f fakeAuthenticator) Authenticate(username, password string) error {
return f.err
}
func TestAuthHandlerRejectsInvalidJSON(t *testing.T) {
handler := NewAuthHandler(fakeAuthenticator{})
req := httptest.NewRequest(http.MethodPost, "/auth", bytes.NewBufferString("{"))
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", rec.Code)
}
}
func TestAuthHandlerReturnsUnauthorizedOnServiceFailure(t *testing.T) {
handler := NewAuthHandler(fakeAuthenticator{err: errors.New("invalid credentials")})
req := httptest.NewRequest(http.MethodPost, "/auth", bytes.NewBufferString(`{"username":"alice","password":"secret"}`))
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Fatalf("expected 401, got %d", rec.Code)
}
}
func TestAuthHandlerReturnsOKOnSuccess(t *testing.T) {
handler := NewAuthHandler(fakeAuthenticator{})
req := httptest.NewRequest(http.MethodPost, "/auth", bytes.NewBufferString(`{"username":"alice","password":"secret"}`))
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", rec.Code)
}
}

234
internal/memory/provider.go Normal file
View File

@ -0,0 +1,234 @@
package memory
import (
"fmt"
"os"
"path/filepath"
"strings"
"time"
)
type Source struct {
Path string
Scope string
}
type Preferences struct {
PreferredRoute string
PreferredModel string
PreferredSkills []string
Provider string
}
type LoadResult struct {
MergedText string
Sources []Source
Preferences Preferences
ProjectFiles []string
}
type SuccessEntry struct {
ResolvedExecutionTarget string
ResolvedProviderID string
ResolvedModel string
ResolvedSkills []string
Summary string
}
type Service struct {
HomeDir string
}
func NewService(homeDir string) Service {
return Service{HomeDir: strings.TrimSpace(homeDir)}
}
func (s Service) Load(workingDirectory string) LoadResult {
projectName := projectNameFromWorkingDirectory(workingDirectory)
paths := []Source{
{Path: filepath.Join(s.HomeDir, "self-improving", "memory.md"), Scope: "global"},
{Path: filepath.Join(s.HomeDir, "self-improving", "projects", projectName+".md"), Scope: "project-home"},
{Path: filepath.Join(strings.TrimSpace(workingDirectory), ".xworkmate", "memory.md"), Scope: "project-local"},
}
merged := make([]string, 0, len(paths))
sources := make([]Source, 0, len(paths))
prefs := Preferences{}
projectFiles := make([]string, 0, 2)
for _, source := range paths {
if strings.TrimSpace(source.Path) == "" {
continue
}
content, err := os.ReadFile(source.Path)
if err != nil {
continue
}
text := sanitizeMemoryText(string(content))
if strings.TrimSpace(text) == "" {
continue
}
sources = append(sources, source)
merged = append(merged, fmt.Sprintf("## %s\n%s", source.Scope, text))
mergePreferences(&prefs, parsePreferences(text))
if source.Scope != "global" {
projectFiles = append(projectFiles, source.Path)
}
}
return LoadResult{
MergedText: strings.TrimSpace(strings.Join(merged, "\n\n")),
Sources: sources,
Preferences: prefs,
ProjectFiles: projectFiles,
}
}
func (s Service) RecordSuccess(workingDirectory string, entry SuccessEntry) error {
workingDirectory = strings.TrimSpace(workingDirectory)
if workingDirectory == "" {
return nil
}
projectName := projectNameFromWorkingDirectory(workingDirectory)
if projectName == "" {
return nil
}
target := s.projectWriteTarget(workingDirectory, projectName)
if target == "" {
return nil
}
block := formatSuccessEntry(entry)
if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
return err
}
file, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644)
if err != nil {
return err
}
if _, err := file.WriteString(block); err != nil {
_ = file.Close()
return err
}
if err := file.Close(); err != nil {
return err
}
return nil
}
func (s Service) projectWriteTarget(
workingDirectory string,
projectName string,
) string {
repoLocalDir := filepath.Join(workingDirectory, ".xworkmate")
if err := os.MkdirAll(repoLocalDir, 0o755); err == nil {
return filepath.Join(repoLocalDir, "memory.md")
}
return filepath.Join(s.HomeDir, "self-improving", "projects", projectName+".md")
}
func formatSuccessEntry(entry SuccessEntry) string {
lines := []string{
"",
fmt.Sprintf("## Auto route %s", time.Now().Format(time.RFC3339)),
fmt.Sprintf("preferred-route: %s", strings.TrimSpace(entry.ResolvedExecutionTarget)),
}
if strings.TrimSpace(entry.ResolvedModel) != "" {
lines = append(lines, fmt.Sprintf("preferred-model: %s", strings.TrimSpace(entry.ResolvedModel)))
}
if len(entry.ResolvedSkills) > 0 {
lines = append(lines, fmt.Sprintf("preferred-skills: %s", strings.Join(entry.ResolvedSkills, ", ")))
}
if strings.TrimSpace(entry.ResolvedProviderID) != "" {
lines = append(lines, fmt.Sprintf("provider: %s", strings.TrimSpace(entry.ResolvedProviderID)))
}
if summary := sanitizeMemoryText(entry.Summary); strings.TrimSpace(summary) != "" {
lines = append(lines, "summary:")
for _, line := range strings.Split(summary, "\n") {
trimmed := strings.TrimSpace(line)
if trimmed == "" {
continue
}
lines = append(lines, fmt.Sprintf("- %s", trimmed))
}
}
return strings.Join(lines, "\n") + "\n"
}
func parsePreferences(text string) Preferences {
prefs := Preferences{}
for _, line := range strings.Split(text, "\n") {
trimmed := strings.TrimSpace(line)
switch {
case strings.HasPrefix(strings.ToLower(trimmed), "preferred-route:"):
prefs.PreferredRoute = normalizePreferredRoute(
strings.TrimSpace(strings.TrimPrefix(trimmed, "preferred-route:")),
)
case strings.HasPrefix(strings.ToLower(trimmed), "preferred-model:"):
prefs.PreferredModel = strings.TrimSpace(strings.TrimPrefix(trimmed, "preferred-model:"))
case strings.HasPrefix(strings.ToLower(trimmed), "preferred-skills:"):
raw := strings.TrimSpace(strings.TrimPrefix(trimmed, "preferred-skills:"))
for _, item := range strings.Split(raw, ",") {
value := strings.TrimSpace(item)
if value != "" {
prefs.PreferredSkills = append(prefs.PreferredSkills, value)
}
}
case strings.HasPrefix(strings.ToLower(trimmed), "provider:"):
prefs.Provider = strings.TrimSpace(strings.TrimPrefix(trimmed, "provider:"))
}
}
return prefs
}
func mergePreferences(dst *Preferences, src Preferences) {
if strings.TrimSpace(src.PreferredRoute) != "" {
dst.PreferredRoute = strings.TrimSpace(src.PreferredRoute)
}
if strings.TrimSpace(src.PreferredModel) != "" {
dst.PreferredModel = strings.TrimSpace(src.PreferredModel)
}
if len(src.PreferredSkills) > 0 {
dst.PreferredSkills = append([]string(nil), src.PreferredSkills...)
}
if strings.TrimSpace(src.Provider) != "" {
dst.Provider = strings.TrimSpace(src.Provider)
}
}
func sanitizeMemoryText(text string) string {
lines := strings.Split(text, "\n")
filtered := make([]string, 0, len(lines))
for _, line := range lines {
normalized := strings.ToLower(strings.TrimSpace(line))
if normalized == "" {
filtered = append(filtered, "")
continue
}
if strings.Contains(normalized, "token") ||
strings.Contains(normalized, "password") ||
strings.Contains(normalized, "secret") ||
strings.Contains(normalized, "api_key") ||
strings.Contains(normalized, "apikey") ||
strings.Contains(normalized, "api key") {
continue
}
filtered = append(filtered, line)
}
return strings.TrimSpace(strings.Join(filtered, "\n"))
}
func projectNameFromWorkingDirectory(workingDirectory string) string {
cleaned := strings.TrimSpace(workingDirectory)
if cleaned == "" {
return ""
}
return strings.TrimSpace(filepath.Base(cleaned))
}
func normalizePreferredRoute(value string) string {
switch strings.ToLower(strings.TrimSpace(value)) {
case "gateway-chat":
return "gateway"
default:
return strings.TrimSpace(value)
}
}

View File

@ -0,0 +1,117 @@
package memory
import (
"os"
"path/filepath"
"strings"
"testing"
)
func TestLoadMergesGlobalAndProjectMemoryAndSanitizesSecrets(t *testing.T) {
tempDir := t.TempDir()
workingDir := filepath.Join(tempDir, "workspace")
homeDir := filepath.Join(tempDir, "home")
if err := os.MkdirAll(filepath.Join(workingDir, ".xworkmate"), 0o755); err != nil {
t.Fatalf("mkdir workspace: %v", err)
}
if err := os.MkdirAll(filepath.Join(homeDir, "self-improving", "projects"), 0o755); err != nil {
t.Fatalf("mkdir home: %v", err)
}
if err := os.WriteFile(filepath.Join(homeDir, "self-improving", "memory.md"), []byte("preferred-route: gateway-chat\napi_key: hidden\n"), 0o644); err != nil {
t.Fatalf("write global memory: %v", err)
}
if err := os.WriteFile(filepath.Join(homeDir, "self-improving", "projects", "workspace.md"), []byte("preferred-model: gpt-5.4\n"), 0o644); err != nil {
t.Fatalf("write project memory: %v", err)
}
if err := os.WriteFile(filepath.Join(workingDir, ".xworkmate", "memory.md"), []byte("preferred-skills: pptx, pdf\npassword: hidden\n"), 0o644); err != nil {
t.Fatalf("write local memory: %v", err)
}
result := NewService(homeDir).Load(workingDir)
if len(result.Sources) != 3 {
t.Fatalf("expected 3 memory sources, got %d", len(result.Sources))
}
if strings.Contains(strings.ToLower(result.MergedText), "api_key") || strings.Contains(strings.ToLower(result.MergedText), "password") {
t.Fatalf("expected sanitized merged text, got %q", result.MergedText)
}
if result.Preferences.PreferredRoute != "gateway" {
t.Fatalf("unexpected preferred route: %#v", result.Preferences)
}
if result.Preferences.PreferredModel != "gpt-5.4" {
t.Fatalf("unexpected preferred model: %#v", result.Preferences)
}
if len(result.Preferences.PreferredSkills) != 2 {
t.Fatalf("unexpected preferred skills: %#v", result.Preferences.PreferredSkills)
}
}
func TestRecordSuccessWritesProjectLevelMemoryFiles(t *testing.T) {
tempDir := t.TempDir()
workingDir := filepath.Join(tempDir, "repo")
homeDir := filepath.Join(tempDir, "home")
if err := os.MkdirAll(workingDir, 0o755); err != nil {
t.Fatalf("mkdir working dir: %v", err)
}
service := NewService(homeDir)
err := service.RecordSuccess(workingDir, SuccessEntry{
ResolvedExecutionTarget: "single-agent",
ResolvedModel: "gpt-5.4",
ResolvedSkills: []string{"pptx", "pdf"},
Summary: "created a clean deck",
})
if err != nil {
t.Fatalf("record success: %v", err)
}
repoLocalTarget := filepath.Join(workingDir, ".xworkmate", "memory.md")
content, err := os.ReadFile(repoLocalTarget)
if err != nil {
t.Fatalf("read target %s: %v", repoLocalTarget, err)
}
text := string(content)
if !strings.Contains(text, "preferred-route: single-agent") {
t.Fatalf("missing preferred route in %s: %q", repoLocalTarget, text)
}
if strings.Contains(strings.ToLower(text), "token") {
t.Fatalf("unexpected sensitive content in %s: %q", repoLocalTarget, text)
}
homeProjectTarget := filepath.Join(homeDir, "self-improving", "projects", "repo.md")
if _, err := os.Stat(homeProjectTarget); !os.IsNotExist(err) {
t.Fatalf("expected single project-level write target, got stat err=%v", err)
}
}
func TestLoadLetsProjectMemoryOverrideGlobalPreferences(t *testing.T) {
tempDir := t.TempDir()
workingDir := filepath.Join(tempDir, "workspace")
homeDir := filepath.Join(tempDir, "home")
if err := os.MkdirAll(filepath.Join(workingDir, ".xworkmate"), 0o755); err != nil {
t.Fatalf("mkdir workspace: %v", err)
}
if err := os.MkdirAll(filepath.Join(homeDir, "self-improving", "projects"), 0o755); err != nil {
t.Fatalf("mkdir home: %v", err)
}
if err := os.WriteFile(filepath.Join(homeDir, "self-improving", "memory.md"), []byte("preferred-route: single-agent\npreferred-model: gpt-4o\npreferred-skills: docx\n"), 0o644); err != nil {
t.Fatalf("write global memory: %v", err)
}
if err := os.WriteFile(filepath.Join(homeDir, "self-improving", "projects", "workspace.md"), []byte("preferred-route: gateway\npreferred-model: gpt-5.4\n"), 0o644); err != nil {
t.Fatalf("write project home memory: %v", err)
}
if err := os.WriteFile(filepath.Join(workingDir, ".xworkmate", "memory.md"), []byte("preferred-route: multi-agent\npreferred-skills: pptx, pdf\n"), 0o644); err != nil {
t.Fatalf("write project local memory: %v", err)
}
result := NewService(homeDir).Load(workingDir)
if result.Preferences.PreferredRoute != "multi-agent" {
t.Fatalf("expected project-local route to win, got %#v", result.Preferences)
}
if result.Preferences.PreferredModel != "gpt-5.4" {
t.Fatalf("expected project-home model to override global, got %#v", result.Preferences)
}
if len(result.Preferences.PreferredSkills) != 2 || result.Preferences.PreferredSkills[0] != "pptx" {
t.Fatalf("expected project-local skills to win, got %#v", result.Preferences.PreferredSkills)
}
}

150
internal/mounts/config.go Normal file
View File

@ -0,0 +1,150 @@
package mounts
import (
"fmt"
"os"
"path/filepath"
"regexp"
"strings"
"time"
)
const (
codexManagedMCPBlockStart = "# BEGIN XWORKMATE MANAGED MCP BLOCK"
codexManagedMCPBlockEnd = "# END XWORKMATE MANAGED MCP BLOCK"
opencodeManagedMCPBlockStart = "# BEGIN XWORKMATE MANAGED MCP BLOCK"
opencodeManagedMCPBlockEnd = "# END XWORKMATE MANAGED MCP BLOCK"
)
var mcpServerSectionPattern = regexp.MustCompile(
`(?m)^\[mcp_servers\.[^\]]+\]`,
)
func countMCPSections(content string) int {
return len(mcpServerSectionPattern.FindAllStringIndex(content, -1))
}
func defaultCodexHome() string {
home, err := os.UserHomeDir()
if err != nil || strings.TrimSpace(home) == "" {
return ""
}
return filepath.Join(home, ".codex")
}
func defaultOpencodeHome() string {
home, err := os.UserHomeDir()
if err != nil || strings.TrimSpace(home) == "" {
return ""
}
return filepath.Join(home, ".opencode")
}
func defaultOpenClawHome() string {
home, err := os.UserHomeDir()
if err != nil || strings.TrimSpace(home) == "" {
return ""
}
return filepath.Join(home, ".openclaw")
}
func stripManagedBlock(content, startMarker, endMarker string) string {
if strings.TrimSpace(content) == "" {
return content
}
remaining := content
for {
start := strings.Index(remaining, startMarker)
if start < 0 {
break
}
end := strings.Index(remaining[start:], endMarker)
if end < 0 {
remaining = remaining[:start]
break
}
end += start
remaining = remaining[:start] + remaining[end+len(endMarker):]
}
return remaining
}
func mergeManagedBlock(content, block, startMarker, endMarker string) string {
preserved := strings.TrimRight(
stripManagedBlock(content, startMarker, endMarker),
"\n",
)
if preserved == "" {
return block + "\n"
}
return preserved + "\n\n" + block + "\n"
}
func buildCodexManagedMCPBlock(servers []ManagedMCPServer) string {
var buffer strings.Builder
buffer.WriteString(codexManagedMCPBlockStart)
buffer.WriteString("\n# Generated by XWorkmate - Managed MCP Server Configuration\n")
buffer.WriteString(
fmt.Sprintf("# Last updated: %s\n\n", time.Now().Format(time.RFC3339Nano)),
)
for _, server := range servers {
buffer.WriteString(fmt.Sprintf("[mcp_servers.%s]\n", server.ID))
buffer.WriteString(fmt.Sprintf("command = %q\n", server.Command))
if len(server.Args) > 0 {
buffer.WriteString(fmt.Sprintf("args = %s\n", formatTOMLArray(server.Args)))
}
buffer.WriteString("\n")
}
buffer.WriteString(codexManagedMCPBlockEnd)
return strings.TrimRight(buffer.String(), "\n")
}
func buildOpencodeManagedMCPBlock(servers []ManagedMCPServer) string {
var buffer strings.Builder
buffer.WriteString(opencodeManagedMCPBlockStart)
buffer.WriteString("\n# Generated by XWorkmate - Managed MCP Server Configuration\n")
buffer.WriteString(
fmt.Sprintf("# Last updated: %s\n\n", time.Now().Format(time.RFC3339Nano)),
)
for _, server := range servers {
buffer.WriteString(fmt.Sprintf("[mcp_servers.%s]\n", server.ID))
if strings.TrimSpace(server.URL) != "" {
buffer.WriteString(fmt.Sprintf("url = %q\n", strings.TrimSpace(server.URL)))
} else {
buffer.WriteString("type = \"stdio\"\n")
buffer.WriteString(fmt.Sprintf("command = %q\n", server.Command))
if len(server.Args) > 0 {
buffer.WriteString(fmt.Sprintf("args = %s\n", formatTOMLArray(server.Args)))
}
}
buffer.WriteString("\n")
}
buffer.WriteString(opencodeManagedMCPBlockEnd)
return strings.TrimRight(buffer.String(), "\n")
}
func formatTOMLArray(items []string) string {
if len(items) == 0 {
return "[]"
}
var quoted []string
for _, item := range items {
quoted = append(quoted, fmt.Sprintf("%q", item))
}
return "[" + strings.Join(quoted, ", ") + "]"
}
func applyManagedBlock(configPath, block, startMarker, endMarker string) error {
configDir := filepath.Dir(configPath)
if err := os.MkdirAll(configDir, 0o755); err != nil {
return err
}
content, err := os.ReadFile(configPath)
if err != nil && !os.IsNotExist(err) {
return err
}
merged := mergeManagedBlock(string(content), block, startMarker, endMarker)
return os.WriteFile(configPath, []byte(merged), 0o644)
}

View File

@ -0,0 +1,437 @@
package mounts
import (
"context"
"encoding/json"
"os"
"os/exec"
"path/filepath"
"sort"
"strconv"
"strings"
"time"
)
type ManagedMCPServer struct {
ID string
Name string
Transport string
Command string
URL string
Args []string
Enabled bool
}
type Config struct {
AutoSync bool
UsesAris bool
ManagedMCPServers []ManagedMCPServer
}
type ArisInput struct {
Available bool
BundleVersion string
LLMChatServerPath string
SkillCount int
BridgeAvailable bool
Error string
}
type Request struct {
Config Config
AIGatewayURL string
ConfiguredCodexCLIPath string
CodexHome string
OpencodeHome string
OpenClawHome string
Aris ArisInput
}
type MountTargetState struct {
TargetID string
Label string
Available bool
SupportsSkills bool
SupportsMCP bool
SupportsAIGatewayInjection bool
DiscoveryState string
SyncState string
DiscoveredSkillCount int
DiscoveredMCPCount int
ManagedMCPCount int
Detail string
}
type Result struct {
MountTargets []MountTargetState
ArisBundleVersion string
ArisCompatStatus string
}
func Reconcile(request Request) Result {
states := []MountTargetState{
reconcileAris(request.Config, request.Aris),
reconcileCodex(
request.Config,
request.AIGatewayURL,
request.ConfiguredCodexCLIPath,
request.CodexHome,
),
reconcileCLIListTarget(
request.Config,
"claude",
"Claude",
[]string{"claude", "mcp", "list"},
),
reconcileCLIListTarget(
request.Config,
"gemini",
"Gemini",
[]string{"gemini", "mcp", "list"},
),
reconcileOpencode(request.Config, request.OpencodeHome),
reconcileOpenClaw(request.Config, request.OpenClawHome),
}
result := Result{
MountTargets: states,
ArisBundleVersion: strings.TrimSpace(request.Aris.BundleVersion),
ArisCompatStatus: "idle",
}
for _, state := range states {
if state.TargetID == "aris" {
result.ArisCompatStatus = state.SyncState
break
}
}
return result
}
func ResultMap(result Result) map[string]any {
rawTargets := make([]map[string]any, 0, len(result.MountTargets))
for _, target := range result.MountTargets {
rawTargets = append(rawTargets, map[string]any{
"targetId": target.TargetID,
"label": target.Label,
"available": target.Available,
"supportsSkills": target.SupportsSkills,
"supportsMcp": target.SupportsMCP,
"supportsAiGatewayInjection": target.SupportsAIGatewayInjection,
"discoveryState": target.DiscoveryState,
"syncState": target.SyncState,
"discoveredSkillCount": target.DiscoveredSkillCount,
"discoveredMcpCount": target.DiscoveredMCPCount,
"managedMcpCount": target.ManagedMCPCount,
"detail": target.Detail,
})
}
return map[string]any{
"mountTargets": rawTargets,
"arisBundleVersion": result.ArisBundleVersion,
"arisCompatStatus": result.ArisCompatStatus,
}
}
func reconcileAris(config Config, input ArisInput) MountTargetState {
state := placeholderState("aris", "ARIS", true, true, false)
if strings.TrimSpace(input.Error) != "" {
state.Available = false
state.DiscoveryState = "error"
state.SyncState = "error"
state.Detail = strings.TrimSpace(input.Error)
return state
}
if !input.Available {
state.DiscoveryState = "missing"
state.SyncState = "missing"
state.Detail = "Embedded ARIS bundle is unavailable."
return state
}
state.Available = true
state.DiscoveryState = "ready"
state.DiscoveredSkillCount = input.SkillCount
llmChatReady := strings.TrimSpace(input.LLMChatServerPath) != ""
if config.UsesAris && llmChatReady && input.BridgeAvailable {
state.SyncState = "ready"
state.DiscoveredMCPCount = 1
state.ManagedMCPCount = 1
state.Detail = "Embedded bundle " +
strings.TrimSpace(input.BundleVersion) +
" ready; XWorkmate Go core manages llm-chat and claude-review."
return state
}
state.SyncState = "embedded"
if llmChatReady {
state.DiscoveredMCPCount = 1
}
if llmChatReady {
state.Detail = "Embedded bundle extracted, but the XWorkmate Go core is not available yet."
} else {
state.Detail = "Embedded bundle extracted, but llm-chat metadata is missing."
}
return state
}
func reconcileCodex(
config Config,
aiGatewayURL string,
configuredCodexCLIPath string,
codexHome string,
) MountTargetState {
state := placeholderState("codex", "Codex", true, true, true)
available := codexAvailable(configuredCodexCLIPath)
configHome := strings.TrimSpace(codexHome)
if configHome == "" {
configHome = defaultCodexHome()
}
configPath := filepath.Join(configHome, "config.toml")
content, _ := os.ReadFile(configPath)
discovered := countMCPSections(string(content))
managedServers := enabledCodexServers(config.ManagedMCPServers)
if available && config.AutoSync && len(managedServers) > 0 {
_ = applyManagedBlock(
configPath,
buildCodexManagedMCPBlock(managedServers),
codexManagedMCPBlockStart,
codexManagedMCPBlockEnd,
)
}
state.Available = available
if available {
state.DiscoveryState = "ready"
} else {
state.DiscoveryState = "missing"
}
switch {
case !available:
state.SyncState = "missing"
case config.AutoSync:
state.SyncState = "ready"
default:
state.SyncState = "disabled"
}
state.DiscoveredMCPCount = discovered
state.ManagedMCPCount = len(managedServers)
if strings.TrimSpace(aiGatewayURL) != "" {
state.Detail = "LLM API uses launch-scoped defaults for collaboration runs."
} else {
state.Detail = "LLM API not configured."
}
return state
}
func reconcileCLIListTarget(
config Config,
targetID string,
label string,
command []string,
) MountTargetState {
state := placeholderState(targetID, label, true, true, true)
available := binaryExists(command[0])
discovered := 0
if available {
discovered = countListedEntries(command)
}
state.Available = available
if available {
state.DiscoveryState = "ready"
} else {
state.DiscoveryState = "missing"
}
if available && config.AutoSync {
state.SyncState = "launch-only"
} else {
state.SyncState = "disabled"
}
state.DiscoveredMCPCount = discovered
state.ManagedMCPCount = len(enabledServers(config.ManagedMCPServers))
state.Detail = "MCP discovery uses `" + strings.Join(command, " ") +
"`; LLM API stays launch-scoped."
return state
}
func reconcileOpencode(config Config, opencodeHome string) MountTargetState {
state := placeholderState("opencode", "OpenCode", true, true, true)
available := binaryExists("opencode")
configHome := strings.TrimSpace(opencodeHome)
if configHome == "" {
configHome = defaultOpencodeHome()
}
configPath := filepath.Join(configHome, "config.toml")
content, _ := os.ReadFile(configPath)
discovered := countMCPSections(string(content))
managedServers := enabledServers(config.ManagedMCPServers)
if available && config.AutoSync && len(managedServers) > 0 {
_ = applyManagedBlock(
configPath,
buildOpencodeManagedMCPBlock(managedServers),
opencodeManagedMCPBlockStart,
opencodeManagedMCPBlockEnd,
)
}
state.Available = available
if available {
state.DiscoveryState = "ready"
} else {
state.DiscoveryState = "missing"
}
switch {
case !available:
state.SyncState = "missing"
case config.AutoSync:
state.SyncState = "ready"
default:
state.SyncState = "disabled"
}
state.DiscoveredMCPCount = discovered
state.ManagedMCPCount = len(managedServers)
state.Detail = "Managed MCP config is preserved in ~/.opencode/config.toml."
return state
}
func reconcileOpenClaw(config Config, openClawHome string) MountTargetState {
state := placeholderState("openclaw", "OpenClaw", true, false, true)
available := binaryExists("openclaw")
state.Available = available
if available {
state.DiscoveryState = "ready"
} else {
state.DiscoveryState = "missing"
}
if available && config.AutoSync {
state.SyncState = "launch-only"
} else {
state.SyncState = "disabled"
}
state.Detail = "OpenClaw acts as the host/control plane mount."
configHome := strings.TrimSpace(openClawHome)
if configHome == "" {
configHome = defaultOpenClawHome()
}
configPath := filepath.Join(configHome, "openclaw.json")
if content, err := os.ReadFile(configPath); err == nil {
var decoded map[string]any
if err := json.Unmarshal(content, &decoded); err == nil {
agents := 0
if rawAgents, ok := decoded["agents"].(map[string]any); ok {
if rawList, ok := rawAgents["list"].([]any); ok {
agents = len(rawList)
}
}
skillsDir := filepath.Join(configHome, "skills")
if entries, err := os.ReadDir(skillsDir); err == nil {
state.DiscoveredSkillCount = len(entries)
}
state.Detail = "agents: " + itoa(agents) + " · skills: " +
itoa(state.DiscoveredSkillCount)
} else {
state.Detail = "OpenClaw config detected but could not be fully parsed."
}
}
return state
}
func placeholderState(
targetID string,
label string,
supportsSkills bool,
supportsMCP bool,
supportsAIGatewayInjection bool,
) MountTargetState {
return MountTargetState{
TargetID: targetID,
Label: label,
SupportsSkills: supportsSkills,
SupportsMCP: supportsMCP,
SupportsAIGatewayInjection: supportsAIGatewayInjection,
DiscoveryState: "idle",
SyncState: "idle",
}
}
func codexAvailable(configuredPath string) bool {
if strings.TrimSpace(configuredPath) != "" {
if _, err := os.Stat(strings.TrimSpace(configuredPath)); err == nil {
return true
}
}
return binaryExists("codex")
}
func binaryExists(command string) bool {
_, err := exec.LookPath(command)
return err == nil
}
func countListedEntries(command []string) int {
output := strings.TrimSpace(runCommand(command))
if output == "" ||
strings.Contains(output, "No MCP servers configured") ||
strings.Contains(output, "No MCP servers configured yet") ||
strings.Contains(output, "No MCP servers configured.") {
return 0
}
lines := strings.Split(output, "\n")
count := 0
for _, line := range lines {
trimmed := strings.TrimSpace(line)
switch {
case trimmed == "":
case strings.HasPrefix(trimmed, "Usage:"):
case strings.HasPrefix(trimmed, "┌"):
case strings.HasPrefix(trimmed, "│"):
case strings.HasPrefix(trimmed, "└"):
default:
count++
}
}
return count
}
func runCommand(command []string) string {
if len(command) == 0 {
return ""
}
ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, command[0], command[1:]...)
output, err := cmd.CombinedOutput()
if err != nil && len(output) == 0 {
return ""
}
return string(output)
}
func enabledServers(servers []ManagedMCPServer) []ManagedMCPServer {
filtered := make([]ManagedMCPServer, 0, len(servers))
for _, server := range servers {
if !server.Enabled {
continue
}
filtered = append(filtered, server)
}
sort.SliceStable(filtered, func(i, j int) bool {
return filtered[i].ID < filtered[j].ID
})
return filtered
}
func enabledCodexServers(servers []ManagedMCPServer) []ManagedMCPServer {
filtered := make([]ManagedMCPServer, 0, len(servers))
for _, server := range servers {
if !server.Enabled || strings.TrimSpace(server.Command) == "" {
continue
}
filtered = append(filtered, server)
}
sort.SliceStable(filtered, func(i, j int) bool {
return filtered[i].ID < filtered[j].ID
})
return filtered
}
func itoa(value int) string {
return strconv.Itoa(value)
}

View File

@ -0,0 +1,115 @@
package mounts
import (
"os"
"path/filepath"
"strings"
"testing"
)
func TestReconcileCodexAppliesManagedBlockAndPreservesUserEntries(t *testing.T) {
tempDir := t.TempDir()
configuredBinary := filepath.Join(tempDir, "custom-codex")
if err := os.WriteFile(configuredBinary, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
t.Fatalf("write configured binary: %v", err)
}
configPath := filepath.Join(tempDir, "config.toml")
if err := os.WriteFile(configPath, []byte(`
[mcp_servers.user_server]
command = "user-mcp"
`), 0o644); err != nil {
t.Fatalf("write config: %v", err)
}
result := Reconcile(Request{
Config: Config{
AutoSync: true,
ManagedMCPServers: []ManagedMCPServer{
{ID: "xworkmate_server", Command: "xworkmate-mcp", Args: []string{"--port", "7777"}, Enabled: true},
},
},
ConfiguredCodexCLIPath: configuredBinary,
CodexHome: tempDir,
})
content, err := os.ReadFile(configPath)
if err != nil {
t.Fatalf("read config: %v", err)
}
if !strings.Contains(string(content), `[mcp_servers.user_server]`) {
t.Fatalf("expected user entry preserved: %s", string(content))
}
if !strings.Contains(string(content), `[mcp_servers.xworkmate_server]`) {
t.Fatalf("expected managed entry written: %s", string(content))
}
if strings.Count(string(content), codexManagedMCPBlockStart) != 1 {
t.Fatalf("expected single managed block: %s", string(content))
}
if result.MountTargets[1].ManagedMCPCount != 1 {
t.Fatalf("expected codex managed count 1, got %d", result.MountTargets[1].ManagedMCPCount)
}
}
func TestReconcileOpencodeAppliesManagedBlockAndPreservesUserEntries(t *testing.T) {
tempDir := t.TempDir()
binDir := t.TempDir()
originalPath := os.Getenv("PATH")
t.Setenv("PATH", binDir+string(os.PathListSeparator)+originalPath)
if err := os.WriteFile(filepath.Join(binDir, "opencode"), []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
t.Fatalf("write opencode binary: %v", err)
}
configPath := filepath.Join(tempDir, "config.toml")
if err := os.WriteFile(configPath, []byte(`
[model]
name = "user-default"
`), 0o644); err != nil {
t.Fatalf("write config: %v", err)
}
result := Reconcile(Request{
Config: Config{
AutoSync: true,
ManagedMCPServers: []ManagedMCPServer{
{ID: "xworkmate_server", Command: "xworkmate-mcp", Args: []string{"--port", "3001"}, Enabled: true},
},
},
OpencodeHome: tempDir,
})
content, err := os.ReadFile(configPath)
if err != nil {
t.Fatalf("read config: %v", err)
}
if !strings.Contains(string(content), `[model]`) {
t.Fatalf("expected user config preserved: %s", string(content))
}
if !strings.Contains(string(content), `[mcp_servers.xworkmate_server]`) {
t.Fatalf("expected managed opencode entry written: %s", string(content))
}
if strings.Count(string(content), opencodeManagedMCPBlockStart) != 1 {
t.Fatalf("expected single opencode managed block: %s", string(content))
}
if result.MountTargets[4].ManagedMCPCount != 1 {
t.Fatalf("expected opencode managed count 1, got %d", result.MountTargets[4].ManagedMCPCount)
}
}
func TestReconcileArisReportsReadyWhenBundleAndBridgeAreAvailable(t *testing.T) {
result := Reconcile(Request{
Config: Config{UsesAris: true},
Aris: ArisInput{
Available: true,
BundleVersion: "test",
LLMChatServerPath: "mcp-server.py",
SkillCount: 2,
BridgeAvailable: true,
},
})
if got := result.MountTargets[0].SyncState; got != "ready" {
t.Fatalf("expected ready aris state, got %q", got)
}
if got := result.ArisBundleVersion; got != "test" {
t.Fatalf("expected bundle version test, got %q", got)
}
}

View File

@ -0,0 +1,78 @@
package router
import (
"context"
"strings"
"time"
"xworkmate-bridge/internal/shared"
)
type ClassificationRequest struct {
Prompt string
AIGatewayBaseURL string
AIGatewayAPIKey string
}
type Classifier interface {
Classify(req ClassificationRequest) string
}
type LLMClassifier struct{}
func (LLMClassifier) Classify(req ClassificationRequest) string {
baseURL := shared.NormalizeBaseURL(strings.TrimSpace(req.AIGatewayBaseURL))
apiKey := strings.TrimSpace(req.AIGatewayAPIKey)
if baseURL == "" {
baseURL = shared.NormalizeBaseURL(
shared.EnvOrDefault("LLM_BASE_URL", "https://api.openai.com/v1"),
)
}
if apiKey == "" {
apiKey = strings.TrimSpace(shared.EnvOrDefault("LLM_API_KEY", ""))
}
if baseURL == "" || apiKey == "" {
return ""
}
model := strings.TrimSpace(shared.EnvOrDefault("ACP_ROUTING_MODEL", "gpt-4o"))
if model == "" {
model = "gpt-4o"
}
ctx, cancel := context.WithTimeout(context.Background(), 45*time.Second)
defer cancel()
content, err := shared.CallOpenAICompatibleCtx(
ctx,
baseURL,
apiKey,
model,
[]map[string]string{
{
"role": "system",
"content": "Classify the user task into exactly one label: single-agent, multi-agent, or gateway. Return only the label.",
},
{
"role": "user",
"content": strings.TrimSpace(req.Prompt),
},
},
)
if err != nil {
return ""
}
return normalizeClassifierLabel(content)
}
func normalizeClassifierLabel(value string) string {
normalized := strings.ToLower(strings.TrimSpace(value))
switch {
case strings.Contains(normalized, ExecutionTargetSingleAgent):
return ExecutionTargetSingleAgent
case strings.Contains(normalized, ExecutionTargetMultiAgent):
return ExecutionTargetMultiAgent
case strings.Contains(normalized, ExecutionTargetGateway):
return ExecutionTargetGateway
default:
return ""
}
}

365
internal/router/router.go Normal file
View File

@ -0,0 +1,365 @@
package router
import (
"os"
"sort"
"strings"
"xworkmate-bridge/internal/memory"
"xworkmate-bridge/internal/skills"
)
const (
RoutingModeAuto = "auto"
RoutingModeExplicit = "explicit"
ExecutionTargetSingleAgent = "single-agent"
ExecutionTargetMultiAgent = "multi-agent"
ExecutionTargetGateway = "gateway"
ExecutionTargetGatewayChat = "gateway-chat"
EndpointTargetSingleAgent = "singleAgent"
EndpointTargetLocal = "local"
EndpointTargetRemote = "remote"
)
type Request struct {
Prompt string
WorkingDirectory string
RoutingMode string
PreferredGatewayTarget string
ExplicitExecutionTarget string
ExplicitProviderID string
ExplicitModel string
ExplicitSkills []string
AllowSkillInstall bool
InstallApproval skills.InstallApproval
AvailableSkills []skills.Candidate
AvailableProviders []string
AIGatewayBaseURL string
AIGatewayAPIKey string
}
type Result struct {
ResolvedExecutionTarget string
ResolvedEndpointTarget string
ResolvedProviderID string
ResolvedModel string
ResolvedSkills []string
SkillResolutionSource string
SkillCandidates []skills.Candidate
NeedsSkillInstall bool
SkillInstallRequestID string
MemorySources []memory.Source
Unavailable bool
UnavailableCode string
UnavailableMessage string
}
type Resolver struct {
SkillFinder skills.Finder
SkillInstaller skills.Installer
MemoryService memory.Service
Classifier Classifier
}
func NewResolver() Resolver {
homeDir, _ := os.UserHomeDir()
return Resolver{
SkillFinder: skills.NewDefaultFinder(),
SkillInstaller: skills.NewDefaultInstaller(),
MemoryService: memory.NewService(homeDir),
Classifier: LLMClassifier{},
}
}
func (r Resolver) Resolve(req Request) Result {
mem := r.MemoryService.Load(req.WorkingDirectory)
availableProviders := normalizeProviders(req.AvailableProviders)
result := Result{
ResolvedModel: strings.TrimSpace(req.ExplicitModel),
MemorySources: mem.Sources,
}
result.ResolvedExecutionTarget, result.ResolvedEndpointTarget = r.resolveExecution(req, mem.Preferences)
result.ResolvedProviderID, result.Unavailable, result.UnavailableCode, result.UnavailableMessage = resolveProvider(
req,
mem.Preferences,
availableProviders,
result.ResolvedExecutionTarget,
)
if result.ResolvedModel == "" {
result.ResolvedModel = strings.TrimSpace(mem.Preferences.PreferredModel)
}
skillRequest := skills.ResolveRequest{
Prompt: req.Prompt,
ExplicitSkills: req.ExplicitSkills,
AvailableSkills: req.AvailableSkills,
AllowSkillInstall: req.AllowSkillInstall,
InstallApproval: req.InstallApproval,
}
skillResult := skills.Resolve(skillRequest, r.SkillFinder, r.SkillInstaller)
result.ResolvedSkills = skillResult.ResolvedSkills
result.SkillResolutionSource = skillResult.Source
result.SkillCandidates = skillResult.Candidates
result.NeedsSkillInstall = skillResult.NeedsInstall
result.SkillInstallRequestID = skillResult.InstallRequestID
if len(result.ResolvedSkills) == 0 && len(mem.Preferences.PreferredSkills) > 0 {
result.ResolvedSkills = append([]string(nil), mem.Preferences.PreferredSkills...)
if result.SkillResolutionSource == "" || result.SkillResolutionSource == "none" {
result.SkillResolutionSource = "local_match"
}
}
if result.SkillResolutionSource == "" {
result.SkillResolutionSource = "none"
}
if result.ResolvedExecutionTarget == "" {
if len(availableProviders) > 0 {
result.ResolvedExecutionTarget = ExecutionTargetSingleAgent
} else {
result.ResolvedExecutionTarget = ExecutionTargetGateway
}
}
if result.ResolvedEndpointTarget == "" {
if result.ResolvedExecutionTarget == ExecutionTargetGateway {
result.ResolvedEndpointTarget = normalizeGatewayTarget(req.PreferredGatewayTarget)
} else {
result.ResolvedEndpointTarget = EndpointTargetSingleAgent
}
}
return result
}
func (r Resolver) resolveExecution(req Request, prefs memory.Preferences) (string, string) {
explicit := strings.TrimSpace(req.ExplicitExecutionTarget)
if strings.EqualFold(strings.TrimSpace(req.RoutingMode), RoutingModeExplicit) && explicit != "" {
return mapExplicitTarget(explicit)
}
prompt := normalize(req.Prompt)
localTask := looksLocal(prompt)
onlineTask := looksOnline(prompt)
complexTask := looksComplex(prompt)
switch {
case localTask && complexTask:
return ExecutionTargetMultiAgent, EndpointTargetSingleAgent
case onlineTask && complexTask:
return ExecutionTargetMultiAgent, EndpointTargetSingleAgent
case localTask:
return ExecutionTargetSingleAgent, EndpointTargetSingleAgent
case onlineTask:
return ExecutionTargetGateway, normalizeGatewayTarget(req.PreferredGatewayTarget)
case complexTask:
return ExecutionTargetMultiAgent, EndpointTargetSingleAgent
}
switch normalizeExecutionTarget(r.classify(req)) {
case ExecutionTargetGateway:
return ExecutionTargetGateway, normalizeGatewayTarget(req.PreferredGatewayTarget)
case ExecutionTargetMultiAgent:
return ExecutionTargetMultiAgent, EndpointTargetSingleAgent
case ExecutionTargetSingleAgent:
return ExecutionTargetSingleAgent, EndpointTargetSingleAgent
}
switch normalizeExecutionTarget(strings.TrimSpace(prefs.PreferredRoute)) {
case ExecutionTargetGateway:
return ExecutionTargetGateway, normalizeGatewayTarget(req.PreferredGatewayTarget)
case ExecutionTargetMultiAgent:
return ExecutionTargetMultiAgent, EndpointTargetSingleAgent
case ExecutionTargetSingleAgent:
if len(normalizeProviders(req.AvailableProviders)) > 0 {
return ExecutionTargetSingleAgent, EndpointTargetSingleAgent
}
}
if len(normalizeProviders(req.AvailableProviders)) > 0 {
return ExecutionTargetSingleAgent, EndpointTargetSingleAgent
}
return ExecutionTargetGateway, normalizeGatewayTarget(req.PreferredGatewayTarget)
}
func (r Resolver) classify(req Request) string {
if r.Classifier == nil {
return ""
}
return normalizeExecutionTarget(r.Classifier.Classify(ClassificationRequest{
Prompt: req.Prompt,
AIGatewayBaseURL: req.AIGatewayBaseURL,
AIGatewayAPIKey: req.AIGatewayAPIKey,
}))
}
func mapExplicitTarget(value string) (string, string) {
switch strings.TrimSpace(value) {
case EndpointTargetLocal:
return ExecutionTargetGateway, EndpointTargetLocal
case EndpointTargetRemote:
return ExecutionTargetGateway, EndpointTargetRemote
case "multiAgent", ExecutionTargetMultiAgent:
return ExecutionTargetMultiAgent, EndpointTargetSingleAgent
case EndpointTargetSingleAgent, ExecutionTargetSingleAgent:
return ExecutionTargetSingleAgent, EndpointTargetSingleAgent
default:
return ExecutionTargetSingleAgent, EndpointTargetSingleAgent
}
}
func normalizeGatewayTarget(value string) string {
switch strings.TrimSpace(value) {
case EndpointTargetLocal, "":
return EndpointTargetLocal
default:
return EndpointTargetRemote
}
}
func resolveProvider(
req Request,
prefs memory.Preferences,
availableProviders []string,
executionTarget string,
) (string, bool, string, string) {
explicitProviderID := normalize(strings.TrimSpace(req.ExplicitProviderID))
if explicitProviderID != "" {
if containsProvider(availableProviders, explicitProviderID) {
return explicitProviderID, false, "", ""
}
return "", true, "PROVIDER_UNAVAILABLE", "explicit provider is unavailable"
}
if executionTarget != ExecutionTargetSingleAgent {
preferredProvider := normalize(strings.TrimSpace(prefs.Provider))
if containsProvider(availableProviders, preferredProvider) {
return preferredProvider, false, "", ""
}
return "", false, "", ""
}
preferredProvider := normalize(strings.TrimSpace(prefs.Provider))
if containsProvider(availableProviders, preferredProvider) {
return preferredProvider, false, "", ""
}
if len(availableProviders) > 0 {
return availableProviders[0], false, "", ""
}
return "", true, "PROVIDER_UNAVAILABLE", "no single-agent provider is available"
}
func normalizeProviders(values []string) []string {
if len(values) == 0 {
return nil
}
unique := make(map[string]struct{}, len(values))
normalized := make([]string, 0, len(values))
for _, value := range values {
providerID := normalize(value)
if providerID == "" {
continue
}
if _, ok := unique[providerID]; ok {
continue
}
unique[providerID] = struct{}{}
normalized = append(normalized, providerID)
}
sort.Strings(normalized)
return normalized
}
func containsProvider(values []string, want string) bool {
want = normalize(want)
if want == "" {
return false
}
for _, value := range values {
if normalize(value) == want {
return true
}
}
return false
}
func looksLocal(prompt string) bool {
return containsAny(prompt, []string{
"ppt", "pptx", "powerpoint", "word", "docx", "excel", "xlsx", "pdf",
"image-resizer", "resize image", "compress image", "crop image",
})
}
func looksOnline(prompt string) bool {
return containsAny(prompt, []string{
"image-cog", "wan", "video-translator", "browser", "search", "news",
"资讯采集", "跨浏览器", "文生图", "文生视频", "图生视频", "视频翻译",
"translate video", "dub video", "subtitles",
})
}
func looksComplex(prompt string) bool {
strongSignals := containsAny(prompt, []string{
"multiple deliverables", "multiple outputs", "多个产物", "多个输出",
"审阅", "复核", "汇编", "end-to-end", "end to end",
})
if strongSignals {
return true
}
reviewSignals := containsAny(prompt, []string{
"review", "audit", "verify", "summarize", "compare",
"审阅", "复核", "汇总", "对比", "整理", "整合", "汇编",
})
multiStepSignals := containsAny(prompt, []string{
"workflow", "pipeline", "step by step", "multi-step", "collect and",
"analyze and", "review and", "compare and", "summarize and",
"先", "然后", "之后",
})
structuredOutputSignals := containsAny(prompt, []string{
"report", "memo", "table", "spreadsheet", "document", "deck", "slides",
"presentation", "报告", "总结", "表格", "文档", "演示",
})
onlineCollectionSignals := containsAny(prompt, []string{
"browser", "search", "news", "research", "crawl", "scrape",
"跨浏览器", "搜索", "资讯", "采集", "检索",
})
score := 0
if reviewSignals {
score++
}
if multiStepSignals {
score++
}
if structuredOutputSignals {
score++
}
if onlineCollectionSignals && structuredOutputSignals {
return true
}
return score >= 2
}
func containsAny(haystack string, needles []string) bool {
for _, needle := range needles {
if strings.Contains(haystack, normalize(needle)) {
return true
}
}
return false
}
func normalize(value string) string {
return strings.ToLower(strings.TrimSpace(value))
}
func normalizeExecutionTarget(value string) string {
switch normalize(value) {
case ExecutionTargetGatewayChat:
return ExecutionTargetGateway
default:
return normalize(value)
}
}

View File

@ -0,0 +1,136 @@
package router
import (
"testing"
"xworkmate-bridge/internal/memory"
"xworkmate-bridge/internal/skills"
)
type fakeClassifier string
func (f fakeClassifier) Classify(req ClassificationRequest) string {
return string(f)
}
func TestResolveExplicitTargetOverridesAuto(t *testing.T) {
resolver := Resolver{
SkillFinder: skills.StaticFinder{},
SkillInstaller: nil,
MemoryService: memory.Service{},
}
result := resolver.Resolve(Request{
Prompt: "search the web and summarize results",
RoutingMode: RoutingModeExplicit,
ExplicitExecutionTarget: "singleAgent",
ExplicitProviderID: "codex",
ExplicitModel: "gpt-5.4",
AvailableProviders: []string{"codex"},
})
if result.ResolvedExecutionTarget != ExecutionTargetSingleAgent {
t.Fatalf("expected explicit single-agent route, got %#v", result)
}
if result.ResolvedEndpointTarget != EndpointTargetSingleAgent {
t.Fatalf("expected singleAgent endpoint target, got %#v", result)
}
if result.ResolvedProviderID != "codex" || result.ResolvedModel != "gpt-5.4" {
t.Fatalf("unexpected explicit provider/model: %#v", result)
}
}
func TestResolveExplicitProviderRequiresAvailability(t *testing.T) {
resolver := Resolver{
SkillFinder: skills.StaticFinder{},
SkillInstaller: nil,
MemoryService: memory.Service{},
}
result := resolver.Resolve(Request{
Prompt: "search the web and summarize results",
RoutingMode: RoutingModeExplicit,
ExplicitExecutionTarget: "singleAgent",
ExplicitProviderID: "codex",
})
if !result.Unavailable {
t.Fatalf("expected explicit provider to be unavailable without synced catalog, got %#v", result)
}
if result.UnavailableCode != "PROVIDER_UNAVAILABLE" {
t.Fatalf("expected PROVIDER_UNAVAILABLE, got %#v", result)
}
}
func TestResolveAutoLocalTaskToSingleAgent(t *testing.T) {
resolver := Resolver{
SkillFinder: skills.StaticFinder{},
SkillInstaller: nil,
MemoryService: memory.Service{},
}
result := resolver.Resolve(Request{
Prompt: "create a PowerPoint deck from this outline",
})
if result.ResolvedExecutionTarget != ExecutionTargetSingleAgent {
t.Fatalf("expected single-agent route, got %#v", result)
}
}
func TestResolveAutoOnlineTaskToGateway(t *testing.T) {
resolver := Resolver{
SkillFinder: skills.StaticFinder{},
SkillInstaller: nil,
MemoryService: memory.Service{},
}
result := resolver.Resolve(Request{
Prompt: "跨浏览器执行并搜索最新资讯",
PreferredGatewayTarget: EndpointTargetLocal,
})
if result.ResolvedExecutionTarget != ExecutionTargetGateway {
t.Fatalf("expected gateway route, got %#v", result)
}
if result.ResolvedEndpointTarget != EndpointTargetLocal {
t.Fatalf("expected local gateway target, got %#v", result)
}
}
func TestResolveComplexTaskUpgradesToMultiAgent(t *testing.T) {
resolver := Resolver{
SkillFinder: skills.StaticFinder{},
SkillInstaller: nil,
MemoryService: memory.Service{},
}
result := resolver.Resolve(Request{
Prompt: "analyze these files, review the output, and summarize multiple deliverables",
})
if result.ResolvedExecutionTarget != ExecutionTargetMultiAgent {
t.Fatalf("expected multi-agent route, got %#v", result)
}
}
func TestResolveUsesClassifierForBoundarySamples(t *testing.T) {
resolver := Resolver{
SkillFinder: skills.StaticFinder{},
SkillInstaller: nil,
MemoryService: memory.Service{},
Classifier: fakeClassifier(ExecutionTargetGateway),
}
result := resolver.Resolve(Request{
Prompt: "help me handle this ambiguous request",
PreferredGatewayTarget: EndpointTargetLocal,
})
if result.ResolvedExecutionTarget != ExecutionTargetGateway {
t.Fatalf("expected classifier to resolve gateway route, got %#v", result)
}
if result.ResolvedEndpointTarget != EndpointTargetLocal {
t.Fatalf("expected local endpoint target, got %#v", result)
}
}

View File

@ -0,0 +1,37 @@
package service
import (
"context"
"errors"
"strings"
)
var ErrInvalidCredentials = errors.New("invalid credentials")
type AuthRepository interface {
Verify(ctx context.Context, username, password string) (bool, error)
}
type AuthService struct {
repo AuthRepository
}
func NewAuthService(repo AuthRepository) *AuthService {
return &AuthService{repo: repo}
}
func (s *AuthService) Authenticate(ctx context.Context, username, password string) error {
username = strings.TrimSpace(username)
password = strings.TrimSpace(password)
if username == "" || password == "" {
return ErrInvalidCredentials
}
ok, err := s.repo.Verify(ctx, username, password)
if err != nil {
return err
}
if !ok {
return ErrInvalidCredentials
}
return nil
}

View File

@ -0,0 +1,55 @@
package service
import (
"context"
"errors"
"testing"
)
type fakeAuthRepo struct {
verify func(ctx context.Context, username, password string) (bool, error)
}
func (f fakeAuthRepo) Verify(ctx context.Context, username, password string) (bool, error) {
return f.verify(ctx, username, password)
}
func TestAuthenticateRejectsBlankValues(t *testing.T) {
svc := NewAuthService(fakeAuthRepo{
verify: func(ctx context.Context, username, password string) (bool, error) {
return true, nil
},
})
if err := svc.Authenticate(context.Background(), " ", "secret"); !errors.Is(err, ErrInvalidCredentials) {
t.Fatalf("expected invalid credentials, got %v", err)
}
}
func TestAuthenticateRejectsFailedVerification(t *testing.T) {
svc := NewAuthService(fakeAuthRepo{
verify: func(ctx context.Context, username, password string) (bool, error) {
if username != "alice" || password != "secret" {
t.Fatalf("unexpected credentials: %q %q", username, password)
}
return false, nil
},
})
if err := svc.Authenticate(context.Background(), "alice", "secret"); !errors.Is(err, ErrInvalidCredentials) {
t.Fatalf("expected invalid credentials, got %v", err)
}
}
func TestAuthenticateReturnsRepoError(t *testing.T) {
wanted := errors.New("boom")
svc := NewAuthService(fakeAuthRepo{
verify: func(ctx context.Context, username, password string) (bool, error) {
return false, wanted
},
})
if err := svc.Authenticate(context.Background(), "alice", "secret"); !errors.Is(err, wanted) {
t.Fatalf("expected repo error, got %v", err)
}
}

View File

@ -0,0 +1,81 @@
package shared
import (
"fmt"
"os"
"strings"
)
func NormalizeBaseURL(raw string) string {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return "https://api.openai.com/v1"
}
if strings.HasSuffix(trimmed, "/v1") {
return trimmed
}
return strings.TrimRight(trimmed, "/") + "/v1"
}
func EnvOrDefault(key, fallback string) string {
value := strings.TrimSpace(os.Getenv(key))
if value == "" {
return fallback
}
return value
}
func StringArg(arguments map[string]any, key, fallback string) string {
if arguments == nil {
return fallback
}
value, ok := arguments[key]
if !ok {
return fallback
}
text := strings.TrimSpace(fmt.Sprint(value))
if text == "" || text == "<nil>" {
return fallback
}
return text
}
func ListArg(arguments map[string]any, key string) []any {
if arguments == nil {
return nil
}
raw, ok := arguments[key]
if !ok || raw == nil {
return nil
}
if values, ok := raw.([]any); ok {
return values
}
if values, ok := raw.([]interface{}); ok {
return values
}
return nil
}
func IntArg(raw string, fallback int) int {
var parsed int
if _, err := fmt.Sscanf(raw, "%d", &parsed); err != nil || parsed <= 0 {
return fallback
}
return parsed
}
func BoolArg(raw string, fallback bool) bool {
trimmed := strings.TrimSpace(strings.ToLower(raw))
if trimmed == "" {
return fallback
}
switch trimmed {
case "1", "true", "yes", "on":
return true
case "0", "false", "no", "off":
return false
default:
return fallback
}
}

108
internal/shared/rpc.go Normal file
View File

@ -0,0 +1,108 @@
package shared
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
)
type RPCRequest struct {
JSONRPC string `json:"jsonrpc,omitempty"`
ID any `json:"id,omitempty"`
Method string `json:"method,omitempty"`
Params map[string]any `json:"params,omitempty"`
}
type RPCError struct {
Code int `json:"code"`
Message string `json:"message"`
}
type ToolCallParams struct {
Name string `json:"name"`
Arguments map[string]any `json:"arguments"`
}
func DecodeRPCRequest(payload []byte) (RPCRequest, error) {
var request RPCRequest
if err := json.Unmarshal(payload, &request); err != nil {
return RPCRequest{}, fmt.Errorf("invalid json: %w", err)
}
if strings.TrimSpace(request.Method) == "" {
return RPCRequest{}, errors.New("missing method")
}
if request.Params == nil {
request.Params = map[string]any{}
}
return request, nil
}
func WriteSSE(w http.ResponseWriter, payload map[string]any) {
encoded, _ := json.Marshal(payload)
_, _ = fmt.Fprintf(w, "data: %s\n\n", encoded)
}
func ResultEnvelope(id any, result map[string]any) map[string]any {
return map[string]any{
"jsonrpc": "2.0",
"id": id,
"result": result,
}
}
func ErrorEnvelope(id any, code int, message string) map[string]any {
return map[string]any{
"jsonrpc": "2.0",
"id": id,
"error": map[string]any{
"code": code,
"message": message,
},
}
}
func NotificationEnvelope(method string, params map[string]any) map[string]any {
return map[string]any{
"jsonrpc": "2.0",
"method": method,
"params": params,
}
}
func ErrorResponse(id any, code int, message string) map[string]any {
return map[string]any{
"jsonrpc": "2.0",
"id": id,
"error": map[string]any{
"code": code,
"message": message,
},
}
}
func ToolTextResult(id any, content string) map[string]any {
return map[string]any{
"jsonrpc": "2.0",
"id": id,
"result": map[string]any{
"content": []map[string]any{
{"type": "text", "text": content},
},
},
}
}
func ToolErrorResult(id any, err error) map[string]any {
return map[string]any{
"jsonrpc": "2.0",
"id": id,
"result": map[string]any{
"content": []map[string]any{
{"type": "text", "text": fmt.Sprintf("Error: %v", err)},
},
"isError": true,
},
}
}

397
internal/shared/tools.go Normal file
View File

@ -0,0 +1,397 @@
package shared
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"os/exec"
"sort"
"strings"
"time"
)
func DetectACPProviders() []string {
candidates := []struct {
provider string
envKey string
binary string
}{
{provider: "codex", envKey: "ACP_CODEX_BIN", binary: "codex"},
{provider: "opencode", envKey: "ACP_OPENCODE_BIN", binary: "opencode"},
{provider: "claude", envKey: "ACP_CLAUDE_BIN", binary: "claude"},
{provider: "gemini", envKey: "ACP_GEMINI_BIN", binary: "gemini"},
}
providers := make([]string, 0, len(candidates))
for _, candidate := range candidates {
binary := strings.TrimSpace(EnvOrDefault(candidate.envKey, candidate.binary))
if binary == "" {
continue
}
if _, err := exec.LookPath(binary); err == nil {
providers = append(providers, candidate.provider)
}
}
sort.Strings(providers)
return providers
}
func RunProviderCommand(
ctx context.Context,
provider,
model,
prompt,
workingDirectory string,
) (string, error) {
command, args := ResolveProviderCommand(
provider,
model,
prompt,
workingDirectory,
)
if command == "" {
return "", fmt.Errorf("unsupported provider: %s", provider)
}
cmd := exec.CommandContext(ctx, command, args...)
if strings.TrimSpace(workingDirectory) != "" {
cmd.Dir = strings.TrimSpace(workingDirectory)
}
var stdout bytes.Buffer
var stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
if errors.Is(ctx.Err(), context.Canceled) {
return "", errors.New("run canceled")
}
message := strings.TrimSpace(stderr.String())
if message == "" {
message = err.Error()
}
return "", fmt.Errorf("%s run failed: %s", provider, message)
}
output := strings.TrimSpace(stdout.String())
if output == "" {
output = strings.TrimSpace(stderr.String())
}
if output == "" {
return "", fmt.Errorf("%s returned empty output", provider)
}
return output, nil
}
func ResolveProviderCommand(
provider,
model,
prompt,
cwd string,
) (string, []string) {
switch strings.TrimSpace(strings.ToLower(provider)) {
case "codex":
binary := strings.TrimSpace(EnvOrDefault("ACP_CODEX_BIN", "codex"))
args := []string{"exec", "--skip-git-repo-check", "--color", "never"}
if strings.TrimSpace(cwd) != "" {
args = append(args, "-C", strings.TrimSpace(cwd))
}
if strings.TrimSpace(model) != "" {
args = append(args, "-m", strings.TrimSpace(model))
}
args = append(args, prompt)
return binary, args
case "opencode":
binary := strings.TrimSpace(EnvOrDefault("ACP_OPENCODE_BIN", "opencode"))
args := []string{"run", "--format", "default"}
if strings.TrimSpace(cwd) != "" {
args = append(args, "--dir", strings.TrimSpace(cwd))
}
if strings.TrimSpace(model) != "" {
args = append(args, "-m", strings.TrimSpace(model))
}
args = append(args, prompt)
return binary, args
case "claude":
binary := strings.TrimSpace(EnvOrDefault("ACP_CLAUDE_BIN", "claude"))
if strings.TrimSpace(model) == "" {
return binary, []string{"-p", prompt}
}
return binary, []string{
"--model",
strings.TrimSpace(model),
"-p",
prompt,
}
case "gemini":
binary := strings.TrimSpace(EnvOrDefault("ACP_GEMINI_BIN", "gemini"))
if strings.TrimSpace(model) == "" {
return binary, []string{"-p", prompt}
}
return binary, []string{
"--model",
strings.TrimSpace(model),
"-p",
prompt,
}
default:
return "", nil
}
}
func AugmentPromptWithAttachments(prompt string, params map[string]any) string {
attachmentsRaw := ListArg(params, "attachments")
if len(attachmentsRaw) == 0 {
return prompt
}
lines := make([]string, 0, len(attachmentsRaw))
for _, raw := range attachmentsRaw {
entry, ok := raw.(map[string]any)
if !ok {
continue
}
name := strings.TrimSpace(StringArg(entry, "name", "attachment"))
path := strings.TrimSpace(StringArg(entry, "path", ""))
if path == "" {
continue
}
lines = append(lines, fmt.Sprintf("- %s: %s", name, path))
}
if len(lines) == 0 {
return prompt
}
var builder strings.Builder
builder.WriteString("User-selected local attachments:\n")
builder.WriteString(strings.Join(lines, "\n"))
builder.WriteString("\n\n")
builder.WriteString(prompt)
return builder.String()
}
func ComposeHistoryPrompt(history []string) string {
if len(history) == 0 {
return ""
}
var builder strings.Builder
for index, turn := range history {
builder.WriteString(fmt.Sprintf("## User Turn %d\n", index+1))
builder.WriteString(turn)
builder.WriteString("\n\n")
}
return strings.TrimSpace(builder.String())
}
func CallOpenAICompatibleCtx(
ctx context.Context,
baseURL,
apiKey,
model string,
messages []map[string]string,
) (string, error) {
payload := map[string]any{
"model": model,
"messages": messages,
"max_tokens": 4096,
"stream": false,
}
body, _ := json.Marshal(payload)
request, err := http.NewRequestWithContext(
ctx,
http.MethodPost,
strings.TrimRight(baseURL, "/")+"/chat/completions",
bytes.NewReader(body),
)
if err != nil {
return "", err
}
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Authorization", "Bearer "+apiKey)
client := &http.Client{Timeout: 120 * time.Second}
response, err := client.Do(request)
if err != nil {
return "", err
}
defer response.Body.Close()
responseBody, err := io.ReadAll(response.Body)
if err != nil {
return "", err
}
if response.StatusCode < 200 || response.StatusCode >= 300 {
return "", fmt.Errorf(
"api error %d: %s",
response.StatusCode,
strings.TrimSpace(string(responseBody)),
)
}
var decoded map[string]any
if err := json.Unmarshal(responseBody, &decoded); err != nil {
return "", err
}
choices, _ := decoded["choices"].([]any)
if len(choices) == 0 {
return "", errors.New("missing choices in response")
}
choice, _ := choices[0].(map[string]any)
message, _ := choice["message"].(map[string]any)
content := strings.TrimSpace(fmt.Sprint(message["content"]))
if content == "" || content == "<nil>" {
return "", errors.New("empty response content")
}
return content, nil
}
func HandleChatTool(arguments map[string]any) (string, error) {
apiKey := strings.TrimSpace(EnvOrDefault("LLM_API_KEY", ""))
if apiKey == "" {
return "", errors.New("LLM_API_KEY environment variable not set")
}
baseURL := NormalizeBaseURL(
EnvOrDefault("LLM_BASE_URL", "https://api.openai.com/v1"),
)
model := StringArg(arguments, "model", EnvOrDefault("LLM_MODEL", "gpt-4o"))
prompt := strings.TrimSpace(StringArg(arguments, "prompt", ""))
if prompt == "" {
return "", errors.New("prompt is required")
}
system := strings.TrimSpace(StringArg(arguments, "system", ""))
messages := make([]map[string]string, 0, 2)
if system != "" {
messages = append(messages, map[string]string{
"role": "system",
"content": system,
})
}
messages = append(messages, map[string]string{
"role": "user",
"content": prompt,
})
return CallOpenAICompatible(baseURL, apiKey, model, messages)
}
func HandleClaudeReviewTool(arguments map[string]any) (string, error) {
prompt := strings.TrimSpace(StringArg(arguments, "prompt", ""))
if prompt == "" {
return "", errors.New("prompt is required")
}
model := strings.TrimSpace(
StringArg(arguments, "model", EnvOrDefault("CLAUDE_REVIEW_MODEL", "")),
)
system := strings.TrimSpace(
StringArg(arguments, "system", EnvOrDefault("CLAUDE_REVIEW_SYSTEM", "")),
)
tools := strings.TrimSpace(
StringArg(arguments, "tools", EnvOrDefault("CLAUDE_REVIEW_TOOLS", "")),
)
timeout := IntArg(EnvOrDefault("CLAUDE_REVIEW_TIMEOUT_SEC", "600"), 600)
return RunClaudeReview(
prompt,
model,
system,
tools,
time.Duration(timeout)*time.Second,
)
}
func CallOpenAICompatible(
baseURL,
apiKey,
model string,
messages []map[string]string,
) (string, error) {
return CallOpenAICompatibleCtx(
context.Background(),
baseURL,
apiKey,
model,
messages,
)
}
func RunClaudeReview(
prompt,
model,
system,
tools string,
timeout time.Duration,
) (string, error) {
claudeBin := strings.TrimSpace(EnvOrDefault("CLAUDE_BIN", "claude"))
resolved, err := exec.LookPath(claudeBin)
if err != nil {
return "", fmt.Errorf("Claude CLI not found: %s", claudeBin)
}
args := []string{
"-p",
prompt,
"--output-format",
"json",
"--permission-mode",
"plan",
}
if model != "" {
args = append(args, "--model", model)
}
if system != "" {
args = append(args, "--system-prompt", system)
}
if tools != "" {
args = append(args, "--tools", tools)
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
cmd := exec.CommandContext(ctx, resolved, args...)
cmd.Stdin = nil
var stdout bytes.Buffer
var stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
return "", fmt.Errorf("Claude review timed out after %s", timeout)
}
message := strings.TrimSpace(stderr.String())
if message == "" {
message = err.Error()
}
return "", fmt.Errorf("Claude review failed: %s", message)
}
payload, err := ParseClaudeJSON(stdout.String())
if err != nil {
message := strings.TrimSpace(stderr.String())
if message != "" {
return "", fmt.Errorf("%v. stderr: %s", err, message)
}
return "", err
}
if isError, _ := payload["is_error"].(bool); isError {
return "", fmt.Errorf("%v", payload["result"])
}
response := strings.TrimSpace(fmt.Sprint(payload["result"]))
if response == "" || response == "<nil>" {
return "", errors.New("Claude review returned empty output")
}
return response, nil
}
func ParseClaudeJSON(raw string) (map[string]any, error) {
lines := strings.Split(raw, "\n")
for i := len(lines) - 1; i >= 0; i-- {
candidate := strings.TrimSpace(lines[i])
if candidate == "" {
continue
}
var payload map[string]any
if err := json.Unmarshal([]byte(candidate), &payload); err == nil {
return payload, nil
}
}
return nil, errors.New("Claude CLI did not return JSON output")
}

325
internal/shared/vault.go Normal file
View File

@ -0,0 +1,325 @@
package shared
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
)
type VaultKVResult struct {
Operation string `json:"operation"`
Mount string `json:"mount"`
Path string `json:"path"`
Data map[string]any `json:"data,omitempty"`
Keys []string `json:"keys,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`
}
func HandleVaultKVTool(arguments map[string]any) (string, error) {
request, err := buildVaultKVRequest(arguments)
if err != nil {
return "", err
}
result, err := executeVaultKVRequest(request)
if err != nil {
return "", err
}
encoded, err := json.MarshalIndent(result, "", " ")
if err != nil {
return "", err
}
return string(encoded), nil
}
type vaultKVRequest struct {
baseURL string
token string
namespace string
operation string
mount string
path string
data map[string]any
cas int
}
func buildVaultKVRequest(arguments map[string]any) (vaultKVRequest, error) {
baseURL := strings.TrimSpace(EnvOrDefault("VAULT_SERVER_URL", ""))
if baseURL == "" {
return vaultKVRequest{}, errors.New("VAULT_SERVER_URL environment variable not set")
}
token := strings.TrimSpace(EnvOrDefault("VAULT_SERVER_ROOT_ACCESS_TOKEN", ""))
if token == "" {
return vaultKVRequest{}, errors.New("VAULT_SERVER_ROOT_ACCESS_TOKEN environment variable not set")
}
operation := strings.ToLower(strings.TrimSpace(StringArg(arguments, "operation", "")))
if operation == "" {
return vaultKVRequest{}, errors.New("operation is required")
}
path := normalizeVaultPath(StringArg(arguments, "path", ""))
if path == "" {
return vaultKVRequest{}, errors.New("path is required")
}
data, err := vaultDataArg(arguments["data"])
if err != nil {
return vaultKVRequest{}, err
}
return vaultKVRequest{
baseURL: strings.TrimRight(baseURL, "/"),
token: token,
namespace: strings.TrimSpace(EnvOrDefault("VAULT_NAMESPACE", "")),
operation: operation,
mount: normalizeVaultMount(StringArg(arguments, "mount", "secret")),
path: path,
data: data,
cas: vaultCASArg(arguments["cas"]),
}, nil
}
func executeVaultKVRequest(request vaultKVRequest) (VaultKVResult, error) {
switch request.operation {
case "get", "read":
return vaultKVRead(request)
case "put", "write":
return vaultKVWrite(request)
case "list":
return vaultKVList(request)
case "delete":
return vaultKVDelete(request)
default:
return VaultKVResult{}, fmt.Errorf("unsupported operation: %s", request.operation)
}
}
func vaultKVRead(request vaultKVRequest) (VaultKVResult, error) {
response, err := doVaultRequest(
request,
http.MethodGet,
vaultDataURL(request.mount, request.path),
nil,
)
if err != nil {
return VaultKVResult{}, err
}
dataBlock := mapArg(response["data"])
return VaultKVResult{
Operation: "read",
Mount: request.mount,
Path: request.path,
Data: mapArg(dataBlock["data"]),
Metadata: mapArg(dataBlock["metadata"]),
}, nil
}
func vaultKVWrite(request vaultKVRequest) (VaultKVResult, error) {
if len(request.data) == 0 {
return VaultKVResult{}, errors.New("data is required for write operations")
}
payload := map[string]any{"data": request.data}
if request.cas > 0 {
payload["options"] = map[string]any{"cas": request.cas}
}
response, err := doVaultRequest(
request,
http.MethodPost,
vaultDataURL(request.mount, request.path),
payload,
)
if err != nil {
return VaultKVResult{}, err
}
return VaultKVResult{
Operation: "write",
Mount: request.mount,
Path: request.path,
Data: request.data,
Metadata: mapArg(mapArg(response["data"])["metadata"]),
}, nil
}
func vaultKVList(request vaultKVRequest) (VaultKVResult, error) {
response, err := doVaultRequest(
request,
"LIST",
vaultMetadataURL(request.mount, request.path),
nil,
)
if err != nil {
return VaultKVResult{}, err
}
dataBlock := mapArg(response["data"])
return VaultKVResult{
Operation: "list",
Mount: request.mount,
Path: request.path,
Keys: stringSliceArg(dataBlock["keys"]),
}, nil
}
func vaultKVDelete(request vaultKVRequest) (VaultKVResult, error) {
_, err := doVaultRequest(
request,
http.MethodDelete,
vaultDataURL(request.mount, request.path),
nil,
)
if err != nil {
return VaultKVResult{}, err
}
return VaultKVResult{
Operation: "delete",
Mount: request.mount,
Path: request.path,
}, nil
}
func doVaultRequest(
request vaultKVRequest,
method string,
target string,
payload map[string]any,
) (map[string]any, error) {
var body io.Reader
if payload != nil {
encoded, err := json.Marshal(payload)
if err != nil {
return nil, err
}
body = bytes.NewReader(encoded)
}
httpRequest, err := http.NewRequest(method, target, body)
if err != nil {
return nil, err
}
httpRequest.Header.Set("X-Vault-Token", request.token)
if request.namespace != "" {
httpRequest.Header.Set("X-Vault-Namespace", request.namespace)
}
if payload != nil {
httpRequest.Header.Set("Content-Type", "application/json")
}
client := &http.Client{Timeout: 30 * time.Second}
response, err := client.Do(httpRequest)
if err != nil {
return nil, err
}
defer response.Body.Close()
bodyBytes, err := io.ReadAll(response.Body)
if err != nil {
return nil, err
}
if response.StatusCode < 200 || response.StatusCode >= 300 {
return nil, fmt.Errorf(
"vault api error %d: %s",
response.StatusCode,
strings.TrimSpace(string(bodyBytes)),
)
}
if len(strings.TrimSpace(string(bodyBytes))) == 0 {
return map[string]any{}, nil
}
var decoded map[string]any
if err := json.Unmarshal(bodyBytes, &decoded); err != nil {
return nil, err
}
return decoded, nil
}
func vaultDataURL(mount, path string) string {
return fmt.Sprintf("%s/data/%s", vaultBasePath(mount), vaultPathSegments(path))
}
func vaultMetadataURL(mount, path string) string {
return fmt.Sprintf("%s/metadata/%s", vaultBasePath(mount), vaultPathSegments(path))
}
func vaultBasePath(mount string) string {
return fmt.Sprintf("%s/v1/%s", strings.TrimRight(strings.TrimSpace(EnvOrDefault("VAULT_SERVER_URL", "")), "/"), url.PathEscape(normalizeVaultMount(mount)))
}
func vaultPathSegments(path string) string {
segments := strings.Split(normalizeVaultPath(path), "/")
for index, segment := range segments {
segments[index] = url.PathEscape(segment)
}
return strings.Join(segments, "/")
}
func normalizeVaultMount(raw string) string {
trimmed := strings.Trim(strings.TrimSpace(raw), "/")
if trimmed == "" {
return "secret"
}
return trimmed
}
func normalizeVaultPath(raw string) string {
return strings.Trim(strings.TrimSpace(raw), "/")
}
func vaultDataArg(raw any) (map[string]any, error) {
if raw == nil {
return nil, nil
}
switch typed := raw.(type) {
case map[string]any:
return typed, nil
case string:
trimmed := strings.TrimSpace(typed)
if trimmed == "" {
return nil, nil
}
var decoded map[string]any
if err := json.Unmarshal([]byte(trimmed), &decoded); err != nil {
return nil, errors.New("data must be a JSON object")
}
return decoded, nil
default:
return nil, errors.New("data must be an object")
}
}
func vaultCASArg(raw any) int {
switch typed := raw.(type) {
case int:
return typed
case int64:
return int(typed)
case float64:
return int(typed)
case string:
return IntArg(typed, 0)
default:
return 0
}
}
func mapArg(raw any) map[string]any {
switch typed := raw.(type) {
case map[string]any:
return typed
default:
return map[string]any{}
}
}
func stringSliceArg(raw any) []string {
values, ok := raw.([]any)
if !ok {
return nil
}
result := make([]string, 0, len(values))
for _, value := range values {
text := strings.TrimSpace(fmt.Sprint(value))
if text == "" || text == "<nil>" {
continue
}
result = append(result, text)
}
return result
}

View File

@ -0,0 +1,142 @@
package shared
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestHandleVaultKVToolReadsSecretData(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
t.Fatalf("unexpected method: %s", r.Method)
}
if got := r.Header.Get("X-Vault-Token"); got != "root-token" {
t.Fatalf("unexpected token header: %s", got)
}
if got := r.Header.Get("X-Vault-Namespace"); got != "platform/team-a" {
t.Fatalf("unexpected namespace header: %s", got)
}
if got := r.URL.Path; got != "/v1/secret/data/apps/demo" {
t.Fatalf("unexpected request path: %s", got)
}
_ = json.NewEncoder(w).Encode(map[string]any{
"data": map[string]any{
"data": map[string]any{
"api_key": "demo-key",
},
"metadata": map[string]any{
"version": 3,
},
},
})
}))
defer server.Close()
t.Setenv("VAULT_SERVER_URL", server.URL)
t.Setenv("VAULT_SERVER_ROOT_ACCESS_TOKEN", "root-token")
t.Setenv("VAULT_NAMESPACE", "platform/team-a")
output, err := HandleVaultKVTool(map[string]any{
"operation": "read",
"path": "apps/demo",
})
if err != nil {
t.Fatalf("HandleVaultKVTool returned error: %v", err)
}
if !strings.Contains(output, `"api_key": "demo-key"`) {
t.Fatalf("expected secret data in output, got %s", output)
}
}
func TestHandleVaultKVToolWritesSecretData(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Fatalf("unexpected method: %s", r.Method)
}
if got := r.URL.Path; got != "/v1/secret/data/apps/demo" {
t.Fatalf("unexpected request path: %s", got)
}
var payload map[string]any
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
t.Fatalf("decode payload: %v", err)
}
data := mapArg(payload["data"])
if got := data["enabled"]; got != true {
t.Fatalf("unexpected data payload: %v", payload)
}
options := mapArg(payload["options"])
if got := options["cas"]; got != float64(2) {
t.Fatalf("unexpected cas payload: %v", payload)
}
_ = json.NewEncoder(w).Encode(map[string]any{
"data": map[string]any{
"metadata": map[string]any{
"version": 4,
},
},
})
}))
defer server.Close()
t.Setenv("VAULT_SERVER_URL", server.URL)
t.Setenv("VAULT_SERVER_ROOT_ACCESS_TOKEN", "root-token")
output, err := HandleVaultKVTool(map[string]any{
"operation": "write",
"path": "apps/demo",
"data": map[string]any{
"enabled": true,
},
"cas": 2,
})
if err != nil {
t.Fatalf("HandleVaultKVTool returned error: %v", err)
}
if !strings.Contains(output, `"version": 4`) {
t.Fatalf("expected metadata in output, got %s", output)
}
}
func TestHandleVaultKVToolListsSecretKeys(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "LIST" {
t.Fatalf("unexpected method: %s", r.Method)
}
if got := r.URL.Path; got != "/v1/secret/metadata/apps" {
t.Fatalf("unexpected request path: %s", got)
}
_ = json.NewEncoder(w).Encode(map[string]any{
"data": map[string]any{
"keys": []string{"demo", "prod"},
},
})
}))
defer server.Close()
t.Setenv("VAULT_SERVER_URL", server.URL)
t.Setenv("VAULT_SERVER_ROOT_ACCESS_TOKEN", "root-token")
output, err := HandleVaultKVTool(map[string]any{
"operation": "list",
"path": "apps",
})
if err != nil {
t.Fatalf("HandleVaultKVTool returned error: %v", err)
}
if !strings.Contains(output, `"demo"`) || !strings.Contains(output, `"prod"`) {
t.Fatalf("expected listed keys in output, got %s", output)
}
}
func TestHandleVaultKVToolRequiresEnvironment(t *testing.T) {
_, err := HandleVaultKVTool(map[string]any{
"operation": "read",
"path": "apps/demo",
})
if err == nil || !strings.Contains(err.Error(), "VAULT_SERVER_URL") {
t.Fatalf("expected missing environment error, got %v", err)
}
}

View File

@ -0,0 +1,209 @@
package skills
import (
"context"
"encoding/json"
"fmt"
"os/exec"
"strings"
"time"
"xworkmate-bridge/internal/shared"
)
type ChainFinder struct {
Primary Finder
Fallback Finder
}
func (f ChainFinder) Find(prompt string) []Candidate {
if f.Primary != nil {
if resolved := dedupeCandidates(f.Primary.Find(prompt)); len(resolved) > 0 {
return resolved
}
}
if f.Fallback == nil {
return nil
}
return dedupeCandidates(f.Fallback.Find(prompt))
}
type CommandFinder struct {
Binary string
}
func (f CommandFinder) Find(prompt string) []Candidate {
payload, ok := runSkillCommand(
strings.TrimSpace(f.Binary),
map[string]any{"prompt": strings.TrimSpace(prompt)},
)
if !ok {
return nil
}
return parseCandidatesPayload(payload)
}
type CommandInstaller struct {
Binary string
}
func (i CommandInstaller) Install(candidates []Candidate) ([]Candidate, error) {
payload, ok := runSkillCommand(
strings.TrimSpace(i.Binary),
map[string]any{
"candidates": routingCandidatesPayload(candidates),
},
)
if !ok {
return nil, nil
}
return parseCandidatesPayload(payload), nil
}
func NewDefaultFinder() Finder {
return ChainFinder{
Primary: CommandFinder{
Binary: strings.TrimSpace(shared.EnvOrDefault("ACP_FIND_SKILLS_BIN", "")),
},
Fallback: StaticFinder{},
}
}
func NewDefaultInstaller() Installer {
return CommandInstaller{
Binary: strings.TrimSpace(shared.EnvOrDefault("ACP_INSTALL_SKILL_BIN", "")),
}
}
func runSkillCommand(binary string, payload map[string]any) (map[string]any, bool) {
if binary == "" {
return nil, false
}
if _, err := exec.LookPath(binary); err != nil {
return nil, false
}
body, err := json.Marshal(payload)
if err != nil {
return nil, false
}
ctx, cancel := context.WithTimeout(context.Background(), 45*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, binary)
cmd.Stdin = strings.NewReader(string(body))
output, err := cmd.Output()
if err != nil {
return nil, false
}
var decoded map[string]any
if err := json.Unmarshal(output, &decoded); err == nil {
return decoded, true
}
var list []map[string]any
if err := json.Unmarshal(output, &list); err == nil {
return map[string]any{"candidates": list}, true
}
return nil, false
}
func parseCandidatesPayload(payload map[string]any) []Candidate {
if len(payload) == 0 {
return nil
}
if raw, ok := payload["candidates"]; ok {
return parseCandidates(raw)
}
if raw, ok := payload["skills"]; ok {
return parseCandidates(raw)
}
return parseCandidates(payload)
}
func parseCandidates(raw any) []Candidate {
switch typed := raw.(type) {
case []any:
result := make([]Candidate, 0, len(typed))
for _, item := range typed {
entry := toMap(item)
if len(entry) == 0 {
continue
}
result = append(result, Candidate{
ID: strings.TrimSpace(stringValue(entry["id"])),
Label: strings.TrimSpace(stringValue(entry["label"])),
Description: strings.TrimSpace(stringValue(entry["description"])),
Installed: boolValue(entry["installed"]),
})
}
return dedupeCandidates(result)
case []map[string]any:
values := make([]any, 0, len(typed))
for _, item := range typed {
values = append(values, item)
}
return parseCandidates(values)
case map[string]any:
entry := Candidate{
ID: strings.TrimSpace(stringValue(typed["id"])),
Label: strings.TrimSpace(stringValue(typed["label"])),
Description: strings.TrimSpace(stringValue(typed["description"])),
Installed: boolValue(typed["installed"]),
}
if entry.ID == "" && entry.Label == "" {
return nil
}
return []Candidate{entry}
default:
return nil
}
}
func routingCandidatesPayload(candidates []Candidate) []map[string]any {
result := make([]map[string]any, 0, len(candidates))
for _, candidate := range candidates {
result = append(result, map[string]any{
"id": strings.TrimSpace(candidate.ID),
"label": strings.TrimSpace(candidate.Label),
"description": strings.TrimSpace(candidate.Description),
"installed": candidate.Installed,
})
}
return result
}
func toMap(value any) map[string]any {
if typed, ok := value.(map[string]any); ok {
return typed
}
if typed, ok := value.(map[string]interface{}); ok {
return typed
}
return nil
}
func stringValue(value any) string {
if value == nil {
return ""
}
switch typed := value.(type) {
case string:
return strings.TrimSpace(typed)
default:
return strings.TrimSpace(fmt.Sprint(value))
}
}
func boolValue(value any) bool {
switch typed := value.(type) {
case bool:
return typed
case string:
normalized := strings.ToLower(strings.TrimSpace(typed))
return normalized == "true" || normalized == "1" || normalized == "yes"
case float64:
return typed != 0
case int:
return typed != 0
default:
return false
}
}

353
internal/skills/resolver.go Normal file
View File

@ -0,0 +1,353 @@
package skills
import (
"fmt"
"sort"
"strings"
)
type Candidate struct {
ID string
Label string
Description string
Installed bool
}
type Finder interface {
Find(prompt string) []Candidate
}
type Installer interface {
Install(candidates []Candidate) ([]Candidate, error)
}
type ResolveRequest struct {
Prompt string
ExplicitSkills []string
AvailableSkills []Candidate
AllowSkillInstall bool
InstallApproval InstallApproval
}
type InstallApproval struct {
RequestID string
ApprovedSkillKeys []string
}
type ResolveResult struct {
ResolvedSkills []string
Candidates []Candidate
Source string
NeedsInstall bool
InstallRequestID string
}
type StaticFinder struct{}
func (StaticFinder) Find(prompt string) []Candidate {
haystack := normalize(prompt)
candidates := make([]Candidate, 0, 4)
for _, entry := range builtinCatalog {
if !containsAny(haystack, entry.keywords) {
continue
}
candidates = append(candidates, Candidate{
ID: entry.id,
Label: entry.label,
Installed: false,
})
}
return dedupeCandidates(candidates)
}
func Resolve(req ResolveRequest, finder Finder, installer Installer) ResolveResult {
available := dedupeCandidates(req.AvailableSkills)
explicit := normalizeList(req.ExplicitSkills)
if len(explicit) > 0 {
return ResolveResult{
ResolvedSkills: explicit,
Source: "local_match",
}
}
localMatches := matchLocalSkills(req.Prompt, available)
if len(localMatches) > 0 {
return ResolveResult{
ResolvedSkills: localMatches,
Source: "local_match",
}
}
if finder == nil {
return ResolveResult{Source: "none"}
}
fallback := dedupeCandidates(finder.Find(req.Prompt))
if len(fallback) == 0 {
return ResolveResult{Source: "none"}
}
installed := make([]string, 0, len(fallback))
uninstalled := make([]Candidate, 0, len(fallback))
for _, candidate := range fallback {
if matched := findInstalledMatch(candidate, available); matched != "" {
installed = append(installed, matched)
continue
}
uninstalled = append(uninstalled, candidate)
}
if len(installed) > 0 {
return ResolveResult{
ResolvedSkills: dedupeStrings(installed),
Candidates: fallback,
Source: "find_skills",
}
}
installRequestID := buildInstallRequestID(uninstalled)
if shouldInstallApprovedCandidates(req, installRequestID) &&
installer != nil &&
len(uninstalled) > 0 {
approvedCandidates := filterApprovedCandidates(
uninstalled,
req.InstallApproval.ApprovedSkillKeys,
)
if len(approvedCandidates) == 0 {
return ResolveResult{
Candidates: fallback,
Source: "find_skills",
NeedsInstall: true,
InstallRequestID: installRequestID,
}
}
installedCandidates, err := installer.Install(approvedCandidates)
if err == nil && len(installedCandidates) > 0 {
mergedAvailable := dedupeCandidates(
append(append([]Candidate(nil), available...), installedCandidates...),
)
if resolved := installedMatches(fallback, mergedAvailable); len(resolved) > 0 {
return ResolveResult{
ResolvedSkills: resolved,
Candidates: dedupeCandidates(
append(append([]Candidate(nil), fallback...), installedCandidates...),
),
Source: "find_skills",
}
}
}
}
return ResolveResult{
Candidates: fallback,
Source: "find_skills",
NeedsInstall: len(uninstalled) > 0,
InstallRequestID: installRequestID,
}
}
func shouldInstallApprovedCandidates(
req ResolveRequest,
expectedRequestID string,
) bool {
if !req.AllowSkillInstall || expectedRequestID == "" {
return false
}
if strings.TrimSpace(req.InstallApproval.RequestID) != expectedRequestID {
return false
}
return len(dedupeStrings(req.InstallApproval.ApprovedSkillKeys)) > 0
}
func filterApprovedCandidates(
candidates []Candidate,
approvedSkillKeys []string,
) []Candidate {
if len(candidates) == 0 {
return nil
}
approved := make(map[string]struct{}, len(approvedSkillKeys))
for _, key := range approvedSkillKeys {
normalized := normalize(key)
if normalized == "" {
continue
}
approved[normalized] = struct{}{}
}
filtered := make([]Candidate, 0, len(candidates))
for _, candidate := range candidates {
if _, ok := approved[normalize(candidate.ID)]; ok {
filtered = append(filtered, candidate)
}
}
return filtered
}
func buildInstallRequestID(candidates []Candidate) string {
if len(candidates) == 0 {
return ""
}
keys := make([]string, 0, len(candidates))
for _, candidate := range candidates {
key := normalize(candidate.ID)
if key == "" {
key = normalize(candidate.Label)
}
if key != "" {
keys = append(keys, key)
}
}
if len(keys) == 0 {
return ""
}
sort.Strings(keys)
return "skill-install:" + strings.Join(keys, ",")
}
type builtinSkill struct {
id string
label string
keywords []string
}
var builtinCatalog = []builtinSkill{
{id: "pptx", label: "pptx", keywords: []string{"ppt", "pptx", "powerpoint", "slides", "幻灯片", "演示文稿"}},
{id: "docx", label: "docx", keywords: []string{"docx", "word", "word document", "文档"}},
{id: "xlsx", label: "xlsx", keywords: []string{"xlsx", "excel", "spreadsheet", "表格", "工作表"}},
{id: "pdf", label: "pdf", keywords: []string{"pdf", "表单", "merge pdf", "split pdf"}},
{id: "image-resizer", label: "image-resizer", keywords: []string{"image-resizer", "resize image", "compress image", "crop image", "批量图片"}},
{id: "image-cog", label: "image-cog", keywords: []string{"image-cog", "文生图", "图生图", "角色一致性"}},
{id: "image-video-generation-editting", label: "image-video-generation-editting", keywords: []string{"wan", "文生视频", "图生视频", "视频生成", "视频编辑"}},
{id: "video-translator", label: "video-translator", keywords: []string{"video-translator", "视频翻译", "配音", "字幕翻译", "translate video", "dub video", "subtitles"}},
{id: "browser-automation", label: "Browser Automation", keywords: []string{"browser", "跨浏览器", "浏览器", "web scraping", "资讯采集", "search", "搜索", "news", "资讯"}},
{id: "find-skills", label: "find_skills", keywords: []string{"find skills", "find_skills", "技能包", "skill package"}},
}
func matchLocalSkills(prompt string, available []Candidate) []string {
if len(available) == 0 {
return nil
}
haystack := normalize(prompt)
if haystack == "" {
return nil
}
matches := make([]string, 0, len(available))
for _, candidate := range available {
keywords := candidateKeywords(candidate)
if containsAny(haystack, keywords) {
matches = append(matches, candidateLabel(candidate))
}
}
return dedupeStrings(matches)
}
func candidateKeywords(candidate Candidate) []string {
base := []string{
normalize(candidate.ID),
normalize(candidate.Label),
}
text := normalize(strings.Join([]string{candidate.ID, candidate.Label}, " "))
for _, entry := range builtinCatalog {
if containsAny(text, []string{normalize(entry.id), normalize(entry.label)}) {
base = append(base, entry.keywords...)
}
}
return dedupeStrings(base)
}
func findInstalledMatch(candidate Candidate, available []Candidate) string {
want := candidateKeywords(candidate)
for _, item := range available {
if containsAny(strings.Join(candidateKeywords(item), " "), want) {
return candidateLabel(item)
}
}
return ""
}
func installedMatches(candidates []Candidate, available []Candidate) []string {
resolved := make([]string, 0, len(candidates))
for _, candidate := range candidates {
if matched := findInstalledMatch(candidate, available); matched != "" {
resolved = append(resolved, matched)
}
}
return dedupeStrings(resolved)
}
func candidateLabel(candidate Candidate) string {
if strings.TrimSpace(candidate.Label) != "" {
return strings.TrimSpace(candidate.Label)
}
return strings.TrimSpace(candidate.ID)
}
func containsAny(haystack string, needles []string) bool {
for _, needle := range needles {
if strings.TrimSpace(needle) == "" {
continue
}
if strings.Contains(haystack, normalize(needle)) {
return true
}
}
return false
}
func normalize(value string) string {
return strings.ToLower(strings.TrimSpace(value))
}
func normalizeList(values []string) []string {
result := make([]string, 0, len(values))
for _, value := range values {
trimmed := strings.TrimSpace(value)
if trimmed == "" {
continue
}
result = append(result, trimmed)
}
return dedupeStrings(result)
}
func dedupeStrings(values []string) []string {
if len(values) == 0 {
return nil
}
seen := make(map[string]string, len(values))
ordered := make([]string, 0, len(values))
for _, value := range values {
trimmed := strings.TrimSpace(value)
if trimmed == "" {
continue
}
key := normalize(trimmed)
if _, ok := seen[key]; ok {
continue
}
seen[key] = trimmed
ordered = append(ordered, trimmed)
}
return ordered
}
func dedupeCandidates(values []Candidate) []Candidate {
if len(values) == 0 {
return nil
}
seen := make(map[string]struct{}, len(values))
ordered := make([]Candidate, 0, len(values))
for _, candidate := range values {
key := normalize(fmt.Sprintf("%s|%s", candidate.ID, candidate.Label))
if key == "|" {
continue
}
if _, ok := seen[key]; ok {
continue
}
seen[key] = struct{}{}
ordered = append(ordered, candidate)
}
return ordered
}

View File

@ -0,0 +1,127 @@
package skills
import "testing"
type fakeFinder []Candidate
func (f fakeFinder) Find(prompt string) []Candidate {
return append([]Candidate(nil), f...)
}
type fakeInstaller struct {
installed []Candidate
}
func (f fakeInstaller) Install(candidates []Candidate) ([]Candidate, error) {
return append([]Candidate(nil), f.installed...), nil
}
func TestResolvePrefersExplicitSkills(t *testing.T) {
result := Resolve(ResolveRequest{
Prompt: "make a deck",
ExplicitSkills: []string{"pptx"},
AvailableSkills: []Candidate{
{ID: "pptx", Label: "pptx", Installed: true},
},
}, StaticFinder{}, nil)
if result.Source != "local_match" {
t.Fatalf("expected local_match source, got %q", result.Source)
}
if len(result.ResolvedSkills) != 1 || result.ResolvedSkills[0] != "pptx" {
t.Fatalf("unexpected resolved skills: %#v", result.ResolvedSkills)
}
}
func TestResolveUsesInstalledLocalMatchesBeforeFallback(t *testing.T) {
result := Resolve(ResolveRequest{
Prompt: "create a PowerPoint presentation from this brief",
AvailableSkills: []Candidate{
{ID: "pptx", Label: "PPTX", Installed: true},
{ID: "docx", Label: "DOCX", Installed: true},
},
}, StaticFinder{}, nil)
if result.Source != "local_match" {
t.Fatalf("expected local_match source, got %q", result.Source)
}
if len(result.ResolvedSkills) != 1 || result.ResolvedSkills[0] != "PPTX" {
t.Fatalf("unexpected resolved skills: %#v", result.ResolvedSkills)
}
}
func TestResolveFallsBackToFindSkillsCandidates(t *testing.T) {
result := Resolve(ResolveRequest{
Prompt: "translate and dub this video with subtitles",
AvailableSkills: []Candidate{{ID: "docx", Label: "docx", Installed: true}},
AllowSkillInstall: false,
}, StaticFinder{}, nil)
if result.Source != "find_skills" {
t.Fatalf("expected find_skills source, got %q", result.Source)
}
if len(result.ResolvedSkills) != 0 {
t.Fatalf("expected no installed resolved skills, got %#v", result.ResolvedSkills)
}
if !result.NeedsInstall {
t.Fatalf("expected install recommendation")
}
if len(result.Candidates) == 0 || result.Candidates[0].ID != "video-translator" {
t.Fatalf("unexpected fallback candidates: %#v", result.Candidates)
}
}
func TestResolveInstallsMissingSkillsWhenAuthorized(t *testing.T) {
initial := Resolve(
ResolveRequest{
Prompt: "translate and dub this video with subtitles",
AvailableSkills: []Candidate{{ID: "docx", Label: "docx", Installed: true}},
AllowSkillInstall: true,
},
fakeFinder{
{ID: "video-translator", Label: "video-translator", Installed: false},
},
fakeInstaller{
installed: []Candidate{
{ID: "video-translator", Label: "video-translator", Installed: true},
},
},
)
if !initial.NeedsInstall {
t.Fatalf("expected install approval flow to pause first, got %#v", initial)
}
if initial.InstallRequestID == "" {
t.Fatalf("expected install request id, got %#v", initial)
}
result := Resolve(
ResolveRequest{
Prompt: "translate and dub this video with subtitles",
AvailableSkills: []Candidate{{ID: "docx", Label: "docx", Installed: true}},
AllowSkillInstall: true,
InstallApproval: InstallApproval{
RequestID: initial.InstallRequestID,
ApprovedSkillKeys: []string{"video-translator"},
},
},
fakeFinder{
{ID: "video-translator", Label: "video-translator", Installed: false},
},
fakeInstaller{
installed: []Candidate{
{ID: "video-translator", Label: "video-translator", Installed: true},
},
},
)
if result.Source != "find_skills" {
t.Fatalf("expected find_skills source, got %q", result.Source)
}
if result.NeedsInstall {
t.Fatalf("expected install retry to resolve the skill, got %#v", result)
}
if len(result.ResolvedSkills) != 1 || result.ResolvedSkills[0] != "video-translator" {
t.Fatalf("unexpected resolved skills after install: %#v", result.ResolvedSkills)
}
}

View File

@ -0,0 +1,194 @@
package toolbridge
import (
"bufio"
"encoding/json"
"errors"
"fmt"
"io"
"strings"
"xworkmate-bridge/internal/shared"
)
func Run(input io.Reader, output io.Writer) {
reader := bufio.NewReader(input)
for {
payload, err := readMessage(reader)
if err != nil {
if errors.Is(err, io.EOF) {
return
}
writeError(output, nil, -32700, err.Error())
continue
}
if len(strings.TrimSpace(string(payload))) == 0 {
continue
}
request, err := shared.DecodeRPCRequest(payload)
if err != nil {
writeError(output, nil, -32700, err.Error())
continue
}
response := handleRequest(request)
if response != nil {
writeMessage(output, response)
}
}
}
func readMessage(reader *bufio.Reader) ([]byte, error) {
line, err := reader.ReadString('\n')
if err != nil {
return nil, err
}
line = strings.TrimSpace(line)
if line == "" {
return nil, nil
}
if strings.HasPrefix(strings.ToLower(line), "content-length:") {
var contentLength int
if _, err := fmt.Sscanf(line, "Content-Length: %d", &contentLength); err != nil {
if _, err2 := fmt.Sscanf(line, "content-length: %d", &contentLength); err2 != nil {
return nil, fmt.Errorf("invalid content-length header")
}
}
for {
headerLine, err := reader.ReadString('\n')
if err != nil {
return nil, err
}
if strings.TrimSpace(headerLine) == "" {
break
}
}
body := make([]byte, contentLength)
if _, err := io.ReadFull(reader, body); err != nil {
return nil, err
}
return body, nil
}
return []byte(line), nil
}
func writeMessage(output io.Writer, message map[string]any) {
payload, _ := json.Marshal(message)
_, _ = output.Write(append(payload, '\n'))
}
func writeError(output io.Writer, id any, code int, message string) {
writeMessage(output, shared.ErrorEnvelope(id, code, message))
}
func handleRequest(request shared.RPCRequest) map[string]any {
if request.ID == nil {
return nil
}
switch request.Method {
case "initialize":
return shared.ResultEnvelope(request.ID, map[string]any{
"protocolVersion": "2024-11-05",
"capabilities": map[string]any{
"tools": map[string]any{},
},
"serverInfo": map[string]any{
"name": "xworkmate-go-core",
"version": "0.2.0",
},
})
case "ping":
return shared.ResultEnvelope(request.ID, map[string]any{})
case "tools/list":
return shared.ResultEnvelope(request.ID, map[string]any{
"tools": []map[string]any{
{
"name": "chat",
"description": "OpenAI-compatible reviewer chat bridge",
"inputSchema": map[string]any{
"type": "object",
"properties": map[string]any{
"prompt": map[string]any{"type": "string"},
"model": map[string]any{"type": "string"},
"system": map[string]any{"type": "string"},
},
"required": []string{"prompt"},
},
},
{
"name": "claude_review",
"description": "Review-only bridge over Claude CLI",
"inputSchema": map[string]any{
"type": "object",
"properties": map[string]any{
"prompt": map[string]any{"type": "string"},
"model": map[string]any{"type": "string"},
"system": map[string]any{"type": "string"},
"tools": map[string]any{"type": "string"},
},
"required": []string{"prompt"},
},
},
{
"name": "vault_kv",
"description": "HashiCorp Vault K/V v2 bridge",
"inputSchema": map[string]any{
"type": "object",
"properties": map[string]any{
"operation": map[string]any{"type": "string"},
"mount": map[string]any{"type": "string"},
"path": map[string]any{"type": "string"},
"data": map[string]any{"type": "object"},
"cas": map[string]any{"type": "number"},
},
"required": []string{"operation", "path"},
},
},
},
})
case "tools/call":
var params shared.ToolCallParams
raw, _ := json.Marshal(request.Params)
if err := json.Unmarshal(raw, &params); err != nil {
return shared.ErrorResponse(
request.ID,
-32602,
fmt.Sprintf("invalid tool params: %v", err),
)
}
switch params.Name {
case "chat":
content, err := shared.HandleChatTool(params.Arguments)
if err != nil {
return shared.ToolErrorResult(request.ID, err)
}
return shared.ToolTextResult(request.ID, content)
case "claude_review":
content, err := shared.HandleClaudeReviewTool(params.Arguments)
if err != nil {
return shared.ToolErrorResult(request.ID, err)
}
return shared.ToolTextResult(request.ID, content)
case "vault_kv":
content, err := shared.HandleVaultKVTool(params.Arguments)
if err != nil {
return shared.ToolErrorResult(request.ID, err)
}
return shared.ToolTextResult(request.ID, content)
default:
return shared.ErrorResponse(
request.ID,
-32601,
fmt.Sprintf("unknown tool: %s", params.Name),
)
}
default:
return shared.ErrorResponse(
request.ID,
-32601,
fmt.Sprintf("unknown method: %s", request.Method),
)
}
}

View File

@ -0,0 +1,80 @@
package toolbridge
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"xworkmate-bridge/internal/shared"
)
func TestHandleRequestListsVaultKVTool(t *testing.T) {
t.Parallel()
response := handleRequest(sharedRequest("tools/list", nil))
result := mapStringAny(response["result"])
tools := result["tools"].([]map[string]any)
found := false
for _, tool := range tools {
if tool["name"] == "vault_kv" {
found = true
break
}
}
if !found {
t.Fatalf("expected vault_kv tool in %v", tools)
}
}
func TestHandleRequestCallsVaultKVTool(t *testing.T) {
var requestPath string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestPath = r.URL.Path
_ = json.NewEncoder(w).Encode(map[string]any{
"data": map[string]any{
"data": map[string]any{
"demo": "value",
},
},
})
}))
defer server.Close()
t.Setenv("VAULT_SERVER_URL", server.URL)
t.Setenv("VAULT_SERVER_ROOT_ACCESS_TOKEN", "root-token")
response := handleRequest(sharedRequest("tools/call", map[string]any{
"name": "vault_kv",
"arguments": map[string]any{
"operation": "read",
"path": "apps/demo",
},
}))
result := mapStringAny(response["result"])
content := result["content"].([]map[string]any)
text := strings.TrimSpace(content[0]["text"].(string))
if !strings.Contains(text, `"demo": "value"`) {
t.Fatalf("unexpected tool output: %s", text)
}
if requestPath != "/v1/secret/data/apps/demo" {
t.Fatalf("unexpected request path: %s", requestPath)
}
}
func sharedRequest(method string, params map[string]any) shared.RPCRequest {
return shared.RPCRequest{
JSONRPC: "2.0",
ID: 1,
Method: method,
Params: params,
}
}
func mapStringAny(raw any) map[string]any {
if typed, ok := raw.(map[string]any); ok {
return typed
}
return map[string]any{}
}

25
main.go Normal file
View File

@ -0,0 +1,25 @@
package main
import (
"fmt"
"os"
"xworkmate-bridge/internal/acp"
"xworkmate-bridge/internal/toolbridge"
)
func main() {
if len(os.Args) > 1 && os.Args[1] == "serve" {
if err := acp.Serve(os.Args[2:]); err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
os.Exit(1)
}
return
}
if len(os.Args) > 1 && os.Args[1] == "acp-stdio" {
acp.RunStdio(os.Stdin, os.Stdout)
return
}
toolbridge.Run(os.Stdin, os.Stdout)
}

131
main_test.go Normal file
View File

@ -0,0 +1,131 @@
package main
import (
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"time"
)
func TestParseClaudeJSON(t *testing.T) {
t.Parallel()
payload, err := parseClaudeJSON("log line\n{\"result\":\"review ok\",\"is_error\":false}\n")
if err != nil {
t.Fatalf("parseClaudeJSON returned error: %v", err)
}
if got := payload["result"]; got != "review ok" {
t.Fatalf("unexpected result: %v", got)
}
}
func TestCallOpenAICompatible(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got := r.Header.Get("Authorization"); got != "Bearer test-key" {
t.Fatalf("unexpected auth header: %s", got)
}
var body map[string]any
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
t.Fatalf("decode request body: %v", err)
}
if got := body["model"]; got != "qwen2.5-coder:latest" {
t.Fatalf("unexpected model: %v", got)
}
_ = json.NewEncoder(w).Encode(map[string]any{
"choices": []map[string]any{
{
"message": map[string]any{
"content": "review ok",
},
},
},
})
}))
defer server.Close()
output, err := callOpenAICompatible(
server.URL,
"test-key",
"qwen2.5-coder:latest",
[]map[string]string{
{"role": "user", "content": "hello"},
},
)
if err != nil {
t.Fatalf("callOpenAICompatible returned error: %v", err)
}
if output != "review ok" {
t.Fatalf("unexpected output: %s", output)
}
}
func TestHandleChatToolRequiresPrompt(t *testing.T) {
t.Setenv("LLM_API_KEY", "test-key")
t.Setenv("LLM_BASE_URL", "http://127.0.0.1:11434/v1")
_, err := handleChatTool(map[string]any{})
if err == nil || err.Error() != "prompt is required" {
t.Fatalf("expected prompt error, got %v", err)
}
}
func TestParseClaudeJSONReturnsErrorForPlainText(t *testing.T) {
t.Parallel()
_, err := parseClaudeJSON("plain text only\n")
if err == nil {
t.Fatal("expected parse error for plain text output")
}
}
func TestCallOpenAICompatibleReturnsStatusError(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "bad gateway", http.StatusBadGateway)
}))
defer server.Close()
_, err := callOpenAICompatible(
server.URL,
"test-key",
"qwen2.5-coder:latest",
[]map[string]string{{"role": "user", "content": "hello"}},
)
if err == nil || err.Error() == "" {
t.Fatal("expected non-2xx status error")
}
}
func TestRunClaudeReviewSurfacesCliExitFailure(t *testing.T) {
tempDir := t.TempDir()
cliPath := filepath.Join(tempDir, "claude")
if err := os.WriteFile(cliPath, []byte("#!/bin/sh\necho boom >&2\nexit 2\n"), 0o755); err != nil {
t.Fatalf("write fake claude script: %v", err)
}
t.Setenv("CLAUDE_BIN", cliPath)
_, err := runClaudeReview("review this", "", "", "", 2*time.Second)
if err == nil || err.Error() == "" {
t.Fatal("expected cli failure")
}
}
func TestRunClaudeReviewSurfacesNonJSONStdout(t *testing.T) {
tempDir := t.TempDir()
cliPath := filepath.Join(tempDir, "claude")
if err := os.WriteFile(cliPath, []byte("#!/bin/sh\necho plain-text-output\nexit 0\n"), 0o755); err != nil {
t.Fatalf("write fake claude script: %v", err)
}
t.Setenv("CLAUDE_BIN", cliPath)
_, err := runClaudeReview("review this", "", "", "", 2*time.Second)
if err == nil || err.Error() == "" {
t.Fatal("expected non-json stdout error")
}
}

38
main_tools.go Normal file
View File

@ -0,0 +1,38 @@
package main
import (
"time"
"xworkmate-bridge/internal/shared"
)
func parseClaudeJSON(raw string) (map[string]any, error) {
return shared.ParseClaudeJSON(raw)
}
func callOpenAICompatible(
baseURL,
apiKey,
model string,
messages []map[string]string,
) (string, error) {
return shared.CallOpenAICompatible(baseURL, apiKey, model, messages)
}
func handleChatTool(arguments map[string]any) (string, error) {
return shared.HandleChatTool(arguments)
}
func handleVaultKVTool(arguments map[string]any) (string, error) {
return shared.HandleVaultKVTool(arguments)
}
func runClaudeReview(
prompt,
model,
system,
tools string,
timeout time.Duration,
) (string, error) {
return shared.RunClaudeReview(prompt, model, system, tools, timeout)
}

33
scripts/build-helper.sh Executable file
View File

@ -0,0 +1,33 @@
#!/usr/bin/env bash
set -euo pipefail
ROOT_DIR="$(cd "$(dirname "$0")/.." && pwd)"
OUTPUT_DIR="${OUTPUT_DIR:-$ROOT_DIR/build/bin}"
OUTPUT_PATH_BASE="${OUTPUT_DIR}/xworkmate-go-core"
if [[ "$(uname -s)" == *MINGW* || "$(uname -s)" == *MSYS* || "$(uname -s)" == *CYGWIN* ]]; then
OUTPUT_PATH="${OUTPUT_PATH:-${OUTPUT_PATH_BASE}.exe}"
else
OUTPUT_PATH="${OUTPUT_PATH:-${OUTPUT_PATH_BASE}}"
fi
if [[ ! -f "$ROOT_DIR/go.mod" ]]; then
echo "Missing go.mod in $ROOT_DIR" >&2
exit 1
fi
if ! command -v go >/dev/null 2>&1; then
echo "Go toolchain is required to build xworkmate-go-core" >&2
exit 1
fi
mkdir -p "$OUTPUT_DIR"
echo "Building xworkmate-go-core from xworkmate-bridge..."
(
cd "$ROOT_DIR"
GO111MODULE=on go build -o "$OUTPUT_PATH" .
)
chmod +x "$OUTPUT_PATH"
echo "Built: $OUTPUT_PATH"