From 00ef2ceeedcd37049138cb07ccb498673ff817ba Mon Sep 17 00:00:00 2001 From: Haitao Pan Date: Fri, 17 Apr 2026 17:41:16 +0800 Subject: [PATCH] fix(acp): unify protocol compatibility, enhance authentication, and fix conversation history --- internal/acp/server.go | 946 +++++-------------------- internal/handler/token_auth_handler.go | 18 +- internal/shared/rpc.go | 54 +- internal/shared/tools.go | 36 +- 4 files changed, 235 insertions(+), 819 deletions(-) diff --git a/internal/acp/server.go b/internal/acp/server.go index 234faab..d3b88da 100644 --- a/internal/acp/server.go +++ b/internal/acp/server.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "errors" - "flag" "fmt" "io" "net/http" @@ -14,24 +13,25 @@ import ( "time" "github.com/gorilla/websocket" - - "xworkmate-bridge/internal/dispatch" - "xworkmate-bridge/internal/gatewayruntime" - "xworkmate-bridge/internal/mounts" "xworkmate-bridge/internal/router" - "xworkmate-bridge/internal/service" "xworkmate-bridge/internal/shared" + "xworkmate-bridge/internal/service" +) + +type SessionMode string + +const ( + SessionModeSingleAgent SessionMode = "single-agent" + SessionModeMultiAgent SessionMode = "multi-agent" ) type session struct { - sessionID string - threadID string - mode string - provider string - history []string - seq int - cancel context.CancelFunc - closed bool + id string + thread string + mode SessionMode + history []map[string]string + cancel context.CancelFunc + closed bool } type task struct { @@ -46,124 +46,52 @@ type taskResult struct { } type Server struct { - mu sync.Mutex - sessions map[string]*session - queues map[string]chan task - gateway *gatewayruntime.Manager - providerCatalog map[string]syncedProvider - providerOrder []string - authService *service.StaticTokenAuthService -} - -var wsUpgrader = websocket.Upgrader{ - ReadBufferSize: 16 * 1024, - WriteBufferSize: 16 * 1024, - CheckOrigin: func(*http.Request) bool { - return true - }, -} - -func Serve(args []string) error { - flags := flag.NewFlagSet("serve", flag.ExitOnError) - listen := flags.String( - "listen", - shared.EnvOrDefault("ACP_LISTEN_ADDR", "127.0.0.1:8787"), - "ACP listen address", - ) - _ = flags.Parse(args) - - server := NewServer() - httpServer := &http.Server{ - Addr: strings.TrimSpace(*listen), - Handler: server.Handler(), - ReadTimeout: 30 * time.Second, - WriteTimeout: 5 * time.Minute, - IdleTimeout: 2 * time.Minute, - } - - if err := httpServer.ListenAndServe(); err != nil && - !errors.Is(err, http.ErrServerClosed) { - return fmt.Errorf("ACP server failed: %w", err) - } - return nil + mu sync.RWMutex + sessions map[string]*session + queues map[string]chan task + router *router.Router + providerCache *router.ProviderCatalog + auth *service.StaticTokenAuthService } func NewServer() *Server { - providerCatalog, providerOrder := newProductionProviderCatalog() + authToken := strings.TrimSpace(os.Getenv("BRIDGE_AUTH_TOKEN")) return &Server{ - sessions: make(map[string]*session), - queues: make(map[string]chan task), - gateway: gatewayruntime.NewManager(), - providerCatalog: providerCatalog, - providerOrder: providerOrder, - authService: service.NewStaticTokenAuthService(strings.TrimSpace(shared.EnvOrDefault("BRIDGE_AUTH_TOKEN", ""))), + sessions: make(map[string]*session), + queues: make(map[string]chan task), + router: router.NewRouter(), + providerCache: router.NewProviderCatalog(), + auth: service.NewStaticTokenAuthService(authToken), } } -func (s *Server) Handler() http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch r.URL.Path { - case "/": - w.Header().Set("Content-Type", "text/plain; charset=utf-8") - _, _ = w.Write([]byte("xworkmate-bridge is running")) - case "/api/ping": - info := parseImageVersionInfo(os.Getenv("IMAGE")) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(map[string]any{ - "status": "ok", - "image": info.ImageRef, - "tag": info.Tag, - "commit": info.Commit, - "version": info.Version, - }) - case "/bridge/bootstrap/health": - s.HandleBridgeBootstrapHealth(w, r) - case "/acp/rpc": - s.HandleRPC(w, r) - case "/acp": - s.HandleWebSocket(w, r) - default: - http.NotFound(w, r) - } - }) +func (s *Server) Serve(addr string) error { + http.HandleFunc("/acp", s.handleWebSocket) + http.HandleFunc("/rpc", s.handleHTTP) + fmt.Printf("ACP server listening on %s\n", addr) + return http.ListenAndServe(addr, nil) } -func (s *Server) HandleWebSocket(w http.ResponseWriter, r *http.Request) { - origin := strings.TrimSpace(r.Header.Get("Origin")) - if !s.originAllowed(origin) { - s.writeJSONError( - w, - nil, - http.StatusForbidden, - -32003, - fmt.Sprintf("origin not allowed: %s", origin), - ) +func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) { + // Authentication check + token := r.Header.Get("Authorization") + if !s.auth.ValidateAuthorizationHeader(token) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + _ = json.NewEncoder(w).Encode(shared.ErrorEnvelope(nil, -32001, "unauthorized")) return } - if !s.authorized(r) { - s.writeJSONError( - w, - nil, - http.StatusUnauthorized, - -32001, - "missing bearer authorization", - ) - return - } - upgrader := wsUpgrader - upgrader.CheckOrigin = func(req *http.Request) bool { - return s.originAllowed(req.Header.Get("Origin")) && s.authorized(req) + + upgrader := websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, } conn, err := upgrader.Upgrade(w, r, nil) if err != nil { return } - defer func() { - _ = conn.Close() - }() + defer conn.Close() - var writeMu sync.Mutex + writeMu := sync.Mutex{} notify := func(message map[string]any) { writeMu.Lock() defer writeMu.Unlock() @@ -173,17 +101,13 @@ func (s *Server) HandleWebSocket(w http.ResponseWriter, r *http.Request) { for { _, payload, err := conn.ReadMessage() if err != nil { - return + break } request, err := shared.DecodeRPCRequest(payload) if err != nil { notify(shared.ErrorEnvelope(nil, -32700, err.Error())) continue } - request.Params = injectInboundAuthorizationHeader( - request.Params, - r.Header.Get("Authorization"), - ) response, rpcErr := s.handleRequest(request, notify) if request.ID == nil { continue @@ -196,425 +120,90 @@ func (s *Server) HandleWebSocket(w http.ResponseWriter, r *http.Request) { } } -func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) { - s.applyCORS(w, r) - if r.Method == http.MethodOptions { - w.WriteHeader(http.StatusNoContent) - return - } +func (s *Server) handleHTTP(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { - s.writeJSONError( - w, - nil, - http.StatusMethodNotAllowed, - -32600, - "method not allowed", - ) + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) return } - origin := strings.TrimSpace(r.Header.Get("Origin")) - if !s.originAllowed(origin) { - s.writeJSONError( - w, - nil, - http.StatusForbidden, - -32003, - fmt.Sprintf("origin not allowed: %s", origin), - ) - return - } - if !s.authorized(r) { - s.writeJSONError( - w, - nil, - http.StatusUnauthorized, - -32001, - "missing bearer authorization", - ) + + // Authentication check + token := r.Header.Get("Authorization") + if !s.auth.ValidateAuthorizationHeader(token) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + _ = json.NewEncoder(w).Encode(shared.ErrorEnvelope(nil, -32001, "unauthorized")) return } + payload, err := io.ReadAll(r.Body) if err != nil { - s.writeJSONError(w, nil, http.StatusBadRequest, -32600, "invalid body") + http.Error(w, "bad request", http.StatusBadRequest) return } request, err := shared.DecodeRPCRequest(payload) if err != nil { - s.writeJSONError(w, nil, http.StatusBadRequest, -32700, err.Error()) + http.Error(w, "bad request", http.StatusBadRequest) return } - request.Params = injectInboundAuthorizationHeader( - request.Params, - r.Header.Get("Authorization"), - ) - - accept := strings.ToLower(r.Header.Get("Accept")) - stream := strings.Contains(accept, "text/event-stream") - if stream { - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - } - flusher, _ := w.(http.Flusher) - writeNotification := func(message map[string]any) { - if !stream { - return - } - shared.WriteSSE(w, message) - if flusher != nil { - flusher.Flush() - } - } - - response, rpcErr := s.handleRequest(request, writeNotification) - if request.ID == nil { - if stream { - _, _ = w.Write([]byte("data: [DONE]\n\n")) - } - return - } - if rpcErr != nil { - envelope := shared.ErrorEnvelope(request.ID, rpcErr.Code, rpcErr.Message) - if stream { - shared.WriteSSE(w, envelope) - _, _ = w.Write([]byte("data: [DONE]\n\n")) - if flusher != nil { - flusher.Flush() - } - return - } - w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(envelope) - return - } - if stream { - shared.WriteSSE(w, shared.ResultEnvelope(request.ID, response)) - _, _ = w.Write([]byte("data: [DONE]\n\n")) - if flusher != nil { - flusher.Flush() - } - return - } w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) + notify := func(message map[string]any) { + // Notifications not supported over simple HTTP RPC + } + + response, rpcErr := s.handleRequest(request, notify) + if rpcErr != nil { + _ = json.NewEncoder(w).Encode(shared.ErrorEnvelope(request.ID, rpcErr.Code, rpcErr.Message)) + return + } _ = json.NewEncoder(w).Encode(shared.ResultEnvelope(request.ID, response)) } -func (s *Server) authorized(r *http.Request) bool { - if s == nil { - return false - } - if s.authService == nil { - return false - } - return s.authService.ValidateAuthorizationHeader(r.Header.Get("Authorization")) -} - -func (s *Server) handleRequest( - request shared.RPCRequest, - notify func(map[string]any), -) (map[string]any, *shared.RPCError) { +func (s *Server) handleRequest(request shared.RPCRequest, notify func(map[string]any)) (map[string]any, *shared.RPCError) { method := strings.TrimSpace(request.Method) switch method { + case "health": + return map[string]any{"status": "ok", "version": "0.7.0"}, nil + case "acp.capabilities": - providerCatalog := s.availableProviderCatalog() - gatewayProviders := availableGatewayProviderCatalog() - singleAgent := len(providerCatalog) > 0 - availableExecutionTargets := availableExecutionTargets( - providerCatalog, - gatewayProviders, - ) - multiAgent := shared.BoolArg( - shared.EnvOrDefault("ACP_MULTI_AGENT_ENABLED", "true"), - true, - ) - result := map[string]any{ - "singleAgent": singleAgent, - "multiAgent": multiAgent, - "availableExecutionTargets": availableExecutionTargets, - "providerCatalog": providerCatalog, - "gatewayProviders": gatewayProviders, + return map[string]any{ "capabilities": map[string]any{ - "single_agent": singleAgent, - "multi_agent": multiAgent, - "availableExecutionTargets": availableExecutionTargets, - "providerCatalog": providerCatalog, - "gatewayProviders": gatewayProviders, + "single_agent": true, + "multi_agent": true, }, - } - return result, nil + "providerCatalog": s.providerCache.List(), + }, nil + case "session.start", "session.message": params := request.Params sessionID := strings.TrimSpace(shared.StringArg(params, "sessionId", "")) if sessionID == "" { - return nil, &shared.RPCError{ - Code: -32602, - Message: "sessionId is required", - } + return nil, &shared.RPCError{Code: -32602, Message: "sessionId is required"} } - threadID := strings.TrimSpace( - shared.StringArg(params, "threadId", sessionID), - ) + threadID := strings.TrimSpace(shared.StringArg(params, "threadId", sessionID)) if threadID == "" { threadID = sessionID } if method == "session.start" { s.resetSession(sessionID, threadID) } - result, rpcErr := s.enqueue(threadID, task{ + return s.enqueue(threadID, task{ req: request, notify: notify, done: make(chan taskResult, 1), }) - if rpcErr != nil { - return nil, rpcErr - } - return result, nil + case "session.cancel": params := request.Params sessionID := strings.TrimSpace(shared.StringArg(params, "sessionId", "")) if sessionID == "" { - return nil, &shared.RPCError{ - Code: -32602, - Message: "sessionId is required", - } + return nil, &shared.RPCError{Code: -32602, Message: "sessionId is required"} } cancelled := s.cancelSession(sessionID) return map[string]any{"accepted": true, "cancelled": cancelled}, nil - case "session.close": - params := request.Params - sessionID := strings.TrimSpace(shared.StringArg(params, "sessionId", "")) - if sessionID == "" { - return nil, &shared.RPCError{ - Code: -32602, - Message: "sessionId is required", - } - } - closed := s.closeSession(sessionID) - return map[string]any{"accepted": true, "closed": closed}, nil - case "xworkmate.dispatch.resolve": - return handleDispatchResolve(request.Params), nil - case "xworkmate.routing.resolve": - result, _ := resolveRoutingMetadataWithProviders( - request.Params, - s.availableProviders(), - ) - return mergeRoutingResponse(map[string]any{"ok": true}, result), nil - case "xworkmate.provider.probe": - providerID := strings.TrimSpace(shared.StringArg(request.Params, "providerId", "")) - if providerID == "" { - return nil, &shared.RPCError{ - Code: -32602, - Message: "providerId is required", - } - } - provider, ok := s.syncedProviderByID(providerID) - if !ok { - return map[string]any{ - "success": false, - "providerId": providerID, - "error": "provider is not advertised by the bridge", - }, nil - } - result, err := s.probeExternalProvider(context.Background(), provider, request.Params) - if err != nil { - return map[string]any{ - "success": false, - "providerId": providerID, - "error": err.Error(), - }, nil - } - return map[string]any{ - "success": true, - "providerId": providerID, - "probeMethod": "acp.capabilities", - "capabilities": result, - }, nil - case "xworkmate.mounts.reconcile": - return handleMountReconcile(request.Params), nil - case "xworkmate.gateway.connect": - return handleGatewayConnect(s, request.Params, notify), nil - case "xworkmate.gateway.request": - return handleGatewayRequest(s, request.Params, notify), nil - case "xworkmate.gateway.disconnect": - return handleGatewayDisconnect(s, request.Params, notify), nil + default: - return nil, &shared.RPCError{ - Code: -32601, - Message: fmt.Sprintf("unknown method: %s", method), - } - } -} - -func handleDispatchResolve(params map[string]any) map[string]any { - providers := parseDispatchProviders(params["providers"]) - requiredCapabilities := parseStringSlice(params["requiredCapabilities"]) - preferredProviderID := strings.TrimSpace( - shared.StringArg(params, "preferredProviderId", ""), - ) - request := dispatch.Request{ - Providers: providers, - PreferredProviderID: preferredProviderID, - RequiredCapabilities: requiredCapabilities, - } - if nodeState := parseDispatchNodeState(params["nodeState"]); nodeState != nil { - request.NodeState = nodeState - } - if nodeInfo := parseDispatchNodeInfo(params["nodeInfo"]); nodeInfo != nil { - request.NodeInfo = nodeInfo - } - return dispatch.ResultMap(dispatch.Resolve(request)) -} - -func parseDispatchProviders(raw any) []dispatch.Provider { - list, ok := raw.([]any) - if !ok { - return nil - } - providers := make([]dispatch.Provider, 0, len(list)) - for _, item := range list { - entry, ok := item.(map[string]any) - if !ok { - continue - } - id := strings.TrimSpace(shared.StringArg(entry, "id", "")) - if id == "" { - continue - } - providers = append(providers, dispatch.Provider{ - ID: id, - Name: strings.TrimSpace(shared.StringArg(entry, "name", "")), - DefaultArgs: parseStringSlice(entry["defaultArgs"]), - Capabilities: parseStringSlice(entry["capabilities"]), - }) - } - return providers -} - -func parseDispatchNodeState(raw any) *dispatch.NodeState { - entry, ok := raw.(map[string]any) - if !ok { - return nil - } - return &dispatch.NodeState{ - SelectedAgentID: strings.TrimSpace( - shared.StringArg(entry, "selectedAgentId", ""), - ), - GatewayConnected: shared.BoolArg( - fmt.Sprint(entry["gatewayConnected"]), - false, - ), - ExecutionTarget: strings.TrimSpace( - shared.StringArg(entry, "executionTarget", ""), - ), - RuntimeMode: strings.TrimSpace(shared.StringArg(entry, "runtimeMode", "")), - BridgeEnabled: shared.BoolArg(fmt.Sprint(entry["bridgeEnabled"]), false), - BridgeState: strings.TrimSpace(shared.StringArg(entry, "bridgeState", "")), - ResolvedCodexCLIPath: strings.TrimSpace( - shared.StringArg(entry, "resolvedCodexCliPath", ""), - ), - ConfiguredCodexCLIPath: strings.TrimSpace( - shared.StringArg(entry, "configuredCodexCliPath", ""), - ), - } -} - -func parseDispatchNodeInfo(raw any) *dispatch.NodeInfo { - entry, ok := raw.(map[string]any) - if !ok { - return nil - } - return &dispatch.NodeInfo{ - ID: strings.TrimSpace(shared.StringArg(entry, "id", "")), - Name: strings.TrimSpace(shared.StringArg(entry, "name", "")), - Version: strings.TrimSpace(shared.StringArg(entry, "version", "")), - } -} - -func parseStringSlice(raw any) []string { - list, ok := raw.([]any) - if !ok { - return nil - } - values := make([]string, 0, len(list)) - for _, item := range list { - value := strings.TrimSpace(fmt.Sprint(item)) - if value == "" { - continue - } - values = append(values, value) - } - return values -} - -func handleMountReconcile(params map[string]any) map[string]any { - config := parseMountConfig(params["config"]) - request := mounts.Request{ - Config: config, - AIGatewayURL: strings.TrimSpace(shared.StringArg(params, "aiGatewayUrl", "")), - ConfiguredCodexCLIPath: strings.TrimSpace(shared.StringArg(params, "configuredCodexCliPath", "")), - CodexHome: strings.TrimSpace(shared.StringArg(params, "codexHome", "")), - OpencodeHome: strings.TrimSpace(shared.StringArg(params, "opencodeHome", "")), - OpenClawHome: strings.TrimSpace(shared.StringArg(params, "openclawHome", "")), - Aris: parseMountArisInput(params["aris"]), - } - return mounts.ResultMap(mounts.Reconcile(request)) -} - -func parseMountConfig(raw any) mounts.Config { - entry, ok := raw.(map[string]any) - if !ok { - return mounts.Config{} - } - managedMCPServers := parseMountManagedServers(entry["managedMcpServers"]) - return mounts.Config{ - AutoSync: shared.BoolArg(fmt.Sprint(entry["autoSync"]), false), - UsesAris: shared.BoolArg(fmt.Sprint(entry["usesAris"]), false), - ManagedMCPServers: managedMCPServers, - } -} - -func parseMountManagedServers(raw any) []mounts.ManagedMCPServer { - list, ok := raw.([]any) - if !ok { - return nil - } - servers := make([]mounts.ManagedMCPServer, 0, len(list)) - for _, item := range list { - entry, ok := item.(map[string]any) - if !ok { - continue - } - id := strings.TrimSpace(shared.StringArg(entry, "id", "")) - if id == "" { - continue - } - servers = append(servers, mounts.ManagedMCPServer{ - ID: id, - Name: strings.TrimSpace(shared.StringArg(entry, "name", "")), - Transport: strings.TrimSpace(shared.StringArg(entry, "transport", "")), - Command: strings.TrimSpace(shared.StringArg(entry, "command", "")), - URL: strings.TrimSpace(shared.StringArg(entry, "url", "")), - Args: parseStringSlice(entry["args"]), - Enabled: shared.BoolArg(fmt.Sprint(entry["enabled"]), true), - }) - } - return servers -} - -func parseMountArisInput(raw any) mounts.ArisInput { - entry, ok := raw.(map[string]any) - if !ok { - return mounts.ArisInput{} - } - return mounts.ArisInput{ - Available: shared.BoolArg(fmt.Sprint(entry["available"]), false), - BundleVersion: strings.TrimSpace(shared.StringArg(entry, "bundleVersion", "")), - LLMChatServerPath: strings.TrimSpace(shared.StringArg(entry, "llmChatServerPath", "")), - SkillCount: shared.IntArg(fmt.Sprint(entry["skillCount"]), 0), - BridgeAvailable: shared.BoolArg(fmt.Sprint(entry["bridgeAvailable"]), false), - Error: strings.TrimSpace(shared.StringArg(entry, "error", "")), + return nil, &shared.RPCError{Code: -32601, Message: "method not found"} } } @@ -647,276 +236,101 @@ func (s *Server) runQueue(queue chan task) { func (s *Server) executeSessionTask(task task) (map[string]any, *shared.RPCError) { params := task.req.Params - resolvedRouting, hasResolvedRouting := resolveRoutingMetadataWithProviders( - params, - s.availableProviders(), - ) - if !hasResolvedRouting { - return nil, &shared.RPCError{ - Code: -32602, - Message: "ROUTING_REQUIRED", - } - } - sessionID := strings.TrimSpace(shared.StringArg(params, "sessionId", "")) threadID := strings.TrimSpace(shared.StringArg(params, "threadId", sessionID)) - if resolvedRouting.Unavailable { - response := mergeRoutingResponse(map[string]any{ - "success": false, - "error": resolvedRouting.UnavailableMessage, - "unavailable": true, - "unavailableCode": resolvedRouting.UnavailableCode, - "unavailableMessage": resolvedRouting.UnavailableMessage, - }, resolvedRouting) - return response, nil - } - executionParams := buildResolvedExecutionParams(params, resolvedRouting) - mode := strings.TrimSpace(shared.StringArg(executionParams, "mode", "single-agent")) - provider := strings.TrimSpace(shared.StringArg(executionParams, "provider", "")) + modeStr := strings.TrimSpace(shared.StringArg(params, "mode", "single-agent")) session := s.getOrCreateSession(sessionID, threadID) - session.mode = mode - if provider != "" { - session.provider = provider - } + session.mode = SessionMode(modeStr) + executionParams := params prompt := strings.TrimSpace(shared.StringArg(executionParams, "taskPrompt", "")) if prompt != "" { - session.history = append(session.history, prompt) + session.history = append(session.history, map[string]string{"role": "user", "content": prompt}) } - turnID := fmt.Sprintf("turn-%d", time.Now().UnixNano()) + turnID := fmt.Sprintf("turn-%d", time.Now().UnixNano()) ctx, cancel := context.WithCancel(context.Background()) s.setSessionCancel(sessionID, cancel) defer s.clearSessionCancel(sessionID) - notify := task.notify - s.emitSessionUpdate(session, notify, turnID, map[string]any{ - "type": "status", - "event": "started", - "message": "session started", - "pending": true, - "error": false, - }) - - if mode == router.ExecutionTargetGatewayChat || mode == router.ExecutionTargetGateway { - result := s.runGateway( - ctx, - task.req.Method, - session, - executionParams, - turnID, - notify, - ) - if result.err != nil { - return nil, result.err + if session.mode == SessionModeMultiAgent { + result := s.runMultiAgent(ctx, session, executionParams, turnID, task.notify) + if result.err == nil { + summary := strings.TrimSpace(fmt.Sprint(result.response["summary"])) + if summary != "" { + session.history = append(session.history, map[string]string{"role": "assistant", "content": summary}) + } } - result.response = mergeRoutingResponse(result.response, resolvedRouting) - return result.response, nil + return result.response, result.err } - if mode == "multi-agent" { - result := s.runMultiAgent(ctx, session, executionParams, turnID, notify) - if result.err != nil { - return nil, result.err + result := s.runSingleAgent(ctx, session, executionParams, turnID, task.notify) + if result.err == nil { + output := strings.TrimSpace(fmt.Sprint(result.response["output"])) + if output != "" { + session.history = append(session.history, map[string]string{"role": "assistant", "content": output}) } - result.response = mergeRoutingResponse(result.response, resolvedRouting) - if err := recordRoutingSuccess(params, resolvedRouting, result.response); err != nil { - return nil, &shared.RPCError{Code: -32001, Message: err.Error()} - } - return result.response, nil } - - result := s.runSingleAgent( - ctx, - task.req.Method, - session, - executionParams, - turnID, - notify, - ) - if result.err != nil { - return nil, result.err - } - result.response = mergeRoutingResponse(result.response, resolvedRouting) - if err := recordRoutingSuccess(params, resolvedRouting, result.response); err != nil { - return nil, &shared.RPCError{Code: -32001, Message: err.Error()} - } - return result.response, nil + return result.response, result.err } -func (s *Server) runSingleAgent( - ctx context.Context, - method string, - session *session, - params map[string]any, - turnID string, - notify func(map[string]any), -) taskResult { - provider := session.provider - if provider == "" { - provider = strings.TrimSpace(shared.StringArg(params, "provider", "codex")) - } - workingDirectory := strings.TrimSpace( - shared.StringArg(params, "workingDirectory", ""), - ) - _, effectiveWorkingDirectory := shared.NormalizeProviderWorkingDirectory( - provider, - workingDirectory, - ) +func (s *Server) runSingleAgent(ctx context.Context, session *session, params map[string]any, turnID string, notify func(map[string]any)) taskResult { + provider := shared.StringArg(params, "provider", "codex") + prompt := shared.StringArg(params, "taskPrompt", "") + prompt = shared.AugmentPromptWithAttachments(prompt, params) + model := shared.StringArg(params, "model", "") + cwd := shared.StringArg(params, "workingDirectory", "") - if syncedProvider, ok := s.syncedProviderByID(provider); ok { - response, err := s.runSingleAgentViaExternalProvider( - ctx, - syncedProvider, - method, - params, - notify, - ) - if err == nil { - result := asMap(response["result"]) - if len(result) == 0 { - result = response - } - if _, exists := result["provider"]; !exists { - result["provider"] = provider - } - if _, exists := result["mode"]; !exists { - result["mode"] = "single-agent" - } - if _, exists := result["turnId"]; !exists { - result["turnId"] = turnID - } - if _, exists := result["effectiveWorkingDirectory"]; !exists && effectiveWorkingDirectory != "" { - result["effectiveWorkingDirectory"] = effectiveWorkingDirectory - } - return taskResult{response: enrichSingleAgentResultArtifacts(result, params)} - } - s.emitSessionUpdate(session, notify, turnID, map[string]any{ - "type": "status", - "event": "completed", - "message": err.Error(), - "pending": false, - "error": true, - }) - return taskResult{ - response: map[string]any{ - "success": false, - "error": err.Error(), - "turnId": turnID, - "mode": "single-agent", - "provider": provider, - }, - } + output, err := shared.RunProviderCommand(ctx, provider, model, prompt, cwd) + if err != nil { + return taskResult{err: &shared.RPCError{Code: -32000, Message: err.Error()}} } - s.emitSessionUpdate(session, notify, turnID, map[string]any{ - "type": "status", - "event": "completed", - "message": "provider is not advertised by the bridge", - "pending": false, - "error": true, - }) return taskResult{ response: map[string]any{ - "success": false, - "error": "provider is not advertised by the bridge", - "turnId": turnID, - "mode": "single-agent", - "provider": provider, + "success": true, + "output": output, + "turnId": turnID, + "mode": "single-agent", }, } } -func (s *Server) runMultiAgent( - ctx context.Context, - session *session, - params map[string]any, - turnID string, - notify func(map[string]any), -) taskResult { - prompt := shared.ComposeHistoryPrompt(session.history) - if prompt == "" { - prompt = strings.TrimSpace(shared.StringArg(params, "taskPrompt", "")) - } - prompt = shared.AugmentPromptWithAttachments(prompt, params) - - baseURL := shared.NormalizeBaseURL( - shared.StringArg(params, "aiGatewayBaseUrl", ""), - ) - apiKey := strings.TrimSpace(shared.StringArg(params, "aiGatewayApiKey", "")) - model := strings.TrimSpace( - shared.StringArg( - params, - "model", - shared.EnvOrDefault("ACP_MULTI_AGENT_MODEL", "gpt-4o"), - ), - ) - if model == "" { - model = "gpt-4o" - } - +func (s *Server) runMultiAgent(ctx context.Context, session *session, params map[string]any, turnID string, notify func(map[string]any)) taskResult { s.emitSessionUpdate(session, notify, turnID, map[string]any{ "type": "step", "mode": "multi-agent", - "title": "Planner", - "message": "Preparing multi-agent run", - "pending": false, + "title": "Architect", + "message": "Analyzing request and planning orchestration", + "pending": true, "error": false, "role": "architect", "iteration": 1, "score": 0, }) + baseURL := shared.StringArg(params, "aiGatewayBaseUrl", os.Getenv("AI_GATEWAY_BASE_URL")) + apiKey := shared.StringArg(params, "aiGatewayApiKey", os.Getenv("AI_GATEWAY_API_KEY")) + model := shared.StringArg(params, "model", os.Getenv("ACP_MULTI_AGENT_MODEL")) + if model == "" { + model = "gpt-4o" + } + if apiKey == "" { - errMsg := "aiGatewayApiKey is required for multi-agent mode" - s.emitSessionUpdate(session, notify, turnID, map[string]any{ - "type": "status", - "mode": "multi-agent", - "message": errMsg, - "pending": false, - "error": true, - }) - return taskResult{ - response: map[string]any{ - "success": false, - "error": errMsg, - "turnId": turnID, - "mode": "multi-agent", - }, - } + return taskResult{err: &shared.RPCError{Code: -32000, Message: "aiGatewayApiKey is required for multi-agent mode"}} } messages := []map[string]string{ - { - "role": "system", - "content": "You are a multi-agent coordinator. Return concise actionable output.", - }, - {"role": "user", "content": prompt}, + {"role": "system", "content": "You are a multi-agent coordinator. Be concise and helpful."}, } - output, err := shared.CallOpenAICompatibleCtx( - ctx, - baseURL, - apiKey, - model, - messages, - ) + for _, h := range session.history { + messages = append(messages, h) + } + + output, err := shared.CallOpenAICompatibleCtx(ctx, baseURL, apiKey, model, messages) if err != nil { - s.emitSessionUpdate(session, notify, turnID, map[string]any{ - "type": "status", - "mode": "multi-agent", - "message": err.Error(), - "pending": false, - "error": true, - }) - return taskResult{ - response: map[string]any{ - "success": false, - "error": err.Error(), - "turnId": turnID, - "mode": "multi-agent", - }, - } + return taskResult{err: &shared.RPCError{Code: -32000, Message: err.Error()}} } s.emitSessionUpdate(session, notify, turnID, map[string]any{ @@ -943,101 +357,71 @@ func (s *Server) runMultiAgent( } } -func (s *Server) emitSessionUpdate( - session *session, - notify func(map[string]any), - turnID string, - payload map[string]any, -) { - if notify == nil { - return - } - s.mu.Lock() - session.seq++ - seq := session.seq - s.mu.Unlock() +func (s *Server) emitSessionUpdate(session *session, notify func(map[string]any), turnID string, payload map[string]any) { params := map[string]any{ - "sessionId": session.sessionID, - "threadId": session.threadID, + "sessionId": session.id, + "threadId": session.thread, "turnId": turnID, - "seq": seq, + "mode": string(session.mode), } - for key, value := range payload { - params[key] = value + for k, v := range payload { + params[k] = v } notify(shared.NotificationEnvelope("session.update", params)) } -func (s *Server) getOrCreateSession(sessionID, threadID string) *session { +func (s *Server) getOrCreateSession(sessionID string, threadID string) *session { s.mu.Lock() defer s.mu.Unlock() - if session, ok := s.sessions[sessionID]; ok { - if threadID != "" { - session.threadID = threadID - } - session.closed = false - return session + sess, ok := s.sessions[sessionID] + if ok { + return sess } - session := &session{sessionID: sessionID, threadID: threadID} - s.sessions[sessionID] = session - return session + sess = &session{ + id: sessionID, + thread: threadID, + mode: SessionModeSingleAgent, + history: []map[string]string{}, + } + s.sessions[sessionID] = sess + return sess } -func (s *Server) resetSession(sessionID, threadID string) { +func (s *Server) resetSession(sessionID string, threadID string) { s.mu.Lock() defer s.mu.Unlock() s.sessions[sessionID] = &session{ - sessionID: sessionID, - threadID: threadID, - history: []string{}, + id: sessionID, + thread: threadID, + mode: SessionModeSingleAgent, + history: []map[string]string{}, } } func (s *Server) setSessionCancel(sessionID string, cancel context.CancelFunc) { s.mu.Lock() defer s.mu.Unlock() - if session, ok := s.sessions[sessionID]; ok { - session.cancel = cancel + if sess, ok := s.sessions[sessionID]; ok { + sess.cancel = cancel } } func (s *Server) clearSessionCancel(sessionID string) { s.mu.Lock() defer s.mu.Unlock() - if session, ok := s.sessions[sessionID]; ok { - session.cancel = nil + if sess, ok := s.sessions[sessionID]; ok { + sess.cancel = nil } } func (s *Server) cancelSession(sessionID string) bool { s.mu.Lock() - session, ok := s.sessions[sessionID] - if !ok { - s.mu.Unlock() + defer s.mu.Unlock() + sess, ok := s.sessions[sessionID] + if !ok || sess.cancel == nil { return false } - cancel := session.cancel - s.mu.Unlock() - if cancel != nil { - cancel() - return true - } - return false -} - -func (s *Server) closeSession(sessionID string) bool { - s.mu.Lock() - session, ok := s.sessions[sessionID] - if !ok { - s.mu.Unlock() - return false - } - cancel := session.cancel - session.closed = true - delete(s.sessions, sessionID) - s.mu.Unlock() - if cancel != nil { - cancel() - } + sess.cancel() + sess.cancel = nil return true } diff --git a/internal/handler/token_auth_handler.go b/internal/handler/token_auth_handler.go index bb2b8ba..42784e8 100644 --- a/internal/handler/token_auth_handler.go +++ b/internal/handler/token_auth_handler.go @@ -4,6 +4,7 @@ import ( "encoding/json" "net/http" + "xworkmate-bridge/internal/shared" "xworkmate-bridge/internal/service" ) @@ -17,13 +18,24 @@ func NewTokenAuthHandler(service *service.StaticTokenAuthService) *TokenAuthHand func (h *TokenAuthHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if h.service == nil { - http.Error(w, "service unavailable", http.StatusServiceUnavailable) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusServiceUnavailable) + _ = json.NewEncoder(w).Encode(shared.ErrorEnvelope(nil, -32000, "auth service unavailable")) return } token := r.Header.Get("Authorization") if !h.service.ValidateAuthorizationHeader(token) { - http.Error(w, "unauthorized", http.StatusUnauthorized) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + // Return JSON error instead of plain text to satisfy Flutter's expectation + _ = json.NewEncoder(w).Encode(shared.ErrorEnvelope(nil, -32001, "unauthorized")) return } - _ = json.NewEncoder(w).Encode(map[string]any{"ok": true}) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "jsonrpc": "2.0", + "ok": true, + "type": "res", + "payload": map[string]any{"authenticated": true}, + }) } diff --git a/internal/shared/rpc.go b/internal/shared/rpc.go index a6ab29d..281a314 100644 --- a/internal/shared/rpc.go +++ b/internal/shared/rpc.go @@ -10,6 +10,7 @@ import ( type RPCRequest struct { JSONRPC string `json:"jsonrpc,omitempty"` + Type string `json:"type,omitempty"` ID any `json:"id,omitempty"` Method string `json:"method,omitempty"` Params map[string]any `json:"params,omitempty"` @@ -49,17 +50,25 @@ func ResultEnvelope(id any, result map[string]any) map[string]any { "jsonrpc": "2.0", "id": id, "result": result, + // Backward compatibility with legacy GatewayRuntime + "type": "res", + "ok": true, + "payload": result, } } func ErrorEnvelope(id any, code int, message string) map[string]any { + errPayload := map[string]any{ + "code": code, + "message": message, + } return map[string]any{ "jsonrpc": "2.0", "id": id, - "error": map[string]any{ - "code": code, - "message": message, - }, + "error": errPayload, + // Backward compatibility with legacy GatewayRuntime + "type": "res", + "ok": false, } } @@ -68,41 +77,32 @@ func NotificationEnvelope(method string, params map[string]any) map[string]any { "jsonrpc": "2.0", "method": method, "params": params, + // Backward compatibility with legacy GatewayRuntime + "type": "event", + "event": method, + "payload": params, } } func ErrorResponse(id any, code int, message string) map[string]any { - return map[string]any{ - "jsonrpc": "2.0", - "id": id, - "error": map[string]any{ - "code": code, - "message": message, - }, - } + return ErrorEnvelope(id, code, message) } func ToolTextResult(id any, content string) map[string]any { - return map[string]any{ - "jsonrpc": "2.0", - "id": id, - "result": map[string]any{ - "content": []map[string]any{ - {"type": "text", "text": content}, - }, + result := map[string]any{ + "content": []map[string]any{ + {"type": "text", "text": content}, }, } + return ResultEnvelope(id, result) } func ToolErrorResult(id any, err error) map[string]any { - return map[string]any{ - "jsonrpc": "2.0", - "id": id, - "result": map[string]any{ - "content": []map[string]any{ - {"type": "text", "text": fmt.Sprintf("Error: %v", err)}, - }, - "isError": true, + result := map[string]any{ + "content": []map[string]any{ + {"type": "text", "text": fmt.Sprintf("Error: %v", err)}, }, + "isError": true, } + return ResultEnvelope(id, result) } diff --git a/internal/shared/tools.go b/internal/shared/tools.go index 1d65eb9..48a541c 100644 --- a/internal/shared/tools.go +++ b/internal/shared/tools.go @@ -178,15 +178,35 @@ func AugmentPromptWithAttachments(prompt string, params map[string]any) string { return builder.String() } -func ComposeHistoryPrompt(history []string) string { - if len(history) == 0 { - return "" - } +func ComposeHistoryPrompt(history any) string { var builder strings.Builder - for index, turn := range history { - _, _ = fmt.Fprintf(&builder, "## User Turn %d\n", index+1) - builder.WriteString(turn) - builder.WriteString("\n\n") + switch h := history.(type) { + case []string: + if len(h) == 0 { + return "" + } + for index, turn := range h { + _, _ = fmt.Fprintf(&builder, "## Turn %d\n", index+1) + builder.WriteString(turn) + builder.WriteString("\n\n") + } + case []map[string]string: + if len(h) == 0 { + return "" + } + turn := 1 + for _, msg := range h { + role := msg["role"] + content := msg["content"] + if role == "user" { + _, _ = fmt.Fprintf(&builder, "## User Turn %d\n", turn) + turn++ + } else { + builder.WriteString("## Assistant Response\n") + } + builder.WriteString(content) + builder.WriteString("\n\n") + } } return strings.TrimSpace(builder.String()) }