Add embed endpoint (POST /api/embed)
Implement the Ollama POST /api/embed endpoint for generating vector embeddings from text input.
- Add EmbedInput, EmbedRequest, EmbedResponse types in src/types/embed.rs
- Add OllamaClient::embed() async method in src/lib.rs
- Register embed module in src/types/mod.rs
- Add usage example in examples/embed.rs
- Update README with embed endpoint documentation
This commit is contained in:
@@ -9,6 +9,7 @@ An async Rust client library for the [Ollama](https://ollama.com/) API. Provides
|
||||
- Structured JSON output with schema validation
|
||||
- Tool calling / function calling support
|
||||
- Model management (list, pull, delete, inspect running models)
|
||||
- Text embedding generation
|
||||
- Builder pattern for constructing requests
|
||||
- Configurable generation parameters (temperature, top-k, top-p, and more)
|
||||
- Thinking / reasoning mode support
|
||||
@@ -161,6 +162,7 @@ When the model decides to call a tool, the response `message.tool_calls` field w
|
||||
| `chat(request)` | Chat conversation (streaming) |
|
||||
| `pull(request)` | Pull/download a model (streaming) |
|
||||
| `delete(request)` | Delete a model from the server |
|
||||
| `embed(request)` | Generate vector embeddings |
|
||||
|
||||
**`OllamaClient::builder(server_address)`** -- `.connection_timeout(Duration)`, `.build()`
|
||||
|
||||
@@ -181,6 +183,8 @@ let client = OllamaClient::builder("http://localhost:11434")
|
||||
|
||||
**`PullRequest::builder(model)`** -- `.stream()`
|
||||
|
||||
**`EmbedRequest::builder(model)`** -- `.input()`, `.inputs()`, `.truncate()`, `.dimensions()`, `.keep_alive()`, `.options()`
|
||||
|
||||
### Generation Options
|
||||
|
||||
Configure sampling parameters via `Options::builder()`:
|
||||
@@ -208,6 +212,7 @@ The `examples/` directory contains runnable programs:
|
||||
| `tool_call` | Function calling / tool use |
|
||||
| `pull` | Download a model |
|
||||
| `delete` | Delete a model |
|
||||
| `embed` | Generate text embeddings |
|
||||
| `tags` | List available models |
|
||||
| `ps` | List running models |
|
||||
| `version` | Query server version |
|
||||
|
||||
24
examples/embed.rs
Normal file
24
examples/embed.rs
Normal file
@@ -0,0 +1,24 @@
|
||||
use std::{env, error::Error};
|
||||
|
||||
use ollama_rs::OllamaClient;
|
||||
use ollama_rs::types::embed::EmbedRequest;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn Error>> {
|
||||
tracing_subscriber::fmt().init();
|
||||
let _ = dotenvy::dotenv();
|
||||
let server_address = env::var("OLLAMA_SERVER")?;
|
||||
let ollama_client = OllamaClient::new(server_address);
|
||||
let request = EmbedRequest::builder("embeddinggemma")
|
||||
.input("Generate embeddings for this text")
|
||||
.build();
|
||||
|
||||
let response = ollama_client.embed(request).await?;
|
||||
for (i, embedding) in response.embeddings.iter().enumerate() {
|
||||
println!("Embedding {}: {} dimensions", i, embedding.len());
|
||||
if embedding.len() >= 3 {
|
||||
println!(" First 3 values: {:?}", &embedding[..3]);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
42
src/lib.rs
42
src/lib.rs
@@ -94,6 +94,7 @@ use crate::{
|
||||
types::{
|
||||
chat::{ChatRequest, ChatResponse},
|
||||
delete::DeleteRequest,
|
||||
embed::{EmbedRequest, EmbedResponse},
|
||||
generate::{GenerateRequest, GenerateResponse},
|
||||
ps::PsResponse,
|
||||
pull::{PullRequest, PullResponse},
|
||||
@@ -282,6 +283,47 @@ impl OllamaClient {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Generates vector embeddings for the given input text(s).
|
||||
///
|
||||
/// Calls `POST /api/embed`.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`OllamaError::NetworkError`] if the server is unreachable or returns
|
||||
/// a non-success status code.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```no_run
|
||||
/// # use ollama_rs::OllamaClient;
|
||||
/// # use ollama_rs::types::embed::EmbedRequest;
|
||||
/// # async fn run() -> ollama_rs::error::OllamaResult<()> {
|
||||
/// let client = OllamaClient::default();
|
||||
/// let request = EmbedRequest::builder("embeddinggemma")
|
||||
/// .input("Generate embeddings for this text")
|
||||
/// .build();
|
||||
///
|
||||
/// let response = client.embed(request).await?;
|
||||
/// for embedding in &response.embeddings {
|
||||
/// println!("Dimension count: {}", embedding.len());
|
||||
/// }
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub async fn embed(&self, request: EmbedRequest) -> OllamaResult<EmbedResponse> {
|
||||
let request_address = format!("{}/api/embed", self.server_address);
|
||||
info!("Generate embeddings: {}", request.model);
|
||||
Ok(self
|
||||
.client
|
||||
.post(request_address)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?
|
||||
.error_for_status()?
|
||||
.json()
|
||||
.await?)
|
||||
}
|
||||
|
||||
fn stream_response<R: Serialize, T: DeserializeOwned>(
|
||||
&self,
|
||||
endpoint: String,
|
||||
|
||||
306
src/types/embed.rs
Normal file
306
src/types/embed.rs
Normal file
@@ -0,0 +1,306 @@
|
||||
//! Types for the embedding endpoint (`POST /api/embed`).
|
||||
//!
|
||||
//! Use [`EmbedRequest::builder()`] to construct a request and pass it to
|
||||
//! [`OllamaClient::embed()`](crate::OllamaClient::embed).
|
||||
//!
|
||||
//! # Examples
|
||||
//!
|
||||
//! ```no_run
|
||||
//! # use ollama_rs::OllamaClient;
|
||||
//! # use ollama_rs::types::embed::EmbedRequest;
|
||||
//! # async fn run() -> ollama_rs::error::OllamaResult<()> {
|
||||
//! let client = OllamaClient::default();
|
||||
//! let request = EmbedRequest::builder("embeddinggemma")
|
||||
//! .input("Generate embeddings for this text")
|
||||
//! .build();
|
||||
//!
|
||||
//! let response = client.embed(request).await?;
|
||||
//! println!("Embeddings: {:?}", response.embeddings);
|
||||
//! # Ok(())
|
||||
//! # }
|
||||
//! ```
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::types::common::Options;
|
||||
|
||||
/// The input text(s) to generate embeddings for.
|
||||
///
|
||||
/// Accepts either a single string or an array of strings. Serialized as an
|
||||
/// untagged enum so both `"hello"` and `["hello", "world"]` are valid JSON
|
||||
/// representations.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use ollama_rs::types::embed::EmbedInput;
|
||||
///
|
||||
/// let single = EmbedInput::Single("hello".to_string());
|
||||
/// let multiple = EmbedInput::Multiple(vec!["hello".to_string(), "world".to_string()]);
|
||||
/// ```
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum EmbedInput {
|
||||
/// A single text string.
|
||||
Single(String),
|
||||
/// Multiple text strings.
|
||||
Multiple(Vec<String>),
|
||||
}
|
||||
|
||||
/// A request to generate embeddings (`POST /api/embed`).
|
||||
///
|
||||
/// Construct via [`EmbedRequest::builder()`].
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use ollama_rs::types::embed::EmbedRequest;
|
||||
///
|
||||
/// let request = EmbedRequest::builder("embeddinggemma")
|
||||
/// .input("Generate embeddings for this text")
|
||||
/// .build();
|
||||
/// ```
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct EmbedRequest {
|
||||
/// The model name to use for generating embeddings.
|
||||
pub model: String,
|
||||
|
||||
/// The text or array of texts to generate embeddings for.
|
||||
pub input: EmbedInput,
|
||||
|
||||
/// If `true`, truncate inputs that exceed the context window. If `false`,
|
||||
/// returns an error. Defaults to `true` on the server.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub truncate: Option<bool>,
|
||||
|
||||
/// Number of dimensions to generate embeddings for.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub dimensions: Option<u32>,
|
||||
|
||||
/// How long the model stays loaded in memory (e.g., `"5m"`, `"1h"`).
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub keep_alive: Option<String>,
|
||||
|
||||
/// Runtime options for the embedding model.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub options: Option<Options>,
|
||||
}
|
||||
|
||||
impl EmbedRequest {
|
||||
/// Returns an [`EmbedRequestBuilder`] for the given model name.
|
||||
pub fn builder<M: Into<String>>(model: M) -> EmbedRequestBuilder {
|
||||
EmbedRequestBuilder {
|
||||
embed_request: EmbedRequest {
|
||||
model: model.into(),
|
||||
input: EmbedInput::Single(String::new()),
|
||||
truncate: None,
|
||||
dimensions: None,
|
||||
keep_alive: None,
|
||||
options: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A builder for constructing an [`EmbedRequest`].
|
||||
///
|
||||
/// Obtain a builder via [`EmbedRequest::builder()`].
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use ollama_rs::types::embed::EmbedRequest;
|
||||
///
|
||||
/// let request = EmbedRequest::builder("embeddinggemma")
|
||||
/// .input("hello world")
|
||||
/// .truncate(true)
|
||||
/// .build();
|
||||
/// ```
|
||||
pub struct EmbedRequestBuilder {
|
||||
embed_request: EmbedRequest,
|
||||
}
|
||||
|
||||
impl EmbedRequestBuilder {
|
||||
/// Sets a single text string as the input.
|
||||
pub fn input<S: Into<String>>(mut self, input: S) -> Self {
|
||||
self.embed_request.input = EmbedInput::Single(input.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets multiple text strings as the input.
|
||||
pub fn inputs<S: Into<String>>(mut self, inputs: Vec<S>) -> Self {
|
||||
self.embed_request.input =
|
||||
EmbedInput::Multiple(inputs.into_iter().map(|s| s.into()).collect());
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets whether to truncate inputs that exceed the context window.
|
||||
pub fn truncate(mut self, truncate: bool) -> Self {
|
||||
self.embed_request.truncate = Some(truncate);
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the number of dimensions for the embeddings.
|
||||
pub fn dimensions(mut self, dimensions: u32) -> Self {
|
||||
self.embed_request.dimensions = Some(dimensions);
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets how long the model stays loaded in memory (e.g., `"5m"`).
|
||||
pub fn keep_alive<S: Into<String>>(mut self, keep_alive: S) -> Self {
|
||||
self.embed_request.keep_alive = Some(keep_alive.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets runtime options for the embedding model.
|
||||
pub fn options(mut self, options: Options) -> Self {
|
||||
self.embed_request.options = Some(options);
|
||||
self
|
||||
}
|
||||
|
||||
/// Consumes the builder and returns the configured [`EmbedRequest`].
|
||||
pub fn build(self) -> EmbedRequest {
|
||||
self.embed_request
|
||||
}
|
||||
}
|
||||
|
||||
/// The response from the embedding endpoint.
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct EmbedResponse {
|
||||
/// The model that produced the embeddings.
|
||||
pub model: String,
|
||||
|
||||
/// The generated vector embeddings. Each inner `Vec<f64>` corresponds to
|
||||
/// one input text, in the same order as the request.
|
||||
pub embeddings: Vec<Vec<f64>>,
|
||||
|
||||
/// Total time spent generating embeddings, in nanoseconds.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub total_duration: Option<u64>,
|
||||
|
||||
/// Time spent loading the model, in nanoseconds.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub load_duration: Option<u64>,
|
||||
|
||||
/// Number of input tokens processed to generate the embeddings.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub prompt_eval_count: Option<u64>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Model name shared by every request-building test.
    const MODEL: &str = "embeddinggemma";

    /// Serializes a request to a JSON value, panicking if serialization fails.
    fn request_json(request: &EmbedRequest) -> serde_json::Value {
        serde_json::to_value(request).unwrap()
    }

    /// A builder given only an input leaves every optional field unset.
    #[test]
    fn builder_minimal() {
        let built = EmbedRequest::builder(MODEL).input("hello").build();

        assert_eq!(built.model, "embeddinggemma");
        match &built.input {
            EmbedInput::Single(text) => assert_eq!(text.as_str(), "hello"),
            EmbedInput::Multiple(_) => panic!("expected a single input"),
        }
        assert!(built.truncate.is_none());
        assert!(built.dimensions.is_none());
        assert!(built.keep_alive.is_none());
        assert!(built.options.is_none());
    }

    /// Every setter stores its value on the built request.
    #[test]
    fn builder_with_all_fields() {
        let built = EmbedRequest::builder(MODEL)
            .inputs(vec!["hello", "world"])
            .truncate(false)
            .dimensions(256)
            .keep_alive("10m")
            .options(Options::builder().seed(42).build())
            .build();

        match &built.input {
            EmbedInput::Multiple(texts) => assert_eq!(texts.len(), 2),
            EmbedInput::Single(_) => panic!("expected multiple inputs"),
        }
        assert_eq!(built.truncate, Some(false));
        assert_eq!(built.dimensions, Some(256));
        assert_eq!(built.keep_alive, Some("10m".to_string()));
        assert!(built.options.is_some());
    }

    /// Unset optional fields do not appear in the serialized JSON.
    #[test]
    fn request_skips_none_fields() {
        let built = EmbedRequest::builder(MODEL).input("hello").build();
        let value = request_json(&built);
        let fields = value.as_object().unwrap();

        for present in ["model", "input"].iter() {
            assert!(fields.contains_key(*present));
        }
        for absent in ["truncate", "dimensions", "keep_alive", "options"].iter() {
            assert!(!fields.contains_key(*absent));
        }
    }

    /// A single input serializes as a bare JSON string.
    #[test]
    fn request_serializes_single_input() {
        let built = EmbedRequest::builder(MODEL).input("hello").build();
        let value = request_json(&built);
        assert_eq!(value["input"], json!("hello"));
    }

    /// Multiple inputs serialize as a JSON array of strings.
    #[test]
    fn request_serializes_multiple_inputs() {
        let built = EmbedRequest::builder(MODEL)
            .inputs(vec!["hello", "world"])
            .build();
        let value = request_json(&built);
        assert_eq!(value["input"], json!(["hello", "world"]));
    }

    /// `EmbedInput::Single` survives a serialize/deserialize round trip.
    #[test]
    fn embed_input_single_round_trip() {
        let original = EmbedInput::Single("test".to_string());

        let encoded = serde_json::to_value(&original).unwrap();
        assert_eq!(encoded, json!("test"));

        let decoded: EmbedInput = serde_json::from_value(encoded).unwrap();
        match decoded {
            EmbedInput::Single(text) => assert_eq!(text, "test"),
            EmbedInput::Multiple(_) => panic!("expected a single input"),
        }
    }

    /// `EmbedInput::Multiple` survives a serialize/deserialize round trip.
    #[test]
    fn embed_input_multiple_round_trip() {
        let original = EmbedInput::Multiple(vec!["a".to_string(), "b".to_string()]);

        let encoded = serde_json::to_value(&original).unwrap();
        assert_eq!(encoded, json!(["a", "b"]));

        let decoded: EmbedInput = serde_json::from_value(encoded).unwrap();
        match decoded {
            EmbedInput::Multiple(texts) => assert_eq!(texts, vec!["a", "b"]),
            EmbedInput::Single(_) => panic!("expected multiple inputs"),
        }
    }

    /// A reply carrying every field deserializes into populated options.
    #[test]
    fn response_deserialize() {
        let payload = json!({
            "model": "embeddinggemma",
            "embeddings": [[0.010071029, -0.0017594862, 0.05007221]],
            "total_duration": 14143917,
            "load_duration": 1019500,
            "prompt_eval_count": 8
        });

        let parsed: EmbedResponse = serde_json::from_value(payload).unwrap();

        assert_eq!(parsed.model, "embeddinggemma");
        assert_eq!(parsed.embeddings.len(), 1);
        assert_eq!(parsed.embeddings[0].len(), 3);
        assert_eq!(parsed.total_duration, Some(14143917));
        assert_eq!(parsed.load_duration, Some(1019500));
        assert_eq!(parsed.prompt_eval_count, Some(8));
    }

    /// A reply with only the required fields leaves the optional ones `None`.
    #[test]
    fn response_deserialize_minimal() {
        let payload = json!({
            "model": "embeddinggemma",
            "embeddings": [[1.0, 2.0], [3.0, 4.0]]
        });

        let parsed: EmbedResponse = serde_json::from_value(payload).unwrap();

        assert_eq!(parsed.embeddings.len(), 2);
        assert!(parsed.total_duration.is_none());
        assert!(parsed.load_duration.is_none());
        assert!(parsed.prompt_eval_count.is_none());
    }
}
|
||||
@@ -6,6 +6,7 @@
|
||||
//! |--------------|-----------------------|------------------------------------------|
|
||||
//! | [`chat`] | `POST /api/chat` | Multi-turn chat conversations |
|
||||
//! | [`delete`] | `DELETE /api/delete` | Delete a model from the server |
|
||||
//! | [`embed`] | `POST /api/embed` | Generate vector embeddings |
|
||||
//! | [`generate`] | `POST /api/generate` | Single-prompt text generation |
|
||||
//! | [`pull`] | `POST /api/pull` | Download models from the registry |
|
||||
//! | [`tags`] | `GET /api/tags` | List available models |
|
||||
@@ -19,6 +20,7 @@
|
||||
pub mod chat;
|
||||
pub mod common;
|
||||
pub mod delete;
|
||||
pub mod embed;
|
||||
pub mod generate;
|
||||
pub mod ps;
|
||||
pub mod pull;
|
||||
|
||||
Reference in New Issue
Block a user