diff --git a/Cargo.toml b/Cargo.toml index 616ed9c..d0014ef 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,11 +6,15 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +deadqueue = "0.2.4" +futures-util = "0.3.30" gcp_auth = "0.10.0" reqwest = { version = "0.11.9", features = ["json", "gzip"] } +reqwest-eventsource = "0.5.0" serde = { version = "*", features = ["derive"] } serde_json = { version = "*"} tracing = "0.1.40" +tokio = { version = "1.36.0" } [dev-dependencies] console = "0.15.8" diff --git a/examples/text-from-text-streaming.rs b/examples/text-from-text-streaming.rs new file mode 100644 index 0000000..2dfd41d --- /dev/null +++ b/examples/text-from-text-streaming.rs @@ -0,0 +1,39 @@ +use std::sync::Arc; + +use gemini_rs::prelude::*; + +use gcp_auth::AuthenticationManager; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let authentication_manager = Arc::new(AuthenticationManager::new().await?); + let api_endpoint = std::env::var("API_ENDPOINT")?; + let project_id = std::env::var("PROJECT_ID")?; + let location_id = std::env::var("LOCATION_ID")?; + + let gemini = GeminiClient::new( + authentication_manager, + api_endpoint, + project_id, + location_id, + ); + + let prompt = "Tell me the story of the genesis of the universe as a bedtime story."; + let request = GenerateContentRequest::from_prompt(prompt, None); + let queue = gemini + .streaming_stream_generate_content(&request, Model::GeminiPro) + .await; + + while let Some(chunk) = queue.pop().await { + if let ResponseStreamChunk::Ok(ok_response) = chunk { + let text = ok_response + .candidates + .iter() + .filter_map(|c| c.get_text()) + .collect::(); + print!("{}", text); + } + } + + Ok(()) +} diff --git a/src/client.rs b/src/client.rs index 683e83d..2214978 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,3 +1,9 @@ +use std::sync::Arc; + +use deadqueue::unlimited::Queue; +use futures_util::stream::StreamExt; +use reqwest_eventsource::{Event, EventSource}; + use crate::dialogue::{Message, Role}; use crate::error::{Error, Result}; use crate::prelude::{ @@ -47,6 +53,42 @@ impl GeminiClient { } } + pub async fn streaming_stream_generate_content( + &self, + request: &GenerateContentRequest, + model: Model, + ) -> Arc>> { + let queue = Arc::new(Queue::>::new()); + + // Clone the queue and other necessary data to move into the async block. + let cloned_queue = queue.clone(); + let access_token = self.token_provider.get_token(AUTH_SCOPE).await.unwrap(); + let endpoint_url: String = format!( + "https://{}/v1beta1/projects/{}/locations/{}/publishers/google/models/{}:streamGenerateContent?alt=sse", self.api_endpoint, self.project_id, self.location_id, model.to_string(), + ); + let client = self.client.clone(); + let request = request.clone(); + + // Start a thread to run the request in the background. + tokio::spawn(async move { + let req = client + .post(&endpoint_url) + .bearer_auth(access_token) + .json(&request); + let mut event_source = EventSource::new(req).unwrap(); + while let Some(Ok(event)) = event_source.next().await { + if let Event::Message(event) = event { + let response: ResponseStreamChunk = serde_json::from_str(&event.data).unwrap(); + cloned_queue.push(Some(response)); + } + } + cloned_queue.push(None); + }); + + // Return the queue that will receive the responses. + queue + } + pub async fn stream_generate_content( &self, request: &GenerateContentRequest, @@ -125,6 +167,12 @@ impl GeminiClient { for chunk in response { match chunk { ResponseStreamChunk::Ok(ok_response) => { + ok_response.candidates.iter().for_each(|c| { + if let Some(t) = c.get_text() { + text.push_str(&t); + } + }); + for candidate in ok_response.candidates { if let Some(parts) = &candidate.content.parts { for part in parts { diff --git a/src/types.rs b/src/types.rs index 8de4693..e74c892 100644 --- a/src/types.rs +++ b/src/types.rs @@ -13,24 +13,51 @@ pub struct CountTokensResponse { pub total_tokens: i32, } -#[derive(Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize)] pub struct GenerateContentRequest { pub contents: Vec, pub generation_config: Option, pub tools: Option>, } -#[derive(Serialize, Deserialize)] +impl GenerateContentRequest { + pub fn from_prompt(prompt: &str, generation_config: Option) -> Self { + GenerateContentRequest { + contents: vec![Content { + role: "user".to_string(), + parts: Some(vec![Part::Text(prompt.to_string())]), + }], + generation_config, + tools: None, + } + } +} + +#[derive(Clone, Serialize, Deserialize)] pub struct Tools { pub function_declarations: Option>, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct Content { pub role: String, pub parts: Option>, } +impl Content { + pub fn get_text(&self) -> Option { + self.parts.as_ref().map(|parts| { + parts + .iter() + .filter_map(|part| match part { + Part::Text(text) => Some(text.clone()), + _ => None, + }) + .collect::() + }) + } +} + #[derive(Clone, Debug, Serialize, Deserialize, Default)] #[serde(rename_all = "camelCase")] pub struct GenerationConfig { @@ -42,7 +69,7 @@ pub struct GenerationConfig { pub candidate_count: Option, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub enum Part { Text(String), @@ -86,6 +113,12 @@ pub struct Candidate { pub finish_reason: Option, } +impl Candidate { + pub fn get_text(&self) -> Option { + self.content.get_text() + } +} + #[derive(Debug, Serialize, Deserialize)] pub struct SafetyRating { pub category: String, @@ -113,7 +146,7 @@ pub struct UsageMetadata { pub total_token_count: i32, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct FunctionDeclaration { pub name: String, @@ -121,7 +154,7 @@ pub struct FunctionDeclaration { pub parameters: FunctionParameters, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct FunctionParameters { pub r#type: String, @@ -129,7 +162,7 @@ pub struct FunctionParameters { pub required: Vec, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct FunctionParametersProperty { pub r#type: String,