fix hermes acp prompt forwarding
This commit is contained in:
parent
849e45bf5b
commit
b697e915eb
@ -15,6 +15,7 @@ import (
|
||||
type rpcClient interface {
|
||||
Initialize() (initializeResult, error)
|
||||
Call(method string, params map[string]any) (map[string]any, error)
|
||||
SetNotificationHandler(func(map[string]any))
|
||||
Close() error
|
||||
}
|
||||
|
||||
@ -25,18 +26,19 @@ type initializeResult struct {
|
||||
}
|
||||
|
||||
type stdioRPCClient struct {
|
||||
mu sync.Mutex
|
||||
command string
|
||||
args []string
|
||||
env []string
|
||||
protocolVersion int
|
||||
cmd *exec.Cmd
|
||||
stdin io.WriteCloser
|
||||
stdout *bufio.Reader
|
||||
stderr io.ReadCloser
|
||||
nextID atomic.Int64
|
||||
initialized bool
|
||||
initResult initializeResult
|
||||
mu sync.Mutex
|
||||
command string
|
||||
args []string
|
||||
env []string
|
||||
protocolVersion int
|
||||
cmd *exec.Cmd
|
||||
stdin io.WriteCloser
|
||||
stdout *bufio.Reader
|
||||
stderr io.ReadCloser
|
||||
nextID atomic.Int64
|
||||
initialized bool
|
||||
initResult initializeResult
|
||||
notificationHandler func(map[string]any)
|
||||
}
|
||||
|
||||
func newStdioRPCClient(command string, args []string, env []string, protocolVersion int) *stdioRPCClient {
|
||||
@ -93,6 +95,12 @@ func (c *stdioRPCClient) Close() error {
|
||||
return c.closeLocked()
|
||||
}
|
||||
|
||||
func (c *stdioRPCClient) SetNotificationHandler(handler func(map[string]any)) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.notificationHandler = handler
|
||||
}
|
||||
|
||||
func (c *stdioRPCClient) ensureStartedLocked() error {
|
||||
if c.cmd != nil {
|
||||
return nil
|
||||
@ -161,19 +169,29 @@ func (c *stdioRPCClient) callLocked(method string, params map[string]any) (map[s
|
||||
if _, err := c.stdin.Write(append(encoded, '\n')); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
line, err := c.stdout.ReadBytes('\n')
|
||||
if err != nil {
|
||||
if stderr, stderrErr := io.ReadAll(c.stderr); stderrErr == nil {
|
||||
trimmed := strings.TrimSpace(string(stderr))
|
||||
if trimmed != "" {
|
||||
return nil, fmt.Errorf("hermes acp read failed: %s", trimmed)
|
||||
for {
|
||||
line, err := c.stdout.ReadBytes('\n')
|
||||
if err != nil {
|
||||
if stderr, stderrErr := io.ReadAll(c.stderr); stderrErr == nil {
|
||||
trimmed := strings.TrimSpace(string(stderr))
|
||||
if trimmed != "" {
|
||||
return nil, fmt.Errorf("hermes acp read failed: %s", trimmed)
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
var response map[string]any
|
||||
if err := json.Unmarshal(line, &response); err != nil {
|
||||
return nil, fmt.Errorf("decode hermes acp response: %w", err)
|
||||
}
|
||||
if responseID, _ := response["id"].(string); responseID != "" {
|
||||
if responseID == requestID {
|
||||
return response, nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
if handler := c.notificationHandler; handler != nil {
|
||||
handler(response)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
var response map[string]any
|
||||
if err := json.Unmarshal(line, &response); err != nil {
|
||||
return nil, fmt.Errorf("decode hermes acp response: %w", err)
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
@ -48,6 +48,7 @@ type adapterSession struct {
|
||||
history []string
|
||||
model string
|
||||
workingDirectory string
|
||||
upstreamSessionID string
|
||||
lastOutput string
|
||||
lastUpstreamMethod string
|
||||
}
|
||||
@ -111,7 +112,7 @@ func NewServer(client rpcClient) *Server {
|
||||
providerID: strings.TrimSpace(shared.EnvOrDefault("HERMES_ADAPTER_PROVIDER_ID", defaultProviderID)),
|
||||
providerLabel: strings.TrimSpace(shared.EnvOrDefault("HERMES_ADAPTER_PROVIDER_LABEL", defaultLabel)),
|
||||
allowedOrigins: parseAllowedOrigins(strings.TrimSpace(shared.EnvOrDefault("HERMES_ADAPTER_ALLOWED_ORIGINS", "https://xworkmate.svc.plus,http://localhost:*,http://127.0.0.1:*"))),
|
||||
upstreamMethod: strings.TrimSpace(shared.EnvOrDefault("HERMES_ADAPTER_UPSTREAM_METHOD", "session.start")),
|
||||
upstreamMethod: strings.TrimSpace(shared.EnvOrDefault("HERMES_ADAPTER_UPSTREAM_METHOD", "prompt")),
|
||||
sessionRunner: func(ctx context.Context, model, prompt, workingDirectory string) (string, error) {
|
||||
return shared.RunProviderCommand(
|
||||
ctx,
|
||||
@ -272,12 +273,15 @@ func (s *Server) handleSessionRequest(method string, params map[string]any) map[
|
||||
}
|
||||
upstreamMethod := s.upstreamMethod
|
||||
if upstreamMethod != "" {
|
||||
return s.handleConfiguredUpstreamSessionRequest(upstreamMethod, params)
|
||||
return s.handleConfiguredUpstreamSessionRequest(method, upstreamMethod, params)
|
||||
}
|
||||
return s.handleCompatSessionRequest(method, params)
|
||||
}
|
||||
|
||||
func (s *Server) handleConfiguredUpstreamSessionRequest(upstreamMethod string, params map[string]any) map[string]any {
|
||||
func (s *Server) handleConfiguredUpstreamSessionRequest(method, upstreamMethod string, params map[string]any) map[string]any {
|
||||
if strings.TrimSpace(strings.ToLower(upstreamMethod)) == "prompt" {
|
||||
return s.handleHermesPromptUpstreamSessionRequest(method, params)
|
||||
}
|
||||
response, err := s.client.Call(upstreamMethod, params)
|
||||
if err != nil {
|
||||
return map[string]any{
|
||||
@ -317,6 +321,193 @@ func (s *Server) handleConfiguredUpstreamSessionRequest(upstreamMethod string, p
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleHermesPromptUpstreamSessionRequest(method string, params map[string]any) map[string]any {
|
||||
sessionID := strings.TrimSpace(shared.StringArg(params, "sessionId", ""))
|
||||
if sessionID == "" {
|
||||
return map[string]any{
|
||||
"success": false,
|
||||
"provider": s.providerID,
|
||||
"mode": "single-agent",
|
||||
"error": "sessionId is required",
|
||||
}
|
||||
}
|
||||
|
||||
state := s.getOrCreateSession(sessionID)
|
||||
if method == "session.start" {
|
||||
state = s.resetSession(sessionID)
|
||||
}
|
||||
taskPrompt := strings.TrimSpace(shared.StringArg(params, "taskPrompt", ""))
|
||||
taskPrompt = shared.AugmentPromptWithAttachments(taskPrompt, params)
|
||||
if taskPrompt == "" {
|
||||
return map[string]any{
|
||||
"success": false,
|
||||
"provider": s.providerID,
|
||||
"mode": "single-agent",
|
||||
"error": "taskPrompt is required",
|
||||
}
|
||||
}
|
||||
|
||||
workingDirectory := strings.TrimSpace(shared.StringArg(params, "workingDirectory", ""))
|
||||
if workingDirectory == "" {
|
||||
workingDirectory = state.workingDirectory
|
||||
}
|
||||
if workingDirectory == "" {
|
||||
workingDirectory = "."
|
||||
}
|
||||
|
||||
if state.upstreamSessionID == "" || method == "session.start" {
|
||||
newSessionResp, err := s.client.Call("new_session", map[string]any{
|
||||
"cwd": workingDirectory,
|
||||
})
|
||||
if err != nil {
|
||||
return map[string]any{
|
||||
"success": false,
|
||||
"provider": s.providerID,
|
||||
"mode": "single-agent",
|
||||
"error": err.Error(),
|
||||
}
|
||||
}
|
||||
state.upstreamSessionID = extractHermesUpstreamSessionID(newSessionResp)
|
||||
if state.upstreamSessionID == "" {
|
||||
return map[string]any{
|
||||
"success": false,
|
||||
"provider": s.providerID,
|
||||
"mode": "single-agent",
|
||||
"error": "hermes upstream did not return a session id",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.sessionsMu.Lock()
|
||||
current := s.sessions[sessionID]
|
||||
if current == nil {
|
||||
current = &adapterSession{}
|
||||
s.sessions[sessionID] = current
|
||||
}
|
||||
current.upstreamSessionID = state.upstreamSessionID
|
||||
current.workingDirectory = workingDirectory
|
||||
current.model = strings.TrimSpace(shared.StringArg(params, "model", current.model))
|
||||
s.sessionsMu.Unlock()
|
||||
|
||||
var outputParts []string
|
||||
notificationHandler := func(notification map[string]any) {
|
||||
text := extractHermesSessionUpdateText(notification)
|
||||
if text != "" {
|
||||
outputParts = append(outputParts, text)
|
||||
}
|
||||
}
|
||||
s.client.SetNotificationHandler(notificationHandler)
|
||||
defer s.client.SetNotificationHandler(nil)
|
||||
|
||||
promptPayload := []map[string]any{
|
||||
{
|
||||
"type": "text",
|
||||
"text": taskPrompt,
|
||||
},
|
||||
}
|
||||
response, err := s.client.Call("prompt", map[string]any{
|
||||
"sessionId": state.upstreamSessionID,
|
||||
"prompt": promptPayload,
|
||||
})
|
||||
if err != nil {
|
||||
return map[string]any{
|
||||
"success": false,
|
||||
"provider": s.providerID,
|
||||
"mode": "single-agent",
|
||||
"error": err.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
output := strings.TrimSpace(strings.Join(outputParts, ""))
|
||||
if output == "" {
|
||||
if resultMap, ok := response["result"].(map[string]any); ok {
|
||||
for _, key := range []string{"output", "finalResponse", "final_response", "text", "message"} {
|
||||
if candidate := strings.TrimSpace(shared.StringArg(resultMap, key, "")); candidate != "" {
|
||||
output = candidate
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if output == "" {
|
||||
output = "ok"
|
||||
}
|
||||
|
||||
s.sessionsMu.Lock()
|
||||
current = s.sessions[sessionID]
|
||||
if current == nil {
|
||||
current = &adapterSession{}
|
||||
s.sessions[sessionID] = current
|
||||
}
|
||||
current.history = append(current.history, "USER: "+taskPrompt, "ASSISTANT: "+output)
|
||||
current.lastOutput = output
|
||||
current.lastUpstreamMethod = "prompt"
|
||||
s.sessionsMu.Unlock()
|
||||
|
||||
result := map[string]any{
|
||||
"success": true,
|
||||
"provider": s.providerID,
|
||||
"mode": "single-agent",
|
||||
"output": output,
|
||||
"sessionId": sessionID,
|
||||
"upstreamMethod": "prompt",
|
||||
}
|
||||
if workingDirectory != "" {
|
||||
result["effectiveWorkingDirectory"] = workingDirectory
|
||||
}
|
||||
if state.upstreamSessionID != "" {
|
||||
result["upstreamSessionId"] = state.upstreamSessionID
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func extractHermesUpstreamSessionID(response map[string]any) string {
|
||||
for _, key := range []string{"sessionId", "session_id", "id"} {
|
||||
if value := strings.TrimSpace(shared.StringArg(asMap(response["result"]), key, "")); value != "" {
|
||||
return value
|
||||
}
|
||||
if value := strings.TrimSpace(shared.StringArg(response, key, "")); value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func extractHermesSessionUpdateText(notification map[string]any) string {
|
||||
if notification == nil {
|
||||
return ""
|
||||
}
|
||||
payload := asMap(notification["params"])
|
||||
if len(payload) == 0 {
|
||||
payload = notification
|
||||
}
|
||||
update := asMap(payload["update"])
|
||||
if len(update) == 0 {
|
||||
update = payload
|
||||
}
|
||||
for _, key := range []string{"text", "message", "content", "delta"} {
|
||||
if text := strings.TrimSpace(shared.StringArg(update, key, "")); text != "" {
|
||||
if updateKind := strings.TrimSpace(shared.StringArg(update, "sessionUpdate", "")); updateKind == "" || updateKind == "agent_message_chunk" || updateKind == "agent_message_text" {
|
||||
return text
|
||||
}
|
||||
}
|
||||
}
|
||||
if text := strings.TrimSpace(shared.StringArg(payload, "text", "")); text != "" {
|
||||
return text
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func asMap(value any) map[string]any {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
if result, ok := value.(map[string]any); ok {
|
||||
return result
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) handleCompatSessionRequest(method string, params map[string]any) map[string]any {
|
||||
if s.sessionRunner == nil {
|
||||
return map[string]any{
|
||||
@ -414,6 +605,7 @@ func (s *Server) getOrCreateSession(sessionID string) *adapterSession {
|
||||
history: append([]string(nil), state.history...),
|
||||
model: state.model,
|
||||
workingDirectory: state.workingDirectory,
|
||||
upstreamSessionID: state.upstreamSessionID,
|
||||
lastOutput: state.lastOutput,
|
||||
lastUpstreamMethod: state.lastUpstreamMethod,
|
||||
}
|
||||
|
||||
@ -14,12 +14,15 @@ import (
|
||||
)
|
||||
|
||||
type stubClient struct {
|
||||
initResult initializeResult
|
||||
initErr error
|
||||
callResult map[string]any
|
||||
callErr error
|
||||
lastMethod string
|
||||
lastParams map[string]any
|
||||
initResult initializeResult
|
||||
initErr error
|
||||
callResult map[string]any
|
||||
callErr error
|
||||
callFn func(method string, params map[string]any) (map[string]any, error)
|
||||
lastMethod string
|
||||
lastParams map[string]any
|
||||
methods []string
|
||||
notificationHandler func(map[string]any)
|
||||
}
|
||||
|
||||
func (s *stubClient) Initialize() (initializeResult, error) {
|
||||
@ -29,9 +32,17 @@ func (s *stubClient) Initialize() (initializeResult, error) {
|
||||
func (s *stubClient) Call(method string, params map[string]any) (map[string]any, error) {
|
||||
s.lastMethod = method
|
||||
s.lastParams = params
|
||||
s.methods = append(s.methods, method)
|
||||
if s.callFn != nil {
|
||||
return s.callFn(method, params)
|
||||
}
|
||||
return s.callResult, s.callErr
|
||||
}
|
||||
|
||||
func (s *stubClient) SetNotificationHandler(handler func(map[string]any)) {
|
||||
s.notificationHandler = handler
|
||||
}
|
||||
|
||||
func (s *stubClient) Close() error { return nil }
|
||||
|
||||
func TestHandleCapabilitiesSynthesizesProviderResponse(t *testing.T) {
|
||||
@ -49,17 +60,38 @@ func TestHandleCapabilitiesSynthesizesProviderResponse(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestHandleRPCSessionStartReturnsUpstreamResult(t *testing.T) {
|
||||
stub := &stubClient{
|
||||
initResult: initializeResult{ProtocolVersion: 1},
|
||||
callResult: map[string]any{
|
||||
"result": map[string]any{
|
||||
"success": true,
|
||||
"output": "hello",
|
||||
},
|
||||
},
|
||||
var stub *stubClient
|
||||
stub = &stubClient{initResult: initializeResult{ProtocolVersion: 1}}
|
||||
stub.callFn = func(method string, params map[string]any) (map[string]any, error) {
|
||||
switch method {
|
||||
case "new_session":
|
||||
return map[string]any{
|
||||
"result": map[string]any{
|
||||
"sessionId": "upstream-session-1",
|
||||
},
|
||||
}, nil
|
||||
case "prompt":
|
||||
if stub.notificationHandler != nil {
|
||||
stub.notificationHandler(map[string]any{
|
||||
"params": map[string]any{
|
||||
"update": map[string]any{
|
||||
"sessionUpdate": "agent_message_chunk",
|
||||
"text": "hello",
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
return map[string]any{
|
||||
"result": map[string]any{
|
||||
"stopReason": "end_turn",
|
||||
},
|
||||
}, nil
|
||||
default:
|
||||
return map[string]any{"result": map[string]any{}}, nil
|
||||
}
|
||||
}
|
||||
server := NewServer(stub)
|
||||
server.upstreamMethod = "session.start"
|
||||
server.upstreamMethod = "prompt"
|
||||
|
||||
body, _ := json.Marshal(shared.RPCRequest{
|
||||
JSONRPC: "2.0",
|
||||
@ -87,8 +119,8 @@ func TestHandleRPCSessionStartReturnsUpstreamResult(t *testing.T) {
|
||||
if got := result["output"]; got != "hello" {
|
||||
t.Fatalf("expected output hello, got %#v", result)
|
||||
}
|
||||
if stub.lastMethod != "session.start" {
|
||||
t.Fatalf("expected upstream method session.start, got %q", stub.lastMethod)
|
||||
if len(stub.methods) != 2 || stub.methods[0] != "new_session" || stub.methods[1] != "prompt" {
|
||||
t.Fatalf("expected new_session then prompt, got %#v", stub.methods)
|
||||
}
|
||||
}
|
||||
|
||||
@ -120,10 +152,10 @@ func TestHandleSessionStartFallsBackToPromptRunner(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewServerDefaultsHermesToUpstreamSessionStart(t *testing.T) {
|
||||
func TestNewServerDefaultsHermesToUpstreamPrompt(t *testing.T) {
|
||||
server := NewServer(&stubClient{})
|
||||
if got := server.upstreamMethod; got != "session.start" {
|
||||
t.Fatalf("expected default upstream method session.start, got %q", got)
|
||||
if got := server.upstreamMethod; got != "prompt" {
|
||||
t.Fatalf("expected default upstream method prompt, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user