Deserialization and error handling improvements

This commit is contained in:
2024-02-07 07:43:11 +00:00
parent 450fe84d8b
commit 8c42901ddf
5 changed files with 73 additions and 20 deletions

View File

@@ -17,3 +17,4 @@ console = "0.15.8"
dialoguer = "0.11.0" dialoguer = "0.11.0"
indicatif = "0.17.7" indicatif = "0.17.7"
tokio = { version = "1.35.1", features = ["full"] } tokio = { version = "1.35.1", features = ["full"] }
tracing-subscriber = "0.3.18"

View File

@@ -14,6 +14,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let project_id = std::env::var("PROJECT_ID")?; let project_id = std::env::var("PROJECT_ID")?;
let location_id = std::env::var("LOCATION_ID")?; let location_id = std::env::var("LOCATION_ID")?;
tracing_subscriber::fmt().init();
let vertex_client = VertexClient::new( let vertex_client = VertexClient::new(
authentication_manager, authentication_manager,
api_endpoint, api_endpoint,
@@ -21,6 +23,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
location_id, location_id,
); );
tracing::info!("Starting conversation...");
let mut conversation = Conversation::new(); let mut conversation = Conversation::new();
loop { loop {
let message: String = Input::with_theme(&ColorfulTheme::default()) let message: String = Input::with_theme(&ColorfulTheme::default())

View File

@@ -1,5 +1,7 @@
use std::fmt::Display; use std::fmt::Display;
use crate::types;
pub type Result<T> = std::result::Result<T, Error>; pub type Result<T> = std::result::Result<T, Error>;
#[derive(Debug)] #[derive(Debug)]
@@ -8,6 +10,7 @@ pub enum Error {
HttpClient(reqwest::Error), HttpClient(reqwest::Error),
Token(gcp_auth::Error), Token(gcp_auth::Error),
Serde(serde_json::Error), Serde(serde_json::Error),
VertexError(types::Error),
} }
impl Display for Error { impl Display for Error {
@@ -17,6 +20,9 @@ impl Display for Error {
Error::HttpClient(e) => write!(f, "HTTP Client error: {}", e), Error::HttpClient(e) => write!(f, "HTTP Client error: {}", e),
Error::Token(e) => write!(f, "Token error: {}", e), Error::Token(e) => write!(f, "Token error: {}", e),
Error::Serde(e) => write!(f, "Serde error: {}", e), Error::Serde(e) => write!(f, "Serde error: {}", e),
Error::VertexError(e) => {
write!(f, "Vertex error: {}", serde_json::to_string(e).unwrap())
}
} }
} }
} }
@@ -46,3 +52,9 @@ impl From<serde_json::Error> for Error {
Error::Serde(e) Error::Serde(e)
} }
} }
impl From<types::Error> for Error {
fn from(e: types::Error) -> Self {
Error::VertexError(e)
}
}

View File

