Updates client to use the generateContent endpoint when not streaming

This commit is contained in:
2024-02-23 14:42:51 +00:00
parent 83335e65dc
commit 93285e53dd
4 changed files with 54 additions and 59 deletions

View File

@@ -21,13 +21,16 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let prompt = "Tell me the story of the genesis of the universe as a bedtime story."; let prompt = "Tell me the story of the genesis of the universe as a bedtime story.";
let request = GenerateContentRequest::from_prompt(prompt, None); let request = GenerateContentRequest::from_prompt(prompt, None);
let queue = gemini let queue = gemini
.streaming_stream_generate_content(&request, Model::GeminiPro) .stream_generate_content(&request, Model::GeminiPro)
.await; .await;
while let Some(chunk) = queue.pop().await { while let Some(response) = queue.pop().await {
if let ResponseStreamChunk::Ok(ok_response) = chunk { if let GenerateContentResponse::Ok {
let text = ok_response candidates,
.candidates usage_metadata: _,
} = response
{
let text = candidates
.iter() .iter()
.filter_map(|c| c.get_text()) .filter_map(|c| c.get_text())
.collect::<String>(); .collect::<String>();

View File

@@ -7,7 +7,7 @@ use reqwest_eventsource::{Event, EventSource};
use crate::dialogue::{Message, Role}; use crate::dialogue::{Message, Role};
use crate::error::{Error, Result}; use crate::error::{Error, Result};
use crate::prelude::{ use crate::prelude::{
Content, GenerateContentRequest, GenerateContentResponse, GenerationConfig, ResponseStreamChunk, Candidate, Content, GenerateContentRequest, GenerateContentResponse, GenerationConfig,
}; };
use crate::{prelude::Part, token_provider::TokenProvider}; use crate::{prelude::Part, token_provider::TokenProvider};
@@ -53,12 +53,12 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
} }
} }
pub async fn streaming_stream_generate_content( pub async fn stream_generate_content(
&self, &self,
request: &GenerateContentRequest, request: &GenerateContentRequest,
model: Model, model: Model,
) -> Arc<Queue<Option<ResponseStreamChunk>>> { ) -> Arc<Queue<Option<GenerateContentResponse>>> {
let queue = Arc::new(Queue::<Option<ResponseStreamChunk>>::new()); let queue = Arc::new(Queue::<Option<GenerateContentResponse>>::new());
// Clone the queue and other necessary data to move into the async block. // Clone the queue and other necessary data to move into the async block.
let cloned_queue = queue.clone(); let cloned_queue = queue.clone();
@@ -78,7 +78,8 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
let mut event_source = EventSource::new(req).unwrap(); let mut event_source = EventSource::new(req).unwrap();
while let Some(Ok(event)) = event_source.next().await { while let Some(Ok(event)) = event_source.next().await {
if let Event::Message(event) = event { if let Event::Message(event) = event {
let response: ResponseStreamChunk = serde_json::from_str(&event.data).unwrap(); let response: GenerateContentResponse =
serde_json::from_str(&event.data).unwrap();
cloned_queue.push(Some(response)); cloned_queue.push(Some(response));
} }
} }
@@ -89,14 +90,14 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
queue queue
} }
pub async fn stream_generate_content( pub async fn generate_content(
&self, &self,
request: &GenerateContentRequest, request: &GenerateContentRequest,
model: Model, model: Model,
) -> Result<GenerateContentResponse> { ) -> Result<GenerateContentResponse> {
let access_token = self.token_provider.get_token(AUTH_SCOPE).await?; let access_token = self.token_provider.get_token(AUTH_SCOPE).await?;
let endpoint_url: String = format!( let endpoint_url: String = format!(
"https://{}/v1beta1/projects/{}/locations/{}/publishers/google/models/{}:streamGenerateContent", self.api_endpoint, self.project_id, self.location_id, model.to_string(), "https://{}/v1beta1/projects/{}/locations/{}/publishers/google/models/{}:generateContent", self.api_endpoint, self.project_id, self.location_id, model.to_string(),
); );
let resp = self let resp = self
.client .client
@@ -130,13 +131,15 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
tools: None, tools: None,
}; };
let response = self let response = self.generate_content(&request, Model::GeminiPro).await?;
.stream_generate_content(&request, Model::GeminiPro)
.await?;
// Check for errors in the response. // Check for errors in the response.
let text = GeminiClient::<T>::collect_text_from_response(response)?; let mut candidates = GeminiClient::<T>::collect_text_from_response(&response)?;
Ok(Message::new(Role::Model, &text))
match candidates.pop() {
Some(text) => Ok(Message::new(Role::Model, &text)),
None => Err(Error::NoCandidatesError),
}
} }
/// Sends a text prompt to the Vertex API using the Gemini Pro model and extracts the text /// Sends a text prompt to the Vertex API using the Gemini Pro model and extracts the text
@@ -155,40 +158,29 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
tools: None, tools: None,
}; };
let response = self let response = self.generate_content(&request, Model::GeminiPro).await?;
.stream_generate_content(&request, Model::GeminiPro) let mut candidates = GeminiClient::<T>::collect_text_from_response(&response)?;
.await?;
GeminiClient::<T>::collect_text_from_response(response) match candidates.pop() {
Some(candidate) => Ok(candidate),
None => Err(Error::NoCandidatesError),
}
} }
fn collect_text_from_response(response: GenerateContentResponse) -> Result<String> { fn collect_text_from_response(response: &GenerateContentResponse) -> Result<Vec<String>> {
let mut text = String::new(); match response {
for chunk in response { GenerateContentResponse::Ok {
match chunk { candidates,
ResponseStreamChunk::Ok(ok_response) => { usage_metadata: _,
ok_response.candidates.iter().for_each(|c| { } => Ok(candidates
if let Some(t) = c.get_text() { .iter()
text.push_str(&t); .map(Candidate::get_text)
} .flatten()
}); .collect::<Vec<String>>()),
GenerateContentResponse::Error { error } => {
for candidate in ok_response.candidates { tracing::error!("Error in response: {:?}", error);
if let Some(parts) = &candidate.content.parts { return Err(Error::VertexError(error.clone()));
for part in parts {
if let Part::Text(t) = part {
text.push_str(t);
}
}
}
}
}
ResponseStreamChunk::Error(err) => {
tracing::error!("Error in response: {:?}", err);
return Err(Error::VertexError(err.clone()));
}
} }
} }
Ok(text)
} }
} }

