Adds pull method and example
This commit is contained in:
8
Cargo.lock
generated
8
Cargo.lock
generated
@@ -1002,9 +1002,9 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "serde_json"
|
name = "serde_json"
|
||||||
version = "1.0.147"
|
version = "1.0.148"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "6af14725505314343e673e9ecb7cd7e8a36aa9791eb936235a3567cc31447ae4"
|
checksum = "3084b546a1dd6289475996f182a22aba973866ea8e8b02c51d9f46b1336a22da"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"itoa",
|
"itoa",
|
||||||
"memchr",
|
"memchr",
|
||||||
@@ -1783,6 +1783,6 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "zmij"
|
name = "zmij"
|
||||||
version = "0.1.7"
|
version = "1.0.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "9e404bcd8afdaf006e529269d3e85a743f9480c3cef60034d77860d02964f3ba"
|
checksum = "e6d6085d62852e35540689d1f97ad663e3971fc19cf5eceab364d62c646ea167"
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ edition = "2024"
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
reqwest = { version = "0.12.28", features = ["json", "stream"] }
|
reqwest = { version = "0.12.28", features = ["json", "stream"] }
|
||||||
serde = { version = "1.0.228", features = ["derive"] }
|
serde = { version = "1.0.228", features = ["derive"] }
|
||||||
serde_json = "1.0.147"
|
serde_json = "1.0.148"
|
||||||
tokio-util = "0.7.17"
|
tokio-util = "0.7.17"
|
||||||
tracing = "0.1.44"
|
tracing = "0.1.44"
|
||||||
futures-util = "0.3.31"
|
futures-util = "0.3.31"
|
||||||
|
|||||||
22
examples/pull.rs
Normal file
22
examples/pull.rs
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
use std::{env, error::Error, io::Write};
|
||||||
|
|
||||||
|
use futures_util::StreamExt;
|
||||||
|
use ollama_rs::{OllamaClient, types::pull::PullRequest};
|
||||||
|
|
||||||
|
const MODEL: &str = "HammerAI/mythomax-l2";
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<(), Box<dyn Error>> {
|
||||||
|
let _ = dotenvy::dotenv();
|
||||||
|
let server_address = env::var("OLLAMA_SERVER")?;
|
||||||
|
let ollama_client = OllamaClient::new(server_address);
|
||||||
|
|
||||||
|
let request = PullRequest::builder(MODEL).stream(true).build();
|
||||||
|
let mut stream = ollama_client.pull(request).await;
|
||||||
|
while let Some(response) = stream.next().await {
|
||||||
|
let response = response?;
|
||||||
|
println!("{:?}", response);
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
38
src/lib.rs
38
src/lib.rs
@@ -13,6 +13,7 @@ use crate::{
|
|||||||
chat::{ChatRequest, ChatResponse},
|
chat::{ChatRequest, ChatResponse},
|
||||||
generate::{GenerateRequest, GenerateResponse},
|
generate::{GenerateRequest, GenerateResponse},
|
||||||
ps::RunningModel,
|
ps::RunningModel,
|
||||||
|
pull::{PullRequest, PullResponse},
|
||||||
tags::Model,
|
tags::Model,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
@@ -154,4 +155,41 @@ impl OllamaClient {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn pull(
|
||||||
|
&self,
|
||||||
|
request: PullRequest,
|
||||||
|
) -> impl Stream<Item = OllamaResult<PullResponse>> {
|
||||||
|
let request_address = format!("{}/api/pull", self.server_address);
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
|
||||||
|
Box::pin(stream! {
|
||||||
|
let response = client
|
||||||
|
.post(request_address)
|
||||||
|
.json(&request)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| OllamaError::from(e))?; // Adjust based on your error type
|
||||||
|
|
||||||
|
let bytes_stream = response.bytes_stream();
|
||||||
|
|
||||||
|
let body_reader = StreamReader::new(
|
||||||
|
bytes_stream.map(|res| res.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))),
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut lines_stream = FramedRead::new(body_reader, LinesCodec::new());
|
||||||
|
|
||||||
|
while let Some(line_result) = lines_stream.next().await {
|
||||||
|
match line_result {
|
||||||
|
Ok(line_content) => {
|
||||||
|
println!("{line_content}");
|
||||||
|
if let Ok(parsed) = serde_json::from_str::<PullResponse>(&line_content) {
|
||||||
|
yield Ok(parsed);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => yield Err(OllamaError::from(e)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,4 +2,5 @@ pub mod chat;
|
|||||||
pub mod common;
|
pub mod common;
|
||||||
pub mod generate;
|
pub mod generate;
|
||||||
pub mod ps;
|
pub mod ps;
|
||||||
|
pub mod pull;
|
||||||
pub mod tags;
|
pub mod tags;
|
||||||
|
|||||||
45
src/types/pull.rs
Normal file
45
src/types/pull.rs
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct PullRequest {
|
||||||
|
pub model: String,
|
||||||
|
pub insecure: Option<bool>,
|
||||||
|
pub stream: Option<bool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PullRequest {
|
||||||
|
pub fn builder<M: Into<String>>(model: M) -> PullRequestBuilder {
|
||||||
|
PullRequestBuilder {
|
||||||
|
pull_request: PullRequest {
|
||||||
|
model: model.into(),
|
||||||
|
insecure: None,
|
||||||
|
stream: None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct PullRequestBuilder {
|
||||||
|
pull_request: PullRequest,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PullRequestBuilder {
|
||||||
|
pub fn stream(mut self, stream: bool) -> Self {
|
||||||
|
self.pull_request.stream = Some(stream);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn insecure(mut self, insecure: bool) -> Self {
|
||||||
|
self.pull_request.insecure = Some(insecure);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn build(self) -> PullRequest {
|
||||||
|
self.pull_request
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct PullResponse {
|
||||||
|
pub status: String,
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user