717 lines
21 KiB
Go
717 lines
21 KiB
Go
package acp
|
|
|
|
import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
	"sync"
	"time"

	"github.com/gorilla/websocket"

	"xworkmate-bridge/internal/shared"
)
|
|
|
|
// externalACPCompat is the generic provider adapter that talks JSON-RPC to an
// external endpoint. Whether HTTP or WebSocket is used is decided from the
// endpoint's URL scheme each time an RPC is issued.
type externalACPCompat struct {
	// providerID identifies the provider (e.g. "gemini", "opencode", "hermes").
	providerID string
	// label is the human-readable provider name reported by Metadata.
	label string
	// endpoint is the RPC URL; a ws:// or wss:// scheme selects WebSocket.
	endpoint string
	// authHeader, when non-empty, is sent as the Authorization header.
	authHeader string
	// category is "protocol-adapter" or "native" (see providerCategory).
	category string
	// client performs requests for the HTTP transport.
	client *http.Client
}
|
|
|
|
type codexCompat struct {
|
|
*externalACPCompat
|
|
mu sync.Mutex
|
|
threads map[string]string
|
|
}
|
|
type opencodeCompat struct{ *externalACPCompat }
|
|
type geminiCompat struct{ *externalACPCompat }
|
|
type hermesCompat struct{ *externalACPCompat }
|
|
|
|
// sessionContinuationUnavailableError reports that a follow-up message cannot
// be delivered because the bridge no longer has the provider-side session
// state (for Codex: the mapping from bridge session/thread to provider thread).
type sessionContinuationUnavailableError struct {
	providerID string
	sessionID  string
	threadID   string
	reason     string
}

// Error implements error. The text always starts with the
// SESSION_CONTINUATION_UNAVAILABLE code; reason is appended when present.
func (e sessionContinuationUnavailableError) Error() string {
	if e.reason != "" {
		return "SESSION_CONTINUATION_UNAVAILABLE: " + e.reason
	}
	return "SESSION_CONTINUATION_UNAVAILABLE: provider session state is unavailable"
}

// newSessionContinuationUnavailableError builds the error with every field
// whitespace-trimmed. It returns the value (not a pointer) form.
func newSessionContinuationUnavailableError(providerID, sessionID, threadID, reason string) error {
	return sessionContinuationUnavailableError{
		providerID: strings.TrimSpace(providerID),
		sessionID:  strings.TrimSpace(sessionID),
		threadID:   strings.TrimSpace(threadID),
		reason:     strings.TrimSpace(reason),
	}
}

