diff --git a/src/types.rs b/src/types.rs index 8fa5cd1..05f3bd5 100644 --- a/src/types.rs +++ b/src/types.rs @@ -62,23 +62,19 @@ pub enum Part { pub type GenerateContentResponse = Vec; -// #[derive(Debug, Serialize, Deserialize)] -// #[serde(rename_all = "camelCase")] -// #[serde(untagged)] -// pub enum ResponseStreamChunkType { -// Ok { -// candidates: Vec, -// usage_metadata: UsageMetadata, -// }, -// Error, -// } +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +#[serde(untagged)] +pub enum ResponseStreamChunk { + Ok(OkResponse), + Error(Error), +} #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] -pub struct ResponseStreamChunk { - pub candidates: Option>, - pub usage_metadata: Option, - pub error: Option, +pub struct OkResponse { + pub candidates: Vec, + pub usage_metadata: UsageMetadata, } #[derive(Debug, Serialize, Deserialize)] diff --git a/src/vertex_client.rs b/src/vertex_client.rs index 27b169d..dc54405 100644 --- a/src/vertex_client.rs +++ b/src/vertex_client.rs @@ -2,6 +2,7 @@ use crate::conversation::{Message, Role}; use crate::error::{Error, Result}; use crate::prelude::{ Content, Conversation, GenerateContentRequest, GenerateContentResponse, GenerationConfig, + ResponseStreamChunk, }; use crate::{prelude::Part, token_provider::TokenProvider}; @@ -94,32 +95,7 @@ impl VertexClient { .await?; // Check for errors in the response. - for chunk in &response { - if let Some(error) = &chunk.error { - tracing::error!("Error in response: {:?}", error); - let cloned = error.clone(); - return Err(Error::VertexError(cloned)); - } - } - - let text = response - .into_iter() - .flat_map(|chunk| { - chunk.candidates.unwrap().into_iter().flat_map(|candidate| { - candidate - .content - .parts - .unwrap() - .into_iter() - .map(|part| match part { - Part::Text(text) => Some(text), - _ => None, - }) - .filter(Option::is_some) - .flatten() - }) - }) - .collect::(); + let text = VertexClient::::collect_text_from_response(response)?; Ok(Message::new(Role::Model, &text)) } @@ -143,33 +119,30 @@ impl VertexClient { .stream_generate_content(&request, Model::GeminiPro) .await?; - // Check for errors in the response. - for chunk in &response { - if let Some(error) = &chunk.error { - tracing::error!("Error in response: {:?}", error); - let cloned = error.clone(); - return Err(Error::VertexError(cloned)); + VertexClient::::collect_text_from_response(response) + } + + fn collect_text_from_response(response: GenerateContentResponse) -> Result { + let mut text = String::new(); + for chunk in response { + match chunk { + ResponseStreamChunk::Ok(ok_response) => { + for candidate in ok_response.candidates { + if let Some(parts) = &candidate.content.parts { + for part in parts { + if let Part::Text(t) = part { + text.push_str(t); + } + } + } + } + } + ResponseStreamChunk::Error(err) => { + tracing::error!("Error in response: {:?}", err); + return Err(Error::VertexError(err.clone())); + } } } - - let text = response - .into_iter() - .flat_map(|chunk| { - chunk.candidates.unwrap().into_iter().flat_map(|candidate| { - candidate - .content - .parts - .unwrap() - .into_iter() - .map(|part| match part { - Part::Text(text) => Some(text), - _ => None, - }) - .filter(Option::is_some) - .flatten() - }) - }) - .collect::(); Ok(text) } }