From 941614e550cbbd4b972fb095a42b1b8e68f8cd46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Cipriani=20Bandarra?= Date: Wed, 28 Jan 2026 21:34:20 +0000 Subject: [PATCH] Initial EventSource implementationo --- Cargo.lock | 40 +++++++++++++++- Cargo.toml | 3 ++ examples/sse.rs | 37 +++++++++++++++ src/lib.rs | 1 + src/network/event_source.rs | 93 +++++++++++++++++++++++++++++++++++++ src/network/mod.rs | 1 + 6 files changed, 173 insertions(+), 2 deletions(-) create mode 100644 examples/sse.rs create mode 100644 src/network/event_source.rs create mode 100644 src/network/mod.rs diff --git a/Cargo.lock b/Cargo.lock index 80086e8..d485ad8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -86,6 +86,39 @@ dependencies = [ "tokio", ] +[[package]] +name = "async-stream" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "async-trait" +version = "0.1.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "atomic-waker" version = "1.1.2" @@ -696,6 +729,8 @@ dependencies = [ name = "google-genai" version = "0.1.0" dependencies = [ + "async-stream", + "async-trait", "console", "deadqueue", "dialoguer", @@ -709,6 +744,7 @@ dependencies = [ "serde_with", "tokio", "tokio-stream", + "tokio-util", "tracing", "tracing-subscriber", ] @@ -2334,9 +2370,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.17" +version = "0.7.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2efa149fe76073d6e8fd97ef4f4eca7b67f599660115591483572e406e165594" +checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098" dependencies = [ "bytes", "futures-core", diff --git a/Cargo.toml b/Cargo.toml index 0781dc2..908a62b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,9 @@ serde_with = { version = "3.16", features = ["base64"] } tracing = "0.1" tokio = { version = "1" } tokio-stream = "0.1" +async-stream = "0.3.6" +tokio-util = "0.7.18" +async-trait = "0.1.89" [dev-dependencies] console = "0.16.2" diff --git a/examples/sse.rs b/examples/sse.rs new file mode 100644 index 0000000..a2b7a11 --- /dev/null +++ b/examples/sse.rs @@ -0,0 +1,37 @@ +use std::env; + +use google_genai::{ + network::event_source::EventSource, + prelude::{Content, GenerateContentRequest, Role}, +}; +use tokio_stream::StreamExt; + +static MODEL: &str = "gemini-2.5-flash"; + +#[tokio::main] +pub async fn main() { + let prompt = vec![ + Content::builder() + .role(Role::User) + .add_text_part("What is the airspeed of an unladen swallow?") + .build(), + ]; + let request = GenerateContentRequest::builder().contents(prompt).build(); + let _ = dotenvy::dotenv(); + let api_key = env::var("GEMINI_API_KEY").unwrap(); + let client = reqwest::Client::new(); + let endpoint_url = format!( + "https://generativelanguage.googleapis.com/v1beta/models/{MODEL}:streamGenerateContent?alt=sse" + ); + let mut event_stream = client + .post(&endpoint_url) + .header("x-goog-api-key", api_key) + .json(&request) + .send() + .await + .unwrap() + .event_stream(); + while let Some(event) = event_stream.next().await { + println!("{event:?}") + } +} diff --git a/src/lib.rs b/src/lib.rs index b97fb81..7ca3236 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ mod client; mod dialogue; pub mod error; +pub mod network; mod types; pub mod prelude { diff --git a/src/network/event_source.rs b/src/network/event_source.rs new file mode 100644 index 0000000..c835f36 --- /dev/null +++ b/src/network/event_source.rs @@ -0,0 +1,93 @@ +use reqwest::Response; +use std::mem; +use tokio_stream::{Stream, StreamExt}; +use tokio_util::{ + codec::{Decoder, FramedRead, LinesCodec, LinesCodecError}, + io::StreamReader, +}; +use tracing::warn; + +static EVENT: &str = "event: "; +static DATA: &str = "data: "; +static ID: &str = "id: "; +static RETRY: &str = "retry: "; + +pub trait EventSource { + fn event_stream(self) -> impl Stream>; +} + +impl EventSource for Response { + fn event_stream(self) -> impl Stream> { + stream_response(self) + } +} + +#[derive(Debug, Default, Clone)] +pub struct ServerSentEvent { + pub event: Option, + pub data: Option, + pub id: Option, + pub retry: Option, +} + +pub struct ServerSentEventsCodec { + lines_code: LinesCodec, + next: ServerSentEvent, +} + +impl ServerSentEventsCodec { + pub fn new() -> Self { + Self { + lines_code: LinesCodec::new(), + next: Default::default(), + } + } +} + +impl Decoder for ServerSentEventsCodec { + type Item = ServerSentEvent; + type Error = LinesCodecError; + fn decode( + &mut self, + src: &mut tokio_util::bytes::BytesMut, + ) -> Result, Self::Error> { + let res = self.lines_code.decode(src)?; + + let Some(mut line) = res else { + return Ok(None); + }; + + if line.is_empty() { + let result = mem::take(&mut self.next); + return Ok(Some(result)); + } + + if line.starts_with(EVENT) { + line.drain(..EVENT.len()); + self.next.event = Some(line); + } else if line.starts_with(DATA) { + line.drain(..DATA.len()); + self.next.data = Some(line) + } else if line.starts_with(ID) { + line.drain(..ID.len()); + self.next.id = Some(line); + } else if line.starts_with(RETRY) { + line.drain(..RETRY.len()); + let Ok(retry) = line.parse() else { + warn!(line, "Received invalid retry value"); + return Ok(None); + }; + self.next.retry = Some(retry); + } + + Ok(None) + } +} + +pub fn stream_response( + response: Response, +) -> impl Stream> { + let bytes_stream = response.bytes_stream(); + let body_reader = StreamReader::new(bytes_stream.map(|res| res.map_err(std::io::Error::other))); + FramedRead::new(body_reader, ServerSentEventsCodec::new()) +} diff --git a/src/network/mod.rs b/src/network/mod.rs new file mode 100644 index 0000000..aafb61f --- /dev/null +++ b/src/network/mod.rs @@ -0,0 +1 @@ +pub mod event_source;