Finishes the generate implementation

This commit is contained in:
2025-12-24 15:58:23 +00:00
parent 8a0c10c6fc
commit d06d69d132
8 changed files with 97 additions and 30 deletions

23
Cargo.lock generated
View File

@@ -2,6 +2,28 @@
# It is not intended for manual editing. # It is not intended for manual editing.
version = 4 version = 4
[[package]]
name = "async-stream"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476"
dependencies = [
"async-stream-impl",
"futures-core",
"pin-project-lite",
]
[[package]]
name = "async-stream-impl"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "atomic-waker" name = "atomic-waker"
version = "1.1.2" version = "1.1.2"
@@ -612,6 +634,7 @@ dependencies = [
name = "ollama-rs" name = "ollama-rs"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"async-stream",
"dotenvy", "dotenvy",
"futures-util", "futures-util",
"reqwest", "reqwest",

View File

@@ -10,6 +10,7 @@ serde_json = "1.0.146"
tokio-util = "0.7.17" tokio-util = "0.7.17"
tracing = "0.1.44" tracing = "0.1.44"
futures-util = "0.3.31" futures-util = "0.3.31"
async-stream = "0.3.6"
[dev-dependencies] [dev-dependencies]
dotenvy = "0.15.7" dotenvy = "0.15.7"

View File

@@ -1,5 +1,6 @@
use std::{env, error::Error}; use std::{env, error::Error, io::Write};
use futures_util::StreamExt;
use ollama_rs::{OllamaClient, types::generate::GenerateRequest}; use ollama_rs::{OllamaClient, types::generate::GenerateRequest};
#[tokio::main] #[tokio::main]
@@ -8,9 +9,21 @@ async fn main() -> Result<(), Box<dyn Error>> {
let server_address = env::var("OLLAMA_SERVER")?; let server_address = env::var("OLLAMA_SERVER")?;
let ollama_client = OllamaClient::new(server_address); let ollama_client = OllamaClient::new(server_address);
let request = GenerateRequest::builder("dolphin3:8b") let request = GenerateRequest::builder("dolphin3:8b")
.system_prompt("You a role play character called Gerald. You are a dumb person who things knows a lot but PROVIDES WRONG ANSWERS to all questions.")
.prompt("Why is the sky blue?") .prompt("Why is the sky blue?")
.build(); .build();
let response = ollama_client.generate(request).await?;
println!("{:?}", response); let mut stream = ollama_client.generate(request).await;
while let Some(response) = stream.next().await {
match response {
Ok(token) => {
print!("{}", token.response);
std::io::stdout().flush()?;
}
Err(e) => println!("Error: {}", e),
}
}
Ok(()) Ok(())
} }

View File

@@ -1,11 +1,14 @@
use std::{error::Error, fmt::Display}; use std::{error::Error, fmt::Display};
use tokio_util::codec::LinesCodecError;
pub type OllamaResult<T> = Result<T, OllamaError>; pub type OllamaResult<T> = Result<T, OllamaError>;
#[derive(Debug)] #[derive(Debug)]
pub enum OllamaError { pub enum OllamaError {
NetworkError(reqwest::Error), NetworkError(reqwest::Error),
ResponseParseError(serde_json::Error), ResponseParseError(serde_json::Error),
LinesCoderError(LinesCodecError),
} }
impl Error for OllamaError {} impl Error for OllamaError {}
@@ -15,6 +18,7 @@ impl Display for OllamaError {
match self { match self {
OllamaError::NetworkError(e) => writeln!(f, "Network error: {}", e), OllamaError::NetworkError(e) => writeln!(f, "Network error: {}", e),
OllamaError::ResponseParseError(e) => writeln!(f, "ResponseParseError error: {}", e), OllamaError::ResponseParseError(e) => writeln!(f, "ResponseParseError error: {}", e),
OllamaError::LinesCoderError(e) => writeln!(f, "LinesCoderError error: {}", e),
} }
} }
} }
@@ -30,3 +34,9 @@ impl From<serde_json::Error> for OllamaError {
Self::ResponseParseError(error) Self::ResponseParseError(error)
} }
} }
impl From<LinesCodecError> for OllamaError {
fn from(value: LinesCodecError) -> Self {
Self::LinesCoderError(value)
}
}

