112 lines
3.5 KiB
Go
112 lines
3.5 KiB
Go
package acp
|
|
|
|
import (
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
)
|
|
|
|
func TestHandleWebSocketRejectsUnknownOrigin(t *testing.T) {
|
|
t.Setenv("ACP_ALLOWED_ORIGINS", "https://xworkmate.svc.plus")
|
|
|
|
server := NewServer()
|
|
recorder := httptest.NewRecorder()
|
|
request := httptest.NewRequest(http.MethodGet, "http://127.0.0.1/acp", nil)
|
|
request.Header.Set("Origin", "https://evil.example.com")
|
|
|
|
server.HandleWebSocket(recorder, request)
|
|
|
|
if recorder.Code != http.StatusForbidden {
|
|
t.Fatalf("expected 403, got %d", recorder.Code)
|
|
}
|
|
if got := recorder.Header().Get("Content-Type"); !strings.Contains(got, "application/json") {
|
|
t.Fatalf("expected application/json content type, got %q", got)
|
|
}
|
|
}
|
|
|
|
func TestHandleRPCAllowsPreflightForConfiguredOrigin(t *testing.T) {
|
|
t.Setenv("ACP_ALLOWED_ORIGINS", "https://xworkmate.svc.plus,http://localhost:*")
|
|
|
|
server := NewServer()
|
|
recorder := httptest.NewRecorder()
|
|
request := httptest.NewRequest(http.MethodOptions, "http://127.0.0.1/acp/rpc", nil)
|
|
request.Header.Set("Origin", "https://xworkmate.svc.plus")
|
|
request.Header.Set("Access-Control-Request-Method", "POST")
|
|
|
|
server.HandleRPC(recorder, request)
|
|
|
|
if recorder.Code != http.StatusNoContent {
|
|
t.Fatalf("expected 204, got %d", recorder.Code)
|
|
}
|
|
if got := recorder.Header().Get("Access-Control-Allow-Origin"); got != "https://xworkmate.svc.plus" {
|
|
t.Fatalf("expected allow origin header, got %q", got)
|
|
}
|
|
}
|
|
|
|
func TestHandleRPCRejectsUnknownOrigin(t *testing.T) {
|
|
t.Setenv("ACP_ALLOWED_ORIGINS", "https://xworkmate.svc.plus")
|
|
|
|
server := NewServer()
|
|
recorder := httptest.NewRecorder()
|
|
request := httptest.NewRequest(
|
|
http.MethodPost,
|
|
"http://127.0.0.1/acp/rpc",
|
|
strings.NewReader(`{"jsonrpc":"2.0","id":1,"method":"acp.capabilities"}`),
|
|
)
|
|
request.Header.Set("Origin", "https://evil.example.com")
|
|
request.Header.Set("Content-Type", "application/json")
|
|
|
|
server.HandleRPC(recorder, request)
|
|
|
|
if recorder.Code != http.StatusForbidden {
|
|
t.Fatalf("expected 403, got %d", recorder.Code)
|
|
}
|
|
var envelope map[string]any
|
|
if err := json.Unmarshal(recorder.Body.Bytes(), &envelope); err != nil {
|
|
t.Fatalf("decode error envelope: %v", err)
|
|
}
|
|
if _, ok := envelope["error"]; !ok {
|
|
t.Fatalf("expected JSON-RPC error envelope, got %v", envelope)
|
|
}
|
|
}
|
|
|
|
func TestHandleRPCMethodErrorUsesJSONEnvelope(t *testing.T) {
|
|
server := NewServer()
|
|
recorder := httptest.NewRecorder()
|
|
request := httptest.NewRequest(http.MethodGet, "http://127.0.0.1/acp/rpc", nil)
|
|
|
|
server.HandleRPC(recorder, request)
|
|
|
|
if recorder.Code != http.StatusMethodNotAllowed {
|
|
t.Fatalf("expected 405, got %d", recorder.Code)
|
|
}
|
|
if got := recorder.Header().Get("Content-Type"); !strings.Contains(got, "application/json") {
|
|
t.Fatalf("expected application/json content type, got %q", got)
|
|
}
|
|
}
|
|
|
|
func TestHandleRPCCapabilitiesStillReturnsJSONResult(t *testing.T) {
|
|
server := NewServer()
|
|
recorder := httptest.NewRecorder()
|
|
request := httptest.NewRequest(
|
|
http.MethodPost,
|
|
"http://127.0.0.1/acp/rpc",
|
|
strings.NewReader(`{"jsonrpc":"2.0","id":1,"method":"acp.capabilities"}`),
|
|
)
|
|
request.Header.Set("Content-Type", "application/json")
|
|
|
|
server.HandleRPC(recorder, request)
|
|
|
|
if recorder.Code != http.StatusOK {
|
|
t.Fatalf("expected 200, got %d", recorder.Code)
|
|
}
|
|
if got := recorder.Header().Get("Content-Type"); !strings.Contains(got, "application/json") {
|
|
t.Fatalf("expected application/json content type, got %q", got)
|
|
}
|
|
if !strings.Contains(recorder.Body.String(), `"providers"`) {
|
|
t.Fatalf("expected capabilities response, got %q", recorder.Body.String())
|
|
}
|
|
}
|