Refactors ResponseChunk to use enums
This commit is contained in:
24
src/types.rs
24
src/types.rs
@@ -62,23 +62,19 @@ pub enum Part {
|
|||||||
|
|
||||||
pub type GenerateContentResponse = Vec<ResponseStreamChunk>;
|
pub type GenerateContentResponse = Vec<ResponseStreamChunk>;
|
||||||
|
|
||||||
// #[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
// #[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
// #[serde(untagged)]
|
#[serde(untagged)]
|
||||||
// pub enum ResponseStreamChunkType {
|
pub enum ResponseStreamChunk {
|
||||||
// Ok {
|
Ok(OkResponse),
|
||||||
// candidates: Vec<Candidate>,
|
Error(Error),
|
||||||
// usage_metadata: UsageMetadata,
|
}
|
||||||
// },
|
|
||||||
// Error,
|
|
||||||
// }
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct ResponseStreamChunk {
|
pub struct OkResponse {
|
||||||
pub candidates: Option<Vec<Candidate>>,
|
pub candidates: Vec<Candidate>,
|
||||||
pub usage_metadata: Option<UsageMetadata>,
|
pub usage_metadata: UsageMetadata,
|
||||||
pub error: Option<Error>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ use crate::conversation::{Message, Role};
|
|||||||
use crate::error::{Error, Result};
|
use crate::error::{Error, Result};
|
||||||
use crate::prelude::{
|
use crate::prelude::{
|
||||||
Content, Conversation, GenerateContentRequest, GenerateContentResponse, GenerationConfig,
|
Content, Conversation, GenerateContentRequest, GenerateContentResponse, GenerationConfig,
|
||||||
|
ResponseStreamChunk,
|
||||||
};
|
};
|
||||||
use crate::{prelude::Part, token_provider::TokenProvider};
|
use crate::{prelude::Part, token_provider::TokenProvider};
|
||||||
|
|
||||||
@@ -94,32 +95,7 @@ impl<T: TokenProvider + Clone> VertexClient<T> {
|
|||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
// Check for errors in the response.
|
// Check for errors in the response.
|
||||||
for chunk in &response {
|
let text = VertexClient::<T>::collect_text_from_response(response)?;
|
||||||
if let Some(error) = &chunk.error {
|
|
||||||
tracing::error!("Error in response: {:?}", error);
|
|
||||||
let cloned = error.clone();
|
|
||||||
return Err(Error::VertexError(cloned));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let text = response
|
|
||||||
.into_iter()
|
|
||||||
.flat_map(|chunk| {
|
|
||||||
chunk.candidates.unwrap().into_iter().flat_map(|candidate| {
|
|
||||||
candidate
|
|
||||||
.content
|
|
||||||
.parts
|
|
||||||
.unwrap()
|
|
||||||
.into_iter()
|
|
||||||
.map(|part| match part {
|
|
||||||
Part::Text(text) => Some(text),
|
|
||||||
_ => None,
|
|
||||||
})
|
|
||||||
.filter(Option::is_some)
|
|
||||||
.flatten()
|
|
||||||
})
|
|
||||||
})
|
|
||||||
.collect::<String>();
|
|
||||||
Ok(Message::new(Role::Model, &text))
|
Ok(Message::new(Role::Model, &text))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -143,33 +119,30 @@ impl<T: TokenProvider + Clone> VertexClient<T> {
|
|||||||
.stream_generate_content(&request, Model::GeminiPro)
|
.stream_generate_content(&request, Model::GeminiPro)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
// Check for errors in the response.
|
VertexClient::<T>::collect_text_from_response(response)
|
||||||
for chunk in &response {
|
|
||||||
if let Some(error) = &chunk.error {
|
|
||||||
tracing::error!("Error in response: {:?}", error);
|
|
||||||
let cloned = error.clone();
|
|
||||||
return Err(Error::VertexError(cloned));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let text = response
|
fn collect_text_from_response(response: GenerateContentResponse) -> Result<String> {
|
||||||
.into_iter()
|
let mut text = String::new();
|
||||||
.flat_map(|chunk| {
|
for chunk in response {
|
||||||
chunk.candidates.unwrap().into_iter().flat_map(|candidate| {
|
match chunk {
|
||||||
candidate
|
ResponseStreamChunk::Ok(ok_response) => {
|
||||||
.content
|
for candidate in ok_response.candidates {
|
||||||
.parts
|
if let Some(parts) = &candidate.content.parts {
|
||||||
.unwrap()
|
for part in parts {
|
||||||
.into_iter()
|
if let Part::Text(t) = part {
|
||||||
.map(|part| match part {
|
text.push_str(t);
|
||||||
Part::Text(text) => Some(text),
|
}
|
||||||
_ => None,
|
}
|
||||||
})
|
}
|
||||||
.filter(Option::is_some)
|
}
|
||||||
.flatten()
|
}
|
||||||
})
|
ResponseStreamChunk::Error(err) => {
|
||||||
})
|
tracing::error!("Error in response: {:?}", err);
|
||||||
.collect::<String>();
|
return Err(Error::VertexError(err.clone()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Ok(text)
|
Ok(text)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user