Compare commits

..

18 Commits

Author SHA1 Message Date
e2c0c67b17 Merge pull request #1 from andreban/dependabot/cargo/bytes-1.11.1
Some checks are pending
CI / Check (push) Waiting to run
CI / Test (push) Waiting to run
CI / Clippy (push) Waiting to run
CI / Format (push) Waiting to run
CI / Documentation (push) Waiting to run
Bump bytes from 1.11.0 to 1.11.1
2026-04-03 17:23:57 +01:00
c4dd835594 Merge pull request #2 from andreban/dependabot/cargo/time-0.3.47
Bump time from 0.3.44 to 0.3.47
2026-04-03 17:23:48 +01:00
95fb24d5f1 Merge pull request #3 from andreban/dependabot/cargo/quinn-proto-0.11.14
Bump quinn-proto from 0.11.13 to 0.11.14
2026-04-03 17:23:31 +01:00
f7dc6ead57 Merge pull request #4 from andreban/dependabot/cargo/rustls-webpki-0.103.10
Bump rustls-webpki from 0.103.8 to 0.103.10
2026-04-03 17:23:18 +01:00
5e0fc06327 Chore: update dependencies
Some checks failed
CI / Check (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Clippy (push) Has been cancelled
CI / Format (push) Has been cancelled
CI / Documentation (push) Has been cancelled
2026-04-03 14:32:12 +01:00
dependabot[bot]
71e2440e0b Bump rustls-webpki from 0.103.8 to 0.103.10
Bumps [rustls-webpki](https://github.com/rustls/webpki) from 0.103.8 to 0.103.10.
- [Release notes](https://github.com/rustls/webpki/releases)
- [Commits](https://github.com/rustls/webpki/compare/v/0.103.8...v/0.103.10)

---
updated-dependencies:
- dependency-name: rustls-webpki
  dependency-version: 0.103.10
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-03-21 09:37:08 +00:00
dependabot[bot]
d9667c29f7 Bump quinn-proto from 0.11.13 to 0.11.14
Bumps [quinn-proto](https://github.com/quinn-rs/quinn) from 0.11.13 to 0.11.14.
- [Release notes](https://github.com/quinn-rs/quinn/releases)
- [Commits](https://github.com/quinn-rs/quinn/compare/quinn-proto-0.11.13...quinn-proto-0.11.14)

---
updated-dependencies:
- dependency-name: quinn-proto
  dependency-version: 0.11.14
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-03-11 00:22:01 +00:00
dependabot[bot]
7da2d82b49 Bump time from 0.3.44 to 0.3.47
Bumps [time](https://github.com/time-rs/time) from 0.3.44 to 0.3.47.
- [Release notes](https://github.com/time-rs/time/releases)
- [Changelog](https://github.com/time-rs/time/blob/main/CHANGELOG.md)
- [Commits](https://github.com/time-rs/time/compare/v0.3.44...v0.3.47)

---
updated-dependencies:
- dependency-name: time
  dependency-version: 0.3.47
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-02-05 18:59:16 +00:00
dependabot[bot]
622156591c Bump bytes from 1.11.0 to 1.11.1
Bumps [bytes](https://github.com/tokio-rs/bytes) from 1.11.0 to 1.11.1.
- [Release notes](https://github.com/tokio-rs/bytes/releases)
- [Changelog](https://github.com/tokio-rs/bytes/blob/master/CHANGELOG.md)
- [Commits](https://github.com/tokio-rs/bytes/compare/v1.11.0...v1.11.1)

---
updated-dependencies:
- dependency-name: bytes
  dependency-version: 1.11.1
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-02-03 19:42:26 +00:00
2ed8d881c6 Add GitHub Actions CI workflow
Runs check, test, clippy, format, and documentation jobs on pushes
and pull requests targeting main.
2026-01-31 07:00:43 +00:00
c4acf465ba Add rustdoc comments to all public items
Document all public structs, enums, traits, functions, fields, and
variants across the library. Adds crate-level documentation with a
usage example, module-level docs, and API reference links where
applicable. Also fixes four bare URL warnings in existing doc comments.
2026-01-31 06:59:08 +00:00
f8a4323117 Fix SSE data field to append instead of overwrite
Per the SSE specification, multiple data: lines within a single event
should be concatenated with newline separators. Previously, each data:
line overwrote the previous value.
2026-01-31 06:40:45 +00:00
d1bd00ce95 Adds missing Clone/Debugs 2026-01-30 20:39:35 +00:00
a8fbe658bb Add consistent error handling to text_embeddings and count_tokens
- Check HTTP status before parsing response body, matching the
  pattern used by generate_content and predict_image
- Unwrap TextEmbeddingResponse enum, returning TextEmbeddingResponseOk
- Extract CountTokensResponseResult struct and add into_result(),
  returning the unwrapped result instead of the raw enum
- All endpoints now consistently return the success type directly
  and surface API errors as GeminiError or GenericApiError
2026-01-30 20:32:40 +00:00
eb38c65ac5 Remove dead code and fix typo
- Remove unused AUTH_SCOPE constant from client.rs
- Remove unused ErrorType, ErrorInfoMetadata, FieldViolation,
  and Link types from types/error.rs
- Fix "EventSourrce" typo in error.rs Display impl
2026-01-30 20:29:55 +00:00
4c156fbb33 Remove unnecessary unsafe impl Send/Sync for GeminiClient
Both reqwest::Client and String already implement Send and Sync,
so the manual unsafe impls were redundant.
2026-01-30 20:27:54 +00:00
92030a0dd9 Remove prompt_conversation and dialogue module
Drop prompt_conversation, collect_text_from_response, Message,
and Dialogue to redesign the conversation feature later.
2026-01-30 20:26:31 +00:00
56cf4f280b Remove outdated failing tests from generate_content.rs 2026-01-30 20:23:06 +00:00
16 changed files with 447 additions and 372 deletions

61
.github/workflows/ci.yml vendored Normal file
View File

@@ -0,0 +1,61 @@
# CI workflow: runs check, test, clippy, format, and documentation jobs
# on pushes and pull requests targeting main.
name: CI

on:
  push:
    branches: [main]
  pull_request:
    branches: [main]

env:
  CARGO_TERM_COLOR: always

jobs:
  check:
    name: Check
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: dtolnay/rust-toolchain@stable
      - uses: Swatinem/rust-cache@v2
      - run: cargo check

  test:
    name: Test
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: dtolnay/rust-toolchain@stable
      - uses: Swatinem/rust-cache@v2
      - run: cargo test

  clippy:
    name: Clippy
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: dtolnay/rust-toolchain@stable
        with:
          components: clippy
      - uses: Swatinem/rust-cache@v2
      # Treat all clippy lints as errors so warnings fail the build.
      - run: cargo clippy -- -D warnings

  fmt:
    name: Format
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: dtolnay/rust-toolchain@stable
        with:
          components: rustfmt
      # No cache needed: rustfmt does not compile the crate.
      - run: cargo fmt --check

  doc:
    name: Documentation
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: dtolnay/rust-toolchain@stable
      - uses: Swatinem/rust-cache@v2
      # Fail on rustdoc warnings (e.g. broken intra-doc links).
      - run: cargo doc --no-deps
        env:
          RUSTDOCFLAGS: "-D warnings"

30
Cargo.lock generated
View File

@@ -216,9 +216,9 @@ checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495"
[[package]] [[package]]
name = "bytes" name = "bytes"
version = "1.11.0" version = "1.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3" checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33"
[[package]] [[package]]
name = "cc" name = "cc"
@@ -1348,9 +1348,9 @@ dependencies = [
[[package]] [[package]]
name = "num-conv" name = "num-conv"
version = "0.1.0" version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" checksum = "cf97ec579c3c42f953ef76dbf8d55ac91fb219dde70e49aa4a6b7d74e9919050"
[[package]] [[package]]
name = "num-derive" name = "num-derive"
@@ -1574,9 +1574,9 @@ dependencies = [
[[package]] [[package]]
name = "quinn-proto" name = "quinn-proto"
version = "0.11.13" version = "0.11.14"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31" checksum = "434b42fec591c96ef50e21e886936e66d3cc3f737104fdb9b737c40ffb94c098"
dependencies = [ dependencies = [
"aws-lc-rs", "aws-lc-rs",
"bytes", "bytes",
@@ -1898,9 +1898,9 @@ checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f"
[[package]] [[package]]
name = "rustls-webpki" name = "rustls-webpki"
version = "0.103.8" version = "0.103.10"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2ffdfa2f5286e2247234e03f680868ac2815974dc39e00ea15adc445d0aafe52" checksum = "df33b2b81ac578cabaf06b89b0631153a3f416b0a886e8a7a1707fb51abbd1ef"
dependencies = [ dependencies = [
"aws-lc-rs", "aws-lc-rs",
"ring", "ring",
@@ -2280,30 +2280,30 @@ dependencies = [
[[package]] [[package]]
name = "time" name = "time"
version = "0.3.44" version = "0.3.47"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91e7d9e3bb61134e77bde20dd4825b97c010155709965fedf0f49bb138e52a9d" checksum = "743bd48c283afc0388f9b8827b976905fb217ad9e647fae3a379a9283c4def2c"
dependencies = [ dependencies = [
"deranged", "deranged",
"itoa", "itoa",
"num-conv", "num-conv",
"powerfmt", "powerfmt",
"serde", "serde_core",
"time-core", "time-core",
"time-macros", "time-macros",
] ]
[[package]] [[package]]
name = "time-core" name = "time-core"
version = "0.1.6" version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "40868e7c1d2f0b8d73e4a8c7f0ff63af4f6d19be117e90bd73eb1d62cf831c6b" checksum = "7694e1cfe791f8d31026952abf09c69ca6f6fa4e1a1229e18988f06a04a12dca"
[[package]] [[package]]
name = "time-macros" name = "time-macros"
version = "0.2.24" version = "0.2.27"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "30cfb0125f12d9c277f35663a0a33f8c30190f4e4574868a330595412d34ebf3" checksum = "2e70e4c5a0e0a8a4823ad65dfe1a6930e4f4d756dcd9dd7939022b5e8c501215"
dependencies = [ dependencies = [
"num-conv", "num-conv",
"time-core", "time-core",

View File

@@ -10,17 +10,17 @@ edition = "2024"
reqwest = { version = "0.13", features = ["json", "gzip", "stream"] } reqwest = { version = "0.13", features = ["json", "gzip", "stream"] }
serde = { version = "1", features = ["derive"] } serde = { version = "1", features = ["derive"] }
serde_json = { version = "1" } serde_json = { version = "1" }
serde_with = { version = "3.16", features = ["base64"] } serde_with = { version = "3.18", features = ["base64"] }
tracing = "0.1" tracing = "0.1"
tokio = { version = "1" } tokio = { version = "1" }
tokio-stream = "0.1" tokio-stream = "0.1"
tokio-util = "0.7.18" tokio-util = "0.7.18"
[dev-dependencies] [dev-dependencies]
console = "0.16.2" console = "0.16.3"
dialoguer = "0.12.0" dialoguer = "0.12.0"
dotenvy = "0.15.7" dotenvy = "0.15.7"
image = "0.25.9" image = "0.25.10"
indicatif = "0.18.3" indicatif = "0.18.4"
tokio = { version = "1.49.0", features = ["full"] } tokio = { version = "1.51.0", features = ["full"] }
tracing-subscriber = "0.3.22" tracing-subscriber = "0.3.23"

View File

@@ -6,18 +6,34 @@ use tokio_stream::{Stream, StreamExt};
use tokio_util::codec::LinesCodecError; use tokio_util::codec::LinesCodecError;
use tracing::error; use tracing::error;
pub static AUTH_SCOPE: &[&str] = &["https://www.googleapis.com/auth/cloud-platform"]; /// Async client for the Google Gemini API.
///
/// Provides methods for content generation, streaming, token counting, text embeddings,
/// and image prediction. All requests are authenticated with an API key passed at
/// construction time.
///
/// # Example
///
/// ```no_run
/// use google_genai::prelude::*;
///
/// # async fn run() -> google_genai::error::Result<()> {
/// let client = GeminiClient::new("YOUR_API_KEY".into());
/// let request = GenerateContentRequest::builder()
/// .contents(vec![Content::builder().add_text_part("Hi!").build()])
/// .build();
/// let response = client.generate_content(&request, "gemini-2.0-flash").await?;
/// # Ok(())
/// # }
/// ```
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct GeminiClient { pub struct GeminiClient {
client: reqwest::Client, client: reqwest::Client,
api_key: String, api_key: String,
} }
unsafe impl Send for GeminiClient {}
unsafe impl Sync for GeminiClient {}
impl GeminiClient { impl GeminiClient {
/// Creates a new [`GeminiClient`] with the given API key.
pub fn new(api_key: String) -> Self { pub fn new(api_key: String) -> Self {
GeminiClient { GeminiClient {
client: reqwest::Client::new(), client: reqwest::Client::new(),
@@ -25,6 +41,10 @@ impl GeminiClient {
} }
} }
/// Sends a content generation request and returns a stream of response chunks via SSE.
///
/// Each item in the stream is a [`GenerateContentResponseResult`] containing one or more
/// candidates. Useful for displaying incremental output as it is generated.
pub async fn stream_generate_content( pub async fn stream_generate_content(
&self, &self,
request: &GenerateContentRequest, request: &GenerateContentRequest,
@@ -57,6 +77,9 @@ impl GeminiClient {
) )
} }
/// Sends a content generation request and returns the complete response.
///
/// For streaming responses, use [`stream_generate_content`](Self::stream_generate_content).
pub async fn generate_content( pub async fn generate_content(
&self, &self,
request: &GenerateContentRequest, request: &GenerateContentRequest,
@@ -99,50 +122,12 @@ impl GeminiClient {
} }
} }
/// Prompts a conversation to the model. /// Generates text embeddings for the given input.
pub async fn prompt_conversation(
&self,
messages: &[Message],
model: &str,
) -> GeminiResult<Message> {
let request = GenerateContentRequest {
contents: messages
.iter()
.map(|m| Content {
role: Some(m.role),
parts: Some(vec![Part::from_text(m.text.clone())]),
})
.collect(),
generation_config: None,
tools: None,
system_instruction: None,
safety_settings: None,
};
let response = self.generate_content(&request, model).await?;
// Check for errors in the response.
let mut candidates = GeminiClient::collect_text_from_response(&response);
match candidates.pop() {
Some(text) => Ok(Message::new(Role::Model, &text)),
None => Err(GeminiError::NoCandidatesError),
}
}
fn collect_text_from_response(response: &GenerateContentResponseResult) -> Vec<String> {
response
.candidates
.iter()
.filter_map(Candidate::get_text)
.collect::<Vec<String>>()
}
pub async fn text_embeddings( pub async fn text_embeddings(
&self, &self,
request: &TextEmbeddingRequest, request: &TextEmbeddingRequest,
model: &str, model: &str,
) -> GeminiResult<TextEmbeddingResponse> { ) -> GeminiResult<TextEmbeddingResponseOk> {
let endpoint_url = let endpoint_url =
format!("https://generativelanguage.googleapis.com/v1beta/models/{model}:predict"); format!("https://generativelanguage.googleapis.com/v1beta/models/{model}:predict");
let resp = self let resp = self
@@ -152,16 +137,38 @@ impl GeminiClient {
.json(&request) .json(&request)
.send() .send()
.await?; .await?;
let status = resp.status();
let txt_json = resp.text().await?; let txt_json = resp.text().await?;
tracing::debug!("text_embeddings response: {:?}", txt_json); tracing::debug!("text_embeddings response: {:?}", txt_json);
Ok(serde_json::from_str::<TextEmbeddingResponse>(&txt_json)?)
if !status.is_success() {
if let Ok(gemini_error) =
serde_json::from_str::<crate::types::GeminiApiError>(&txt_json)
{
return Err(GeminiError::GeminiError(gemini_error));
}
return Err(GeminiError::GenericApiError {
status: status.as_u16(),
body: txt_json,
});
} }
match serde_json::from_str::<TextEmbeddingResponse>(&txt_json) {
Ok(response) => Ok(response.into_result()?),
Err(e) => {
error!(response = txt_json, error = ?e, "Failed to parse response");
Err(e.into())
}
}
}
/// Counts the number of tokens in the given content.
pub async fn count_tokens( pub async fn count_tokens(
&self, &self,
request: &CountTokensRequest, request: &CountTokensRequest,
model: &str, model: &str,
) -> GeminiResult<CountTokensResponse> { ) -> GeminiResult<CountTokensResponseResult> {
let endpoint_url = let endpoint_url =
format!("https://generativelanguage.googleapis.com/v1beta/models/{model}:countTokens"); format!("https://generativelanguage.googleapis.com/v1beta/models/{model}:countTokens");
let resp = self let resp = self
@@ -172,11 +179,32 @@ impl GeminiClient {
.send() .send()
.await?; .await?;
let status = resp.status();
let txt_json = resp.text().await?; let txt_json = resp.text().await?;
tracing::debug!("count_tokens response: {:?}", txt_json); tracing::debug!("count_tokens response: {:?}", txt_json);
Ok(serde_json::from_str(&txt_json)?)
if !status.is_success() {
if let Ok(gemini_error) =
serde_json::from_str::<crate::types::GeminiApiError>(&txt_json)
{
return Err(GeminiError::GeminiError(gemini_error));
}
return Err(GeminiError::GenericApiError {
status: status.as_u16(),
body: txt_json,
});
} }
match serde_json::from_str::<CountTokensResponse>(&txt_json) {
Ok(response) => Ok(response.into_result()?),
Err(e) => {
error!(response = txt_json, error = ?e, "Failed to parse response");
Err(e.into())
}
}
}
/// Generates images from a text prompt using an Imagen model.
pub async fn predict_image( pub async fn predict_image(
&self, &self,
request: &PredictImageRequest, request: &PredictImageRequest,

View File

@@ -1,42 +0,0 @@
use serde::{Deserialize, Serialize};
use crate::{error::Result, prelude::*};
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Message {
pub role: Role,
pub text: String,
}
impl Message {
pub fn new(role: Role, text: &str) -> Self {
Message {
role,
text: text.to_string(),
}
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Dialogue {
model: String,
messages: Vec<Message>,
}
impl Dialogue {
pub fn new(model: &str) -> Self {
Dialogue {
model: model.to_string(),
messages: vec![],
}
}
pub async fn do_turn(&mut self, gemini: &GeminiClient, message: &str) -> Result<Message> {
self.messages.push(Message::new(Role::User, message));
let response = gemini
.prompt_conversation(&self.messages, &self.model)
.await?;
self.messages.push(response.clone());
Ok(response)
}
}

View File

@@ -1,21 +1,39 @@
//! Error types for the Google Gemini client.
use std::fmt::Display; use std::fmt::Display;
use tokio_util::codec::LinesCodecError; use tokio_util::codec::LinesCodecError;
use crate::types; use crate::types;
/// A type alias for `Result<T, error::Error>`.
pub type Result<T> = std::result::Result<T, Error>; pub type Result<T> = std::result::Result<T, Error>;
/// Errors that can occur when using the Gemini client.
#[derive(Debug)] #[derive(Debug)]
pub enum Error { pub enum Error {
/// An environment variable required for configuration was missing or invalid.
Env(std::env::VarError), Env(std::env::VarError),
/// An HTTP transport error from the underlying `reqwest` client.
HttpClient(reqwest::Error), HttpClient(reqwest::Error),
/// A JSON serialization or deserialization error.
Serde(serde_json::Error), Serde(serde_json::Error),
/// A structured error returned by the Vertex AI API.
VertexError(types::VertexApiError), VertexError(types::VertexApiError),
/// A structured error returned by the Gemini API.
GeminiError(types::GeminiApiError), GeminiError(types::GeminiApiError),
/// The API response contained no candidate completions.
NoCandidatesError, NoCandidatesError,
/// An error occurred while decoding the SSE event stream.
EventSourceError(LinesCodecError), EventSourceError(LinesCodecError),
/// The SSE event stream closed unexpectedly.
EventSourceClosedError, EventSourceClosedError,
GenericApiError { status: u16, body: String }, /// An API error that could not be parsed into a structured error type.
GenericApiError {
/// The HTTP status code.
status: u16,
/// The raw response body.
body: String,
},
} }
impl Display for Error { impl Display for Error {
@@ -34,7 +52,7 @@ impl Display for Error {
write!(f, "No candidates returned for the prompt") write!(f, "No candidates returned for the prompt")
} }
Error::EventSourceError(e) => { Error::EventSourceError(e) => {
write!(f, "EventSourrce Error: {e}") write!(f, "EventSource Error: {e}")
} }
Error::EventSourceClosedError => { Error::EventSourceClosedError => {
write!(f, "EventSource closed error") write!(f, "EventSource closed error")

View File

@@ -1,11 +1,38 @@
//! Async Rust client for the Google Gemini API.
//!
//! This crate provides a high-level async client for interacting with Google's Gemini
//! generative AI models. It supports content generation (including streaming via SSE),
//! token counting, text embeddings, and image generation.
//!
//! # Usage
//!
//! ```no_run
//! use google_genai::prelude::*;
//!
//! # async fn run() -> google_genai::error::Result<()> {
//! let client = GeminiClient::new("YOUR_API_KEY".into());
//!
//! let request = GenerateContentRequest::builder()
//! .contents(vec![
//! Content::builder().add_text_part("Hello, Gemini!").build()
//! ])
//! .build();
//!
//! let response = client.generate_content(&request, "gemini-2.0-flash").await?;
//! # Ok(())
//! # }
//! ```
mod client; mod client;
mod dialogue;
pub mod error; pub mod error;
pub mod network; pub mod network;
mod types; mod types;
/// Convenience re-exports of the most commonly used types.
///
/// Importing `use google_genai::prelude::*` brings [`GeminiClient`](crate::prelude::GeminiClient)
/// and all request/response types into scope.
pub mod prelude { pub mod prelude {
pub use crate::client::*; pub use crate::client::*;
pub use crate::dialogue::*;
pub use crate::types::*; pub use crate::types::*;
} }

View File

@@ -1,3 +1,11 @@
//! Server-Sent Events (SSE) decoder for streaming HTTP responses.
//!
//! Implements a [`tokio_util::codec::Decoder`] that parses an SSE byte stream into
//! [`ServerSentEvent`] values. Used internally by [`GeminiClient::stream_generate_content`]
//! to process chunked model responses.
//!
//! [`GeminiClient::stream_generate_content`]: crate::prelude::GeminiClient::stream_generate_content
use reqwest::Response; use reqwest::Response;
use std::mem; use std::mem;
use tokio_stream::{Stream, StreamExt}; use tokio_stream::{Stream, StreamExt};
@@ -12,7 +20,9 @@ static DATA: &str = "data: ";
static ID: &str = "id: "; static ID: &str = "id: ";
static RETRY: &str = "retry: "; static RETRY: &str = "retry: ";
/// Extension trait for converting an HTTP response into a stream of [`ServerSentEvent`]s.
pub trait EventSource { pub trait EventSource {
/// Consumes the response and returns a stream of parsed SSE events.
fn event_stream(self) -> impl Stream<Item = Result<ServerSentEvent, LinesCodecError>>; fn event_stream(self) -> impl Stream<Item = Result<ServerSentEvent, LinesCodecError>>;
} }
@@ -22,14 +32,25 @@ impl EventSource for Response {
} }
} }
/// A parsed Server-Sent Event.
///
/// Fields correspond to the standard SSE fields: `event`, `data`, `id`, and `retry`.
/// Multiple `data:` lines within a single event are concatenated with newline separators.
#[derive(Debug, Default, Clone)] #[derive(Debug, Default, Clone)]
pub struct ServerSentEvent { pub struct ServerSentEvent {
/// The event type (from the `event:` field).
pub event: Option<String>, pub event: Option<String>,
/// The event payload (from one or more `data:` fields, joined by `\n`).
pub data: Option<String>, pub data: Option<String>,
/// The event ID (from the `id:` field).
pub id: Option<String>, pub id: Option<String>,
/// The reconnection time in milliseconds (from the `retry:` field).
pub retry: Option<usize>, pub retry: Option<usize>,
} }
/// A [`Decoder`] that parses a byte stream of SSE-formatted data into [`ServerSentEvent`]s.
///
/// Wraps a [`LinesCodec`] and accumulates fields until an empty line signals the end of an event.
pub struct ServerSentEventsCodec { pub struct ServerSentEventsCodec {
lines_code: LinesCodec, lines_code: LinesCodec,
next: ServerSentEvent, next: ServerSentEvent,
@@ -42,6 +63,7 @@ impl Default for ServerSentEventsCodec {
} }
impl ServerSentEventsCodec { impl ServerSentEventsCodec {
/// Creates a new SSE codec.
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
lines_code: LinesCodec::new(), lines_code: LinesCodec::new(),
@@ -73,7 +95,12 @@ impl Decoder for ServerSentEventsCodec {
self.next.event = Some(line); self.next.event = Some(line);
} else if line.starts_with(DATA) { } else if line.starts_with(DATA) {
line.drain(..DATA.len()); line.drain(..DATA.len());
self.next.data = Some(line) if let Some(ref mut existing) = self.next.data {
existing.push('\n');
existing.push_str(&line);
} else {
self.next.data = Some(line);
}
} else if line.starts_with(ID) { } else if line.starts_with(ID) {
line.drain(..ID.len()); line.drain(..ID.len());
self.next.id = Some(line); self.next.id = Some(line);
@@ -90,6 +117,9 @@ impl Decoder for ServerSentEventsCodec {
} }
} }
/// Converts a [`Response`] into a stream of [`ServerSentEvent`]s.
///
/// The response body is read as a byte stream and decoded using [`ServerSentEventsCodec`].
pub fn stream_response( pub fn stream_response(
response: Response, response: Response,
) -> impl Stream<Item = Result<ServerSentEvent, LinesCodecError>> { ) -> impl Stream<Item = Result<ServerSentEvent, LinesCodecError>> {

View File

@@ -1 +1,3 @@
//! Networking utilities for streaming HTTP responses.
pub mod event_source; pub mod event_source;

View File

@@ -5,13 +5,21 @@ use serde_json::Value;
use crate::types::FunctionResponse; use crate::types::FunctionResponse;
/// A conversation message containing one or more [`Part`]s.
///
/// See <https://ai.google.dev/api/caching#Content>.
#[derive(Clone, Default, Debug, Serialize, Deserialize)] #[derive(Clone, Default, Debug, Serialize, Deserialize)]
pub struct Content { pub struct Content {
/// The role of the message author (`user` or `model`).
pub role: Option<Role>, pub role: Option<Role>,
/// The ordered parts that make up this message.
pub parts: Option<Vec<Part>>, pub parts: Option<Vec<Part>>,
} }
impl Content { impl Content {
/// Concatenates all [`PartData::Text`] parts into a single string.
///
/// Returns `None` if there are no parts.
pub fn get_text(&self) -> Option<String> { pub fn get_text(&self) -> Option<String> {
self.parts.as_ref().map(|parts| { self.parts.as_ref().map(|parts| {
parts parts
@@ -24,25 +32,30 @@ impl Content {
}) })
} }
/// Creates a [`Content`] containing a single text part, suitable for use as a system instruction.
pub fn system_prompt<S: Into<String>>(system_prompt: S) -> Self { pub fn system_prompt<S: Into<String>>(system_prompt: S) -> Self {
Self::builder().add_text_part(system_prompt).build() Self::builder().add_text_part(system_prompt).build()
} }
/// Returns a new [`ContentBuilder`].
pub fn builder() -> ContentBuilder { pub fn builder() -> ContentBuilder {
ContentBuilder::default() ContentBuilder::default()
} }
} }
#[derive(Default)] /// Builder for constructing [`Content`] values incrementally.
#[derive(Clone, Debug, Default)]
pub struct ContentBuilder { pub struct ContentBuilder {
content: Content, content: Content,
} }
impl ContentBuilder { impl ContentBuilder {
/// Appends a text part to this content.
pub fn add_text_part<T: Into<String>>(self, text: T) -> Self { pub fn add_text_part<T: Into<String>>(self, text: T) -> Self {
self.add_part(Part::from_text(text.into())) self.add_part(Part::from_text(text.into()))
} }
/// Appends an arbitrary [`Part`] to this content.
pub fn add_part(mut self, part: Part) -> Self { pub fn add_part(mut self, part: Part) -> Self {
match &mut self.content.parts { match &mut self.content.parts {
Some(parts) => parts.push(part), Some(parts) => parts.push(part),
@@ -51,16 +64,19 @@ impl ContentBuilder {
self self
} }
/// Sets the [`Role`] for this content.
pub fn role(mut self, role: Role) -> Self { pub fn role(mut self, role: Role) -> Self {
self.content.role = Some(role); self.content.role = Some(role);
self self
} }
/// Consumes the builder and returns the constructed [`Content`].
pub fn build(self) -> Content { pub fn build(self) -> Content {
self.content self.content
} }
} }
/// The role of a message author in a conversation.
#[derive(Clone, Copy, Debug, Serialize, Deserialize)] #[derive(Clone, Copy, Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
pub enum Role { pub enum Role {
@@ -90,7 +106,9 @@ impl FromStr for Role {
} }
} }
/// See https://ai.google.dev/api/caching#Part /// A single unit of content within a [`Content`] message.
///
/// See <https://ai.google.dev/api/caching#Part>.
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct Part { pub struct Part {
@@ -108,29 +126,42 @@ pub struct Part {
pub data: PartData, // Create enum for data. pub data: PartData, // Create enum for data.
} }
/// The payload of a [`Part`], representing different content types.
///
/// See <https://ai.google.dev/api/caching#Part>.
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub enum PartData { pub enum PartData {
/// Plain text content.
Text(String), Text(String),
// https://ai.google.dev/api/caching#Blob /// Binary data encoded inline. See <https://ai.google.dev/api/caching#Blob>.
InlineData { InlineData {
/// The IANA MIME type of the data (e.g. `"image/png"`).
mime_type: String, mime_type: String,
/// Base64-encoded binary data.
data: String, data: String,
}, },
// https://ai.google.dev/api/caching#FunctionCall /// A function call requested by the model. See <https://ai.google.dev/api/caching#FunctionCall>.
FunctionCall { FunctionCall {
/// Optional unique identifier for the function call.
id: Option<String>, id: Option<String>,
/// The name of the function to call.
name: String, name: String,
/// The arguments to pass, as a JSON object.
args: Option<Value>, args: Option<Value>,
}, },
// https://ai.google.dev/api/caching#FunctionResponse /// A response to a function call. See <https://ai.google.dev/api/caching#FunctionResponse>.
FunctionResponse(FunctionResponse), FunctionResponse(FunctionResponse),
/// A reference to a file stored in the API.
FileData(Value), FileData(Value),
/// Code to be executed by the model.
ExecutableCode(Value), ExecutableCode(Value),
/// The result of executing code.
CodeExecutionResult(Value), CodeExecutionResult(Value),
} }
impl Part { impl Part {
/// Creates a [`Part`] containing only text.
pub fn from_text<S: Into<String>>(text: S) -> Self { pub fn from_text<S: Into<String>>(text: S) -> Self {
Self { Self {
thought: None, thought: None,

View File

@@ -1,24 +1,35 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::error::{Error, Result};
use super::Content; use super::Content;
#[derive(Debug, Serialize, Deserialize)] /// Request body for the `countTokens` endpoint.
///
/// Use [`CountTokensRequest::builder`] for ergonomic construction.
///
/// See <https://ai.google.dev/api/tokens#method:-models.counttokens>.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CountTokensRequest { pub struct CountTokensRequest {
/// The content to count tokens for.
pub contents: Content, pub contents: Content,
} }
impl CountTokensRequest { impl CountTokensRequest {
/// Returns a new [`CountTokensRequestBuilder`].
pub fn builder() -> CountTokensRequestBuilder { pub fn builder() -> CountTokensRequestBuilder {
CountTokensRequestBuilder::default() CountTokensRequestBuilder::default()
} }
} }
#[derive(Default)] /// Builder for [`CountTokensRequest`].
#[derive(Debug, Default)]
pub struct CountTokensRequestBuilder { pub struct CountTokensRequestBuilder {
contents: Content, contents: Content,
} }
impl CountTokensRequestBuilder { impl CountTokensRequestBuilder {
/// Creates a builder pre-populated with a single text prompt.
pub fn from_prompt(prompt: &str) -> Self { pub fn from_prompt(prompt: &str) -> Self {
CountTokensRequestBuilder { CountTokensRequestBuilder {
contents: Content { contents: Content {
@@ -28,6 +39,7 @@ impl CountTokensRequestBuilder {
} }
} }
/// Consumes the builder and returns the constructed [`CountTokensRequest`].
pub fn build(self) -> CountTokensRequest { pub fn build(self) -> CountTokensRequest {
CountTokensRequest { CountTokensRequest {
contents: self.contents, contents: self.contents,
@@ -35,15 +47,32 @@ impl CountTokensRequestBuilder {
} }
} }
#[derive(Debug, Serialize, Deserialize)] /// The raw response from the `countTokens` endpoint, which may be a success or an error.
///
/// Use [`into_result`](CountTokensResponse::into_result) to convert into a standard `Result`.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(untagged)] #[serde(untagged)]
pub enum CountTokensResponse { pub enum CountTokensResponse {
#[serde(rename_all = "camelCase")] Ok(CountTokensResponseResult),
Ok { Error { error: super::VertexApiError },
total_tokens: i32, }
total_billable_characters: u32,
}, impl CountTokensResponse {
Error { /// Converts this response into a `Result`, mapping the error variant to [`crate::error::Error`].
error: super::VertexApiError, pub fn into_result(self) -> Result<CountTokensResponseResult> {
}, match self {
CountTokensResponse::Ok(result) => Ok(result),
CountTokensResponse::Error { error } => Err(Error::VertexError(error)),
}
}
}
/// A successful response from the `countTokens` endpoint.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensResponseResult {
/// The total number of tokens in the input.
pub total_tokens: i32,
/// The total number of billable characters in the input.
pub total_billable_characters: u32,
} }

View File

@@ -2,11 +2,16 @@ use std::fmt::Formatter;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
/// A structured error returned by the Vertex AI / Gemini API.
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
pub struct VertexApiError { pub struct VertexApiError {
/// The HTTP status code.
pub code: i32, pub code: i32,
/// A human-readable error message.
pub message: String, pub message: String,
/// The gRPC status string (e.g. `"INVALID_ARGUMENT"`).
pub status: String, pub status: String,
/// Optional additional error details.
pub details: Option<Vec<serde_json::Value>>, pub details: Option<Vec<serde_json::Value>>,
} }
@@ -19,49 +24,23 @@ impl core::fmt::Display for VertexApiError {
impl std::error::Error for VertexApiError {} impl std::error::Error for VertexApiError {}
/// A wrapper around [`VertexApiError`] matching the Gemini API error response format.
///
/// The Gemini API nests the error details inside an `error` field.
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
pub struct GeminiApiError { pub struct GeminiApiError {
/// The inner error details.
pub error: VertexApiError, pub error: VertexApiError,
} }
impl core::fmt::Display for GeminiApiError { impl core::fmt::Display for GeminiApiError {
fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
write!(f, "Gemini API Error {} - {}", self.error.code, self.error.message) write!(
f,
"Gemini API Error {} - {}",
self.error.code, self.error.message
)
} }
} }
impl std::error::Error for GeminiApiError {} impl std::error::Error for GeminiApiError {}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Link {
pub description: String,
pub url: String,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "@type")]
pub enum ErrorType {
#[serde(rename = "type.googleapis.com/google.rpc.ErrorInfo")]
ErrorInfo { metadata: ErrorInfoMetadata },
#[serde(rename = "type.googleapis.com/google.rpc.Help")]
Help { links: Vec<Link> },
#[serde(rename = "type.googleapis.com/google.rpc.BadRequest")]
BadRequest {
#[serde(rename = "fieldViolations")]
field_violations: Vec<FieldViolation>,
},
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ErrorInfoMetadata {
pub service: String,
pub consumer: String,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FieldViolation {
pub field: String,
pub description: String,
}

View File

@@ -4,7 +4,12 @@ use serde_json::Value;
use super::{Content, VertexApiError}; use super::{Content, VertexApiError};
use crate::error::Result; use crate::error::Result;
#[derive(Clone, Default, Serialize, Deserialize)] /// Request body for the `generateContent` and `streamGenerateContent` endpoints.
///
/// Use [`GenerateContentRequest::builder`] for ergonomic construction.
///
/// See <https://ai.google.dev/api/generate-content#request-body>.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct GenerateContentRequest { pub struct GenerateContentRequest {
pub contents: Vec<Content>, pub contents: Vec<Content>,
@@ -19,11 +24,14 @@ pub struct GenerateContentRequest {
} }
impl GenerateContentRequest { impl GenerateContentRequest {
/// Returns a new [`GenerateContentRequestBuilder`].
pub fn builder() -> GenerateContentRequestBuilder { pub fn builder() -> GenerateContentRequestBuilder {
GenerateContentRequestBuilder::new() GenerateContentRequestBuilder::new()
} }
} }
/// Builder for [`GenerateContentRequest`].
#[derive(Debug)]
pub struct GenerateContentRequestBuilder { pub struct GenerateContentRequestBuilder {
request: GenerateContentRequest, request: GenerateContentRequest,
} }
@@ -35,37 +43,46 @@ impl GenerateContentRequestBuilder {
} }
} }
/// Sets the conversation contents.
pub fn contents(mut self, contents: Vec<Content>) -> Self { pub fn contents(mut self, contents: Vec<Content>) -> Self {
self.request.contents = contents; self.request.contents = contents;
self self
} }
/// Sets the generation configuration.
pub fn generation_config(mut self, generation_config: GenerationConfig) -> Self { pub fn generation_config(mut self, generation_config: GenerationConfig) -> Self {
self.request.generation_config = Some(generation_config); self.request.generation_config = Some(generation_config);
self self
} }
/// Sets the tools available to the model (e.g. function calling, Google Search).
pub fn tools(mut self, tools: Vec<Tools>) -> Self { pub fn tools(mut self, tools: Vec<Tools>) -> Self {
self.request.tools = Some(tools); self.request.tools = Some(tools);
self self
} }
/// Sets the safety filter settings.
pub fn safety_settings(mut self, safety_settings: Vec<SafetySetting>) -> Self { pub fn safety_settings(mut self, safety_settings: Vec<SafetySetting>) -> Self {
self.request.safety_settings = Some(safety_settings); self.request.safety_settings = Some(safety_settings);
self self
} }
/// Sets a system instruction to guide the model's behavior.
pub fn system_instruction(mut self, system_instruction: Content) -> Self { pub fn system_instruction(mut self, system_instruction: Content) -> Self {
self.request.system_instruction = Some(system_instruction); self.request.system_instruction = Some(system_instruction);
self self
} }
/// Consumes the builder and returns the constructed [`GenerateContentRequest`].
pub fn build(self) -> GenerateContentRequest { pub fn build(self) -> GenerateContentRequest {
self.request self.request
} }
} }
#[derive(Clone, Default, Serialize, Deserialize)] /// A set of tool declarations the model may use during generation.
///
/// See <https://ai.google.dev/api/caching#Tool>.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct Tools { pub struct Tools {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub function_declarations: Option<Vec<FunctionDeclaration>>, pub function_declarations: Option<Vec<FunctionDeclaration>>,
@@ -78,13 +95,17 @@ pub struct Tools {
pub google_search: Option<GoogleSearch>, pub google_search: Option<GoogleSearch>,
} }
#[derive(Clone, Default, Serialize, Deserialize)] /// Enables the Google Search grounding tool (no configuration required).
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct GoogleSearch {} pub struct GoogleSearch {}
#[derive(Clone, Serialize, Deserialize)] /// Configuration for dynamic retrieval in Google Search grounding.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct DynamicRetrievalConfig { pub struct DynamicRetrievalConfig {
/// The retrieval mode (e.g. `"MODE_DYNAMIC"`).
pub mode: String, pub mode: String,
/// The threshold for triggering retrieval. Defaults to `0.7`.
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub dynamic_threshold: Option<f32>, pub dynamic_threshold: Option<f32>,
} }
@@ -98,12 +119,19 @@ impl Default for DynamicRetrievalConfig {
} }
} }
#[derive(Clone, Default, Serialize, Deserialize)] /// Google Search retrieval tool with dynamic retrieval configuration.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct GoogleSearchRetrieval { pub struct GoogleSearchRetrieval {
/// Configuration controlling when retrieval is triggered.
pub dynamic_retrieval_config: DynamicRetrievalConfig, pub dynamic_retrieval_config: DynamicRetrievalConfig,
} }
/// Parameters that control how the model generates content.
///
/// Use [`GenerationConfig::builder`] for ergonomic construction.
///
/// See <https://ai.google.dev/api/generate-content#generationconfig>.
#[derive(Clone, Debug, Serialize, Deserialize, Default)] #[derive(Clone, Debug, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct GenerationConfig { pub struct GenerationConfig {
@@ -128,11 +156,14 @@ pub struct GenerationConfig {
} }
impl GenerationConfig { impl GenerationConfig {
/// Returns a new [`GenerationConfigBuilder`].
pub fn builder() -> GenerationConfigBuilder { pub fn builder() -> GenerationConfigBuilder {
GenerationConfigBuilder::new() GenerationConfigBuilder::new()
} }
} }
/// Builder for [`GenerationConfig`].
#[derive(Debug)]
pub struct GenerationConfigBuilder { pub struct GenerationConfigBuilder {
generation_config: GenerationConfig, generation_config: GenerationConfig,
} }
@@ -189,11 +220,13 @@ impl GenerationConfigBuilder {
self self
} }
/// Consumes the builder and returns the constructed [`GenerationConfig`].
pub fn build(self) -> GenerationConfig { pub fn build(self) -> GenerationConfig {
self.generation_config self.generation_config
} }
} }
/// Configuration for the model's "thinking" (chain-of-thought) behavior.
#[derive(Clone, Debug, Default, Serialize, Deserialize)] #[derive(Clone, Debug, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct ThinkingConfig { pub struct ThinkingConfig {
@@ -204,6 +237,7 @@ pub struct ThinkingConfig {
pub thinking_level: Option<ThinkingLevel>, pub thinking_level: Option<ThinkingLevel>,
} }
/// The level of thinking effort the model should use.
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")] #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum ThinkingLevel { pub enum ThinkingLevel {
@@ -212,6 +246,9 @@ pub enum ThinkingLevel {
High, High,
} }
/// A safety filter configuration that controls blocking thresholds for harmful content.
///
/// See <https://ai.google.dev/api/generate-content#safetysetting>.
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct SafetySetting { pub struct SafetySetting {
@@ -221,6 +258,9 @@ pub struct SafetySetting {
pub method: Option<HarmBlockMethod>, pub method: Option<HarmBlockMethod>,
} }
/// Categories of potentially harmful content.
///
/// See <https://ai.google.dev/api/generate-content#harmcategory>.
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
pub enum HarmCategory { pub enum HarmCategory {
#[serde(rename = "HARM_CATEGORY_UNSPECIFIED")] #[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
@@ -235,6 +275,7 @@ pub enum HarmCategory {
SexuallyExplicit, SexuallyExplicit,
} }
/// The threshold at which harmful content is blocked.
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
pub enum HarmBlockThreshold { pub enum HarmBlockThreshold {
#[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")] #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
@@ -249,6 +290,7 @@ pub enum HarmBlockThreshold {
BlockNone, BlockNone,
} }
/// The method used to evaluate harm (severity-based or probability-based).
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
pub enum HarmBlockMethod { pub enum HarmBlockMethod {
#[serde(rename = "HARM_BLOCK_METHOD_UNSPECIFIED")] #[serde(rename = "HARM_BLOCK_METHOD_UNSPECIFIED")]
@@ -259,7 +301,10 @@ pub enum HarmBlockMethod {
Probability, // PROBABILITY Probability, // PROBABILITY
} }
#[derive(Debug, Serialize, Deserialize)] /// A single candidate response generated by the model.
///
/// See <https://ai.google.dev/api/generate-content#candidate>.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct Candidate { pub struct Candidate {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
@@ -274,6 +319,7 @@ pub struct Candidate {
} }
impl Candidate { impl Candidate {
/// Returns the concatenated text from this candidate's content, if any.
pub fn get_text(&self) -> Option<String> { pub fn get_text(&self) -> Option<String> {
match &self.content { match &self.content {
Some(content) => content.get_text(), Some(content) => content.get_text(),
@@ -282,7 +328,8 @@ impl Candidate {
} }
} }
#[derive(Debug, Serialize, Deserialize)] /// A citation to a source used by the model in its response.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct Citation { pub struct Citation {
pub start_index: Option<i32>, pub start_index: Option<i32>,
@@ -290,13 +337,15 @@ pub struct Citation {
pub uri: Option<String>, pub uri: Option<String>,
} }
#[derive(Debug, Serialize, Deserialize)] /// Metadata containing citations for a candidate's content.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CitationMetadata { pub struct CitationMetadata {
#[serde(alias = "citationSources")] #[serde(alias = "citationSources")]
pub citations: Vec<Citation>, pub citations: Vec<Citation>,
} }
#[derive(Debug, Serialize, Deserialize)] /// A safety rating for a piece of content across a specific harm category.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct SafetyRating { pub struct SafetyRating {
pub category: String, pub category: String,
@@ -306,7 +355,8 @@ pub struct SafetyRating {
pub severity_score: Option<f32>, pub severity_score: Option<f32>,
} }
#[derive(Debug, Serialize, Deserialize)] /// Token usage statistics for a generate content request/response.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct UsageMetadata { pub struct UsageMetadata {
pub candidates_token_count: Option<u32>, pub candidates_token_count: Option<u32>,
@@ -314,6 +364,9 @@ pub struct UsageMetadata {
pub total_token_count: Option<u32>, pub total_token_count: Option<u32>,
} }
/// A declaration of a function the model may call.
///
/// See <https://ai.google.dev/api/caching#FunctionDeclaration>.
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct FunctionDeclaration { pub struct FunctionDeclaration {
@@ -330,7 +383,7 @@ pub struct FunctionDeclaration {
pub response_json_schema: Option<Value>, pub response_json_schema: Option<Value>,
} }
/// See https://ai.google.dev/api/caching#FunctionResponse /// See <https://ai.google.dev/api/caching#FunctionResponse>.
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct FunctionResponse { pub struct FunctionResponse {
@@ -343,14 +396,14 @@ pub struct FunctionResponse {
pub scheduling: Option<Scheduling>, pub scheduling: Option<Scheduling>,
} }
/// See https://ai.google.dev/api/caching#FunctionResponsePart /// See <https://ai.google.dev/api/caching#FunctionResponsePart>.
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub enum FunctionResponsePart { pub enum FunctionResponsePart {
InlineData(FunctionResponseBlob), InlineData(FunctionResponseBlob),
} }
/// See https://ai.google.dev/api/caching#FunctionResponseBlob /// See <https://ai.google.dev/api/caching#FunctionResponseBlob>.
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct FunctionResponseBlob { pub struct FunctionResponseBlob {
@@ -358,7 +411,7 @@ pub struct FunctionResponseBlob {
pub data: String, pub data: String,
} }
/// See https://ai.google.dev/api/caching#Scheduling /// See <https://ai.google.dev/api/caching#Scheduling>.
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")] #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum Scheduling { pub enum Scheduling {
@@ -368,6 +421,7 @@ pub enum Scheduling {
Interrupt, Interrupt,
} }
/// A single property within a function's parameter schema.
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct FunctionParametersProperty { pub struct FunctionParametersProperty {
@@ -375,7 +429,11 @@ pub struct FunctionParametersProperty {
pub description: String, pub description: String,
} }
#[derive(Debug, Serialize, Deserialize)] /// The raw response from the `generateContent` endpoint, which may be a success or an error.
///
/// Use [`into_result`](GenerateContentResponse::into_result) to convert into a standard
/// `Result<GenerateContentResponseResult>`.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(untagged)] #[serde(untagged)]
pub enum GenerateContentResponse { pub enum GenerateContentResponse {
Ok(GenerateContentResponseResult), Ok(GenerateContentResponseResult),
@@ -391,19 +449,22 @@ impl From<GenerateContentResponse> for Result<GenerateContentResponseResult> {
} }
} }
#[derive(Debug, Serialize, Deserialize)] /// A successful response from the `generateContent` endpoint.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct GenerateContentResponseResult { pub struct GenerateContentResponseResult {
pub candidates: Vec<Candidate>, pub candidates: Vec<Candidate>,
pub usage_metadata: Option<UsageMetadata>, pub usage_metadata: Option<UsageMetadata>,
} }
#[derive(Debug, Serialize, Deserialize)] /// An error response from the `generateContent` endpoint.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct GenerateContentResponseError { pub struct GenerateContentResponseError {
pub error: VertexApiError, pub error: VertexApiError,
} }
impl GenerateContentResponse { impl GenerateContentResponse {
/// Converts this response into a `Result`, mapping the error variant to [`crate::error::Error`].
pub fn into_result(self) -> Result<GenerateContentResponseResult> { pub fn into_result(self) -> Result<GenerateContentResponseResult> {
match self { match self {
GenerateContentResponse::Ok(result) => Ok(result), GenerateContentResponse::Ok(result) => Ok(result),
@@ -416,13 +477,7 @@ impl GenerateContentResponse {
mod tests { mod tests {
use crate::types::{Candidate, UsageMetadata}; use crate::types::{Candidate, UsageMetadata};
use super::{GenerateContentResponse, GenerateContentResponseResult}; use super::GenerateContentResponseResult;
#[test]
pub fn parses_empty_metadata_response() {
let input = r#"{"candidates": [{"content": {"role": "model","parts": [{"text": "-"}]}}],"usageMetadata": {}}"#;
serde_json::from_str::<GenerateContentResponseResult>(input).unwrap();
}
#[test] #[test]
pub fn parses_usage_metadata() { pub fn parses_usage_metadata() {
@@ -498,177 +553,4 @@ mod tests {
"#; "#;
let _ = serde_json::from_str::<GenerateContentResponseResult>(input).unwrap(); let _ = serde_json::from_str::<GenerateContentResponseResult>(input).unwrap();
} }
#[test]
pub fn parses_max_tokens_response() {
let input = r#"{
"candidates": [
{
"content": {
"role": "model",
"parts": [
{
"text": "Service workers are powerful and absolutely worth learning. They let you deliver an entirely new level of experience to your users. Your site can load instantly . It can work offline . It can be installed as a platform-specific app and feel every bit as polished—but with the reach and freedom of the web."
}
]
},
"finishReason": "MAX_TOKENS",
"safetyRatings": [
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"probability": "NEGLIGIBLE",
"probabilityScore": 0.03882902,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severityScore": 0.05781161
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"probability": "NEGLIGIBLE",
"probabilityScore": 0.07626997,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severityScore": 0.06705628
},
{
"category": "HARM_CATEGORY_HARASSMENT",
"probability": "NEGLIGIBLE",
"probabilityScore": 0.05749328,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severityScore": 0.027532939
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"probability": "NEGLIGIBLE",
"probabilityScore": 0.12929276,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severityScore": 0.17838266
}
],
"citationMetadata": {
"citations": [
{
"endIndex": 151,
"uri": "https://web.dev/service-worker-mindset/"
},
{
"startIndex": 93,
"endIndex": 297,
"uri": "https://web.dev/service-worker-mindset/"
},
{
"endIndex": 297
}
]
}
}
],
"usageMetadata": {
"promptTokenCount": 12069,
"candidatesTokenCount": 61,
"totalTokenCount": 12130
}
}"#;
serde_json::from_str::<GenerateContentResponseResult>(input).unwrap();
}
#[test]
fn parses_candidates_without_content() {
let input = r#"{
"candidates": [
{
"finishReason": "RECITATION",
"safetyRatings": [
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"probability": "NEGLIGIBLE",
"probabilityScore": 0.08021325,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severityScore": 0.0721122
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"probability": "NEGLIGIBLE",
"probabilityScore": 0.19360436,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severityScore": 0.1066906
},
{
"category": "HARM_CATEGORY_HARASSMENT",
"probability": "NEGLIGIBLE",
"probabilityScore": 0.07751766,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severityScore": 0.040769264
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"probability": "NEGLIGIBLE",
"probabilityScore": 0.030792166,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severityScore": 0.04138472
}
],
"citationMetadata": {
"citations": [
{
"startIndex": 1108,
"endIndex": 1250,
"uri": "https://chrome.google.com/webstore/detail/autocontrol-shortcut-mana/lkaihdpfpifdlgoapbfocpmekbokmcfd?hl=zh-TW"
}
]
}
}
],
"usageMetadata": {
"promptTokenCount": 577,
"totalTokenCount": 577
}
}"#;
serde_json::from_str::<GenerateContentResponse>(input).unwrap();
}
#[test]
fn parses_safety_rating_without_scores() {
let input = r#"{
"candidates": [
{
"content": {
"role": "model",
"parts": [
{
"text": "Return text"
}
]
},
"finishReason": "STOP",
"safetyRatings": [
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"probability": "NEGLIGIBLE",
"severity": "HARM_SEVERITY_NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"probability": "NEGLIGIBLE",
"severity": "HARM_SEVERITY_NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_HARASSMENT",
"probability": "NEGLIGIBLE",
"severity": "HARM_SEVERITY_NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"probability": "NEGLIGIBLE",
"severity": "HARM_SEVERITY_NEGLIGIBLE"
}
]
}
],
"usageMetadata": {
"promptTokenCount": 5492,
"candidatesTokenCount": 1256,
"totalTokenCount": 6748
}
}"#;
serde_json::from_str::<GenerateContentResponse>(input).unwrap();
}
} }

View File

@@ -1,3 +1,5 @@
//! Request and response types for the Gemini API.
mod common; mod common;
mod count_tokens; mod count_tokens;
mod error; mod error;

View File

@@ -2,13 +2,15 @@ use serde::{Deserialize, Serialize};
use serde_with::base64::Base64; use serde_with::base64::Base64;
use serde_with::serde_as; use serde_with::serde_as;
#[derive(Debug, Serialize, Deserialize)] /// Request body for the Imagen image generation `predict` endpoint.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PredictImageRequest { pub struct PredictImageRequest {
pub instances: Vec<PredictImageRequestPrompt>, pub instances: Vec<PredictImageRequestPrompt>,
pub parameters: PredictImageRequestParameters, pub parameters: PredictImageRequestParameters,
} }
#[derive(Debug, Serialize, Deserialize)] /// A text prompt instance for image generation.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PredictImageRequestPrompt { pub struct PredictImageRequestPrompt {
/// The text prompt for the image. /// The text prompt for the image.
/// The following models support different values for this parameter: /// The following models support different values for this parameter:
@@ -20,7 +22,8 @@ pub struct PredictImageRequestPrompt {
pub prompt: String, pub prompt: String,
} }
#[derive(Debug, Default, Serialize, Deserialize)] /// Parameters controlling image generation behavior.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct PredictImageRequestParameters { pub struct PredictImageRequestParameters {
/// The number of images to generate. The default value is 4. /// The number of images to generate. The default value is 4.
@@ -139,7 +142,8 @@ pub struct PredictImageRequestParameters {
pub storage_uri: Option<String>, pub storage_uri: Option<String>,
} }
#[derive(Debug, Serialize, Deserialize)] /// Output format options for generated images.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct PredictImageRequestParametersOutputOptions { pub struct PredictImageRequestParametersOutputOptions {
/// The image format that the output should be saved as. The following values are supported: /// The image format that the output should be saved as. The following values are supported:
@@ -155,13 +159,15 @@ pub struct PredictImageRequestParametersOutputOptions {
pub compression_quality: Option<i32>, pub compression_quality: Option<i32>,
} }
#[derive(Debug, Serialize, Deserialize)] /// A successful response from the Imagen `predict` endpoint.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PredictImageResponse { pub struct PredictImageResponse {
pub predictions: Vec<PredictImageResponsePrediction>, pub predictions: Vec<PredictImageResponsePrediction>,
} }
/// A single generated image from the prediction response.
#[serde_as] #[serde_as]
#[derive(Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct PredictImageResponsePrediction { pub struct PredictImageResponsePrediction {
#[serde_as(as = "Base64")] #[serde_as(as = "Base64")]
@@ -169,7 +175,8 @@ pub struct PredictImageResponsePrediction {
pub mime_type: String, pub mime_type: String,
} }
#[derive(Debug, Serialize, Deserialize)] /// Controls whether generated images may include people.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum PersonGeneration { pub enum PersonGeneration {
DontAllow, DontAllow,
@@ -177,7 +184,8 @@ pub enum PersonGeneration {
AllowAll, AllowAll,
} }
#[derive(Debug, Serialize, Deserialize)] /// Safety filter level for image generation.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum PredictImageSafetySetting { pub enum PredictImageSafetySetting {
BlockLowAndAbove, BlockLowAndAbove,

View File

@@ -3,18 +3,27 @@ use serde::{Deserialize, Serialize};
use crate::error::{Error, Result}; use crate::error::{Error, Result};
use crate::prelude::VertexApiError; use crate::prelude::VertexApiError;
/// Request body for the text embeddings `predict` endpoint.
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TextEmbeddingRequest { pub struct TextEmbeddingRequest {
/// The list of text instances to embed.
pub instances: Vec<TextEmbeddingRequestInstance>, pub instances: Vec<TextEmbeddingRequestInstance>,
} }
/// A single text instance to embed.
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TextEmbeddingRequestInstance { pub struct TextEmbeddingRequestInstance {
/// The text content to generate an embedding for.
pub content: String, pub content: String,
/// The task type for the embedding (e.g. `"RETRIEVAL_DOCUMENT"`, `"RETRIEVAL_QUERY"`).
pub task_type: String, pub task_type: String,
/// An optional title for the content (used with retrieval task types).
pub title: Option<String>, pub title: Option<String>,
} }
/// The raw response from the text embeddings endpoint, which may be a success or an error.
///
/// Use [`into_result`](TextEmbeddingResponse::into_result) to convert into a standard `Result`.
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(untagged)] #[serde(untagged)]
pub enum TextEmbeddingResponse { pub enum TextEmbeddingResponse {
@@ -23,13 +32,16 @@ pub enum TextEmbeddingResponse {
} }
impl TextEmbeddingResponse { impl TextEmbeddingResponse {
/// Converts this response into a `Result`, mapping the error variant to [`crate::error::Error`].
pub fn into_result(self) -> Result<TextEmbeddingResponseOk> { pub fn into_result(self) -> Result<TextEmbeddingResponseOk> {
self.into() self.into()
} }
} }
/// A successful response from the text embeddings endpoint.
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TextEmbeddingResponseOk { pub struct TextEmbeddingResponseOk {
/// The embedding predictions, one per input instance.
pub predictions: Vec<TextEmbeddingPrediction>, pub predictions: Vec<TextEmbeddingPrediction>,
} }
@@ -42,19 +54,27 @@ impl From<TextEmbeddingResponse> for Result<TextEmbeddingResponseOk> {
} }
} }
/// A single embedding prediction.
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TextEmbeddingPrediction { pub struct TextEmbeddingPrediction {
/// The embedding result containing the vector and statistics.
pub embeddings: TextEmbeddingResult, pub embeddings: TextEmbeddingResult,
} }
/// The embedding vector and associated statistics.
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TextEmbeddingResult { pub struct TextEmbeddingResult {
/// Statistics about the embedding computation.
pub statistics: TextEmbeddingStatistics, pub statistics: TextEmbeddingStatistics,
/// The embedding vector.
pub values: Vec<f64>, pub values: Vec<f64>,
} }
/// Statistics about a text embedding computation.
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TextEmbeddingStatistics { pub struct TextEmbeddingStatistics {
/// Whether the input was truncated to fit the model's context window.
pub truncated: bool, pub truncated: bool,
/// The number of tokens in the input.
pub token_count: u32, pub token_count: u32,
} }