541 lines
15 KiB
Go
541 lines
15 KiB
Go
package geminiadapter
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"flag"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
|
|
"xworkmate-bridge/internal/service"
|
|
"xworkmate-bridge/internal/shared"
|
|
)
|
|
|
|
// Built-in defaults, each overridable through the environment
// (GEMINI_ADAPTER_LISTEN_ADDR, GEMINI_ADAPTER_PROVIDER_ID,
// GEMINI_ADAPTER_PROVIDER_LABEL).
const (
	// defaultProviderID is the provider identifier reported to clients.
	defaultProviderID = "gemini"
	// defaultLabel is the human-readable provider name.
	defaultLabel = "Gemini"
	// defaultListenAddr is the loopback address the adapter binds to.
	defaultListenAddr = "127.0.0.1:8791"
)
|
|
|
|
type Server struct {
|
|
client rpcClient
|
|
authService *service.StaticTokenAuthService
|
|
providerID string
|
|
providerLabel string
|
|
allowedOrigins []string
|
|
upstreamMethod string
|
|
sessionRunner func(context.Context, string, string, string) (string, error)
|
|
sessionsMu sync.Mutex
|
|
sessions map[string]*adapterSession
|
|
}
|
|
|
|
var adapterWSUpgrader = websocket.Upgrader{
|
|
ReadBufferSize: 16 * 1024,
|
|
WriteBufferSize: 16 * 1024,
|
|
CheckOrigin: func(*http.Request) bool {
|
|
return true
|
|
},
|
|
}
|
|
|
|
// adapterSession is the per-sessionId state kept by the compat session path.
type adapterSession struct {
	// Settings reused on later turns when a request omits them.
	model            string
	workingDirectory string

	// history holds the prompts sent so far, oldest first. The whole list is
	// passed to shared.ComposeHistoryPrompt on every turn.
	history []string

	// Details recorded from the most recent successful turn.
	lastOutput         string
	lastUpstreamMethod string
}
|
|
|
|
func Serve(args []string) error {
|
|
flags := flag.NewFlagSet("gemini-acp-adapter", flag.ExitOnError)
|
|
listen := flags.String(
|
|
"listen",
|
|
strings.TrimSpace(shared.EnvOrDefault("GEMINI_ADAPTER_LISTEN_ADDR", defaultListenAddr)),
|
|
"Gemini ACP adapter listen address",
|
|
)
|
|
binary := flags.String(
|
|
"gemini-bin",
|
|
strings.TrimSpace(shared.EnvOrDefault("GEMINI_ADAPTER_BIN", shared.EnvOrDefault("ACP_GEMINI_BIN", "gemini"))),
|
|
"Gemini CLI binary path",
|
|
)
|
|
rawArgs := flags.String(
|
|
"gemini-args",
|
|
strings.TrimSpace(shared.EnvOrDefault("GEMINI_ADAPTER_ARGS", "--experimental-acp")),
|
|
"Gemini CLI arguments",
|
|
)
|
|
_ = flags.Parse(args)
|
|
|
|
client := newStdioRPCClient(
|
|
*binary,
|
|
strings.Fields(strings.TrimSpace(*rawArgs)),
|
|
nil,
|
|
shared.IntArg(shared.EnvOrDefault("GEMINI_ADAPTER_PROTOCOL_VERSION", "1"), 1),
|
|
)
|
|
defer func() {
|
|
_ = client.Close()
|
|
}()
|
|
|
|
server := NewServer(client)
|
|
httpServer := &http.Server{
|
|
Addr: strings.TrimSpace(*listen),
|
|
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.URL.Path {
|
|
case "/acp/rpc":
|
|
server.HandleRPC(w, r)
|
|
case "/acp":
|
|
server.HandleWebSocket(w, r)
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}),
|
|
ReadTimeout: 30 * time.Second,
|
|
WriteTimeout: 5 * time.Minute,
|
|
IdleTimeout: 2 * time.Minute,
|
|
}
|
|
if err := httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
|
return fmt.Errorf("gemini adapter failed: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func NewServer(client rpcClient) *Server {
|
|
return &Server{
|
|
client: client,
|
|
authService: service.NewStaticTokenAuthService(strings.TrimSpace(shared.EnvOrDefault("GEMINI_ADAPTER_AUTH_TOKEN", ""))),
|
|
providerID: strings.TrimSpace(shared.EnvOrDefault("GEMINI_ADAPTER_PROVIDER_ID", defaultProviderID)),
|
|
providerLabel: strings.TrimSpace(shared.EnvOrDefault("GEMINI_ADAPTER_PROVIDER_LABEL", defaultLabel)),
|
|
allowedOrigins: parseAllowedOrigins(strings.TrimSpace(shared.EnvOrDefault("GEMINI_ADAPTER_ALLOWED_ORIGINS", "https://xworkmate.svc.plus,http://localhost:*,http://127.0.0.1:*"))),
|
|
upstreamMethod: strings.TrimSpace(shared.EnvOrDefault("GEMINI_ADAPTER_UPSTREAM_METHOD", "")),
|
|
sessionRunner: func(ctx context.Context, model, prompt, workingDirectory string) (string, error) {
|
|
return shared.RunProviderCommand(
|
|
ctx,
|
|
defaultProviderID,
|
|
model,
|
|
prompt,
|
|
workingDirectory,
|
|
)
|
|
},
|
|
sessions: make(map[string]*adapterSession),
|
|
}
|
|
}
|
|
|
|
func (s *Server) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
|
|
if !s.originAllowed(r.Header.Get("Origin")) {
|
|
s.writeJSONError(w, nil, http.StatusForbidden, -32003, fmt.Sprintf("origin not allowed: %s", strings.TrimSpace(r.Header.Get("Origin"))))
|
|
return
|
|
}
|
|
if !s.authorized(r) {
|
|
s.writeJSONError(w, nil, http.StatusUnauthorized, -32001, "missing bearer authorization")
|
|
return
|
|
}
|
|
upgrader := adapterWSUpgrader
|
|
upgrader.CheckOrigin = func(req *http.Request) bool {
|
|
return s.originAllowed(req.Header.Get("Origin")) && s.authorized(req)
|
|
}
|
|
conn, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
return
|
|
}
|
|
defer func() {
|
|
_ = conn.Close()
|
|
}()
|
|
|
|
var writeMu sync.Mutex
|
|
notify := func(message map[string]any) {
|
|
writeMu.Lock()
|
|
defer writeMu.Unlock()
|
|
_ = conn.WriteJSON(message)
|
|
}
|
|
|
|
for {
|
|
_, payload, err := conn.ReadMessage()
|
|
if err != nil {
|
|
return
|
|
}
|
|
request, err := shared.DecodeRPCRequest(payload)
|
|
if err != nil {
|
|
notify(shared.ErrorEnvelope(nil, -32700, err.Error()))
|
|
continue
|
|
}
|
|
response := s.handleRequest(request)
|
|
if request.ID == nil {
|
|
continue
|
|
}
|
|
notify(shared.ResultEnvelope(request.ID, response))
|
|
}
|
|
}
|
|
|
|
func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
|
|
s.applyCORS(w, r)
|
|
if r.Method == http.MethodOptions {
|
|
w.WriteHeader(http.StatusNoContent)
|
|
return
|
|
}
|
|
if r.Method != http.MethodPost {
|
|
s.writeJSONError(w, nil, http.StatusMethodNotAllowed, -32600, "method not allowed")
|
|
return
|
|
}
|
|
if !s.originAllowed(r.Header.Get("Origin")) {
|
|
s.writeJSONError(w, nil, http.StatusForbidden, -32003, fmt.Sprintf("origin not allowed: %s", strings.TrimSpace(r.Header.Get("Origin"))))
|
|
return
|
|
}
|
|
if !s.authorized(r) {
|
|
s.writeJSONError(w, nil, http.StatusUnauthorized, -32001, "missing bearer authorization")
|
|
return
|
|
}
|
|
payload, err := io.ReadAll(r.Body)
|
|
if err != nil {
|
|
s.writeJSONError(w, nil, http.StatusBadRequest, -32600, "invalid body")
|
|
return
|
|
}
|
|
request, err := shared.DecodeRPCRequest(payload)
|
|
if err != nil {
|
|
s.writeJSONError(w, nil, http.StatusBadRequest, -32700, err.Error())
|
|
return
|
|
}
|
|
result := s.handleRequest(request)
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
_ = json.NewEncoder(w).Encode(shared.ResultEnvelope(request.ID, result))
|
|
}
|
|
|
|
func (s *Server) handleRequest(request shared.RPCRequest) map[string]any {
|
|
switch strings.TrimSpace(request.Method) {
|
|
case "acp.capabilities":
|
|
return s.handleCapabilities()
|
|
case "session.start", "session.message":
|
|
return s.handleSessionRequest(request.Method, request.Params)
|
|
case "session.cancel":
|
|
return map[string]any{"accepted": true, "cancelled": false}
|
|
case "session.close":
|
|
sessionID := strings.TrimSpace(shared.StringArg(request.Params, "sessionId", ""))
|
|
return map[string]any{"accepted": true, "closed": s.closeSession(sessionID)}
|
|
case "gemini.initialize":
|
|
return s.handleInitialize()
|
|
case "gemini.raw":
|
|
return s.handleRaw(request.Params)
|
|
default:
|
|
return map[string]any{
|
|
"success": false,
|
|
"error": fmt.Sprintf("unsupported method: %s", strings.TrimSpace(request.Method)),
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Server) handleCapabilities() map[string]any {
|
|
result, err := s.client.Initialize()
|
|
if err != nil {
|
|
return map[string]any{
|
|
"singleAgent": false,
|
|
"multiAgent": false,
|
|
"providers": []string{},
|
|
"capabilities": map[string]any{
|
|
"single_agent": false,
|
|
"multi_agent": false,
|
|
"providers": []string{},
|
|
},
|
|
"success": false,
|
|
"error": err.Error(),
|
|
}
|
|
}
|
|
return map[string]any{
|
|
"singleAgent": true,
|
|
"multiAgent": false,
|
|
"providers": []string{s.providerID},
|
|
"capabilities": map[string]any{
|
|
"single_agent": true,
|
|
"multi_agent": false,
|
|
"providers": []string{s.providerID},
|
|
},
|
|
"provider": map[string]any{
|
|
"id": s.providerID,
|
|
"label": s.providerLabel,
|
|
},
|
|
"upstream": map[string]any{
|
|
"protocolVersion": result.ProtocolVersion,
|
|
"authMethods": result.AuthMethods,
|
|
"agentCapabilities": result.AgentCapabilities,
|
|
},
|
|
}
|
|
}
|
|
|
|
func (s *Server) handleInitialize() map[string]any {
|
|
result, err := s.client.Initialize()
|
|
if err != nil {
|
|
return map[string]any{"success": false, "error": err.Error()}
|
|
}
|
|
return map[string]any{
|
|
"success": true,
|
|
"result": result,
|
|
}
|
|
}
|
|
|
|
func (s *Server) handleRaw(params map[string]any) map[string]any {
|
|
method := strings.TrimSpace(shared.StringArg(params, "method", ""))
|
|
upstreamParams, _ := params["params"].(map[string]any)
|
|
if method == "" {
|
|
return map[string]any{"success": false, "error": "method is required"}
|
|
}
|
|
if _, err := s.client.Initialize(); err != nil {
|
|
return map[string]any{"success": false, "error": err.Error()}
|
|
}
|
|
response, err := s.client.Call(method, upstreamParams)
|
|
if err != nil {
|
|
return map[string]any{"success": false, "error": err.Error()}
|
|
}
|
|
return map[string]any{"success": true, "response": response}
|
|
}
|
|
|
|
func (s *Server) handleSessionRequest(method string, params map[string]any) map[string]any {
|
|
if _, err := s.client.Initialize(); err != nil {
|
|
return map[string]any{
|
|
"success": false,
|
|
"provider": s.providerID,
|
|
"mode": "single-agent",
|
|
"error": err.Error(),
|
|
}
|
|
}
|
|
upstreamMethod := s.upstreamMethod
|
|
if upstreamMethod != "" {
|
|
return s.handleConfiguredUpstreamSessionRequest(upstreamMethod, params)
|
|
}
|
|
return s.handleCompatSessionRequest(method, params)
|
|
}
|
|
|
|
func (s *Server) handleConfiguredUpstreamSessionRequest(upstreamMethod string, params map[string]any) map[string]any {
|
|
response, err := s.client.Call(upstreamMethod, params)
|
|
if err != nil {
|
|
return map[string]any{
|
|
"success": false,
|
|
"provider": s.providerID,
|
|
"mode": "single-agent",
|
|
"error": err.Error(),
|
|
"upstreamMethod": upstreamMethod,
|
|
}
|
|
}
|
|
result, _ := response["result"].(map[string]any)
|
|
if len(result) > 0 {
|
|
if _, ok := result["provider"]; !ok {
|
|
result["provider"] = s.providerID
|
|
}
|
|
if _, ok := result["mode"]; !ok {
|
|
result["mode"] = "single-agent"
|
|
}
|
|
return result
|
|
}
|
|
if errPayload, ok := response["error"].(map[string]any); ok && len(errPayload) > 0 {
|
|
return map[string]any{
|
|
"success": false,
|
|
"provider": s.providerID,
|
|
"mode": "single-agent",
|
|
"error": strings.TrimSpace(shared.StringArg(errPayload, "message", "upstream gemini acp error")),
|
|
"upstreamMethod": upstreamMethod,
|
|
"upstreamError": errPayload,
|
|
}
|
|
}
|
|
return map[string]any{
|
|
"success": true,
|
|
"provider": s.providerID,
|
|
"mode": "single-agent",
|
|
"upstreamMethod": upstreamMethod,
|
|
"upstream": response,
|
|
}
|
|
}
|
|
|
|
func (s *Server) handleCompatSessionRequest(method string, params map[string]any) map[string]any {
|
|
if s.sessionRunner == nil {
|
|
return map[string]any{
|
|
"success": false,
|
|
"provider": s.providerID,
|
|
"mode": "single-agent",
|
|
"error": "gemini session runner is not configured",
|
|
}
|
|
}
|
|
sessionID := strings.TrimSpace(shared.StringArg(params, "sessionId", ""))
|
|
if sessionID == "" {
|
|
return map[string]any{
|
|
"success": false,
|
|
"provider": s.providerID,
|
|
"mode": "single-agent",
|
|
"error": "sessionId is required",
|
|
}
|
|
}
|
|
state := s.getOrCreateSession(sessionID)
|
|
if method == "session.start" {
|
|
state = s.resetSession(sessionID)
|
|
}
|
|
taskPrompt := strings.TrimSpace(shared.StringArg(params, "taskPrompt", ""))
|
|
taskPrompt = shared.AugmentPromptWithAttachments(taskPrompt, params)
|
|
if taskPrompt == "" {
|
|
return map[string]any{
|
|
"success": false,
|
|
"provider": s.providerID,
|
|
"mode": "single-agent",
|
|
"error": "taskPrompt is required",
|
|
}
|
|
}
|
|
|
|
model := strings.TrimSpace(shared.StringArg(params, "model", ""))
|
|
if model == "" {
|
|
model = state.model
|
|
}
|
|
workingDirectory := strings.TrimSpace(shared.StringArg(params, "workingDirectory", ""))
|
|
if workingDirectory == "" {
|
|
workingDirectory = state.workingDirectory
|
|
}
|
|
|
|
sessionsHistory := append([]string(nil), state.history...)
|
|
sessionsHistory = append(sessionsHistory, taskPrompt)
|
|
composedPrompt := shared.ComposeHistoryPrompt(sessionsHistory)
|
|
output, err := s.sessionRunner(context.Background(), model, composedPrompt, workingDirectory)
|
|
if err != nil {
|
|
return map[string]any{
|
|
"success": false,
|
|
"provider": s.providerID,
|
|
"mode": "single-agent",
|
|
"error": err.Error(),
|
|
}
|
|
}
|
|
|
|
s.sessionsMu.Lock()
|
|
state = s.sessions[sessionID]
|
|
if state == nil {
|
|
state = &adapterSession{}
|
|
s.sessions[sessionID] = state
|
|
}
|
|
state.history = sessionsHistory
|
|
state.model = model
|
|
state.workingDirectory = workingDirectory
|
|
state.lastOutput = output
|
|
state.lastUpstreamMethod = "prompt"
|
|
s.sessionsMu.Unlock()
|
|
|
|
result := map[string]any{
|
|
"success": true,
|
|
"provider": s.providerID,
|
|
"mode": "single-agent",
|
|
"output": output,
|
|
"sessionId": sessionID,
|
|
"upstreamMethod": "prompt",
|
|
}
|
|
if workingDirectory != "" {
|
|
result["effectiveWorkingDirectory"] = workingDirectory
|
|
}
|
|
if model != "" {
|
|
result["resolvedModel"] = model
|
|
}
|
|
return result
|
|
}
|
|
|
|
func (s *Server) getOrCreateSession(sessionID string) *adapterSession {
|
|
s.sessionsMu.Lock()
|
|
defer s.sessionsMu.Unlock()
|
|
state := s.sessions[sessionID]
|
|
if state == nil {
|
|
state = &adapterSession{}
|
|
s.sessions[sessionID] = state
|
|
}
|
|
return &adapterSession{
|
|
history: append([]string(nil), state.history...),
|
|
model: state.model,
|
|
workingDirectory: state.workingDirectory,
|
|
lastOutput: state.lastOutput,
|
|
lastUpstreamMethod: state.lastUpstreamMethod,
|
|
}
|
|
}
|
|
|
|
func (s *Server) resetSession(sessionID string) *adapterSession {
|
|
s.sessionsMu.Lock()
|
|
defer s.sessionsMu.Unlock()
|
|
state := &adapterSession{}
|
|
s.sessions[sessionID] = state
|
|
return state
|
|
}
|
|
|
|
func (s *Server) closeSession(sessionID string) bool {
|
|
sessionID = strings.TrimSpace(sessionID)
|
|
if sessionID == "" {
|
|
return false
|
|
}
|
|
s.sessionsMu.Lock()
|
|
defer s.sessionsMu.Unlock()
|
|
if _, ok := s.sessions[sessionID]; !ok {
|
|
return false
|
|
}
|
|
delete(s.sessions, sessionID)
|
|
return true
|
|
}
|
|
|
|
// parseAllowedOrigins splits a comma-separated origin list, trims each entry,
// and drops empty ones. An empty raw string yields nil. Any other input
// yields a non-nil, possibly empty, slice.
func parseAllowedOrigins(raw string) []string {
	if len(raw) == 0 {
		return nil
	}

	pieces := strings.Split(raw, ",")
	origins := make([]string, 0, len(pieces))
	for i := range pieces {
		if entry := strings.TrimSpace(pieces[i]); entry != "" {
			origins = append(origins, entry)
		}
	}
	return origins
}
|
|
|
|
func (s *Server) originAllowed(origin string) bool {
|
|
origin = strings.TrimSpace(origin)
|
|
if origin == "" {
|
|
return true
|
|
}
|
|
for _, allowed := range s.allowedOrigins {
|
|
if strings.HasSuffix(allowed, ":*") {
|
|
if strings.HasPrefix(origin, strings.TrimSuffix(allowed, "*")) {
|
|
return true
|
|
}
|
|
continue
|
|
}
|
|
if origin == allowed {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (s *Server) applyCORS(w http.ResponseWriter, r *http.Request) {
|
|
origin := strings.TrimSpace(r.Header.Get("Origin"))
|
|
if origin == "" || !s.originAllowed(origin) {
|
|
return
|
|
}
|
|
headers := w.Header()
|
|
headers.Set("Access-Control-Allow-Origin", origin)
|
|
headers.Set("Access-Control-Allow-Methods", "POST, OPTIONS")
|
|
headers.Set("Access-Control-Allow-Headers", "Authorization, Content-Type, Accept")
|
|
headers.Set("Access-Control-Max-Age", "600")
|
|
headers.Add("Vary", "Origin")
|
|
headers.Add("Vary", "Access-Control-Request-Method")
|
|
headers.Add("Vary", "Access-Control-Request-Headers")
|
|
}
|
|
|
|
func (s *Server) authorized(r *http.Request) bool {
|
|
if s == nil || s.authService == nil {
|
|
return true
|
|
}
|
|
expected := strings.TrimSpace(shared.EnvOrDefault("GEMINI_ADAPTER_AUTH_TOKEN", ""))
|
|
if expected == "" {
|
|
return true
|
|
}
|
|
return s.authService.ValidateAuthorizationHeader(r.Header.Get("Authorization"))
|
|
}
|
|
|
|
func (s *Server) writeJSONError(w http.ResponseWriter, requestID any, statusCode int, code int, message string) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(statusCode)
|
|
_ = json.NewEncoder(w).Encode(shared.ErrorEnvelope(requestID, code, message))
|
|
}
|