@@ -28,7 +28,7 @@ pub struct Tools {
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct Content { pub struct Content {
pub role: String, pub role: String,
pub parts: Vec<Part>, pub parts: Option<Vec<Part>>,
} }
#[derive(Clone, Debug, Serialize, Deserialize, Default)] #[derive(Clone, Debug, Serialize, Deserialize, Default)]
@@ -60,8 +60,18 @@ pub enum Part {
}, },
} }
#[derive(Debug, Serialize, Deserialize)] pub type GenerateContentResponse = Vec<ResponseStreamChunk>;
pub struct GenerateContentResponse(pub Vec<ResponseStreamChunk>);
// #[derive(Debug, Serialize, Deserialize)]
// #[serde(rename_all = "camelCase")]
// #[serde(untagged)]
// pub enum ResponseStreamChunkType {
// Ok {
// candidates: Vec<Candidate>,
// usage_metadata: UsageMetadata,
// },
// Error,
// }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
@@ -130,23 +140,32 @@ pub struct FunctionParametersProperty {
pub description: String, pub description: String,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Error { pub struct Error {
pub code: i32, pub code: i32,
pub message: String, pub message: String,
pub status: String, pub status: String,
pub details: Vec<ErrorDetail>, pub details: Vec<ErrorType>,
} }
// TODO: Make ErrorDetail an enum and map to the different types of errors. #[derive(Clone, Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize)]
pub struct ErrorDetail {
#[serde(rename = "@type")]
pub r#type: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Link { pub struct Link {
pub description: String, pub description: String,
pub url: String, pub url: String,
} }
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "@type")]
pub enum ErrorType {
#[serde(rename = "type.googleapis.com/google.rpc.ErrorInfo")]
ErrorInfo { metadata: ErrorInfoMetadata },
#[serde(rename = "type.googleapis.com/google.rpc.Help")]
Help { links: Vec<Link> },
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ErrorInfoMetadata {
service: String,
consumer: String,
}

View File

@@ -1,5 +1,5 @@
use crate::conversation::{Message, Role}; use crate::conversation::{Message, Role};
use crate::error::Result; use crate::error::{Error, Result};
use crate::prelude::{ use crate::prelude::{
Content, Conversation, GenerateContentRequest, GenerateContentResponse, GenerationConfig, Content, Conversation, GenerateContentRequest, GenerateContentResponse, GenerationConfig,
}; };
@@ -65,11 +65,10 @@ impl<T: TokenProvider + Clone> VertexClient<T> {
.await?; .await?;
let txt_json = resp.text().await?; let txt_json = resp.text().await?;
tracing::debug!("Vertex API Response: {}", txt_json);
match serde_json::from_str(&txt_json) { match serde_json::from_str(&txt_json) {
Ok(response) => Ok(response), Ok(response) => Ok(response),
Err(e) => { Err(e) => {
eprintln!("Failed to parse response: {} / {}", txt_json, e); tracing::error!("Failed to parse response: {} with error {}", txt_json, e);
Err(e.into()) Err(e.into())
} }
} }
@@ -83,7 +82,7 @@ impl<T: TokenProvider + Clone> VertexClient<T> {
.iter() .iter()
.map(|m| Content { .map(|m| Content {
role: m.role.to_string(), role: m.role.to_string(),
parts: vec![Part::Text(m.text.clone())], parts: Some(vec![Part::Text(m.text.clone())]),
}) })
.collect(), .collect(),
generation_config: None, generation_config: None,
@@ -94,14 +93,23 @@ impl<T: TokenProvider + Clone> VertexClient<T> {
.stream_generate_content(&request, Model::GeminiPro) .stream_generate_content(&request, Model::GeminiPro)
.await?; .await?;
// Check for errors in the response.
for chunk in &response {
if let Some(error) = &chunk.error {
tracing::error!("Error in response: {:?}", error);
let cloned = error.clone();
return Err(Error::VertexError(cloned));
}
}
let text = response let text = response
.0
.into_iter() .into_iter()
.flat_map(|chunk| { .flat_map(|chunk| {
chunk.candidates.unwrap().into_iter().flat_map(|candidate| { chunk.candidates.unwrap().into_iter().flat_map(|candidate| {
candidate candidate
.content .content
.parts .parts
.unwrap()
.into_iter() .into_iter()
.map(|part| match part { .map(|part| match part {
Part::Text(text) => Some(text), Part::Text(text) => Some(text),
@@ -125,7 +133,7 @@ impl<T: TokenProvider + Clone> VertexClient<T> {
let request = GenerateContentRequest { let request = GenerateContentRequest {
contents: vec![Content { contents: vec![Content {
role: "user".to_string(), role: "user".to_string(),
parts: vec![Part::Text(prompt.to_string())], parts: Some(vec![Part::Text(prompt.to_string())]),
}], }],
generation_config: generation_config.cloned(), generation_config: generation_config.cloned(),
tools: None, tools: None,
@@ -135,14 +143,23 @@ impl<T: TokenProvider + Clone> VertexClient<T> {
.stream_generate_content(&request, Model::GeminiPro) .stream_generate_content(&request, Model::GeminiPro)
.await?; .await?;
// Check for errors in the response.
for chunk in &response {
if let Some(error) = &chunk.error {
tracing::error!("Error in response: {:?}", error);
let cloned = error.clone();
return Err(Error::VertexError(cloned));
}
}
let text = response let text = response
.0
.into_iter() .into_iter()
.flat_map(|chunk| { .flat_map(|chunk| {
chunk.candidates.unwrap().into_iter().flat_map(|candidate| { chunk.candidates.unwrap().into_iter().flat_map(|candidate| {
candidate candidate
.content .content
.parts .parts
.unwrap()
.into_iter() .into_iter()
.map(|part| match part { .map(|part| match part {
Part::Text(text) => Some(text), Part::Text(text) => Some(text),