Adds function calling
This commit is contained in:
@@ -7,32 +7,28 @@ use ollama_rs::{
|
|||||||
types::chat::{ChatRequest, Message, Role},
|
types::chat::{ChatRequest, Message, Role},
|
||||||
};
|
};
|
||||||
|
|
||||||
const MODEL: &str = "dolphin3:8b";
|
const MODEL: &str = "functiongemma";
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> Result<(), Box<dyn Error>> {
|
async fn main() -> Result<(), Box<dyn Error>> {
|
||||||
let _ = dotenvy::dotenv();
|
let _ = dotenvy::dotenv();
|
||||||
let server_address = env::var("OLLAMA_SERVER")?;
|
let server_address = env::var("OLLAMA_SERVER")?;
|
||||||
let ollama_client = OllamaClient::new(server_address);
|
let ollama_client = OllamaClient::new(server_address);
|
||||||
let mut messages = vec![Message {
|
let mut messages = vec![Message::system(
|
||||||
content: "You a role play character called Gerald. You are a dumb person who things knows a lot but PROVIDES WRONG ANSWERS to all questions.".to_string(),
|
"You a role play character called Gerald. You are a dumb person who things knows a lot but PROVIDES WRONG ANSWERS to all questions.",
|
||||||
role: Role::System,
|
)];
|
||||||
}];
|
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
let user_input: String = Input::new().with_prompt(">").interact_text()?;
|
let user_input: String = Input::new().with_prompt(">").interact_text()?;
|
||||||
if user_input == "/quit" {
|
if user_input == "/quit" {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
let message = Message {
|
let message = Message::user(user_input);
|
||||||
content: user_input,
|
|
||||||
role: Role::User,
|
|
||||||
};
|
|
||||||
messages.push(message);
|
messages.push(message);
|
||||||
let request = ChatRequest::builder(MODEL)
|
let request = ChatRequest::builder(MODEL)
|
||||||
.messages(messages.clone())
|
.messages(messages.clone())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
let mut stream = ollama_client.chat(request).await;
|
let mut stream = ollama_client.chat(request);
|
||||||
let mut full_message = String::new();
|
let mut full_message = String::new();
|
||||||
while let Some(response) = stream.next().await {
|
while let Some(response) = stream.next().await {
|
||||||
let response = response?;
|
let response = response?;
|
||||||
@@ -48,6 +44,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
|
|||||||
messages.push(Message {
|
messages.push(Message {
|
||||||
content: full_message,
|
content: full_message,
|
||||||
role: Role::Assistant,
|
role: Role::Assistant,
|
||||||
|
tool_calls: vec![],
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -14,8 +14,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
|
|||||||
.prompt("Why is the sky blue?")
|
.prompt("Why is the sky blue?")
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
let mut stream = ollama_client.generate(request).await;
|
let mut stream = ollama_client.generate(request);
|
||||||
|
|
||||||
while let Some(response) = stream.next().await {
|
while let Some(response) = stream.next().await {
|
||||||
match response {
|
match response {
|
||||||
Ok(token) => {
|
Ok(token) => {
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ use std::{env, error::Error, io::Write};
|
|||||||
use futures_util::StreamExt;
|
use futures_util::StreamExt;
|
||||||
use ollama_rs::{OllamaClient, types::pull::PullRequest};
|
use ollama_rs::{OllamaClient, types::pull::PullRequest};
|
||||||
|
|
||||||
const MODEL: &str = "HammerAI/mythomax-l2";
|
const MODEL: &str = "functiongemma";
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> Result<(), Box<dyn Error>> {
|
async fn main() -> Result<(), Box<dyn Error>> {
|
||||||
@@ -12,7 +12,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
|
|||||||
let ollama_client = OllamaClient::new(server_address);
|
let ollama_client = OllamaClient::new(server_address);
|
||||||
|
|
||||||
let request = PullRequest::builder(MODEL).stream(true).build();
|
let request = PullRequest::builder(MODEL).stream(true).build();
|
||||||
let mut stream = ollama_client.pull(request).await;
|
let mut stream = ollama_client.pull(request);
|
||||||
while let Some(response) = stream.next().await {
|
while let Some(response) = stream.next().await {
|
||||||
let response = response?;
|
let response = response?;
|
||||||
println!("{:?}", response);
|
println!("{:?}", response);
|
||||||
|
|||||||
80
examples/tool_call.rs
Normal file
80
examples/tool_call.rs
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
use std::{env, error::Error, io::Write};
|
||||||
|
|
||||||
|
use futures_util::StreamExt;
|
||||||
|
use ollama_rs::{
|
||||||
|
OllamaClient,
|
||||||
|
types::chat::{ChatRequest, Function, Message, Tool, ToolType},
|
||||||
|
};
|
||||||
|
use serde::Deserialize;
|
||||||
|
use serde_json::{Value, json};
|
||||||
|
|
||||||
|
const MODEL: &str = "functiongemma";
|
||||||
|
|
||||||
|
fn get_weather(city: &str) -> Value {
|
||||||
|
json!({
|
||||||
|
"city": city,
|
||||||
|
"temperature": 22.0,
|
||||||
|
"unit": "celsius",
|
||||||
|
"condition": "sunny",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct GetWeatherArgs {
|
||||||
|
city: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<(), Box<dyn Error>> {
|
||||||
|
let _ = dotenvy::dotenv();
|
||||||
|
let server_address = env::var("OLLAMA_SERVER")?;
|
||||||
|
let ollama_client = OllamaClient::new(server_address);
|
||||||
|
let tools = vec![Tool {
|
||||||
|
tool_type: ToolType::Function,
|
||||||
|
function: Function {
|
||||||
|
name: "get_weather".to_string(),
|
||||||
|
description: "Get the current weather for a city.".to_string(),
|
||||||
|
parameters: json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"city": { "type": "string", "description": "The name of the city" },
|
||||||
|
},
|
||||||
|
"required": ["city"],
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
}];
|
||||||
|
|
||||||
|
let mut messages = vec![Message::user("What is the weather in Paris?")];
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let request = ChatRequest::builder(MODEL)
|
||||||
|
.messages(messages.clone())
|
||||||
|
.stream(false)
|
||||||
|
.tools(tools.clone())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
let mut stream = ollama_client.chat(request);
|
||||||
|
let mut full_message = String::new();
|
||||||
|
let Some(response) = stream.next().await else {
|
||||||
|
println!("No response from stream.");
|
||||||
|
return Ok(());
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = response?;
|
||||||
|
|
||||||
|
if response.message.tool_calls.is_empty() {
|
||||||
|
full_message += &response.message.content;
|
||||||
|
print!("{}", response.message.content);
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
messages.push(response.message.clone());
|
||||||
|
|
||||||
|
let tool_call = &response.message.tool_calls[0];
|
||||||
|
let arg: GetWeatherArgs = serde_json::from_value(tool_call.function.arguments.clone())?;
|
||||||
|
let result = get_weather(&arg.city);
|
||||||
|
messages.push(Message::tool_response(&result));
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
20
src/lib.rs
20
src/lib.rs
@@ -65,7 +65,7 @@ impl OllamaClient {
|
|||||||
.await?)
|
.await?)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn stream_response<R: Serialize, T: DeserializeOwned>(
|
fn stream_response<R: Serialize, T: DeserializeOwned>(
|
||||||
&self,
|
&self,
|
||||||
endpoint: String,
|
endpoint: String,
|
||||||
request: R,
|
request: R,
|
||||||
@@ -101,29 +101,23 @@ impl OllamaClient {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Generates a response for the provided prompt
|
/// Generates a response for the provided prompt
|
||||||
pub async fn generate(
|
pub fn generate(
|
||||||
&self,
|
&self,
|
||||||
request: GenerateRequest,
|
request: GenerateRequest,
|
||||||
) -> impl Stream<Item = OllamaResult<GenerateResponse>> {
|
) -> impl Stream<Item = OllamaResult<GenerateResponse>> {
|
||||||
let request_address = format!("{}/api/generate", self.server_address);
|
let request_address = format!("{}/api/generate", self.server_address);
|
||||||
self.stream_response(request_address, request).await
|
self.stream_response(request_address, request)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generate the next chat message in a conversation between a user and an assistant.
|
/// Generate the next chat message in a conversation between a user and an assistant.
|
||||||
pub async fn chat(
|
pub fn chat(&self, request: ChatRequest) -> impl Stream<Item = OllamaResult<ChatResponse>> {
|
||||||
&self,
|
|
||||||
request: ChatRequest,
|
|
||||||
) -> impl Stream<Item = OllamaResult<ChatResponse>> {
|
|
||||||
let request_address = format!("{}/api/chat", self.server_address);
|
let request_address = format!("{}/api/chat", self.server_address);
|
||||||
self.stream_response(request_address, request).await
|
self.stream_response(request_address, request)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Pull a model
|
/// Pull a model
|
||||||
pub async fn pull(
|
pub fn pull(&self, request: PullRequest) -> impl Stream<Item = OllamaResult<PullResponse>> {
|
||||||
&self,
|
|
||||||
request: PullRequest,
|
|
||||||
) -> impl Stream<Item = OllamaResult<PullResponse>> {
|
|
||||||
let request_address = format!("{}/api/pull", self.server_address);
|
let request_address = format!("{}/api/pull", self.server_address);
|
||||||
self.stream_response(request_address, request).await
|
self.stream_response(request_address, request)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
use crate::types::common::Options;
|
use crate::types::common::Options;
|
||||||
|
|
||||||
@@ -8,12 +9,42 @@ pub enum Role {
|
|||||||
User,
|
User,
|
||||||
System,
|
System,
|
||||||
Assistant,
|
Assistant,
|
||||||
|
Tool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
pub struct Message {
|
pub struct Message {
|
||||||
pub content: String,
|
pub content: String,
|
||||||
pub role: Role,
|
pub role: Role,
|
||||||
|
#[serde(skip_serializing_if = "Vec::is_empty")]
|
||||||
|
#[serde(default)]
|
||||||
|
pub tool_calls: Vec<ToolCall>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Message {
|
||||||
|
pub fn system<T: Into<String>>(content: T) -> Self {
|
||||||
|
Self {
|
||||||
|
content: content.into(),
|
||||||
|
role: Role::System,
|
||||||
|
tool_calls: vec![],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn user<T: Into<String>>(content: T) -> Self {
|
||||||
|
Self {
|
||||||
|
content: content.into(),
|
||||||
|
role: Role::User,
|
||||||
|
tool_calls: vec![],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn tool_response(content: &Value) -> Self {
|
||||||
|
Message {
|
||||||
|
content: serde_json::to_string(content).unwrap(),
|
||||||
|
role: Role::Tool,
|
||||||
|
tool_calls: vec![],
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
@@ -29,6 +60,9 @@ pub struct ChatRequest {
|
|||||||
/// Runtime options that control text generation
|
/// Runtime options that control text generation
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub options: Option<Options>,
|
pub options: Option<Options>,
|
||||||
|
|
||||||
|
#[serde(skip_serializing_if = "Vec::is_empty")]
|
||||||
|
pub tools: Vec<Tool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ChatRequest {
|
impl ChatRequest {
|
||||||
@@ -57,6 +91,7 @@ impl ChatRequestBuilder {
|
|||||||
messages: vec![],
|
messages: vec![],
|
||||||
stream: None,
|
stream: None,
|
||||||
options: None,
|
options: None,
|
||||||
|
tools: vec![],
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -71,7 +106,49 @@ impl ChatRequestBuilder {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn tools(mut self, tools: Vec<Tool>) -> Self {
|
||||||
|
self.chat_request.tools = tools;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn stream(mut self, stream: bool) -> Self {
|
||||||
|
self.chat_request.stream = Some(stream);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
pub fn build(self) -> ChatRequest {
|
pub fn build(self) -> ChatRequest {
|
||||||
self.chat_request
|
self.chat_request
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
pub struct Tool {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub tool_type: ToolType,
|
||||||
|
pub function: Function,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum ToolType {
|
||||||
|
Function,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
pub struct Function {
|
||||||
|
pub name: String,
|
||||||
|
pub parameters: Value,
|
||||||
|
pub description: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
pub struct ToolCall {
|
||||||
|
pub function: ToolCallFunction,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
pub struct ToolCallFunction {
|
||||||
|
pub name: String,
|
||||||
|
pub arguments: Value,
|
||||||
|
pub index: usize,
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user