use async_stream::stream; use futures_util::{Stream, StreamExt}; use serde::{Serialize, de::DeserializeOwned}; use tokio_util::{ codec::{FramedRead, LinesCodec}, io::StreamReader, }; use tracing::info; use crate::{ error::{OllamaError, OllamaResult}, types::{ chat::{ChatRequest, ChatResponse}, generate::{GenerateRequest, GenerateResponse}, ps::PsResponse, pull::{PullRequest, PullResponse}, tags::TagsResponse, version::VersionResponse, }, }; pub mod error; pub mod types; pub struct OllamaClient { server_address: String, } impl OllamaClient { pub fn new>(server_address: S) -> Self { Self { server_address: server_address.as_ref().to_string(), } } /// Retrieve the version of the Ollama pub async fn version(&self) -> OllamaResult { let request_address = format!("{}/api/version", self.server_address); Ok(reqwest::get(request_address) .await? .error_for_status()? .json() .await?) } /// Fetch a list of models and their details pub async fn tags(&self) -> OllamaResult { let request_address = format!("{}/api/tags", self.server_address); info!("List models: {}", request_address); Ok(reqwest::get(request_address) .await? .error_for_status()? .json() .await?) } /// Retrieve a list of models that are currently running pub async fn ps(&self) -> OllamaResult { let request_address = format!("{}/api/ps", self.server_address); info!("List models: {}", request_address); Ok(reqwest::get(request_address) .await? .error_for_status()? .json() .await?) } async fn stream_response( &self, endpoint: String, request: R, ) -> impl Stream> { let client = reqwest::Client::new(); Box::pin(stream! { let response = client .post(endpoint) .json(&request) .send() .await .map_err(OllamaError::from)?; // Adjust based on your error type let bytes_stream = response.bytes_stream(); let body_reader = StreamReader::new( bytes_stream.map(|res| res.map_err(std::io::Error::other)), ); let mut lines_stream = FramedRead::new(body_reader, LinesCodec::new()); while let Some(line_result) = lines_stream.next().await { match line_result { Ok(line_content) => { if let Ok(parsed) = serde_json::from_str::(&line_content) { yield Ok(parsed); } } Err(e) => yield Err(OllamaError::from(e)), } } }) } /// Generates a response for the provided prompt pub async fn generate( &self, request: GenerateRequest, ) -> impl Stream> { let request_address = format!("{}/api/generate", self.server_address); self.stream_response(request_address, request).await } /// Generate the next chat message in a conversation between a user and an assistant. pub async fn chat( &self, request: ChatRequest, ) -> impl Stream> { let request_address = format!("{}/api/chat", self.server_address); self.stream_response(request_address, request).await } /// Pull a model pub async fn pull( &self, request: PullRequest, ) -> impl Stream> { let request_address = format!("{}/api/pull", self.server_address); self.stream_response(request_address, request).await } }