Split ACP bridge into standalone repository
This commit is contained in:
parent
4d200b6faa
commit
e5b343ba3a
2
Makefile
2
Makefile
@ -85,7 +85,7 @@ build-ios-sim: ## Build the iOS app for the simulator
|
||||
$(FLUTTER) build ios --simulator $(APP_STORE_DART_DEFINE) --build-name=$(APP_VERSION) --build-number=$(APP_BUILD_NUMBER) $(APP_DART_DEFINE_VERSION) $(APP_DART_DEFINE_BUILD)
|
||||
bash scripts/check-apple-export-compliance.sh build/ios/iphonesimulator/Runner.app
|
||||
|
||||
build-go-core: ## Build the Go core helper
|
||||
build-go-core: ## Build the external ACP bridge helper from xworkmate-bridge
|
||||
bash scripts/build-go-core.sh
|
||||
|
||||
package-deb: ## Create the Linux .deb package
|
||||
|
||||
58
docs/architecture/xworkmate-bridge-migration.md
Normal file
58
docs/architecture/xworkmate-bridge-migration.md
Normal file
@ -0,0 +1,58 @@
|
||||
# XWorkmate Bridge Migration
|
||||
|
||||
## Summary
|
||||
|
||||
The ACP Bridge Server implementation was migrated out of `xworkmate-app` into the standalone sibling repository `xworkmate-bridge`.
|
||||
|
||||
This migration separates the embedded Go bridge/server from the Flutter application repository while preserving the existing helper binary contract used by the app.
|
||||
|
||||
## New Repository
|
||||
|
||||
- Repository path: `/Users/shenlan/workspaces/cloud-neutral-toolkit/xworkmate-bridge`
|
||||
- Go module: `xworkmate-bridge`
|
||||
- Helper binary output name: `xworkmate-go-core`
|
||||
|
||||
## What Moved
|
||||
|
||||
The previous `xworkmate-app/go/go_core` implementation was migrated to `xworkmate-bridge`, including:
|
||||
|
||||
- ACP Bridge HTTP/WebSocket server
|
||||
- ACP stdio entrypoint
|
||||
- internal routing, dispatch, mounts, shared RPC helpers, gateway runtime support, memory, skills, and toolbridge packages
|
||||
- Go tests for ACP routing/contracts and bridge helper behavior
|
||||
|
||||
## What Stayed In xworkmate-app
|
||||
|
||||
The following app-side concerns remain in `xworkmate-app`:
|
||||
|
||||
- Flutter UI and settings pages
|
||||
- ACP Bridge client-side configuration and secure-storage handling
|
||||
- Dart runtime launch/locator logic for the helper binary
|
||||
- packaging logic that embeds the helper into the app bundle
|
||||
|
||||
## Build Contract
|
||||
|
||||
`xworkmate-app` still expects a helper named `xworkmate-go-core`.
|
||||
|
||||
To preserve compatibility, `xworkmate-bridge` continues to build the helper using that binary name.
|
||||
|
||||
## App Repository Changes
|
||||
|
||||
In `xworkmate-app`:
|
||||
|
||||
- `go/go_core` was removed
|
||||
- `scripts/build-go-core.sh` now resolves and builds from sibling repo `xworkmate-bridge`
|
||||
- the script supports both normal workspace layout and worktree layout
|
||||
- release notes references were updated to point at the new repository
|
||||
|
||||
## Validation
|
||||
|
||||
Validated during migration:
|
||||
|
||||
- `cd xworkmate-bridge && go test ./...`
|
||||
- `cd xworkmate-bridge && bash scripts/build-helper.sh`
|
||||
- `cd xworkmate-app && bash scripts/build-go-core.sh`
|
||||
|
||||
## Operational Note
|
||||
|
||||
For local development and packaging, `xworkmate-bridge` must exist as a sibling repository next to `xworkmate-app`, unless `XWORKMATE_BRIDGE_DIR` is set explicitly.
|
||||
@ -1,5 +0,0 @@
|
||||
module xworkmate/go_core
|
||||
|
||||
go 1.25.0
|
||||
|
||||
require github.com/gorilla/websocket v1.5.3
|
||||
@ -1,2 +0,0 @@
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
@ -1,393 +0,0 @@
|
||||
package acp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"xworkmate/go_core/internal/router"
|
||||
"xworkmate/go_core/internal/shared"
|
||||
)
|
||||
|
||||
const (
|
||||
externalProviderEndpointKey = "externalProviderEndpoint"
|
||||
externalProviderAuthorizationHeaderKey = "externalProviderAuthorizationHeader"
|
||||
externalProviderLabelKey = "externalProviderLabel"
|
||||
)
|
||||
|
||||
func buildResolvedExecutionParams(
|
||||
params map[string]any,
|
||||
resolved router.Result,
|
||||
) map[string]any {
|
||||
next := make(map[string]any, len(params)+8)
|
||||
for key, value := range params {
|
||||
next[key] = value
|
||||
}
|
||||
switch resolved.ResolvedExecutionTarget {
|
||||
case router.ExecutionTargetGateway:
|
||||
next["mode"] = router.ExecutionTargetGatewayChat
|
||||
next["executionTarget"] = resolved.ResolvedEndpointTarget
|
||||
case router.ExecutionTargetMultiAgent:
|
||||
next["mode"] = router.ExecutionTargetMultiAgent
|
||||
default:
|
||||
next["mode"] = router.ExecutionTargetSingleAgent
|
||||
}
|
||||
if strings.TrimSpace(resolved.ResolvedProviderID) != "" {
|
||||
next["provider"] = strings.TrimSpace(resolved.ResolvedProviderID)
|
||||
}
|
||||
if strings.TrimSpace(resolved.ResolvedModel) != "" {
|
||||
next["model"] = strings.TrimSpace(resolved.ResolvedModel)
|
||||
}
|
||||
if len(resolved.ResolvedSkills) > 0 {
|
||||
next["selectedSkills"] = append([]string(nil), resolved.ResolvedSkills...)
|
||||
}
|
||||
next["resolvedExecutionTarget"] = resolved.ResolvedExecutionTarget
|
||||
next["resolvedEndpointTarget"] = resolved.ResolvedEndpointTarget
|
||||
next["resolvedProviderId"] = resolved.ResolvedProviderID
|
||||
next["resolvedModel"] = resolved.ResolvedModel
|
||||
next["resolvedSkills"] = append([]string(nil), resolved.ResolvedSkills...)
|
||||
return next
|
||||
}
|
||||
|
||||
func injectResolvedExternalProviderParams(
|
||||
params map[string]any,
|
||||
provider syncedProvider,
|
||||
) map[string]any {
|
||||
if params == nil {
|
||||
params = map[string]any{}
|
||||
}
|
||||
if endpoint := strings.TrimSpace(provider.Endpoint); endpoint != "" {
|
||||
params[externalProviderEndpointKey] = endpoint
|
||||
}
|
||||
if authorization := strings.TrimSpace(provider.AuthorizationHeader); authorization != "" {
|
||||
params[externalProviderAuthorizationHeaderKey] = authorization
|
||||
}
|
||||
if label := strings.TrimSpace(provider.Label); label != "" {
|
||||
params[externalProviderLabelKey] = label
|
||||
}
|
||||
return params
|
||||
}
|
||||
|
||||
func (s *Server) runGateway(
|
||||
ctx context.Context,
|
||||
method string,
|
||||
session *session,
|
||||
params map[string]any,
|
||||
turnID string,
|
||||
notify func(map[string]any),
|
||||
) taskResult {
|
||||
_ = ctx
|
||||
executionTarget := strings.TrimSpace(shared.StringArg(params, "executionTarget", ""))
|
||||
if executionTarget == "" {
|
||||
executionTarget = router.EndpointTargetLocal
|
||||
}
|
||||
result := s.gateway.RequestByMode(
|
||||
executionTarget,
|
||||
method,
|
||||
params,
|
||||
2*time.Minute,
|
||||
notify,
|
||||
)
|
||||
if !result.OK {
|
||||
errMessage := strings.TrimSpace(shared.StringArg(result.Error, "message", "gateway execution failed"))
|
||||
s.emitSessionUpdate(session, notify, turnID, map[string]any{
|
||||
"type": "status",
|
||||
"event": "completed",
|
||||
"message": errMessage,
|
||||
"pending": false,
|
||||
"error": true,
|
||||
})
|
||||
return taskResult{
|
||||
response: map[string]any{
|
||||
"success": false,
|
||||
"error": errMessage,
|
||||
"turnId": turnID,
|
||||
"mode": router.ExecutionTargetGatewayChat,
|
||||
},
|
||||
}
|
||||
}
|
||||
payload := asMap(result.Payload)
|
||||
if len(payload) == 0 {
|
||||
payload = map[string]any{
|
||||
"success": true,
|
||||
"turnId": turnID,
|
||||
"mode": router.ExecutionTargetGatewayChat,
|
||||
}
|
||||
}
|
||||
if _, ok := payload["turnId"]; !ok {
|
||||
payload["turnId"] = turnID
|
||||
}
|
||||
if _, ok := payload["mode"]; !ok {
|
||||
payload["mode"] = router.ExecutionTargetGatewayChat
|
||||
}
|
||||
return taskResult{response: payload}
|
||||
}
|
||||
|
||||
func (s *Server) runSingleAgentViaExternalProvider(
|
||||
ctx context.Context,
|
||||
provider syncedProvider,
|
||||
method string,
|
||||
params map[string]any,
|
||||
notify func(map[string]any),
|
||||
) (map[string]any, error) {
|
||||
endpoint := strings.TrimSpace(provider.Endpoint)
|
||||
if endpoint == "" {
|
||||
return nil, fmt.Errorf("external provider endpoint is missing")
|
||||
}
|
||||
forwardParams := sanitizeExternalACPParams(method, params)
|
||||
return requestExternalACP(
|
||||
ctx,
|
||||
endpoint,
|
||||
provider.AuthorizationHeader,
|
||||
method,
|
||||
forwardParams,
|
||||
notify,
|
||||
)
|
||||
}
|
||||
|
||||
func sanitizeExternalACPParams(method string, params map[string]any) map[string]any {
|
||||
if len(params) == 0 {
|
||||
return map[string]any{}
|
||||
}
|
||||
next := make(map[string]any, len(params))
|
||||
for key, value := range params {
|
||||
next[key] = value
|
||||
}
|
||||
// Internal routing/runtime fields must not leak into external provider payloads.
|
||||
delete(next, "metadata")
|
||||
delete(next, "resolvedExecutionTarget")
|
||||
delete(next, "resolvedEndpointTarget")
|
||||
delete(next, "resolvedProviderId")
|
||||
delete(next, "resolvedModel")
|
||||
delete(next, "resolvedSkills")
|
||||
delete(next, externalProviderEndpointKey)
|
||||
delete(next, externalProviderAuthorizationHeaderKey)
|
||||
delete(next, externalProviderLabelKey)
|
||||
// Gateway-only fields are irrelevant in ACP single-agent forwarding.
|
||||
normalizedMethod := strings.TrimSpace(method)
|
||||
if normalizedMethod == "session.start" || normalizedMethod == "session.message" {
|
||||
delete(next, "executionTarget")
|
||||
delete(next, "agentId")
|
||||
}
|
||||
return next
|
||||
}
|
||||
|
||||
func externalProviderFromParams(params map[string]any) (syncedProvider, bool) {
|
||||
endpoint := strings.TrimSpace(shared.StringArg(params, externalProviderEndpointKey, ""))
|
||||
if endpoint == "" {
|
||||
return syncedProvider{}, false
|
||||
}
|
||||
return syncedProvider{
|
||||
ProviderID: strings.TrimSpace(shared.StringArg(params, "provider", "")),
|
||||
Label: strings.TrimSpace(shared.StringArg(params, externalProviderLabelKey, "")),
|
||||
Endpoint: endpoint,
|
||||
AuthorizationHeader: strings.TrimSpace(shared.StringArg(params, externalProviderAuthorizationHeaderKey, "")),
|
||||
Enabled: true,
|
||||
}, true
|
||||
}
|
||||
|
||||
func requestExternalACP(
|
||||
ctx context.Context,
|
||||
endpoint,
|
||||
authorization,
|
||||
method string,
|
||||
params map[string]any,
|
||||
notify func(map[string]any),
|
||||
) (map[string]any, error) {
|
||||
parsed, err := httpOrWebsocketEndpoint(endpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch parsed.Scheme {
|
||||
case "http", "https":
|
||||
return requestExternalACPHTTP(ctx, parsed, authorization, method, params)
|
||||
default:
|
||||
return requestExternalACPWebSocket(ctx, parsed, authorization, method, params, notify)
|
||||
}
|
||||
}
|
||||
|
||||
func requestExternalACPHTTP(
|
||||
ctx context.Context,
|
||||
endpoint *urlSpec,
|
||||
authorization,
|
||||
method string,
|
||||
params map[string]any,
|
||||
) (map[string]any, error) {
|
||||
requestBody, _ := json.Marshal(map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": fmt.Sprintf("req-%d", time.Now().UnixNano()),
|
||||
"method": method,
|
||||
"params": params,
|
||||
})
|
||||
req, err := http.NewRequestWithContext(
|
||||
ctx,
|
||||
http.MethodPost,
|
||||
endpoint.httpRPCEndpoint(),
|
||||
strings.NewReader(string(requestBody)),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json; charset=utf-8")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
if strings.TrimSpace(authorization) != "" {
|
||||
req.Header.Set("Authorization", strings.TrimSpace(authorization))
|
||||
}
|
||||
response, err := (&http.Client{Timeout: 2 * time.Minute}).Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
var decoded map[string]any
|
||||
if err := json.NewDecoder(response.Body).Decode(&decoded); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if errPayload := asMap(decoded["error"]); len(errPayload) > 0 {
|
||||
return nil, fmt.Errorf(
|
||||
"%s",
|
||||
strings.TrimSpace(shared.StringArg(errPayload, "message", "external ACP request failed")),
|
||||
)
|
||||
}
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
func requestExternalACPWebSocket(
|
||||
ctx context.Context,
|
||||
endpoint *urlSpec,
|
||||
authorization,
|
||||
method string,
|
||||
params map[string]any,
|
||||
notify func(map[string]any),
|
||||
) (map[string]any, error) {
|
||||
headers := http.Header{}
|
||||
if strings.TrimSpace(authorization) != "" {
|
||||
headers.Set("Authorization", strings.TrimSpace(authorization))
|
||||
}
|
||||
conn, _, err := websocket.DefaultDialer.DialContext(
|
||||
ctx,
|
||||
endpoint.webSocketEndpoint(),
|
||||
headers,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
requestID := fmt.Sprintf("req-%d", time.Now().UnixNano())
|
||||
if err := conn.WriteJSON(map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": requestID,
|
||||
"method": method,
|
||||
"params": params,
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for {
|
||||
if err := conn.SetReadDeadline(time.Now().Add(2 * time.Minute)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var payload map[string]any
|
||||
if err := conn.ReadJSON(&payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(shared.StringArg(payload, "id", "")) == requestID &&
|
||||
(payload["result"] != nil || payload["error"] != nil) {
|
||||
if errPayload := asMap(payload["error"]); len(errPayload) > 0 {
|
||||
return nil, fmt.Errorf(
|
||||
"%s",
|
||||
strings.TrimSpace(shared.StringArg(errPayload, "message", "external ACP request failed")),
|
||||
)
|
||||
}
|
||||
return payload, nil
|
||||
}
|
||||
if notify != nil && strings.TrimSpace(shared.StringArg(payload, "method", "")) != "" {
|
||||
notify(payload)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type urlSpec struct {
|
||||
Scheme string
|
||||
Host string
|
||||
Port string
|
||||
Path string
|
||||
}
|
||||
|
||||
func httpOrWebsocketEndpoint(raw string) (*urlSpec, error) {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return nil, fmt.Errorf("missing external ACP endpoint")
|
||||
}
|
||||
parsed, err := url.ParseRequestURI(trimmed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme))
|
||||
if scheme != "http" && scheme != "https" && scheme != "ws" && scheme != "wss" {
|
||||
return nil, fmt.Errorf("unsupported external ACP scheme: %s", scheme)
|
||||
}
|
||||
return &urlSpec{
|
||||
Scheme: scheme,
|
||||
Host: parsed.Host,
|
||||
Path: strings.TrimRight(parsed.Path, "/"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (u *urlSpec) basePath() string {
|
||||
path := strings.TrimSpace(u.Path)
|
||||
if path == "" || path == "/" {
|
||||
return ""
|
||||
}
|
||||
if strings.HasSuffix(path, "/acp/rpc") {
|
||||
path = strings.TrimSuffix(path, "/acp/rpc")
|
||||
} else if strings.HasSuffix(path, "/acp") {
|
||||
path = strings.TrimSuffix(path, "/acp")
|
||||
}
|
||||
path = strings.TrimRight(path, "/")
|
||||
if path == "" || path == "/" {
|
||||
return ""
|
||||
}
|
||||
if !strings.HasPrefix(path, "/") {
|
||||
return "/" + path
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
func (u *urlSpec) httpRPCEndpoint() string {
|
||||
scheme := u.Scheme
|
||||
if scheme == "ws" {
|
||||
scheme = "http"
|
||||
} else if scheme == "wss" {
|
||||
scheme = "https"
|
||||
}
|
||||
basePath := u.basePath()
|
||||
if basePath == "" {
|
||||
basePath = "/acp/rpc"
|
||||
} else {
|
||||
basePath += "/acp/rpc"
|
||||
}
|
||||
return fmt.Sprintf("%s://%s%s", scheme, u.Host, basePath)
|
||||
}
|
||||
|
||||
func (u *urlSpec) webSocketEndpoint() string {
|
||||
scheme := u.Scheme
|
||||
if scheme == "http" {
|
||||
scheme = "ws"
|
||||
} else if scheme == "https" {
|
||||
scheme = "wss"
|
||||
}
|
||||
basePath := u.basePath()
|
||||
if basePath == "" {
|
||||
basePath = "/acp"
|
||||
} else {
|
||||
basePath += "/acp"
|
||||
}
|
||||
return fmt.Sprintf("%s://%s%s", scheme, u.Host, basePath)
|
||||
}
|
||||
@ -1,159 +0,0 @@
|
||||
package acp
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"xworkmate/go_core/internal/gatewayruntime"
|
||||
"xworkmate/go_core/internal/shared"
|
||||
)
|
||||
|
||||
func handleGatewayConnect(
|
||||
server *Server,
|
||||
params map[string]any,
|
||||
notify func(map[string]any),
|
||||
) map[string]any {
|
||||
request := gatewayruntime.ConnectRequest{
|
||||
RuntimeID: strings.TrimSpace(shared.StringArg(params, "runtimeId", "")),
|
||||
Mode: strings.TrimSpace(shared.StringArg(params, "mode", "unconfigured")),
|
||||
ClientID: strings.TrimSpace(shared.StringArg(params, "clientId", "")),
|
||||
Locale: strings.TrimSpace(shared.StringArg(params, "locale", "")),
|
||||
UserAgent: strings.TrimSpace(shared.StringArg(params, "userAgent", "")),
|
||||
ConnectAuthMode: strings.TrimSpace(shared.StringArg(params, "connectAuthMode", "")),
|
||||
ConnectAuthFields: parseGatewayRuntimeStringSlice(params["connectAuthFields"]),
|
||||
ConnectAuthSources: parseGatewayRuntimeStringSlice(params["connectAuthSources"]),
|
||||
HasSharedAuth: parseBool(params["hasSharedAuth"]),
|
||||
HasDeviceToken: parseBool(params["hasDeviceToken"]),
|
||||
Endpoint: gatewayruntime.Endpoint{
|
||||
Host: strings.TrimSpace(shared.StringArg(asMap(params["endpoint"]), "host", "")),
|
||||
Port: parsePositiveInt(asMap(params["endpoint"])["port"]),
|
||||
TLS: parseBool(asMap(params["endpoint"])["tls"]),
|
||||
},
|
||||
PackageInfo: gatewayruntime.PackageInfo{
|
||||
AppName: strings.TrimSpace(shared.StringArg(asMap(params["packageInfo"]), "appName", "")),
|
||||
PackageName: strings.TrimSpace(shared.StringArg(asMap(params["packageInfo"]), "packageName", "")),
|
||||
Version: strings.TrimSpace(shared.StringArg(asMap(params["packageInfo"]), "version", "")),
|
||||
BuildNumber: strings.TrimSpace(shared.StringArg(asMap(params["packageInfo"]), "buildNumber", "")),
|
||||
},
|
||||
DeviceInfo: gatewayruntime.DeviceInfo{
|
||||
Platform: strings.TrimSpace(shared.StringArg(asMap(params["deviceInfo"]), "platform", "")),
|
||||
PlatformVersion: strings.TrimSpace(shared.StringArg(asMap(params["deviceInfo"]), "platformVersion", "")),
|
||||
DeviceFamily: strings.TrimSpace(shared.StringArg(asMap(params["deviceInfo"]), "deviceFamily", "")),
|
||||
ModelIdentifier: strings.TrimSpace(shared.StringArg(asMap(params["deviceInfo"]), "modelIdentifier", "")),
|
||||
},
|
||||
Identity: gatewayruntime.DeviceIdentity{
|
||||
DeviceID: strings.TrimSpace(shared.StringArg(asMap(params["identity"]), "deviceId", "")),
|
||||
PublicKeyBase64URL: strings.TrimSpace(shared.StringArg(asMap(params["identity"]), "publicKeyBase64Url", "")),
|
||||
PrivateKeyBase64URL: strings.TrimSpace(shared.StringArg(asMap(params["identity"]), "privateKeyBase64Url", "")),
|
||||
},
|
||||
Auth: gatewayruntime.AuthConfig{
|
||||
Token: strings.TrimSpace(shared.StringArg(asMap(params["auth"]), "token", "")),
|
||||
DeviceToken: strings.TrimSpace(shared.StringArg(asMap(params["auth"]), "deviceToken", "")),
|
||||
Password: strings.TrimSpace(shared.StringArg(asMap(params["auth"]), "password", "")),
|
||||
},
|
||||
}
|
||||
result := server.gateway.Connect(request, notify)
|
||||
return map[string]any{
|
||||
"ok": result.OK,
|
||||
"snapshot": result.Snapshot,
|
||||
"auth": result.Auth,
|
||||
"returnedDeviceToken": result.ReturnedDeviceToken,
|
||||
"error": result.Error,
|
||||
}
|
||||
}
|
||||
|
||||
func handleGatewayRequest(
|
||||
server *Server,
|
||||
params map[string]any,
|
||||
notify func(map[string]any),
|
||||
) map[string]any {
|
||||
timeout := time.Duration(parsePositiveInt(params["timeoutMs"])) * time.Millisecond
|
||||
result := server.gateway.Request(
|
||||
strings.TrimSpace(shared.StringArg(params, "runtimeId", "")),
|
||||
strings.TrimSpace(shared.StringArg(params, "method", "")),
|
||||
asMap(params["params"]),
|
||||
timeout,
|
||||
notify,
|
||||
)
|
||||
return map[string]any{
|
||||
"ok": result.OK,
|
||||
"payload": result.Payload,
|
||||
"error": result.Error,
|
||||
}
|
||||
}
|
||||
|
||||
func handleGatewayDisconnect(
|
||||
server *Server,
|
||||
params map[string]any,
|
||||
notify func(map[string]any),
|
||||
) map[string]any {
|
||||
server.gateway.Disconnect(
|
||||
strings.TrimSpace(shared.StringArg(params, "runtimeId", "")),
|
||||
notify,
|
||||
)
|
||||
return map[string]any{"accepted": true}
|
||||
}
|
||||
|
||||
func asMap(value any) map[string]any {
|
||||
if typed, ok := value.(map[string]any); ok {
|
||||
return typed
|
||||
}
|
||||
if typed, ok := value.(map[string]interface{}); ok {
|
||||
return typed
|
||||
}
|
||||
return map[string]any{}
|
||||
}
|
||||
|
||||
func parseGatewayRuntimeStringSlice(value any) []string {
|
||||
list, ok := value.([]any)
|
||||
if !ok {
|
||||
if typed, ok := value.([]string); ok {
|
||||
return append([]string(nil), typed...)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
result := make([]string, 0, len(list))
|
||||
for _, item := range list {
|
||||
text := strings.TrimSpace(shared.StringArg(map[string]any{"value": item}, "value", ""))
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
result = append(result, text)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func parseBool(value any) bool {
|
||||
switch typed := value.(type) {
|
||||
case bool:
|
||||
return typed
|
||||
case string:
|
||||
return shared.BoolArg(typed, false)
|
||||
case float64:
|
||||
return typed != 0
|
||||
case int:
|
||||
return typed != 0
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func parsePositiveInt(value any) int {
|
||||
switch typed := value.(type) {
|
||||
case int:
|
||||
if typed > 0 {
|
||||
return typed
|
||||
}
|
||||
case int64:
|
||||
if typed > 0 {
|
||||
return int(typed)
|
||||
}
|
||||
case float64:
|
||||
if typed > 0 {
|
||||
return int(typed)
|
||||
}
|
||||
case string:
|
||||
return shared.IntArg(typed, 0)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
@ -1,94 +0,0 @@
|
||||
package acp
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type syncedProvider struct {
|
||||
ProviderID string
|
||||
Label string
|
||||
Endpoint string
|
||||
AuthorizationHeader string
|
||||
Enabled bool
|
||||
}
|
||||
|
||||
func parseSyncedProviders(raw any) []syncedProvider {
|
||||
list, ok := raw.([]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
providers := make([]syncedProvider, 0, len(list))
|
||||
for _, item := range list {
|
||||
entry := asMap(item)
|
||||
providerID := strings.TrimSpace(sharedString(entry, "providerId"))
|
||||
if providerID == "" {
|
||||
continue
|
||||
}
|
||||
providers = append(providers, syncedProvider{
|
||||
ProviderID: providerID,
|
||||
Label: strings.TrimSpace(sharedString(entry, "label")),
|
||||
Endpoint: strings.TrimSpace(sharedString(entry, "endpoint")),
|
||||
AuthorizationHeader: strings.TrimSpace(sharedString(entry, "authorizationHeader")),
|
||||
Enabled: parseBool(entry["enabled"]),
|
||||
})
|
||||
}
|
||||
return providers
|
||||
}
|
||||
|
||||
func (s *Server) syncProviders(providers []syncedProvider) map[string]any {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.providerCatalog = make(map[string]syncedProvider, len(providers))
|
||||
for _, provider := range providers {
|
||||
if strings.TrimSpace(provider.ProviderID) == "" {
|
||||
continue
|
||||
}
|
||||
s.providerCatalog[provider.ProviderID] = provider
|
||||
}
|
||||
return map[string]any{
|
||||
"ok": true,
|
||||
"providers": syncedProvidersResult(providers),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) syncedProviderByID(providerID string) (syncedProvider, bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
provider, ok := s.providerCatalog[strings.TrimSpace(providerID)]
|
||||
if !ok || !provider.Enabled || strings.TrimSpace(provider.Endpoint) == "" {
|
||||
return syncedProvider{}, false
|
||||
}
|
||||
return provider, true
|
||||
}
|
||||
|
||||
func (s *Server) availableProviders() []string {
|
||||
providers := make(map[string]struct{})
|
||||
s.mu.Lock()
|
||||
for _, provider := range s.providerCatalog {
|
||||
if !provider.Enabled || strings.TrimSpace(provider.Endpoint) == "" {
|
||||
continue
|
||||
}
|
||||
providers[provider.ProviderID] = struct{}{}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
ordered := make([]string, 0, len(providers))
|
||||
for providerID := range providers {
|
||||
ordered = append(ordered, providerID)
|
||||
}
|
||||
sort.Strings(ordered)
|
||||
return ordered
|
||||
}
|
||||
|
||||
func syncedProvidersResult(providers []syncedProvider) []map[string]any {
|
||||
result := make([]map[string]any, 0, len(providers))
|
||||
for _, provider := range providers {
|
||||
result = append(result, map[string]any{
|
||||
"providerId": provider.ProviderID,
|
||||
"label": provider.Label,
|
||||
"endpoint": provider.Endpoint,
|
||||
"enabled": provider.Enabled,
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
@ -1,209 +0,0 @@
|
||||
package acp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"xworkmate/go_core/internal/shared"
|
||||
)
|
||||
|
||||
func TestCapabilitiesIgnoreLocalProviderAutodetectUntilSync(t *testing.T) {
|
||||
fakeProvider := t.TempDir() + "/fake-claude"
|
||||
if err := os.WriteFile(fakeProvider, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
|
||||
t.Fatalf("write fake provider: %v", err)
|
||||
}
|
||||
t.Setenv("ACP_CLAUDE_BIN", fakeProvider)
|
||||
|
||||
server := NewServer()
|
||||
result, rpcErr := server.handleRequest(shared.RPCRequest{
|
||||
Method: "acp.capabilities",
|
||||
Params: map[string]any{},
|
||||
}, func(map[string]any) {})
|
||||
if rpcErr != nil {
|
||||
t.Fatalf("expected capabilities success, got %v", rpcErr)
|
||||
}
|
||||
|
||||
providers, _ := result["providers"].([]string)
|
||||
if len(providers) != 0 {
|
||||
t.Fatalf("expected no providers before sync, got %#v", providers)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvidersSyncUpdatesCapabilities(t *testing.T) {
|
||||
server := NewServer()
|
||||
|
||||
_, rpcErr := server.handleRequest(shared.RPCRequest{
|
||||
Method: "xworkmate.providers.sync",
|
||||
Params: map[string]any{
|
||||
"providers": []any{
|
||||
map[string]any{
|
||||
"providerId": "claude",
|
||||
"label": "Claude",
|
||||
"endpoint": "http://127.0.0.1:9999",
|
||||
"authorizationHeader": "Bearer test",
|
||||
"enabled": true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}, func(map[string]any) {})
|
||||
if rpcErr != nil {
|
||||
t.Fatalf("expected sync success, got %v", rpcErr)
|
||||
}
|
||||
|
||||
result, rpcErr := server.handleRequest(shared.RPCRequest{
|
||||
Method: "acp.capabilities",
|
||||
Params: map[string]any{},
|
||||
}, func(map[string]any) {})
|
||||
if rpcErr != nil {
|
||||
t.Fatalf("expected capabilities success, got %v", rpcErr)
|
||||
}
|
||||
providers, _ := result["providers"].([]string)
|
||||
if len(providers) == 0 {
|
||||
t.Fatalf("expected synced provider in capabilities, got %#v", result)
|
||||
}
|
||||
found := false
|
||||
for _, provider := range providers {
|
||||
if provider == "claude" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatalf("expected claude provider after sync, got %#v", providers)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteSessionTaskUsesSyncedExternalProvider(t *testing.T) {
|
||||
var lastForwardedParams map[string]any
|
||||
externalServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/acp/rpc" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
var request map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
|
||||
t.Fatalf("decode request: %v", err)
|
||||
}
|
||||
lastForwardedParams = asMap(request["params"])
|
||||
method, _ := request["method"].(string)
|
||||
switch method {
|
||||
case "session.start":
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": request["id"],
|
||||
"result": map[string]any{
|
||||
"success": true,
|
||||
"output": "external-provider-ok",
|
||||
"turnId": "turn-external",
|
||||
"provider": "claude",
|
||||
"mode": "single-agent",
|
||||
},
|
||||
})
|
||||
default:
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": request["id"],
|
||||
"result": map[string]any{"ok": true},
|
||||
})
|
||||
}
|
||||
}))
|
||||
defer externalServer.Close()
|
||||
|
||||
server := NewServer()
|
||||
server.syncProviders([]syncedProvider{
|
||||
{
|
||||
ProviderID: "claude",
|
||||
Label: "Claude",
|
||||
Endpoint: externalServer.URL,
|
||||
AuthorizationHeader: "Bearer test",
|
||||
Enabled: true,
|
||||
},
|
||||
})
|
||||
|
||||
response, rpcErr := server.executeSessionTask(task{
|
||||
req: shared.RPCRequest{
|
||||
Method: "session.start",
|
||||
Params: map[string]any{
|
||||
"sessionId": "session-external",
|
||||
"threadId": "thread-external",
|
||||
"taskPrompt": "hello from external provider",
|
||||
"workingDirectory": t.TempDir(),
|
||||
"routing": map[string]any{
|
||||
"routingMode": "explicit",
|
||||
"explicitExecutionTarget": "singleAgent",
|
||||
"explicitProviderId": "claude",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
if rpcErr != nil {
|
||||
t.Fatalf("expected success, got rpc error: %v", rpcErr)
|
||||
}
|
||||
if got := response["output"]; got != "external-provider-ok" {
|
||||
t.Fatalf("expected external provider output, got %#v", response)
|
||||
}
|
||||
if got := response["resolvedProviderId"]; got != "claude" {
|
||||
t.Fatalf("expected resolved provider claude, got %#v", response)
|
||||
}
|
||||
if _, exists := lastForwardedParams["metadata"]; exists {
|
||||
t.Fatalf("expected metadata to be stripped for external provider request, got %#v", lastForwardedParams)
|
||||
}
|
||||
if _, exists := lastForwardedParams[externalProviderEndpointKey]; exists {
|
||||
t.Fatalf("expected internal endpoint key to be stripped, got %#v", lastForwardedParams)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunSingleAgentUsesFrozenExternalProviderParams(t *testing.T) {
|
||||
externalServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/acp/rpc" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
var request map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
|
||||
t.Fatalf("decode request: %v", err)
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": request["id"],
|
||||
"result": map[string]any{
|
||||
"success": true,
|
||||
"output": "frozen-provider-ok",
|
||||
"turnId": "turn-frozen",
|
||||
"provider": "custom-agent-1",
|
||||
"mode": "single-agent",
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer externalServer.Close()
|
||||
|
||||
server := NewServer()
|
||||
session := server.getOrCreateSession("session-frozen", "thread-frozen")
|
||||
result := server.runSingleAgent(
|
||||
context.Background(),
|
||||
"session.start",
|
||||
session,
|
||||
map[string]any{
|
||||
"provider": "custom-agent-1",
|
||||
"taskPrompt": "hello",
|
||||
"workingDirectory": t.TempDir(),
|
||||
externalProviderEndpointKey: externalServer.URL,
|
||||
externalProviderAuthorizationHeaderKey: "Bearer test",
|
||||
externalProviderLabelKey: "Codex",
|
||||
},
|
||||
"turn-frozen",
|
||||
func(map[string]any) {},
|
||||
)
|
||||
if result.err != nil {
|
||||
t.Fatalf("expected success, got rpc error: %v", result.err)
|
||||
}
|
||||
if got := result.response["output"]; got != "frozen-provider-ok" {
|
||||
t.Fatalf("expected frozen provider output, got %#v", result.response)
|
||||
}
|
||||
}
|
||||
@ -1,196 +0,0 @@
|
||||
package acp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"xworkmate/go_core/internal/memory"
|
||||
"xworkmate/go_core/internal/router"
|
||||
"xworkmate/go_core/internal/skills"
|
||||
)
|
||||
|
||||
func handleRoutingResolve(params map[string]any) map[string]any {
|
||||
result, _ := resolveRoutingMetadataWithProviders(params, nil)
|
||||
return mergeRoutingResponse(map[string]any{"ok": true}, result)
|
||||
}
|
||||
|
||||
func resolveRoutingMetadata(params map[string]any) (router.Result, bool) {
|
||||
return resolveRoutingMetadataWithProviders(params, nil)
|
||||
}
|
||||
|
||||
func resolveRoutingMetadataWithProviders(
|
||||
params map[string]any,
|
||||
availableProviders []string,
|
||||
) (router.Result, bool) {
|
||||
routingParams := asMap(params["routing"])
|
||||
if len(routingParams) == 0 {
|
||||
return router.Result{}, false
|
||||
}
|
||||
installApproval := asMap(routingParams["installApproval"])
|
||||
|
||||
resolver := router.NewResolver()
|
||||
result := resolver.Resolve(router.Request{
|
||||
Prompt: strings.TrimSpace(sharedString(params, "taskPrompt")),
|
||||
WorkingDirectory: strings.TrimSpace(sharedString(params, "workingDirectory")),
|
||||
RoutingMode: strings.TrimSpace(sharedString(routingParams, "routingMode")),
|
||||
PreferredGatewayTarget: strings.TrimSpace(sharedString(routingParams, "preferredGatewayTarget")),
|
||||
ExplicitExecutionTarget: strings.TrimSpace(sharedString(routingParams, "explicitExecutionTarget")),
|
||||
ExplicitProviderID: strings.TrimSpace(sharedString(routingParams, "explicitProviderId")),
|
||||
ExplicitModel: strings.TrimSpace(sharedString(routingParams, "explicitModel")),
|
||||
ExplicitSkills: parseRoutingStringSlice(routingParams["explicitSkills"]),
|
||||
AllowSkillInstall: parseBool(routingParams["allowSkillInstall"]),
|
||||
InstallApproval: skills.InstallApproval{
|
||||
RequestID: strings.TrimSpace(sharedString(installApproval, "requestId")),
|
||||
ApprovedSkillKeys: parseRoutingStringSlice(installApproval["approvedSkillKeys"]),
|
||||
},
|
||||
AvailableSkills: parseRoutingSkillCandidates(routingParams["availableSkills"]),
|
||||
AvailableProviders: append([]string(nil), availableProviders...),
|
||||
AIGatewayBaseURL: strings.TrimSpace(sharedString(params, "aiGatewayBaseUrl")),
|
||||
AIGatewayAPIKey: strings.TrimSpace(sharedString(params, "aiGatewayApiKey")),
|
||||
})
|
||||
return result, true
|
||||
}
|
||||
|
||||
func mergeRoutingResponse(response map[string]any, result router.Result) map[string]any {
|
||||
if response == nil {
|
||||
response = map[string]any{}
|
||||
}
|
||||
response["resolvedExecutionTarget"] = result.ResolvedExecutionTarget
|
||||
response["resolvedEndpointTarget"] = result.ResolvedEndpointTarget
|
||||
response["resolvedProviderId"] = result.ResolvedProviderID
|
||||
response["resolvedModel"] = result.ResolvedModel
|
||||
response["resolvedSkills"] = append([]string(nil), result.ResolvedSkills...)
|
||||
response["skillResolutionSource"] = result.SkillResolutionSource
|
||||
response["needsSkillInstall"] = result.NeedsSkillInstall
|
||||
response["unavailable"] = result.Unavailable
|
||||
if strings.TrimSpace(result.UnavailableCode) != "" {
|
||||
response["unavailableCode"] = result.UnavailableCode
|
||||
}
|
||||
if strings.TrimSpace(result.UnavailableMessage) != "" {
|
||||
response["unavailableMessage"] = result.UnavailableMessage
|
||||
}
|
||||
if strings.TrimSpace(result.SkillInstallRequestID) != "" {
|
||||
response["skillInstallRequestId"] = result.SkillInstallRequestID
|
||||
}
|
||||
if len(result.SkillCandidates) > 0 {
|
||||
response["skillCandidates"] = routingSkillCandidatesMap(result.SkillCandidates)
|
||||
}
|
||||
if len(result.MemorySources) > 0 {
|
||||
response["memorySources"] = routingMemorySourcesMap(result.MemorySources)
|
||||
}
|
||||
return response
|
||||
}
|
||||
|
||||
func recordRoutingSuccess(
|
||||
params map[string]any,
|
||||
result router.Result,
|
||||
response map[string]any,
|
||||
) error {
|
||||
routingParams := asMap(params["routing"])
|
||||
if len(routingParams) == 0 {
|
||||
return nil
|
||||
}
|
||||
if strings.EqualFold(
|
||||
strings.TrimSpace(sharedString(routingParams, "routingMode")),
|
||||
router.RoutingModeExplicit,
|
||||
) {
|
||||
return nil
|
||||
}
|
||||
if !parseBool(response["success"]) {
|
||||
return nil
|
||||
}
|
||||
|
||||
workingDirectory := strings.TrimSpace(sharedString(params, "workingDirectory"))
|
||||
if workingDirectory == "" {
|
||||
return nil
|
||||
}
|
||||
homeDir, _ := os.UserHomeDir()
|
||||
service := memory.NewService(homeDir)
|
||||
return service.RecordSuccess(workingDirectory, memory.SuccessEntry{
|
||||
ResolvedExecutionTarget: result.ResolvedExecutionTarget,
|
||||
ResolvedProviderID: result.ResolvedProviderID,
|
||||
ResolvedModel: result.ResolvedModel,
|
||||
ResolvedSkills: append([]string(nil), result.ResolvedSkills...),
|
||||
})
|
||||
}
|
||||
|
||||
func parseRoutingSkillCandidates(raw any) []skills.Candidate {
|
||||
list, ok := raw.([]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
candidates := make([]skills.Candidate, 0, len(list))
|
||||
for _, item := range list {
|
||||
entry := asMap(item)
|
||||
candidates = append(candidates, skills.Candidate{
|
||||
ID: strings.TrimSpace(sharedString(entry, "id")),
|
||||
Label: strings.TrimSpace(sharedString(entry, "label")),
|
||||
Description: strings.TrimSpace(sharedString(entry, "description")),
|
||||
Installed: parseBool(entry["installed"]),
|
||||
})
|
||||
}
|
||||
return candidates
|
||||
}
|
||||
|
||||
func routingSkillCandidatesMap(candidates []skills.Candidate) []map[string]any {
|
||||
result := make([]map[string]any, 0, len(candidates))
|
||||
for _, candidate := range candidates {
|
||||
result = append(result, map[string]any{
|
||||
"id": candidate.ID,
|
||||
"label": candidate.Label,
|
||||
"description": candidate.Description,
|
||||
"installed": candidate.Installed,
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func routingMemorySourcesMap(sources []memory.Source) []map[string]any {
|
||||
result := make([]map[string]any, 0, len(sources))
|
||||
for _, source := range sources {
|
||||
result = append(result, map[string]any{
|
||||
"path": source.Path,
|
||||
"scope": source.Scope,
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func parseRoutingStringSlice(raw any) []string {
|
||||
list, ok := raw.([]any)
|
||||
if !ok {
|
||||
if typed, ok := raw.([]string); ok {
|
||||
return append([]string(nil), typed...)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
values := make([]string, 0, len(list))
|
||||
for _, item := range list {
|
||||
value := strings.TrimSpace(sharedStringArg(item))
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
values = append(values, value)
|
||||
}
|
||||
return values
|
||||
}
|
||||
|
||||
func sharedString(params map[string]any, key string) string {
|
||||
if params == nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(sharedStringArg(params[key]))
|
||||
}
|
||||
|
||||
func sharedStringArg(value any) string {
|
||||
if value == nil {
|
||||
return ""
|
||||
}
|
||||
switch typed := value.(type) {
|
||||
case string:
|
||||
return typed
|
||||
default:
|
||||
return fmt.Sprint(value)
|
||||
}
|
||||
}
|
||||
@ -1,472 +0,0 @@
|
||||
package acp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"xworkmate/go_core/internal/shared"
|
||||
)
|
||||
|
||||
func newExternalSingleAgentProvider(
|
||||
t *testing.T,
|
||||
providerID string,
|
||||
output string,
|
||||
) *httptest.Server {
|
||||
t.Helper()
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/acp/rpc" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
var request map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
|
||||
t.Fatalf("decode request: %v", err)
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": request["id"],
|
||||
"result": map[string]any{
|
||||
"success": true,
|
||||
"output": output,
|
||||
"turnId": "turn-" + providerID,
|
||||
"provider": providerID,
|
||||
"mode": "single-agent",
|
||||
},
|
||||
})
|
||||
}))
|
||||
}
|
||||
|
||||
func TestHandleRoutingResolveCoversNineScenarioBuckets(t *testing.T) {
|
||||
localAvailableSkills := []map[string]any{
|
||||
{"id": "pptx", "label": "PPTX", "description": "slides", "installed": true},
|
||||
{"id": "docx", "label": "DOCX", "description": "docs", "installed": true},
|
||||
{"id": "xlsx", "label": "XLSX", "description": "sheets", "installed": true},
|
||||
{"id": "pdf", "label": "PDF", "description": "pdf", "installed": true},
|
||||
{"id": "image-resizer", "label": "image-resizer", "description": "image resize", "installed": true},
|
||||
{"id": "browser-automation", "label": "Browser Automation", "description": "browser", "installed": true},
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
prompt string
|
||||
expectedExecutionTarget string
|
||||
expectedSkillSource string
|
||||
expectedResolvedSkill string
|
||||
expectedNeedsSkillInstall bool
|
||||
}{
|
||||
{
|
||||
name: "powerpoint-pptx",
|
||||
prompt: "create a powerpoint deck for this launch",
|
||||
expectedExecutionTarget: "single-agent",
|
||||
expectedSkillSource: "local_match",
|
||||
expectedResolvedSkill: "PPTX",
|
||||
},
|
||||
{
|
||||
name: "word-docx",
|
||||
prompt: "draft a word document memo",
|
||||
expectedExecutionTarget: "single-agent",
|
||||
expectedSkillSource: "local_match",
|
||||
expectedResolvedSkill: "DOCX",
|
||||
},
|
||||
{
|
||||
name: "excel-xlsx",
|
||||
prompt: "build an excel workbook with formulas",
|
||||
expectedExecutionTarget: "single-agent",
|
||||
expectedSkillSource: "local_match",
|
||||
expectedResolvedSkill: "XLSX",
|
||||
},
|
||||
{
|
||||
name: "pdf",
|
||||
prompt: "merge and fill this pdf form",
|
||||
expectedExecutionTarget: "single-agent",
|
||||
expectedSkillSource: "local_match",
|
||||
expectedResolvedSkill: "PDF",
|
||||
},
|
||||
{
|
||||
name: "image-resizer",
|
||||
prompt: "batch resize image assets",
|
||||
expectedExecutionTarget: "single-agent",
|
||||
expectedSkillSource: "local_match",
|
||||
expectedResolvedSkill: "image-resizer",
|
||||
},
|
||||
{
|
||||
name: "image-cog",
|
||||
prompt: "use image-cog to generate consistent characters",
|
||||
expectedExecutionTarget: "gateway",
|
||||
expectedSkillSource: "find_skills",
|
||||
expectedNeedsSkillInstall: true,
|
||||
},
|
||||
{
|
||||
name: "image-video-generation-editting",
|
||||
prompt: "wan 图生视频并做视频编辑",
|
||||
expectedExecutionTarget: "gateway",
|
||||
expectedSkillSource: "find_skills",
|
||||
expectedNeedsSkillInstall: true,
|
||||
},
|
||||
{
|
||||
name: "video-translator",
|
||||
prompt: "translate video subtitles and dub the clip",
|
||||
expectedExecutionTarget: "gateway",
|
||||
expectedSkillSource: "find_skills",
|
||||
expectedNeedsSkillInstall: true,
|
||||
},
|
||||
{
|
||||
name: "browser-search-news",
|
||||
prompt: "跨浏览器执行并搜索最新资讯采集结果",
|
||||
expectedExecutionTarget: "gateway",
|
||||
expectedSkillSource: "local_match",
|
||||
expectedResolvedSkill: "Browser Automation",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := handleRoutingResolve(map[string]any{
|
||||
"taskPrompt": tc.prompt,
|
||||
"workingDirectory": "/tmp/workspace",
|
||||
"routing": map[string]any{
|
||||
"routingMode": "auto",
|
||||
"preferredGatewayTarget": "local",
|
||||
"allowSkillInstall": false,
|
||||
"availableSkills": func() []any {
|
||||
values := make([]any, 0, len(localAvailableSkills))
|
||||
for _, item := range localAvailableSkills {
|
||||
values = append(values, item)
|
||||
}
|
||||
return values
|
||||
}(),
|
||||
},
|
||||
})
|
||||
|
||||
if got := result["resolvedExecutionTarget"]; got != tc.expectedExecutionTarget {
|
||||
t.Fatalf("expected execution target %q, got %#v", tc.expectedExecutionTarget, got)
|
||||
}
|
||||
if got := result["skillResolutionSource"]; got != tc.expectedSkillSource {
|
||||
t.Fatalf("expected skill source %q, got %#v", tc.expectedSkillSource, got)
|
||||
}
|
||||
if tc.expectedResolvedSkill != "" {
|
||||
resolvedSkills, _ := result["resolvedSkills"].([]string)
|
||||
if len(resolvedSkills) == 0 || resolvedSkills[0] != tc.expectedResolvedSkill {
|
||||
t.Fatalf("expected resolved skill %q, got %#v", tc.expectedResolvedSkill, result["resolvedSkills"])
|
||||
}
|
||||
}
|
||||
if got := result["needsSkillInstall"]; got != tc.expectedNeedsSkillInstall {
|
||||
t.Fatalf("expected needsSkillInstall=%v, got %#v", tc.expectedNeedsSkillInstall, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteSessionTaskAutoRoutingRecordsProjectMemory(t *testing.T) {
|
||||
homeDir := t.TempDir()
|
||||
workspaceDir := filepath.Join(t.TempDir(), "workspace")
|
||||
if err := os.MkdirAll(workspaceDir, 0o755); err != nil {
|
||||
t.Fatalf("create workspace: %v", err)
|
||||
}
|
||||
|
||||
t.Setenv("HOME", homeDir)
|
||||
|
||||
server := NewServer()
|
||||
providerServer := newExternalSingleAgentProvider(t, "claude", "done")
|
||||
defer providerServer.Close()
|
||||
server.syncProviders([]syncedProvider{{
|
||||
ProviderID: "claude",
|
||||
Label: "Claude",
|
||||
Endpoint: providerServer.URL,
|
||||
Enabled: true,
|
||||
}})
|
||||
response, rpcErr := server.executeSessionTask(task{
|
||||
req: shared.RPCRequest{
|
||||
Params: map[string]any{
|
||||
"sessionId": "session-auto",
|
||||
"threadId": "thread-auto",
|
||||
"provider": "claude",
|
||||
"taskPrompt": "create a powerpoint deck for launch",
|
||||
"workingDirectory": workspaceDir,
|
||||
"routing": map[string]any{
|
||||
"routingMode": "auto",
|
||||
"preferredGatewayTarget": "local",
|
||||
"availableSkills": []any{
|
||||
map[string]any{
|
||||
"id": "pptx",
|
||||
"label": "PPTX",
|
||||
"description": "slides",
|
||||
"installed": true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
if rpcErr != nil {
|
||||
t.Fatalf("expected success, got rpc error: %v", rpcErr)
|
||||
}
|
||||
if success, _ := response["success"].(bool); !success {
|
||||
t.Fatalf("expected success response, got %#v", response)
|
||||
}
|
||||
|
||||
projectLocalMemory := filepath.Join(workspaceDir, ".xworkmate", "memory.md")
|
||||
content, err := os.ReadFile(projectLocalMemory)
|
||||
if err != nil {
|
||||
t.Fatalf("expected memory file %s: %v", projectLocalMemory, err)
|
||||
}
|
||||
text := string(content)
|
||||
if !strings.Contains(text, "preferred-route: single-agent") {
|
||||
t.Fatalf("expected preferred route in %s, got %q", projectLocalMemory, text)
|
||||
}
|
||||
if !strings.Contains(text, "preferred-skills: PPTX") {
|
||||
t.Fatalf("expected preferred skills in %s, got %q", projectLocalMemory, text)
|
||||
}
|
||||
projectHomeMemory := filepath.Join(
|
||||
homeDir,
|
||||
"self-improving",
|
||||
"projects",
|
||||
filepath.Base(workspaceDir)+".md",
|
||||
)
|
||||
if _, err := os.Stat(projectHomeMemory); !os.IsNotExist(err) {
|
||||
t.Fatalf("expected auto memory write to stay project-local only, got stat err=%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteSessionTaskExplicitRoutingDoesNotRecordProjectMemory(t *testing.T) {
|
||||
homeDir := t.TempDir()
|
||||
workspaceDir := filepath.Join(t.TempDir(), "workspace")
|
||||
if err := os.MkdirAll(workspaceDir, 0o755); err != nil {
|
||||
t.Fatalf("create workspace: %v", err)
|
||||
}
|
||||
|
||||
t.Setenv("HOME", homeDir)
|
||||
|
||||
server := NewServer()
|
||||
providerServer := newExternalSingleAgentProvider(t, "claude", "done")
|
||||
defer providerServer.Close()
|
||||
server.syncProviders([]syncedProvider{{
|
||||
ProviderID: "claude",
|
||||
Label: "Claude",
|
||||
Endpoint: providerServer.URL,
|
||||
Enabled: true,
|
||||
}})
|
||||
response, rpcErr := server.executeSessionTask(task{
|
||||
req: shared.RPCRequest{
|
||||
Params: map[string]any{
|
||||
"sessionId": "session-explicit",
|
||||
"threadId": "thread-explicit",
|
||||
"provider": "claude",
|
||||
"taskPrompt": "create a powerpoint deck for launch",
|
||||
"workingDirectory": workspaceDir,
|
||||
"routing": map[string]any{
|
||||
"routingMode": "explicit",
|
||||
"explicitExecutionTarget": "singleAgent",
|
||||
"explicitProviderId": "claude",
|
||||
"availableSkills": []any{
|
||||
map[string]any{
|
||||
"id": "pptx",
|
||||
"label": "PPTX",
|
||||
"description": "slides",
|
||||
"installed": true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
if rpcErr != nil {
|
||||
t.Fatalf("expected success, got rpc error: %v", rpcErr)
|
||||
}
|
||||
if success, _ := response["success"].(bool); !success {
|
||||
t.Fatalf("expected success response, got %#v", response)
|
||||
}
|
||||
|
||||
projectHomeMemory := filepath.Join(
|
||||
homeDir,
|
||||
"self-improving",
|
||||
"projects",
|
||||
filepath.Base(workspaceDir)+".md",
|
||||
)
|
||||
projectLocalMemory := filepath.Join(workspaceDir, ".xworkmate", "memory.md")
|
||||
for _, target := range []string{projectHomeMemory, projectLocalMemory} {
|
||||
if _, err := os.Stat(target); !os.IsNotExist(err) {
|
||||
t.Fatalf("expected no memory write for explicit routing at %s, err=%v", target, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteSessionTaskExplicitProviderRequiresSyncedCatalog(t *testing.T) {
|
||||
server := NewServer()
|
||||
response, rpcErr := server.executeSessionTask(task{
|
||||
req: shared.RPCRequest{
|
||||
Method: "session.start",
|
||||
Params: map[string]any{
|
||||
"sessionId": "session-explicit-provider",
|
||||
"threadId": "thread-explicit-provider",
|
||||
"taskPrompt": "create a powerpoint deck for launch",
|
||||
"routing": map[string]any{
|
||||
"routingMode": "explicit",
|
||||
"explicitExecutionTarget": "singleAgent",
|
||||
"explicitProviderId": "claude",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
if rpcErr != nil {
|
||||
t.Fatalf("expected structured unavailable response, got rpc error: %v", rpcErr)
|
||||
}
|
||||
if got := response["unavailable"]; got != true {
|
||||
t.Fatalf("expected unavailable response, got %#v", response)
|
||||
}
|
||||
if got := response["unavailableCode"]; got != "PROVIDER_UNAVAILABLE" {
|
||||
t.Fatalf("expected PROVIDER_UNAVAILABLE, got %#v", response)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteSessionTaskRequiresRouting(t *testing.T) {
|
||||
server := NewServer()
|
||||
_, rpcErr := server.executeSessionTask(task{
|
||||
req: shared.RPCRequest{
|
||||
ID: "request-1",
|
||||
Method: "session.start",
|
||||
Params: map[string]any{
|
||||
"sessionId": "session-missing-routing",
|
||||
"threadId": "thread-missing-routing",
|
||||
"taskPrompt": "hello",
|
||||
},
|
||||
},
|
||||
})
|
||||
if rpcErr == nil {
|
||||
t.Fatalf("expected routing-required error")
|
||||
}
|
||||
if rpcErr.Message != "ROUTING_REQUIRED" {
|
||||
t.Fatalf("expected ROUTING_REQUIRED, got %#v", rpcErr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteSessionTaskAutoRoutingPromotesComplexRequestToMultiAgent(t *testing.T) {
|
||||
workspaceDir := filepath.Join(t.TempDir(), "workspace")
|
||||
if err := os.MkdirAll(workspaceDir, 0o755); err != nil {
|
||||
t.Fatalf("create workspace: %v", err)
|
||||
}
|
||||
|
||||
aiGateway := httptest.NewServer(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"choices":[{"message":{"content":"planner output"}}]}`))
|
||||
}),
|
||||
)
|
||||
defer aiGateway.Close()
|
||||
|
||||
server := NewServer()
|
||||
response, rpcErr := server.executeSessionTask(task{
|
||||
req: shared.RPCRequest{
|
||||
Params: map[string]any{
|
||||
"sessionId": "session-complex",
|
||||
"threadId": "thread-complex",
|
||||
"provider": "claude",
|
||||
"taskPrompt": "collect latest news and summarize it into a report for review",
|
||||
"workingDirectory": workspaceDir,
|
||||
"aiGatewayBaseUrl": aiGateway.URL,
|
||||
"aiGatewayApiKey": "test-key",
|
||||
"routing": map[string]any{
|
||||
"routingMode": "auto",
|
||||
"preferredGatewayTarget": "local",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
if rpcErr != nil {
|
||||
t.Fatalf("expected success, got rpc error: %v", rpcErr)
|
||||
}
|
||||
if success, _ := response["success"].(bool); !success {
|
||||
t.Fatalf("expected success response, got %#v", response)
|
||||
}
|
||||
if got := response["mode"]; got != "multi-agent" {
|
||||
t.Fatalf("expected session mode to be promoted to multi-agent, got %#v", got)
|
||||
}
|
||||
if got := response["resolvedExecutionTarget"]; got != "multi-agent" {
|
||||
t.Fatalf("expected resolved execution target multi-agent, got %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleRoutingResolveAllowsSkillInstallRetry(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
finder := filepath.Join(tempDir, "find-skills.sh")
|
||||
installer := filepath.Join(tempDir, "install-skills.sh")
|
||||
if err := os.WriteFile(
|
||||
finder,
|
||||
[]byte("#!/bin/sh\nprintf '%s' '{\"candidates\":[{\"id\":\"video-translator\",\"label\":\"video-translator\",\"description\":\"translate video\",\"installed\":false}]}'\n"),
|
||||
0o755,
|
||||
); err != nil {
|
||||
t.Fatalf("write finder: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(
|
||||
installer,
|
||||
[]byte("#!/bin/sh\nprintf '%s' '{\"candidates\":[{\"id\":\"video-translator\",\"label\":\"video-translator\",\"description\":\"translate video\",\"installed\":true}]}'\n"),
|
||||
0o755,
|
||||
); err != nil {
|
||||
t.Fatalf("write installer: %v", err)
|
||||
}
|
||||
t.Setenv("ACP_FIND_SKILLS_BIN", finder)
|
||||
t.Setenv("ACP_INSTALL_SKILL_BIN", installer)
|
||||
|
||||
result := handleRoutingResolve(map[string]any{
|
||||
"taskPrompt": "translate and dub this video with subtitles",
|
||||
"workingDirectory": "/tmp/workspace",
|
||||
"routing": map[string]any{
|
||||
"routingMode": "auto",
|
||||
"allowSkillInstall": true,
|
||||
"availableSkills": []any{
|
||||
map[string]any{
|
||||
"id": "docx",
|
||||
"label": "docx",
|
||||
"description": "docs",
|
||||
"installed": true,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
if got := result["skillResolutionSource"]; got != "find_skills" {
|
||||
t.Fatalf("expected find_skills source, got %#v", got)
|
||||
}
|
||||
if got := result["needsSkillInstall"]; got != true {
|
||||
t.Fatalf("expected first pass to request install approval, got %#v", got)
|
||||
}
|
||||
requestID, _ := result["skillInstallRequestId"].(string)
|
||||
if strings.TrimSpace(requestID) == "" {
|
||||
t.Fatalf("expected install request id, got %#v", result)
|
||||
}
|
||||
|
||||
retried := handleRoutingResolve(map[string]any{
|
||||
"taskPrompt": "translate and dub this video with subtitles",
|
||||
"workingDirectory": "/tmp/workspace",
|
||||
"routing": map[string]any{
|
||||
"routingMode": "auto",
|
||||
"allowSkillInstall": true,
|
||||
"installApproval": map[string]any{
|
||||
"requestId": requestID,
|
||||
"approvedSkillKeys": []any{"video-translator"},
|
||||
},
|
||||
"availableSkills": []any{
|
||||
map[string]any{
|
||||
"id": "docx",
|
||||
"label": "docx",
|
||||
"description": "docs",
|
||||
"installed": true,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
if got := retried["needsSkillInstall"]; got != false {
|
||||
t.Fatalf("expected install retry to clear needsSkillInstall, got %#v", got)
|
||||
}
|
||||
resolvedSkills, _ := retried["resolvedSkills"].([]string)
|
||||
if len(resolvedSkills) != 1 || resolvedSkills[0] != "video-translator" {
|
||||
t.Fatalf("expected installed skill to resolve, got %#v", retried["resolvedSkills"])
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,95 +0,0 @@
|
||||
package acp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"xworkmate/go_core/internal/shared"
|
||||
)
|
||||
|
||||
func RunStdio(input io.Reader, output io.Writer) {
|
||||
server := NewServer()
|
||||
reader := bufio.NewReader(input)
|
||||
var writeMu sync.Mutex
|
||||
|
||||
writeMessage := func(message map[string]any) {
|
||||
payload, _ := jsonMarshal(message)
|
||||
writeMu.Lock()
|
||||
defer writeMu.Unlock()
|
||||
_, _ = output.Write(append(payload, '\n'))
|
||||
}
|
||||
|
||||
for {
|
||||
payload, err := readStdioMessage(reader)
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return
|
||||
}
|
||||
writeMessage(shared.ErrorEnvelope(nil, -32700, err.Error()))
|
||||
continue
|
||||
}
|
||||
if len(strings.TrimSpace(string(payload))) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
request, err := shared.DecodeRPCRequest(payload)
|
||||
if err != nil {
|
||||
writeMessage(shared.ErrorEnvelope(nil, -32700, err.Error()))
|
||||
continue
|
||||
}
|
||||
response, rpcErr := server.handleRequest(request, writeMessage)
|
||||
if request.ID == nil {
|
||||
continue
|
||||
}
|
||||
if rpcErr != nil {
|
||||
writeMessage(
|
||||
shared.ErrorEnvelope(request.ID, rpcErr.Code, rpcErr.Message),
|
||||
)
|
||||
continue
|
||||
}
|
||||
writeMessage(shared.ResultEnvelope(request.ID, response))
|
||||
}
|
||||
}
|
||||
|
||||
func readStdioMessage(reader *bufio.Reader) ([]byte, error) {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
return nil, nil
|
||||
}
|
||||
if strings.HasPrefix(strings.ToLower(line), "content-length:") {
|
||||
var contentLength int
|
||||
if _, err := fmt.Sscanf(line, "Content-Length: %d", &contentLength); err != nil {
|
||||
if _, err2 := fmt.Sscanf(line, "content-length: %d", &contentLength); err2 != nil {
|
||||
return nil, fmt.Errorf("invalid content-length header")
|
||||
}
|
||||
}
|
||||
for {
|
||||
headerLine, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(headerLine) == "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
body := make([]byte, contentLength)
|
||||
if _, err := io.ReadFull(reader, body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
return []byte(line), nil
|
||||
}
|
||||
|
||||
func jsonMarshal(message map[string]any) ([]byte, error) {
|
||||
return json.Marshal(message)
|
||||
}
|
||||
@ -1,78 +0,0 @@
|
||||
package acp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"xworkmate/go_core/internal/shared"
|
||||
)
|
||||
|
||||
func (s *Server) allowedOrigins() []string {
|
||||
raw := strings.TrimSpace(shared.EnvOrDefault(
|
||||
"ACP_ALLOWED_ORIGINS",
|
||||
"https://xworkmate.svc.plus,http://localhost:*,http://127.0.0.1:*",
|
||||
))
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
parts := strings.Split(raw, ",")
|
||||
origins := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
candidate := strings.TrimSpace(part)
|
||||
if candidate == "" {
|
||||
continue
|
||||
}
|
||||
origins = append(origins, candidate)
|
||||
}
|
||||
return origins
|
||||
}
|
||||
|
||||
func (s *Server) originAllowed(origin string) bool {
|
||||
origin = strings.TrimSpace(origin)
|
||||
if origin == "" {
|
||||
return true
|
||||
}
|
||||
for _, allowed := range s.allowedOrigins() {
|
||||
if strings.HasSuffix(allowed, ":*") {
|
||||
if strings.HasPrefix(origin, strings.TrimSuffix(allowed, "*")) {
|
||||
return true
|
||||
}
|
||||
continue
|
||||
}
|
||||
if origin == allowed {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *Server) applyCORS(w http.ResponseWriter, r *http.Request) {
|
||||
origin := strings.TrimSpace(r.Header.Get("Origin"))
|
||||
if origin == "" || !s.originAllowed(origin) {
|
||||
return
|
||||
}
|
||||
headers := w.Header()
|
||||
headers.Set("Access-Control-Allow-Origin", origin)
|
||||
headers.Set("Access-Control-Allow-Methods", "POST, OPTIONS")
|
||||
headers.Set(
|
||||
"Access-Control-Allow-Headers",
|
||||
"Authorization, Content-Type, Accept",
|
||||
)
|
||||
headers.Set("Access-Control-Max-Age", "600")
|
||||
headers.Add("Vary", "Origin")
|
||||
headers.Add("Vary", "Access-Control-Request-Method")
|
||||
headers.Add("Vary", "Access-Control-Request-Headers")
|
||||
}
|
||||
|
||||
func (s *Server) writeJSONError(
|
||||
w http.ResponseWriter,
|
||||
requestID any,
|
||||
statusCode int,
|
||||
code int,
|
||||
message string,
|
||||
) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(statusCode)
|
||||
_ = json.NewEncoder(w).Encode(shared.ErrorEnvelope(requestID, code, message))
|
||||
}
|
||||
@ -1,111 +0,0 @@
|
||||
package acp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHandleWebSocketRejectsUnknownOrigin(t *testing.T) {
|
||||
t.Setenv("ACP_ALLOWED_ORIGINS", "https://xworkmate.svc.plus")
|
||||
|
||||
server := NewServer()
|
||||
recorder := httptest.NewRecorder()
|
||||
request := httptest.NewRequest(http.MethodGet, "http://127.0.0.1/acp", nil)
|
||||
request.Header.Set("Origin", "https://evil.example.com")
|
||||
|
||||
server.HandleWebSocket(recorder, request)
|
||||
|
||||
if recorder.Code != http.StatusForbidden {
|
||||
t.Fatalf("expected 403, got %d", recorder.Code)
|
||||
}
|
||||
if got := recorder.Header().Get("Content-Type"); !strings.Contains(got, "application/json") {
|
||||
t.Fatalf("expected application/json content type, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleRPCAllowsPreflightForConfiguredOrigin(t *testing.T) {
|
||||
t.Setenv("ACP_ALLOWED_ORIGINS", "https://xworkmate.svc.plus,http://localhost:*")
|
||||
|
||||
server := NewServer()
|
||||
recorder := httptest.NewRecorder()
|
||||
request := httptest.NewRequest(http.MethodOptions, "http://127.0.0.1/acp/rpc", nil)
|
||||
request.Header.Set("Origin", "https://xworkmate.svc.plus")
|
||||
request.Header.Set("Access-Control-Request-Method", "POST")
|
||||
|
||||
server.HandleRPC(recorder, request)
|
||||
|
||||
if recorder.Code != http.StatusNoContent {
|
||||
t.Fatalf("expected 204, got %d", recorder.Code)
|
||||
}
|
||||
if got := recorder.Header().Get("Access-Control-Allow-Origin"); got != "https://xworkmate.svc.plus" {
|
||||
t.Fatalf("expected allow origin header, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleRPCRejectsUnknownOrigin(t *testing.T) {
|
||||
t.Setenv("ACP_ALLOWED_ORIGINS", "https://xworkmate.svc.plus")
|
||||
|
||||
server := NewServer()
|
||||
recorder := httptest.NewRecorder()
|
||||
request := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"http://127.0.0.1/acp/rpc",
|
||||
strings.NewReader(`{"jsonrpc":"2.0","id":1,"method":"acp.capabilities"}`),
|
||||
)
|
||||
request.Header.Set("Origin", "https://evil.example.com")
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
server.HandleRPC(recorder, request)
|
||||
|
||||
if recorder.Code != http.StatusForbidden {
|
||||
t.Fatalf("expected 403, got %d", recorder.Code)
|
||||
}
|
||||
var envelope map[string]any
|
||||
if err := json.Unmarshal(recorder.Body.Bytes(), &envelope); err != nil {
|
||||
t.Fatalf("decode error envelope: %v", err)
|
||||
}
|
||||
if _, ok := envelope["error"]; !ok {
|
||||
t.Fatalf("expected JSON-RPC error envelope, got %v", envelope)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleRPCMethodErrorUsesJSONEnvelope(t *testing.T) {
|
||||
server := NewServer()
|
||||
recorder := httptest.NewRecorder()
|
||||
request := httptest.NewRequest(http.MethodGet, "http://127.0.0.1/acp/rpc", nil)
|
||||
|
||||
server.HandleRPC(recorder, request)
|
||||
|
||||
if recorder.Code != http.StatusMethodNotAllowed {
|
||||
t.Fatalf("expected 405, got %d", recorder.Code)
|
||||
}
|
||||
if got := recorder.Header().Get("Content-Type"); !strings.Contains(got, "application/json") {
|
||||
t.Fatalf("expected application/json content type, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleRPCCapabilitiesStillReturnsJSONResult(t *testing.T) {
|
||||
server := NewServer()
|
||||
recorder := httptest.NewRecorder()
|
||||
request := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"http://127.0.0.1/acp/rpc",
|
||||
strings.NewReader(`{"jsonrpc":"2.0","id":1,"method":"acp.capabilities"}`),
|
||||
)
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
server.HandleRPC(recorder, request)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", recorder.Code)
|
||||
}
|
||||
if got := recorder.Header().Get("Content-Type"); !strings.Contains(got, "application/json") {
|
||||
t.Fatalf("expected application/json content type, got %q", got)
|
||||
}
|
||||
if !strings.Contains(recorder.Body.String(), `"providers"`) {
|
||||
t.Fatalf("expected capabilities response, got %q", recorder.Body.String())
|
||||
}
|
||||
}
|
||||
@ -1,203 +0,0 @@
|
||||
package dispatch
|
||||
|
||||
import (
|
||||
"maps"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Provider struct {
|
||||
ID string
|
||||
Name string
|
||||
DefaultArgs []string
|
||||
Capabilities []string
|
||||
}
|
||||
|
||||
type NodeState struct {
|
||||
SelectedAgentID string
|
||||
GatewayConnected bool
|
||||
ExecutionTarget string
|
||||
RuntimeMode string
|
||||
BridgeEnabled bool
|
||||
BridgeState string
|
||||
ResolvedCodexCLIPath string
|
||||
ConfiguredCodexCLIPath string
|
||||
}
|
||||
|
||||
type NodeInfo struct {
|
||||
ID string
|
||||
Name string
|
||||
Version string
|
||||
}
|
||||
|
||||
type Request struct {
|
||||
Providers []Provider
|
||||
PreferredProviderID string
|
||||
RequiredCapabilities []string
|
||||
NodeState *NodeState
|
||||
NodeInfo *NodeInfo
|
||||
}
|
||||
|
||||
type Result struct {
|
||||
Provider *Provider
|
||||
AgentID string
|
||||
Metadata map[string]any
|
||||
}
|
||||
|
||||
func Resolve(request Request) Result {
|
||||
provider := selectProvider(
|
||||
request.Providers,
|
||||
request.PreferredProviderID,
|
||||
request.RequiredCapabilities,
|
||||
)
|
||||
if request.NodeState == nil {
|
||||
return Result{Provider: provider, Metadata: map[string]any{}}
|
||||
}
|
||||
|
||||
state := request.NodeState
|
||||
nodeInfo := request.NodeInfo
|
||||
nodeID := "xworkmate-app"
|
||||
nodeName := "XWorkmate"
|
||||
nodeVersion := ""
|
||||
if nodeInfo != nil {
|
||||
if strings.TrimSpace(nodeInfo.ID) != "" {
|
||||
nodeID = strings.TrimSpace(nodeInfo.ID)
|
||||
}
|
||||
if strings.TrimSpace(nodeInfo.Name) != "" {
|
||||
nodeName = strings.TrimSpace(nodeInfo.Name)
|
||||
}
|
||||
nodeVersion = strings.TrimSpace(nodeInfo.Version)
|
||||
}
|
||||
|
||||
configuredPath := strings.TrimSpace(state.ConfiguredCodexCLIPath)
|
||||
if strings.TrimSpace(state.ResolvedCodexCLIPath) != "" {
|
||||
configuredPath = strings.TrimSpace(state.ResolvedCodexCLIPath)
|
||||
}
|
||||
localTransport := "stdio-jsonrpc"
|
||||
if strings.TrimSpace(state.RuntimeMode) == "builtIn" {
|
||||
localTransport = "ffi-runtime"
|
||||
}
|
||||
|
||||
metadata := map[string]any{
|
||||
"node": map[string]any{
|
||||
"id": nodeID,
|
||||
"name": nodeName,
|
||||
"version": nodeVersion,
|
||||
"kind": "app-mediated-cooperative-node",
|
||||
"gatewayTransport": "websocket-rpc",
|
||||
},
|
||||
"dispatch": map[string]any{
|
||||
"mode": dispatchMode(state.BridgeEnabled),
|
||||
"executionTarget": strings.TrimSpace(state.ExecutionTarget),
|
||||
},
|
||||
"bridge": map[string]any{
|
||||
"enabled": state.BridgeEnabled,
|
||||
"state": strings.TrimSpace(state.BridgeState),
|
||||
"gatewayConnected": state.GatewayConnected,
|
||||
"runtimeMode": strings.TrimSpace(state.RuntimeMode),
|
||||
"localTransport": localTransport,
|
||||
},
|
||||
}
|
||||
if configuredPath != "" {
|
||||
bridge := metadata["bridge"].(map[string]any)
|
||||
bridge["binaryConfigured"] = true
|
||||
}
|
||||
if provider != nil {
|
||||
metadata["provider"] = map[string]any{
|
||||
"id": provider.ID,
|
||||
"name": provider.Name,
|
||||
"defaultArgs": provider.DefaultArgs,
|
||||
"capabilities": provider.Capabilities,
|
||||
}
|
||||
}
|
||||
|
||||
return Result{
|
||||
Provider: provider,
|
||||
AgentID: strings.TrimSpace(state.SelectedAgentID),
|
||||
Metadata: metadata,
|
||||
}
|
||||
}
|
||||
|
||||
func dispatchMode(bridgeEnabled bool) string {
|
||||
if bridgeEnabled {
|
||||
return "cooperative"
|
||||
}
|
||||
return "gateway-only"
|
||||
}
|
||||
|
||||
func selectProvider(
|
||||
providers []Provider,
|
||||
preferredProviderID string,
|
||||
requiredCapabilities []string,
|
||||
) *Provider {
|
||||
required := normalizeCapabilities(requiredCapabilities)
|
||||
preferredID := strings.TrimSpace(preferredProviderID)
|
||||
if preferredID != "" {
|
||||
for _, provider := range providers {
|
||||
if provider.ID == preferredID && supportsProvider(provider, required) {
|
||||
candidate := provider
|
||||
return &candidate
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
filtered := make([]Provider, 0, len(providers))
|
||||
for _, provider := range providers {
|
||||
if supportsProvider(provider, required) {
|
||||
filtered = append(filtered, provider)
|
||||
}
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
return nil
|
||||
}
|
||||
slices.SortFunc(filtered, func(a, b Provider) int {
|
||||
return strings.Compare(a.ID, b.ID)
|
||||
})
|
||||
candidate := filtered[0]
|
||||
return &candidate
|
||||
}
|
||||
|
||||
func supportsProvider(provider Provider, required map[string]struct{}) bool {
|
||||
if len(required) == 0 {
|
||||
return true
|
||||
}
|
||||
provided := normalizeCapabilities(provider.Capabilities)
|
||||
for capability := range required {
|
||||
if _, ok := provided[capability]; !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func normalizeCapabilities(values []string) map[string]struct{} {
|
||||
normalized := map[string]struct{}{}
|
||||
for _, value := range values {
|
||||
item := strings.TrimSpace(strings.ToLower(value))
|
||||
if item == "" {
|
||||
continue
|
||||
}
|
||||
normalized[item] = struct{}{}
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
func ResultMap(result Result) map[string]any {
|
||||
response := map[string]any{
|
||||
"metadata": result.Metadata,
|
||||
}
|
||||
if result.Provider != nil {
|
||||
provider := *result.Provider
|
||||
response["providerId"] = provider.ID
|
||||
response["provider"] = map[string]any{
|
||||
"id": provider.ID,
|
||||
"name": provider.Name,
|
||||
"defaultArgs": slices.Clone(provider.DefaultArgs),
|
||||
"capabilities": slices.Clone(provider.Capabilities),
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(result.AgentID) != "" {
|
||||
response["agentId"] = strings.TrimSpace(result.AgentID)
|
||||
}
|
||||
return maps.Clone(response)
|
||||
}
|
||||
@ -1,96 +0,0 @@
|
||||
package dispatch
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestResolvePrefersRequestedProviderWhenCapabilitiesMatch(t *testing.T) {
|
||||
result := Resolve(Request{
|
||||
Providers: []Provider{
|
||||
{
|
||||
ID: "codex",
|
||||
Name: "Codex",
|
||||
Capabilities: []string{"chat", "gateway-bridge"},
|
||||
},
|
||||
{
|
||||
ID: "qwen",
|
||||
Name: "Qwen",
|
||||
Capabilities: []string{"chat"},
|
||||
},
|
||||
},
|
||||
PreferredProviderID: "codex",
|
||||
RequiredCapabilities: []string{"gateway-bridge"},
|
||||
})
|
||||
|
||||
if result.Provider == nil || result.Provider.ID != "codex" {
|
||||
t.Fatalf("expected codex provider, got %#v", result.Provider)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveFallsBackDeterministicallyByID(t *testing.T) {
|
||||
result := Resolve(Request{
|
||||
Providers: []Provider{
|
||||
{
|
||||
ID: "qwen",
|
||||
Name: "Qwen",
|
||||
Capabilities: []string{"chat"},
|
||||
},
|
||||
{
|
||||
ID: "codex",
|
||||
Name: "Codex",
|
||||
Capabilities: []string{"chat"},
|
||||
},
|
||||
},
|
||||
RequiredCapabilities: []string{"chat"},
|
||||
})
|
||||
|
||||
if result.Provider == nil || result.Provider.ID != "codex" {
|
||||
t.Fatalf("expected deterministic codex fallback, got %#v", result.Provider)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveBuildsGatewayDispatchMetadata(t *testing.T) {
|
||||
result := Resolve(Request{
|
||||
Providers: []Provider{
|
||||
{
|
||||
ID: "codex",
|
||||
Name: "Codex CLI",
|
||||
DefaultArgs: []string{"app-server"},
|
||||
Capabilities: []string{"chat", "gateway-bridge"},
|
||||
},
|
||||
},
|
||||
PreferredProviderID: "codex",
|
||||
RequiredCapabilities: []string{"gateway-bridge"},
|
||||
NodeState: &NodeState{
|
||||
SelectedAgentID: "main",
|
||||
GatewayConnected: true,
|
||||
ExecutionTarget: "local",
|
||||
RuntimeMode: "externalCli",
|
||||
BridgeEnabled: true,
|
||||
BridgeState: "registered",
|
||||
ResolvedCodexCLIPath: "/opt/homebrew/bin/codex",
|
||||
},
|
||||
NodeInfo: &NodeInfo{
|
||||
ID: "xworkmate-app",
|
||||
Name: "XWorkmate",
|
||||
Version: "1.0.0",
|
||||
},
|
||||
})
|
||||
|
||||
if result.Provider == nil || result.Provider.ID != "codex" {
|
||||
t.Fatalf("expected codex provider, got %#v", result.Provider)
|
||||
}
|
||||
if result.AgentID != "main" {
|
||||
t.Fatalf("expected agent id main, got %q", result.AgentID)
|
||||
}
|
||||
dispatch, ok := result.Metadata["dispatch"].(map[string]any)
|
||||
if !ok || dispatch["mode"] != "cooperative" {
|
||||
t.Fatalf("expected cooperative dispatch, got %#v", result.Metadata["dispatch"])
|
||||
}
|
||||
bridge, ok := result.Metadata["bridge"].(map[string]any)
|
||||
if !ok || bridge["localTransport"] != "stdio-jsonrpc" {
|
||||
t.Fatalf("expected stdio-jsonrpc bridge transport, got %#v", result.Metadata["bridge"])
|
||||
}
|
||||
provider, ok := result.Metadata["provider"].(map[string]any)
|
||||
if !ok || provider["id"] != "codex" {
|
||||
t.Fatalf("expected provider metadata for codex, got %#v", result.Metadata["provider"])
|
||||
}
|
||||
}
|
||||
@ -1,98 +0,0 @@
|
||||
package gatewayruntime
|
||||
|
||||
import "strings"
|
||||
|
||||
func normalizeChatRunEvent(event string, payload map[string]any) map[string]any {
|
||||
switch event {
|
||||
case "chat":
|
||||
runID := strings.TrimSpace(stringValue(payload["runId"]))
|
||||
state := strings.TrimSpace(stringValue(payload["state"]))
|
||||
if runID == "" && state == "" {
|
||||
return nil
|
||||
}
|
||||
message := asMap(payload["message"])
|
||||
assistantText := ""
|
||||
if strings.EqualFold(strings.TrimSpace(stringValue(message["role"])), "assistant") {
|
||||
assistantText = extractMessageText(message)
|
||||
}
|
||||
normalized := map[string]any{
|
||||
"runId": runID,
|
||||
"sessionKey": strings.TrimSpace(stringValue(payload["sessionKey"])),
|
||||
"state": state,
|
||||
"source": "chat",
|
||||
"terminal": state == "final" || state == "aborted" || state == "error",
|
||||
}
|
||||
if assistantText != "" {
|
||||
normalized["assistantText"] = assistantText
|
||||
}
|
||||
if errorMessage := strings.TrimSpace(stringValue(payload["errorMessage"])); errorMessage != "" {
|
||||
normalized["errorMessage"] = errorMessage
|
||||
}
|
||||
return normalized
|
||||
case "agent":
|
||||
runID := strings.TrimSpace(stringValue(payload["runId"]))
|
||||
if runID == "" {
|
||||
return nil
|
||||
}
|
||||
stream := strings.TrimSpace(stringValue(payload["stream"]))
|
||||
if !strings.EqualFold(stream, "assistant") {
|
||||
return nil
|
||||
}
|
||||
data := asMap(payload["data"])
|
||||
assistantText := strings.TrimSpace(stringValue(data["text"]))
|
||||
if assistantText == "" {
|
||||
assistantText = extractMessageText(data)
|
||||
}
|
||||
if assistantText == "" {
|
||||
return nil
|
||||
}
|
||||
sessionKey := strings.TrimSpace(stringValue(payload["sessionKey"]))
|
||||
if sessionKey == "" {
|
||||
sessionKey = strings.TrimSpace(stringValue(data["sessionKey"]))
|
||||
}
|
||||
return map[string]any{
|
||||
"runId": runID,
|
||||
"sessionKey": sessionKey,
|
||||
"state": "delta",
|
||||
"source": "agent",
|
||||
"stream": stream,
|
||||
"assistantText": assistantText,
|
||||
"terminal": false,
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func asList(value any) []any {
|
||||
switch typed := value.(type) {
|
||||
case []any:
|
||||
return typed
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func extractMessageText(message map[string]any) string {
|
||||
directContent, ok := message["content"].(string)
|
||||
if ok {
|
||||
return strings.TrimSpace(directContent)
|
||||
}
|
||||
parts := make([]string, 0, 4)
|
||||
for _, part := range asList(message["content"]) {
|
||||
segment := asMap(part)
|
||||
text := strings.TrimSpace(firstNonEmpty(
|
||||
stringValue(segment["text"]),
|
||||
stringValue(segment["thinking"]),
|
||||
))
|
||||
if text != "" {
|
||||
parts = append(parts, text)
|
||||
continue
|
||||
}
|
||||
nestedContent := strings.TrimSpace(stringValue(segment["content"]))
|
||||
if nestedContent != "" {
|
||||
parts = append(parts, nestedContent)
|
||||
}
|
||||
}
|
||||
return strings.TrimSpace(strings.Join(parts, "\n"))
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,337 +0,0 @@
|
||||
package gatewayruntime
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
func TestManagerConnectAndRequest(t *testing.T) {
|
||||
server := newFakeGatewayServer(t)
|
||||
defer server.Close()
|
||||
|
||||
manager := NewManager()
|
||||
manager.ReconnectDelay = 20 * time.Millisecond
|
||||
notifications := make([]map[string]any, 0, 8)
|
||||
var mu sync.Mutex
|
||||
notify := func(message map[string]any) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
notifications = append(notifications, message)
|
||||
}
|
||||
|
||||
result := manager.Connect(buildTestConnectRequest(server.Port()), notify)
|
||||
if !result.OK {
|
||||
t.Fatalf("expected connect success, got %#v", result.Error)
|
||||
}
|
||||
if result.ReturnedDeviceToken != "device-token-1" {
|
||||
t.Fatalf("expected returned device token, got %#v", result.ReturnedDeviceToken)
|
||||
}
|
||||
|
||||
requestResult := manager.Request(
|
||||
"runtime-1",
|
||||
"health",
|
||||
map[string]any{},
|
||||
2*time.Second,
|
||||
notify,
|
||||
)
|
||||
if !requestResult.OK {
|
||||
t.Fatalf("expected health success, got %#v", requestResult.Error)
|
||||
}
|
||||
payload, ok := requestResult.Payload.(map[string]any)
|
||||
if !ok || payload["status"] != "ok" {
|
||||
t.Fatalf("unexpected health payload %#v", requestResult.Payload)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if len(notifications) == 0 {
|
||||
t.Fatalf("expected notifications during connect")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerReconnectsAfterSocketClose(t *testing.T) {
|
||||
server := newFakeGatewayServer(t)
|
||||
server.closeAfterConnect.Store(true)
|
||||
defer server.Close()
|
||||
|
||||
manager := NewManager()
|
||||
manager.ReconnectDelay = 25 * time.Millisecond
|
||||
|
||||
reconnected := make(chan struct{}, 1)
|
||||
notify := func(message map[string]any) {
|
||||
params := asMap(message["params"])
|
||||
if strings.TrimSpace(stringValue(message["method"])) != "xworkmate.gateway.snapshot" {
|
||||
return
|
||||
}
|
||||
snapshot := asMap(params["snapshot"])
|
||||
if snapshot["status"] == "connected" && server.ConnectCount() >= 2 {
|
||||
select {
|
||||
case reconnected <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result := manager.Connect(buildTestConnectRequest(server.Port()), notify)
|
||||
if !result.OK {
|
||||
t.Fatalf("expected connect success, got %#v", result.Error)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-reconnected:
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatalf("expected reconnect to complete; connect count=%d", server.ConnectCount())
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerSuppressesReconnectForPairingRequired(t *testing.T) {
|
||||
server := newFakeGatewayServer(t)
|
||||
server.connectErrorCode = "NOT_PAIRED"
|
||||
server.connectErrorDetailCode = "PAIRING_REQUIRED"
|
||||
defer server.Close()
|
||||
|
||||
manager := NewManager()
|
||||
manager.ReconnectDelay = 20 * time.Millisecond
|
||||
result := manager.Connect(buildTestConnectRequest(server.Port()), func(map[string]any) {})
|
||||
if result.OK {
|
||||
t.Fatalf("expected connect failure")
|
||||
}
|
||||
time.Sleep(120 * time.Millisecond)
|
||||
if server.ConnectCount() != 1 {
|
||||
t.Fatalf("expected reconnect suppression, got %d connect attempts", server.ConnectCount())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionEmitsNormalizedChatRunPushEvents(t *testing.T) {
|
||||
manager := NewManager()
|
||||
session := newSession(manager, "runtime-1")
|
||||
notifications := make([]map[string]any, 0, 8)
|
||||
session.setNotify(func(message map[string]any) {
|
||||
notifications = append(notifications, message)
|
||||
})
|
||||
|
||||
session.handleEvent(
|
||||
"chat",
|
||||
map[string]any{"seq": 7},
|
||||
map[string]any{
|
||||
"runId": "run-1",
|
||||
"sessionKey": "agent:main:main",
|
||||
"state": "final",
|
||||
"message": map[string]any{
|
||||
"role": "assistant",
|
||||
"content": []any{
|
||||
map[string]any{"type": "text", "text": "XWORKMATE_OK"},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
session.handleEvent(
|
||||
"agent",
|
||||
map[string]any{"seq": 8},
|
||||
map[string]any{
|
||||
"runId": "run-1",
|
||||
"stream": "assistant",
|
||||
"data": map[string]any{
|
||||
"text": "DELTA_TEXT",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
normalized := make([]map[string]any, 0, 2)
|
||||
for _, notification := range notifications {
|
||||
if strings.TrimSpace(stringValue(notification["method"])) != "xworkmate.gateway.push" {
|
||||
continue
|
||||
}
|
||||
params := asMap(notification["params"])
|
||||
event := asMap(params["event"])
|
||||
if strings.TrimSpace(stringValue(event["event"])) != "chat.run" {
|
||||
continue
|
||||
}
|
||||
normalized = append(normalized, asMap(event["payload"]))
|
||||
}
|
||||
|
||||
if len(normalized) != 2 {
|
||||
t.Fatalf("expected 2 normalized chat.run notifications, got %#v", normalized)
|
||||
}
|
||||
if normalized[0]["runId"] != "run-1" || normalized[0]["state"] != "final" {
|
||||
t.Fatalf("unexpected normalized chat payload %#v", normalized[0])
|
||||
}
|
||||
if normalized[0]["assistantText"] != "XWORKMATE_OK" {
|
||||
t.Fatalf("expected final assistant text, got %#v", normalized[0])
|
||||
}
|
||||
if normalized[0]["terminal"] != true {
|
||||
t.Fatalf("expected terminal final chat.run, got %#v", normalized[0])
|
||||
}
|
||||
if normalized[1]["assistantText"] != "DELTA_TEXT" || normalized[1]["state"] != "delta" {
|
||||
t.Fatalf("unexpected normalized agent payload %#v", normalized[1])
|
||||
}
|
||||
}
|
||||
|
||||
type fakeGatewayServer struct {
|
||||
server *http.Server
|
||||
listener net.Listener
|
||||
connectCount atomic.Int32
|
||||
closeAfterConnect atomic.Bool
|
||||
connectErrorCode string
|
||||
connectErrorDetailCode string
|
||||
}
|
||||
|
||||
func newFakeGatewayServer(t *testing.T) *fakeGatewayServer {
|
||||
t.Helper()
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen: %v", err)
|
||||
}
|
||||
fake := &fakeGatewayServer{listener: listener}
|
||||
upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
_ = conn.WriteJSON(map[string]any{
|
||||
"type": "event",
|
||||
"event": "connect.challenge",
|
||||
"payload": map[string]any{
|
||||
"nonce": "nonce-1",
|
||||
},
|
||||
})
|
||||
for {
|
||||
_, payload, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var frame map[string]any
|
||||
if err := json.Unmarshal(payload, &frame); err != nil {
|
||||
continue
|
||||
}
|
||||
if frame["type"] != "req" {
|
||||
continue
|
||||
}
|
||||
id := frame["id"]
|
||||
method := stringValue(frame["method"])
|
||||
switch method {
|
||||
case "connect":
|
||||
fake.connectCount.Add(1)
|
||||
if fake.connectErrorCode != "" {
|
||||
_ = conn.WriteJSON(map[string]any{
|
||||
"type": "res",
|
||||
"id": id,
|
||||
"ok": false,
|
||||
"error": map[string]any{
|
||||
"code": fake.connectErrorCode,
|
||||
"message": "connect failed",
|
||||
"details": map[string]any{
|
||||
"code": fake.connectErrorDetailCode,
|
||||
},
|
||||
},
|
||||
})
|
||||
continue
|
||||
}
|
||||
_ = conn.WriteJSON(map[string]any{
|
||||
"type": "res",
|
||||
"id": id,
|
||||
"ok": true,
|
||||
"payload": map[string]any{
|
||||
"server": map[string]any{"host": "127.0.0.1"},
|
||||
"snapshot": map[string]any{
|
||||
"sessionDefaults": map[string]any{"mainSessionKey": "main"},
|
||||
},
|
||||
"auth": map[string]any{
|
||||
"role": "operator",
|
||||
"scopes": defaultOperatorScopes,
|
||||
"deviceToken": "device-token-1",
|
||||
},
|
||||
},
|
||||
})
|
||||
if fake.closeAfterConnect.Load() && fake.connectCount.Load() == 1 {
|
||||
go func() {
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
_ = conn.Close()
|
||||
}()
|
||||
}
|
||||
case "health":
|
||||
_ = conn.WriteJSON(map[string]any{
|
||||
"type": "res",
|
||||
"id": id,
|
||||
"ok": true,
|
||||
"payload": map[string]any{
|
||||
"status": "ok",
|
||||
},
|
||||
})
|
||||
default:
|
||||
_ = conn.WriteJSON(map[string]any{
|
||||
"type": "res",
|
||||
"id": id,
|
||||
"ok": true,
|
||||
"payload": map[string]any{},
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
fake.server = &http.Server{Handler: mux}
|
||||
go func() {
|
||||
_ = fake.server.Serve(listener)
|
||||
}()
|
||||
return fake
|
||||
}
|
||||
|
||||
func (f *fakeGatewayServer) Port() int {
|
||||
return f.listener.Addr().(*net.TCPAddr).Port
|
||||
}
|
||||
|
||||
func (f *fakeGatewayServer) ConnectCount() int {
|
||||
return int(f.connectCount.Load())
|
||||
}
|
||||
|
||||
func (f *fakeGatewayServer) Close() {
|
||||
_ = f.server.Close()
|
||||
}
|
||||
|
||||
func buildTestConnectRequest(port int) ConnectRequest {
|
||||
return ConnectRequest{
|
||||
RuntimeID: "runtime-1",
|
||||
Mode: "remote",
|
||||
ClientID: "openclaw-macos",
|
||||
Locale: "en_US",
|
||||
UserAgent: "XWorkmate/1.0.0",
|
||||
Endpoint: Endpoint{
|
||||
Host: "127.0.0.1",
|
||||
Port: port,
|
||||
TLS: false,
|
||||
},
|
||||
ConnectAuthMode: "shared-token",
|
||||
ConnectAuthFields: []string{"token"},
|
||||
ConnectAuthSources: []string{"shared:form"},
|
||||
HasSharedAuth: true,
|
||||
HasDeviceToken: false,
|
||||
PackageInfo: PackageInfo{
|
||||
AppName: "XWorkmate",
|
||||
Version: "1.0.0",
|
||||
},
|
||||
DeviceInfo: DeviceInfo{
|
||||
Platform: "macos",
|
||||
PlatformVersion: "14.0",
|
||||
DeviceFamily: "Mac",
|
||||
ModelIdentifier: "Mac14,5",
|
||||
},
|
||||
Identity: DeviceIdentity{
|
||||
DeviceID: "device-1",
|
||||
PublicKeyBase64URL: "test-public-key-value",
|
||||
PrivateKeyBase64URL: "test-private-key-value",
|
||||
},
|
||||
Auth: AuthConfig{
|
||||
Token: "shared-token",
|
||||
},
|
||||
}
|
||||
}
|
||||
@ -1,129 +0,0 @@
|
||||
package gatewayruntime
|
||||
|
||||
import "time"
|
||||
|
||||
const (
|
||||
defaultProtocolVersion = 3
|
||||
defaultReconnectDelay = 2 * time.Second
|
||||
defaultConnectTimeout = 10 * time.Second
|
||||
defaultChallengeWait = 2 * time.Second
|
||||
defaultRequestTimeout = 15 * time.Second
|
||||
)
|
||||
|
||||
var defaultOperatorScopes = []string{
|
||||
"operator.admin",
|
||||
"operator.read",
|
||||
"operator.write",
|
||||
"operator.approvals",
|
||||
"operator.pairing",
|
||||
}
|
||||
|
||||
type Endpoint struct {
|
||||
Host string
|
||||
Port int
|
||||
TLS bool
|
||||
}
|
||||
|
||||
type PackageInfo struct {
|
||||
AppName string
|
||||
PackageName string
|
||||
Version string
|
||||
BuildNumber string
|
||||
}
|
||||
|
||||
type DeviceInfo struct {
|
||||
Platform string
|
||||
PlatformVersion string
|
||||
DeviceFamily string
|
||||
ModelIdentifier string
|
||||
}
|
||||
|
||||
func (d DeviceInfo) PlatformLabel() string {
|
||||
if d.PlatformVersion == "" {
|
||||
return d.Platform
|
||||
}
|
||||
return d.Platform + " " + d.PlatformVersion
|
||||
}
|
||||
|
||||
type DeviceIdentity struct {
|
||||
DeviceID string
|
||||
PublicKeyBase64URL string
|
||||
PrivateKeyBase64URL string
|
||||
}
|
||||
|
||||
type AuthConfig struct {
|
||||
Token string
|
||||
DeviceToken string
|
||||
Password string
|
||||
}
|
||||
|
||||
type ConnectRequest struct {
|
||||
RuntimeID string
|
||||
Mode string
|
||||
ClientID string
|
||||
Locale string
|
||||
UserAgent string
|
||||
Endpoint Endpoint
|
||||
ConnectAuthMode string
|
||||
ConnectAuthFields []string
|
||||
ConnectAuthSources []string
|
||||
HasSharedAuth bool
|
||||
HasDeviceToken bool
|
||||
PackageInfo PackageInfo
|
||||
DeviceInfo DeviceInfo
|
||||
Identity DeviceIdentity
|
||||
Auth AuthConfig
|
||||
}
|
||||
|
||||
type ConnectResult struct {
|
||||
OK bool
|
||||
Snapshot map[string]any
|
||||
Auth map[string]any
|
||||
ReturnedDeviceToken string
|
||||
Error map[string]any
|
||||
}
|
||||
|
||||
type RequestResult struct {
|
||||
OK bool
|
||||
Payload any
|
||||
Error map[string]any
|
||||
}
|
||||
|
||||
type GatewayError struct {
|
||||
Message string
|
||||
Code string
|
||||
Details map[string]any
|
||||
}
|
||||
|
||||
func (e *GatewayError) Error() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
return e.Message
|
||||
}
|
||||
|
||||
func (e *GatewayError) DetailCode() string {
|
||||
if e == nil || e.Details == nil {
|
||||
return ""
|
||||
}
|
||||
if value, ok := e.Details["code"].(string); ok {
|
||||
return value
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (e *GatewayError) Map() map[string]any {
|
||||
if e == nil {
|
||||
return map[string]any{}
|
||||
}
|
||||
payload := map[string]any{
|
||||
"message": e.Message,
|
||||
}
|
||||
if e.Code != "" {
|
||||
payload["code"] = e.Code
|
||||
}
|
||||
if len(e.Details) > 0 {
|
||||
payload["details"] = e.Details
|
||||
}
|
||||
return payload
|
||||
}
|
||||
@ -1,49 +0,0 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"xworkmate/go_core/internal/service"
|
||||
)
|
||||
|
||||
type Authenticator interface {
|
||||
Authenticate(username, password string) error
|
||||
}
|
||||
|
||||
type AuthHandler struct {
|
||||
service Authenticator
|
||||
}
|
||||
|
||||
func NewAuthHandler(svc Authenticator) *AuthHandler {
|
||||
return &AuthHandler{service: svc}
|
||||
}
|
||||
|
||||
func NewServiceAdapter(svc *service.AuthService) Authenticator {
|
||||
return authServiceAdapter{service: svc}
|
||||
}
|
||||
|
||||
func (h *AuthHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
var payload struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
http.Error(w, "invalid json", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if err := h.service.Authenticate(payload.Username, payload.Password); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
|
||||
}
|
||||
|
||||
type authServiceAdapter struct {
|
||||
service *service.AuthService
|
||||
}
|
||||
|
||||
func (a authServiceAdapter) Authenticate(username, password string) error {
|
||||
return a.service.Authenticate(nil, username, password)
|
||||
}
|
||||
@ -1,53 +0,0 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type fakeAuthenticator struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (f fakeAuthenticator) Authenticate(username, password string) error {
|
||||
return f.err
|
||||
}
|
||||
|
||||
func TestAuthHandlerRejectsInvalidJSON(t *testing.T) {
|
||||
handler := NewAuthHandler(fakeAuthenticator{})
|
||||
req := httptest.NewRequest(http.MethodPost, "/auth", bytes.NewBufferString("{"))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthHandlerReturnsUnauthorizedOnServiceFailure(t *testing.T) {
|
||||
handler := NewAuthHandler(fakeAuthenticator{err: errors.New("invalid credentials")})
|
||||
req := httptest.NewRequest(http.MethodPost, "/auth", bytes.NewBufferString(`{"username":"alice","password":"secret"}`))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected 401, got %d", rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthHandlerReturnsOKOnSuccess(t *testing.T) {
|
||||
handler := NewAuthHandler(fakeAuthenticator{})
|
||||
req := httptest.NewRequest(http.MethodPost, "/auth", bytes.NewBufferString(`{"username":"alice","password":"secret"}`))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", rec.Code)
|
||||
}
|
||||
}
|
||||
@ -1,234 +0,0 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Source struct {
|
||||
Path string
|
||||
Scope string
|
||||
}
|
||||
|
||||
type Preferences struct {
|
||||
PreferredRoute string
|
||||
PreferredModel string
|
||||
PreferredSkills []string
|
||||
Provider string
|
||||
}
|
||||
|
||||
type LoadResult struct {
|
||||
MergedText string
|
||||
Sources []Source
|
||||
Preferences Preferences
|
||||
ProjectFiles []string
|
||||
}
|
||||
|
||||
type SuccessEntry struct {
|
||||
ResolvedExecutionTarget string
|
||||
ResolvedProviderID string
|
||||
ResolvedModel string
|
||||
ResolvedSkills []string
|
||||
Summary string
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
HomeDir string
|
||||
}
|
||||
|
||||
func NewService(homeDir string) Service {
|
||||
return Service{HomeDir: strings.TrimSpace(homeDir)}
|
||||
}
|
||||
|
||||
func (s Service) Load(workingDirectory string) LoadResult {
|
||||
projectName := projectNameFromWorkingDirectory(workingDirectory)
|
||||
paths := []Source{
|
||||
{Path: filepath.Join(s.HomeDir, "self-improving", "memory.md"), Scope: "global"},
|
||||
{Path: filepath.Join(s.HomeDir, "self-improving", "projects", projectName+".md"), Scope: "project-home"},
|
||||
{Path: filepath.Join(strings.TrimSpace(workingDirectory), ".xworkmate", "memory.md"), Scope: "project-local"},
|
||||
}
|
||||
merged := make([]string, 0, len(paths))
|
||||
sources := make([]Source, 0, len(paths))
|
||||
prefs := Preferences{}
|
||||
projectFiles := make([]string, 0, 2)
|
||||
|
||||
for _, source := range paths {
|
||||
if strings.TrimSpace(source.Path) == "" {
|
||||
continue
|
||||
}
|
||||
content, err := os.ReadFile(source.Path)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
text := sanitizeMemoryText(string(content))
|
||||
if strings.TrimSpace(text) == "" {
|
||||
continue
|
||||
}
|
||||
sources = append(sources, source)
|
||||
merged = append(merged, fmt.Sprintf("## %s\n%s", source.Scope, text))
|
||||
mergePreferences(&prefs, parsePreferences(text))
|
||||
if source.Scope != "global" {
|
||||
projectFiles = append(projectFiles, source.Path)
|
||||
}
|
||||
}
|
||||
|
||||
return LoadResult{
|
||||
MergedText: strings.TrimSpace(strings.Join(merged, "\n\n")),
|
||||
Sources: sources,
|
||||
Preferences: prefs,
|
||||
ProjectFiles: projectFiles,
|
||||
}
|
||||
}
|
||||
|
||||
func (s Service) RecordSuccess(workingDirectory string, entry SuccessEntry) error {
|
||||
workingDirectory = strings.TrimSpace(workingDirectory)
|
||||
if workingDirectory == "" {
|
||||
return nil
|
||||
}
|
||||
projectName := projectNameFromWorkingDirectory(workingDirectory)
|
||||
if projectName == "" {
|
||||
return nil
|
||||
}
|
||||
target := s.projectWriteTarget(workingDirectory, projectName)
|
||||
if target == "" {
|
||||
return nil
|
||||
}
|
||||
block := formatSuccessEntry(entry)
|
||||
if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
file, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := file.WriteString(block); err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
if err := file.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s Service) projectWriteTarget(
|
||||
workingDirectory string,
|
||||
projectName string,
|
||||
) string {
|
||||
repoLocalDir := filepath.Join(workingDirectory, ".xworkmate")
|
||||
if err := os.MkdirAll(repoLocalDir, 0o755); err == nil {
|
||||
return filepath.Join(repoLocalDir, "memory.md")
|
||||
}
|
||||
return filepath.Join(s.HomeDir, "self-improving", "projects", projectName+".md")
|
||||
}
|
||||
|
||||
func formatSuccessEntry(entry SuccessEntry) string {
|
||||
lines := []string{
|
||||
"",
|
||||
fmt.Sprintf("## Auto route %s", time.Now().Format(time.RFC3339)),
|
||||
fmt.Sprintf("preferred-route: %s", strings.TrimSpace(entry.ResolvedExecutionTarget)),
|
||||
}
|
||||
if strings.TrimSpace(entry.ResolvedModel) != "" {
|
||||
lines = append(lines, fmt.Sprintf("preferred-model: %s", strings.TrimSpace(entry.ResolvedModel)))
|
||||
}
|
||||
if len(entry.ResolvedSkills) > 0 {
|
||||
lines = append(lines, fmt.Sprintf("preferred-skills: %s", strings.Join(entry.ResolvedSkills, ", ")))
|
||||
}
|
||||
if strings.TrimSpace(entry.ResolvedProviderID) != "" {
|
||||
lines = append(lines, fmt.Sprintf("provider: %s", strings.TrimSpace(entry.ResolvedProviderID)))
|
||||
}
|
||||
if summary := sanitizeMemoryText(entry.Summary); strings.TrimSpace(summary) != "" {
|
||||
lines = append(lines, "summary:")
|
||||
for _, line := range strings.Split(summary, "\n") {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf("- %s", trimmed))
|
||||
}
|
||||
}
|
||||
return strings.Join(lines, "\n") + "\n"
|
||||
}
|
||||
|
||||
func parsePreferences(text string) Preferences {
|
||||
prefs := Preferences{}
|
||||
for _, line := range strings.Split(text, "\n") {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
switch {
|
||||
case strings.HasPrefix(strings.ToLower(trimmed), "preferred-route:"):
|
||||
prefs.PreferredRoute = normalizePreferredRoute(
|
||||
strings.TrimSpace(strings.TrimPrefix(trimmed, "preferred-route:")),
|
||||
)
|
||||
case strings.HasPrefix(strings.ToLower(trimmed), "preferred-model:"):
|
||||
prefs.PreferredModel = strings.TrimSpace(strings.TrimPrefix(trimmed, "preferred-model:"))
|
||||
case strings.HasPrefix(strings.ToLower(trimmed), "preferred-skills:"):
|
||||
raw := strings.TrimSpace(strings.TrimPrefix(trimmed, "preferred-skills:"))
|
||||
for _, item := range strings.Split(raw, ",") {
|
||||
value := strings.TrimSpace(item)
|
||||
if value != "" {
|
||||
prefs.PreferredSkills = append(prefs.PreferredSkills, value)
|
||||
}
|
||||
}
|
||||
case strings.HasPrefix(strings.ToLower(trimmed), "provider:"):
|
||||
prefs.Provider = strings.TrimSpace(strings.TrimPrefix(trimmed, "provider:"))
|
||||
}
|
||||
}
|
||||
return prefs
|
||||
}
|
||||
|
||||
func mergePreferences(dst *Preferences, src Preferences) {
|
||||
if strings.TrimSpace(src.PreferredRoute) != "" {
|
||||
dst.PreferredRoute = strings.TrimSpace(src.PreferredRoute)
|
||||
}
|
||||
if strings.TrimSpace(src.PreferredModel) != "" {
|
||||
dst.PreferredModel = strings.TrimSpace(src.PreferredModel)
|
||||
}
|
||||
if len(src.PreferredSkills) > 0 {
|
||||
dst.PreferredSkills = append([]string(nil), src.PreferredSkills...)
|
||||
}
|
||||
if strings.TrimSpace(src.Provider) != "" {
|
||||
dst.Provider = strings.TrimSpace(src.Provider)
|
||||
}
|
||||
}
|
||||
|
||||
func sanitizeMemoryText(text string) string {
|
||||
lines := strings.Split(text, "\n")
|
||||
filtered := make([]string, 0, len(lines))
|
||||
for _, line := range lines {
|
||||
normalized := strings.ToLower(strings.TrimSpace(line))
|
||||
if normalized == "" {
|
||||
filtered = append(filtered, "")
|
||||
continue
|
||||
}
|
||||
if strings.Contains(normalized, "token") ||
|
||||
strings.Contains(normalized, "password") ||
|
||||
strings.Contains(normalized, "secret") ||
|
||||
strings.Contains(normalized, "api_key") ||
|
||||
strings.Contains(normalized, "apikey") ||
|
||||
strings.Contains(normalized, "api key") {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, line)
|
||||
}
|
||||
return strings.TrimSpace(strings.Join(filtered, "\n"))
|
||||
}
|
||||
|
||||
func projectNameFromWorkingDirectory(workingDirectory string) string {
|
||||
cleaned := strings.TrimSpace(workingDirectory)
|
||||
if cleaned == "" {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(filepath.Base(cleaned))
|
||||
}
|
||||
|
||||
func normalizePreferredRoute(value string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(value)) {
|
||||
case "gateway-chat":
|
||||
return "gateway"
|
||||
default:
|
||||
return strings.TrimSpace(value)
|
||||
}
|
||||
}
|
||||
@ -1,117 +0,0 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadMergesGlobalAndProjectMemoryAndSanitizesSecrets(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
workingDir := filepath.Join(tempDir, "workspace")
|
||||
homeDir := filepath.Join(tempDir, "home")
|
||||
if err := os.MkdirAll(filepath.Join(workingDir, ".xworkmate"), 0o755); err != nil {
|
||||
t.Fatalf("mkdir workspace: %v", err)
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Join(homeDir, "self-improving", "projects"), 0o755); err != nil {
|
||||
t.Fatalf("mkdir home: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(homeDir, "self-improving", "memory.md"), []byte("preferred-route: gateway-chat\napi_key: hidden\n"), 0o644); err != nil {
|
||||
t.Fatalf("write global memory: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(homeDir, "self-improving", "projects", "workspace.md"), []byte("preferred-model: gpt-5.4\n"), 0o644); err != nil {
|
||||
t.Fatalf("write project memory: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(workingDir, ".xworkmate", "memory.md"), []byte("preferred-skills: pptx, pdf\npassword: hidden\n"), 0o644); err != nil {
|
||||
t.Fatalf("write local memory: %v", err)
|
||||
}
|
||||
|
||||
result := NewService(homeDir).Load(workingDir)
|
||||
|
||||
if len(result.Sources) != 3 {
|
||||
t.Fatalf("expected 3 memory sources, got %d", len(result.Sources))
|
||||
}
|
||||
if strings.Contains(strings.ToLower(result.MergedText), "api_key") || strings.Contains(strings.ToLower(result.MergedText), "password") {
|
||||
t.Fatalf("expected sanitized merged text, got %q", result.MergedText)
|
||||
}
|
||||
if result.Preferences.PreferredRoute != "gateway" {
|
||||
t.Fatalf("unexpected preferred route: %#v", result.Preferences)
|
||||
}
|
||||
if result.Preferences.PreferredModel != "gpt-5.4" {
|
||||
t.Fatalf("unexpected preferred model: %#v", result.Preferences)
|
||||
}
|
||||
if len(result.Preferences.PreferredSkills) != 2 {
|
||||
t.Fatalf("unexpected preferred skills: %#v", result.Preferences.PreferredSkills)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordSuccessWritesProjectLevelMemoryFiles(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
workingDir := filepath.Join(tempDir, "repo")
|
||||
homeDir := filepath.Join(tempDir, "home")
|
||||
if err := os.MkdirAll(workingDir, 0o755); err != nil {
|
||||
t.Fatalf("mkdir working dir: %v", err)
|
||||
}
|
||||
|
||||
service := NewService(homeDir)
|
||||
err := service.RecordSuccess(workingDir, SuccessEntry{
|
||||
ResolvedExecutionTarget: "single-agent",
|
||||
ResolvedModel: "gpt-5.4",
|
||||
ResolvedSkills: []string{"pptx", "pdf"},
|
||||
Summary: "created a clean deck",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("record success: %v", err)
|
||||
}
|
||||
|
||||
repoLocalTarget := filepath.Join(workingDir, ".xworkmate", "memory.md")
|
||||
content, err := os.ReadFile(repoLocalTarget)
|
||||
if err != nil {
|
||||
t.Fatalf("read target %s: %v", repoLocalTarget, err)
|
||||
}
|
||||
text := string(content)
|
||||
if !strings.Contains(text, "preferred-route: single-agent") {
|
||||
t.Fatalf("missing preferred route in %s: %q", repoLocalTarget, text)
|
||||
}
|
||||
if strings.Contains(strings.ToLower(text), "token") {
|
||||
t.Fatalf("unexpected sensitive content in %s: %q", repoLocalTarget, text)
|
||||
}
|
||||
homeProjectTarget := filepath.Join(homeDir, "self-improving", "projects", "repo.md")
|
||||
if _, err := os.Stat(homeProjectTarget); !os.IsNotExist(err) {
|
||||
t.Fatalf("expected single project-level write target, got stat err=%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadLetsProjectMemoryOverrideGlobalPreferences(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
workingDir := filepath.Join(tempDir, "workspace")
|
||||
homeDir := filepath.Join(tempDir, "home")
|
||||
if err := os.MkdirAll(filepath.Join(workingDir, ".xworkmate"), 0o755); err != nil {
|
||||
t.Fatalf("mkdir workspace: %v", err)
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Join(homeDir, "self-improving", "projects"), 0o755); err != nil {
|
||||
t.Fatalf("mkdir home: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(homeDir, "self-improving", "memory.md"), []byte("preferred-route: single-agent\npreferred-model: gpt-4o\npreferred-skills: docx\n"), 0o644); err != nil {
|
||||
t.Fatalf("write global memory: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(homeDir, "self-improving", "projects", "workspace.md"), []byte("preferred-route: gateway\npreferred-model: gpt-5.4\n"), 0o644); err != nil {
|
||||
t.Fatalf("write project home memory: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(workingDir, ".xworkmate", "memory.md"), []byte("preferred-route: multi-agent\npreferred-skills: pptx, pdf\n"), 0o644); err != nil {
|
||||
t.Fatalf("write project local memory: %v", err)
|
||||
}
|
||||
|
||||
result := NewService(homeDir).Load(workingDir)
|
||||
|
||||
if result.Preferences.PreferredRoute != "multi-agent" {
|
||||
t.Fatalf("expected project-local route to win, got %#v", result.Preferences)
|
||||
}
|
||||
if result.Preferences.PreferredModel != "gpt-5.4" {
|
||||
t.Fatalf("expected project-home model to override global, got %#v", result.Preferences)
|
||||
}
|
||||
if len(result.Preferences.PreferredSkills) != 2 || result.Preferences.PreferredSkills[0] != "pptx" {
|
||||
t.Fatalf("expected project-local skills to win, got %#v", result.Preferences.PreferredSkills)
|
||||
}
|
||||
}
|
||||
@ -1,150 +0,0 @@
|
||||
package mounts
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
codexManagedMCPBlockStart = "# BEGIN XWORKMATE MANAGED MCP BLOCK"
|
||||
codexManagedMCPBlockEnd = "# END XWORKMATE MANAGED MCP BLOCK"
|
||||
opencodeManagedMCPBlockStart = "# BEGIN XWORKMATE MANAGED MCP BLOCK"
|
||||
opencodeManagedMCPBlockEnd = "# END XWORKMATE MANAGED MCP BLOCK"
|
||||
)
|
||||
|
||||
var mcpServerSectionPattern = regexp.MustCompile(
|
||||
`(?m)^\[mcp_servers\.[^\]]+\]`,
|
||||
)
|
||||
|
||||
func countMCPSections(content string) int {
|
||||
return len(mcpServerSectionPattern.FindAllStringIndex(content, -1))
|
||||
}
|
||||
|
||||
func defaultCodexHome() string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil || strings.TrimSpace(home) == "" {
|
||||
return ""
|
||||
}
|
||||
return filepath.Join(home, ".codex")
|
||||
}
|
||||
|
||||
func defaultOpencodeHome() string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil || strings.TrimSpace(home) == "" {
|
||||
return ""
|
||||
}
|
||||
return filepath.Join(home, ".opencode")
|
||||
}
|
||||
|
||||
func defaultOpenClawHome() string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil || strings.TrimSpace(home) == "" {
|
||||
return ""
|
||||
}
|
||||
return filepath.Join(home, ".openclaw")
|
||||
}
|
||||
|
||||
func stripManagedBlock(content, startMarker, endMarker string) string {
|
||||
if strings.TrimSpace(content) == "" {
|
||||
return content
|
||||
}
|
||||
|
||||
remaining := content
|
||||
for {
|
||||
start := strings.Index(remaining, startMarker)
|
||||
if start < 0 {
|
||||
break
|
||||
}
|
||||
end := strings.Index(remaining[start:], endMarker)
|
||||
if end < 0 {
|
||||
remaining = remaining[:start]
|
||||
break
|
||||
}
|
||||
end += start
|
||||
remaining = remaining[:start] + remaining[end+len(endMarker):]
|
||||
}
|
||||
return remaining
|
||||
}
|
||||
|
||||
func mergeManagedBlock(content, block, startMarker, endMarker string) string {
|
||||
preserved := strings.TrimRight(
|
||||
stripManagedBlock(content, startMarker, endMarker),
|
||||
"\n",
|
||||
)
|
||||
if preserved == "" {
|
||||
return block + "\n"
|
||||
}
|
||||
return preserved + "\n\n" + block + "\n"
|
||||
}
|
||||
|
||||
func buildCodexManagedMCPBlock(servers []ManagedMCPServer) string {
|
||||
var buffer strings.Builder
|
||||
buffer.WriteString(codexManagedMCPBlockStart)
|
||||
buffer.WriteString("\n# Generated by XWorkmate - Managed MCP Server Configuration\n")
|
||||
buffer.WriteString(
|
||||
fmt.Sprintf("# Last updated: %s\n\n", time.Now().Format(time.RFC3339Nano)),
|
||||
)
|
||||
for _, server := range servers {
|
||||
buffer.WriteString(fmt.Sprintf("[mcp_servers.%s]\n", server.ID))
|
||||
buffer.WriteString(fmt.Sprintf("command = %q\n", server.Command))
|
||||
if len(server.Args) > 0 {
|
||||
buffer.WriteString(fmt.Sprintf("args = %s\n", formatTOMLArray(server.Args)))
|
||||
}
|
||||
buffer.WriteString("\n")
|
||||
}
|
||||
buffer.WriteString(codexManagedMCPBlockEnd)
|
||||
return strings.TrimRight(buffer.String(), "\n")
|
||||
}
|
||||
|
||||
func buildOpencodeManagedMCPBlock(servers []ManagedMCPServer) string {
|
||||
var buffer strings.Builder
|
||||
buffer.WriteString(opencodeManagedMCPBlockStart)
|
||||
buffer.WriteString("\n# Generated by XWorkmate - Managed MCP Server Configuration\n")
|
||||
buffer.WriteString(
|
||||
fmt.Sprintf("# Last updated: %s\n\n", time.Now().Format(time.RFC3339Nano)),
|
||||
)
|
||||
for _, server := range servers {
|
||||
buffer.WriteString(fmt.Sprintf("[mcp_servers.%s]\n", server.ID))
|
||||
if strings.TrimSpace(server.URL) != "" {
|
||||
buffer.WriteString(fmt.Sprintf("url = %q\n", strings.TrimSpace(server.URL)))
|
||||
} else {
|
||||
buffer.WriteString("type = \"stdio\"\n")
|
||||
buffer.WriteString(fmt.Sprintf("command = %q\n", server.Command))
|
||||
if len(server.Args) > 0 {
|
||||
buffer.WriteString(fmt.Sprintf("args = %s\n", formatTOMLArray(server.Args)))
|
||||
}
|
||||
}
|
||||
buffer.WriteString("\n")
|
||||
}
|
||||
buffer.WriteString(opencodeManagedMCPBlockEnd)
|
||||
return strings.TrimRight(buffer.String(), "\n")
|
||||
}
|
||||
|
||||
func formatTOMLArray(items []string) string {
|
||||
if len(items) == 0 {
|
||||
return "[]"
|
||||
}
|
||||
var quoted []string
|
||||
for _, item := range items {
|
||||
quoted = append(quoted, fmt.Sprintf("%q", item))
|
||||
}
|
||||
return "[" + strings.Join(quoted, ", ") + "]"
|
||||
}
|
||||
|
||||
func applyManagedBlock(configPath, block, startMarker, endMarker string) error {
|
||||
configDir := filepath.Dir(configPath)
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(configPath)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
merged := mergeManagedBlock(string(content), block, startMarker, endMarker)
|
||||
return os.WriteFile(configPath, []byte(merged), 0o644)
|
||||
}
|
||||
@ -1,437 +0,0 @@
|
||||
package mounts
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ManagedMCPServer struct {
|
||||
ID string
|
||||
Name string
|
||||
Transport string
|
||||
Command string
|
||||
URL string
|
||||
Args []string
|
||||
Enabled bool
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
AutoSync bool
|
||||
UsesAris bool
|
||||
ManagedMCPServers []ManagedMCPServer
|
||||
}
|
||||
|
||||
type ArisInput struct {
|
||||
Available bool
|
||||
BundleVersion string
|
||||
LLMChatServerPath string
|
||||
SkillCount int
|
||||
BridgeAvailable bool
|
||||
Error string
|
||||
}
|
||||
|
||||
type Request struct {
|
||||
Config Config
|
||||
AIGatewayURL string
|
||||
ConfiguredCodexCLIPath string
|
||||
CodexHome string
|
||||
OpencodeHome string
|
||||
OpenClawHome string
|
||||
Aris ArisInput
|
||||
}
|
||||
|
||||
type MountTargetState struct {
|
||||
TargetID string
|
||||
Label string
|
||||
Available bool
|
||||
SupportsSkills bool
|
||||
SupportsMCP bool
|
||||
SupportsAIGatewayInjection bool
|
||||
DiscoveryState string
|
||||
SyncState string
|
||||
DiscoveredSkillCount int
|
||||
DiscoveredMCPCount int
|
||||
ManagedMCPCount int
|
||||
Detail string
|
||||
}
|
||||
|
||||
type Result struct {
|
||||
MountTargets []MountTargetState
|
||||
ArisBundleVersion string
|
||||
ArisCompatStatus string
|
||||
}
|
||||
|
||||
func Reconcile(request Request) Result {
|
||||
states := []MountTargetState{
|
||||
reconcileAris(request.Config, request.Aris),
|
||||
reconcileCodex(
|
||||
request.Config,
|
||||
request.AIGatewayURL,
|
||||
request.ConfiguredCodexCLIPath,
|
||||
request.CodexHome,
|
||||
),
|
||||
reconcileCLIListTarget(
|
||||
request.Config,
|
||||
"claude",
|
||||
"Claude",
|
||||
[]string{"claude", "mcp", "list"},
|
||||
),
|
||||
reconcileCLIListTarget(
|
||||
request.Config,
|
||||
"gemini",
|
||||
"Gemini",
|
||||
[]string{"gemini", "mcp", "list"},
|
||||
),
|
||||
reconcileOpencode(request.Config, request.OpencodeHome),
|
||||
reconcileOpenClaw(request.Config, request.OpenClawHome),
|
||||
}
|
||||
|
||||
result := Result{
|
||||
MountTargets: states,
|
||||
ArisBundleVersion: strings.TrimSpace(request.Aris.BundleVersion),
|
||||
ArisCompatStatus: "idle",
|
||||
}
|
||||
for _, state := range states {
|
||||
if state.TargetID == "aris" {
|
||||
result.ArisCompatStatus = state.SyncState
|
||||
break
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func ResultMap(result Result) map[string]any {
|
||||
rawTargets := make([]map[string]any, 0, len(result.MountTargets))
|
||||
for _, target := range result.MountTargets {
|
||||
rawTargets = append(rawTargets, map[string]any{
|
||||
"targetId": target.TargetID,
|
||||
"label": target.Label,
|
||||
"available": target.Available,
|
||||
"supportsSkills": target.SupportsSkills,
|
||||
"supportsMcp": target.SupportsMCP,
|
||||
"supportsAiGatewayInjection": target.SupportsAIGatewayInjection,
|
||||
"discoveryState": target.DiscoveryState,
|
||||
"syncState": target.SyncState,
|
||||
"discoveredSkillCount": target.DiscoveredSkillCount,
|
||||
"discoveredMcpCount": target.DiscoveredMCPCount,
|
||||
"managedMcpCount": target.ManagedMCPCount,
|
||||
"detail": target.Detail,
|
||||
})
|
||||
}
|
||||
return map[string]any{
|
||||
"mountTargets": rawTargets,
|
||||
"arisBundleVersion": result.ArisBundleVersion,
|
||||
"arisCompatStatus": result.ArisCompatStatus,
|
||||
}
|
||||
}
|
||||
|
||||
func reconcileAris(config Config, input ArisInput) MountTargetState {
|
||||
state := placeholderState("aris", "ARIS", true, true, false)
|
||||
if strings.TrimSpace(input.Error) != "" {
|
||||
state.Available = false
|
||||
state.DiscoveryState = "error"
|
||||
state.SyncState = "error"
|
||||
state.Detail = strings.TrimSpace(input.Error)
|
||||
return state
|
||||
}
|
||||
if !input.Available {
|
||||
state.DiscoveryState = "missing"
|
||||
state.SyncState = "missing"
|
||||
state.Detail = "Embedded ARIS bundle is unavailable."
|
||||
return state
|
||||
}
|
||||
|
||||
state.Available = true
|
||||
state.DiscoveryState = "ready"
|
||||
state.DiscoveredSkillCount = input.SkillCount
|
||||
llmChatReady := strings.TrimSpace(input.LLMChatServerPath) != ""
|
||||
if config.UsesAris && llmChatReady && input.BridgeAvailable {
|
||||
state.SyncState = "ready"
|
||||
state.DiscoveredMCPCount = 1
|
||||
state.ManagedMCPCount = 1
|
||||
state.Detail = "Embedded bundle " +
|
||||
strings.TrimSpace(input.BundleVersion) +
|
||||
" ready; XWorkmate Go core manages llm-chat and claude-review."
|
||||
return state
|
||||
}
|
||||
state.SyncState = "embedded"
|
||||
if llmChatReady {
|
||||
state.DiscoveredMCPCount = 1
|
||||
}
|
||||
if llmChatReady {
|
||||
state.Detail = "Embedded bundle extracted, but the XWorkmate Go core is not available yet."
|
||||
} else {
|
||||
state.Detail = "Embedded bundle extracted, but llm-chat metadata is missing."
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
func reconcileCodex(
|
||||
config Config,
|
||||
aiGatewayURL string,
|
||||
configuredCodexCLIPath string,
|
||||
codexHome string,
|
||||
) MountTargetState {
|
||||
state := placeholderState("codex", "Codex", true, true, true)
|
||||
available := codexAvailable(configuredCodexCLIPath)
|
||||
configHome := strings.TrimSpace(codexHome)
|
||||
if configHome == "" {
|
||||
configHome = defaultCodexHome()
|
||||
}
|
||||
configPath := filepath.Join(configHome, "config.toml")
|
||||
content, _ := os.ReadFile(configPath)
|
||||
discovered := countMCPSections(string(content))
|
||||
managedServers := enabledCodexServers(config.ManagedMCPServers)
|
||||
if available && config.AutoSync && len(managedServers) > 0 {
|
||||
_ = applyManagedBlock(
|
||||
configPath,
|
||||
buildCodexManagedMCPBlock(managedServers),
|
||||
codexManagedMCPBlockStart,
|
||||
codexManagedMCPBlockEnd,
|
||||
)
|
||||
}
|
||||
state.Available = available
|
||||
if available {
|
||||
state.DiscoveryState = "ready"
|
||||
} else {
|
||||
state.DiscoveryState = "missing"
|
||||
}
|
||||
switch {
|
||||
case !available:
|
||||
state.SyncState = "missing"
|
||||
case config.AutoSync:
|
||||
state.SyncState = "ready"
|
||||
default:
|
||||
state.SyncState = "disabled"
|
||||
}
|
||||
state.DiscoveredMCPCount = discovered
|
||||
state.ManagedMCPCount = len(managedServers)
|
||||
if strings.TrimSpace(aiGatewayURL) != "" {
|
||||
state.Detail = "LLM API uses launch-scoped defaults for collaboration runs."
|
||||
} else {
|
||||
state.Detail = "LLM API not configured."
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
func reconcileCLIListTarget(
|
||||
config Config,
|
||||
targetID string,
|
||||
label string,
|
||||
command []string,
|
||||
) MountTargetState {
|
||||
state := placeholderState(targetID, label, true, true, true)
|
||||
available := binaryExists(command[0])
|
||||
discovered := 0
|
||||
if available {
|
||||
discovered = countListedEntries(command)
|
||||
}
|
||||
state.Available = available
|
||||
if available {
|
||||
state.DiscoveryState = "ready"
|
||||
} else {
|
||||
state.DiscoveryState = "missing"
|
||||
}
|
||||
if available && config.AutoSync {
|
||||
state.SyncState = "launch-only"
|
||||
} else {
|
||||
state.SyncState = "disabled"
|
||||
}
|
||||
state.DiscoveredMCPCount = discovered
|
||||
state.ManagedMCPCount = len(enabledServers(config.ManagedMCPServers))
|
||||
state.Detail = "MCP discovery uses `" + strings.Join(command, " ") +
|
||||
"`; LLM API stays launch-scoped."
|
||||
return state
|
||||
}
|
||||
|
||||
func reconcileOpencode(config Config, opencodeHome string) MountTargetState {
|
||||
state := placeholderState("opencode", "OpenCode", true, true, true)
|
||||
available := binaryExists("opencode")
|
||||
configHome := strings.TrimSpace(opencodeHome)
|
||||
if configHome == "" {
|
||||
configHome = defaultOpencodeHome()
|
||||
}
|
||||
configPath := filepath.Join(configHome, "config.toml")
|
||||
content, _ := os.ReadFile(configPath)
|
||||
discovered := countMCPSections(string(content))
|
||||
managedServers := enabledServers(config.ManagedMCPServers)
|
||||
if available && config.AutoSync && len(managedServers) > 0 {
|
||||
_ = applyManagedBlock(
|
||||
configPath,
|
||||
buildOpencodeManagedMCPBlock(managedServers),
|
||||
opencodeManagedMCPBlockStart,
|
||||
opencodeManagedMCPBlockEnd,
|
||||
)
|
||||
}
|
||||
state.Available = available
|
||||
if available {
|
||||
state.DiscoveryState = "ready"
|
||||
} else {
|
||||
state.DiscoveryState = "missing"
|
||||
}
|
||||
switch {
|
||||
case !available:
|
||||
state.SyncState = "missing"
|
||||
case config.AutoSync:
|
||||
state.SyncState = "ready"
|
||||
default:
|
||||
state.SyncState = "disabled"
|
||||
}
|
||||
state.DiscoveredMCPCount = discovered
|
||||
state.ManagedMCPCount = len(managedServers)
|
||||
state.Detail = "Managed MCP config is preserved in ~/.opencode/config.toml."
|
||||
return state
|
||||
}
|
||||
|
||||
func reconcileOpenClaw(config Config, openClawHome string) MountTargetState {
|
||||
state := placeholderState("openclaw", "OpenClaw", true, false, true)
|
||||
available := binaryExists("openclaw")
|
||||
state.Available = available
|
||||
if available {
|
||||
state.DiscoveryState = "ready"
|
||||
} else {
|
||||
state.DiscoveryState = "missing"
|
||||
}
|
||||
if available && config.AutoSync {
|
||||
state.SyncState = "launch-only"
|
||||
} else {
|
||||
state.SyncState = "disabled"
|
||||
}
|
||||
state.Detail = "OpenClaw acts as the host/control plane mount."
|
||||
|
||||
configHome := strings.TrimSpace(openClawHome)
|
||||
if configHome == "" {
|
||||
configHome = defaultOpenClawHome()
|
||||
}
|
||||
configPath := filepath.Join(configHome, "openclaw.json")
|
||||
if content, err := os.ReadFile(configPath); err == nil {
|
||||
var decoded map[string]any
|
||||
if err := json.Unmarshal(content, &decoded); err == nil {
|
||||
agents := 0
|
||||
if rawAgents, ok := decoded["agents"].(map[string]any); ok {
|
||||
if rawList, ok := rawAgents["list"].([]any); ok {
|
||||
agents = len(rawList)
|
||||
}
|
||||
}
|
||||
skillsDir := filepath.Join(configHome, "skills")
|
||||
if entries, err := os.ReadDir(skillsDir); err == nil {
|
||||
state.DiscoveredSkillCount = len(entries)
|
||||
}
|
||||
state.Detail = "agents: " + itoa(agents) + " · skills: " +
|
||||
itoa(state.DiscoveredSkillCount)
|
||||
} else {
|
||||
state.Detail = "OpenClaw config detected but could not be fully parsed."
|
||||
}
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
func placeholderState(
|
||||
targetID string,
|
||||
label string,
|
||||
supportsSkills bool,
|
||||
supportsMCP bool,
|
||||
supportsAIGatewayInjection bool,
|
||||
) MountTargetState {
|
||||
return MountTargetState{
|
||||
TargetID: targetID,
|
||||
Label: label,
|
||||
SupportsSkills: supportsSkills,
|
||||
SupportsMCP: supportsMCP,
|
||||
SupportsAIGatewayInjection: supportsAIGatewayInjection,
|
||||
DiscoveryState: "idle",
|
||||
SyncState: "idle",
|
||||
}
|
||||
}
|
||||
|
||||
func codexAvailable(configuredPath string) bool {
|
||||
if strings.TrimSpace(configuredPath) != "" {
|
||||
if _, err := os.Stat(strings.TrimSpace(configuredPath)); err == nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return binaryExists("codex")
|
||||
}
|
||||
|
||||
func binaryExists(command string) bool {
|
||||
_, err := exec.LookPath(command)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func countListedEntries(command []string) int {
|
||||
output := strings.TrimSpace(runCommand(command))
|
||||
if output == "" ||
|
||||
strings.Contains(output, "No MCP servers configured") ||
|
||||
strings.Contains(output, "No MCP servers configured yet") ||
|
||||
strings.Contains(output, "No MCP servers configured.") {
|
||||
return 0
|
||||
}
|
||||
lines := strings.Split(output, "\n")
|
||||
count := 0
|
||||
for _, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
switch {
|
||||
case trimmed == "":
|
||||
case strings.HasPrefix(trimmed, "Usage:"):
|
||||
case strings.HasPrefix(trimmed, "┌"):
|
||||
case strings.HasPrefix(trimmed, "│"):
|
||||
case strings.HasPrefix(trimmed, "└"):
|
||||
default:
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func runCommand(command []string) string {
|
||||
if len(command) == 0 {
|
||||
return ""
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second)
|
||||
defer cancel()
|
||||
cmd := exec.CommandContext(ctx, command[0], command[1:]...)
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil && len(output) == 0 {
|
||||
return ""
|
||||
}
|
||||
return string(output)
|
||||
}
|
||||
|
||||
func enabledServers(servers []ManagedMCPServer) []ManagedMCPServer {
|
||||
filtered := make([]ManagedMCPServer, 0, len(servers))
|
||||
for _, server := range servers {
|
||||
if !server.Enabled {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, server)
|
||||
}
|
||||
sort.SliceStable(filtered, func(i, j int) bool {
|
||||
return filtered[i].ID < filtered[j].ID
|
||||
})
|
||||
return filtered
|
||||
}
|
||||
|
||||
func enabledCodexServers(servers []ManagedMCPServer) []ManagedMCPServer {
|
||||
filtered := make([]ManagedMCPServer, 0, len(servers))
|
||||
for _, server := range servers {
|
||||
if !server.Enabled || strings.TrimSpace(server.Command) == "" {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, server)
|
||||
}
|
||||
sort.SliceStable(filtered, func(i, j int) bool {
|
||||
return filtered[i].ID < filtered[j].ID
|
||||
})
|
||||
return filtered
|
||||
}
|
||||
|
||||
func itoa(value int) string {
|
||||
return strconv.Itoa(value)
|
||||
}
|
||||
@ -1,115 +0,0 @@
|
||||
package mounts
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestReconcileCodexAppliesManagedBlockAndPreservesUserEntries(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
configuredBinary := filepath.Join(tempDir, "custom-codex")
|
||||
if err := os.WriteFile(configuredBinary, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
|
||||
t.Fatalf("write configured binary: %v", err)
|
||||
}
|
||||
configPath := filepath.Join(tempDir, "config.toml")
|
||||
if err := os.WriteFile(configPath, []byte(`
|
||||
[mcp_servers.user_server]
|
||||
command = "user-mcp"
|
||||
`), 0o644); err != nil {
|
||||
t.Fatalf("write config: %v", err)
|
||||
}
|
||||
|
||||
result := Reconcile(Request{
|
||||
Config: Config{
|
||||
AutoSync: true,
|
||||
ManagedMCPServers: []ManagedMCPServer{
|
||||
{ID: "xworkmate_server", Command: "xworkmate-mcp", Args: []string{"--port", "7777"}, Enabled: true},
|
||||
},
|
||||
},
|
||||
ConfiguredCodexCLIPath: configuredBinary,
|
||||
CodexHome: tempDir,
|
||||
})
|
||||
|
||||
content, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read config: %v", err)
|
||||
}
|
||||
if !strings.Contains(string(content), `[mcp_servers.user_server]`) {
|
||||
t.Fatalf("expected user entry preserved: %s", string(content))
|
||||
}
|
||||
if !strings.Contains(string(content), `[mcp_servers.xworkmate_server]`) {
|
||||
t.Fatalf("expected managed entry written: %s", string(content))
|
||||
}
|
||||
if strings.Count(string(content), codexManagedMCPBlockStart) != 1 {
|
||||
t.Fatalf("expected single managed block: %s", string(content))
|
||||
}
|
||||
if result.MountTargets[1].ManagedMCPCount != 1 {
|
||||
t.Fatalf("expected codex managed count 1, got %d", result.MountTargets[1].ManagedMCPCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReconcileOpencodeAppliesManagedBlockAndPreservesUserEntries(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
binDir := t.TempDir()
|
||||
originalPath := os.Getenv("PATH")
|
||||
t.Setenv("PATH", binDir+string(os.PathListSeparator)+originalPath)
|
||||
if err := os.WriteFile(filepath.Join(binDir, "opencode"), []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
|
||||
t.Fatalf("write opencode binary: %v", err)
|
||||
}
|
||||
configPath := filepath.Join(tempDir, "config.toml")
|
||||
if err := os.WriteFile(configPath, []byte(`
|
||||
[model]
|
||||
name = "user-default"
|
||||
`), 0o644); err != nil {
|
||||
t.Fatalf("write config: %v", err)
|
||||
}
|
||||
|
||||
result := Reconcile(Request{
|
||||
Config: Config{
|
||||
AutoSync: true,
|
||||
ManagedMCPServers: []ManagedMCPServer{
|
||||
{ID: "xworkmate_server", Command: "xworkmate-mcp", Args: []string{"--port", "3001"}, Enabled: true},
|
||||
},
|
||||
},
|
||||
OpencodeHome: tempDir,
|
||||
})
|
||||
|
||||
content, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read config: %v", err)
|
||||
}
|
||||
if !strings.Contains(string(content), `[model]`) {
|
||||
t.Fatalf("expected user config preserved: %s", string(content))
|
||||
}
|
||||
if !strings.Contains(string(content), `[mcp_servers.xworkmate_server]`) {
|
||||
t.Fatalf("expected managed opencode entry written: %s", string(content))
|
||||
}
|
||||
if strings.Count(string(content), opencodeManagedMCPBlockStart) != 1 {
|
||||
t.Fatalf("expected single opencode managed block: %s", string(content))
|
||||
}
|
||||
if result.MountTargets[4].ManagedMCPCount != 1 {
|
||||
t.Fatalf("expected opencode managed count 1, got %d", result.MountTargets[4].ManagedMCPCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReconcileArisReportsReadyWhenBundleAndBridgeAreAvailable(t *testing.T) {
|
||||
result := Reconcile(Request{
|
||||
Config: Config{UsesAris: true},
|
||||
Aris: ArisInput{
|
||||
Available: true,
|
||||
BundleVersion: "test",
|
||||
LLMChatServerPath: "mcp-server.py",
|
||||
SkillCount: 2,
|
||||
BridgeAvailable: true,
|
||||
},
|
||||
})
|
||||
|
||||
if got := result.MountTargets[0].SyncState; got != "ready" {
|
||||
t.Fatalf("expected ready aris state, got %q", got)
|
||||
}
|
||||
if got := result.ArisBundleVersion; got != "test" {
|
||||
t.Fatalf("expected bundle version test, got %q", got)
|
||||
}
|
||||
}
|
||||
@ -1,78 +0,0 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"xworkmate/go_core/internal/shared"
|
||||
)
|
||||
|
||||
type ClassificationRequest struct {
|
||||
Prompt string
|
||||
AIGatewayBaseURL string
|
||||
AIGatewayAPIKey string
|
||||
}
|
||||
|
||||
type Classifier interface {
|
||||
Classify(req ClassificationRequest) string
|
||||
}
|
||||
|
||||
type LLMClassifier struct{}
|
||||
|
||||
func (LLMClassifier) Classify(req ClassificationRequest) string {
|
||||
baseURL := shared.NormalizeBaseURL(strings.TrimSpace(req.AIGatewayBaseURL))
|
||||
apiKey := strings.TrimSpace(req.AIGatewayAPIKey)
|
||||
if baseURL == "" {
|
||||
baseURL = shared.NormalizeBaseURL(
|
||||
shared.EnvOrDefault("LLM_BASE_URL", "https://api.openai.com/v1"),
|
||||
)
|
||||
}
|
||||
if apiKey == "" {
|
||||
apiKey = strings.TrimSpace(shared.EnvOrDefault("LLM_API_KEY", ""))
|
||||
}
|
||||
if baseURL == "" || apiKey == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
model := strings.TrimSpace(shared.EnvOrDefault("ACP_ROUTING_MODEL", "gpt-4o"))
|
||||
if model == "" {
|
||||
model = "gpt-4o"
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 45*time.Second)
|
||||
defer cancel()
|
||||
content, err := shared.CallOpenAICompatibleCtx(
|
||||
ctx,
|
||||
baseURL,
|
||||
apiKey,
|
||||
model,
|
||||
[]map[string]string{
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Classify the user task into exactly one label: single-agent, multi-agent, or gateway. Return only the label.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": strings.TrimSpace(req.Prompt),
|
||||
},
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return normalizeClassifierLabel(content)
|
||||
}
|
||||
|
||||
func normalizeClassifierLabel(value string) string {
|
||||
normalized := strings.ToLower(strings.TrimSpace(value))
|
||||
switch {
|
||||
case strings.Contains(normalized, ExecutionTargetSingleAgent):
|
||||
return ExecutionTargetSingleAgent
|
||||
case strings.Contains(normalized, ExecutionTargetMultiAgent):
|
||||
return ExecutionTargetMultiAgent
|
||||
case strings.Contains(normalized, ExecutionTargetGateway):
|
||||
return ExecutionTargetGateway
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
@ -1,365 +0,0 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"xworkmate/go_core/internal/memory"
|
||||
"xworkmate/go_core/internal/skills"
|
||||
)
|
||||
|
||||
const (
|
||||
RoutingModeAuto = "auto"
|
||||
RoutingModeExplicit = "explicit"
|
||||
|
||||
ExecutionTargetSingleAgent = "single-agent"
|
||||
ExecutionTargetMultiAgent = "multi-agent"
|
||||
ExecutionTargetGateway = "gateway"
|
||||
ExecutionTargetGatewayChat = "gateway-chat"
|
||||
|
||||
EndpointTargetSingleAgent = "singleAgent"
|
||||
EndpointTargetLocal = "local"
|
||||
EndpointTargetRemote = "remote"
|
||||
)
|
||||
|
||||
type Request struct {
|
||||
Prompt string
|
||||
WorkingDirectory string
|
||||
RoutingMode string
|
||||
PreferredGatewayTarget string
|
||||
ExplicitExecutionTarget string
|
||||
ExplicitProviderID string
|
||||
ExplicitModel string
|
||||
ExplicitSkills []string
|
||||
AllowSkillInstall bool
|
||||
InstallApproval skills.InstallApproval
|
||||
AvailableSkills []skills.Candidate
|
||||
AvailableProviders []string
|
||||
AIGatewayBaseURL string
|
||||
AIGatewayAPIKey string
|
||||
}
|
||||
|
||||
type Result struct {
|
||||
ResolvedExecutionTarget string
|
||||
ResolvedEndpointTarget string
|
||||
ResolvedProviderID string
|
||||
ResolvedModel string
|
||||
ResolvedSkills []string
|
||||
SkillResolutionSource string
|
||||
SkillCandidates []skills.Candidate
|
||||
NeedsSkillInstall bool
|
||||
SkillInstallRequestID string
|
||||
MemorySources []memory.Source
|
||||
Unavailable bool
|
||||
UnavailableCode string
|
||||
UnavailableMessage string
|
||||
}
|
||||
|
||||
type Resolver struct {
|
||||
SkillFinder skills.Finder
|
||||
SkillInstaller skills.Installer
|
||||
MemoryService memory.Service
|
||||
Classifier Classifier
|
||||
}
|
||||
|
||||
func NewResolver() Resolver {
|
||||
homeDir, _ := os.UserHomeDir()
|
||||
return Resolver{
|
||||
SkillFinder: skills.NewDefaultFinder(),
|
||||
SkillInstaller: skills.NewDefaultInstaller(),
|
||||
MemoryService: memory.NewService(homeDir),
|
||||
Classifier: LLMClassifier{},
|
||||
}
|
||||
}
|
||||
|
||||
func (r Resolver) Resolve(req Request) Result {
|
||||
mem := r.MemoryService.Load(req.WorkingDirectory)
|
||||
availableProviders := normalizeProviders(req.AvailableProviders)
|
||||
|
||||
result := Result{
|
||||
ResolvedModel: strings.TrimSpace(req.ExplicitModel),
|
||||
MemorySources: mem.Sources,
|
||||
}
|
||||
|
||||
result.ResolvedExecutionTarget, result.ResolvedEndpointTarget = r.resolveExecution(req, mem.Preferences)
|
||||
result.ResolvedProviderID, result.Unavailable, result.UnavailableCode, result.UnavailableMessage = resolveProvider(
|
||||
req,
|
||||
mem.Preferences,
|
||||
availableProviders,
|
||||
result.ResolvedExecutionTarget,
|
||||
)
|
||||
if result.ResolvedModel == "" {
|
||||
result.ResolvedModel = strings.TrimSpace(mem.Preferences.PreferredModel)
|
||||
}
|
||||
|
||||
skillRequest := skills.ResolveRequest{
|
||||
Prompt: req.Prompt,
|
||||
ExplicitSkills: req.ExplicitSkills,
|
||||
AvailableSkills: req.AvailableSkills,
|
||||
AllowSkillInstall: req.AllowSkillInstall,
|
||||
InstallApproval: req.InstallApproval,
|
||||
}
|
||||
skillResult := skills.Resolve(skillRequest, r.SkillFinder, r.SkillInstaller)
|
||||
result.ResolvedSkills = skillResult.ResolvedSkills
|
||||
result.SkillResolutionSource = skillResult.Source
|
||||
result.SkillCandidates = skillResult.Candidates
|
||||
result.NeedsSkillInstall = skillResult.NeedsInstall
|
||||
result.SkillInstallRequestID = skillResult.InstallRequestID
|
||||
|
||||
if len(result.ResolvedSkills) == 0 && len(mem.Preferences.PreferredSkills) > 0 {
|
||||
result.ResolvedSkills = append([]string(nil), mem.Preferences.PreferredSkills...)
|
||||
if result.SkillResolutionSource == "" || result.SkillResolutionSource == "none" {
|
||||
result.SkillResolutionSource = "local_match"
|
||||
}
|
||||
}
|
||||
if result.SkillResolutionSource == "" {
|
||||
result.SkillResolutionSource = "none"
|
||||
}
|
||||
if result.ResolvedExecutionTarget == "" {
|
||||
if len(availableProviders) > 0 {
|
||||
result.ResolvedExecutionTarget = ExecutionTargetSingleAgent
|
||||
} else {
|
||||
result.ResolvedExecutionTarget = ExecutionTargetGateway
|
||||
}
|
||||
}
|
||||
if result.ResolvedEndpointTarget == "" {
|
||||
if result.ResolvedExecutionTarget == ExecutionTargetGateway {
|
||||
result.ResolvedEndpointTarget = normalizeGatewayTarget(req.PreferredGatewayTarget)
|
||||
} else {
|
||||
result.ResolvedEndpointTarget = EndpointTargetSingleAgent
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (r Resolver) resolveExecution(req Request, prefs memory.Preferences) (string, string) {
|
||||
explicit := strings.TrimSpace(req.ExplicitExecutionTarget)
|
||||
if strings.EqualFold(strings.TrimSpace(req.RoutingMode), RoutingModeExplicit) && explicit != "" {
|
||||
return mapExplicitTarget(explicit)
|
||||
}
|
||||
|
||||
prompt := normalize(req.Prompt)
|
||||
|
||||
localTask := looksLocal(prompt)
|
||||
onlineTask := looksOnline(prompt)
|
||||
complexTask := looksComplex(prompt)
|
||||
|
||||
switch {
|
||||
case localTask && complexTask:
|
||||
return ExecutionTargetMultiAgent, EndpointTargetSingleAgent
|
||||
case onlineTask && complexTask:
|
||||
return ExecutionTargetMultiAgent, EndpointTargetSingleAgent
|
||||
case localTask:
|
||||
return ExecutionTargetSingleAgent, EndpointTargetSingleAgent
|
||||
case onlineTask:
|
||||
return ExecutionTargetGateway, normalizeGatewayTarget(req.PreferredGatewayTarget)
|
||||
case complexTask:
|
||||
return ExecutionTargetMultiAgent, EndpointTargetSingleAgent
|
||||
}
|
||||
|
||||
switch normalizeExecutionTarget(r.classify(req)) {
|
||||
case ExecutionTargetGateway:
|
||||
return ExecutionTargetGateway, normalizeGatewayTarget(req.PreferredGatewayTarget)
|
||||
case ExecutionTargetMultiAgent:
|
||||
return ExecutionTargetMultiAgent, EndpointTargetSingleAgent
|
||||
case ExecutionTargetSingleAgent:
|
||||
return ExecutionTargetSingleAgent, EndpointTargetSingleAgent
|
||||
}
|
||||
|
||||
switch normalizeExecutionTarget(strings.TrimSpace(prefs.PreferredRoute)) {
|
||||
case ExecutionTargetGateway:
|
||||
return ExecutionTargetGateway, normalizeGatewayTarget(req.PreferredGatewayTarget)
|
||||
case ExecutionTargetMultiAgent:
|
||||
return ExecutionTargetMultiAgent, EndpointTargetSingleAgent
|
||||
case ExecutionTargetSingleAgent:
|
||||
if len(normalizeProviders(req.AvailableProviders)) > 0 {
|
||||
return ExecutionTargetSingleAgent, EndpointTargetSingleAgent
|
||||
}
|
||||
}
|
||||
if len(normalizeProviders(req.AvailableProviders)) > 0 {
|
||||
return ExecutionTargetSingleAgent, EndpointTargetSingleAgent
|
||||
}
|
||||
return ExecutionTargetGateway, normalizeGatewayTarget(req.PreferredGatewayTarget)
|
||||
}
|
||||
|
||||
func (r Resolver) classify(req Request) string {
|
||||
if r.Classifier == nil {
|
||||
return ""
|
||||
}
|
||||
return normalizeExecutionTarget(r.Classifier.Classify(ClassificationRequest{
|
||||
Prompt: req.Prompt,
|
||||
AIGatewayBaseURL: req.AIGatewayBaseURL,
|
||||
AIGatewayAPIKey: req.AIGatewayAPIKey,
|
||||
}))
|
||||
}
|
||||
|
||||
func mapExplicitTarget(value string) (string, string) {
|
||||
switch strings.TrimSpace(value) {
|
||||
case EndpointTargetLocal:
|
||||
return ExecutionTargetGateway, EndpointTargetLocal
|
||||
case EndpointTargetRemote:
|
||||
return ExecutionTargetGateway, EndpointTargetRemote
|
||||
case "multiAgent", ExecutionTargetMultiAgent:
|
||||
return ExecutionTargetMultiAgent, EndpointTargetSingleAgent
|
||||
case EndpointTargetSingleAgent, ExecutionTargetSingleAgent:
|
||||
return ExecutionTargetSingleAgent, EndpointTargetSingleAgent
|
||||
default:
|
||||
return ExecutionTargetSingleAgent, EndpointTargetSingleAgent
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeGatewayTarget(value string) string {
|
||||
switch strings.TrimSpace(value) {
|
||||
case EndpointTargetLocal, "":
|
||||
return EndpointTargetLocal
|
||||
default:
|
||||
return EndpointTargetRemote
|
||||
}
|
||||
}
|
||||
|
||||
func resolveProvider(
|
||||
req Request,
|
||||
prefs memory.Preferences,
|
||||
availableProviders []string,
|
||||
executionTarget string,
|
||||
) (string, bool, string, string) {
|
||||
explicitProviderID := normalize(strings.TrimSpace(req.ExplicitProviderID))
|
||||
if explicitProviderID != "" {
|
||||
if containsProvider(availableProviders, explicitProviderID) {
|
||||
return explicitProviderID, false, "", ""
|
||||
}
|
||||
return "", true, "PROVIDER_UNAVAILABLE", "explicit provider is unavailable"
|
||||
}
|
||||
|
||||
if executionTarget != ExecutionTargetSingleAgent {
|
||||
preferredProvider := normalize(strings.TrimSpace(prefs.Provider))
|
||||
if containsProvider(availableProviders, preferredProvider) {
|
||||
return preferredProvider, false, "", ""
|
||||
}
|
||||
return "", false, "", ""
|
||||
}
|
||||
|
||||
preferredProvider := normalize(strings.TrimSpace(prefs.Provider))
|
||||
if containsProvider(availableProviders, preferredProvider) {
|
||||
return preferredProvider, false, "", ""
|
||||
}
|
||||
if len(availableProviders) > 0 {
|
||||
return availableProviders[0], false, "", ""
|
||||
}
|
||||
return "", true, "PROVIDER_UNAVAILABLE", "no single-agent provider is available"
|
||||
}
|
||||
|
||||
func normalizeProviders(values []string) []string {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
unique := make(map[string]struct{}, len(values))
|
||||
normalized := make([]string, 0, len(values))
|
||||
for _, value := range values {
|
||||
providerID := normalize(value)
|
||||
if providerID == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := unique[providerID]; ok {
|
||||
continue
|
||||
}
|
||||
unique[providerID] = struct{}{}
|
||||
normalized = append(normalized, providerID)
|
||||
}
|
||||
sort.Strings(normalized)
|
||||
return normalized
|
||||
}
|
||||
|
||||
func containsProvider(values []string, want string) bool {
|
||||
want = normalize(want)
|
||||
if want == "" {
|
||||
return false
|
||||
}
|
||||
for _, value := range values {
|
||||
if normalize(value) == want {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func looksLocal(prompt string) bool {
|
||||
return containsAny(prompt, []string{
|
||||
"ppt", "pptx", "powerpoint", "word", "docx", "excel", "xlsx", "pdf",
|
||||
"image-resizer", "resize image", "compress image", "crop image",
|
||||
})
|
||||
}
|
||||
|
||||
func looksOnline(prompt string) bool {
|
||||
return containsAny(prompt, []string{
|
||||
"image-cog", "wan", "video-translator", "browser", "search", "news",
|
||||
"资讯采集", "跨浏览器", "文生图", "文生视频", "图生视频", "视频翻译",
|
||||
"translate video", "dub video", "subtitles",
|
||||
})
|
||||
}
|
||||
|
||||
func looksComplex(prompt string) bool {
|
||||
strongSignals := containsAny(prompt, []string{
|
||||
"multiple deliverables", "multiple outputs", "多个产物", "多个输出",
|
||||
"审阅", "复核", "汇编", "end-to-end", "end to end",
|
||||
})
|
||||
if strongSignals {
|
||||
return true
|
||||
}
|
||||
|
||||
reviewSignals := containsAny(prompt, []string{
|
||||
"review", "audit", "verify", "summarize", "compare",
|
||||
"审阅", "复核", "汇总", "对比", "整理", "整合", "汇编",
|
||||
})
|
||||
multiStepSignals := containsAny(prompt, []string{
|
||||
"workflow", "pipeline", "step by step", "multi-step", "collect and",
|
||||
"analyze and", "review and", "compare and", "summarize and",
|
||||
"先", "然后", "之后",
|
||||
})
|
||||
structuredOutputSignals := containsAny(prompt, []string{
|
||||
"report", "memo", "table", "spreadsheet", "document", "deck", "slides",
|
||||
"presentation", "报告", "总结", "表格", "文档", "演示",
|
||||
})
|
||||
onlineCollectionSignals := containsAny(prompt, []string{
|
||||
"browser", "search", "news", "research", "crawl", "scrape",
|
||||
"跨浏览器", "搜索", "资讯", "采集", "检索",
|
||||
})
|
||||
|
||||
score := 0
|
||||
if reviewSignals {
|
||||
score++
|
||||
}
|
||||
if multiStepSignals {
|
||||
score++
|
||||
}
|
||||
if structuredOutputSignals {
|
||||
score++
|
||||
}
|
||||
if onlineCollectionSignals && structuredOutputSignals {
|
||||
return true
|
||||
}
|
||||
return score >= 2
|
||||
}
|
||||
|
||||
func containsAny(haystack string, needles []string) bool {
|
||||
for _, needle := range needles {
|
||||
if strings.Contains(haystack, normalize(needle)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func normalize(value string) string {
|
||||
return strings.ToLower(strings.TrimSpace(value))
|
||||
}
|
||||
|
||||
func normalizeExecutionTarget(value string) string {
|
||||
switch normalize(value) {
|
||||
case ExecutionTargetGatewayChat:
|
||||
return ExecutionTargetGateway
|
||||
default:
|
||||
return normalize(value)
|
||||
}
|
||||
}
|
||||
@ -1,136 +0,0 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"xworkmate/go_core/internal/memory"
|
||||
"xworkmate/go_core/internal/skills"
|
||||
)
|
||||
|
||||
type fakeClassifier string
|
||||
|
||||
func (f fakeClassifier) Classify(req ClassificationRequest) string {
|
||||
return string(f)
|
||||
}
|
||||
|
||||
func TestResolveExplicitTargetOverridesAuto(t *testing.T) {
|
||||
resolver := Resolver{
|
||||
SkillFinder: skills.StaticFinder{},
|
||||
SkillInstaller: nil,
|
||||
MemoryService: memory.Service{},
|
||||
}
|
||||
|
||||
result := resolver.Resolve(Request{
|
||||
Prompt: "search the web and summarize results",
|
||||
RoutingMode: RoutingModeExplicit,
|
||||
ExplicitExecutionTarget: "singleAgent",
|
||||
ExplicitProviderID: "codex",
|
||||
ExplicitModel: "gpt-5.4",
|
||||
AvailableProviders: []string{"codex"},
|
||||
})
|
||||
|
||||
if result.ResolvedExecutionTarget != ExecutionTargetSingleAgent {
|
||||
t.Fatalf("expected explicit single-agent route, got %#v", result)
|
||||
}
|
||||
if result.ResolvedEndpointTarget != EndpointTargetSingleAgent {
|
||||
t.Fatalf("expected singleAgent endpoint target, got %#v", result)
|
||||
}
|
||||
if result.ResolvedProviderID != "codex" || result.ResolvedModel != "gpt-5.4" {
|
||||
t.Fatalf("unexpected explicit provider/model: %#v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveExplicitProviderRequiresAvailability(t *testing.T) {
|
||||
resolver := Resolver{
|
||||
SkillFinder: skills.StaticFinder{},
|
||||
SkillInstaller: nil,
|
||||
MemoryService: memory.Service{},
|
||||
}
|
||||
|
||||
result := resolver.Resolve(Request{
|
||||
Prompt: "search the web and summarize results",
|
||||
RoutingMode: RoutingModeExplicit,
|
||||
ExplicitExecutionTarget: "singleAgent",
|
||||
ExplicitProviderID: "codex",
|
||||
})
|
||||
|
||||
if !result.Unavailable {
|
||||
t.Fatalf("expected explicit provider to be unavailable without synced catalog, got %#v", result)
|
||||
}
|
||||
if result.UnavailableCode != "PROVIDER_UNAVAILABLE" {
|
||||
t.Fatalf("expected PROVIDER_UNAVAILABLE, got %#v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveAutoLocalTaskToSingleAgent(t *testing.T) {
|
||||
resolver := Resolver{
|
||||
SkillFinder: skills.StaticFinder{},
|
||||
SkillInstaller: nil,
|
||||
MemoryService: memory.Service{},
|
||||
}
|
||||
|
||||
result := resolver.Resolve(Request{
|
||||
Prompt: "create a PowerPoint deck from this outline",
|
||||
})
|
||||
|
||||
if result.ResolvedExecutionTarget != ExecutionTargetSingleAgent {
|
||||
t.Fatalf("expected single-agent route, got %#v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveAutoOnlineTaskToGateway(t *testing.T) {
|
||||
resolver := Resolver{
|
||||
SkillFinder: skills.StaticFinder{},
|
||||
SkillInstaller: nil,
|
||||
MemoryService: memory.Service{},
|
||||
}
|
||||
|
||||
result := resolver.Resolve(Request{
|
||||
Prompt: "跨浏览器执行并搜索最新资讯",
|
||||
PreferredGatewayTarget: EndpointTargetLocal,
|
||||
})
|
||||
|
||||
if result.ResolvedExecutionTarget != ExecutionTargetGateway {
|
||||
t.Fatalf("expected gateway route, got %#v", result)
|
||||
}
|
||||
if result.ResolvedEndpointTarget != EndpointTargetLocal {
|
||||
t.Fatalf("expected local gateway target, got %#v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveComplexTaskUpgradesToMultiAgent(t *testing.T) {
|
||||
resolver := Resolver{
|
||||
SkillFinder: skills.StaticFinder{},
|
||||
SkillInstaller: nil,
|
||||
MemoryService: memory.Service{},
|
||||
}
|
||||
|
||||
result := resolver.Resolve(Request{
|
||||
Prompt: "analyze these files, review the output, and summarize multiple deliverables",
|
||||
})
|
||||
|
||||
if result.ResolvedExecutionTarget != ExecutionTargetMultiAgent {
|
||||
t.Fatalf("expected multi-agent route, got %#v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveUsesClassifierForBoundarySamples(t *testing.T) {
|
||||
resolver := Resolver{
|
||||
SkillFinder: skills.StaticFinder{},
|
||||
SkillInstaller: nil,
|
||||
MemoryService: memory.Service{},
|
||||
Classifier: fakeClassifier(ExecutionTargetGateway),
|
||||
}
|
||||
|
||||
result := resolver.Resolve(Request{
|
||||
Prompt: "help me handle this ambiguous request",
|
||||
PreferredGatewayTarget: EndpointTargetLocal,
|
||||
})
|
||||
|
||||
if result.ResolvedExecutionTarget != ExecutionTargetGateway {
|
||||
t.Fatalf("expected classifier to resolve gateway route, got %#v", result)
|
||||
}
|
||||
if result.ResolvedEndpointTarget != EndpointTargetLocal {
|
||||
t.Fatalf("expected local endpoint target, got %#v", result)
|
||||
}
|
||||
}
|
||||
@ -1,37 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var ErrInvalidCredentials = errors.New("invalid credentials")
|
||||
|
||||
type AuthRepository interface {
|
||||
Verify(ctx context.Context, username, password string) (bool, error)
|
||||
}
|
||||
|
||||
type AuthService struct {
|
||||
repo AuthRepository
|
||||
}
|
||||
|
||||
func NewAuthService(repo AuthRepository) *AuthService {
|
||||
return &AuthService{repo: repo}
|
||||
}
|
||||
|
||||
func (s *AuthService) Authenticate(ctx context.Context, username, password string) error {
|
||||
username = strings.TrimSpace(username)
|
||||
password = strings.TrimSpace(password)
|
||||
if username == "" || password == "" {
|
||||
return ErrInvalidCredentials
|
||||
}
|
||||
ok, err := s.repo.Verify(ctx, username, password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
return ErrInvalidCredentials
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -1,55 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type fakeAuthRepo struct {
|
||||
verify func(ctx context.Context, username, password string) (bool, error)
|
||||
}
|
||||
|
||||
func (f fakeAuthRepo) Verify(ctx context.Context, username, password string) (bool, error) {
|
||||
return f.verify(ctx, username, password)
|
||||
}
|
||||
|
||||
func TestAuthenticateRejectsBlankValues(t *testing.T) {
|
||||
svc := NewAuthService(fakeAuthRepo{
|
||||
verify: func(ctx context.Context, username, password string) (bool, error) {
|
||||
return true, nil
|
||||
},
|
||||
})
|
||||
|
||||
if err := svc.Authenticate(context.Background(), " ", "secret"); !errors.Is(err, ErrInvalidCredentials) {
|
||||
t.Fatalf("expected invalid credentials, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthenticateRejectsFailedVerification(t *testing.T) {
|
||||
svc := NewAuthService(fakeAuthRepo{
|
||||
verify: func(ctx context.Context, username, password string) (bool, error) {
|
||||
if username != "alice" || password != "secret" {
|
||||
t.Fatalf("unexpected credentials: %q %q", username, password)
|
||||
}
|
||||
return false, nil
|
||||
},
|
||||
})
|
||||
|
||||
if err := svc.Authenticate(context.Background(), "alice", "secret"); !errors.Is(err, ErrInvalidCredentials) {
|
||||
t.Fatalf("expected invalid credentials, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthenticateReturnsRepoError(t *testing.T) {
|
||||
wanted := errors.New("boom")
|
||||
svc := NewAuthService(fakeAuthRepo{
|
||||
verify: func(ctx context.Context, username, password string) (bool, error) {
|
||||
return false, wanted
|
||||
},
|
||||
})
|
||||
|
||||
if err := svc.Authenticate(context.Background(), "alice", "secret"); !errors.Is(err, wanted) {
|
||||
t.Fatalf("expected repo error, got %v", err)
|
||||
}
|
||||
}
|
||||
@ -1,81 +0,0 @@
|
||||
package shared
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func NormalizeBaseURL(raw string) string {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return "https://api.openai.com/v1"
|
||||
}
|
||||
if strings.HasSuffix(trimmed, "/v1") {
|
||||
return trimmed
|
||||
}
|
||||
return strings.TrimRight(trimmed, "/") + "/v1"
|
||||
}
|
||||
|
||||
func EnvOrDefault(key, fallback string) string {
|
||||
value := strings.TrimSpace(os.Getenv(key))
|
||||
if value == "" {
|
||||
return fallback
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func StringArg(arguments map[string]any, key, fallback string) string {
|
||||
if arguments == nil {
|
||||
return fallback
|
||||
}
|
||||
value, ok := arguments[key]
|
||||
if !ok {
|
||||
return fallback
|
||||
}
|
||||
text := strings.TrimSpace(fmt.Sprint(value))
|
||||
if text == "" || text == "<nil>" {
|
||||
return fallback
|
||||
}
|
||||
return text
|
||||
}
|
||||
|
||||
func ListArg(arguments map[string]any, key string) []any {
|
||||
if arguments == nil {
|
||||
return nil
|
||||
}
|
||||
raw, ok := arguments[key]
|
||||
if !ok || raw == nil {
|
||||
return nil
|
||||
}
|
||||
if values, ok := raw.([]any); ok {
|
||||
return values
|
||||
}
|
||||
if values, ok := raw.([]interface{}); ok {
|
||||
return values
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func IntArg(raw string, fallback int) int {
|
||||
var parsed int
|
||||
if _, err := fmt.Sscanf(raw, "%d", &parsed); err != nil || parsed <= 0 {
|
||||
return fallback
|
||||
}
|
||||
return parsed
|
||||
}
|
||||
|
||||
func BoolArg(raw string, fallback bool) bool {
|
||||
trimmed := strings.TrimSpace(strings.ToLower(raw))
|
||||
if trimmed == "" {
|
||||
return fallback
|
||||
}
|
||||
switch trimmed {
|
||||
case "1", "true", "yes", "on":
|
||||
return true
|
||||
case "0", "false", "no", "off":
|
||||
return false
|
||||
default:
|
||||
return fallback
|
||||
}
|
||||
}
|
||||
@ -1,108 +0,0 @@
|
||||
package shared
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type RPCRequest struct {
|
||||
JSONRPC string `json:"jsonrpc,omitempty"`
|
||||
ID any `json:"id,omitempty"`
|
||||
Method string `json:"method,omitempty"`
|
||||
Params map[string]any `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
type RPCError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type ToolCallParams struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]any `json:"arguments"`
|
||||
}
|
||||
|
||||
func DecodeRPCRequest(payload []byte) (RPCRequest, error) {
|
||||
var request RPCRequest
|
||||
if err := json.Unmarshal(payload, &request); err != nil {
|
||||
return RPCRequest{}, fmt.Errorf("invalid json: %w", err)
|
||||
}
|
||||
if strings.TrimSpace(request.Method) == "" {
|
||||
return RPCRequest{}, errors.New("missing method")
|
||||
}
|
||||
if request.Params == nil {
|
||||
request.Params = map[string]any{}
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func WriteSSE(w http.ResponseWriter, payload map[string]any) {
|
||||
encoded, _ := json.Marshal(payload)
|
||||
_, _ = fmt.Fprintf(w, "data: %s\n\n", encoded)
|
||||
}
|
||||
|
||||
func ResultEnvelope(id any, result map[string]any) map[string]any {
|
||||
return map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"result": result,
|
||||
}
|
||||
}
|
||||
|
||||
func ErrorEnvelope(id any, code int, message string) map[string]any {
|
||||
return map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"error": map[string]any{
|
||||
"code": code,
|
||||
"message": message,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func NotificationEnvelope(method string, params map[string]any) map[string]any {
|
||||
return map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"method": method,
|
||||
"params": params,
|
||||
}
|
||||
}
|
||||
|
||||
func ErrorResponse(id any, code int, message string) map[string]any {
|
||||
return map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"error": map[string]any{
|
||||
"code": code,
|
||||
"message": message,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func ToolTextResult(id any, content string) map[string]any {
|
||||
return map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"result": map[string]any{
|
||||
"content": []map[string]any{
|
||||
{"type": "text", "text": content},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func ToolErrorResult(id any, err error) map[string]any {
|
||||
return map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"result": map[string]any{
|
||||
"content": []map[string]any{
|
||||
{"type": "text", "text": fmt.Sprintf("Error: %v", err)},
|
||||
},
|
||||
"isError": true,
|
||||
},
|
||||
}
|
||||
}
|
||||
@ -1,397 +0,0 @@
|
||||
package shared
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os/exec"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func DetectACPProviders() []string {
|
||||
candidates := []struct {
|
||||
provider string
|
||||
envKey string
|
||||
binary string
|
||||
}{
|
||||
{provider: "codex", envKey: "ACP_CODEX_BIN", binary: "codex"},
|
||||
{provider: "opencode", envKey: "ACP_OPENCODE_BIN", binary: "opencode"},
|
||||
{provider: "claude", envKey: "ACP_CLAUDE_BIN", binary: "claude"},
|
||||
{provider: "gemini", envKey: "ACP_GEMINI_BIN", binary: "gemini"},
|
||||
}
|
||||
providers := make([]string, 0, len(candidates))
|
||||
for _, candidate := range candidates {
|
||||
binary := strings.TrimSpace(EnvOrDefault(candidate.envKey, candidate.binary))
|
||||
if binary == "" {
|
||||
continue
|
||||
}
|
||||
if _, err := exec.LookPath(binary); err == nil {
|
||||
providers = append(providers, candidate.provider)
|
||||
}
|
||||
}
|
||||
sort.Strings(providers)
|
||||
return providers
|
||||
}
|
||||
|
||||
func RunProviderCommand(
|
||||
ctx context.Context,
|
||||
provider,
|
||||
model,
|
||||
prompt,
|
||||
workingDirectory string,
|
||||
) (string, error) {
|
||||
command, args := ResolveProviderCommand(
|
||||
provider,
|
||||
model,
|
||||
prompt,
|
||||
workingDirectory,
|
||||
)
|
||||
if command == "" {
|
||||
return "", fmt.Errorf("unsupported provider: %s", provider)
|
||||
}
|
||||
cmd := exec.CommandContext(ctx, command, args...)
|
||||
if strings.TrimSpace(workingDirectory) != "" {
|
||||
cmd.Dir = strings.TrimSpace(workingDirectory)
|
||||
}
|
||||
var stdout bytes.Buffer
|
||||
var stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
if errors.Is(ctx.Err(), context.Canceled) {
|
||||
return "", errors.New("run canceled")
|
||||
}
|
||||
message := strings.TrimSpace(stderr.String())
|
||||
if message == "" {
|
||||
message = err.Error()
|
||||
}
|
||||
return "", fmt.Errorf("%s run failed: %s", provider, message)
|
||||
}
|
||||
output := strings.TrimSpace(stdout.String())
|
||||
if output == "" {
|
||||
output = strings.TrimSpace(stderr.String())
|
||||
}
|
||||
if output == "" {
|
||||
return "", fmt.Errorf("%s returned empty output", provider)
|
||||
}
|
||||
return output, nil
|
||||
}
|
||||
|
||||
func ResolveProviderCommand(
|
||||
provider,
|
||||
model,
|
||||
prompt,
|
||||
cwd string,
|
||||
) (string, []string) {
|
||||
switch strings.TrimSpace(strings.ToLower(provider)) {
|
||||
case "codex":
|
||||
binary := strings.TrimSpace(EnvOrDefault("ACP_CODEX_BIN", "codex"))
|
||||
args := []string{"exec", "--skip-git-repo-check", "--color", "never"}
|
||||
if strings.TrimSpace(cwd) != "" {
|
||||
args = append(args, "-C", strings.TrimSpace(cwd))
|
||||
}
|
||||
if strings.TrimSpace(model) != "" {
|
||||
args = append(args, "-m", strings.TrimSpace(model))
|
||||
}
|
||||
args = append(args, prompt)
|
||||
return binary, args
|
||||
case "opencode":
|
||||
binary := strings.TrimSpace(EnvOrDefault("ACP_OPENCODE_BIN", "opencode"))
|
||||
args := []string{"run", "--format", "default"}
|
||||
if strings.TrimSpace(cwd) != "" {
|
||||
args = append(args, "--dir", strings.TrimSpace(cwd))
|
||||
}
|
||||
if strings.TrimSpace(model) != "" {
|
||||
args = append(args, "-m", strings.TrimSpace(model))
|
||||
}
|
||||
args = append(args, prompt)
|
||||
return binary, args
|
||||
case "claude":
|
||||
binary := strings.TrimSpace(EnvOrDefault("ACP_CLAUDE_BIN", "claude"))
|
||||
if strings.TrimSpace(model) == "" {
|
||||
return binary, []string{"-p", prompt}
|
||||
}
|
||||
return binary, []string{
|
||||
"--model",
|
||||
strings.TrimSpace(model),
|
||||
"-p",
|
||||
prompt,
|
||||
}
|
||||
case "gemini":
|
||||
binary := strings.TrimSpace(EnvOrDefault("ACP_GEMINI_BIN", "gemini"))
|
||||
if strings.TrimSpace(model) == "" {
|
||||
return binary, []string{"-p", prompt}
|
||||
}
|
||||
return binary, []string{
|
||||
"--model",
|
||||
strings.TrimSpace(model),
|
||||
"-p",
|
||||
prompt,
|
||||
}
|
||||
default:
|
||||
return "", nil
|
||||
}
|
||||
}
|
||||
|
||||
func AugmentPromptWithAttachments(prompt string, params map[string]any) string {
|
||||
attachmentsRaw := ListArg(params, "attachments")
|
||||
if len(attachmentsRaw) == 0 {
|
||||
return prompt
|
||||
}
|
||||
lines := make([]string, 0, len(attachmentsRaw))
|
||||
for _, raw := range attachmentsRaw {
|
||||
entry, ok := raw.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
name := strings.TrimSpace(StringArg(entry, "name", "attachment"))
|
||||
path := strings.TrimSpace(StringArg(entry, "path", ""))
|
||||
if path == "" {
|
||||
continue
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf("- %s: %s", name, path))
|
||||
}
|
||||
if len(lines) == 0 {
|
||||
return prompt
|
||||
}
|
||||
var builder strings.Builder
|
||||
builder.WriteString("User-selected local attachments:\n")
|
||||
builder.WriteString(strings.Join(lines, "\n"))
|
||||
builder.WriteString("\n\n")
|
||||
builder.WriteString(prompt)
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
func ComposeHistoryPrompt(history []string) string {
|
||||
if len(history) == 0 {
|
||||
return ""
|
||||
}
|
||||
var builder strings.Builder
|
||||
for index, turn := range history {
|
||||
builder.WriteString(fmt.Sprintf("## User Turn %d\n", index+1))
|
||||
builder.WriteString(turn)
|
||||
builder.WriteString("\n\n")
|
||||
}
|
||||
return strings.TrimSpace(builder.String())
|
||||
}
|
||||
|
||||
func CallOpenAICompatibleCtx(
|
||||
ctx context.Context,
|
||||
baseURL,
|
||||
apiKey,
|
||||
model string,
|
||||
messages []map[string]string,
|
||||
) (string, error) {
|
||||
payload := map[string]any{
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"max_tokens": 4096,
|
||||
"stream": false,
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
request, err := http.NewRequestWithContext(
|
||||
ctx,
|
||||
http.MethodPost,
|
||||
strings.TrimRight(baseURL, "/")+"/chat/completions",
|
||||
bytes.NewReader(body),
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
request.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
|
||||
client := &http.Client{Timeout: 120 * time.Second}
|
||||
response, err := client.Do(request)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
responseBody, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if response.StatusCode < 200 || response.StatusCode >= 300 {
|
||||
return "", fmt.Errorf(
|
||||
"api error %d: %s",
|
||||
response.StatusCode,
|
||||
strings.TrimSpace(string(responseBody)),
|
||||
)
|
||||
}
|
||||
|
||||
var decoded map[string]any
|
||||
if err := json.Unmarshal(responseBody, &decoded); err != nil {
|
||||
return "", err
|
||||
}
|
||||
choices, _ := decoded["choices"].([]any)
|
||||
if len(choices) == 0 {
|
||||
return "", errors.New("missing choices in response")
|
||||
}
|
||||
choice, _ := choices[0].(map[string]any)
|
||||
message, _ := choice["message"].(map[string]any)
|
||||
content := strings.TrimSpace(fmt.Sprint(message["content"]))
|
||||
if content == "" || content == "<nil>" {
|
||||
return "", errors.New("empty response content")
|
||||
}
|
||||
return content, nil
|
||||
}
|
||||
|
||||
func HandleChatTool(arguments map[string]any) (string, error) {
|
||||
apiKey := strings.TrimSpace(EnvOrDefault("LLM_API_KEY", ""))
|
||||
if apiKey == "" {
|
||||
return "", errors.New("LLM_API_KEY environment variable not set")
|
||||
}
|
||||
baseURL := NormalizeBaseURL(
|
||||
EnvOrDefault("LLM_BASE_URL", "https://api.openai.com/v1"),
|
||||
)
|
||||
model := StringArg(arguments, "model", EnvOrDefault("LLM_MODEL", "gpt-4o"))
|
||||
prompt := strings.TrimSpace(StringArg(arguments, "prompt", ""))
|
||||
if prompt == "" {
|
||||
return "", errors.New("prompt is required")
|
||||
}
|
||||
system := strings.TrimSpace(StringArg(arguments, "system", ""))
|
||||
|
||||
messages := make([]map[string]string, 0, 2)
|
||||
if system != "" {
|
||||
messages = append(messages, map[string]string{
|
||||
"role": "system",
|
||||
"content": system,
|
||||
})
|
||||
}
|
||||
messages = append(messages, map[string]string{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
})
|
||||
return CallOpenAICompatible(baseURL, apiKey, model, messages)
|
||||
}
|
||||
|
||||
func HandleClaudeReviewTool(arguments map[string]any) (string, error) {
|
||||
prompt := strings.TrimSpace(StringArg(arguments, "prompt", ""))
|
||||
if prompt == "" {
|
||||
return "", errors.New("prompt is required")
|
||||
}
|
||||
model := strings.TrimSpace(
|
||||
StringArg(arguments, "model", EnvOrDefault("CLAUDE_REVIEW_MODEL", "")),
|
||||
)
|
||||
system := strings.TrimSpace(
|
||||
StringArg(arguments, "system", EnvOrDefault("CLAUDE_REVIEW_SYSTEM", "")),
|
||||
)
|
||||
tools := strings.TrimSpace(
|
||||
StringArg(arguments, "tools", EnvOrDefault("CLAUDE_REVIEW_TOOLS", "")),
|
||||
)
|
||||
timeout := IntArg(EnvOrDefault("CLAUDE_REVIEW_TIMEOUT_SEC", "600"), 600)
|
||||
return RunClaudeReview(
|
||||
prompt,
|
||||
model,
|
||||
system,
|
||||
tools,
|
||||
time.Duration(timeout)*time.Second,
|
||||
)
|
||||
}
|
||||
|
||||
func CallOpenAICompatible(
|
||||
baseURL,
|
||||
apiKey,
|
||||
model string,
|
||||
messages []map[string]string,
|
||||
) (string, error) {
|
||||
return CallOpenAICompatibleCtx(
|
||||
context.Background(),
|
||||
baseURL,
|
||||
apiKey,
|
||||
model,
|
||||
messages,
|
||||
)
|
||||
}
|
||||
|
||||
func RunClaudeReview(
|
||||
prompt,
|
||||
model,
|
||||
system,
|
||||
tools string,
|
||||
timeout time.Duration,
|
||||
) (string, error) {
|
||||
claudeBin := strings.TrimSpace(EnvOrDefault("CLAUDE_BIN", "claude"))
|
||||
resolved, err := exec.LookPath(claudeBin)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Claude CLI not found: %s", claudeBin)
|
||||
}
|
||||
|
||||
args := []string{
|
||||
"-p",
|
||||
prompt,
|
||||
"--output-format",
|
||||
"json",
|
||||
"--permission-mode",
|
||||
"plan",
|
||||
}
|
||||
if model != "" {
|
||||
args = append(args, "--model", model)
|
||||
}
|
||||
if system != "" {
|
||||
args = append(args, "--system-prompt", system)
|
||||
}
|
||||
if tools != "" {
|
||||
args = append(args, "--tools", tools)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, resolved, args...)
|
||||
cmd.Stdin = nil
|
||||
var stdout bytes.Buffer
|
||||
var stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
|
||||
return "", fmt.Errorf("Claude review timed out after %s", timeout)
|
||||
}
|
||||
message := strings.TrimSpace(stderr.String())
|
||||
if message == "" {
|
||||
message = err.Error()
|
||||
}
|
||||
return "", fmt.Errorf("Claude review failed: %s", message)
|
||||
}
|
||||
|
||||
payload, err := ParseClaudeJSON(stdout.String())
|
||||
if err != nil {
|
||||
message := strings.TrimSpace(stderr.String())
|
||||
if message != "" {
|
||||
return "", fmt.Errorf("%v. stderr: %s", err, message)
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
if isError, _ := payload["is_error"].(bool); isError {
|
||||
return "", fmt.Errorf("%v", payload["result"])
|
||||
}
|
||||
response := strings.TrimSpace(fmt.Sprint(payload["result"]))
|
||||
if response == "" || response == "<nil>" {
|
||||
return "", errors.New("Claude review returned empty output")
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func ParseClaudeJSON(raw string) (map[string]any, error) {
|
||||
lines := strings.Split(raw, "\n")
|
||||
for i := len(lines) - 1; i >= 0; i-- {
|
||||
candidate := strings.TrimSpace(lines[i])
|
||||
if candidate == "" {
|
||||
continue
|
||||
}
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(candidate), &payload); err == nil {
|
||||
return payload, nil
|
||||
}
|
||||
}
|
||||
return nil, errors.New("Claude CLI did not return JSON output")
|
||||
}
|
||||
@ -1,325 +0,0 @@
|
||||
package shared
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type VaultKVResult struct {
|
||||
Operation string `json:"operation"`
|
||||
Mount string `json:"mount"`
|
||||
Path string `json:"path"`
|
||||
Data map[string]any `json:"data,omitempty"`
|
||||
Keys []string `json:"keys,omitempty"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
func HandleVaultKVTool(arguments map[string]any) (string, error) {
|
||||
request, err := buildVaultKVRequest(arguments)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
result, err := executeVaultKVRequest(request)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
encoded, err := json.MarshalIndent(result, "", " ")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(encoded), nil
|
||||
}
|
||||
|
||||
type vaultKVRequest struct {
|
||||
baseURL string
|
||||
token string
|
||||
namespace string
|
||||
operation string
|
||||
mount string
|
||||
path string
|
||||
data map[string]any
|
||||
cas int
|
||||
}
|
||||
|
||||
func buildVaultKVRequest(arguments map[string]any) (vaultKVRequest, error) {
|
||||
baseURL := strings.TrimSpace(EnvOrDefault("VAULT_SERVER_URL", ""))
|
||||
if baseURL == "" {
|
||||
return vaultKVRequest{}, errors.New("VAULT_SERVER_URL environment variable not set")
|
||||
}
|
||||
token := strings.TrimSpace(EnvOrDefault("VAULT_SERVER_ROOT_ACCESS_TOKEN", ""))
|
||||
if token == "" {
|
||||
return vaultKVRequest{}, errors.New("VAULT_SERVER_ROOT_ACCESS_TOKEN environment variable not set")
|
||||
}
|
||||
operation := strings.ToLower(strings.TrimSpace(StringArg(arguments, "operation", "")))
|
||||
if operation == "" {
|
||||
return vaultKVRequest{}, errors.New("operation is required")
|
||||
}
|
||||
path := normalizeVaultPath(StringArg(arguments, "path", ""))
|
||||
if path == "" {
|
||||
return vaultKVRequest{}, errors.New("path is required")
|
||||
}
|
||||
data, err := vaultDataArg(arguments["data"])
|
||||
if err != nil {
|
||||
return vaultKVRequest{}, err
|
||||
}
|
||||
return vaultKVRequest{
|
||||
baseURL: strings.TrimRight(baseURL, "/"),
|
||||
token: token,
|
||||
namespace: strings.TrimSpace(EnvOrDefault("VAULT_NAMESPACE", "")),
|
||||
operation: operation,
|
||||
mount: normalizeVaultMount(StringArg(arguments, "mount", "secret")),
|
||||
path: path,
|
||||
data: data,
|
||||
cas: vaultCASArg(arguments["cas"]),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func executeVaultKVRequest(request vaultKVRequest) (VaultKVResult, error) {
|
||||
switch request.operation {
|
||||
case "get", "read":
|
||||
return vaultKVRead(request)
|
||||
case "put", "write":
|
||||
return vaultKVWrite(request)
|
||||
case "list":
|
||||
return vaultKVList(request)
|
||||
case "delete":
|
||||
return vaultKVDelete(request)
|
||||
default:
|
||||
return VaultKVResult{}, fmt.Errorf("unsupported operation: %s", request.operation)
|
||||
}
|
||||
}
|
||||
|
||||
func vaultKVRead(request vaultKVRequest) (VaultKVResult, error) {
|
||||
response, err := doVaultRequest(
|
||||
request,
|
||||
http.MethodGet,
|
||||
vaultDataURL(request.mount, request.path),
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
return VaultKVResult{}, err
|
||||
}
|
||||
dataBlock := mapArg(response["data"])
|
||||
return VaultKVResult{
|
||||
Operation: "read",
|
||||
Mount: request.mount,
|
||||
Path: request.path,
|
||||
Data: mapArg(dataBlock["data"]),
|
||||
Metadata: mapArg(dataBlock["metadata"]),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func vaultKVWrite(request vaultKVRequest) (VaultKVResult, error) {
|
||||
if len(request.data) == 0 {
|
||||
return VaultKVResult{}, errors.New("data is required for write operations")
|
||||
}
|
||||
payload := map[string]any{"data": request.data}
|
||||
if request.cas > 0 {
|
||||
payload["options"] = map[string]any{"cas": request.cas}
|
||||
}
|
||||
response, err := doVaultRequest(
|
||||
request,
|
||||
http.MethodPost,
|
||||
vaultDataURL(request.mount, request.path),
|
||||
payload,
|
||||
)
|
||||
if err != nil {
|
||||
return VaultKVResult{}, err
|
||||
}
|
||||
return VaultKVResult{
|
||||
Operation: "write",
|
||||
Mount: request.mount,
|
||||
Path: request.path,
|
||||
Data: request.data,
|
||||
Metadata: mapArg(mapArg(response["data"])["metadata"]),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func vaultKVList(request vaultKVRequest) (VaultKVResult, error) {
|
||||
response, err := doVaultRequest(
|
||||
request,
|
||||
"LIST",
|
||||
vaultMetadataURL(request.mount, request.path),
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
return VaultKVResult{}, err
|
||||
}
|
||||
dataBlock := mapArg(response["data"])
|
||||
return VaultKVResult{
|
||||
Operation: "list",
|
||||
Mount: request.mount,
|
||||
Path: request.path,
|
||||
Keys: stringSliceArg(dataBlock["keys"]),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func vaultKVDelete(request vaultKVRequest) (VaultKVResult, error) {
|
||||
_, err := doVaultRequest(
|
||||
request,
|
||||
http.MethodDelete,
|
||||
vaultDataURL(request.mount, request.path),
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
return VaultKVResult{}, err
|
||||
}
|
||||
return VaultKVResult{
|
||||
Operation: "delete",
|
||||
Mount: request.mount,
|
||||
Path: request.path,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func doVaultRequest(
|
||||
request vaultKVRequest,
|
||||
method string,
|
||||
target string,
|
||||
payload map[string]any,
|
||||
) (map[string]any, error) {
|
||||
var body io.Reader
|
||||
if payload != nil {
|
||||
encoded, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
body = bytes.NewReader(encoded)
|
||||
}
|
||||
httpRequest, err := http.NewRequest(method, target, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpRequest.Header.Set("X-Vault-Token", request.token)
|
||||
if request.namespace != "" {
|
||||
httpRequest.Header.Set("X-Vault-Namespace", request.namespace)
|
||||
}
|
||||
if payload != nil {
|
||||
httpRequest.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
response, err := client.Do(httpRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
bodyBytes, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if response.StatusCode < 200 || response.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf(
|
||||
"vault api error %d: %s",
|
||||
response.StatusCode,
|
||||
strings.TrimSpace(string(bodyBytes)),
|
||||
)
|
||||
}
|
||||
if len(strings.TrimSpace(string(bodyBytes))) == 0 {
|
||||
return map[string]any{}, nil
|
||||
}
|
||||
var decoded map[string]any
|
||||
if err := json.Unmarshal(bodyBytes, &decoded); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
func vaultDataURL(mount, path string) string {
|
||||
return fmt.Sprintf("%s/data/%s", vaultBasePath(mount), vaultPathSegments(path))
|
||||
}
|
||||
|
||||
func vaultMetadataURL(mount, path string) string {
|
||||
return fmt.Sprintf("%s/metadata/%s", vaultBasePath(mount), vaultPathSegments(path))
|
||||
}
|
||||
|
||||
func vaultBasePath(mount string) string {
|
||||
return fmt.Sprintf("%s/v1/%s", strings.TrimRight(strings.TrimSpace(EnvOrDefault("VAULT_SERVER_URL", "")), "/"), url.PathEscape(normalizeVaultMount(mount)))
|
||||
}
|
||||
|
||||
func vaultPathSegments(path string) string {
|
||||
segments := strings.Split(normalizeVaultPath(path), "/")
|
||||
for index, segment := range segments {
|
||||
segments[index] = url.PathEscape(segment)
|
||||
}
|
||||
return strings.Join(segments, "/")
|
||||
}
|
||||
|
||||
func normalizeVaultMount(raw string) string {
|
||||
trimmed := strings.Trim(strings.TrimSpace(raw), "/")
|
||||
if trimmed == "" {
|
||||
return "secret"
|
||||
}
|
||||
return trimmed
|
||||
}
|
||||
|
||||
func normalizeVaultPath(raw string) string {
|
||||
return strings.Trim(strings.TrimSpace(raw), "/")
|
||||
}
|
||||
|
||||
func vaultDataArg(raw any) (map[string]any, error) {
|
||||
if raw == nil {
|
||||
return nil, nil
|
||||
}
|
||||
switch typed := raw.(type) {
|
||||
case map[string]any:
|
||||
return typed, nil
|
||||
case string:
|
||||
trimmed := strings.TrimSpace(typed)
|
||||
if trimmed == "" {
|
||||
return nil, nil
|
||||
}
|
||||
var decoded map[string]any
|
||||
if err := json.Unmarshal([]byte(trimmed), &decoded); err != nil {
|
||||
return nil, errors.New("data must be a JSON object")
|
||||
}
|
||||
return decoded, nil
|
||||
default:
|
||||
return nil, errors.New("data must be an object")
|
||||
}
|
||||
}
|
||||
|
||||
func vaultCASArg(raw any) int {
|
||||
switch typed := raw.(type) {
|
||||
case int:
|
||||
return typed
|
||||
case int64:
|
||||
return int(typed)
|
||||
case float64:
|
||||
return int(typed)
|
||||
case string:
|
||||
return IntArg(typed, 0)
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func mapArg(raw any) map[string]any {
|
||||
switch typed := raw.(type) {
|
||||
case map[string]any:
|
||||
return typed
|
||||
default:
|
||||
return map[string]any{}
|
||||
}
|
||||
}
|
||||
|
||||
func stringSliceArg(raw any) []string {
|
||||
values, ok := raw.([]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
result := make([]string, 0, len(values))
|
||||
for _, value := range values {
|
||||
text := strings.TrimSpace(fmt.Sprint(value))
|
||||
if text == "" || text == "<nil>" {
|
||||
continue
|
||||
}
|
||||
result = append(result, text)
|
||||
}
|
||||
return result
|
||||
}
|
||||
@ -1,142 +0,0 @@
|
||||
package shared
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHandleVaultKVToolReadsSecretData(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
t.Fatalf("unexpected method: %s", r.Method)
|
||||
}
|
||||
if got := r.Header.Get("X-Vault-Token"); got != "root-token" {
|
||||
t.Fatalf("unexpected token header: %s", got)
|
||||
}
|
||||
if got := r.Header.Get("X-Vault-Namespace"); got != "platform/team-a" {
|
||||
t.Fatalf("unexpected namespace header: %s", got)
|
||||
}
|
||||
if got := r.URL.Path; got != "/v1/secret/data/apps/demo" {
|
||||
t.Fatalf("unexpected request path: %s", got)
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"data": map[string]any{
|
||||
"data": map[string]any{
|
||||
"api_key": "demo-key",
|
||||
},
|
||||
"metadata": map[string]any{
|
||||
"version": 3,
|
||||
},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
t.Setenv("VAULT_SERVER_URL", server.URL)
|
||||
t.Setenv("VAULT_SERVER_ROOT_ACCESS_TOKEN", "root-token")
|
||||
t.Setenv("VAULT_NAMESPACE", "platform/team-a")
|
||||
|
||||
output, err := HandleVaultKVTool(map[string]any{
|
||||
"operation": "read",
|
||||
"path": "apps/demo",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("HandleVaultKVTool returned error: %v", err)
|
||||
}
|
||||
if !strings.Contains(output, `"api_key": "demo-key"`) {
|
||||
t.Fatalf("expected secret data in output, got %s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleVaultKVToolWritesSecretData(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Fatalf("unexpected method: %s", r.Method)
|
||||
}
|
||||
if got := r.URL.Path; got != "/v1/secret/data/apps/demo" {
|
||||
t.Fatalf("unexpected request path: %s", got)
|
||||
}
|
||||
var payload map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
t.Fatalf("decode payload: %v", err)
|
||||
}
|
||||
data := mapArg(payload["data"])
|
||||
if got := data["enabled"]; got != true {
|
||||
t.Fatalf("unexpected data payload: %v", payload)
|
||||
}
|
||||
options := mapArg(payload["options"])
|
||||
if got := options["cas"]; got != float64(2) {
|
||||
t.Fatalf("unexpected cas payload: %v", payload)
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"data": map[string]any{
|
||||
"metadata": map[string]any{
|
||||
"version": 4,
|
||||
},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
t.Setenv("VAULT_SERVER_URL", server.URL)
|
||||
t.Setenv("VAULT_SERVER_ROOT_ACCESS_TOKEN", "root-token")
|
||||
|
||||
output, err := HandleVaultKVTool(map[string]any{
|
||||
"operation": "write",
|
||||
"path": "apps/demo",
|
||||
"data": map[string]any{
|
||||
"enabled": true,
|
||||
},
|
||||
"cas": 2,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("HandleVaultKVTool returned error: %v", err)
|
||||
}
|
||||
if !strings.Contains(output, `"version": 4`) {
|
||||
t.Fatalf("expected metadata in output, got %s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleVaultKVToolListsSecretKeys(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "LIST" {
|
||||
t.Fatalf("unexpected method: %s", r.Method)
|
||||
}
|
||||
if got := r.URL.Path; got != "/v1/secret/metadata/apps" {
|
||||
t.Fatalf("unexpected request path: %s", got)
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"data": map[string]any{
|
||||
"keys": []string{"demo", "prod"},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
t.Setenv("VAULT_SERVER_URL", server.URL)
|
||||
t.Setenv("VAULT_SERVER_ROOT_ACCESS_TOKEN", "root-token")
|
||||
|
||||
output, err := HandleVaultKVTool(map[string]any{
|
||||
"operation": "list",
|
||||
"path": "apps",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("HandleVaultKVTool returned error: %v", err)
|
||||
}
|
||||
if !strings.Contains(output, `"demo"`) || !strings.Contains(output, `"prod"`) {
|
||||
t.Fatalf("expected listed keys in output, got %s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleVaultKVToolRequiresEnvironment(t *testing.T) {
|
||||
_, err := HandleVaultKVTool(map[string]any{
|
||||
"operation": "read",
|
||||
"path": "apps/demo",
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "VAULT_SERVER_URL") {
|
||||
t.Fatalf("expected missing environment error, got %v", err)
|
||||
}
|
||||
}
|
||||
@ -1,209 +0,0 @@
|
||||
package skills
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"xworkmate/go_core/internal/shared"
|
||||
)
|
||||
|
||||
type ChainFinder struct {
|
||||
Primary Finder
|
||||
Fallback Finder
|
||||
}
|
||||
|
||||
func (f ChainFinder) Find(prompt string) []Candidate {
|
||||
if f.Primary != nil {
|
||||
if resolved := dedupeCandidates(f.Primary.Find(prompt)); len(resolved) > 0 {
|
||||
return resolved
|
||||
}
|
||||
}
|
||||
if f.Fallback == nil {
|
||||
return nil
|
||||
}
|
||||
return dedupeCandidates(f.Fallback.Find(prompt))
|
||||
}
|
||||
|
||||
type CommandFinder struct {
|
||||
Binary string
|
||||
}
|
||||
|
||||
func (f CommandFinder) Find(prompt string) []Candidate {
|
||||
payload, ok := runSkillCommand(
|
||||
strings.TrimSpace(f.Binary),
|
||||
map[string]any{"prompt": strings.TrimSpace(prompt)},
|
||||
)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return parseCandidatesPayload(payload)
|
||||
}
|
||||
|
||||
type CommandInstaller struct {
|
||||
Binary string
|
||||
}
|
||||
|
||||
func (i CommandInstaller) Install(candidates []Candidate) ([]Candidate, error) {
|
||||
payload, ok := runSkillCommand(
|
||||
strings.TrimSpace(i.Binary),
|
||||
map[string]any{
|
||||
"candidates": routingCandidatesPayload(candidates),
|
||||
},
|
||||
)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
return parseCandidatesPayload(payload), nil
|
||||
}
|
||||
|
||||
func NewDefaultFinder() Finder {
|
||||
return ChainFinder{
|
||||
Primary: CommandFinder{
|
||||
Binary: strings.TrimSpace(shared.EnvOrDefault("ACP_FIND_SKILLS_BIN", "")),
|
||||
},
|
||||
Fallback: StaticFinder{},
|
||||
}
|
||||
}
|
||||
|
||||
func NewDefaultInstaller() Installer {
|
||||
return CommandInstaller{
|
||||
Binary: strings.TrimSpace(shared.EnvOrDefault("ACP_INSTALL_SKILL_BIN", "")),
|
||||
}
|
||||
}
|
||||
|
||||
func runSkillCommand(binary string, payload map[string]any) (map[string]any, bool) {
|
||||
if binary == "" {
|
||||
return nil, false
|
||||
}
|
||||
if _, err := exec.LookPath(binary); err != nil {
|
||||
return nil, false
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 45*time.Second)
|
||||
defer cancel()
|
||||
cmd := exec.CommandContext(ctx, binary)
|
||||
cmd.Stdin = strings.NewReader(string(body))
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
var decoded map[string]any
|
||||
if err := json.Unmarshal(output, &decoded); err == nil {
|
||||
return decoded, true
|
||||
}
|
||||
var list []map[string]any
|
||||
if err := json.Unmarshal(output, &list); err == nil {
|
||||
return map[string]any{"candidates": list}, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func parseCandidatesPayload(payload map[string]any) []Candidate {
|
||||
if len(payload) == 0 {
|
||||
return nil
|
||||
}
|
||||
if raw, ok := payload["candidates"]; ok {
|
||||
return parseCandidates(raw)
|
||||
}
|
||||
if raw, ok := payload["skills"]; ok {
|
||||
return parseCandidates(raw)
|
||||
}
|
||||
return parseCandidates(payload)
|
||||
}
|
||||
|
||||
func parseCandidates(raw any) []Candidate {
|
||||
switch typed := raw.(type) {
|
||||
case []any:
|
||||
result := make([]Candidate, 0, len(typed))
|
||||
for _, item := range typed {
|
||||
entry := toMap(item)
|
||||
if len(entry) == 0 {
|
||||
continue
|
||||
}
|
||||
result = append(result, Candidate{
|
||||
ID: strings.TrimSpace(stringValue(entry["id"])),
|
||||
Label: strings.TrimSpace(stringValue(entry["label"])),
|
||||
Description: strings.TrimSpace(stringValue(entry["description"])),
|
||||
Installed: boolValue(entry["installed"]),
|
||||
})
|
||||
}
|
||||
return dedupeCandidates(result)
|
||||
case []map[string]any:
|
||||
values := make([]any, 0, len(typed))
|
||||
for _, item := range typed {
|
||||
values = append(values, item)
|
||||
}
|
||||
return parseCandidates(values)
|
||||
case map[string]any:
|
||||
entry := Candidate{
|
||||
ID: strings.TrimSpace(stringValue(typed["id"])),
|
||||
Label: strings.TrimSpace(stringValue(typed["label"])),
|
||||
Description: strings.TrimSpace(stringValue(typed["description"])),
|
||||
Installed: boolValue(typed["installed"]),
|
||||
}
|
||||
if entry.ID == "" && entry.Label == "" {
|
||||
return nil
|
||||
}
|
||||
return []Candidate{entry}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func routingCandidatesPayload(candidates []Candidate) []map[string]any {
|
||||
result := make([]map[string]any, 0, len(candidates))
|
||||
for _, candidate := range candidates {
|
||||
result = append(result, map[string]any{
|
||||
"id": strings.TrimSpace(candidate.ID),
|
||||
"label": strings.TrimSpace(candidate.Label),
|
||||
"description": strings.TrimSpace(candidate.Description),
|
||||
"installed": candidate.Installed,
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func toMap(value any) map[string]any {
|
||||
if typed, ok := value.(map[string]any); ok {
|
||||
return typed
|
||||
}
|
||||
if typed, ok := value.(map[string]interface{}); ok {
|
||||
return typed
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func stringValue(value any) string {
|
||||
if value == nil {
|
||||
return ""
|
||||
}
|
||||
switch typed := value.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(typed)
|
||||
default:
|
||||
return strings.TrimSpace(fmt.Sprint(value))
|
||||
}
|
||||
}
|
||||
|
||||
func boolValue(value any) bool {
|
||||
switch typed := value.(type) {
|
||||
case bool:
|
||||
return typed
|
||||
case string:
|
||||
normalized := strings.ToLower(strings.TrimSpace(typed))
|
||||
return normalized == "true" || normalized == "1" || normalized == "yes"
|
||||
case float64:
|
||||
return typed != 0
|
||||
case int:
|
||||
return typed != 0
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
@ -1,353 +0,0 @@
|
||||
package skills
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Candidate struct {
|
||||
ID string
|
||||
Label string
|
||||
Description string
|
||||
Installed bool
|
||||
}
|
||||
|
||||
type Finder interface {
|
||||
Find(prompt string) []Candidate
|
||||
}
|
||||
|
||||
type Installer interface {
|
||||
Install(candidates []Candidate) ([]Candidate, error)
|
||||
}
|
||||
|
||||
type ResolveRequest struct {
|
||||
Prompt string
|
||||
ExplicitSkills []string
|
||||
AvailableSkills []Candidate
|
||||
AllowSkillInstall bool
|
||||
InstallApproval InstallApproval
|
||||
}
|
||||
|
||||
type InstallApproval struct {
|
||||
RequestID string
|
||||
ApprovedSkillKeys []string
|
||||
}
|
||||
|
||||
type ResolveResult struct {
|
||||
ResolvedSkills []string
|
||||
Candidates []Candidate
|
||||
Source string
|
||||
NeedsInstall bool
|
||||
InstallRequestID string
|
||||
}
|
||||
|
||||
type StaticFinder struct{}
|
||||
|
||||
func (StaticFinder) Find(prompt string) []Candidate {
|
||||
haystack := normalize(prompt)
|
||||
candidates := make([]Candidate, 0, 4)
|
||||
for _, entry := range builtinCatalog {
|
||||
if !containsAny(haystack, entry.keywords) {
|
||||
continue
|
||||
}
|
||||
candidates = append(candidates, Candidate{
|
||||
ID: entry.id,
|
||||
Label: entry.label,
|
||||
Installed: false,
|
||||
})
|
||||
}
|
||||
return dedupeCandidates(candidates)
|
||||
}
|
||||
|
||||
func Resolve(req ResolveRequest, finder Finder, installer Installer) ResolveResult {
|
||||
available := dedupeCandidates(req.AvailableSkills)
|
||||
explicit := normalizeList(req.ExplicitSkills)
|
||||
if len(explicit) > 0 {
|
||||
return ResolveResult{
|
||||
ResolvedSkills: explicit,
|
||||
Source: "local_match",
|
||||
}
|
||||
}
|
||||
|
||||
localMatches := matchLocalSkills(req.Prompt, available)
|
||||
if len(localMatches) > 0 {
|
||||
return ResolveResult{
|
||||
ResolvedSkills: localMatches,
|
||||
Source: "local_match",
|
||||
}
|
||||
}
|
||||
|
||||
if finder == nil {
|
||||
return ResolveResult{Source: "none"}
|
||||
}
|
||||
|
||||
fallback := dedupeCandidates(finder.Find(req.Prompt))
|
||||
if len(fallback) == 0 {
|
||||
return ResolveResult{Source: "none"}
|
||||
}
|
||||
|
||||
installed := make([]string, 0, len(fallback))
|
||||
uninstalled := make([]Candidate, 0, len(fallback))
|
||||
for _, candidate := range fallback {
|
||||
if matched := findInstalledMatch(candidate, available); matched != "" {
|
||||
installed = append(installed, matched)
|
||||
continue
|
||||
}
|
||||
uninstalled = append(uninstalled, candidate)
|
||||
}
|
||||
|
||||
if len(installed) > 0 {
|
||||
return ResolveResult{
|
||||
ResolvedSkills: dedupeStrings(installed),
|
||||
Candidates: fallback,
|
||||
Source: "find_skills",
|
||||
}
|
||||
}
|
||||
|
||||
installRequestID := buildInstallRequestID(uninstalled)
|
||||
if shouldInstallApprovedCandidates(req, installRequestID) &&
|
||||
installer != nil &&
|
||||
len(uninstalled) > 0 {
|
||||
approvedCandidates := filterApprovedCandidates(
|
||||
uninstalled,
|
||||
req.InstallApproval.ApprovedSkillKeys,
|
||||
)
|
||||
if len(approvedCandidates) == 0 {
|
||||
return ResolveResult{
|
||||
Candidates: fallback,
|
||||
Source: "find_skills",
|
||||
NeedsInstall: true,
|
||||
InstallRequestID: installRequestID,
|
||||
}
|
||||
}
|
||||
installedCandidates, err := installer.Install(approvedCandidates)
|
||||
if err == nil && len(installedCandidates) > 0 {
|
||||
mergedAvailable := dedupeCandidates(
|
||||
append(append([]Candidate(nil), available...), installedCandidates...),
|
||||
)
|
||||
if resolved := installedMatches(fallback, mergedAvailable); len(resolved) > 0 {
|
||||
return ResolveResult{
|
||||
ResolvedSkills: resolved,
|
||||
Candidates: dedupeCandidates(
|
||||
append(append([]Candidate(nil), fallback...), installedCandidates...),
|
||||
),
|
||||
Source: "find_skills",
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ResolveResult{
|
||||
Candidates: fallback,
|
||||
Source: "find_skills",
|
||||
NeedsInstall: len(uninstalled) > 0,
|
||||
InstallRequestID: installRequestID,
|
||||
}
|
||||
}
|
||||
|
||||
func shouldInstallApprovedCandidates(
|
||||
req ResolveRequest,
|
||||
expectedRequestID string,
|
||||
) bool {
|
||||
if !req.AllowSkillInstall || expectedRequestID == "" {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(req.InstallApproval.RequestID) != expectedRequestID {
|
||||
return false
|
||||
}
|
||||
return len(dedupeStrings(req.InstallApproval.ApprovedSkillKeys)) > 0
|
||||
}
|
||||
|
||||
func filterApprovedCandidates(
|
||||
candidates []Candidate,
|
||||
approvedSkillKeys []string,
|
||||
) []Candidate {
|
||||
if len(candidates) == 0 {
|
||||
return nil
|
||||
}
|
||||
approved := make(map[string]struct{}, len(approvedSkillKeys))
|
||||
for _, key := range approvedSkillKeys {
|
||||
normalized := normalize(key)
|
||||
if normalized == "" {
|
||||
continue
|
||||
}
|
||||
approved[normalized] = struct{}{}
|
||||
}
|
||||
filtered := make([]Candidate, 0, len(candidates))
|
||||
for _, candidate := range candidates {
|
||||
if _, ok := approved[normalize(candidate.ID)]; ok {
|
||||
filtered = append(filtered, candidate)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func buildInstallRequestID(candidates []Candidate) string {
|
||||
if len(candidates) == 0 {
|
||||
return ""
|
||||
}
|
||||
keys := make([]string, 0, len(candidates))
|
||||
for _, candidate := range candidates {
|
||||
key := normalize(candidate.ID)
|
||||
if key == "" {
|
||||
key = normalize(candidate.Label)
|
||||
}
|
||||
if key != "" {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
}
|
||||
if len(keys) == 0 {
|
||||
return ""
|
||||
}
|
||||
sort.Strings(keys)
|
||||
return "skill-install:" + strings.Join(keys, ",")
|
||||
}
|
||||
|
||||
type builtinSkill struct {
|
||||
id string
|
||||
label string
|
||||
keywords []string
|
||||
}
|
||||
|
||||
var builtinCatalog = []builtinSkill{
|
||||
{id: "pptx", label: "pptx", keywords: []string{"ppt", "pptx", "powerpoint", "slides", "幻灯片", "演示文稿"}},
|
||||
{id: "docx", label: "docx", keywords: []string{"docx", "word", "word document", "文档"}},
|
||||
{id: "xlsx", label: "xlsx", keywords: []string{"xlsx", "excel", "spreadsheet", "表格", "工作表"}},
|
||||
{id: "pdf", label: "pdf", keywords: []string{"pdf", "表单", "merge pdf", "split pdf"}},
|
||||
{id: "image-resizer", label: "image-resizer", keywords: []string{"image-resizer", "resize image", "compress image", "crop image", "批量图片"}},
|
||||
{id: "image-cog", label: "image-cog", keywords: []string{"image-cog", "文生图", "图生图", "角色一致性"}},
|
||||
{id: "image-video-generation-editting", label: "image-video-generation-editting", keywords: []string{"wan", "文生视频", "图生视频", "视频生成", "视频编辑"}},
|
||||
{id: "video-translator", label: "video-translator", keywords: []string{"video-translator", "视频翻译", "配音", "字幕翻译", "translate video", "dub video", "subtitles"}},
|
||||
{id: "browser-automation", label: "Browser Automation", keywords: []string{"browser", "跨浏览器", "浏览器", "web scraping", "资讯采集", "search", "搜索", "news", "资讯"}},
|
||||
{id: "find-skills", label: "find_skills", keywords: []string{"find skills", "find_skills", "技能包", "skill package"}},
|
||||
}
|
||||
|
||||
func matchLocalSkills(prompt string, available []Candidate) []string {
|
||||
if len(available) == 0 {
|
||||
return nil
|
||||
}
|
||||
haystack := normalize(prompt)
|
||||
if haystack == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
matches := make([]string, 0, len(available))
|
||||
for _, candidate := range available {
|
||||
keywords := candidateKeywords(candidate)
|
||||
if containsAny(haystack, keywords) {
|
||||
matches = append(matches, candidateLabel(candidate))
|
||||
}
|
||||
}
|
||||
return dedupeStrings(matches)
|
||||
}
|
||||
|
||||
func candidateKeywords(candidate Candidate) []string {
|
||||
base := []string{
|
||||
normalize(candidate.ID),
|
||||
normalize(candidate.Label),
|
||||
}
|
||||
text := normalize(strings.Join([]string{candidate.ID, candidate.Label}, " "))
|
||||
for _, entry := range builtinCatalog {
|
||||
if containsAny(text, []string{normalize(entry.id), normalize(entry.label)}) {
|
||||
base = append(base, entry.keywords...)
|
||||
}
|
||||
}
|
||||
return dedupeStrings(base)
|
||||
}
|
||||
|
||||
func findInstalledMatch(candidate Candidate, available []Candidate) string {
|
||||
want := candidateKeywords(candidate)
|
||||
for _, item := range available {
|
||||
if containsAny(strings.Join(candidateKeywords(item), " "), want) {
|
||||
return candidateLabel(item)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func installedMatches(candidates []Candidate, available []Candidate) []string {
|
||||
resolved := make([]string, 0, len(candidates))
|
||||
for _, candidate := range candidates {
|
||||
if matched := findInstalledMatch(candidate, available); matched != "" {
|
||||
resolved = append(resolved, matched)
|
||||
}
|
||||
}
|
||||
return dedupeStrings(resolved)
|
||||
}
|
||||
|
||||
func candidateLabel(candidate Candidate) string {
|
||||
if strings.TrimSpace(candidate.Label) != "" {
|
||||
return strings.TrimSpace(candidate.Label)
|
||||
}
|
||||
return strings.TrimSpace(candidate.ID)
|
||||
}
|
||||
|
||||
func containsAny(haystack string, needles []string) bool {
|
||||
for _, needle := range needles {
|
||||
if strings.TrimSpace(needle) == "" {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(haystack, normalize(needle)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func normalize(value string) string {
|
||||
return strings.ToLower(strings.TrimSpace(value))
|
||||
}
|
||||
|
||||
func normalizeList(values []string) []string {
|
||||
result := make([]string, 0, len(values))
|
||||
for _, value := range values {
|
||||
trimmed := strings.TrimSpace(value)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
result = append(result, trimmed)
|
||||
}
|
||||
return dedupeStrings(result)
|
||||
}
|
||||
|
||||
func dedupeStrings(values []string) []string {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
seen := make(map[string]string, len(values))
|
||||
ordered := make([]string, 0, len(values))
|
||||
for _, value := range values {
|
||||
trimmed := strings.TrimSpace(value)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
key := normalize(trimmed)
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[key] = trimmed
|
||||
ordered = append(ordered, trimmed)
|
||||
}
|
||||
return ordered
|
||||
}
|
||||
|
||||
func dedupeCandidates(values []Candidate) []Candidate {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
seen := make(map[string]struct{}, len(values))
|
||||
ordered := make([]Candidate, 0, len(values))
|
||||
for _, candidate := range values {
|
||||
key := normalize(fmt.Sprintf("%s|%s", candidate.ID, candidate.Label))
|
||||
if key == "|" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
ordered = append(ordered, candidate)
|
||||
}
|
||||
return ordered
|
||||
}
|
||||
@ -1,127 +0,0 @@
|
||||
package skills
|
||||
|
||||
import "testing"
|
||||
|
||||
type fakeFinder []Candidate
|
||||
|
||||
func (f fakeFinder) Find(prompt string) []Candidate {
|
||||
return append([]Candidate(nil), f...)
|
||||
}
|
||||
|
||||
type fakeInstaller struct {
|
||||
installed []Candidate
|
||||
}
|
||||
|
||||
func (f fakeInstaller) Install(candidates []Candidate) ([]Candidate, error) {
|
||||
return append([]Candidate(nil), f.installed...), nil
|
||||
}
|
||||
|
||||
func TestResolvePrefersExplicitSkills(t *testing.T) {
|
||||
result := Resolve(ResolveRequest{
|
||||
Prompt: "make a deck",
|
||||
ExplicitSkills: []string{"pptx"},
|
||||
AvailableSkills: []Candidate{
|
||||
{ID: "pptx", Label: "pptx", Installed: true},
|
||||
},
|
||||
}, StaticFinder{}, nil)
|
||||
|
||||
if result.Source != "local_match" {
|
||||
t.Fatalf("expected local_match source, got %q", result.Source)
|
||||
}
|
||||
if len(result.ResolvedSkills) != 1 || result.ResolvedSkills[0] != "pptx" {
|
||||
t.Fatalf("unexpected resolved skills: %#v", result.ResolvedSkills)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveUsesInstalledLocalMatchesBeforeFallback(t *testing.T) {
|
||||
result := Resolve(ResolveRequest{
|
||||
Prompt: "create a PowerPoint presentation from this brief",
|
||||
AvailableSkills: []Candidate{
|
||||
{ID: "pptx", Label: "PPTX", Installed: true},
|
||||
{ID: "docx", Label: "DOCX", Installed: true},
|
||||
},
|
||||
}, StaticFinder{}, nil)
|
||||
|
||||
if result.Source != "local_match" {
|
||||
t.Fatalf("expected local_match source, got %q", result.Source)
|
||||
}
|
||||
if len(result.ResolvedSkills) != 1 || result.ResolvedSkills[0] != "PPTX" {
|
||||
t.Fatalf("unexpected resolved skills: %#v", result.ResolvedSkills)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveFallsBackToFindSkillsCandidates(t *testing.T) {
|
||||
result := Resolve(ResolveRequest{
|
||||
Prompt: "translate and dub this video with subtitles",
|
||||
AvailableSkills: []Candidate{{ID: "docx", Label: "docx", Installed: true}},
|
||||
AllowSkillInstall: false,
|
||||
}, StaticFinder{}, nil)
|
||||
|
||||
if result.Source != "find_skills" {
|
||||
t.Fatalf("expected find_skills source, got %q", result.Source)
|
||||
}
|
||||
if len(result.ResolvedSkills) != 0 {
|
||||
t.Fatalf("expected no installed resolved skills, got %#v", result.ResolvedSkills)
|
||||
}
|
||||
if !result.NeedsInstall {
|
||||
t.Fatalf("expected install recommendation")
|
||||
}
|
||||
if len(result.Candidates) == 0 || result.Candidates[0].ID != "video-translator" {
|
||||
t.Fatalf("unexpected fallback candidates: %#v", result.Candidates)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveInstallsMissingSkillsWhenAuthorized(t *testing.T) {
|
||||
initial := Resolve(
|
||||
ResolveRequest{
|
||||
Prompt: "translate and dub this video with subtitles",
|
||||
AvailableSkills: []Candidate{{ID: "docx", Label: "docx", Installed: true}},
|
||||
AllowSkillInstall: true,
|
||||
},
|
||||
fakeFinder{
|
||||
{ID: "video-translator", Label: "video-translator", Installed: false},
|
||||
},
|
||||
fakeInstaller{
|
||||
installed: []Candidate{
|
||||
{ID: "video-translator", Label: "video-translator", Installed: true},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
if !initial.NeedsInstall {
|
||||
t.Fatalf("expected install approval flow to pause first, got %#v", initial)
|
||||
}
|
||||
if initial.InstallRequestID == "" {
|
||||
t.Fatalf("expected install request id, got %#v", initial)
|
||||
}
|
||||
|
||||
result := Resolve(
|
||||
ResolveRequest{
|
||||
Prompt: "translate and dub this video with subtitles",
|
||||
AvailableSkills: []Candidate{{ID: "docx", Label: "docx", Installed: true}},
|
||||
AllowSkillInstall: true,
|
||||
InstallApproval: InstallApproval{
|
||||
RequestID: initial.InstallRequestID,
|
||||
ApprovedSkillKeys: []string{"video-translator"},
|
||||
},
|
||||
},
|
||||
fakeFinder{
|
||||
{ID: "video-translator", Label: "video-translator", Installed: false},
|
||||
},
|
||||
fakeInstaller{
|
||||
installed: []Candidate{
|
||||
{ID: "video-translator", Label: "video-translator", Installed: true},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
if result.Source != "find_skills" {
|
||||
t.Fatalf("expected find_skills source, got %q", result.Source)
|
||||
}
|
||||
if result.NeedsInstall {
|
||||
t.Fatalf("expected install retry to resolve the skill, got %#v", result)
|
||||
}
|
||||
if len(result.ResolvedSkills) != 1 || result.ResolvedSkills[0] != "video-translator" {
|
||||
t.Fatalf("unexpected resolved skills after install: %#v", result.ResolvedSkills)
|
||||
}
|
||||
}
|
||||
@ -1,194 +0,0 @@
|
||||
package toolbridge
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"xworkmate/go_core/internal/shared"
|
||||
)
|
||||
|
||||
func Run(input io.Reader, output io.Writer) {
|
||||
reader := bufio.NewReader(input)
|
||||
for {
|
||||
payload, err := readMessage(reader)
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return
|
||||
}
|
||||
writeError(output, nil, -32700, err.Error())
|
||||
continue
|
||||
}
|
||||
if len(strings.TrimSpace(string(payload))) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
request, err := shared.DecodeRPCRequest(payload)
|
||||
if err != nil {
|
||||
writeError(output, nil, -32700, err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
response := handleRequest(request)
|
||||
if response != nil {
|
||||
writeMessage(output, response)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func readMessage(reader *bufio.Reader) ([]byte, error) {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
return nil, nil
|
||||
}
|
||||
if strings.HasPrefix(strings.ToLower(line), "content-length:") {
|
||||
var contentLength int
|
||||
if _, err := fmt.Sscanf(line, "Content-Length: %d", &contentLength); err != nil {
|
||||
if _, err2 := fmt.Sscanf(line, "content-length: %d", &contentLength); err2 != nil {
|
||||
return nil, fmt.Errorf("invalid content-length header")
|
||||
}
|
||||
}
|
||||
for {
|
||||
headerLine, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(headerLine) == "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
body := make([]byte, contentLength)
|
||||
if _, err := io.ReadFull(reader, body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
return []byte(line), nil
|
||||
}
|
||||
|
||||
func writeMessage(output io.Writer, message map[string]any) {
|
||||
payload, _ := json.Marshal(message)
|
||||
_, _ = output.Write(append(payload, '\n'))
|
||||
}
|
||||
|
||||
func writeError(output io.Writer, id any, code int, message string) {
|
||||
writeMessage(output, shared.ErrorEnvelope(id, code, message))
|
||||
}
|
||||
|
||||
func handleRequest(request shared.RPCRequest) map[string]any {
|
||||
if request.ID == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch request.Method {
|
||||
case "initialize":
|
||||
return shared.ResultEnvelope(request.ID, map[string]any{
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": map[string]any{
|
||||
"tools": map[string]any{},
|
||||
},
|
||||
"serverInfo": map[string]any{
|
||||
"name": "xworkmate-go-core",
|
||||
"version": "0.2.0",
|
||||
},
|
||||
})
|
||||
case "ping":
|
||||
return shared.ResultEnvelope(request.ID, map[string]any{})
|
||||
case "tools/list":
|
||||
return shared.ResultEnvelope(request.ID, map[string]any{
|
||||
"tools": []map[string]any{
|
||||
{
|
||||
"name": "chat",
|
||||
"description": "OpenAI-compatible reviewer chat bridge",
|
||||
"inputSchema": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"prompt": map[string]any{"type": "string"},
|
||||
"model": map[string]any{"type": "string"},
|
||||
"system": map[string]any{"type": "string"},
|
||||
},
|
||||
"required": []string{"prompt"},
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "claude_review",
|
||||
"description": "Review-only bridge over Claude CLI",
|
||||
"inputSchema": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"prompt": map[string]any{"type": "string"},
|
||||
"model": map[string]any{"type": "string"},
|
||||
"system": map[string]any{"type": "string"},
|
||||
"tools": map[string]any{"type": "string"},
|
||||
},
|
||||
"required": []string{"prompt"},
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "vault_kv",
|
||||
"description": "HashiCorp Vault K/V v2 bridge",
|
||||
"inputSchema": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"operation": map[string]any{"type": "string"},
|
||||
"mount": map[string]any{"type": "string"},
|
||||
"path": map[string]any{"type": "string"},
|
||||
"data": map[string]any{"type": "object"},
|
||||
"cas": map[string]any{"type": "number"},
|
||||
},
|
||||
"required": []string{"operation", "path"},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
case "tools/call":
|
||||
var params shared.ToolCallParams
|
||||
raw, _ := json.Marshal(request.Params)
|
||||
if err := json.Unmarshal(raw, ¶ms); err != nil {
|
||||
return shared.ErrorResponse(
|
||||
request.ID,
|
||||
-32602,
|
||||
fmt.Sprintf("invalid tool params: %v", err),
|
||||
)
|
||||
}
|
||||
switch params.Name {
|
||||
case "chat":
|
||||
content, err := shared.HandleChatTool(params.Arguments)
|
||||
if err != nil {
|
||||
return shared.ToolErrorResult(request.ID, err)
|
||||
}
|
||||
return shared.ToolTextResult(request.ID, content)
|
||||
case "claude_review":
|
||||
content, err := shared.HandleClaudeReviewTool(params.Arguments)
|
||||
if err != nil {
|
||||
return shared.ToolErrorResult(request.ID, err)
|
||||
}
|
||||
return shared.ToolTextResult(request.ID, content)
|
||||
case "vault_kv":
|
||||
content, err := shared.HandleVaultKVTool(params.Arguments)
|
||||
if err != nil {
|
||||
return shared.ToolErrorResult(request.ID, err)
|
||||
}
|
||||
return shared.ToolTextResult(request.ID, content)
|
||||
default:
|
||||
return shared.ErrorResponse(
|
||||
request.ID,
|
||||
-32601,
|
||||
fmt.Sprintf("unknown tool: %s", params.Name),
|
||||
)
|
||||
}
|
||||
default:
|
||||
return shared.ErrorResponse(
|
||||
request.ID,
|
||||
-32601,
|
||||
fmt.Sprintf("unknown method: %s", request.Method),
|
||||
)
|
||||
}
|
||||
}
|
||||
@ -1,80 +0,0 @@
|
||||
package toolbridge
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"xworkmate/go_core/internal/shared"
|
||||
)
|
||||
|
||||
func TestHandleRequestListsVaultKVTool(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
response := handleRequest(sharedRequest("tools/list", nil))
|
||||
result := mapStringAny(response["result"])
|
||||
tools := result["tools"].([]map[string]any)
|
||||
found := false
|
||||
for _, tool := range tools {
|
||||
if tool["name"] == "vault_kv" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatalf("expected vault_kv tool in %v", tools)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleRequestCallsVaultKVTool(t *testing.T) {
|
||||
var requestPath string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestPath = r.URL.Path
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"data": map[string]any{
|
||||
"data": map[string]any{
|
||||
"demo": "value",
|
||||
},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
t.Setenv("VAULT_SERVER_URL", server.URL)
|
||||
t.Setenv("VAULT_SERVER_ROOT_ACCESS_TOKEN", "root-token")
|
||||
|
||||
response := handleRequest(sharedRequest("tools/call", map[string]any{
|
||||
"name": "vault_kv",
|
||||
"arguments": map[string]any{
|
||||
"operation": "read",
|
||||
"path": "apps/demo",
|
||||
},
|
||||
}))
|
||||
result := mapStringAny(response["result"])
|
||||
content := result["content"].([]map[string]any)
|
||||
text := strings.TrimSpace(content[0]["text"].(string))
|
||||
if !strings.Contains(text, `"demo": "value"`) {
|
||||
t.Fatalf("unexpected tool output: %s", text)
|
||||
}
|
||||
if requestPath != "/v1/secret/data/apps/demo" {
|
||||
t.Fatalf("unexpected request path: %s", requestPath)
|
||||
}
|
||||
}
|
||||
|
||||
func sharedRequest(method string, params map[string]any) shared.RPCRequest {
|
||||
return shared.RPCRequest{
|
||||
JSONRPC: "2.0",
|
||||
ID: 1,
|
||||
Method: method,
|
||||
Params: params,
|
||||
}
|
||||
}
|
||||
|
||||
func mapStringAny(raw any) map[string]any {
|
||||
if typed, ok := raw.(map[string]any); ok {
|
||||
return typed
|
||||
}
|
||||
return map[string]any{}
|
||||
}
|
||||
@ -1,25 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"xworkmate/go_core/internal/acp"
|
||||
"xworkmate/go_core/internal/toolbridge"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if len(os.Args) > 1 && os.Args[1] == "serve" {
|
||||
if err := acp.Serve(os.Args[2:]); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
return
|
||||
}
|
||||
if len(os.Args) > 1 && os.Args[1] == "acp-stdio" {
|
||||
acp.RunStdio(os.Stdin, os.Stdout)
|
||||
return
|
||||
}
|
||||
|
||||
toolbridge.Run(os.Stdin, os.Stdout)
|
||||
}
|
||||
@ -1,131 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestParseClaudeJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
payload, err := parseClaudeJSON("log line\n{\"result\":\"review ok\",\"is_error\":false}\n")
|
||||
if err != nil {
|
||||
t.Fatalf("parseClaudeJSON returned error: %v", err)
|
||||
}
|
||||
if got := payload["result"]; got != "review ok" {
|
||||
t.Fatalf("unexpected result: %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallOpenAICompatible(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if got := r.Header.Get("Authorization"); got != "Bearer test-key" {
|
||||
t.Fatalf("unexpected auth header: %s", got)
|
||||
}
|
||||
var body map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
t.Fatalf("decode request body: %v", err)
|
||||
}
|
||||
if got := body["model"]; got != "qwen2.5-coder:latest" {
|
||||
t.Fatalf("unexpected model: %v", got)
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{
|
||||
"content": "review ok",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
output, err := callOpenAICompatible(
|
||||
server.URL,
|
||||
"test-key",
|
||||
"qwen2.5-coder:latest",
|
||||
[]map[string]string{
|
||||
{"role": "user", "content": "hello"},
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("callOpenAICompatible returned error: %v", err)
|
||||
}
|
||||
if output != "review ok" {
|
||||
t.Fatalf("unexpected output: %s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleChatToolRequiresPrompt(t *testing.T) {
|
||||
t.Setenv("LLM_API_KEY", "test-key")
|
||||
t.Setenv("LLM_BASE_URL", "http://127.0.0.1:11434/v1")
|
||||
|
||||
_, err := handleChatTool(map[string]any{})
|
||||
if err == nil || err.Error() != "prompt is required" {
|
||||
t.Fatalf("expected prompt error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseClaudeJSONReturnsErrorForPlainText(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := parseClaudeJSON("plain text only\n")
|
||||
if err == nil {
|
||||
t.Fatal("expected parse error for plain text output")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallOpenAICompatibleReturnsStatusError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "bad gateway", http.StatusBadGateway)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
_, err := callOpenAICompatible(
|
||||
server.URL,
|
||||
"test-key",
|
||||
"qwen2.5-coder:latest",
|
||||
[]map[string]string{{"role": "user", "content": "hello"}},
|
||||
)
|
||||
if err == nil || err.Error() == "" {
|
||||
t.Fatal("expected non-2xx status error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunClaudeReviewSurfacesCliExitFailure(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
cliPath := filepath.Join(tempDir, "claude")
|
||||
if err := os.WriteFile(cliPath, []byte("#!/bin/sh\necho boom >&2\nexit 2\n"), 0o755); err != nil {
|
||||
t.Fatalf("write fake claude script: %v", err)
|
||||
}
|
||||
t.Setenv("CLAUDE_BIN", cliPath)
|
||||
|
||||
_, err := runClaudeReview("review this", "", "", "", 2*time.Second)
|
||||
if err == nil || err.Error() == "" {
|
||||
t.Fatal("expected cli failure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunClaudeReviewSurfacesNonJSONStdout(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
cliPath := filepath.Join(tempDir, "claude")
|
||||
if err := os.WriteFile(cliPath, []byte("#!/bin/sh\necho plain-text-output\nexit 0\n"), 0o755); err != nil {
|
||||
t.Fatalf("write fake claude script: %v", err)
|
||||
}
|
||||
t.Setenv("CLAUDE_BIN", cliPath)
|
||||
|
||||
_, err := runClaudeReview("review this", "", "", "", 2*time.Second)
|
||||
if err == nil || err.Error() == "" {
|
||||
t.Fatal("expected non-json stdout error")
|
||||
}
|
||||
}
|
||||
@ -1,38 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"xworkmate/go_core/internal/shared"
|
||||
)
|
||||
|
||||
func parseClaudeJSON(raw string) (map[string]any, error) {
|
||||
return shared.ParseClaudeJSON(raw)
|
||||
}
|
||||
|
||||
func callOpenAICompatible(
|
||||
baseURL,
|
||||
apiKey,
|
||||
model string,
|
||||
messages []map[string]string,
|
||||
) (string, error) {
|
||||
return shared.CallOpenAICompatible(baseURL, apiKey, model, messages)
|
||||
}
|
||||
|
||||
func handleChatTool(arguments map[string]any) (string, error) {
|
||||
return shared.HandleChatTool(arguments)
|
||||
}
|
||||
|
||||
func handleVaultKVTool(arguments map[string]any) (string, error) {
|
||||
return shared.HandleVaultKVTool(arguments)
|
||||
}
|
||||
|
||||
func runClaudeReview(
|
||||
prompt,
|
||||
model,
|
||||
system,
|
||||
tools string,
|
||||
timeout time.Duration,
|
||||
) (string, error) {
|
||||
return shared.RunClaudeReview(prompt, model, system, tools, timeout)
|
||||
}
|
||||
@ -24,7 +24,7 @@
|
||||
|
||||
- `flutter analyze`
|
||||
- `flutter test`
|
||||
- `cd go/go_core && go test ./...`
|
||||
- `cd ../xworkmate-bridge && go test ./...`
|
||||
- `flutter test integration_test/desktop_navigation_flow_test.dart -d macos`
|
||||
- `flutter test integration_test/desktop_settings_flow_test.dart -d macos`
|
||||
- `flutter build macos`
|
||||
|
||||
@ -2,7 +2,15 @@
|
||||
set -euo pipefail
|
||||
|
||||
ROOT_DIR="$(cd "$(dirname "$0")/.." && pwd)"
|
||||
BRIDGE_DIR="$ROOT_DIR/go/go_core"
|
||||
DEFAULT_BRIDGE_DIR=""
|
||||
for candidate in "$ROOT_DIR/../xworkmate-bridge" "$ROOT_DIR/../../xworkmate-bridge"
|
||||
do
|
||||
if [[ -f "$candidate/go.mod" ]]; then
|
||||
DEFAULT_BRIDGE_DIR="$candidate"
|
||||
break
|
||||
fi
|
||||
done
|
||||
BRIDGE_DIR="${XWORKMATE_BRIDGE_DIR:-$DEFAULT_BRIDGE_DIR}"
|
||||
OUTPUT_DIR="${OUTPUT_DIR:-$ROOT_DIR/build/bin}"
|
||||
OUTPUT_PATH_BASE="${OUTPUT_DIR}/xworkmate-go-core"
|
||||
|
||||
@ -13,7 +21,7 @@ else
|
||||
fi
|
||||
|
||||
if [[ ! -f "$BRIDGE_DIR/go.mod" ]]; then
|
||||
echo "Missing go.mod in $BRIDGE_DIR" >&2
|
||||
echo "Missing xworkmate-bridge repo or go.mod in $BRIDGE_DIR" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
@ -24,7 +32,7 @@ fi
|
||||
|
||||
mkdir -p "$OUTPUT_DIR"
|
||||
|
||||
echo "Building xworkmate-go-core..."
|
||||
echo "Building xworkmate-go-core from xworkmate-bridge..."
|
||||
(
|
||||
cd "$BRIDGE_DIR"
|
||||
GO111MODULE=on go build -o "$OUTPUT_PATH" .
|
||||
|
||||
Loading…
Reference in New Issue
Block a user