View File

@@ -1,10 +1,14 @@
use futures_util::{StreamExt}; use async_stream::stream;
use futures_util::{Stream, StreamExt};
use serde_json::Value; use serde_json::Value;
use tokio_util::io::StreamReader; use tokio_util::{
codec::{FramedRead, LinesCodec},
io::StreamReader,
};
use tracing::info; use tracing::info;
use crate::{ use crate::{
error::OllamaResult, error::{OllamaError, OllamaResult},
types::{ types::{
generate::{GenerateRequest, GenerateResponse}, generate::{GenerateRequest, GenerateResponse},
ps::RunningModel, ps::RunningModel,
@@ -71,32 +75,42 @@ impl OllamaClient {
} }
/// Generates a response for the provided prompt /// Generates a response for the provided prompt
pub async fn generate(&self, request: GenerateRequest) -> OllamaResult<()> { pub async fn generate(
&self,
request: GenerateRequest,
) -> impl Stream<Item = OllamaResult<GenerateResponse>> {
let request_address = format!("{}/api/generate", self.server_address); let request_address = format!("{}/api/generate", self.server_address);
let client = reqwest::Client::new(); let client = reqwest::Client::new();
let response = client
.post(request_address)
.json(&request)
.send()
.await?
.error_for_status()?;
let stream = response.bytes_stream().;
// let reader = BufReader::new(stream);
let reader = StreamReader(stream);
while reader
while let Some(item) = stream.next().await {
let item = item?;
println!("Chunk: {:?}", item?);
}
// let stream_reader = tokio_util::io::StreamReader::new(stream); // The stream macro creates an asynchronous generator
// let reder = BufReader::new(stream); Box::pin(stream! {
// let full_response = response.text().await?; let response = client
// let parts = full_response .post(request_address)
// .lines() .json(&request)
// .map(|line| serde_json::from_str::<GenerateResponse>(line).unwrap()) .send()
// .collect::<Vec<_>>(); .await
// println!("{:#?}", parts); .map_err(|e| OllamaError::from(e))?; // Adjust based on your error type
Ok(())
let bytes_stream = response.bytes_stream();
let body_reader = StreamReader::new(
bytes_stream.map(|res| res.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))),
);
let mut lines_stream = FramedRead::new(body_reader, LinesCodec::new());
while let Some(line_result) = lines_stream.next().await {
match line_result {
Ok(line_content) => {
if let Ok(parsed) = serde_json::from_str::<GenerateResponse>(&line_content) {
let done = parsed.done;
yield Ok(parsed);
if done { break; }
}
}
Err(e) => yield Err(OllamaError::from(e)),
}
}
})
} }
} }

0
src/types/chat.rs Normal file
View File

View File

@@ -41,6 +41,11 @@ impl GenerateRequestBuilder {
} }
} }
/// Stores `system_prompt` in the request's `system` field, replacing any
/// previous value, and returns the builder so calls can be chained.
pub fn system_prompt<P: Into<String>>(mut self, system_prompt: P) -> Self {
    let system: String = system_prompt.into();
    self.generate_request.system = Some(system);
    self
}
pub fn prompt<P: Into<String>>(mut self, prompt: P) -> Self { pub fn prompt<P: Into<String>>(mut self, prompt: P) -> Self {
self.generate_request.prompt = Some(prompt.into()); self.generate_request.prompt = Some(prompt.into());
self self

View File

@@ -1,3 +1,4 @@
pub mod chat;
pub mod common; pub mod common;
pub mod generate; pub mod generate;
pub mod ps; pub mod ps;