Enforce strict Bearer token validation even when the bridge auth token is not explicitly configured in the environment, so that unauthenticated requests are rejected with HTTP 401 by default. Also update the deployment scripts to pass the required auth token, and adjust the test suite to the new security requirements.
540 lines · 15 KiB · Go
package geminiadapter
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"flag"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
|
|
"xworkmate-bridge/internal/service"
|
|
"xworkmate-bridge/internal/shared"
|
|
)
|
|
|
|
// Fallback values handed to shared.EnvOrDefault when the matching
// GEMINI_ADAPTER_* setting is not provided.
const (
	// defaultProviderID is the provider identifier reported to clients; the
	// compat session runner also always passes it to shared.RunProviderCommand.
	defaultProviderID = "gemini"

	// defaultLabel is the human-readable provider name reported by
	// acp.capabilities.
	defaultLabel = "Gemini"

	// defaultListenAddr is the default value of the -listen flag (loopback only).
	defaultListenAddr = "127.0.0.1:8791"
)
|
|
|
|
type Server struct {
|
|
client rpcClient
|
|
authService *service.StaticTokenAuthService
|
|
providerID string
|
|
providerLabel string
|
|
allowedOrigins []string
|
|
upstreamMethod string
|
|
sessionRunner func(context.Context, string, string, string) (string, error)
|
|
sessionsMu sync.Mutex
|
|
sessions map[string]*adapterSession
|
|
}
|
|
|
|
var adapterWSUpgrader = websocket.Upgrader{
|
|
ReadBufferSize: 16 * 1024,
|
|
WriteBufferSize: 16 * 1024,
|
|
CheckOrigin: func(*http.Request) bool {
|
|
return true
|
|
},
|
|
}
|
|
|
|
// adapterSession is the per-sessionId state kept by the compat session path.
type adapterSession struct {
	// history holds every prompt sent so far in send order; it is replayed
	// through shared.ComposeHistoryPrompt on each turn.
	history []string

	// model and workingDirectory are the values used by the previous turn and
	// serve as defaults when a later request omits them; lastOutput and
	// lastUpstreamMethod record the most recent turn's result.
	model, workingDirectory, lastOutput, lastUpstreamMethod string
}
|
|
|
|
// Serve runs the Gemini ACP adapter. It parses flags from args, launches the
// Gemini CLI as a stdio RPC client, and serves HTTP until ListenAndServe
// returns.
//
// Flag defaults are read through shared.EnvOrDefault:
//   - -listen:      GEMINI_ADAPTER_LISTEN_ADDR, fallback defaultListenAddr
//   - -gemini-bin:  GEMINI_ADAPTER_BIN, then ACP_GEMINI_BIN, then "gemini"
//   - -gemini-args: GEMINI_ADAPTER_ARGS, fallback "--experimental-acp"
//
// GEMINI_ADAPTER_PROTOCOL_VERSION (fallback 1) is passed to the RPC client.
// Routes: /acp/rpc -> HandleRPC, /acp -> HandleWebSocket, anything else 404.
// The RPC client is closed when Serve returns, and http.ErrServerClosed is
// reported as a nil error.
func Serve(args []string) error {
	fs := flag.NewFlagSet("gemini-acp-adapter", flag.ExitOnError)
	listenAddr := fs.String(
		"listen",
		strings.TrimSpace(shared.EnvOrDefault("GEMINI_ADAPTER_LISTEN_ADDR", defaultListenAddr)),
		"Gemini ACP adapter listen address",
	)
	geminiBin := fs.String(
		"gemini-bin",
		strings.TrimSpace(shared.EnvOrDefault("GEMINI_ADAPTER_BIN", shared.EnvOrDefault("ACP_GEMINI_BIN", "gemini"))),
		"Gemini CLI binary path",
	)
	geminiArgs := fs.String(
		"gemini-args",
		strings.TrimSpace(shared.EnvOrDefault("GEMINI_ADAPTER_ARGS", "--experimental-acp")),
		"Gemini CLI arguments",
	)
	// With flag.ExitOnError, Parse exits the process rather than returning an error.
	_ = fs.Parse(args)

	protocolVersion := shared.IntArg(shared.EnvOrDefault("GEMINI_ADAPTER_PROTOCOL_VERSION", "1"), 1)
	// strings.Fields already ignores leading/trailing whitespace.
	client := newStdioRPCClient(*geminiBin, strings.Fields(*geminiArgs), nil, protocolVersion)
	defer func() { _ = client.Close() }()

	adapter := NewServer(client)
	routes := func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/acp/rpc":
			adapter.HandleRPC(w, r)
		case "/acp":
			adapter.HandleWebSocket(w, r)
		default:
			http.NotFound(w, r)
		}
	}

	srv := &http.Server{
		Addr:         strings.TrimSpace(*listenAddr),
		Handler:      http.HandlerFunc(routes),
		ReadTimeout:  30 * time.Second,
		WriteTimeout: 5 * time.Minute,
		IdleTimeout:  2 * time.Minute,
	}

	err := srv.ListenAndServe()
	if err == nil || errors.Is(err, http.ErrServerClosed) {
		return nil
	}
	return fmt.Errorf("gemini adapter failed: %w", err)
}
|
|
|
|
// NewServer builds a Server around client, taking its configuration from the
// environment via shared.EnvOrDefault (every value is whitespace-trimmed):
//   - GEMINI_ADAPTER_AUTH_TOKEN: static bearer token (fallback ""). The auth
//     service is always constructed; how an empty token is treated is decided
//     by service.NewStaticTokenAuthService.
//   - GEMINI_ADAPTER_PROVIDER_ID / GEMINI_ADAPTER_PROVIDER_LABEL: identity
//     reported to clients.
//   - GEMINI_ADAPTER_ALLOWED_ORIGINS: comma-separated origin allow-list.
//   - GEMINI_ADAPTER_UPSTREAM_METHOD: optional upstream session method.
//
// The default session runner always invokes shared.RunProviderCommand with
// defaultProviderID, independent of GEMINI_ADAPTER_PROVIDER_ID.
func NewServer(client rpcClient) *Server {
	env := func(key, fallback string) string {
		return strings.TrimSpace(shared.EnvOrDefault(key, fallback))
	}
	runGemini := func(ctx context.Context, model, prompt, workingDirectory string) (string, error) {
		return shared.RunProviderCommand(ctx, defaultProviderID, model, prompt, workingDirectory)
	}

	return &Server{
		client:         client,
		authService:    service.NewStaticTokenAuthService(env("GEMINI_ADAPTER_AUTH_TOKEN", "")),
		providerID:     env("GEMINI_ADAPTER_PROVIDER_ID", defaultProviderID),
		providerLabel:  env("GEMINI_ADAPTER_PROVIDER_LABEL", defaultLabel),
		allowedOrigins: parseAllowedOrigins(env("GEMINI_ADAPTER_ALLOWED_ORIGINS", "https://xworkmate.svc.plus,http://localhost:*,http://127.0.0.1:*")),
		upstreamMethod: env("GEMINI_ADAPTER_UPSTREAM_METHOD", ""),
		sessionRunner:  runGemini,
		sessions:       make(map[string]*adapterSession),
	}
}
|
|
|
|
// HandleWebSocket serves JSON-RPC over a WebSocket connection at /acp.
//
// Before upgrading it runs two checks and answers with a JSON-RPC error
// envelope on failure:
//   - a disallowed Origin gets 403;
//   - a request failing bearer authorization gets 401.
//
// After the upgrade, each incoming message is decoded as one RPC request and
// handled in arrival order:
//   - undecodable messages get a -32700 error envelope;
//   - requests without an ID are handled but get no reply;
//   - all others get a result envelope.
//
// The connection is closed on the first read error.
func (s *Server) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
	origin := r.Header.Get("Origin")
	if !s.originAllowed(origin) {
		s.writeJSONError(w, nil, http.StatusForbidden, -32003, fmt.Sprintf("origin not allowed: %s", strings.TrimSpace(origin)))
		return
	}
	if !s.authorized(r) {
		s.writeJSONError(w, nil, http.StatusUnauthorized, -32001, "missing bearer authorization")
		return
	}

	// Work on a copy so the per-request CheckOrigin closure never touches the
	// shared template; it repeats both checks for the handshake itself.
	upgrader := adapterWSUpgrader
	upgrader.CheckOrigin = func(req *http.Request) bool {
		return s.originAllowed(req.Header.Get("Origin")) && s.authorized(req)
	}
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		// gorilla's Upgrade has already replied with an HTTP error.
		return
	}
	defer func() { _ = conn.Close() }()

	// gorilla/websocket allows only one concurrent writer; serialize writes.
	var writeMu sync.Mutex
	send := func(message map[string]any) {
		writeMu.Lock()
		defer writeMu.Unlock()
		_ = conn.WriteJSON(message)
	}

	for {
		_, payload, readErr := conn.ReadMessage()
		if readErr != nil {
			return
		}
		request, decodeErr := shared.DecodeRPCRequest(payload)
		switch {
		case decodeErr != nil:
			send(shared.ErrorEnvelope(nil, -32700, decodeErr.Error()))
		default:
			result := s.handleRequest(request)
			if request.ID != nil {
				send(shared.ResultEnvelope(request.ID, result))
			}
		}
	}
}
|
|
|
|
// HandleRPC serves a single JSON-RPC request sent by HTTP POST to /acp/rpc.
//
// CORS headers are applied first, so preflight and error responses carry them
// for allowed origins.
//
// Status codes, in the order the checks run:
//   - 204 for an OPTIONS preflight;
//   - 405 for any other non-POST method;
//   - 403 for a disallowed Origin;
//   - 401 when bearer authorization fails;
//   - 400 when the body cannot be read or decoded;
//   - 200 otherwise, with the handler's result in a JSON-RPC result envelope.
//     Handler-level failures are reported inside that result, not via status.
//
// NOTE(review): the body is read with io.ReadAll and no size cap; consider
// http.MaxBytesReader if this endpoint is ever reachable beyond loopback.
func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
	s.applyCORS(w, r)

	switch r.Method {
	case http.MethodOptions:
		w.WriteHeader(http.StatusNoContent)
		return
	case http.MethodPost:
		// Continue below.
	default:
		s.writeJSONError(w, nil, http.StatusMethodNotAllowed, -32600, "method not allowed")
		return
	}

	if origin := r.Header.Get("Origin"); !s.originAllowed(origin) {
		s.writeJSONError(w, nil, http.StatusForbidden, -32003, fmt.Sprintf("origin not allowed: %s", strings.TrimSpace(origin)))
		return
	}
	if !s.authorized(r) {
		s.writeJSONError(w, nil, http.StatusUnauthorized, -32001, "missing bearer authorization")
		return
	}

	body, readErr := io.ReadAll(r.Body)
	if readErr != nil {
		s.writeJSONError(w, nil, http.StatusBadRequest, -32600, "invalid body")
		return
	}
	request, decodeErr := shared.DecodeRPCRequest(body)
	if decodeErr != nil {
		s.writeJSONError(w, nil, http.StatusBadRequest, -32700, decodeErr.Error())
		return
	}

	envelope := shared.ResultEnvelope(request.ID, s.handleRequest(request))
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	_ = json.NewEncoder(w).Encode(envelope)
}
|
|
|
|
// handleRequest dispatches one decoded JSON-RPC request to its handler and
// returns the result payload.
//
// The method name is trimmed once. The trimmed form is used both for matching
// and when forwarding to the session handlers, so " session.start " behaves
// exactly like "session.start" (including the history reset). Unknown methods
// yield {"success": false, "error": "unsupported method: ..."}.
func (s *Server) handleRequest(request shared.RPCRequest) map[string]any {
	method := strings.TrimSpace(request.Method)
	switch method {
	case "acp.capabilities":
		return s.handleCapabilities()
	case "session.start", "session.message":
		// Forward the trimmed name: the compat session path compares it
		// verbatim against "session.start" to decide whether to reset history.
		return s.handleSessionRequest(method, request.Params)
	case "session.cancel":
		// Acknowledged only: reports cancelled=false and interrupts nothing.
		return map[string]any{"accepted": true, "cancelled": false}
	case "session.close":
		sessionID := strings.TrimSpace(shared.StringArg(request.Params, "sessionId", ""))
		return map[string]any{"accepted": true, "closed": s.closeSession(sessionID)}
	case "gemini.initialize":
		return s.handleInitialize()
	case "gemini.raw":
		return s.handleRaw(request.Params)
	default:
		return map[string]any{
			"success": false,
			"error":   fmt.Sprintf("unsupported method: %s", method),
		}
	}
}
|
|
|
|
func (s *Server) handleCapabilities() map[string]any {
|
|
result, err := s.client.Initialize()
|
|
if err != nil {
|
|
return map[string]any{
|
|
"singleAgent": false,
|
|
"multiAgent": false,
|
|
"providers": []string{},
|
|
"capabilities": map[string]any{
|
|
"single_agent": false,
|
|
"multi_agent": false,
|
|
"providers": []string{},
|
|
},
|
|
"success": false,
|
|
"error": err.Error(),
|
|
}
|
|
}
|
|
return map[string]any{
|
|
"singleAgent": true,
|
|
"multiAgent": false,
|
|
"providers": []string{s.providerID},
|
|
"capabilities": map[string]any{
|
|
"single_agent": true,
|
|
"multi_agent": false,
|
|
"providers": []string{s.providerID},
|
|
},
|
|
"provider": map[string]any{
|
|
"id": s.providerID,
|
|
"label": s.providerLabel,
|
|
},
|
|
"upstream": map[string]any{
|
|
"protocolVersion": result.ProtocolVersion,
|
|
"authMethods": result.AuthMethods,
|
|
"agentCapabilities": result.AgentCapabilities,
|
|
},
|
|
}
|
|
}
|
|
|
|
func (s *Server) handleInitialize() map[string]any {
|
|
result, err := s.client.Initialize()
|
|
if err != nil {
|
|
return map[string]any{"success": false, "error": err.Error()}
|
|
}
|
|
return map[string]any{
|
|
"success": true,
|
|
"result": result,
|
|
}
|
|
}
|
|
|
|
// handleRaw answers gemini.raw. It forwards the "method" param, with the
// "params" param as arguments, straight to the upstream client after ensuring
// the client is initialized. "params" is passed as nil when it is not a JSON
// object. The upstream response is returned unmodified under "response".
//
// NOTE(review): any authorized caller can invoke arbitrary upstream methods
// through this passthrough.
func (s *Server) handleRaw(params map[string]any) map[string]any {
	fail := func(message string) map[string]any {
		return map[string]any{"success": false, "error": message}
	}

	method := strings.TrimSpace(shared.StringArg(params, "method", ""))
	if method == "" {
		return fail("method is required")
	}
	if _, err := s.client.Initialize(); err != nil {
		return fail(err.Error())
	}

	upstreamParams, _ := params["params"].(map[string]any)
	response, err := s.client.Call(method, upstreamParams)
	if err != nil {
		return fail(err.Error())
	}
	return map[string]any{"success": true, "response": response}
}
|
|
|
|
func (s *Server) handleSessionRequest(method string, params map[string]any) map[string]any {
|
|
if _, err := s.client.Initialize(); err != nil {
|
|
return map[string]any{
|
|
"success": false,
|
|
"provider": s.providerID,
|
|
"mode": "single-agent",
|
|
"error": err.Error(),
|
|
}
|
|
}
|
|
upstreamMethod := s.upstreamMethod
|
|
if upstreamMethod != "" {
|
|
return s.handleConfiguredUpstreamSessionRequest(upstreamMethod, params)
|
|
}
|
|
return s.handleCompatSessionRequest(method, params)
|
|
}
|
|
|
|
// handleConfiguredUpstreamSessionRequest forwards session params unchanged to
// upstreamMethod and normalizes the reply:
//   - call error: success=false with the error text.
//   - non-empty "result" object: that object is returned. "provider" and
//     "mode" are filled in only where upstream left them unset, and the map
//     is modified in place.
//   - non-empty "error" object: success=false. Its "message" (or a generic
//     fallback) becomes the error text, and the raw object is included as
//     "upstreamError".
//   - anything else: success=true with the whole raw response under "upstream".
func (s *Server) handleConfiguredUpstreamSessionRequest(upstreamMethod string, params map[string]any) map[string]any {
	failure := func(message string) map[string]any {
		return map[string]any{
			"success":        false,
			"provider":       s.providerID,
			"mode":           "single-agent",
			"error":          message,
			"upstreamMethod": upstreamMethod,
		}
	}

	response, err := s.client.Call(upstreamMethod, params)
	if err != nil {
		return failure(err.Error())
	}

	if result, _ := response["result"].(map[string]any); len(result) > 0 {
		defaults := map[string]any{"provider": s.providerID, "mode": "single-agent"}
		for key, fallback := range defaults {
			if _, present := result[key]; !present {
				result[key] = fallback
			}
		}
		return result
	}

	if upstreamErr, _ := response["error"].(map[string]any); len(upstreamErr) > 0 {
		out := failure(strings.TrimSpace(shared.StringArg(upstreamErr, "message", "upstream gemini acp error")))
		out["upstreamError"] = upstreamErr
		return out
	}

	return map[string]any{
		"success":        true,
		"provider":       s.providerID,
		"mode":           "single-agent",
		"upstreamMethod": upstreamMethod,
		"upstream":       response,
	}
}
|
|
|
|
// handleCompatSessionRequest runs one turn of a session by calling
// s.sessionRunner with the session's full prompt history. It is used when no
// upstream session method is configured.
//
// Params read:
//   - "sessionId" (required);
//   - "taskPrompt" (required after shared.AugmentPromptWithAttachments has
//     folded in any attachments);
//   - optional "model" and "workingDirectory", which default to the values
//     stored by the session's previous turn.
//
// "session.start" (whitespace-trimmed) discards the stored history first.
//
// All required params are validated before any session state is touched, so
// a rejected request leaves the stored session unchanged; for example, a
// session.start with an empty prompt no longer wipes the history. On success
// the session's history, model, working directory and last output are stored,
// and the runner's output is returned.
//
// NOTE(review): concurrent turns on the same session are last-writer-wins, and
// a session closed while its runner is executing is recreated on completion.
func (s *Server) handleCompatSessionRequest(method string, params map[string]any) map[string]any {
	fail := func(message string) map[string]any {
		return map[string]any{
			"success":  false,
			"provider": s.providerID,
			"mode":     "single-agent",
			"error":    message,
		}
	}

	if s.sessionRunner == nil {
		return fail("gemini session runner is not configured")
	}
	sessionID := strings.TrimSpace(shared.StringArg(params, "sessionId", ""))
	if sessionID == "" {
		return fail("sessionId is required")
	}
	taskPrompt := strings.TrimSpace(shared.StringArg(params, "taskPrompt", ""))
	taskPrompt = shared.AugmentPromptWithAttachments(taskPrompt, params)
	if taskPrompt == "" {
		return fail("taskPrompt is required")
	}

	// Only now touch session state. A start wipes the stored entry, so its
	// prior values are the zero session. Otherwise, continue from a detached
	// snapshot that is safe to read without the lock.
	state := &adapterSession{}
	if strings.TrimSpace(method) == "session.start" {
		s.resetSession(sessionID)
	} else {
		state = s.getOrCreateSession(sessionID)
	}

	model := strings.TrimSpace(shared.StringArg(params, "model", ""))
	if model == "" {
		model = state.model
	}
	workingDirectory := strings.TrimSpace(shared.StringArg(params, "workingDirectory", ""))
	if workingDirectory == "" {
		workingDirectory = state.workingDirectory
	}

	history := append([]string(nil), state.history...)
	history = append(history, taskPrompt)
	output, err := s.sessionRunner(context.Background(), model, shared.ComposeHistoryPrompt(history), workingDirectory)
	if err != nil {
		return fail(err.Error())
	}

	s.sessionsMu.Lock()
	stored := s.sessions[sessionID]
	if stored == nil {
		stored = &adapterSession{}
		s.sessions[sessionID] = stored
	}
	stored.history = history
	stored.model = model
	stored.workingDirectory = workingDirectory
	stored.lastOutput = output
	stored.lastUpstreamMethod = "prompt"
	s.sessionsMu.Unlock()

	result := map[string]any{
		"success":        true,
		"provider":       s.providerID,
		"mode":           "single-agent",
		"output":         output,
		"sessionId":      sessionID,
		"upstreamMethod": "prompt",
	}
	if workingDirectory != "" {
		result["effectiveWorkingDirectory"] = workingDirectory
	}
	if model != "" {
		result["resolvedModel"] = model
	}
	return result
}
|
|
|
|
func (s *Server) getOrCreateSession(sessionID string) *adapterSession {
|
|
s.sessionsMu.Lock()
|
|
defer s.sessionsMu.Unlock()
|
|
state := s.sessions[sessionID]
|
|
if state == nil {
|
|
state = &adapterSession{}
|
|
s.sessions[sessionID] = state
|
|
}
|
|
return &adapterSession{
|
|
history: append([]string(nil), state.history...),
|
|
model: state.model,
|
|
workingDirectory: state.workingDirectory,
|
|
lastOutput: state.lastOutput,
|
|
lastUpstreamMethod: state.lastUpstreamMethod,
|
|
}
|
|
}
|
|
|
|
func (s *Server) resetSession(sessionID string) *adapterSession {
|
|
s.sessionsMu.Lock()
|
|
defer s.sessionsMu.Unlock()
|
|
state := &adapterSession{}
|
|
s.sessions[sessionID] = state
|
|
return state
|
|
}
|
|
|
|
// closeSession removes the session stored under sessionID (after trimming
// whitespace) and reports whether one existed. A blank ID never matches.
func (s *Server) closeSession(sessionID string) bool {
	key := strings.TrimSpace(sessionID)
	if key == "" {
		return false
	}

	s.sessionsMu.Lock()
	defer s.sessionsMu.Unlock()

	_, existed := s.sessions[key]
	delete(s.sessions, key) // no-op when the key is absent
	return existed
}
|
|
|
|
// parseAllowedOrigins splits a comma-separated origin list, trimming each
// entry and dropping blank ones. An empty input yields nil. Any other input
// yields a non-nil slice, which may be empty when every entry is blank.
func parseAllowedOrigins(raw string) []string {
	if raw == "" {
		return nil
	}

	entries := strings.Split(raw, ",")
	origins := make([]string, 0, len(entries))
	for i := range entries {
		if entry := strings.TrimSpace(entries[i]); entry != "" {
			origins = append(origins, entry)
		}
	}
	return origins
}
|
|
|
|
func (s *Server) originAllowed(origin string) bool {
|
|
origin = strings.TrimSpace(origin)
|
|
if origin == "" {
|
|
return true
|
|
}
|
|
for _, allowed := range s.allowedOrigins {
|
|
if strings.HasSuffix(allowed, ":*") {
|
|
if strings.HasPrefix(origin, strings.TrimSuffix(allowed, "*")) {
|
|
return true
|
|
}
|
|
continue
|
|
}
|
|
if origin == allowed {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// applyCORS adds CORS response headers only when the request carries a
// non-empty, allowed Origin; otherwise the response is left untouched.
//
// Headers set:
//   - Access-Control-Allow-Origin: the trimmed origin, echoed back;
//   - Access-Control-Allow-Methods: POST and OPTIONS;
//   - Access-Control-Max-Age: 600 seconds for preflight caching;
//   - Vary: extended with Origin and the two Access-Control-Request-* headers.
func (s *Server) applyCORS(w http.ResponseWriter, r *http.Request) {
	origin := strings.TrimSpace(r.Header.Get("Origin"))
	if origin == "" || !s.originAllowed(origin) {
		return
	}

	h := w.Header()
	for _, kv := range [][2]string{
		{"Access-Control-Allow-Origin", origin},
		{"Access-Control-Allow-Methods", "POST, OPTIONS"},
		{"Access-Control-Allow-Headers", "Authorization, Content-Type, Accept"},
		{"Access-Control-Max-Age", "600"},
	} {
		h.Set(kv[0], kv[1])
	}
	for _, v := range []string{"Origin", "Access-Control-Request-Method", "Access-Control-Request-Headers"} {
		h.Add("Vary", v)
	}
}
|
|
|
|
// authorized reports whether r's Authorization header is accepted by the
// static-token auth service. It fails closed: a nil Server, or one without an
// auth service, rejects every request.
func (s *Server) authorized(r *http.Request) bool {
	if s == nil || s.authService == nil {
		return false
	}
	return s.authService.ValidateAuthorizationHeader(r.Header.Get("Authorization"))
}
|
|
|
|
// writeJSONError responds with the given HTTP status and a JSON-RPC error
// envelope built from requestID, code and message. The encode error is
// ignored because the status line has already been written.
func (s *Server) writeJSONError(w http.ResponseWriter, requestID any, statusCode int, code int, message string) {
	envelope := shared.ErrorEnvelope(requestID, code, message)
	header := w.Header()
	header.Set("Content-Type", "application/json")
	w.WriteHeader(statusCode)
	_ = json.NewEncoder(w).Encode(envelope)
}
|