Add consistent error handling to text_embeddings and count_tokens
- Check HTTP status before parsing response body, matching the pattern used by generate_content and predict_image - Unwrap TextEmbeddingResponse enum, returning TextEmbeddingResponseOk - Extract CountTokensResponseResult struct and add into_result(), returning the unwrapped result instead of the raw enum - All endpoints now consistently return the success type directly and surface API errors as GeminiError or GenericApiError
This commit is contained in:
@@ -98,7 +98,7 @@ impl GeminiClient {
|
|||||||
&self,
|
&self,
|
||||||
request: &TextEmbeddingRequest,
|
request: &TextEmbeddingRequest,
|
||||||
model: &str,
|
model: &str,
|
||||||
) -> GeminiResult<TextEmbeddingResponse> {
|
) -> GeminiResult<TextEmbeddingResponseOk> {
|
||||||
let endpoint_url =
|
let endpoint_url =
|
||||||
format!("https://generativelanguage.googleapis.com/v1beta/models/{model}:predict");
|
format!("https://generativelanguage.googleapis.com/v1beta/models/{model}:predict");
|
||||||
let resp = self
|
let resp = self
|
||||||
@@ -108,16 +108,37 @@ impl GeminiClient {
|
|||||||
.json(&request)
|
.json(&request)
|
||||||
.send()
|
.send()
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
let status = resp.status();
|
||||||
let txt_json = resp.text().await?;
|
let txt_json = resp.text().await?;
|
||||||
tracing::debug!("text_embeddings response: {:?}", txt_json);
|
tracing::debug!("text_embeddings response: {:?}", txt_json);
|
||||||
Ok(serde_json::from_str::<TextEmbeddingResponse>(&txt_json)?)
|
|
||||||
|
if !status.is_success() {
|
||||||
|
if let Ok(gemini_error) =
|
||||||
|
serde_json::from_str::<crate::types::GeminiApiError>(&txt_json)
|
||||||
|
{
|
||||||
|
return Err(GeminiError::GeminiError(gemini_error));
|
||||||
|
}
|
||||||
|
return Err(GeminiError::GenericApiError {
|
||||||
|
status: status.as_u16(),
|
||||||
|
body: txt_json,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
match serde_json::from_str::<TextEmbeddingResponse>(&txt_json) {
|
||||||
|
Ok(response) => Ok(response.into_result()?),
|
||||||
|
Err(e) => {
|
||||||
|
error!(response = txt_json, error = ?e, "Failed to parse response");
|
||||||
|
Err(e.into())
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn count_tokens(
|
pub async fn count_tokens(
|
||||||
&self,
|
&self,
|
||||||
request: &CountTokensRequest,
|
request: &CountTokensRequest,
|
||||||
model: &str,
|
model: &str,
|
||||||
) -> GeminiResult<CountTokensResponse> {
|
) -> GeminiResult<CountTokensResponseResult> {
|
||||||
let endpoint_url =
|
let endpoint_url =
|
||||||
format!("https://generativelanguage.googleapis.com/v1beta/models/{model}:countTokens");
|
format!("https://generativelanguage.googleapis.com/v1beta/models/{model}:countTokens");
|
||||||
let resp = self
|
let resp = self
|
||||||
@@ -128,9 +149,29 @@ impl GeminiClient {
|
|||||||
.send()
|
.send()
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
let status = resp.status();
|
||||||
let txt_json = resp.text().await?;
|
let txt_json = resp.text().await?;
|
||||||
tracing::debug!("count_tokens response: {:?}", txt_json);
|
tracing::debug!("count_tokens response: {:?}", txt_json);
|
||||||
Ok(serde_json::from_str(&txt_json)?)
|
|
||||||
|
if !status.is_success() {
|
||||||
|
if let Ok(gemini_error) =
|
||||||
|
serde_json::from_str::<crate::types::GeminiApiError>(&txt_json)
|
||||||
|
{
|
||||||
|
return Err(GeminiError::GeminiError(gemini_error));
|
||||||
|
}
|
||||||
|
return Err(GeminiError::GenericApiError {
|
||||||
|
status: status.as_u16(),
|
||||||
|
body: txt_json,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
match serde_json::from_str::<CountTokensResponse>(&txt_json) {
|
||||||
|
Ok(response) => Ok(response.into_result()?),
|
||||||
|
Err(e) => {
|
||||||
|
error!(response = txt_json, error = ?e, "Failed to parse response");
|
||||||
|
Err(e.into())
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn predict_image(
|
pub async fn predict_image(
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::error::{Error, Result};
|
||||||
|
|
||||||
use super::Content;
|
use super::Content;
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
@@ -38,12 +40,22 @@ impl CountTokensRequestBuilder {
|
|||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
#[serde(untagged)]
|
#[serde(untagged)]
|
||||||
pub enum CountTokensResponse {
|
pub enum CountTokensResponse {
|
||||||
#[serde(rename_all = "camelCase")]
|
Ok(CountTokensResponseResult),
|
||||||
Ok {
|
Error { error: super::VertexApiError },
|
||||||
total_tokens: i32,
|
}
|
||||||
total_billable_characters: u32,
|
|
||||||
},
|
impl CountTokensResponse {
|
||||||
Error {
|
pub fn into_result(self) -> Result<CountTokensResponseResult> {
|
||||||
error: super::VertexApiError,
|
match self {
|
||||||
},
|
CountTokensResponse::Ok(result) => Ok(result),
|
||||||
|
CountTokensResponse::Error { error } => Err(Error::VertexError(error)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct CountTokensResponseResult {
|
||||||
|
pub total_tokens: i32,
|
||||||
|
pub total_billable_characters: u32,
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user