Initial EventSource implementationo
This commit is contained in:
40
Cargo.lock
generated
40
Cargo.lock
generated
@@ -86,6 +86,39 @@ dependencies = [
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-stream"
|
||||
version = "0.3.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476"
|
||||
dependencies = [
|
||||
"async-stream-impl",
|
||||
"futures-core",
|
||||
"pin-project-lite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-stream-impl"
|
||||
version = "0.3.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-trait"
|
||||
version = "0.1.89"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "atomic-waker"
|
||||
version = "1.1.2"
|
||||
@@ -696,6 +729,8 @@ dependencies = [
|
||||
name = "google-genai"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"async-stream",
|
||||
"async-trait",
|
||||
"console",
|
||||
"deadqueue",
|
||||
"dialoguer",
|
||||
@@ -709,6 +744,7 @@ dependencies = [
|
||||
"serde_with",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
]
|
||||
@@ -2334,9 +2370,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "tokio-util"
|
||||
version = "0.7.17"
|
||||
version = "0.7.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2efa149fe76073d6e8fd97ef4f4eca7b67f599660115591483572e406e165594"
|
||||
checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"futures-core",
|
||||
|
||||
@@ -18,6 +18,9 @@ serde_with = { version = "3.16", features = ["base64"] }
|
||||
tracing = "0.1"
|
||||
tokio = { version = "1" }
|
||||
tokio-stream = "0.1"
|
||||
async-stream = "0.3.6"
|
||||
tokio-util = "0.7.18"
|
||||
async-trait = "0.1.89"
|
||||
|
||||
[dev-dependencies]
|
||||
console = "0.16.2"
|
||||
|
||||
37
examples/sse.rs
Normal file
37
examples/sse.rs
Normal file
@@ -0,0 +1,37 @@
|
||||
use std::env;
|
||||
|
||||
use google_genai::{
|
||||
network::event_source::EventSource,
|
||||
prelude::{Content, GenerateContentRequest, Role},
|
||||
};
|
||||
use tokio_stream::StreamExt;
|
||||
|
||||
static MODEL: &str = "gemini-2.5-flash";
|
||||
|
||||
#[tokio::main]
|
||||
pub async fn main() {
|
||||
let prompt = vec![
|
||||
Content::builder()
|
||||
.role(Role::User)
|
||||
.add_text_part("What is the airspeed of an unladen swallow?")
|
||||
.build(),
|
||||
];
|
||||
let request = GenerateContentRequest::builder().contents(prompt).build();
|
||||
let _ = dotenvy::dotenv();
|
||||
let api_key = env::var("GEMINI_API_KEY").unwrap();
|
||||
let client = reqwest::Client::new();
|
||||
let endpoint_url = format!(
|
||||
"https://generativelanguage.googleapis.com/v1beta/models/{MODEL}:streamGenerateContent?alt=sse"
|
||||
);
|
||||
let mut event_stream = client
|
||||
.post(&endpoint_url)
|
||||
.header("x-goog-api-key", api_key)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.unwrap()
|
||||
.event_stream();
|
||||
while let Some(event) = event_stream.next().await {
|
||||
println!("{event:?}")
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
mod client;
|
||||
mod dialogue;
|
||||
pub mod error;
|
||||
pub mod network;
|
||||
mod types;
|
||||
|
||||
pub mod prelude {
|
||||
|
||||
93
src/network/event_source.rs
Normal file
93
src/network/event_source.rs
Normal file
@@ -0,0 +1,93 @@
|
||||
use reqwest::Response;
|
||||
use std::mem;
|
||||
use tokio_stream::{Stream, StreamExt};
|
||||
use tokio_util::{
|
||||
codec::{Decoder, FramedRead, LinesCodec, LinesCodecError},
|
||||
io::StreamReader,
|
||||
};
|
||||
use tracing::warn;
|
||||
|
||||
static EVENT: &str = "event: ";
|
||||
static DATA: &str = "data: ";
|
||||
static ID: &str = "id: ";
|
||||
static RETRY: &str = "retry: ";
|
||||
|
||||
pub trait EventSource {
|
||||
fn event_stream(self) -> impl Stream<Item = Result<ServerSentEvent, LinesCodecError>>;
|
||||
}
|
||||
|
||||
impl EventSource for Response {
|
||||
fn event_stream(self) -> impl Stream<Item = Result<ServerSentEvent, LinesCodecError>> {
|
||||
stream_response(self)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct ServerSentEvent {
|
||||
pub event: Option<String>,
|
||||
pub data: Option<String>,
|
||||
pub id: Option<String>,
|
||||
pub retry: Option<usize>,
|
||||
}
|
||||
|
||||
pub struct ServerSentEventsCodec {
|
||||
lines_code: LinesCodec,
|
||||
next: ServerSentEvent,
|
||||
}
|
||||
|
||||
impl ServerSentEventsCodec {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
lines_code: LinesCodec::new(),
|
||||
next: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for ServerSentEventsCodec {
|
||||
type Item = ServerSentEvent;
|
||||
type Error = LinesCodecError;
|
||||
fn decode(
|
||||
&mut self,
|
||||
src: &mut tokio_util::bytes::BytesMut,
|
||||
) -> Result<Option<Self::Item>, Self::Error> {
|
||||
let res = self.lines_code.decode(src)?;
|
||||
|
||||
let Some(mut line) = res else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
if line.is_empty() {
|
||||
let result = mem::take(&mut self.next);
|
||||
return Ok(Some(result));
|
||||
}
|
||||
|
||||
if line.starts_with(EVENT) {
|
||||
line.drain(..EVENT.len());
|
||||
self.next.event = Some(line);
|
||||
} else if line.starts_with(DATA) {
|
||||
line.drain(..DATA.len());
|
||||
self.next.data = Some(line)
|
||||
} else if line.starts_with(ID) {
|
||||
line.drain(..ID.len());
|
||||
self.next.id = Some(line);
|
||||
} else if line.starts_with(RETRY) {
|
||||
line.drain(..RETRY.len());
|
||||
let Ok(retry) = line.parse() else {
|
||||
warn!(line, "Received invalid retry value");
|
||||
return Ok(None);
|
||||
};
|
||||
self.next.retry = Some(retry);
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn stream_response(
|
||||
response: Response,
|
||||
) -> impl Stream<Item = Result<ServerSentEvent, LinesCodecError>> {
|
||||
let bytes_stream = response.bytes_stream();
|
||||
let body_reader = StreamReader::new(bytes_stream.map(|res| res.map_err(std::io::Error::other)));
|
||||
FramedRead::new(body_reader, ServerSentEventsCodec::new())
|
||||
}
|
||||
1
src/network/mod.rs
Normal file
1
src/network/mod.rs
Normal file
@@ -0,0 +1 @@
|
||||
pub mod event_source;
|
||||
Reference in New Issue
Block a user