Adds code for image generation
This commit is contained in:
@@ -11,6 +11,7 @@ use crate::prelude::{
|
||||
GenerateContentResponse, GenerateContentResponseResult, GenerationConfig, TextEmbeddingRequest,
|
||||
TextEmbeddingResponse,
|
||||
};
|
||||
use crate::types::{PredictImageRequest, PredictImageResponse};
|
||||
use crate::{prelude::Part, token_provider::TokenProvider};
|
||||
|
||||
pub static AUTH_SCOPE: &[&str] = &["https://www.googleapis.com/auth/cloud-platform"];
|
||||
@@ -252,4 +253,27 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
||||
tracing::debug!("count_tokens response: {:?}", txt_json);
|
||||
Ok(serde_json::from_str(&txt_json)?)
|
||||
}
|
||||
|
||||
pub async fn predict_image(
|
||||
&self,
|
||||
request: &PredictImageRequest,
|
||||
model: &str,
|
||||
) -> Result<PredictImageResponse> {
|
||||
let endpoint_url = format!(
|
||||
"https://{}/v1/projects/{}/locations/{}/publishers/google/models/{}:predict",
|
||||
self.api_endpoint, self.project_id, self.location_id, model,
|
||||
);
|
||||
|
||||
let access_token = self.token_provider.get_token(AUTH_SCOPE).await?;
|
||||
let resp = self
|
||||
.client
|
||||
.post(&endpoint_url)
|
||||
.bearer_auth(access_token)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let txt_json = resp.text().await?;
|
||||
Ok(serde_json::from_str(&txt_json)?)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,10 +2,12 @@ mod common;
|
||||
mod count_tokens;
|
||||
mod error;
|
||||
mod generate_content;
|
||||
mod predict_image;
|
||||
mod text_embeddings;
|
||||
|
||||
pub use common::*;
|
||||
pub use count_tokens::*;
|
||||
pub use error::*;
|
||||
pub use generate_content::*;
|
||||
pub use predict_image::*;
|
||||
pub use text_embeddings::*;
|
||||
|
||||
142
src/types/predict_image.rs
Normal file
142
src/types/predict_image.rs
Normal file
@@ -0,0 +1,142 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::base64::Base64;
|
||||
use serde_with::serde_as;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct PredictImageRequest {
|
||||
pub instances: Vec<PredictImageRequestPrompt>,
|
||||
pub parameters: PredictImageRequestParameters,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct PredictImageRequestPrompt {
|
||||
/// The text prompt for the image.
|
||||
/// The following models support different values for this parameter:
|
||||
/// - `imagen-3.0-generate-001`: up to 480 tokens.
|
||||
/// - `imagen-3.0-fast-generate-001`: up to 480 tokens.
|
||||
/// - `imagegeneration@006`: up to 128 tokens.
|
||||
/// - `imagegeneration@005`: up to 128 tokens.
|
||||
/// - `imagegeneration@002`: up to 64 tokens.
|
||||
pub prompt: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PredictImageRequestParameters {
|
||||
/// The number of images to generate. The default value is 4.
|
||||
/// The following models support different values for this parameter:
|
||||
/// - `imagen-3.0-generate-001`: 1 to 4.
|
||||
/// - `imagen-3.0-fast-generate-001`: 1 to 4.
|
||||
/// - `imagegeneration@006`: 1 to 4.
|
||||
/// - `imagegeneration@005`: 1 to 4.
|
||||
/// - `imagegeneration@002`: 1 to 8.
|
||||
pub sample_count: i32,
|
||||
|
||||
/// The random seed for image generation. This is not available when addWatermark is set to
|
||||
/// true.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub seed: Option<u32>,
|
||||
|
||||
/// A description of what to discourage in the generated images.
|
||||
/// The following models support this parameter:
|
||||
/// - `imagen-3.0-generate-001`: up to 480 tokens.
|
||||
/// - `imagen-3.0-fast-generate-001`: up to 480 tokens.
|
||||
/// - `imagegeneration@006`: up to 128 tokens.
|
||||
/// - `imagegeneration@005`: up to 128 tokens.
|
||||
/// - `imagegeneration@002`: up to 64 tokens.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub negative_prompt: Option<String>,
|
||||
|
||||
/// The aspect ratio for the image. The default value is "1:1".
|
||||
/// The following models support different values for this parameter:
|
||||
/// - `imagen-3.0-generate-001`: "1:1", "9:16", "16:9", "3:4", or "4:3".
|
||||
/// - `imagen-3.0-fast-generate-001`: "1:1", "9:16", "16:9", "3:4", or "4:3".
|
||||
/// - `imagegeneration@006`: "1:1", "9:16", "16:9", "3:4", or "4:3".
|
||||
/// - `imagegeneration@005`: "1:1" or "9:16".
|
||||
/// - `imagegeneration@002`: "1:1".
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub aspect_ratio: Option<String>,
|
||||
|
||||
/// Describes the output image format in an `PredictImageRequestParametersOutputOptions
|
||||
/// object.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub output_options: Option<PredictImageRequestParametersOutputOptions>,
|
||||
|
||||
/// Describes the style for the generated images. The following values are supported:
|
||||
/// - "photograph"
|
||||
/// - "digital_art"
|
||||
/// - "landscape"
|
||||
/// - "sketch"
|
||||
/// - "watercolor"
|
||||
/// - "cyberpunk"
|
||||
/// - "pop_art"
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub sample_image_style: Option<String>,
|
||||
|
||||
/// Allow generation of people by the model. The following values are supported:
|
||||
/// - `"dont_allow"`: Disallow the inclusion of people or faces in images.
|
||||
/// - `"allow_adult"`: Allow generation of adults only.
|
||||
/// - `"allow_all"`: Allow generation of people of all ages.
|
||||
///
|
||||
/// The default value is `"allow_adult"`.
|
||||
///
|
||||
/// Supported by the models `imagen-3.0-generate-001`, `imagen-3.0-fast-generate-001`, and
|
||||
/// `imagegeneration@006` only.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub person_generation: Option<String>,
|
||||
|
||||
/// Adds a filter level to safety filtering. The following values are supported:
|
||||
/// - `"block_most"`: Strongest filtering level, most strict blocking.
|
||||
/// - `"block_some"`: Block some problematic prompts and responses.
|
||||
/// - `"block_few"`: Reduces the number of requests blocked due to safety filters. May
|
||||
/// increase objectionable content generated by Imagen.
|
||||
/// - `"block_fewest"`: Block very few problematic prompts and responses. Access to this
|
||||
/// feature is restricted.
|
||||
///
|
||||
/// The default value is `"block_some"`.
|
||||
///
|
||||
/// Supported by the models `imagen-3.0-generate-001`, `imagen-3.0-fast-generate-001`, and
|
||||
/// `imagegeneration@006` only.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub safety_setting: Option<String>,
|
||||
|
||||
/// Add an invisible watermark to the generated images. The default value is `false` for the
|
||||
/// `imagegeneration@002` and `imagegeneration@005` models, and `true` for the
|
||||
/// `imagen-3.0-fast-generate-001`, `imagegeneration@006`, and imagegeneration@006 models.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub add_watermark: Option<bool>,
|
||||
|
||||
/// Cloud Storage URI to store the generated images.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub storage_uri: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PredictImageRequestParametersOutputOptions {
|
||||
/// The image format that the output should be saved as. The following values are supported:
|
||||
///
|
||||
/// - "image/png": Save as a PNG image
|
||||
/// - "image/jpeg": Save as a JPEG image
|
||||
///
|
||||
/// The default value is "image/png".v
|
||||
pub mime_type: Option<String>,
|
||||
|
||||
/// The level of compression if the output type is "image/jpeg".
|
||||
/// Accepted values are 0 through 100. The default value is 75.
|
||||
pub compression_quality: Option<i32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct PredictImageResponse {
|
||||
pub predictions: Vec<PredictImageResponsePrediction>,
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PredictImageResponsePrediction {
|
||||
#[serde_as(as = "Base64")]
|
||||
pub bytes_base64_encoded: Vec<u8>,
|
||||
pub mime_type: String,
|
||||
}
|
||||
Reference in New Issue
Block a user