xworkmate-bridge/internal/acp/web_contract_test.go

147 lines
4.6 KiB
Go

package acp
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
// TestHandleWebSocketRejectsUnknownOrigin sets ACP_ALLOWED_ORIGINS to a single
// origin, sends HandleWebSocket a GET carrying a different Origin header, and
// expects a 403 response with a JSON content type.
func TestHandleWebSocketRejectsUnknownOrigin(t *testing.T) {
	t.Setenv("ACP_ALLOWED_ORIGINS", "https://xworkmate.svc.plus")
	srv := NewServer()

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "http://127.0.0.1/acp", nil)
	req.Header.Set("Origin", "https://evil.example.com")

	srv.HandleWebSocket(rec, req)

	if got, want := rec.Code, http.StatusForbidden; got != want {
		t.Fatalf("expected 403, got %d", got)
	}

	contentType := rec.Header().Get("Content-Type")
	if !strings.Contains(contentType, "application/json") {
		t.Fatalf("expected application/json content type, got %q", contentType)
	}
}
// TestHandleRPCAllowsPreflightForConfiguredOrigin sends HandleRPC an OPTIONS
// preflight whose Origin is one of the entries in ACP_ALLOWED_ORIGINS and
// expects a 204 response whose Access-Control-Allow-Origin header equals that
// origin.
func TestHandleRPCAllowsPreflightForConfiguredOrigin(t *testing.T) {
	const origin = "https://xworkmate.svc.plus"

	t.Setenv("ACP_ALLOWED_ORIGINS", "https://xworkmate.svc.plus,http://localhost:*")
	srv := NewServer()

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodOptions, "http://127.0.0.1/acp/rpc", nil)
	req.Header.Set("Origin", origin)
	req.Header.Set("Access-Control-Request-Method", "POST")

	srv.HandleRPC(rec, req)

	if got, want := rec.Code, http.StatusNoContent; got != want {
		t.Fatalf("expected 204, got %d", got)
	}

	allowed := rec.Header().Get("Access-Control-Allow-Origin")
	if allowed != origin {
		t.Fatalf("expected allow origin header, got %q", allowed)
	}
}
// TestHandleRPCRequiresBearerAuthorization posts a JSON-RPC request to
// HandleRPC without an Authorization header and expects a 401 response.
func TestHandleRPCRequiresBearerAuthorization(t *testing.T) {
	const payload = `{"jsonrpc":"2.0","id":1,"method":"acp.capabilities"}`

	srv := NewServer()
	rec := httptest.NewRecorder()

	body := strings.NewReader(payload)
	req := httptest.NewRequest(http.MethodPost, "http://127.0.0.1/acp/rpc", body)
	req.Header.Set("Content-Type", "application/json")

	srv.HandleRPC(rec, req)

	if got, want := rec.Code, http.StatusUnauthorized; got != want {
		t.Fatalf("expected 401, got %d", got)
	}
}
// TestHandleRPCRejectsUnknownOrigin sets ACP_ALLOWED_ORIGINS to a single
// origin, posts an authorized JSON-RPC request to HandleRPC from a different
// origin, and expects a 403 response whose body decodes as a JSON object
// containing an "error" key.
func TestHandleRPCRejectsUnknownOrigin(t *testing.T) {
	const payload = `{"jsonrpc":"2.0","id":1,"method":"acp.capabilities"}`

	t.Setenv("ACP_ALLOWED_ORIGINS", "https://xworkmate.svc.plus")
	srv := NewServer()

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "http://127.0.0.1/acp/rpc", strings.NewReader(payload))
	req.Header.Set("Origin", "https://evil.example.com")
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer test")

	srv.HandleRPC(rec, req)

	if got, want := rec.Code, http.StatusForbidden; got != want {
		t.Fatalf("expected 403, got %d", got)
	}

	// The rejection must still be a decodable JSON object, not plain text.
	var envelope map[string]any
	err := json.Unmarshal(rec.Body.Bytes(), &envelope)
	if err != nil {
		t.Fatalf("decode error envelope: %v", err)
	}

	_, hasError := envelope["error"]
	if !hasError {
		t.Fatalf("expected JSON-RPC error envelope, got %v", envelope)
	}
}
// TestHandleRPCMethodErrorUsesJSONEnvelope sends HandleRPC a GET request with
// a bearer Authorization header and expects a 405 response with a JSON
// content type.
func TestHandleRPCMethodErrorUsesJSONEnvelope(t *testing.T) {
	srv := NewServer()
	rec := httptest.NewRecorder()

	req := httptest.NewRequest(http.MethodGet, "http://127.0.0.1/acp/rpc", nil)
	req.Header.Set("Authorization", "Bearer test")

	srv.HandleRPC(rec, req)

	// Cases are evaluated top to bottom and Fatalf stops the test, so the
	// content type is only inspected once the status code has matched.
	switch {
	case rec.Code != http.StatusMethodNotAllowed:
		t.Fatalf("expected 405, got %d", rec.Code)
	case !strings.Contains(rec.Header().Get("Content-Type"), "application/json"):
		t.Fatalf("expected application/json content type, got %q", rec.Header().Get("Content-Type"))
	}
}
// TestHandleRPCCapabilitiesStillReturnsJSONResult posts an authorized
// "acp.capabilities" JSON-RPC request to HandleRPC and expects a 200 response
// with a JSON content type whose body contains the "providerCatalog" key.
func TestHandleRPCCapabilitiesStillReturnsJSONResult(t *testing.T) {
	const payload = `{"jsonrpc":"2.0","id":1,"method":"acp.capabilities"}`

	srv := NewServer()
	rec := httptest.NewRecorder()

	req := httptest.NewRequest(http.MethodPost, "http://127.0.0.1/acp/rpc", strings.NewReader(payload))
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer test")

	srv.HandleRPC(rec, req)

	if got, want := rec.Code, http.StatusOK; got != want {
		t.Fatalf("expected 200, got %d", got)
	}

	contentType := rec.Header().Get("Content-Type")
	if !strings.Contains(contentType, "application/json") {
		t.Fatalf("expected application/json content type, got %q", contentType)
	}

	body := rec.Body.String()
	if !strings.Contains(body, `"providerCatalog"`) {
		t.Fatalf("expected capabilities response, got %q", body)
	}
}
// TestHandleWebSocketRequiresBearerAuthorization sends HandleWebSocket a GET
// whose Origin matches ACP_ALLOWED_ORIGINS but which carries no Authorization
// header, and expects a 401 response.
func TestHandleWebSocketRequiresBearerAuthorization(t *testing.T) {
	const origin = "https://xworkmate.svc.plus"

	t.Setenv("ACP_ALLOWED_ORIGINS", origin)
	srv := NewServer()

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "http://127.0.0.1/acp", nil)
	req.Header.Set("Origin", origin)

	srv.HandleWebSocket(rec, req)

	if got, want := rec.Code, http.StatusUnauthorized; got != want {
		t.Fatalf("expected 401, got %d", got)
	}
}