From 93285e53dd3bd57294cc3295a5a5ae9892403018 Mon Sep 17 00:00:00 2001 From: Andre Bandarra Date: Fri, 23 Feb 2024 14:42:51 +0000 Subject: [PATCH] Updates client to use the generateContent endpoint when not streaming --- examples/text-from-text-streaming.rs | 13 +++-- src/client.rs | 76 +++++++++++++--------------- src/error.rs | 4 ++ src/types.rs | 20 +++----- 4 files changed, 54 insertions(+), 59 deletions(-) diff --git a/examples/text-from-text-streaming.rs b/examples/text-from-text-streaming.rs index 2dfd41d..ddd1b0f 100644 --- a/examples/text-from-text-streaming.rs +++ b/examples/text-from-text-streaming.rs @@ -21,13 +21,16 @@ async fn main() -> Result<(), Box> { let prompt = "Tell me the story of the genesis of the universe as a bedtime story."; let request = GenerateContentRequest::from_prompt(prompt, None); let queue = gemini - .streaming_stream_generate_content(&request, Model::GeminiPro) + .stream_generate_content(&request, Model::GeminiPro) .await; - while let Some(chunk) = queue.pop().await { - if let ResponseStreamChunk::Ok(ok_response) = chunk { - let text = ok_response - .candidates + while let Some(response) = queue.pop().await { + if let GenerateContentResponse::Ok { + candidates, + usage_metadata: _, + } = response + { + let text = candidates .iter() .filter_map(|c| c.get_text()) .collect::(); diff --git a/src/client.rs b/src/client.rs index 2214978..6f1037d 100644 --- a/src/client.rs +++ b/src/client.rs @@ -7,7 +7,7 @@ use reqwest_eventsource::{Event, EventSource}; use crate::dialogue::{Message, Role}; use crate::error::{Error, Result}; use crate::prelude::{ - Content, GenerateContentRequest, GenerateContentResponse, GenerationConfig, ResponseStreamChunk, + Candidate, Content, GenerateContentRequest, GenerateContentResponse, GenerationConfig, }; use crate::{prelude::Part, token_provider::TokenProvider}; @@ -53,12 +53,12 @@ impl GeminiClient { } } - pub async fn streaming_stream_generate_content( + pub async fn stream_generate_content( &self, request: &GenerateContentRequest, model: Model, - ) -> Arc>> { - let queue = Arc::new(Queue::>::new()); + ) -> Arc>> { + let queue = Arc::new(Queue::>::new()); // Clone the queue and other necessary data to move into the async block. let cloned_queue = queue.clone(); @@ -78,7 +78,8 @@ impl GeminiClient { let mut event_source = EventSource::new(req).unwrap(); while let Some(Ok(event)) = event_source.next().await { if let Event::Message(event) = event { - let response: ResponseStreamChunk = serde_json::from_str(&event.data).unwrap(); + let response: GenerateContentResponse = + serde_json::from_str(&event.data).unwrap(); cloned_queue.push(Some(response)); } } @@ -89,14 +90,14 @@ impl GeminiClient { queue } - pub async fn stream_generate_content( + pub async fn generate_content( &self, request: &GenerateContentRequest, model: Model, ) -> Result { let access_token = self.token_provider.get_token(AUTH_SCOPE).await?; let endpoint_url: String = format!( - "https://{}/v1beta1/projects/{}/locations/{}/publishers/google/models/{}:streamGenerateContent", self.api_endpoint, self.project_id, self.location_id, model.to_string(), + "https://{}/v1beta1/projects/{}/locations/{}/publishers/google/models/{}:generateContent", self.api_endpoint, self.project_id, self.location_id, model.to_string(), ); let resp = self .client @@ -130,13 +131,15 @@ impl GeminiClient { tools: None, }; - let response = self - .stream_generate_content(&request, Model::GeminiPro) - .await?; + let response = self.generate_content(&request, Model::GeminiPro).await?; // Check for errors in the response. - let text = GeminiClient::::collect_text_from_response(response)?; - Ok(Message::new(Role::Model, &text)) + let mut candidates = GeminiClient::::collect_text_from_response(&response)?; + + match candidates.pop() { + Some(text) => Ok(Message::new(Role::Model, &text)), + None => Err(Error::NoCandidatesError), + } } /// Sends a text prompt to the Vertex API using the Gemini Pro model and extracts the text @@ -155,40 +158,29 @@ impl GeminiClient { tools: None, }; - let response = self - .stream_generate_content(&request, Model::GeminiPro) - .await?; + let response = self.generate_content(&request, Model::GeminiPro).await?; + let mut candidates = GeminiClient::::collect_text_from_response(&response)?; - GeminiClient::::collect_text_from_response(response) + match candidates.pop() { + Some(candidate) => Ok(candidate), + None => Err(Error::NoCandidatesError), + } } - fn collect_text_from_response(response: GenerateContentResponse) -> Result { - let mut text = String::new(); - for chunk in response { - match chunk { - ResponseStreamChunk::Ok(ok_response) => { - ok_response.candidates.iter().for_each(|c| { - if let Some(t) = c.get_text() { - text.push_str(&t); - } - }); - - for candidate in ok_response.candidates { - if let Some(parts) = &candidate.content.parts { - for part in parts { - if let Part::Text(t) = part { - text.push_str(t); - } - } - } - } - } - ResponseStreamChunk::Error(err) => { - tracing::error!("Error in response: {:?}", err); - return Err(Error::VertexError(err.clone())); - } + fn collect_text_from_response(response: &GenerateContentResponse) -> Result> { + match response { + GenerateContentResponse::Ok { + candidates, + usage_metadata: _, + } => Ok(candidates + .iter() + .map(Candidate::get_text) + .flatten() + .collect::>()), + GenerateContentResponse::Error { error } => { + tracing::error!("Error in response: {:?}", error); + return Err(Error::VertexError(error.clone())); } } - Ok(text) } } diff --git a/src/error.rs b/src/error.rs index 3c5708e..5210b88 100644 --- a/src/error.rs +++ b/src/error.rs @@ -11,6 +11,7 @@ pub enum Error { Token(gcp_auth::Error), Serde(serde_json::Error), VertexError(types::Error), + NoCandidatesError, } impl Display for Error { @@ -23,6 +24,9 @@ impl Display for Error { Error::VertexError(e) => { write!(f, "Vertex error: {}", serde_json::to_string(e).unwrap()) } + Error::NoCandidatesError => { + write!(f, "No candidates returned for the prompt") + } } } } diff --git a/src/types.rs b/src/types.rs index e74c892..f40dd7b 100644 --- a/src/types.rs +++ b/src/types.rs @@ -87,21 +87,17 @@ pub enum Part { }, } -pub type GenerateContentResponse = Vec; - #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] #[serde(untagged)] -pub enum ResponseStreamChunk { - Ok(OkResponse), - Error(Error), -} - -#[derive(Debug, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct OkResponse { - pub candidates: Vec, - pub usage_metadata: Option, +pub enum GenerateContentResponse { + Ok { + candidates: Vec, + usage_metadata: Option, + }, + Error { + error: Error, + }, } #[derive(Debug, Serialize, Deserialize)]