diff --git a/src/error.rs b/src/error.rs index ec30308..72a454e 100644 --- a/src/error.rs +++ b/src/error.rs @@ -7,6 +7,7 @@ pub enum Error { Env(std::env::VarError), HttpClient(reqwest::Error), Token(gcp_auth::Error), + Serde(serde_json::Error), } impl Display for Error { @@ -15,6 +16,7 @@ impl Display for Error { Error::Env(e) => write!(f, "Environment variable error: {}", e), Error::HttpClient(e) => write!(f, "HTTP Client error: {}", e), Error::Token(e) => write!(f, "Token error: {}", e), + Error::Serde(e) => write!(f, "Serde error: {}", e), } } } @@ -38,3 +40,9 @@ impl From for Error { Error::Token(e) } } + +impl From for Error { + fn from(e: serde_json::Error) -> Self { + Error::Serde(e) + } +} diff --git a/src/types.rs b/src/types.rs index 0550cfa..b7d1a92 100644 --- a/src/types.rs +++ b/src/types.rs @@ -66,8 +66,9 @@ pub struct GenerateContentResponse(pub Vec); #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ResponseStreamChunk { - pub candidates: Vec, + pub candidates: Option>, pub usage_metadata: Option, + pub error: Option, } #[derive(Debug, Serialize, Deserialize)] @@ -128,3 +129,24 @@ pub struct FunctionParametersProperty { pub r#type: String, pub description: String, } + +#[derive(Debug, Serialize, Deserialize)] +pub struct Error { + pub code: i32, + pub message: String, + pub status: String, + pub details: Vec, +} + +// TODO: Make ErrorDetail an enum and map to the different types of errors. +#[derive(Debug, Serialize, Deserialize)] +pub struct ErrorDetail { + #[serde(rename = "@type")] + pub r#type: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Link { + pub description: String, + pub url: String, +} diff --git a/src/vertex_client.rs b/src/vertex_client.rs index b01cda4..24c31ac 100644 --- a/src/vertex_client.rs +++ b/src/vertex_client.rs @@ -66,7 +66,13 @@ impl VertexClient { let txt_json = resp.text().await?; tracing::debug!("Vertex API Response: {}", txt_json); - Ok(serde_json::from_str(&txt_json).unwrap()) + match serde_json::from_str(&txt_json) { + Ok(response) => Ok(response), + Err(e) => { + eprintln!("Failed to parse response: {} / {}", txt_json, e); + Err(e.into()) + } + } } /// Prompts a conversation to the model. @@ -92,7 +98,7 @@ impl VertexClient { .0 .into_iter() .flat_map(|chunk| { - chunk.candidates.into_iter().flat_map(|candidate| { + chunk.candidates.unwrap().into_iter().flat_map(|candidate| { candidate .content .parts @@ -133,7 +139,7 @@ impl VertexClient { .0 .into_iter() .flat_map(|chunk| { - chunk.candidates.into_iter().flat_map(|candidate| { + chunk.candidates.unwrap().into_iter().flat_map(|candidate| { candidate .content .parts