diff --git a/go.mod b/go.mod index 526e640..4aec15b 100644 --- a/go.mod +++ b/go.mod @@ -5,3 +5,28 @@ go 1.25.0 require github.com/gorilla/websocket v1.5.3 require gopkg.in/yaml.v3 v3.0.1 + +require ( + github.com/google/uuid v1.6.0 // indirect + github.com/pion/datachannel v1.6.0 // indirect + github.com/pion/dtls/v3 v3.1.3 // indirect + github.com/pion/ice/v4 v4.2.7 // indirect + github.com/pion/interceptor v0.1.45 // indirect + github.com/pion/logging v0.2.4 // indirect + github.com/pion/mdns/v2 v2.1.0 // indirect + github.com/pion/randutil v0.1.0 // indirect + github.com/pion/rtcp v1.2.16 // indirect + github.com/pion/rtp v1.10.2 // indirect + github.com/pion/sctp v1.10.0 // indirect + github.com/pion/sdp/v3 v3.0.18 // indirect + github.com/pion/srtp/v3 v3.0.11 // indirect + github.com/pion/stun/v3 v3.1.4 // indirect + github.com/pion/transport/v4 v4.0.2 // indirect + github.com/pion/turn/v5 v5.0.7 // indirect + github.com/pion/webrtc/v4 v4.2.14 // indirect + github.com/wlynxg/anet v0.0.5 // indirect + golang.org/x/crypto v0.48.0 // indirect + golang.org/x/net v0.50.0 // indirect + golang.org/x/sys v0.41.0 // indirect + golang.org/x/time v0.14.0 // indirect +) diff --git a/go.sum b/go.sum index 6a2c68d..a1e6515 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,49 @@ +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/pion/datachannel v1.6.0 h1:XecBlj+cvsxhAMZWFfFcPyUaDZtd7IJvrXqlXD/53i0= +github.com/pion/datachannel v1.6.0/go.mod h1:ur+wzYF8mWdC+Mkis5Thosk+u/VOL287apDNEbFpsIk= +github.com/pion/dtls/v3 v3.1.3 h1:OA6J5UCeA8DvRXD8ofaMnlNPXN3ISBLHHJ9P8SWL09E= +github.com/pion/dtls/v3 v3.1.3/go.mod h1:GEwid4EzCcakfrNvHXM7bs6ci2mASI5Y5Q4tbtLFuWs= +github.com/pion/ice/v4 v4.2.7 
h1:zDEbC6MiEdhQpF8TxBOTws+NU6ZgGpveHrQq4Lc1kao= +github.com/pion/ice/v4 v4.2.7/go.mod h1:9SNPaq0c7El/ki8leJzyCkK10zsskprR3zTNbO3monY= +github.com/pion/interceptor v0.1.45 h1:6PUo/5829bIfRFIPPJQzuDn8EjxRTSB/CSD7QVCOaqo= +github.com/pion/interceptor v0.1.45/go.mod h1:gNDYM/uFKcLe/B3gS2/7+aw6z+RDiMy2qKTnF1LO31w= +github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8= +github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so= +github.com/pion/mdns/v2 v2.1.0 h1:3IJ9+Xio6tWYjhN6WwuY142P/1jA0D5ERaIqawg/fOY= +github.com/pion/mdns/v2 v2.1.0/go.mod h1:pcez23GdynwcfRU1977qKU0mDxSeucttSHbCSfFOd9A= +github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= +github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= +github.com/pion/rtcp v1.2.16 h1:fk1B1dNW4hsI78XUCljZJlC4kZOPk67mNRuQ0fcEkSo= +github.com/pion/rtcp v1.2.16/go.mod h1:/as7VKfYbs5NIb4h6muQ35kQF/J0ZVNz2Z3xKoCBYOo= +github.com/pion/rtp v1.10.2 h1:l+f6tTDcAH6xwepaAoW791ddhuYsJlqRATOzirO04Mo= +github.com/pion/rtp v1.10.2/go.mod h1:Au8fc6cEByy8RLTwKTQTEeQqDB/SJDxwL4mZuxYA5Pk= +github.com/pion/sctp v1.10.0 h1:qeoD6swF/2M5bYRcAGayqSbTKX3m4AW29CiQxG1+Pfg= +github.com/pion/sctp v1.10.0/go.mod h1:N20Dq6LY+JvJDAh9VVh1JELngb2rQ8dPgds5yBWiPgw= +github.com/pion/sdp/v3 v3.0.18 h1:l0bAXazKHpepazVdp+tPYnrsy9dfh7ZbT8DxesH5ZnI= +github.com/pion/sdp/v3 v3.0.18/go.mod h1:ZREGo6A9ZygQ9XkqAj5xYCQtQpif0i6Pa81HOiAdqQ8= +github.com/pion/srtp/v3 v3.0.11 h1:GiESUr54/K4UuPigfq/CvWUed80JenQAHXn0C2MQQIQ= +github.com/pion/srtp/v3 v3.0.11/go.mod h1:EeZOi/sd6glM1EXapg051gdNWO9yWT1YSsgQ4SlJkns= +github.com/pion/stun/v3 v3.1.4 h1:/7ZL0j0dmLroKOq4GfkyKQ6asByYqntwyHSp5sYLcGY= +github.com/pion/stun/v3 v3.1.4/go.mod h1:ET7PFiXo1nrD2ZNVpbEHDuT0kCPVXhKmyWdiePNMw/U= +github.com/pion/transport/v4 v4.0.2 h1:ifYlPqNwsy6aKQ9y8yzxXlHae5431ZrH2avkD/Rn6Tk= +github.com/pion/transport/v4 v4.0.2/go.mod h1:06hFI+jCFcok2X2MekVufNZ/uzNZXivGBPfviSVcjgM= 
+github.com/pion/turn/v5 v5.0.7 h1:cA4zPYZR/tS1qZqOi5myHSQ+cwPENCvY8T/wMloP8Tg= +github.com/pion/turn/v5 v5.0.7/go.mod h1:1VwvxElZaOdJU0liJ/WUSm/Tsh+n2OxS5ISSDxgOWxU= +github.com/pion/webrtc/v4 v4.2.14 h1:Q6zMs+fSDsYuhZcNlvFGBxCOMHVV9oYcDa6O9/HIGTc= +github.com/pion/webrtc/v4 v4.2.14/go.mod h1:87NVKP86+g4OMrRxWhjWfUjeXP4JrV6RTlUrIW+/Jak= +github.com/wlynxg/anet v0.0.5 h1:J3VJGi1gvo0JwZ/P1/Yc/8p63SoW98B5dHkYDmpgvvU= +github.com/wlynxg/anet v0.0.5/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA= +golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= +golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= +golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60= +golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM= +golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= +golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= +golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/internal/acp/config.go b/internal/acp/config.go index d20f72f..c4229e9 100644 --- a/internal/acp/config.go +++ b/internal/acp/config.go @@ -76,9 +76,10 @@ type DistributedRouteConfig struct { } type OpenClawGatewayConfig struct { - MaxActive *int `yaml:"max_active"` - MaxQueued *int `yaml:"max_queued"` - QueueTimeout string `yaml:"queue_timeout"` + MaxActive *int `yaml:"max_active"` + MaxQueued *int `yaml:"max_queued"` + QueueTimeout string `yaml:"queue_timeout"` + MaxAllowedSilentDuration string `yaml:"max_allowed_silent_duration"` } func loadBridgeConfig() 
*BridgeConfig { diff --git a/internal/acp/execution_test.go b/internal/acp/execution_test.go index 0fee1f9..260989d 100644 --- a/internal/acp/execution_test.go +++ b/internal/acp/execution_test.go @@ -5,10 +5,14 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "os" + "path/filepath" "strings" "testing" + "time" "github.com/gorilla/websocket" + "xworkmate-bridge/internal/shared" ) func TestResolveSingleAgentForwardEndpointFromExampleConfig(t *testing.T) { @@ -557,3 +561,116 @@ func TestExternalACPNotificationCollectorIgnoresCodexCommentaryMessages(t *testi t.Fatalf("expected commentary to be hidden and duplicate final line collapsed, got %#v", result) } } + +func TestProbeOpenClawTaskFailsAfterMaxAllowedSilentDuration(t *testing.T) { + t.Setenv("XWORKMATE_BRIDGE_OPENCLAW_GATEWAY_MAX_SILENT_DURATION", "2s") + t.Setenv("BRIDGE_CONFIG_PATH", filepath.Join(t.TempDir(), "missing-config.yaml")) + + server := NewServer() + orchestrator := server.orchestrator + sess := server.getOrCreateSession("silent-session", "silent-thread") + startedAt := time.Now().Add(-time.Minute) + sess.mu.Lock() + sess.task = QueuedTask{ + SessionID: "silent-session", + ThreadID: "silent-thread", + TurnID: "silent-turn", + RunID: "silent-run", + SessionKey: "silent-session", + GatewayProviderID: "openclaw", + State: TaskStateRunning, + Kind: TaskKindGateway, + RuntimeBudgetMinutes: openClawLongTaskMinutes, + StartedAt: startedAt, + DeadlineAt: time.Now().Add(time.Minute), + } + sess.openClaw = &OpenClawTaskRecord{ + SessionID: "silent-session", + ThreadID: "silent-thread", + TurnID: "silent-turn", + RunID: "silent-run", + SessionKey: "silent-session", + GatewayProviderID: "openclaw", + TaskLoadClass: "long_task", + RuntimeBudgetMinutes: openClawLongTaskMinutes, + StartedAt: startedAt, + DeadlineAt: time.Now().Add(time.Minute), + FirstSilentFailureAt: time.Now().Add(-3 * time.Second), + } + sess.mu.Unlock() + + result := orchestrator.probeOpenClawTask(context.Background(), sess, nil) + 
+ if got := result["status"]; got != string(TaskStateFailed) { + t.Fatalf("expected failed status after silent duration, got %#v", result) + } + if got := result["code"]; got != "OPENCLAW_GATEWAY_LOST" { + t.Fatalf("expected OPENCLAW_GATEWAY_LOST, got %#v", result) + } + sess.mu.Lock() + state := sess.task.State + sess.mu.Unlock() + if state != TaskStateFailed { + t.Fatalf("task state = %s, want %s", state, TaskStateFailed) + } +} + +func TestTerminalOpenClawTaskRemovesInlineAttachmentDirectory(t *testing.T) { + workspace := t.TempDir() + turnID := "turn-inline-gc" + chatParams, rpcErr := openClawChatSendParams(map[string]any{ + "threadId": "thread-inline-gc", + "taskPrompt": "inspect uploaded file", + "workingDirectory": workspace, + "inlineAttachments": []any{ + map[string]any{ + "name": "note.txt", + "mimeType": "text/plain", + "content": "bm90ZQ==", + }, + }, + }, turnID) + if rpcErr != nil { + t.Fatalf("expected chat params, got rpc error: %#v", rpcErr) + } + attachments := shared.ListArg(chatParams, "attachments") + if len(attachments) != 1 { + t.Fatalf("expected materialized attachment, got %#v", attachments) + } + attachmentPath := shared.StringArg(shared.AsMap(attachments[0]), "path", "") + attachmentDirectory := filepath.Dir(attachmentPath) + if _, err := os.Stat(attachmentDirectory); err != nil { + t.Fatalf("expected attachment directory before terminal task state: %v", err) + } + + server := NewServer() + sess := server.getOrCreateSession("gc-session", "gc-thread") + now := time.Now() + sess.mu.Lock() + sess.task = QueuedTask{ + SessionID: "gc-session", + ThreadID: "gc-thread", + TurnID: turnID, + RunID: "gc-run", + State: TaskStateRunning, + Kind: TaskKindGateway, + StartedAt: now, + } + sess.openClaw = &OpenClawTaskRecord{ + SessionID: "gc-session", + ThreadID: "gc-thread", + TurnID: turnID, + RunID: "gc-run", + StartedAt: now, + ChatParams: map[string]any{ + "workingDirectory": workspace, + }, + } + sess.mu.Unlock() + + 
server.orchestrator.failOpenClawTask(sess, "TEST_FAILED", "terminal") + + if _, err := os.Stat(attachmentDirectory); !os.IsNotExist(err) { + t.Fatalf("expected terminal task to remove attachment directory, stat err=%v", err) + } +} diff --git a/internal/acp/gateway_runtime_test.go b/internal/acp/gateway_runtime_test.go index 500bb71..29615af 100644 --- a/internal/acp/gateway_runtime_test.go +++ b/internal/acp/gateway_runtime_test.go @@ -64,3 +64,56 @@ func TestResolveGatewayReportedRemoteAddressNormalizesExplicitPublicRemoteHost( t.Fatalf("resolveGatewayReportedRemoteAddress() = %q, want %q", got, want) } } + +func TestReassociateOpenClawTaskDerivesRuntimeBudgetWithoutExplicitBudget(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + params map[string]any + want int + }{ + { + name: "short task load class", + params: map[string]any{ + "runId": "run-short", + "artifactScope": "tasks/main/run-short", + "taskLoadClass": "short_task", + }, + want: openClawShortTaskMinutes, + }, + { + name: "required final artifact", + params: map[string]any{ + "runId": "run-pdf", + "artifactScope": "tasks/main/run-pdf", + "requiredArtifactExtensions": []any{"pdf"}, + }, + want: openClawLongTaskMinutes, + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + server := NewServer() + sess := server.reassociateOpenClawTask(tc.params) + if sess == nil { + t.Fatal("expected reassociated session") + } + sess.mu.Lock() + gotTaskBudget := sess.task.RuntimeBudgetMinutes + gotRecordBudget := sess.openClaw.RuntimeBudgetMinutes + sess.mu.Unlock() + + if gotTaskBudget != tc.want { + t.Fatalf("task RuntimeBudgetMinutes = %d, want %d", gotTaskBudget, tc.want) + } + if gotRecordBudget != tc.want { + t.Fatalf("record RuntimeBudgetMinutes = %d, want %d", gotRecordBudget, tc.want) + } + }) + } +} diff --git a/internal/acp/openclaw_async_tasks.go b/internal/acp/openclaw_async_tasks.go index b129ef7..66dba3d 100644 --- 
a/internal/acp/openclaw_async_tasks.go +++ b/internal/acp/openclaw_async_tasks.go @@ -2,6 +2,8 @@ package acp import ( "context" + "os" + "path/filepath" "strings" "time" @@ -10,12 +12,13 @@ import ( ) const ( - openClawTaskProbeTimeout = 2 * time.Second - openClawTaskProbeTimeoutMs = 1000 - openClawTaskMonitorInterval = time.Second - openClawShortTaskMinutes = 10 - openClawLongTaskMinutes = 30 - openClawComplexTaskMinutes = 60 + openClawTaskProbeTimeout = 2 * time.Second + openClawTaskProbeTimeoutMs = 1000 + openClawTaskMonitorInterval = time.Second + openClawShortTaskMinutes = 10 + openClawLongTaskMinutes = 30 + openClawComplexTaskMinutes = 60 + openClawDefaultMaxAllowedSilentDuration = 10 * time.Minute ) type OpenClawTaskRecord struct { @@ -34,6 +37,7 @@ type OpenClawTaskRecord struct { ProgressStage string ProgressMessage string ProgressTerminal bool + FirstSilentFailureAt time.Time ChatParams map[string]any PreparedArtifact *openClawPreparedArtifactScope ArtifactContract openClawArtifactContract @@ -50,7 +54,22 @@ func openClawTaskRuntimePolicy(params map[string]any, chatParams map[string]any, message = openClawCurrentTurnMessage(params) } lower := strings.ToLower(message) + taskLoadClass := strings.TrimSpace(contract.TaskLoadClass) + if taskLoadClass == "" { + taskLoadClass = strings.TrimSpace(shared.StringArg(params, "taskLoadClass", "")) + } metadataClass := strings.TrimSpace(shared.StringArg(shared.AsMap(params["metadata"]), "taskLoadClass", "")) + if taskLoadClass == "" { + taskLoadClass = metadataClass + } + switch taskLoadClass { + case "short_task": + return "short_task", openClawShortTaskMinutes + case "long_task": + return "long_task", openClawLongTaskMinutes + case "complex_chain_task", "complex_long_chain_task": + return "complex_chain_task", openClawComplexTaskMinutes + } if metadataClass == "complex_long_chain_task" || contract.ComplexLongChain || openClawMessageContainsAny(lower, []string{ "复杂链路", "多章节", "每章", "拆章节", "汇总排版", "gpt images", 
"images2", "image generation", "视频", "渲染", "hyperframes", "remotion", "ffmpeg", }) { @@ -194,6 +213,7 @@ func (o *SessionOrchestrator) failOpenClawTask(sess *session, code string, messa turnID := sess.task.TurnID runID := sess.task.RunID gatewayProviderID := sess.task.GatewayProviderID + record := sess.openClaw sess.task.State = TaskStateFailed sess.task.UpdatedAt = time.Now() sess.task.ProgressStage = "failed" @@ -221,6 +241,7 @@ func (o *SessionOrchestrator) failOpenClawTask(sess *session, code string, messa sess.lastResult = cloneMap(result) sess.mu.Unlock() o.releaseOpenClawAdmission(sess) + cleanupOpenClawTurnAttachments(record) return result } @@ -295,8 +316,17 @@ func (o *SessionOrchestrator) probeOpenClawTask(ctx context.Context, sess *sessi ) if !waitResult.OK { if openClawProbeStillRunning(waitResult.Error) { + now := time.Now() sess.mu.Lock() if sess.openClaw != nil { + if sess.openClaw.FirstSilentFailureAt.IsZero() { + sess.openClaw.FirstSilentFailureAt = now + } + if openClawSilentFailureExceeded(o.server.config, sess.openClaw.FirstSilentFailureAt, now) { + sess.openClaw.ProbeInFlight = false + sess.mu.Unlock() + return o.failOpenClawTask(sess, "OPENCLAW_GATEWAY_LOST", "OpenClaw gateway stayed unreachable beyond the allowed silent duration") + } sess.openClaw.ProgressStage = "running" sess.openClaw.ProgressMessage = "OpenClaw task is still running" sess.openClaw.ProbeInFlight = false @@ -313,9 +343,34 @@ func (o *SessionOrchestrator) probeOpenClawTask(ctx context.Context, sess *sessi message := strings.TrimSpace(shared.StringArg(waitResult.Error, "message", "openclaw wait failed")) return o.failOpenClawTask(sess, code, message) } + sess.mu.Lock() + if sess.openClaw != nil { + sess.openClaw.FirstSilentFailureAt = time.Time{} + } + sess.mu.Unlock() return o.completeOpenClawTask(sess, shared.AsMap(waitResult.Payload), collector, notify) } +func openClawSilentFailureExceeded(config *BridgeConfig, firstFailureAt time.Time, now time.Time) bool { + if 
firstFailureAt.IsZero() { + return false + } + return now.Sub(firstFailureAt) >= openClawMaxAllowedSilentDuration(config) +} + +func openClawMaxAllowedSilentDuration(config *BridgeConfig) time.Duration { + raw := strings.TrimSpace(shared.EnvOrDefault("XWORKMATE_BRIDGE_OPENCLAW_GATEWAY_MAX_SILENT_DURATION", "")) + if raw == "" && config != nil { + raw = strings.TrimSpace(config.OpenClawGateway.MaxAllowedSilentDuration) + } + if raw != "" { + if parsed, err := time.ParseDuration(raw); err == nil && parsed > 0 { + return parsed + } + } + return openClawDefaultMaxAllowedSilentDuration +} + func openClawProbeStillRunning(errorPayload map[string]any) bool { code := strings.TrimSpace(strings.ToUpper(shared.StringArg(errorPayload, "code", ""))) if code == "TIMEOUT" || code == "RPC_TIMEOUT" || code == "REQUEST_TIMEOUT" || @@ -445,12 +500,50 @@ func (o *SessionOrchestrator) completeOpenClawTask( sess.lastResult = cloneMap(result) sess.mu.Unlock() o.releaseOpenClawAdmission(sess) + cleanupOpenClawTurnAttachments(record) if notify != nil { notify(shared.NotificationEnvelope("session.update", openClawGatewayCompletedResultUpdate(record.SessionID, record.ThreadID, record.TurnID, result))) } return result } +func cleanupOpenClawTurnAttachments(record *OpenClawTaskRecord) { + if record == nil { + return + } + workingDirectory := strings.TrimSpace(shared.StringArg(record.ChatParams, "workingDirectory", "")) + if workingDirectory == "" { + return + } + attachmentDirectory := filepath.Join( + workingDirectory, + ".xworkmate", + "attachments", + safeOpenClawAttachmentPathSegment(record.TurnID, "turn"), + ) + if !openClawSafeAttachmentCleanupPath(workingDirectory, attachmentDirectory) { + return + } + _ = os.RemoveAll(attachmentDirectory) +} + +func openClawSafeAttachmentCleanupPath(workingDirectory string, attachmentDirectory string) bool { + workingRoot, err := filepath.Abs(strings.TrimSpace(workingDirectory)) + if err != nil || workingRoot == "" { + return false + } + attachmentRoot 
:= filepath.Join(workingRoot, ".xworkmate", "attachments") + target, err := filepath.Abs(strings.TrimSpace(attachmentDirectory)) + if err != nil || target == "" { + return false + } + rel, err := filepath.Rel(attachmentRoot, target) + if err != nil || rel == "." || strings.HasPrefix(rel, "..") || filepath.IsAbs(rel) { + return false + } + return true +} + func openClawSessionSnapshotLocked(sess *session) map[string]any { payload := map[string]any{ "status": string(sess.task.State), diff --git a/internal/acp/rpc_handler.go b/internal/acp/rpc_handler.go index 53f2f92..7e7e683 100644 --- a/internal/acp/rpc_handler.go +++ b/internal/acp/rpc_handler.go @@ -2,10 +2,14 @@ package acp import ( "context" + "encoding/json" "fmt" "strings" "time" + "xworkmate-bridge/internal/desktop" "xworkmate-bridge/internal/shared" + + "github.com/pion/webrtc/v4" ) func (s *Server) handleRequest(request shared.RPCRequest, notify func(map[string]any)) (map[string]any, *shared.RPCError) { @@ -56,6 +60,9 @@ func (s *Server) handleRequest(request shared.RPCRequest, notify func(map[string // Gateway 语义由专门的 Gateway 组件通过 Adapter 处理 return s.handleGatewayMethod(ctx, method, request.Params, notify) + case "xworkmate.desktop.offer", "xworkmate.desktop.ice", "xworkmate.desktop.close": + return s.handleDesktopMethod(ctx, method, request.Params, notify) + case "xworkmate.jobs.submit", "xworkmate.jobs.get", "xworkmate.jobs.list", "xworkmate.jobs.stats": return s.handleJobMethod(ctx, method, request.Params, notify) @@ -175,7 +182,6 @@ func (s *Server) reassociateOpenClawTask(params map[string]any) *session { turnID := strings.TrimSpace(shared.StringArg(params, "turnId", runID)) sessionKey := strings.TrimSpace(shared.StringArg(params, "sessionKey", threadID)) gatewayProvider := strings.TrimSpace(shared.StringArg(params, "gatewayProviderId", "openclaw")) - budget := shared.IntArg(shared.StringArg(params, "runtimeBudgetMinutes", ""), openClawComplexTaskMinutes) now := time.Now() prepared := 
&openClawPreparedArtifactScope{ ArtifactScope: artifactScope, @@ -190,6 +196,10 @@ func (s *Server) reassociateOpenClawTask(params map[string]any) *session { ExpectedArtifactExtensions: normalizeOpenClawExtensionList(shared.ListArg(params, "expectedArtifactExtensions")), RequiredFinalExtensions: normalizeOpenClawExtensionList(shared.ListArg(params, "requiredArtifactExtensions")), } + taskLoadClass, budget := openClawTaskRuntimePolicy(params, map[string]any{"sessionKey": sessionKey}, contract) + if explicitBudget := shared.IntArg(shared.StringArg(params, "runtimeBudgetMinutes", ""), 0); explicitBudget > 0 { + budget = explicitBudget + } sess := s.getOrCreateSession(sessionID, threadID) sess.mu.Lock() sess.provider = gatewayProvider @@ -206,7 +216,7 @@ func (s *Server) reassociateOpenClawTask(params map[string]any) *session { GatewayProviderID: gatewayProvider, State: TaskStateRunning, Kind: TaskKindGateway, - TaskLoadClass: contract.TaskLoadClass, + TaskLoadClass: taskLoadClass, ArtifactScope: artifactScope, ArtifactDirectory: prepared.ArtifactDirectory, RuntimeBudgetMinutes: budget, @@ -223,7 +233,7 @@ func (s *Server) reassociateOpenClawTask(params map[string]any) *session { RunID: runID, SessionKey: sessionKey, GatewayProviderID: gatewayProvider, - TaskLoadClass: contract.TaskLoadClass, + TaskLoadClass: taskLoadClass, ArtifactSinceUnixMs: 0, RuntimeBudgetMinutes: budget, StartedAt: now, @@ -254,15 +264,99 @@ func (s *Server) cancelSession(ctx context.Context, sessionID string) { func (s *Server) closeSession(ctx context.Context, sessionID string) bool { s.mu.Lock() - sess, ok := s.sessions[sessionID] - delete(s.sessions, sessionID) - s.mu.Unlock() - if ok && sess != nil && sess.compat != nil { - sess.mu.Lock() - sess.task.State = TaskStateCancelled - sess.task.UpdatedAt = time.Now() - sess.mu.Unlock() - _ = sess.compat.CloseSession(ctx, sessionID) + _, existed := s.sessions[sessionID] + if existed { + delete(s.sessions, sessionID) } - return ok + s.mu.Unlock() + 
return existed } + +func (s *Server) handleDesktopMethod(ctx context.Context, method string, params map[string]any, notify func(map[string]any)) (map[string]any, *shared.RPCError) { + sessionID := strings.TrimSpace(shared.StringArg(params, "sessionId", "")) + if sessionID == "" { + sessionID = "default" + } + + srv := desktop.GetService() + + switch method { + case "xworkmate.desktop.offer": + sdpOffer := strings.TrimSpace(shared.StringArg(params, "sdpOffer", "")) + if sdpOffer == "" { + return nil, &shared.RPCError{Code: -32602, Message: "sdpOffer is required"} + } + + display := strings.TrimSpace(shared.StringArg(params, "display", "")) + width := shared.IntArg(shared.StringArg(params, "width", ""), 1280) + height := shared.IntArg(shared.StringArg(params, "height", ""), 720) + fps := shared.IntArg(shared.StringArg(params, "fps", ""), 30) + bitrate := shared.IntArg(shared.StringArg(params, "bitrate", ""), 2000) + useGPU := shared.BoolArg(shared.StringArg(params, "useGpu", ""), false) + + var iceServers []string + if rawIce, ok := params["iceServers"].([]any); ok { + for _, ice := range rawIce { + if s, ok := ice.(string); ok { + iceServers = append(iceServers, s) + } + } + } + + cfg := desktop.PipelineConfig{ + Display: display, + Port: 5004, + Width: width, + Height: height, + FPS: fps, + Bitrate: bitrate, + UseGPU: useGPU, + ToolType: "auto", + } + + sess, err := srv.StartSession(sessionID, cfg, iceServers) + if err != nil { + return nil, &shared.RPCError{Code: -32001, Message: fmt.Sprintf("failed to start desktop session: %v", err)} + } + + sdpAnswer, err := sess.WebRTC.ProcessOffer(sdpOffer) + if err != nil { + srv.StopSession(sessionID) + return nil, &shared.RPCError{Code: -32002, Message: fmt.Sprintf("failed to process SDP offer: %v", err)} + } + + return map[string]any{ + "sessionId": sessionID, + "sdpAnswer": sdpAnswer, + }, nil + + case "xworkmate.desktop.ice": + candidateData, ok := params["candidate"].(map[string]any) + if !ok { + return nil, 
&shared.RPCError{Code: -32602, Message: "candidate object is required"} + } + + var candidate webrtc.ICECandidateInit + bytes, err := json.Marshal(candidateData) + if err != nil { + return nil, &shared.RPCError{Code: -32602, Message: fmt.Sprintf("failed to marshal candidate: %v", err)} + } + if err := json.Unmarshal(bytes, &candidate); err != nil { + return nil, &shared.RPCError{Code: -32602, Message: fmt.Sprintf("failed to unmarshal candidate: %v", err)} + } + + if err := srv.AddICECandidate(sessionID, candidate); err != nil { + return nil, &shared.RPCError{Code: -32003, Message: fmt.Sprintf("failed to add ICE candidate: %v", err)} + } + + return map[string]any{"status": "ok"}, nil + + case "xworkmate.desktop.close": + srv.StopSession(sessionID) + return map[string]any{"status": "closed"}, nil + + default: + return nil, &shared.RPCError{Code: -32601, Message: fmt.Sprintf("unknown desktop method: %s", method)} + } +} + diff --git a/internal/desktop/input.go b/internal/desktop/input.go new file mode 100644 index 0000000..b65a832 --- /dev/null +++ b/internal/desktop/input.go @@ -0,0 +1,219 @@ +package desktop + +import ( + "fmt" + "io" + "log" + "os/exec" + "strconv" + "strings" + "sync" +) + +// InputEvent represents a client mouse or keyboard action +type InputEvent struct { + Type string `json:"type"` // "mouse_move", "mouse_down", "mouse_up", "key_down", "key_up", "scroll" + X float64 `json:"x,omitempty"` // normalized x coordinate (0.0 to 1.0) + Y float64 `json:"y,omitempty"` // normalized y coordinate (0.0 to 1.0) + Button int `json:"button,omitempty"` // mouse button: 1=left, 2=middle, 3=right, 4=scroll_up, 5=scroll_down + Key string `json:"key,omitempty"` // key symbol or keycode +} + +// XdotoolInjector injects inputs by writing to a persistent xdotool process stdin +type XdotoolInjector struct { + cmd *exec.Cmd + stdin io.WriteCloser + mu sync.Mutex + display string + width int + height int + isStarted bool +} + +func NewXdotoolInjector(display string) 
*XdotoolInjector { + if display == "" { + display = ":0.0" + } + return &XdotoolInjector{ + display: display, + width: 1280, // Default fallbacks + height: 720, + } +} + +// Start launches xdotool and queries screen resolution +func (xi *XdotoolInjector) Start() error { + xi.mu.Lock() + defer xi.mu.Unlock() + + if xi.isStarted { + return nil + } + + // 1. Resolve screen resolution + w, h, err := xi.queryDisplayGeometry() + if err == nil { + xi.width = w + xi.height = h + log.Printf("Detected remote display geometry: %dx%d on %s", w, h, xi.display) + } else { + log.Printf("Warning: Failed to query display geometry: %v. Using default: %dx%d", err, xi.width, xi.height) + } + + // 2. Launch persistent xdotool process + cmd := exec.Command("xdotool", "-") + cmd.Env = append(cmd.Env, "DISPLAY="+xi.display) + + stdin, err := cmd.StdinPipe() + if err != nil { + return fmt.Errorf("failed to open xdotool stdin pipe: %w", err) + } + + if err := cmd.Start(); err != nil { + stdin.Close() + return fmt.Errorf("failed to start xdotool process: %w", err) + } + + xi.cmd = cmd + xi.stdin = stdin + xi.isStarted = true + + return nil +} + +// Inject sends a command to the persistent xdotool process +func (xi *XdotoolInjector) Inject(event InputEvent) error { + xi.mu.Lock() + defer xi.mu.Unlock() + + if !xi.isStarted || xi.stdin == nil { + return fmt.Errorf("injector is not running") + } + + var cmdStr string + + switch event.Type { + case "mouse_move": + absX := int(event.X * float64(xi.width)) + absY := int(event.Y * float64(xi.height)) + cmdStr = fmt.Sprintf("mousemove %d %d\n", absX, absY) + + case "mouse_down": + btn := xi.mapButton(event.Button) + cmdStr = fmt.Sprintf("mousedown %d\n", btn) + + case "mouse_up": + btn := xi.mapButton(event.Button) + cmdStr = fmt.Sprintf("mouseup %d\n", btn) + + case "key_down": + key := xi.sanitizeKey(event.Key) + cmdStr = fmt.Sprintf("keydown %s\n", key) + + case "key_up": + key := xi.sanitizeKey(event.Key) + cmdStr = fmt.Sprintf("keyup %s\n", 
key) + + case "scroll": + // xdotool maps scroll up to button 4 and scroll down to button 5 + if event.Button == 4 || event.Button == 5 { + cmdStr = fmt.Sprintf("click %d\n", event.Button) + } + + default: + return fmt.Errorf("unsupported input type: %s", event.Type) + } + + if cmdStr != "" { + _, err := xi.stdin.Write([]byte(cmdStr)) + if err != nil { + // Try to restart if pipe is broken + log.Printf("xdotool write error: %v. Attempting to restart injector.", err) + xi.isStarted = false + xi.stdin.Close() + if restartErr := xi.Start(); restartErr == nil { + _, _ = xi.stdin.Write([]byte(cmdStr)) + } + return err + } + } + + return nil +} + +// Close terminates the xdotool process +func (xi *XdotoolInjector) Close() error { + xi.mu.Lock() + defer xi.mu.Unlock() + + if !xi.isStarted { + return nil + } + + xi.isStarted = false + if xi.stdin != nil { + _ = xi.stdin.Close() + } + if xi.cmd != nil { + _ = xi.cmd.Process.Kill() + } + + xi.stdin = nil + xi.cmd = nil + return nil +} + +func (xi *XdotoolInjector) queryDisplayGeometry() (int, int, error) { + cmd := exec.Command("xdotool", "getdisplaygeometry") + cmd.Env = append(cmd.Env, "DISPLAY="+xi.display) + out, err := cmd.Output() + if err != nil { + return 0, 0, err + } + + parts := strings.Fields(string(out)) + if len(parts) < 2 { + return 0, 0, fmt.Errorf("invalid geometry output: %s", string(out)) + } + + w, err1 := strconv.Atoi(parts[0]) + h, err2 := strconv.Atoi(parts[1]) + if err1 != nil || err2 != nil { + return 0, 0, fmt.Errorf("failed to parse geometry: %w, %w", err1, err2) + } + + return w, h, nil +} + +func (xi *XdotoolInjector) mapButton(btn int) int { + // Standard mapping: 1=left, 2=middle, 3=right + if btn <= 0 || btn > 3 { + return 1 + } + return btn +} + +func (xi *XdotoolInjector) sanitizeKey(key string) string { + // Clean or map keys if they need special handling in xdotool + // For example, Flutter "Backspace" needs to be "BackSpace" or "Return" to "Return". 
+ key = strings.TrimSpace(key) + switch strings.ToLower(key) { + case "enter": + return "Return" + case "backspace": + return "BackSpace" + case "tab": + return "Tab" + case "escape": + return "Escape" + case "control": + return "ctrl" + case "shift": + return "shift" + case "alt": + return "alt" + case "meta": + return "super" + } + return key +} diff --git a/internal/desktop/pipeline.go b/internal/desktop/pipeline.go new file mode 100644 index 0000000..a6e4079 --- /dev/null +++ b/internal/desktop/pipeline.go @@ -0,0 +1,245 @@ +package desktop + +import ( + "context" + "fmt" + "log" + "os" + "os/exec" + "strings" + "sync" + "time" +) + +// PipelineManager manages the screen capture process lifecycle +type PipelineManager struct { + cmd *exec.Cmd + mu sync.Mutex + isRunning bool + cancel context.CancelFunc +} + +type PipelineConfig struct { + Display string // e.g. ":0" + Port int // e.g. 5004 + Width int // e.g. 1920 + Height int // e.g. 1080 + FPS int // e.g. 30 + Bitrate int // in kbps, e.g. 
2000
+	UseGPU   bool   // try using hardware accelerated encoding (nvh264enc / h264_nvenc)
+	ToolType string // "gstreamer" or "ffmpeg" or "auto"
+}
+
+// NewPipelineManager returns an idle PipelineManager.
+func NewPipelineManager() *PipelineManager {
+	return &PipelineManager{}
+}
+
+// normalizePipelineConfig returns a copy of cfg in which every unset field
+// (empty Display, non-positive numeric field) is replaced by its default:
+// Display from $DISPLAY or ":0.0", Port 5004, 1280x720, 30 FPS, Bitrate 2000.
+// Applying it to an already normalized config changes nothing.
+func normalizePipelineConfig(cfg PipelineConfig) PipelineConfig {
+	if cfg.Display == "" {
+		cfg.Display = os.Getenv("DISPLAY")
+		if cfg.Display == "" {
+			cfg.Display = ":0.0"
+		}
+	}
+	if cfg.Port <= 0 {
+		cfg.Port = 5004
+	}
+	if cfg.Width <= 0 {
+		cfg.Width = 1280
+	}
+	if cfg.Height <= 0 {
+		cfg.Height = 720
+	}
+	if cfg.FPS <= 0 {
+		cfg.FPS = 30
+	}
+	if cfg.Bitrate <= 0 {
+		cfg.Bitrate = 2000
+	}
+	return cfg
+}
+
+// Start spawns the GStreamer or FFmpeg capture process in the background.
+//
+// Unset fields of cfg are replaced by defaults (see normalizePipelineConfig).
+// It returns an error if a pipeline is already running, if no capture tool
+// can be resolved, or if the process cannot be started.
+func (pm *PipelineManager) Start(cfg PipelineConfig) error {
+	pm.mu.Lock()
+	defer pm.mu.Unlock()
+
+	if pm.isRunning {
+		return fmt.Errorf("pipeline is already running")
+	}
+
+	cfg = normalizePipelineConfig(cfg)
+
+	tool, args, err := pm.resolvePipeline(cfg)
+	if err != nil {
+		return fmt.Errorf("failed to resolve pipeline command: %w", err)
+	}
+
+	ctx, cancel := context.WithCancel(context.Background())
+
+	cmd := exec.CommandContext(ctx, tool, args...)
+	// Set X11 display environment variable
+	cmd.Env = append(os.Environ(), "DISPLAY="+cfg.Display)
+
+	// Forward the child's stderr to ours. Stdout is left unset, so os/exec
+	// connects it to the null device.
+	cmd.Stderr = os.Stderr
+
+	log.Printf("Starting capture pipeline: %s %s", tool, strings.Join(args, " "))
+	if err := cmd.Start(); err != nil {
+		cancel()
+		return fmt.Errorf("failed to start pipeline process: %w", err)
+	}
+
+	pm.cmd = cmd
+	pm.cancel = cancel
+	pm.isRunning = true
+
+	// Monitor process termination asynchronously
+	go func() {
+		err := cmd.Wait()
+		// Release the context even when the process exited on its own.
+		cancel()
+
+		pm.mu.Lock()
+		// Only reset the state if it still belongs to this process; a newer
+		// run may have been started after a Stop that timed out.
+		if pm.cmd == cmd {
+			pm.isRunning = false
+			pm.cmd = nil
+			pm.cancel = nil
+		}
+		pm.mu.Unlock()
+
+		if err != nil {
+			log.Printf("Capture pipeline exited with error: %v", err)
+		} else {
+			log.Printf("Capture pipeline stopped cleanly")
+		}
+	}()
+
+	return nil
+}
+
+// Stop terminates the capture process and waits up to 2 seconds for it to
+// exit. It is a no-op when no pipeline is running.
+//
+// The lock is not held while waiting: the monitor goroutine started by Start
+// needs it to record the exit, and IsRunning must stay responsive.
+func (pm *PipelineManager) Stop() {
+	pm.mu.Lock()
+	if !pm.isRunning || pm.cancel == nil {
+		pm.mu.Unlock()
+		return
+	}
+	cmd := pm.cmd
+	cancel := pm.cancel
+	pm.mu.Unlock()
+
+	log.Println("Stopping capture pipeline...")
+	cancel() // cancels the context; exec.CommandContext then kills the process
+
+	// Wait up to 2 seconds for the monitor goroutine to record the exit.
+	for i := 0; i < 20; i++ {
+		pm.mu.Lock()
+		exited := pm.cmd != cmd
+		pm.mu.Unlock()
+		if exited {
+			return
+		}
+		time.Sleep(100 * time.Millisecond)
+	}
+
+	// The process did not exit in time: mark the pipeline as stopped anyway,
+	// unless the state has meanwhile been taken over by another run.
+	pm.mu.Lock()
+	if pm.cmd == cmd {
+		pm.isRunning = false
+		pm.cmd = nil
+		pm.cancel = nil
+	}
+	pm.mu.Unlock()
+}
+
+// IsRunning reports whether a capture process is currently marked as running.
+func (pm *PipelineManager) IsRunning() bool {
+	pm.mu.Lock()
+	defer pm.mu.Unlock()
+	return pm.isRunning
+}
+
+// resolvePipeline checks for available software and builds command arguments.
+//
+// With ToolType "auto" or "" it prefers gst-launch-1.0 and falls back to
+// ffmpeg, returning an error if neither is found in PATH. It returns the
+// executable name and its argument list.
+func (pm *PipelineManager) resolvePipeline(cfg PipelineConfig) (string, []string, error) {
+	tool := cfg.ToolType
+	if tool == "auto" || tool == "" {
+		switch {
+		case pm.hasExecutable("gst-launch-1.0"):
+			tool = "gstreamer"
+		case pm.hasExecutable("ffmpeg"):
+			tool = "ffmpeg"
+		default:
+			return "", nil, fmt.Errorf("neither GStreamer (gst-launch-1.0) nor FFmpeg was found in PATH")
+		}
+	}
+
+	switch tool {
+	case "gstreamer":
+		return pm.buildGStreamer(cfg)
+	case "ffmpeg":
+		return pm.buildFFmpeg(cfg)
+	default:
+		return "", nil, fmt.Errorf("unsupported capture tool type: %s", tool)
+	}
+}
+
+// hasExecutable reports whether name can be found in PATH.
+func (pm *PipelineManager) hasExecutable(name string) bool {
+	_, err := exec.LookPath(name)
+	return err == nil
+}
+
+// buildGStreamer returns the gst-launch-1.0 command line that captures
+// cfg.Display with ximagesrc, encodes it as H.264 (nvh264enc when cfg.UseGPU
+// is set, x264enc otherwise) and sends it as RTP to 127.0.0.1:cfg.Port.
+//
+// Each element is passed as its own group of arguments and the groups are
+// joined by a literal "!" argument, which is what links elements in a
+// gst-launch pipeline description.
+func (pm *PipelineManager) buildGStreamer(cfg PipelineConfig) (string, []string, error) {
+	fps := cfg.FPS
+	if fps <= 0 {
+		fps = 30
+	}
+
+	// Encoder
+	encoder := []string{
+		"x264enc", "speed-preset=ultrafast", "tune=zerolatency",
+		fmt.Sprintf("bitrate=%d", cfg.Bitrate),
+	}
+	if cfg.UseGPU {
+		// No probing for the nvcodec plugin here; the pipeline simply fails
+		// to start if nvh264enc is not available.
+		encoder = []string{
+			"nvh264enc", fmt.Sprintf("bitrate=%d", cfg.Bitrate),
+			"preset=low-latency", "gop-size=30",
+		}
+	}
+
+	elements := [][]string{
+		// 1. Capture Source (X11)
+		{"ximagesrc", "display-name=" + cfg.Display},
+		{fmt.Sprintf("video/x-raw,framerate=%d/1", fps)},
+		{"videoconvert"},
+		// 2. Encoder
+		encoder,
+		// 3. Payload and Sink
+		{"rtph264pay", "config-interval=1", "pt=96"},
+		{"udpsink", "host=127.0.0.1", fmt.Sprintf("port=%d", cfg.Port), "sync=false", "async=false"},
+	}
+
+	args := []string{"-v"}
+	for i, element := range elements {
+		if i > 0 {
+			args = append(args, "!")
+		}
+		args = append(args, element...)
+	}
+
+	return "gst-launch-1.0", args, nil
+}
+
+// buildFFmpeg returns the ffmpeg command line that grabs cfg.Display with
+// x11grab, encodes it as H.264 (h264_nvenc when cfg.UseGPU is set, libx264
+// otherwise) at a constant bitrate and sends it as RTP to 127.0.0.1:cfg.Port.
+func (pm *PipelineManager) buildFFmpeg(cfg PipelineConfig) (string, []string, error) {
+	args := []string{
+		"-f", "x11grab",
+		"-draw_mouse", "1",
+		"-framerate", fmt.Sprintf("%d", cfg.FPS),
+		"-video_size", fmt.Sprintf("%dx%d", cfg.Width, cfg.Height),
+		"-i", cfg.Display,
+	}
+
+	// Encoder config
+	if cfg.UseGPU {
+		args = append(args,
+			"-c:v", "h264_nvenc",
+			"-preset", "llhp", // low latency, high performance
+			// "zerolatency" is a libx264 tune, not an NVENC one; NVENC offers
+			// the equivalent as its own boolean option.
+			"-zerolatency", "1",
+			"-g", "30",
+		)
+	} else {
+		args = append(args,
+			"-c:v", "libx264",
+			"-preset", "ultrafast",
+			"-tune", "zerolatency",
+			"-g", "30",
+		)
+	}
+
+	// Constant Bitrate
+	args = append(args,
+		"-b:v", fmt.Sprintf("%dk", cfg.Bitrate),
+		"-maxrate", fmt.Sprintf("%dk", cfg.Bitrate),
+		"-bufsize", fmt.Sprintf("%dk", cfg.Bitrate*2),
+	)
+
+	// RTP Stream over UDP
+	args = append(args,
+		"-f", "rtp",
+		fmt.Sprintf("rtp://127.0.0.1:%d", cfg.Port),
+	)
+
+	return "ffmpeg", args, nil
+}
diff --git a/internal/desktop/service.go b/internal/desktop/service.go
new file mode 100644
index 0000000..227c4bd
--- /dev/null
+++ b/internal/desktop/service.go
@@ -0,0 +1,161 @@
+package desktop
+
+import (
+	"fmt"
+	"log"
+	"sync"
+
+	"github.com/pion/webrtc/v4"
+)
+
+// DesktopSession bundles the components that make up one remote desktop
+// session: the capture pipeline, the input injector and the WebRTC endpoint.
+type DesktopSession struct {
+	SessionID string
+	Pipeline  *PipelineManager
+	Injector  *XdotoolInjector
+	WebRTC    *WebRTCServer
+}
+
+// Service is the registry of active remote desktop sessions, keyed by
+// session ID. Access to the registry is serialized by mu.
+type Service struct {
+	sessions map[string]*DesktopSession
+	mu       sync.Mutex
+}
+
+var (
+	instance *Service
+	once     sync.Once
+)
+
+// GetService returns the process-wide Service, creating it on first use.
+func GetService() *Service {
+	once.Do(func() {
+		instance = &Service{
+			sessions: make(map[string]*DesktopSession),
+		}
+	})
+	return instance
+}
+
+// StartSession creates and registers a session under sessionID, replacing
+// (and stopping) any session already registered under that ID.
+//
+// Components are started in order: input injector, WebRTC peer connection,
+// local RTP receiver, capture pipeline. If any step fails, the components
+// started so far are closed, nothing is registered and the error is returned.
+func (s *Service) StartSession(sessionID string, cfg PipelineConfig, iceServers []string) (*DesktopSession, error) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	// Stop the old session if one exists, and drop it from the registry right
+	// away so a failed restart does not leave a dead session behind.
+	if old, exists := s.sessions[sessionID]; exists {
+		s.stopSessionLocked(old)
+		delete(s.sessions, sessionID)
+	}
+
+	// Resolve defaults once, up front, so the injector, the RTP receiver and
+	// the pipeline agree on display and port even when the caller set neither.
+	cfg = normalizePipelineConfig(cfg)
+
+	log.Printf("Starting Remote Desktop session: %s", sessionID)
+
+	// 1. Initialize input injector
+	injector := NewXdotoolInjector(cfg.Display)
+	if err := injector.Start(); err != nil {
+		return nil, fmt.Errorf("failed to start input injector: %w", err)
+	}
+
+	// 2. Initialize WebRTC server
+	webrtcSrv, err := NewWebRTCServer(injector)
+	if err != nil {
+		_ = injector.Close()
+		return nil, fmt.Errorf("failed to create WebRTC server: %w", err)
+	}
+
+	if err := webrtcSrv.InitPeerConnection(iceServers); err != nil {
+		webrtcSrv.Close()
+		_ = injector.Close()
+		return nil, fmt.Errorf("failed to init peer connection: %w", err)
+	}
+
+	// Start local UDP listener for the pipeline's RTP packets
+	if err := webrtcSrv.StartRTPReceiver(cfg.Port); err != nil {
+		webrtcSrv.Close()
+		_ = injector.Close()
+		return nil, fmt.Errorf("failed to start RTP receiver: %w", err)
+	}
+
+	// 3. Initialize screen capture pipeline
+	pipeline := NewPipelineManager()
+	if err := pipeline.Start(cfg); err != nil {
+		webrtcSrv.Close()
+		_ = injector.Close()
+		return nil, fmt.Errorf("failed to start capture pipeline: %w", err)
+	}
+
+	sess := &DesktopSession{
+		SessionID: sessionID,
+		Pipeline:  pipeline,
+		Injector:  injector,
+		WebRTC:    webrtcSrv,
+	}
+
+	s.sessions[sessionID] = sess
+	return sess, nil
+}
+
+// GetSession returns the session registered under sessionID, or an error if
+// there is none.
+func (s *Service) GetSession(sessionID string) (*DesktopSession, error) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	sess, exists := s.sessions[sessionID]
+	if !exists {
+		return nil, fmt.Errorf("session %s not found", sessionID)
+	}
+	return sess, nil
+}
+
+// StopSession stops and unregisters the session with the given ID. It is a
+// no-op if no such session exists.
+func (s *Service) StopSession(sessionID string) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if sess, exists := s.sessions[sessionID]; exists {
+		s.stopSessionLocked(sess)
+		delete(s.sessions, sessionID)
+	}
+}
+
+// stopSessionLocked shuts down the session's pipeline, WebRTC server and
+// input injector, in that order. The caller must hold s.mu.
+func (s *Service) stopSessionLocked(sess *DesktopSession) {
+	log.Printf("Stopping Remote Desktop session: %s", sess.SessionID)
+	if sess.Pipeline != nil {
+		sess.Pipeline.Stop()
+	}
+	if sess.WebRTC != nil {
+		sess.WebRTC.Close()
+	}
+	if sess.Injector != nil {
+		_ = sess.Injector.Close()
+	}
+}
+
+// AddICECandidate forwards a remote ICE candidate to the WebRTC server of the
+// session with the given ID.
+func (s *Service) AddICECandidate(sessionID string, candidate webrtc.ICECandidateInit) error {
+	sess, err := s.GetSession(sessionID)
+	if err != nil {
+		return err
+	}
+	if sess.WebRTC == nil {
+		return fmt.Errorf("session %s has no WebRTC server", sessionID)
+	}
+	return sess.WebRTC.AddICECandidate(candidate)
+}
diff --git a/internal/desktop/webrtc.go b/internal/desktop/webrtc.go
new file mode 100644
index 0000000..ae8ae95
--- /dev/null
+++ b/internal/desktop/webrtc.go
@@ -0,0 +1,245 @@
+package desktop
+
+import (
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"log"
+	"net"
+	"sync"
+
+	"github.com/pion/webrtc/v4"
+)
+
+// WebRTCServer owns one Pion PeerConnection with a single outgoing H.264
+// video track, plus the local UDP socket that feeds RTP packets into it.
+type WebRTCServer struct {
+	peerConnection *webrtc.PeerConnection
+	videoTrack     *webrtc.TrackLocalStaticRTP
+	udpListener    net.PacketConn
+	inputInjector  *XdotoolInjector
+	mu             sync.Mutex
+	isClosed       bool
+}
+
+// NewWebRTCServer returns a server that passes input events received over
+// the "input" data channel to injector. The returned error is always nil.
+func NewWebRTCServer(injector *XdotoolInjector) (*WebRTCServer, error) {
+	return &WebRTCServer{
+		inputInjector: injector,
+	}, nil
+}
+
+// InitPeerConnection sets up the Pion PeerConnection and local H.264 video track.
+//
+// Each entry of iceServers is used as the URL of one ICE server; when the
+// list is empty, stun:stun.l.google.com:19302 is used instead.
+func (w *WebRTCServer) InitPeerConnection(iceServers []string) error {
+	w.mu.Lock()
+	defer w.mu.Unlock()
+
+	var webrtcIceServers []webrtc.ICEServer
+	for _, url := range iceServers {
+		webrtcIceServers = append(webrtcIceServers, webrtc.ICEServer{
+			URLs: []string{url},
+		})
+	}
+	if len(webrtcIceServers) == 0 {
+		// Default STUN server
+		webrtcIceServers = append(webrtcIceServers, webrtc.ICEServer{
+			URLs: []string{"stun:stun.l.google.com:19302"},
+		})
+	}
+
+	config := webrtc.Configuration{
+		ICEServers: webrtcIceServers,
+	}
+
+	pc, err := webrtc.NewPeerConnection(config)
+	if err != nil {
+		return fmt.Errorf("failed to create WebRTC PeerConnection: %w", err)
+	}
+
+	// Create H.264 video track
+	videoTrack, err := webrtc.NewTrackLocalStaticRTP(
+		webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeH264},
+		"video",
+		"xworkmate-desktop",
+	)
+	if err != nil {
+		_ = pc.Close()
+		return fmt.Errorf("failed to create video track: %w", err)
+	}
+
+	if _, err = pc.AddTrack(videoTrack); err != nil {
+		_ = pc.Close()
+		return fmt.Errorf("failed to add video track: %w", err)
+	}
+
+	// Handle Data Channel for inputs
+	pc.OnDataChannel(func(d *webrtc.DataChannel) {
+		// ID() returns a pointer; dereference it so the log shows the channel
+		// ID rather than an address.
+		var id uint16
+		if p := d.ID(); p != nil {
+			id = *p
+		}
+		log.Printf("Data channel '%s'-'%d' opened", d.Label(), id)
+		if d.Label() != "input" {
+			return
+		}
+		d.OnMessage(func(msg webrtc.DataChannelMessage) {
+			var event InputEvent
+			if err := json.Unmarshal(msg.Data, &event); err != nil {
+				log.Printf("Failed to unmarshal input event: %v", err)
+				return
+			}
+			if err := w.inputInjector.Inject(event); err != nil {
+				log.Printf("Failed to inject input event: %v", err)
+			}
+		})
+	})
+
+	w.peerConnection = pc
+	w.videoTrack = videoTrack
+
+	return nil
+}
+
+// StartRTPReceiver listens on the local UDP port for the capture pipeline's
+// RTP stream and forwards every packet to the WebRTC video track.
+//
+// The forwarding goroutine ends when the socket is closed (see Close) or when
+// the track reports that the peer connection has been closed.
+func (w *WebRTCServer) StartRTPReceiver(port int) error {
+	addr := fmt.Sprintf("127.0.0.1:%d", port)
+	conn, err := net.ListenPacket("udp", addr)
+	if err != nil {
+		return fmt.Errorf("failed to bind UDP port %s: %w", addr, err)
+	}
+
+	w.mu.Lock()
+	w.udpListener = conn
+	w.mu.Unlock()
+
+	go func() {
+		buf := make([]byte, 2048)
+		log.Printf("WebRTC RTP receiver listening on UDP %s", addr)
+		for {
+			n, _, err := conn.ReadFrom(buf)
+			if err != nil {
+				w.mu.Lock()
+				closed := w.isClosed
+				w.mu.Unlock()
+				if !closed {
+					log.Printf("UDP RTP read error: %v", err)
+				}
+				return
+			}
+
+			// Hand the packet to the WebRTC track as-is (no re-packetizing here).
+			w.mu.Lock()
+			track := w.videoTrack
+			w.mu.Unlock()
+			if track == nil {
+				continue
+			}
+
+			if _, err := track.Write(buf[:n]); err != nil {
+				if errors.Is(err, io.ErrClosedPipe) {
+					// The peer connection has been closed; nothing left to feed.
+					return
+				}
+				log.Printf("Failed to write RTP packet to track: %v", err)
+			}
+		}
+	}()
+
+	return nil
+}
+
+// ProcessOffer handles an SDP offer: it sets the remote description, creates
+// and applies an answer, waits for ICE gathering to complete and returns the
+// SDP of the resulting local description.
+func (w *WebRTCServer) ProcessOffer(sdpOffer string) (string, error) {
+	w.mu.Lock()
+	pc := w.peerConnection
+	w.mu.Unlock()
+
+	if pc == nil {
+		return "", fmt.Errorf("peer connection not initialized")
+	}
+
+	offer := webrtc.SessionDescription{
+		Type: webrtc.SDPTypeOffer,
+		SDP:  sdpOffer,
+	}
+
+	if err := pc.SetRemoteDescription(offer); err != nil {
+		return "", fmt.Errorf("failed to set remote description: %w", err)
+	}
+
+	answer, err := pc.CreateAnswer(nil)
+	if err != nil {
+		return "", fmt.Errorf("failed to create SDP answer: %w", err)
+	}
+
+	// Channel that is closed once ICE candidate gathering has completed.
+	gatherComplete := webrtc.GatheringCompletePromise(pc)
+
+	if err := pc.SetLocalDescription(answer); err != nil {
+		return "", fmt.Errorf("failed to set local description: %w", err)
+	}
+
+	<-gatherComplete
+
+	localDesc := pc.LocalDescription()
+	if localDesc == nil {
+		return "", fmt.Errorf("local description is nil after gathering")
+	}
+
+	return localDesc.SDP, nil
+}
+
+// AddICECandidate adds a remote ICE candidate.
+func (w *WebRTCServer) AddICECandidate(candidate webrtc.ICECandidateInit) error {
+	w.mu.Lock()
+	pc := w.peerConnection
+	w.mu.Unlock()
+
+	if pc == nil {
+		return fmt.Errorf("peer connection not initialized")
+	}
+
+	if err := pc.AddICECandidate(candidate); err != nil {
+		return fmt.Errorf("failed to add remote ICE candidate: %w", err)
+	}
+
+	return nil
+}
+
+// Close terminates the WebRTC server: it closes the UDP listener (which ends
+// the RTP forwarding goroutine) and then the peer connection. Calling Close
+// more than once is safe; only the first call has an effect.
+func (w *WebRTCServer) Close() {
+	w.mu.Lock()
+	if w.isClosed {
+		w.mu.Unlock()
+		return
+	}
+	w.isClosed = true
+	pc := w.peerConnection
+	conn := w.udpListener
+	w.mu.Unlock()
+
+	log.Println("Closing WebRTC server...")
+
+	if conn != nil {
+		_ = conn.Close()
+	}
+
+	if pc != nil {
+		_ = pc.Close()
+	}
+}