Adds chat and chat example
This commit is contained in:
56
Cargo.lock
generated
56
Cargo.lock
generated
@@ -70,6 +70,19 @@ version = "1.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801"
|
||||
|
||||
[[package]]
|
||||
name = "console"
|
||||
version = "0.16.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "03e45a4a8926227e4197636ba97a9fc9b00477e9f4bd711395687c5f0734bec4"
|
||||
dependencies = [
|
||||
"encode_unicode",
|
||||
"libc",
|
||||
"once_cell",
|
||||
"unicode-width",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "core-foundation"
|
||||
version = "0.9.4"
|
||||
@@ -86,6 +99,18 @@ version = "0.8.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b"
|
||||
|
||||
[[package]]
|
||||
name = "dialoguer"
|
||||
version = "0.12.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "25f104b501bf2364e78d0d3974cbc774f738f5865306ed128e1e0d7499c0ad96"
|
||||
dependencies = [
|
||||
"console",
|
||||
"shell-words",
|
||||
"tempfile",
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "displaydoc"
|
||||
version = "0.2.5"
|
||||
@@ -103,6 +128,12 @@ version = "0.15.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b"
|
||||
|
||||
[[package]]
|
||||
name = "encode_unicode"
|
||||
version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0"
|
||||
|
||||
[[package]]
|
||||
name = "encoding_rs"
|
||||
version = "0.8.35"
|
||||
@@ -635,6 +666,7 @@ name = "ollama-rs"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"async-stream",
|
||||
"dialoguer",
|
||||
"dotenvy",
|
||||
"futures-util",
|
||||
"reqwest",
|
||||
@@ -970,15 +1002,15 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.146"
|
||||
version = "1.0.147"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "217ca874ae0207aac254aa02c957ded05585a90892cc8d87f9e5fa49669dadd8"
|
||||
checksum = "6af14725505314343e673e9ecb7cd7e8a36aa9791eb936235a3567cc31447ae4"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"memchr",
|
||||
"ryu",
|
||||
"serde",
|
||||
"serde_core",
|
||||
"zmij",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1002,6 +1034,12 @@ dependencies = [
|
||||
"lazy_static",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "shell-words"
|
||||
version = "1.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dc6fe69c597f9c37bfeeeeeb33da3530379845f10be461a66d16d03eca2ded77"
|
||||
|
||||
[[package]]
|
||||
name = "shlex"
|
||||
version = "1.3.0"
|
||||
@@ -1310,6 +1348,12 @@ version = "1.0.22"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-width"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254"
|
||||
|
||||
[[package]]
|
||||
name = "untrusted"
|
||||
version = "0.9.0"
|
||||
@@ -1736,3 +1780,9 @@ dependencies = [
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zmij"
|
||||
version = "0.1.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9e404bcd8afdaf006e529269d3e85a743f9480c3cef60034d77860d02964f3ba"
|
||||
|
||||
@@ -6,13 +6,14 @@ edition = "2024"
|
||||
[dependencies]
|
||||
reqwest = { version = "0.12.28", features = ["json", "stream"] }
|
||||
serde = { version = "1.0.228", features = ["derive"] }
|
||||
serde_json = "1.0.146"
|
||||
serde_json = "1.0.147"
|
||||
tokio-util = "0.7.17"
|
||||
tracing = "0.1.44"
|
||||
futures-util = "0.3.31"
|
||||
async-stream = "0.3.6"
|
||||
|
||||
[dev-dependencies]
|
||||
dialoguer = "0.12.0"
|
||||
dotenvy = "0.15.7"
|
||||
tokio = { version = "1.48.0", features = ["full"] }
|
||||
tracing-subscriber = "0.3.22"
|
||||
|
||||
52
examples/chat.rs
Normal file
52
examples/chat.rs
Normal file
@@ -0,0 +1,52 @@
|
||||
use std::{env, error::Error, io::Write};
|
||||
|
||||
use dialoguer::Input;
|
||||
use futures_util::StreamExt;
|
||||
use ollama_rs::{
|
||||
OllamaClient,
|
||||
types::chat::{ChatRequest, Message, Role},
|
||||
};
|
||||
|
||||
const MODEL: &str = "dolphin3:8b";
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn Error>> {
|
||||
let _ = dotenvy::dotenv();
|
||||
let server_address = env::var("OLLAMA_SERVER")?;
|
||||
let ollama_client = OllamaClient::new(server_address);
|
||||
let mut messages = vec![Message {
|
||||
content: "You a role play character called Gerald. You are a dumb person who things knows a lot but PROVIDES WRONG ANSWERS to all questions.".to_string(),
|
||||
role: Role::System,
|
||||
}];
|
||||
|
||||
loop {
|
||||
let user_input: String = Input::new().with_prompt(">").interact_text()?;
|
||||
if user_input == "/quit" {
|
||||
break;
|
||||
}
|
||||
let message = Message {
|
||||
content: user_input,
|
||||
role: Role::User,
|
||||
};
|
||||
messages.push(message);
|
||||
let request = ChatRequest::builder(MODEL)
|
||||
.messages(messages.clone())
|
||||
.build();
|
||||
|
||||
let mut stream = ollama_client.chat(request).await;
|
||||
let mut full_message = String::new();
|
||||
while let Some(response) = stream.next().await {
|
||||
let response = response?;
|
||||
full_message += &response.message.content;
|
||||
print!("{}", response.message.content);
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
println!();
|
||||
|
||||
messages.push(Message {
|
||||
content: full_message,
|
||||
role: Role::Assistant,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -10,6 +10,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
|
||||
let ollama_client = OllamaClient::new(server_address);
|
||||
let request = GenerateRequest::builder("dolphin3:8b")
|
||||
.system_prompt("You a role play character called Gerald. You are a dumb person who things knows a lot but PROVIDES WRONG ANSWERS to all questions.")
|
||||
.stream(false)
|
||||
.prompt("Why is the sky blue?")
|
||||
.build();
|
||||
|
||||
|
||||
41
src/lib.rs
41
src/lib.rs
@@ -10,6 +10,7 @@ use tracing::info;
|
||||
use crate::{
|
||||
error::{OllamaError, OllamaResult},
|
||||
types::{
|
||||
chat::{ChatRequest, ChatResponse},
|
||||
generate::{GenerateRequest, GenerateResponse},
|
||||
ps::RunningModel,
|
||||
tags::Model,
|
||||
@@ -113,4 +114,44 @@ impl OllamaClient {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Generate the next chat message in a conversation between a user and an assistant.
|
||||
pub async fn chat(
|
||||
&self,
|
||||
request: ChatRequest,
|
||||
) -> impl Stream<Item = OllamaResult<ChatResponse>> {
|
||||
let request_address = format!("{}/api/chat", self.server_address);
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
// The stream macro creates an asynchronous generator
|
||||
Box::pin(stream! {
|
||||
let response = client
|
||||
.post(request_address)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| OllamaError::from(e))?; // Adjust based on your error type
|
||||
|
||||
let bytes_stream = response.bytes_stream();
|
||||
|
||||
let body_reader = StreamReader::new(
|
||||
bytes_stream.map(|res| res.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))),
|
||||
);
|
||||
|
||||
let mut lines_stream = FramedRead::new(body_reader, LinesCodec::new());
|
||||
|
||||
while let Some(line_result) = lines_stream.next().await {
|
||||
match line_result {
|
||||
Ok(line_content) => {
|
||||
if let Ok(parsed) = serde_json::from_str::<ChatResponse>(&line_content) {
|
||||
let done = parsed.done;
|
||||
yield Ok(parsed);
|
||||
if done { break; }
|
||||
}
|
||||
}
|
||||
Err(e) => yield Err(OllamaError::from(e)),
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Role {
|
||||
User,
|
||||
System,
|
||||
Assistant,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Message {
|
||||
pub content: String,
|
||||
pub role: Role,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ChatRequest {
|
||||
pub model: String,
|
||||
pub messages: Vec<Message>,
|
||||
pub stream: Option<bool>,
|
||||
}
|
||||
|
||||
impl ChatRequest {
|
||||
pub fn builder<M: Into<String>>(model: M) -> ChatRequestBuilder {
|
||||
ChatRequestBuilder::new(model)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ChatResponse {
|
||||
pub model: String,
|
||||
pub created_at: String,
|
||||
pub message: Message,
|
||||
pub done: bool,
|
||||
}
|
||||
|
||||
pub struct ChatRequestBuilder {
|
||||
chat_request: ChatRequest,
|
||||
}
|
||||
|
||||
impl ChatRequestBuilder {
|
||||
fn new<M: Into<String>>(model: M) -> Self {
|
||||
Self {
|
||||
chat_request: ChatRequest {
|
||||
model: model.into(),
|
||||
messages: vec![],
|
||||
stream: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn messages(mut self, messages: Vec<Message>) -> Self {
|
||||
self.chat_request.messages = messages;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> ChatRequest {
|
||||
self.chat_request
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,6 +17,10 @@ pub struct GenerateRequest {
|
||||
/// System prompt for the model to generate a response from
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub system: Option<String>,
|
||||
|
||||
/// When true, returns a stream of partial responses
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stream: Option<bool>,
|
||||
}
|
||||
|
||||
impl GenerateRequest {
|
||||
@@ -37,6 +41,7 @@ impl GenerateRequestBuilder {
|
||||
prompt: None,
|
||||
suffix: None,
|
||||
system: None,
|
||||
stream: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -51,6 +56,11 @@ impl GenerateRequestBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn stream(mut self, stream: bool) -> Self {
|
||||
self.generate_request.stream = Some(stream);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> GenerateRequest {
|
||||
self.generate_request
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user