366 lines
11 KiB
Go
366 lines
11 KiB
Go
package router
|
|
|
|
import (
|
|
"os"
|
|
"sort"
|
|
"strings"
|
|
|
|
"xworkmate-bridge/internal/memory"
|
|
"xworkmate-bridge/internal/skills"
|
|
)
|
|
|
|
// Routing vocabulary shared by Request, Result and the resolver helpers.
const (
	// Routing modes accepted in Request.RoutingMode.
	RoutingModeAuto     = "auto"
	RoutingModeExplicit = "explicit"

	// Execution targets reported in Result.ResolvedExecutionTarget.
	// ExecutionTargetGatewayChat is folded into ExecutionTargetGateway by
	// normalizeExecutionTarget.
	ExecutionTargetSingleAgent = "single-agent"
	ExecutionTargetMultiAgent  = "multi-agent"
	ExecutionTargetGateway     = "gateway"
	ExecutionTargetGatewayChat = "gateway-chat"

	// Endpoint targets reported in Result.ResolvedEndpointTarget.
	EndpointTargetSingleAgent = "singleAgent"
	EndpointTargetLocal       = "local"
	EndpointTargetRemote      = "remote"
)
|
|
|
|
type Request struct {
|
|
Prompt string
|
|
WorkingDirectory string
|
|
RoutingMode string
|
|
PreferredGatewayTarget string
|
|
ExplicitExecutionTarget string
|
|
ExplicitProviderID string
|
|
ExplicitModel string
|
|
ExplicitSkills []string
|
|
AllowSkillInstall bool
|
|
InstallApproval skills.InstallApproval
|
|
AvailableSkills []skills.Candidate
|
|
AvailableProviders []string
|
|
AIGatewayBaseURL string
|
|
AIGatewayAPIKey string
|
|
}
|
|
|
|
type Result struct {
|
|
ResolvedExecutionTarget string
|
|
ResolvedEndpointTarget string
|
|
ResolvedProviderID string
|
|
ResolvedModel string
|
|
ResolvedSkills []string
|
|
SkillResolutionSource string
|
|
SkillCandidates []skills.Candidate
|
|
NeedsSkillInstall bool
|
|
SkillInstallRequestID string
|
|
MemorySources []memory.Source
|
|
Unavailable bool
|
|
UnavailableCode string
|
|
UnavailableMessage string
|
|
}
|
|
|
|
type Resolver struct {
|
|
SkillFinder skills.Finder
|
|
SkillInstaller skills.Installer
|
|
MemoryService memory.Service
|
|
Classifier Classifier
|
|
}
|
|
|
|
func NewResolver() Resolver {
|
|
homeDir, _ := os.UserHomeDir()
|
|
return Resolver{
|
|
SkillFinder: skills.NewDefaultFinder(),
|
|
SkillInstaller: skills.NewDefaultInstaller(),
|
|
MemoryService: memory.NewService(homeDir),
|
|
Classifier: LLMClassifier{},
|
|
}
|
|
}
|
|
|
|
func (r Resolver) Resolve(req Request) Result {
|
|
mem := r.MemoryService.Load(req.WorkingDirectory)
|
|
availableProviders := normalizeProviders(req.AvailableProviders)
|
|
|
|
result := Result{
|
|
ResolvedModel: strings.TrimSpace(req.ExplicitModel),
|
|
MemorySources: mem.Sources,
|
|
}
|
|
|
|
result.ResolvedExecutionTarget, result.ResolvedEndpointTarget = r.resolveExecution(req, mem.Preferences)
|
|
result.ResolvedProviderID, result.Unavailable, result.UnavailableCode, result.UnavailableMessage = resolveProvider(
|
|
req,
|
|
mem.Preferences,
|
|
availableProviders,
|
|
result.ResolvedExecutionTarget,
|
|
)
|
|
if result.ResolvedModel == "" {
|
|
result.ResolvedModel = strings.TrimSpace(mem.Preferences.PreferredModel)
|
|
}
|
|
|
|
skillRequest := skills.ResolveRequest{
|
|
Prompt: req.Prompt,
|
|
ExplicitSkills: req.ExplicitSkills,
|
|
AvailableSkills: req.AvailableSkills,
|
|
AllowSkillInstall: req.AllowSkillInstall,
|
|
InstallApproval: req.InstallApproval,
|
|
}
|
|
skillResult := skills.Resolve(skillRequest, r.SkillFinder, r.SkillInstaller)
|
|
result.ResolvedSkills = skillResult.ResolvedSkills
|
|
result.SkillResolutionSource = skillResult.Source
|
|
result.SkillCandidates = skillResult.Candidates
|
|
result.NeedsSkillInstall = skillResult.NeedsInstall
|
|
result.SkillInstallRequestID = skillResult.InstallRequestID
|
|
|
|
if len(result.ResolvedSkills) == 0 && len(mem.Preferences.PreferredSkills) > 0 {
|
|
result.ResolvedSkills = append([]string(nil), mem.Preferences.PreferredSkills...)
|
|
if result.SkillResolutionSource == "" || result.SkillResolutionSource == "none" {
|
|
result.SkillResolutionSource = "local_match"
|
|
}
|
|
}
|
|
if result.SkillResolutionSource == "" {
|
|
result.SkillResolutionSource = "none"
|
|
}
|
|
if result.ResolvedExecutionTarget == "" {
|
|
if len(availableProviders) > 0 {
|
|
result.ResolvedExecutionTarget = ExecutionTargetSingleAgent
|
|
} else {
|
|
result.ResolvedExecutionTarget = ExecutionTargetGateway
|
|
}
|
|
}
|
|
if result.ResolvedEndpointTarget == "" {
|
|
if result.ResolvedExecutionTarget == ExecutionTargetGateway {
|
|
result.ResolvedEndpointTarget = normalizeGatewayTarget(req.PreferredGatewayTarget)
|
|
} else {
|
|
result.ResolvedEndpointTarget = EndpointTargetSingleAgent
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
func (r Resolver) resolveExecution(req Request, prefs memory.Preferences) (string, string) {
|
|
explicit := strings.TrimSpace(req.ExplicitExecutionTarget)
|
|
if strings.EqualFold(strings.TrimSpace(req.RoutingMode), RoutingModeExplicit) && explicit != "" {
|
|
return mapExplicitTarget(explicit)
|
|
}
|
|
|
|
prompt := normalize(req.Prompt)
|
|
|
|
localTask := looksLocal(prompt)
|
|
onlineTask := looksOnline(prompt)
|
|
complexTask := looksComplex(prompt)
|
|
|
|
switch {
|
|
case localTask && complexTask:
|
|
return ExecutionTargetMultiAgent, EndpointTargetSingleAgent
|
|
case onlineTask && complexTask:
|
|
return ExecutionTargetMultiAgent, EndpointTargetSingleAgent
|
|
case localTask:
|
|
return ExecutionTargetSingleAgent, EndpointTargetSingleAgent
|
|
case onlineTask:
|
|
return ExecutionTargetGateway, normalizeGatewayTarget(req.PreferredGatewayTarget)
|
|
case complexTask:
|
|
return ExecutionTargetMultiAgent, EndpointTargetSingleAgent
|
|
}
|
|
|
|
switch normalizeExecutionTarget(r.classify(req)) {
|
|
case ExecutionTargetGateway:
|
|
return ExecutionTargetGateway, normalizeGatewayTarget(req.PreferredGatewayTarget)
|
|
case ExecutionTargetMultiAgent:
|
|
return ExecutionTargetMultiAgent, EndpointTargetSingleAgent
|
|
case ExecutionTargetSingleAgent:
|
|
return ExecutionTargetSingleAgent, EndpointTargetSingleAgent
|
|
}
|
|
|
|
switch normalizeExecutionTarget(strings.TrimSpace(prefs.PreferredRoute)) {
|
|
case ExecutionTargetGateway:
|
|
return ExecutionTargetGateway, normalizeGatewayTarget(req.PreferredGatewayTarget)
|
|
case ExecutionTargetMultiAgent:
|
|
return ExecutionTargetMultiAgent, EndpointTargetSingleAgent
|
|
case ExecutionTargetSingleAgent:
|
|
if len(normalizeProviders(req.AvailableProviders)) > 0 {
|
|
return ExecutionTargetSingleAgent, EndpointTargetSingleAgent
|
|
}
|
|
}
|
|
if len(normalizeProviders(req.AvailableProviders)) > 0 {
|
|
return ExecutionTargetSingleAgent, EndpointTargetSingleAgent
|
|
}
|
|
return ExecutionTargetGateway, normalizeGatewayTarget(req.PreferredGatewayTarget)
|
|
}
|
|
|
|
func (r Resolver) classify(req Request) string {
|
|
if r.Classifier == nil {
|
|
return ""
|
|
}
|
|
return normalizeExecutionTarget(r.Classifier.Classify(ClassificationRequest{
|
|
Prompt: req.Prompt,
|
|
AIGatewayBaseURL: req.AIGatewayBaseURL,
|
|
AIGatewayAPIKey: req.AIGatewayAPIKey,
|
|
}))
|
|
}
|
|
|
|
func mapExplicitTarget(value string) (string, string) {
|
|
switch strings.TrimSpace(value) {
|
|
case EndpointTargetLocal:
|
|
return ExecutionTargetGateway, EndpointTargetLocal
|
|
case EndpointTargetRemote:
|
|
return ExecutionTargetGateway, EndpointTargetRemote
|
|
case "multiAgent", ExecutionTargetMultiAgent:
|
|
return ExecutionTargetMultiAgent, EndpointTargetSingleAgent
|
|
case EndpointTargetSingleAgent, ExecutionTargetSingleAgent:
|
|
return ExecutionTargetSingleAgent, EndpointTargetSingleAgent
|
|
default:
|
|
return ExecutionTargetSingleAgent, EndpointTargetSingleAgent
|
|
}
|
|
}
|
|
|
|
func normalizeGatewayTarget(value string) string {
|
|
switch strings.TrimSpace(value) {
|
|
case EndpointTargetLocal, "":
|
|
return EndpointTargetLocal
|
|
default:
|
|
return EndpointTargetRemote
|
|
}
|
|
}
|
|
|
|
func resolveProvider(
|
|
req Request,
|
|
prefs memory.Preferences,
|
|
availableProviders []string,
|
|
executionTarget string,
|
|
) (string, bool, string, string) {
|
|
explicitProviderID := normalize(strings.TrimSpace(req.ExplicitProviderID))
|
|
if explicitProviderID != "" {
|
|
if containsProvider(availableProviders, explicitProviderID) {
|
|
return explicitProviderID, false, "", ""
|
|
}
|
|
return "", true, "PROVIDER_UNAVAILABLE", "explicit provider is unavailable"
|
|
}
|
|
|
|
if executionTarget != ExecutionTargetSingleAgent {
|
|
preferredProvider := normalize(strings.TrimSpace(prefs.Provider))
|
|
if containsProvider(availableProviders, preferredProvider) {
|
|
return preferredProvider, false, "", ""
|
|
}
|
|
return "", false, "", ""
|
|
}
|
|
|
|
preferredProvider := normalize(strings.TrimSpace(prefs.Provider))
|
|
if containsProvider(availableProviders, preferredProvider) {
|
|
return preferredProvider, false, "", ""
|
|
}
|
|
if len(availableProviders) > 0 {
|
|
return availableProviders[0], false, "", ""
|
|
}
|
|
return "", true, "PROVIDER_UNAVAILABLE", "no single-agent provider is available"
|
|
}
|
|
|
|
func normalizeProviders(values []string) []string {
|
|
if len(values) == 0 {
|
|
return nil
|
|
}
|
|
unique := make(map[string]struct{}, len(values))
|
|
normalized := make([]string, 0, len(values))
|
|
for _, value := range values {
|
|
providerID := normalize(value)
|
|
if providerID == "" {
|
|
continue
|
|
}
|
|
if _, ok := unique[providerID]; ok {
|
|
continue
|
|
}
|
|
unique[providerID] = struct{}{}
|
|
normalized = append(normalized, providerID)
|
|
}
|
|
sort.Strings(normalized)
|
|
return normalized
|
|
}
|
|
|
|
func containsProvider(values []string, want string) bool {
|
|
want = normalize(want)
|
|
if want == "" {
|
|
return false
|
|
}
|
|
for _, value := range values {
|
|
if normalize(value) == want {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func looksLocal(prompt string) bool {
|
|
return containsAny(prompt, []string{
|
|
"ppt", "pptx", "powerpoint", "word", "docx", "excel", "xlsx", "pdf",
|
|
"image-resizer", "resize image", "compress image", "crop image",
|
|
})
|
|
}
|
|
|
|
// looksOnline reports whether prompt mentions an online capability: media
// generation or translation, browsing, search or news gathering. The prompt
// is expected to be normalized (trimmed, lowercased) already.
//
// Most keywords are matched as substrings. The short keyword "wan" is only
// accepted as a standalone word: as a raw substring it also matched everyday
// words such as "want", "swan" or "taiwan", which sent unrelated prompts to
// the gateway.
func looksOnline(prompt string) bool {
	// These literals are already lowercase and trimmed, so they can be passed
	// to strings.Contains without per-keyword normalization.
	substringKeywords := []string{
		"image-cog", "video-translator", "browser", "search", "news",
		"资讯采集", "跨浏览器", "文生图", "文生视频", "图生视频", "视频翻译",
		"translate video", "dub video", "subtitles",
	}
	for _, keyword := range substringKeywords {
		if strings.Contains(prompt, keyword) {
			return true
		}
	}
	return containsStandaloneWord(prompt, "wan")
}

// containsStandaloneWord reports whether word occurs in text without an ASCII
// letter directly before or after it. Digits, punctuation and non-ASCII
// neighbours are allowed, so "wan2.1" and "用wan生成" both match.
func containsStandaloneWord(text, word string) bool {
	if word == "" {
		return false
	}
	isLetter := func(b byte) bool {
		return (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z')
	}
	for offset := 0; offset < len(text); {
		idx := strings.Index(text[offset:], word)
		if idx < 0 {
			return false
		}
		start := offset + idx
		end := start + len(word)
		leftOK := start == 0 || !isLetter(text[start-1])
		rightOK := end == len(text) || !isLetter(text[end])
		if leftOK && rightOK {
			return true
		}
		// Embedded in a longer word: keep scanning after this occurrence.
		offset = start + 1
	}
	return false
}
|
|
|
|
func looksComplex(prompt string) bool {
|
|
strongSignals := containsAny(prompt, []string{
|
|
"multiple deliverables", "multiple outputs", "多个产物", "多个输出",
|
|
"审阅", "复核", "汇编", "end-to-end", "end to end",
|
|
})
|
|
if strongSignals {
|
|
return true
|
|
}
|
|
|
|
reviewSignals := containsAny(prompt, []string{
|
|
"review", "audit", "verify", "summarize", "compare",
|
|
"审阅", "复核", "汇总", "对比", "整理", "整合", "汇编",
|
|
})
|
|
multiStepSignals := containsAny(prompt, []string{
|
|
"workflow", "pipeline", "step by step", "multi-step", "collect and",
|
|
"analyze and", "review and", "compare and", "summarize and",
|
|
"先", "然后", "之后",
|
|
})
|
|
structuredOutputSignals := containsAny(prompt, []string{
|
|
"report", "memo", "table", "spreadsheet", "document", "deck", "slides",
|
|
"presentation", "报告", "总结", "表格", "文档", "演示",
|
|
})
|
|
onlineCollectionSignals := containsAny(prompt, []string{
|
|
"browser", "search", "news", "research", "crawl", "scrape",
|
|
"跨浏览器", "搜索", "资讯", "采集", "检索",
|
|
})
|
|
|
|
score := 0
|
|
if reviewSignals {
|
|
score++
|
|
}
|
|
if multiStepSignals {
|
|
score++
|
|
}
|
|
if structuredOutputSignals {
|
|
score++
|
|
}
|
|
if onlineCollectionSignals && structuredOutputSignals {
|
|
return true
|
|
}
|
|
return score >= 2
|
|
}
|
|
|
|
func containsAny(haystack string, needles []string) bool {
|
|
for _, needle := range needles {
|
|
if strings.Contains(haystack, normalize(needle)) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// normalize returns value with surrounding whitespace removed and all letters
// lowercased.
func normalize(value string) string {
	trimmed := strings.TrimSpace(value)
	return strings.ToLower(trimmed)
}
|
|
|
|
func normalizeExecutionTarget(value string) string {
|
|
switch normalize(value) {
|
|
case ExecutionTargetGatewayChat:
|
|
return ExecutionTargetGateway
|
|
default:
|
|
return normalize(value)
|
|
}
|
|
}
|