Unifies streaming code
This commit is contained in:
@@ -39,6 +39,9 @@ async fn main() -> Result<(), Box<dyn Error>> {
|
|||||||
full_message += &response.message.content;
|
full_message += &response.message.content;
|
||||||
print!("{}", response.message.content);
|
print!("{}", response.message.content);
|
||||||
std::io::stdout().flush()?;
|
std::io::stdout().flush()?;
|
||||||
|
if response.done {
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
println!();
|
println!();
|
||||||
|
|
||||||
|
|||||||
@@ -21,6 +21,9 @@ async fn main() -> Result<(), Box<dyn Error>> {
|
|||||||
Ok(token) => {
|
Ok(token) => {
|
||||||
print!("{}", token.response);
|
print!("{}", token.response);
|
||||||
std::io::stdout().flush()?;
|
std::io::stdout().flush()?;
|
||||||
|
if token.done {
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Err(e) => println!("Error: {}", e),
|
Err(e) => println!("Error: {}", e),
|
||||||
}
|
}
|
||||||
|
|||||||
95
src/lib.rs
95
src/lib.rs
@@ -1,5 +1,6 @@
|
|||||||
use async_stream::stream;
|
use async_stream::stream;
|
||||||
use futures_util::{Stream, StreamExt};
|
use futures_util::{Stream, StreamExt};
|
||||||
|
use serde::{Serialize, de::DeserializeOwned};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use tokio_util::{
|
use tokio_util::{
|
||||||
codec::{FramedRead, LinesCodec},
|
codec::{FramedRead, LinesCodec},
|
||||||
@@ -76,27 +77,24 @@ impl OllamaClient {
|
|||||||
Ok(models)
|
Ok(models)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates a response for the provided prompt
|
async fn stream_response<R: Serialize, T: DeserializeOwned>(
|
||||||
pub async fn generate(
|
|
||||||
&self,
|
&self,
|
||||||
request: GenerateRequest,
|
endpoint: String,
|
||||||
) -> impl Stream<Item = OllamaResult<GenerateResponse>> {
|
request: R,
|
||||||
let request_address = format!("{}/api/generate", self.server_address);
|
) -> impl Stream<Item = OllamaResult<T>> {
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
|
|
||||||
// The stream macro creates an asynchronous generator
|
|
||||||
Box::pin(stream! {
|
Box::pin(stream! {
|
||||||
let response = client
|
let response = client
|
||||||
.post(request_address)
|
.post(endpoint)
|
||||||
.json(&request)
|
.json(&request)
|
||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
.map_err(|e| OllamaError::from(e))?; // Adjust based on your error type
|
.map_err(OllamaError::from)?; // Adjust based on your error type
|
||||||
|
|
||||||
let bytes_stream = response.bytes_stream();
|
let bytes_stream = response.bytes_stream();
|
||||||
|
|
||||||
let body_reader = StreamReader::new(
|
let body_reader = StreamReader::new(
|
||||||
bytes_stream.map(|res| res.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))),
|
bytes_stream.map(|res| res.map_err(std::io::Error::other)),
|
||||||
);
|
);
|
||||||
|
|
||||||
let mut lines_stream = FramedRead::new(body_reader, LinesCodec::new());
|
let mut lines_stream = FramedRead::new(body_reader, LinesCodec::new());
|
||||||
@@ -104,10 +102,8 @@ impl OllamaClient {
|
|||||||
while let Some(line_result) = lines_stream.next().await {
|
while let Some(line_result) = lines_stream.next().await {
|
||||||
match line_result {
|
match line_result {
|
||||||
Ok(line_content) => {
|
Ok(line_content) => {
|
||||||
if let Ok(parsed) = serde_json::from_str::<GenerateResponse>(&line_content) {
|
if let Ok(parsed) = serde_json::from_str::<T>(&line_content) {
|
||||||
let done = parsed.done;
|
|
||||||
yield Ok(parsed);
|
yield Ok(parsed);
|
||||||
if done { break; }
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => yield Err(OllamaError::from(e)),
|
Err(e) => yield Err(OllamaError::from(e)),
|
||||||
@@ -116,44 +112,22 @@ impl OllamaClient {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Generates a response for the provided prompt
|
||||||
|
pub async fn generate(
|
||||||
|
&self,
|
||||||
|
request: GenerateRequest,
|
||||||
|
) -> impl Stream<Item = OllamaResult<GenerateResponse>> {
|
||||||
|
let request_address = format!("{}/api/generate", self.server_address);
|
||||||
|
self.stream_response(request_address, request).await
|
||||||
|
}
|
||||||
|
|
||||||
/// Generate the next chat message in a conversation between a user and an assistant.
|
/// Generate the next chat message in a conversation between a user and an assistant.
|
||||||
pub async fn chat(
|
pub async fn chat(
|
||||||
&self,
|
&self,
|
||||||
request: ChatRequest,
|
request: ChatRequest,
|
||||||
) -> impl Stream<Item = OllamaResult<ChatResponse>> {
|
) -> impl Stream<Item = OllamaResult<ChatResponse>> {
|
||||||
let request_address = format!("{}/api/chat", self.server_address);
|
let request_address = format!("{}/api/chat", self.server_address);
|
||||||
let client = reqwest::Client::new();
|
self.stream_response(request_address, request).await
|
||||||
|
|
||||||
// The stream macro creates an asynchronous generator
|
|
||||||
Box::pin(stream! {
|
|
||||||
let response = client
|
|
||||||
.post(request_address)
|
|
||||||
.json(&request)
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.map_err(|e| OllamaError::from(e))?; // Adjust based on your error type
|
|
||||||
|
|
||||||
let bytes_stream = response.bytes_stream();
|
|
||||||
|
|
||||||
let body_reader = StreamReader::new(
|
|
||||||
bytes_stream.map(|res| res.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))),
|
|
||||||
);
|
|
||||||
|
|
||||||
let mut lines_stream = FramedRead::new(body_reader, LinesCodec::new());
|
|
||||||
|
|
||||||
while let Some(line_result) = lines_stream.next().await {
|
|
||||||
match line_result {
|
|
||||||
Ok(line_content) => {
|
|
||||||
if let Ok(parsed) = serde_json::from_str::<ChatResponse>(&line_content) {
|
|
||||||
let done = parsed.done;
|
|
||||||
yield Ok(parsed);
|
|
||||||
if done { break; }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => yield Err(OllamaError::from(e)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn pull(
|
pub async fn pull(
|
||||||
@@ -161,35 +135,6 @@ impl OllamaClient {
|
|||||||
request: PullRequest,
|
request: PullRequest,
|
||||||
) -> impl Stream<Item = OllamaResult<PullResponse>> {
|
) -> impl Stream<Item = OllamaResult<PullResponse>> {
|
||||||
let request_address = format!("{}/api/pull", self.server_address);
|
let request_address = format!("{}/api/pull", self.server_address);
|
||||||
let client = reqwest::Client::new();
|
self.stream_response(request_address, request).await
|
||||||
|
|
||||||
Box::pin(stream! {
|
|
||||||
let response = client
|
|
||||||
.post(request_address)
|
|
||||||
.json(&request)
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.map_err(|e| OllamaError::from(e))?; // Adjust based on your error type
|
|
||||||
|
|
||||||
let bytes_stream = response.bytes_stream();
|
|
||||||
|
|
||||||
let body_reader = StreamReader::new(
|
|
||||||
bytes_stream.map(|res| res.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))),
|
|
||||||
);
|
|
||||||
|
|
||||||
let mut lines_stream = FramedRead::new(body_reader, LinesCodec::new());
|
|
||||||
|
|
||||||
while let Some(line_result) = lines_stream.next().await {
|
|
||||||
match line_result {
|
|
||||||
Ok(line_content) => {
|
|
||||||
println!("{line_content}");
|
|
||||||
if let Ok(parsed) = serde_json::from_str::<PullResponse>(&line_content) {
|
|
||||||
yield Ok(parsed);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => yield Err(OllamaError::from(e)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user