Finishes the generate implementation
This commit is contained in:
23
Cargo.lock
generated
23
Cargo.lock
generated
@@ -2,6 +2,28 @@
|
|||||||
# It is not intended for manual editing.
|
# It is not intended for manual editing.
|
||||||
version = 4
|
version = 4
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "async-stream"
|
||||||
|
version = "0.3.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476"
|
||||||
|
dependencies = [
|
||||||
|
"async-stream-impl",
|
||||||
|
"futures-core",
|
||||||
|
"pin-project-lite",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "async-stream-impl"
|
||||||
|
version = "0.3.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "atomic-waker"
|
name = "atomic-waker"
|
||||||
version = "1.1.2"
|
version = "1.1.2"
|
||||||
@@ -612,6 +634,7 @@ dependencies = [
|
|||||||
name = "ollama-rs"
|
name = "ollama-rs"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"async-stream",
|
||||||
"dotenvy",
|
"dotenvy",
|
||||||
"futures-util",
|
"futures-util",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ serde_json = "1.0.146"
|
|||||||
tokio-util = "0.7.17"
|
tokio-util = "0.7.17"
|
||||||
tracing = "0.1.44"
|
tracing = "0.1.44"
|
||||||
futures-util = "0.3.31"
|
futures-util = "0.3.31"
|
||||||
|
async-stream = "0.3.6"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
dotenvy = "0.15.7"
|
dotenvy = "0.15.7"
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
use std::{env, error::Error};
|
use std::{env, error::Error, io::Write};
|
||||||
|
|
||||||
|
use futures_util::StreamExt;
|
||||||
use ollama_rs::{OllamaClient, types::generate::GenerateRequest};
|
use ollama_rs::{OllamaClient, types::generate::GenerateRequest};
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
@@ -8,9 +9,21 @@ async fn main() -> Result<(), Box<dyn Error>> {
|
|||||||
let server_address = env::var("OLLAMA_SERVER")?;
|
let server_address = env::var("OLLAMA_SERVER")?;
|
||||||
let ollama_client = OllamaClient::new(server_address);
|
let ollama_client = OllamaClient::new(server_address);
|
||||||
let request = GenerateRequest::builder("dolphin3:8b")
|
let request = GenerateRequest::builder("dolphin3:8b")
|
||||||
|
.system_prompt("You a role play character called Gerald. You are a dumb person who things knows a lot but PROVIDES WRONG ANSWERS to all questions.")
|
||||||
.prompt("Why is the sky blue?")
|
.prompt("Why is the sky blue?")
|
||||||
.build();
|
.build();
|
||||||
let response = ollama_client.generate(request).await?;
|
|
||||||
println!("{:?}", response);
|
let mut stream = ollama_client.generate(request).await;
|
||||||
|
|
||||||
|
while let Some(response) = stream.next().await {
|
||||||
|
match response {
|
||||||
|
Ok(token) => {
|
||||||
|
print!("{}", token.response);
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
}
|
||||||
|
Err(e) => println!("Error: {}", e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
10
src/error.rs
10
src/error.rs
@@ -1,11 +1,14 @@
|
|||||||
use std::{error::Error, fmt::Display};
|
use std::{error::Error, fmt::Display};
|
||||||
|
|
||||||
|
use tokio_util::codec::LinesCodecError;
|
||||||
|
|
||||||
pub type OllamaResult<T> = Result<T, OllamaError>;
|
pub type OllamaResult<T> = Result<T, OllamaError>;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum OllamaError {
|
pub enum OllamaError {
|
||||||
NetworkError(reqwest::Error),
|
NetworkError(reqwest::Error),
|
||||||
ResponseParseError(serde_json::Error),
|
ResponseParseError(serde_json::Error),
|
||||||
|
LinesCoderError(LinesCodecError),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Error for OllamaError {}
|
impl Error for OllamaError {}
|
||||||
@@ -15,6 +18,7 @@ impl Display for OllamaError {
|
|||||||
match self {
|
match self {
|
||||||
OllamaError::NetworkError(e) => writeln!(f, "Network error: {}", e),
|
OllamaError::NetworkError(e) => writeln!(f, "Network error: {}", e),
|
||||||
OllamaError::ResponseParseError(e) => writeln!(f, "ResponseParseError error: {}", e),
|
OllamaError::ResponseParseError(e) => writeln!(f, "ResponseParseError error: {}", e),
|
||||||
|
OllamaError::LinesCoderError(e) => writeln!(f, "LinesCoderError error: {}", e),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -30,3 +34,9 @@ impl From<serde_json::Error> for OllamaError {
|
|||||||
Self::ResponseParseError(error)
|
Self::ResponseParseError(error)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<LinesCodecError> for OllamaError {
|
||||||
|
fn from(value: LinesCodecError) -> Self {
|
||||||
|
Self::LinesCoderError(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
68
src/lib.rs
68
src/lib.rs
@@ -1,10 +1,14 @@
|
|||||||
use futures_util::{StreamExt};
|
use async_stream::stream;
|
||||||
|
use futures_util::{Stream, StreamExt};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use tokio_util::io::StreamReader;
|
use tokio_util::{
|
||||||
|
codec::{FramedRead, LinesCodec},
|
||||||
|
io::StreamReader,
|
||||||
|
};
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
error::OllamaResult,
|
error::{OllamaError, OllamaResult},
|
||||||
types::{
|
types::{
|
||||||
generate::{GenerateRequest, GenerateResponse},
|
generate::{GenerateRequest, GenerateResponse},
|
||||||
ps::RunningModel,
|
ps::RunningModel,
|
||||||
@@ -71,32 +75,42 @@ impl OllamaClient {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Generates a response for the provided prompt
|
/// Generates a response for the provided prompt
|
||||||
pub async fn generate(&self, request: GenerateRequest) -> OllamaResult<()> {
|
pub async fn generate(
|
||||||
|
&self,
|
||||||
|
request: GenerateRequest,
|
||||||
|
) -> impl Stream<Item = OllamaResult<GenerateResponse>> {
|
||||||
let request_address = format!("{}/api/generate", self.server_address);
|
let request_address = format!("{}/api/generate", self.server_address);
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
let response = client
|
|
||||||
.post(request_address)
|
|
||||||
.json(&request)
|
|
||||||
.send()
|
|
||||||
.await?
|
|
||||||
.error_for_status()?;
|
|
||||||
let stream = response.bytes_stream().;
|
|
||||||
// let reader = BufReader::new(stream);
|
|
||||||
let reader = StreamReader(stream);
|
|
||||||
while reader
|
|
||||||
while let Some(item) = stream.next().await {
|
|
||||||
let item = item?;
|
|
||||||
println!("Chunk: {:?}", item?);
|
|
||||||
}
|
|
||||||
|
|
||||||
// let stream_reader = tokio_util::io::StreamReader::new(stream);
|
// The stream macro creates an asynchronous generator
|
||||||
// let reder = BufReader::new(stream);
|
Box::pin(stream! {
|
||||||
// let full_response = response.text().await?;
|
let response = client
|
||||||
// let parts = full_response
|
.post(request_address)
|
||||||
// .lines()
|
.json(&request)
|
||||||
// .map(|line| serde_json::from_str::<GenerateResponse>(line).unwrap())
|
.send()
|
||||||
// .collect::<Vec<_>>();
|
.await
|
||||||
// println!("{:#?}", parts);
|
.map_err(|e| OllamaError::from(e))?; // Adjust based on your error type
|
||||||
Ok(())
|
|
||||||
|
let bytes_stream = response.bytes_stream();
|
||||||
|
|
||||||
|
let body_reader = StreamReader::new(
|
||||||
|
bytes_stream.map(|res| res.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))),
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut lines_stream = FramedRead::new(body_reader, LinesCodec::new());
|
||||||
|
|
||||||
|
while let Some(line_result) = lines_stream.next().await {
|
||||||
|
match line_result {
|
||||||
|
Ok(line_content) => {
|
||||||
|
if let Ok(parsed) = serde_json::from_str::<GenerateResponse>(&line_content) {
|
||||||
|
let done = parsed.done;
|
||||||
|
yield Ok(parsed);
|
||||||
|
if done { break; }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => yield Err(OllamaError::from(e)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
0
src/types/chat.rs
Normal file
0
src/types/chat.rs
Normal file
@@ -41,6 +41,11 @@ impl GenerateRequestBuilder {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn system_prompt<P: Into<String>>(mut self, system_prompt: P) -> Self {
|
||||||
|
self.generate_request.system = Some(system_prompt.into());
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
pub fn prompt<P: Into<String>>(mut self, prompt: P) -> Self {
|
pub fn prompt<P: Into<String>>(mut self, prompt: P) -> Self {
|
||||||
self.generate_request.prompt = Some(prompt.into());
|
self.generate_request.prompt = Some(prompt.into());
|
||||||
self
|
self
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
pub mod chat;
|
||||||
pub mod common;
|
pub mod common;
|
||||||
pub mod generate;
|
pub mod generate;
|
||||||
pub mod ps;
|
pub mod ps;
|
||||||
|
|||||||
Reference in New Issue
Block a user