Adds code for image generation

This commit is contained in:
2024-10-05 17:07:18 +01:00
parent cb954ea5db
commit bbe5433d83
6 changed files with 216 additions and 0 deletions

1
.gitignore vendored
View File

@@ -7,3 +7,4 @@ Cargo.lock
.cargo/config.toml .cargo/config.toml
.DS_Store .DS_Store
/.idea /.idea
output.jpg

View File

@@ -13,12 +13,14 @@ reqwest = { version = "0.12", features = ["json", "gzip"] }
reqwest-eventsource = "0.6" reqwest-eventsource = "0.6"
serde = { version = "1", features = ["derive"] } serde = { version = "1", features = ["derive"] }
serde_json = { version = "1"} serde_json = { version = "1"}
serde_with = { version = "3.9", features = ["base64"]}
tracing = "0.1" tracing = "0.1"
tokio = { version = "1" } tokio = { version = "1" }
[dev-dependencies] [dev-dependencies]
console = "0.15.8" console = "0.15.8"
dialoguer = "0.11.0" dialoguer = "0.11.0"
image = "0.25.2"
indicatif = "0.17.8" indicatif = "0.17.8"
tokio = { version = "1.37.0", features = ["full"] } tokio = { version = "1.37.0", features = ["full"] }
tracing-subscriber = "0.3.18" tracing-subscriber = "0.3.18"

View File

@@ -0,0 +1,45 @@
use std::{error::Error, io::Cursor};
use gemini_rs::prelude::{
GeminiClient, PredictImageRequest, PredictImageRequestParameters, PredictImageRequestPrompt,
};
use image::{ImageFormat, ImageReader};
#[tokio::main]
pub async fn main() -> Result<(), Box<dyn Error>> {
tracing_subscriber::fmt().init();
let authentication_manager = gcp_auth::provider().await?;
let api_endpoint = std::env::var("API_ENDPOINT")?;
let project_id = std::env::var("PROJECT_ID")?;
let location_id = std::env::var("LOCATION_ID")?;
let gemini = GeminiClient::new(
authentication_manager,
api_endpoint,
project_id,
location_id,
);
let prompt = "
Create an image of a tuxedo cat riding a rocket to the moon.";
let request = PredictImageRequest {
instances: vec![PredictImageRequestPrompt {
prompt: prompt.to_string(),
}],
parameters: PredictImageRequestParameters {
sample_count: 1,
..Default::default()
},
};
let mut result = gemini
.predict_image(&request, "imagen-3.0-fast-generate-001")
.await?;
let result = result.predictions.pop().unwrap();
let format = ImageFormat::from_mime_type(result.mime_type).unwrap();
let img =
ImageReader::with_format(Cursor::new(result.bytes_base64_encoded), format).decode()?;
img.save("output.jpg")?;
Ok(())
}

View File

@@ -11,6 +11,7 @@ use crate::prelude::{
GenerateContentResponse, GenerateContentResponseResult, GenerationConfig, TextEmbeddingRequest, GenerateContentResponse, GenerateContentResponseResult, GenerationConfig, TextEmbeddingRequest,
TextEmbeddingResponse, TextEmbeddingResponse,
}; };
use crate::types::{PredictImageRequest, PredictImageResponse};
use crate::{prelude::Part, token_provider::TokenProvider}; use crate::{prelude::Part, token_provider::TokenProvider};
pub static AUTH_SCOPE: &[&str] = &["https://www.googleapis.com/auth/cloud-platform"]; pub static AUTH_SCOPE: &[&str] = &["https://www.googleapis.com/auth/cloud-platform"];
@@ -252,4 +253,27 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
tracing::debug!("count_tokens response: {:?}", txt_json); tracing::debug!("count_tokens response: {:?}", txt_json);
Ok(serde_json::from_str(&txt_json)?) Ok(serde_json::from_str(&txt_json)?)
} }
pub async fn predict_image(
&self,
request: &PredictImageRequest,
model: &str,
) -> Result<PredictImageResponse> {
let endpoint_url = format!(
"https://{}/v1/projects/{}/locations/{}/publishers/google/models/{}:predict",
self.api_endpoint, self.project_id, self.location_id, model,
);
let access_token = self.token_provider.get_token(AUTH_SCOPE).await?;
let resp = self
.client
.post(&endpoint_url)
.bearer_auth(access_token)
.json(&request)
.send()
.await?;
let txt_json = resp.text().await?;
Ok(serde_json::from_str(&txt_json)?)
}
} }

View File

@@ -2,10 +2,12 @@ mod common;
mod count_tokens; mod count_tokens;
mod error; mod error;
mod generate_content; mod generate_content;
mod predict_image;
mod text_embeddings; mod text_embeddings;
pub use common::*; pub use common::*;
pub use count_tokens::*; pub use count_tokens::*;
pub use error::*; pub use error::*;
pub use generate_content::*; pub use generate_content::*;
pub use predict_image::*;
pub use text_embeddings::*; pub use text_embeddings::*;

142
src/types/predict_image.rs Normal file
View File

