Adds chat and chat example

This commit is contained in:
2025-12-24 16:55:34 +00:00
parent d06d69d132
commit 65ea1dcfec
7 changed files with 220 additions and 4 deletions

56
Cargo.lock generated
View File

@@ -70,6 +70,19 @@ version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801"
[[package]]
name = "console"
version = "0.16.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03e45a4a8926227e4197636ba97a9fc9b00477e9f4bd711395687c5f0734bec4"
dependencies = [
"encode_unicode",
"libc",
"once_cell",
"unicode-width",
"windows-sys 0.61.2",
]
[[package]]
name = "core-foundation"
version = "0.9.4"
@@ -86,6 +99,18 @@ version = "0.8.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b"
[[package]]
name = "dialoguer"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "25f104b501bf2364e78d0d3974cbc774f738f5865306ed128e1e0d7499c0ad96"
dependencies = [
"console",
"shell-words",
"tempfile",
"zeroize",
]
[[package]]
name = "displaydoc"
version = "0.2.5"
@@ -103,6 +128,12 @@ version = "0.15.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b"
[[package]]
name = "encode_unicode"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0"
[[package]]
name = "encoding_rs"
version = "0.8.35"
@@ -635,6 +666,7 @@ name = "ollama-rs"
version = "0.1.0"
dependencies = [
"async-stream",
"dialoguer",
"dotenvy",
"futures-util",
"reqwest",
@@ -970,15 +1002,15 @@ dependencies = [
[[package]]
name = "serde_json"
version = "1.0.146"
version = "1.0.147"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "217ca874ae0207aac254aa02c957ded05585a90892cc8d87f9e5fa49669dadd8"
checksum = "6af14725505314343e673e9ecb7cd7e8a36aa9791eb936235a3567cc31447ae4"
dependencies = [
"itoa",
"memchr",
"ryu",
"serde",
"serde_core",
"zmij",
]
[[package]]
@@ -1002,6 +1034,12 @@ dependencies = [
"lazy_static",
]
[[package]]
name = "shell-words"
version = "1.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc6fe69c597f9c37bfeeeeeb33da3530379845f10be461a66d16d03eca2ded77"
[[package]]
name = "shlex"
version = "1.3.0"
@@ -1310,6 +1348,12 @@ version = "1.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5"
[[package]]
name = "unicode-width"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254"
[[package]]
name = "untrusted"
version = "0.9.0"
@@ -1736,3 +1780,9 @@ dependencies = [
"quote",
"syn",
]
[[package]]
name = "zmij"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e404bcd8afdaf006e529269d3e85a743f9480c3cef60034d77860d02964f3ba"

View File

@@ -6,13 +6,14 @@ edition = "2024"
[dependencies]
reqwest = { version = "0.12.28", features = ["json", "stream"] }
serde = { version = "1.0.228", features = ["derive"] }
serde_json = "1.0.146"
serde_json = "1.0.147"
tokio-util = "0.7.17"
tracing = "0.1.44"
futures-util = "0.3.31"
async-stream = "0.3.6"
[dev-dependencies]
dialoguer = "0.12.0"
dotenvy = "0.15.7"
tokio = { version = "1.48.0", features = ["full"] }
tracing-subscriber = "0.3.22"

52
examples/chat.rs Normal file
View File

@@ -0,0 +1,52 @@
use std::{env, error::Error, io::Write};
use dialoguer::Input;
use futures_util::StreamExt;
use ollama_rs::{
OllamaClient,
types::chat::{ChatRequest, Message, Role},
};
const MODEL: &str = "dolphin3:8b";
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
let _ = dotenvy::dotenv();
let server_address = env::var("OLLAMA_SERVER")?;
let ollama_client = OllamaClient::new(server_address);
let mut messages = vec![Message {
content: "You a role play character called Gerald. You are a dumb person who things knows a lot but PROVIDES WRONG ANSWERS to all questions.".to_string(),
role: Role::System,
}];
loop {
let user_input: String = Input::new().with_prompt(">").interact_text()?;
if user_input == "/quit" {
break;
}
let message = Message {
content: user_input,
role: Role::User,
};
messages.push(message);
let request = ChatRequest::builder(MODEL)
.messages(messages.clone())
.build();
let mut stream = ollama_client.chat(request).await;
let mut full_message = String::new();
while let Some(response) = stream.next().await {
let response = response?;
full_message += &response.message.content;
print!("{}", response.message.content);
std::io::stdout().flush()?;
}
println!();
messages.push(Message {
content: full_message,
role: Role::Assistant,
});
}
Ok(())
}

View File

@@ -10,6 +10,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
let ollama_client = OllamaClient::new(server_address);
let request = GenerateRequest::builder("dolphin3:8b")
.system_prompt("You a role play character called Gerald. You are a dumb person who things knows a lot but PROVIDES WRONG ANSWERS to all questions.")
.stream(false)
.prompt("Why is the sky blue?")
.build();

View File

@@ -10,6 +10,7 @@ use tracing::info;
use crate::{
error::{OllamaError, OllamaResult},
types::{
chat::{ChatRequest, ChatResponse},
generate::{GenerateRequest, GenerateResponse},
ps::RunningModel,
tags::Model,
@@ -113,4 +114,44 @@ impl OllamaClient {
}
})
}
/// Generate the next chat message in a conversation between a user and an assistant.
pub async fn chat(
&self,
request: ChatRequest,
) -> impl Stream<Item = OllamaResult<ChatResponse>> {
let request_address = format!("{}/api/chat", self.server_address);
let client = reqwest::Client::new();
// The stream macro creates an asynchronous generator
Box::pin(stream! {
let response = client
.post(request_address)
.json(&request)
.send()
.await
.map_err(|e| OllamaError::from(e))?; // Adjust based on your error type
let bytes_stream = response.bytes_stream();
let body_reader = StreamReader::new(
bytes_stream.map(|res| res.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))),
);
let mut lines_stream = FramedRead::new(body_reader, LinesCodec::new());
while let Some(line_result) = lines_stream.next().await {
match line_result {
Ok(line_content) => {
if let Ok(parsed) = serde_json::from_str::<ChatResponse>(&line_content) {
let done = parsed.done;
yield Ok(parsed);
if done { break; }
}
}
Err(e) => yield Err(OllamaError::from(e)),
}
}
})
}
}

