Refactors ResponseChunk to use enums
This commit is contained in:
24
src/types.rs
24
src/types.rs
@@ -62,23 +62,19 @@ pub enum Part {
|
||||
|
||||
pub type GenerateContentResponse = Vec<ResponseStreamChunk>;
|
||||
|
||||
// #[derive(Debug, Serialize, Deserialize)]
|
||||
// #[serde(rename_all = "camelCase")]
|
||||
// #[serde(untagged)]
|
||||
// pub enum ResponseStreamChunkType {
|
||||
// Ok {
|
||||
// candidates: Vec<Candidate>,
|
||||
// usage_metadata: UsageMetadata,
|
||||
// },
|
||||
// Error,
|
||||
// }
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[serde(untagged)]
|
||||
pub enum ResponseStreamChunk {
|
||||
Ok(OkResponse),
|
||||
Error(Error),
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ResponseStreamChunk {
|
||||
pub candidates: Option<Vec<Candidate>>,
|
||||
pub usage_metadata: Option<UsageMetadata>,
|
||||
pub error: Option<Error>,
|
||||
pub struct OkResponse {
|
||||
pub candidates: Vec<Candidate>,
|
||||
pub usage_metadata: UsageMetadata,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
|
||||
@@ -2,6 +2,7 @@ use crate::conversation::{Message, Role};
|
||||
use crate::error::{Error, Result};
|
||||
use crate::prelude::{
|
||||
Content, Conversation, GenerateContentRequest, GenerateContentResponse, GenerationConfig,
|
||||
ResponseStreamChunk,
|
||||
};
|
||||
use crate::{prelude::Part, token_provider::TokenProvider};
|
||||
|
||||
@@ -94,32 +95,7 @@ impl<T: TokenProvider + Clone> VertexClient<T> {
|
||||
.await?;
|
||||
|
||||
// Check for errors in the response.
|
||||
for chunk in &response {
|
||||
if let Some(error) = &chunk.error {
|
||||
tracing::error!("Error in response: {:?}", error);
|
||||
let cloned = error.clone();
|
||||
return Err(Error::VertexError(cloned));
|
||||
}
|
||||
}
|
||||
|
||||
let text = response
|
||||
.into_iter()
|
||||
.flat_map(|chunk| {
|
||||
chunk.candidates.unwrap().into_iter().flat_map(|candidate| {
|
||||
candidate
|
||||
.content
|
||||
.parts
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(|part| match part {
|
||||
Part::Text(text) => Some(text),
|
||||
_ => None,
|
||||
})
|
||||
.filter(Option::is_some)
|
||||
.flatten()
|
||||
})
|
||||
})
|
||||
.collect::<String>();
|
||||
let text = VertexClient::<T>::collect_text_from_response(response)?;
|
||||
Ok(Message::new(Role::Model, &text))
|
||||
}
|
||||
|
||||
@@ -143,33 +119,30 @@ impl<T: TokenProvider + Clone> VertexClient<T> {
|
||||
.stream_generate_content(&request, Model::GeminiPro)
|
||||
.await?;
|
||||
|
||||
// Check for errors in the response.
|
||||
for chunk in &response {
|
||||
if let Some(error) = &chunk.error {
|
||||
tracing::error!("Error in response: {:?}", error);
|
||||
let cloned = error.clone();
|
||||
return Err(Error::VertexError(cloned));
|
||||
}
|
||||
VertexClient::<T>::collect_text_from_response(response)
|
||||
}
|
||||
|
||||
let text = response
|
||||
.into_iter()
|
||||
.flat_map(|chunk| {
|
||||
chunk.candidates.unwrap().into_iter().flat_map(|candidate| {
|
||||
candidate
|
||||
.content
|
||||
.parts
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(|part| match part {
|
||||
Part::Text(text) => Some(text),
|
||||
_ => None,
|
||||
})
|
||||
.filter(Option::is_some)
|
||||
.flatten()
|
||||
})
|
||||
})
|
||||
.collect::<String>();
|
||||
fn collect_text_from_response(response: GenerateContentResponse) -> Result<String> {
|
||||
let mut text = String::new();
|
||||
for chunk in response {
|
||||
match chunk {
|
||||
ResponseStreamChunk::Ok(ok_response) => {
|
||||
for candidate in ok_response.candidates {
|
||||
if let Some(parts) = &candidate.content.parts {
|
||||
for part in parts {
|
||||
if let Part::Text(t) = part {
|
||||
text.push_str(t);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
ResponseStreamChunk::Error(err) => {
|
||||
tracing::error!("Error in response: {:?}", err);
|
||||
return Err(Error::VertexError(err.clone()));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(text)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user