Finishes the generate implementation
This commit is contained in:
23
Cargo.lock
generated
23
Cargo.lock
generated
@@ -2,6 +2,28 @@
|
||||
# It is not intended for manual editing.
|
||||
version = 4
|
||||
|
||||
[[package]]
|
||||
name = "async-stream"
|
||||
version = "0.3.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476"
|
||||
dependencies = [
|
||||
"async-stream-impl",
|
||||
"futures-core",
|
||||
"pin-project-lite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-stream-impl"
|
||||
version = "0.3.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "atomic-waker"
|
||||
version = "1.1.2"
|
||||
@@ -612,6 +634,7 @@ dependencies = [
|
||||
name = "ollama-rs"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"async-stream",
|
||||
"dotenvy",
|
||||
"futures-util",
|
||||
"reqwest",
|
||||
|
||||
@@ -10,6 +10,7 @@ serde_json = "1.0.146"
|
||||
tokio-util = "0.7.17"
|
||||
tracing = "0.1.44"
|
||||
futures-util = "0.3.31"
|
||||
async-stream = "0.3.6"
|
||||
|
||||
[dev-dependencies]
|
||||
dotenvy = "0.15.7"
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use std::{env, error::Error};
|
||||
use std::{env, error::Error, io::Write};
|
||||
|
||||
use futures_util::StreamExt;
|
||||
use ollama_rs::{OllamaClient, types::generate::GenerateRequest};
|
||||
|
||||
#[tokio::main]
|
||||
@@ -8,9 +9,21 @@ async fn main() -> Result<(), Box<dyn Error>> {
|
||||
let server_address = env::var("OLLAMA_SERVER")?;
|
||||
let ollama_client = OllamaClient::new(server_address);
|
||||
let request = GenerateRequest::builder("dolphin3:8b")
|
||||
.system_prompt("You a role play character called Gerald. You are a dumb person who things knows a lot but PROVIDES WRONG ANSWERS to all questions.")
|
||||
.prompt("Why is the sky blue?")
|
||||
.build();
|
||||
let response = ollama_client.generate(request).await?;
|
||||
println!("{:?}", response);
|
||||
|
||||
let mut stream = ollama_client.generate(request).await;
|
||||
|
||||
while let Some(response) = stream.next().await {
|
||||
match response {
|
||||
Ok(token) => {
|
||||
print!("{}", token.response);
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
Err(e) => println!("Error: {}", e),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
10
src/error.rs
10
src/error.rs
@@ -1,11 +1,14 @@
|
||||
use std::{error::Error, fmt::Display};
|
||||
|
||||
use tokio_util::codec::LinesCodecError;
|
||||
|
||||
pub type OllamaResult<T> = Result<T, OllamaError>;
|
||||
|
||||
/// Error type carried by `OllamaResult`; each variant wraps the underlying
/// library error it was converted from.
#[derive(Debug)]
|
||||
pub enum OllamaError {
|
||||
/// Wraps a `reqwest::Error`.
NetworkError(reqwest::Error),
|
||||
/// Wraps a `serde_json::Error`.
ResponseParseError(serde_json::Error),
|
||||
/// Wraps a `tokio_util::codec::LinesCodecError`.
// NOTE(review): the variant is spelled `Coder` while the wrapped type is
// `LinesCodecError`; left unchanged here because the variant is public.
LinesCoderError(LinesCodecError),
|
||||
}
|
||||
|
||||
impl Error for OllamaError {}
|
||||
@@ -15,6 +18,7 @@ impl Display for OllamaError {
|
||||
match self {
|
||||
OllamaError::NetworkError(e) => writeln!(f, "Network error: {}", e),
|
||||
OllamaError::ResponseParseError(e) => writeln!(f, "ResponseParseError error: {}", e),
|
||||
OllamaError::LinesCoderError(e) => writeln!(f, "LinesCoderError error: {}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -30,3 +34,9 @@ impl From<serde_json::Error> for OllamaError {
|
||||
Self::ResponseParseError(error)
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts a `LinesCodecError` into an `OllamaError` so it can be propagated
/// with `?` or built explicitly through `OllamaError::from`.
impl From<LinesCodecError> for OllamaError {
|
||||
/// Wraps `value` in the `LinesCoderError` variant without altering it.
fn from(value: LinesCodecError) -> Self {
|
||||
Self::LinesCoderError(value)
|
||||
}
|
||||
}
|
||||
|
||||
60
src/lib.rs
60
src/lib.rs
@@ -1,10 +1,14 @@
|
||||
use futures_util::{StreamExt};
|
||||
use async_stream::stream;
|
||||
use futures_util::{Stream, StreamExt};
|
||||
use serde_json::Value;
|
||||
use tokio_util::io::StreamReader;
|
||||
use tokio_util::{
|
||||
codec::{FramedRead, LinesCodec},
|
||||
io::StreamReader,
|
||||
};
|
||||
use tracing::info;
|
||||
|
||||
use crate::{
|
||||
error::OllamaResult,
|
||||
error::{OllamaError, OllamaResult},
|
||||
types::{
|
||||
generate::{GenerateRequest, GenerateResponse},
|
||||
ps::RunningModel,
|
||||
@@ -71,32 +75,42 @@ impl OllamaClient {
|
||||
}
|
||||
|
||||
/// Generates a response for the provided prompt
|
||||
pub async fn generate(&self, request: GenerateRequest) -> OllamaResult<()> {
|
||||
pub async fn generate(
|
||||
&self,
|
||||
request: GenerateRequest,
|
||||
) -> impl Stream<Item = OllamaResult<GenerateResponse>> {
|
||||
let request_address = format!("{}/api/generate", self.server_address);
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
// The stream macro creates an asynchronous generator
|
||||
Box::pin(stream! {
|
||||
let response = client
|
||||
.post(request_address)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?
|
||||
.error_for_status()?;
|
||||
let stream = response.bytes_stream().;
|
||||
// let reader = BufReader::new(stream);
|
||||
let reader = StreamReader(stream);
|
||||
while reader
|
||||
while let Some(item) = stream.next().await {
|
||||
let item = item?;
|
||||
println!("Chunk: {:?}", item?);
|
||||
}
|
||||
.await
|
||||
.map_err(|e| OllamaError::from(e))?; // Adjust based on your error type
|
||||
|
||||
// let stream_reader = tokio_util::io::StreamReader::new(stream);
|
||||
// let reder = BufReader::new(stream);
|
||||
// let full_response = response.text().await?;
|
||||
// let parts = full_response
|
||||
// .lines()
|
||||
// .map(|line| serde_json::from_str::<GenerateResponse>(line).unwrap())
|
||||
// .collect::<Vec<_>>();
|
||||
// println!("{:#?}", parts);
|
||||
Ok(())
|
||||
let bytes_stream = response.bytes_stream();
|
||||
|
||||
let body_reader = StreamReader::new(
|
||||
bytes_stream.map(|res| res.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))),
|
||||
);
|
||||
|
||||
let mut lines_stream = FramedRead::new(body_reader, LinesCodec::new());
|
||||
|
||||
while let Some(line_result) = lines_stream.next().await {
|
||||
match line_result {
|
||||
Ok(line_content) => {
|
||||
if let Ok(parsed) = serde_json::from_str::<GenerateResponse>(&line_content) {
|
||||
let done = parsed.done;
|
||||
yield Ok(parsed);
|
||||
if done { break; }
|
||||
}
|
||||
}
|
||||
Err(e) => yield Err(OllamaError::from(e)),
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
0
src/types/chat.rs
Normal file
0
src/types/chat.rs
Normal file
@@ -41,6 +41,11 @@ impl GenerateRequestBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
/// Stores `system_prompt` (converted into a `String`) in the `system` field of
/// the request being built, replacing any previous value, and returns the
/// builder so further setters can be chained.
// NOTE(review): presumably this becomes the model's system message on the
// Ollama side — confirm against the generate API.
pub fn system_prompt<P: Into<String>>(mut self, system_prompt: P) -> Self {
|
||||
self.generate_request.system = Some(system_prompt.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn prompt<P: Into<String>>(mut self, prompt: P) -> Self {
|
||||
self.generate_request.prompt = Some(prompt.into());
|
||||
self
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod chat;
|
||||
pub mod common;
|
||||
pub mod generate;
|
||||
pub mod ps;
|
||||
|
||||
Reference in New Issue
Block a user