View File

@@ -11,6 +11,7 @@ pub enum Error {
Token(gcp_auth::Error), Token(gcp_auth::Error),
Serde(serde_json::Error), Serde(serde_json::Error),
VertexError(types::Error), VertexError(types::Error),
NoCandidatesError,
} }
impl Display for Error { impl Display for Error {
@@ -23,6 +24,9 @@ impl Display for Error {
Error::VertexError(e) => { Error::VertexError(e) => {
write!(f, "Vertex error: {}", serde_json::to_string(e).unwrap()) write!(f, "Vertex error: {}", serde_json::to_string(e).unwrap())
} }
Error::NoCandidatesError => {
write!(f, "No candidates returned for the prompt")
}
} }
} }
} }

View File

@@ -87,21 +87,17 @@ pub enum Part {
}, },
} }
pub type GenerateContentResponse = Vec<ResponseStreamChunk>;
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
#[serde(untagged)] #[serde(untagged)]
pub enum ResponseStreamChunk { pub enum GenerateContentResponse {
Ok(OkResponse), Ok {
Error(Error), candidates: Vec<Candidate>,
} usage_metadata: Option<UsageMetadata>,
},
#[derive(Debug, Serialize, Deserialize)] Error {
#[serde(rename_all = "camelCase")] error: Error,
pub struct OkResponse { },
pub candidates: Vec<Candidate>,
pub usage_metadata: Option<UsageMetadata>,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]