From 1c431715c666fea4bd4efed79c2ea7cec3777bc3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Cipriani=20Bandarra?= Date: Sun, 28 Dec 2025 12:50:47 +0000 Subject: [PATCH] Unifies streaming code --- examples/chat.rs | 3 ++ examples/generate.rs | 3 ++ src/lib.rs | 95 ++++++++++---------------------------------- 3 files changed, 26 insertions(+), 75 deletions(-) diff --git a/examples/chat.rs b/examples/chat.rs index 9d2a154..e7efaf1 100644 --- a/examples/chat.rs +++ b/examples/chat.rs @@ -39,6 +39,9 @@ async fn main() -> Result<(), Box> { full_message += &response.message.content; print!("{}", response.message.content); std::io::stdout().flush()?; + if response.done { + break; + } } println!(); diff --git a/examples/generate.rs b/examples/generate.rs index 5fa9820..64ee242 100644 --- a/examples/generate.rs +++ b/examples/generate.rs @@ -21,6 +21,9 @@ async fn main() -> Result<(), Box> { Ok(token) => { print!("{}", token.response); std::io::stdout().flush()?; + if token.done { + break; + } } Err(e) => println!("Error: {}", e), } diff --git a/src/lib.rs b/src/lib.rs index 0951dd6..30641bb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ use async_stream::stream; use futures_util::{Stream, StreamExt}; +use serde::{Serialize, de::DeserializeOwned}; use serde_json::Value; use tokio_util::{ codec::{FramedRead, LinesCodec}, @@ -76,27 +77,24 @@ impl OllamaClient { Ok(models) } - /// Generates a response for the provided prompt - pub async fn generate( + async fn stream_response( &self, - request: GenerateRequest, - ) -> impl Stream> { - let request_address = format!("{}/api/generate", self.server_address); + endpoint: String, + request: R, + ) -> impl Stream> { let client = reqwest::Client::new(); - - // The stream macro creates an asynchronous generator Box::pin(stream! { let response = client - .post(request_address) + .post(endpoint) .json(&request) .send() .await - .map_err(|e| OllamaError::from(e))?; // Adjust based on your error type + .map_err(OllamaError::from)?; // Adjust based on your error type let bytes_stream = response.bytes_stream(); let body_reader = StreamReader::new( - bytes_stream.map(|res| res.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))), + bytes_stream.map(|res| res.map_err(std::io::Error::other)), ); let mut lines_stream = FramedRead::new(body_reader, LinesCodec::new()); @@ -104,10 +102,8 @@ impl OllamaClient { while let Some(line_result) = lines_stream.next().await { match line_result { Ok(line_content) => { - if let Ok(parsed) = serde_json::from_str::(&line_content) { - let done = parsed.done; + if let Ok(parsed) = serde_json::from_str::(&line_content) { yield Ok(parsed); - if done { break; } } } Err(e) => yield Err(OllamaError::from(e)), @@ -116,44 +112,22 @@ impl OllamaClient { }) } + /// Generates a response for the provided prompt + pub async fn generate( + &self, + request: GenerateRequest, + ) -> impl Stream> { + let request_address = format!("{}/api/generate", self.server_address); + self.stream_response(request_address, request).await + } + /// Generate the next chat message in a conversation between a user and an assistant. pub async fn chat( &self, request: ChatRequest, ) -> impl Stream> { let request_address = format!("{}/api/chat", self.server_address); - let client = reqwest::Client::new(); - - // The stream macro creates an asynchronous generator - Box::pin(stream! { - let response = client - .post(request_address) - .json(&request) - .send() - .await - .map_err(|e| OllamaError::from(e))?; // Adjust based on your error type - - let bytes_stream = response.bytes_stream(); - - let body_reader = StreamReader::new( - bytes_stream.map(|res| res.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))), - ); - - let mut lines_stream = FramedRead::new(body_reader, LinesCodec::new()); - - while let Some(line_result) = lines_stream.next().await { - match line_result { - Ok(line_content) => { - if let Ok(parsed) = serde_json::from_str::(&line_content) { - let done = parsed.done; - yield Ok(parsed); - if done { break; } - } - } - Err(e) => yield Err(OllamaError::from(e)), - } - } - }) + self.stream_response(request_address, request).await } pub async fn pull( @@ -161,35 +135,6 @@ impl OllamaClient { request: PullRequest, ) -> impl Stream> { let request_address = format!("{}/api/pull", self.server_address); - let client = reqwest::Client::new(); - - Box::pin(stream! { - let response = client - .post(request_address) - .json(&request) - .send() - .await - .map_err(|e| OllamaError::from(e))?; // Adjust based on your error type - - let bytes_stream = response.bytes_stream(); - - let body_reader = StreamReader::new( - bytes_stream.map(|res| res.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))), - ); - - let mut lines_stream = FramedRead::new(body_reader, LinesCodec::new()); - - while let Some(line_result) = lines_stream.next().await { - match line_result { - Ok(line_content) => { - println!("{line_content}"); - if let Ok(parsed) = serde_json::from_str::(&line_content) { - yield Ok(parsed); - } - } - Err(e) => yield Err(OllamaError::from(e)), - } - } - }) + self.stream_response(request_address, request).await } }