-
Notifications
You must be signed in to change notification settings - Fork 3.1k
fix: include all models in provider model listing, not just canonical… #8001
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -456,6 +456,11 @@ pub trait Provider: Send + Sync { | |
| /// Get the name of this provider instance | ||
| fn get_name(&self) -> &str; | ||
|
|
||
| /// Get the provider classification for model listing behavior. | ||
| fn provider_type(&self) -> ProviderType { | ||
| ProviderType::Custom | ||
|
Comment on lines
+460
to
+461
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Useful? React with 👍 / 👎. |
||
| } | ||
|
|
||
| /// Primary streaming method that all providers must implement. | ||
| /// | ||
| /// Note: Do not add `#[instrument]` here — the call sites (`complete` and | ||
|
|
@@ -534,40 +539,77 @@ pub trait Provider: Send + Sync { | |
| Ok(vec![]) | ||
| } | ||
|
|
||
| /// Fetch models filtered by canonical registry and usability | ||
| /// Fetch models sorted by release date when available from canonical registry. | ||
| /// For built-in providers, models must be in the canonical registry and pass | ||
| /// usability checks (text modality, tool support). | ||
| /// For custom providers, all models are included; unknown models sort alphabetically. | ||
| async fn fetch_recommended_models(&self) -> Result<Vec<String>, ProviderError> { | ||
| let all_models = self.fetch_supported_models().await?; | ||
|
|
||
| // Try to load the canonical registry for metadata. | ||
| // If it fails, propagate the error - we don't want to silently return | ||
| // an empty list or all models when we can't properly validate. | ||
| let registry = CanonicalModelRegistry::bundled().map_err(|e| { | ||
| ProviderError::ExecutionError(format!("Failed to load canonical registry: {}", e)) | ||
| })?; | ||
|
|
||
| let provider_name = self.get_name(); | ||
| let provider_type = self.provider_type(); | ||
| let uses_strict_model_filtering = matches!( | ||
| provider_type, | ||
| ProviderType::Builtin | ProviderType::Preferred | ||
| ); | ||
| let allows_unknown_models = matches!( | ||
| provider_type, | ||
| ProviderType::Custom | ProviderType::Declarative | ||
| ); | ||
| let toolshim_enabled = self.get_model_config().toolshim; | ||
|
|
||
| // Get all text-capable models with their release dates | ||
| // Build list of (model_name, release_date) for sorting. | ||
| // For built-in providers, filter out models without canonical metadata | ||
| // or that don't pass usability checks. | ||
| let mut models_with_dates: Vec<(String, Option<String>)> = all_models | ||
| .iter() | ||
| .filter_map(|model| { | ||
| let canonical_id = map_to_canonical_model(provider_name, model, registry)?; | ||
|
|
||
| let (provider, model_name) = canonical_id.split_once('/')?; | ||
| let canonical_model = registry.get(provider, model_name)?; | ||
| let canonical = map_to_canonical_model(provider_name, model, registry).and_then( | ||
| |canonical_id| { | ||
| let (provider, model_name) = canonical_id.split_once('/')?; | ||
| registry.get(provider, model_name) | ||
| }, | ||
| ); | ||
|
|
||
| match canonical { | ||
| Some(cm) => { | ||
| // Model has canonical metadata - apply checks | ||
| // Check text modality | ||
| if !cm | ||
| .modalities | ||
| .input | ||
| .contains(&crate::providers::canonical::Modality::Text) | ||
| { | ||
| return None; | ||
| } | ||
|
|
||
| if !canonical_model | ||
| .modalities | ||
| .input | ||
| .contains(&crate::providers::canonical::Modality::Text) | ||
| { | ||
| return None; | ||
| } | ||
| // Check tool support | ||
| if !cm.tool_call && !toolshim_enabled { | ||
| return None; | ||
| } | ||
|
|
||
| if !canonical_model.tool_call && !self.get_model_config().toolshim { | ||
| return None; | ||
| Some((model.clone(), cm.release_date.clone())) | ||
| } | ||
| None => { | ||
| // Model not in canonical registry | ||
| if uses_strict_model_filtering { | ||
| // Built-in/preferred providers: skip unknown models | ||
| None | ||
octogonz marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| } else if allows_unknown_models { | ||
| // Custom/declarative providers: include unknown models | ||
| Some((model.clone(), None)) | ||
|
Comment on lines
+605
to
+607
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Fresh evidence: although the Useful? React with 👍 / 👎. |
||
| } else { | ||
| None | ||
| } | ||
| } | ||
| } | ||
|
|
||
| let release_date = canonical_model.release_date.clone(); | ||
|
|
||
| Some((model.clone(), release_date)) | ||
| }) | ||
| .collect(); | ||
|
|
||
|
|
@@ -579,16 +621,10 @@ pub trait Provider: Send + Sync { | |
| (None, None) => a.0.cmp(&b.0), | ||
| }); | ||
|
|
||
| let recommended_models: Vec<String> = models_with_dates | ||
| Ok(models_with_dates | ||
| .into_iter() | ||
| .map(|(name, _)| name) | ||
| .collect(); | ||
|
|
||
| if recommended_models.is_empty() { | ||
| Ok(all_models) | ||
| } else { | ||
| Ok(recommended_models) | ||
| } | ||
| .collect()) | ||
| } | ||
|
|
||
| async fn map_to_canonical_model( | ||
|
|
@@ -895,6 +931,45 @@ mod tests { | |
| } | ||
| } | ||
|
|
||
| struct ListingProvider { | ||
| name: String, | ||
| provider_type: ProviderType, | ||
| model_config: ModelConfig, | ||
| supported_models: Vec<String>, | ||
| } | ||
|
|
||
| #[async_trait::async_trait] | ||
| impl Provider for ListingProvider { | ||
| fn get_name(&self) -> &str { | ||
| &self.name | ||
| } | ||
|
|
||
| fn provider_type(&self) -> ProviderType { | ||
| self.provider_type | ||
| } | ||
|
|
||
| fn get_model_config(&self) -> ModelConfig { | ||
| self.model_config.clone() | ||
| } | ||
|
|
||
| async fn fetch_supported_models(&self) -> Result<Vec<String>, ProviderError> { | ||
| Ok(self.supported_models.clone()) | ||
| } | ||
|
|
||
| async fn stream( | ||
| &self, | ||
| _model_config: &ModelConfig, | ||
| _session_id: &str, | ||
| _system: &str, | ||
| _messages: &[Message], | ||
| _tools: &[Tool], | ||
| ) -> Result<MessageStream, ProviderError> { | ||
| Err(ProviderError::ExecutionError( | ||
| "stream not implemented for listing tests".to_string(), | ||
| )) | ||
| } | ||
| } | ||
|
|
||
| fn create_test_stream( | ||
| items: Vec<String>, | ||
| ) -> impl Stream<Item = Result<(Option<Message>, Option<ProviderUsage>), ProviderError>> { | ||
|
|
@@ -1078,4 +1153,30 @@ mod tests { | |
| assert_eq!(info.output_token_cost, Some(0.00001)); | ||
| assert_eq!(info.currency, Some("$".to_string())); | ||
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn test_fetch_recommended_models_includes_unknown_for_custom_provider() { | ||
| let provider = ListingProvider { | ||
| name: "custom-proxy".to_string(), | ||
| provider_type: ProviderType::Custom, | ||
| model_config: ModelConfig::new_or_fail("glm-5"), | ||
| supported_models: vec!["glm-5".to_string()], | ||
| }; | ||
|
|
||
| let recommended = provider.fetch_recommended_models().await.unwrap(); | ||
| assert!(recommended.contains(&"glm-5".to_string())); | ||
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn test_fetch_recommended_models_excludes_unknown_for_builtin_provider() { | ||
| let provider = ListingProvider { | ||
| name: "openai".to_string(), | ||
| provider_type: ProviderType::Builtin, | ||
| model_config: ModelConfig::new_or_fail("gpt-4o"), | ||
| supported_models: vec!["definitely-unknown-model-id".to_string()], | ||
| }; | ||
|
|
||
| let recommended = provider.fetch_recommended_models().await.unwrap(); | ||
| assert!(recommended.is_empty()); | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -65,6 +65,7 @@ pub struct OpenAiProvider { | |
| custom_headers: Option<HashMap<String, String>>, | ||
| supports_streaming: bool, | ||
| name: String, | ||
| is_custom_host: bool, | ||
| } | ||
|
|
||
| impl OpenAiProvider { | ||
|
|
@@ -76,6 +77,8 @@ impl OpenAiProvider { | |
| .get_param("OPENAI_HOST") | ||
| .unwrap_or_else(|_| "https://api.openai.com".to_string()); | ||
|
|
||
| let is_custom_host = host != "https://api.openai.com"; | ||
|
|
||
| let secrets = config | ||
| .get_secrets("OPENAI_API_KEY", &["OPENAI_CUSTOM_HEADERS"]) | ||
| .unwrap_or_default(); | ||
|
|
@@ -126,6 +129,7 @@ impl OpenAiProvider { | |
| custom_headers, | ||
| supports_streaming: true, | ||
| name: OPEN_AI_PROVIDER_NAME.to_string(), | ||
| is_custom_host, | ||
| }) | ||
| } | ||
|
|
||
|
|
@@ -140,6 +144,7 @@ impl OpenAiProvider { | |
| custom_headers: None, | ||
| supports_streaming: true, | ||
| name: OPEN_AI_PROVIDER_NAME.to_string(), | ||
| is_custom_host: false, | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -208,6 +213,7 @@ impl OpenAiProvider { | |
| custom_headers: config.headers, | ||
| supports_streaming: config.supports_streaming.unwrap_or(true), | ||
| name: config.name.clone(), | ||
| is_custom_host: true, | ||
| }) | ||
| } | ||
|
|
||
|
|
@@ -361,6 +367,14 @@ impl Provider for OpenAiProvider { | |
| &self.name | ||
| } | ||
|
|
||
| fn provider_type(&self) -> crate::providers::base::ProviderType { | ||
| if self.name == OPEN_AI_PROVIDER_NAME && !self.is_custom_host { | ||
| crate::providers::base::ProviderType::Builtin | ||
|
Comment on lines
+370
to
+372
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This keeps the default-host OpenAI provider on the strict canonical-only path, so account-specific ids from Useful? React with 👍 / 👎. |
||
| } else { | ||
| crate::providers::base::ProviderType::Custom | ||
octogonz marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| } | ||
| } | ||
|
|
||
| fn get_model_config(&self) -> ModelConfig { | ||
| self.model.clone() | ||
| } | ||
|
|
@@ -617,6 +631,7 @@ mod tests { | |
| custom_headers: None, | ||
| supports_streaming: true, | ||
| name: name.to_string(), | ||
| is_custom_host: true, | ||
| } | ||
| } | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When
ANTHROPIC_HOSTis set to an equivalent official URL such ashttps://api.anthropic.com/orhttps://api.anthropic.com:443, this raw string comparison flips the provider toCustomeven thoughApiClient::build_url()still resolves requests against the normal Anthropic API (crates/goose/src/providers/api_client.rs). That changesfetch_recommended_models()onto the relaxed custom-provider path inbase.rs, so the CLI/UI model picker stops applying the usual canonical filtering just because the host string was formatted differently.Useful? React with 👍 / 👎.