diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs index 9428be4a8bf5..5b68accde867 100644 --- a/crates/goose/src/providers/databricks.rs +++ b/crates/goose/src/providers/databricks.rs @@ -412,16 +412,40 @@ impl Provider for DatabricksProvider { let status = resp.status(); let error_text = resp.text().await.unwrap_or_default(); - // Parse as JSON if possible to pass to map_http_error_to_provider_error let json_payload = serde_json::from_str::(&error_text).ok(); return Err(map_http_error_to_provider_error(status, json_payload)); } Ok(resp) }) - .await - .inspect_err(|e| { - let _ = log.error(e); - })?; + .await; + + let response = match response { + Err(e) if e.to_string().contains("stream_options") => { + payload.as_object_mut().unwrap().remove("stream_options"); + self.with_retry(|| async { + let resp = self + .api_client + .response_post(Some(session_id), &path, &payload) + .await?; + if !resp.status().is_success() { + let status = resp.status(); + let error_text = resp.text().await.unwrap_or_default(); + let json_payload = serde_json::from_str::(&error_text).ok(); + return Err(map_http_error_to_provider_error(status, json_payload)); + } + Ok(resp) + }) + .await + .inspect_err(|e| { + let _ = log.error(e); + })? + } + Err(e) => { + let _ = log.error(&e); + return Err(e); + } + Ok(resp) => resp, + }; stream_openai_compat(response, log) } diff --git a/crates/goose/src/providers/formats/openai.rs b/crates/goose/src/providers/formats/openai.rs index 3d84ce3272a8..db733fc4dfc3 100644 --- a/crates/goose/src/providers/formats/openai.rs +++ b/crates/goose/src/providers/formats/openai.rs @@ -48,9 +48,25 @@ struct DeltaToolCall { extra: Option>, } +#[derive(Serialize, Deserialize, Debug)] +#[serde(untagged)] +enum DeltaContent { + String(String), + Array(Vec), +} + +#[derive(Serialize, Deserialize, Debug)] +struct ContentPart { + r#type: String, + text: String, + #[serde(rename = "thoughtSignature")] + thought_signature: Option, +} + #[derive(Serialize, Deserialize, Debug)] struct Delta { - content: Option, + #[serde(default)] + content: Option, role: Option, tool_calls: Option>, reasoning_details: Option>, @@ -74,6 +90,32 @@ struct StreamingChunk { model: Option, } +fn extract_content_and_signature( + delta_content: Option<&DeltaContent>, +) -> (Option, Option) { + match delta_content { + Some(DeltaContent::String(s)) => (Some(s.clone()), None), + Some(DeltaContent::Array(parts)) => { + let text_parts: Vec<_> = parts.iter().filter(|p| p.r#type == "text").collect(); + + let text = text_parts + .iter() + .map(|p| p.text.as_str()) + .collect::(); + + let signature = text_parts + .iter() + .find_map(|p| p.thought_signature.as_ref()) + .cloned(); + + let text = if text.is_empty() { None } else { Some(text) }; + + (text, signature) + } + None => (None, None), + } +} + pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec { let mut messages_spec = Vec::new(); for message in messages { @@ -564,6 +606,7 @@ where let mut accumulated_reasoning: Vec = Vec::new(); let mut accumulated_reasoning_content = String::new(); + let mut last_signature: Option = None; 'outer: while let Some(response) = stream.next().await { let response_str = response?; @@ -685,14 +728,23 @@ where serde_json::from_str::(arguments) }; - let metadata = extra_fields.as_ref().filter(|m| !m.is_empty()); + let metadata = if let Some(sig) = &last_signature { + let mut combined = extra_fields.clone().unwrap_or_default(); + combined.insert( + crate::providers::formats::google::THOUGHT_SIGNATURE_KEY.to_string(), + json!(sig) + ); + Some(combined) + } else { + extra_fields.as_ref().filter(|m| !m.is_empty()).cloned() + }; let content = match parsed { Ok(params) => { MessageContent::tool_request_with_metadata( id.clone(), Ok(CallToolRequestParams::new(function_name.clone()).with_arguments(object(params))), - metadata, + metadata.as_ref(), ) }, Err(e) => { @@ -704,7 +756,7 @@ where )), data: None, }; - MessageContent::tool_request_with_metadata(id.clone(), Err(error), metadata) + MessageContent::tool_request_with_metadata(id.clone(), Err(error), metadata.as_ref()) } }; contents.push(content); @@ -731,13 +783,20 @@ where if let Some(reasoning) = &chunk.choices[0].delta.reasoning_content { if !reasoning.is_empty() { - content.push(MessageContent::thinking(reasoning, "")); + let signature = last_signature.as_deref().unwrap_or(""); + content.push(MessageContent::thinking(reasoning, signature)); } } - if let Some(text) = &chunk.choices[0].delta.content { + let (text_content, thought_signature) = extract_content_and_signature(chunk.choices[0].delta.content.as_ref()); + + if let Some(sig) = thought_signature { + last_signature = Some(sig); + } + + if let Some(text) = text_content { if !text.is_empty() { - content.push(MessageContent::text(text)); + content.push(MessageContent::text(&text)); } } @@ -748,7 +807,6 @@ where content, ); - // Add ID if present if let Some(id) = chunk.id { msg = msg.with_id(id); }