Adds tags and ps methods
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1 +1,2 @@
|
|||||||
/target
|
/target
|
||||||
|
.env
|
||||||
|
|||||||
1715
Cargo.lock
generated
Normal file
1715
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
11
Cargo.toml
11
Cargo.toml
@@ -4,3 +4,14 @@ version = "0.1.0"
|
|||||||
edition = "2024"
|
edition = "2024"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
reqwest = { version = "0.12.28", features = ["json", "stream"] }
|
||||||
|
serde = { version = "1.0.228", features = ["derive"] }
|
||||||
|
serde_json = "1.0.146"
|
||||||
|
tokio-util = "0.7.17"
|
||||||
|
tracing = "0.1.44"
|
||||||
|
futures-util = "0.3.31"
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
dotenvy = "0.15.7"
|
||||||
|
tokio = { version = "1.48.0", features = ["full"] }
|
||||||
|
tracing-subscriber = "0.3.22"
|
||||||
|
|||||||
16
examples/generate.rs
Normal file
16
examples/generate.rs
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
use std::{env, error::Error};
|
||||||
|
|
||||||
|
use ollama_rs::{OllamaClient, types::generate::GenerateRequest};
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<(), Box<dyn Error>> {
|
||||||
|
let _ = dotenvy::dotenv();
|
||||||
|
let server_address = env::var("OLLAMA_SERVER")?;
|
||||||
|
let ollama_client = OllamaClient::new(server_address);
|
||||||
|
let request = GenerateRequest::builder("dolphin3:8b")
|
||||||
|
.prompt("Why is the sky blue?")
|
||||||
|
.build();
|
||||||
|
let response = ollama_client.generate(request).await?;
|
||||||
|
println!("{:?}", response);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
16
examples/ps.rs
Normal file
16
examples/ps.rs
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
use std::{env, error::Error};
|
||||||
|
|
||||||
|
use ollama_rs::OllamaClient;
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<(), Box<dyn Error>> {
|
||||||
|
tracing_subscriber::fmt().init();
|
||||||
|
let _ = dotenvy::dotenv();
|
||||||
|
let server_address = env::var("OLLAMA_SERVER")?;
|
||||||
|
let ollama_client = OllamaClient::new(server_address);
|
||||||
|
let models = ollama_client.list_runnning_models().await?;
|
||||||
|
for model in models {
|
||||||
|
println!("{:?}", model);
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
16
examples/tags.rs
Normal file
16
examples/tags.rs
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
use std::{env, error::Error};
|
||||||
|
|
||||||
|
use ollama_rs::OllamaClient;
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<(), Box<dyn Error>> {
|
||||||
|
tracing_subscriber::fmt().init();
|
||||||
|
let _ = dotenvy::dotenv();
|
||||||
|
let server_address = env::var("OLLAMA_SERVER")?;
|
||||||
|
let ollama_client = OllamaClient::new(server_address);
|
||||||
|
let models = ollama_client.tags().await?;
|
||||||
|
for model in models {
|
||||||
|
println!("{:?}", model);
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
32
src/error.rs
Normal file
32
src/error.rs
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
use std::{error::Error, fmt::Display};
|
||||||
|
|
||||||
|
pub type OllamaResult<T> = Result<T, OllamaError>;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum OllamaError {
|
||||||
|
NetworkError(reqwest::Error),
|
||||||
|
ResponseParseError(serde_json::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Error for OllamaError {}
|
||||||
|
|
||||||
|
impl Display for OllamaError {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
OllamaError::NetworkError(e) => writeln!(f, "Network error: {}", e),
|
||||||
|
OllamaError::ResponseParseError(e) => writeln!(f, "ResponseParseError error: {}", e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<reqwest::Error> for OllamaError {
|
||||||
|
fn from(error: reqwest::Error) -> Self {
|
||||||
|
Self::NetworkError(error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<serde_json::Error> for OllamaError {
|
||||||
|
fn from(error: serde_json::Error) -> Self {
|
||||||
|
Self::ResponseParseError(error)
|
||||||
|
}
|
||||||
|
}
|
||||||
102
src/lib.rs
Normal file
102
src/lib.rs
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
use futures_util::{StreamExt};
|
||||||
|
use serde_json::Value;
|
||||||
|
use tokio_util::io::StreamReader;
|
||||||
|
use tracing::info;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
error::OllamaResult,
|
||||||
|
types::{
|
||||||
|
generate::{GenerateRequest, GenerateResponse},
|
||||||
|
ps::RunningModel,
|
||||||
|
tags::Model,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
pub mod error;
|
||||||
|
pub mod types;
|
||||||
|
|
||||||
|
/// Client handle for one Ollama server.
///
/// Holds only the server's base address, so it is cheap to clone and safe to
/// print for diagnostics.
#[derive(Debug, Clone)]
pub struct OllamaClient {
    /// Base address that endpoint paths such as `/api/tags` are appended to.
    server_address: String,
}
|
||||||
|
|
||||||
|
impl OllamaClient {
|
||||||
|
pub fn new<S: AsRef<str>>(server_address: S) -> Self {
|
||||||
|
Self {
|
||||||
|
server_address: server_address.as_ref().to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Fetch a list of models and their details
|
||||||
|
pub async fn tags(&self) -> OllamaResult<Vec<Model>> {
|
||||||
|
let request_address = format!("{}/api/tags", self.server_address);
|
||||||
|
info!("List models: {}", request_address);
|
||||||
|
let mut response: Value = reqwest::get(request_address)
|
||||||
|
.await?
|
||||||
|
.error_for_status()?
|
||||||
|
.json()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let Some(response) = response.as_object_mut() else {
|
||||||
|
return Ok(vec![]);
|
||||||
|
};
|
||||||
|
|
||||||
|
let Some(models) = response.remove("models") else {
|
||||||
|
return Ok(vec![]);
|
||||||
|
};
|
||||||
|
|
||||||
|
let models = serde_json::from_value(models)?;
|
||||||
|
Ok(models)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Retrieve a list of models that are currently running
|
||||||
|
pub async fn list_runnning_models(&self) -> OllamaResult<Vec<RunningModel>> {
|
||||||
|
let request_address = format!("{}/api/ps", self.server_address);
|
||||||
|
info!("List models: {}", request_address);
|
||||||
|
let mut response: Value = reqwest::get(request_address)
|
||||||
|
.await?
|
||||||
|
.error_for_status()?
|
||||||
|
.json()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let Some(response) = response.as_object_mut() else {
|
||||||
|
return Ok(vec![]);
|
||||||
|
};
|
||||||
|
|
||||||
|
let Some(models) = response.remove("models") else {
|
||||||
|
return Ok(vec![]);
|
||||||
|
};
|
||||||
|
|
||||||
|
let models = serde_json::from_value(models)?;
|
||||||
|
Ok(models)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates a response for the provided prompt
|
||||||
|
pub async fn generate(&self, request: GenerateRequest) -> OllamaResult<()> {
|
||||||
|
let request_address = format!("{}/api/generate", self.server_address);
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
let response = client
|
||||||
|
.post(request_address)
|
||||||
|
.json(&request)
|
||||||
|
.send()
|
||||||
|
.await?
|
||||||
|
.error_for_status()?;
|
||||||
|
let stream = response.bytes_stream().;
|
||||||
|
// let reader = BufReader::new(stream);
|
||||||
|
let reader = StreamReader(stream);
|
||||||
|
while reader
|
||||||
|
while let Some(item) = stream.next().await {
|
||||||
|
let item = item?;
|
||||||
|
println!("Chunk: {:?}", item?);
|
||||||
|
}
|
||||||
|
|
||||||
|
// let stream_reader = tokio_util::io::StreamReader::new(stream);
|
||||||
|
// let reder = BufReader::new(stream);
|
||||||
|
// let full_response = response.text().await?;
|
||||||
|
// let parts = full_response
|
||||||
|
// .lines()
|
||||||
|
// .map(|line| serde_json::from_str::<GenerateResponse>(line).unwrap())
|
||||||
|
// .collect::<Vec<_>>();
|
||||||
|
// println!("{:#?}", parts);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
/// Placeholder binary entry point: prints a fixed greeting.
fn main() {
    let greeting = "Hello, world!";
    println!("{greeting}");
}
|
|
||||||
10
src/types/common.rs
Normal file
10
src/types/common.rs
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct ModelDetails {
|
||||||
|
pub format: String,
|
||||||
|
pub family: String,
|
||||||
|
pub families: Vec<String>,
|
||||||
|
pub parameter_size: String,
|
||||||
|
pub quantization_level: String,
|
||||||
|
}
|
||||||
91
src/types/generate.rs
Normal file
91
src/types/generate.rs
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct GenerateRequest {
|
||||||
|
/// Model name
|
||||||
|
pub model: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
|
||||||
|
/// Text for the model to generate a response from
|
||||||
|
pub prompt: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
|
||||||
|
/// Used for fill-in-the-middle models, text that appears after the user prompt and before the
|
||||||
|
/// model response
|
||||||
|
pub suffix: Option<String>,
|
||||||
|
|
||||||
|
/// System prompt for the model to generate a response from
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub system: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GenerateRequest {
|
||||||
|
pub fn builder<M: Into<String>>(model: M) -> GenerateRequestBuilder {
|
||||||
|
GenerateRequestBuilder::new(model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct GenerateRequestBuilder {
|
||||||
|
generate_request: GenerateRequest,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GenerateRequestBuilder {
|
||||||
|
fn new<M: Into<String>>(model: M) -> Self {
|
||||||
|
Self {
|
||||||
|
generate_request: GenerateRequest {
|
||||||
|
model: model.into(),
|
||||||
|
prompt: None,
|
||||||
|
suffix: None,
|
||||||
|
system: None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn prompt<P: Into<String>>(mut self, prompt: P) -> Self {
|
||||||
|
self.generate_request.prompt = Some(prompt.into());
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn build(self) -> GenerateRequest {
|
||||||
|
self.generate_request
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct GenerateResponse {
|
||||||
|
/// Model name
|
||||||
|
pub model: String,
|
||||||
|
|
||||||
|
/// ISO 8601 timestamp of response creation
|
||||||
|
pub created_at: String,
|
||||||
|
|
||||||
|
/// The model's generated text response
|
||||||
|
pub response: String,
|
||||||
|
|
||||||
|
/// The model's generated thinking output
|
||||||
|
pub thinking: Option<String>,
|
||||||
|
|
||||||
|
/// Indicates whether generation has finished
|
||||||
|
pub done: bool,
|
||||||
|
|
||||||
|
/// Reason the generation stopped
|
||||||
|
pub done_reason: Option<String>,
|
||||||
|
|
||||||
|
/// Time spent generating the response in nanoseconds
|
||||||
|
pub total_duration: Option<usize>,
|
||||||
|
|
||||||
|
/// Time spent loading the model in nanoseconds
|
||||||
|
pub load_duration: Option<usize>,
|
||||||
|
|
||||||
|
/// Number of input tokens in the prompt
|
||||||
|
pub prompt_eval_count: Option<usize>,
|
||||||
|
|
||||||
|
/// Time spent evaluating the prompt in nanoseconds
|
||||||
|
pub prompt_eval_duration: Option<usize>,
|
||||||
|
|
||||||
|
/// Number of output tokens generated in the response
|
||||||
|
pub eval_count: Option<usize>,
|
||||||
|
|
||||||
|
/// Time spent generating tokens in nanoseconds
|
||||||
|
pub eval_duration: Option<usize>,
|
||||||
|
}
|
||||||
4
src/types/mod.rs
Normal file
4
src/types/mod.rs
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
pub mod common;
|
||||||
|
pub mod generate;
|
||||||
|
pub mod ps;
|
||||||
|
pub mod tags;
|
||||||
15
src/types/ps.rs
Normal file
15
src/types/ps.rs
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::types::common::ModelDetails;
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct RunningModel {
|
||||||
|
pub name: String,
|
||||||
|
pub model: String,
|
||||||
|
pub size: usize,
|
||||||
|
pub digest: String,
|
||||||
|
pub details: ModelDetails,
|
||||||
|
pub expires_at: String,
|
||||||
|
pub size_vram: usize,
|
||||||
|
pub context_length: u32,
|
||||||
|
}
|
||||||
13
src/types/tags.rs
Normal file
13
src/types/tags.rs
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::types::common::ModelDetails;
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct Model {
|
||||||
|
pub name: String,
|
||||||
|
pub model: String,
|
||||||
|
pub modified_at: String,
|
||||||
|
pub size: usize,
|
||||||
|
pub digest: String,
|
||||||
|
pub details: ModelDetails,
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user