diff --git a/README.md b/README.md index d7d965f..57b3233 100644 --- a/README.md +++ b/README.md @@ -143,7 +143,7 @@ let request = ChatRequest::builder("llama3:8b") .build(); ``` -When the model decides to call a tool, the response `message.tool_calls` field will contain the tool name and arguments. You can then execute the function and send the result back via `Message::tool_response(...)`. +When the model decides to call a tool, the response `message.tool_calls` field will contain the tool name and arguments. You can then execute the function and send the result back via `Message::tool_response(...)` which returns an `OllamaResult`. ## API Reference @@ -151,7 +151,9 @@ When the model decides to call a tool, the response `message.tool_calls` field w | Method | Description | |--------|-------------| -| `new(server_address)` | Create a new client pointing at an Ollama server | +| `new(server_address)` | Create a new client with a 30-second connection timeout | +| `default()` | Create a client connecting to `http://localhost:11434` | +| `builder(server_address)` | Create a client with custom configuration (see below) | | `version()` | Get the Ollama server version | | `tags()` | List all available models | | `ps()` | List currently running/loaded models | @@ -159,11 +161,22 @@ When the model decides to call a tool, the response `message.tool_calls` field w | `chat(request)` | Chat conversation (streaming) | | `pull(request)` | Pull/download a model (streaming) | +**`OllamaClient::builder(server_address)`** -- `.connection_timeout(Duration)`, `.build()` + +```rust +use std::time::Duration; +use ollama_rs::OllamaClient; + +let client = OllamaClient::builder("http://localhost:11434") + .connection_timeout(Duration::from_secs(60)) + .build(); +``` + ### Request Builders **`GenerateRequest::builder(model)`** -- `.prompt()`, `.system_prompt()`, `.format()`, `.options()`, `.stream()`, `.think()`, `.images()`, `.suffix()` -**`ChatRequest::builder(model)`** -- `.messages()`, `.tools()`, `.format()`, `.options()`, `.stream()` +**`ChatRequest::builder(model)`** -- `.messages()`, `.tools()`, `.format()`, `.options()`, `.stream()`, `.think()` **`PullRequest::builder(model)`** -- `.stream()` diff --git a/src/lib.rs b/src/lib.rs index 391cce0..69a73b5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,5 @@ +use std::time::Duration; + use async_stream::stream; use futures_util::{Stream, StreamExt}; use serde::{Serialize, de::DeserializeOwned}; @@ -22,6 +24,8 @@ use crate::{ pub mod error; pub mod types; +const DEFAULT_CONNECTION_TIMEOUT: Duration = Duration::from_secs(30); + #[derive(Clone)] pub struct OllamaClient { server_address: String, @@ -38,7 +42,17 @@ impl OllamaClient { pub fn new>(server_address: S) -> Self { Self { server_address: server_address.as_ref().to_string(), - client: reqwest::Client::new(), + client: reqwest::Client::builder() + .connect_timeout(DEFAULT_CONNECTION_TIMEOUT) + .build() + .expect("failed to build reqwest client"), + } + } + + pub fn builder>(server_address: S) -> OllamaClientBuilder { + OllamaClientBuilder { + server_address: server_address.as_ref().to_string(), + connection_timeout: DEFAULT_CONNECTION_TIMEOUT, } } @@ -142,3 +156,56 @@ impl OllamaClient { self.stream_response(request_address, request) } } + +pub struct OllamaClientBuilder { + server_address: String, + connection_timeout: Duration, +} + +impl OllamaClientBuilder { + pub fn connection_timeout(mut self, timeout: Duration) -> Self { + self.connection_timeout = timeout; + self + } + + pub fn build(self) -> OllamaClient { + OllamaClient { + server_address: self.server_address, + client: reqwest::Client::builder() + .connect_timeout(self.connection_timeout) + .build() + .expect("failed to build reqwest client"), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn new_creates_client() { + let client = OllamaClient::new("http://localhost:11434"); + assert_eq!(client.server_address, "http://localhost:11434"); + } + + #[test] + fn default_creates_localhost_client() { + let client = OllamaClient::default(); + assert_eq!(client.server_address, "http://localhost:11434"); + } + + #[test] + fn builder_creates_client() { + let client = OllamaClient::builder("http://myserver:11434").build(); + assert_eq!(client.server_address, "http://myserver:11434"); + } + + #[test] + fn builder_with_custom_timeout() { + let client = OllamaClient::builder("http://localhost:11434") + .connection_timeout(Duration::from_secs(60)) + .build(); + assert_eq!(client.server_address, "http://localhost:11434"); + } +}