using System.Collections.Generic; using Unity.Sentis; using UnityEngine; using UnityEngine.UI; using UnityEngine.Video; using System.IO; using FF = Unity.Sentis.Functional; /* * YOLO Inference Script * ======================== * * Place this script on the Main Camera and set the script parameters according to the tooltips. * */ public class RunYOLO : MonoBehaviour { [Tooltip("Drag a YOLO model .onnx file here")] public ModelAsset modelAsset; [Tooltip("Drag the classes.txt here")] public TextAsset classesAsset; [Tooltip("Create a Raw Image in the scene and link it here")] public RawImage displayImage; [Tooltip("Drag a border box texture here")] public Texture2D borderTexture; [Tooltip("Select an appropriate font for the labels")] public Font font; [Tooltip("Change this to the name of the video you put in the Assets/StreamingAssets folder")] public string videoFilename = "giraffes.mp4"; const BackendType backend = BackendType.GPUCompute; private Transform displayLocation; private Worker worker; private string[] labels; private RenderTexture targetRT; private Sprite borderSprite; //Image size for the model private const int imageWidth = 640; private const int imageHeight = 640; private VideoPlayer video; List boxPool = new(); [Tooltip("Intersection over union threshold used for non-maximum suppression")] [SerializeField, Range(0, 1)] float iouThreshold = 0.5f; [Tooltip("Confidence score threshold used for non-maximum suppression")] [SerializeField, Range(0, 1)] float scoreThreshold = 0.5f; Tensor centersToCorners; //bounding box data public struct BoundingBox { public float centerX; public float centerY; public float width; public float height; public string label; } void Start() { Application.targetFrameRate = 60; Screen.orientation = ScreenOrientation.LandscapeLeft; //Parse neural net labels labels = classesAsset.text.Split('\n'); LoadModel(); targetRT = new RenderTexture(imageWidth, imageHeight, 0); //Create image to display video displayLocation = displayImage.transform; SetupInput(); borderSprite = Sprite.Create(borderTexture, new Rect(0, 0, borderTexture.width, borderTexture.height), new Vector2(borderTexture.width / 2, borderTexture.height / 2)); } void LoadModel() { //Load model var model1 = ModelLoader.Load(modelAsset); centersToCorners = new Tensor(new TensorShape(4, 4), new float[] { 1, 0, 1, 0, 0, 1, 0, 1, -0.5f, 0, 0.5f, 0, 0, -0.5f, 0, 0.5f }); //Here we transform the output of the model1 by feeding it through a Non-Max-Suppression layer. var graph = new FunctionalGraph(); var inputs = graph.AddInputs(model1); var modelOutput = FF.Forward(model1, inputs)[0]; //shape=(1,84,8400) var boxCoords = modelOutput[0, 0..4, ..].Transpose(0, 1); //shape=(8400,4) var allScores = modelOutput[0, 4.., ..]; //shape=(80,8400) var scores = FF.ReduceMax(allScores, 0); //shape=(8400) var classIDs = FF.ArgMax(allScores, 0); //shape=(8400) var boxCorners = FF.MatMul(boxCoords, FF.Constant(centersToCorners)); //shape=(8400,4) var indices = FF.NMS(boxCorners, scores, iouThreshold, scoreThreshold); //shape=(N) var coords = FF.IndexSelect(boxCoords, 0, indices); //shape=(N,4) var labelIDs = FF.IndexSelect(classIDs, 0, indices); //shape=(N) //Create worker to run model worker = new Worker(graph.Compile(coords, labelIDs), backend); } void SetupInput() { video = gameObject.AddComponent(); video.renderMode = VideoRenderMode.APIOnly; video.source = VideoSource.Url; video.url = Path.Join(Application.streamingAssetsPath, videoFilename); video.isLooping = true; video.Play(); } private void Update() { ExecuteML(); if (Input.GetKeyDown(KeyCode.Escape)) { Application.Quit(); } } public void ExecuteML() { ClearAnnotations(); if (video && video.texture) { float aspect = video.width * 1f / video.height; Graphics.Blit(video.texture, targetRT, new Vector2(1f / aspect, 1), new Vector2(0, 0)); displayImage.texture = targetRT; } else return; using Tensor inputTensor = new Tensor(new TensorShape(1, 3, imageHeight, imageWidth)); TextureConverter.ToTensor(targetRT, inputTensor, default); worker.Schedule(inputTensor); using var output = (worker.PeekOutput("output_0") as Tensor).ReadbackAndClone(); using var labelIDs = (worker.PeekOutput("output_1") as Tensor).ReadbackAndClone(); float displayWidth = displayImage.rectTransform.rect.width; float displayHeight = displayImage.rectTransform.rect.height; float scaleX = displayWidth / imageWidth; float scaleY = displayHeight / imageHeight; int boxesFound = output.shape[0]; //Draw the bounding boxes for (int n = 0; n < Mathf.Min(boxesFound, 200); n++) { var box = new BoundingBox { centerX = output[n, 0] * scaleX - displayWidth / 2, centerY = output[n, 1] * scaleY - displayHeight / 2, width = output[n, 2] * scaleX, height = output[n, 3] * scaleY, label = labels[labelIDs[n]], }; DrawBox(box, n, displayHeight * 0.05f); } } public void DrawBox(BoundingBox box, int id, float fontSize) { //Create the bounding box graphic or get from pool GameObject panel; if (id < boxPool.Count) { panel = boxPool[id]; panel.SetActive(true); } else { panel = CreateNewBox(Color.yellow); } //Set box position panel.transform.localPosition = new Vector3(box.centerX, -box.centerY); //Set box size RectTransform rt = panel.GetComponent(); rt.sizeDelta = new Vector2(box.width, box.height); //Set label text var label = panel.GetComponentInChildren(); label.text = box.label; label.fontSize = (int)fontSize; } public GameObject CreateNewBox(Color color) { //Create the box and set image var panel = new GameObject("ObjectBox"); panel.AddComponent(); Image img = panel.AddComponent(); img.color = color; img.sprite = borderSprite; img.type = Image.Type.Sliced; panel.transform.SetParent(displayLocation, false); //Create the label var text = new GameObject("ObjectLabel"); text.AddComponent(); text.transform.SetParent(panel.transform, false); Text txt = text.AddComponent(); txt.font = font; txt.color = color; txt.fontSize = 40; txt.horizontalOverflow = HorizontalWrapMode.Overflow; RectTransform rt2 = text.GetComponent(); rt2.offsetMin = new Vector2(20, rt2.offsetMin.y); rt2.offsetMax = new Vector2(0, rt2.offsetMax.y); rt2.offsetMin = new Vector2(rt2.offsetMin.x, 0); rt2.offsetMax = new Vector2(rt2.offsetMax.x, 30); rt2.anchorMin = new Vector2(0, 0); rt2.anchorMax = new Vector2(1, 1); boxPool.Add(panel); return panel; } public void ClearAnnotations() { foreach (var box in boxPool) { box.SetActive(false); } } private void OnDestroy() { centersToCorners?.Dispose(); worker?.Dispose(); } }