Deserialization and error handling improvements
This commit is contained in:
@@ -17,3 +17,4 @@ console = "0.15.8"
|
|||||||
dialoguer = "0.11.0"
|
dialoguer = "0.11.0"
|
||||||
indicatif = "0.17.7"
|
indicatif = "0.17.7"
|
||||||
tokio = { version = "1.35.1", features = ["full"] }
|
tokio = { version = "1.35.1", features = ["full"] }
|
||||||
|
tracing-subscriber = "0.3.18"
|
||||||
|
|||||||
@@ -14,6 +14,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
let project_id = std::env::var("PROJECT_ID")?;
|
let project_id = std::env::var("PROJECT_ID")?;
|
||||||
let location_id = std::env::var("LOCATION_ID")?;
|
let location_id = std::env::var("LOCATION_ID")?;
|
||||||
|
|
||||||
|
tracing_subscriber::fmt().init();
|
||||||
|
|
||||||
let vertex_client = VertexClient::new(
|
let vertex_client = VertexClient::new(
|
||||||
authentication_manager,
|
authentication_manager,
|
||||||
api_endpoint,
|
api_endpoint,
|
||||||
@@ -21,6 +23,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
location_id,
|
location_id,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
tracing::info!("Starting conversation...");
|
||||||
|
|
||||||
let mut conversation = Conversation::new();
|
let mut conversation = Conversation::new();
|
||||||
loop {
|
loop {
|
||||||
let message: String = Input::with_theme(&ColorfulTheme::default())
|
let message: String = Input::with_theme(&ColorfulTheme::default())
|
||||||
|
|||||||
12
src/error.rs
12
src/error.rs
@@ -1,5 +1,7 @@
|
|||||||
use std::fmt::Display;
|
use std::fmt::Display;
|
||||||
|
|
||||||
|
use crate::types;
|
||||||
|
|
||||||
pub type Result<T> = std::result::Result<T, Error>;
|
pub type Result<T> = std::result::Result<T, Error>;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@@ -8,6 +10,7 @@ pub enum Error {
|
|||||||
HttpClient(reqwest::Error),
|
HttpClient(reqwest::Error),
|
||||||
Token(gcp_auth::Error),
|
Token(gcp_auth::Error),
|
||||||
Serde(serde_json::Error),
|
Serde(serde_json::Error),
|
||||||
|
VertexError(types::Error),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Display for Error {
|
impl Display for Error {
|
||||||
@@ -17,6 +20,9 @@ impl Display for Error {
|
|||||||
Error::HttpClient(e) => write!(f, "HTTP Client error: {}", e),
|
Error::HttpClient(e) => write!(f, "HTTP Client error: {}", e),
|
||||||
Error::Token(e) => write!(f, "Token error: {}", e),
|
Error::Token(e) => write!(f, "Token error: {}", e),
|
||||||
Error::Serde(e) => write!(f, "Serde error: {}", e),
|
Error::Serde(e) => write!(f, "Serde error: {}", e),
|
||||||
|
Error::VertexError(e) => {
|
||||||
|
write!(f, "Vertex error: {}", serde_json::to_string(e).unwrap())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -46,3 +52,9 @@ impl From<serde_json::Error> for Error {
|
|||||||
Error::Serde(e)
|
Error::Serde(e)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<types::Error> for Error {
|
||||||
|
fn from(e: types::Error) -> Self {
|
||||||
|
Error::VertexError(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
45
src/types.rs
45
src/types.rs
@@ -28,7 +28,7 @@ pub struct Tools {
|
|||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
pub struct Content {
|
pub struct Content {
|
||||||
pub role: String,
|
pub role: String,
|
||||||
pub parts: Vec<Part>,
|
pub parts: Option<Vec<Part>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
|
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
|
||||||
@@ -60,8 +60,18 @@ pub enum Part {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
pub type GenerateContentResponse = Vec<ResponseStreamChunk>;
|
||||||
pub struct GenerateContentResponse(pub Vec<ResponseStreamChunk>);
|
|
||||||
|
// #[derive(Debug, Serialize, Deserialize)]
|
||||||
|
// #[serde(rename_all = "camelCase")]
|
||||||
|
// #[serde(untagged)]
|
||||||
|
// pub enum ResponseStreamChunkType {
|
||||||
|
// Ok {
|
||||||
|
// candidates: Vec<Candidate>,
|
||||||
|
// usage_metadata: UsageMetadata,
|
||||||
|
// },
|
||||||
|
// Error,
|
||||||
|
// }
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
@@ -130,23 +140,32 @@ pub struct FunctionParametersProperty {
|
|||||||
pub description: String,
|
pub description: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
pub struct Error {
|
pub struct Error {
|
||||||
pub code: i32,
|
pub code: i32,
|
||||||
pub message: String,
|
pub message: String,
|
||||||
pub status: String,
|
pub status: String,
|
||||||
pub details: Vec<ErrorDetail>,
|
pub details: Vec<ErrorType>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Make ErrorDetail an enum and map to the different types of errors.
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
pub struct ErrorDetail {
|
|
||||||
#[serde(rename = "@type")]
|
|
||||||
pub r#type: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
pub struct Link {
|
pub struct Link {
|
||||||
pub description: String,
|
pub description: String,
|
||||||
pub url: String,
|
pub url: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
#[serde(tag = "@type")]
|
||||||
|
pub enum ErrorType {
|
||||||
|
#[serde(rename = "type.googleapis.com/google.rpc.ErrorInfo")]
|
||||||
|
ErrorInfo { metadata: ErrorInfoMetadata },
|
||||||
|
|
||||||
|
#[serde(rename = "type.googleapis.com/google.rpc.Help")]
|
||||||
|
Help { links: Vec<Link> },
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
pub struct ErrorInfoMetadata {
|
||||||
|
service: String,
|
||||||
|
consumer: String,
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
use crate::conversation::{Message, Role};
|
use crate::conversation::{Message, Role};
|
||||||
use crate::error::Result;
|
use crate::error::{Error, Result};
|
||||||
use crate::prelude::{
|
use crate::prelude::{
|
||||||
Content, Conversation, GenerateContentRequest, GenerateContentResponse, GenerationConfig,
|
Content, Conversation, GenerateContentRequest, GenerateContentResponse, GenerationConfig,
|
||||||
};
|
};
|
||||||
@@ -65,11 +65,10 @@ impl<T: TokenProvider + Clone> VertexClient<T> {
|
|||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let txt_json = resp.text().await?;
|
let txt_json = resp.text().await?;
|
||||||
tracing::debug!("Vertex API Response: {}", txt_json);
|
|
||||||
match serde_json::from_str(&txt_json) {
|
match serde_json::from_str(&txt_json) {
|
||||||
Ok(response) => Ok(response),
|
Ok(response) => Ok(response),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
eprintln!("Failed to parse response: {} / {}", txt_json, e);
|
tracing::error!("Failed to parse response: {} with error {}", txt_json, e);
|
||||||
Err(e.into())
|
Err(e.into())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -83,7 +82,7 @@ impl<T: TokenProvider + Clone> VertexClient<T> {
|
|||||||
.iter()
|
.iter()
|
||||||
.map(|m| Content {
|
.map(|m| Content {
|
||||||
role: m.role.to_string(),
|
role: m.role.to_string(),
|
||||||
parts: vec![Part::Text(m.text.clone())],
|
parts: Some(vec![Part::Text(m.text.clone())]),
|
||||||
})
|
})
|
||||||
.collect(),
|
.collect(),
|
||||||
generation_config: None,
|
generation_config: None,
|
||||||
@@ -94,14 +93,23 @@ impl<T: TokenProvider + Clone> VertexClient<T> {
|
|||||||
.stream_generate_content(&request, Model::GeminiPro)
|
.stream_generate_content(&request, Model::GeminiPro)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
// Check for errors in the response.
|
||||||
|
for chunk in &response {
|
||||||
|
if let Some(error) = &chunk.error {
|
||||||
|
tracing::error!("Error in response: {:?}", error);
|
||||||
|
let cloned = error.clone();
|
||||||
|
return Err(Error::VertexError(cloned));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let text = response
|
let text = response
|
||||||
.0
|
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.flat_map(|chunk| {
|
.flat_map(|chunk| {
|
||||||
chunk.candidates.unwrap().into_iter().flat_map(|candidate| {
|
chunk.candidates.unwrap().into_iter().flat_map(|candidate| {
|
||||||
candidate
|
candidate
|
||||||
.content
|
.content
|
||||||
.parts
|
.parts
|
||||||
|
.unwrap()
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|part| match part {
|
.map(|part| match part {
|
||||||
Part::Text(text) => Some(text),
|
Part::Text(text) => Some(text),
|
||||||
@@ -125,7 +133,7 @@ impl<T: TokenProvider + Clone> VertexClient<T> {
|
|||||||
let request = GenerateContentRequest {
|
let request = GenerateContentRequest {
|
||||||
contents: vec![Content {
|
contents: vec![Content {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
parts: vec![Part::Text(prompt.to_string())],
|
parts: Some(vec![Part::Text(prompt.to_string())]),
|
||||||
}],
|
}],
|
||||||
generation_config: generation_config.cloned(),
|
generation_config: generation_config.cloned(),
|
||||||
tools: None,
|
tools: None,
|
||||||
@@ -135,14 +143,23 @@ impl<T: TokenProvider + Clone> VertexClient<T> {
|
|||||||
.stream_generate_content(&request, Model::GeminiPro)
|
.stream_generate_content(&request, Model::GeminiPro)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
// Check for errors in the response.
|
||||||
|
for chunk in &response {
|
||||||
|
if let Some(error) = &chunk.error {
|
||||||
|
tracing::error!("Error in response: {:?}", error);
|
||||||
|
let cloned = error.clone();
|
||||||
|
return Err(Error::VertexError(cloned));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let text = response
|
let text = response
|
||||||
.0
|
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.flat_map(|chunk| {
|
.flat_map(|chunk| {
|
||||||
chunk.candidates.unwrap().into_iter().flat_map(|candidate| {
|
chunk.candidates.unwrap().into_iter().flat_map(|candidate| {
|
||||||
candidate
|
candidate
|
||||||
.content
|
.content
|
||||||
.parts
|
.parts
|
||||||
|
.unwrap()
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|part| match part {
|
.map(|part| match part {
|
||||||
Part::Text(text) => Some(text),
|
Part::Text(text) => Some(text),
|
||||||
|
|||||||
Reference in New Issue
Block a user