Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions cmd/cli/commands/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -809,11 +809,34 @@ func newRunCmd() *cobra.Command {
return nil
}

_, err := desktopClient.Inspect(model, false)
if err != nil {
if !errors.Is(err, desktop.ErrNotFound) {
return handleClientError(err, "Failed to inspect model")
modelInfo, err := desktopClient.Inspect(model, false)
modelFoundLocally := err == nil
if err != nil && !errors.Is(err, desktop.ErrNotFound) {
return handleClientError(err, "Failed to inspect model")
}

if !modelFoundLocally {
remoteInfo, remoteErr := desktopClient.Inspect(model, true)
if remoteErr == nil {
modelInfo = remoteInfo
}
}

backend := ""
if modelInfo.ID != "" {
backend, _ = GetRequiredBackendFromModelInfo(&modelInfo)
}

if backend != "" {
if err := EnsureBackendAvailable(backend, cmd); err != nil {
if err.Error() == "backend installation cancelled" {
return nil
}
return err
}
}

if !modelFoundLocally {
cmd.Println("Unable to find model '" + model + "' locally. Pulling from the server.")
if err := pullModel(cmd, desktopClient, model); err != nil {
return err
Expand Down
116 changes: 116 additions & 0 deletions cmd/cli/commands/utils.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package commands

import (
"bufio"
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
Expand All @@ -12,7 +14,10 @@ import (
"github.com/docker/model-runner/cmd/cli/desktop"
"github.com/docker/model-runner/cmd/cli/pkg/standalone"
"github.com/docker/model-runner/pkg/distribution/oci/reference"
"github.com/docker/model-runner/pkg/distribution/types"
"github.com/docker/model-runner/pkg/inference/backends/llamacpp"
"github.com/docker/model-runner/pkg/inference/backends/vllm"
dmrm "github.com/docker/model-runner/pkg/inference/models"
"github.com/moby/term"
"github.com/olekukonko/tablewriter"
"github.com/olekukonko/tablewriter/renderer"
Expand Down Expand Up @@ -270,3 +275,114 @@ func newTable(w io.Writer) *tablewriter.Table {
}),
)
}

func CheckBackendInstalled(backend string) (bool, error) {
status := desktopClient.Status()
if status.Error != nil {
return false, fmt.Errorf("failed to get backend status: %w", status.Error)
}

var backendStatus map[string]string
if err := json.Unmarshal(status.Status, &backendStatus); err != nil {
return false, fmt.Errorf("failed to parse backend status: %w", err)
}

backendState, exists := backendStatus[backend]
if !exists {
return false, nil
}

state := strings.TrimSpace(strings.ToLower(backendState))
if strings.HasPrefix(state, "not ") || strings.HasPrefix(state, "error") {
return false, nil
}

return strings.HasPrefix(state, "installed") || strings.HasPrefix(state, "running"), nil
}

func PromptInstallBackend(backend string, cmd *cobra.Command) (bool, error) {
fmt.Fprintf(cmd.OutOrStdout(), "Backend %q is not installed. Download and install it now? [Y/n]: ", backend)

reader := bufio.NewReader(os.Stdin)
input, err := reader.ReadString('\n')
if err != nil {
return false, fmt.Errorf("failed to read input: %w", err)
}

input = strings.TrimSpace(strings.ToLower(input))
return input == "" || input == "y" || input == "yes", nil
}

func InstallBackend(backend string, cmd *cobra.Command) error {
installCmd := newInstallRunner()
installCmd.SetArgs([]string{"--backend", backend})

if err := installCmd.Execute(); err != nil {
return fmt.Errorf("failed to install backend %s: %w", backend, err)
}

return nil
}

func EnsureBackendAvailable(backend string, cmd *cobra.Command) error {
installed, err := CheckBackendInstalled(backend)
if err != nil {
return err
}

if installed {
return nil
}

confirm, err := PromptInstallBackend(backend, cmd)
if err != nil {
return err
}

if !confirm {
cmd.Printf("Run 'docker model install-runner --backend %s' to install it manually.\n", backend)
return fmt.Errorf("backend installation cancelled")
}

if err := InstallBackend(backend, cmd); err != nil {
return err
}

installed, err = CheckBackendInstalled(backend)
if err != nil {
return err
}
if !installed {
return fmt.Errorf("backend %q is still not installed; run 'docker model install-runner --backend %s'", backend, backend)
}

cmd.Printf("Backend %q installed successfully.\n", backend)
return nil
}

func GetRequiredBackend(model string) (string, error) {
modelInfo, err := desktopClient.Inspect(model, false)
if err != nil {
return "", err
}

return GetRequiredBackendFromModelInfo(&modelInfo)
}

func GetRequiredBackendFromModelInfo(modelInfo *dmrm.Model) (string, error) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please check and use func (s *Scheduler) selectBackendForModel.
The server could expose a simple endpoint like:

GET /models/{name}/backend → {"backend": "vllm", "installed": true}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The backend selection logic in GetRequiredBackendFromModelInfo duplicates what the server already does in Scheduler.selectBackendForModel, and the two already diverge (the server has a platform-aware fallback chain: vLLM → MLX → SGLang, while the CLI always returns vllm for safetensors). Rather than maintaining this mapping in two places, consider calling the server to determine the required backend.

config, ok := modelInfo.Config.(*types.Config)
if !ok {
return llamacpp.Name, nil
}

switch config.Format {
case types.FormatSafetensors:
return vllm.Name, nil
case types.FormatGGUF:
return llamacpp.Name, nil
case types.FormatDiffusers:
return "diffusers", nil
default:
return llamacpp.Name, nil
}
}