Adds chat and chat example
This commit is contained in:
56
Cargo.lock
generated
56
Cargo.lock
generated
@@ -70,6 +70,19 @@ version = "1.0.4"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801"
|
checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "console"
|
||||||
|
version = "0.16.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "03e45a4a8926227e4197636ba97a9fc9b00477e9f4bd711395687c5f0734bec4"
|
||||||
|
dependencies = [
|
||||||
|
"encode_unicode",
|
||||||
|
"libc",
|
||||||
|
"once_cell",
|
||||||
|
"unicode-width",
|
||||||
|
"windows-sys 0.61.2",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "core-foundation"
|
name = "core-foundation"
|
||||||
version = "0.9.4"
|
version = "0.9.4"
|
||||||
@@ -86,6 +99,18 @@ version = "0.8.7"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b"
|
checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "dialoguer"
|
||||||
|
version = "0.12.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "25f104b501bf2364e78d0d3974cbc774f738f5865306ed128e1e0d7499c0ad96"
|
||||||
|
dependencies = [
|
||||||
|
"console",
|
||||||
|
"shell-words",
|
||||||
|
"tempfile",
|
||||||
|
"zeroize",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "displaydoc"
|
name = "displaydoc"
|
||||||
version = "0.2.5"
|
version = "0.2.5"
|
||||||
@@ -103,6 +128,12 @@ version = "0.15.7"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b"
|
checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "encode_unicode"
|
||||||
|
version = "1.0.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "encoding_rs"
|
name = "encoding_rs"
|
||||||
version = "0.8.35"
|
version = "0.8.35"
|
||||||
@@ -635,6 +666,7 @@ name = "ollama-rs"
|
|||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"async-stream",
|
"async-stream",
|
||||||
|
"dialoguer",
|
||||||
"dotenvy",
|
"dotenvy",
|
||||||
"futures-util",
|
"futures-util",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
@@ -970,15 +1002,15 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "serde_json"
|
name = "serde_json"
|
||||||
version = "1.0.146"
|
version = "1.0.147"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "217ca874ae0207aac254aa02c957ded05585a90892cc8d87f9e5fa49669dadd8"
|
checksum = "6af14725505314343e673e9ecb7cd7e8a36aa9791eb936235a3567cc31447ae4"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"itoa",
|
"itoa",
|
||||||
"memchr",
|
"memchr",
|
||||||
"ryu",
|
|
||||||
"serde",
|
"serde",
|
||||||
"serde_core",
|
"serde_core",
|
||||||
|
"zmij",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -1002,6 +1034,12 @@ dependencies = [
|
|||||||
"lazy_static",
|
"lazy_static",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "shell-words"
|
||||||
|
version = "1.1.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "dc6fe69c597f9c37bfeeeeeb33da3530379845f10be461a66d16d03eca2ded77"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "shlex"
|
name = "shlex"
|
||||||
version = "1.3.0"
|
version = "1.3.0"
|
||||||
@@ -1310,6 +1348,12 @@ version = "1.0.22"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5"
|
checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "unicode-width"
|
||||||
|
version = "0.2.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "untrusted"
|
name = "untrusted"
|
||||||
version = "0.9.0"
|
version = "0.9.0"
|
||||||
@@ -1736,3 +1780,9 @@ dependencies = [
|
|||||||
"quote",
|
"quote",
|
||||||
"syn",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "zmij"
|
||||||
|
version = "0.1.7"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "9e404bcd8afdaf006e529269d3e85a743f9480c3cef60034d77860d02964f3ba"
|
||||||
|
|||||||
@@ -6,13 +6,14 @@ edition = "2024"
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
reqwest = { version = "0.12.28", features = ["json", "stream"] }
|
reqwest = { version = "0.12.28", features = ["json", "stream"] }
|
||||||
serde = { version = "1.0.228", features = ["derive"] }
|
serde = { version = "1.0.228", features = ["derive"] }
|
||||||
serde_json = "1.0.146"
|
serde_json = "1.0.147"
|
||||||
tokio-util = "0.7.17"
|
tokio-util = "0.7.17"
|
||||||
tracing = "0.1.44"
|
tracing = "0.1.44"
|
||||||
futures-util = "0.3.31"
|
futures-util = "0.3.31"
|
||||||
async-stream = "0.3.6"
|
async-stream = "0.3.6"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
|
dialoguer = "0.12.0"
|
||||||
dotenvy = "0.15.7"
|
dotenvy = "0.15.7"
|
||||||
tokio = { version = "1.48.0", features = ["full"] }
|
tokio = { version = "1.48.0", features = ["full"] }
|
||||||
tracing-subscriber = "0.3.22"
|
tracing-subscriber = "0.3.22"
|
||||||
|
|||||||
52
examples/chat.rs
Normal file
52
examples/chat.rs
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
use std::{env, error::Error, io::Write};
|
||||||
|
|
||||||
|
use dialoguer::Input;
|
||||||
|
use futures_util::StreamExt;
|
||||||
|
use ollama_rs::{
|
||||||
|
OllamaClient,
|
||||||
|
types::chat::{ChatRequest, Message, Role},
|
||||||
|
};
|
||||||
|
|
||||||
|
const MODEL: &str = "dolphin3:8b";
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<(), Box<dyn Error>> {
|
||||||
|
let _ = dotenvy::dotenv();
|
||||||
|
let server_address = env::var("OLLAMA_SERVER")?;
|
||||||
|
let ollama_client = OllamaClient::new(server_address);
|
||||||
|
let mut messages = vec![Message {
|
||||||
|
content: "You a role play character called Gerald. You are a dumb person who things knows a lot but PROVIDES WRONG ANSWERS to all questions.".to_string(),
|
||||||
|
role: Role::System,
|
||||||
|
}];
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let user_input: String = Input::new().with_prompt(">").interact_text()?;
|
||||||
|
if user_input == "/quit" {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
let message = Message {
|
||||||
|
content: user_input,
|
||||||
|
role: Role::User,
|
||||||
|
};
|
||||||
|
messages.push(message);
|
||||||
|
let request = ChatRequest::builder(MODEL)
|
||||||
|
.messages(messages.clone())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
let mut stream = ollama_client.chat(request).await;
|
||||||
|
let mut full_message = String::new();
|
||||||
|
while let Some(response) = stream.next().await {
|
||||||
|
let response = response?;
|
||||||
|
full_message += &response.message.content;
|
||||||
|
print!("{}", response.message.content);
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
}
|
||||||
|
println!();
|
||||||
|
|
||||||
|
messages.push(Message {
|
||||||
|
content: full_message,
|
||||||
|
role: Role::Assistant,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
@@ -10,6 +10,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
|
|||||||
let ollama_client = OllamaClient::new(server_address);
|
let ollama_client = OllamaClient::new(server_address);
|
||||||
let request = GenerateRequest::builder("dolphin3:8b")
|
let request = GenerateRequest::builder("dolphin3:8b")
|
||||||
.system_prompt("You a role play character called Gerald. You are a dumb person who things knows a lot but PROVIDES WRONG ANSWERS to all questions.")
|
.system_prompt("You a role play character called Gerald. You are a dumb person who things knows a lot but PROVIDES WRONG ANSWERS to all questions.")
|
||||||
|
.stream(false)
|
||||||
.prompt("Why is the sky blue?")
|
.prompt("Why is the sky blue?")
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
|||||||
41
src/lib.rs
41
src/lib.rs
@@ -10,6 +10,7 @@ use tracing::info;
|
|||||||
use crate::{
|
use crate::{
|
||||||
error::{OllamaError, OllamaResult},
|
error::{OllamaError, OllamaResult},
|
||||||
types::{
|
types::{
|
||||||
|
chat::{ChatRequest, ChatResponse},
|
||||||
generate::{GenerateRequest, GenerateResponse},
|
generate::{GenerateRequest, GenerateResponse},
|
||||||
ps::RunningModel,
|
ps::RunningModel,
|
||||||
tags::Model,
|
tags::Model,
|
||||||
@@ -113,4 +114,44 @@ impl OllamaClient {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Generate the next chat message in a conversation between a user and an assistant.
|
||||||
|
pub async fn chat(
|
||||||
|
&self,
|
||||||
|
request: ChatRequest,
|
||||||
|
) -> impl Stream<Item = OllamaResult<ChatResponse>> {
|
||||||
|
let request_address = format!("{}/api/chat", self.server_address);
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
|
||||||
|
// The stream macro creates an asynchronous generator
|
||||||
|
Box::pin(stream! {
|
||||||
|
let response = client
|
||||||
|
.post(request_address)
|
||||||
|
.json(&request)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| OllamaError::from(e))?; // Adjust based on your error type
|
||||||
|
|
||||||
|
let bytes_stream = response.bytes_stream();
|
||||||
|
|
||||||
|
let body_reader = StreamReader::new(
|
||||||
|
bytes_stream.map(|res| res.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))),
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut lines_stream = FramedRead::new(body_reader, LinesCodec::new());
|
||||||
|
|
||||||
|
while let Some(line_result) = lines_stream.next().await {
|
||||||
|
match line_result {
|
||||||
|
Ok(line_content) => {
|
||||||
|
if let Ok(parsed) = serde_json::from_str::<ChatResponse>(&line_content) {
|
||||||
|
let done = parsed.done;
|
||||||
|
yield Ok(parsed);
|
||||||
|
if done { break; }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => yield Err(OllamaError::from(e)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,61 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum Role {
|
||||||
|
User,
|
||||||
|
System,
|
||||||
|
Assistant,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
pub struct Message {
|
||||||
|
pub content: String,
|
||||||
|
pub role: Role,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct ChatRequest {
|
||||||
|
pub model: String,
|
||||||
|
pub messages: Vec<Message>,
|
||||||
|
pub stream: Option<bool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ChatRequest {
|
||||||
|
pub fn builder<M: Into<String>>(model: M) -> ChatRequestBuilder {
|
||||||
|
ChatRequestBuilder::new(model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct ChatResponse {
|
||||||
|
pub model: String,
|
||||||
|
pub created_at: String,
|
||||||
|
pub message: Message,
|
||||||
|
pub done: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ChatRequestBuilder {
|
||||||
|
chat_request: ChatRequest,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ChatRequestBuilder {
|
||||||
|
fn new<M: Into<String>>(model: M) -> Self {
|
||||||
|
Self {
|
||||||
|
chat_request: ChatRequest {
|
||||||
|
model: model.into(),
|
||||||
|
messages: vec![],
|
||||||
|
stream: None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn messages(mut self, messages: Vec<Message>) -> Self {
|
||||||
|
self.chat_request.messages = messages;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn build(self) -> ChatRequest {
|
||||||
|
self.chat_request
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -17,6 +17,10 @@ pub struct GenerateRequest {
|
|||||||
/// System prompt for the model to generate a response from
|
/// System prompt for the model to generate a response from
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub system: Option<String>,
|
pub system: Option<String>,
|
||||||
|
|
||||||
|
/// When true, returns a stream of partial responses
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub stream: Option<bool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl GenerateRequest {
|
impl GenerateRequest {
|
||||||
@@ -37,6 +41,7 @@ impl GenerateRequestBuilder {
|
|||||||
prompt: None,
|
prompt: None,
|
||||||
suffix: None,
|
suffix: None,
|
||||||
system: None,
|
system: None,
|
||||||
|
stream: None,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -51,6 +56,11 @@ impl GenerateRequestBuilder {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn stream(mut self, stream: bool) -> Self {
|
||||||
|
self.generate_request.stream = Some(stream);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
pub fn build(self) -> GenerateRequest {
|
pub fn build(self) -> GenerateRequest {
|
||||||
self.generate_request
|
self.generate_request
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user