View File

@@ -0,0 +1,61 @@
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
User,
System,
Assistant,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Message {
pub content: String,
pub role: Role,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ChatRequest {
pub model: String,
pub messages: Vec<Message>,
pub stream: Option<bool>,
}
impl ChatRequest {
pub fn builder<M: Into<String>>(model: M) -> ChatRequestBuilder {
ChatRequestBuilder::new(model)
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ChatResponse {
pub model: String,
pub created_at: String,
pub message: Message,
pub done: bool,
}
pub struct ChatRequestBuilder {
chat_request: ChatRequest,
}
impl ChatRequestBuilder {
fn new<M: Into<String>>(model: M) -> Self {
Self {
chat_request: ChatRequest {
model: model.into(),
messages: vec![],
stream: None,
},
}
}
pub fn messages(mut self, messages: Vec<Message>) -> Self {
self.chat_request.messages = messages;
self
}
pub fn build(self) -> ChatRequest {
self.chat_request
}
}

View File

@@ -17,6 +17,10 @@ pub struct GenerateRequest {
/// System prompt for the model to generate a response from
#[serde(skip_serializing_if = "Option::is_none")]
pub system: Option<String>,
/// When true, returns a stream of partial responses
#[serde(skip_serializing_if = "Option::is_none")]
pub stream: Option<bool>,
}
impl GenerateRequest {
@@ -37,6 +41,7 @@ impl GenerateRequestBuilder {
prompt: None,
suffix: None,
system: None,
stream: None,
},
}
}
@@ -51,6 +56,11 @@ impl GenerateRequestBuilder {
self
}
pub fn stream(mut self, stream: bool) -> Self {
self.generate_request.stream = Some(stream);
self
}
pub fn build(self) -> GenerateRequest {
self.generate_request
}