diff --git a/examples/system_instruction.rs b/examples/system_instruction.rs index 8bde470..f98931e 100644 --- a/examples/system_instruction.rs +++ b/examples/system_instruction.rs @@ -38,13 +38,7 @@ async fn main() -> Result<(), Box> { .generate_content(&request, "gemini-1.0-pro-002") .await?; - if let GenerateContentResponse::Ok { - candidates, - usage_metadata: _, - } = result - { - println!("Response: {:?}", candidates[0].get_text().unwrap()); - } + println!("Response: {:?}", result.candidates[0].get_text().unwrap()); Ok(()) } diff --git a/examples/text-from-text-streaming.rs b/examples/text-from-text-streaming.rs index b241bbf..f90af7e 100644 --- a/examples/text-from-text-streaming.rs +++ b/examples/text-from-text-streaming.rs @@ -23,16 +23,18 @@ async fn main() -> Result<(), Box> { let queue = gemini.stream_generate_content(&request, "gemini-pro").await; while let Some(response) = queue.pop().await { - if let GenerateContentResponse::Ok { - candidates, - usage_metadata: _, - } = response - { - let text = candidates - .iter() - .filter_map(|c| c.get_text()) - .collect::(); - print!("{}", text); + match response { + Ok(result) => { + let text = result + .candidates + .iter() + .filter_map(|c| c.get_text()) + .collect::(); + print!("{}", text); + } + Err(error) => { + println!("{error}"); + } } } diff --git a/src/client.rs b/src/client.rs index 122f0df..055df52 100644 --- a/src/client.rs +++ b/src/client.rs @@ -8,7 +8,8 @@ use crate::dialogue::{Message, Role}; use crate::error::{Error, Result}; use crate::prelude::{ Candidate, Content, CountTokensRequest, CountTokensResponse, GenerateContentRequest, - GenerateContentResponse, GenerationConfig, TextEmbeddingRequest, TextEmbeddingResponse, + GenerateContentResponse, GenerateContentResponseResult, GenerationConfig, TextEmbeddingRequest, + TextEmbeddingResponse, }; use crate::{prelude::Part, token_provider::TokenProvider}; @@ -46,12 +47,18 @@ impl GeminiClient { &self, request: &GenerateContentRequest, model: &str, - ) -> Arc>> { - let queue = Arc::new(Queue::>::new()); + ) -> Arc>>> { + let queue = Arc::new(Queue::>>::new()); + let access_token = match self.token_provider.get_token(AUTH_SCOPE).await { + Ok(access_token) => access_token, + Err(e) => { + queue.push(Some(Err(e.into()))); + return queue; + } + }; // Clone the queue and other necessary data to move into the async block. let cloned_queue = queue.clone(); - let access_token = self.token_provider.get_token(AUTH_SCOPE).await.unwrap(); let endpoint_url: String = format!( "https://{}/v1beta1/projects/{}/locations/{}/publishers/google/models/{}:streamGenerateContent?alt=sse", self.api_endpoint, self.project_id, self.location_id, model.to_string(), ); @@ -64,7 +71,14 @@ impl GeminiClient { .post(&endpoint_url) .bearer_auth(access_token) .json(&request); - let mut event_source = EventSource::new(req).unwrap(); + + let mut event_source = match EventSource::new(req) { + Ok(event_source) => event_source, + Err(e) => { + cloned_queue.push(Some(Err(e.into()))); + return; + } + }; while let Some(event) = event_source.next().await { match event { Ok(event) => { @@ -72,7 +86,7 @@ impl GeminiClient { let response: serde_json::error::Result = serde_json::from_str(&event.data); if let Ok(response) = response { - cloned_queue.push(Some(response)); + cloned_queue.push(Some(response.into_result())); } else { tracing::error!("Error parsing message: {}", event.data); }; @@ -95,7 +109,7 @@ impl GeminiClient { &self, request: &GenerateContentRequest, model: &str, - ) -> Result { + ) -> Result { let access_token = self.token_provider.get_token(AUTH_SCOPE).await?; let endpoint_url: String = format!( "https://{}/v1beta1/projects/{}/locations/{}/publishers/google/models/{}:generateContent", self.api_endpoint, self.project_id, self.location_id, model.to_string(), @@ -137,7 +151,7 @@ impl GeminiClient { let response = self.generate_content(&request, model).await?; // Check for errors in the response. - let mut candidates = GeminiClient::::collect_text_from_response(&response)?; + let mut candidates = GeminiClient::::collect_text_from_response(&response); match candidates.pop() { Some(text) => Ok(Message::new(Role::Model, &text)), @@ -163,7 +177,7 @@ impl GeminiClient { }; let response = self.generate_content(&request, "gemini-pro").await?; - let mut candidates = GeminiClient::::collect_text_from_response(&response)?; + let mut candidates = GeminiClient::::collect_text_from_response(&response); match candidates.pop() { Some(candidate) => Ok(candidate), @@ -171,21 +185,13 @@ impl GeminiClient { } } - fn collect_text_from_response(response: &GenerateContentResponse) -> Result> { - match response { - GenerateContentResponse::Ok { - candidates, - usage_metadata: _, - } => Ok(candidates - .iter() - .map(Candidate::get_text) - .flatten() - .collect::>()), - GenerateContentResponse::Error { error } => { - tracing::error!("Error in response: {:?}", error); - return Err(Error::VertexError(error.clone())); - } - } + fn collect_text_from_response(response: &GenerateContentResponseResult) -> Vec { + response + .candidates + .iter() + .map(Candidate::get_text) + .flatten() + .collect::>() } pub async fn text_embeddings( diff --git a/src/error.rs b/src/error.rs index 5210b88..e524c3c 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,5 +1,7 @@ use std::fmt::Display; +use reqwest_eventsource::CannotCloneRequestError; + use crate::types; pub type Result = std::result::Result; @@ -10,8 +12,9 @@ pub enum Error { HttpClient(reqwest::Error), Token(gcp_auth::Error), Serde(serde_json::Error), - VertexError(types::Error), + VertexError(types::VertexApiError), NoCandidatesError, + EventSourceError(CannotCloneRequestError), } impl Display for Error { @@ -22,11 +25,14 @@ impl Display for Error { Error::Token(e) => write!(f, "Token error: {}", e), Error::Serde(e) => write!(f, "Serde error: {}", e), Error::VertexError(e) => { - write!(f, "Vertex error: {}", serde_json::to_string(e).unwrap()) + write!(f, "Vertex error: {}", e.to_string()) } Error::NoCandidatesError => { write!(f, "No candidates returned for the prompt") } + Error::EventSourceError(e) => { + write!(f, "EventSourrce Error: {}", e) + } } } } @@ -57,8 +63,14 @@ impl From for Error { } } -impl From for Error { - fn from(e: types::Error) -> Self { +impl From for Error { + fn from(e: types::VertexApiError) -> Self { Error::VertexError(e) } } + +impl From for Error { + fn from(e: CannotCloneRequestError) -> Self { + Error::EventSourceError(e) + } +} diff --git a/src/types/count_tokens.rs b/src/types/count_tokens.rs index e294ecc..cce0612 100644 --- a/src/types/count_tokens.rs +++ b/src/types/count_tokens.rs @@ -16,6 +16,6 @@ pub enum CountTokensResponse { total_billable_characters: u32, }, Error { - error: super::Error, + error: super::VertexApiError, }, } diff --git a/src/types/error.rs b/src/types/error.rs index 85baef5..22cc876 100644 --- a/src/types/error.rs +++ b/src/types/error.rs @@ -1,13 +1,24 @@ +use std::fmt::Formatter; + use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, Serialize, Deserialize)] -pub struct Error { +pub struct VertexApiError { pub code: i32, pub message: String, pub status: String, pub details: Option>, } +impl core::fmt::Display for VertexApiError { + fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { + writeln!(f, "Vertex API Error {} - {}", self.code, self.message)?; + Ok(()) + } +} + +impl std::error::Error for VertexApiError {} + #[derive(Clone, Debug, Serialize, Deserialize)] pub struct Link { pub description: String, diff --git a/src/types/generate_content.rs b/src/types/generate_content.rs index 40796c8..a46d9ec 100644 --- a/src/types/generate_content.rs +++ b/src/types/generate_content.rs @@ -2,7 +2,8 @@ use std::collections::HashMap; use serde::{Deserialize, Serialize}; -use super::{Content, Error, Part}; +use super::{Content, Part, VertexApiError}; +use crate::error::Result; #[derive(Clone, Default, Serialize, Deserialize)] pub struct GenerateContentRequest { @@ -116,16 +117,40 @@ pub struct FunctionParametersProperty { } #[derive(Debug, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] #[serde(untagged)] pub enum GenerateContentResponse { - Ok { - candidates: Vec, - usage_metadata: Option, - }, - Error { - error: Error, - }, + Ok(GenerateContentResponseResult), + Error(GenerateContentResponseError), +} + +impl Into> for GenerateContentResponse { + fn into(self) -> Result { + match self { + GenerateContentResponse::Ok(result) => Ok(result), + GenerateContentResponse::Error(error) => Err(error.error.into()), + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct GenerateContentResponseResult { + pub candidates: Vec, + pub usage_metadata: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct GenerateContentResponseError { + pub error: VertexApiError, +} + +impl GenerateContentResponse { + pub fn into_result(self) -> Result { + match self { + GenerateContentResponse::Ok(result) => Ok(result), + GenerateContentResponse::Error(error) => Err(error.error.into()), + } + } } #[cfg(test)] diff --git a/src/types/text_embeddings.rs b/src/types/text_embeddings.rs index 83f0bd5..09be2fe 100644 --- a/src/types/text_embeddings.rs +++ b/src/types/text_embeddings.rs @@ -1,6 +1,6 @@ use serde::{Deserialize, Serialize}; -use super::Error; +use super::VertexApiError; #[derive(Clone, Debug, Serialize, Deserialize)] pub struct TextEmbeddingRequest { @@ -21,7 +21,7 @@ pub enum TextEmbeddingResponse { predictions: Vec, }, Error { - error: Error, + error: VertexApiError, }, }