Adds function calling
This commit is contained in:
20
src/lib.rs
20
src/lib.rs
@@ -65,7 +65,7 @@ impl OllamaClient {
|
||||
.await?)
|
||||
}
|
||||
|
||||
async fn stream_response<R: Serialize, T: DeserializeOwned>(
|
||||
fn stream_response<R: Serialize, T: DeserializeOwned>(
|
||||
&self,
|
||||
endpoint: String,
|
||||
request: R,
|
||||
@@ -101,29 +101,23 @@ impl OllamaClient {
|
||||
}
|
||||
|
||||
/// Generates a response for the provided prompt
|
||||
pub async fn generate(
|
||||
pub fn generate(
|
||||
&self,
|
||||
request: GenerateRequest,
|
||||
) -> impl Stream<Item = OllamaResult<GenerateResponse>> {
|
||||
let request_address = format!("{}/api/generate", self.server_address);
|
||||
self.stream_response(request_address, request).await
|
||||
self.stream_response(request_address, request)
|
||||
}
|
||||
|
||||
/// Generate the next chat message in a conversation between a user and an assistant.
|
||||
pub async fn chat(
|
||||
&self,
|
||||
request: ChatRequest,
|
||||
) -> impl Stream<Item = OllamaResult<ChatResponse>> {
|
||||
pub fn chat(&self, request: ChatRequest) -> impl Stream<Item = OllamaResult<ChatResponse>> {
|
||||
let request_address = format!("{}/api/chat", self.server_address);
|
||||
self.stream_response(request_address, request).await
|
||||
self.stream_response(request_address, request)
|
||||
}
|
||||
|
||||
/// Pull a model
|
||||
pub async fn pull(
|
||||
&self,
|
||||
request: PullRequest,
|
||||
) -> impl Stream<Item = OllamaResult<PullResponse>> {
|
||||
pub fn pull(&self, request: PullRequest) -> impl Stream<Item = OllamaResult<PullResponse>> {
|
||||
let request_address = format!("{}/api/pull", self.server_address);
|
||||
self.stream_response(request_address, request).await
|
||||
self.stream_response(request_address, request)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::types::common::Options;
|
||||
|
||||
@@ -8,12 +9,42 @@ pub enum Role {
|
||||
User,
|
||||
System,
|
||||
Assistant,
|
||||
Tool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Message {
|
||||
pub content: String,
|
||||
pub role: Role,
|
||||
#[serde(skip_serializing_if = "Vec::is_empty")]
|
||||
#[serde(default)]
|
||||
pub tool_calls: Vec<ToolCall>,
|
||||
}
|
||||
|
||||
impl Message {
|
||||
pub fn system<T: Into<String>>(content: T) -> Self {
|
||||
Self {
|
||||
content: content.into(),
|
||||
role: Role::System,
|
||||
tool_calls: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn user<T: Into<String>>(content: T) -> Self {
|
||||
Self {
|
||||
content: content.into(),
|
||||
role: Role::User,
|
||||
tool_calls: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tool_response(content: &Value) -> Self {
|
||||
Message {
|
||||
content: serde_json::to_string(content).unwrap(),
|
||||
role: Role::Tool,
|
||||
tool_calls: vec![],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
@@ -29,6 +60,9 @@ pub struct ChatRequest {
|
||||
/// Runtime options that control text generation
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub options: Option<Options>,
|
||||
|
||||
#[serde(skip_serializing_if = "Vec::is_empty")]
|
||||
pub tools: Vec<Tool>,
|
||||
}
|
||||
|
||||
impl ChatRequest {
|
||||
@@ -57,6 +91,7 @@ impl ChatRequestBuilder {
|
||||
messages: vec![],
|
||||
stream: None,
|
||||
options: None,
|
||||
tools: vec![],
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -71,7 +106,49 @@ impl ChatRequestBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn tools(mut self, tools: Vec<Tool>) -> Self {
|
||||
self.chat_request.tools = tools;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn stream(mut self, stream: bool) -> Self {
|
||||
self.chat_request.stream = Some(stream);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> ChatRequest {
|
||||
self.chat_request
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Tool {
|
||||
#[serde(rename = "type")]
|
||||
pub tool_type: ToolType,
|
||||
pub function: Function,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ToolType {
|
||||
Function,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Function {
|
||||
pub name: String,
|
||||
pub parameters: Value,
|
||||
pub description: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ToolCall {
|
||||
pub function: ToolCallFunction,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ToolCallFunction {
|
||||
pub name: String,
|
||||
pub arguments: Value,
|
||||
pub index: usize,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user