Adds function calling
This commit is contained in:
@@ -7,32 +7,28 @@ use ollama_rs::{
|
||||
types::chat::{ChatRequest, Message, Role},
|
||||
};
|
||||
|
||||
const MODEL: &str = "dolphin3:8b";
|
||||
const MODEL: &str = "functiongemma";
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn Error>> {
|
||||
let _ = dotenvy::dotenv();
|
||||
let server_address = env::var("OLLAMA_SERVER")?;
|
||||
let ollama_client = OllamaClient::new(server_address);
|
||||
let mut messages = vec![Message {
|
||||
content: "You a role play character called Gerald. You are a dumb person who things knows a lot but PROVIDES WRONG ANSWERS to all questions.".to_string(),
|
||||
role: Role::System,
|
||||
}];
|
||||
let mut messages = vec![Message::system(
|
||||
"You a role play character called Gerald. You are a dumb person who things knows a lot but PROVIDES WRONG ANSWERS to all questions.",
|
||||
)];
|
||||
|
||||
loop {
|
||||
let user_input: String = Input::new().with_prompt(">").interact_text()?;
|
||||
if user_input == "/quit" {
|
||||
break;
|
||||
}
|
||||
let message = Message {
|
||||
content: user_input,
|
||||
role: Role::User,
|
||||
};
|
||||
let message = Message::user(user_input);
|
||||
messages.push(message);
|
||||
let request = ChatRequest::builder(MODEL)
|
||||
.messages(messages.clone())
|
||||
.build();
|
||||
|
||||
let mut stream = ollama_client.chat(request).await;
|
||||
let mut stream = ollama_client.chat(request);
|
||||
let mut full_message = String::new();
|
||||
while let Some(response) = stream.next().await {
|
||||
let response = response?;
|
||||
@@ -48,6 +44,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
|
||||
messages.push(Message {
|
||||
content: full_message,
|
||||
role: Role::Assistant,
|
||||
tool_calls: vec![],
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -14,8 +14,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
|
||||
.prompt("Why is the sky blue?")
|
||||
.build();
|
||||
|
||||
let mut stream = ollama_client.generate(request).await;
|
||||
|
||||
let mut stream = ollama_client.generate(request);
|
||||
while let Some(response) = stream.next().await {
|
||||
match response {
|
||||
Ok(token) => {
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::{env, error::Error, io::Write};
|
||||
use futures_util::StreamExt;
|
||||
use ollama_rs::{OllamaClient, types::pull::PullRequest};
|
||||
|
||||
const MODEL: &str = "HammerAI/mythomax-l2";
|
||||
const MODEL: &str = "functiongemma";
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn Error>> {
|
||||
@@ -12,7 +12,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
|
||||
let ollama_client = OllamaClient::new(server_address);
|
||||
|
||||
let request = PullRequest::builder(MODEL).stream(true).build();
|
||||
let mut stream = ollama_client.pull(request).await;
|
||||
let mut stream = ollama_client.pull(request);
|
||||
while let Some(response) = stream.next().await {
|
||||
let response = response?;
|
||||
println!("{:?}", response);
|
||||
|
||||
80
examples/tool_call.rs
Normal file
80
examples/tool_call.rs
Normal file
@@ -0,0 +1,80 @@
|
||||
use std::{env, error::Error, io::Write};
|
||||
|
||||
use futures_util::StreamExt;
|
||||
use ollama_rs::{
|
||||
OllamaClient,
|
||||
types::chat::{ChatRequest, Function, Message, Tool, ToolType},
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use serde_json::{Value, json};
|
||||
|
||||
const MODEL: &str = "functiongemma";
|
||||
|
||||
fn get_weather(city: &str) -> Value {
|
||||
json!({
|
||||
"city": city,
|
||||
"temperature": 22.0,
|
||||
"unit": "celsius",
|
||||
"condition": "sunny",
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct GetWeatherArgs {
|
||||
city: String,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn Error>> {
|
||||
let _ = dotenvy::dotenv();
|
||||
let server_address = env::var("OLLAMA_SERVER")?;
|
||||
let ollama_client = OllamaClient::new(server_address);
|
||||
let tools = vec![Tool {
|
||||
tool_type: ToolType::Function,
|
||||
function: Function {
|
||||
name: "get_weather".to_string(),
|
||||
description: "Get the current weather for a city.".to_string(),
|
||||
parameters: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": { "type": "string", "description": "The name of the city" },
|
||||
},
|
||||
"required": ["city"],
|
||||
}),
|
||||
},
|
||||
}];
|
||||
|
||||
let mut messages = vec![Message::user("What is the weather in Paris?")];
|
||||
|
||||
loop {
|
||||
let request = ChatRequest::builder(MODEL)
|
||||
.messages(messages.clone())
|
||||
.stream(false)
|
||||
.tools(tools.clone())
|
||||
.build();
|
||||
|
||||
let mut stream = ollama_client.chat(request);
|
||||
let mut full_message = String::new();
|
||||
let Some(response) = stream.next().await else {
|
||||
println!("No response from stream.");
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let response = response?;
|
||||
|
||||
if response.message.tool_calls.is_empty() {
|
||||
full_message += &response.message.content;
|
||||
print!("{}", response.message.content);
|
||||
std::io::stdout().flush()?;
|
||||
break;
|
||||
}
|
||||
|
||||
messages.push(response.message.clone());
|
||||
|
||||
let tool_call = &response.message.tool_calls[0];
|
||||
let arg: GetWeatherArgs = serde_json::from_value(tool_call.function.arguments.clone())?;
|
||||
let result = get_weather(&arg.city);
|
||||
messages.push(Message::tool_response(&result));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
Reference in New Issue
Block a user