Makes generate content error result into Error
This commit is contained in:
@@ -38,13 +38,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
.generate_content(&request, "gemini-1.0-pro-002")
|
||||
.await?;
|
||||
|
||||
if let GenerateContentResponse::Ok {
|
||||
candidates,
|
||||
usage_metadata: _,
|
||||
} = result
|
||||
{
|
||||
println!("Response: {:?}", candidates[0].get_text().unwrap());
|
||||
}
|
||||
println!("Response: {:?}", result.candidates[0].get_text().unwrap());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -23,16 +23,18 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let queue = gemini.stream_generate_content(&request, "gemini-pro").await;
|
||||
|
||||
while let Some(response) = queue.pop().await {
|
||||
if let GenerateContentResponse::Ok {
|
||||
candidates,
|
||||
usage_metadata: _,
|
||||
} = response
|
||||
{
|
||||
let text = candidates
|
||||
.iter()
|
||||
.filter_map(|c| c.get_text())
|
||||
.collect::<String>();
|
||||
print!("{}", text);
|
||||
match response {
|
||||
Ok(result) => {
|
||||
let text = result
|
||||
.candidates
|
||||
.iter()
|
||||
.filter_map(|c| c.get_text())
|
||||
.collect::<String>();
|
||||
print!("{}", text);
|
||||
}
|
||||
Err(error) => {
|
||||
println!("{error}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -8,7 +8,8 @@ use crate::dialogue::{Message, Role};
|
||||
use crate::error::{Error, Result};
|
||||
use crate::prelude::{
|
||||
Candidate, Content, CountTokensRequest, CountTokensResponse, GenerateContentRequest,
|
||||
GenerateContentResponse, GenerationConfig, TextEmbeddingRequest, TextEmbeddingResponse,
|
||||
GenerateContentResponse, GenerateContentResponseResult, GenerationConfig, TextEmbeddingRequest,
|
||||
TextEmbeddingResponse,
|
||||
};
|
||||
use crate::{prelude::Part, token_provider::TokenProvider};
|
||||
|
||||
@@ -46,12 +47,18 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
||||
&self,
|
||||
request: &GenerateContentRequest,
|
||||
model: &str,
|
||||
) -> Arc<Queue<Option<GenerateContentResponse>>> {
|
||||
let queue = Arc::new(Queue::<Option<GenerateContentResponse>>::new());
|
||||
) -> Arc<Queue<Option<Result<GenerateContentResponseResult>>>> {
|
||||
let queue = Arc::new(Queue::<Option<Result<GenerateContentResponseResult>>>::new());
|
||||
let access_token = match self.token_provider.get_token(AUTH_SCOPE).await {
|
||||
Ok(access_token) => access_token,
|
||||
Err(e) => {
|
||||
queue.push(Some(Err(e.into())));
|
||||
return queue;
|
||||
}
|
||||
};
|
||||
|
||||
// Clone the queue and other necessary data to move into the async block.
|
||||
let cloned_queue = queue.clone();
|
||||
let access_token = self.token_provider.get_token(AUTH_SCOPE).await.unwrap();
|
||||
let endpoint_url: String = format!(
|
||||
"https://{}/v1beta1/projects/{}/locations/{}/publishers/google/models/{}:streamGenerateContent?alt=sse", self.api_endpoint, self.project_id, self.location_id, model.to_string(),
|
||||
);
|
||||
@@ -64,7 +71,14 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
||||
.post(&endpoint_url)
|
||||
.bearer_auth(access_token)
|
||||
.json(&request);
|
||||
let mut event_source = EventSource::new(req).unwrap();
|
||||
|
||||
let mut event_source = match EventSource::new(req) {
|
||||
Ok(event_source) => event_source,
|
||||
Err(e) => {
|
||||
cloned_queue.push(Some(Err(e.into())));
|
||||
return;
|
||||
}
|
||||
};
|
||||
while let Some(event) = event_source.next().await {
|
||||
match event {
|
||||
Ok(event) => {
|
||||
@@ -72,7 +86,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
||||
let response: serde_json::error::Result<GenerateContentResponse> =
|
||||
serde_json::from_str(&event.data);
|
||||
if let Ok(response) = response {
|
||||
cloned_queue.push(Some(response));
|
||||
cloned_queue.push(Some(response.into_result()));
|
||||
} else {
|
||||
tracing::error!("Error parsing message: {}", event.data);
|
||||
};
|
||||
@@ -95,7 +109,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
||||
&self,
|
||||
request: &GenerateContentRequest,
|
||||
model: &str,
|
||||
) -> Result<GenerateContentResponse> {
|
||||
) -> Result<GenerateContentResponseResult> {
|
||||
let access_token = self.token_provider.get_token(AUTH_SCOPE).await?;
|
||||
let endpoint_url: String = format!(
|
||||
"https://{}/v1beta1/projects/{}/locations/{}/publishers/google/models/{}:generateContent", self.api_endpoint, self.project_id, self.location_id, model.to_string(),
|
||||
@@ -137,7 +151,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
||||
let response = self.generate_content(&request, model).await?;
|
||||
|
||||
// Check for errors in the response.
|
||||
let mut candidates = GeminiClient::<T>::collect_text_from_response(&response)?;
|
||||
let mut candidates = GeminiClient::<T>::collect_text_from_response(&response);
|
||||
|
||||
match candidates.pop() {
|
||||
Some(text) => Ok(Message::new(Role::Model, &text)),
|
||||
@@ -163,7 +177,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
||||
};
|
||||
|
||||
let response = self.generate_content(&request, "gemini-pro").await?;
|
||||
let mut candidates = GeminiClient::<T>::collect_text_from_response(&response)?;
|
||||
let mut candidates = GeminiClient::<T>::collect_text_from_response(&response);
|
||||
|
||||
match candidates.pop() {
|
||||
Some(candidate) => Ok(candidate),
|
||||
@@ -171,21 +185,13 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
||||
}
|
||||
}
|
||||
|
||||
fn collect_text_from_response(response: &GenerateContentResponse) -> Result<Vec<String>> {
|
||||
match response {
|
||||
GenerateContentResponse::Ok {
|
||||
candidates,
|
||||
usage_metadata: _,
|
||||
} => Ok(candidates
|
||||
.iter()
|
||||
.map(Candidate::get_text)
|
||||
.flatten()
|
||||
.collect::<Vec<String>>()),
|
||||
GenerateContentResponse::Error { error } => {
|
||||
tracing::error!("Error in response: {:?}", error);
|
||||
return Err(Error::VertexError(error.clone()));
|
||||
}
|
||||
}
|
||||
fn collect_text_from_response(response: &GenerateContentResponseResult) -> Vec<String> {
|
||||
response
|
||||
.candidates
|
||||
.iter()
|
||||
.map(Candidate::get_text)
|
||||
.flatten()
|
||||
.collect::<Vec<String>>()
|
||||
}
|
||||
|
||||
pub async fn text_embeddings(
|
||||
|
||||
20
src/error.rs
20
src/error.rs
@@ -1,5 +1,7 @@
|
||||
use std::fmt::Display;
|
||||
|
||||
use reqwest_eventsource::CannotCloneRequestError;
|
||||
|
||||
use crate::types;
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
@@ -10,8 +12,9 @@ pub enum Error {
|
||||
HttpClient(reqwest::Error),
|
||||
Token(gcp_auth::Error),
|
||||
Serde(serde_json::Error),
|
||||
VertexError(types::Error),
|
||||
VertexError(types::VertexApiError),
|
||||
NoCandidatesError,
|
||||
EventSourceError(CannotCloneRequestError),
|
||||
}
|
||||
|
||||
impl Display for Error {
|
||||
@@ -22,11 +25,14 @@ impl Display for Error {
|
||||
Error::Token(e) => write!(f, "Token error: {}", e),
|
||||
Error::Serde(e) => write!(f, "Serde error: {}", e),
|
||||
Error::VertexError(e) => {
|
||||
write!(f, "Vertex error: {}", serde_json::to_string(e).unwrap())
|
||||
write!(f, "Vertex error: {}", e.to_string())
|
||||
}
|
||||
Error::NoCandidatesError => {
|
||||
write!(f, "No candidates returned for the prompt")
|
||||
}
|
||||
Error::EventSourceError(e) => {
|
||||
write!(f, "EventSourrce Error: {}", e)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -57,8 +63,14 @@ impl From<serde_json::Error> for Error {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<types::Error> for Error {
|
||||
fn from(e: types::Error) -> Self {
|
||||
impl From<types::VertexApiError> for Error {
|
||||
fn from(e: types::VertexApiError) -> Self {
|
||||
Error::VertexError(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<CannotCloneRequestError> for Error {
|
||||
fn from(e: CannotCloneRequestError) -> Self {
|
||||
Error::EventSourceError(e)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,6 +16,6 @@ pub enum CountTokensResponse {
|
||||
total_billable_characters: u32,
|
||||
},
|
||||
Error {
|
||||
error: super::Error,
|
||||
error: super::VertexApiError,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1,13 +1,24 @@
|
||||
use std::fmt::Formatter;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Error {
|
||||
pub struct VertexApiError {
|
||||
pub code: i32,
|
||||
pub message: String,
|
||||
pub status: String,
|
||||
pub details: Option<Vec<ErrorType>>,
|
||||
}
|
||||
|
||||
impl core::fmt::Display for VertexApiError {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
|
||||
writeln!(f, "Vertex API Error {} - {}", self.code, self.message)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for VertexApiError {}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Link {
|
||||
pub description: String,
|
||||
|
||||
@@ -2,7 +2,8 @@ use std::collections::HashMap;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::{Content, Error, Part};
|
||||
use super::{Content, Part, VertexApiError};
|
||||
use crate::error::Result;
|
||||
|
||||
#[derive(Clone, Default, Serialize, Deserialize)]
|
||||
pub struct GenerateContentRequest {
|
||||
@@ -116,16 +117,40 @@ pub struct FunctionParametersProperty {
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[serde(untagged)]
|
||||
pub enum GenerateContentResponse {
|
||||
Ok {
|
||||
candidates: Vec<Candidate>,
|
||||
usage_metadata: Option<UsageMetadata>,
|
||||
},
|
||||
Error {
|
||||
error: Error,
|
||||
},
|
||||
Ok(GenerateContentResponseResult),
|
||||
Error(GenerateContentResponseError),
|
||||
}
|
||||
|
||||
impl Into<Result<GenerateContentResponseResult>> for GenerateContentResponse {
|
||||
fn into(self) -> Result<GenerateContentResponseResult> {
|
||||
match self {
|
||||
GenerateContentResponse::Ok(result) => Ok(result),
|
||||
GenerateContentResponse::Error(error) => Err(error.error.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct GenerateContentResponseResult {
|
||||
pub candidates: Vec<Candidate>,
|
||||
pub usage_metadata: Option<UsageMetadata>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct GenerateContentResponseError {
|
||||
pub error: VertexApiError,
|
||||
}
|
||||
|
||||
impl GenerateContentResponse {
|
||||
pub fn into_result(self) -> Result<GenerateContentResponseResult> {
|
||||
match self {
|
||||
GenerateContentResponse::Ok(result) => Ok(result),
|
||||
GenerateContentResponse::Error(error) => Err(error.error.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::Error;
|
||||
use super::VertexApiError;
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct TextEmbeddingRequest {
|
||||
@@ -21,7 +21,7 @@ pub enum TextEmbeddingResponse {
|
||||
predictions: Vec<TextEmbeddingPrediction>,
|
||||
},
|
||||
Error {
|
||||
error: Error,
|
||||
error: VertexApiError,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user