From a8fbe658bbde4d6cb6d44de03ad9bcec826bf93a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Cipriani=20Bandarra?= Date: Fri, 30 Jan 2026 20:32:40 +0000 Subject: [PATCH] Add consistent error handling to text_embeddings and count_tokens - Check HTTP status before parsing response body, matching the pattern used by generate_content and predict_image - Unwrap TextEmbeddingResponse enum, returning TextEmbeddingResponseOk - Extract CountTokensResponseResult struct and add into_result(), returning the unwrapped result instead of the raw enum - All endpoints now consistently return the success type directly and surface API errors as GeminiError or GenericApiError --- src/client.rs | 49 +++++++++++++++++++++++++++++++++++---- src/types/count_tokens.rs | 28 +++++++++++++++------- 2 files changed, 65 insertions(+), 12 deletions(-) diff --git a/src/client.rs b/src/client.rs index f110338..14226da 100644 --- a/src/client.rs +++ b/src/client.rs @@ -98,7 +98,7 @@ impl GeminiClient { &self, request: &TextEmbeddingRequest, model: &str, - ) -> GeminiResult { + ) -> GeminiResult { let endpoint_url = format!("https://generativelanguage.googleapis.com/v1beta/models/{model}:predict"); let resp = self @@ -108,16 +108,37 @@ impl GeminiClient { .json(&request) .send() .await?; + + let status = resp.status(); let txt_json = resp.text().await?; tracing::debug!("text_embeddings response: {:?}", txt_json); - Ok(serde_json::from_str::(&txt_json)?) + + if !status.is_success() { + if let Ok(gemini_error) = + serde_json::from_str::(&txt_json) + { + return Err(GeminiError::GeminiError(gemini_error)); + } + return Err(GeminiError::GenericApiError { + status: status.as_u16(), + body: txt_json, + }); + } + + match serde_json::from_str::(&txt_json) { + Ok(response) => Ok(response.into_result()?), + Err(e) => { + error!(response = txt_json, error = ?e, "Failed to parse response"); + Err(e.into()) + } + } } pub async fn count_tokens( &self, request: &CountTokensRequest, model: &str, - ) -> GeminiResult { + ) -> GeminiResult { let endpoint_url = format!("https://generativelanguage.googleapis.com/v1beta/models/{model}:countTokens"); let resp = self @@ -128,9 +149,29 @@ impl GeminiClient { .send() .await?; + let status = resp.status(); let txt_json = resp.text().await?; tracing::debug!("count_tokens response: {:?}", txt_json); - Ok(serde_json::from_str(&txt_json)?) + + if !status.is_success() { + if let Ok(gemini_error) = + serde_json::from_str::(&txt_json) + { + return Err(GeminiError::GeminiError(gemini_error)); + } + return Err(GeminiError::GenericApiError { + status: status.as_u16(), + body: txt_json, + }); + } + + match serde_json::from_str::(&txt_json) { + Ok(response) => Ok(response.into_result()?), + Err(e) => { + error!(response = txt_json, error = ?e, "Failed to parse response"); + Err(e.into()) + } + } } pub async fn predict_image( diff --git a/src/types/count_tokens.rs b/src/types/count_tokens.rs index 79af363..d764bc4 100644 --- a/src/types/count_tokens.rs +++ b/src/types/count_tokens.rs @@ -1,5 +1,7 @@ use serde::{Deserialize, Serialize}; +use crate::error::{Error, Result}; + use super::Content; #[derive(Debug, Serialize, Deserialize)] @@ -38,12 +40,22 @@ impl CountTokensRequestBuilder { #[derive(Debug, Serialize, Deserialize)] #[serde(untagged)] pub enum CountTokensResponse { - #[serde(rename_all = "camelCase")] - Ok { - total_tokens: i32, - total_billable_characters: u32, - }, - Error { - error: super::VertexApiError, - }, + Ok(CountTokensResponseResult), + Error { error: super::VertexApiError }, +} + +impl CountTokensResponse { + pub fn into_result(self) -> Result { + match self { + CountTokensResponse::Ok(result) => Ok(result), + CountTokensResponse::Error { error } => Err(Error::VertexError(error)), + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CountTokensResponseResult { + pub total_tokens: i32, + pub total_billable_characters: u32, }