use github.com/docker/docker/api/types/versions for comparing versions and store plugin version obtained by pluginManager in newModelAPI

Signed-off-by: Ignacio López Luna <ignasi.lopez.luna@gmail.com>
This commit is contained in:
Ignacio López Luna
2025-12-17 17:38:07 +01:00
committed by Nicolas De loof
parent 58403169f3
commit 29d6c918c4

View File

@@ -29,6 +29,7 @@ import (
"github.com/compose-spec/compose-go/v2/types"
"github.com/containerd/errdefs"
"github.com/docker/cli/cli-plugins/manager"
"github.com/docker/docker/api/types/versions"
"github.com/spf13/cobra"
"golang.org/x/sync/errgroup"
@@ -71,6 +72,7 @@ func (s *composeService) ensureModels(ctx context.Context, project *types.Projec
type modelAPI struct {
path string
version string // cached plugin version
env []string
prepare func(ctx context.Context, cmd *exec.Cmd) error
cleanup func()
@@ -170,7 +172,7 @@ func (m *modelAPI) ConfigureModel(ctx context.Context, config types.ModelConfig,
}
args = append(args, config.Model)
// Only append RuntimeFlags if docker model CLI version is >= v1.0.6
if len(config.RuntimeFlags) != 0 && m.supportsRuntimeFlags(ctx) {
if len(config.RuntimeFlags) != 0 && m.supportsRuntimeFlags() {
args = append(args, "--")
args = append(args, config.RuntimeFlags...)
}
@@ -279,113 +281,23 @@ func (m *modelAPI) ListModels(ctx context.Context) ([]string, error) {
return availableModels, nil
}
// getModelVersion retrieves the docker model CLI version
func (m *modelAPI) getModelVersion(ctx context.Context) (string, error) {
cmd := exec.CommandContext(ctx, m.path, "version")
err := m.prepare(ctx, cmd)
if err != nil {
return "", err
}
output, err := cmd.CombinedOutput()
if err != nil {
return "", fmt.Errorf("error getting docker model version: %w", err)
}
// Parse output like: "Docker Model Runner version v1.0.4"
// We need to extract the version string (e.g., "v1.0.4")
lines := strings.Split(strings.TrimSpace(string(output)), "\n")
for _, line := range lines {
if strings.Contains(line, "version") {
parts := strings.Fields(line)
for i, part := range parts {
if part == "version" && i+1 < len(parts) {
return parts[i+1], nil
}
}
}
}
return "", fmt.Errorf("could not parse docker model version from output: %s", string(output))
}
// supportsRuntimeFlags checks if the docker model version supports runtime flags
// Runtime flags are supported in version >= v1.0.6
func (m *modelAPI) supportsRuntimeFlags(ctx context.Context) bool {
versionStr, err := m.getModelVersion(ctx)
if err != nil {
// If we can't determine the version, don't append runtime flags to be safe
func (m *modelAPI) supportsRuntimeFlags() bool {
// If version is not cached, don't append runtime flags to be safe
if m.version == "" {
return false
}
// Parse version strings
currentVersion, err := parseVersion(versionStr)
if err != nil {
return false
}
minVersion, err := parseVersion("1.0.6")
if err != nil {
return false
}
return !currentVersion.LessThan(minVersion)
}
// parseVersion parses a semantic version string
// Strips build metadata and prerelease suffixes (e.g., "1.0.6-dirty" or "1.0.6+build")
func parseVersion(versionStr string) (*semVersion, error) {
// Remove 'v' prefix if present
versionStr = strings.TrimPrefix(versionStr, "v")
// Strip 'v' prefix if present (e.g., "v1.0.6" -> "1.0.6")
versionStr := strings.TrimPrefix(m.version, "v")
// Strip build metadata or prerelease suffix after "-" or "+"
// Examples: "1.0.6-dirty" -> "1.0.6", "1.0.6+build" -> "1.0.6"
// This is necessary because versions.LessThan treats "1.0.6-dirty" < "1.0.6" per semver rules
// but we want to compare the base version numbers only
if idx := strings.IndexAny(versionStr, "-+"); idx != -1 {
versionStr = versionStr[:idx]
}
parts := strings.Split(versionStr, ".")
if len(parts) < 2 {
return nil, fmt.Errorf("invalid version format: %s", versionStr)
}
var v semVersion
var err error
v.major, err = strconv.Atoi(parts[0])
if err != nil {
return nil, fmt.Errorf("invalid major version: %s", parts[0])
}
v.minor, err = strconv.Atoi(parts[1])
if err != nil {
return nil, fmt.Errorf("invalid minor version: %s", parts[1])
}
if len(parts) > 2 {
v.patch, err = strconv.Atoi(parts[2])
if err != nil {
return nil, fmt.Errorf("invalid patch version: %s", parts[2])
}
}
return &v, nil
}
// semVersion represents a semantic version
type semVersion struct {
major int
minor int
patch int
}
// LessThan compares two semantic versions
func (v *semVersion) LessThan(other *semVersion) bool {
if v.major != other.major {
return v.major < other.major
}
if v.minor != other.minor {
return v.minor < other.minor
}
return v.patch < other.patch
return !versions.LessThan(versionStr, "1.0.6")
}