Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions cmd/cli/commands/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -809,11 +809,34 @@ func newRunCmd() *cobra.Command {
return nil
}

_, err := desktopClient.Inspect(model, false)
if err != nil {
if !errors.Is(err, desktop.ErrNotFound) {
return handleClientError(err, "Failed to inspect model")
modelInfo, err := desktopClient.Inspect(model, false)
modelFoundLocally := err == nil
if err != nil && !errors.Is(err, desktop.ErrNotFound) {
return handleClientError(err, "Failed to inspect model")
}

if !modelFoundLocally {
remoteInfo, remoteErr := desktopClient.Inspect(model, true)
if remoteErr == nil {
modelInfo = remoteInfo
}
}

backend := ""
if modelInfo.ID != "" {
backend, _ = GetRequiredBackendFromModelInfo(&modelInfo)
}

if backend != "" {
if err := EnsureBackendAvailable(backend, cmd); err != nil {
if errors.Is(err, errBackendInstallationCancelled) {
return nil
}
return err
}
}

if !modelFoundLocally {
cmd.Println("Unable to find model '" + model + "' locally. Pulling from the server.")
if err := pullModel(cmd, desktopClient, model); err != nil {
return err
Expand Down
107 changes: 107 additions & 0 deletions cmd/cli/commands/utils.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package commands

import (
"bufio"
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
Expand All @@ -11,7 +13,11 @@ import (
"github.com/docker/model-runner/cmd/cli/desktop"
"github.com/docker/model-runner/cmd/cli/pkg/standalone"
"github.com/docker/model-runner/pkg/distribution/oci/reference"
"github.com/docker/model-runner/pkg/distribution/types"
"github.com/docker/model-runner/pkg/inference/backends/diffusers"
"github.com/docker/model-runner/pkg/inference/backends/llamacpp"
"github.com/docker/model-runner/pkg/inference/backends/vllm"
dmrm "github.com/docker/model-runner/pkg/inference/models"
"github.com/moby/term"
"github.com/olekukonko/tablewriter"
"github.com/olekukonko/tablewriter/renderer"
Expand Down Expand Up @@ -42,6 +48,8 @@ func getDefaultRegistry() string {

var errNotRunning = fmt.Errorf("Docker Model Runner is not running. Please start it and try again.\n")

var errBackendInstallationCancelled = errors.New("backend installation cancelled")

func handleClientError(err error, message string) error {
if errors.Is(err, desktop.ErrServiceUnavailable) {
err = errNotRunning
Expand Down Expand Up @@ -270,6 +278,105 @@ func newTable(w io.Writer) *tablewriter.Table {
)
}

// CheckBackendInstalled reports whether the named inference backend is
// installed (or already running) according to the model runner's status
// endpoint. Unknown backends and "not ..."/"error..." states count as not
// installed.
func CheckBackendInstalled(backend string) (bool, error) {
	status := desktopClient.Status()
	if status.Error != nil {
		return false, fmt.Errorf("failed to get backend status: %w", status.Error)
	}

	// The status payload is a JSON object mapping backend name -> state string.
	var states map[string]string
	if err := json.Unmarshal(status.Status, &states); err != nil {
		return false, fmt.Errorf("failed to parse backend status: %w", err)
	}

	raw, ok := states[backend]
	if !ok {
		return false, nil
	}

	state := strings.TrimSpace(strings.ToLower(raw))
	switch {
	case strings.HasPrefix(state, "not "), strings.HasPrefix(state, "error"):
		// e.g. "not running", "error failed to start".
		return false, nil
	case strings.HasPrefix(state, "installed"), strings.HasPrefix(state, "running"):
		return true, nil
	default:
		return false, nil
	}
}

// PromptInstallBackend asks the user whether the given backend should be
// downloaded and installed, reading the answer from the command's input
// stream. An empty answer, "y", or "yes" (case-insensitive) counts as
// confirmation.
//
// Input that ends without a trailing newline (e.g. `printf "y" | ...`) is
// still accepted, and a closed/non-interactive stdin is treated as a
// decline rather than a read failure.
func PromptInstallBackend(backend string, cmd *cobra.Command) (bool, error) {
	fmt.Fprintf(cmd.OutOrStdout(), "Backend %q is not installed. Download and install it now? [Y/n]: ", backend)

	reader := bufio.NewReader(cmd.InOrStdin())
	input, err := reader.ReadString('\n')
	if err != nil {
		if !errors.Is(err, io.EOF) {
			return false, fmt.Errorf("failed to read input: %w", err)
		}
		// EOF with no answer at all (closed stdin in a non-interactive
		// session): decline instead of failing the whole command.
		if input == "" {
			return false, nil
		}
		// EOF after a partial line is still a usable answer; fall through.
	}

	input = strings.TrimSpace(strings.ToLower(input))
	return input == "" || input == "y" || input == "yes", nil
}

// InstallBackend asks the model runner to install the named inference
// backend, wrapping any failure with the backend name for context.
func InstallBackend(backend string) error {
	err := desktopClient.InstallBackend(backend)
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to install backend %s: %w", backend, err)
}

// EnsureBackendAvailable makes sure the named backend is installed,
// prompting the user to install it when it is missing. It returns
// errBackendInstallationCancelled when the user declines, and a regular
// error when the status check or installation fails, or when the backend
// still does not report as installed after installation.
func EnsureBackendAvailable(backend string, cmd *cobra.Command) error {
	installed, err := CheckBackendInstalled(backend)
	if err != nil {
		return err
	}
	if installed {
		return nil
	}

	confirmed, err := PromptInstallBackend(backend, cmd)
	if err != nil {
		return err
	}
	if !confirmed {
		cmd.Printf("Run 'docker model install-runner --backend %s' to install it manually.\n", backend)
		return errBackendInstallationCancelled
	}

	if err := InstallBackend(backend); err != nil {
		return err
	}

	// Re-check after installation so a silently failed install is surfaced
	// to the user instead of a confusing downstream error.
	installed, err = CheckBackendInstalled(backend)
	if err != nil {
		return err
	}
	if !installed {
		return fmt.Errorf("backend %q is still not installed; run 'docker model install-runner --backend %s'", backend, backend)
	}

	cmd.Printf("Backend %q installed successfully.\n", backend)
	return nil
}

func GetRequiredBackendFromModelInfo(modelInfo *dmrm.Model) (string, error) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please check and use func (s *Scheduler) selectBackendForModel.
The server could expose a simple endpoint like:

GET /models/{name}/backend → {"backend": "vllm", "installed": true}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The backend selection logic in GetRequiredBackendFromModelInfo duplicates what the server already does in Scheduler.selectBackendForModel, and the two already diverge (the server has a platform-aware fallback chain: vLLM → MLX → SGLang, while the CLI always returns vllm for safetensors). Rather than maintaining this mapping in two places, consider calling the server to determine the required backend.

config, ok := modelInfo.Config.(*types.Config)
if !ok {
return llamacpp.Name, nil
}

switch config.Format {
case types.FormatSafetensors:
return vllm.Name, nil
case types.FormatGGUF:
return llamacpp.Name, nil
case types.FormatDiffusers:
return diffusers.Name, nil
default:
return llamacpp.Name, nil
}
}

func printNextSteps(out io.Writer, messages []string) {
if len(messages) == 0 {
return
Expand Down
126 changes: 126 additions & 0 deletions cmd/cli/commands/utils_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,26 @@
package commands

import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"testing"

"github.com/docker/model-runner/cmd/cli/desktop"
mockdesktop "github.com/docker/model-runner/cmd/cli/mocks"
"github.com/docker/model-runner/pkg/distribution/types"
"github.com/docker/model-runner/pkg/inference"
"github.com/docker/model-runner/pkg/inference/backends/diffusers"
"github.com/docker/model-runner/pkg/inference/backends/llamacpp"
"github.com/docker/model-runner/pkg/inference/backends/vllm"
dmrm "github.com/docker/model-runner/pkg/inference/models"
"github.com/spf13/cobra"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
)

func TestStripDefaultsFromModelName(t *testing.T) {
Expand Down Expand Up @@ -112,3 +129,112 @@ func TestHandleClientErrorFormat(t *testing.T) {
}
})
}

// setupDesktopClientStatusMock points the package-level desktopClient at a
// mock HTTP client that answers the models-list request with an empty JSON
// object and the inference status request with the given backend state map.
func setupDesktopClientStatusMock(t *testing.T, ctrl *gomock.Controller, backendStatus map[string]string) {
	t.Helper()

	mockClient := mockdesktop.NewMockDockerHttpClient(ctrl)
	modelRunner = desktop.NewContextForMock(mockClient)
	desktopClient = desktop.New(modelRunner)

	payload, err := json.Marshal(backendStatus)
	require.NoError(t, err)

	wantAgent := "docker-model-cli/" + desktop.Version
	// matchRequest matches requests by exact URL plus the CLI user agent.
	matchRequest := func(url string) gomock.Matcher {
		return gomock.Cond(func(req any) bool {
			r, ok := req.(*http.Request)
			return ok && r.URL.String() == url && r.Header.Get("User-Agent") == wantAgent
		})
	}

	mockClient.EXPECT().
		Do(matchRequest(modelRunner.URL(inference.ModelsPrefix))).
		Return(&http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader("{}"))}, nil)
	mockClient.EXPECT().
		Do(matchRequest(modelRunner.URL(inference.InferencePrefix + "/status"))).
		Return(&http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(payload))}, nil)
}

// TestCheckBackendInstalled verifies how backend status strings map to the
// installed / not-installed verdict.
func TestCheckBackendInstalled(t *testing.T) {
	tests := []struct {
		name          string
		status        string
		wantInstalled bool
	}{
		{"running status string is treated as installed", "running vllm latest-cuda", true},
		{"not running status is treated as missing", "not running", false},
		{"error status is treated as missing", "error failed to start", false},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			ctrl := gomock.NewController(t)
			defer ctrl.Finish()

			setupDesktopClientStatusMock(t, ctrl, map[string]string{"vllm": tc.status})

			installed, err := CheckBackendInstalled(vllm.Name)
			require.NoError(t, err)
			require.Equal(t, tc.wantInstalled, installed)
		})
	}
}

