diff --git a/src/client.rs b/src/client.rs index 14226da..d1969ff 100644 --- a/src/client.rs +++ b/src/client.rs @@ -6,6 +6,26 @@ use tokio_stream::{Stream, StreamExt}; use tokio_util::codec::LinesCodecError; use tracing::error; +/// Async client for the Google Gemini API. +/// +/// Provides methods for content generation, streaming, token counting, text embeddings, +/// and image prediction. All requests are authenticated with an API key passed at +/// construction time. +/// +/// # Example +/// +/// ```no_run +/// use google_genai::prelude::*; +/// +/// # async fn run() -> google_genai::error::Result<()> { +/// let client = GeminiClient::new("YOUR_API_KEY".into()); +/// let request = GenerateContentRequest::builder() +/// .contents(vec![Content::builder().add_text_part("Hi!").build()]) +/// .build(); +/// let response = client.generate_content(&request, "gemini-2.0-flash").await?; +/// # Ok(()) +/// # } +/// ``` #[derive(Clone, Debug)] pub struct GeminiClient { client: reqwest::Client, @@ -13,6 +33,7 @@ pub struct GeminiClient { } impl GeminiClient { + /// Creates a new [`GeminiClient`] with the given API key. pub fn new(api_key: String) -> Self { GeminiClient { client: reqwest::Client::new(), @@ -20,6 +41,10 @@ impl GeminiClient { } } + /// Sends a content generation request and returns a stream of response chunks via SSE. + /// + /// Each item in the stream is a [`GenerateContentResponseResult`] containing one or more + /// candidates. Useful for displaying incremental output as it is generated. pub async fn stream_generate_content( &self, request: &GenerateContentRequest, @@ -52,6 +77,9 @@ impl GeminiClient { ) } + /// Sends a content generation request and returns the complete response. + /// + /// For streaming responses, use [`stream_generate_content`](Self::stream_generate_content). pub async fn generate_content( &self, request: &GenerateContentRequest, @@ -94,6 +122,7 @@ impl GeminiClient { } } + /// Generates text embeddings for the given input. pub async fn text_embeddings( &self, request: &TextEmbeddingRequest, @@ -134,6 +163,7 @@ impl GeminiClient { } } + /// Counts the number of tokens in the given content. pub async fn count_tokens( &self, request: &CountTokensRequest, @@ -174,6 +204,7 @@ impl GeminiClient { } } + /// Generates images from a text prompt using an Imagen model. pub async fn predict_image( &self, request: &PredictImageRequest, diff --git a/src/error.rs b/src/error.rs index 77b1692..30f1e15 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,21 +1,39 @@ +//! Error types for the Google Gemini client. + use std::fmt::Display; use tokio_util::codec::LinesCodecError; use crate::types; +/// A type alias for `Result`. pub type Result = std::result::Result; +/// Errors that can occur when using the Gemini client. #[derive(Debug)] pub enum Error { + /// An environment variable required for configuration was missing or invalid. Env(std::env::VarError), + /// An HTTP transport error from the underlying `reqwest` client. HttpClient(reqwest::Error), + /// A JSON serialization or deserialization error. Serde(serde_json::Error), + /// A structured error returned by the Vertex AI API. VertexError(types::VertexApiError), + /// A structured error returned by the Gemini API. GeminiError(types::GeminiApiError), + /// The API response contained no candidate completions. NoCandidatesError, + /// An error occurred while decoding the SSE event stream. EventSourceError(LinesCodecError), + /// The SSE event stream closed unexpectedly. EventSourceClosedError, - GenericApiError { status: u16, body: String }, + /// An API error that could not be parsed into a structured error type. + GenericApiError { + /// The HTTP status code. + status: u16, + /// The raw response body. + body: String, + }, } impl Display for Error { diff --git a/src/lib.rs b/src/lib.rs index f26b63e..9f41712 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,8 +1,37 @@ +//! Async Rust client for the Google Gemini API. +//! +//! This crate provides a high-level async client for interacting with Google's Gemini +//! generative AI models. It supports content generation (including streaming via SSE), +//! token counting, text embeddings, and image generation. +//! +//! # Usage +//! +//! ```no_run +//! use google_genai::prelude::*; +//! +//! # async fn run() -> google_genai::error::Result<()> { +//! let client = GeminiClient::new("YOUR_API_KEY".into()); +//! +//! let request = GenerateContentRequest::builder() +//! .contents(vec![ +//! Content::builder().add_text_part("Hello, Gemini!").build() +//! ]) +//! .build(); +//! +//! let response = client.generate_content(&request, "gemini-2.0-flash").await?; +//! # Ok(()) +//! # } +//! ``` + mod client; pub mod error; pub mod network; mod types; +/// Convenience re-exports of the most commonly used types. +/// +/// Importing `use google_genai::prelude::*` brings [`GeminiClient`](crate::prelude::GeminiClient) +/// and all request/response types into scope. pub mod prelude { pub use crate::client::*; pub use crate::types::*; diff --git a/src/network/event_source.rs b/src/network/event_source.rs index 53eaff9..4e74f78 100644 --- a/src/network/event_source.rs +++ b/src/network/event_source.rs @@ -1,3 +1,11 @@ +//! Server-Sent Events (SSE) decoder for streaming HTTP responses. +//! +//! Implements a [`tokio_util::codec::Decoder`] that parses an SSE byte stream into +//! [`ServerSentEvent`] values. Used internally by [`GeminiClient::stream_generate_content`] +//! to process chunked model responses. +//! +//! [`GeminiClient::stream_generate_content`]: crate::prelude::GeminiClient::stream_generate_content + use reqwest::Response; use std::mem; use tokio_stream::{Stream, StreamExt}; @@ -12,7 +20,9 @@ static DATA: &str = "data: "; static ID: &str = "id: "; static RETRY: &str = "retry: "; +/// Extension trait for converting an HTTP response into a stream of [`ServerSentEvent`]s. pub trait EventSource { + /// Consumes the response and returns a stream of parsed SSE events. fn event_stream(self) -> impl Stream>; } @@ -22,14 +32,25 @@ impl EventSource for Response { } } +/// A parsed Server-Sent Event. +/// +/// Fields correspond to the standard SSE fields: `event`, `data`, `id`, and `retry`. +/// Multiple `data:` lines within a single event are concatenated with newline separators. #[derive(Debug, Default, Clone)] pub struct ServerSentEvent { + /// The event type (from the `event:` field). pub event: Option, + /// The event payload (from one or more `data:` fields, joined by `\n`). pub data: Option, + /// The event ID (from the `id:` field). pub id: Option, + /// The reconnection time in milliseconds (from the `retry:` field). pub retry: Option, } +/// A [`Decoder`] that parses a byte stream of SSE-formatted data into [`ServerSentEvent`]s. +/// +/// Wraps a [`LinesCodec`] and accumulates fields until an empty line signals the end of an event. pub struct ServerSentEventsCodec { lines_code: LinesCodec, next: ServerSentEvent, @@ -42,6 +63,7 @@ impl Default for ServerSentEventsCodec { } impl ServerSentEventsCodec { + /// Creates a new SSE codec. pub fn new() -> Self { Self { lines_code: LinesCodec::new(), @@ -95,6 +117,9 @@ impl Decoder for ServerSentEventsCodec { } } +/// Converts a [`Response`] into a stream of [`ServerSentEvent`]s. +/// +/// The response body is read as a byte stream and decoded using [`ServerSentEventsCodec`]. pub fn stream_response( response: Response, ) -> impl Stream> { diff --git a/src/network/mod.rs b/src/network/mod.rs index aafb61f..5ac91d2 100644 --- a/src/network/mod.rs +++ b/src/network/mod.rs @@ -1 +1,3 @@ +//! Networking utilities for streaming HTTP responses. + pub mod event_source; diff --git a/src/types/common.rs b/src/types/common.rs index 74470c9..ed3f1b8 100644 --- a/src/types/common.rs +++ b/src/types/common.rs @@ -5,13 +5,21 @@ use serde_json::Value; use crate::types::FunctionResponse; +/// A conversation message containing one or more [`Part`]s. +/// +/// See . #[derive(Clone, Default, Debug, Serialize, Deserialize)] pub struct Content { + /// The role of the message author (`user` or `model`). pub role: Option, + /// The ordered parts that make up this message. pub parts: Option>, } impl Content { + /// Concatenates all [`PartData::Text`] parts into a single string. + /// + /// Returns `None` if there are no parts. pub fn get_text(&self) -> Option { self.parts.as_ref().map(|parts| { parts @@ -24,25 +32,30 @@ impl Content { }) } + /// Creates a [`Content`] containing a single text part, suitable for use as a system instruction. pub fn system_prompt>(system_prompt: S) -> Self { Self::builder().add_text_part(system_prompt).build() } + /// Returns a new [`ContentBuilder`]. pub fn builder() -> ContentBuilder { ContentBuilder::default() } } +/// Builder for constructing [`Content`] values incrementally. #[derive(Clone, Debug, Default)] pub struct ContentBuilder { content: Content, } impl ContentBuilder { + /// Appends a text part to this content. pub fn add_text_part>(self, text: T) -> Self { self.add_part(Part::from_text(text.into())) } + /// Appends an arbitrary [`Part`] to this content. pub fn add_part(mut self, part: Part) -> Self { match &mut self.content.parts { Some(parts) => parts.push(part), @@ -51,16 +64,19 @@ impl ContentBuilder { self } + /// Sets the [`Role`] for this content. pub fn role(mut self, role: Role) -> Self { self.content.role = Some(role); self } + /// Consumes the builder and returns the constructed [`Content`]. pub fn build(self) -> Content { self.content } } +/// The role of a message author in a conversation. #[derive(Clone, Copy, Debug, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] pub enum Role { @@ -90,7 +106,9 @@ impl FromStr for Role { } } -/// See https://ai.google.dev/api/caching#Part +/// A single unit of content within a [`Content`] message. +/// +/// See . #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct Part { @@ -108,29 +126,42 @@ pub struct Part { pub data: PartData, // Create enum for data. } +/// The payload of a [`Part`], representing different content types. +/// +/// See . #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub enum PartData { + /// Plain text content. Text(String), - // https://ai.google.dev/api/caching#Blob + /// Binary data encoded inline. See . InlineData { + /// The IANA MIME type of the data (e.g. `"image/png"`). mime_type: String, + /// Base64-encoded binary data. data: String, }, - // https://ai.google.dev/api/caching#FunctionCall + /// A function call requested by the model. See . FunctionCall { + /// Optional unique identifier for the function call. id: Option, + /// The name of the function to call. name: String, + /// The arguments to pass, as a JSON object. args: Option, }, - // https://ai.google.dev/api/caching#FunctionResponse + /// A response to a function call. See . FunctionResponse(FunctionResponse), + /// A reference to a file stored in the API. FileData(Value), + /// Code to be executed by the model. ExecutableCode(Value), + /// The result of executing code. CodeExecutionResult(Value), } impl Part { + /// Creates a [`Part`] containing only text. pub fn from_text>(text: S) -> Self { Self { thought: None, diff --git a/src/types/count_tokens.rs b/src/types/count_tokens.rs index 2a94893..31cb4a0 100644 --- a/src/types/count_tokens.rs +++ b/src/types/count_tokens.rs @@ -4,23 +4,32 @@ use crate::error::{Error, Result}; use super::Content; +/// Request body for the `countTokens` endpoint. +/// +/// Use [`CountTokensRequest::builder`] for ergonomic construction. +/// +/// See . #[derive(Clone, Debug, Serialize, Deserialize)] pub struct CountTokensRequest { + /// The content to count tokens for. pub contents: Content, } impl CountTokensRequest { + /// Returns a new [`CountTokensRequestBuilder`]. pub fn builder() -> CountTokensRequestBuilder { CountTokensRequestBuilder::default() } } +/// Builder for [`CountTokensRequest`]. #[derive(Debug, Default)] pub struct CountTokensRequestBuilder { contents: Content, } impl CountTokensRequestBuilder { + /// Creates a builder pre-populated with a single text prompt. pub fn from_prompt(prompt: &str) -> Self { CountTokensRequestBuilder { contents: Content { @@ -30,6 +39,7 @@ impl CountTokensRequestBuilder { } } + /// Consumes the builder and returns the constructed [`CountTokensRequest`]. pub fn build(self) -> CountTokensRequest { CountTokensRequest { contents: self.contents, @@ -37,6 +47,9 @@ impl CountTokensRequestBuilder { } } +/// The raw response from the `countTokens` endpoint, which may be a success or an error. +/// +/// Use [`into_result`](CountTokensResponse::into_result) to convert into a standard `Result`. #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(untagged)] pub enum CountTokensResponse { @@ -45,6 +58,7 @@ pub enum CountTokensResponse { } impl CountTokensResponse { + /// Converts this response into a `Result`, mapping the error variant to [`crate::error::Error`]. pub fn into_result(self) -> Result { match self { CountTokensResponse::Ok(result) => Ok(result), @@ -53,9 +67,12 @@ impl CountTokensResponse { } } +/// A successful response from the `countTokens` endpoint. #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CountTokensResponseResult { + /// The total number of tokens in the input. pub total_tokens: i32, + /// The total number of billable characters in the input. pub total_billable_characters: u32, } diff --git a/src/types/error.rs b/src/types/error.rs index 8c0e6ae..9c30c77 100644 --- a/src/types/error.rs +++ b/src/types/error.rs @@ -2,11 +2,16 @@ use std::fmt::Formatter; use serde::{Deserialize, Serialize}; +/// A structured error returned by the Vertex AI / Gemini API. #[derive(Clone, Debug, Serialize, Deserialize)] pub struct VertexApiError { + /// The HTTP status code. pub code: i32, + /// A human-readable error message. pub message: String, + /// The gRPC status string (e.g. `"INVALID_ARGUMENT"`). pub status: String, + /// Optional additional error details. pub details: Option>, } @@ -19,8 +24,12 @@ impl core::fmt::Display for VertexApiError { impl std::error::Error for VertexApiError {} +/// A wrapper around [`VertexApiError`] matching the Gemini API error response format. +/// +/// The Gemini API nests the error details inside an `error` field. #[derive(Clone, Debug, Serialize, Deserialize)] pub struct GeminiApiError { + /// The inner error details. pub error: VertexApiError, } diff --git a/src/types/generate_content.rs b/src/types/generate_content.rs index 094e541..9d733c9 100644 --- a/src/types/generate_content.rs +++ b/src/types/generate_content.rs @@ -4,6 +4,11 @@ use serde_json::Value; use super::{Content, VertexApiError}; use crate::error::Result; +/// Request body for the `generateContent` and `streamGenerateContent` endpoints. +/// +/// Use [`GenerateContentRequest::builder`] for ergonomic construction. +/// +/// See . #[derive(Clone, Debug, Default, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct GenerateContentRequest { @@ -19,11 +24,13 @@ pub struct GenerateContentRequest { } impl GenerateContentRequest { + /// Returns a new [`GenerateContentRequestBuilder`]. pub fn builder() -> GenerateContentRequestBuilder { GenerateContentRequestBuilder::new() } } +/// Builder for [`GenerateContentRequest`]. #[derive(Debug)] pub struct GenerateContentRequestBuilder { request: GenerateContentRequest, @@ -36,36 +43,45 @@ impl GenerateContentRequestBuilder { } } + /// Sets the conversation contents. pub fn contents(mut self, contents: Vec) -> Self { self.request.contents = contents; self } + /// Sets the generation configuration. pub fn generation_config(mut self, generation_config: GenerationConfig) -> Self { self.request.generation_config = Some(generation_config); self } + /// Sets the tools available to the model (e.g. function calling, Google Search). pub fn tools(mut self, tools: Vec) -> Self { self.request.tools = Some(tools); self } + /// Sets the safety filter settings. pub fn safety_settings(mut self, safety_settings: Vec) -> Self { self.request.safety_settings = Some(safety_settings); self } + /// Sets a system instruction to guide the model's behavior. pub fn system_instruction(mut self, system_instruction: Content) -> Self { self.request.system_instruction = Some(system_instruction); self } + /// Consumes the builder and returns the constructed [`GenerateContentRequest`]. pub fn build(self) -> GenerateContentRequest { self.request } } +/// A set of tool declarations the model may use during generation. +/// +/// See . #[derive(Clone, Debug, Default, Serialize, Deserialize)] pub struct Tools { #[serde(skip_serializing_if = "Option::is_none")] @@ -79,13 +95,17 @@ pub struct Tools { pub google_search: Option, } +/// Enables the Google Search grounding tool (no configuration required). #[derive(Clone, Debug, Default, Serialize, Deserialize)] pub struct GoogleSearch {} +/// Configuration for dynamic retrieval in Google Search grounding. #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct DynamicRetrievalConfig { + /// The retrieval mode (e.g. `"MODE_DYNAMIC"`). pub mode: String, + /// The threshold for triggering retrieval. Defaults to `0.7`. #[serde(skip_serializing_if = "Option::is_none")] pub dynamic_threshold: Option, } @@ -99,12 +119,19 @@ impl Default for DynamicRetrievalConfig { } } +/// Google Search retrieval tool with dynamic retrieval configuration. #[derive(Clone, Debug, Default, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct GoogleSearchRetrieval { + /// Configuration controlling when retrieval is triggered. pub dynamic_retrieval_config: DynamicRetrievalConfig, } +/// Parameters that control how the model generates content. +/// +/// Use [`GenerationConfig::builder`] for ergonomic construction. +/// +/// See . #[derive(Clone, Debug, Serialize, Deserialize, Default)] #[serde(rename_all = "camelCase")] pub struct GenerationConfig { @@ -129,11 +156,13 @@ pub struct GenerationConfig { } impl GenerationConfig { + /// Returns a new [`GenerationConfigBuilder`]. pub fn builder() -> GenerationConfigBuilder { GenerationConfigBuilder::new() } } +/// Builder for [`GenerationConfig`]. #[derive(Debug)] pub struct GenerationConfigBuilder { generation_config: GenerationConfig, @@ -191,11 +220,13 @@ impl GenerationConfigBuilder { self } + /// Consumes the builder and returns the constructed [`GenerationConfig`]. pub fn build(self) -> GenerationConfig { self.generation_config } } +/// Configuration for the model's "thinking" (chain-of-thought) behavior. #[derive(Clone, Debug, Default, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ThinkingConfig { @@ -206,6 +237,7 @@ pub struct ThinkingConfig { pub thinking_level: Option, } +/// The level of thinking effort the model should use. #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "SCREAMING_SNAKE_CASE")] pub enum ThinkingLevel { @@ -214,6 +246,9 @@ pub enum ThinkingLevel { High, } +/// A safety filter configuration that controls blocking thresholds for harmful content. +/// +/// See . #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct SafetySetting { @@ -223,6 +258,9 @@ pub struct SafetySetting { pub method: Option, } +/// Categories of potentially harmful content. +/// +/// See . #[derive(Clone, Debug, Serialize, Deserialize)] pub enum HarmCategory { #[serde(rename = "HARM_CATEGORY_UNSPECIFIED")] @@ -237,6 +275,7 @@ pub enum HarmCategory { SexuallyExplicit, } +/// The threshold at which harmful content is blocked. #[derive(Clone, Debug, Serialize, Deserialize)] pub enum HarmBlockThreshold { #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")] @@ -251,6 +290,7 @@ pub enum HarmBlockThreshold { BlockNone, } +/// The method used to evaluate harm (severity-based or probability-based). #[derive(Clone, Debug, Serialize, Deserialize)] pub enum HarmBlockMethod { #[serde(rename = "HARM_BLOCK_METHOD_UNSPECIFIED")] @@ -261,6 +301,9 @@ pub enum HarmBlockMethod { Probability, // PROBABILITY } +/// A single candidate response generated by the model. +/// +/// See . #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct Candidate { @@ -276,6 +319,7 @@ pub struct Candidate { } impl Candidate { + /// Returns the concatenated text from this candidate's content, if any. pub fn get_text(&self) -> Option { match &self.content { Some(content) => content.get_text(), @@ -284,6 +328,7 @@ impl Candidate { } } +/// A citation to a source used by the model in its response. #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct Citation { @@ -292,12 +337,14 @@ pub struct Citation { pub uri: Option, } +/// Metadata containing citations for a candidate's content. #[derive(Clone, Debug, Serialize, Deserialize)] pub struct CitationMetadata { #[serde(alias = "citationSources")] pub citations: Vec, } +/// A safety rating for a piece of content across a specific harm category. #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct SafetyRating { @@ -308,6 +355,7 @@ pub struct SafetyRating { pub severity_score: Option, } +/// Token usage statistics for a generate content request/response. #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct UsageMetadata { @@ -316,6 +364,9 @@ pub struct UsageMetadata { pub total_token_count: Option, } +/// A declaration of a function the model may call. +/// +/// See . #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct FunctionDeclaration { @@ -332,7 +383,7 @@ pub struct FunctionDeclaration { pub response_json_schema: Option, } -/// See https://ai.google.dev/api/caching#FunctionResponse +/// See . #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct FunctionResponse { @@ -345,14 +396,14 @@ pub struct FunctionResponse { pub scheduling: Option, } -/// See https://ai.google.dev/api/caching#FunctionResponsePart +/// See . #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub enum FunctionResponsePart { InlineData(FunctionResponseBlob), } -/// See https://ai.google.dev/api/caching#FunctionResponseBlob +/// See . #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct FunctionResponseBlob { @@ -360,7 +411,7 @@ pub struct FunctionResponseBlob { pub data: String, } -/// See https://ai.google.dev/api/caching#Scheduling +/// See . #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "SCREAMING_SNAKE_CASE")] pub enum Scheduling { @@ -370,6 +421,7 @@ pub enum Scheduling { Interrupt, } +/// A single property within a function's parameter schema. #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct FunctionParametersProperty { @@ -377,6 +429,10 @@ pub struct FunctionParametersProperty { pub description: String, } +/// The raw response from the `generateContent` endpoint, which may be a success or an error. +/// +/// Use [`into_result`](GenerateContentResponse::into_result) to convert into a standard +/// `Result`. #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(untagged)] pub enum GenerateContentResponse { @@ -393,6 +449,7 @@ impl From for Result { } } +/// A successful response from the `generateContent` endpoint. #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct GenerateContentResponseResult { @@ -400,12 +457,14 @@ pub struct GenerateContentResponseResult { pub usage_metadata: Option, } +/// An error response from the `generateContent` endpoint. #[derive(Clone, Debug, Serialize, Deserialize)] pub struct GenerateContentResponseError { pub error: VertexApiError, } impl GenerateContentResponse { + /// Converts this response into a `Result`, mapping the error variant to [`crate::error::Error`]. pub fn into_result(self) -> Result { match self { GenerateContentResponse::Ok(result) => Ok(result), diff --git a/src/types/mod.rs b/src/types/mod.rs index 8fa1ed5..24ab35f 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -1,3 +1,5 @@ +//! Request and response types for the Gemini API. + mod common; mod count_tokens; mod error; diff --git a/src/types/predict_image.rs b/src/types/predict_image.rs index 31c751e..4e15087 100644 --- a/src/types/predict_image.rs +++ b/src/types/predict_image.rs @@ -2,12 +2,14 @@ use serde::{Deserialize, Serialize}; use serde_with::base64::Base64; use serde_with::serde_as; +/// Request body for the Imagen image generation `predict` endpoint. #[derive(Clone, Debug, Serialize, Deserialize)] pub struct PredictImageRequest { pub instances: Vec, pub parameters: PredictImageRequestParameters, } +/// A text prompt instance for image generation. #[derive(Clone, Debug, Serialize, Deserialize)] pub struct PredictImageRequestPrompt { /// The text prompt for the image. @@ -20,6 +22,7 @@ pub struct PredictImageRequestPrompt { pub prompt: String, } +/// Parameters controlling image generation behavior. #[derive(Clone, Debug, Default, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct PredictImageRequestParameters { @@ -139,6 +142,7 @@ pub struct PredictImageRequestParameters { pub storage_uri: Option, } +/// Output format options for generated images. #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct PredictImageRequestParametersOutputOptions { @@ -155,11 +159,13 @@ pub struct PredictImageRequestParametersOutputOptions { pub compression_quality: Option, } +/// A successful response from the Imagen `predict` endpoint. #[derive(Clone, Debug, Serialize, Deserialize)] pub struct PredictImageResponse { pub predictions: Vec, } +/// A single generated image from the prediction response. #[serde_as] #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] @@ -169,6 +175,7 @@ pub struct PredictImageResponsePrediction { pub mime_type: String, } +/// Controls whether generated images may include people. #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum PersonGeneration { @@ -177,6 +184,7 @@ pub enum PersonGeneration { AllowAll, } +/// Safety filter level for image generation. #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum PredictImageSafetySetting { diff --git a/src/types/text_embeddings.rs b/src/types/text_embeddings.rs index 3548307..a4cf9db 100644 --- a/src/types/text_embeddings.rs +++ b/src/types/text_embeddings.rs @@ -3,18 +3,27 @@ use serde::{Deserialize, Serialize}; use crate::error::{Error, Result}; use crate::prelude::VertexApiError; +/// Request body for the text embeddings `predict` endpoint. #[derive(Clone, Debug, Serialize, Deserialize)] pub struct TextEmbeddingRequest { + /// The list of text instances to embed. pub instances: Vec, } +/// A single text instance to embed. #[derive(Clone, Debug, Serialize, Deserialize)] pub struct TextEmbeddingRequestInstance { + /// The text content to generate an embedding for. pub content: String, + /// The task type for the embedding (e.g. `"RETRIEVAL_DOCUMENT"`, `"RETRIEVAL_QUERY"`). pub task_type: String, + /// An optional title for the content (used with retrieval task types). pub title: Option, } +/// The raw response from the text embeddings endpoint, which may be a success or an error. +/// +/// Use [`into_result`](TextEmbeddingResponse::into_result) to convert into a standard `Result`. #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(untagged)] pub enum TextEmbeddingResponse { @@ -23,13 +32,16 @@ pub enum TextEmbeddingResponse { } impl TextEmbeddingResponse { + /// Converts this response into a `Result`, mapping the error variant to [`crate::error::Error`]. pub fn into_result(self) -> Result { self.into() } } +/// A successful response from the text embeddings endpoint. #[derive(Clone, Debug, Serialize, Deserialize)] pub struct TextEmbeddingResponseOk { + /// The embedding predictions, one per input instance. pub predictions: Vec, } @@ -42,19 +54,27 @@ impl From for Result { } } +/// A single embedding prediction. #[derive(Clone, Debug, Serialize, Deserialize)] pub struct TextEmbeddingPrediction { + /// The embedding result containing the vector and statistics. pub embeddings: TextEmbeddingResult, } +/// The embedding vector and associated statistics. #[derive(Clone, Debug, Serialize, Deserialize)] pub struct TextEmbeddingResult { + /// Statistics about the embedding computation. pub statistics: TextEmbeddingStatistics, + /// The embedding vector. pub values: Vec, } +/// Statistics about a text embedding computation. #[derive(Clone, Debug, Serialize, Deserialize)] pub struct TextEmbeddingStatistics { + /// Whether the input was truncated to fit the model's context window. pub truncated: bool, + /// The number of tokens in the input. pub token_count: u32, }