Updates client to use the generateContent endpoint when not streaming
This commit is contained in:
@@ -21,13 +21,16 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
let prompt = "Tell me the story of the genesis of the universe as a bedtime story.";
|
let prompt = "Tell me the story of the genesis of the universe as a bedtime story.";
|
||||||
let request = GenerateContentRequest::from_prompt(prompt, None);
|
let request = GenerateContentRequest::from_prompt(prompt, None);
|
||||||
let queue = gemini
|
let queue = gemini
|
||||||
.streaming_stream_generate_content(&request, Model::GeminiPro)
|
.stream_generate_content(&request, Model::GeminiPro)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
while let Some(chunk) = queue.pop().await {
|
while let Some(response) = queue.pop().await {
|
||||||
if let ResponseStreamChunk::Ok(ok_response) = chunk {
|
if let GenerateContentResponse::Ok {
|
||||||
let text = ok_response
|
candidates,
|
||||||
.candidates
|
usage_metadata: _,
|
||||||
|
} = response
|
||||||
|
{
|
||||||
|
let text = candidates
|
||||||
.iter()
|
.iter()
|
||||||
.filter_map(|c| c.get_text())
|
.filter_map(|c| c.get_text())
|
||||||
.collect::<String>();
|
.collect::<String>();
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ use reqwest_eventsource::{Event, EventSource};
|
|||||||
use crate::dialogue::{Message, Role};
|
use crate::dialogue::{Message, Role};
|
||||||
use crate::error::{Error, Result};
|
use crate::error::{Error, Result};
|
||||||
use crate::prelude::{
|
use crate::prelude::{
|
||||||
Content, GenerateContentRequest, GenerateContentResponse, GenerationConfig, ResponseStreamChunk,
|
Candidate, Content, GenerateContentRequest, GenerateContentResponse, GenerationConfig,
|
||||||
};
|
};
|
||||||
use crate::{prelude::Part, token_provider::TokenProvider};
|
use crate::{prelude::Part, token_provider::TokenProvider};
|
||||||
|
|
||||||
@@ -53,12 +53,12 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn streaming_stream_generate_content(
|
pub async fn stream_generate_content(
|
||||||
&self,
|
&self,
|
||||||
request: &GenerateContentRequest,
|
request: &GenerateContentRequest,
|
||||||
model: Model,
|
model: Model,
|
||||||
) -> Arc<Queue<Option<ResponseStreamChunk>>> {
|
) -> Arc<Queue<Option<GenerateContentResponse>>> {
|
||||||
let queue = Arc::new(Queue::<Option<ResponseStreamChunk>>::new());
|
let queue = Arc::new(Queue::<Option<GenerateContentResponse>>::new());
|
||||||
|
|
||||||
// Clone the queue and other necessary data to move into the async block.
|
// Clone the queue and other necessary data to move into the async block.
|
||||||
let cloned_queue = queue.clone();
|
let cloned_queue = queue.clone();
|
||||||
@@ -78,7 +78,8 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
|||||||
let mut event_source = EventSource::new(req).unwrap();
|
let mut event_source = EventSource::new(req).unwrap();
|
||||||
while let Some(Ok(event)) = event_source.next().await {
|
while let Some(Ok(event)) = event_source.next().await {
|
||||||
if let Event::Message(event) = event {
|
if let Event::Message(event) = event {
|
||||||
let response: ResponseStreamChunk = serde_json::from_str(&event.data).unwrap();
|
let response: GenerateContentResponse =
|
||||||
|
serde_json::from_str(&event.data).unwrap();
|
||||||
cloned_queue.push(Some(response));
|
cloned_queue.push(Some(response));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -89,14 +90,14 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
|||||||
queue
|
queue
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn stream_generate_content(
|
pub async fn generate_content(
|
||||||
&self,
|
&self,
|
||||||
request: &GenerateContentRequest,
|
request: &GenerateContentRequest,
|
||||||
model: Model,
|
model: Model,
|
||||||
) -> Result<GenerateContentResponse> {
|
) -> Result<GenerateContentResponse> {
|
||||||
let access_token = self.token_provider.get_token(AUTH_SCOPE).await?;
|
let access_token = self.token_provider.get_token(AUTH_SCOPE).await?;
|
||||||
let endpoint_url: String = format!(
|
let endpoint_url: String = format!(
|
||||||
"https://{}/v1beta1/projects/{}/locations/{}/publishers/google/models/{}:streamGenerateContent", self.api_endpoint, self.project_id, self.location_id, model.to_string(),
|
"https://{}/v1beta1/projects/{}/locations/{}/publishers/google/models/{}:generateContent", self.api_endpoint, self.project_id, self.location_id, model.to_string(),
|
||||||
);
|
);
|
||||||
let resp = self
|
let resp = self
|
||||||
.client
|
.client
|
||||||
@@ -130,13 +131,15 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
|||||||
tools: None,
|
tools: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let response = self
|
let response = self.generate_content(&request, Model::GeminiPro).await?;
|
||||||
.stream_generate_content(&request, Model::GeminiPro)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
// Check for errors in the response.
|
// Check for errors in the response.
|
||||||
let text = GeminiClient::<T>::collect_text_from_response(response)?;
|
let mut candidates = GeminiClient::<T>::collect_text_from_response(&response)?;
|
||||||
Ok(Message::new(Role::Model, &text))
|
|
||||||
|
match candidates.pop() {
|
||||||
|
Some(text) => Ok(Message::new(Role::Model, &text)),
|
||||||
|
None => Err(Error::NoCandidatesError),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Sends a text prompt to the Vertex API using the Gemini Pro model and extracts the text
|
/// Sends a text prompt to the Vertex API using the Gemini Pro model and extracts the text
|
||||||
@@ -155,40 +158,29 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
|||||||
tools: None,
|
tools: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let response = self
|
let response = self.generate_content(&request, Model::GeminiPro).await?;
|
||||||
.stream_generate_content(&request, Model::GeminiPro)
|
let mut candidates = GeminiClient::<T>::collect_text_from_response(&response)?;
|
||||||
.await?;
|
|
||||||
|
|
||||||
GeminiClient::<T>::collect_text_from_response(response)
|
match candidates.pop() {
|
||||||
|
Some(candidate) => Ok(candidate),
|
||||||
|
None => Err(Error::NoCandidatesError),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn collect_text_from_response(response: GenerateContentResponse) -> Result<String> {
|
fn collect_text_from_response(response: &GenerateContentResponse) -> Result<Vec<String>> {
|
||||||
let mut text = String::new();
|
match response {
|
||||||
for chunk in response {
|
GenerateContentResponse::Ok {
|
||||||
match chunk {
|
candidates,
|
||||||
ResponseStreamChunk::Ok(ok_response) => {
|
usage_metadata: _,
|
||||||
ok_response.candidates.iter().for_each(|c| {
|
} => Ok(candidates
|
||||||
if let Some(t) = c.get_text() {
|
.iter()
|
||||||
text.push_str(&t);
|
.map(Candidate::get_text)
|
||||||
}
|
.flatten()
|
||||||
});
|
.collect::<Vec<String>>()),
|
||||||
|
GenerateContentResponse::Error { error } => {
|
||||||
for candidate in ok_response.candidates {
|
tracing::error!("Error in response: {:?}", error);
|
||||||
if let Some(parts) = &candidate.content.parts {
|
return Err(Error::VertexError(error.clone()));
|
||||||
for part in parts {
|
|
||||||
if let Part::Text(t) = part {
|
|
||||||
text.push_str(t);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
ResponseStreamChunk::Error(err) => {
|
|
||||||
tracing::error!("Error in response: {:?}", err);
|
|
||||||
return Err(Error::VertexError(err.clone()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(text)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ pub enum Error {
|
|||||||
Token(gcp_auth::Error),
|
Token(gcp_auth::Error),
|
||||||
Serde(serde_json::Error),
|
Serde(serde_json::Error),
|
||||||
VertexError(types::Error),
|
VertexError(types::Error),
|
||||||
|
NoCandidatesError,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Display for Error {
|
impl Display for Error {
|
||||||
@@ -23,6 +24,9 @@ impl Display for Error {
|
|||||||
Error::VertexError(e) => {
|
Error::VertexError(e) => {
|
||||||
write!(f, "Vertex error: {}", serde_json::to_string(e).unwrap())
|
write!(f, "Vertex error: {}", serde_json::to_string(e).unwrap())
|
||||||
}
|
}
|
||||||
|
Error::NoCandidatesError => {
|
||||||
|
write!(f, "No candidates returned for the prompt")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
20
src/types.rs
20
src/types.rs
@@ -87,21 +87,17 @@ pub enum Part {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
pub type GenerateContentResponse = Vec<ResponseStreamChunk>;
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
#[serde(untagged)]
|
#[serde(untagged)]
|
||||||
pub enum ResponseStreamChunk {
|
pub enum GenerateContentResponse {
|
||||||
Ok(OkResponse),
|
Ok {
|
||||||
Error(Error),
|
candidates: Vec<Candidate>,
|
||||||
}
|
usage_metadata: Option<UsageMetadata>,
|
||||||
|
},
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
Error {
|
||||||
#[serde(rename_all = "camelCase")]
|
error: Error,
|
||||||
pub struct OkResponse {
|
},
|
||||||
pub candidates: Vec<Candidate>,
|
|
||||||
pub usage_metadata: Option<UsageMetadata>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
|||||||
Reference in New Issue
Block a user