diff --git a/Cargo.toml b/Cargo.toml index d7a9b3e..a2dcfc0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,6 @@ edition = "2021" [dependencies] deadqueue = "0.2" -futures-util = "0.3" gcp_auth = "0.12" reqwest = { version = "0.12", features = ["json", "gzip"] } reqwest-eventsource = "0.6" @@ -16,6 +15,7 @@ serde_json = { version = "1"} serde_with = { version = "3.9", features = ["base64"]} tracing = "0.1" tokio = { version = "1" } +tokio-stream = "0.1.17" [dev-dependencies] console = "0.15.8" diff --git a/examples/text-from-text-streaming.rs b/examples/text-from-text-streaming.rs index 06de5f4..9f25647 100644 --- a/examples/text-from-text-streaming.rs +++ b/examples/text-from-text-streaming.rs @@ -1,4 +1,5 @@ use gemini_rs::prelude::*; +use tokio_stream::StreamExt; #[tokio::main] async fn main() -> Result<(), Box> { @@ -21,22 +22,12 @@ async fn main() -> Result<(), Box> { let request = GenerateContentRequest::builder().contents(prompt).build(); - let queue = gemini.stream_generate_content(&request, "gemini-pro").await; + let mut queue = gemini + .generate_content_stream(&request, "gemini-2.0-flash-001") + .await?; - while let Some(response) = queue.pop().await { - match response { - Ok(result) => { - let text = result - .candidates - .iter() - .filter_map(|c| c.get_text()) - .collect::(); - print!("{}", text); - } - Err(error) => { - println!("{error}"); - } - } + while let Some(Ok(response)) = queue.next().await { + println!("Response: {:?}", response); } Ok(()) diff --git a/src/client.rs b/src/client.rs index c54ded3..711db78 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,7 +1,9 @@ +use crate::error::Result as GeminiResult; use std::sync::Arc; +use std::vec; +use tokio_stream::{Stream, StreamExt}; use deadqueue::unlimited::Queue; -use futures_util::stream::StreamExt; use reqwest_eventsource::{Event, EventSource}; use tracing::error; @@ -9,7 +11,7 @@ use crate::dialogue::Message; use crate::error::{Error, Result}; use crate::prelude::{ Candidate, Content, CountTokensRequest, CountTokensResponse, GenerateContentRequest, - GenerateContentResponse, GenerateContentResponseResult, GenerationConfig, TextEmbeddingRequest, + GenerateContentResponse, GenerateContentResponseResult, TextEmbeddingRequest, TextEmbeddingResponse, }; use crate::types::{PredictImageRequest, PredictImageResponse, Role}; @@ -45,6 +47,50 @@ impl GeminiClient { } } + pub async fn generate_content_stream( + &self, + request: &GenerateContentRequest, + model: &str, + ) -> Result>> { + let access_token = self.token_provider.get_token(AUTH_SCOPE).await.unwrap(); + let endpoint_url = format!( + "https://{}/v1beta1/projects/{}/locations/{}/publishers/google/models/{}:streamGenerateContent?alt=sse", self.api_endpoint, self.project_id, self.location_id, model, + ); + let client = self.client.clone(); + let request = request.clone(); + let req = client + .post(&endpoint_url) + .bearer_auth(access_token) + .json(&request); + + let event_source = EventSource::new(req).unwrap(); + + let mapped = event_source.filter_map(|event| { + let event = match event { + Ok(event) => event, + Err(e) => return Some(Err(e.into())), + }; + + let Event::Message(event_message) = event else { + return None; + }; + + let gemini_response: GenerateContentResponse = + match serde_json::from_str(&event_message.data) { + Ok(gemini_response) => gemini_response, + Err(e) => return Some(Err(e.into())), + }; + + let gemini_response = match gemini_response.into_result() { + Ok(gemini_response) => gemini_response, + Err(e) => return Some(Err(e)), + }; + + Some(Ok(gemini_response)) + }); + Ok(mapped) + } + pub async fn stream_generate_content( &self, request: &GenerateContentRequest, @@ -175,34 +221,6 @@ impl GeminiClient { } } - /// Sends a text prompt to the Vertex API using the Gemini Pro model and extracts the text - /// from the response. - #[deprecated(note = "Use `generate_content` instead")] - pub async fn prompt_text( - &self, - prompt: &str, - generation_config: Option<&GenerationConfig>, - ) -> Result { - let request = GenerateContentRequest { - contents: vec![Content { - role: Some(Role::User), - parts: Some(vec![Part::Text(prompt.to_string())]), - }], - generation_config: generation_config.cloned(), - tools: None, - system_instruction: None, - safety_settings: None, - }; - - let response = self.generate_content(&request, "gemini-pro").await?; - let mut candidates = GeminiClient::::collect_text_from_response(&response); - - match candidates.pop() { - Some(candidate) => Ok(candidate), - None => Err(Error::NoCandidatesError), - } - } - fn collect_text_from_response(response: &GenerateContentResponseResult) -> Vec { response .candidates @@ -278,10 +296,10 @@ impl GeminiClient { let txt_json = resp.text().await?; match serde_json::from_str::(&txt_json) { - Ok(response) => return Ok(response), + Ok(response) => Ok(response), Err(e) => { error!(response = txt_json, error = ?e, "Failed to parse response"); - return Err(e.into()); + Err(e.into()) } } } diff --git a/src/error.rs b/src/error.rs index 2a60e87..ad43c29 100644 --- a/src/error.rs +++ b/src/error.rs @@ -14,7 +14,8 @@ pub enum Error { Serde(serde_json::Error), VertexError(types::VertexApiError), NoCandidatesError, - EventSourceError(CannotCloneRequestError), + CannotCloneRequestError(CannotCloneRequestError), + EventSourceError(reqwest_eventsource::Error), } impl Display for Error { @@ -30,6 +31,9 @@ impl Display for Error { Error::NoCandidatesError => { write!(f, "No candidates returned for the prompt") } + Error::CannotCloneRequestError(e) => { + write!(f, "Cannot clone request: {}", e) + } Error::EventSourceError(e) => { write!(f, "EventSourrce Error: {}", e) } @@ -71,6 +75,12 @@ impl From for Error { impl From for Error { fn from(e: CannotCloneRequestError) -> Self { + Error::CannotCloneRequestError(e) + } +} + +impl From for Error { + fn from(e: reqwest_eventsource::Error) -> Self { Error::EventSourceError(e) } } diff --git a/src/token_provider.rs b/src/token_provider.rs index dd80900..f79284f 100644 --- a/src/token_provider.rs +++ b/src/token_provider.rs @@ -7,7 +7,7 @@ pub trait TokenProvider { -> impl std::future::Future> + Send; } -impl<'a> TokenProvider for Arc { +impl TokenProvider for Arc { async fn get_token(&self, scope: &[&str]) -> Result { let token = self.token(scope).await; match token { diff --git a/src/types/common.rs b/src/types/common.rs index 8bbdeea..f17d522 100644 --- a/src/types/common.rs +++ b/src/types/common.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, str::FromStr, vec}; +use std::{collections::HashMap, fmt::Display, str::FromStr, vec}; use serde::{Deserialize, Serialize}; @@ -22,21 +22,16 @@ impl Content { } pub fn builder() -> ContentBuilder { - ContentBuilder::new() + ContentBuilder::default() } } +#[derive(Default)] pub struct ContentBuilder { content: Content, } impl ContentBuilder { - pub fn new() -> Self { - Self { - content: Default::default(), - } - } - pub fn add_text_part>(self, text: T) -> Self { self.add_part(Part::Text(text.into())) } @@ -66,12 +61,13 @@ pub enum Role { Model, } -impl ToString for Role { - fn to_string(&self) -> String { - match self { - Role::User => "user".to_string(), - Role::Model => "model".to_string(), - } +impl Display for Role { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let role_str = match self { + Role::User => "user", + Role::Model => "model", + }; + f.write_str(role_str) } } diff --git a/src/types/count_tokens.rs b/src/types/count_tokens.rs index a152055..b586339 100644 --- a/src/types/count_tokens.rs +++ b/src/types/count_tokens.rs @@ -9,21 +9,16 @@ pub struct CountTokensRequest { impl CountTokensRequest { pub fn builder() -> CountTokensRequestBuilder { - CountTokensRequestBuilder::new() + CountTokensRequestBuilder::default() } } +#[derive(Default)] pub struct CountTokensRequestBuilder { contents: Content, } impl CountTokensRequestBuilder { - pub fn new() -> Self { - CountTokensRequestBuilder { - contents: Content::default(), - } - } - pub fn from_prompt(prompt: &str) -> Self { CountTokensRequestBuilder { contents: Content { diff --git a/src/types/predict_image.rs b/src/types/predict_image.rs index 9f08ed8..b342485 100644 --- a/src/types/predict_image.rs +++ b/src/types/predict_image.rs @@ -70,6 +70,7 @@ pub struct PredictImageRequestParameters { /// - "watercolor" /// - "cyberpunk" /// - "pop_art" + /// /// Pre-defined styles is only supported for model imagegeneration@002 #[serde(skip_serializing_if = "Option::is_none")] pub sample_image_style: Option, @@ -90,9 +91,9 @@ pub struct PredictImageRequestParameters { /// - `"block_most"`: Strongest filtering level, most strict blocking. /// - `"block_some"`: Block some problematic prompts and responses. /// - `"block_few"`: Reduces the number of requests blocked due to safety filters. May - /// increase objectionable content generated by Imagen. + /// increase objectionable content generated by Imagen. /// - `"block_fewest"`: Block very few problematic prompts and responses. Access to this - /// feature is restricted. + /// feature is restricted. /// /// The default value is `"block_some"`. ///