Compare commits
10 Commits
09c8696a36
...
5e0fc06327
| Author | SHA1 | Date | |
|---|---|---|---|
| 5e0fc06327 | |||
| 2ed8d881c6 | |||
| c4acf465ba | |||
| f8a4323117 | |||
| d1bd00ce95 | |||
| a8fbe658bb | |||
| eb38c65ac5 | |||
| 4c156fbb33 | |||
| 92030a0dd9 | |||
| 56cf4f280b |
61
.github/workflows/ci.yml
vendored
Normal file
61
.github/workflows/ci.yml
vendored
Normal file
@@ -0,0 +1,61 @@
|
||||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
jobs:
|
||||
check:
|
||||
name: Check
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
- run: cargo check
|
||||
|
||||
test:
|
||||
name: Test
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
- run: cargo test
|
||||
|
||||
clippy:
|
||||
name: Clippy
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
with:
|
||||
components: clippy
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
- run: cargo clippy -- -D warnings
|
||||
|
||||
fmt:
|
||||
name: Format
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
with:
|
||||
components: rustfmt
|
||||
- run: cargo fmt --check
|
||||
|
||||
doc:
|
||||
name: Documentation
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
- run: cargo doc --no-deps
|
||||
env:
|
||||
RUSTDOCFLAGS: "-D warnings"
|
||||
12
Cargo.toml
12
Cargo.toml
@@ -10,17 +10,17 @@ edition = "2024"
|
||||
reqwest = { version = "0.13", features = ["json", "gzip", "stream"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = { version = "1" }
|
||||
serde_with = { version = "3.16", features = ["base64"] }
|
||||
serde_with = { version = "3.18", features = ["base64"] }
|
||||
tracing = "0.1"
|
||||
tokio = { version = "1" }
|
||||
tokio-stream = "0.1"
|
||||
tokio-util = "0.7.18"
|
||||
|
||||
[dev-dependencies]
|
||||
console = "0.16.2"
|
||||
console = "0.16.3"
|
||||
dialoguer = "0.12.0"
|
||||
dotenvy = "0.15.7"
|
||||
image = "0.25.9"
|
||||
indicatif = "0.18.3"
|
||||
tokio = { version = "1.49.0", features = ["full"] }
|
||||
tracing-subscriber = "0.3.22"
|
||||
image = "0.25.10"
|
||||
indicatif = "0.18.4"
|
||||
tokio = { version = "1.51.0", features = ["full"] }
|
||||
tracing-subscriber = "0.3.23"
|
||||
|
||||
124
src/client.rs
124
src/client.rs
@@ -6,18 +6,34 @@ use tokio_stream::{Stream, StreamExt};
|
||||
use tokio_util::codec::LinesCodecError;
|
||||
use tracing::error;
|
||||
|
||||
pub static AUTH_SCOPE: &[&str] = &["https://www.googleapis.com/auth/cloud-platform"];
|
||||
|
||||
/// Async client for the Google Gemini API.
|
||||
///
|
||||
/// Provides methods for content generation, streaming, token counting, text embeddings,
|
||||
/// and image prediction. All requests are authenticated with an API key passed at
|
||||
/// construction time.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use google_genai::prelude::*;
|
||||
///
|
||||
/// # async fn run() -> google_genai::error::Result<()> {
|
||||
/// let client = GeminiClient::new("YOUR_API_KEY".into());
|
||||
/// let request = GenerateContentRequest::builder()
|
||||
/// .contents(vec![Content::builder().add_text_part("Hi!").build()])
|
||||
/// .build();
|
||||
/// let response = client.generate_content(&request, "gemini-2.0-flash").await?;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct GeminiClient {
|
||||
client: reqwest::Client,
|
||||
api_key: String,
|
||||
}
|
||||
|
||||
unsafe impl Send for GeminiClient {}
|
||||
unsafe impl Sync for GeminiClient {}
|
||||
|
||||
impl GeminiClient {
|
||||
/// Creates a new [`GeminiClient`] with the given API key.
|
||||
pub fn new(api_key: String) -> Self {
|
||||
GeminiClient {
|
||||
client: reqwest::Client::new(),
|
||||
@@ -25,6 +41,10 @@ impl GeminiClient {
|
||||
}
|
||||
}
|
||||
|
||||
/// Sends a content generation request and returns a stream of response chunks via SSE.
|
||||
///
|
||||
/// Each item in the stream is a [`GenerateContentResponseResult`] containing one or more
|
||||
/// candidates. Useful for displaying incremental output as it is generated.
|
||||
pub async fn stream_generate_content(
|
||||
&self,
|
||||
request: &GenerateContentRequest,
|
||||
@@ -57,6 +77,9 @@ impl GeminiClient {
|
||||
)
|
||||
}
|
||||
|
||||
/// Sends a content generation request and returns the complete response.
|
||||
///
|
||||
/// For streaming responses, use [`stream_generate_content`](Self::stream_generate_content).
|
||||
pub async fn generate_content(
|
||||
&self,
|
||||
request: &GenerateContentRequest,
|
||||
@@ -99,50 +122,12 @@ impl GeminiClient {
|
||||
}
|
||||
}
|
||||
|
||||
/// Prompts a conversation to the model.
|
||||
pub async fn prompt_conversation(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
model: &str,
|
||||
) -> GeminiResult<Message> {
|
||||
let request = GenerateContentRequest {
|
||||
contents: messages
|
||||
.iter()
|
||||
.map(|m| Content {
|
||||
role: Some(m.role),
|
||||
parts: Some(vec![Part::from_text(m.text.clone())]),
|
||||
})
|
||||
.collect(),
|
||||
generation_config: None,
|
||||
tools: None,
|
||||
system_instruction: None,
|
||||
safety_settings: None,
|
||||
};
|
||||
|
||||
let response = self.generate_content(&request, model).await?;
|
||||
|
||||
// Check for errors in the response.
|
||||
let mut candidates = GeminiClient::collect_text_from_response(&response);
|
||||
|
||||
match candidates.pop() {
|
||||
Some(text) => Ok(Message::new(Role::Model, &text)),
|
||||
None => Err(GeminiError::NoCandidatesError),
|
||||
}
|
||||
}
|
||||
|
||||
fn collect_text_from_response(response: &GenerateContentResponseResult) -> Vec<String> {
|
||||
response
|
||||
.candidates
|
||||
.iter()
|
||||
.filter_map(Candidate::get_text)
|
||||
.collect::<Vec<String>>()
|
||||
}
|
||||
|
||||
/// Generates text embeddings for the given input.
|
||||
pub async fn text_embeddings(
|
||||
&self,
|
||||
request: &TextEmbeddingRequest,
|
||||
model: &str,
|
||||
) -> GeminiResult<TextEmbeddingResponse> {
|
||||
) -> GeminiResult<TextEmbeddingResponseOk> {
|
||||
let endpoint_url =
|
||||
format!("https://generativelanguage.googleapis.com/v1beta/models/{model}:predict");
|
||||
let resp = self
|
||||
@@ -152,16 +137,38 @@ impl GeminiClient {
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = resp.status();
|
||||
let txt_json = resp.text().await?;
|
||||
tracing::debug!("text_embeddings response: {:?}", txt_json);
|
||||
Ok(serde_json::from_str::<TextEmbeddingResponse>(&txt_json)?)
|
||||
|
||||
if !status.is_success() {
|
||||
if let Ok(gemini_error) =
|
||||
serde_json::from_str::<crate::types::GeminiApiError>(&txt_json)
|
||||
{
|
||||
return Err(GeminiError::GeminiError(gemini_error));
|
||||
}
|
||||
return Err(GeminiError::GenericApiError {
|
||||
status: status.as_u16(),
|
||||
body: txt_json,
|
||||
});
|
||||
}
|
||||
|
||||
match serde_json::from_str::<TextEmbeddingResponse>(&txt_json) {
|
||||
Ok(response) => Ok(response.into_result()?),
|
||||
Err(e) => {
|
||||
error!(response = txt_json, error = ?e, "Failed to parse response");
|
||||
Err(e.into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Counts the number of tokens in the given content.
|
||||
pub async fn count_tokens(
|
||||
&self,
|
||||
request: &CountTokensRequest,
|
||||
model: &str,
|
||||
) -> GeminiResult<CountTokensResponse> {
|
||||
) -> GeminiResult<CountTokensResponseResult> {
|
||||
let endpoint_url =
|
||||
format!("https://generativelanguage.googleapis.com/v1beta/models/{model}:countTokens");
|
||||
let resp = self
|
||||
@@ -172,11 +179,32 @@ impl GeminiClient {
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = resp.status();
|
||||
let txt_json = resp.text().await?;
|
||||
tracing::debug!("count_tokens response: {:?}", txt_json);
|
||||
Ok(serde_json::from_str(&txt_json)?)
|
||||
|
||||
if !status.is_success() {
|
||||
if let Ok(gemini_error) =
|
||||
serde_json::from_str::<crate::types::GeminiApiError>(&txt_json)
|
||||
{
|
||||
return Err(GeminiError::GeminiError(gemini_error));
|
||||
}
|
||||
return Err(GeminiError::GenericApiError {
|
||||
status: status.as_u16(),
|
||||
body: txt_json,
|
||||
});
|
||||
}
|
||||
|
||||
match serde_json::from_str::<CountTokensResponse>(&txt_json) {
|
||||
Ok(response) => Ok(response.into_result()?),
|
||||
Err(e) => {
|
||||
error!(response = txt_json, error = ?e, "Failed to parse response");
|
||||
Err(e.into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates images from a text prompt using an Imagen model.
|
||||
pub async fn predict_image(
|
||||
&self,
|
||||
request: &PredictImageRequest,
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{error::Result, prelude::*};
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Message {
|
||||
pub role: Role,
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
impl Message {
|
||||
pub fn new(role: Role, text: &str) -> Self {
|
||||
Message {
|
||||
role,
|
||||
text: text.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Dialogue {
|
||||
model: String,
|
||||
messages: Vec<Message>,
|
||||
}
|
||||
|
||||
impl Dialogue {
|
||||
pub fn new(model: &str) -> Self {
|
||||
Dialogue {
|
||||
model: model.to_string(),
|
||||
messages: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn do_turn(&mut self, gemini: &GeminiClient, message: &str) -> Result<Message> {
|
||||
self.messages.push(Message::new(Role::User, message));
|
||||
let response = gemini
|
||||
.prompt_conversation(&self.messages, &self.model)
|
||||
.await?;
|
||||
self.messages.push(response.clone());
|
||||
Ok(response)
|
||||
}
|
||||
}
|
||||
22
src/error.rs
22
src/error.rs
@@ -1,21 +1,39 @@
|
||||
//! Error types for the Google Gemini client.
|
||||
|
||||
use std::fmt::Display;
|
||||
use tokio_util::codec::LinesCodecError;
|
||||
|
||||
use crate::types;
|
||||
|
||||
/// A type alias for `Result<T, error::Error>`.
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
/// Errors that can occur when using the Gemini client.
|
||||
#[derive(Debug)]
|
||||
pub enum Error {
|
||||
/// An environment variable required for configuration was missing or invalid.
|
||||
Env(std::env::VarError),
|
||||
/// An HTTP transport error from the underlying `reqwest` client.
|
||||
HttpClient(reqwest::Error),
|
||||
/// A JSON serialization or deserialization error.
|
||||
Serde(serde_json::Error),
|
||||
/// A structured error returned by the Vertex AI API.
|
||||
VertexError(types::VertexApiError),
|
||||
/// A structured error returned by the Gemini API.
|
||||
GeminiError(types::GeminiApiError),
|
||||
/// The API response contained no candidate completions.
|
||||
NoCandidatesError,
|
||||
/// An error occurred while decoding the SSE event stream.
|
||||
EventSourceError(LinesCodecError),
|
||||
/// The SSE event stream closed unexpectedly.
|
||||
EventSourceClosedError,
|
||||
GenericApiError { status: u16, body: String },
|
||||
/// An API error that could not be parsed into a structured error type.
|
||||
GenericApiError {
|
||||
/// The HTTP status code.
|
||||
status: u16,
|
||||
/// The raw response body.
|
||||
body: String,
|
||||
},
|
||||
}
|
||||
|
||||
impl Display for Error {
|
||||
@@ -34,7 +52,7 @@ impl Display for Error {
|
||||
write!(f, "No candidates returned for the prompt")
|
||||
}
|
||||
Error::EventSourceError(e) => {
|
||||
write!(f, "EventSourrce Error: {e}")
|
||||
write!(f, "EventSource Error: {e}")
|
||||
}
|
||||
Error::EventSourceClosedError => {
|
||||
write!(f, "EventSource closed error")
|
||||
|
||||
31
src/lib.rs
31
src/lib.rs
@@ -1,11 +1,38 @@
|
||||
//! Async Rust client for the Google Gemini API.
|
||||
//!
|
||||
//! This crate provides a high-level async client for interacting with Google's Gemini
|
||||
//! generative AI models. It supports content generation (including streaming via SSE),
|
||||
//! token counting, text embeddings, and image generation.
|
||||
//!
|
||||
//! # Usage
|
||||
//!
|
||||
//! ```no_run
|
||||
//! use google_genai::prelude::*;
|
||||
//!
|
||||
//! # async fn run() -> google_genai::error::Result<()> {
|
||||
//! let client = GeminiClient::new("YOUR_API_KEY".into());
|
||||
//!
|
||||
//! let request = GenerateContentRequest::builder()
|
||||
//! .contents(vec![
|
||||
//! Content::builder().add_text_part("Hello, Gemini!").build()
|
||||
//! ])
|
||||
//! .build();
|
||||
//!
|
||||
//! let response = client.generate_content(&request, "gemini-2.0-flash").await?;
|
||||
//! # Ok(())
|
||||
//! # }
|
||||
//! ```
|
||||
|
||||
mod client;
|
||||
mod dialogue;
|
||||
pub mod error;
|
||||
pub mod network;
|
||||
mod types;
|
||||
|
||||
/// Convenience re-exports of the most commonly used types.
|
||||
///
|
||||
/// Importing `use google_genai::prelude::*` brings [`GeminiClient`](crate::prelude::GeminiClient)
|
||||
/// and all request/response types into scope.
|
||||
pub mod prelude {
|
||||
pub use crate::client::*;
|
||||
pub use crate::dialogue::*;
|
||||
pub use crate::types::*;
|
||||
}
|
||||
|
||||
@@ -1,3 +1,11 @@
|
||||
//! Server-Sent Events (SSE) decoder for streaming HTTP responses.
|
||||
//!
|
||||
//! Implements a [`tokio_util::codec::Decoder`] that parses an SSE byte stream into
|
||||
//! [`ServerSentEvent`] values. Used internally by [`GeminiClient::stream_generate_content`]
|
||||
//! to process chunked model responses.
|
||||
//!
|
||||
//! [`GeminiClient::stream_generate_content`]: crate::prelude::GeminiClient::stream_generate_content
|
||||
|
||||
use reqwest::Response;
|
||||
use std::mem;
|
||||
use tokio_stream::{Stream, StreamExt};
|
||||
@@ -12,7 +20,9 @@ static DATA: &str = "data: ";
|
||||
static ID: &str = "id: ";
|
||||
static RETRY: &str = "retry: ";
|
||||
|
||||
/// Extension trait for converting an HTTP response into a stream of [`ServerSentEvent`]s.
|
||||
pub trait EventSource {
|
||||
/// Consumes the response and returns a stream of parsed SSE events.
|
||||
fn event_stream(self) -> impl Stream<Item = Result<ServerSentEvent, LinesCodecError>>;
|
||||
}
|
||||
|
||||
@@ -22,14 +32,25 @@ impl EventSource for Response {
|
||||
}
|
||||
}
|
||||
|
||||
/// A parsed Server-Sent Event.
|
||||
///
|
||||
/// Fields correspond to the standard SSE fields: `event`, `data`, `id`, and `retry`.
|
||||
/// Multiple `data:` lines within a single event are concatenated with newline separators.
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct ServerSentEvent {
|
||||
/// The event type (from the `event:` field).
|
||||
pub event: Option<String>,
|
||||
/// The event payload (from one or more `data:` fields, joined by `\n`).
|
||||
pub data: Option<String>,
|
||||
/// The event ID (from the `id:` field).
|
||||
pub id: Option<String>,
|
||||
/// The reconnection time in milliseconds (from the `retry:` field).
|
||||
pub retry: Option<usize>,
|
||||
}
|
||||
|
||||
/// A [`Decoder`] that parses a byte stream of SSE-formatted data into [`ServerSentEvent`]s.
|
||||
///
|
||||
/// Wraps a [`LinesCodec`] and accumulates fields until an empty line signals the end of an event.
|
||||
pub struct ServerSentEventsCodec {
|
||||
lines_code: LinesCodec,
|
||||
next: ServerSentEvent,
|
||||
@@ -42,6 +63,7 @@ impl Default for ServerSentEventsCodec {
|
||||
}
|
||||
|
||||
impl ServerSentEventsCodec {
|
||||
/// Creates a new SSE codec.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
lines_code: LinesCodec::new(),
|
||||
@@ -73,7 +95,12 @@ impl Decoder for ServerSentEventsCodec {
|
||||
self.next.event = Some(line);
|
||||
} else if line.starts_with(DATA) {
|
||||
line.drain(..DATA.len());
|
||||
self.next.data = Some(line)
|
||||
if let Some(ref mut existing) = self.next.data {
|
||||
existing.push('\n');
|
||||
existing.push_str(&line);
|
||||
} else {
|
||||
self.next.data = Some(line);
|
||||
}
|
||||
} else if line.starts_with(ID) {
|
||||
line.drain(..ID.len());
|
||||
self.next.id = Some(line);
|
||||
@@ -90,6 +117,9 @@ impl Decoder for ServerSentEventsCodec {
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts a [`Response`] into a stream of [`ServerSentEvent`]s.
|
||||
///
|
||||
/// The response body is read as a byte stream and decoded using [`ServerSentEventsCodec`].
|
||||
pub fn stream_response(
|
||||
response: Response,
|
||||
) -> impl Stream<Item = Result<ServerSentEvent, LinesCodecError>> {
|
||||
|
||||
@@ -1 +1,3 @@
|
||||
//! Networking utilities for streaming HTTP responses.
|
||||
|
||||
pub mod event_source;
|
||||
|
||||
@@ -5,13 +5,21 @@ use serde_json::Value;
|
||||
|
||||
use crate::types::FunctionResponse;
|
||||
|
||||
/// A conversation message containing one or more [`Part`]s.
|
||||
///
|
||||
/// See <https://ai.google.dev/api/caching#Content>.
|
||||
#[derive(Clone, Default, Debug, Serialize, Deserialize)]
|
||||
pub struct Content {
|
||||
/// The role of the message author (`user` or `model`).
|
||||
pub role: Option<Role>,
|
||||
/// The ordered parts that make up this message.
|
||||
pub parts: Option<Vec<Part>>,
|
||||
}
|
||||
|
||||
impl Content {
|
||||
/// Concatenates all [`PartData::Text`] parts into a single string.
|
||||
///
|
||||
/// Returns `None` if there are no parts.
|
||||
pub fn get_text(&self) -> Option<String> {
|
||||
self.parts.as_ref().map(|parts| {
|
||||
parts
|
||||
@@ -24,25 +32,30 @@ impl Content {
|
||||
})
|
||||
}
|
||||
|
||||
/// Creates a [`Content`] containing a single text part, suitable for use as a system instruction.
|
||||
pub fn system_prompt<S: Into<String>>(system_prompt: S) -> Self {
|
||||
Self::builder().add_text_part(system_prompt).build()
|
||||
}
|
||||
|
||||
/// Returns a new [`ContentBuilder`].
|
||||
pub fn builder() -> ContentBuilder {
|
||||
ContentBuilder::default()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
/// Builder for constructing [`Content`] values incrementally.
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct ContentBuilder {
|
||||
content: Content,
|
||||
}
|
||||
|
||||
impl ContentBuilder {
|
||||
/// Appends a text part to this content.
|
||||
pub fn add_text_part<T: Into<String>>(self, text: T) -> Self {
|
||||
self.add_part(Part::from_text(text.into()))
|
||||
}
|
||||
|
||||
/// Appends an arbitrary [`Part`] to this content.
|
||||
pub fn add_part(mut self, part: Part) -> Self {
|
||||
match &mut self.content.parts {
|
||||
Some(parts) => parts.push(part),
|
||||
@@ -51,16 +64,19 @@ impl ContentBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the [`Role`] for this content.
|
||||
pub fn role(mut self, role: Role) -> Self {
|
||||
self.content.role = Some(role);
|
||||
self
|
||||
}
|
||||
|
||||
/// Consumes the builder and returns the constructed [`Content`].
|
||||
pub fn build(self) -> Content {
|
||||
self.content
|
||||
}
|
||||
}
|
||||
|
||||
/// The role of a message author in a conversation.
|
||||
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Role {
|
||||
@@ -90,7 +106,9 @@ impl FromStr for Role {
|
||||
}
|
||||
}
|
||||
|
||||
/// See https://ai.google.dev/api/caching#Part
|
||||
/// A single unit of content within a [`Content`] message.
|
||||
///
|
||||
/// See <https://ai.google.dev/api/caching#Part>.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Part {
|
||||
@@ -108,29 +126,42 @@ pub struct Part {
|
||||
pub data: PartData, // Create enum for data.
|
||||
}
|
||||
|
||||
/// The payload of a [`Part`], representing different content types.
|
||||
///
|
||||
/// See <https://ai.google.dev/api/caching#Part>.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub enum PartData {
|
||||
/// Plain text content.
|
||||
Text(String),
|
||||
// https://ai.google.dev/api/caching#Blob
|
||||
/// Binary data encoded inline. See <https://ai.google.dev/api/caching#Blob>.
|
||||
InlineData {
|
||||
/// The IANA MIME type of the data (e.g. `"image/png"`).
|
||||
mime_type: String,
|
||||
/// Base64-encoded binary data.
|
||||
data: String,
|
||||
},
|
||||
// https://ai.google.dev/api/caching#FunctionCall
|
||||
/// A function call requested by the model. See <https://ai.google.dev/api/caching#FunctionCall>.
|
||||
FunctionCall {
|
||||
/// Optional unique identifier for the function call.
|
||||
id: Option<String>,
|
||||
/// The name of the function to call.
|
||||
name: String,
|
||||
/// The arguments to pass, as a JSON object.
|
||||
args: Option<Value>,
|
||||
},
|
||||
// https://ai.google.dev/api/caching#FunctionResponse
|
||||
/// A response to a function call. See <https://ai.google.dev/api/caching#FunctionResponse>.
|
||||
FunctionResponse(FunctionResponse),
|
||||
/// A reference to a file stored in the API.
|
||||
FileData(Value),
|
||||
/// Code to be executed by the model.
|
||||
ExecutableCode(Value),
|
||||
/// The result of executing code.
|
||||
CodeExecutionResult(Value),
|
||||
}
|
||||
|
||||
impl Part {
|
||||
/// Creates a [`Part`] containing only text.
|
||||
pub fn from_text<S: Into<String>>(text: S) -> Self {
|
||||
Self {
|
||||
thought: None,
|
||||
|
||||
@@ -1,24 +1,35 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
|
||||
use super::Content;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
/// Request body for the `countTokens` endpoint.
|
||||
///
|
||||
/// Use [`CountTokensRequest::builder`] for ergonomic construction.
|
||||
///
|
||||
/// See <https://ai.google.dev/api/tokens#method:-models.counttokens>.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct CountTokensRequest {
|
||||
/// The content to count tokens for.
|
||||
pub contents: Content,
|
||||
}
|
||||
|
||||
impl CountTokensRequest {
|
||||
/// Returns a new [`CountTokensRequestBuilder`].
|
||||
pub fn builder() -> CountTokensRequestBuilder {
|
||||
CountTokensRequestBuilder::default()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
/// Builder for [`CountTokensRequest`].
|
||||
#[derive(Debug, Default)]
|
||||
pub struct CountTokensRequestBuilder {
|
||||
contents: Content,
|
||||
}
|
||||
|
||||
impl CountTokensRequestBuilder {
|
||||
/// Creates a builder pre-populated with a single text prompt.
|
||||
pub fn from_prompt(prompt: &str) -> Self {
|
||||
CountTokensRequestBuilder {
|
||||
contents: Content {
|
||||
@@ -28,6 +39,7 @@ impl CountTokensRequestBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
/// Consumes the builder and returns the constructed [`CountTokensRequest`].
|
||||
pub fn build(self) -> CountTokensRequest {
|
||||
CountTokensRequest {
|
||||
contents: self.contents,
|
||||
@@ -35,15 +47,32 @@ impl CountTokensRequestBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
/// The raw response from the `countTokens` endpoint, which may be a success or an error.
|
||||
///
|
||||
/// Use [`into_result`](CountTokensResponse::into_result) to convert into a standard `Result`.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum CountTokensResponse {
|
||||
#[serde(rename_all = "camelCase")]
|
||||
Ok {
|
||||
total_tokens: i32,
|
||||
total_billable_characters: u32,
|
||||
},
|
||||
Error {
|
||||
error: super::VertexApiError,
|
||||
},
|
||||
Ok(CountTokensResponseResult),
|
||||
Error { error: super::VertexApiError },
|
||||
}
|
||||
|
||||
impl CountTokensResponse {
|
||||
/// Converts this response into a `Result`, mapping the error variant to [`crate::error::Error`].
|
||||
pub fn into_result(self) -> Result<CountTokensResponseResult> {
|
||||
match self {
|
||||
CountTokensResponse::Ok(result) => Ok(result),
|
||||
CountTokensResponse::Error { error } => Err(Error::VertexError(error)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A successful response from the `countTokens` endpoint.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CountTokensResponseResult {
|
||||
/// The total number of tokens in the input.
|
||||
pub total_tokens: i32,
|
||||
/// The total number of billable characters in the input.
|
||||
pub total_billable_characters: u32,
|
||||
}
|
||||
|
||||
@@ -2,11 +2,16 @@ use std::fmt::Formatter;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// A structured error returned by the Vertex AI / Gemini API.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct VertexApiError {
|
||||
/// The HTTP status code.
|
||||
pub code: i32,
|
||||
/// A human-readable error message.
|
||||
pub message: String,
|
||||
/// The gRPC status string (e.g. `"INVALID_ARGUMENT"`).
|
||||
pub status: String,
|
||||
/// Optional additional error details.
|
||||
pub details: Option<Vec<serde_json::Value>>,
|
||||
}
|
||||
|
||||
@@ -19,49 +24,23 @@ impl core::fmt::Display for VertexApiError {
|
||||
|
||||
impl std::error::Error for VertexApiError {}
|
||||
|
||||
/// A wrapper around [`VertexApiError`] matching the Gemini API error response format.
|
||||
///
|
||||
/// The Gemini API nests the error details inside an `error` field.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct GeminiApiError {
|
||||
/// The inner error details.
|
||||
pub error: VertexApiError,
|
||||
}
|
||||
|
||||
impl core::fmt::Display for GeminiApiError {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
|
||||
write!(f, "Gemini API Error {} - {}", self.error.code, self.error.message)
|
||||
write!(
|
||||
f,
|
||||
"Gemini API Error {} - {}",
|
||||
self.error.code, self.error.message
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for GeminiApiError {}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Link {
|
||||
pub description: String,
|
||||
pub url: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "@type")]
|
||||
pub enum ErrorType {
|
||||
#[serde(rename = "type.googleapis.com/google.rpc.ErrorInfo")]
|
||||
ErrorInfo { metadata: ErrorInfoMetadata },
|
||||
|
||||
#[serde(rename = "type.googleapis.com/google.rpc.Help")]
|
||||
Help { links: Vec<Link> },
|
||||
|
||||
#[serde(rename = "type.googleapis.com/google.rpc.BadRequest")]
|
||||
BadRequest {
|
||||
#[serde(rename = "fieldViolations")]
|
||||
field_violations: Vec<FieldViolation>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ErrorInfoMetadata {
|
||||
pub service: String,
|
||||
pub consumer: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct FieldViolation {
|
||||
pub field: String,
|
||||
pub description: String,
|
||||
}
|
||||
|
||||
@@ -4,7 +4,12 @@ use serde_json::Value;
|
||||
use super::{Content, VertexApiError};
|
||||
use crate::error::Result;
|
||||
|
||||
#[derive(Clone, Default, Serialize, Deserialize)]
|
||||
/// Request body for the `generateContent` and `streamGenerateContent` endpoints.
|
||||
///
|
||||
/// Use [`GenerateContentRequest::builder`] for ergonomic construction.
|
||||
///
|
||||
/// See <https://ai.google.dev/api/generate-content#request-body>.
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct GenerateContentRequest {
|
||||
pub contents: Vec<Content>,
|
||||
@@ -19,11 +24,14 @@ pub struct GenerateContentRequest {
|
||||
}
|
||||
|
||||
impl GenerateContentRequest {
|
||||
/// Returns a new [`GenerateContentRequestBuilder`].
|
||||
pub fn builder() -> GenerateContentRequestBuilder {
|
||||
GenerateContentRequestBuilder::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for [`GenerateContentRequest`].
|
||||
#[derive(Debug)]
|
||||
pub struct GenerateContentRequestBuilder {
|
||||
request: GenerateContentRequest,
|
||||
}
|
||||
@@ -35,37 +43,46 @@ impl GenerateContentRequestBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets the conversation contents.
|
||||
pub fn contents(mut self, contents: Vec<Content>) -> Self {
|
||||
self.request.contents = contents;
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the generation configuration.
|
||||
pub fn generation_config(mut self, generation_config: GenerationConfig) -> Self {
|
||||
self.request.generation_config = Some(generation_config);
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the tools available to the model (e.g. function calling, Google Search).
|
||||
pub fn tools(mut self, tools: Vec<Tools>) -> Self {
|
||||
self.request.tools = Some(tools);
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the safety filter settings.
|
||||
pub fn safety_settings(mut self, safety_settings: Vec<SafetySetting>) -> Self {
|
||||
self.request.safety_settings = Some(safety_settings);
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets a system instruction to guide the model's behavior.
|
||||
pub fn system_instruction(mut self, system_instruction: Content) -> Self {
|
||||
self.request.system_instruction = Some(system_instruction);
|
||||
self
|
||||
}
|
||||
|
||||
/// Consumes the builder and returns the constructed [`GenerateContentRequest`].
|
||||
pub fn build(self) -> GenerateContentRequest {
|
||||
self.request
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Default, Serialize, Deserialize)]
|
||||
/// A set of tool declarations the model may use during generation.
|
||||
///
|
||||
/// See <https://ai.google.dev/api/caching#Tool>.
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
|
||||
pub struct Tools {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub function_declarations: Option<Vec<FunctionDeclaration>>,
|
||||
@@ -78,13 +95,17 @@ pub struct Tools {
|
||||
pub google_search: Option<GoogleSearch>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Default, Serialize, Deserialize)]
|
||||
/// Enables the Google Search grounding tool (no configuration required).
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
|
||||
pub struct GoogleSearch {}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
/// Configuration for dynamic retrieval in Google Search grounding.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct DynamicRetrievalConfig {
|
||||
/// The retrieval mode (e.g. `"MODE_DYNAMIC"`).
|
||||
pub mode: String,
|
||||
/// The threshold for triggering retrieval. Defaults to `0.7`.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub dynamic_threshold: Option<f32>,
|
||||
}
|
||||
@@ -98,12 +119,19 @@ impl Default for DynamicRetrievalConfig {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Default, Serialize, Deserialize)]
|
||||
/// Google Search retrieval tool with dynamic retrieval configuration.
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct GoogleSearchRetrieval {
|
||||
/// Configuration controlling when retrieval is triggered.
|
||||
pub dynamic_retrieval_config: DynamicRetrievalConfig,
|
||||
}
|
||||
|
||||
/// Parameters that control how the model generates content.
|
||||
///
|
||||
/// Use [`GenerationConfig::builder`] for ergonomic construction.
|
||||
///
|
||||
/// See <https://ai.google.dev/api/generate-content#generationconfig>.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct GenerationConfig {
|
||||
@@ -128,11 +156,14 @@ pub struct GenerationConfig {
|
||||
}
|
||||
|
||||
impl GenerationConfig {
|
||||
/// Returns a new [`GenerationConfigBuilder`].
|
||||
pub fn builder() -> GenerationConfigBuilder {
|
||||
GenerationConfigBuilder::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for [`GenerationConfig`].
|
||||
#[derive(Debug)]
|
||||
pub struct GenerationConfigBuilder {
|
||||
generation_config: GenerationConfig,
|
||||
}
|
||||
@@ -189,11 +220,13 @@ impl GenerationConfigBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
/// Consumes the builder and returns the constructed [`GenerationConfig`].
|
||||
pub fn build(self) -> GenerationConfig {
|
||||
self.generation_config
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for the model's "thinking" (chain-of-thought) behavior.
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ThinkingConfig {
|
||||
@@ -204,6 +237,7 @@ pub struct ThinkingConfig {
|
||||
pub thinking_level: Option<ThinkingLevel>,
|
||||
}
|
||||
|
||||
/// The level of thinking effort the model should use.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
|
||||
pub enum ThinkingLevel {
|
||||
@@ -212,6 +246,9 @@ pub enum ThinkingLevel {
|
||||
High,
|
||||
}
|
||||
|
||||
/// A safety filter configuration that controls blocking thresholds for harmful content.
|
||||
///
|
||||
/// See <https://ai.google.dev/api/generate-content#safetysetting>.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SafetySetting {
|
||||
@@ -221,6 +258,9 @@ pub struct SafetySetting {
|
||||
pub method: Option<HarmBlockMethod>,
|
||||
}
|
||||
|
||||
/// Categories of potentially harmful content.
|
||||
///
|
||||
/// See <https://ai.google.dev/api/generate-content#harmcategory>.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub enum HarmCategory {
|
||||
#[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
|
||||
@@ -235,6 +275,7 @@ pub enum HarmCategory {
|
||||
SexuallyExplicit,
|
||||
}
|
||||
|
||||
/// The threshold at which harmful content is blocked.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub enum HarmBlockThreshold {
|
||||
#[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
|
||||
@@ -249,6 +290,7 @@ pub enum HarmBlockThreshold {
|
||||
BlockNone,
|
||||
}
|
||||
|
||||
/// The method used to evaluate harm (severity-based or probability-based).
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub enum HarmBlockMethod {
|
||||
#[serde(rename = "HARM_BLOCK_METHOD_UNSPECIFIED")]
|
||||
@@ -259,7 +301,10 @@ pub enum HarmBlockMethod {
|
||||
Probability, // PROBABILITY
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
/// A single candidate response generated by the model.
|
||||
///
|
||||
/// See <https://ai.google.dev/api/generate-content#candidate>.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Candidate {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
@@ -274,6 +319,7 @@ pub struct Candidate {
|
||||
}
|
||||
|
||||
impl Candidate {
|
||||
/// Returns the concatenated text from this candidate's content, if any.
|
||||
pub fn get_text(&self) -> Option<String> {
|
||||
match &self.content {
|
||||
Some(content) => content.get_text(),
|
||||
@@ -282,7 +328,8 @@ impl Candidate {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
/// A citation to a source used by the model in its response.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Citation {
|
||||
pub start_index: Option<i32>,
|
||||
@@ -290,13 +337,15 @@ pub struct Citation {
|
||||
pub uri: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
/// Metadata containing citations for a candidate's content.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct CitationMetadata {
|
||||
#[serde(alias = "citationSources")]
|
||||
pub citations: Vec<Citation>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
/// A safety rating for a piece of content across a specific harm category.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SafetyRating {
|
||||
pub category: String,
|
||||
@@ -306,7 +355,8 @@ pub struct SafetyRating {
|
||||
pub severity_score: Option<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
/// Token usage statistics for a generate content request/response.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct UsageMetadata {
|
||||
pub candidates_token_count: Option<u32>,
|
||||
@@ -314,6 +364,9 @@ pub struct UsageMetadata {
|
||||
pub total_token_count: Option<u32>,
|
||||
}
|
||||
|
||||
/// A declaration of a function the model may call.
|
||||
///
|
||||
/// See <https://ai.google.dev/api/caching#FunctionDeclaration>.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct FunctionDeclaration {
|
||||
@@ -330,7 +383,7 @@ pub struct FunctionDeclaration {
|
||||
pub response_json_schema: Option<Value>,
|
||||
}
|
||||
|
||||
/// See https://ai.google.dev/api/caching#FunctionResponse
|
||||
/// See <https://ai.google.dev/api/caching#FunctionResponse>.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct FunctionResponse {
|
||||
@@ -343,14 +396,14 @@ pub struct FunctionResponse {
|
||||
pub scheduling: Option<Scheduling>,
|
||||
}
|
||||
|
||||
/// See https://ai.google.dev/api/caching#FunctionResponsePart
|
||||
/// See <https://ai.google.dev/api/caching#FunctionResponsePart>.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub enum FunctionResponsePart {
|
||||
InlineData(FunctionResponseBlob),
|
||||
}
|
||||
|
||||
/// See https://ai.google.dev/api/caching#FunctionResponseBlob
|
||||
/// See <https://ai.google.dev/api/caching#FunctionResponseBlob>.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct FunctionResponseBlob {
|
||||
@@ -358,7 +411,7 @@ pub struct FunctionResponseBlob {
|
||||
pub data: String,
|
||||
}
|
||||
|
||||
/// See https://ai.google.dev/api/caching#Scheduling
|
||||
/// See <https://ai.google.dev/api/caching#Scheduling>.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
|
||||
pub enum Scheduling {
|
||||
@@ -368,6 +421,7 @@ pub enum Scheduling {
|
||||
Interrupt,
|
||||
}
|
||||
|
||||
/// A single property within a function's parameter schema.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct FunctionParametersProperty {
|
||||
@@ -375,7 +429,11 @@ pub struct FunctionParametersProperty {
|
||||
pub description: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
/// The raw response from the `generateContent` endpoint, which may be a success or an error.
|
||||
///
|
||||
/// Use [`into_result`](GenerateContentResponse::into_result) to convert into a standard
|
||||
/// `Result<GenerateContentResponseResult>`.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum GenerateContentResponse {
|
||||
Ok(GenerateContentResponseResult),
|
||||
@@ -391,19 +449,22 @@ impl From<GenerateContentResponse> for Result<GenerateContentResponseResult> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
/// A successful response from the `generateContent` endpoint.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct GenerateContentResponseResult {
|
||||
pub candidates: Vec<Candidate>,
|
||||
pub usage_metadata: Option<UsageMetadata>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
/// An error response from the `generateContent` endpoint.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct GenerateContentResponseError {
|
||||
pub error: VertexApiError,
|
||||
}
|
||||
|
||||
impl GenerateContentResponse {
|
||||
/// Converts this response into a `Result`, mapping the error variant to [`crate::error::Error`].
|
||||
pub fn into_result(self) -> Result<GenerateContentResponseResult> {
|
||||
match self {
|
||||
GenerateContentResponse::Ok(result) => Ok(result),
|
||||
@@ -416,13 +477,7 @@ impl GenerateContentResponse {
|
||||
mod tests {
|
||||
use crate::types::{Candidate, UsageMetadata};
|
||||
|
||||
use super::{GenerateContentResponse, GenerateContentResponseResult};
|
||||
|
||||
#[test]
|
||||
pub fn parses_empty_metadata_response() {
|
||||
let input = r#"{"candidates": [{"content": {"role": "model","parts": [{"text": "-"}]}}],"usageMetadata": {}}"#;
|
||||
serde_json::from_str::<GenerateContentResponseResult>(input).unwrap();
|
||||
}
|
||||
use super::GenerateContentResponseResult;
|
||||
|
||||
#[test]
|
||||
pub fn parses_usage_metadata() {
|
||||
@@ -498,177 +553,4 @@ mod tests {
|
||||
"#;
|
||||
let _ = serde_json::from_str::<GenerateContentResponseResult>(input).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn parses_max_tokens_response() {
|
||||
let input = r#"{
|
||||
"candidates": [
|
||||
{
|
||||
"content": {
|
||||
"role": "model",
|
||||
"parts": [
|
||||
{
|
||||
"text": "Service workers are powerful and absolutely worth learning. They let you deliver an entirely new level of experience to your users. Your site can load instantly . It can work offline . It can be installed as a platform-specific app and feel every bit as polished—but with the reach and freedom of the web."
|
||||
}
|
||||
]
|
||||
},
|
||||
"finishReason": "MAX_TOKENS",
|
||||
"safetyRatings": [
|
||||
{
|
||||
"category": "HARM_CATEGORY_HATE_SPEECH",
|
||||
"probability": "NEGLIGIBLE",
|
||||
"probabilityScore": 0.03882902,
|
||||
"severity": "HARM_SEVERITY_NEGLIGIBLE",
|
||||
"severityScore": 0.05781161
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
"probability": "NEGLIGIBLE",
|
||||
"probabilityScore": 0.07626997,
|
||||
"severity": "HARM_SEVERITY_NEGLIGIBLE",
|
||||
"severityScore": 0.06705628
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_HARASSMENT",
|
||||
"probability": "NEGLIGIBLE",
|
||||
"probabilityScore": 0.05749328,
|
||||
"severity": "HARM_SEVERITY_NEGLIGIBLE",
|
||||
"severityScore": 0.027532939
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
"probability": "NEGLIGIBLE",
|
||||
"probabilityScore": 0.12929276,
|
||||
"severity": "HARM_SEVERITY_NEGLIGIBLE",
|
||||
"severityScore": 0.17838266
|
||||
}
|
||||
],
|
||||
"citationMetadata": {
|
||||
"citations": [
|
||||
{
|
||||
"endIndex": 151,
|
||||
"uri": "https://web.dev/service-worker-mindset/"
|
||||
},
|
||||
{
|
||||
"startIndex": 93,
|
||||
"endIndex": 297,
|
||||
"uri": "https://web.dev/service-worker-mindset/"
|
||||
},
|
||||
{
|
||||
"endIndex": 297
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"usageMetadata": {
|
||||
"promptTokenCount": 12069,
|
||||
"candidatesTokenCount": 61,
|
||||
"totalTokenCount": 12130
|
||||
}
|
||||
}"#;
|
||||
serde_json::from_str::<GenerateContentResponseResult>(input).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_candidates_without_content() {
|
||||
let input = r#"{
|
||||
"candidates": [
|
||||
{
|
||||
"finishReason": "RECITATION",
|
||||
"safetyRatings": [
|
||||
{
|
||||
"category": "HARM_CATEGORY_HATE_SPEECH",
|
||||
"probability": "NEGLIGIBLE",
|
||||
"probabilityScore": 0.08021325,
|
||||
"severity": "HARM_SEVERITY_NEGLIGIBLE",
|
||||
"severityScore": 0.0721122
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
"probability": "NEGLIGIBLE",
|
||||
"probabilityScore": 0.19360436,
|
||||
"severity": "HARM_SEVERITY_NEGLIGIBLE",
|
||||
"severityScore": 0.1066906
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_HARASSMENT",
|
||||
"probability": "NEGLIGIBLE",
|
||||
"probabilityScore": 0.07751766,
|
||||
"severity": "HARM_SEVERITY_NEGLIGIBLE",
|
||||
"severityScore": 0.040769264
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
"probability": "NEGLIGIBLE",
|
||||
"probabilityScore": 0.030792166,
|
||||
"severity": "HARM_SEVERITY_NEGLIGIBLE",
|
||||
"severityScore": 0.04138472
|
||||
}
|
||||
],
|
||||
"citationMetadata": {
|
||||
"citations": [
|
||||
{
|
||||
"startIndex": 1108,
|
||||
"endIndex": 1250,
|
||||
"uri": "https://chrome.google.com/webstore/detail/autocontrol-shortcut-mana/lkaihdpfpifdlgoapbfocpmekbokmcfd?hl=zh-TW"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"usageMetadata": {
|
||||
"promptTokenCount": 577,
|
||||
"totalTokenCount": 577
|
||||
}
|
||||
}"#;
|
||||
serde_json::from_str::<GenerateContentResponse>(input).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_safety_rating_without_scores() {
|
||||
let input = r#"{
|
||||
"candidates": [
|
||||
{
|
||||
"content": {
|
||||
"role": "model",
|
||||
"parts": [
|
||||
{
|
||||
"text": "Return text"
|
||||
}
|
||||
]
|
||||
},
|
||||
"finishReason": "STOP",
|
||||
"safetyRatings": [
|
||||
{
|
||||
"category": "HARM_CATEGORY_HATE_SPEECH",
|
||||
"probability": "NEGLIGIBLE",
|
||||
"severity": "HARM_SEVERITY_NEGLIGIBLE"
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
"probability": "NEGLIGIBLE",
|
||||
"severity": "HARM_SEVERITY_NEGLIGIBLE"
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_HARASSMENT",
|
||||
"probability": "NEGLIGIBLE",
|
||||
"severity": "HARM_SEVERITY_NEGLIGIBLE"
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
"probability": "NEGLIGIBLE",
|
||||
"severity": "HARM_SEVERITY_NEGLIGIBLE"
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"usageMetadata": {
|
||||
"promptTokenCount": 5492,
|
||||
"candidatesTokenCount": 1256,
|
||||
"totalTokenCount": 6748
|
||||
}
|
||||
}"#;
|
||||
serde_json::from_str::<GenerateContentResponse>(input).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//! Request and response types for the Gemini API.
|
||||
|
||||
mod common;
|
||||
mod count_tokens;
|
||||
mod error;
|
||||
|
||||
@@ -2,13 +2,15 @@ use serde::{Deserialize, Serialize};
|
||||
use serde_with::base64::Base64;
|
||||
use serde_with::serde_as;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
/// Request body for the Imagen image generation `predict` endpoint.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct PredictImageRequest {
|
||||
pub instances: Vec<PredictImageRequestPrompt>,
|
||||
pub parameters: PredictImageRequestParameters,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
/// A text prompt instance for image generation.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct PredictImageRequestPrompt {
|
||||
/// The text prompt for the image.
|
||||
/// The following models support different values for this parameter:
|
||||
@@ -20,7 +22,8 @@ pub struct PredictImageRequestPrompt {
|
||||
pub prompt: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Serialize, Deserialize)]
|
||||
/// Parameters controlling image generation behavior.
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PredictImageRequestParameters {
|
||||
/// The number of images to generate. The default value is 4.
|
||||
@@ -139,7 +142,8 @@ pub struct PredictImageRequestParameters {
|
||||
pub storage_uri: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
/// Output format options for generated images.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PredictImageRequestParametersOutputOptions {
|
||||
/// The image format that the output should be saved as. The following values are supported:
|
||||
@@ -155,13 +159,15 @@ pub struct PredictImageRequestParametersOutputOptions {
|
||||
pub compression_quality: Option<i32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
/// A successful response from the Imagen `predict` endpoint.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct PredictImageResponse {
|
||||
pub predictions: Vec<PredictImageResponsePrediction>,
|
||||
}
|
||||
|
||||
/// A single generated image from the prediction response.
|
||||
#[serde_as]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PredictImageResponsePrediction {
|
||||
#[serde_as(as = "Base64")]
|
||||
@@ -169,7 +175,8 @@ pub struct PredictImageResponsePrediction {
|
||||
pub mime_type: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
/// Controls whether generated images may include people.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum PersonGeneration {
|
||||
DontAllow,
|
||||
@@ -177,7 +184,8 @@ pub enum PersonGeneration {
|
||||
AllowAll,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
/// Safety filter level for image generation.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum PredictImageSafetySetting {
|
||||
BlockLowAndAbove,
|
||||
|
||||
@@ -3,18 +3,27 @@ use serde::{Deserialize, Serialize};
|
||||
use crate::error::{Error, Result};
|
||||
use crate::prelude::VertexApiError;
|
||||
|
||||
/// Request body for the text embeddings `predict` endpoint.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct TextEmbeddingRequest {
|
||||
/// The list of text instances to embed.
|
||||
pub instances: Vec<TextEmbeddingRequestInstance>,
|
||||
}
|
||||
|
||||
/// A single text instance to embed.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct TextEmbeddingRequestInstance {
|
||||
/// The text content to generate an embedding for.
|
||||
pub content: String,
|
||||
/// The task type for the embedding (e.g. `"RETRIEVAL_DOCUMENT"`, `"RETRIEVAL_QUERY"`).
|
||||
pub task_type: String,
|
||||
/// An optional title for the content (used with retrieval task types).
|
||||
pub title: Option<String>,
|
||||
}
|
||||
|
||||
/// The raw response from the text embeddings endpoint, which may be a success or an error.
|
||||
///
|
||||
/// Use [`into_result`](TextEmbeddingResponse::into_result) to convert into a standard `Result`.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum TextEmbeddingResponse {
|
||||
@@ -23,13 +32,16 @@ pub enum TextEmbeddingResponse {
|
||||
}
|
||||
|
||||
impl TextEmbeddingResponse {
|
||||
/// Converts this response into a `Result`, mapping the error variant to [`crate::error::Error`].
|
||||
pub fn into_result(self) -> Result<TextEmbeddingResponseOk> {
|
||||
self.into()
|
||||
}
|
||||
}
|
||||
|
||||
/// A successful response from the text embeddings endpoint.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct TextEmbeddingResponseOk {
|
||||
/// The embedding predictions, one per input instance.
|
||||
pub predictions: Vec<TextEmbeddingPrediction>,
|
||||
}
|
||||
|
||||
@@ -42,19 +54,27 @@ impl From<TextEmbeddingResponse> for Result<TextEmbeddingResponseOk> {
|
||||
}
|
||||
}
|
||||
|
||||
/// A single embedding prediction.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct TextEmbeddingPrediction {
|
||||
/// The embedding result containing the vector and statistics.
|
||||
pub embeddings: TextEmbeddingResult,
|
||||
}
|
||||
|
||||
/// The embedding vector and associated statistics.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct TextEmbeddingResult {
|
||||
/// Statistics about the embedding computation.
|
||||
pub statistics: TextEmbeddingStatistics,
|
||||
/// The embedding vector.
|
||||
pub values: Vec<f64>,
|
||||
}
|
||||
|
||||
/// Statistics about a text embedding computation.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct TextEmbeddingStatistics {
|
||||
/// Whether the input was truncated to fit the model's context window.
|
||||
pub truncated: bool,
|
||||
/// The number of tokens in the input.
|
||||
pub token_count: u32,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user