diff --git a/examples/chat.rs b/examples/chat.rs index e7efaf1..c1bab56 100644 --- a/examples/chat.rs +++ b/examples/chat.rs @@ -7,32 +7,28 @@ use ollama_rs::{ types::chat::{ChatRequest, Message, Role}, }; -const MODEL: &str = "dolphin3:8b"; +const MODEL: &str = "functiongemma"; #[tokio::main] async fn main() -> Result<(), Box> { let _ = dotenvy::dotenv(); let server_address = env::var("OLLAMA_SERVER")?; let ollama_client = OllamaClient::new(server_address); - let mut messages = vec![Message { - content: "You a role play character called Gerald. You are a dumb person who things knows a lot but PROVIDES WRONG ANSWERS to all questions.".to_string(), - role: Role::System, - }]; + let mut messages = vec![Message::system( + "You a role play character called Gerald. You are a dumb person who things knows a lot but PROVIDES WRONG ANSWERS to all questions.", + )]; loop { let user_input: String = Input::new().with_prompt(">").interact_text()?; if user_input == "/quit" { break; } - let message = Message { - content: user_input, - role: Role::User, - }; + let message = Message::user(user_input); messages.push(message); let request = ChatRequest::builder(MODEL) .messages(messages.clone()) .build(); - let mut stream = ollama_client.chat(request).await; + let mut stream = ollama_client.chat(request); let mut full_message = String::new(); while let Some(response) = stream.next().await { let response = response?; @@ -48,6 +44,7 @@ async fn main() -> Result<(), Box> { messages.push(Message { content: full_message, role: Role::Assistant, + tool_calls: vec![], }); } diff --git a/examples/generate.rs b/examples/generate.rs index 64ee242..970faee 100644 --- a/examples/generate.rs +++ b/examples/generate.rs @@ -14,8 +14,7 @@ async fn main() -> Result<(), Box> { .prompt("Why is the sky blue?") .build(); - let mut stream = ollama_client.generate(request).await; - + let mut stream = ollama_client.generate(request); while let Some(response) = stream.next().await { match response { Ok(token) => { diff --git a/examples/pull.rs b/examples/pull.rs index c61f8a8..89a9c1e 100644 --- a/examples/pull.rs +++ b/examples/pull.rs @@ -3,7 +3,7 @@ use std::{env, error::Error, io::Write}; use futures_util::StreamExt; use ollama_rs::{OllamaClient, types::pull::PullRequest}; -const MODEL: &str = "HammerAI/mythomax-l2"; +const MODEL: &str = "functiongemma"; #[tokio::main] async fn main() -> Result<(), Box> { @@ -12,7 +12,7 @@ async fn main() -> Result<(), Box> { let ollama_client = OllamaClient::new(server_address); let request = PullRequest::builder(MODEL).stream(true).build(); - let mut stream = ollama_client.pull(request).await; + let mut stream = ollama_client.pull(request); while let Some(response) = stream.next().await { let response = response?; println!("{:?}", response); diff --git a/examples/tool_call.rs b/examples/tool_call.rs new file mode 100644 index 0000000..ffe757e --- /dev/null +++ b/examples/tool_call.rs @@ -0,0 +1,80 @@ +use std::{env, error::Error, io::Write}; + +use futures_util::StreamExt; +use ollama_rs::{ + OllamaClient, + types::chat::{ChatRequest, Function, Message, Tool, ToolType}, +}; +use serde::Deserialize; +use serde_json::{Value, json}; + +const MODEL: &str = "functiongemma"; + +fn get_weather(city: &str) -> Value { + json!({ + "city": city, + "temperature": 22.0, + "unit": "celsius", + "condition": "sunny", + }) +} + +#[derive(Deserialize)] +struct GetWeatherArgs { + city: String, +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let _ = dotenvy::dotenv(); + let server_address = env::var("OLLAMA_SERVER")?; + let ollama_client = OllamaClient::new(server_address); + let tools = vec![Tool { + tool_type: ToolType::Function, + function: Function { + name: "get_weather".to_string(), + description: "Get the current weather for a city.".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "city": { "type": "string", "description": "The name of the city" }, + }, + "required": ["city"], + }), + }, + }]; + + let mut messages = vec![Message::user("What is the weather in Paris?")]; + + loop { + let request = ChatRequest::builder(MODEL) + .messages(messages.clone()) + .stream(false) + .tools(tools.clone()) + .build(); + + let mut stream = ollama_client.chat(request); + let mut full_message = String::new(); + let Some(response) = stream.next().await else { + println!("No response from stream."); + return Ok(()); + }; + + let response = response?; + + if response.message.tool_calls.is_empty() { + full_message += &response.message.content; + print!("{}", response.message.content); + std::io::stdout().flush()?; + break; + } + + messages.push(response.message.clone()); + + let tool_call = &response.message.tool_calls[0]; + let arg: GetWeatherArgs = serde_json::from_value(tool_call.function.arguments.clone())?; + let result = get_weather(&arg.city); + messages.push(Message::tool_response(&result)); + } + Ok(()) +} diff --git a/src/lib.rs b/src/lib.rs index a4a3630..1ec799f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -65,7 +65,7 @@ impl OllamaClient { .await?) } - async fn stream_response( + fn stream_response( &self, endpoint: String, request: R, @@ -101,29 +101,23 @@ impl OllamaClient { } /// Generates a response for the provided prompt - pub async fn generate( + pub fn generate( &self, request: GenerateRequest, ) -> impl Stream> { let request_address = format!("{}/api/generate", self.server_address); - self.stream_response(request_address, request).await + self.stream_response(request_address, request) } /// Generate the next chat message in a conversation between a user and an assistant. - pub async fn chat( - &self, - request: ChatRequest, - ) -> impl Stream> { + pub fn chat(&self, request: ChatRequest) -> impl Stream> { let request_address = format!("{}/api/chat", self.server_address); - self.stream_response(request_address, request).await + self.stream_response(request_address, request) } /// Pull a model - pub async fn pull( - &self, - request: PullRequest, - ) -> impl Stream> { + pub fn pull(&self, request: PullRequest) -> impl Stream> { let request_address = format!("{}/api/pull", self.server_address); - self.stream_response(request_address, request).await + self.stream_response(request_address, request) } } diff --git a/src/types/chat.rs b/src/types/chat.rs index 76806e9..24e90ae 100644 --- a/src/types/chat.rs +++ b/src/types/chat.rs @@ -1,4 +1,5 @@ use serde::{Deserialize, Serialize}; +use serde_json::Value; use crate::types::common::Options; @@ -8,12 +9,42 @@ pub enum Role { User, System, Assistant, + Tool, } #[derive(Clone, Debug, Serialize, Deserialize)] pub struct Message { pub content: String, pub role: Role, + #[serde(skip_serializing_if = "Vec::is_empty")] + #[serde(default)] + pub tool_calls: Vec, +} + +impl Message { + pub fn system>(content: T) -> Self { + Self { + content: content.into(), + role: Role::System, + tool_calls: vec![], + } + } + + pub fn user>(content: T) -> Self { + Self { + content: content.into(), + role: Role::User, + tool_calls: vec![], + } + } + + pub fn tool_response(content: &Value) -> Self { + Message { + content: serde_json::to_string(content).unwrap(), + role: Role::Tool, + tool_calls: vec![], + } + } } #[derive(Debug, Serialize, Deserialize)] @@ -29,6 +60,9 @@ pub struct ChatRequest { /// Runtime options that control text generation #[serde(skip_serializing_if = "Option::is_none")] pub options: Option, + + #[serde(skip_serializing_if = "Vec::is_empty")] + pub tools: Vec, } impl ChatRequest { @@ -57,6 +91,7 @@ impl ChatRequestBuilder { messages: vec![], stream: None, options: None, + tools: vec![], }, } } @@ -71,7 +106,49 @@ impl ChatRequestBuilder { self } + pub fn tools(mut self, tools: Vec) -> Self { + self.chat_request.tools = tools; + self + } + + pub fn stream(mut self, stream: bool) -> Self { + self.chat_request.stream = Some(stream); + self + } + pub fn build(self) -> ChatRequest { self.chat_request } } + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Tool { + #[serde(rename = "type")] + pub tool_type: ToolType, + pub function: Function, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum ToolType { + Function, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Function { + pub name: String, + pub parameters: Value, + pub description: String, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ToolCall { + pub function: ToolCallFunction, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ToolCallFunction { + pub name: String, + pub arguments: Value, + pub index: usize, +}