diff --git a/Cargo.lock b/Cargo.lock index 22bda13..780760b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1002,9 +1002,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.147" +version = "1.0.148" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6af14725505314343e673e9ecb7cd7e8a36aa9791eb936235a3567cc31447ae4" +checksum = "3084b546a1dd6289475996f182a22aba973866ea8e8b02c51d9f46b1336a22da" dependencies = [ "itoa", "memchr", @@ -1783,6 +1783,6 @@ dependencies = [ [[package]] name = "zmij" -version = "0.1.7" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e404bcd8afdaf006e529269d3e85a743f9480c3cef60034d77860d02964f3ba" +checksum = "e6d6085d62852e35540689d1f97ad663e3971fc19cf5eceab364d62c646ea167" diff --git a/Cargo.toml b/Cargo.toml index 0a045bf..e02d9b6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,7 @@ edition = "2024" [dependencies] reqwest = { version = "0.12.28", features = ["json", "stream"] } serde = { version = "1.0.228", features = ["derive"] } -serde_json = "1.0.147" +serde_json = "1.0.148" tokio-util = "0.7.17" tracing = "0.1.44" futures-util = "0.3.31" diff --git a/examples/pull.rs b/examples/pull.rs new file mode 100644 index 0000000..c61f8a8 --- /dev/null +++ b/examples/pull.rs @@ -0,0 +1,22 @@ +use std::{env, error::Error, io::Write}; + +use futures_util::StreamExt; +use ollama_rs::{OllamaClient, types::pull::PullRequest}; + +const MODEL: &str = "HammerAI/mythomax-l2"; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let _ = dotenvy::dotenv(); + let server_address = env::var("OLLAMA_SERVER")?; + let ollama_client = OllamaClient::new(server_address); + + let request = PullRequest::builder(MODEL).stream(true).build(); + let mut stream = ollama_client.pull(request).await; + while let Some(response) = stream.next().await { + let response = response?; + println!("{:?}", response); + std::io::stdout().flush()?; + } + Ok(()) +} diff --git a/src/lib.rs b/src/lib.rs index 584ffd7..0951dd6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,6 +13,7 @@ use crate::{ chat::{ChatRequest, ChatResponse}, generate::{GenerateRequest, GenerateResponse}, ps::RunningModel, + pull::{PullRequest, PullResponse}, tags::Model, }, }; @@ -154,4 +155,41 @@ impl OllamaClient { } }) } + + pub async fn pull( + &self, + request: PullRequest, + ) -> impl Stream> { + let request_address = format!("{}/api/pull", self.server_address); + let client = reqwest::Client::new(); + + Box::pin(stream! { + let response = client + .post(request_address) + .json(&request) + .send() + .await + .map_err(|e| OllamaError::from(e))?; // Adjust based on your error type + + let bytes_stream = response.bytes_stream(); + + let body_reader = StreamReader::new( + bytes_stream.map(|res| res.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))), + ); + + let mut lines_stream = FramedRead::new(body_reader, LinesCodec::new()); + + while let Some(line_result) = lines_stream.next().await { + match line_result { + Ok(line_content) => { + println!("{line_content}"); + if let Ok(parsed) = serde_json::from_str::(&line_content) { + yield Ok(parsed); + } + } + Err(e) => yield Err(OllamaError::from(e)), + } + } + }) + } } diff --git a/src/types/mod.rs b/src/types/mod.rs index f20738e..1fcc5d3 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -2,4 +2,5 @@ pub mod chat; pub mod common; pub mod generate; pub mod ps; +pub mod pull; pub mod tags; diff --git a/src/types/pull.rs b/src/types/pull.rs new file mode 100644 index 0000000..eda931e --- /dev/null +++ b/src/types/pull.rs @@ -0,0 +1,45 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize)] +pub struct PullRequest { + pub model: String, + pub insecure: Option, + pub stream: Option, +} + +impl PullRequest { + pub fn builder>(model: M) -> PullRequestBuilder { + PullRequestBuilder { + pull_request: PullRequest { + model: model.into(), + insecure: None, + stream: None, + }, + } + } +} + +pub struct PullRequestBuilder { + pull_request: PullRequest, +} + +impl PullRequestBuilder { + pub fn stream(mut self, stream: bool) -> Self { + self.pull_request.stream = Some(stream); + self + } + + pub fn insecure(mut self, insecure: bool) -> Self { + self.pull_request.insecure = Some(insecure); + self + } + + pub fn build(self) -> PullRequest { + self.pull_request + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct PullResponse { + pub status: String, +}