fix hermes acp prompt forwarding

This commit is contained in:
Haitao Pan 2026-04-22 14:22:09 +08:00
parent 849e45bf5b
commit b697e915eb
3 changed files with 289 additions and 47 deletions

View File

@ -15,6 +15,7 @@ import (
type rpcClient interface {
Initialize() (initializeResult, error)
Call(method string, params map[string]any) (map[string]any, error)
SetNotificationHandler(func(map[string]any))
Close() error
}
@ -25,18 +26,19 @@ type initializeResult struct {
}
type stdioRPCClient struct {
mu sync.Mutex
command string
args []string
env []string
protocolVersion int
cmd *exec.Cmd
stdin io.WriteCloser
stdout *bufio.Reader
stderr io.ReadCloser
nextID atomic.Int64
initialized bool
initResult initializeResult
mu sync.Mutex
command string
args []string
env []string
protocolVersion int
cmd *exec.Cmd
stdin io.WriteCloser
stdout *bufio.Reader
stderr io.ReadCloser
nextID atomic.Int64
initialized bool
initResult initializeResult
notificationHandler func(map[string]any)
}
func newStdioRPCClient(command string, args []string, env []string, protocolVersion int) *stdioRPCClient {
@ -93,6 +95,12 @@ func (c *stdioRPCClient) Close() error {
return c.closeLocked()
}
func (c *stdioRPCClient) SetNotificationHandler(handler func(map[string]any)) {
c.mu.Lock()
defer c.mu.Unlock()
c.notificationHandler = handler
}
func (c *stdioRPCClient) ensureStartedLocked() error {
if c.cmd != nil {
return nil
@ -161,19 +169,29 @@ func (c *stdioRPCClient) callLocked(method string, params map[string]any) (map[s
if _, err := c.stdin.Write(append(encoded, '\n')); err != nil {
return nil, err
}
line, err := c.stdout.ReadBytes('\n')
if err != nil {
if stderr, stderrErr := io.ReadAll(c.stderr); stderrErr == nil {
trimmed := strings.TrimSpace(string(stderr))
if trimmed != "" {
return nil, fmt.Errorf("hermes acp read failed: %s", trimmed)
for {
line, err := c.stdout.ReadBytes('\n')
if err != nil {
if stderr, stderrErr := io.ReadAll(c.stderr); stderrErr == nil {
trimmed := strings.TrimSpace(string(stderr))
if trimmed != "" {
return nil, fmt.Errorf("hermes acp read failed: %s", trimmed)
}
}
return nil, err
}
var response map[string]any
if err := json.Unmarshal(line, &response); err != nil {
return nil, fmt.Errorf("decode hermes acp response: %w", err)
}
if responseID, _ := response["id"].(string); responseID != "" {
if responseID == requestID {
return response, nil
}
continue
}
if handler := c.notificationHandler; handler != nil {
handler(response)
}
return nil, err
}
var response map[string]any
if err := json.Unmarshal(line, &response); err != nil {
return nil, fmt.Errorf("decode hermes acp response: %w", err)
}
return response, nil
}

View File

@ -48,6 +48,7 @@ type adapterSession struct {
history []string
model string
workingDirectory string
upstreamSessionID string
lastOutput string
lastUpstreamMethod string
}
@ -111,7 +112,7 @@ func NewServer(client rpcClient) *Server {
providerID: strings.TrimSpace(shared.EnvOrDefault("HERMES_ADAPTER_PROVIDER_ID", defaultProviderID)),
providerLabel: strings.TrimSpace(shared.EnvOrDefault("HERMES_ADAPTER_PROVIDER_LABEL", defaultLabel)),
allowedOrigins: parseAllowedOrigins(strings.TrimSpace(shared.EnvOrDefault("HERMES_ADAPTER_ALLOWED_ORIGINS", "https://xworkmate.svc.plus,http://localhost:*,http://127.0.0.1:*"))),
upstreamMethod: strings.TrimSpace(shared.EnvOrDefault("HERMES_ADAPTER_UPSTREAM_METHOD", "session.start")),
upstreamMethod: strings.TrimSpace(shared.EnvOrDefault("HERMES_ADAPTER_UPSTREAM_METHOD", "prompt")),
sessionRunner: func(ctx context.Context, model, prompt, workingDirectory string) (string, error) {
return shared.RunProviderCommand(
ctx,
@ -272,12 +273,15 @@ func (s *Server) handleSessionRequest(method string, params map[string]any) map[
}
upstreamMethod := s.upstreamMethod
if upstreamMethod != "" {
return s.handleConfiguredUpstreamSessionRequest(upstreamMethod, params)
return s.handleConfiguredUpstreamSessionRequest(method, upstreamMethod, params)
}
return s.handleCompatSessionRequest(method, params)
}
func (s *Server) handleConfiguredUpstreamSessionRequest(upstreamMethod string, params map[string]any) map[string]any {
func (s *Server) handleConfiguredUpstreamSessionRequest(method, upstreamMethod string, params map[string]any) map[string]any {
if strings.TrimSpace(strings.ToLower(upstreamMethod)) == "prompt" {
return s.handleHermesPromptUpstreamSessionRequest(method, params)
}
response, err := s.client.Call(upstreamMethod, params)
if err != nil {
return map[string]any{
@ -317,6 +321,193 @@ func (s *Server) handleConfiguredUpstreamSessionRequest(upstreamMethod string, p
}
}
func (s *Server) handleHermesPromptUpstreamSessionRequest(method string, params map[string]any) map[string]any {
sessionID := strings.TrimSpace(shared.StringArg(params, "sessionId", ""))
if sessionID == "" {
return map[string]any{
"success": false,
"provider": s.providerID,
"mode": "single-agent",
"error": "sessionId is required",
}
}
state := s.getOrCreateSession(sessionID)
if method == "session.start" {
state = s.resetSession(sessionID)
}
taskPrompt := strings.TrimSpace(shared.StringArg(params, "taskPrompt", ""))
taskPrompt = shared.AugmentPromptWithAttachments(taskPrompt, params)
if taskPrompt == "" {
return map[string]any{
"success": false,
"provider": s.providerID,
"mode": "single-agent",
"error": "taskPrompt is required",
}
}
workingDirectory := strings.TrimSpace(shared.StringArg(params, "workingDirectory", ""))
if workingDirectory == "" {
workingDirectory = state.workingDirectory
}
if workingDirectory == "" {
workingDirectory = "."
}
if state.upstreamSessionID == "" || method == "session.start" {
newSessionResp, err := s.client.Call("new_session", map[string]any{
"cwd": workingDirectory,
})
if err != nil {
return map[string]any{
"success": false,
"provider": s.providerID,
"mode": "single-agent",
"error": err.Error(),
}
}
state.upstreamSessionID = extractHermesUpstreamSessionID(newSessionResp)
if state.upstreamSessionID == "" {
return map[string]any{
"success": false,
"provider": s.providerID,
"mode": "single-agent",
"error": "hermes upstream did not return a session id",
}
}
}
s.sessionsMu.Lock()
current := s.sessions[sessionID]
if current == nil {
current = &adapterSession{}
s.sessions[sessionID] = current
}
current.upstreamSessionID = state.upstreamSessionID
current.workingDirectory = workingDirectory
current.model = strings.TrimSpace(shared.StringArg(params, "model", current.model))
s.sessionsMu.Unlock()
var outputParts []string
notificationHandler := func(notification map[string]any) {
text := extractHermesSessionUpdateText(notification)
if text != "" {
outputParts = append(outputParts, text)
}
}
s.client.SetNotificationHandler(notificationHandler)
defer s.client.SetNotificationHandler(nil)
promptPayload := []map[string]any{
{
"type": "text",
"text": taskPrompt,
},
}
response, err := s.client.Call("prompt", map[string]any{
"sessionId": state.upstreamSessionID,
"prompt": promptPayload,
})
if err != nil {
return map[string]any{
"success": false,
"provider": s.providerID,
"mode": "single-agent",
"error": err.Error(),
}
}
output := strings.TrimSpace(strings.Join(outputParts, ""))
if output == "" {
if resultMap, ok := response["result"].(map[string]any); ok {
for _, key := range []string{"output", "finalResponse", "final_response", "text", "message"} {
if candidate := strings.TrimSpace(shared.StringArg(resultMap, key, "")); candidate != "" {
output = candidate
break
}
}
}
}
if output == "" {
output = "ok"
}
s.sessionsMu.Lock()
current = s.sessions[sessionID]
if current == nil {
current = &adapterSession{}
s.sessions[sessionID] = current
}
current.history = append(current.history, "USER: "+taskPrompt, "ASSISTANT: "+output)
current.lastOutput = output
current.lastUpstreamMethod = "prompt"
s.sessionsMu.Unlock()
result := map[string]any{
"success": true,
"provider": s.providerID,
"mode": "single-agent",
"output": output,
"sessionId": sessionID,
"upstreamMethod": "prompt",
}
if workingDirectory != "" {
result["effectiveWorkingDirectory"] = workingDirectory
}
if state.upstreamSessionID != "" {
result["upstreamSessionId"] = state.upstreamSessionID
}
return result
}
func extractHermesUpstreamSessionID(response map[string]any) string {
for _, key := range []string{"sessionId", "session_id", "id"} {
if value := strings.TrimSpace(shared.StringArg(asMap(response["result"]), key, "")); value != "" {
return value
}
if value := strings.TrimSpace(shared.StringArg(response, key, "")); value != "" {
return value
}
}
return ""
}
func extractHermesSessionUpdateText(notification map[string]any) string {
if notification == nil {
return ""
}
payload := asMap(notification["params"])
if len(payload) == 0 {
payload = notification
}
update := asMap(payload["update"])
if len(update) == 0 {
update = payload
}
for _, key := range []string{"text", "message", "content", "delta"} {
if text := strings.TrimSpace(shared.StringArg(update, key, "")); text != "" {
if updateKind := strings.TrimSpace(shared.StringArg(update, "sessionUpdate", "")); updateKind == "" || updateKind == "agent_message_chunk" || updateKind == "agent_message_text" {
return text
}
}
}
if text := strings.TrimSpace(shared.StringArg(payload, "text", "")); text != "" {
return text
}
return ""
}
func asMap(value any) map[string]any {
if value == nil {
return nil
}
if result, ok := value.(map[string]any); ok {
return result
}
return nil
}
func (s *Server) handleCompatSessionRequest(method string, params map[string]any) map[string]any {
if s.sessionRunner == nil {
return map[string]any{
@ -414,6 +605,7 @@ func (s *Server) getOrCreateSession(sessionID string) *adapterSession {
history: append([]string(nil), state.history...),
model: state.model,
workingDirectory: state.workingDirectory,
upstreamSessionID: state.upstreamSessionID,
lastOutput: state.lastOutput,
lastUpstreamMethod: state.lastUpstreamMethod,
}

View File

@ -14,12 +14,15 @@ import (
)
type stubClient struct {
initResult initializeResult
initErr error
callResult map[string]any
callErr error
lastMethod string
lastParams map[string]any
initResult initializeResult
initErr error
callResult map[string]any
callErr error
callFn func(method string, params map[string]any) (map[string]any, error)
lastMethod string
lastParams map[string]any
methods []string
notificationHandler func(map[string]any)
}
func (s *stubClient) Initialize() (initializeResult, error) {
@ -29,9 +32,17 @@ func (s *stubClient) Initialize() (initializeResult, error) {
func (s *stubClient) Call(method string, params map[string]any) (map[string]any, error) {
s.lastMethod = method
s.lastParams = params
s.methods = append(s.methods, method)
if s.callFn != nil {
return s.callFn(method, params)
}
return s.callResult, s.callErr
}
func (s *stubClient) SetNotificationHandler(handler func(map[string]any)) {
s.notificationHandler = handler
}
func (s *stubClient) Close() error { return nil }
func TestHandleCapabilitiesSynthesizesProviderResponse(t *testing.T) {
@ -49,17 +60,38 @@ func TestHandleCapabilitiesSynthesizesProviderResponse(t *testing.T) {
}
func TestHandleRPCSessionStartReturnsUpstreamResult(t *testing.T) {
stub := &stubClient{
initResult: initializeResult{ProtocolVersion: 1},
callResult: map[string]any{
"result": map[string]any{
"success": true,
"output": "hello",
},
},
var stub *stubClient
stub = &stubClient{initResult: initializeResult{ProtocolVersion: 1}}
stub.callFn = func(method string, params map[string]any) (map[string]any, error) {
switch method {
case "new_session":
return map[string]any{
"result": map[string]any{
"sessionId": "upstream-session-1",
},
}, nil
case "prompt":
if stub.notificationHandler != nil {
stub.notificationHandler(map[string]any{
"params": map[string]any{
"update": map[string]any{
"sessionUpdate": "agent_message_chunk",
"text": "hello",
},
},
})
}
return map[string]any{
"result": map[string]any{
"stopReason": "end_turn",
},
}, nil
default:
return map[string]any{"result": map[string]any{}}, nil
}
}
server := NewServer(stub)
server.upstreamMethod = "session.start"
server.upstreamMethod = "prompt"
body, _ := json.Marshal(shared.RPCRequest{
JSONRPC: "2.0",
@ -87,8 +119,8 @@ func TestHandleRPCSessionStartReturnsUpstreamResult(t *testing.T) {
if got := result["output"]; got != "hello" {
t.Fatalf("expected output hello, got %#v", result)
}
if stub.lastMethod != "session.start" {
t.Fatalf("expected upstream method session.start, got %q", stub.lastMethod)
if len(stub.methods) != 2 || stub.methods[0] != "new_session" || stub.methods[1] != "prompt" {
t.Fatalf("expected new_session then prompt, got %#v", stub.methods)
}
}
@ -120,10 +152,10 @@ func TestHandleSessionStartFallsBackToPromptRunner(t *testing.T) {
}
}
func TestNewServerDefaultsHermesToUpstreamSessionStart(t *testing.T) {
func TestNewServerDefaultsHermesToUpstreamPrompt(t *testing.T) {
server := NewServer(&stubClient{})
if got := server.upstreamMethod; got != "session.start" {
t.Fatalf("expected default upstream method session.start, got %q", got)
if got := server.upstreamMethod; got != "prompt" {
t.Fatalf("expected default upstream method prompt, got %q", got)
}
}