Convert generate-content error responses into `Error`

This commit is contained in:
2024-05-01 17:14:32 +01:00
parent 9754618153
commit a640d97efb
8 changed files with 108 additions and 58 deletions

View File

@@ -8,7 +8,8 @@ use crate::dialogue::{Message, Role};
use crate::error::{Error, Result};
use crate::prelude::{
Candidate, Content, CountTokensRequest, CountTokensResponse, GenerateContentRequest,
GenerateContentResponse, GenerationConfig, TextEmbeddingRequest, TextEmbeddingResponse,
GenerateContentResponse, GenerateContentResponseResult, GenerationConfig, TextEmbeddingRequest,
TextEmbeddingResponse,
};
use crate::{prelude::Part, token_provider::TokenProvider};
@@ -46,12 +47,18 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
&self,
request: &GenerateContentRequest,
model: &str,
) -> Arc<Queue<Option<GenerateContentResponse>>> {
let queue = Arc::new(Queue::<Option<GenerateContentResponse>>::new());
) -> Arc<Queue<Option<Result<GenerateContentResponseResult>>>> {
let queue = Arc::new(Queue::<Option<Result<GenerateContentResponseResult>>>::new());
let access_token = match self.token_provider.get_token(AUTH_SCOPE).await {
Ok(access_token) => access_token,
Err(e) => {
queue.push(Some(Err(e.into())));
return queue;
}
};
// Clone the queue and other necessary data to move into the async block.
let cloned_queue = queue.clone();
let access_token = self.token_provider.get_token(AUTH_SCOPE).await.unwrap();
let endpoint_url: String = format!(
"https://{}/v1beta1/projects/{}/locations/{}/publishers/google/models/{}:streamGenerateContent?alt=sse", self.api_endpoint, self.project_id, self.location_id, model.to_string(),
);
@@ -64,7 +71,14 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
.post(&endpoint_url)
.bearer_auth(access_token)
.json(&request);
let mut event_source = EventSource::new(req).unwrap();
let mut event_source = match EventSource::new(req) {
Ok(event_source) => event_source,
Err(e) => {
cloned_queue.push(Some(Err(e.into())));
return;
}
};
while let Some(event) = event_source.next().await {
match event {
Ok(event) => {
@@ -72,7 +86,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
let response: serde_json::error::Result<GenerateContentResponse> =
serde_json::from_str(&event.data);
if let Ok(response) = response {
cloned_queue.push(Some(response));
cloned_queue.push(Some(response.into_result()));
} else {
tracing::error!("Error parsing message: {}", event.data);
};
@@ -95,7 +109,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
&self,
request: &GenerateContentRequest,
model: &str,
) -> Result<GenerateContentResponse> {
) -> Result<GenerateContentResponseResult> {
let access_token = self.token_provider.get_token(AUTH_SCOPE).await?;
let endpoint_url: String = format!(
"https://{}/v1beta1/projects/{}/locations/{}/publishers/google/models/{}:generateContent", self.api_endpoint, self.project_id, self.location_id, model.to_string(),
@@ -137,7 +151,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
let response = self.generate_content(&request, model).await?;
// Check for errors in the response.
let mut candidates = GeminiClient::<T>::collect_text_from_response(&response)?;
let mut candidates = GeminiClient::<T>::collect_text_from_response(&response);
match candidates.pop() {
Some(text) => Ok(Message::new(Role::Model, &text)),
@@ -163,7 +177,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
};
let response = self.generate_content(&request, "gemini-pro").await?;
let mut candidates = GeminiClient::<T>::collect_text_from_response(&response)?;
let mut candidates = GeminiClient::<T>::collect_text_from_response(&response);
match candidates.pop() {
Some(candidate) => Ok(candidate),
@@ -171,21 +185,13 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
}
}
fn collect_text_from_response(response: &GenerateContentResponse) -> Result<Vec<String>> {
match response {
GenerateContentResponse::Ok {
candidates,
usage_metadata: _,
} => Ok(candidates
.iter()
.map(Candidate::get_text)
.flatten()
.collect::<Vec<String>>()),
GenerateContentResponse::Error { error } => {
tracing::error!("Error in response: {:?}", error);
return Err(Error::VertexError(error.clone()));
}
}
fn collect_text_from_response(response: &GenerateContentResponseResult) -> Vec<String> {
response
.candidates
.iter()
.map(Candidate::get_text)
.flatten()
.collect::<Vec<String>>()
}
pub async fn text_embeddings(