diff --git a/pkg/compose/model.go b/pkg/compose/model.go index 12b7c73e4..0b225609c 100644 --- a/pkg/compose/model.go +++ b/pkg/compose/model.go @@ -169,7 +169,8 @@ func (m *modelAPI) ConfigureModel(ctx context.Context, config types.ModelConfig, args = append(args, "--context-size", strconv.Itoa(config.ContextSize)) } args = append(args, config.Model) - if len(config.RuntimeFlags) != 0 { + // Only append RuntimeFlags if docker model CLI version is >= v1.0.6 + if len(config.RuntimeFlags) != 0 && m.supportsRuntimeFlags(ctx) { args = append(args, "--") args = append(args, config.RuntimeFlags...) } @@ -277,3 +278,114 @@ func (m *modelAPI) ListModels(ctx context.Context) ([]string, error) { } return availableModels, nil } + +// getModelVersion retrieves the docker model CLI version +func (m *modelAPI) getModelVersion(ctx context.Context) (string, error) { + cmd := exec.CommandContext(ctx, m.path, "version") + err := m.prepare(ctx, cmd) + if err != nil { + return "", err + } + + output, err := cmd.CombinedOutput() + if err != nil { + return "", fmt.Errorf("error getting docker model version: %w", err) + } + + // Parse output like: "Docker Model Runner version v1.0.4" + // We need to extract the version string (e.g., "v1.0.4") + lines := strings.Split(strings.TrimSpace(string(output)), "\n") + for _, line := range lines { + if strings.Contains(line, "version") { + parts := strings.Fields(line) + for i, part := range parts { + if part == "version" && i+1 < len(parts) { + return parts[i+1], nil + } + } + } + } + + return "", fmt.Errorf("could not parse docker model version from output: %s", string(output)) +} + +// supportsRuntimeFlags checks if the docker model version supports runtime flags +// Runtime flags are supported in version >= v1.0.6 +func (m *modelAPI) supportsRuntimeFlags(ctx context.Context) bool { + versionStr, err := m.getModelVersion(ctx) + if err != nil { + // If we can't determine the version, don't append runtime flags to be safe + return false + } + + // Parse version strings + currentVersion, err := parseVersion(versionStr) + if err != nil { + return false + } + + minVersion, err := parseVersion("1.0.6") + if err != nil { + return false + } + + return !currentVersion.LessThan(minVersion) +} + +// parseVersion parses a semantic version string +// Strips build metadata and prerelease suffixes (e.g., "1.0.6-dirty" or "1.0.6+build") +func parseVersion(versionStr string) (*semVersion, error) { + // Remove 'v' prefix if present + versionStr = strings.TrimPrefix(versionStr, "v") + + // Strip build metadata or prerelease suffix after "-" or "+" + // Examples: "1.0.6-dirty" -> "1.0.6", "1.0.6+build" -> "1.0.6" + if idx := strings.IndexAny(versionStr, "-+"); idx != -1 { + versionStr = versionStr[:idx] + } + + parts := strings.Split(versionStr, ".") + if len(parts) < 2 { + return nil, fmt.Errorf("invalid version format: %s", versionStr) + } + + var v semVersion + var err error + + v.major, err = strconv.Atoi(parts[0]) + if err != nil { + return nil, fmt.Errorf("invalid major version: %s", parts[0]) + } + + v.minor, err = strconv.Atoi(parts[1]) + if err != nil { + return nil, fmt.Errorf("invalid minor version: %s", parts[1]) + } + + if len(parts) > 2 { + v.patch, err = strconv.Atoi(parts[2]) + if err != nil { + return nil, fmt.Errorf("invalid patch version: %s", parts[2]) + } + } + + return &v, nil +} + +// semVersion represents a semantic version +type semVersion struct { + major int + minor int + patch int +} + +// LessThan compares two semantic versions +func (v *semVersion) LessThan(other *semVersion) bool { + if v.major != other.major { + return v.major < other.major + } + if v.minor != other.minor { + return v.minor < other.minor + } + return v.patch < other.patch +}