xworkmate-bridge/internal/acp/distributed_forwarder.go
2026-05-26 14:11:01 +08:00

149 lines
4.2 KiB
Go

package acp
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"xworkmate-bridge/internal/shared"
)
// distributedForwardedHeader names the request header that forward sets to "1"
// on every request it relays to a peer bridge. shouldForward declines to relay
// any request that already carries a non-blank value for it, presumably so a
// peer running this same code handles a relayed request locally instead of
// relaying it again.
const distributedForwardedHeader = "X-XWorkmate-Bridge-Forwarded"
// distributedTaskForwarderConfig holds the raw settings consumed by
// newDistributedTaskForwarder. Both fields are whitespace-trimmed there, and a
// blank Endpoint makes the constructor return nil (forwarding disabled).
type distributedTaskForwarderConfig struct {
// Endpoint is the base URL of the peer bridge; trailing slashes are removed.
// Each forward validates it: https is accepted, plain http only for
// localhost or a loopback/private/link-local IP literal.
Endpoint string
// Token, when non-blank, is sent in the Authorization header; a "Bearer "
// prefix is added unless the token already starts with one.
Token string
}
// distributedTaskForwarder relays selected JSON-RPC requests (session.start and
// session.message) to a peer bridge over HTTP. Build it with
// newDistributedTaskForwarder; a nil *distributedTaskForwarder is valid and
// simply never forwards.
type distributedTaskForwarder struct {
// endpoint is the whitespace-trimmed peer base URL without trailing slashes.
endpoint string
// token is the whitespace-trimmed credential; empty means no Authorization
// header is sent.
token string
// httpClient sends the outbound request; the constructor gives it a Timeout
// derived from the agent wait limits.
httpClient *http.Client
}
// newDistributedTaskForwarder builds a forwarder from config.
//
// The endpoint is whitespace-trimmed and stripped of trailing slashes; if
// nothing is left, nil is returned, which disables forwarding (shouldForward
// treats a nil receiver as "do not forward"). The token is whitespace-trimmed.
// The HTTP client timeout is the maximum agent wait plus an HTTP margin,
// presumably so the client does not give up before the agent wait itself can
// expire — confirm against the constants' definitions.
func newDistributedTaskForwarder(config distributedTaskForwarderConfig) *distributedTaskForwarder {
	base := strings.TrimSpace(config.Endpoint)
	base = strings.TrimRight(base, "/")
	if len(base) == 0 {
		return nil
	}

	client := &http.Client{Timeout: openClawAgentWaitMaxTimeout + openClawAgentWaitHTTPMargin}
	forwarder := &distributedTaskForwarder{
		endpoint:   base,
		token:      strings.TrimSpace(config.Token),
		httpClient: client,
	}
	return forwarder
}
// shouldForward reports whether the request should be relayed to the peer.
// All of the following must hold:
//   - the forwarder is configured (non-nil, non-blank endpoint);
//   - an inbound HTTP request is present;
//   - that request lacks the marker header forward sets on relayed requests;
//   - the JSON-RPC method, ignoring surrounding whitespace, is
//     session.start or session.message.
func (f *distributedTaskForwarder) shouldForward(r *http.Request, request shared.RPCRequest) bool {
	switch {
	case f == nil, strings.TrimSpace(f.endpoint) == "", r == nil:
		return false
	case strings.TrimSpace(r.Header.Get(distributedForwardedHeader)) != "":
		// Already relayed once; handle it here rather than bouncing it onward.
		return false
	}

	switch strings.TrimSpace(request.Method) {
	case "session.start", "session.message":
		return true
	default:
		return false
	}
}
// forward relays an eligible JSON-RPC request to the configured peer bridge and
// streams the peer's reply back to w.
//
// It returns false without touching w when shouldForward rejects the request,
// leaving the caller to handle it locally. Otherwise it returns true, having
// written to w one of:
//   - a JSON-RPC error: -32603 when the request cannot be encoded, -32060 when
//     the endpoint is unusable or the outbound call fails;
//   - the peer's status code, selected response headers and body.
func (f *distributedTaskForwarder) forward(ctx context.Context, w http.ResponseWriter, r *http.Request, request shared.RPCRequest) bool {
	if !f.shouldForward(r, request) {
		return false
	}
	payload, err := json.Marshal(request)
	if err != nil {
		shared.WriteJSONError(w, request.ID, http.StatusInternalServerError, -32603, "TASK_FORWARD_ENCODE_FAILED")
		return true
	}
	forwardURL, err := f.forwardURL(r.URL.Path)
	if err != nil {
		shared.WriteJSONError(w, request.ID, http.StatusBadGateway, -32060, err.Error())
		return true
	}
	// ctx bounds the outbound call in addition to the client's own Timeout.
	outbound, err := http.NewRequestWithContext(ctx, http.MethodPost, forwardURL, bytes.NewReader(payload))
	if err != nil {
		shared.WriteJSONError(w, request.ID, http.StatusBadGateway, -32060, "TASK_FORWARD_REQUEST_BUILD_FAILED: "+err.Error())
		return true
	}
	outbound.Header.Set("Content-Type", "application/json")
	// Mark the hop: shouldForward refuses requests carrying this header, so a
	// peer running the same code handles the request instead of relaying it on.
	outbound.Header.Set(distributedForwardedHeader, "1")
	copyForwardHeader(outbound.Header, r.Header, "Accept")
	copyForwardHeader(outbound.Header, r.Header, "Origin")
	if f.token != "" {
		outbound.Header.Set("Authorization", distributedForwardBearerHeader(f.token))
	}
	response, err := f.httpClient.Do(outbound)
	if err != nil {
		shared.WriteJSONError(w, request.ID, http.StatusBadGateway, -32060, "TASK_FORWARD_FAILED: "+err.Error())
		return true
	}
	defer func() { _ = response.Body.Close() }()
	copyForwardResponseHeaders(w.Header(), response.Header)
	w.WriteHeader(response.StatusCode)
	// The status line is already committed, so a copy error can no longer be
	// reported to the client; ending the relayed body is all that can be done.
	_ = streamForwardResponseBody(w, response.Body)
	return true
}

