Adds streaming generation that returns a streaming
- Introducues generate_content_stream that returns a Tokio Steam instead of a Queue. This allows using the standard stream APIs from tokio-streams. - Replace future-utils with tokio-streams, mainly due to better ergonomics for using the filter_map stream combinator.
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
use crate::error::Result as GeminiResult;
|
||||
use std::sync::Arc;
|
||||
use std::vec;
|
||||
use tokio_stream::{Stream, StreamExt};
|
||||
|
||||
use deadqueue::unlimited::Queue;
|
||||
use futures_util::stream::StreamExt;
|
||||
use reqwest_eventsource::{Event, EventSource};
|
||||
use tracing::error;
|
||||
|
||||
@@ -9,7 +11,7 @@ use crate::dialogue::Message;
|
||||
use crate::error::{Error, Result};
|
||||
use crate::prelude::{
|
||||
Candidate, Content, CountTokensRequest, CountTokensResponse, GenerateContentRequest,
|
||||
GenerateContentResponse, GenerateContentResponseResult, GenerationConfig, TextEmbeddingRequest,
|
||||
GenerateContentResponse, GenerateContentResponseResult, TextEmbeddingRequest,
|
||||
TextEmbeddingResponse,
|
||||
};
|
||||
use crate::types::{PredictImageRequest, PredictImageResponse, Role};
|
||||
@@ -45,6 +47,50 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn generate_content_stream(
|
||||
&self,
|
||||
request: &GenerateContentRequest,
|
||||
model: &str,
|
||||
) -> Result<impl Stream<Item = GeminiResult<GenerateContentResponseResult>>> {
|
||||
let access_token = self.token_provider.get_token(AUTH_SCOPE).await.unwrap();
|
||||
let endpoint_url = format!(
|
||||
"https://{}/v1beta1/projects/{}/locations/{}/publishers/google/models/{}:streamGenerateContent?alt=sse", self.api_endpoint, self.project_id, self.location_id, model,
|
||||
);
|
||||
let client = self.client.clone();
|
||||
let request = request.clone();
|
||||
let req = client
|
||||
.post(&endpoint_url)
|
||||
.bearer_auth(access_token)
|
||||
.json(&request);
|
||||
|
||||
let event_source = EventSource::new(req).unwrap();
|
||||
|
||||
let mapped = event_source.filter_map(|event| {
|
||||
let event = match event {
|
||||
Ok(event) => event,
|
||||
Err(e) => return Some(Err(e.into())),
|
||||
};
|
||||
|
||||
let Event::Message(event_message) = event else {
|
||||
return None;
|
||||
};
|
||||
|
||||
let gemini_response: GenerateContentResponse =
|
||||
match serde_json::from_str(&event_message.data) {
|
||||
Ok(gemini_response) => gemini_response,
|
||||
Err(e) => return Some(Err(e.into())),
|
||||
};
|
||||
|
||||
let gemini_response = match gemini_response.into_result() {
|
||||
Ok(gemini_response) => gemini_response,
|
||||
Err(e) => return Some(Err(e)),
|
||||
};
|
||||
|
||||
Some(Ok(gemini_response))
|
||||
});
|
||||
Ok(mapped)
|
||||
}
|
||||
|
||||
pub async fn stream_generate_content(
|
||||
&self,
|
||||
request: &GenerateContentRequest,
|
||||
@@ -175,34 +221,6 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Sends a text prompt to the Vertex API using the Gemini Pro model and extracts the text
|
||||
/// from the response.
|
||||
#[deprecated(note = "Use `generate_content` instead")]
|
||||
pub async fn prompt_text(
|
||||
&self,
|
||||
prompt: &str,
|
||||
generation_config: Option<&GenerationConfig>,
|
||||
) -> Result<String> {
|
||||
let request = GenerateContentRequest {
|
||||
contents: vec![Content {
|
||||
role: Some(Role::User),
|
||||
parts: Some(vec![Part::Text(prompt.to_string())]),
|
||||
}],
|
||||
generation_config: generation_config.cloned(),
|
||||
tools: None,
|
||||
system_instruction: None,
|
||||
safety_settings: None,
|
||||
};
|
||||
|
||||
let response = self.generate_content(&request, "gemini-pro").await?;
|
||||
let mut candidates = GeminiClient::<T>::collect_text_from_response(&response);
|
||||
|
||||
match candidates.pop() {
|
||||
Some(candidate) => Ok(candidate),
|
||||
None => Err(Error::NoCandidatesError),
|
||||
}
|
||||
}
|
||||
|
||||
fn collect_text_from_response(response: &GenerateContentResponseResult) -> Vec<String> {
|
||||
response
|
||||
.candidates
|
||||
@@ -278,10 +296,10 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
||||
let txt_json = resp.text().await?;
|
||||
|
||||
match serde_json::from_str::<PredictImageResponse>(&txt_json) {
|
||||
Ok(response) => return Ok(response),
|
||||
Ok(response) => Ok(response),
|
||||
Err(e) => {
|
||||
error!(response = txt_json, error = ?e, "Failed to parse response");
|
||||
return Err(e.into());
|
||||
Err(e.into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user