From d06d69d132f774be1748e2042345b9aa6efa430f Mon Sep 17 00:00:00 2001 From: Andre Cipriani Bandarra Date: Wed, 24 Dec 2025 15:58:23 +0000 Subject: [PATCH] Finishes the generate implementation --- Cargo.lock | 23 +++++++++++++++ Cargo.toml | 1 + examples/generate.rs | 19 ++++++++++-- src/error.rs | 10 +++++++ src/lib.rs | 68 ++++++++++++++++++++++++++----------------- src/types/chat.rs | 0 src/types/generate.rs | 5 ++++ src/types/mod.rs | 1 + 8 files changed, 97 insertions(+), 30 deletions(-) create mode 100644 src/types/chat.rs diff --git a/Cargo.lock b/Cargo.lock index 911b5db..647337c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,28 @@ # It is not intended for manual editing. version = 4 +[[package]] +name = "async-stream" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "atomic-waker" version = "1.1.2" @@ -612,6 +634,7 @@ dependencies = [ name = "ollama-rs" version = "0.1.0" dependencies = [ + "async-stream", "dotenvy", "futures-util", "reqwest", diff --git a/Cargo.toml b/Cargo.toml index ad494e4..6a8c530 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,7 @@ serde_json = "1.0.146" tokio-util = "0.7.17" tracing = "0.1.44" futures-util = "0.3.31" +async-stream = "0.3.6" [dev-dependencies] dotenvy = "0.15.7" diff --git a/examples/generate.rs b/examples/generate.rs index 1b7316a..f95669a 100644 --- a/examples/generate.rs +++ b/examples/generate.rs @@ -1,5 +1,6 @@ -use std::{env, error::Error}; +use std::{env, error::Error, io::Write}; +use futures_util::StreamExt; use ollama_rs::{OllamaClient, types::generate::GenerateRequest}; #[tokio::main] @@ -8,9 +9,21 @@ async fn main() -> Result<(), Box> { let server_address = env::var("OLLAMA_SERVER")?; let ollama_client = OllamaClient::new(server_address); let request = GenerateRequest::builder("dolphin3:8b") + .system_prompt("You a role play character called Gerald. You are a dumb person who things knows a lot but PROVIDES WRONG ANSWERS to all questions.") .prompt("Why is the sky blue?") .build(); - let response = ollama_client.generate(request).await?; - println!("{:?}", response); + + let mut stream = ollama_client.generate(request).await; + + while let Some(response) = stream.next().await { + match response { + Ok(token) => { + print!("{}", token.response); + std::io::stdout().flush()?; + } + Err(e) => println!("Error: {}", e), + } + } + Ok(()) } diff --git a/src/error.rs b/src/error.rs index 35c1eb2..07d60cb 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,11 +1,14 @@ use std::{error::Error, fmt::Display}; +use tokio_util::codec::LinesCodecError; + pub type OllamaResult = Result; #[derive(Debug)] pub enum OllamaError { NetworkError(reqwest::Error), ResponseParseError(serde_json::Error), + LinesCoderError(LinesCodecError), } impl Error for OllamaError {} @@ -15,6 +18,7 @@ impl Display for OllamaError { match self { OllamaError::NetworkError(e) => writeln!(f, "Network error: {}", e), OllamaError::ResponseParseError(e) => writeln!(f, "ResponseParseError error: {}", e), + OllamaError::LinesCoderError(e) => writeln!(f, "LinesCoderError error: {}", e), } } } @@ -30,3 +34,9 @@ impl From for OllamaError { Self::ResponseParseError(error) } } + +impl From for OllamaError { + fn from(value: LinesCodecError) -> Self { + Self::LinesCoderError(value) + } +} diff --git a/src/lib.rs b/src/lib.rs index 9c6dc5c..213b4e1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,10 +1,14 @@ -use futures_util::{StreamExt}; +use async_stream::stream; +use futures_util::{Stream, StreamExt}; use serde_json::Value; -use tokio_util::io::StreamReader; +use tokio_util::{ + codec::{FramedRead, LinesCodec}, + io::StreamReader, +}; use tracing::info; use crate::{ - error::OllamaResult, + error::{OllamaError, OllamaResult}, types::{ generate::{GenerateRequest, GenerateResponse}, ps::RunningModel, @@ -71,32 +75,42 @@ impl OllamaClient { } /// Generates a response for the provided prompt - pub async fn generate(&self, request: GenerateRequest) -> OllamaResult<()> { + pub async fn generate( + &self, + request: GenerateRequest, + ) -> impl Stream> { let request_address = format!("{}/api/generate", self.server_address); let client = reqwest::Client::new(); - let response = client - .post(request_address) - .json(&request) - .send() - .await? - .error_for_status()?; - let stream = response.bytes_stream().; - // let reader = BufReader::new(stream); - let reader = StreamReader(stream); - while reader - while let Some(item) = stream.next().await { - let item = item?; - println!("Chunk: {:?}", item?); - } - // let stream_reader = tokio_util::io::StreamReader::new(stream); - // let reder = BufReader::new(stream); - // let full_response = response.text().await?; - // let parts = full_response - // .lines() - // .map(|line| serde_json::from_str::(line).unwrap()) - // .collect::>(); - // println!("{:#?}", parts); - Ok(()) + // The stream macro creates an asynchronous generator + Box::pin(stream! { + let response = client + .post(request_address) + .json(&request) + .send() + .await + .map_err(|e| OllamaError::from(e))?; // Adjust based on your error type + + let bytes_stream = response.bytes_stream(); + + let body_reader = StreamReader::new( + bytes_stream.map(|res| res.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))), + ); + + let mut lines_stream = FramedRead::new(body_reader, LinesCodec::new()); + + while let Some(line_result) = lines_stream.next().await { + match line_result { + Ok(line_content) => { + if let Ok(parsed) = serde_json::from_str::(&line_content) { + let done = parsed.done; + yield Ok(parsed); + if done { break; } + } + } + Err(e) => yield Err(OllamaError::from(e)), + } + } + }) } } diff --git a/src/types/chat.rs b/src/types/chat.rs new file mode 100644 index 0000000..e69de29 diff --git a/src/types/generate.rs b/src/types/generate.rs index a5d472c..7d218c6 100644 --- a/src/types/generate.rs +++ b/src/types/generate.rs @@ -41,6 +41,11 @@ impl GenerateRequestBuilder { } } + pub fn system_prompt>(mut self, system_prompt: P) -> Self { + self.generate_request.system = Some(system_prompt.into()); + self + } + pub fn prompt>(mut self, prompt: P) -> Self { self.generate_request.prompt = Some(prompt.into()); self diff --git a/src/types/mod.rs b/src/types/mod.rs index 054f2cd..f20738e 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -1,3 +1,4 @@ +pub mod chat; pub mod common; pub mod generate; pub mod ps;