Deserialization and error handling improvements

This commit is contained in:
2024-02-07 07:43:11 +00:00
parent 450fe84d8b
commit 8c42901ddf
5 changed files with 73 additions and 20 deletions

View File

@@ -1,5 +1,7 @@
use std::fmt::Display;
use crate::types;
pub type Result<T> = std::result::Result<T, Error>;
#[derive(Debug)]
@@ -8,6 +10,7 @@ pub enum Error {
HttpClient(reqwest::Error),
Token(gcp_auth::Error),
Serde(serde_json::Error),
VertexError(types::Error),
}
impl Display for Error {
@@ -17,6 +20,9 @@ impl Display for Error {
Error::HttpClient(e) => write!(f, "HTTP Client error: {}", e),
Error::Token(e) => write!(f, "Token error: {}", e),
Error::Serde(e) => write!(f, "Serde error: {}", e),
Error::VertexError(e) => {
write!(f, "Vertex error: {}", serde_json::to_string(e).unwrap())
}
}
}
}
@@ -46,3 +52,9 @@ impl From<serde_json::Error> for Error {
Error::Serde(e)
}
}
impl From<types::Error> for Error {
fn from(e: types::Error) -> Self {
Error::VertexError(e)
}
}

View File

@@ -28,7 +28,7 @@ pub struct Tools {
#[derive(Debug, Serialize, Deserialize)]
pub struct Content {
pub role: String,
pub parts: Vec<Part>,
pub parts: Option<Vec<Part>>,
}
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
@@ -60,8 +60,18 @@ pub enum Part {
},
}
#[derive(Debug, Serialize, Deserialize)]
pub struct GenerateContentResponse(pub Vec<ResponseStreamChunk>);
pub type GenerateContentResponse = Vec<ResponseStreamChunk>;
// #[derive(Debug, Serialize, Deserialize)]
// #[serde(rename_all = "camelCase")]
// #[serde(untagged)]
// pub enum ResponseStreamChunkType {
// Ok {
// candidates: Vec<Candidate>,
// usage_metadata: UsageMetadata,
// },
// Error,
// }
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
@@ -130,23 +140,32 @@ pub struct FunctionParametersProperty {
pub description: String,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Error {
pub code: i32,
pub message: String,
pub status: String,
pub details: Vec<ErrorDetail>,
pub details: Vec<ErrorType>,
}
// TODO: Make ErrorDetail an enum and map to the different types of errors.
#[derive(Debug, Serialize, Deserialize)]
pub struct ErrorDetail {
#[serde(rename = "@type")]
pub r#type: String,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Link {
pub description: String,
pub url: String,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "@type")]
pub enum ErrorType {
#[serde(rename = "type.googleapis.com/google.rpc.ErrorInfo")]
ErrorInfo { metadata: ErrorInfoMetadata },
#[serde(rename = "type.googleapis.com/google.rpc.Help")]
Help { links: Vec<Link> },
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ErrorInfoMetadata {
service: String,
consumer: String,
}

View File

@@ -1,5 +1,5 @@
use crate::conversation::{Message, Role};
use crate::error::Result;
use crate::error::{Error, Result};
use crate::prelude::{
Content, Conversation, GenerateContentRequest, GenerateContentResponse, GenerationConfig,
};
@@ -65,11 +65,10 @@ impl<T: TokenProvider + Clone> VertexClient<T> {
.await?;
let txt_json = resp.text().await?;
tracing::debug!("Vertex API Response: {}", txt_json);
match serde_json::from_str(&txt_json) {
Ok(response) => Ok(response),
Err(e) => {
eprintln!("Failed to parse response: {} / {}", txt_json, e);
tracing::error!("Failed to parse response: {} with error {}", txt_json, e);
Err(e.into())
}
}
@@ -83,7 +82,7 @@ impl<T: TokenProvider + Clone> VertexClient<T> {
.iter()
.map(|m| Content {
role: m.role.to_string(),
parts: vec![Part::Text(m.text.clone())],
parts: Some(vec![Part::Text(m.text.clone())]),
})
.collect(),
generation_config: None,
@@ -94,14 +93,23 @@ impl<T: TokenProvider + Clone> VertexClient<T> {
.stream_generate_content(&request, Model::GeminiPro)
.await?;
// Check for errors in the response.
for chunk in &response {
if let Some(error) = &chunk.error {
tracing::error!("Error in response: {:?}", error);
let cloned = error.clone();
return Err(Error::VertexError(cloned));
}
}
let text = response
.0
.into_iter()
.flat_map(|chunk| {
chunk.candidates.unwrap().into_iter().flat_map(|candidate| {
candidate
.content
.parts
.unwrap()
.into_iter()
.map(|part| match part {
Part::Text(text) => Some(text),
@@ -125,7 +133,7 @@ impl<T: TokenProvider + Clone> VertexClient<T> {
let request = GenerateContentRequest {
contents: vec![Content {
role: "user".to_string(),
parts: vec![Part::Text(prompt.to_string())],
parts: Some(vec![Part::Text(prompt.to_string())]),
}],
generation_config: generation_config.cloned(),
tools: None,
@@ -135,14 +143,23 @@ impl<T: TokenProvider + Clone> VertexClient<T> {
.stream_generate_content(&request, Model::GeminiPro)
.await?;
// Check for errors in the response.
for chunk in &response {
if let Some(error) = &chunk.error {
tracing::error!("Error in response: {:?}", error);
let cloned = error.clone();
return Err(Error::VertexError(cloned));
}
}
let text = response
.0
.into_iter()
.flat_map(|chunk| {
chunk.candidates.unwrap().into_iter().flat_map(|candidate| {
candidate
.content
.parts
.unwrap()
.into_iter()
.map(|part| match part {
Part::Text(text) => Some(text),