Unifies streaming code

This commit is contained in:
2025-12-28 12:50:47 +00:00
parent ac480881e4
commit 1c431715c6
3 changed files with 26 additions and 75 deletions

View File

@@ -39,6 +39,9 @@ async fn main() -> Result<(), Box<dyn Error>> {
full_message += &response.message.content; full_message += &response.message.content;
print!("{}", response.message.content); print!("{}", response.message.content);
std::io::stdout().flush()?; std::io::stdout().flush()?;
if response.done {
break;
}
} }
println!(); println!();

View File

@@ -21,6 +21,9 @@ async fn main() -> Result<(), Box<dyn Error>> {
Ok(token) => { Ok(token) => {
print!("{}", token.response); print!("{}", token.response);
std::io::stdout().flush()?; std::io::stdout().flush()?;
if token.done {
break;
}
} }
Err(e) => println!("Error: {}", e), Err(e) => println!("Error: {}", e),
} }

View File

@@ -1,5 +1,6 @@
use async_stream::stream; use async_stream::stream;
use futures_util::{Stream, StreamExt}; use futures_util::{Stream, StreamExt};
use serde::{Serialize, de::DeserializeOwned};
use serde_json::Value; use serde_json::Value;
use tokio_util::{ use tokio_util::{
codec::{FramedRead, LinesCodec}, codec::{FramedRead, LinesCodec},
@@ -76,27 +77,24 @@ impl OllamaClient {
Ok(models) Ok(models)
} }
/// Generates a response for the provided prompt async fn stream_response<R: Serialize, T: DeserializeOwned>(
pub async fn generate(
&self, &self,
request: GenerateRequest, endpoint: String,
) -> impl Stream<Item = OllamaResult<GenerateResponse>> { request: R,
let request_address = format!("{}/api/generate", self.server_address); ) -> impl Stream<Item = OllamaResult<T>> {
let client = reqwest::Client::new(); let client = reqwest::Client::new();
// The stream macro creates an asynchronous generator
Box::pin(stream! { Box::pin(stream! {
let response = client let response = client
.post(request_address) .post(endpoint)
.json(&request) .json(&request)
.send() .send()
.await .await
.map_err(|e| OllamaError::from(e))?; // Adjust based on your error type .map_err(OllamaError::from)?; // Adjust based on your error type
let bytes_stream = response.bytes_stream(); let bytes_stream = response.bytes_stream();
let body_reader = StreamReader::new( let body_reader = StreamReader::new(
bytes_stream.map(|res| res.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))), bytes_stream.map(|res| res.map_err(std::io::Error::other)),
); );
let mut lines_stream = FramedRead::new(body_reader, LinesCodec::new()); let mut lines_stream = FramedRead::new(body_reader, LinesCodec::new());
@@ -104,10 +102,8 @@ impl OllamaClient {
while let Some(line_result) = lines_stream.next().await { while let Some(line_result) = lines_stream.next().await {
match line_result { match line_result {
Ok(line_content) => { Ok(line_content) => {
if let Ok(parsed) = serde_json::from_str::<GenerateResponse>(&line_content) { if let Ok(parsed) = serde_json::from_str::<T>(&line_content) {
let done = parsed.done;
yield Ok(parsed); yield Ok(parsed);
if done { break; }
} }
} }
Err(e) => yield Err(OllamaError::from(e)), Err(e) => yield Err(OllamaError::from(e)),
@@ -116,44 +112,22 @@ impl OllamaClient {
}) })
} }
/// Generates a response for the provided prompt.
///
/// Builds the `/api/generate` URL from `self.server_address` and hands the
/// request to `stream_response`, returning the stream it produces. Each
/// stream item is an `OllamaResult<GenerateResponse>`.
pub async fn generate(
    &self,
    request: GenerateRequest,
) -> impl Stream<Item = OllamaResult<GenerateResponse>> {
    // The endpoint string is built inline; `stream_response` takes ownership of it.
    self.stream_response(format!("{}/api/generate", self.server_address), request)
        .await
}
/// Generate the next chat message in a conversation between a user and an assistant. /// Generate the next chat message in a conversation between a user and an assistant.
pub async fn chat( pub async fn chat(
&self, &self,
request: ChatRequest, request: ChatRequest,
) -> impl Stream<Item = OllamaResult<ChatResponse>> { ) -> impl Stream<Item = OllamaResult<ChatResponse>> {
let request_address = format!("{}/api/chat", self.server_address); let request_address = format!("{}/api/chat", self.server_address);
let client = reqwest::Client::new(); self.stream_response(request_address, request).await
// The stream macro creates an asynchronous generator
Box::pin(stream! {
let response = client
.post(request_address)
.json(&request)
.send()
.await
.map_err(|e| OllamaError::from(e))?; // Adjust based on your error type
let bytes_stream = response.bytes_stream();
let body_reader = StreamReader::new(
bytes_stream.map(|res| res.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))),
);
let mut lines_stream = FramedRead::new(body_reader, LinesCodec::new());
while let Some(line_result) = lines_stream.next().await {
match line_result {
Ok(line_content) => {
if let Ok(parsed) = serde_json::from_str::<ChatResponse>(&line_content) {
let done = parsed.done;
yield Ok(parsed);
if done { break; }
}
}
Err(e) => yield Err(OllamaError::from(e)),
}
}
})
} }
pub async fn pull( pub async fn pull(
@@ -161,35 +135,6 @@ impl OllamaClient {
request: PullRequest, request: PullRequest,
) -> impl Stream<Item = OllamaResult<PullResponse>> { ) -> impl Stream<Item = OllamaResult<PullResponse>> {
let request_address = format!("{}/api/pull", self.server_address); let request_address = format!("{}/api/pull", self.server_address);
let client = reqwest::Client::new(); self.stream_response(request_address, request).await
Box::pin(stream! {
let response = client
.post(request_address)
.json(&request)
.send()
.await
.map_err(|e| OllamaError::from(e))?; // Adjust based on your error type
let bytes_stream = response.bytes_stream();
let body_reader = StreamReader::new(
bytes_stream.map(|res| res.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))),
);
let mut lines_stream = FramedRead::new(body_reader, LinesCodec::new());
while let Some(line_result) = lines_stream.next().await {
match line_result {
Ok(line_content) => {
println!("{line_content}");
if let Ok(parsed) = serde_json::from_str::<PullResponse>(&line_content) {
yield Ok(parsed);
}
}
Err(e) => yield Err(OllamaError::from(e)),
}
}
})
} }
} }