-
Notifications
You must be signed in to change notification settings - Fork 3.1k
fix: gemini models via databricks #8042
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
49f99c3
852929d
71a7300
e5fbf1f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -412,16 +412,40 @@ impl Provider for DatabricksProvider { | |
| let status = resp.status(); | ||
| let error_text = resp.text().await.unwrap_or_default(); | ||
|
|
||
| // Parse as JSON if possible to pass to map_http_error_to_provider_error | ||
| let json_payload = serde_json::from_str::<Value>(&error_text).ok(); | ||
| return Err(map_http_error_to_provider_error(status, json_payload)); | ||
| } | ||
| Ok(resp) | ||
| }) | ||
| .await | ||
| .inspect_err(|e| { | ||
| let _ = log.error(e); | ||
| })?; | ||
| .await; | ||
|
|
||
| let response = match response { | ||
| Err(e) if e.to_string().contains("stream_options") => { | ||
| payload.as_object_mut().unwrap().remove("stream_options"); | ||
| self.with_retry(|| async { | ||
|
Comment on lines
+423
to
+425
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
When a Databricks endpoint rejects Useful? React with 👍 / 👎. |
||
| let resp = self | ||
| .api_client | ||
| .response_post(Some(session_id), &path, &payload) | ||
| .await?; | ||
| if !resp.status().is_success() { | ||
| let status = resp.status(); | ||
| let error_text = resp.text().await.unwrap_or_default(); | ||
| let json_payload = serde_json::from_str::<Value>(&error_text).ok(); | ||
| return Err(map_http_error_to_provider_error(status, json_payload)); | ||
| } | ||
| Ok(resp) | ||
| }) | ||
| .await | ||
| .inspect_err(|e| { | ||
| let _ = log.error(e); | ||
| })? | ||
| } | ||
| Err(e) => { | ||
| let _ = log.error(&e); | ||
| return Err(e); | ||
| } | ||
| Ok(resp) => resp, | ||
| }; | ||
|
|
||
| stream_openai_compat(response, log) | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -48,9 +48,25 @@ struct DeltaToolCall { | |
| extra: Option<serde_json::Map<String, Value>>, | ||
| } | ||
|
|
||
| #[derive(Serialize, Deserialize, Debug)] | ||
| #[serde(untagged)] | ||
| enum DeltaContent { | ||
| String(String), | ||
| Array(Vec<ContentPart>), | ||
| } | ||
|
|
||
| #[derive(Serialize, Deserialize, Debug)] | ||
| struct ContentPart { | ||
| r#type: String, | ||
| text: String, | ||
| #[serde(rename = "thoughtSignature")] | ||
| thought_signature: Option<String>, | ||
| } | ||
|
|
||
| #[derive(Serialize, Deserialize, Debug)] | ||
| struct Delta { | ||
| content: Option<String>, | ||
| #[serde(default)] | ||
| content: Option<DeltaContent>, | ||
| role: Option<String>, | ||
| tool_calls: Option<Vec<DeltaToolCall>>, | ||
| reasoning_details: Option<Vec<Value>>, | ||
|
|
@@ -74,6 +90,32 @@ struct StreamingChunk { | |
| model: Option<String>, | ||
| } | ||
|
|
||
| fn extract_content_and_signature( | ||
| delta_content: Option<&DeltaContent>, | ||
| ) -> (Option<String>, Option<String>) { | ||
| match delta_content { | ||
| Some(DeltaContent::String(s)) => (Some(s.clone()), None), | ||
| Some(DeltaContent::Array(parts)) => { | ||
| let text_parts: Vec<_> = parts.iter().filter(|p| p.r#type == "text").collect(); | ||
|
|
||
| let text = text_parts | ||
| .iter() | ||
| .map(|p| p.text.as_str()) | ||
| .collect::<String>(); | ||
|
Comment on lines
+99
to
+104
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Gemini's OpenAI-compatible chat completions support Useful? React with 👍 / 👎. |
||
|
|
||
| let signature = text_parts | ||
| .iter() | ||
| .find_map(|p| p.thought_signature.as_ref()) | ||
|
Comment on lines
+106
to
+108
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Gemini 3 places the continuation signature on the last part of a non-function response, but Useful? React with 👍 / 👎. |
||
| .cloned(); | ||
|
|
||
| let text = if text.is_empty() { None } else { Some(text) }; | ||
|
|
||
| (text, signature) | ||
| } | ||
| None => (None, None), | ||
| } | ||
| } | ||
|
|
||
| pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<Value> { | ||
| let mut messages_spec = Vec::new(); | ||
| for message in messages { | ||
|
|
@@ -564,6 +606,7 @@ where | |
|
|
||
| let mut accumulated_reasoning: Vec<Value> = Vec::new(); | ||
| let mut accumulated_reasoning_content = String::new(); | ||
| let mut last_signature: Option<String> = None; | ||
|
|
||
| 'outer: while let Some(response) = stream.next().await { | ||
| let response_str = response?; | ||
|
|
@@ -685,14 +728,23 @@ where | |
| serde_json::from_str::<Value>(arguments) | ||
| }; | ||
|
|
||
| let metadata = extra_fields.as_ref().filter(|m| !m.is_empty()); | ||
| let metadata = if let Some(sig) = &last_signature { | ||
| let mut combined = extra_fields.clone().unwrap_or_default(); | ||
| combined.insert( | ||
| crate::providers::formats::google::THOUGHT_SIGNATURE_KEY.to_string(), | ||
| json!(sig) | ||
| ); | ||
| Some(combined) | ||
| } else { | ||
| extra_fields.as_ref().filter(|m| !m.is_empty()).cloned() | ||
| }; | ||
|
|
||
| let content = match parsed { | ||
| Ok(params) => { | ||
| MessageContent::tool_request_with_metadata( | ||
| id.clone(), | ||
| Ok(CallToolRequestParams::new(function_name.clone()).with_arguments(object(params))), | ||
| metadata, | ||
| metadata.as_ref(), | ||
| ) | ||
| }, | ||
| Err(e) => { | ||
|
|
@@ -704,7 +756,7 @@ where | |
| )), | ||
| data: None, | ||
| }; | ||
| MessageContent::tool_request_with_metadata(id.clone(), Err(error), metadata) | ||
| MessageContent::tool_request_with_metadata(id.clone(), Err(error), metadata.as_ref()) | ||
| } | ||
| }; | ||
| contents.push(content); | ||
|
|
@@ -731,13 +783,20 @@ where | |
|
|
||
| if let Some(reasoning) = &chunk.choices[0].delta.reasoning_content { | ||
| if !reasoning.is_empty() { | ||
| content.push(MessageContent::thinking(reasoning, "")); | ||
| let signature = last_signature.as_deref().unwrap_or(""); | ||
| content.push(MessageContent::thinking(reasoning, signature)); | ||
| } | ||
| } | ||
|
|
||
| if let Some(text) = &chunk.choices[0].delta.content { | ||
| let (text_content, thought_signature) = extract_content_and_signature(chunk.choices[0].delta.content.as_ref()); | ||
|
|
||
| if let Some(sig) = thought_signature { | ||
| last_signature = Some(sig); | ||
| } | ||
|
|
||
| if let Some(text) = text_content { | ||
| if !text.is_empty() { | ||
| content.push(MessageContent::text(text)); | ||
| content.push(MessageContent::text(&text)); | ||
|
Comment on lines
+791
to
+799
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Databricks/Gemini chunks on this new Useful? React with 👍 / 👎. |
||
| } | ||
|
Comment on lines
+797
to
800
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This branch updates Useful? React with 👍 / 👎. |
||
| } | ||
|
|
||
|
|
@@ -748,7 +807,6 @@ where | |
| content, | ||
| ); | ||
|
|
||
| // Add ID if present | ||
| if let Some(id) = chunk.id { | ||
| msg = msg.with_id(id); | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
stream_optionsOn Gemini-backed Databricks endpoints this 400 is deterministic, but the request goes through
with_retry()before thestream_optionsfallback runs. I checkedcrates/goose/src/providers/retry.rs, andshould_retry()treats everyProviderError::RequestFailedas retryable, so affected requests will send the same bad payload up to four times with backoff before line 423 finally stripsstream_options. That adds several seconds of latency and extra billable calls to every streaming turn on the models this patch is trying to unblock.Useful? React with 👍 / 👎.