diff --git a/Cargo.toml b/Cargo.toml index 2e09bbe..616ed9c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,3 +17,4 @@ console = "0.15.8" dialoguer = "0.11.0" indicatif = "0.17.7" tokio = { version = "1.35.1", features = ["full"] } +tracing-subscriber = "0.3.18" diff --git a/examples/conversation.rs b/examples/conversation.rs index fa45e0f..3b4af20 100644 --- a/examples/conversation.rs +++ b/examples/conversation.rs @@ -14,6 +14,8 @@ async fn main() -> Result<(), Box> { let project_id = std::env::var("PROJECT_ID")?; let location_id = std::env::var("LOCATION_ID")?; + tracing_subscriber::fmt().init(); + let vertex_client = VertexClient::new( authentication_manager, api_endpoint, @@ -21,6 +23,8 @@ async fn main() -> Result<(), Box> { location_id, ); + tracing::info!("Starting conversation..."); + let mut conversation = Conversation::new(); loop { let message: String = Input::with_theme(&ColorfulTheme::default()) diff --git a/src/error.rs b/src/error.rs index 72a454e..3c5708e 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,5 +1,7 @@ use std::fmt::Display; +use crate::types; + pub type Result = std::result::Result; #[derive(Debug)] @@ -8,6 +10,7 @@ pub enum Error { HttpClient(reqwest::Error), Token(gcp_auth::Error), Serde(serde_json::Error), + VertexError(types::Error), } impl Display for Error { @@ -17,6 +20,9 @@ impl Display for Error { Error::HttpClient(e) => write!(f, "HTTP Client error: {}", e), Error::Token(e) => write!(f, "Token error: {}", e), Error::Serde(e) => write!(f, "Serde error: {}", e), + Error::VertexError(e) => { + write!(f, "Vertex error: {}", serde_json::to_string(e).unwrap()) + } } } } @@ -46,3 +52,9 @@ impl From for Error { Error::Serde(e) } } + +impl From for Error { + fn from(e: types::Error) -> Self { + Error::VertexError(e) + } +} diff --git a/src/types.rs b/src/types.rs index b7d1a92..8fa5cd1 100644 --- a/src/types.rs +++ b/src/types.rs @@ -28,7 +28,7 @@ pub struct Tools { #[derive(Debug, Serialize, Deserialize)] pub struct Content { pub role: String, - pub parts: Vec, + pub parts: Option>, } #[derive(Clone, Debug, Serialize, Deserialize, Default)] @@ -60,8 +60,18 @@ pub enum Part { }, } -#[derive(Debug, Serialize, Deserialize)] -pub struct GenerateContentResponse(pub Vec); +pub type GenerateContentResponse = Vec; + +// #[derive(Debug, Serialize, Deserialize)] +// #[serde(rename_all = "camelCase")] +// #[serde(untagged)] +// pub enum ResponseStreamChunkType { +// Ok { +// candidates: Vec, +// usage_metadata: UsageMetadata, +// }, +// Error, +// } #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] @@ -130,23 +140,32 @@ pub struct FunctionParametersProperty { pub description: String, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct Error { pub code: i32, pub message: String, pub status: String, - pub details: Vec, + pub details: Vec, } -// TODO: Make ErrorDetail an enum and map to the different types of errors. -#[derive(Debug, Serialize, Deserialize)] -pub struct ErrorDetail { - #[serde(rename = "@type")] - pub r#type: String, -} - -#[derive(Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct Link { pub description: String, pub url: String, } + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(tag = "@type")] +pub enum ErrorType { + #[serde(rename = "type.googleapis.com/google.rpc.ErrorInfo")] + ErrorInfo { metadata: ErrorInfoMetadata }, + + #[serde(rename = "type.googleapis.com/google.rpc.Help")] + Help { links: Vec }, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ErrorInfoMetadata { + service: String, + consumer: String, +} diff --git a/src/vertex_client.rs b/src/vertex_client.rs index 24c31ac..27b169d 100644 --- a/src/vertex_client.rs +++ b/src/vertex_client.rs @@ -1,5 +1,5 @@ use crate::conversation::{Message, Role}; -use crate::error::Result; +use crate::error::{Error, Result}; use crate::prelude::{ Content, Conversation, GenerateContentRequest, GenerateContentResponse, GenerationConfig, }; @@ -65,11 +65,10 @@ impl VertexClient { .await?; let txt_json = resp.text().await?; - tracing::debug!("Vertex API Response: {}", txt_json); match serde_json::from_str(&txt_json) { Ok(response) => Ok(response), Err(e) => { - eprintln!("Failed to parse response: {} / {}", txt_json, e); + tracing::error!("Failed to parse response: {} with error {}", txt_json, e); Err(e.into()) } } @@ -83,7 +82,7 @@ impl VertexClient { .iter() .map(|m| Content { role: m.role.to_string(), - parts: vec![Part::Text(m.text.clone())], + parts: Some(vec![Part::Text(m.text.clone())]), }) .collect(), generation_config: None, @@ -94,14 +93,23 @@ impl VertexClient { .stream_generate_content(&request, Model::GeminiPro) .await?; + // Check for errors in the response. + for chunk in &response { + if let Some(error) = &chunk.error { + tracing::error!("Error in response: {:?}", error); + let cloned = error.clone(); + return Err(Error::VertexError(cloned)); + } + } + let text = response - .0 .into_iter() .flat_map(|chunk| { chunk.candidates.unwrap().into_iter().flat_map(|candidate| { candidate .content .parts + .unwrap() .into_iter() .map(|part| match part { Part::Text(text) => Some(text), @@ -125,7 +133,7 @@ impl VertexClient { let request = GenerateContentRequest { contents: vec![Content { role: "user".to_string(), - parts: vec![Part::Text(prompt.to_string())], + parts: Some(vec![Part::Text(prompt.to_string())]), }], generation_config: generation_config.cloned(), tools: None, @@ -135,14 +143,23 @@ impl VertexClient { .stream_generate_content(&request, Model::GeminiPro) .await?; + // Check for errors in the response. + for chunk in &response { + if let Some(error) = &chunk.error { + tracing::error!("Error in response: {:?}", error); + let cloned = error.clone(); + return Err(Error::VertexError(cloned)); + } + } + let text = response - .0 .into_iter() .flat_map(|chunk| { chunk.candidates.unwrap().into_iter().flat_map(|candidate| { candidate .content .parts + .unwrap() .into_iter() .map(|part| match part { Part::Text(text) => Some(text),