From 65ea1dcfec1b354f2b94fd272adf88bf30bccd24 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Cipriani=20Bandarra?= Date: Wed, 24 Dec 2025 16:55:34 +0000 Subject: [PATCH] Adds chat and chat example --- Cargo.lock | 56 ++++++++++++++++++++++++++++++++++++--- Cargo.toml | 3 ++- examples/chat.rs | 52 ++++++++++++++++++++++++++++++++++++ examples/generate.rs | 1 + src/lib.rs | 41 +++++++++++++++++++++++++++++ src/types/chat.rs | 61 +++++++++++++++++++++++++++++++++++++++++++ src/types/generate.rs | 10 +++++++ 7 files changed, 220 insertions(+), 4 deletions(-) create mode 100644 examples/chat.rs diff --git a/Cargo.lock b/Cargo.lock index 647337c..22bda13 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -70,6 +70,19 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" +[[package]] +name = "console" +version = "0.16.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03e45a4a8926227e4197636ba97a9fc9b00477e9f4bd711395687c5f0734bec4" +dependencies = [ + "encode_unicode", + "libc", + "once_cell", + "unicode-width", + "windows-sys 0.61.2", +] + [[package]] name = "core-foundation" version = "0.9.4" @@ -86,6 +99,18 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "dialoguer" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25f104b501bf2364e78d0d3974cbc774f738f5865306ed128e1e0d7499c0ad96" +dependencies = [ + "console", + "shell-words", + "tempfile", + "zeroize", +] + [[package]] name = "displaydoc" version = "0.2.5" @@ -103,6 +128,12 @@ version = "0.15.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b" +[[package]] +name = "encode_unicode" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" + [[package]] name = "encoding_rs" version = "0.8.35" @@ -635,6 +666,7 @@ name = "ollama-rs" version = "0.1.0" dependencies = [ "async-stream", + "dialoguer", "dotenvy", "futures-util", "reqwest", @@ -970,15 +1002,15 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.146" +version = "1.0.147" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "217ca874ae0207aac254aa02c957ded05585a90892cc8d87f9e5fa49669dadd8" +checksum = "6af14725505314343e673e9ecb7cd7e8a36aa9791eb936235a3567cc31447ae4" dependencies = [ "itoa", "memchr", - "ryu", "serde", "serde_core", + "zmij", ] [[package]] @@ -1002,6 +1034,12 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "shell-words" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc6fe69c597f9c37bfeeeeeb33da3530379845f10be461a66d16d03eca2ded77" + [[package]] name = "shlex" version = "1.3.0" @@ -1310,6 +1348,12 @@ version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" +[[package]] +name = "unicode-width" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254" + [[package]] name = "untrusted" version = "0.9.0" @@ -1736,3 +1780,9 @@ dependencies = [ "quote", "syn", ] + +[[package]] +name = "zmij" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e404bcd8afdaf006e529269d3e85a743f9480c3cef60034d77860d02964f3ba" diff --git a/Cargo.toml b/Cargo.toml index 6a8c530..0a045bf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,13 +6,14 @@ edition = "2024" [dependencies] reqwest = { version = "0.12.28", features = ["json", "stream"] } serde = { version = "1.0.228", features = ["derive"] } -serde_json = "1.0.146" +serde_json = "1.0.147" tokio-util = "0.7.17" tracing = "0.1.44" futures-util = "0.3.31" async-stream = "0.3.6" [dev-dependencies] +dialoguer = "0.12.0" dotenvy = "0.15.7" tokio = { version = "1.48.0", features = ["full"] } tracing-subscriber = "0.3.22" diff --git a/examples/chat.rs b/examples/chat.rs new file mode 100644 index 0000000..9d2a154 --- /dev/null +++ b/examples/chat.rs @@ -0,0 +1,52 @@ +use std::{env, error::Error, io::Write}; + +use dialoguer::Input; +use futures_util::StreamExt; +use ollama_rs::{ + OllamaClient, + types::chat::{ChatRequest, Message, Role}, +}; + +const MODEL: &str = "dolphin3:8b"; +#[tokio::main] +async fn main() -> Result<(), Box> { + let _ = dotenvy::dotenv(); + let server_address = env::var("OLLAMA_SERVER")?; + let ollama_client = OllamaClient::new(server_address); + let mut messages = vec![Message { + content: "You a role play character called Gerald. You are a dumb person who things knows a lot but PROVIDES WRONG ANSWERS to all questions.".to_string(), + role: Role::System, + }]; + + loop { + let user_input: String = Input::new().with_prompt(">").interact_text()?; + if user_input == "/quit" { + break; + } + let message = Message { + content: user_input, + role: Role::User, + }; + messages.push(message); + let request = ChatRequest::builder(MODEL) + .messages(messages.clone()) + .build(); + + let mut stream = ollama_client.chat(request).await; + let mut full_message = String::new(); + while let Some(response) = stream.next().await { + let response = response?; + full_message += &response.message.content; + print!("{}", response.message.content); + std::io::stdout().flush()?; + } + println!(); + + messages.push(Message { + content: full_message, + role: Role::Assistant, + }); + } + + Ok(()) +} diff --git a/examples/generate.rs b/examples/generate.rs index f95669a..5fa9820 100644 --- a/examples/generate.rs +++ b/examples/generate.rs @@ -10,6 +10,7 @@ async fn main() -> Result<(), Box> { let ollama_client = OllamaClient::new(server_address); let request = GenerateRequest::builder("dolphin3:8b") .system_prompt("You a role play character called Gerald. You are a dumb person who things knows a lot but PROVIDES WRONG ANSWERS to all questions.") + .stream(false) .prompt("Why is the sky blue?") .build(); diff --git a/src/lib.rs b/src/lib.rs index 213b4e1..584ffd7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,6 +10,7 @@ use tracing::info; use crate::{ error::{OllamaError, OllamaResult}, types::{ + chat::{ChatRequest, ChatResponse}, generate::{GenerateRequest, GenerateResponse}, ps::RunningModel, tags::Model, @@ -113,4 +114,44 @@ impl OllamaClient { } }) } + + /// Generate the next chat message in a conversation between a user and an assistant. + pub async fn chat( + &self, + request: ChatRequest, + ) -> impl Stream> { + let request_address = format!("{}/api/chat", self.server_address); + let client = reqwest::Client::new(); + + // The stream macro creates an asynchronous generator + Box::pin(stream! { + let response = client + .post(request_address) + .json(&request) + .send() + .await + .map_err(|e| OllamaError::from(e))?; // Adjust based on your error type + + let bytes_stream = response.bytes_stream(); + + let body_reader = StreamReader::new( + bytes_stream.map(|res| res.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))), + ); + + let mut lines_stream = FramedRead::new(body_reader, LinesCodec::new()); + + while let Some(line_result) = lines_stream.next().await { + match line_result { + Ok(line_content) => { + if let Ok(parsed) = serde_json::from_str::(&line_content) { + let done = parsed.done; + yield Ok(parsed); + if done { break; } + } + } + Err(e) => yield Err(OllamaError::from(e)), + } + } + }) + } } diff --git a/src/types/chat.rs b/src/types/chat.rs index e69de29..5af6f1f 100644 --- a/src/types/chat.rs +++ b/src/types/chat.rs @@ -0,0 +1,61 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum Role { + User, + System, + Assistant, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Message { + pub content: String, + pub role: Role, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ChatRequest { + pub model: String, + pub messages: Vec, + pub stream: Option, +} + +impl ChatRequest { + pub fn builder>(model: M) -> ChatRequestBuilder { + ChatRequestBuilder::new(model) + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ChatResponse { + pub model: String, + pub created_at: String, + pub message: Message, + pub done: bool, +} + +pub struct ChatRequestBuilder { + chat_request: ChatRequest, +} + +impl ChatRequestBuilder { + fn new>(model: M) -> Self { + Self { + chat_request: ChatRequest { + model: model.into(), + messages: vec![], + stream: None, + }, + } + } + + pub fn messages(mut self, messages: Vec) -> Self { + self.chat_request.messages = messages; + self + } + + pub fn build(self) -> ChatRequest { + self.chat_request + } +} diff --git a/src/types/generate.rs b/src/types/generate.rs index 7d218c6..c781e59 100644 --- a/src/types/generate.rs +++ b/src/types/generate.rs @@ -17,6 +17,10 @@ pub struct GenerateRequest { /// System prompt for the model to generate a response from #[serde(skip_serializing_if = "Option::is_none")] pub system: Option, + + /// When true, returns a stream of partial responses + #[serde(skip_serializing_if = "Option::is_none")] + pub stream: Option, } impl GenerateRequest { @@ -37,6 +41,7 @@ impl GenerateRequestBuilder { prompt: None, suffix: None, system: None, + stream: None, }, } } @@ -51,6 +56,11 @@ impl GenerateRequestBuilder { self } + pub fn stream(mut self, stream: bool) -> Self { + self.generate_request.stream = Some(stream); + self + } + pub fn build(self) -> GenerateRequest { self.generate_request }