diff --git a/crates/goose/src/providers/anthropic.rs b/crates/goose/src/providers/anthropic.rs index 21338a732b18..23c889d0dd96 100644 --- a/crates/goose/src/providers/anthropic.rs +++ b/crates/goose/src/providers/anthropic.rs @@ -55,6 +55,7 @@ pub struct AnthropicProvider { model: ModelConfig, supports_streaming: bool, name: String, + is_custom_host: bool, } impl AnthropicProvider { @@ -67,6 +68,8 @@ impl AnthropicProvider { .get_param("ANTHROPIC_HOST") .unwrap_or_else(|_| "https://api.anthropic.com".to_string()); + let is_custom_host = host != "https://api.anthropic.com"; + let auth = AuthMethod::ApiKey { header_name: "x-api-key".to_string(), key: api_key, @@ -80,6 +83,7 @@ impl AnthropicProvider { model, supports_streaming: true, name: ANTHROPIC_PROVIDER_NAME.to_string(), + is_custom_host, }) } @@ -124,6 +128,7 @@ impl AnthropicProvider { model, supports_streaming, name: config.name.clone(), + is_custom_host: true, }) } @@ -189,6 +194,14 @@ impl Provider for AnthropicProvider { &self.name } + fn provider_type(&self) -> crate::providers::base::ProviderType { + if self.name == ANTHROPIC_PROVIDER_NAME && !self.is_custom_host { + crate::providers::base::ProviderType::Builtin + } else { + crate::providers::base::ProviderType::Custom + } + } + fn get_model_config(&self) -> ModelConfig { self.model.clone() } diff --git a/crates/goose/src/providers/base.rs b/crates/goose/src/providers/base.rs index fae0dd8b7ae1..ddb6931b1b90 100644 --- a/crates/goose/src/providers/base.rs +++ b/crates/goose/src/providers/base.rs @@ -456,6 +456,11 @@ pub trait Provider: Send + Sync { /// Get the name of this provider instance fn get_name(&self) -> &str; + /// Get the provider classification for model listing behavior. + fn provider_type(&self) -> ProviderType { + ProviderType::Custom + } + /// Primary streaming method that all providers must implement. /// /// Note: Do not add `#[instrument]` here — the call sites (`complete` and @@ -534,40 +539,77 @@ pub trait Provider: Send + Sync { Ok(vec![]) } - /// Fetch models filtered by canonical registry and usability + /// Fetch models sorted by release date when available from canonical registry. + /// For built-in providers, models must be in the canonical registry and pass + /// usability checks (text modality, tool support). + /// For custom providers, all models are included; unknown models sort alphabetically. async fn fetch_recommended_models(&self) -> Result, ProviderError> { let all_models = self.fetch_supported_models().await?; + // Try to load the canonical registry for metadata. + // If it fails, propagate the error - we don't want to silently return + // an empty list or all models when we can't properly validate. let registry = CanonicalModelRegistry::bundled().map_err(|e| { ProviderError::ExecutionError(format!("Failed to load canonical registry: {}", e)) })?; let provider_name = self.get_name(); + let provider_type = self.provider_type(); + let uses_strict_model_filtering = matches!( + provider_type, + ProviderType::Builtin | ProviderType::Preferred + ); + let allows_unknown_models = matches!( + provider_type, + ProviderType::Custom | ProviderType::Declarative + ); + let toolshim_enabled = self.get_model_config().toolshim; - // Get all text-capable models with their release dates + // Build list of (model_name, release_date) for sorting. + // For built-in providers, filter out models without canonical metadata + // or that don't pass usability checks. let mut models_with_dates: Vec<(String, Option)> = all_models .iter() .filter_map(|model| { - let canonical_id = map_to_canonical_model(provider_name, model, registry)?; - - let (provider, model_name) = canonical_id.split_once('/')?; - let canonical_model = registry.get(provider, model_name)?; + let canonical = map_to_canonical_model(provider_name, model, registry).and_then( + |canonical_id| { + let (provider, model_name) = canonical_id.split_once('/')?; + registry.get(provider, model_name) + }, + ); + + match canonical { + Some(cm) => { + // Model has canonical metadata - apply checks + // Check text modality + if !cm + .modalities + .input + .contains(&crate::providers::canonical::Modality::Text) + { + return None; + } - if !canonical_model - .modalities - .input - .contains(&crate::providers::canonical::Modality::Text) - { - return None; - } + // Check tool support + if !cm.tool_call && !toolshim_enabled { + return None; + } - if !canonical_model.tool_call && !self.get_model_config().toolshim { - return None; + Some((model.clone(), cm.release_date.clone())) + } + None => { + // Model not in canonical registry + if uses_strict_model_filtering { + // Built-in/preferred providers: skip unknown models + None + } else if allows_unknown_models { + // Custom/declarative providers: include unknown models + Some((model.clone(), None)) + } else { + None + } + } } - - let release_date = canonical_model.release_date.clone(); - - Some((model.clone(), release_date)) }) .collect(); @@ -579,16 +621,10 @@ pub trait Provider: Send + Sync { (None, None) => a.0.cmp(&b.0), }); - let recommended_models: Vec = models_with_dates + Ok(models_with_dates .into_iter() .map(|(name, _)| name) - .collect(); - - if recommended_models.is_empty() { - Ok(all_models) - } else { - Ok(recommended_models) - } + .collect()) } async fn map_to_canonical_model( @@ -895,6 +931,45 @@ mod tests { } } + struct ListingProvider { + name: String, + provider_type: ProviderType, + model_config: ModelConfig, + supported_models: Vec, + } + + #[async_trait::async_trait] + impl Provider for ListingProvider { + fn get_name(&self) -> &str { + &self.name + } + + fn provider_type(&self) -> ProviderType { + self.provider_type + } + + fn get_model_config(&self) -> ModelConfig { + self.model_config.clone() + } + + async fn fetch_supported_models(&self) -> Result, ProviderError> { + Ok(self.supported_models.clone()) + } + + async fn stream( + &self, + _model_config: &ModelConfig, + _session_id: &str, + _system: &str, + _messages: &[Message], + _tools: &[Tool], + ) -> Result { + Err(ProviderError::ExecutionError( + "stream not implemented for listing tests".to_string(), + )) + } + } + fn create_test_stream( items: Vec, ) -> impl Stream, Option), ProviderError>> { @@ -1078,4 +1153,30 @@ mod tests { assert_eq!(info.output_token_cost, Some(0.00001)); assert_eq!(info.currency, Some("$".to_string())); } + + #[tokio::test] + async fn test_fetch_recommended_models_includes_unknown_for_custom_provider() { + let provider = ListingProvider { + name: "custom-proxy".to_string(), + provider_type: ProviderType::Custom, + model_config: ModelConfig::new_or_fail("glm-5"), + supported_models: vec!["glm-5".to_string()], + }; + + let recommended = provider.fetch_recommended_models().await.unwrap(); + assert!(recommended.contains(&"glm-5".to_string())); + } + + #[tokio::test] + async fn test_fetch_recommended_models_excludes_unknown_for_builtin_provider() { + let provider = ListingProvider { + name: "openai".to_string(), + provider_type: ProviderType::Builtin, + model_config: ModelConfig::new_or_fail("gpt-4o"), + supported_models: vec!["definitely-unknown-model-id".to_string()], + }; + + let recommended = provider.fetch_recommended_models().await.unwrap(); + assert!(recommended.is_empty()); + } } diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs index 3e032c40a3d6..68e0670da3a6 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -65,6 +65,7 @@ pub struct OpenAiProvider { custom_headers: Option>, supports_streaming: bool, name: String, + is_custom_host: bool, } impl OpenAiProvider { @@ -76,6 +77,8 @@ impl OpenAiProvider { .get_param("OPENAI_HOST") .unwrap_or_else(|_| "https://api.openai.com".to_string()); + let is_custom_host = host != "https://api.openai.com"; + let secrets = config .get_secrets("OPENAI_API_KEY", &["OPENAI_CUSTOM_HEADERS"]) .unwrap_or_default(); @@ -126,6 +129,7 @@ impl OpenAiProvider { custom_headers, supports_streaming: true, name: OPEN_AI_PROVIDER_NAME.to_string(), + is_custom_host, }) } @@ -140,6 +144,7 @@ impl OpenAiProvider { custom_headers: None, supports_streaming: true, name: OPEN_AI_PROVIDER_NAME.to_string(), + is_custom_host: false, } } @@ -208,6 +213,7 @@ impl OpenAiProvider { custom_headers: config.headers, supports_streaming: config.supports_streaming.unwrap_or(true), name: config.name.clone(), + is_custom_host: true, }) } @@ -361,6 +367,14 @@ impl Provider for OpenAiProvider { &self.name } + fn provider_type(&self) -> crate::providers::base::ProviderType { + if self.name == OPEN_AI_PROVIDER_NAME && !self.is_custom_host { + crate::providers::base::ProviderType::Builtin + } else { + crate::providers::base::ProviderType::Custom + } + } + fn get_model_config(&self) -> ModelConfig { self.model.clone() } @@ -617,6 +631,7 @@ mod tests { custom_headers: None, supports_streaming: true, name: name.to_string(), + is_custom_host: true, } }