From bbe5433d8355e73c2743301bdbf68d574b86f832 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Cipriani=20Bandarra?= Date: Sat, 5 Oct 2024 17:07:18 +0100 Subject: [PATCH] Adds code for image generation --- .gitignore | 1 + Cargo.toml | 2 + examples/generate_image.rs | 45 ++++++++++++ src/client.rs | 24 +++++++ src/types/mod.rs | 2 + src/types/predict_image.rs | 142 +++++++++++++++++++++++++++++++++++++ 6 files changed, 216 insertions(+) create mode 100644 examples/generate_image.rs create mode 100644 src/types/predict_image.rs diff --git a/.gitignore b/.gitignore index 19002ec..77063c3 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ Cargo.lock .cargo/config.toml .DS_Store /.idea +output.jpg diff --git a/Cargo.toml b/Cargo.toml index 1aee212..d7a9b3e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,12 +13,14 @@ reqwest = { version = "0.12", features = ["json", "gzip"] } reqwest-eventsource = "0.6" serde = { version = "1", features = ["derive"] } serde_json = { version = "1"} +serde_with = { version = "3.9", features = ["base64"]} tracing = "0.1" tokio = { version = "1" } [dev-dependencies] console = "0.15.8" dialoguer = "0.11.0" +image = "0.25.2" indicatif = "0.17.8" tokio = { version = "1.37.0", features = ["full"] } tracing-subscriber = "0.3.18" diff --git a/examples/generate_image.rs b/examples/generate_image.rs new file mode 100644 index 0000000..117c4fd --- /dev/null +++ b/examples/generate_image.rs @@ -0,0 +1,45 @@ +use std::{error::Error, io::Cursor}; + +use gemini_rs::prelude::{ + GeminiClient, PredictImageRequest, PredictImageRequestParameters, PredictImageRequestPrompt, +}; +use image::{ImageFormat, ImageReader}; + +#[tokio::main] +pub async fn main() -> Result<(), Box> { + tracing_subscriber::fmt().init(); + let authentication_manager = gcp_auth::provider().await?; + let api_endpoint = std::env::var("API_ENDPOINT")?; + let project_id = std::env::var("PROJECT_ID")?; + let location_id = std::env::var("LOCATION_ID")?; + + let gemini = 
GeminiClient::new( + authentication_manager, + api_endpoint, + project_id, + location_id, + ); + + let prompt = " + Create an image of a tuxedo cat riding a rocket to the moon."; + let request = PredictImageRequest { + instances: vec![PredictImageRequestPrompt { + prompt: prompt.to_string(), + }], + parameters: PredictImageRequestParameters { + sample_count: 1, + ..Default::default() + }, + }; + let mut result = gemini + .predict_image(&request, "imagen-3.0-fast-generate-001") + .await?; + + let result = result.predictions.pop().unwrap(); + + let format = ImageFormat::from_mime_type(result.mime_type).unwrap(); + let img = + ImageReader::with_format(Cursor::new(result.bytes_base64_encoded), format).decode()?; + img.save("output.jpg")?; + Ok(()) +} diff --git a/src/client.rs b/src/client.rs index 89a57d5..92bfce7 100644 --- a/src/client.rs +++ b/src/client.rs @@ -11,6 +11,7 @@ use crate::prelude::{ GenerateContentResponse, GenerateContentResponseResult, GenerationConfig, TextEmbeddingRequest, TextEmbeddingResponse, }; +use crate::types::{PredictImageRequest, PredictImageResponse}; use crate::{prelude::Part, token_provider::TokenProvider}; pub static AUTH_SCOPE: &[&str] = &["https://www.googleapis.com/auth/cloud-platform"]; @@ -252,4 +253,27 @@ impl GeminiClient { tracing::debug!("count_tokens response: {:?}", txt_json); Ok(serde_json::from_str(&txt_json)?) } + + pub async fn predict_image( + &self, + request: &PredictImageRequest, + model: &str, + ) -> Result { + let endpoint_url = format!( + "https://{}/v1/projects/{}/locations/{}/publishers/google/models/{}:predict", + self.api_endpoint, self.project_id, self.location_id, model, + ); + + let access_token = self.token_provider.get_token(AUTH_SCOPE).await?; + let resp = self + .client + .post(&endpoint_url) + .bearer_auth(access_token) + .json(&request) + .send() + .await?; + + let txt_json = resp.text().await?; + Ok(serde_json::from_str(&txt_json)?) 
+ } } diff --git a/src/types/mod.rs b/src/types/mod.rs index 42712fd..8fa1ed5 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -2,10 +2,12 @@ mod common; mod count_tokens; mod error; mod generate_content; +mod predict_image; mod text_embeddings; pub use common::*; pub use count_tokens::*; pub use error::*; pub use generate_content::*; +pub use predict_image::*; pub use text_embeddings::*; diff --git a/src/types/predict_image.rs b/src/types/predict_image.rs new file mode 100644 index 0000000..139052b --- /dev/null +++ b/src/types/predict_image.rs @@ -0,0 +1,142 @@ +use serde::{Deserialize, Serialize}; +use serde_with::base64::Base64; +use serde_with::serde_as; + +#[derive(Debug, Serialize, Deserialize)] +pub struct PredictImageRequest { + pub instances: Vec, + pub parameters: PredictImageRequestParameters, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct PredictImageRequestPrompt { + /// The text prompt for the image. + /// The following models support different values for this parameter: + /// - `imagen-3.0-generate-001`: up to 480 tokens. + /// - `imagen-3.0-fast-generate-001`: up to 480 tokens. + /// - `imagegeneration@006`: up to 128 tokens. + /// - `imagegeneration@005`: up to 128 tokens. + /// - `imagegeneration@002`: up to 64 tokens. + pub prompt: String, +} + +#[derive(Debug, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PredictImageRequestParameters { + /// The number of images to generate. The default value is 4. + /// The following models support different values for this parameter: + /// - `imagen-3.0-generate-001`: 1 to 4. + /// - `imagen-3.0-fast-generate-001`: 1 to 4. + /// - `imagegeneration@006`: 1 to 4. + /// - `imagegeneration@005`: 1 to 4. + /// - `imagegeneration@002`: 1 to 8. + pub sample_count: i32, + + /// The random seed for image generation. This is not available when addWatermark is set to + /// true. 
+ #[serde(skip_serializing_if = "Option::is_none")] + pub seed: Option, + + /// A description of what to discourage in the generated images. + /// The following models support this parameter: + /// - `imagen-3.0-generate-001`: up to 480 tokens. + /// - `imagen-3.0-fast-generate-001`: up to 480 tokens. + /// - `imagegeneration@006`: up to 128 tokens. + /// - `imagegeneration@005`: up to 128 tokens. + /// - `imagegeneration@002`: up to 64 tokens. + #[serde(skip_serializing_if = "Option::is_none")] + pub negative_prompt: Option, + + /// The aspect ratio for the image. The default value is "1:1". + /// The following models support different values for this parameter: + /// - `imagen-3.0-generate-001`: "1:1", "9:16", "16:9", "3:4", or "4:3". + /// - `imagen-3.0-fast-generate-001`: "1:1", "9:16", "16:9", "3:4", or "4:3". + /// - `imagegeneration@006`: "1:1", "9:16", "16:9", "3:4", or "4:3". + /// - `imagegeneration@005`: "1:1" or "9:16". + /// - `imagegeneration@002`: "1:1". + #[serde(skip_serializing_if = "Option::is_none")] + pub aspect_ratio: Option, + + /// Describes the output image format in a `PredictImageRequestParametersOutputOptions` + /// object. + #[serde(skip_serializing_if = "Option::is_none")] + pub output_options: Option, + + /// Describes the style for the generated images. The following values are supported: + /// - "photograph" + /// - "digital_art" + /// - "landscape" + /// - "sketch" + /// - "watercolor" + /// - "cyberpunk" + /// - "pop_art" + #[serde(skip_serializing_if = "Option::is_none")] + pub sample_image_style: Option, + + /// Allow generation of people by the model. The following values are supported: + /// - `"dont_allow"`: Disallow the inclusion of people or faces in images. + /// - `"allow_adult"`: Allow generation of adults only. + /// - `"allow_all"`: Allow generation of people of all ages. + /// + /// The default value is `"allow_adult"`. 
+ /// + /// Supported by the models `imagen-3.0-generate-001`, `imagen-3.0-fast-generate-001`, and + /// `imagegeneration@006` only. + #[serde(skip_serializing_if = "Option::is_none")] + pub person_generation: Option, + + /// Adds a filter level to safety filtering. The following values are supported: + /// - `"block_most"`: Strongest filtering level, most strict blocking. + /// - `"block_some"`: Block some problematic prompts and responses. + /// - `"block_few"`: Reduces the number of requests blocked due to safety filters. May + /// increase objectionable content generated by Imagen. + /// - `"block_fewest"`: Block very few problematic prompts and responses. Access to this + /// feature is restricted. + /// + /// The default value is `"block_some"`. + /// + /// Supported by the models `imagen-3.0-generate-001`, `imagen-3.0-fast-generate-001`, and + /// `imagegeneration@006` only. + #[serde(skip_serializing_if = "Option::is_none")] + pub safety_setting: Option, + + /// Add an invisible watermark to the generated images. The default value is `false` for the + /// `imagegeneration@002` and `imagegeneration@005` models, and `true` for the + /// `imagen-3.0-fast-generate-001` and `imagegeneration@006` models. + #[serde(skip_serializing_if = "Option::is_none")] + pub add_watermark: Option, + + /// Cloud Storage URI to store the generated images. + #[serde(skip_serializing_if = "Option::is_none")] + pub storage_uri: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PredictImageRequestParametersOutputOptions { + /// The image format that the output should be saved as. The following values are supported: + /// + /// - "image/png": Save as a PNG image + /// - "image/jpeg": Save as a JPEG image + /// + /// The default value is "image/png". + pub mime_type: Option, + + /// The level of compression if the output type is "image/jpeg". + /// Accepted values are 0 through 100. The default value is 75. 
+ pub compression_quality: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct PredictImageResponse { + pub predictions: Vec, +} + +#[serde_as] +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PredictImageResponsePrediction { + #[serde_as(as = "Base64")] + pub bytes_base64_encoded: Vec, + pub mime_type: String, +}