Refactors ResponseChunk to use enums

This commit is contained in:
2024-02-07 08:14:43 +00:00
parent 8c42901ddf
commit ee7dcd4df4
2 changed files with 34 additions and 65 deletions

View File

@@ -62,23 +62,19 @@ pub enum Part {
pub type GenerateContentResponse = Vec<ResponseStreamChunk>; pub type GenerateContentResponse = Vec<ResponseStreamChunk>;
// #[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
// #[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
// #[serde(untagged)] #[serde(untagged)]
// pub enum ResponseStreamChunkType { pub enum ResponseStreamChunk {
// Ok { Ok(OkResponse),
// candidates: Vec<Candidate>, Error(Error),
// usage_metadata: UsageMetadata, }
// },
// Error,
// }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct ResponseStreamChunk { pub struct OkResponse {
pub candidates: Option<Vec<Candidate>>, pub candidates: Vec<Candidate>,
pub usage_metadata: Option<UsageMetadata>, pub usage_metadata: UsageMetadata,
pub error: Option<Error>,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]

View File

@@ -2,6 +2,7 @@ use crate::conversation::{Message, Role};
use crate::error::{Error, Result}; use crate::error::{Error, Result};
use crate::prelude::{ use crate::prelude::{
Content, Conversation, GenerateContentRequest, GenerateContentResponse, GenerationConfig, Content, Conversation, GenerateContentRequest, GenerateContentResponse, GenerationConfig,
ResponseStreamChunk,
}; };
use crate::{prelude::Part, token_provider::TokenProvider}; use crate::{prelude::Part, token_provider::TokenProvider};
@@ -94,32 +95,7 @@ impl<T: TokenProvider + Clone> VertexClient<T> {
.await?; .await?;
// Check for errors in the response. // Check for errors in the response.
for chunk in &response { let text = VertexClient::<T>::collect_text_from_response(response)?;
if let Some(error) = &chunk.error {
tracing::error!("Error in response: {:?}", error);
let cloned = error.clone();
return Err(Error::VertexError(cloned));
}
}
let text = response
.into_iter()
.flat_map(|chunk| {
chunk.candidates.unwrap().into_iter().flat_map(|candidate| {
candidate
.content
.parts
.unwrap()
.into_iter()
.map(|part| match part {
Part::Text(text) => Some(text),
_ => None,
})
.filter(Option::is_some)
.flatten()
})
})
.collect::<String>();
Ok(Message::new(Role::Model, &text)) Ok(Message::new(Role::Model, &text))
} }
@@ -143,33 +119,30 @@ impl<T: TokenProvider + Clone> VertexClient<T> {
.stream_generate_content(&request, Model::GeminiPro) .stream_generate_content(&request, Model::GeminiPro)
.await?; .await?;
// Check for errors in the response. VertexClient::<T>::collect_text_from_response(response)
for chunk in &response { }
if let Some(error) = &chunk.error {
tracing::error!("Error in response: {:?}", error); fn collect_text_from_response(response: GenerateContentResponse) -> Result<String> {
let cloned = error.clone(); let mut text = String::new();
return Err(Error::VertexError(cloned)); for chunk in response {
match chunk {
ResponseStreamChunk::Ok(ok_response) => {
for candidate in ok_response.candidates {
if let Some(parts) = &candidate.content.parts {
for part in parts {
if let Part::Text(t) = part {
text.push_str(t);
}
}
}
}
}
ResponseStreamChunk::Error(err) => {
tracing::error!("Error in response: {:?}", err);
return Err(Error::VertexError(err.clone()));
}
} }
} }
let text = response
.into_iter()
.flat_map(|chunk| {
chunk.candidates.unwrap().into_iter().flat_map(|candidate| {
candidate
.content
.parts
.unwrap()
.into_iter()
.map(|part| match part {
Part::Text(text) => Some(text),
_ => None,
})
.filter(Option::is_some)
.flatten()
})
})
.collect::<String>();
Ok(text) Ok(text)
} }
} }