@@ -0,0 +1,142 @@
use serde::{Deserialize, Serialize};
use serde_with::base64::Base64;
use serde_with::serde_as;
#[derive(Debug, Serialize, Deserialize)]
pub struct PredictImageRequest {
pub instances: Vec<PredictImageRequestPrompt>,
pub parameters: PredictImageRequestParameters,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct PredictImageRequestPrompt {
/// The text prompt for the image.
/// The following models support different values for this parameter:
/// - `imagen-3.0-generate-001`: up to 480 tokens.
/// - `imagen-3.0-fast-generate-001`: up to 480 tokens.
/// - `imagegeneration@006`: up to 128 tokens.
/// - `imagegeneration@005`: up to 128 tokens.
/// - `imagegeneration@002`: up to 64 tokens.
pub prompt: String,
}
#[derive(Debug, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PredictImageRequestParameters {
/// The number of images to generate. The default value is 4.
/// The following models support different values for this parameter:
/// - `imagen-3.0-generate-001`: 1 to 4.
/// - `imagen-3.0-fast-generate-001`: 1 to 4.
/// - `imagegeneration@006`: 1 to 4.
/// - `imagegeneration@005`: 1 to 4.
/// - `imagegeneration@002`: 1 to 8.
pub sample_count: i32,
/// The random seed for image generation. This is not available when addWatermark is set to
/// true.
#[serde(skip_serializing_if = "Option::is_none")]
pub seed: Option<u32>,
/// A description of what to discourage in the generated images.
/// The following models support this parameter:
/// - `imagen-3.0-generate-001`: up to 480 tokens.
/// - `imagen-3.0-fast-generate-001`: up to 480 tokens.
/// - `imagegeneration@006`: up to 128 tokens.
/// - `imagegeneration@005`: up to 128 tokens.
/// - `imagegeneration@002`: up to 64 tokens.
#[serde(skip_serializing_if = "Option::is_none")]
pub negative_prompt: Option<String>,
/// The aspect ratio for the image. The default value is "1:1".
/// The following models support different values for this parameter:
/// - `imagen-3.0-generate-001`: "1:1", "9:16", "16:9", "3:4", or "4:3".
/// - `imagen-3.0-fast-generate-001`: "1:1", "9:16", "16:9", "3:4", or "4:3".
/// - `imagegeneration@006`: "1:1", "9:16", "16:9", "3:4", or "4:3".
/// - `imagegeneration@005`: "1:1" or "9:16".
/// - `imagegeneration@002`: "1:1".
#[serde(skip_serializing_if = "Option::is_none")]
pub aspect_ratio: Option<String>,
/// Describes the output image format in an `PredictImageRequestParametersOutputOptions
/// object.
#[serde(skip_serializing_if = "Option::is_none")]
pub output_options: Option<PredictImageRequestParametersOutputOptions>,
/// Describes the style for the generated images. The following values are supported:
/// - "photograph"
/// - "digital_art"
/// - "landscape"
/// - "sketch"
/// - "watercolor"
/// - "cyberpunk"
/// - "pop_art"
#[serde(skip_serializing_if = "Option::is_none")]
pub sample_image_style: Option<String>,
/// Allow generation of people by the model. The following values are supported:
/// - `"dont_allow"`: Disallow the inclusion of people or faces in images.
/// - `"allow_adult"`: Allow generation of adults only.
/// - `"allow_all"`: Allow generation of people of all ages.
///
/// The default value is `"allow_adult"`.
///
/// Supported by the models `imagen-3.0-generate-001`, `imagen-3.0-fast-generate-001`, and
/// `imagegeneration@006` only.
#[serde(skip_serializing_if = "Option::is_none")]
pub person_generation: Option<String>,
/// Adds a filter level to safety filtering. The following values are supported:
/// - `"block_most"`: Strongest filtering level, most strict blocking.
/// - `"block_some"`: Block some problematic prompts and responses.
/// - `"block_few"`: Reduces the number of requests blocked due to safety filters. May
/// increase objectionable content generated by Imagen.
/// - `"block_fewest"`: Block very few problematic prompts and responses. Access to this
/// feature is restricted.
///
/// The default value is `"block_some"`.
///
/// Supported by the models `imagen-3.0-generate-001`, `imagen-3.0-fast-generate-001`, and
/// `imagegeneration@006` only.
#[serde(skip_serializing_if = "Option::is_none")]
pub safety_setting: Option<String>,
/// Add an invisible watermark to the generated images. The default value is `false` for the
/// `imagegeneration@002` and `imagegeneration@005` models, and `true` for the
/// `imagen-3.0-fast-generate-001`, `imagegeneration@006`, and imagegeneration@006 models.
#[serde(skip_serializing_if = "Option::is_none")]
pub add_watermark: Option<bool>,
/// Cloud Storage URI to store the generated images.
#[serde(skip_serializing_if = "Option::is_none")]
pub storage_uri: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PredictImageRequestParametersOutputOptions {
/// The image format that the output should be saved as. The following values are supported:
///
/// - "image/png": Save as a PNG image
/// - "image/jpeg": Save as a JPEG image
///
/// The default value is "image/png".v
pub mime_type: Option<String>,
/// The level of compression if the output type is "image/jpeg".
/// Accepted values are 0 through 100. The default value is 75.
pub compression_quality: Option<i32>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct PredictImageResponse {
pub predictions: Vec<PredictImageResponsePrediction>,
}
#[serde_as]
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PredictImageResponsePrediction {
#[serde_as(as = "Base64")]
pub bytes_base64_encoded: Vec<u8>,
pub mime_type: String,
}