Drops the reqwest_eventsource dependency
reqwest_eventsource is taking a while to update to reqwest 0.13.*, so this PR implements SSE handling in the library code instead.
This commit is contained in:
167
src/client.rs
167
src/client.rs
@@ -1,22 +1,14 @@
|
||||
use crate::error::Result as GeminiResult;
|
||||
use std::sync::Arc;
|
||||
use crate::dialogue::Message;
|
||||
use crate::error::{Error as GeminiError, Result as GeminiResult};
|
||||
use crate::network::event_source::{EventSource, ServerSentEvent};
|
||||
use crate::prelude::*;
|
||||
use crate::types::{PredictImageRequest, PredictImageResponse, Role};
|
||||
|
||||
use std::vec;
|
||||
use tokio_stream::{Stream, StreamExt};
|
||||
|
||||
use deadqueue::unlimited::Queue;
|
||||
use reqwest_eventsource::{Event, EventSource};
|
||||
use tokio_util::codec::LinesCodecError;
|
||||
use tracing::error;
|
||||
|
||||
use crate::dialogue::Message;
|
||||
use crate::error::{Error, Result};
|
||||
use crate::prelude::Part;
|
||||
use crate::prelude::{
|
||||
Candidate, Content, CountTokensRequest, CountTokensResponse, GenerateContentRequest,
|
||||
GenerateContentResponse, GenerateContentResponseResult, TextEmbeddingRequest,
|
||||
TextEmbeddingResponse,
|
||||
};
|
||||
use crate::types::{PredictImageRequest, PredictImageResponse, Role};
|
||||
|
||||
pub static AUTH_SCOPE: &[&str] = &["https://www.googleapis.com/auth/cloud-platform"];
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
@@ -36,126 +28,43 @@ impl GeminiClient {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn generate_content_stream(
|
||||
&self,
|
||||
request: &GenerateContentRequest,
|
||||
model: &str,
|
||||
) -> Result<impl Stream<Item = GeminiResult<GenerateContentResponseResult>>> {
|
||||
let endpoint_url = format!(
|
||||
"https://generativelanguage.googleapis.com/v1beta/models/{model}:streamGenerateContent?alt=sse"
|
||||
);
|
||||
let client = self.client.clone();
|
||||
let request = request.clone();
|
||||
let req = client
|
||||
.post(&endpoint_url)
|
||||
.header("x-goog-api-key", &self.api_key)
|
||||
.json(&request);
|
||||
|
||||
let event_source = EventSource::new(req).unwrap();
|
||||
|
||||
let mapped = event_source.filter_map(|event| {
|
||||
let event = match event {
|
||||
Ok(event) => event,
|
||||
Err(reqwest_eventsource::Error::StreamEnded) => {
|
||||
return Some(Err(Error::EventSourceClosedError));
|
||||
}
|
||||
Err(e) => return Some(Err(e.into())),
|
||||
};
|
||||
|
||||
let Event::Message(event_message) = event else {
|
||||
return None;
|
||||
};
|
||||
|
||||
let gemini_response: GenerateContentResponse =
|
||||
match serde_json::from_str(&event_message.data) {
|
||||
Ok(gemini_response) => gemini_response,
|
||||
Err(e) => return Some(Err(e.into())),
|
||||
};
|
||||
|
||||
let gemini_response = match gemini_response.into_result() {
|
||||
Ok(gemini_response) => gemini_response,
|
||||
Err(e) => return Some(Err(e)),
|
||||
};
|
||||
|
||||
Some(Ok(gemini_response))
|
||||
});
|
||||
Ok(mapped)
|
||||
}
|
||||
|
||||
pub async fn stream_generate_content(
|
||||
&self,
|
||||
request: &GenerateContentRequest,
|
||||
model: &str,
|
||||
) -> Arc<Queue<Option<Result<GenerateContentResponseResult>>>> {
|
||||
let queue = Arc::new(Queue::<Option<Result<GenerateContentResponseResult>>>::new());
|
||||
|
||||
// Clone the queue and other necessary data to move into the async block.
|
||||
let cloned_queue = queue.clone();
|
||||
) -> GeminiResult<impl Stream<Item = GeminiResult<GenerateContentResponseResult>>> {
|
||||
let endpoint_url = format!(
|
||||
"https://generativelanguage.googleapis.com/v1beta/models/{model}:streamGenerateContent?alt=sse"
|
||||
);
|
||||
let client = self.client.clone();
|
||||
let request = request.clone();
|
||||
Ok(client
|
||||
.post(&endpoint_url)
|
||||
.header("x-goog-api-key", &self.api_key)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?
|
||||
.event_stream()
|
||||
.filter_map(Self::parse_event))
|
||||
}
|
||||
|
||||
let api_key = self.api_key.clone();
|
||||
// Start a thread to run the request in the background.
|
||||
tokio::spawn(async move {
|
||||
let req = client
|
||||
.post(&endpoint_url)
|
||||
.header("x-goog-api-key", api_key)
|
||||
.json(&request);
|
||||
fn parse_event(
|
||||
event_result: std::result::Result<ServerSentEvent, LinesCodecError>,
|
||||
) -> Option<GeminiResult<GenerateContentResponseResult>> {
|
||||
let data = event_result.map_err(Into::<GeminiError>::into).ok()?.data?;
|
||||
|
||||
let mut event_source = match EventSource::new(req) {
|
||||
Ok(event_source) => event_source,
|
||||
Err(e) => {
|
||||
cloned_queue.push(Some(Err(e.into())));
|
||||
return;
|
||||
}
|
||||
};
|
||||
while let Some(event) = event_source.next().await {
|
||||
match event {
|
||||
Ok(event) => {
|
||||
if let Event::Message(event) = event {
|
||||
let response: serde_json::error::Result<GenerateContentResponse> =
|
||||
serde_json::from_str(&event.data);
|
||||
|
||||
match response {
|
||||
Ok(response) => {
|
||||
let result = response.into_result();
|
||||
let finished = match &result {
|
||||
Ok(result) => result.candidates[0].finish_reason.is_some(),
|
||||
Err(_) => true,
|
||||
};
|
||||
cloned_queue.push(Some(result));
|
||||
if finished {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
tracing::error!("Error parsing message: {}", event.data);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Error in event source: {:?}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
cloned_queue.push(None);
|
||||
});
|
||||
|
||||
// Return the queue that will receive the responses.
|
||||
queue
|
||||
Some(
|
||||
serde_json::from_str::<GenerateContentResponse>(&data)
|
||||
.map_err(Into::into)
|
||||
.and_then(|resp| resp.into_result()),
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn generate_content(
|
||||
&self,
|
||||
request: &GenerateContentRequest,
|
||||
model: &str,
|
||||
) -> Result<GenerateContentResponseResult> {
|
||||
) -> GeminiResult<GenerateContentResponseResult> {
|
||||
let endpoint_url: String = format!(
|
||||
"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent",
|
||||
);
|
||||
@@ -175,10 +84,10 @@ impl GeminiClient {
|
||||
if let Ok(gemini_error) =
|
||||
serde_json::from_str::<crate::types::GeminiApiError>(&txt_json)
|
||||
{
|
||||
return Err(Error::GeminiError(gemini_error));
|
||||
return Err(GeminiError::GeminiError(gemini_error));
|
||||
}
|
||||
// Fallback if parsing fails, though it should ideally match GeminiApiError
|
||||
return Err(Error::GenericApiError {
|
||||
return Err(GeminiError::GenericApiError {
|
||||
status: status.as_u16(),
|
||||
body: txt_json,
|
||||
});
|
||||
@@ -194,7 +103,11 @@ impl GeminiClient {
|
||||
}
|
||||
|
||||
/// Prompts a conversation to the model.
|
||||
pub async fn prompt_conversation(&self, messages: &[Message], model: &str) -> Result<Message> {
|
||||
pub async fn prompt_conversation(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
model: &str,
|
||||
) -> GeminiResult<Message> {
|
||||
let request = GenerateContentRequest {
|
||||
contents: messages
|
||||
.iter()
|
||||
@@ -216,7 +129,7 @@ impl GeminiClient {
|
||||
|
||||
match candidates.pop() {
|
||||
Some(text) => Ok(Message::new(Role::Model, &text)),
|
||||
None => Err(Error::NoCandidatesError),
|
||||
None => Err(GeminiError::NoCandidatesError),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -232,7 +145,7 @@ impl GeminiClient {
|
||||
&self,
|
||||
request: &TextEmbeddingRequest,
|
||||
model: &str,
|
||||
) -> Result<TextEmbeddingResponse> {
|
||||
) -> GeminiResult<TextEmbeddingResponse> {
|
||||
let endpoint_url =
|
||||
format!("https://generativelanguage.googleapis.com/v1beta/models/{model}:predict");
|
||||
let resp = self
|
||||
@@ -251,7 +164,7 @@ impl GeminiClient {
|
||||
&self,
|
||||
request: &CountTokensRequest,
|
||||
model: &str,
|
||||
) -> Result<CountTokensResponse> {
|
||||
) -> GeminiResult<CountTokensResponse> {
|
||||
let endpoint_url =
|
||||
format!("https://generativelanguage.googleapis.com/v1beta/models/{model}:countTokens");
|
||||
let resp = self
|
||||
@@ -271,7 +184,7 @@ impl GeminiClient {
|
||||
&self,
|
||||
request: &PredictImageRequest,
|
||||
model: &str,
|
||||
) -> Result<PredictImageResponse> {
|
||||
) -> GeminiResult<PredictImageResponse> {
|
||||
let endpoint_url =
|
||||
format!("https://generativelanguage.googleapis.com/v1beta/models/{model}:predict");
|
||||
|
||||
@@ -290,9 +203,9 @@ impl GeminiClient {
|
||||
if let Ok(gemini_error) =
|
||||
serde_json::from_str::<crate::types::GeminiApiError>(&txt_json)
|
||||
{
|
||||
return Err(Error::GeminiError(gemini_error));
|
||||
return Err(GeminiError::GeminiError(gemini_error));
|
||||
}
|
||||
return Err(Error::GenericApiError {
|
||||
return Err(GeminiError::GenericApiError {
|
||||
status: status.as_u16(),
|
||||
body: txt_json,
|
||||
});
|
||||
|
||||
20
src/error.rs
20
src/error.rs
@@ -1,6 +1,6 @@
|
||||
use std::fmt::Display;
|
||||
|
||||
use reqwest_eventsource::CannotCloneRequestError;
|
||||
use tokio_util::codec::LinesCodecError;
|
||||
|
||||
use crate::types;
|
||||
|
||||
@@ -14,8 +14,7 @@ pub enum Error {
|
||||
VertexError(types::VertexApiError),
|
||||
GeminiError(types::GeminiApiError),
|
||||
NoCandidatesError,
|
||||
CannotCloneRequestError(CannotCloneRequestError),
|
||||
EventSourceError(Box<reqwest_eventsource::Error>),
|
||||
EventSourceError(LinesCodecError),
|
||||
EventSourceClosedError,
|
||||
GenericApiError { status: u16, body: String },
|
||||
}
|
||||
@@ -35,9 +34,6 @@ impl Display for Error {
|
||||
Error::NoCandidatesError => {
|
||||
write!(f, "No candidates returned for the prompt")
|
||||
}
|
||||
Error::CannotCloneRequestError(e) => {
|
||||
write!(f, "Cannot clone request: {e}")
|
||||
}
|
||||
Error::EventSourceError(e) => {
|
||||
write!(f, "EventSourrce Error: {e}")
|
||||
}
|
||||
@@ -83,14 +79,8 @@ impl From<types::GeminiApiError> for Error {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<CannotCloneRequestError> for Error {
|
||||
fn from(e: CannotCloneRequestError) -> Self {
|
||||
Error::CannotCloneRequestError(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<reqwest_eventsource::Error> for Error {
|
||||
fn from(e: reqwest_eventsource::Error) -> Self {
|
||||
Error::EventSourceError(Box::new(e))
|
||||
impl From<LinesCodecError> for Error {
|
||||
fn from(e: LinesCodecError) -> Self {
|
||||
Error::EventSourceError(e)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,6 +35,12 @@ pub struct ServerSentEventsCodec {
|
||||
next: ServerSentEvent,
|
||||
}
|
||||
|
||||
impl Default for ServerSentEventsCodec {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerSentEventsCodec {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
|
||||
Reference in New Issue
Block a user