Adds function calling

This commit is contained in:
2026-01-06 21:15:20 +00:00
parent ecedb1c054
commit 53353eabe0
6 changed files with 174 additions and 27 deletions

View File

@@ -7,32 +7,28 @@ use ollama_rs::{
types::chat::{ChatRequest, Message, Role}, types::chat::{ChatRequest, Message, Role},
}; };
const MODEL: &str = "dolphin3:8b"; const MODEL: &str = "functiongemma";
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> { async fn main() -> Result<(), Box<dyn Error>> {
let _ = dotenvy::dotenv(); let _ = dotenvy::dotenv();
let server_address = env::var("OLLAMA_SERVER")?; let server_address = env::var("OLLAMA_SERVER")?;
let ollama_client = OllamaClient::new(server_address); let ollama_client = OllamaClient::new(server_address);
let mut messages = vec![Message { let mut messages = vec![Message::system(
content: "You a role play character called Gerald. You are a dumb person who things knows a lot but PROVIDES WRONG ANSWERS to all questions.".to_string(), "You a role play character called Gerald. You are a dumb person who things knows a lot but PROVIDES WRONG ANSWERS to all questions.",
role: Role::System, )];
}];
loop { loop {
let user_input: String = Input::new().with_prompt(">").interact_text()?; let user_input: String = Input::new().with_prompt(">").interact_text()?;
if user_input == "/quit" { if user_input == "/quit" {
break; break;
} }
let message = Message { let message = Message::user(user_input);
content: user_input,
role: Role::User,
};
messages.push(message); messages.push(message);
let request = ChatRequest::builder(MODEL) let request = ChatRequest::builder(MODEL)
.messages(messages.clone()) .messages(messages.clone())
.build(); .build();
let mut stream = ollama_client.chat(request).await; let mut stream = ollama_client.chat(request);
let mut full_message = String::new(); let mut full_message = String::new();
while let Some(response) = stream.next().await { while let Some(response) = stream.next().await {
let response = response?; let response = response?;
@@ -48,6 +44,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
messages.push(Message { messages.push(Message {
content: full_message, content: full_message,
role: Role::Assistant, role: Role::Assistant,
tool_calls: vec![],
}); });
} }

View File

@@ -14,8 +14,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
.prompt("Why is the sky blue?") .prompt("Why is the sky blue?")
.build(); .build();
let mut stream = ollama_client.generate(request).await; let mut stream = ollama_client.generate(request);
while let Some(response) = stream.next().await { while let Some(response) = stream.next().await {
match response { match response {
Ok(token) => { Ok(token) => {

View File

@@ -3,7 +3,7 @@ use std::{env, error::Error, io::Write};
use futures_util::StreamExt; use futures_util::StreamExt;
use ollama_rs::{OllamaClient, types::pull::PullRequest}; use ollama_rs::{OllamaClient, types::pull::PullRequest};
const MODEL: &str = "HammerAI/mythomax-l2"; const MODEL: &str = "functiongemma";
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> { async fn main() -> Result<(), Box<dyn Error>> {
@@ -12,7 +12,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
let ollama_client = OllamaClient::new(server_address); let ollama_client = OllamaClient::new(server_address);
let request = PullRequest::builder(MODEL).stream(true).build(); let request = PullRequest::builder(MODEL).stream(true).build();
let mut stream = ollama_client.pull(request).await; let mut stream = ollama_client.pull(request);
while let Some(response) = stream.next().await { while let Some(response) = stream.next().await {
let response = response?; let response = response?;
println!("{:?}", response); println!("{:?}", response);

80
examples/tool_call.rs Normal file
View File

@@ -0,0 +1,80 @@
use std::{env, error::Error, io::Write};
use futures_util::StreamExt;
use ollama_rs::{
OllamaClient,
types::chat::{ChatRequest, Function, Message, Tool, ToolType},
};
use serde::Deserialize;
use serde_json::{Value, json};
const MODEL: &str = "functiongemma";
fn get_weather(city: &str) -> Value {
json!({
"city": city,
"temperature": 22.0,
"unit": "celsius",
"condition": "sunny",
})
}
#[derive(Deserialize)]
struct GetWeatherArgs {
city: String,
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
let _ = dotenvy::dotenv();
let server_address = env::var("OLLAMA_SERVER")?;
let ollama_client = OllamaClient::new(server_address);
let tools = vec![Tool {
tool_type: ToolType::Function,
function: Function {
name: "get_weather".to_string(),
description: "Get the current weather for a city.".to_string(),
parameters: json!({
"type": "object",
"properties": {
"city": { "type": "string", "description": "The name of the city" },
},
"required": ["city"],
}),
},
}];
let mut messages = vec![Message::user("What is the weather in Paris?")];
loop {
let request = ChatRequest::builder(MODEL)
.messages(messages.clone())
.stream(false)
.tools(tools.clone())
.build();
let mut stream = ollama_client.chat(request);
let mut full_message = String::new();
let Some(response) = stream.next().await else {
println!("No response from stream.");
return Ok(());
};
let response = response?;
if response.message.tool_calls.is_empty() {
full_message += &response.message.content;
print!("{}", response.message.content);
std::io::stdout().flush()?;
break;
}
messages.push(response.message.clone());
let tool_call = &response.message.tool_calls[0];
let arg: GetWeatherArgs = serde_json::from_value(tool_call.function.arguments.clone())?;
let result = get_weather(&arg.city);
messages.push(Message::tool_response(&result));
}
Ok(())
}

View File

@@ -65,7 +65,7 @@ impl OllamaClient {
.await?) .await?)
} }
async fn stream_response<R: Serialize, T: DeserializeOwned>( fn stream_response<R: Serialize, T: DeserializeOwned>(
&self, &self,
endpoint: String, endpoint: String,
request: R, request: R,
@@ -101,29 +101,23 @@ impl OllamaClient {
} }
/// Generates a response for the provided prompt /// Generates a response for the provided prompt
pub async fn generate( pub fn generate(
&self, &self,
request: GenerateRequest, request: GenerateRequest,
) -> impl Stream<Item = OllamaResult<GenerateResponse>> { ) -> impl Stream<Item = OllamaResult<GenerateResponse>> {
let request_address = format!("{}/api/generate", self.server_address); let request_address = format!("{}/api/generate", self.server_address);
self.stream_response(request_address, request).await self.stream_response(request_address, request)
} }
/// Generate the next chat message in a conversation between a user and an assistant. /// Generate the next chat message in a conversation between a user and an assistant.
pub async fn chat( pub fn chat(&self, request: ChatRequest) -> impl Stream<Item = OllamaResult<ChatResponse>> {
&self,
request: ChatRequest,
) -> impl Stream<Item = OllamaResult<ChatResponse>> {
let request_address = format!("{}/api/chat", self.server_address); let request_address = format!("{}/api/chat", self.server_address);
self.stream_response(request_address, request).await self.stream_response(request_address, request)
} }
/// Pull a model /// Pull a model
pub async fn pull( pub fn pull(&self, request: PullRequest) -> impl Stream<Item = OllamaResult<PullResponse>> {
&self,
request: PullRequest,
) -> impl Stream<Item = OllamaResult<PullResponse>> {
let request_address = format!("{}/api/pull", self.server_address); let request_address = format!("{}/api/pull", self.server_address);
self.stream_response(request_address, request).await self.stream_response(request_address, request)
} }
} }

View File

@@ -1,4 +1,5 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::types::common::Options; use crate::types::common::Options;
@@ -8,12 +9,42 @@ pub enum Role {
User, User,
System, System,
Assistant, Assistant,
Tool,
} }
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Message { pub struct Message {
pub content: String, pub content: String,
pub role: Role, pub role: Role,
#[serde(skip_serializing_if = "Vec::is_empty")]
#[serde(default)]
pub tool_calls: Vec<ToolCall>,
}
impl Message {
pub fn system<T: Into<String>>(content: T) -> Self {
Self {
content: content.into(),
role: Role::System,
tool_calls: vec![],
}
}
pub fn user<T: Into<String>>(content: T) -> Self {
Self {
content: content.into(),
role: Role::User,
tool_calls: vec![],
}
}
pub fn tool_response(content: &Value) -> Self {
Message {
content: serde_json::to_string(content).unwrap(),
role: Role::Tool,
tool_calls: vec![],
}
}
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
@@ -29,6 +60,9 @@ pub struct ChatRequest {
/// Runtime options that control text generation /// Runtime options that control text generation
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub options: Option<Options>, pub options: Option<Options>,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub tools: Vec<Tool>,
} }
impl ChatRequest { impl ChatRequest {
@@ -57,6 +91,7 @@ impl ChatRequestBuilder {
messages: vec![], messages: vec![],
stream: None, stream: None,
options: None, options: None,
tools: vec![],
}, },
} }
} }
@@ -71,7 +106,49 @@ impl ChatRequestBuilder {
self self
} }
pub fn tools(mut self, tools: Vec<Tool>) -> Self {
self.chat_request.tools = tools;
self
}
pub fn stream(mut self, stream: bool) -> Self {
self.chat_request.stream = Some(stream);
self
}
pub fn build(self) -> ChatRequest { pub fn build(self) -> ChatRequest {
self.chat_request self.chat_request
} }
} }
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Tool {
#[serde(rename = "type")]
pub tool_type: ToolType,
pub function: Function,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
Function,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Function {
pub name: String,
pub parameters: Value,
pub description: String,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ToolCall {
pub function: ToolCallFunction,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ToolCallFunction {
pub name: String,
pub arguments: Value,
pub index: usize,
}