// asSessionContinuationUnavailableError reports whether err, or any error in
// its wrap chain, is a sessionContinuationUnavailableError held by value or by
// non-nil pointer, and returns it by value.
//
// errors.As is used (rather than a bare type assertion) so the error is still
// recognised after a caller adds context with fmt.Errorf("...: %w", err) or
// combines it with errors.Join.
func asSessionContinuationUnavailableError(err error) (sessionContinuationUnavailableError, bool) {
	if err == nil {
		return sessionContinuationUnavailableError{}, false
	}
	var byValue sessionContinuationUnavailableError
	if errors.As(err, &byValue) {
		return byValue, true
	}
	var byPointer *sessionContinuationUnavailableError
	if errors.As(err, &byPointer) && byPointer != nil {
		return *byPointer, true
	}
	return sessionContinuationUnavailableError{}, false
}
|
|
|
|
func newProviderCompat(provider syncedProvider) ProviderCompat {
|
|
base := &externalACPCompat{
|
|
providerID: provider.ProviderID,
|
|
label: provider.Label,
|
|
endpoint: resolveSingleAgentForwardEndpoint(provider),
|
|
authHeader: provider.AuthorizationHeader,
|
|
category: providerCategory(provider.ProviderID),
|
|
client: shared.NewHTTPClient(5 * time.Minute),
|
|
}
|
|
switch provider.ProviderID {
|
|
case "gemini":
|
|
return &geminiCompat{externalACPCompat: base}
|
|
case "opencode":
|
|
return &opencodeCompat{externalACPCompat: base}
|
|
case "hermes":
|
|
return &hermesCompat{externalACPCompat: base}
|
|
default:
|
|
return &codexCompat{
|
|
externalACPCompat: base,
|
|
threads: make(map[string]string),
|
|
}
|
|
}
|
|
}
|
|
|
|
// providerCategory classifies a provider ID. gemini, hermes and opencode
// (exact, case-sensitive match) are "protocol-adapter"; everything else is
// "native".
func providerCategory(providerID string) string {
	if providerID == "gemini" || providerID == "hermes" || providerID == "opencode" {
		return "protocol-adapter"
	}
	return "native"
}
|
|
|
|
// ID returns the provider identifier this adapter was built for.
func (c *externalACPCompat) ID() string {
	return c.providerID
}
|
|
|
|
func (c *externalACPCompat) Metadata() map[string]any {
|
|
return map[string]any{
|
|
"providerId": c.providerID,
|
|
"label": c.label,
|
|
"category": c.category,
|
|
"transport": c.transport(),
|
|
}
|
|
}
|
|
|
|
func (c *externalACPCompat) Probe(ctx context.Context) ProviderProbeResult {
|
|
_, err := c.rpcCall(ctx, "acp.capabilities", nil, nil)
|
|
if err != nil {
|
|
return ProviderProbeResult{Available: false, Status: err.Error()}
|
|
}
|
|
return ProviderProbeResult{Available: true, Status: "ok"}
|
|
}
|
|
|
|
func (c *codexCompat) Probe(ctx context.Context) ProviderProbeResult {
|
|
_, err := c.codexCall(ctx, "initialize", codexInitializeParams(), nil)
|
|
if err != nil {
|
|
if strings.Contains(strings.ToLower(err.Error()), "already initialized") {
|
|
return ProviderProbeResult{Available: true, Status: "ok"}
|
|
}
|
|
return ProviderProbeResult{Available: false, Status: err.Error()}
|
|
}
|
|
return ProviderProbeResult{Available: true, Status: "ok"}
|
|
}
|
|
|
|
func (c *codexCompat) StartSession(ctx context.Context, sessionID string, threadID string, params map[string]any, sink SessionNotificationSink) (map[string]any, error) {
|
|
if c.transport() == "ws" {
|
|
return c.startSessionWS(ctx, sessionID, threadID, params, sink)
|
|
}
|
|
thread, err := c.codexCall(ctx, "thread/start", codexThreadStartParams(params), nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
codexThreadID := codexThreadIDFromResult(thread)
|
|
if codexThreadID == "" {
|
|
return nil, fmt.Errorf("codex thread/start response missing thread id")
|
|
}
|
|
c.rememberThread(sessionID, threadID, codexThreadID)
|
|
return c.startTurn(ctx, codexThreadID, params, sink)
|
|
}
|
|
|
|
func (c *codexCompat) SendMessage(ctx context.Context, sessionID string, threadID string, params map[string]any, sink SessionNotificationSink) (map[string]any, error) {
|
|
if c.transport() == "ws" {
|
|
return c.sendMessageWS(ctx, sessionID, threadID, params, sink)
|
|
}
|
|
codexThreadID := c.lookupThread(sessionID, threadID)
|
|
if codexThreadID == "" {
|
|
return nil, newSessionContinuationUnavailableError(
|
|
c.providerID,
|
|
sessionID,
|
|
threadID,
|
|
"provider thread mapping is missing",
|
|
)
|
|
}
|
|
thread, err := c.codexCall(ctx, "thread/resume", map[string]any{"threadId": codexThreadID}, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if resolved := codexThreadIDFromResult(thread); resolved != "" {
|
|
codexThreadID = resolved
|
|
c.rememberThread(sessionID, threadID, codexThreadID)
|
|
}
|
|
return c.startTurn(ctx, codexThreadID, params, sink)
|
|
}
|
|
|
|
func (c *codexCompat) startSessionWS(ctx context.Context, sessionID string, threadID string, params map[string]any, sink SessionNotificationSink) (map[string]any, error) {
|
|
return c.withInitializedCodexWS(ctx, func(conn *websocket.Conn) (map[string]any, error) {
|
|
thread, err := c.writeAndReadWSRPC(ctx, conn, "thread/start", codexThreadStartParams(params), nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
codexThreadID := codexThreadIDFromResult(thread)
|
|
if codexThreadID == "" {
|
|
return nil, fmt.Errorf("codex thread/start response missing thread id")
|
|
}
|
|
c.rememberThread(sessionID, threadID, codexThreadID)
|
|
return c.startTurnOnWS(ctx, conn, codexThreadID, params, sink)
|
|
})
|
|
}
|
|
|
|
func (c *codexCompat) sendMessageWS(ctx context.Context, sessionID string, threadID string, params map[string]any, sink SessionNotificationSink) (map[string]any, error) {
|
|
codexThreadID := c.lookupThread(sessionID, threadID)
|
|
if codexThreadID == "" {
|
|
return nil, newSessionContinuationUnavailableError(
|
|
c.providerID,
|
|
sessionID,
|
|
threadID,
|
|
"provider thread mapping is missing",
|
|
)
|
|
}
|
|
return c.withInitializedCodexWS(ctx, func(conn *websocket.Conn) (map[string]any, error) {
|
|
thread, err := c.writeAndReadWSRPC(ctx, conn, "thread/resume", map[string]any{"threadId": codexThreadID}, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if resolved := codexThreadIDFromResult(thread); resolved != "" {
|
|
codexThreadID = resolved
|
|
c.rememberThread(sessionID, threadID, codexThreadID)
|
|
}
|
|
return c.startTurnOnWS(ctx, conn, codexThreadID, params, sink)
|
|
})
|
|
}
|
|
|
|
func (c *codexCompat) CloseSession(ctx context.Context, sessionID string) error {
|
|
c.mu.Lock()
|
|
delete(c.threads, sessionID)
|
|
c.mu.Unlock()
|
|
return nil
|
|
}
|
|
|
|
func (c *codexCompat) CancelSession(ctx context.Context, sessionID string) error {
|
|
codexThreadID := c.lookupThread(sessionID, "")
|
|
if codexThreadID == "" {
|
|
return nil
|
|
}
|
|
_, err := c.codexCall(ctx, "turn/interrupt", map[string]any{"threadId": codexThreadID}, nil)
|
|
return err
|
|
}
|
|
|
|
func (c *codexCompat) startTurn(ctx context.Context, codexThreadID string, params map[string]any, sink SessionNotificationSink) (map[string]any, error) {
|
|
result, err := c.codexCall(
|
|
ctx,
|
|
"turn/start",
|
|
map[string]any{
|
|
"threadId": codexThreadID,
|
|
"input": codexUserInput(params),
|
|
},
|
|
sink,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return c.finalizeCodexTurnResult(codexThreadID, result), nil
|
|
}
|
|
|
|
func (c *codexCompat) startTurnOnWS(ctx context.Context, conn *websocket.Conn, codexThreadID string, params map[string]any, sink SessionNotificationSink) (map[string]any, error) {
|
|
result, err := c.writeAndReadWSRPC(
|
|
ctx,
|
|
conn,
|
|
"turn/start",
|
|
map[string]any{
|
|
"threadId": codexThreadID,
|
|
"input": codexUserInput(params),
|
|
},
|
|
sink,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return c.finalizeCodexTurnResult(codexThreadID, result), nil
|
|
}
|
|
|
|
func (c *codexCompat) finalizeCodexTurnResult(codexThreadID string, result map[string]any) map[string]any {
|
|
if _, ok := result["output"]; !ok {
|
|
if summary := strings.TrimSpace(shared.StringArg(result, "summary", "")); summary != "" {
|
|
result["output"] = summary
|
|
}
|
|
}
|
|
result["providerThreadId"] = codexThreadID
|
|
if codexDisplayText(result) == "" && !isProviderFailureResult(result) {
|
|
result["success"] = false
|
|
result["error"] = "codex returned no displayable output"
|
|
result["message"] = "codex returned no displayable output"
|
|
}
|
|
return result
|
|
}
|
|
|
|
func codexDisplayText(result map[string]any) string {
|
|
for _, key := range []string{"output", "summary", "message"} {
|
|
if text := strings.TrimSpace(shared.StringArg(result, key, "")); text != "" && !isGenericHermesAckText(text) {
|
|
return text
|
|
}
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func isProviderFailureResult(result map[string]any) bool {
|
|
if result == nil {
|
|
return false
|
|
}
|
|
if value, ok := result["success"]; ok && !parseBool(value) {
|
|
return true
|
|
}
|
|
for _, key := range []string{"error", "errorMessage", "unavailableMessage"} {
|
|
if strings.TrimSpace(shared.StringArg(result, key, "")) != "" {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (c *codexCompat) codexCall(ctx context.Context, method string, params map[string]any, sink SessionNotificationSink) (map[string]any, error) {
|
|
if c.transport() != "ws" {
|
|
return c.rpcCall(ctx, method, params, sink)
|
|
}
|
|
return c.callWSRPCWithInitialize(ctx, method, params, sink)
|
|
}
|
|
|
|
func (c *codexCompat) callWSRPCWithInitialize(ctx context.Context, method string, params map[string]any, sink SessionNotificationSink) (map[string]any, error) {
|
|
return c.withInitializedCodexWS(ctx, func(conn *websocket.Conn) (map[string]any, error) {
|
|
return c.writeAndReadWSRPC(ctx, conn, method, params, sink)
|
|
})
|
|
}
|
|
|
|
func (c *codexCompat) withInitializedCodexWS(ctx context.Context, run func(*websocket.Conn) (map[string]any, error)) (map[string]any, error) {
|
|
headers := http.Header{}
|
|
if c.authHeader != "" {
|
|
headers.Set("Authorization", c.authHeader)
|
|
}
|
|
conn, _, err := websocket.DefaultDialer.DialContext(ctx, c.endpoint, headers)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer func() { _ = conn.Close() }()
|
|
|
|
if _, err := c.writeAndReadWSRPC(ctx, conn, "initialize", codexInitializeParams(), nil); err != nil {
|
|
return nil, err
|
|
}
|
|
return run(conn)
|
|
}
|
|
|
|
func (c *codexCompat) writeAndReadWSRPC(ctx context.Context, conn *websocket.Conn, method string, params map[string]any, sink SessionNotificationSink) (map[string]any, error) {
|
|
requestID := fmt.Sprintf("req-%d", time.Now().UnixNano())
|
|
request := map[string]any{
|
|
"jsonrpc": "2.0",
|
|
"id": requestID,
|
|
"method": method,
|
|
"params": params,
|
|
}
|
|
if err := conn.WriteJSON(request); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
collector := &externalACPNotificationCollector{}
|
|
var pendingTurn map[string]any
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
default:
|
|
}
|
|
_ = conn.SetReadDeadline(time.Now().Add(5 * time.Minute))
|
|
_, payload, err := conn.ReadMessage()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var decoded map[string]any
|
|
if err := json.Unmarshal(payload, &decoded); err != nil {
|
|
return nil, fmt.Errorf("failed to decode websocket rpc response: %w", err)
|
|
}
|
|
|
|
methodName := strings.TrimSpace(shared.StringArg(decoded, "method", ""))
|
|
if methodName != "" {
|
|
if isExternalPermissionRequest(methodName) {
|
|
_ = writeExternalPermissionApproval(conn, decoded)
|
|
continue
|
|
}
|
|
collector.observe(decoded)
|
|
if isExternalSessionUpdateMethod(methodName) && sink != nil {
|
|
update := shared.AsMap(decoded["params"])
|
|
if len(update) > 0 {
|
|
if structured := structuredExternalACPEvent(decoded); len(structured) > 0 {
|
|
update["structuredEvent"] = structured
|
|
update["eventType"] = structured["type"]
|
|
}
|
|
sink(update)
|
|
}
|
|
}
|
|
if pendingTurn != nil && isCodexTurnCompletedNotification(decoded, pendingTurn) {
|
|
return collector.apply(pendingTurn), nil
|
|
}
|
|
continue
|
|
}
|
|
|
|
if fmt.Sprintf("%v", decoded["id"]) != requestID {
|
|
continue
|
|
}
|
|
|
|
result, err := parseExternalRPCResult(decoded)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if method == "turn/start" && isCodexTurnInProgress(result) {
|
|
pendingTurn = collector.apply(result)
|
|
continue
|
|
}
|
|
return collector.apply(result), nil
|
|
}
|
|
}
|
|
|
|
func isCodexTurnInProgress(result map[string]any) bool {
|
|
if result == nil {
|
|
return false
|
|
}
|
|
turn := shared.AsMap(result["turn"])
|
|
if len(turn) == 0 {
|
|
return false
|
|
}
|
|
status := strings.TrimSpace(shared.StringArg(turn, "status", ""))
|
|
return status == "" || strings.EqualFold(status, "inProgress") || strings.EqualFold(status, "running")
|
|
}
|
|
|
|
func isCodexTurnCompletedNotification(notification map[string]any, pendingTurn map[string]any) bool {
|
|
if notification == nil || pendingTurn == nil {
|
|
return false
|
|
}
|
|
if strings.TrimSpace(shared.StringArg(notification, "method", "")) != "turn/completed" {
|
|
return false
|
|
}
|
|
params := shared.AsMap(notification["params"])
|
|
turn := shared.AsMap(params["turn"])
|
|
if len(turn) == 0 {
|
|
return true
|
|
}
|
|
pending := shared.AsMap(pendingTurn["turn"])
|
|
pendingID := strings.TrimSpace(shared.StringArg(pending, "id", ""))
|
|
completedID := strings.TrimSpace(shared.StringArg(turn, "id", ""))
|
|
if pendingID == "" || completedID == "" || pendingID == completedID {
|
|
pendingTurn["turn"] = turn
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (c *codexCompat) rememberThread(sessionID string, threadID string, codexThreadID string) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
if sessionID != "" {
|
|
c.threads[sessionID] = codexThreadID
|
|
}
|
|
if threadID != "" {
|
|
c.threads[threadID] = codexThreadID
|
|
}
|
|
}
|
|
|
|
func (c *codexCompat) lookupThread(sessionID string, threadID string) string {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
if sessionID != "" {
|
|
if value := strings.TrimSpace(c.threads[sessionID]); value != "" {
|
|
return value
|
|
}
|
|
}
|
|
if threadID != "" {
|
|
return strings.TrimSpace(c.threads[threadID])
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func codexThreadStartParams(params map[string]any) map[string]any {
|
|
result := map[string]any{}
|
|
if cwd := strings.TrimSpace(shared.StringArg(params, "workingDirectory", "")); cwd != "" {
|
|
result["cwd"] = cwd
|
|
}
|
|
return result
|
|
}
|
|
|
|
// codexInitializeParams returns a fresh "initialize" payload identifying this
// bridge as the client.
func codexInitializeParams() map[string]any {
	clientInfo := map[string]any{
		"name":    "xworkmate-bridge",
		"version": "1.1.0",
	}
	return map[string]any{"clientInfo": clientInfo}
}
|
|
|
|
func codexUserInput(params map[string]any) []any {
|
|
input := map[string]any{
|
|
"type": "text",
|
|
"text": shared.StringArg(params, "taskPrompt", ""),
|
|
}
|
|
if attachments := anyList(params["attachments"]); len(attachments) > 0 {
|
|
input["attachments"] = attachments
|
|
}
|
|
return []any{input}
|
|
}
|
|
|
|
func codexThreadIDFromResult(result map[string]any) string {
|
|
for _, key := range []string{"threadId", "id"} {
|
|
if value := strings.TrimSpace(shared.StringArg(result, key, "")); value != "" {
|
|
return value
|
|
}
|
|
}
|
|
thread := shared.AsMap(result["thread"])
|
|
for _, key := range []string{"id", "threadId"} {
|
|
if value := strings.TrimSpace(shared.StringArg(thread, key, "")); value != "" {
|
|
return value
|
|
}
|
|
}
|
|
return ""
|
|
}
|
|
|
|
// anyList normalizes a list-shaped value to []any. A []any is returned as is,
// and a []map[string]any is copied element by element into a new slice. Any
// other value yields nil.
func anyList(value any) []any {
	if list, ok := value.([]any); ok {
		return list
	}
	entries, ok := value.([]map[string]any)
	if !ok {
		return nil
	}
	converted := make([]any, len(entries))
	for i := range entries {
		converted[i] = entries[i]
	}
	return converted
}
|
|
|
|
func (c *externalACPCompat) StartSession(ctx context.Context, sessionID string, threadID string, params map[string]any, sink SessionNotificationSink) (map[string]any, error) {
|
|
return c.rpcCall(ctx, "session.start", params, sink)
|
|
}
|
|
|
|
func (c *externalACPCompat) SendMessage(ctx context.Context, sessionID string, threadID string, params map[string]any, sink SessionNotificationSink) (map[string]any, error) {
|
|
return c.rpcCall(ctx, "session.message", params, sink)
|
|
}
|
|
|
|
func (c *externalACPCompat) CancelSession(ctx context.Context, sessionID string) error {
|
|
_, err := c.rpcCall(ctx, "session.cancel", map[string]any{"sessionId": sessionID}, nil)
|
|
return err
|
|
}
|
|
|
|
func (c *externalACPCompat) CloseSession(ctx context.Context, sessionID string) error {
|
|
_, err := c.rpcCall(ctx, "session.close", map[string]any{"sessionId": sessionID}, nil)
|
|
return err
|
|
}
|
|
|
|
func (c *externalACPCompat) rpcCall(ctx context.Context, method string, params map[string]any, sink SessionNotificationSink) (map[string]any, error) {
|
|
switch c.transport() {
|
|
case "ws":
|
|
return c.callWSRPC(ctx, method, params, sink)
|
|
default:
|
|
return c.callHTTPRPC(ctx, method, params)
|
|
}
|
|
}
|
|
|
|
func (c *externalACPCompat) transport() string {
|
|
parsed, err := url.Parse(strings.TrimSpace(c.endpoint))
|
|
if err != nil {
|
|
return "http"
|
|
}
|
|
switch strings.ToLower(parsed.Scheme) {
|
|
case "ws", "wss":
|
|
return "ws"
|
|
default:
|
|
return "http"
|
|
}
|
|
}
|
|
|
|
func (c *externalACPCompat) callHTTPRPC(ctx context.Context, method string, params map[string]any) (map[string]any, error) {
|
|
requestID := fmt.Sprintf("req-%d", time.Now().UnixNano())
|
|
requestBody, _ := json.Marshal(map[string]any{
|
|
"jsonrpc": "2.0",
|
|
"id": requestID,
|
|
"method": method,
|
|
"params": params,
|
|
})
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpoint, bytes.NewReader(requestBody))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req.Header.Set("Content-Type", "application/json; charset=utf-8")
|
|
req.Header.Set("Accept", "application/json")
|
|
if c.authHeader != "" {
|
|
req.Header.Set("Authorization", c.authHeader)
|
|
}
|
|
|
|
response, err := c.client.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer func() {
|
|
_ = response.Body.Close()
|
|
}()
|
|
|
|
if response.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(io.LimitReader(response.Body, 2048))
|
|
return nil, fmt.Errorf("rpc request failed (%d): %s", response.StatusCode, strings.TrimSpace(string(body)))
|
|
}
|
|
|
|
var decoded map[string]any
|
|
if err := json.NewDecoder(response.Body).Decode(&decoded); err != nil {
|
|
return nil, fmt.Errorf("failed to decode rpc response: %w", err)
|
|
}
|
|
return parseExternalRPCResult(decoded)
|
|
}
|
|
|
|
func (c *externalACPCompat) callWSRPC(ctx context.Context, method string, params map[string]any, sink SessionNotificationSink) (map[string]any, error) {
|
|
headers := http.Header{}
|
|
if c.authHeader != "" {
|
|
headers.Set("Authorization", c.authHeader)
|
|
}
|
|
conn, _, err := websocket.DefaultDialer.DialContext(ctx, c.endpoint, headers)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer func() { _ = conn.Close() }()
|
|
|
|
requestID := fmt.Sprintf("req-%d", time.Now().UnixNano())
|
|
request := map[string]any{
|
|
"jsonrpc": "2.0",
|
|
"id": requestID,
|
|
"method": method,
|
|
"params": params,
|
|
}
|
|
if err := conn.WriteJSON(request); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
collector := &externalACPNotificationCollector{}
|
|
for {
|
|
_, payload, err := conn.ReadMessage()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var decoded map[string]any
|
|
if err := json.Unmarshal(payload, &decoded); err != nil {
|
|
return nil, fmt.Errorf("failed to decode websocket rpc response: %w", err)
|
|
}
|
|
|
|
methodName := strings.TrimSpace(shared.StringArg(decoded, "method", ""))
|
|
if methodName != "" {
|
|
if isExternalPermissionRequest(methodName) {
|
|
_ = writeExternalPermissionApproval(conn, decoded)
|
|
continue
|
|
}
|
|
collector.observe(decoded)
|
|
if isExternalSessionUpdateMethod(methodName) && sink != nil {
|
|
update := shared.AsMap(decoded["params"])
|
|
if len(update) > 0 {
|
|
if structured := structuredExternalACPEvent(decoded); len(structured) > 0 {
|
|
update["structuredEvent"] = structured
|
|
update["eventType"] = structured["type"]
|
|
}
|
|
sink(update)
|
|
}
|
|
}
|
|
continue
|
|
}
|
|
|
|
if fmt.Sprintf("%v", decoded["id"]) != requestID {
|
|
continue
|
|
}
|
|
|
|
result, err := parseExternalRPCResult(decoded)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return collector.apply(result), nil
|
|
}
|
|
}
|
|
|
|
// isExternalSessionUpdateMethod reports whether a server notification method
// should be forwarded as a session update. That covers the explicit names
// "session.update", "acp.session.update" and "session/update", plus anything
// in the "item/" or "turn/" namespaces.
//
// Surrounding whitespace is ignored for every check. Previously only the
// exact-name comparison was trimmed, so " item/started" was misclassified.
func isExternalSessionUpdateMethod(method string) bool {
	normalized := strings.TrimSpace(method)
	switch normalized {
	case "session.update", "acp.session.update", "session/update":
		return true
	}
	return strings.HasPrefix(normalized, "item/") || strings.HasPrefix(normalized, "turn/")
}
|
|
|
|
// isExternalPermissionRequest reports whether method (whitespace-trimmed) is
// one of the recognized permission-request spellings.
func isExternalPermissionRequest(method string) bool {
	switch strings.TrimSpace(method) {
	case "session/request_permission", "session.request_permission", "request_permission":
		return true
	default:
		return false
	}
}
|
|
|
|
func writeExternalPermissionApproval(conn *websocket.Conn, request map[string]any) error {
|
|
if conn == nil || request == nil || request["id"] == nil {
|
|
return nil
|
|
}
|
|
return conn.WriteJSON(map[string]any{
|
|
"jsonrpc": "2.0",
|
|
"id": request["id"],
|
|
"result": map[string]any{
|
|
"approved": true,
|
|
"decision": "approved",
|
|
"behavior": "allow",
|
|
},
|
|
})
|
|
}
|
|
|
|
func parseExternalRPCResult(decoded map[string]any) (map[string]any, error) {
|
|
if decoded == nil {
|
|
return map[string]any{}, nil
|
|
}
|
|
if errPayload := shared.AsMap(decoded["error"]); len(errPayload) > 0 {
|
|
message := strings.TrimSpace(shared.StringArg(errPayload, "message", "upstream rpc error"))
|
|
if message == "" {
|
|
message = "upstream rpc error"
|
|
}
|
|
return nil, fmt.Errorf("%s", message)
|
|
}
|
|
result := shared.AsMap(decoded["result"])
|
|
if len(result) > 0 {
|
|
return result, nil
|
|
}
|
|
if ok, _ := decoded["ok"].(bool); ok {
|
|
payload := shared.AsMap(decoded["payload"])
|
|
if len(payload) > 0 {
|
|
return payload, nil
|
|
}
|
|
}
|
|
return map[string]any{}, nil
|
|
}
|