// TestPromptInstallBackend confirms that an explicit "yes" answer is
// accepted and that the prompt text names the backend.
func TestPromptInstallBackend(t *testing.T) {
	var output bytes.Buffer
	cmd := &cobra.Command{Use: "test"}
	cmd.SetIn(strings.NewReader("yes\n"))
	cmd.SetOut(&output)

	confirmed, err := PromptInstallBackend(vllm.Name, cmd)
	require.NoError(t, err)
	require.True(t, confirmed)
	require.Contains(t, output.String(), "Backend \"vllm\" is not installed")
}

// TestEnsureBackendAvailableCancelled checks that declining the install
// prompt yields errBackendInstallationCancelled and prints the manual
// install hint.
func TestEnsureBackendAvailableCancelled(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	setupDesktopClientStatusMock(t, ctrl, map[string]string{"vllm": "not running"})

	var output bytes.Buffer
	cmd := &cobra.Command{Use: "test"}
	cmd.SetIn(strings.NewReader("n\n"))
	cmd.SetOut(&output)

	err := EnsureBackendAvailable(vllm.Name, cmd)
	require.Error(t, err)
	require.ErrorIs(t, err, errBackendInstallationCancelled)
	require.Contains(t, output.String(), "docker model install-runner --backend vllm")
}

func TestGetRequiredBackendFromModelInfo(t *testing.T) {
t.Run("safetensors chooses vllm", func(t *testing.T) {
backend, err := GetRequiredBackendFromModelInfo(&dmrm.Model{Config: &types.Config{Format: types.FormatSafetensors}})
require.NoError(t, err)
require.Equal(t, vllm.Name, backend)
})

t.Run("gguf chooses llamacpp", func(t *testing.T) {
backend, err := GetRequiredBackendFromModelInfo(&dmrm.Model{Config: &types.Config{Format: types.FormatGGUF}})
require.NoError(t, err)
require.Equal(t, llamacpp.Name, backend)
})

t.Run("diffusers chooses diffusers backend", func(t *testing.T) {
backend, err := GetRequiredBackendFromModelInfo(&dmrm.Model{Config: &types.Config{Format: types.FormatDiffusers}})
require.NoError(t, err)
require.Equal(t, diffusers.Name, backend)
})
}