Adds streaming generation that returns a streaming
- Introducues generate_content_stream that returns a Tokio Steam instead of a Queue. This allows using the standard stream APIs from tokio-streams. - Replace future-utils with tokio-streams, mainly due to better ergonomics for using the filter_map stream combinator.
This commit is contained in:
@@ -7,7 +7,6 @@ edition = "2021"
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
deadqueue = "0.2"
|
deadqueue = "0.2"
|
||||||
futures-util = "0.3"
|
|
||||||
gcp_auth = "0.12"
|
gcp_auth = "0.12"
|
||||||
reqwest = { version = "0.12", features = ["json", "gzip"] }
|
reqwest = { version = "0.12", features = ["json", "gzip"] }
|
||||||
reqwest-eventsource = "0.6"
|
reqwest-eventsource = "0.6"
|
||||||
@@ -16,6 +15,7 @@ serde_json = { version = "1"}
|
|||||||
serde_with = { version = "3.9", features = ["base64"]}
|
serde_with = { version = "3.9", features = ["base64"]}
|
||||||
tracing = "0.1"
|
tracing = "0.1"
|
||||||
tokio = { version = "1" }
|
tokio = { version = "1" }
|
||||||
|
tokio-stream = "0.1.17"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
console = "0.15.8"
|
console = "0.15.8"
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
use gemini_rs::prelude::*;
|
use gemini_rs::prelude::*;
|
||||||
|
use tokio_stream::StreamExt;
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
@@ -21,22 +22,12 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
|
|
||||||
let request = GenerateContentRequest::builder().contents(prompt).build();
|
let request = GenerateContentRequest::builder().contents(prompt).build();
|
||||||
|
|
||||||
let queue = gemini.stream_generate_content(&request, "gemini-pro").await;
|
let mut queue = gemini
|
||||||
|
.generate_content_stream(&request, "gemini-2.0-flash-001")
|
||||||
|
.await?;
|
||||||
|
|
||||||
while let Some(response) = queue.pop().await {
|
while let Some(Ok(response)) = queue.next().await {
|
||||||
match response {
|
println!("Response: {:?}", response);
|
||||||
Ok(result) => {
|
|
||||||
let text = result
|
|
||||||
.candidates
|
|
||||||
.iter()
|
|
||||||
.filter_map(|c| c.get_text())
|
|
||||||
.collect::<String>();
|
|
||||||
print!("{}", text);
|
|
||||||
}
|
|
||||||
Err(error) => {
|
|
||||||
println!("{error}");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
|
use crate::error::Result as GeminiResult;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::vec;
|
||||||
|
use tokio_stream::{Stream, StreamExt};
|
||||||
|
|
||||||
use deadqueue::unlimited::Queue;
|
use deadqueue::unlimited::Queue;
|
||||||
use futures_util::stream::StreamExt;
|
|
||||||
use reqwest_eventsource::{Event, EventSource};
|
use reqwest_eventsource::{Event, EventSource};
|
||||||
use tracing::error;
|
use tracing::error;
|
||||||
|
|
||||||
@@ -9,7 +11,7 @@ use crate::dialogue::Message;
|
|||||||
use crate::error::{Error, Result};
|
use crate::error::{Error, Result};
|
||||||
use crate::prelude::{
|
use crate::prelude::{
|
||||||
Candidate, Content, CountTokensRequest, CountTokensResponse, GenerateContentRequest,
|
Candidate, Content, CountTokensRequest, CountTokensResponse, GenerateContentRequest,
|
||||||
GenerateContentResponse, GenerateContentResponseResult, GenerationConfig, TextEmbeddingRequest,
|
GenerateContentResponse, GenerateContentResponseResult, TextEmbeddingRequest,
|
||||||
TextEmbeddingResponse,
|
TextEmbeddingResponse,
|
||||||
};
|
};
|
||||||
use crate::types::{PredictImageRequest, PredictImageResponse, Role};
|
use crate::types::{PredictImageRequest, PredictImageResponse, Role};
|
||||||
@@ -45,6 +47,50 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn generate_content_stream(
|
||||||
|
&self,
|
||||||
|
request: &GenerateContentRequest,
|
||||||
|
model: &str,
|
||||||
|
) -> Result<impl Stream<Item = GeminiResult<GenerateContentResponseResult>>> {
|
||||||
|
let access_token = self.token_provider.get_token(AUTH_SCOPE).await.unwrap();
|
||||||
|
let endpoint_url = format!(
|
||||||
|
"https://{}/v1beta1/projects/{}/locations/{}/publishers/google/models/{}:streamGenerateContent?alt=sse", self.api_endpoint, self.project_id, self.location_id, model,
|
||||||
|
);
|
||||||
|
let client = self.client.clone();
|
||||||
|
let request = request.clone();
|
||||||
|
let req = client
|
||||||
|
.post(&endpoint_url)
|
||||||
|
.bearer_auth(access_token)
|
||||||
|
.json(&request);
|
||||||
|
|
||||||
|
let event_source = EventSource::new(req).unwrap();
|
||||||
|
|
||||||
|
let mapped = event_source.filter_map(|event| {
|
||||||
|
let event = match event {
|
||||||
|
Ok(event) => event,
|
||||||
|
Err(e) => return Some(Err(e.into())),
|
||||||
|
};
|
||||||
|
|
||||||
|
let Event::Message(event_message) = event else {
|
||||||
|
return None;
|
||||||
|
};
|
||||||
|
|
||||||
|
let gemini_response: GenerateContentResponse =
|
||||||
|
match serde_json::from_str(&event_message.data) {
|
||||||
|
Ok(gemini_response) => gemini_response,
|
||||||
|
Err(e) => return Some(Err(e.into())),
|
||||||
|
};
|
||||||
|
|
||||||
|
let gemini_response = match gemini_response.into_result() {
|
||||||
|
Ok(gemini_response) => gemini_response,
|
||||||
|
Err(e) => return Some(Err(e)),
|
||||||
|
};
|
||||||
|
|
||||||
|
Some(Ok(gemini_response))
|
||||||
|
});
|
||||||
|
Ok(mapped)
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn stream_generate_content(
|
pub async fn stream_generate_content(
|
||||||
&self,
|
&self,
|
||||||
request: &GenerateContentRequest,
|
request: &GenerateContentRequest,
|
||||||
@@ -175,34 +221,6 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Sends a text prompt to the Vertex API using the Gemini Pro model and extracts the text
|
|
||||||
/// from the response.
|
|
||||||
#[deprecated(note = "Use `generate_content` instead")]
|
|
||||||
pub async fn prompt_text(
|
|
||||||
&self,
|
|
||||||
prompt: &str,
|
|
||||||
generation_config: Option<&GenerationConfig>,
|
|
||||||
) -> Result<String> {
|
|
||||||
let request = GenerateContentRequest {
|
|
||||||
contents: vec![Content {
|
|
||||||
role: Some(Role::User),
|
|
||||||
parts: Some(vec![Part::Text(prompt.to_string())]),
|
|
||||||
}],
|
|
||||||
generation_config: generation_config.cloned(),
|
|
||||||
tools: None,
|
|
||||||
system_instruction: None,
|
|
||||||
safety_settings: None,
|
|
||||||
};
|
|
||||||
|
|
||||||
let response = self.generate_content(&request, "gemini-pro").await?;
|
|
||||||
let mut candidates = GeminiClient::<T>::collect_text_from_response(&response);
|
|
||||||
|
|
||||||
match candidates.pop() {
|
|
||||||
Some(candidate) => Ok(candidate),
|
|
||||||
None => Err(Error::NoCandidatesError),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn collect_text_from_response(response: &GenerateContentResponseResult) -> Vec<String> {
|
fn collect_text_from_response(response: &GenerateContentResponseResult) -> Vec<String> {
|
||||||
response
|
response
|
||||||
.candidates
|
.candidates
|
||||||
@@ -278,10 +296,10 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
|||||||
let txt_json = resp.text().await?;
|
let txt_json = resp.text().await?;
|
||||||
|
|
||||||
match serde_json::from_str::<PredictImageResponse>(&txt_json) {
|
match serde_json::from_str::<PredictImageResponse>(&txt_json) {
|
||||||
Ok(response) => return Ok(response),
|
Ok(response) => Ok(response),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!(response = txt_json, error = ?e, "Failed to parse response");
|
error!(response = txt_json, error = ?e, "Failed to parse response");
|
||||||
return Err(e.into());
|
Err(e.into())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
12
src/error.rs
12
src/error.rs
@@ -14,7 +14,8 @@ pub enum Error {
|
|||||||
Serde(serde_json::Error),
|
Serde(serde_json::Error),
|
||||||
VertexError(types::VertexApiError),
|
VertexError(types::VertexApiError),
|
||||||
NoCandidatesError,
|
NoCandidatesError,
|
||||||
EventSourceError(CannotCloneRequestError),
|
CannotCloneRequestError(CannotCloneRequestError),
|
||||||
|
EventSourceError(reqwest_eventsource::Error),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Display for Error {
|
impl Display for Error {
|
||||||
@@ -30,6 +31,9 @@ impl Display for Error {
|
|||||||
Error::NoCandidatesError => {
|
Error::NoCandidatesError => {
|
||||||
write!(f, "No candidates returned for the prompt")
|
write!(f, "No candidates returned for the prompt")
|
||||||
}
|
}
|
||||||
|
Error::CannotCloneRequestError(e) => {
|
||||||
|
write!(f, "Cannot clone request: {}", e)
|
||||||
|
}
|
||||||
Error::EventSourceError(e) => {
|
Error::EventSourceError(e) => {
|
||||||
write!(f, "EventSourrce Error: {}", e)
|
write!(f, "EventSourrce Error: {}", e)
|
||||||
}
|
}
|
||||||
@@ -71,6 +75,12 @@ impl From<types::VertexApiError> for Error {
|
|||||||
|
|
||||||
impl From<CannotCloneRequestError> for Error {
|
impl From<CannotCloneRequestError> for Error {
|
||||||
fn from(e: CannotCloneRequestError) -> Self {
|
fn from(e: CannotCloneRequestError) -> Self {
|
||||||
|
Error::CannotCloneRequestError(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<reqwest_eventsource::Error> for Error {
|
||||||
|
fn from(e: reqwest_eventsource::Error) -> Self {
|
||||||
Error::EventSourceError(e)
|
Error::EventSourceError(e)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ pub trait TokenProvider {
|
|||||||
-> impl std::future::Future<Output = Result<String>> + Send;
|
-> impl std::future::Future<Output = Result<String>> + Send;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> TokenProvider for Arc<dyn gcp_auth::TokenProvider + 'a> {
|
impl TokenProvider for Arc<dyn gcp_auth::TokenProvider> {
|
||||||
async fn get_token(&self, scope: &[&str]) -> Result<String> {
|
async fn get_token(&self, scope: &[&str]) -> Result<String> {
|
||||||
let token = self.token(scope).await;
|
let token = self.token(scope).await;
|
||||||
match token {
|
match token {
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
use std::{collections::HashMap, str::FromStr, vec};
|
use std::{collections::HashMap, fmt::Display, str::FromStr, vec};
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
@@ -22,21 +22,16 @@ impl Content {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn builder() -> ContentBuilder {
|
pub fn builder() -> ContentBuilder {
|
||||||
ContentBuilder::new()
|
ContentBuilder::default()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
pub struct ContentBuilder {
|
pub struct ContentBuilder {
|
||||||
content: Content,
|
content: Content,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ContentBuilder {
|
impl ContentBuilder {
|
||||||
pub fn new() -> Self {
|
|
||||||
Self {
|
|
||||||
content: Default::default(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn add_text_part<T: Into<String>>(self, text: T) -> Self {
|
pub fn add_text_part<T: Into<String>>(self, text: T) -> Self {
|
||||||
self.add_part(Part::Text(text.into()))
|
self.add_part(Part::Text(text.into()))
|
||||||
}
|
}
|
||||||
@@ -66,12 +61,13 @@ pub enum Role {
|
|||||||
Model,
|
Model,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ToString for Role {
|
impl Display for Role {
|
||||||
fn to_string(&self) -> String {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
match self {
|
let role_str = match self {
|
||||||
Role::User => "user".to_string(),
|
Role::User => "user",
|
||||||
Role::Model => "model".to_string(),
|
Role::Model => "model",
|
||||||
}
|
};
|
||||||
|
f.write_str(role_str)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,21 +9,16 @@ pub struct CountTokensRequest {
|
|||||||
|
|
||||||
impl CountTokensRequest {
|
impl CountTokensRequest {
|
||||||
pub fn builder() -> CountTokensRequestBuilder {
|
pub fn builder() -> CountTokensRequestBuilder {
|
||||||
CountTokensRequestBuilder::new()
|
CountTokensRequestBuilder::default()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
pub struct CountTokensRequestBuilder {
|
pub struct CountTokensRequestBuilder {
|
||||||
contents: Content,
|
contents: Content,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CountTokensRequestBuilder {
|
impl CountTokensRequestBuilder {
|
||||||
pub fn new() -> Self {
|
|
||||||
CountTokensRequestBuilder {
|
|
||||||
contents: Content::default(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn from_prompt(prompt: &str) -> Self {
|
pub fn from_prompt(prompt: &str) -> Self {
|
||||||
CountTokensRequestBuilder {
|
CountTokensRequestBuilder {
|
||||||
contents: Content {
|
contents: Content {
|
||||||
|
|||||||
@@ -70,6 +70,7 @@ pub struct PredictImageRequestParameters {
|
|||||||
/// - "watercolor"
|
/// - "watercolor"
|
||||||
/// - "cyberpunk"
|
/// - "cyberpunk"
|
||||||
/// - "pop_art"
|
/// - "pop_art"
|
||||||
|
///
|
||||||
/// Pre-defined styles is only supported for model imagegeneration@002
|
/// Pre-defined styles is only supported for model imagegeneration@002
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub sample_image_style: Option<String>,
|
pub sample_image_style: Option<String>,
|
||||||
@@ -90,9 +91,9 @@ pub struct PredictImageRequestParameters {
|
|||||||
/// - `"block_most"`: Strongest filtering level, most strict blocking.
|
/// - `"block_most"`: Strongest filtering level, most strict blocking.
|
||||||
/// - `"block_some"`: Block some problematic prompts and responses.
|
/// - `"block_some"`: Block some problematic prompts and responses.
|
||||||
/// - `"block_few"`: Reduces the number of requests blocked due to safety filters. May
|
/// - `"block_few"`: Reduces the number of requests blocked due to safety filters. May
|
||||||
/// increase objectionable content generated by Imagen.
|
/// increase objectionable content generated by Imagen.
|
||||||
/// - `"block_fewest"`: Block very few problematic prompts and responses. Access to this
|
/// - `"block_fewest"`: Block very few problematic prompts and responses. Access to this
|
||||||
/// feature is restricted.
|
/// feature is restricted.
|
||||||
///
|
///
|
||||||
/// The default value is `"block_some"`.
|
/// The default value is `"block_some"`.
|
||||||
///
|
///
|
||||||
|
|||||||
Reference in New Issue
Block a user