// streamForwardResponseBody copies src to w, flushing after every chunk when w
// implements http.Flusher. This lets incremental peer output (for example a
// text/event-stream reply) reach the client as it arrives. A plain io.Copy
// would leave it in the server's write buffer until the buffer fills or the
// peer finishes. It returns nil at EOF, otherwise the first read or write error.
func streamForwardResponseBody(w http.ResponseWriter, src io.Reader) error {
	flusher, canFlush := w.(http.Flusher)
	buf := make([]byte, 32*1024)
	for {
		n, readErr := src.Read(buf)
		if n > 0 {
			if _, writeErr := w.Write(buf[:n]); writeErr != nil {
				return writeErr
			}
			if canFlush {
				flusher.Flush()
			}
		}
		if readErr == io.EOF {
			return nil
		}
		if readErr != nil {
			return readErr
		}
	}
}
// forwardURL builds the peer URL for an inbound request path. It joins the
// configured endpoint's path and path with exactly one "/" between them, and
// drops any query string carried by the endpoint.
//
// It returns a TASK_FORWARD_ENDPOINT_INVALID error when the endpoint does not
// parse to a URL with both a scheme and a host. It returns
// TASK_FORWARD_ENDPOINT_INSECURE when distributedForwardEndpointEncryptedOrPrivate
// rejects the endpoint.
func (f *distributedTaskForwarder) forwardURL(path string) (string, error) {
	target, parseErr := url.Parse(f.endpoint)
	switch {
	case parseErr != nil, target.Scheme == "", target.Host == "":
		return "", fmt.Errorf("TASK_FORWARD_ENDPOINT_INVALID: %s", f.endpoint)
	case !distributedForwardEndpointEncryptedOrPrivate(target):
		return "", fmt.Errorf("TASK_FORWARD_ENDPOINT_INSECURE: use https or a private VPN endpoint")
	}

	prefix := strings.TrimRight(target.Path, "/")
	suffix := strings.TrimLeft(path, "/")
	target.Path = prefix + "/" + suffix
	target.RawQuery = ""
	return target.String(), nil
}
// distributedForwardEndpointEncryptedOrPrivate reports whether endpoint is an
// acceptable forwarding target, given that the request (and any bearer token)
// is sent to it. It accepts:
//   - https to any host;
//   - plain http only to "localhost" or to a loopback, private or link-local
//     IP literal.
//
// A nil URL, any other scheme, and http to a public IP or any other DNS name
// are rejected.
func distributedForwardEndpointEncryptedOrPrivate(endpoint *url.URL) bool {
	if endpoint == nil {
		return false
	}
	if strings.EqualFold(endpoint.Scheme, "https") {
		return true
	}
	if !strings.EqualFold(endpoint.Scheme, "http") {
		return false
	}
	host := strings.Trim(endpoint.Hostname(), "[]")
	// Host names are case-insensitive, so "LOCALHOST" is localhost too.
	if strings.EqualFold(host, "localhost") {
		return true
	}
	// net.ParseIP rejects scoped IPv6 literals such as "fe80::1%eth0", which is
	// how link-local addresses are normally written; drop the zone first so the
	// link-local check below can actually match them.
	if zone := strings.IndexByte(host, '%'); zone >= 0 {
		host = host[:zone]
	}
	ip := net.ParseIP(host)
	if ip == nil {
		return false
	}
	return ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast()
}
// distributedForwardBearerHeader turns a configured credential into an
// Authorization header value. The token is whitespace-trimmed first. The
// result depends on what remains:
//   - blank: "" is returned;
//   - already starts with "bearer " (any letter case): returned unchanged;
//   - anything else: "Bearer " is prepended.
func distributedForwardBearerHeader(token string) string {
	value := strings.TrimSpace(token)
	switch {
	case value == "":
		return ""
	case strings.HasPrefix(strings.ToLower(value), "bearer "):
		return value
	default:
		return "Bearer " + value
	}
}
// copyForwardHeader sets key on dst to src's first value for key, with
// surrounding whitespace trimmed. When that value is missing or blank, dst is
// left untouched, so any value dst already has survives.
func copyForwardHeader(dst http.Header, src http.Header, key string) {
	value := strings.TrimSpace(src.Get(key))
	if value == "" {
		return
	}
	dst.Set(key, value)
}
// copyForwardResponseHeaders copies the peer response headers that the bridge
// relays to its own client: Content-Type and Cache-Control. Values are
// whitespace-trimmed, blank values are skipped, and copied values replace
// whatever dst held for that key.
//
// Connection is deliberately not relayed. It is a hop-by-hop header describing
// the bridge-to-peer connection, not the client's, so an intermediary must not
// forward it. Echoing e.g. "Connection: close" from the peer would also make
// net/http close the client connection after this reply.
func copyForwardResponseHeaders(dst http.Header, src http.Header) {
	for _, key := range []string{"Content-Type", "Cache-Control"} {
		if value := strings.TrimSpace(src.Get(key)); value != "" {
			dst.Set(key, value)
		}
	}
}