Refactors ResponseChunk to use enums

This commit is contained in:
2024-02-07 08:14:43 +00:00
parent 8c42901ddf
commit ee7dcd4df4
2 changed files with 34 additions and 65 deletions

View File

@@ -62,23 +62,19 @@ pub enum Part {
pub type GenerateContentResponse = Vec<ResponseStreamChunk>;
// #[derive(Debug, Serialize, Deserialize)]
// #[serde(rename_all = "camelCase")]
// #[serde(untagged)]
// pub enum ResponseStreamChunkType {
// Ok {
// candidates: Vec<Candidate>,
// usage_metadata: UsageMetadata,
// },
// Error,
// }
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[serde(untagged)]
pub enum ResponseStreamChunk {
Ok(OkResponse),
Error(Error),
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ResponseStreamChunk {
pub candidates: Option<Vec<Candidate>>,
pub usage_metadata: Option<UsageMetadata>,
pub error: Option<Error>,
pub struct OkResponse {
pub candidates: Vec<Candidate>,
pub usage_metadata: UsageMetadata,
}
#[derive(Debug, Serialize, Deserialize)]

View File

@@ -2,6 +2,7 @@ use crate::conversation::{Message, Role};
use crate::error::{Error, Result};
use crate::prelude::{
Content, Conversation, GenerateContentRequest, GenerateContentResponse, GenerationConfig,
ResponseStreamChunk,
};
use crate::{prelude::Part, token_provider::TokenProvider};
@@ -94,32 +95,7 @@ impl<T: TokenProvider + Clone> VertexClient<T> {
.await?;
// Check for errors in the response.
for chunk in &response {
if let Some(error) = &chunk.error {
tracing::error!("Error in response: {:?}", error);
let cloned = error.clone();
return Err(Error::VertexError(cloned));
}
}
let text = response
.into_iter()
.flat_map(|chunk| {
chunk.candidates.unwrap().into_iter().flat_map(|candidate| {
candidate
.content
.parts
.unwrap()
.into_iter()
.map(|part| match part {
Part::Text(text) => Some(text),
_ => None,
})
.filter(Option::is_some)
.flatten()
})
})
.collect::<String>();
let text = VertexClient::<T>::collect_text_from_response(response)?;
Ok(Message::new(Role::Model, &text))
}
@@ -143,33 +119,30 @@ impl<T: TokenProvider + Clone> VertexClient<T> {
.stream_generate_content(&request, Model::GeminiPro)
.await?;
// Check for errors in the response.
for chunk in &response {
if let Some(error) = &chunk.error {
tracing::error!("Error in response: {:?}", error);
let cloned = error.clone();
return Err(Error::VertexError(cloned));
VertexClient::<T>::collect_text_from_response(response)
}
fn collect_text_from_response(response: GenerateContentResponse) -> Result<String> {
let mut text = String::new();
for chunk in response {
match chunk {
ResponseStreamChunk::Ok(ok_response) => {
for candidate in ok_response.candidates {
if let Some(parts) = &candidate.content.parts {
for part in parts {
if let Part::Text(t) = part {
text.push_str(t);
}
}
}
}
}
ResponseStreamChunk::Error(err) => {
tracing::error!("Error in response: {:?}", err);
return Err(Error::VertexError(err.clone()));
}
}
}
let text = response
.into_iter()
.flat_map(|chunk| {
chunk.candidates.unwrap().into_iter().flat_map(|candidate| {
candidate
.content
.parts
.unwrap()
.into_iter()
.map(|part| match part {
Part::Text(text) => Some(text),
_ => None,
})
.filter(Option::is_some)
.flatten()
})
})
.collect::<String>();
Ok(text)
}
}