From 0f796f1a2f54b1e3df410774deda4f7414abc14b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Cipriani=20Bandarra?= Date: Sun, 1 Feb 2026 21:26:44 +0000 Subject: [PATCH] Add embed endpoint (POST /api/embed) Implement the Ollama POST /api/embed endpoint for generating vector embeddings from text input. - Add EmbedInput, EmbedRequest, EmbedResponse types in src/types/embed.rs - Add OllamaClient::embed() async method in src/lib.rs - Register embed module in src/types/mod.rs - Add usage example in examples/embed.rs - Update README with embed endpoint documentation --- README.md | 5 + examples/embed.rs | 24 ++++ src/lib.rs | 42 +++++++ src/types/embed.rs | 306 +++++++++++++++++++++++++++++++++++++++++++++ src/types/mod.rs | 2 + 5 files changed, 379 insertions(+) create mode 100644 examples/embed.rs create mode 100644 src/types/embed.rs diff --git a/README.md b/README.md index f1b8fdf..ba1f5c5 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,7 @@ An async Rust client library for the [Ollama](https://ollama.com/) API. Provides - Structured JSON output with schema validation - Tool calling / function calling support - Model management (list, pull, delete, inspect running models) +- Text embeddings generation - Builder pattern for constructing requests - Configurable generation parameters (temperature, top-k, top-p, and more) - Thinking / reasoning mode support @@ -161,6 +162,7 @@ When the model decides to call a tool, the response `message.tool_calls` field w | `chat(request)` | Chat conversation (streaming) | | `pull(request)` | Pull/download a model (streaming) | | `delete(request)` | Delete a model from the server | +| `embed(request)` | Generate vector embeddings | **`OllamaClient::builder(server_address)`** -- `.connection_timeout(Duration)`, `.build()` @@ -181,6 +183,8 @@ let client = OllamaClient::builder("http://localhost:11434") **`PullRequest::builder(model)`** -- `.stream()` +**`EmbedRequest::builder(model)`** -- `.input()`, `.inputs()`, `.truncate()`, `.dimensions()`, `.keep_alive()`, `.options()` + ### Generation Options Configure sampling parameters via `Options::builder()`: @@ -208,6 +212,7 @@ The `examples/` directory contains runnable programs: | `tool_call` | Function calling / tool use | | `pull` | Download a model | | `delete` | Delete a model | +| `embed` | Generate text embeddings | | `tags` | List available models | | `ps` | List running models | | `version` | Query server version | diff --git a/examples/embed.rs b/examples/embed.rs new file mode 100644 index 0000000..1d7b8cc --- /dev/null +++ b/examples/embed.rs @@ -0,0 +1,24 @@ +use std::{env, error::Error}; + +use ollama_rs::OllamaClient; +use ollama_rs::types::embed::EmbedRequest; + +#[tokio::main] +async fn main() -> Result<(), Box> { + tracing_subscriber::fmt().init(); + let _ = dotenvy::dotenv(); + let server_address = env::var("OLLAMA_SERVER")?; + let ollama_client = OllamaClient::new(server_address); + let request = EmbedRequest::builder("embeddinggemma") + .input("Generate embeddings for this text") + .build(); + + let response = ollama_client.embed(request).await?; + for (i, embedding) in response.embeddings.iter().enumerate() { + println!("Embedding {}: {} dimensions", i, embedding.len()); + if embedding.len() >= 3 { + println!(" First 3 values: {:?}", &embedding[..3]); + } + } + Ok(()) +} diff --git a/src/lib.rs b/src/lib.rs index 349aad9..bcbe680 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -94,6 +94,7 @@ use crate::{ types::{ chat::{ChatRequest, ChatResponse}, delete::DeleteRequest, + embed::{EmbedRequest, EmbedResponse}, generate::{GenerateRequest, GenerateResponse}, ps::PsResponse, pull::{PullRequest, PullResponse}, @@ -282,6 +283,47 @@ impl OllamaClient { Ok(()) } + /// Generates vector embeddings for the given input text(s). + /// + /// Calls `POST /api/embed`. + /// + /// # Errors + /// + /// Returns [`OllamaError::NetworkError`] if the server is unreachable or returns + /// a non-success status code. + /// + /// # Examples + /// + /// ```no_run + /// # use ollama_rs::OllamaClient; + /// # use ollama_rs::types::embed::EmbedRequest; + /// # async fn run() -> ollama_rs::error::OllamaResult<()> { + /// let client = OllamaClient::default(); + /// let request = EmbedRequest::builder("embeddinggemma") + /// .input("Generate embeddings for this text") + /// .build(); + /// + /// let response = client.embed(request).await?; + /// for embedding in &response.embeddings { + /// println!("Dimension count: {}", embedding.len()); + /// } + /// # Ok(()) + /// # } + /// ``` + pub async fn embed(&self, request: EmbedRequest) -> OllamaResult { + let request_address = format!("{}/api/embed", self.server_address); + info!("Generate embeddings: {}", request.model); + Ok(self + .client + .post(request_address) + .json(&request) + .send() + .await? + .error_for_status()? + .json() + .await?) + } + fn stream_response( &self, endpoint: String, diff --git a/src/types/embed.rs b/src/types/embed.rs new file mode 100644 index 0000000..e6bdc4d --- /dev/null +++ b/src/types/embed.rs @@ -0,0 +1,306 @@ +//! Types for the embedding endpoint (`POST /api/embed`). +//! +//! Use [`EmbedRequest::builder()`] to construct a request and pass it to +//! [`OllamaClient::embed()`](crate::OllamaClient::embed). +//! +//! # Examples +//! +//! ```no_run +//! # use ollama_rs::OllamaClient; +//! # use ollama_rs::types::embed::EmbedRequest; +//! # async fn run() -> ollama_rs::error::OllamaResult<()> { +//! let client = OllamaClient::default(); +//! let request = EmbedRequest::builder("embeddinggemma") +//! .input("Generate embeddings for this text") +//! .build(); +//! +//! let response = client.embed(request).await?; +//! println!("Embeddings: {:?}", response.embeddings); +//! # Ok(()) +//! # } +//! ``` + +use serde::{Deserialize, Serialize}; + +use crate::types::common::Options; + +/// The input text(s) to generate embeddings for. +/// +/// Accepts either a single string or an array of strings. Serialized as an +/// untagged enum so both `"hello"` and `["hello", "world"]` are valid JSON +/// representations. +/// +/// # Examples +/// +/// ``` +/// use ollama_rs::types::embed::EmbedInput; +/// +/// let single = EmbedInput::Single("hello".to_string()); +/// let multiple = EmbedInput::Multiple(vec!["hello".to_string(), "world".to_string()]); +/// ``` +#[derive(Debug, Serialize, Deserialize)] +#[serde(untagged)] +pub enum EmbedInput { + /// A single text string. + Single(String), + /// Multiple text strings. + Multiple(Vec), +} + +/// A request to generate embeddings (`POST /api/embed`). +/// +/// Construct via [`EmbedRequest::builder()`]. +/// +/// # Examples +/// +/// ``` +/// use ollama_rs::types::embed::EmbedRequest; +/// +/// let request = EmbedRequest::builder("embeddinggemma") +/// .input("Generate embeddings for this text") +/// .build(); +/// ``` +#[derive(Debug, Serialize, Deserialize)] +pub struct EmbedRequest { + /// The model name to use for generating embeddings. + pub model: String, + + /// The text or array of texts to generate embeddings for. + pub input: EmbedInput, + + /// If `true`, truncate inputs that exceed the context window. If `false`, + /// returns an error. Defaults to `true` on the server. + #[serde(skip_serializing_if = "Option::is_none")] + pub truncate: Option, + + /// Number of dimensions to generate embeddings for. + #[serde(skip_serializing_if = "Option::is_none")] + pub dimensions: Option, + + /// How long the model stays loaded in memory (e.g., `"5m"`, `"1h"`). + #[serde(skip_serializing_if = "Option::is_none")] + pub keep_alive: Option, + + /// Runtime options for the embedding model. + #[serde(skip_serializing_if = "Option::is_none")] + pub options: Option, +} + +impl EmbedRequest { + /// Returns an [`EmbedRequestBuilder`] for the given model name. + pub fn builder>(model: M) -> EmbedRequestBuilder { + EmbedRequestBuilder { + embed_request: EmbedRequest { + model: model.into(), + input: EmbedInput::Single(String::new()), + truncate: None, + dimensions: None, + keep_alive: None, + options: None, + }, + } + } +} + +/// A builder for constructing an [`EmbedRequest`]. +/// +/// Obtain a builder via [`EmbedRequest::builder()`]. +/// +/// # Examples +/// +/// ``` +/// use ollama_rs::types::embed::EmbedRequest; +/// +/// let request = EmbedRequest::builder("embeddinggemma") +/// .input("hello world") +/// .truncate(true) +/// .build(); +/// ``` +pub struct EmbedRequestBuilder { + embed_request: EmbedRequest, +} + +impl EmbedRequestBuilder { + /// Sets a single text string as the input. + pub fn input>(mut self, input: S) -> Self { + self.embed_request.input = EmbedInput::Single(input.into()); + self + } + + /// Sets multiple text strings as the input. + pub fn inputs>(mut self, inputs: Vec) -> Self { + self.embed_request.input = + EmbedInput::Multiple(inputs.into_iter().map(|s| s.into()).collect()); + self + } + + /// Sets whether to truncate inputs that exceed the context window. + pub fn truncate(mut self, truncate: bool) -> Self { + self.embed_request.truncate = Some(truncate); + self + } + + /// Sets the number of dimensions for the embeddings. + pub fn dimensions(mut self, dimensions: u32) -> Self { + self.embed_request.dimensions = Some(dimensions); + self + } + + /// Sets how long the model stays loaded in memory (e.g., `"5m"`). + pub fn keep_alive>(mut self, keep_alive: S) -> Self { + self.embed_request.keep_alive = Some(keep_alive.into()); + self + } + + /// Sets runtime options for the embedding model. + pub fn options(mut self, options: Options) -> Self { + self.embed_request.options = Some(options); + self + } + + /// Consumes the builder and returns the configured [`EmbedRequest`]. + pub fn build(self) -> EmbedRequest { + self.embed_request + } +} + +/// The response from the embedding endpoint. +#[derive(Debug, Serialize, Deserialize)] +pub struct EmbedResponse { + /// The model that produced the embeddings. + pub model: String, + + /// The generated vector embeddings. Each inner `Vec` corresponds to + /// one input text, in the same order as the request. + pub embeddings: Vec>, + + /// Total time spent generating embeddings, in nanoseconds. + #[serde(skip_serializing_if = "Option::is_none")] + pub total_duration: Option, + + /// Time spent loading the model, in nanoseconds. + #[serde(skip_serializing_if = "Option::is_none")] + pub load_duration: Option, + + /// Number of input tokens processed to generate the embeddings. + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt_eval_count: Option, +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn builder_minimal() { + let request = EmbedRequest::builder("embeddinggemma") + .input("hello") + .build(); + assert_eq!(request.model, "embeddinggemma"); + assert!(matches!(request.input, EmbedInput::Single(ref s) if s == "hello")); + assert!(request.truncate.is_none()); + assert!(request.dimensions.is_none()); + assert!(request.keep_alive.is_none()); + assert!(request.options.is_none()); + } + + #[test] + fn builder_with_all_fields() { + let request = EmbedRequest::builder("embeddinggemma") + .inputs(vec!["hello", "world"]) + .truncate(false) + .dimensions(256) + .keep_alive("10m") + .options(Options::builder().seed(42).build()) + .build(); + + assert!(matches!(request.input, EmbedInput::Multiple(ref v) if v.len() == 2)); + assert_eq!(request.truncate, Some(false)); + assert_eq!(request.dimensions, Some(256)); + assert_eq!(request.keep_alive, Some("10m".to_string())); + assert!(request.options.is_some()); + } + + #[test] + fn request_skips_none_fields() { + let request = EmbedRequest::builder("embeddinggemma") + .input("hello") + .build(); + let json = serde_json::to_value(&request).unwrap(); + let obj = json.as_object().unwrap(); + assert!(obj.contains_key("model")); + assert!(obj.contains_key("input")); + assert!(!obj.contains_key("truncate")); + assert!(!obj.contains_key("dimensions")); + assert!(!obj.contains_key("keep_alive")); + assert!(!obj.contains_key("options")); + } + + #[test] + fn request_serializes_single_input() { + let request = EmbedRequest::builder("embeddinggemma") + .input("hello") + .build(); + let json = serde_json::to_value(&request).unwrap(); + assert_eq!(json["input"], json!("hello")); + } + + #[test] + fn request_serializes_multiple_inputs() { + let request = EmbedRequest::builder("embeddinggemma") + .inputs(vec!["hello", "world"]) + .build(); + let json = serde_json::to_value(&request).unwrap(); + assert_eq!(json["input"], json!(["hello", "world"])); + } + + #[test] + fn embed_input_single_round_trip() { + let input = EmbedInput::Single("test".to_string()); + let json = serde_json::to_value(&input).unwrap(); + assert_eq!(json, json!("test")); + let deserialized: EmbedInput = serde_json::from_value(json).unwrap(); + assert!(matches!(deserialized, EmbedInput::Single(s) if s == "test")); + } + + #[test] + fn embed_input_multiple_round_trip() { + let input = EmbedInput::Multiple(vec!["a".to_string(), "b".to_string()]); + let json = serde_json::to_value(&input).unwrap(); + assert_eq!(json, json!(["a", "b"])); + let deserialized: EmbedInput = serde_json::from_value(json).unwrap(); + assert!(matches!(deserialized, EmbedInput::Multiple(v) if v == vec!["a", "b"])); + } + + #[test] + fn response_deserialize() { + let json = json!({ + "model": "embeddinggemma", + "embeddings": [[0.010071029, -0.0017594862, 0.05007221]], + "total_duration": 14143917, + "load_duration": 1019500, + "prompt_eval_count": 8 + }); + let response: EmbedResponse = serde_json::from_value(json).unwrap(); + assert_eq!(response.model, "embeddinggemma"); + assert_eq!(response.embeddings.len(), 1); + assert_eq!(response.embeddings[0].len(), 3); + assert_eq!(response.total_duration, Some(14143917)); + assert_eq!(response.load_duration, Some(1019500)); + assert_eq!(response.prompt_eval_count, Some(8)); + } + + #[test] + fn response_deserialize_minimal() { + let json = json!({ + "model": "embeddinggemma", + "embeddings": [[1.0, 2.0], [3.0, 4.0]] + }); + let response: EmbedResponse = serde_json::from_value(json).unwrap(); + assert_eq!(response.embeddings.len(), 2); + assert!(response.total_duration.is_none()); + assert!(response.load_duration.is_none()); + assert!(response.prompt_eval_count.is_none()); + } +} diff --git a/src/types/mod.rs b/src/types/mod.rs index 5b42851..7916e4d 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -6,6 +6,7 @@ //! |--------------|-----------------------|------------------------------------------| //! | [`chat`] | `POST /api/chat` | Multi-turn chat conversations | //! | [`delete`] | `DELETE /api/delete` | Delete a model from the server | +//! | [`embed`] | `POST /api/embed` | Generate vector embeddings | //! | [`generate`] | `POST /api/generate` | Single-prompt text generation | //! | [`pull`] | `POST /api/pull` | Download models from the registry | //! | [`tags`] | `GET /api/tags` | List available models | @@ -19,6 +20,7 @@ pub mod chat; pub mod common; pub mod delete; +pub mod embed; pub mod generate; pub mod ps; pub mod pull;